435 lines
15 KiB
Python
Executable File
435 lines
15 KiB
Python
Executable File
"""
|
|
AI Chatbot to generate shell commands.
|
|
|
|
This script allows the user to ask their question in plain English and translates
|
|
that question into a command that can be run in the shell. The functionalities
|
|
include leveraging OpenAI's GPT models to generate commands, verifying newly generated
|
|
commands, checking commands for any unsafe attributes, and allowing the user to
|
|
execute or modify the generated command.
|
|
|
|
This program uses an AI model to assist users in
|
|
generating Unix/shell commands or other scripts, based on their natural language
|
|
input. The objective is to aid those users who might not remember the exact syntax
|
|
of every command or script they frequently use.
|
|
|
|
Sources:
|
|
— https://github.com/wunderwuzzi23/yolo-ai-cmdbot
|
|
"""
|
|
|
|
import os
|
|
import platform
|
|
import subprocess
|
|
import sys
|
|
|
|
import argparse
|
|
import distro
|
|
import dotenv
|
|
import openai
|
|
import pyperclip
|
|
import yaml
|
|
|
|
from termcolor import colored
|
|
from colorama import init
|
|
|
|
from ai_model import AIModel
|
|
|
|
# File names of the companion files that are looked up in the directory
# containing this script (see read_yaml_config and get_system_prompt).
CONFIG_FILE: str = "yolo.yaml"  # YAML settings
PROMPT_FILE: str = "yolo.prompt"  # system-prompt template with {shell}/{os} placeholders
|
|
|
|
|
|
def read_yaml_config() -> dict:
    """
    Read the YAML configuration file that lives next to this script.

    The file is located relative to this script's own path (``__file__``)
    rather than the current working directory, so the script finds its
    configuration no matter where it is invoked from (e.g. through a shell
    alias).

    Returns:
        dict: The parsed configuration mapping.

    Raises:
        OSError: If the configuration file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
        ValueError: If the file does not contain a top-level mapping (for
            example when it is empty). The rest of the program looks settings
            up by key, so anything else would only fail later, less clearly.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    config_file = os.path.join(script_dir, CONFIG_FILE)

    with open(config_file, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    # yaml.safe_load returns None for an empty file and a list/scalar for
    # other documents; only a mapping is usable as configuration.
    if not isinstance(config, dict):
        raise ValueError(
            f"{config_file} must contain a YAML mapping of settings, "
            f"got {type(config).__name__}"
        )
    return config
|
|
|
|
|
|
def set_openai_api_key(config):
    """
    Resolve the OpenAI API key and record it in ``config`` and ``openai.api_key``.

    Sources are tried in order; the first non-empty value wins:

    1. The ``OPENAI_API_KEY`` environment variable. ``dotenv.load_dotenv()``
       runs first, so the variable may also come from a ``.env`` file.
    2. The file ``~/.openai.apikey`` containing the raw key (this method might
       be deprecated in future versions).
    3. The ``openai_api_key`` entry already present in ``config`` (e.g. the line
       ``openai_api_key: <yourkey>`` in ``yolo.yaml``).

    An ``openai.api_key`` that is already set is left untouched.

    Parameters:
        config (dict): Configuration values. Its ``openai_api_key`` entry is
            used as the last fallback and is then overwritten with the resolved
            key (``None`` when no source provided one).
    """
    dotenv.load_dotenv()

    # Method 1: environment variable (possibly populated from a .env file).
    api_key = os.getenv("OPENAI_API_KEY")

    # Method 2: raw key stored in ~/.openai.apikey. The file is read here, and
    # only when it exists, instead of assigning `openai.api_key_path`: pointing
    # that attribute at a missing file can cause an "invalid filepath" error in
    # openai releases that honour it, even if a key is available elsewhere.
    if not api_key:
        key_file = os.path.join(os.path.expanduser("~"), ".openai.apikey")
        if os.path.isfile(key_file):
            with open(key_file, "r", encoding="utf-8") as file:
                api_key = file.read().strip()

    # Method 3: value from the config dictionary (e.g. yolo.yaml). This is only
    # consulted as a fallback, so a key configured there is no longer clobbered
    # when the environment variable is unset.
    if not api_key:
        api_key = config.get("openai_api_key")

    config["openai_api_key"] = api_key or None
    if api_key and not openai.api_key:
        openai.api_key = api_key
|
|
|
|
|
|
def print_config(config):
    """
    Print the effective configuration, one setting per line.

    Parameters
    ----------
    config : dict
        Configuration mapping. The keys "api", "model", "temperature",
        "max_tokens", "safety", "modify", "suggested_command_color" and
        "shell" are read, in that order; "safety" and "modify" are shown as
        booleans. A missing key raises ``KeyError`` after the preceding lines
        have been printed.
    """
    # (label prefix, config key, optional conversion applied before str())
    layout = (
        ("— API : ", "api", None),
        ("— Model : ", "model", None),
        ("— Temperature : ", "temperature", None),
        ("— Max. Tokens : ", "max_tokens", None),
        ("— Safety : ", "safety", bool),
        ("— Modify : ", "modify", bool),
        ("— Color : ", "suggested_command_color", None),
        ("— Shell : ", "shell", None),
    )

    print("Current configuration per yolo.yaml:")
    for prefix, key, convert in layout:
        value = config[key]
        if convert is not None:
            value = convert(value)
        print(prefix + str(value))
|
|
|
|
|
|
def get_os_friendly_name():
    """
    Return a human-friendly name for the operating system running this script.

    The value of ``platform.system()`` is refined for two systems: Linux gets
    the distribution's pretty name appended (via ``distro.name(pretty=True)``),
    and Darwin is labelled as macOS. Any other system name is returned as is.

    Returns
    -------
    str
        ``"Linux/<distribution name>"``, ``"Darwin/macOS"``, or the unchanged
        ``platform.system()`` string for any other system.
    """
    system_name = platform.system()

    if system_name == "Darwin":
        return "Darwin/macOS"
    if system_name != "Linux":
        return system_name
    return "Linux/" + distro.name(pretty=True)
|
|
|
|
|
|
def get_system_prompt(shell):
    """
    Build the system prompt from the PROMPT_FILE template.

    The template is read from PROMPT_FILE in the directory containing this
    script (resolved from ``__file__``), and every occurrence of the
    ``{shell}`` and ``{os}`` placeholders is replaced with ``shell`` and with
    ``get_os_friendly_name()`` respectively.

    Parameters
    ----------
    shell : str
        Value substituted for the ``{shell}`` placeholder (e.g. ``/bin/bash``
        or ``powershell.exe``).

    Returns
    -------
    str
        The template text with both placeholders filled in.

    Raises
    ------
    OSError
        If the prompt template cannot be opened.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    prompt_file = os.path.join(script_dir, PROMPT_FILE)

    with open(prompt_file, "r", encoding="utf-8") as file:
        template = file.read()

    # Plain str.replace (rather than str.format) leaves any other braces in the
    # template untouched.
    return template.replace("{shell}", shell).replace("{os}", get_os_friendly_name())
