Beautified source code with Black

This commit is contained in:
2024-08-20 11:27:30 +02:00
parent 48c77403a3
commit 21ca4378df
2 changed files with 132 additions and 80 deletions
+81 -56
View File
@@ -2,10 +2,11 @@ from abc import ABC, abstractmethod
from openai import OpenAI from openai import OpenAI
from groq import Groq from groq import Groq
from ollama import Client from ollama import Client
from openai import AzureOpenAI from openai import AzureOpenAI
from anthropic import Anthropic from anthropic import Anthropic
import os import os
class AIModel(ABC): class AIModel(ABC):
@abstractmethod @abstractmethod
def chat(self, model, messages): def chat(self, model, messages):
@@ -17,109 +18,129 @@ class AIModel(ABC):
@staticmethod @staticmethod
def get_model_client(config): def get_model_client(config):
api_provider=config["api"] api_provider = config["api"]
if api_provider == "" or api_provider==None: if api_provider == "" or api_provider == None:
api_provider = "groq" api_provider = "groq"
if api_provider == "groq": if api_provider == "groq":
return GroqModel(api_key=os.environ.get("GROQ_API_KEY")) return GroqModel(api_key=os.environ.get("GROQ_API_KEY"))
elif api_provider == "openai": elif api_provider == "openai":
api_key = os.getenv("OPENAI_API_KEY") api_key = os.getenv("OPENAI_API_KEY")
if not api_key: if not api_key:
api_key=config["openai_api_key"] api_key = config["openai_api_key"]
if not api_key: #If statement to avoid "invalid filepath" error if not api_key: # If statement to avoid "invalid filepath" error
home_path = os.path.expanduser("~") home_path = os.path.expanduser("~")
api_key=open(os.path.join(home_path,".openai.apikey"), "r").readline().strip() api_key = (
open(os.path.join(home_path, ".openai.apikey"), "r")
.readline()
.strip()
)
api_key = api_key api_key = api_key
return OpenAIModel(api_key=api_key) return OpenAIModel(api_key=api_key)
elif api_provider == "azure": elif api_provider == "azure":
api_key = os.getenv("AZURE_OPENAI_API_KEY") api_key = os.getenv("AZURE_OPENAI_API_KEY")
if not api_key: if not api_key:
api_key=config["azure_openai_api_key"] api_key = config["azure_openai_api_key"]
if not api_key: if not api_key:
home_path = os.path.expanduser("~") home_path = os.path.expanduser("~")
api_key=open(os.path.join(home_path,".azureopenai.apikey"), "r").readline().strip() api_key = (
open(os.path.join(home_path, ".azureopenai.apikey"), "r")
.readline()
.strip()
)
return AzureOpenAIModel( return AzureOpenAIModel(
api_key=api_key, api_key=api_key,
azure_endpoint=config["azure_endpoint"], azure_endpoint=config["azure_endpoint"],
api_version=config["azure_api_version"]) api_version=config["azure_api_version"],
)
elif api_provider == "ollama": elif api_provider == "ollama":
ollama_api = os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434") ollama_api = os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434")
#ollama_model = os.environ.get("OLLAMA_MODEL", "llama3-8b-8192") # ollama_model = os.environ.get("OLLAMA_MODEL", "llama3-8b-8192")
return OllamaModel(ollama_api) return OllamaModel(ollama_api)
if api_provider == "anthropic": if api_provider == "anthropic":
api_key = os.getenv("ANTHROPIC_API_KEY") api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key: if not api_key:
api_key=config["anthropic_api_key"] api_key = config["anthropic_api_key"]
return AnthropicModel(api_key=api_key) return AnthropicModel(api_key=api_key)
else: else:
raise ValueError(f"Invalid AI model provider: {api_provider}") raise ValueError(f"Invalid AI model provider: {api_provider}")
class GroqModel(AIModel): class GroqModel(AIModel):
def __init__(self, api_key): def __init__(self, api_key):
self.client = Groq(api_key=api_key) self.client = Groq(api_key=api_key)
def chat(self, messages, model, temperature, max_tokens): def chat(self, messages, model, temperature, max_tokens):
resp = self.client.chat.completions.create(model=model, resp = self.client.chat.completions.create(
messages=messages, model=model,
temperature=temperature, messages=messages,
max_tokens=max_tokens) temperature=temperature,
max_tokens=max_tokens,
)
return resp.choices[0].message.content return resp.choices[0].message.content
def moderate(self, message): def moderate(self, message):
pass pass
class OpenAIModel(AIModel): class OpenAIModel(AIModel):
def __init__(self, api_key): def __init__(self, api_key):
self.client = OpenAI(api_key=api_key) self.client = OpenAI(api_key=api_key)
def chat(self, messages, model, temperature, max_tokens): def chat(self, messages, model, temperature, max_tokens):
resp = self.client.chat.completions.create(model=model, resp = self.client.chat.completions.create(
messages=messages, model=model,
temperature=temperature, messages=messages,
max_tokens=max_tokens) temperature=temperature,
max_tokens=max_tokens,
)
return resp.choices[0].message.content return resp.choices[0].message.content
def moderate(self, message): def moderate(self, message):
return self.client.moderations.create(input=message) return self.client.moderations.create(input=message)
class OllamaModel(AIModel): class OllamaModel(AIModel):
def __init__(self, host): def __init__(self, host):
self.client = Client(host=host) self.client = Client(host=host)
def chat(self, messages, model, temperature, max_tokens): def chat(self, messages, model, temperature, max_tokens):
resp = self.client.chat(model=model, resp = self.client.chat(model=model, messages=messages)
messages=messages)
return resp["message"]["content"] return resp["message"]["content"]
def moderate(self, message): def moderate(self, message):
pass pass
class AzureOpenAIModel(AIModel): class AzureOpenAIModel(AIModel):
def __init__(self, azure_endpoint, api_key, api_version): def __init__(self, azure_endpoint, api_key, api_version):
self.client = AzureOpenAI(azure_endpoint=azure_endpoint, api_key=api_key, api_version=api_version) self.client = AzureOpenAI(
azure_endpoint=azure_endpoint, api_key=api_key, api_version=api_version
)
def chat(self, messages, model, temperature, max_tokens): def chat(self, messages, model, temperature, max_tokens):
resp = self.client.chat.completions.create(model=model, resp = self.client.chat.completions.create(
messages=messages, model=model,
temperature=temperature, messages=messages,
max_tokens=max_tokens) temperature=temperature,
max_tokens=max_tokens,
)
return resp.choices[0].message.content return resp.choices[0].message.content
def moderate(self, message): def moderate(self, message):
return self.client.moderations.create(input=message) return self.client.moderations.create(input=message)
class AnthropicModel(AIModel): class AnthropicModel(AIModel):
def __init__(self, api_key): def __init__(self, api_key):
self.client = Anthropic(api_key=api_key) self.client = Anthropic(api_key=api_key)
@@ -127,19 +148,23 @@ class AnthropicModel(AIModel):
def chat(self, messages, model, temperature, max_tokens): def chat(self, messages, model, temperature, max_tokens):
## Anthropic requires the system prompt to be passed separately ## Anthropic requires the system prompt to be passed separately
## Hence extracting system prompt role from the messages ## Hence extracting system prompt role from the messages
## and then passing the messages without the system role ## and then passing the messages without the system role
## messages is not subscriptable, so we need to convert it to a list ## messages is not subscriptable, so we need to convert it to a list
system_prompt = next((m.get("content", "") for m in messages if m.get("role") == "system"), "") system_prompt = next(
(m.get("content", "") for m in messages if m.get("role") == "system"), ""
)
# Remove system messages from the list # Remove system messages from the list
user_messages = [m for m in messages if m.get("role") != "system"] user_messages = [m for m in messages if m.get("role") != "system"]
resp = self.client.messages.create(model=model, resp = self.client.messages.create(
system=system_prompt, model=model,
messages=user_messages, system=system_prompt,
temperature=temperature, messages=user_messages,
max_tokens=max_tokens) temperature=temperature,
max_tokens=max_tokens,
)
return resp.content[0].text return resp.content[0].text
def moderate(self, message): def moderate(self, message):
pass pass
+51 -24
View File
@@ -15,6 +15,7 @@ of every command or script they frequently use.
Sources: Sources:
— https://github.com/wunderwuzzi23/yolo-ai-cmdbot — https://github.com/wunderwuzzi23/yolo-ai-cmdbot
""" """
import os import os
import platform import platform
import subprocess import subprocess
@@ -35,6 +36,7 @@ from ai_model import AIModel
CONFIG_FILE = "yolo.yaml" CONFIG_FILE = "yolo.yaml"
PROMPT_FILE = "yolo.prompt" PROMPT_FILE = "yolo.prompt"
def read_yaml_config() -> any: def read_yaml_config() -> any:
""" """
Read the configuration file from the executing directory. Read the configuration file from the executing directory.
@@ -51,9 +53,10 @@ def read_yaml_config() -> any:
prompt_path = os.path.dirname(yolo_path) prompt_path = os.path.dirname(yolo_path)
config_file = os.path.join(prompt_path, CONFIG_FILE) config_file = os.path.join(prompt_path, CONFIG_FILE)
with open(config_file, 'r', encoding='utf-8') as file: with open(config_file, "r", encoding="utf-8") as file:
return yaml.safe_load(file) return yaml.safe_load(file)
def set_openai_api_key(config): def set_openai_api_key(config):
""" """
Set the OpenAI API key by attempting several methods. Set the OpenAI API key by attempting several methods.
@@ -89,6 +92,7 @@ def set_openai_api_key(config):
if not openai.api_key: if not openai.api_key:
openai.api_key = config["openai_api_key"] openai.api_key = config["openai_api_key"]
def print_config(config): def print_config(config):
""" """
Print config information. Print config information.
@@ -113,6 +117,7 @@ def print_config(config):
print("— Color : " + str(config["suggested_command_color"])) print("— Color : " + str(config["suggested_command_color"]))
print("— Shell : " + str(config["shell"])) print("— Shell : " + str(config["shell"]))
def get_os_friendly_name(): def get_os_friendly_name():
""" """
Returns a friendly name of the user's operating system. Returns a friendly name of the user's operating system.
@@ -140,6 +145,7 @@ def get_os_friendly_name():
return os_name return os_name
def get_system_prompt(shell): def get_system_prompt(shell):
""" """
Retrieves and constructs a system prompt by replacing placeholders Retrieves and constructs a system prompt by replacing placeholders
@@ -175,6 +181,7 @@ def get_system_prompt(shell):
return system_prompt return system_prompt
def chat_completion(client, query, config): def chat_completion(client, query, config):
""" """
Generate a chat-based completion for a given query using a specified model. Generate a chat-based completion for a given query using a specified model.
@@ -198,7 +205,7 @@ def chat_completion(client, query, config):
SystemExit: If the query is an empty string, the function will print an error message and exit. SystemExit: If the query is an empty string, the function will print an error message and exit.
""" """
if query == "": if query == "":
print ("No user prompt specified.") print("No user prompt specified.")
sys.exit(-1) sys.exit(-1)
system_prompt = get_system_prompt(config["shell"]) system_prompt = get_system_prompt(config["shell"])
@@ -211,13 +218,15 @@ def chat_completion(client, query, config):
model=config["model"], model=config["model"],
messages=[ messages=[
{"role": "system", "content": system_prompt}, {"role": "system", "content": system_prompt},
{"role": "user", "content": query} {"role": "user", "content": query},
], ],
temperature=config["temperature"], temperature=config["temperature"],
max_tokens=config["max_tokens"]) max_tokens=config["max_tokens"],
)
return response return response
def check_for_issue(response): def check_for_issue(response):
""" """
Checks the given response for any issues and raise an error when detected. Checks the given response for any issues and raise an error when detected.
@@ -233,9 +242,10 @@ def check_for_issue(response):
""" """
prefixes = ("sorry", "i'm sorry", "the question is not clear", "i'm", "i am") prefixes = ("sorry", "i'm sorry", "the question is not clear", "i'm", "i am")
if response.lower().startswith(prefixes): if response.lower().startswith(prefixes):
print(colored("There was an issue: "+response, 'red')) print(colored("There was an issue: " + response, "red"))
sys.exit(-1) sys.exit(-1)
def check_for_markdown(response): def check_for_markdown(response):
""" """
Checks for the presence of markdown formatting (specifically, code snippet markdown) in the Checks for the presence of markdown formatting (specifically, code snippet markdown) in the
@@ -251,12 +261,17 @@ def check_for_markdown(response):
response : str response : str
A response text string that needs to be examined for markdown formatting. A response text string that needs to be examined for markdown formatting.
""" """
if response.count("```",2): if response.count("```", 2):
print(colored( print(
"The proposed command contains markdown, response not executed directly: \n", 'red' colored(
) + response) "The proposed command contains markdown, response not executed directly: \n",
"red",
)
+ response
)
sys.exit(-1) sys.exit(-1)
def missing_posix_display(): def missing_posix_display():
""" """
Checks if the DISPLAY environment variable is set in a POSIX-compliant shell. Checks if the DISPLAY environment variable is set in a POSIX-compliant shell.
@@ -273,7 +288,8 @@ def missing_posix_display():
""" """
display = subprocess.check_output("echo $DISPLAY", shell=True) display = subprocess.check_output("echo $DISPLAY", shell=True)
return display == b'\n' return display == b"\n"
def prompt_user_input(config, response): def prompt_user_input(config, response):
""" """
@@ -294,7 +310,10 @@ def prompt_user_input(config, response):
response : str response : str
The proposed command which is to be printed and may be executed by the user. The proposed command which is to be printed and may be executed by the user.
""" """
print("Command: " + colored(response, config["suggested_command_color"], attrs=['bold'])) print(
"Command: "
+ colored(response, config["suggested_command_color"], attrs=["bold"])
)
if config["safety"]: if config["safety"]:
modify_text = "" modify_text = ""
@@ -302,12 +321,14 @@ def prompt_user_input(config, response):
if config["modify"]: if config["modify"]:
modify_text = " [m]modify" modify_text = " [m]modify"
prompt_text = "Execute command? [Y]es [n]o [c]opy to clipboard" + modify_text + ": " prompt_text = (
"Execute command? [Y]es [n]o [c]opy to clipboard" + modify_text + ": "
)
if os.name == "posix" and missing_posix_display(): if os.name == "posix" and missing_posix_display():
prompt_text = "Execute command? [Y]es [n]o" + modify_text + ": " prompt_text = "Execute command? [Y]es [n]o" + modify_text + ": "
print(prompt_text, end = '') print(prompt_text, end="")
user_input = input() user_input = input()
else: else:
@@ -315,6 +336,7 @@ def prompt_user_input(config, response):
return user_input return user_input
def evaluate_input(client, config, user_input, command): def evaluate_input(client, config, user_input, command):
""" """
Evaluate the user input to either execute, modify, or copy the command. Evaluate the user input to either execute, modify, or copy the command.
@@ -344,7 +366,7 @@ def evaluate_input(client, config, user_input, command):
subprocess.run([config["shell"], "-c", command], shell=False, check=True) subprocess.run([config["shell"], "-c", command], shell=False, check=True)
if user_input.upper() == "M": if user_input.upper() == "M":
print("Modify prompt: ", end = '') print("Modify prompt: ", end="")
modded_query = input() modded_query = input()
modded_response = chat_completion(client, modded_query, config) modded_response = chat_completion(client, modded_query, config)
check_for_issue(modded_response) check_for_issue(modded_response)
@@ -365,14 +387,18 @@ def main():
Defined starting point of source code. Defined starting point of source code.
""" """
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='AI bot that translates your question to a command.' description="AI bot that translates your question to a command."
)
parser.add_argument("text", nargs="+", help="A sequence of strings")
parser.add_argument(
"-s",
"--safety",
action="store_true",
help="Enable safety mode (only useful when safety is off)",
)
parser.add_argument(
"-c", "--config", action="store_true", help="Print current configuration"
) )
parser.add_argument('text', nargs='+',
help='A sequence of strings')
parser.add_argument("-s", "--safety", action='store_true',
help='Enable safety mode (only useful when safety is off)')
parser.add_argument("-c", "--config", action='store_true',
help='Print current configuration')
args = parser.parse_args() args = parser.parse_args()
# Load configurations and set up client # Load configurations and set up client
@@ -403,5 +429,6 @@ def main():
print() print()
evaluate_input(client, config, user_input, result) evaluate_input(client, config, user_input, result)
if __name__ == "__main__": if __name__ == "__main__":
main() main()