Updated to version 0.5
This commit is contained in:
+149
@@ -0,0 +1,149 @@
|
||||
# MIT License
|
||||
# Copyright (c) 2023-2024 wunderwuzzi23
|
||||
# Greetings from Seattle!
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from openai import OpenAI
|
||||
from groq import Groq
|
||||
from ollama import Client
|
||||
from openai import AzureOpenAI
|
||||
from anthropic import Anthropic
|
||||
import os
|
||||
|
||||
class AIModel(ABC):
|
||||
@abstractmethod
|
||||
def chat(self, model, messages):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def moderate(self, message):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def get_model_client(config):
|
||||
api_provider=config["api"]
|
||||
|
||||
if api_provider == "" or api_provider==None:
|
||||
api_provider = "groq"
|
||||
|
||||
if api_provider == "groq":
|
||||
return GroqModel(api_key=os.environ.get("GROQ_API_KEY"))
|
||||
|
||||
elif api_provider == "openai":
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
api_key=config["openai_api_key"]
|
||||
if not api_key: #If statement to avoid "invalid filepath" error
|
||||
home_path = os.path.expanduser("~")
|
||||
api_key=open(os.path.join(home_path,".openai.apikey"), "r").readline().strip()
|
||||
api_key = api_key
|
||||
|
||||
return OpenAIModel(api_key=api_key)
|
||||
|
||||
elif api_provider == "azure":
|
||||
api_key = os.getenv("AZURE_OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
api_key=config["azure_openai_api_key"]
|
||||
if not api_key:
|
||||
home_path = os.path.expanduser("~")
|
||||
api_key=open(os.path.join(home_path,".azureopenai.apikey"), "r").readline().strip()
|
||||
|
||||
return AzureOpenAIModel(
|
||||
api_key=api_key,
|
||||
azure_endpoint=config["azure_endpoint"],
|
||||
api_version=config["azure_api_version"])
|
||||
|
||||
elif api_provider == "ollama":
|
||||
ollama_api = os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434")
|
||||
#ollama_model = os.environ.get("OLLAMA_MODEL", "llama3-8b-8192")
|
||||
return OllamaModel(ollama_api)
|
||||
|
||||
if api_provider == "anthropic":
|
||||
api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
if not api_key:
|
||||
api_key=config["anthropic_api_key"]
|
||||
return AnthropicModel(api_key=api_key)
|
||||
else:
|
||||
raise ValueError(f"Invalid AI model provider: {api_provider}")
|
||||
|
||||
class GroqModel(AIModel):
|
||||
def __init__(self, api_key):
|
||||
self.client = Groq(api_key=api_key)
|
||||
|
||||
def chat(self, messages, model, temperature, max_tokens):
|
||||
resp = self.client.chat.completions.create(model=model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens)
|
||||
return resp.choices[0].message.content
|
||||
|
||||
def moderate(self, message):
|
||||
pass
|
||||
|
||||
class OpenAIModel(AIModel):
|
||||
def __init__(self, api_key):
|
||||
self.client = OpenAI(api_key=api_key)
|
||||
|
||||
def chat(self, messages, model, temperature, max_tokens):
|
||||
resp = self.client.chat.completions.create(model=model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens)
|
||||
|
||||
return resp.choices[0].message.content
|
||||
|
||||
def moderate(self, message):
|
||||
return self.client.moderations.create(input=message)
|
||||
|
||||
class OllamaModel(AIModel):
|
||||
def __init__(self, host):
|
||||
self.client = Client(host=host)
|
||||
|
||||
def chat(self, messages, model, temperature, max_tokens):
|
||||
resp = self.client.chat(model=model,
|
||||
messages=messages)
|
||||
return resp["message"]["content"]
|
||||
|
||||
def moderate(self, message):
|
||||
pass
|
||||
|
||||
|
||||
class AzureOpenAIModel(AIModel):
|
||||
def __init__(self, azure_endpoint, api_key, api_version):
|
||||
self.client = AzureOpenAI(azure_endpoint=azure_endpoint, api_key=api_key, api_version=api_version)
|
||||
|
||||
def chat(self, messages, model, temperature, max_tokens):
|
||||
|
||||
resp = self.client.chat.completions.create(model=model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens)
|
||||
|
||||
return resp.choices[0].message.content
|
||||
|
||||
def moderate(self, message):
|
||||
return self.client.moderations.create(input=message)
|
||||
|
||||
class AnthropicModel(AIModel):
|
||||
def __init__(self, api_key):
|
||||
self.client = Anthropic(api_key=api_key)
|
||||
|
||||
def chat(self, messages, model, temperature, max_tokens):
|
||||
## Anthropic requires the system prompt to be passed separately
|
||||
## Hence extracting system prompt role from the messages
|
||||
## and then passing the messages without the system role
|
||||
## messages is not subscriptable, so we need to convert it to a list
|
||||
system_prompt = next((m.get("content", "") for m in messages if m.get("role") == "system"), "")
|
||||
|
||||
# Remove system messages from the list
|
||||
user_messages = [m for m in messages if m.get("role") != "system"]
|
||||
resp = self.client.messages.create(model=model,
|
||||
system=system_prompt,
|
||||
messages=user_messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens)
|
||||
|
||||
return resp.content[0].text
|
||||
|
||||
def moderate(self, message):
|
||||
pass
|
||||
Reference in New Issue
Block a user