219 lines
7.9 KiB
Python
219 lines
7.9 KiB
Python
"""
|
|
OpenAI API client for the Edison application.
|
|
"""
|
|
import os
|
|
import time
|
|
import logging
|
|
import dotenv
|
|
from openai import OpenAI
|
|
from edison.core import prompt_manager
|
|
from edison.utils import logging_utils
|
|
|
|
# Module-level logger named after this module; the functions in this file
# report missing keys, API errors and rate-limit retries through it.
logger = logging.getLogger(__name__)
|
|
|
|
def get_api_key(config=None):
    """
    Get the OpenAI API key from various sources.

    Sources are tried in this order; the first non-empty value wins:

    1. The ``OPENAI_API_KEY`` environment variable. ``dotenv.load_dotenv()``
       is called first, so the variable may also come from a ``.env`` file.
    2. The file ``~/.openai.apikey`` containing the raw key.
    3. The ``openai_api_key`` entry of ``config``.

    Surrounding whitespace (e.g. a trailing newline) is stripped from every
    source, and a whitespace-only value is treated as missing.

    Args:
        config (dict, optional): Configuration values. It may contain
            ``openai_api_key``. ``None`` is treated like an empty dict.

    Returns:
        str: The OpenAI API key.

    Raises:
        ValueError: If none of the sources provides a non-empty key.
    """
    dotenv.load_dotenv()

    # Method 1: Read API key from environment variable.
    # The user can set their OpenAI API key by creating a ".env" file (e.g. next
    # to this script) or by exporting it to their environment variables.
    # The file or environment variable should contain `OPENAI_API_KEY="<yourkey>"`.
    api_key = (os.getenv("OPENAI_API_KEY") or "").strip()

    # Method 2: Read API key from a file in the home directory.
    # The user can also place a file named ".openai.apikey" in their home
    # directory containing the raw API key. This method might be deprecated in
    # future versions.
    if not api_key:
        api_key_path = os.path.join(os.path.expanduser("~"), ".openai.apikey")
        # isfile (not exists) so a directory with this name is skipped instead of
        # crashing open(); "utf-8-sig" also drops a BOM some editors prepend.
        if os.path.isfile(api_key_path):
            with open(api_key_path, "r", encoding="utf-8-sig") as f:
                api_key = f.read().strip()

    # Method 3: Read API key from the provided config dictionary, under the key
    # 'openai_api_key'. For instance, in an `edison.yaml` config file it would
    # appear as `openai_api_key: <yourkey>`.
    if not api_key and config:
        api_key = str(config.get("openai_api_key") or "").strip()

    if not api_key:
        logger.error("No OpenAI API key found. Please set it in your environment, .env file, or config.")
        raise ValueError("No OpenAI API key found")

    return api_key
|
|
|
|
def create_client(config):
    """
    Build an OpenAI client authenticated with the resolved API key.

    The key is resolved by ``get_api_key(config)``, which checks the
    environment / ``.env`` file, then ``~/.openai.apikey``, then ``config``.

    Args:
        config (dict): The configuration dictionary forwarded to ``get_api_key``.

    Returns:
        OpenAI: A ready-to-use OpenAI client.

    Raises:
        ValueError: Propagated from ``get_api_key`` when no key can be found.
    """
    return OpenAI(api_key=get_api_key(config))
|
|
|
|
def call_api(client, config, query):
    """
    Call the OpenAI Chat Completions API once with the given query.

    The full prompt is built by ``prompt_manager.get_full_prompt`` for the
    configured shell. Its first line is sent as the system message and the
    complete prompt as the user message.

    Args:
        client (OpenAI): The OpenAI client instance.
        config (dict): Configuration dictionary. Reads ``shell`` (default
            ``"bash"``), ``model`` (default ``"gpt-4o-mini"``),
            ``temperature`` (default ``0``) and ``max_tokens`` (default ``500``).
        query (str): The user's query string.

    Returns:
        str: The generated command with surrounding whitespace stripped, or
        ``""`` if the API response carried no message content.

    Raises:
        ValueError: If ``query`` is empty.
        Exception: Any error raised by the API request is logged and re-raised.
    """
    if not query:
        logger.error("No user prompt specified.")
        raise ValueError("No user prompt specified")

    # Load the correct prompt for the configured shell and append the user's prompt.
    prompt = prompt_manager.get_full_prompt(query, config.get("shell", "bash"))

    # The first line of the prompt is used as the system prompt; partition()
    # yields the whole string when there is no newline.
    # NOTE(review): the user message below still contains this first line, so
    # it is sent twice -- confirm whether that duplication is intended.
    system_prompt = prompt.partition("\n")[0]

    try:
        response = client.chat.completions.create(
            model=config.get("model", "gpt-4o-mini"),
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
            temperature=config.get("temperature", 0),
            max_tokens=config.get("max_tokens", 500),
        )
    except Exception as e:
        logger.error("Error calling OpenAI API: %s", e)
        raise

    # The response schema allows `message.content` to be None (e.g. when the
    # model returns a refusal or tool call instead of text), and `choices`
    # could be empty; calling .strip() / [0] blindly would raise.
    choices = response.choices or []
    content = choices[0].message.content if choices else None
    if content is None:
        logger.warning("OpenAI API returned no message content.")
        return ""
    return content.strip()
