import json import re import os import sys import time import ollama import dotenv from google.generativeai.types import HarmCategory, HarmBlockThreshold import anthropic from tqdm import tqdm from abc import ABC, abstractmethod from django.core.management.base import ( BaseCommand, CommandError, CommandParser ) from commons.models import ( QA, Dataset, EvalAnswer, EvalConfig, RoleMessage, EvalSession ) dotenv.load_dotenv() class LLMClient(ABC): def __init__(self, model, session=None, **kwargs): self.messages = [] self.model = model self.session = session self.final_answer_pattern = None self.stats = { "instruction": "", "answer": "" } if session and isinstance(session, EvalSession): self.final_answer_pattern = re.compile(session.config.final_answer_pattern) self.messages.append({ "role": "system", "content": self.session.config.sys_prompt }) for role_message in RoleMessage.objects.filter(eval_config=session.config): self.messages.append({ "role": role_message.role, "content": role_message.content }) def make_messages(self, question, options, context): messages = self.messages.copy() option_str = self.option_str(options) content = f"Question: {question}" if option_str: content = f"{content}\nChoices:\n {option_str}" if context: content = f"{content}\nContext:\n{context}" messages.append( {"role": "user", "content": content }) return messages def get_chat_params(self): model_parameters = self.session.llm_model.parameters if self.session.parameters: model_parameters.update(self.session.parameters) if not model_parameters: model_parameters = {} return model_parameters @abstractmethod def send_question(self, question, options, context): pass def llm_eval(self, q): question = q.question options = q.options context = q.context correct_answer_idx = q.correct_answer_idx result = self.send_question(question, options, context) self.stats['answer'] = result match = re.search(self.final_answer_pattern, result) if not match: if self.session.answer_interpreter: interpreter_client = get_client(self.session.answer_interpreter.llm_model, self.session) question = self.session.answer_interpreter.prompt.replace("$QUESTION", self.stats['answer']) result = interpreter_client.send_question(question) self.stats['answer'] += f"\nInterpreter: {result}" match = re.search(self.final_answer_pattern, result) if match: final_answer = match.group(1) if final_answer.upper() == correct_answer_idx.upper(): return True return False def messages_2_instruction(self, messages): instruction = "" for message in messages: instruction += f"{message['role'].upper()}: {message['content']}\n\n" return instruction def option_str(self, options=[]): options_str = "" if not options: return None for i, option in options.items(): options_str += f"{i}) {option}\n" return options_str class GoogleGenAI(LLMClient): def __init__(self, model, session, **kwargs): super().__init__(model, session, **kwargs) import google.generativeai as genai if 'api_key' not in kwargs: raise CommandError('Google Gen AI API key not found') genai.configure(api_key=kwargs['api_key']) self.client = genai.GenerativeModel( model_name = model, system_instruction=self.session.config.sys_prompt ) def send_question(self, question, options=[], context=None): messages = self.make_messages(question, options, context) self.stats['instruction'] = self.messages_2_instruction(messages) prompt = messages[-1]['content'] messages = messages[:-1] history = [] for message in messages: if message['role'] == 'system': continue if message['role'] == 'assistant': history.append({ "role": "model", "parts": message['content'] }) continue history.append({ "role": message['role'], "parts": message['content'] }) chat = self.client.start_chat(history=history) response = chat.send_message(prompt, safety_settings={ HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, }) return response.text class OpenAIClient(LLMClient): def __init__(self, model, session, **kwargs): super().__init__(model, session, **kwargs) from openai import OpenAI self.client = OpenAI(**kwargs) def send_question(self, question, options=[], context=None): messages = self.make_messages(question, options, context) self.stats['instruction'] = self.messages_2_instruction(messages) model_parameters = self.get_chat_params() completion = self.client.chat.completions.create( model=self.model, messages=messages, **model_parameters ) msg_content = completion.choices[0].message.content return msg_content class TogheterAIClient(LLMClient): def __init__(self, model, session, **kwargs): super().__init__(model, session, **kwargs) from together import Together self.client = Together(**kwargs) def send_question(self, question, options=[], context=None): messages = self.make_messages(question, options, context) self.stats['instruction'] = self.messages_2_instruction(messages) model_parameters = self.get_chat_params() completion = self.client.chat.completions.create( model=self.model, messages=messages, **model_parameters ) msg_content = completion.choices[0].message.content return msg_content class GroqClient(LLMClient): def __init__(self, model, session, **kwargs): super().__init__(model, session, **kwargs) from groq import Groq self.client = Groq(**kwargs) def send_question(self, question, options=[], context=None): messages = self.make_messages(question, options, context) self.stats['instruction'] = self.messages_2_instruction(messages) model_parameters = self.get_chat_params() completion = self.client.chat.completions.create( model=self.model, messages=messages, **model_parameters ) msg_content = completion.choices[0].message.content return msg_content class AnthropicClient(LLMClient): def __init__(self, model, session, **kwargs): super().__init__(model, session, **kwargs) if 'api_key' not in kwargs: raise CommandError('Anthropic API key not found') self.client = anthropic.Anthropic(api_key=kwargs['api_key']) def send_question(self, question, options=[], context=None): messages = self.make_messages(question, options, context) self.stats['instruction'] = self.messages_2_instruction(messages) sys_prompt = messages[0]['content'] messages = messages[1:] # iterate the messages and concatenate the successive messages that have the same role new_messages = [] for message in messages: if new_messages and new_messages[-1]['role'] == message['role']: new_messages[-1]['content'] += "\n" + message['content'] else: new_messages.append(message) model_parameters =self.get_chat_params() response = self.client.messages.create( model=self.model, system=sys_prompt, messages=new_messages, **model_parameters ) return response.content[0].text class OllamaClient(LLMClient): def __init__(self, model, session, **kwargs): super().__init__(model, session, **kwargs) self.client = ollama.Client(**kwargs) def send_question(self, question, options, context=None): messages = self.make_messages(question, options, context) self.stats['instruction'] = self.messages_2_instruction(messages) model_parameters = self.get_chat_params() response = self.client.chat(model=self.model, messages=messages, **model_parameters) self.stats['answer'] = response['message']['content'] return response['message']['content'] def get_client(llm_model, session): llm_backend = llm_model.backend if llm_backend.client_type == 'openai': if not llm_backend.parameters: raise CommandError('OpenAI parameters not found') return OpenAIClient(llm_model.name, session, **llm_backend.parameters) elif llm_backend.client_type == 'ollama': if not llm_backend.parameters: raise CommandError('Ollama parameters not found') return OllamaClient(llm_model.name, session, **llm_backend.parameters) elif llm_backend.client_type == 'genai': if not llm_backend.parameters: raise CommandError('Google GenAI parameters not found') return GoogleGenAI(llm_model.name, session, **llm_backend.parameters) elif llm_backend.client_type == 'anthropic': if not llm_backend.parameters: raise CommandError('Anthropic parameters not found') return AnthropicClient(llm_model.name, session, **llm_backend.parameters) elif llm_backend.client_type == 'togheter': if not llm_backend.parameters: raise CommandError('Togheter.ai parameters not found') return TogheterAIClient(llm_model.name, session, **llm_backend.parameters) elif llm_backend.client_type == 'groq': if not llm_backend.parameters: raise CommandError('Groq parameters not found') return GroqClient(llm_model.name, session, **llm_backend.parameters) class Command(BaseCommand): help = 'Evaluate MedQA' def add_arguments(self, parser: CommandParser) -> None: parser.add_argument('--sample-size', type=int, help='Sample size') parser.add_argument('--randomize', action='store_true', help='Randomize questions') parser.add_argument('--session-id', type=int, help='Session ID') parser.add_argument('--continue', action='store_true', help='Continue from last question') def handle(self, *args, **options): sample_size = options['sample_size'] randomize = options['randomize'] session_id = options['session_id'] continue_from_last = options['continue'] session = EvalSession.objects.filter(id=session_id).first() if not session: raise CommandError('Session not found') dataset = session.config.dataset target = session.dataset_target llm_model = session.llm_model llm_backend = llm_model.backend qa_set = QA.objects.filter(dataset=dataset).order_by('id') if target: qa_set = qa_set.filter(target=target) if randomize: qa_set = qa_set.order_by('?') if sample_size: qa_set = qa_set[:sample_size] if not qa_set.exists(): raise CommandError('No questions found') if continue_from_last: for q in qa_set: if EvalAnswer.objects.filter(eval_session=session, question=q).exists(): qa_set = qa_set.exclude(id=q.id) print(f"Questions to evaluate: {qa_set.count()}") client = get_client(llm_model, session) stats = { "correct": 0, "total": qa_set.count() } for q in tqdm(qa_set, desc="Evaluating dataset"): correct = client.llm_eval(q) eval_answer = EvalAnswer( eval_session=session, question=q, is_correct=correct, instruction=client.stats['instruction'], assistant_answer = client.stats['answer'], hash=q.hash, llm_backend=llm_backend, llm_model=llm_model ) if correct: stats["correct"] += 1 eval_answer.save() if session.request_delay: time.sleep(session.request_delay) print(f"Accuracy: {stats['correct']}/{stats['total']} ({stats['correct']/stats['total']:.2f})") self.stdout.write(self.style.SUCCESS('Successfully evaluated MedQA'))