Beautified source code with Black
This commit is contained in:
+81
-56
@@ -2,10 +2,11 @@ from abc import ABC, abstractmethod
|
|||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from groq import Groq
|
from groq import Groq
|
||||||
from ollama import Client
|
from ollama import Client
|
||||||
from openai import AzureOpenAI
|
from openai import AzureOpenAI
|
||||||
from anthropic import Anthropic
|
from anthropic import Anthropic
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
class AIModel(ABC):
|
class AIModel(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def chat(self, model, messages):
|
def chat(self, model, messages):
|
||||||
@@ -17,109 +18,129 @@ class AIModel(ABC):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_model_client(config):
|
def get_model_client(config):
|
||||||
api_provider=config["api"]
|
api_provider = config["api"]
|
||||||
|
|
||||||
if api_provider == "" or api_provider==None:
|
if api_provider == "" or api_provider == None:
|
||||||
api_provider = "groq"
|
api_provider = "groq"
|
||||||
|
|
||||||
if api_provider == "groq":
|
if api_provider == "groq":
|
||||||
return GroqModel(api_key=os.environ.get("GROQ_API_KEY"))
|
return GroqModel(api_key=os.environ.get("GROQ_API_KEY"))
|
||||||
|
|
||||||
elif api_provider == "openai":
|
elif api_provider == "openai":
|
||||||
api_key = os.getenv("OPENAI_API_KEY")
|
api_key = os.getenv("OPENAI_API_KEY")
|
||||||
if not api_key:
|
if not api_key:
|
||||||
api_key=config["openai_api_key"]
|
api_key = config["openai_api_key"]
|
||||||
if not api_key: #If statement to avoid "invalid filepath" error
|
if not api_key: # If statement to avoid "invalid filepath" error
|
||||||
home_path = os.path.expanduser("~")
|
home_path = os.path.expanduser("~")
|
||||||
api_key=open(os.path.join(home_path,".openai.apikey"), "r").readline().strip()
|
api_key = (
|
||||||
|
open(os.path.join(home_path, ".openai.apikey"), "r")
|
||||||
|
.readline()
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
api_key = api_key
|
api_key = api_key
|
||||||
|
|
||||||
return OpenAIModel(api_key=api_key)
|
return OpenAIModel(api_key=api_key)
|
||||||
|
|
||||||
elif api_provider == "azure":
|
elif api_provider == "azure":
|
||||||
api_key = os.getenv("AZURE_OPENAI_API_KEY")
|
api_key = os.getenv("AZURE_OPENAI_API_KEY")
|
||||||
if not api_key:
|
if not api_key:
|
||||||
api_key=config["azure_openai_api_key"]
|
api_key = config["azure_openai_api_key"]
|
||||||
if not api_key:
|
if not api_key:
|
||||||
home_path = os.path.expanduser("~")
|
home_path = os.path.expanduser("~")
|
||||||
api_key=open(os.path.join(home_path,".azureopenai.apikey"), "r").readline().strip()
|
api_key = (
|
||||||
|
open(os.path.join(home_path, ".azureopenai.apikey"), "r")
|
||||||
|
.readline()
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
|
|
||||||
return AzureOpenAIModel(
|
return AzureOpenAIModel(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
azure_endpoint=config["azure_endpoint"],
|
azure_endpoint=config["azure_endpoint"],
|
||||||
api_version=config["azure_api_version"])
|
api_version=config["azure_api_version"],
|
||||||
|
)
|
||||||
|
|
||||||
elif api_provider == "ollama":
|
elif api_provider == "ollama":
|
||||||
ollama_api = os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434")
|
ollama_api = os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434")
|
||||||
#ollama_model = os.environ.get("OLLAMA_MODEL", "llama3-8b-8192")
|
# ollama_model = os.environ.get("OLLAMA_MODEL", "llama3-8b-8192")
|
||||||
return OllamaModel(ollama_api)
|
return OllamaModel(ollama_api)
|
||||||
|
|
||||||
if api_provider == "anthropic":
|
if api_provider == "anthropic":
|
||||||
api_key = os.getenv("ANTHROPIC_API_KEY")
|
api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||||
if not api_key:
|
if not api_key:
|
||||||
api_key=config["anthropic_api_key"]
|
api_key = config["anthropic_api_key"]
|
||||||
return AnthropicModel(api_key=api_key)
|
return AnthropicModel(api_key=api_key)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid AI model provider: {api_provider}")
|
raise ValueError(f"Invalid AI model provider: {api_provider}")
|
||||||
|
|
||||||
|
|
||||||
class GroqModel(AIModel):
|
class GroqModel(AIModel):
|
||||||
def __init__(self, api_key):
|
def __init__(self, api_key):
|
||||||
self.client = Groq(api_key=api_key)
|
self.client = Groq(api_key=api_key)
|
||||||
|
|
||||||
def chat(self, messages, model, temperature, max_tokens):
|
def chat(self, messages, model, temperature, max_tokens):
|
||||||
resp = self.client.chat.completions.create(model=model,
|
resp = self.client.chat.completions.create(
|
||||||
messages=messages,
|
model=model,
|
||||||
temperature=temperature,
|
messages=messages,
|
||||||
max_tokens=max_tokens)
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
)
|
||||||
return resp.choices[0].message.content
|
return resp.choices[0].message.content
|
||||||
|
|
||||||
def moderate(self, message):
|
def moderate(self, message):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class OpenAIModel(AIModel):
|
class OpenAIModel(AIModel):
|
||||||
def __init__(self, api_key):
|
def __init__(self, api_key):
|
||||||
self.client = OpenAI(api_key=api_key)
|
self.client = OpenAI(api_key=api_key)
|
||||||
|
|
||||||
def chat(self, messages, model, temperature, max_tokens):
|
def chat(self, messages, model, temperature, max_tokens):
|
||||||
resp = self.client.chat.completions.create(model=model,
|
resp = self.client.chat.completions.create(
|
||||||
messages=messages,
|
model=model,
|
||||||
temperature=temperature,
|
messages=messages,
|
||||||
max_tokens=max_tokens)
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
return resp.choices[0].message.content
|
return resp.choices[0].message.content
|
||||||
|
|
||||||
def moderate(self, message):
|
def moderate(self, message):
|
||||||
return self.client.moderations.create(input=message)
|
return self.client.moderations.create(input=message)
|
||||||
|
|
||||||
|
|
||||||
class OllamaModel(AIModel):
|
class OllamaModel(AIModel):
|
||||||
def __init__(self, host):
|
def __init__(self, host):
|
||||||
self.client = Client(host=host)
|
self.client = Client(host=host)
|
||||||
|
|
||||||
def chat(self, messages, model, temperature, max_tokens):
|
def chat(self, messages, model, temperature, max_tokens):
|
||||||
resp = self.client.chat(model=model,
|
resp = self.client.chat(model=model, messages=messages)
|
||||||
messages=messages)
|
|
||||||
return resp["message"]["content"]
|
return resp["message"]["content"]
|
||||||
|
|
||||||
def moderate(self, message):
|
def moderate(self, message):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class AzureOpenAIModel(AIModel):
|
class AzureOpenAIModel(AIModel):
|
||||||
def __init__(self, azure_endpoint, api_key, api_version):
|
def __init__(self, azure_endpoint, api_key, api_version):
|
||||||
self.client = AzureOpenAI(azure_endpoint=azure_endpoint, api_key=api_key, api_version=api_version)
|
self.client = AzureOpenAI(
|
||||||
|
azure_endpoint=azure_endpoint, api_key=api_key, api_version=api_version
|
||||||
|
)
|
||||||
|
|
||||||
def chat(self, messages, model, temperature, max_tokens):
|
def chat(self, messages, model, temperature, max_tokens):
|
||||||
|
|
||||||
resp = self.client.chat.completions.create(model=model,
|
resp = self.client.chat.completions.create(
|
||||||
messages=messages,
|
model=model,
|
||||||
temperature=temperature,
|
messages=messages,
|
||||||
max_tokens=max_tokens)
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
return resp.choices[0].message.content
|
return resp.choices[0].message.content
|
||||||
|
|
||||||
def moderate(self, message):
|
def moderate(self, message):
|
||||||
return self.client.moderations.create(input=message)
|
return self.client.moderations.create(input=message)
|
||||||
|
|
||||||
|
|
||||||
class AnthropicModel(AIModel):
|
class AnthropicModel(AIModel):
|
||||||
def __init__(self, api_key):
|
def __init__(self, api_key):
|
||||||
self.client = Anthropic(api_key=api_key)
|
self.client = Anthropic(api_key=api_key)
|
||||||
@@ -127,19 +148,23 @@ class AnthropicModel(AIModel):
|
|||||||
def chat(self, messages, model, temperature, max_tokens):
|
def chat(self, messages, model, temperature, max_tokens):
|
||||||
## Anthropic requires the system prompt to be passed separately
|
## Anthropic requires the system prompt to be passed separately
|
||||||
## Hence extracting system prompt role from the messages
|
## Hence extracting system prompt role from the messages
|
||||||
## and then passing the messages without the system role
|
## and then passing the messages without the system role
|
||||||
## messages is not subscriptable, so we need to convert it to a list
|
## messages is not subscriptable, so we need to convert it to a list
|
||||||
system_prompt = next((m.get("content", "") for m in messages if m.get("role") == "system"), "")
|
system_prompt = next(
|
||||||
|
(m.get("content", "") for m in messages if m.get("role") == "system"), ""
|
||||||
|
)
|
||||||
|
|
||||||
# Remove system messages from the list
|
# Remove system messages from the list
|
||||||
user_messages = [m for m in messages if m.get("role") != "system"]
|
user_messages = [m for m in messages if m.get("role") != "system"]
|
||||||
resp = self.client.messages.create(model=model,
|
resp = self.client.messages.create(
|
||||||
system=system_prompt,
|
model=model,
|
||||||
messages=user_messages,
|
system=system_prompt,
|
||||||
temperature=temperature,
|
messages=user_messages,
|
||||||
max_tokens=max_tokens)
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
return resp.content[0].text
|
return resp.content[0].text
|
||||||
|
|
||||||
def moderate(self, message):
|
def moderate(self, message):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ of every command or script they frequently use.
|
|||||||
Sources:
|
Sources:
|
||||||
— https://github.com/wunderwuzzi23/yolo-ai-cmdbot
|
— https://github.com/wunderwuzzi23/yolo-ai-cmdbot
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import subprocess
|
import subprocess
|
||||||
@@ -35,6 +36,7 @@ from ai_model import AIModel
|
|||||||
CONFIG_FILE = "yolo.yaml"
|
CONFIG_FILE = "yolo.yaml"
|
||||||
PROMPT_FILE = "yolo.prompt"
|
PROMPT_FILE = "yolo.prompt"
|
||||||
|
|
||||||
|
|
||||||
def read_yaml_config() -> any:
|
def read_yaml_config() -> any:
|
||||||
"""
|
"""
|
||||||
Read the configuration file from the executing directory.
|
Read the configuration file from the executing directory.
|
||||||
@@ -51,9 +53,10 @@ def read_yaml_config() -> any:
|
|||||||
prompt_path = os.path.dirname(yolo_path)
|
prompt_path = os.path.dirname(yolo_path)
|
||||||
|
|
||||||
config_file = os.path.join(prompt_path, CONFIG_FILE)
|
config_file = os.path.join(prompt_path, CONFIG_FILE)
|
||||||
with open(config_file, 'r', encoding='utf-8') as file:
|
with open(config_file, "r", encoding="utf-8") as file:
|
||||||
return yaml.safe_load(file)
|
return yaml.safe_load(file)
|
||||||
|
|
||||||
|
|
||||||
def set_openai_api_key(config):
|
def set_openai_api_key(config):
|
||||||
"""
|
"""
|
||||||
Set the OpenAI API key by attempting several methods.
|
Set the OpenAI API key by attempting several methods.
|
||||||
@@ -89,6 +92,7 @@ def set_openai_api_key(config):
|
|||||||
if not openai.api_key:
|
if not openai.api_key:
|
||||||
openai.api_key = config["openai_api_key"]
|
openai.api_key = config["openai_api_key"]
|
||||||
|
|
||||||
|
|
||||||
def print_config(config):
|
def print_config(config):
|
||||||
"""
|
"""
|
||||||
Print config information.
|
Print config information.
|
||||||
@@ -113,6 +117,7 @@ def print_config(config):
|
|||||||
print("— Color : " + str(config["suggested_command_color"]))
|
print("— Color : " + str(config["suggested_command_color"]))
|
||||||
print("— Shell : " + str(config["shell"]))
|
print("— Shell : " + str(config["shell"]))
|
||||||
|
|
||||||
|
|
||||||
def get_os_friendly_name():
|
def get_os_friendly_name():
|
||||||
"""
|
"""
|
||||||
Returns a friendly name of the user's operating system.
|
Returns a friendly name of the user's operating system.
|
||||||
@@ -140,6 +145,7 @@ def get_os_friendly_name():
|
|||||||
|
|
||||||
return os_name
|
return os_name
|
||||||
|
|
||||||
|
|
||||||
def get_system_prompt(shell):
|
def get_system_prompt(shell):
|
||||||
"""
|
"""
|
||||||
Retrieves and constructs a system prompt by replacing placeholders
|
Retrieves and constructs a system prompt by replacing placeholders
|
||||||
@@ -175,6 +181,7 @@ def get_system_prompt(shell):
|
|||||||
|
|
||||||
return system_prompt
|
return system_prompt
|
||||||
|
|
||||||
|
|
||||||
def chat_completion(client, query, config):
|
def chat_completion(client, query, config):
|
||||||
"""
|
"""
|
||||||
Generate a chat-based completion for a given query using a specified model.
|
Generate a chat-based completion for a given query using a specified model.
|
||||||
@@ -198,7 +205,7 @@ def chat_completion(client, query, config):
|
|||||||
SystemExit: If the query is an empty string, the function will print an error message and exit.
|
SystemExit: If the query is an empty string, the function will print an error message and exit.
|
||||||
"""
|
"""
|
||||||
if query == "":
|
if query == "":
|
||||||
print ("No user prompt specified.")
|
print("No user prompt specified.")
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
|
||||||
system_prompt = get_system_prompt(config["shell"])
|
system_prompt = get_system_prompt(config["shell"])
|
||||||
@@ -211,13 +218,15 @@ def chat_completion(client, query, config):
|
|||||||
model=config["model"],
|
model=config["model"],
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "system", "content": system_prompt},
|
{"role": "system", "content": system_prompt},
|
||||||
{"role": "user", "content": query}
|
{"role": "user", "content": query},
|
||||||
],
|
],
|
||||||
temperature=config["temperature"],
|
temperature=config["temperature"],
|
||||||
max_tokens=config["max_tokens"])
|
max_tokens=config["max_tokens"],
|
||||||
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
def check_for_issue(response):
|
def check_for_issue(response):
|
||||||
"""
|
"""
|
||||||
Checks the given response for any issues and raise an error when detected.
|
Checks the given response for any issues and raise an error when detected.
|
||||||
@@ -233,9 +242,10 @@ def check_for_issue(response):
|
|||||||
"""
|
"""
|
||||||
prefixes = ("sorry", "i'm sorry", "the question is not clear", "i'm", "i am")
|
prefixes = ("sorry", "i'm sorry", "the question is not clear", "i'm", "i am")
|
||||||
if response.lower().startswith(prefixes):
|
if response.lower().startswith(prefixes):
|
||||||
print(colored("There was an issue: "+response, 'red'))
|
print(colored("There was an issue: " + response, "red"))
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
|
||||||
|
|
||||||
def check_for_markdown(response):
|
def check_for_markdown(response):
|
||||||
"""
|
"""
|
||||||
Checks for the presence of markdown formatting (specifically, code snippet markdown) in the
|
Checks for the presence of markdown formatting (specifically, code snippet markdown) in the
|
||||||
@@ -251,12 +261,17 @@ def check_for_markdown(response):
|
|||||||
response : str
|
response : str
|
||||||
A response text string that needs to be examined for markdown formatting.
|
A response text string that needs to be examined for markdown formatting.
|
||||||
"""
|
"""
|
||||||
if response.count("```",2):
|
if response.count("```", 2):
|
||||||
print(colored(
|
print(
|
||||||
"The proposed command contains markdown, response not executed directly: \n", 'red'
|
colored(
|
||||||
) + response)
|
"The proposed command contains markdown, response not executed directly: \n",
|
||||||
|
"red",
|
||||||
|
)
|
||||||
|
+ response
|
||||||
|
)
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
|
||||||
|
|
||||||
def missing_posix_display():
|
def missing_posix_display():
|
||||||
"""
|
"""
|
||||||
Checks if the DISPLAY environment variable is set in a POSIX-compliant shell.
|
Checks if the DISPLAY environment variable is set in a POSIX-compliant shell.
|
||||||
@@ -273,7 +288,8 @@ def missing_posix_display():
|
|||||||
"""
|
"""
|
||||||
display = subprocess.check_output("echo $DISPLAY", shell=True)
|
display = subprocess.check_output("echo $DISPLAY", shell=True)
|
||||||
|
|
||||||
return display == b'\n'
|
return display == b"\n"
|
||||||
|
|
||||||
|
|
||||||
def prompt_user_input(config, response):
|
def prompt_user_input(config, response):
|
||||||
"""
|
"""
|
||||||
@@ -294,7 +310,10 @@ def prompt_user_input(config, response):
|
|||||||
response : str
|
response : str
|
||||||
The proposed command which is to be printed and may be executed by the user.
|
The proposed command which is to be printed and may be executed by the user.
|
||||||
"""
|
"""
|
||||||
print("Command: " + colored(response, config["suggested_command_color"], attrs=['bold']))
|
print(
|
||||||
|
"Command: "
|
||||||
|
+ colored(response, config["suggested_command_color"], attrs=["bold"])
|
||||||
|
)
|
||||||
|
|
||||||
if config["safety"]:
|
if config["safety"]:
|
||||||
modify_text = ""
|
modify_text = ""
|
||||||
@@ -302,12 +321,14 @@ def prompt_user_input(config, response):
|
|||||||
if config["modify"]:
|
if config["modify"]:
|
||||||
modify_text = " [m]modify"
|
modify_text = " [m]modify"
|
||||||
|
|
||||||
prompt_text = "Execute command? [Y]es [n]o [c]opy to clipboard" + modify_text + ": "
|
prompt_text = (
|
||||||
|
"Execute command? [Y]es [n]o [c]opy to clipboard" + modify_text + ": "
|
||||||
|
)
|
||||||
|
|
||||||
if os.name == "posix" and missing_posix_display():
|
if os.name == "posix" and missing_posix_display():
|
||||||
prompt_text = "Execute command? [Y]es [n]o" + modify_text + ": "
|
prompt_text = "Execute command? [Y]es [n]o" + modify_text + ": "
|
||||||
|
|
||||||
print(prompt_text, end = '')
|
print(prompt_text, end="")
|
||||||
|
|
||||||
user_input = input()
|
user_input = input()
|
||||||
else:
|
else:
|
||||||
@@ -315,6 +336,7 @@ def prompt_user_input(config, response):
|
|||||||
|
|
||||||
return user_input
|
return user_input
|
||||||
|
|
||||||
|
|
||||||
def evaluate_input(client, config, user_input, command):
|
def evaluate_input(client, config, user_input, command):
|
||||||
"""
|
"""
|
||||||
Evaluate the user input to either execute, modify, or copy the command.
|
Evaluate the user input to either execute, modify, or copy the command.
|
||||||
@@ -344,7 +366,7 @@ def evaluate_input(client, config, user_input, command):
|
|||||||
subprocess.run([config["shell"], "-c", command], shell=False, check=True)
|
subprocess.run([config["shell"], "-c", command], shell=False, check=True)
|
||||||
|
|
||||||
if user_input.upper() == "M":
|
if user_input.upper() == "M":
|
||||||
print("Modify prompt: ", end = '')
|
print("Modify prompt: ", end="")
|
||||||
modded_query = input()
|
modded_query = input()
|
||||||
modded_response = chat_completion(client, modded_query, config)
|
modded_response = chat_completion(client, modded_query, config)
|
||||||
check_for_issue(modded_response)
|
check_for_issue(modded_response)
|
||||||
@@ -365,14 +387,18 @@ def main():
|
|||||||
Defined starting point of source code.
|
Defined starting point of source code.
|
||||||
"""
|
"""
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description='AI bot that translates your question to a command.'
|
description="AI bot that translates your question to a command."
|
||||||
|
)
|
||||||
|
parser.add_argument("text", nargs="+", help="A sequence of strings")
|
||||||
|
parser.add_argument(
|
||||||
|
"-s",
|
||||||
|
"--safety",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable safety mode (only useful when safety is off)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-c", "--config", action="store_true", help="Print current configuration"
|
||||||
)
|
)
|
||||||
parser.add_argument('text', nargs='+',
|
|
||||||
help='A sequence of strings')
|
|
||||||
parser.add_argument("-s", "--safety", action='store_true',
|
|
||||||
help='Enable safety mode (only useful when safety is off)')
|
|
||||||
parser.add_argument("-c", "--config", action='store_true',
|
|
||||||
help='Print current configuration')
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Load configurations and set up client
|
# Load configurations and set up client
|
||||||
@@ -403,5 +429,6 @@ def main():
|
|||||||
print()
|
print()
|
||||||
evaluate_input(client, config, user_input, result)
|
evaluate_input(client, config, user_input, result)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user