171 lines
5.2 KiB
Python
171 lines
5.2 KiB
Python
from abc import ABC, abstractmethod
|
|
from openai import OpenAI
|
|
from groq import Groq
|
|
from ollama import Client
|
|
from openai import AzureOpenAI
|
|
from anthropic import Anthropic
|
|
import os
|
|
|
|
|
|
class AIModel(ABC):
    """Common interface for chat backends.

    Concrete subclasses wrap a provider SDK client and implement ``chat`` and
    ``moderate``. :meth:`get_model_client` builds the backend selected by a
    configuration mapping.
    """

    @abstractmethod
    def chat(self, messages, model, temperature, max_tokens):
        """Send ``messages`` to ``model`` and return the reply text.

        Args:
            messages: Chat history as role/content mappings
                (e.g. ``{"role": "user", "content": "..."}``).
            model: Provider-specific model identifier.
            temperature: Sampling temperature forwarded to the provider.
            max_tokens: Upper bound on generated tokens.
        """

    @abstractmethod
    def moderate(self, message):
        """Run provider moderation on ``message``.

        Implementations without a moderation endpoint return ``None``.
        """

    @staticmethod
    def _config_value(config, key, default=None):
        """Return ``config[key]``, or ``default`` when ``key`` is absent."""
        try:
            return config[key]
        except KeyError:
            return default

    @staticmethod
    def _read_first_line(path):
        """Return the first line of the file at ``path``, whitespace-stripped."""
        # ``with`` guarantees the handle is closed (the original leaked it).
        with open(path, "r", encoding="utf-8") as fh:
            return fh.readline().strip()

    @staticmethod
    def get_model_client(config):
        """Build the ``AIModel`` backend named by ``config["api"]``.

        Supported providers: ``"groq"`` (the default when ``"api"`` is
        missing, ``None`` or ``""``), ``"openai"``, ``"azure"``, ``"ollama"``
        and ``"anthropic"``.

        API keys are resolved in this order: the provider's environment
        variable, then the matching ``config`` key, then (OpenAI/Azure only)
        the first line of ``~/.openai.apikey`` / ``~/.azureopenai.apikey``.

        Args:
            config: Mapping of configuration values.

        Returns:
            An instance of the matching ``AIModel`` subclass.

        Raises:
            ValueError: If the provider name is not recognised.
            KeyError: If ``"azure"`` is selected and ``azure_endpoint`` or
                ``azure_api_version`` is missing from ``config``.
            FileNotFoundError: If an OpenAI/Azure key is found neither in the
                environment nor in ``config``, and the fallback key file
                does not exist.
        """
        home_path = os.path.expanduser("~")

        api_provider = AIModel._config_value(config, "api")
        if api_provider is None or api_provider == "":
            api_provider = "groq"

        if api_provider == "groq":
            api_key = os.environ.get("GROQ_API_KEY") or AIModel._config_value(
                config, "groq_api_key"
            )
            return GroqModel(api_key=api_key)

        elif api_provider == "openai":
            api_key = os.getenv("OPENAI_API_KEY") or AIModel._config_value(
                config, "openai_api_key"
            )
            if not api_key:  # Only touch the key file when nothing else supplied a key.
                api_key = AIModel._read_first_line(
                    os.path.join(home_path, ".openai.apikey")
                )
            return OpenAIModel(api_key=api_key)

        elif api_provider == "azure":
            api_key = os.getenv("AZURE_OPENAI_API_KEY") or AIModel._config_value(
                config, "azure_openai_api_key"
            )
            if not api_key:
                api_key = AIModel._read_first_line(
                    os.path.join(home_path, ".azureopenai.apikey")
                )
            return AzureOpenAIModel(
                api_key=api_key,
                azure_endpoint=config["azure_endpoint"],
                api_version=config["azure_api_version"],
            )

        elif api_provider == "ollama":
            ollama_api = os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434")
            return OllamaModel(ollama_api)

        elif api_provider == "anthropic":
            api_key = os.getenv("ANTHROPIC_API_KEY") or AIModel._config_value(
                config, "anthropic_api_key"
            )
            return AnthropicModel(api_key=api_key)

        else:
            raise ValueError(f"Invalid AI model provider: {api_provider}")
