# eval_qa.py
import re
import time
from abc import ABC, abstractmethod

import anthropic
import dotenv
import google.generativeai as genai
import ollama
from google.generativeai.types import HarmCategory, HarmBlockThreshold
from openai import OpenAI
from tqdm import tqdm
from django.core.management.base import (
    BaseCommand,
    CommandError,
    CommandParser
)
from commons.models import (
    QA, EvalAnswer,
    RoleMessage, EvalSession
)

dotenv.load_dotenv()
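
# Base class for all LLM clients: holds the shared conversation prefix (system
# prompt plus any configured role messages) and the common evaluation logic.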
class LLMClient(ABC):
    def __init__(self, model, session=None, **kwargs):
        self.messages = []
        self.model = model
        self.session = session
        self.final_answer_pattern = None
        self.stats = {
            "instruction": "",
            "answer": ""
        }
        if session and isinstance(session, EvalSession):
            self.final_answer_pattern = re.compile(session.config.final_answer_pattern)
            self.messages.append({
                "role": "system",
                "content": self.session.config.sys_prompt
            })
            for role_message in RoleMessage.objects.filter(eval_config=session.config):
                self.messages.append({
                    "role": role_message.role,
                    "content": role_message.content
                })
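
    # Build the message list for one question: the configured prefix plus a user
    # turn holding the question, optional answer choices, and optional context.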
    def make_messages(self, question, options, context):
        messages = self.messages.copy()
        option_str = self.option_str(options)
        content = f"Question: {question}"
        if option_str:
            content = f"{content}\nChoices:\n {option_str}"
        if context:
            content = f"{content}\nContext:\n{context}"
        messages.append({
            "role": "user",
            "content": content
        })
        return messages

    def get_chat_params(self):
        # Copy the model's default parameters so session-level overrides never
        # mutate the shared llm_model.parameters dict in place.
        model_parameters = dict(self.session.llm_model.parameters or {})
        if self.session.parameters:
            model_parameters.update(self.session.parameters)
        return model_parameters

    @abstractmethod
    def send_question(self, question, options=None, context=None):
        pass
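
    # Send one QA item to the model, extract the final answer with the session's
    # regex and, if it is missing, retry through the configured interpreter model.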
    def llm_eval(self, q):
        question = q.question
        options = q.options
        context = q.context
        correct_answer_idx = q.correct_answer_idx
        result = self.send_question(question, options, context)
        self.stats['answer'] = result
        match = re.search(self.final_answer_pattern, result)
        if not match and self.session.answer_interpreter:
            # Ask a second model to restate the raw answer in the expected format.
            interpreter_client = get_client(self.session.answer_interpreter.llm_model, self.session)
            question = self.session.answer_interpreter.prompt.replace("$QUESTION", self.stats['answer'])
            result = interpreter_client.send_question(question)
            self.stats['answer'] += f"\nInterpreter: {result}"
            match = re.search(self.final_answer_pattern, result)
        if match:
            return match.group(1).upper() == correct_answer_idx.upper()
        return False

    def messages_2_instruction(self, messages):
        instruction = ""
        for message in messages:
            instruction += f"{message['role'].upper()}: {message['content']}\n\n"
        return instruction

    def option_str(self, options=None):
        # options is a dict mapping choice index -> choice text.
        if not options:
            return None
        options_str = ""
        for i, option in options.items():
            options_str += f"{i}) {option}\n"
        return options_str
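
# Google Gemini backend (google-generativeai SDK). The system prompt is passed
# as system_instruction, so system messages are dropped from the chat history.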
class GoogleGenAI(LLMClient):
    def __init__(self, model, session, **kwargs):
        super().__init__(model, session, **kwargs)
        if 'api_key' not in kwargs:
            raise CommandError('Google Gen AI API key not found')
        genai.configure(api_key=kwargs['api_key'])
        self.client = genai.GenerativeModel(
            model_name=model,
            system_instruction=self.session.config.sys_prompt
        )

    def send_question(self, question, options=None, context=None):
        messages = self.make_messages(question, options, context)
        self.stats['instruction'] = self.messages_2_instruction(messages)
        prompt = messages[-1]['content']
        messages = messages[:-1]
        history = []
        for message in messages:
            if message['role'] == 'system':
                continue
            # The genai SDK calls the assistant role "model".
            role = 'model' if message['role'] == 'assistant' else message['role']
            history.append({
                "role": role,
                "parts": message['content']
            })
        chat = self.client.start_chat(history=history)
        response = chat.send_message(prompt, safety_settings={
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
        })
        return response.text
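
# OpenAI-compatible backend; the backend's stored parameters (api_key, base_url,
# etc.) are forwarded to the OpenAI() constructor.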
class OpenAIClient(LLMClient):
    def __init__(self, model, session, **kwargs):
        super().__init__(model, session, **kwargs)
        self.client = OpenAI(**kwargs)

    def send_question(self, question, options=None, context=None):
        messages = self.make_messages(question, options, context)
        self.stats['instruction'] = self.messages_2_instruction(messages)
        model_parameters = self.get_chat_params()
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            **model_parameters
        )
        return completion.choices[0].message.content
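
# Anthropic backend. The Messages API takes the system prompt as a separate
# argument and rejects consecutive turns with the same role.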
class AnthropicClient(LLMClient):
    def __init__(self, model, session, **kwargs):
        super().__init__(model, session, **kwargs)
        if 'api_key' not in kwargs:
            raise CommandError('Anthropic API key not found')
        self.client = anthropic.Anthropic(api_key=kwargs['api_key'])

    def send_question(self, question, options=None, context=None):
        messages = self.make_messages(question, options, context)
        self.stats['instruction'] = self.messages_2_instruction(messages)
        sys_prompt = messages[0]['content']
        messages = messages[1:]
        # Merge successive messages that share a role; copy each dict so the
        # shared prefix held in self.messages is never mutated in place.
        new_messages = []
        for message in messages:
            if new_messages and new_messages[-1]['role'] == message['role']:
                new_messages[-1]['content'] += "\n" + message['content']
            else:
                new_messages.append(dict(message))
        model_parameters = self.get_chat_params()
        response = self.client.messages.create(
            model=self.model,
            system=sys_prompt,
            messages=new_messages,
            **model_parameters
        )
        return response.content[0].text
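
# Backend for local models served through Ollama.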
class OllamaClient(LLMClient):
    def __init__(self, model, session, **kwargs):
        super().__init__(model, session, **kwargs)
        self.client = ollama.Client(**kwargs)

    def send_question(self, question, options=None, context=None):
        messages = self.make_messages(question, options, context)
        self.stats['instruction'] = self.messages_2_instruction(messages)
        model_parameters = self.get_chat_params()
        response = self.client.chat(model=self.model, messages=messages, **model_parameters)
        return response['message']['content']
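
# Factory: map a backend's client_type to the matching client class.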
def get_client(llm_model, session):
    llm_backend = llm_model.backend
    if llm_backend.client_type == 'openai':
        if not llm_backend.parameters:
            raise CommandError('OpenAI parameters not found')
        return OpenAIClient(llm_model.name, session, **llm_backend.parameters)
    elif llm_backend.client_type == 'ollama':
        if not llm_backend.parameters:
            raise CommandError('Ollama parameters not found')
        return OllamaClient(llm_model.name, session, **llm_backend.parameters)
    elif llm_backend.client_type == 'genai':
        if not llm_backend.parameters:
            raise CommandError('Google GenAI parameters not found')
        return GoogleGenAI(llm_model.name, session, **llm_backend.parameters)
    elif llm_backend.client_type == 'anthropic':
        if not llm_backend.parameters:
            raise CommandError('Anthropic parameters not found')
        return AnthropicClient(llm_model.name, session, **llm_backend.parameters)
    # Fail loudly instead of silently returning None for an unknown backend.
    raise CommandError(f'Unknown client type: {llm_backend.client_type}')
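
# Management command entry point, invoked e.g. as:
#   python manage.py eval_qa --session-id 3 --sample-size 100 --randomize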
class Command(BaseCommand):
    help = 'Evaluate MedQA'

    def add_arguments(self, parser: CommandParser) -> None:
        parser.add_argument('--sample-size', type=int, help='Sample size')
        parser.add_argument('--randomize', action='store_true', help='Randomize questions')
        parser.add_argument('--session-id', type=int, help='Session ID')
        parser.add_argument('--continue', action='store_true', help='Continue from last question')

    def handle(self, *args, **options):
        sample_size = options['sample_size']
        randomize = options['randomize']
        session_id = options['session_id']
        continue_from_last = options['continue']
        session = EvalSession.objects.filter(id=session_id).first()
        if not session:
            raise CommandError('Session not found')
        dataset = session.config.dataset
        target = session.dataset_target
        llm_model = session.llm_model
        llm_backend = llm_model.backend
        qa_set = QA.objects.filter(dataset=dataset).order_by('id')
        if target:
            qa_set = qa_set.filter(target=target)
        if randomize:
            qa_set = qa_set.order_by('?')
        if continue_from_last:
            # Skip questions already answered in this session. This must happen
            # before slicing: a Django queryset cannot be filtered after a slice.
            answered_ids = EvalAnswer.objects.filter(eval_session=session).values_list('question_id', flat=True)
            qa_set = qa_set.exclude(id__in=answered_ids)
        if sample_size:
            qa_set = qa_set[:sample_size]
        if not qa_set.exists():
            raise CommandError('No questions found')
        print(f"Questions to evaluate: {qa_set.count()}")
        client = get_client(llm_model, session)
        stats = {
            "correct": 0,
            "total": qa_set.count()
        }
        for q in tqdm(qa_set, desc="Evaluating dataset"):
            correct = client.llm_eval(q)
            eval_answer = EvalAnswer(
                eval_session=session,
                question=q,
                is_correct=correct,
                instruction=client.stats['instruction'],
                assistant_answer=client.stats['answer'],
                hash=q.hash,
                llm_backend=llm_backend,
                llm_model=llm_model
            )
            if correct:
                stats["correct"] += 1
            eval_answer.save()
            if session.request_delay:
                time.sleep(session.request_delay)
        print(f"Accuracy: {stats['correct']}/{stats['total']} ({stats['correct'] / stats['total']:.2f})")
        self.stdout.write(self.style.SUCCESS('Successfully evaluated MedQA'))