|
|
|
|
def call_api_streaming(client, config, query, callback):
    """
    Call the OpenAI Chat Completions API with streaming enabled.

    Each non-empty text fragment is passed to ``callback`` as soon as it
    arrives, and all fragments are accumulated into the return value.

    Args:
        client (OpenAI): The OpenAI client instance.
        config (dict): Configuration dictionary. Reads ``shell`` (default
            ``"bash"``), ``model`` (default ``"gpt-4o-mini"``),
            ``temperature`` (default ``0``) and ``max_tokens`` (default ``500``).
        query (str): The user's query string.
        callback (callable): Called with each text fragment (``str``) in
            arrival order.

    Returns:
        str: The complete generated command with surrounding whitespace
        stripped (``""`` if no content was streamed).

    Raises:
        ValueError: If ``query`` is empty.
        Exception: Any error raised while requesting or consuming the stream
            (including one raised by ``callback``) is logged and re-raised.
    """
    if not query:
        logger.error("No user prompt specified.")
        raise ValueError("No user prompt specified")

    # Load the correct prompt for the configured shell and append the user's prompt.
    prompt = prompt_manager.get_full_prompt(query, config.get("shell", "bash"))

    # The first line of the prompt is used as the system prompt; partition()
    # yields the whole string when there is no newline.
    system_prompt = prompt.partition("\n")[0]

    # Collect fragments in a list and join once, instead of repeated `+=`.
    parts = []
    try:
        stream = client.chat.completions.create(
            model=config.get("model", "gpt-4o-mini"),
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
            temperature=config.get("temperature", 0),
            max_tokens=config.get("max_tokens", 500),
            stream=True,  # Enable streaming
        )

        for chunk in stream:
            # Some chunks carry no choices at all (e.g. a trailing usage-only
            # chunk, or content-filter metadata chunks from some endpoints);
            # indexing choices[0] on those would raise IndexError.
            if not chunk.choices:
                continue
            content = getattr(chunk.choices[0].delta, "content", None)
            if content:
                callback(content)
                parts.append(content)
    except Exception as e:
        logger.error("Error calling OpenAI API: %s", e)
        raise

    return "".join(parts).strip()
|
|
|
|
def generate_command_streaming(client, config, query, callback, max_retries=3):
    """
    Generate a command via the streaming OpenAI API, retrying on rate limits.

    Errors whose message contains "rate limit" (case-insensitive) are retried
    with exponential backoff (1s, 2s, 4s, ...). Any other error, or a rate
    limit on the final attempt, is logged and re-raised.

    Args:
        client (OpenAI): The OpenAI client instance.
        config (dict): Configuration dictionary containing model and parameters.
        query (str): The user's query string.
        callback (callable): Function to call with each token as it arrives.
        max_retries (int): Total number of attempts, including the first one.
            Values below 1 are treated as 1 so the API is always called at
            least once (previously the function silently returned None).

    Returns:
        str: The generated command as a string.
    """
    attempts = max(1, max_retries)
    for attempt in range(attempts):
        try:
            return call_api_streaming(client, config, query, callback)
        except Exception as e:
            # NOTE(review): a retry re-runs the whole stream; if a rate-limit
            # error could ever surface after tokens were already passed to
            # `callback`, the caller would see them again. Presumably 429s
            # arrive before streaming starts -- confirm.
            if "rate limit" in str(e).lower() and attempt < attempts - 1:
                wait_time = 2 ** attempt  # Exponential backoff: 1s, 2s, 4s, ...
                logger.warning("Rate limited. Retrying in %s seconds...", wait_time)
                time.sleep(wait_time)
            else:
                logger.error("Error after %s attempts: %s", attempt + 1, e)
                raise
|
|
|
|
def generate_command(client, config, query, max_retries=3):
    """
    Generate a command using the OpenAI API, retrying on rate limits.

    Errors whose message contains "rate limit" (case-insensitive) are retried
    with exponential backoff (1s, 2s, 4s, ...). Any other error, or a rate
    limit on the final attempt, is logged and re-raised.

    Args:
        client (OpenAI): The OpenAI client instance.
        config (dict): Configuration dictionary containing model and parameters.
        query (str): The user's query string.
        max_retries (int): Total number of attempts, including the first one.
            Values below 1 are treated as 1 so the API is always called at
            least once (previously the function silently returned None).

    Returns:
        str: The generated command as a string.
    """
    attempts = max(1, max_retries)
    for attempt in range(attempts):
        try:
            return call_api(client, config, query)
        except Exception as e:
            if "rate limit" in str(e).lower() and attempt < attempts - 1:
                wait_time = 2 ** attempt  # Exponential backoff: 1s, 2s, 4s, ...
                logger.warning("Rate limited. Retrying in %s seconds...", wait_time)
                time.sleep(wait_time)
            else:
                logger.error("Error after %s attempts: %s", attempt + 1, e)
                raise
|