Files
yolo-ai-cmdbot/ai_model.py

171 lines
5.2 KiB
Python

from abc import ABC, abstractmethod
from openai import OpenAI
from groq import Groq
from ollama import Client
from openai import AzureOpenAI
from anthropic import Anthropic
import os
class AIModel(ABC):
    """Abstract base for the chat back-ends, plus a factory that selects one.

    Subclasses wrap a provider SDK client and implement ``chat`` and
    ``moderate``. ``get_model_client`` builds the subclass named by the
    configuration.
    """

    @abstractmethod
    def chat(self, model, messages):
        """Send ``messages`` to ``model``; implemented by subclasses.

        NOTE(review): the concrete subclasses in this file declare
        ``chat(self, messages, model, temperature, max_tokens)``, which does
        not match this abstract signature -- confirm which one is intended.
        """
        pass

    @abstractmethod
    def moderate(self, message):
        """Run a moderation check on ``message``; implemented by subclasses."""
        pass

    @staticmethod
    def _read_key_file(filename):
        """Return the stripped first line of ``~/<filename>``.

        Raises:
            OSError: if the file cannot be opened (e.g. it does not exist).
        """
        key_path = os.path.join(os.path.expanduser("~"), filename)
        # Context manager so the handle is closed even if reading fails.
        with open(key_path, "r") as key_file:
            return key_file.readline().strip()

    @staticmethod
    def _resolve_api_key(config, env_var, config_key, key_file=None):
        """Look up an API key: environment first, then config, then key file.

        Args:
            config: Mapping holding the application settings.
            env_var: Name of the environment variable to check first.
            config_key: Key in ``config`` to check second; it may be absent.
            key_file: Optional file name under the home directory that is
                read only when the first two sources yield nothing.

        Returns:
            The first non-empty key found, or a falsy value if none was found.
        """
        # .get() rather than [] so a missing config entry falls through to the
        # key file instead of raising KeyError.
        api_key = os.getenv(env_var) or config.get(config_key)
        if not api_key and key_file:
            api_key = AIModel._read_key_file(key_file)
        return api_key

    @staticmethod
    def get_model_client(config):
        """Build the model client selected by ``config["api"]``.

        Args:
            config: Mapping of settings. ``"api"`` names the provider
                (``"groq"``, ``"openai"``, ``"azure"``, ``"ollama"`` or
                ``"anthropic"``); a missing, empty or ``None`` value selects
                ``"groq"``. The Azure provider additionally requires
                ``"azure_endpoint"`` and ``"azure_api_version"``.

        Returns:
            An instance of the matching ``AIModel`` subclass.

        Raises:
            ValueError: if the provider name is not recognised.
            KeyError: if a required Azure setting is missing.
            OSError: if a key file has to be read and cannot be opened.
        """
        api_provider = config.get("api")
        if api_provider is None or api_provider == "":
            api_provider = "groq"

        if api_provider == "groq":
            return GroqModel(api_key=os.environ.get("GROQ_API_KEY"))

        if api_provider == "openai":
            api_key = AIModel._resolve_api_key(
                config, "OPENAI_API_KEY", "openai_api_key", ".openai.apikey"
            )
            return OpenAIModel(api_key=api_key)

        if api_provider == "azure":
            api_key = AIModel._resolve_api_key(
                config,
                "AZURE_OPENAI_API_KEY",
                "azure_openai_api_key",
                ".azureopenai.apikey",
            )
            return AzureOpenAIModel(
                api_key=api_key,
                azure_endpoint=config["azure_endpoint"],
                api_version=config["azure_api_version"],
            )

        if api_provider == "ollama":
            ollama_api = os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434")
            return OllamaModel(ollama_api)

        if api_provider == "anthropic":
            api_key = AIModel._resolve_api_key(
                config, "ANTHROPIC_API_KEY", "anthropic_api_key"
            )
            return AnthropicModel(api_key=api_key)

        raise ValueError(f"Invalid AI model provider: {api_provider}")
