@@ -0,0 +1,314 @@
+import re
+import time
+from abc import ABC, abstractmethod
+
+import anthropic
+import dotenv
+import google.generativeai as genai
+import ollama
+from google.generativeai.types import HarmCategory, HarmBlockThreshold
+from openai import OpenAI
+from tqdm import tqdm
+
+from django.core.management.base import (
+    BaseCommand,
+    CommandError,
+    CommandParser
+)
+from commons.models import (
+    QA, EvalAnswer,
+    RoleMessage, EvalSession
+)
+
+
+dotenv.load_dotenv()
+
+
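+# Base class: builds the chat transcript from the session's EvalConfig
+# (system prompt plus any seeded RoleMessage turns) and scores one QA row
+# per call. Concrete subclasses only implement send_question().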
+class LLMClient(ABC):
+    def __init__(self, model, session=None, **kwargs):
+        self.messages = []
+        self.model = model
+        self.session = session
+        self.final_answer_pattern = None
+        self.stats = {
+            "instruction": "",
+            "answer": ""
+        }
+
+        if session and isinstance(session, EvalSession):
+            self.final_answer_pattern = re.compile(session.config.final_answer_pattern)
+            self.messages.append({
+                "role": "system",
+                "content": session.config.sys_prompt
+            })
+
+            for role_message in RoleMessage.objects.filter(eval_config=session.config):
+                self.messages.append({
+                    "role": role_message.role,
+                    "content": role_message.content
+                })
+
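+    # Compose the outgoing transcript: the configured preamble plus one user
+    # turn containing the question and, when present, its choices and context.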
+    def make_messages(self, question, options, context):
+        messages = self.messages.copy()
+        option_str = self.option_str(options)
+
+        content = f"Question: {question}"
+        if option_str:
+            content = f"{content}\nChoices:\n{option_str}"
+        if context:
+            content = f"{content}\nContext:\n{context}"
+
+        messages.append({
+            "role": "user",
+            "content": content
+        })
+        return messages
+
+    def get_chat_params(self):
+        # Copy before merging so session overrides never mutate the
+        # parameters dict stored on the model row.
+        model_parameters = dict(self.session.llm_model.parameters or {})
+        if self.session.parameters:
+            model_parameters.update(self.session.parameters)
+        return model_parameters
+
+    @abstractmethod
+    def send_question(self, question, options=None, context=None):
+        pass
+
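+    # Score a single QA row. The reply must contain the session's
+    # final_answer_pattern with the chosen option index in group(1); when it
+    # does not, the configured answer_interpreter model (if any) gets one
+    # attempt to restate the reply in the expected format.
+    # Hypothetical example pattern: r"Final answer:\s*([A-Z])".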
+    def llm_eval(self, q):
+        question = q.question
+        options = q.options
+        context = q.context
+        correct_answer_idx = q.correct_answer_idx
+        result = self.send_question(question, options, context)
+        self.stats['answer'] = result
+        match = re.search(self.final_answer_pattern, result)
+        if not match:
+            if self.session.answer_interpreter:
+                interpreter_client = get_client(self.session.answer_interpreter.llm_model, self.session)
+                question = self.session.answer_interpreter.prompt.replace("$QUESTION", self.stats['answer'])
+                result = interpreter_client.send_question(question)
+                self.stats['answer'] += f"\nInterpreter: {result}"
+                match = re.search(self.final_answer_pattern, result)
+        if match:
+            final_answer = match.group(1)
+            if final_answer.upper() == correct_answer_idx.upper():
+                return True
+        return False
+
+    def messages_2_instruction(self, messages):
+        instruction = ""
+        for message in messages:
+            instruction += f"{message['role'].upper()}: {message['content']}\n\n"
+        return instruction
+
+    def option_str(self, options=None):
+        # options is expected to be a mapping of option index to option text,
+        # e.g. {"A": "...", "B": "..."}.
+        if not options:
+            return None
+        options_str = ""
+        for i, option in options.items():
+            options_str += f"{i}) {option}\n"
+        return options_str
+
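+# Gemini client. The google.generativeai chat API takes the system prompt at
+# model construction time and has no 'assistant' role, so the preamble is
+# replayed as start_chat() history with assistant turns mapped to 'model'.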
+class GoogleGenAI(LLMClient):
+    def __init__(self, model, session, **kwargs):
+        super().__init__(model, session, **kwargs)
+        if 'api_key' not in kwargs:
+            raise CommandError('Google Gen AI API key not found')
+        genai.configure(api_key=kwargs['api_key'])
+        self.client = genai.GenerativeModel(
+            model_name=model,
+            system_instruction=self.session.config.sys_prompt
+        )
+
+    def send_question(self, question, options=None, context=None):
+        messages = self.make_messages(question, options, context)
+        self.stats['instruction'] = self.messages_2_instruction(messages)
+
+        # The last user turn is sent as the prompt; everything before it
+        # becomes chat history.
+        prompt = messages[-1]['content']
+        messages = messages[:-1]
+
+        history = []
+        for message in messages:
+            if message['role'] == 'system':
+                continue
+            # Gemini uses 'model' where the OpenAI-style schema uses 'assistant'.
+            role = 'model' if message['role'] == 'assistant' else message['role']
+            history.append({
+                "role": role,
+                "parts": message['content']
+            })
+        chat = self.client.start_chat(history=history)
+        # Disable safety blocking so exam-style medical content is not refused.
+        response = chat.send_message(prompt, safety_settings={
+            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
+            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
+            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
+            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
+        })
+
+        return response.text
+
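+# OpenAI client; backend parameters (e.g. api_key, base_url) are forwarded
+# to the OpenAI() constructor, so this can also cover OpenAI-compatible
+# servers.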
+class OpenAIClient(LLMClient):
+    def __init__(self, model, session, **kwargs):
+        super().__init__(model, session, **kwargs)
+        self.client = OpenAI(**kwargs)
+
+    def send_question(self, question, options=None, context=None):
+        messages = self.make_messages(question, options, context)
+        self.stats['instruction'] = self.messages_2_instruction(messages)
+
+        model_parameters = self.get_chat_params()
+
+        completion = self.client.chat.completions.create(
+            model=self.model,
+            messages=messages,
+            **model_parameters
+        )
+        return completion.choices[0].message.content
+
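+# Anthropic client. The Messages API takes the system prompt as a separate
+# argument and requires strictly alternating user/assistant turns, so
+# consecutive same-role messages are merged below.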
+class AnthropicClient(LLMClient):
+    def __init__(self, model, session, **kwargs):
+        super().__init__(model, session, **kwargs)
+        if 'api_key' not in kwargs:
+            raise CommandError('Anthropic API key not found')
+        self.client = anthropic.Anthropic(api_key=kwargs['api_key'])
+
+    def send_question(self, question, options=None, context=None):
+        messages = self.make_messages(question, options, context)
+        self.stats['instruction'] = self.messages_2_instruction(messages)
+
+        sys_prompt = messages[0]['content']
+        messages = messages[1:]
+
+        # Concatenate successive messages that share a role, since the API
+        # rejects consecutive turns from the same role.
+        new_messages = []
+        for message in messages:
+            if new_messages and new_messages[-1]['role'] == message['role']:
+                new_messages[-1]['content'] += "\n" + message['content']
+            else:
+                new_messages.append(message)
+
+        model_parameters = self.get_chat_params()
+
+        response = self.client.messages.create(
+            model=self.model,
+            system=sys_prompt,
+            messages=new_messages,
+            **model_parameters
+        )
+        return response.content[0].text
+
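+# Local Ollama client. Backend parameters are forwarded to ollama.Client()
+# (e.g. host), and the session's chat parameters are splatted into chat();
+# they are assumed to be valid ollama chat() keyword arguments such as
+# options=... or format=...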
+class OllamaClient(LLMClient):
+    def __init__(self, model, session, **kwargs):
+        super().__init__(model, session, **kwargs)
+        self.client = ollama.Client(**kwargs)
+
+    def send_question(self, question, options=None, context=None):
+        messages = self.make_messages(question, options, context)
+        self.stats['instruction'] = self.messages_2_instruction(messages)
+
+        model_parameters = self.get_chat_params()
+
+        response = self.client.chat(model=self.model, messages=messages, **model_parameters)
+
+        return response['message']['content']
+
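+# Factory: map the backend's client_type to a concrete client, passing the
+# backend's stored parameters as constructor kwargs.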
+def get_client(llm_model, session):
+    llm_backend = llm_model.backend
+    if llm_backend.client_type == 'openai':
+        if not llm_backend.parameters:
+            raise CommandError('OpenAI parameters not found')
+        return OpenAIClient(llm_model.name, session, **llm_backend.parameters)
+    elif llm_backend.client_type == 'ollama':
+        if not llm_backend.parameters:
+            raise CommandError('Ollama parameters not found')
+        return OllamaClient(llm_model.name, session, **llm_backend.parameters)
+    elif llm_backend.client_type == 'genai':
+        if not llm_backend.parameters:
+            raise CommandError('Google GenAI parameters not found')
+        return GoogleGenAI(llm_model.name, session, **llm_backend.parameters)
+    elif llm_backend.client_type == 'anthropic':
+        if not llm_backend.parameters:
+            raise CommandError('Anthropic parameters not found')
+        return AnthropicClient(llm_model.name, session, **llm_backend.parameters)
+    # Fail loudly instead of silently returning None for an unknown backend.
+    raise CommandError(f'Unknown client type: {llm_backend.client_type}')
+
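+# Management command entry point: resolves the EvalSession, builds the QA
+# queryset, and walks it with the session's client.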
+class Command(BaseCommand):
+    help = 'Evaluate MedQA'
+
+    def add_arguments(self, parser: CommandParser) -> None:
+        parser.add_argument('--sample-size', type=int, help='Sample size')
+        parser.add_argument('--randomize', action='store_true', help='Randomize questions')
+        parser.add_argument('--session-id', type=int, help='Session ID')
+        parser.add_argument('--continue', action='store_true', help='Continue from last question')
+
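+    # Queryset construction order matters below: the already-answered
+    # exclusion must run before any slicing, because Django raises if a
+    # sliced queryset is filtered further.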
+    def handle(self, *args, **options):
+        sample_size = options['sample_size']
+        randomize = options['randomize']
+        session_id = options['session_id']
+        continue_from_last = options['continue']
+
+        session = EvalSession.objects.filter(id=session_id).first()
+        if not session:
+            raise CommandError('Session not found')
+
+        dataset = session.config.dataset
+        target = session.dataset_target
+        llm_model = session.llm_model
+        llm_backend = llm_model.backend
+
+        qa_set = QA.objects.filter(dataset=dataset).order_by('id')
+
+        if target:
+            qa_set = qa_set.filter(target=target)
+
+        if continue_from_last:
+            answered = EvalAnswer.objects.filter(eval_session=session).values('question_id')
+            qa_set = qa_set.exclude(id__in=answered)
+
+        if randomize:
+            qa_set = qa_set.order_by('?')
+
+        if sample_size:
+            qa_set = qa_set[:sample_size]
+
+        if not qa_set.exists():
+            raise CommandError('No questions found')
+
+        self.stdout.write(f"Questions to evaluate: {qa_set.count()}")
+
+        client = get_client(llm_model, session)
+
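+        # Evaluate each question, persisting one EvalAnswer per row so a
+        # rerun with --continue can skip completed work.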
+        stats = {
+            "correct": 0,
+            "total": qa_set.count()
+        }
+        for q in tqdm(qa_set, desc="Evaluating dataset"):
+            correct = client.llm_eval(q)
+            eval_answer = EvalAnswer(
+                eval_session=session,
+                question=q,
+                is_correct=correct,
+                instruction=client.stats['instruction'],
+                assistant_answer=client.stats['answer'],
+                hash=q.hash,
+                llm_backend=llm_backend,
+                llm_model=llm_model
+            )
+            if correct:
+                stats["correct"] += 1
+
+            eval_answer.save()
+            if session.request_delay:
+                time.sleep(session.request_delay)
+
+        self.stdout.write(f"Accuracy: {stats['correct']}/{stats['total']} ({stats['correct'] / stats['total']:.2f})")
+
+        self.stdout.write(self.style.SUCCESS('Successfully evaluated MedQA'))