|
|
|
|
|
|
def chat_completion(client, query, config):
    """
    Ask the chat model to translate a plain-English query into a command.

    Parameters:
        client (object): Model client whose ``chat(model=..., messages=...,
            temperature=..., max_tokens=...)`` method is called.
        query (str): The user's request. Surrounding whitespace is stripped,
            and "?" is appended unless the text already ends in "?" or ".".
        config (dict): Settings; this function reads:
            - "shell" (str): Shell name inserted into the system prompt.
            - "model" (str): Model to use for the completion.
            - "temperature" (float): Sampling temperature (higher values make
              the output more random).
            - "max_tokens" (int): Maximum number of tokens to generate.

    Returns:
        Whatever ``client.chat`` returns. The callers in this script treat it as
        the proposed command text (they call ``.lower()`` on it and pass it to
        the shell).

    Raises:
        SystemExit: If ``query`` is empty or only whitespace; an error message
            is printed and the process exits with status -1.
    """
    query = query.strip()
    if not query:
        print("No user prompt specified.")
        sys.exit(-1)

    system_prompt = get_system_prompt(config["shell"])

    # Ensure the query is phrased as a question (or at least a sentence).
    if not query.endswith(("?", ".")):
        query += "?"

    return client.chat(
        model=config["model"],
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": query},
        ],
        temperature=config["temperature"],
        max_tokens=config["max_tokens"],
    )
