Beautified source code with Black

This commit is contained in:
2024-08-20 11:27:30 +02:00
parent 48c77403a3
commit 21ca4378df
2 changed files with 132 additions and 80 deletions
+47 -22
View File
@@ -6,6 +6,7 @@ from openai import AzureOpenAI
from anthropic import Anthropic
import os
class AIModel(ABC):
@abstractmethod
def chat(self, model, messages):
@@ -17,9 +18,9 @@ class AIModel(ABC):
@staticmethod
def get_model_client(config):
api_provider=config["api"]
api_provider = config["api"]
if api_provider == "" or api_provider==None:
if api_provider == "" or api_provider == None:
api_provider = "groq"
if api_provider == "groq":
@@ -28,10 +29,14 @@ class AIModel(ABC):
elif api_provider == "openai":
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
api_key=config["openai_api_key"]
if not api_key: #If statement to avoid "invalid filepath" error
api_key = config["openai_api_key"]
if not api_key: # If statement to avoid "invalid filepath" error
home_path = os.path.expanduser("~")
api_key=open(os.path.join(home_path,".openai.apikey"), "r").readline().strip()
api_key = (
open(os.path.join(home_path, ".openai.apikey"), "r")
.readline()
.strip()
)
api_key = api_key
return OpenAIModel(api_key=api_key)
@@ -39,65 +44,76 @@ class AIModel(ABC):
elif api_provider == "azure":
api_key = os.getenv("AZURE_OPENAI_API_KEY")
if not api_key:
api_key=config["azure_openai_api_key"]
api_key = config["azure_openai_api_key"]
if not api_key:
home_path = os.path.expanduser("~")
api_key=open(os.path.join(home_path,".azureopenai.apikey"), "r").readline().strip()
api_key = (
open(os.path.join(home_path, ".azureopenai.apikey"), "r")
.readline()
.strip()
)
return AzureOpenAIModel(
api_key=api_key,
azure_endpoint=config["azure_endpoint"],
api_version=config["azure_api_version"])
api_version=config["azure_api_version"],
)
elif api_provider == "ollama":
ollama_api = os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434")
#ollama_model = os.environ.get("OLLAMA_MODEL", "llama3-8b-8192")
# ollama_model = os.environ.get("OLLAMA_MODEL", "llama3-8b-8192")
return OllamaModel(ollama_api)
if api_provider == "anthropic":
api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key:
api_key=config["anthropic_api_key"]
api_key = config["anthropic_api_key"]
return AnthropicModel(api_key=api_key)
else:
raise ValueError(f"Invalid AI model provider: {api_provider}")
class GroqModel(AIModel):
def __init__(self, api_key):
self.client = Groq(api_key=api_key)
def chat(self, messages, model, temperature, max_tokens):
resp = self.client.chat.completions.create(model=model,
resp = self.client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens)
max_tokens=max_tokens,
)
return resp.choices[0].message.content
def moderate(self, message):
pass
class OpenAIModel(AIModel):
def __init__(self, api_key):
self.client = OpenAI(api_key=api_key)
def chat(self, messages, model, temperature, max_tokens):
resp = self.client.chat.completions.create(model=model,
resp = self.client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens)
max_tokens=max_tokens,
)
return resp.choices[0].message.content
def moderate(self, message):
return self.client.moderations.create(input=message)
class OllamaModel(AIModel):
def __init__(self, host):
self.client = Client(host=host)
def chat(self, messages, model, temperature, max_tokens):
resp = self.client.chat(model=model,
messages=messages)
resp = self.client.chat(model=model, messages=messages)
return resp["message"]["content"]
def moderate(self, message):
@@ -106,20 +122,25 @@ class OllamaModel(AIModel):
class AzureOpenAIModel(AIModel):
def __init__(self, azure_endpoint, api_key, api_version):
self.client = AzureOpenAI(azure_endpoint=azure_endpoint, api_key=api_key, api_version=api_version)
self.client = AzureOpenAI(
azure_endpoint=azure_endpoint, api_key=api_key, api_version=api_version
)
def chat(self, messages, model, temperature, max_tokens):
resp = self.client.chat.completions.create(model=model,
resp = self.client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens)
max_tokens=max_tokens,
)
return resp.choices[0].message.content
def moderate(self, message):
return self.client.moderations.create(input=message)
class AnthropicModel(AIModel):
def __init__(self, api_key):
self.client = Anthropic(api_key=api_key)
@@ -129,15 +150,19 @@ class AnthropicModel(AIModel):
## Hence extracting system prompt role from the messages
## and then passing the messages without the system role
## messages is not subscriptable, so we need to convert it to a list
system_prompt = next((m.get("content", "") for m in messages if m.get("role") == "system"), "")
system_prompt = next(
(m.get("content", "") for m in messages if m.get("role") == "system"), ""
)
# Remove system messages from the list
user_messages = [m for m in messages if m.get("role") != "system"]
resp = self.client.messages.create(model=model,
resp = self.client.messages.create(
model=model,
system=system_prompt,
messages=user_messages,
temperature=temperature,
max_tokens=max_tokens)
max_tokens=max_tokens,
)
return resp.content[0].text
+48 -21
View File
@@ -15,6 +15,7 @@ of every command or script they frequently use.
Sources:
— https://github.com/wunderwuzzi23/yolo-ai-cmdbot
"""
import os
import platform
import subprocess
@@ -35,6 +36,7 @@ from ai_model import AIModel
CONFIG_FILE = "yolo.yaml"
PROMPT_FILE = "yolo.prompt"
def read_yaml_config() -> any:
"""
Read the configuration file from the executing directory.
@@ -51,9 +53,10 @@ def read_yaml_config() -> any:
prompt_path = os.path.dirname(yolo_path)
config_file = os.path.join(prompt_path, CONFIG_FILE)
with open(config_file, 'r', encoding='utf-8') as file:
with open(config_file, "r", encoding="utf-8") as file:
return yaml.safe_load(file)
def set_openai_api_key(config):
"""
Set the OpenAI API key by attempting several methods.
@@ -89,6 +92,7 @@ def set_openai_api_key(config):
if not openai.api_key:
openai.api_key = config["openai_api_key"]
def print_config(config):
"""
Print config information.
@@ -113,6 +117,7 @@ def print_config(config):
print("— Color : " + str(config["suggested_command_color"]))
print("— Shell : " + str(config["shell"]))
def get_os_friendly_name():
"""
Returns a friendly name of the user's operating system.
@@ -140,6 +145,7 @@ def get_os_friendly_name():
return os_name
def get_system_prompt(shell):
"""
Retrieves and constructs a system prompt by replacing placeholders
@@ -175,6 +181,7 @@ def get_system_prompt(shell):
return system_prompt
def chat_completion(client, query, config):
"""
Generate a chat-based completion for a given query using a specified model.
@@ -198,7 +205,7 @@ def chat_completion(client, query, config):
SystemExit: If the query is an empty string, the function will print an error message and exit.
"""
if query == "":
print ("No user prompt specified.")
print("No user prompt specified.")
sys.exit(-1)
system_prompt = get_system_prompt(config["shell"])
@@ -211,13 +218,15 @@ def chat_completion(client, query, config):
model=config["model"],
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": query}
{"role": "user", "content": query},
],
temperature=config["temperature"],
max_tokens=config["max_tokens"])
max_tokens=config["max_tokens"],
)
return response
def check_for_issue(response):
"""
Checks the given response for any issues and raise an error when detected.
@@ -233,9 +242,10 @@ def check_for_issue(response):
"""
prefixes = ("sorry", "i'm sorry", "the question is not clear", "i'm", "i am")
if response.lower().startswith(prefixes):
print(colored("There was an issue: "+response, 'red'))
print(colored("There was an issue: " + response, "red"))
sys.exit(-1)
def check_for_markdown(response):
"""
Checks for the presence of markdown formatting (specifically, code snippet markdown) in the
@@ -251,12 +261,17 @@ def check_for_markdown(response):
response : str
A response text string that needs to be examined for markdown formatting.
"""
if response.count("```",2):
print(colored(
"The proposed command contains markdown, response not executed directly: \n", 'red'
) + response)
if response.count("```", 2):
print(
colored(
"The proposed command contains markdown, response not executed directly: \n",
"red",
)
+ response
)
sys.exit(-1)
def missing_posix_display():
"""
Checks if the DISPLAY environment variable is set in a POSIX-compliant shell.
@@ -273,7 +288,8 @@ def missing_posix_display():
"""
display = subprocess.check_output("echo $DISPLAY", shell=True)
return display == b'\n'
return display == b"\n"
def prompt_user_input(config, response):
"""
@@ -294,7 +310,10 @@ def prompt_user_input(config, response):
response : str
The proposed command which is to be printed and may be executed by the user.
"""
print("Command: " + colored(response, config["suggested_command_color"], attrs=['bold']))
print(
"Command: "
+ colored(response, config["suggested_command_color"], attrs=["bold"])
)
if config["safety"]:
modify_text = ""
@@ -302,12 +321,14 @@ def prompt_user_input(config, response):
if config["modify"]:
modify_text = " [m]modify"
prompt_text = "Execute command? [Y]es [n]o [c]opy to clipboard" + modify_text + ": "
prompt_text = (
"Execute command? [Y]es [n]o [c]opy to clipboard" + modify_text + ": "
)
if os.name == "posix" and missing_posix_display():
prompt_text = "Execute command? [Y]es [n]o" + modify_text + ": "
print(prompt_text, end = '')
print(prompt_text, end="")
user_input = input()
else:
@@ -315,6 +336,7 @@ def prompt_user_input(config, response):
return user_input
def evaluate_input(client, config, user_input, command):
"""
Evaluate the user input to either execute, modify, or copy the command.
@@ -344,7 +366,7 @@ def evaluate_input(client, config, user_input, command):
subprocess.run([config["shell"], "-c", command], shell=False, check=True)
if user_input.upper() == "M":
print("Modify prompt: ", end = '')
print("Modify prompt: ", end="")
modded_query = input()
modded_response = chat_completion(client, modded_query, config)
check_for_issue(modded_response)
@@ -365,14 +387,18 @@ def main():
Defined starting point of source code.
"""
parser = argparse.ArgumentParser(
description='AI bot that translates your question to a command.'
description="AI bot that translates your question to a command."
)
parser.add_argument("text", nargs="+", help="A sequence of strings")
parser.add_argument(
"-s",
"--safety",
action="store_true",
help="Enable safety mode (only useful when safety is off)",
)
parser.add_argument(
"-c", "--config", action="store_true", help="Print current configuration"
)
parser.add_argument('text', nargs='+',
help='A sequence of strings')
parser.add_argument("-s", "--safety", action='store_true',
help='Enable safety mode (only useful when safety is off)')
parser.add_argument("-c", "--config", action='store_true',
help='Print current configuration')
args = parser.parse_args()
# Load configurations and set up client
@@ -403,5 +429,6 @@ def main():
print()
evaluate_input(client, config, user_input, result)
if __name__ == "__main__":
main()