from abc import ABC, abstractmethod
import os

from anthropic import Anthropic
from groq import Groq
from ollama import Client
from openai import AzureOpenAI, OpenAI


def _config_value(config, key, default=None):
    """Return ``config[key]``, or ``default`` when the key is absent.

    Uses subscript access (as the rest of this module does) so any mapping-like
    config object works, not only ones that implement ``.get``.
    """
    try:
        return config[key]
    except KeyError:
        return default


def _resolve_api_key(env_var, config, config_key, key_filename=None):
    """Look up an API key from the environment, the config, then a dotfile.

    Sources are tried in order, and an unset/empty value falls through to the
    next one:

    1. environment variable ``env_var``;
    2. ``config[config_key]`` (a missing key counts as empty);
    3. if ``key_filename`` is given, the first line (stripped) of
       ``~/<key_filename>``.

    Returns:
        The key, or whatever falsy value the last consulted source produced
        when ``key_filename`` is None.

    Raises:
        FileNotFoundError: if the dotfile is consulted but does not exist.
    """
    api_key = os.getenv(env_var)
    if not api_key:
        api_key = _config_value(config, config_key)
    if not api_key and key_filename is not None:
        key_path = os.path.join(os.path.expanduser("~"), key_filename)
        # Context manager so the file handle is closed even on error.
        with open(key_path, "r", encoding="utf-8") as key_file:
            api_key = key_file.readline().strip()
    return api_key


class AIModel(ABC):
    """Common interface for the chat/moderation providers in this module."""

    @abstractmethod
    def chat(self, messages, model, temperature, max_tokens):
        """Send ``messages`` to ``model`` and return the reply text."""

    @abstractmethod
    def moderate(self, message):
        """Run provider moderation on ``message`` (None if unsupported)."""

    @staticmethod
    def get_model_client(config):
        """Build the provider client selected by ``config["api"]``.

        ``config["api"]`` may be "groq", "openai", "azure", "ollama" or
        "anthropic". A missing, None or empty value selects "groq".

        Raises:
            ValueError: for any other provider name.
            KeyError: for "azure" when ``azure_endpoint`` or
                ``azure_api_version`` is missing from ``config``.
            FileNotFoundError: for "openai"/"azure" when no key is found in
                the environment or config and the ``~/.*.apikey`` file is absent.
        """
        api_provider = _config_value(config, "api")
        if api_provider is None or api_provider == "":
            api_provider = "groq"

        if api_provider == "groq":
            return GroqModel(api_key=os.environ.get("GROQ_API_KEY"))
        if api_provider == "openai":
            api_key = _resolve_api_key(
                "OPENAI_API_KEY", config, "openai_api_key", ".openai.apikey"
            )
            return OpenAIModel(api_key=api_key)
        if api_provider == "azure":
            api_key = _resolve_api_key(
                "AZURE_OPENAI_API_KEY",
                config,
                "azure_openai_api_key",
                ".azureopenai.apikey",
            )
            return AzureOpenAIModel(
                api_key=api_key,
                azure_endpoint=config["azure_endpoint"],
                api_version=config["azure_api_version"],
            )
        if api_provider == "ollama":
            ollama_api = os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434")
            return OllamaModel(ollama_api)
        if api_provider == "anthropic":
            # No dotfile fallback for Anthropic (matches previous behavior).
            api_key = _resolve_api_key("ANTHROPIC_API_KEY", config, "anthropic_api_key")
            return AnthropicModel(api_key=api_key)
        raise ValueError(f"Invalid AI model provider: {api_provider}")


class _ChatCompletionsModel(AIModel):
    """Shared ``chat`` for clients exposing ``client.chat.completions.create``.

    Subclasses must set ``self.client`` and implement ``moderate``.
    """

    def chat(self, messages, model, temperature, max_tokens):
        """Return the content of the first choice of a chat completion."""
        resp = self.client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        return resp.choices[0].message.content


class GroqModel(_ChatCompletionsModel):
    """Groq-backed model; moderation is not implemented."""

    def __init__(self, api_key):
        self.client = Groq(api_key=api_key)

    def moderate(self, message):
        """Not supported for Groq; always returns None."""
        return None


class OpenAIModel(_ChatCompletionsModel):
    """OpenAI-backed model."""

    def __init__(self, api_key):
        self.client = OpenAI(api_key=api_key)

    def moderate(self, message):
        """Return the raw response of ``client.moderations.create``."""
        return self.client.moderations.create(input=message)


class OllamaModel(AIModel):
    """Model served by an Ollama host; moderation is not implemented."""

    def __init__(self, host):
        self.client = Client(host=host)

    def chat(self, messages, model, temperature, max_tokens):
        """Return the reply text from an Ollama chat call.

        ``temperature`` and ``max_tokens`` are forwarded as Ollama options
        (``temperature`` / ``num_predict``) when they are not None; previously
        they were silently ignored.
        """
        options = {}
        if temperature is not None:
            options["temperature"] = temperature
        if max_tokens is not None:
            options["num_predict"] = max_tokens
        resp = self.client.chat(model=model, messages=messages, options=options or None)
        return resp["message"]["content"]

    def moderate(self, message):
        """Not supported for Ollama; always returns None."""
        return None


class AzureOpenAIModel(_ChatCompletionsModel):
    """Azure OpenAI-backed model."""

    def __init__(self, azure_endpoint, api_key, api_version):
        self.client = AzureOpenAI(
            azure_endpoint=azure_endpoint, api_key=api_key, api_version=api_version
        )

    def moderate(self, message):
        """Return the raw response of ``client.moderations.create``."""
        return self.client.moderations.create(input=message)


class AnthropicModel(AIModel):
    """Anthropic-backed model; moderation is not implemented."""

    def __init__(self, api_key):
        self.client = Anthropic(api_key=api_key)

    def chat(self, messages, model, temperature, max_tokens):
        """Return the text of the first content block of the reply.

        Anthropic takes the system prompt as a separate ``system`` argument
        rather than as a message, so system-role entries are pulled out of
        ``messages`` (assumed to be dicts with "role"/"content" keys).
        NOTE: only the first system message is used, as before.
        """
        # Materialise once: the messages are scanned twice below.
        messages = list(messages)
        system_prompt = next(
            (m.get("content", "") for m in messages if m.get("role") == "system"), ""
        )
        chat_messages = [m for m in messages if m.get("role") != "system"]

        request = {
            "model": model,
            "messages": chat_messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        # Omit ``system`` entirely instead of sending an empty prompt.
        if system_prompt:
            request["system"] = system_prompt
        resp = self.client.messages.create(**request)
        return resp.content[0].text

    def moderate(self, message):
        """Not supported for Anthropic; always returns None."""
        return None