Browse Source

chore: Add support for Together.ai and Groq clients

Radu Boncea 8 months ago
parent
commit
7887d98fd0
2 changed files with 59 additions and 2 deletions
  1. +57 -2
      llmeval/commons/management/commands/eval_qa.py
  2. +2 -0
      llmeval/commons/models.py

+ 57 - 2
llmeval/commons/management/commands/eval_qa.py

@@ -5,8 +5,6 @@ import sys
 import time
 import ollama
 import dotenv
-from openai import OpenAI
-import google.generativeai as genai
 from google.generativeai.types import HarmCategory, HarmBlockThreshold
 import anthropic
 from tqdm import tqdm
@@ -69,6 +67,11 @@ class LLMClient(ABC):
         model_parameters = self.session.llm_model.parameters
+
+        # Guard against a null parameters field before merging session overrides
+        if not model_parameters:
+            model_parameters = {}
+
         if self.session.parameters:
             model_parameters.update(self.session.parameters)
         return model_parameters
 
     @abstractmethod
@@ -113,6 +115,9 @@ class LLMClient(ABC):
 class GoogleGenAI(LLMClient):
     def __init__(self, model, session, **kwargs):
         super().__init__(model, session, **kwargs)
+
+        import google.generativeai as genai
+
         if 'api_key' not in kwargs:
             raise CommandError('Google Gen AI API key not found')
         genai.configure(api_key=kwargs['api_key'])
@@ -155,6 +160,8 @@ class GoogleGenAI(LLMClient):
 class OpenAIClient(LLMClient):
     def __init__(self, model, session, **kwargs):
         super().__init__(model, session, **kwargs)
+
+        from openai import OpenAI
         self.client = OpenAI(**kwargs)
 
     def send_question(self, question, options=[], context=None):
@@ -170,7 +177,47 @@ class OpenAIClient(LLMClient):
         )
         msg_content = completion.choices[0].message.content
         return msg_content
+
+class TogetherAIClient(LLMClient):
+    def __init__(self, model, session, **kwargs):
+        super().__init__(model, session, **kwargs)
+        from together import Together
+        self.client = Together(**kwargs)
+
+    def send_question(self, question, options=[], context=None):
+        messages = self.make_messages(question, options, context)
+        self.stats['instruction'] = self.messages_2_instruction(messages)
+
+        model_parameters = self.get_chat_params()
+
+        completion = self.client.chat.completions.create(
+            model=self.model,
+            messages=messages,
+            **model_parameters
+        )
+        msg_content = completion.choices[0].message.content
+        return msg_content
     
+class GroqClient(LLMClient):
+    def __init__(self, model, session, **kwargs):
+        super().__init__(model, session, **kwargs)
+        from groq import Groq
+        self.client = Groq(**kwargs)
+
+    def send_question(self, question, options=[], context=None):
+        messages = self.make_messages(question, options, context)
+        self.stats['instruction'] = self.messages_2_instruction(messages)
+
+        model_parameters = self.get_chat_params()
+
+        completion = self.client.chat.completions.create(
+            model=self.model,
+            messages=messages,
+            **model_parameters
+        )
+        msg_content = completion.choices[0].message.content
+        return msg_content
+
 class AnthropicClient(LLMClient):
     def __init__(self, model, session, **kwargs):
         super().__init__(model, session, **kwargs)
@@ -236,6 +283,14 @@ def get_client(llm_model, session):
         if not llm_backend.parameters:
             raise CommandError('Anthropic parameters not found')
         return AnthropicClient(llm_model.name, session, **llm_backend.parameters)
+    elif llm_backend.client_type == 'together':
+        if not llm_backend.parameters:
+            raise CommandError('Together.ai parameters not found')
+        return TogetherAIClient(llm_model.name, session, **llm_backend.parameters)
+    elif llm_backend.client_type == 'groq':
+        if not llm_backend.parameters:
+            raise CommandError('Groq parameters not found')
+        return GroqClient(llm_model.name, session, **llm_backend.parameters)
 
 class Command(BaseCommand):
     help = 'Evaluate MedQA'
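
Both new clients mirror OpenAIClient: build the message list, record the
rendered instruction in the stats, and forward the session's chat parameters
to an OpenAI-compatible chat.completions.create call. A minimal standalone
sketch of the two SDK calls being wrapped (the model names are illustrative,
and the API keys are assumed to be set in environment variables):

    import os
    from together import Together
    from groq import Groq

    messages = [{"role": "user", "content": "What is the capital of France?"}]

    # Together.ai exposes an OpenAI-compatible chat completions API
    together = Together(api_key=os.environ["TOGETHER_API_KEY"])
    reply = together.chat.completions.create(
        model="meta-llama/Llama-3-8b-chat-hf",  # illustrative model name
        messages=messages,
    )
    print(reply.choices[0].message.content)

    # Groq uses the same call shape
    groq = Groq(api_key=os.environ["GROQ_API_KEY"])
    reply = groq.chat.completions.create(
        model="llama3-8b-8192",  # illustrative model name
        messages=messages,
    )
    print(reply.choices[0].message.content)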

+ 2 - 0
llmeval/commons/models.py

@@ -24,6 +24,8 @@ CLIENT_CHOICES = [
     ('ollama', 'Ollama'),
     ('genai', 'Google GenAI'),
     ('anthropic', 'Anthropic'),
+    ('together', 'Together.ai'),
+    ('groq', 'Groq'),
 ]
 
 class Dataset(models.Model):
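
Since get_client unpacks llm_backend.parameters directly into the SDK
constructors (Together(**kwargs), Groq(**kwargs)), the parameters JSON for
a backend using one of the new client types must hold keyword arguments the
SDK constructor accepts. A sketch of the minimal payload, with placeholder
values standing in for real keys (both SDKs accept an api_key keyword):

    # Minimal backend parameters for the new client types; the
    # placeholder values stand in for real API keys.
    together_parameters = {"api_key": "<TOGETHER_API_KEY>"}
    groq_parameters = {"api_key": "<GROQ_API_KEY>"}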