Initial commit
This commit is contained in:
@@ -0,0 +1,218 @@
|
||||
"""
|
||||
OpenAI API client for the Edison application.
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import dotenv
|
||||
from openai import OpenAI
|
||||
from edison.core import prompt_manager
|
||||
from edison.utils import logging_utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_api_key(config):
|
||||
"""
|
||||
Get the OpenAI API key from various sources.
|
||||
|
||||
This function first tries to grab the OpenAI API key from environment variables,
|
||||
if not found, it then looks for the key in the `.openai.apikey` in the home directory,
|
||||
and lastly, it will look in the provided config dictionary.
|
||||
|
||||
Args:
|
||||
config (dict): A dictionary containing configuration values.
|
||||
It may contain `openai_api_key` as one of the keys.
|
||||
|
||||
Returns:
|
||||
str: The OpenAI API key
|
||||
"""
|
||||
dotenv.load_dotenv()
|
||||
|
||||
# Method 1: Read API key from environment variable
|
||||
# The user can set their OpenAI API key by creating a ".env" file in the same
|
||||
# directory as this script or by exporting it to their environment variables.
|
||||
# The file or environment variable should contain the line `OPENAI_API_KEY="<yourkey>"`.
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
|
||||
# Method 2: Read API key from a file in the home directory
|
||||
# The user can also place a file named ".openai.apikey" in their home directory,
|
||||
# which includes the API key in raw format. This method might be deprecated in future versions.
|
||||
if not api_key:
|
||||
home_path = os.path.expanduser("~")
|
||||
api_key_path = os.path.join(home_path, ".openai.apikey")
|
||||
if os.path.exists(api_key_path):
|
||||
with open(api_key_path, 'r') as f:
|
||||
api_key = f.read().strip()
|
||||
|
||||
# Method 3: Read API key from the provided config dictionary
|
||||
# The final method to set the API key is by providing it in the 'config' dictionary under the
|
||||
# key 'openai_api_key'. For instance, in a `edison.yaml` config file, it would appear as
|
||||
# `openai_apikey: <yourkey>`.
|
||||
if not api_key:
|
||||
api_key = config.get("openai_api_key")
|
||||
|
||||
if not api_key:
|
||||
logger.error("No OpenAI API key found. Please set it in your environment, .env file, or config.")
|
||||
raise ValueError("No OpenAI API key found")
|
||||
|
||||
return api_key
|
||||
|
||||
def create_client(config):
|
||||
"""
|
||||
Create and initialize an OpenAI client.
|
||||
|
||||
Args:
|
||||
config (dict): The configuration dictionary.
|
||||
|
||||
Returns:
|
||||
OpenAI: An initialized OpenAI client.
|
||||
"""
|
||||
api_key = get_api_key(config)
|
||||
return OpenAI(api_key=api_key)
|
||||
|
||||
def call_api(client, config, query):
|
||||
"""
|
||||
Call the OpenAI API with the given query.
|
||||
|
||||
Args:
|
||||
client (OpenAI): The OpenAI client instance.
|
||||
config (dict): Configuration dictionary containing model and parameters.
|
||||
query (str): The user's query string.
|
||||
|
||||
Returns:
|
||||
str: The generated command as a string.
|
||||
"""
|
||||
if not query:
|
||||
logger.error("No user prompt specified.")
|
||||
raise ValueError("No user prompt specified")
|
||||
|
||||
# Load the correct prompt based on shell and OS and append the user's prompt
|
||||
prompt = prompt_manager.get_full_prompt(query, config.get("shell", "bash"))
|
||||
|
||||
# Extract the system prompt from the first line
|
||||
system_prompt = prompt.split('\n')[0] if '\n' in prompt else prompt
|
||||
|
||||
try:
|
||||
# Use the modern API pattern
|
||||
response = client.chat.completions.create(
|
||||
model=config.get("model", "gpt-4o-mini"),
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
temperature=config.get("temperature", 0),
|
||||
max_tokens=config.get("max_tokens", 500),
|
||||
)
|
||||
|
||||
# Extract the content from the new response structure
|
||||
return response.choices[0].message.content.strip()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling OpenAI API: {str(e)}")
|
||||
raise
|
||||
|
||||
def call_api_streaming(client, config, query, callback):
|
||||
"""
|
||||
Call the OpenAI API with streaming enabled.
|
||||
|
||||
Args:
|
||||
client (OpenAI): The OpenAI client instance.
|
||||
config (dict): Configuration dictionary containing model and parameters.
|
||||
query (str): The user's query string.
|
||||
callback (callable): Function to call with each token as it arrives.
|
||||
|
||||
Returns:
|
||||
str: The complete generated command as a string.
|
||||
"""
|
||||
if not query:
|
||||
logger.error("No user prompt specified.")
|
||||
raise ValueError("No user prompt specified")
|
||||
|
||||
# Load the correct prompt based on shell and OS and append the user's prompt
|
||||
prompt = prompt_manager.get_full_prompt(query, config.get("shell", "bash"))
|
||||
|
||||
# Extract the system prompt from the first line
|
||||
system_prompt = prompt.split('\n')[0] if '\n' in prompt else prompt
|
||||
|
||||
try:
|
||||
# Use the streaming API pattern
|
||||
response = client.chat.completions.create(
|
||||
model=config.get("model", "gpt-4o-mini"),
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
temperature=config.get("temperature", 0),
|
||||
max_tokens=config.get("max_tokens", 500),
|
||||
stream=True # Enable streaming
|
||||
)
|
||||
|
||||
# Initialize an empty string to collect the full response
|
||||
full_response = ""
|
||||
|
||||
# Process the streaming response
|
||||
for chunk in response:
|
||||
if hasattr(chunk.choices[0].delta, 'content'):
|
||||
content = chunk.choices[0].delta.content
|
||||
if content:
|
||||
# Call the callback with the new content
|
||||
callback(content)
|
||||
# Append to the full response
|
||||
full_response += content
|
||||
|
||||
return full_response.strip()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling OpenAI API: {str(e)}")
|
||||
raise
|
||||
|
||||
def generate_command_streaming(client, config, query, callback, max_retries=3):
|
||||
"""
|
||||
Generate a command using the OpenAI API with streaming and retry logic.
|
||||
|
||||
Args:
|
||||
client (OpenAI): The OpenAI client instance.
|
||||
config (dict): Configuration dictionary containing model and parameters.
|
||||
query (str): The user's query string.
|
||||
callback (callable): Function to call with each token as it arrives.
|
||||
max_retries (int): Maximum number of retry attempts.
|
||||
|
||||
Returns:
|
||||
str: The generated command as a string.
|
||||
"""
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return call_api_streaming(client, config, query, callback)
|
||||
except Exception as e:
|
||||
if "rate limit" in str(e).lower() and attempt < max_retries - 1:
|
||||
wait_time = 2 ** attempt # Exponential backoff
|
||||
logger.warning(f"Rate limited. Retrying in {wait_time} seconds...")
|
||||
time.sleep(wait_time)
|
||||
else:
|
||||
logger.error(f"Error after {attempt+1} attempts: {str(e)}")
|
||||
raise
|
||||
|
||||
def generate_command(client, config, query, max_retries=3):
|
||||
"""
|
||||
Generate a command using the OpenAI API with retry logic.
|
||||
|
||||
Args:
|
||||
client (OpenAI): The OpenAI client instance.
|
||||
config (dict): Configuration dictionary containing model and parameters.
|
||||
query (str): The user's query string.
|
||||
max_retries (int): Maximum number of retry attempts.
|
||||
|
||||
Returns:
|
||||
str: The generated command as a string.
|
||||
"""
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return call_api(client, config, query)
|
||||
except Exception as e:
|
||||
if "rate limit" in str(e).lower() and attempt < max_retries - 1:
|
||||
wait_time = 2 ** attempt # Exponential backoff
|
||||
logger.warning(f"Rate limited. Retrying in {wait_time} seconds...")
|
||||
time.sleep(wait_time)
|
||||
else:
|
||||
logger.error(f"Error after {attempt+1} attempts: {str(e)}")
|
||||
raise
|
||||
@@ -0,0 +1,39 @@
|
||||
"""
|
||||
Command execution for the Edison application.
|
||||
"""
|
||||
import subprocess
|
||||
import logging
|
||||
from edison.utils import validation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def execute_command(shell, command):
|
||||
"""
|
||||
Execute a shell command.
|
||||
|
||||
Args:
|
||||
shell (str): The shell to use.
|
||||
command (str): The command to execute.
|
||||
|
||||
Returns:
|
||||
subprocess.CompletedProcess: The result of the command execution.
|
||||
|
||||
Raises:
|
||||
subprocess.CalledProcessError: If the command execution fails.
|
||||
"""
|
||||
if validation.is_dangerous_command(command):
|
||||
logger.warning(f"Potentially dangerous command detected: {command}")
|
||||
# We still allow execution but log a warning
|
||||
|
||||
try:
|
||||
if shell == "powershell.exe":
|
||||
result = subprocess.run([shell, "/c", command], shell=False, check=True)
|
||||
else:
|
||||
# Unix: /bin/bash /bin/zsh: uses -c both Ubuntu and macOS should work, others might not
|
||||
result = subprocess.run([shell, "-c", command], shell=False, check=True)
|
||||
|
||||
logger.debug(f"Command executed successfully: {command}")
|
||||
return result
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"Command execution failed: {e}")
|
||||
raise
|
||||
@@ -0,0 +1,54 @@
|
||||
"""
|
||||
Prompt management for the Edison application.
|
||||
"""
|
||||
import os
|
||||
import logging
|
||||
from edison.utils import os_utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_prompt_template_path():
|
||||
"""
|
||||
Get the path to the prompt template file.
|
||||
|
||||
Returns:
|
||||
str: The path to the prompt template file.
|
||||
"""
|
||||
script_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
return os.path.join(script_dir, "edison.prompt")
|
||||
|
||||
def get_full_prompt(user_prompt, shell):
|
||||
"""
|
||||
Construct a full prompt from the template and user input.
|
||||
|
||||
Args:
|
||||
user_prompt (str): The user's prompt.
|
||||
shell (str): The shell to use.
|
||||
|
||||
Returns:
|
||||
str: The full prompt.
|
||||
"""
|
||||
# Get the path to the prompt template
|
||||
prompt_file = get_prompt_template_path()
|
||||
|
||||
try:
|
||||
# Load the prompt template
|
||||
with open(prompt_file, "r") as f:
|
||||
pre_prompt = f.read()
|
||||
|
||||
# Replace placeholders
|
||||
pre_prompt = pre_prompt.replace("{shell}", shell)
|
||||
pre_prompt = pre_prompt.replace("{os}", os_utils.get_os_friendly_name())
|
||||
|
||||
# Append the user prompt
|
||||
prompt = pre_prompt + user_prompt
|
||||
|
||||
# Make it a question if it's not already
|
||||
if prompt[-1:] != "?" and prompt[-1:] != ".":
|
||||
prompt += "?"
|
||||
|
||||
return prompt
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading prompt template: {e}")
|
||||
# Fallback to a simple prompt
|
||||
return f"Act as a natural language to {shell} command translation engine on {os_utils.get_os_friendly_name()}. {user_prompt}"
|
||||
Reference in New Issue
Block a user