|
|
|
|
|
|
def check_for_issue(response):
    """
    Abort if the model's reply looks like a refusal or an error message.

    The reply is compared, case-insensitively and ignoring surrounding
    whitespace, against a set of prefixes that indicate the model did not
    return a command. On a match, the reply is printed in red and the program
    exits with status -1.

    Parameters
    ----------
    response : str
        The model's reply to examine.
    """
    # "i'm" also covers "i'm sorry".
    prefixes = ("sorry", "the question is not clear", "i'm", "i am")

    # Normalise the typographic apostrophe (U+2019) to "'" and drop leading
    # whitespace so replies such as "  I’m sorry, ..." are still recognised.
    normalized = response.strip().lower().replace("\u2019", "'")
    if normalized.startswith(prefixes):
        print(colored("There was an issue: " + response, "red"))
        sys.exit(-1)
|
|
|
|
|
|
def check_for_markdown(response):
    """
    Abort if the model's reply contains a Markdown code fence.

    A reply containing a code fence is Markdown rather than a bare command, so
    it must not be executed directly. In that case the reply is printed after a
    red warning and the program exits with status -1.

    Parameters
    ----------
    response : str
        The model's reply to examine.
    """
    # Any fence counts. The former `response.count("```", 2)` started counting
    # at index 2, so a reply that opened with a fence but never closed it
    # slipped through and would have been executed.
    if "```" in response:
        print(
            colored(
                "The proposed command contains markdown, response not executed directly: \n",
                "red",
            )
            + response
        )
        sys.exit(-1)
|
|
|
|
|
|
def missing_posix_display():
    """
    Report whether no X display is configured, based on ``DISPLAY``.

    Callers use this on POSIX systems to decide whether clipboard support is
    usable.

    Returns
    -------
    bool
        ``True`` if the DISPLAY environment variable is unset, empty or only
        whitespace, ``False`` otherwise.
    """
    # Reading the environment directly gives the same answer as running
    # `echo $DISPLAY` in a shell (whose unquoted expansion also reduced a
    # whitespace-only value to an empty line) without spawning a process.
    return not os.environ.get("DISPLAY", "").strip()
|
|
|
|
|
|
def prompt_user_input(config, response):
    """
    Show the proposed command and ask the user what to do with it.

    The command is printed in bold, using the color configured under
    "suggested_command_color". With safety disabled, no question is asked and
    "Y" (execute) is returned. Otherwise the user is offered execute / no /
    copy-to-clipboard, plus "modify" when config["modify"] is truthy. The
    clipboard option is left out on POSIX systems where no display is
    available.

    Parameters
    ----------
    config : dict
        Configuration; reads "suggested_command_color", "safety" and, when
        safety is on, "modify".
    response : str
        The proposed command to display.

    Returns
    -------
    str
        The user's raw answer, or "Y" when safety is disabled.
    """
    highlighted = colored(response, config["suggested_command_color"], attrs=["bold"])
    print("Command: " + highlighted)

    if not config["safety"]:
        return "Y"

    modify_option = " [m]modify" if config["modify"] else ""
    no_clipboard = os.name == "posix" and missing_posix_display()
    choices = "[Y]es [n]o" if no_clipboard else "[Y]es [n]o [c]opy to clipboard"

    print("Execute command? " + choices + modify_option + ": ", end="")
    return input()
