Beautified source code with Black

This commit is contained in:
2024-08-20 11:27:30 +02:00
parent 48c77403a3
commit 21ca4378df
2 changed files with 132 additions and 80 deletions
+81 -56
View File
@@ -2,10 +2,11 @@ from abc import ABC, abstractmethod
from openai import OpenAI
from groq import Groq
from ollama import Client
from openai import AzureOpenAI
from openai import AzureOpenAI
from anthropic import Anthropic
import os
class AIModel(ABC):
@abstractmethod
def chat(self, model, messages):
@@ -17,109 +18,129 @@ class AIModel(ABC):
@staticmethod
def get_model_client(config):
api_provider=config["api"]
api_provider = config["api"]
if api_provider == "" or api_provider==None:
if api_provider == "" or api_provider == None:
api_provider = "groq"
if api_provider == "groq":
return GroqModel(api_key=os.environ.get("GROQ_API_KEY"))
elif api_provider == "openai":
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
api_key=config["openai_api_key"]
if not api_key: #If statement to avoid "invalid filepath" error
home_path = os.path.expanduser("~")
api_key=open(os.path.join(home_path,".openai.apikey"), "r").readline().strip()
if not api_key:
api_key = config["openai_api_key"]
if not api_key: # If statement to avoid "invalid filepath" error
home_path = os.path.expanduser("~")
api_key = (
open(os.path.join(home_path, ".openai.apikey"), "r")
.readline()
.strip()
)
api_key = api_key
return OpenAIModel(api_key=api_key)
elif api_provider == "azure":
api_key = os.getenv("AZURE_OPENAI_API_KEY")
if not api_key:
api_key=config["azure_openai_api_key"]
if not api_key:
home_path = os.path.expanduser("~")
api_key=open(os.path.join(home_path,".azureopenai.apikey"), "r").readline().strip()
if not api_key:
api_key = config["azure_openai_api_key"]
if not api_key:
home_path = os.path.expanduser("~")
api_key = (
open(os.path.join(home_path, ".azureopenai.apikey"), "r")
.readline()
.strip()
)
return AzureOpenAIModel(
api_key=api_key,
azure_endpoint=config["azure_endpoint"],
api_version=config["azure_api_version"])
api_key=api_key,
azure_endpoint=config["azure_endpoint"],
api_version=config["azure_api_version"],
)
elif api_provider == "ollama":
ollama_api = os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434")
#ollama_model = os.environ.get("OLLAMA_MODEL", "llama3-8b-8192")
ollama_api = os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434")
# ollama_model = os.environ.get("OLLAMA_MODEL", "llama3-8b-8192")
return OllamaModel(ollama_api)
if api_provider == "anthropic":
api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key:
api_key=config["anthropic_api_key"]
return AnthropicModel(api_key=api_key)
if not api_key:
api_key = config["anthropic_api_key"]
return AnthropicModel(api_key=api_key)
else:
raise ValueError(f"Invalid AI model provider: {api_provider}")
class GroqModel(AIModel):
def __init__(self, api_key):
self.client = Groq(api_key=api_key)
def chat(self, messages, model, temperature, max_tokens):
resp = self.client.chat.completions.create(model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens)
resp = self.client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
)
return resp.choices[0].message.content
def moderate(self, message):
pass
class OpenAIModel(AIModel):
def __init__(self, api_key):
self.client = OpenAI(api_key=api_key)
def chat(self, messages, model, temperature, max_tokens):
resp = self.client.chat.completions.create(model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens)
resp = self.client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
)
return resp.choices[0].message.content
def moderate(self, message):
return self.client.moderations.create(input=message)
class OllamaModel(AIModel):
def __init__(self, host):
self.client = Client(host=host)
def chat(self, messages, model, temperature, max_tokens):
resp = self.client.chat(model=model,
messages=messages)
resp = self.client.chat(model=model, messages=messages)
return resp["message"]["content"]
def moderate(self, message):
pass
class AzureOpenAIModel(AIModel):
def __init__(self, azure_endpoint, api_key, api_version):
self.client = AzureOpenAI(azure_endpoint=azure_endpoint, api_key=api_key, api_version=api_version)
self.client = AzureOpenAI(
azure_endpoint=azure_endpoint, api_key=api_key, api_version=api_version
)
def chat(self, messages, model, temperature, max_tokens):
resp = self.client.chat.completions.create(model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens)
resp = self.client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
)
return resp.choices[0].message.content
def moderate(self, message):
return self.client.moderations.create(input=message)
class AnthropicModel(AIModel):
def __init__(self, api_key):
self.client = Anthropic(api_key=api_key)
@@ -127,19 +148,23 @@ class AnthropicModel(AIModel):
def chat(self, messages, model, temperature, max_tokens):
## Anthropic requires the system prompt to be passed separately
## Hence extracting system prompt role from the messages
## and then passing the messages without the system role
## and then passing the messages without the system role
## messages is not subscriptable, so we need to convert it to a list
system_prompt = next((m.get("content", "") for m in messages if m.get("role") == "system"), "")
system_prompt = next(
(m.get("content", "") for m in messages if m.get("role") == "system"), ""
)
# Remove system messages from the list
user_messages = [m for m in messages if m.get("role") != "system"]
resp = self.client.messages.create(model=model,
system=system_prompt,
messages=user_messages,
temperature=temperature,
max_tokens=max_tokens)
resp = self.client.messages.create(
model=model,
system=system_prompt,
messages=user_messages,
temperature=temperature,
max_tokens=max_tokens,
)
return resp.content[0].text
def moderate(self, message):
pass
+51 -24
View File
@@ -15,6 +15,7 @@ of every command or script they frequently use.
Sources:
— https://github.com/wunderwuzzi23/yolo-ai-cmdbot
"""
import os
import platform
import subprocess
@@ -35,6 +36,7 @@ from ai_model import AIModel
CONFIG_FILE = "yolo.yaml"
PROMPT_FILE = "yolo.prompt"
def read_yaml_config() -> any:
"""
Read the configuration file from the executing directory.
@@ -51,9 +53,10 @@ def read_yaml_config() -> any:
prompt_path = os.path.dirname(yolo_path)
config_file = os.path.join(prompt_path, CONFIG_FILE)
with open(config_file, 'r', encoding='utf-8') as file:
with open(config_file, "r", encoding="utf-8") as file:
return yaml.safe_load(file)
def set_openai_api_key(config):
"""
Set the OpenAI API key by attempting several methods.
@@ -89,6 +92,7 @@ def set_openai_api_key(config):
if not openai.api_key:
openai.api_key = config["openai_api_key"]
def print_config(config):
"""
Print config information.
@@ -113,6 +117,7 @@ def print_config(config):
print("— Color : " + str(config["suggested_command_color"]))
print("— Shell : " + str(config["shell"]))
def get_os_friendly_name():
"""
Returns a friendly name of the user's operating system.
@@ -140,6 +145,7 @@ def get_os_friendly_name():
return os_name
def get_system_prompt(shell):
"""
Retrieves and constructs a system prompt by replacing placeholders
@@ -175,6 +181,7 @@ def get_system_prompt(shell):
return system_prompt
def chat_completion(client, query, config):
"""
Generate a chat-based completion for a given query using a specified model.
@@ -198,7 +205,7 @@ def chat_completion(client, query, config):
SystemExit: If the query is an empty string, the function will print an error message and exit.
"""
if query == "":
print ("No user prompt specified.")
print("No user prompt specified.")
sys.exit(-1)
system_prompt = get_system_prompt(config["shell"])
@@ -211,13 +218,15 @@ def chat_completion(client, query, config):
model=config["model"],
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": query}
],
temperature=config["temperature"],
max_tokens=config["max_tokens"])
{"role": "user", "content": query},
],
temperature=config["temperature"],
max_tokens=config["max_tokens"],
)
return response
def check_for_issue(response):
"""
Checks the given response for any issues and raise an error when detected.
@@ -233,9 +242,10 @@ def check_for_issue(response):
"""
prefixes = ("sorry", "i'm sorry", "the question is not clear", "i'm", "i am")
if response.lower().startswith(prefixes):
print(colored("There was an issue: "+response, 'red'))
print(colored("There was an issue: " + response, "red"))
sys.exit(-1)
def check_for_markdown(response):
"""
Checks for the presence of markdown formatting (specifically, code snippet markdown) in the
@@ -251,12 +261,17 @@ def check_for_markdown(response):
response : str
A response text string that needs to be examined for markdown formatting.
"""
if response.count("```",2):
print(colored(
"The proposed command contains markdown, response not executed directly: \n", 'red'
) + response)
if response.count("```", 2):
print(
colored(
"The proposed command contains markdown, response not executed directly: \n",
"red",
)
+ response
)
sys.exit(-1)
def missing_posix_display():
"""
Checks if the DISPLAY environment variable is set in a POSIX-compliant shell.
@@ -273,7 +288,8 @@ def missing_posix_display():
"""
display = subprocess.check_output("echo $DISPLAY", shell=True)
return display == b'\n'
return display == b"\n"
def prompt_user_input(config, response):
"""
@@ -294,7 +310,10 @@ def prompt_user_input(config, response):
response : str
The proposed command which is to be printed and may be executed by the user.
"""
print("Command: " + colored(response, config["suggested_command_color"], attrs=['bold']))
print(
"Command: "
+ colored(response, config["suggested_command_color"], attrs=["bold"])
)
if config["safety"]:
modify_text = ""
@@ -302,12 +321,14 @@ def prompt_user_input(config, response):
if config["modify"]:
modify_text = " [m]modify"
prompt_text = "Execute command? [Y]es [n]o [c]opy to clipboard" + modify_text + ": "
prompt_text = (
"Execute command? [Y]es [n]o [c]opy to clipboard" + modify_text + ": "
)
if os.name == "posix" and missing_posix_display():
prompt_text = "Execute command? [Y]es [n]o" + modify_text + ": "
prompt_text = "Execute command? [Y]es [n]o" + modify_text + ": "
print(prompt_text, end = '')
print(prompt_text, end="")
user_input = input()
else:
@@ -315,6 +336,7 @@ def prompt_user_input(config, response):
return user_input
def evaluate_input(client, config, user_input, command):
"""
Evaluate the user input to either execute, modify, or copy the command.
@@ -344,7 +366,7 @@ def evaluate_input(client, config, user_input, command):
subprocess.run([config["shell"], "-c", command], shell=False, check=True)
if user_input.upper() == "M":
print("Modify prompt: ", end = '')
print("Modify prompt: ", end="")
modded_query = input()
modded_response = chat_completion(client, modded_query, config)
check_for_issue(modded_response)
@@ -365,14 +387,18 @@ def main():
Defined starting point of source code.
"""
parser = argparse.ArgumentParser(
description='AI bot that translates your question to a command.'
description="AI bot that translates your question to a command."
)
parser.add_argument("text", nargs="+", help="A sequence of strings")
parser.add_argument(
"-s",
"--safety",
action="store_true",
help="Enable safety mode (only useful when safety is off)",
)
parser.add_argument(
"-c", "--config", action="store_true", help="Print current configuration"
)
parser.add_argument('text', nargs='+',
help='A sequence of strings')
parser.add_argument("-s", "--safety", action='store_true',
help='Enable safety mode (only useful when safety is off)')
parser.add_argument("-c", "--config", action='store_true',
help='Print current configuration')
args = parser.parse_args()
# Load configurations and set up client
@@ -403,5 +429,6 @@ def main():
print()
evaluate_input(client, config, user_input, result)
if __name__ == "__main__":
main()