"""OpenAI API client for the Edison application."""

import os
import time
import logging

import dotenv
from openai import OpenAI, RateLimitError

from edison.core import prompt_manager
from edison.utils import logging_utils

logger = logging.getLogger(__name__)


def get_api_key(config):
    """
    Get the OpenAI API key from various sources.

    Sources are checked in this order; the first non-empty value wins:

    1. The ``OPENAI_API_KEY`` environment variable. A ``.env`` file, if
       ``dotenv.load_dotenv()`` finds one, is loaded into the environment first.
    2. The file ``~/.openai.apikey``, read as raw text and stripped.
    3. The ``openai_api_key`` entry of the provided config dictionary.

    Args:
        config (dict): A dictionary containing configuration values. It may
            contain ``openai_api_key`` as one of the keys. ``None`` is treated
            like an empty dictionary.

    Returns:
        str: The OpenAI API key.

    Raises:
        ValueError: If no key is found in any of the sources.
    """
    dotenv.load_dotenv()

    # Method 1: environment variable (possibly populated from a ".env" file).
    # The variable/file should contain a line like `OPENAI_API_KEY=...`.
    api_key = os.getenv("OPENAI_API_KEY")

    # Method 2: raw key stored in ~/.openai.apikey.
    # This method might be deprecated in future versions.
    if not api_key:
        api_key_path = os.path.join(os.path.expanduser("~"), ".openai.apikey")
        # isfile (rather than exists) so a directory with this name is skipped
        # instead of raising IsADirectoryError on open().
        if os.path.isfile(api_key_path):
            with open(api_key_path, "r", encoding="utf-8") as f:
                api_key = f.read().strip()

    # Method 3: the config dictionary, under the key 'openai_api_key'
    # (e.g. `openai_api_key: ...` in an `edison.yaml` config file).
    if not api_key:
        api_key = (config or {}).get("openai_api_key")

    if not api_key:
        logger.error(
            "No OpenAI API key found. "
            "Please set it in your environment, .env file, or config."
        )
        raise ValueError("No OpenAI API key found")
    return api_key


def create_client(config):
    """
    Create and initialize an OpenAI client.

    Args:
        config (dict): The configuration dictionary, used to look up the API
            key via ``get_api_key``.

    Returns:
        OpenAI: An initialized OpenAI client.

    Raises:
        ValueError: If no API key can be found.
    """
    api_key = get_api_key(config)
    return OpenAI(api_key=api_key)


def _build_request(config, query):
    """Validate *query* and build the keyword arguments for a chat completion."""
    if not query:
        logger.error("No user prompt specified.")
        raise ValueError("No user prompt specified")

    # Load the correct prompt based on shell and OS and append the user's prompt.
    prompt = prompt_manager.get_full_prompt(query, config.get("shell", "bash"))
    # The system prompt is the first line of the full prompt (or all of it when
    # there is no newline); the user message carries the full prompt.
    system_prompt = prompt.split("\n", 1)[0]

    return {
        "model": config.get("model", "gpt-4o-mini"),
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
        "temperature": config.get("temperature", 0),
        "max_tokens": config.get("max_tokens", 500),
    }


def call_api(client, config, query):
    """
    Call the OpenAI API with the given query.

    Args:
        client (OpenAI): The OpenAI client instance.
        config (dict): Configuration dictionary containing model and parameters
            (``model``, ``temperature``, ``max_tokens``, ``shell``).
        query (str): The user's query string.

    Returns:
        str: The generated command as a string (empty if the model returned
        no text content).

    Raises:
        ValueError: If *query* is empty.
        Exception: Any error raised by the OpenAI client is logged and re-raised.
    """
    request = _build_request(config, query)
    try:
        response = client.chat.completions.create(**request)
        # message.content can be None (e.g. refusals or tool-call responses);
        # treat that as an empty answer instead of crashing on .strip().
        content = response.choices[0].message.content
        return (content or "").strip()
    except Exception as e:
        logger.error("Error calling OpenAI API: %s", e)
        raise


