Beautified source code with Black

This commit is contained in:
2024-08-20 11:27:30 +02:00
parent 48c77403a3
commit 21ca4378df
2 changed files with 132 additions and 80 deletions
+51 -24
View File
@@ -15,6 +15,7 @@ of every command or script they frequently use.
Sources:
— https://github.com/wunderwuzzi23/yolo-ai-cmdbot
"""
import os
import platform
import subprocess
@@ -35,6 +36,7 @@ from ai_model import AIModel
CONFIG_FILE = "yolo.yaml"
PROMPT_FILE = "yolo.prompt"
def read_yaml_config() -> any:
"""
Read the configuration file from the executing directory.
@@ -51,9 +53,10 @@ def read_yaml_config() -> any:
prompt_path = os.path.dirname(yolo_path)
config_file = os.path.join(prompt_path, CONFIG_FILE)
with open(config_file, 'r', encoding='utf-8') as file:
with open(config_file, "r", encoding="utf-8") as file:
return yaml.safe_load(file)
def set_openai_api_key(config):
"""
Set the OpenAI API key by attempting several methods.
@@ -89,6 +92,7 @@ def set_openai_api_key(config):
if not openai.api_key:
openai.api_key = config["openai_api_key"]
def print_config(config):
"""
Print config information.
@@ -113,6 +117,7 @@ def print_config(config):
print("— Color : " + str(config["suggested_command_color"]))
print("— Shell : " + str(config["shell"]))
def get_os_friendly_name():
"""
Returns a friendly name of the user's operating system.
@@ -140,6 +145,7 @@ def get_os_friendly_name():
return os_name
def get_system_prompt(shell):
"""
Retrieves and constructs a system prompt by replacing placeholders
@@ -175,6 +181,7 @@ def get_system_prompt(shell):
return system_prompt
def chat_completion(client, query, config):
"""
Generate a chat-based completion for a given query using a specified model.
@@ -198,7 +205,7 @@ def chat_completion(client, query, config):
SystemExit: If the query is an empty string, the function will print an error message and exit.
"""
if query == "":
print ("No user prompt specified.")
print("No user prompt specified.")
sys.exit(-1)
system_prompt = get_system_prompt(config["shell"])
@@ -211,13 +218,15 @@ def chat_completion(client, query, config):
model=config["model"],
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": query}
],
temperature=config["temperature"],
max_tokens=config["max_tokens"])
{"role": "user", "content": query},
],
temperature=config["temperature"],
max_tokens=config["max_tokens"],
)
return response
def check_for_issue(response):
"""
Checks the given response for any issues and raise an error when detected.
@@ -233,9 +242,10 @@ def check_for_issue(response):
"""
prefixes = ("sorry", "i'm sorry", "the question is not clear", "i'm", "i am")
if response.lower().startswith(prefixes):
print(colored("There was an issue: "+response, 'red'))
print(colored("There was an issue: " + response, "red"))
sys.exit(-1)
def check_for_markdown(response):
"""
Checks for the presence of markdown formatting (specifically, code snippet markdown) in the
@@ -251,12 +261,17 @@ def check_for_markdown(response):
response : str
A response text string that needs to be examined for markdown formatting.
"""
if response.count("```",2):
print(colored(
"The proposed command contains markdown, response not executed directly: \n", 'red'
) + response)
if response.count("```", 2):
print(
colored(
"The proposed command contains markdown, response not executed directly: \n",
"red",
)
+ response
)
sys.exit(-1)
def missing_posix_display():
"""
Checks if the DISPLAY environment variable is set in a POSIX-compliant shell.
@@ -273,7 +288,8 @@ def missing_posix_display():
"""
display = subprocess.check_output("echo $DISPLAY", shell=True)
return display == b'\n'
return display == b"\n"
def prompt_user_input(config, response):
"""
@@ -294,7 +310,10 @@ def prompt_user_input(config, response):
response : str
The proposed command which is to be printed and may be executed by the user.
"""
print("Command: " + colored(response, config["suggested_command_color"], attrs=['bold']))
print(
"Command: "
+ colored(response, config["suggested_command_color"], attrs=["bold"])
)
if config["safety"]:
modify_text = ""
@@ -302,12 +321,14 @@ def prompt_user_input(config, response):
if config["modify"]:
modify_text = " [m]modify"
prompt_text = "Execute command? [Y]es [n]o [c]opy to clipboard" + modify_text + ": "
prompt_text = (
"Execute command? [Y]es [n]o [c]opy to clipboard" + modify_text + ": "
)
if os.name == "posix" and missing_posix_display():
prompt_text = "Execute command? [Y]es [n]o" + modify_text + ": "
prompt_text = "Execute command? [Y]es [n]o" + modify_text + ": "
print(prompt_text, end = '')
print(prompt_text, end="")
user_input = input()
else:
@@ -315,6 +336,7 @@ def prompt_user_input(config, response):
return user_input
def evaluate_input(client, config, user_input, command):
"""
Evaluate the user input to either execute, modify, or copy the command.
@@ -344,7 +366,7 @@ def evaluate_input(client, config, user_input, command):
subprocess.run([config["shell"], "-c", command], shell=False, check=True)
if user_input.upper() == "M":
print("Modify prompt: ", end = '')
print("Modify prompt: ", end="")
modded_query = input()
modded_response = chat_completion(client, modded_query, config)
check_for_issue(modded_response)
@@ -365,14 +387,18 @@ def main():
Defined starting point of source code.
"""
parser = argparse.ArgumentParser(
description='AI bot that translates your question to a command.'
description="AI bot that translates your question to a command."
)
parser.add_argument("text", nargs="+", help="A sequence of strings")
parser.add_argument(
"-s",
"--safety",
action="store_true",
help="Enable safety mode (only useful when safety is off)",
)
parser.add_argument(
"-c", "--config", action="store_true", help="Print current configuration"
)
parser.add_argument('text', nargs='+',
help='A sequence of strings')
parser.add_argument("-s", "--safety", action='store_true',
help='Enable safety mode (only useful when safety is off)')
parser.add_argument("-c", "--config", action='store_true',
help='Print current configuration')
args = parser.parse_args()
# Load configurations and set up client
@@ -403,5 +429,6 @@ def main():
print()
evaluate_input(client, config, user_input, result)
if __name__ == "__main__":
main()