|
@@ -5,8 +5,6 @@ import sys
|
|
|
import time
|
|
|
import ollama
|
|
|
import dotenv
|
|
|
-from openai import OpenAI
|
|
|
-import google.generativeai as genai
|
|
|
from google.generativeai.types import HarmCategory, HarmBlockThreshold
|
|
|
import anthropic
|
|
|
from tqdm import tqdm
|
|
@@ -69,6 +67,10 @@ class LLMClient(ABC):
|
|
|
model_parameters = self.session.llm_model.parameters
|
|
|
if self.session.parameters:
|
|
|
model_parameters.update(self.session.parameters)
|
|
|
+
|
|
|
+ if not model_parameters:
|
|
|
+ model_parameters = {}
|
|
|
+
|
|
|
return model_parameters
|
|
|
|
|
|
@abstractmethod
|
|
@@ -113,6 +115,9 @@ class LLMClient(ABC):
|
|
|
class GoogleGenAI(LLMClient):
|
|
|
def __init__(self, model, session, **kwargs):
|
|
|
super().__init__(model, session, **kwargs)
|
|
|
+
|
|
|
+ import google.generativeai as genai
|
|
|
+
|
|
|
if 'api_key' not in kwargs:
|
|
|
raise CommandError('Google Gen AI API key not found')
|
|
|
genai.configure(api_key=kwargs['api_key'])
|
|
@@ -155,6 +160,8 @@ class GoogleGenAI(LLMClient):
|
|
|
class OpenAIClient(LLMClient):
|
|
|
def __init__(self, model, session, **kwargs):
|
|
|
super().__init__(model, session, **kwargs)
|
|
|
+
|
|
|
+ from openai import OpenAI
|
|
|
self.client = OpenAI(**kwargs)
|
|
|
|
|
|
def send_question(self, question, options=[], context=None):
|
|
@@ -170,7 +177,47 @@ class OpenAIClient(LLMClient):
|
|
|
)
|
|
|
msg_content = completion.choices[0].message.content
|
|
|
return msg_content
|
|
|
+
|
|
|
+class TogheterAIClient(LLMClient):
|
|
|
+ def __init__(self, model, session, **kwargs):
|
|
|
+ super().__init__(model, session, **kwargs)
|
|
|
+ from together import Together
|
|
|
+ self.client = Together(**kwargs)
|
|
|
+
|
|
|
+ def send_question(self, question, options=[], context=None):
|
|
|
+ messages = self.make_messages(question, options, context)
|
|
|
+ self.stats['instruction'] = self.messages_2_instruction(messages)
|
|
|
+
|
|
|
+ model_parameters = self.get_chat_params()
|
|
|
+
|
|
|
+ completion = self.client.chat.completions.create(
|
|
|
+ model=self.model,
|
|
|
+ messages=messages,
|
|
|
+ **model_parameters
|
|
|
+ )
|
|
|
+ msg_content = completion.choices[0].message.content
|
|
|
+ return msg_content
|
|
|
|
|
|
+class GroqClient(LLMClient):
|
|
|
+ def __init__(self, model, session, **kwargs):
|
|
|
+ super().__init__(model, session, **kwargs)
|
|
|
+ from groq import Groq
|
|
|
+ self.client = Groq(**kwargs)
|
|
|
+
|
|
|
+ def send_question(self, question, options=[], context=None):
|
|
|
+ messages = self.make_messages(question, options, context)
|
|
|
+ self.stats['instruction'] = self.messages_2_instruction(messages)
|
|
|
+
|
|
|
+ model_parameters = self.get_chat_params()
|
|
|
+
|
|
|
+ completion = self.client.chat.completions.create(
|
|
|
+ model=self.model,
|
|
|
+ messages=messages,
|
|
|
+ **model_parameters
|
|
|
+ )
|
|
|
+ msg_content = completion.choices[0].message.content
|
|
|
+ return msg_content
|
|
|
+
|
|
|
class AnthropicClient(LLMClient):
|
|
|
def __init__(self, model, session, **kwargs):
|
|
|
super().__init__(model, session, **kwargs)
|
|
@@ -236,6 +283,14 @@ def get_client(llm_model, session):
|
|
|
if not llm_backend.parameters:
|
|
|
raise CommandError('Anthropic parameters not found')
|
|
|
return AnthropicClient(llm_model.name, session, **llm_backend.parameters)
|
|
|
+ elif llm_backend.client_type == 'togheter':
|
|
|
+ if not llm_backend.parameters:
|
|
|
+ raise CommandError('Togheter.ai parameters not found')
|
|
|
+ return TogheterAIClient(llm_model.name, session, **llm_backend.parameters)
|
|
|
+ elif llm_backend.client_type == 'groq':
|
|
|
+ if not llm_backend.parameters:
|
|
|
+ raise CommandError('Groq parameters not found')
|
|
|
+ return GroqClient(llm_model.name, session, **llm_backend.parameters)
|
|
|
|
|
|
class Command(BaseCommand):
|
|
|
help = 'Evaluate MedQA'
|