def call_api_streaming(client, config, query, callback):
    """
    Call the OpenAI API with streaming enabled.

    Args:
        client (OpenAI): The OpenAI client instance.
        config (dict): Configuration dictionary containing model and parameters
            (``model``, ``temperature``, ``max_tokens``, ``shell``).
        query (str): The user's query string.
        callback (callable): Function called with each non-empty text fragment
            as it arrives.

    Returns:
        str: The complete generated command as a string.

    Raises:
        ValueError: If *query* is empty.
        Exception: Any error raised by the OpenAI client (or by *callback*) is
            logged and re-raised.
    """
    request = _build_request(config, query)
    try:
        response = client.chat.completions.create(stream=True, **request)

        parts = []
        for chunk in response:
            # Some chunks carry an empty `choices` list (e.g. a final
            # usage-only chunk); skip them instead of raising IndexError.
            if not chunk.choices:
                continue
            content = getattr(chunk.choices[0].delta, "content", None)
            if content:
                callback(content)
                parts.append(content)
        return "".join(parts).strip()
    except Exception as e:
        logger.error("Error calling OpenAI API: %s", e)
        raise


def _is_rate_limit_error(exc):
    """Return True if *exc* looks like an OpenAI rate-limit error."""
    # The message check is kept for errors that are not RateLimitError
    # instances but still report a rate limit in their text.
    return isinstance(exc, RateLimitError) or "rate limit" in str(exc).lower()


def _retry_on_rate_limit(operation, max_retries, can_retry=None):
    """Run *operation*, retrying with exponential backoff on rate limits.

    *can_retry*, if given, is consulted after a rate-limit failure; returning
    False prevents further retries. At least one attempt is always made.
    """
    attempts = max(1, max_retries)
    for attempt in range(attempts):
        try:
            return operation()
        except Exception as e:
            retryable = _is_rate_limit_error(e) and (can_retry is None or can_retry())
            if retryable and attempt < attempts - 1:
                wait_time = 2 ** attempt  # Exponential backoff
                logger.warning("Rate limited. Retrying in %s seconds...", wait_time)
                time.sleep(wait_time)
            else:
                logger.error("Error after %d attempts: %s", attempt + 1, e)
                raise


def generate_command_streaming(client, config, query, callback, max_retries=3):
    """
    Generate a command using the OpenAI API with streaming and retry logic.

    Rate-limit errors are retried with exponential backoff (1s, 2s, 4s, ...),
    but only while no tokens have been passed to *callback* yet, so a retry
    never re-emits output the caller has already received.

    Args:
        client (OpenAI): The OpenAI client instance.
        config (dict): Configuration dictionary containing model and parameters.
        query (str): The user's query string.
        callback (callable): Function to call with each token as it arrives.
        max_retries (int): Maximum number of attempts; values below 1 are
            treated as 1.

    Returns:
        str: The generated command as a string.

    Raises:
        Exception: The last error if it is not retryable or attempts run out.
    """
    emitted = False

    def tracking_callback(token):
        nonlocal emitted
        emitted = True
        callback(token)

    return _retry_on_rate_limit(
        lambda: call_api_streaming(client, config, query, tracking_callback),
        max_retries,
        can_retry=lambda: not emitted,
    )


def generate_command(client, config, query, max_retries=3):
    """
    Generate a command using the OpenAI API with retry logic.

    Rate-limit errors are retried with exponential backoff (1s, 2s, 4s, ...).

    Args:
        client (OpenAI): The OpenAI client instance.
        config (dict): Configuration dictionary containing model and parameters.
        query (str): The user's query string.
        max_retries (int): Maximum number of attempts; values below 1 are
            treated as 1.

    Returns:
        str: The generated command as a string.

    Raises:
        Exception: The last error if it is not retryable or attempts run out.
    """
    return _retry_on_rate_limit(
        lambda: call_api(client, config, query),
        max_retries,
    )