Fir changes after version was increased to 0.5

This commit is contained in:
2024-08-19 19:09:47 +02:00
parent 84a8876016
commit 256b616f02
5 changed files with 244 additions and 46 deletions
+49 -34
View File
@@ -30,6 +30,8 @@ import yaml
from termcolor import colored
from colorama import init
from ai_model import AIModel, GroqModel, OpenAIModel, OllamaModel, AnthropicModel, AzureOpenAIModel
CONFIG_FILE = "yolo.yaml"
PROMPT_FILE = "yolo.prompt"
@@ -87,6 +89,7 @@ def set_openai_api_key(config):
if not openai.api_key:
openai.api_key = config["openai_api_key"]
# TODO: Add new configuration paramters
def print_config(config):
"""
Print config information.
@@ -135,7 +138,8 @@ def get_os_friendly_name():
return os_name
def get_full_prompt(user_prompt, shell):
# TODO: Change comment
def get_system_prompt(user_prompt, shell):
"""
Constructs a full prompt string by appending the user's prompt to a predefined prompt template
located in the PROMPT_FILE file.
@@ -164,44 +168,53 @@ def get_full_prompt(user_prompt, shell):
## Load the prompt and prep it
prompt_file = os.path.join(prompt_path, PROMPT_FILE)
pre_prompt = open(prompt_file,"r").read()
pre_prompt = pre_prompt.replace("{shell}", shell)
pre_prompt = pre_prompt.replace("{os}", get_os_friendly_name())
prompt = pre_prompt + user_prompt
system_prompt = open(prompt_file,"r").read()
system_prompt = system_prompt.replace("{shell}", shell)
system_prompt = system_prompt.replace("{os}", get_os_friendly_name())
return system_prompt
# Be nice and make it a question.
if prompt[-1:] != "?" and prompt[-1:] != ".":
prompt+="?"
return prompt
def call_open_ai(config, query):
def chat_completion(client, query, config):
"""
Do we have a prompt from the user?
Generate a chat-based completion for a given query using a specified model.
This function sends a user query to a chat model and returns the generated response.
Parameters:
client (object): The client object to interact with the chat service.
query (str): The user's query to send to the chat model.
config (dict): Configuration settings for the chat service, which should include:
- "shell" (str): Type of shell to use in the system prompt.
- "model" (str): The specific model to use for the chat completion.
- "temperature" (float): Sampling temperature to use for the response generation (higher values mean the model will take more risks).
- "max_tokens" (int): Maximum number of tokens to generate in the chat response.
Returns:
dict: The response from the chat model.
Raises:
SystemExit: If the query is an empty string, the function will print an error message and exit.
"""
if query == "":
print ("No user prompt specified.")
sys.exit(-1)
system_prompt = get_system_prompt(query, config["shell"])
# Load the correct prompt based on shell and OS and append the user's prompt.
prompt = get_full_prompt(query, config["shell"])
# Ensure query is a question
if query[-1:] != "?" and query[-1:] != ".":
query += "?"
# Make the first line also the system prompt
system_prompt = prompt[1]
#print(prompt)
# Call the ChatGPT API
response = openai.ChatCompletion.create(
response = client.chat(
model=config["model"],
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt}
],
temperature=config["temperature"],
max_tokens=config["max_tokens"],
)
return response.choices[0].message.content.strip()
{"role": "user", "content": query}
],
temperature=config["temperature"],
max_tokens=config["max_tokens"])
return response
def check_for_issue(response):
"""
@@ -355,8 +368,9 @@ def main():
help='Print current configuration')
args = parser.parse_args()
# Load configuration
# Load configurations and set up client
config = read_yaml_config()
client = AIModel.get_model_client(config)
set_openai_api_key(config)
# Process parameters
@@ -374,12 +388,13 @@ def main():
# Enable color output on Windows using colorama
init()
res_command = call_open_ai(config, user_prompt)
check_for_issue(res_command)
check_for_markdown(res_command)
user_input = prompt_user_input(config, res_command)
result = chat_completion(client, user_prompt, config)
check_for_issue(result)
check_for_markdown(result)
user_input = prompt_user_input(config, result)
print()
evaluate_input(config, user_input, res_command)
evaluate_input(config, user_input, result)
if __name__ == "__main__":
main()