Files
yolo-ai-cmdbot/yolo.py
T

435 lines
15 KiB
Python
Executable File

"""
AI Chatbot to generate shell commands.
This script allows the user to ask a question in plain English and translates
that question into a command that can be run in the shell. Its features include
using AI models (such as OpenAI's GPT models) to generate commands, rejecting
responses that look like refusals or contain markdown, asking for confirmation
before running a command (safety mode), and letting the user execute, modify, or
copy the generated command.
This program is an implementation of an AI model used to assist users in
generating Unix/shell commands or other scripts, based on their natural language
input. The objective is to aid those users who might not remember the exact syntax
of every command or script they frequently use.
Sources:
— https://github.com/wunderwuzzi23/yolo-ai-cmdbot
"""
import argparse
import os
import platform
import subprocess
import sys
from typing import Any

import distro
import dotenv
import openai
import pyperclip
import yaml
from colorama import init
from termcolor import colored

from ai_model import AIModel
# Names of the configuration and prompt-template files; both are looked up in
# the directory that contains this script.
CONFIG_FILE, PROMPT_FILE = "yolo.yaml", "yolo.prompt"
def read_yaml_config() -> Any:
    """
    Read the yolo configuration file that sits next to this script.

    The directory is derived from ``__file__`` rather than the current working
    directory, so the file is found even when yolo is started through an alias
    or from another folder.

    Returns:
        The result of ``yaml.safe_load`` on the file: normally a dict for
        yolo.yaml, but any YAML value (or None for an empty file) is possible.

    Raises:
        OSError (e.g. FileNotFoundError) if the file cannot be opened, and
        yaml.YAMLError if its content is not valid YAML.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    config_file = os.path.join(script_dir, CONFIG_FILE)
    with open(config_file, "r", encoding="utf-8") as file:
        # safe_load never constructs arbitrary Python objects from YAML tags.
        return yaml.safe_load(file)
def set_openai_api_key(config):
    """
    Resolve the OpenAI API key and assign it to ``openai.api_key``.

    Sources, in order of precedence:

    1. The ``OPENAI_API_KEY`` environment variable. ``dotenv.load_dotenv()`` runs
       first, so a ``.env`` file containing ``OPENAI_API_KEY="<yourkey>"`` also works.
    2. A file named ``.openai.apikey`` in the user's home directory containing the
       raw key. It is only read if it exists. This method might be deprecated in
       future versions.
    3. The ``openai_api_key`` entry of `config`, e.g. ``openai_api_key: <yourkey>``
       in yolo.yaml.

    Steps 2 and 3 are skipped whenever ``openai.api_key`` already holds a value.
    Afterwards ``config["openai_api_key"]`` is set to the chosen key (None if no
    source supplied one).

    Parameters:
        config (dict): Configuration dictionary; may contain "openai_api_key".
            It is modified in place.
    """
    dotenv.load_dotenv()

    # Method 1: environment variable (possibly populated from .env just above).
    env_key = os.getenv("OPENAI_API_KEY")
    if env_key:
        openai.api_key = env_key

    # Method 2: raw key stored in ~/.openai.apikey. Read it here, and only if the
    # file exists, so a missing file can never become an invalid key path.
    if not openai.api_key:
        key_file = os.path.join(os.path.expanduser("~"), ".openai.apikey")
        if os.path.isfile(key_file):
            with open(key_file, "r", encoding="utf-8") as file:
                openai.api_key = file.read().strip() or None

    # Method 3: the value from the config dictionary (yolo.yaml). Reading it
    # before anything overwrites config["openai_api_key"] keeps the user's key.
    if not openai.api_key:
        openai.api_key = config.get("openai_api_key")

    # Keep the config dictionary in sync with the key that was actually chosen.
    config["openai_api_key"] = openai.api_key
def print_config(config):
    """
    Print the settings currently loaded from yolo.yaml, one per line.

    Parameters
    ----------
    config : dict
        Must contain the keys "api", "model", "temperature", "max_tokens",
        "safety", "modify", "suggested_command_color" and "shell". "safety" and
        "modify" are shown as booleans; the rest are shown via ``str()``.
    """
    print("Current configuration per yolo.yaml:")
    # (label, config key, show as bool?) -- looked up lazily, one line at a time.
    rows = (
        ("— API : ", "api", False),
        ("— Model : ", "model", False),
        ("— Temperature : ", "temperature", False),
        ("— Max. Tokens : ", "max_tokens", False),
        ("— Safety : ", "safety", True),
        ("— Modify : ", "modify", True),
        ("— Color : ", "suggested_command_color", False),
        ("— Shell : ", "shell", False),
    )
    for label, key, as_flag in rows:
        value = config[key]
        print(label + str(bool(value) if as_flag else value))
def get_os_friendly_name():
    """
    Return a human-readable name for the current operating system.

    The name is derived from ``platform.system()``:

    - "Linux"  -> "Linux/" followed by ``distro.name(pretty=True)``
    - "Darwin" -> "Darwin/macOS"
    - any other value is returned unchanged.

    Returns
    -------
    str
        The friendly operating-system name.
    """
    system = platform.system()
    if system == "Darwin":
        return "Darwin/macOS"
    if system == "Linux":
        return "Linux/" + distro.name(pretty=True)
    return system
def get_system_prompt(shell):
    """
    Build the system prompt from the PROMPT_FILE template.

    The template is read (UTF-8) from the directory containing this script, so it
    is found regardless of the current working directory. ``{shell}`` is replaced
    with `shell` first, then ``{os}`` is replaced with get_os_friendly_name().

    Parameters
    ----------
    shell : str
        Shell name or path substituted for the ``{shell}`` placeholder.

    Returns
    -------
    str
        The template text with both placeholders filled in.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    prompt_file = os.path.join(script_dir, PROMPT_FILE)
    with open(prompt_file, "r", encoding="utf-8") as file:
        template = file.read()
    # Substitution is sequential: a shell value that itself contained "{os}"
    # would also be expanded by the second replace.
    return template.replace("{shell}", shell).replace("{os}", get_os_friendly_name())