|
|
|
|
|
|
class GroqModel(AIModel):
    """Chat backend that sends requests through a ``groq.Groq`` client."""

    def __init__(self, api_key):
        """Create the Groq client authenticated with ``api_key``."""
        self.client = Groq(api_key=api_key)

    def chat(self, messages, model, temperature, max_tokens):
        """Return the text of the first choice ``model`` produces for ``messages``."""
        completion = self.client.chat.completions.create(
            messages=messages,
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        first_choice = completion.choices[0]
        return first_choice.message.content

    def moderate(self, message):
        """Moderation is not wired up for Groq; always returns ``None``."""
        return None
|
|
|
|
|
|
class OpenAIModel(AIModel):
    """Chat and moderation backend using an ``openai.OpenAI`` client."""

    def __init__(self, api_key):
        """Create the OpenAI client authenticated with ``api_key``."""
        self.client = OpenAI(api_key=api_key)

    def chat(self, messages, model, temperature, max_tokens):
        """Return the text of the first choice ``model`` produces for ``messages``."""
        completion = self.client.chat.completions.create(
            messages=messages,
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        first_choice = completion.choices[0]
        return first_choice.message.content

    def moderate(self, message):
        """Return the raw response of the OpenAI moderations endpoint for ``message``."""
        moderations = self.client.moderations
        return moderations.create(input=message)
|
|
|
|
|
|
class OllamaModel(AIModel):
    """Chat backend for an Ollama server reached through ``ollama.Client``."""

    def __init__(self, host):
        """Create a client for the Ollama server at ``host``."""
        self.client = Client(host=host)

    def chat(self, messages, model, temperature, max_tokens):
        """Return the reply text ``model`` produces for ``messages``.

        ``temperature`` and ``max_tokens`` are forwarded as Ollama runtime
        options (``temperature`` and ``num_predict``). Previously both were
        silently ignored. ``None`` values are omitted so the model's own
        defaults apply.
        """
        options = {}
        if temperature is not None:
            options["temperature"] = temperature
        if max_tokens is not None:
            options["num_predict"] = max_tokens
        resp = self.client.chat(
            model=model,
            messages=messages,
            options=options or None,
        )
        return resp["message"]["content"]

    def moderate(self, message):
        """Moderation is not implemented for Ollama; always returns ``None``."""
        return None
|
|
|
|
|
|
class AzureOpenAIModel(AIModel):
    """Chat and moderation backend using an ``openai.AzureOpenAI`` client."""

    def __init__(self, azure_endpoint, api_key, api_version):
        """Create the Azure OpenAI client for ``azure_endpoint`` at ``api_version``."""
        client_kwargs = {
            "azure_endpoint": azure_endpoint,
            "api_key": api_key,
            "api_version": api_version,
        }
        self.client = AzureOpenAI(**client_kwargs)

    def chat(self, messages, model, temperature, max_tokens):
        """Return the text of the first choice ``model`` produces for ``messages``.

        ``model`` is passed through unchanged as the ``model`` argument.
        """
        completion = self.client.chat.completions.create(
            messages=messages,
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        first_choice = completion.choices[0]
        return first_choice.message.content

    def moderate(self, message):
        """Return the raw response of the moderations endpoint for ``message``."""
        moderations = self.client.moderations
        return moderations.create(input=message)
|
|
|
|
|
|
class AnthropicModel(AIModel):
    """Chat backend for Anthropic's Messages API via ``anthropic.Anthropic``."""

    def __init__(self, api_key):
        """Create the Anthropic client authenticated with ``api_key``."""
        self.client = Anthropic(api_key=api_key)

    def chat(self, messages, model, temperature, max_tokens):
        """Return the reply text ``model`` produces for ``messages``.

        The Messages API takes the system prompt as a separate ``system``
        argument, so ``"system"``-role entries are removed from ``messages``.
        The string contents of all such entries are joined with blank lines
        and sent as ``system``. The argument is omitted when there is no
        system text.

        Returns:
            The concatenated text of every text block in the response, or
            ``""`` when the response contains none.
        """
        # Materialise once so a one-shot iterator is not consumed by the first scan.
        messages = list(messages)

        system_parts = [
            m.get("content", "") for m in messages if m.get("role") == "system"
        ]
        conversation = [m for m in messages if m.get("role") != "system"]

        request = {
            "model": model,
            "messages": conversation,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        system_prompt = "\n\n".join(
            part for part in system_parts if isinstance(part, str) and part
        )
        if system_prompt:
            request["system"] = system_prompt

        resp = self.client.messages.create(**request)
        # Avoid IndexError/AttributeError when the response is empty or its
        # first block is not a text block.
        return "".join(block.text for block in resp.content if block.type == "text")

    def moderate(self, message):
        """Moderation is not implemented for Anthropic; always returns ``None``."""
        return None
|