Beautified source code with Black
This commit is contained in:
+47
-22
@@ -6,6 +6,7 @@ from openai import AzureOpenAI
|
||||
from anthropic import Anthropic
|
||||
import os
|
||||
|
||||
|
||||
class AIModel(ABC):
|
||||
@abstractmethod
|
||||
def chat(self, model, messages):
|
||||
@@ -17,9 +18,9 @@ class AIModel(ABC):
|
||||
|
||||
@staticmethod
|
||||
def get_model_client(config):
|
||||
api_provider=config["api"]
|
||||
api_provider = config["api"]
|
||||
|
||||
if api_provider == "" or api_provider==None:
|
||||
if api_provider == "" or api_provider == None:
|
||||
api_provider = "groq"
|
||||
|
||||
if api_provider == "groq":
|
||||
@@ -28,10 +29,14 @@ class AIModel(ABC):
|
||||
elif api_provider == "openai":
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
api_key=config["openai_api_key"]
|
||||
if not api_key: #If statement to avoid "invalid filepath" error
|
||||
api_key = config["openai_api_key"]
|
||||
if not api_key: # If statement to avoid "invalid filepath" error
|
||||
home_path = os.path.expanduser("~")
|
||||
api_key=open(os.path.join(home_path,".openai.apikey"), "r").readline().strip()
|
||||
api_key = (
|
||||
open(os.path.join(home_path, ".openai.apikey"), "r")
|
||||
.readline()
|
||||
.strip()
|
||||
)
|
||||
api_key = api_key
|
||||
|
||||
return OpenAIModel(api_key=api_key)
|
||||
@@ -39,65 +44,76 @@ class AIModel(ABC):
|
||||
elif api_provider == "azure":
|
||||
api_key = os.getenv("AZURE_OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
api_key=config["azure_openai_api_key"]
|
||||
api_key = config["azure_openai_api_key"]
|
||||
if not api_key:
|
||||
home_path = os.path.expanduser("~")
|
||||
api_key=open(os.path.join(home_path,".azureopenai.apikey"), "r").readline().strip()
|
||||
api_key = (
|
||||
open(os.path.join(home_path, ".azureopenai.apikey"), "r")
|
||||
.readline()
|
||||
.strip()
|
||||
)
|
||||
|
||||
return AzureOpenAIModel(
|
||||
api_key=api_key,
|
||||
azure_endpoint=config["azure_endpoint"],
|
||||
api_version=config["azure_api_version"])
|
||||
api_version=config["azure_api_version"],
|
||||
)
|
||||
|
||||
elif api_provider == "ollama":
|
||||
ollama_api = os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434")
|
||||
#ollama_model = os.environ.get("OLLAMA_MODEL", "llama3-8b-8192")
|
||||
# ollama_model = os.environ.get("OLLAMA_MODEL", "llama3-8b-8192")
|
||||
return OllamaModel(ollama_api)
|
||||
|
||||
if api_provider == "anthropic":
|
||||
api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
if not api_key:
|
||||
api_key=config["anthropic_api_key"]
|
||||
api_key = config["anthropic_api_key"]
|
||||
return AnthropicModel(api_key=api_key)
|
||||
else:
|
||||
raise ValueError(f"Invalid AI model provider: {api_provider}")
|
||||
|
||||
|
||||
class GroqModel(AIModel):
|
||||
def __init__(self, api_key):
|
||||
self.client = Groq(api_key=api_key)
|
||||
|
||||
def chat(self, messages, model, temperature, max_tokens):
|
||||
resp = self.client.chat.completions.create(model=model,
|
||||
resp = self.client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens)
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
return resp.choices[0].message.content
|
||||
|
||||
def moderate(self, message):
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIModel(AIModel):
|
||||
def __init__(self, api_key):
|
||||
self.client = OpenAI(api_key=api_key)
|
||||
|
||||
def chat(self, messages, model, temperature, max_tokens):
|
||||
resp = self.client.chat.completions.create(model=model,
|
||||
resp = self.client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens)
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
return resp.choices[0].message.content
|
||||
|
||||
def moderate(self, message):
|
||||
return self.client.moderations.create(input=message)
|
||||
|
||||
|
||||
class OllamaModel(AIModel):
|
||||
def __init__(self, host):
|
||||
self.client = Client(host=host)
|
||||
|
||||
def chat(self, messages, model, temperature, max_tokens):
|
||||
resp = self.client.chat(model=model,
|
||||
messages=messages)
|
||||
resp = self.client.chat(model=model, messages=messages)
|
||||
return resp["message"]["content"]
|
||||
|
||||
def moderate(self, message):
|
||||
@@ -106,20 +122,25 @@ class OllamaModel(AIModel):
|
||||
|
||||
class AzureOpenAIModel(AIModel):
|
||||
def __init__(self, azure_endpoint, api_key, api_version):
|
||||
self.client = AzureOpenAI(azure_endpoint=azure_endpoint, api_key=api_key, api_version=api_version)
|
||||
self.client = AzureOpenAI(
|
||||
azure_endpoint=azure_endpoint, api_key=api_key, api_version=api_version
|
||||
)
|
||||
|
||||
def chat(self, messages, model, temperature, max_tokens):
|
||||
|
||||
resp = self.client.chat.completions.create(model=model,
|
||||
resp = self.client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens)
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
return resp.choices[0].message.content
|
||||
|
||||
def moderate(self, message):
|
||||
return self.client.moderations.create(input=message)
|
||||
|
||||
|
||||
class AnthropicModel(AIModel):
|
||||
def __init__(self, api_key):
|
||||
self.client = Anthropic(api_key=api_key)
|
||||
@@ -129,15 +150,19 @@ class AnthropicModel(AIModel):
|
||||
## Hence extracting system prompt role from the messages
|
||||
## and then passing the messages without the system role
|
||||
## messages is not subscriptable, so we need to convert it to a list
|
||||
system_prompt = next((m.get("content", "") for m in messages if m.get("role") == "system"), "")
|
||||
system_prompt = next(
|
||||
(m.get("content", "") for m in messages if m.get("role") == "system"), ""
|
||||
)
|
||||
|
||||
# Remove system messages from the list
|
||||
user_messages = [m for m in messages if m.get("role") != "system"]
|
||||
resp = self.client.messages.create(model=model,
|
||||
resp = self.client.messages.create(
|
||||
model=model,
|
||||
system=system_prompt,
|
||||
messages=user_messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens)
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
return resp.content[0].text
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ of every command or script they frequently use.
|
||||
Sources:
|
||||
— https://github.com/wunderwuzzi23/yolo-ai-cmdbot
|
||||
"""
|
||||
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
@@ -35,6 +36,7 @@ from ai_model import AIModel
|
||||
CONFIG_FILE = "yolo.yaml"
|
||||
PROMPT_FILE = "yolo.prompt"
|
||||
|
||||
|
||||
def read_yaml_config() -> any:
|
||||
"""
|
||||
Read the configuration file from the executing directory.
|
||||
@@ -51,9 +53,10 @@ def read_yaml_config() -> any:
|
||||
prompt_path = os.path.dirname(yolo_path)
|
||||
|
||||
config_file = os.path.join(prompt_path, CONFIG_FILE)
|
||||
with open(config_file, 'r', encoding='utf-8') as file:
|
||||
with open(config_file, "r", encoding="utf-8") as file:
|
||||
return yaml.safe_load(file)
|
||||
|
||||
|
||||
def set_openai_api_key(config):
|
||||
"""
|
||||
Set the OpenAI API key by attempting several methods.
|
||||
@@ -89,6 +92,7 @@ def set_openai_api_key(config):
|
||||
if not openai.api_key:
|
||||
openai.api_key = config["openai_api_key"]
|
||||
|
||||
|
||||
def print_config(config):
|
||||
"""
|
||||
Print config information.
|
||||
@@ -113,6 +117,7 @@ def print_config(config):
|
||||
print("— Color : " + str(config["suggested_command_color"]))
|
||||
print("— Shell : " + str(config["shell"]))
|
||||
|
||||
|
||||
def get_os_friendly_name():
|
||||
"""
|
||||
Returns a friendly name of the user's operating system.
|
||||
@@ -140,6 +145,7 @@ def get_os_friendly_name():
|
||||
|
||||
return os_name
|
||||
|
||||
|
||||
def get_system_prompt(shell):
|
||||
"""
|
||||
Retrieves and constructs a system prompt by replacing placeholders
|
||||
@@ -175,6 +181,7 @@ def get_system_prompt(shell):
|
||||
|
||||
return system_prompt
|
||||
|
||||
|
||||
def chat_completion(client, query, config):
|
||||
"""
|
||||
Generate a chat-based completion for a given query using a specified model.
|
||||
@@ -198,7 +205,7 @@ def chat_completion(client, query, config):
|
||||
SystemExit: If the query is an empty string, the function will print an error message and exit.
|
||||
"""
|
||||
if query == "":
|
||||
print ("No user prompt specified.")
|
||||
print("No user prompt specified.")
|
||||
sys.exit(-1)
|
||||
|
||||
system_prompt = get_system_prompt(config["shell"])
|
||||
@@ -211,13 +218,15 @@ def chat_completion(client, query, config):
|
||||
model=config["model"],
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": query}
|
||||
{"role": "user", "content": query},
|
||||
],
|
||||
temperature=config["temperature"],
|
||||
max_tokens=config["max_tokens"])
|
||||
max_tokens=config["max_tokens"],
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def check_for_issue(response):
|
||||
"""
|
||||
Checks the given response for any issues and raise an error when detected.
|
||||
@@ -233,9 +242,10 @@ def check_for_issue(response):
|
||||
"""
|
||||
prefixes = ("sorry", "i'm sorry", "the question is not clear", "i'm", "i am")
|
||||
if response.lower().startswith(prefixes):
|
||||
print(colored("There was an issue: "+response, 'red'))
|
||||
print(colored("There was an issue: " + response, "red"))
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
def check_for_markdown(response):
|
||||
"""
|
||||
Checks for the presence of markdown formatting (specifically, code snippet markdown) in the
|
||||
@@ -251,12 +261,17 @@ def check_for_markdown(response):
|
||||
response : str
|
||||
A response text string that needs to be examined for markdown formatting.
|
||||
"""
|
||||
if response.count("```",2):
|
||||
print(colored(
|
||||
"The proposed command contains markdown, response not executed directly: \n", 'red'
|
||||
) + response)
|
||||
if response.count("```", 2):
|
||||
print(
|
||||
colored(
|
||||
"The proposed command contains markdown, response not executed directly: \n",
|
||||
"red",
|
||||
)
|
||||
+ response
|
||||
)
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
def missing_posix_display():
|
||||
"""
|
||||
Checks if the DISPLAY environment variable is set in a POSIX-compliant shell.
|
||||
@@ -273,7 +288,8 @@ def missing_posix_display():
|
||||
"""
|
||||
display = subprocess.check_output("echo $DISPLAY", shell=True)
|
||||
|
||||
return display == b'\n'
|
||||
return display == b"\n"
|
||||
|
||||
|
||||
def prompt_user_input(config, response):
|
||||
"""
|
||||
@@ -294,7 +310,10 @@ def prompt_user_input(config, response):
|
||||
response : str
|
||||
The proposed command which is to be printed and may be executed by the user.
|
||||
"""
|
||||
print("Command: " + colored(response, config["suggested_command_color"], attrs=['bold']))
|
||||
print(
|
||||
"Command: "
|
||||
+ colored(response, config["suggested_command_color"], attrs=["bold"])
|
||||
)
|
||||
|
||||
if config["safety"]:
|
||||
modify_text = ""
|
||||
@@ -302,12 +321,14 @@ def prompt_user_input(config, response):
|
||||
if config["modify"]:
|
||||
modify_text = " [m]modify"
|
||||
|
||||
prompt_text = "Execute command? [Y]es [n]o [c]opy to clipboard" + modify_text + ": "
|
||||
prompt_text = (
|
||||
"Execute command? [Y]es [n]o [c]opy to clipboard" + modify_text + ": "
|
||||
)
|
||||
|
||||
if os.name == "posix" and missing_posix_display():
|
||||
prompt_text = "Execute command? [Y]es [n]o" + modify_text + ": "
|
||||
|
||||
print(prompt_text, end = '')
|
||||
print(prompt_text, end="")
|
||||
|
||||
user_input = input()
|
||||
else:
|
||||
@@ -315,6 +336,7 @@ def prompt_user_input(config, response):
|
||||
|
||||
return user_input
|
||||
|
||||
|
||||
def evaluate_input(client, config, user_input, command):
|
||||
"""
|
||||
Evaluate the user input to either execute, modify, or copy the command.
|
||||
@@ -344,7 +366,7 @@ def evaluate_input(client, config, user_input, command):
|
||||
subprocess.run([config["shell"], "-c", command], shell=False, check=True)
|
||||
|
||||
if user_input.upper() == "M":
|
||||
print("Modify prompt: ", end = '')
|
||||
print("Modify prompt: ", end="")
|
||||
modded_query = input()
|
||||
modded_response = chat_completion(client, modded_query, config)
|
||||
check_for_issue(modded_response)
|
||||
@@ -365,14 +387,18 @@ def main():
|
||||
Defined starting point of source code.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description='AI bot that translates your question to a command.'
|
||||
description="AI bot that translates your question to a command."
|
||||
)
|
||||
parser.add_argument("text", nargs="+", help="A sequence of strings")
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--safety",
|
||||
action="store_true",
|
||||
help="Enable safety mode (only useful when safety is off)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c", "--config", action="store_true", help="Print current configuration"
|
||||
)
|
||||
parser.add_argument('text', nargs='+',
|
||||
help='A sequence of strings')
|
||||
parser.add_argument("-s", "--safety", action='store_true',
|
||||
help='Enable safety mode (only useful when safety is off)')
|
||||
parser.add_argument("-c", "--config", action='store_true',
|
||||
help='Print current configuration')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load configurations and set up client
|
||||
@@ -403,5 +429,6 @@ def main():
|
||||
print()
|
||||
evaluate_input(client, config, user_input, result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user