def chat_completion(client, query, config):
    """
    Ask the configured model to turn a natural-language query into a command.

    The system prompt is built from the prompt template for ``config["shell"]``.
    The query is sent as the user message.

    Parameters:
        client (object): Model client exposing
            ``chat(model=, messages=, temperature=, max_tokens=)``.
        query (str): The user's question. Surrounding whitespace is removed, and
            "?" is appended unless it already ends with "?" or ".".
        config (dict): Must contain "shell", "model", "temperature" (sampling
            temperature) and "max_tokens" (response length limit).

    Returns:
        Whatever ``client.chat`` returns. Callers in this script treat it as the
        command text (a str). Presumably AIModel returns the message content;
        confirm in ai_model.

    Raises:
        SystemExit: with status -1, after printing a message, if the query is
            empty or whitespace-only.
    """
    query = query.strip()
    if not query:
        print("No user prompt specified.")
        sys.exit(-1)

    system_prompt = get_system_prompt(config["shell"])

    # Phrase the request as a question unless it already ends with "?" or ".".
    if not query.endswith(("?", ".")):
        query += "?"

    response = client.chat(
        model=config["model"],
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": query},
        ],
        temperature=config["temperature"],
        max_tokens=config["max_tokens"],
    )
    return response
def check_for_issue(response):
    """
    Abort if the model's reply looks like a refusal or apology, not a command.

    The reply is compared case-insensitively, ignoring leading whitespace and
    treating a typographic apostrophe (’) as "'". It is checked against these
    prefixes: "sorry", "i'm sorry", "the question is not clear", "i'm", "i am".
    On a match the original reply is printed in red and the program exits with
    status -1.

    Parameters
    ----------
    response : str
        The model's reply to examine.
    """
    prefixes = ("sorry", "i'm sorry", "the question is not clear", "i'm", "i am")
    # Models may prepend whitespace/newlines or use a curly apostrophe. Either one
    # would defeat a plain prefix test and let refusal text be offered as a command.
    normalised = response.lstrip().lower().replace("\u2019", "'")
    if normalised.startswith(prefixes):
        print(colored("There was an issue: " + response, "red"))
        sys.exit(-1)
def check_for_markdown(response):
    """
    Abort if the model's reply contains a markdown code fence (```).

    A fenced reply cannot be executed directly, so the reply is printed after a red
    warning and the program exits with status -1.

    Parameters
    ----------
    response : str
        The model's reply to examine.
    """
    # Any fence anywhere, including at position 0, counts. (str.count's second
    # argument is a start index, not a minimum count, so count("```", 2) would
    # miss fences that start at index 0 or 1.)
    if "```" in response:
        print(
            colored(
                "The proposed command contains markdown, response not executed directly: \n",
                "red",
            )
            + response
        )
        sys.exit(-1)
def missing_posix_display():
    """
    Report whether no display is available, judged by the DISPLAY variable.

    DISPLAY is read from this process's environment instead of spawning a shell
    to ``echo $DISPLAY``. Such a child shell only inherits this environment, so
    the answer is the same without an extra process or a dependency on /bin/sh.

    Returns
    -------
    bool
        True if DISPLAY is unset, empty, or whitespace-only; False otherwise.
    """
    return not os.environ.get("DISPLAY", "").strip()