|
|
|
|
|
|
def evaluate_input(client, config, user_input, command):
    """
    Act on the user's answer: execute, modify, or copy the command.

    - "Y"/"y" or an empty answer: run ``command`` with the shell in
      config["shell"]. If the command exits with a non-zero status, a short
      message is printed to stderr and the program exits with that status
      (instead of crashing with a traceback).
    - "M"/"m": ask for a new prompt, query the model again, validate and show
      the new proposal, then evaluate the user's next answer. This repeats for
      as long as the user keeps choosing "modify".
    - "C"/"c": copy ``command`` to the clipboard, unless on POSIX without a
      display, where nothing happens.
    - Anything else: do nothing.

    Parameters
    ----------
    client : object
        Model client, used only for the "modify" action.
    config : dict
        Configuration; "shell" names the shell executable.
    user_input : str
        The user's answer, e.g. 'Y', 'n', 'm', 'c' or '' (empty string).
    command : str
        The command to execute or copy.

    Raises
    ------
    SystemExit
        If the executed command exits with a non-zero status.
    """
    # Loop instead of recursing, so repeated "modify" rounds do not grow the stack.
    while True:
        choice = user_input.upper()

        if choice == "Y" or user_input == "":
            # powershell.exe is invoked with "/c". POSIX shells (bash, zsh, ...)
            # take "-c"; Ubuntu and macOS should work, other systems might not.
            flag = "/c" if config["shell"] == "powershell.exe" else "-c"
            try:
                subprocess.run([config["shell"], flag, command], shell=False, check=True)
            except subprocess.CalledProcessError as err:
                print(f"Command exited with status {err.returncode}.", file=sys.stderr)
                # A negative return code means the process was killed by a signal.
                sys.exit(err.returncode if err.returncode > 0 else 1)
            return

        if choice == "M":
            print("Modify prompt: ", end="")
            modded_query = input()
            command = chat_completion(client, modded_query, config)
            check_for_issue(command)
            check_for_markdown(command)
            user_input = prompt_user_input(config, command)
            print()
            continue

        if choice == "C":
            if os.name == "posix" and missing_posix_display():
                return
            pyperclip.copy(command)
            print("Copied command to clipboard.")

        return
|
|
|
|
|
|
def main():
    """
    Command-line entry point.

    Parse the arguments and load the configuration. Then either print the
    configuration (``-c/--config``) or translate the given question into a
    command and let the user execute, modify, or copy it, or do both.
    """
    parser = argparse.ArgumentParser(
        description="AI bot that translates your question to a command."
    )
    # nargs="*" (instead of "+") so that `-c/--config` works without a
    # question; a missing question is still reported as a usage error below.
    parser.add_argument("text", nargs="*", help="A sequence of strings")
    parser.add_argument(
        "-s",
        "--safety",
        action="store_true",
        help="Enable safety mode (only useful when safety is off)",
    )
    parser.add_argument(
        "-c", "--config", action="store_true", help="Print current configuration"
    )
    args = parser.parse_args()

    if not args.text and not args.config:
        parser.error("the following arguments are required: text")

    # Load configurations and set up client
    config = read_yaml_config()
    client = AIModel.get_model_client(config)
    # NOTE(review): the API key is resolved after the client is created;
    # confirm AIModel.get_model_client does not need it at construction time.
    set_openai_api_key(config)

    # Process parameters
    user_prompt = " ".join(args.text)

    if args.safety:
        config["safety"] = args.safety

    # Unix based SHELL (/bin/bash, /bin/zsh), otherwise assuming it's Windows
    config["shell"] = os.environ.get("SHELL", "powershell.exe")

    if args.config:
        print_config(config)
        if not args.text:
            # Only the configuration was requested.
            return

    # Enable color output on Windows using colorama
    init()

    result = chat_completion(client, user_prompt, config)
    check_for_issue(result)
    check_for_markdown(result)

    user_input = prompt_user_input(config, result)
    print()
    evaluate_input(client, config, user_input, result)
|
|
|
|
|
|
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|