Fir changes after version was increased to 0.5
This commit is contained in:
@@ -30,6 +30,8 @@ import yaml
|
||||
from termcolor import colored
|
||||
from colorama import init
|
||||
|
||||
from ai_model import AIModel, GroqModel, OpenAIModel, OllamaModel, AnthropicModel, AzureOpenAIModel
|
||||
|
||||
CONFIG_FILE = "yolo.yaml"
|
||||
PROMPT_FILE = "yolo.prompt"
|
||||
|
||||
@@ -87,6 +89,7 @@ def set_openai_api_key(config):
|
||||
if not openai.api_key:
|
||||
openai.api_key = config["openai_api_key"]
|
||||
|
||||
# TODO: Add new configuration paramters
|
||||
def print_config(config):
|
||||
"""
|
||||
Print config information.
|
||||
@@ -135,7 +138,8 @@ def get_os_friendly_name():
|
||||
|
||||
return os_name
|
||||
|
||||
def get_full_prompt(user_prompt, shell):
|
||||
# TODO: Change comment
|
||||
def get_system_prompt(user_prompt, shell):
|
||||
"""
|
||||
Constructs a full prompt string by appending the user's prompt to a predefined prompt template
|
||||
located in the PROMPT_FILE file.
|
||||
@@ -164,44 +168,53 @@ def get_full_prompt(user_prompt, shell):
|
||||
|
||||
## Load the prompt and prep it
|
||||
prompt_file = os.path.join(prompt_path, PROMPT_FILE)
|
||||
pre_prompt = open(prompt_file,"r").read()
|
||||
pre_prompt = pre_prompt.replace("{shell}", shell)
|
||||
pre_prompt = pre_prompt.replace("{os}", get_os_friendly_name())
|
||||
prompt = pre_prompt + user_prompt
|
||||
system_prompt = open(prompt_file,"r").read()
|
||||
system_prompt = system_prompt.replace("{shell}", shell)
|
||||
system_prompt = system_prompt.replace("{os}", get_os_friendly_name())
|
||||
|
||||
return system_prompt
|
||||
|
||||
# Be nice and make it a question.
|
||||
if prompt[-1:] != "?" and prompt[-1:] != ".":
|
||||
prompt+="?"
|
||||
|
||||
return prompt
|
||||
|
||||
def call_open_ai(config, query):
|
||||
def chat_completion(client, query, config):
|
||||
"""
|
||||
Do we have a prompt from the user?
|
||||
Generate a chat-based completion for a given query using a specified model.
|
||||
|
||||
This function sends a user query to a chat model and returns the generated response.
|
||||
|
||||
Parameters:
|
||||
client (object): The client object to interact with the chat service.
|
||||
query (str): The user's query to send to the chat model.
|
||||
config (dict): Configuration settings for the chat service, which should include:
|
||||
- "shell" (str): Type of shell to use in the system prompt.
|
||||
- "model" (str): The specific model to use for the chat completion.
|
||||
- "temperature" (float): Sampling temperature to use for the response generation (higher values mean the model will take more risks).
|
||||
- "max_tokens" (int): Maximum number of tokens to generate in the chat response.
|
||||
|
||||
Returns:
|
||||
dict: The response from the chat model.
|
||||
|
||||
Raises:
|
||||
SystemExit: If the query is an empty string, the function will print an error message and exit.
|
||||
"""
|
||||
if query == "":
|
||||
print ("No user prompt specified.")
|
||||
sys.exit(-1)
|
||||
|
||||
system_prompt = get_system_prompt(query, config["shell"])
|
||||
|
||||
# Load the correct prompt based on shell and OS and append the user's prompt.
|
||||
prompt = get_full_prompt(query, config["shell"])
|
||||
# Ensure query is a question
|
||||
if query[-1:] != "?" and query[-1:] != ".":
|
||||
query += "?"
|
||||
|
||||
# Make the first line also the system prompt
|
||||
system_prompt = prompt[1]
|
||||
#print(prompt)
|
||||
|
||||
# Call the ChatGPT API
|
||||
response = openai.ChatCompletion.create(
|
||||
response = client.chat(
|
||||
model=config["model"],
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
temperature=config["temperature"],
|
||||
max_tokens=config["max_tokens"],
|
||||
)
|
||||
|
||||
return response.choices[0].message.content.strip()
|
||||
{"role": "user", "content": query}
|
||||
],
|
||||
temperature=config["temperature"],
|
||||
max_tokens=config["max_tokens"])
|
||||
|
||||
return response
|
||||
|
||||
def check_for_issue(response):
|
||||
"""
|
||||
@@ -355,8 +368,9 @@ def main():
|
||||
help='Print current configuration')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load configuration
|
||||
# Load configurations and set up client
|
||||
config = read_yaml_config()
|
||||
client = AIModel.get_model_client(config)
|
||||
set_openai_api_key(config)
|
||||
|
||||
# Process parameters
|
||||
@@ -374,12 +388,13 @@ def main():
|
||||
# Enable color output on Windows using colorama
|
||||
init()
|
||||
|
||||
res_command = call_open_ai(config, user_prompt)
|
||||
check_for_issue(res_command)
|
||||
check_for_markdown(res_command)
|
||||
user_input = prompt_user_input(config, res_command)
|
||||
result = chat_completion(client, user_prompt, config)
|
||||
check_for_issue(result)
|
||||
check_for_markdown(result)
|
||||
|
||||
user_input = prompt_user_input(config, result)
|
||||
print()
|
||||
evaluate_input(config, user_input, res_command)
|
||||
evaluate_input(config, user_input, result)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user