From 21ca4378dfd46949a44a720d05916e96258f0cee Mon Sep 17 00:00:00 2001 From: Heiko Joerg Schick Date: Tue, 20 Aug 2024 11:27:30 +0200 Subject: [PATCH] Beautified source code with Black --- ai_model.py | 137 +++++++++++++++++++++++++++++++--------------------- yolo.py | 75 +++++++++++++++++++--------- 2 files changed, 132 insertions(+), 80 deletions(-) diff --git a/ai_model.py b/ai_model.py index f946cf5..a0d2eed 100644 --- a/ai_model.py +++ b/ai_model.py @@ -2,10 +2,11 @@ from abc import ABC, abstractmethod from openai import OpenAI from groq import Groq from ollama import Client -from openai import AzureOpenAI +from openai import AzureOpenAI from anthropic import Anthropic import os + class AIModel(ABC): @abstractmethod def chat(self, model, messages): @@ -17,109 +18,129 @@ class AIModel(ABC): @staticmethod def get_model_client(config): - api_provider=config["api"] + api_provider = config["api"] - if api_provider == "" or api_provider==None: + if api_provider == "" or api_provider == None: api_provider = "groq" - + if api_provider == "groq": return GroqModel(api_key=os.environ.get("GROQ_API_KEY")) - + elif api_provider == "openai": api_key = os.getenv("OPENAI_API_KEY") - if not api_key: - api_key=config["openai_api_key"] - if not api_key: #If statement to avoid "invalid filepath" error - home_path = os.path.expanduser("~") - api_key=open(os.path.join(home_path,".openai.apikey"), "r").readline().strip() + if not api_key: + api_key = config["openai_api_key"] + if not api_key: # If statement to avoid "invalid filepath" error + home_path = os.path.expanduser("~") + api_key = ( + open(os.path.join(home_path, ".openai.apikey"), "r") + .readline() + .strip() + ) api_key = api_key return OpenAIModel(api_key=api_key) - + elif api_provider == "azure": api_key = os.getenv("AZURE_OPENAI_API_KEY") - if not api_key: - api_key=config["azure_openai_api_key"] - if not api_key: - home_path = os.path.expanduser("~") - api_key=open(os.path.join(home_path,".azureopenai.apikey"), "r").readline().strip() + if not api_key: + api_key = config["azure_openai_api_key"] + if not api_key: + home_path = os.path.expanduser("~") + api_key = ( + open(os.path.join(home_path, ".azureopenai.apikey"), "r") + .readline() + .strip() + ) return AzureOpenAIModel( - api_key=api_key, - azure_endpoint=config["azure_endpoint"], - api_version=config["azure_api_version"]) - + api_key=api_key, + azure_endpoint=config["azure_endpoint"], + api_version=config["azure_api_version"], + ) + elif api_provider == "ollama": - ollama_api = os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434") - #ollama_model = os.environ.get("OLLAMA_MODEL", "llama3-8b-8192") + ollama_api = os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434") + # ollama_model = os.environ.get("OLLAMA_MODEL", "llama3-8b-8192") return OllamaModel(ollama_api) if api_provider == "anthropic": api_key = os.getenv("ANTHROPIC_API_KEY") - if not api_key: - api_key=config["anthropic_api_key"] - return AnthropicModel(api_key=api_key) + if not api_key: + api_key = config["anthropic_api_key"] + return AnthropicModel(api_key=api_key) else: raise ValueError(f"Invalid AI model provider: {api_provider}") + class GroqModel(AIModel): def __init__(self, api_key): self.client = Groq(api_key=api_key) def chat(self, messages, model, temperature, max_tokens): - resp = self.client.chat.completions.create(model=model, - messages=messages, - temperature=temperature, - max_tokens=max_tokens) + resp = self.client.chat.completions.create( + model=model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + ) return resp.choices[0].message.content - + def moderate(self, message): pass + class OpenAIModel(AIModel): def __init__(self, api_key): self.client = OpenAI(api_key=api_key) def chat(self, messages, model, temperature, max_tokens): - resp = self.client.chat.completions.create(model=model, - messages=messages, - temperature=temperature, - max_tokens=max_tokens) - + resp = self.client.chat.completions.create( + model=model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + ) + return resp.choices[0].message.content - + def moderate(self, message): return self.client.moderations.create(input=message) + class OllamaModel(AIModel): def __init__(self, host): self.client = Client(host=host) - + def chat(self, messages, model, temperature, max_tokens): - resp = self.client.chat(model=model, - messages=messages) + resp = self.client.chat(model=model, messages=messages) return resp["message"]["content"] - + def moderate(self, message): pass class AzureOpenAIModel(AIModel): def __init__(self, azure_endpoint, api_key, api_version): - self.client = AzureOpenAI(azure_endpoint=azure_endpoint, api_key=api_key, api_version=api_version) + self.client = AzureOpenAI( + azure_endpoint=azure_endpoint, api_key=api_key, api_version=api_version + ) def chat(self, messages, model, temperature, max_tokens): - resp = self.client.chat.completions.create(model=model, - messages=messages, - temperature=temperature, - max_tokens=max_tokens) - + resp = self.client.chat.completions.create( + model=model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + ) + return resp.choices[0].message.content - + def moderate(self, message): return self.client.moderations.create(input=message) + class AnthropicModel(AIModel): def __init__(self, api_key): self.client = Anthropic(api_key=api_key) @@ -127,19 +148,23 @@ class AnthropicModel(AIModel): def chat(self, messages, model, temperature, max_tokens): ## Anthropic requires the system prompt to be passed separately ## Hence extracting system prompt role from the messages - ## and then passing the messages without the system role + ## and then passing the messages without the system role ## messages is not subscriptable, so we need to convert it to a list - system_prompt = next((m.get("content", "") for m in messages if m.get("role") == "system"), "") + system_prompt = next( + (m.get("content", "") for m in messages if m.get("role") == "system"), "" + ) # Remove system messages from the list user_messages = [m for m in messages if m.get("role") != "system"] - resp = self.client.messages.create(model=model, - system=system_prompt, - messages=user_messages, - temperature=temperature, - max_tokens=max_tokens) - + resp = self.client.messages.create( + model=model, + system=system_prompt, + messages=user_messages, + temperature=temperature, + max_tokens=max_tokens, + ) + return resp.content[0].text - + def moderate(self, message): pass diff --git a/yolo.py b/yolo.py index 0e87d4f..656da5d 100755 --- a/yolo.py +++ b/yolo.py @@ -15,6 +15,7 @@ of every command or script they frequently use. Sources: — https://github.com/wunderwuzzi23/yolo-ai-cmdbot """ + import os import platform import subprocess @@ -35,6 +36,7 @@ from ai_model import AIModel CONFIG_FILE = "yolo.yaml" PROMPT_FILE = "yolo.prompt" + def read_yaml_config() -> any: """ Read the configuration file from the executing directory. @@ -51,9 +53,10 @@ def read_yaml_config() -> any: prompt_path = os.path.dirname(yolo_path) config_file = os.path.join(prompt_path, CONFIG_FILE) - with open(config_file, 'r', encoding='utf-8') as file: + with open(config_file, "r", encoding="utf-8") as file: return yaml.safe_load(file) + def set_openai_api_key(config): """ Set the OpenAI API key by attempting several methods. @@ -89,6 +92,7 @@ def set_openai_api_key(config): if not openai.api_key: openai.api_key = config["openai_api_key"] + def print_config(config): """ Print config information. @@ -113,6 +117,7 @@ def print_config(config): print("— Color : " + str(config["suggested_command_color"])) print("— Shell : " + str(config["shell"])) + def get_os_friendly_name(): """ Returns a friendly name of the user's operating system. @@ -140,6 +145,7 @@ def get_os_friendly_name(): return os_name + def get_system_prompt(shell): """ Retrieves and constructs a system prompt by replacing placeholders @@ -175,6 +181,7 @@ def get_system_prompt(shell): return system_prompt + def chat_completion(client, query, config): """ Generate a chat-based completion for a given query using a specified model. @@ -198,7 +205,7 @@ def chat_completion(client, query, config): SystemExit: If the query is an empty string, the function will print an error message and exit. """ if query == "": - print ("No user prompt specified.") + print("No user prompt specified.") sys.exit(-1) system_prompt = get_system_prompt(config["shell"]) @@ -211,13 +218,15 @@ def chat_completion(client, query, config): model=config["model"], messages=[ {"role": "system", "content": system_prompt}, - {"role": "user", "content": query} - ], - temperature=config["temperature"], - max_tokens=config["max_tokens"]) + {"role": "user", "content": query}, + ], + temperature=config["temperature"], + max_tokens=config["max_tokens"], + ) return response + def check_for_issue(response): """ Checks the given response for any issues and raise an error when detected. @@ -233,9 +242,10 @@ def check_for_issue(response): """ prefixes = ("sorry", "i'm sorry", "the question is not clear", "i'm", "i am") if response.lower().startswith(prefixes): - print(colored("There was an issue: "+response, 'red')) + print(colored("There was an issue: " + response, "red")) sys.exit(-1) + def check_for_markdown(response): """ Checks for the presence of markdown formatting (specifically, code snippet markdown) in the @@ -251,12 +261,17 @@ def check_for_markdown(response): response : str A response text string that needs to be examined for markdown formatting. """ - if response.count("```",2): - print(colored( - "The proposed command contains markdown, response not executed directly: \n", 'red' - ) + response) + if response.count("```", 2): + print( + colored( + "The proposed command contains markdown, response not executed directly: \n", + "red", + ) + + response + ) sys.exit(-1) + def missing_posix_display(): """ Checks if the DISPLAY environment variable is set in a POSIX-compliant shell. @@ -273,7 +288,8 @@ def missing_posix_display(): """ display = subprocess.check_output("echo $DISPLAY", shell=True) - return display == b'\n' + return display == b"\n" + def prompt_user_input(config, response): """ @@ -294,7 +310,10 @@ def prompt_user_input(config, response): response : str The proposed command which is to be printed and may be executed by the user. """ - print("Command: " + colored(response, config["suggested_command_color"], attrs=['bold'])) + print( + "Command: " + + colored(response, config["suggested_command_color"], attrs=["bold"]) + ) if config["safety"]: modify_text = "" @@ -302,12 +321,14 @@ def prompt_user_input(config, response): if config["modify"]: modify_text = " [m]modify" - prompt_text = "Execute command? [Y]es [n]o [c]opy to clipboard" + modify_text + ": " + prompt_text = ( + "Execute command? [Y]es [n]o [c]opy to clipboard" + modify_text + ": " + ) if os.name == "posix" and missing_posix_display(): - prompt_text = "Execute command? [Y]es [n]o" + modify_text + ": " + prompt_text = "Execute command? [Y]es [n]o" + modify_text + ": " - print(prompt_text, end = '') + print(prompt_text, end="") user_input = input() else: @@ -315,6 +336,7 @@ def prompt_user_input(config, response): return user_input + def evaluate_input(client, config, user_input, command): """ Evaluate the user input to either execute, modify, or copy the command. @@ -344,7 +366,7 @@ def evaluate_input(client, config, user_input, command): subprocess.run([config["shell"], "-c", command], shell=False, check=True) if user_input.upper() == "M": - print("Modify prompt: ", end = '') + print("Modify prompt: ", end="") modded_query = input() modded_response = chat_completion(client, modded_query, config) check_for_issue(modded_response) @@ -365,14 +387,18 @@ def main(): Defined starting point of source code. """ parser = argparse.ArgumentParser( - description='AI bot that translates your question to a command.' + description="AI bot that translates your question to a command." + ) + parser.add_argument("text", nargs="+", help="A sequence of strings") + parser.add_argument( + "-s", + "--safety", + action="store_true", + help="Enable safety mode (only useful when safety is off)", + ) + parser.add_argument( + "-c", "--config", action="store_true", help="Print current configuration" ) - parser.add_argument('text', nargs='+', - help='A sequence of strings') - parser.add_argument("-s", "--safety", action='store_true', - help='Enable safety mode (only useful when safety is off)') - parser.add_argument("-c", "--config", action='store_true', - help='Print current configuration') args = parser.parse_args() # Load configurations and set up client @@ -403,5 +429,6 @@ def main(): print() evaluate_input(client, config, user_input, result) + if __name__ == "__main__": main()