def prompt_user_input(config, response):
    """
    Show the proposed command and ask the user what to do with it.

    The command is printed in bold, in the colour named by
    config["suggested_command_color"].

    With safety on, the user is asked to confirm. The "[m]modify" choice is shown
    only when config["modify"] is truthy. The "[c]opy to clipboard" choice is
    omitted on POSIX systems without a DISPLAY. With safety off, nothing is asked
    and "Y" is returned.

    Parameters
    ----------
    config : dict
        Reads "suggested_command_color", "safety" and, when safety is on, "modify".
    response : str
        The proposed command.

    Returns
    -------
    str
        The raw text typed by the user, or "Y" when safety is disabled.
    """
    highlighted = colored(response, config["suggested_command_color"], attrs=["bold"])
    print("Command: " + highlighted)

    if not config["safety"]:
        return "Y"

    modify_text = " [m]modify" if config["modify"] else ""
    if os.name == "posix" and missing_posix_display():
        question = "Execute command? [Y]es [n]o" + modify_text + ": "
    else:
        question = "Execute command? [Y]es [n]o [c]opy to clipboard" + modify_text + ": "
    print(question, end="")
    return input()
def evaluate_input(client, config, user_input, command):
    """
    Act on the user's answer to the "Execute command?" prompt.

    - "Y"/"y" or an empty answer: run `command` through config["shell"]. "/c" is
      passed for powershell.exe and "-c" for any other shell.
    - "M"/"m": read a new prompt, generate and vet a replacement command, prompt
      again, and handle that answer recursively.
    - "C"/"c": copy `command` to the clipboard. On POSIX without a DISPLAY,
      nothing is done.
    - Any other answer: nothing happens.

    Parameters
    ----------
    client : object
        Model client, forwarded to chat_completion in the modify path.
    config : dict
        Configuration. "shell" selects the interpreter, and the whole dict is
        passed on to chat_completion and prompt_user_input when modifying.
    user_input : str
        The user's answer.
    command : str
        The generated command to run or copy.

    Raises
    ------
    subprocess.CalledProcessError
        If the executed command exits with a non-zero status (check=True).
    """
    choice = user_input.upper()
    if choice in ("Y", ""):
        shell = config["shell"]
        # powershell.exe gets "/c". Unix shells (/bin/bash, /bin/zsh) take "-c";
        # both Ubuntu and macOS should work, others might not.
        run_flag = "/c" if shell == "powershell.exe" else "-c"
        subprocess.run([shell, run_flag, command], shell=False, check=True)
    elif choice == "M":
        print("Modify prompt: ", end="")
        new_command = chat_completion(client, input(), config)
        check_for_issue(new_command)
        check_for_markdown(new_command)
        next_input = prompt_user_input(config, new_command)
        print()
        evaluate_input(client, config, next_input, new_command)
    elif choice == "C":
        headless = os.name == "posix" and missing_posix_display()
        if not headless:
            pyperclip.copy(command)
            print("Copied command to clipboard.")
def _build_arg_parser():
    """Return the argparse parser for yolo's command line."""
    parser = argparse.ArgumentParser(
        description="AI bot that translates your question to a command."
    )
    parser.add_argument("text", nargs="+", help="A sequence of strings")
    parser.add_argument(
        "-s",
        "--safety",
        action="store_true",
        help="Enable safety mode (only useful when safety is off)",
    )
    parser.add_argument(
        "-c", "--config", action="store_true", help="Print current configuration"
    )
    return parser


def main():
    """
    Run yolo once: parse the command line, load yolo.yaml, ask the model for a
    command, vet it, and let the user execute, modify or copy it.
    """
    args = _build_arg_parser().parse_args()

    # Configuration and model client.
    # NOTE(review): the client is created before set_openai_api_key runs; confirm
    # AIModel does not need openai.api_key at construction time.
    config = read_yaml_config()
    client = AIModel.get_model_client(config)
    set_openai_api_key(config)

    question = " ".join(args.text)
    if args.safety:
        config["safety"] = True
    # $SHELL exists on Unix-like systems (/bin/bash, /bin/zsh); without it,
    # assume Windows.
    config["shell"] = os.environ.get("SHELL", "powershell.exe")

    if args.config:
        print_config(config)

    # colorama: enable coloured terminal output on Windows.
    init()

    command = chat_completion(client, question, config)
    for vet in (check_for_issue, check_for_markdown):
        vet(command)
    answer = prompt_user_input(config, command)
    print()
    evaluate_input(client, config, answer, command)
if __name__ == "__main__":
    # Start the CLI only when run as a script, never on import.
    main()