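"""Django management command that evaluates an LLM against a QA dataset.

A pluggable client layer (OpenAI, Together, Groq, Anthropic, Google GenAI,
Ollama) sends each question to the configured model, extracts the final
answer with a configurable regex, and stores per-question results as
EvalAnswer rows.
"""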
import re
import time
from abc import ABC, abstractmethod

import anthropic
import dotenv
import ollama
from tqdm import tqdm

from django.core.management.base import (
    BaseCommand,
    CommandError,
    CommandParser,
)

from commons.models import (
    QA, EvalAnswer,
    RoleMessage, EvalSession,
)

dotenv.load_dotenv()


class LLMClient(ABC):
    """Base class for chat-based LLM backends."""

    def __init__(self, model, session=None, **kwargs):
        self.messages = []
        self.model = model
        self.session = session
        self.final_answer_pattern = None
        self.stats = {
            "instruction": "",
            "answer": ""
        }
        if session and isinstance(session, EvalSession):
            self.final_answer_pattern = re.compile(session.config.final_answer_pattern)
            # Seed the conversation with the system prompt and any few-shot
            # role messages attached to the eval config.
            self.messages.append({
                "role": "system",
                "content": session.config.sys_prompt
            })
            for role_message in RoleMessage.objects.filter(eval_config=session.config):
                self.messages.append({
                    "role": role_message.role,
                    "content": role_message.content
                })

    def make_messages(self, question, options, context):
        """Build the full message list for a single question."""
        messages = self.messages.copy()
        option_str = self.option_str(options)
        content = f"Question: {question}"
        if option_str:
            content = f"{content}\nChoices:\n{option_str}"
        if context:
            content = f"{content}\nContext:\n{context}"
        messages.append({
            "role": "user",
            "content": content
        })
        return messages
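
    # A rendered user message looks like:
    #   Question: <question text>
    #   Choices:
    #   A) <option text>
    #   B) <option text>
    #   Context:
    #   <optional context>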

    def get_chat_params(self):
        """Merge model-level parameters with session-level overrides.

        Session parameters take precedence over the model defaults.
        """
        # Copy so session overrides never mutate the stored model parameters.
        model_parameters = dict(self.session.llm_model.parameters or {})
        if self.session.parameters:
            model_parameters.update(self.session.parameters)
        return model_parameters
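
    # e.g. model defaults {"temperature": 0.0, "max_tokens": 512} merged with
    # session overrides {"temperature": 0.7} yield
    # {"temperature": 0.7, "max_tokens": 512}.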

    @abstractmethod
    def send_question(self, question, options=None, context=None):
        """Send one question to the backend and return the raw answer text."""

    def llm_eval(self, q):
        """Evaluate a single QA row; return True if the model answered correctly."""
        result = self.send_question(q.question, q.options, q.context)
        self.stats['answer'] = result
        match = re.search(self.final_answer_pattern, result)
        if not match and self.session.answer_interpreter:
            # The model produced no parseable final answer: ask a second
            # "interpreter" model to extract it from the raw output.
            interpreter_client = get_client(self.session.answer_interpreter.llm_model, self.session)
            question = self.session.answer_interpreter.prompt.replace("$QUESTION", self.stats['answer'])
            result = interpreter_client.send_question(question)
            self.stats['answer'] += f"\nInterpreter: {result}"
            match = re.search(self.final_answer_pattern, result)
        if match:
            final_answer = match.group(1)
            return final_answer.upper() == q.correct_answer_idx.upper()
        return False

    def messages_2_instruction(self, messages):
        """Flatten a message list into a single transcript string for logging."""
        instruction = ""
        for message in messages:
            instruction += f"{message['role'].upper()}: {message['content']}\n\n"
        return instruction

    def option_str(self, options=None):
        """Render an options mapping (index -> text) as a lettered list."""
        if not options:
            return None
        options_str = ""
        for i, option in options.items():
            options_str += f"{i}) {option}\n"
        return options_str
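
    # Example: {"A": "Aspirin", "B": "Ibuprofen"} renders as
    #   A) Aspirin
    #   B) Ibuprofen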


class GoogleGenAI(LLMClient):
    def __init__(self, model, session, **kwargs):
        super().__init__(model, session, **kwargs)
        # SDKs are imported lazily so only the configured backend must be installed.
        import google.generativeai as genai
        if 'api_key' not in kwargs:
            raise CommandError('Google Gen AI API key not found')
        genai.configure(api_key=kwargs['api_key'])
        self.client = genai.GenerativeModel(
            model_name=model,
            system_instruction=self.session.config.sys_prompt
        )

    def send_question(self, question, options=None, context=None):
        from google.generativeai.types import HarmCategory, HarmBlockThreshold
        messages = self.make_messages(question, options, context)
        self.stats['instruction'] = self.messages_2_instruction(messages)
        # Gemini takes the last user turn as the prompt and the rest as history.
        prompt = messages[-1]['content']
        history = []
        for message in messages[:-1]:
            if message['role'] == 'system':
                # The system prompt is already attached to the model itself.
                continue
            history.append({
                # Gemini names the assistant role "model".
                "role": "model" if message['role'] == 'assistant' else message['role'],
                "parts": message['content']
            })
        chat = self.client.start_chat(history=history)
        # Relax safety filters so clinical question content is not blocked.
        response = chat.send_message(prompt, safety_settings={
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
        })
        return response.text


class OpenAIClient(LLMClient):
    def __init__(self, model, session, **kwargs):
        super().__init__(model, session, **kwargs)
        from openai import OpenAI
        self.client = OpenAI(**kwargs)

    def send_question(self, question, options=None, context=None):
        messages = self.make_messages(question, options, context)
        self.stats['instruction'] = self.messages_2_instruction(messages)
        model_parameters = self.get_chat_params()
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            **model_parameters
        )
        return completion.choices[0].message.content


class TogetherAIClient(LLMClient):
    def __init__(self, model, session, **kwargs):
        super().__init__(model, session, **kwargs)
        from together import Together
        self.client = Together(**kwargs)

    def send_question(self, question, options=None, context=None):
        messages = self.make_messages(question, options, context)
        self.stats['instruction'] = self.messages_2_instruction(messages)
        model_parameters = self.get_chat_params()
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            **model_parameters
        )
        return completion.choices[0].message.content


class GroqClient(LLMClient):
    def __init__(self, model, session, **kwargs):
        super().__init__(model, session, **kwargs)
        from groq import Groq
        self.client = Groq(**kwargs)

    def send_question(self, question, options=None, context=None):
        messages = self.make_messages(question, options, context)
        self.stats['instruction'] = self.messages_2_instruction(messages)
        model_parameters = self.get_chat_params()
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            **model_parameters
        )
        return completion.choices[0].message.content


class AnthropicClient(LLMClient):
    def __init__(self, model, session, **kwargs):
        super().__init__(model, session, **kwargs)
        if 'api_key' not in kwargs:
            raise CommandError('Anthropic API key not found')
        self.client = anthropic.Anthropic(api_key=kwargs['api_key'])

    def send_question(self, question, options=None, context=None):
        messages = self.make_messages(question, options, context)
        self.stats['instruction'] = self.messages_2_instruction(messages)
        # The Anthropic API takes the system prompt separately and requires
        # strictly alternating roles, so merge consecutive same-role messages.
        sys_prompt = messages[0]['content']
        new_messages = []
        for message in messages[1:]:
            if new_messages and new_messages[-1]['role'] == message['role']:
                new_messages[-1]['content'] += "\n" + message['content']
            else:
                # Copy so the merge never mutates the shared template dicts.
                new_messages.append(dict(message))
        model_parameters = self.get_chat_params()
        response = self.client.messages.create(
            model=self.model,
            system=sys_prompt,
            messages=new_messages,
            **model_parameters
        )
        return response.content[0].text


class OllamaClient(LLMClient):
    def __init__(self, model, session, **kwargs):
        super().__init__(model, session, **kwargs)
        self.client = ollama.Client(**kwargs)

    def send_question(self, question, options=None, context=None):
        messages = self.make_messages(question, options, context)
        self.stats['instruction'] = self.messages_2_instruction(messages)
        # These are forwarded as keyword arguments to ollama.Client.chat, so
        # stored parameters must match its signature (e.g. 'options', 'format').
        model_parameters = self.get_chat_params()
        response = self.client.chat(model=self.model, messages=messages, **model_parameters)
        return response['message']['content']


# Maps backend client_type values to client classes. The 'togheter' spelling
# is kept as-is to match values already stored in the database.
CLIENT_CLASSES = {
    'openai': OpenAIClient,
    'ollama': OllamaClient,
    'genai': GoogleGenAI,
    'anthropic': AnthropicClient,
    'togheter': TogetherAIClient,
    'groq': GroqClient,
}


def get_client(llm_model, session):
    """Instantiate the client class matching the model backend's client_type."""
    llm_backend = llm_model.backend
    client_class = CLIENT_CLASSES.get(llm_backend.client_type)
    if client_class is None:
        raise CommandError(f'Unknown client type: {llm_backend.client_type}')
    if not llm_backend.parameters:
        raise CommandError(f'{llm_backend.client_type} parameters not found')
    return client_class(llm_model.name, session, **llm_backend.parameters)
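
# Typical flow: get_client(session.llm_model, session) returns a ready client;
# client.llm_eval(qa_row) then sends the question and reports whether the
# extracted final answer matched qa_row.correct_answer_idx.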


class Command(BaseCommand):
    help = 'Evaluate MedQA'

    def add_arguments(self, parser: CommandParser) -> None:
        parser.add_argument('--sample-size', type=int, help='Limit evaluation to a sample of N questions')
        parser.add_argument('--randomize', action='store_true', help='Randomize question order')
        parser.add_argument('--session-id', type=int, help='EvalSession ID to run')
        parser.add_argument('--continue', action='store_true',
                            help='Skip questions already answered in this session')

    def handle(self, *args, **options):
        sample_size = options['sample_size']
        randomize = options['randomize']
        session_id = options['session_id']
        continue_from_last = options['continue']

        session = EvalSession.objects.filter(id=session_id).first()
        if not session:
            raise CommandError('Session not found')

        dataset = session.config.dataset
        target = session.dataset_target
        llm_model = session.llm_model
        llm_backend = llm_model.backend

        qa_set = QA.objects.filter(dataset=dataset).order_by('id')

        if target:
            qa_set = qa_set.filter(target=target)
        if continue_from_last:
            # Exclude questions already answered in this session. This must
            # happen before slicing: a sliced queryset cannot be filtered.
            answered = EvalAnswer.objects.filter(eval_session=session).values_list('question_id', flat=True)
            qa_set = qa_set.exclude(id__in=answered)
        if randomize:
            qa_set = qa_set.order_by('?')
        if sample_size:
            qa_set = qa_set[:sample_size]
        if not qa_set.exists():
            raise CommandError('No questions found')

        self.stdout.write(f"Questions to evaluate: {qa_set.count()}")
        client = get_client(llm_model, session)

        stats = {
            "correct": 0,
            "total": qa_set.count()
        }
        for q in tqdm(qa_set, desc="Evaluating dataset"):
            correct = client.llm_eval(q)
            eval_answer = EvalAnswer(
                eval_session=session,
                question=q,
                is_correct=correct,
                instruction=client.stats['instruction'],
                assistant_answer=client.stats['answer'],
                hash=q.hash,
                llm_backend=llm_backend,
                llm_model=llm_model
            )
            if correct:
                stats["correct"] += 1
            eval_answer.save()
            if session.request_delay:
                # Throttle requests to respect backend rate limits.
                time.sleep(session.request_delay)

        if stats['total'] > 0:
            self.stdout.write(
                f"Accuracy: {stats['correct']}/{stats['total']} "
                f"({stats['correct'] / stats['total']:.2f})"
            )
        self.stdout.write(self.style.SUCCESS('Successfully evaluated MedQA'))
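
# Example invocation (the command name comes from this file's name under
# management/commands/; "eval_medqa" here is illustrative):
#   python manage.py eval_medqa --session-id 1 --sample-size 100 --randomize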