class GroqModel(AIModel):
    """Chat back-end that wraps a ``groq.Groq`` client."""

    def __init__(self, api_key):
        """Create the Groq client with the given ``api_key``."""
        self.client = Groq(api_key=api_key)

    def chat(self, messages, model, temperature, max_tokens):
        """Request a chat completion and return the first choice's content."""
        request = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        completion = self.client.chat.completions.create(**request)
        first_choice = completion.choices[0]
        return first_choice.message.content

    def moderate(self, message):
        """No moderation is performed for this back-end; returns ``None``."""
        return None
class OpenAIModel(AIModel):
    """Chat back-end that wraps an ``openai.OpenAI`` client."""

    def __init__(self, api_key):
        """Create the OpenAI client with the given ``api_key``."""
        self.client = OpenAI(api_key=api_key)

    def chat(self, messages, model, temperature, max_tokens):
        """Request a chat completion and return the first choice's content."""
        request = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        completion = self.client.chat.completions.create(**request)
        first_choice = completion.choices[0]
        return first_choice.message.content

    def moderate(self, message):
        """Submit ``message`` to the client's moderation call; return its result."""
        moderation = self.client.moderations.create(input=message)
        return moderation
class OllamaModel(AIModel):
    """Chat back-end that talks to an Ollama server via ``ollama.Client``."""

    def __init__(self, host):
        """Create the Ollama client pointed at ``host``."""
        self.client = Client(host=host)

    def chat(self, messages, model, temperature, max_tokens):
        """Send ``messages`` to ``model`` and return the reply text.

        ``temperature`` and ``max_tokens`` are forwarded through the Ollama
        ``options`` mapping (``max_tokens`` maps to ``num_predict``) so this
        back-end honours the same sampling settings as the others. A value of
        ``None`` is omitted, leaving the server default in effect.
        """
        options = {}
        if temperature is not None:
            options["temperature"] = temperature
        if max_tokens is not None:
            options["num_predict"] = max_tokens
        resp = self.client.chat(
            model=model,
            messages=messages,
            # An empty mapping carries no information; pass None instead.
            options=options or None,
        )
        return resp["message"]["content"]

    def moderate(self, message):
        """No moderation is performed for this back-end; returns ``None``."""
        return None
class AzureOpenAIModel(AIModel):
    """Chat back-end that wraps an ``openai.AzureOpenAI`` client."""

    def __init__(self, azure_endpoint, api_key, api_version):
        """Create the Azure OpenAI client from endpoint, key and API version."""
        client_settings = {
            "azure_endpoint": azure_endpoint,
            "api_key": api_key,
            "api_version": api_version,
        }
        self.client = AzureOpenAI(**client_settings)

    def chat(self, messages, model, temperature, max_tokens):
        """Request a chat completion and return the first choice's content."""
        request = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        completion = self.client.chat.completions.create(**request)
        first_choice = completion.choices[0]
        return first_choice.message.content

    def moderate(self, message):
        """Submit ``message`` to the client's moderation call; return its result."""
        moderation = self.client.moderations.create(input=message)
        return moderation
class AnthropicModel(AIModel):
    """Chat back-end that wraps an ``anthropic.Anthropic`` client."""

    def __init__(self, api_key):
        """Create the Anthropic client with the given ``api_key``."""
        self.client = Anthropic(api_key=api_key)

    def chat(self, messages, model, temperature, max_tokens):
        """Send the conversation and return the text of the first content item.

        The system prompt is passed to the client separately from the other
        messages: the content of the first ``"system"`` entry becomes the
        ``system`` argument (empty string if there is none), and every
        ``"system"`` entry is left out of the message list that is sent.
        """
        # First pass: stop at the first system entry and keep its content.
        system_prompt = ""
        for entry in messages:
            if entry.get("role") == "system":
                system_prompt = entry.get("content", "")
                break

        # Second pass: everything that is not a system entry, in order.
        conversation = []
        for entry in messages:
            if entry.get("role") != "system":
                conversation.append(entry)

        request = {
            "model": model,
            "system": system_prompt,
            "messages": conversation,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        reply = self.client.messages.create(**request)
        return reply.content[0].text

    def moderate(self, message):
        """No moderation is performed for this back-end; returns ``None``."""
        return None