# eval_qa.py
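"""Django management command that evaluates an LLM against a QA dataset.

One EvalSession row drives each run: it names the model, its backend, the
prompt configuration, the regex used to extract the final answer, and an
optional interpreter model. Assuming this file lives under an app's
management/commands/ directory, it is invoked as `python manage.py eval_qa`.
"""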

import re
import time
from abc import ABC, abstractmethod

import anthropic
import dotenv
import ollama
from google.generativeai.types import HarmBlockThreshold, HarmCategory
from tqdm import tqdm

from django.core.management.base import (
    BaseCommand,
    CommandError,
    CommandParser,
)
from commons.models import (
    QA, EvalAnswer,
    RoleMessage, EvalSession,
)

dotenv.load_dotenv()


class LLMClient(ABC):
    """Base interface for the chat backends used by the evaluation command."""

    def __init__(self, model, session=None, **kwargs):
        self.messages = []
        self.model = model
        self.session = session
        self.final_answer_pattern = None
        # Per-question bookkeeping: the full instruction sent and the raw answer received.
        self.stats = {
            "instruction": "",
            "answer": ""
        }
        if session and isinstance(session, EvalSession):
            self.final_answer_pattern = re.compile(session.config.final_answer_pattern)
            self.messages.append({
                "role": "system",
                "content": self.session.config.sys_prompt
            })
            # Prepend any role messages (e.g. few-shot turns) configured for this session.
            for role_message in RoleMessage.objects.filter(eval_config=session.config):
                self.messages.append({
                    "role": role_message.role,
                    "content": role_message.content
                })

    def make_messages(self, question, options, context):
        """Build the message list for one question on top of the configured preamble."""
        messages = self.messages.copy()
        option_str = self.option_str(options)
        content = f"Question: {question}"
        if option_str:
            content = f"{content}\nChoices:\n{option_str}"
        if context:
            content = f"{content}\nContext:\n{context}"
        messages.append({
            "role": "user",
            "content": content
        })
        return messages
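
    # A composed user message looks like (values illustrative):
    #   Question: <question text>
    #   Choices:
    #   A) <option A>
    #   B) <option B>
    #   Context:
    #   <context passage>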

    def get_chat_params(self):
        # Merge per-model defaults with per-session overrides; either may be
        # unset. Copy so the per-model defaults are not mutated in place.
        model_parameters = dict(self.session.llm_model.parameters or {})
        if self.session.parameters:
            model_parameters.update(self.session.parameters)
        return model_parameters
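
    # Example (illustrative values): model defaults {"temperature": 0.0,
    # "max_tokens": 512} merged with a session override {"temperature": 0.7}
    # produce {"temperature": 0.7, "max_tokens": 512}.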

    @abstractmethod
    def send_question(self, question, options=None, context=None):
        pass

    def llm_eval(self, q):
        """Evaluate a single QA row; return True if the model answered correctly."""
        question = q.question
        options = q.options
        context = q.context
        correct_answer_idx = q.correct_answer_idx
        result = self.send_question(question, options, context)
        self.stats['answer'] = result
        match = self.final_answer_pattern.search(result)
        if not match and self.session.answer_interpreter:
            # No recognizable final answer in the reply: ask the configured
            # interpreter model to extract one from the raw text.
            interpreter_client = get_client(self.session.answer_interpreter.llm_model, self.session)
            question = self.session.answer_interpreter.prompt.replace("$QUESTION", self.stats['answer'])
            result = interpreter_client.send_question(question)
            self.stats['answer'] += f"\nInterpreter: {result}"
            match = self.final_answer_pattern.search(result)
        if match:
            return match.group(1).upper() == correct_answer_idx.upper()
        return False
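
    # `final_answer_pattern` comes from EvalConfig and must contain one capture
    # group; a pattern such as r"FINAL ANSWER:\s*([A-E])" (illustrative, not
    # taken from the source data) captures the letter compared above.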

    def messages_2_instruction(self, messages):
        # Flatten the chat into a single transcript for logging on EvalAnswer.
        instruction = ""
        for message in messages:
            instruction += f"{message['role'].upper()}: {message['content']}\n\n"
        return instruction

    def option_str(self, options=None):
        # `options` maps an answer index to its text; empty input yields None.
        if not options:
            return None
        options_str = ""
        for i, option in options.items():
            options_str += f"{i}) {option}\n"
        return options_str
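
    # For example, options={"A": "Aspirin", "B": "Ibuprofen"} (hypothetical
    # data) renders as:
    #   A) Aspirin
    #   B) Ibuprofen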


class GoogleGenAI(LLMClient):
    def __init__(self, model, session, **kwargs):
        super().__init__(model, session, **kwargs)
        import google.generativeai as genai
        if 'api_key' not in kwargs:
            raise CommandError('Google Gen AI API key not found')
        genai.configure(api_key=kwargs['api_key'])
        self.client = genai.GenerativeModel(
            model_name=model,
            system_instruction=self.session.config.sys_prompt
        )

    def send_question(self, question, options=None, context=None):
        messages = self.make_messages(question, options, context)
        self.stats['instruction'] = self.messages_2_instruction(messages)
        # The last message is the prompt; everything before it becomes chat
        # history. The system message is skipped because it is already set via
        # system_instruction, and the Gemini API calls the assistant role "model".
        prompt = messages[-1]['content']
        history = []
        for message in messages[:-1]:
            if message['role'] == 'system':
                continue
            role = 'model' if message['role'] == 'assistant' else message['role']
            history.append({"role": role, "parts": message['content']})
        chat = self.client.start_chat(history=history)
        # Disable safety filtering so that clinical content is not blocked.
        response = chat.send_message(prompt, safety_settings={
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
        })
        return response.text


class OpenAIClient(LLMClient):
    def __init__(self, model, session, **kwargs):
        super().__init__(model, session, **kwargs)
        from openai import OpenAI
        self.client = OpenAI(**kwargs)

    def send_question(self, question, options=None, context=None):
        messages = self.make_messages(question, options, context)
        self.stats['instruction'] = self.messages_2_instruction(messages)
        model_parameters = self.get_chat_params()
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            **model_parameters
        )
        return completion.choices[0].message.content


class TogetherAIClient(LLMClient):
    def __init__(self, model, session, **kwargs):
        super().__init__(model, session, **kwargs)
        from together import Together
        self.client = Together(**kwargs)

    def send_question(self, question, options=None, context=None):
        messages = self.make_messages(question, options, context)
        self.stats['instruction'] = self.messages_2_instruction(messages)
        model_parameters = self.get_chat_params()
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            **model_parameters
        )
        return completion.choices[0].message.content


class GroqClient(LLMClient):
    def __init__(self, model, session, **kwargs):
        super().__init__(model, session, **kwargs)
        from groq import Groq
        self.client = Groq(**kwargs)

    def send_question(self, question, options=None, context=None):
        messages = self.make_messages(question, options, context)
        self.stats['instruction'] = self.messages_2_instruction(messages)
        model_parameters = self.get_chat_params()
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            **model_parameters
        )
        return completion.choices[0].message.content


class AnthropicClient(LLMClient):
    def __init__(self, model, session, **kwargs):
        super().__init__(model, session, **kwargs)
        if 'api_key' not in kwargs:
            raise CommandError('Anthropic API key not found')
        self.client = anthropic.Anthropic(api_key=kwargs['api_key'])

    def send_question(self, question, options=None, context=None):
        messages = self.make_messages(question, options, context)
        self.stats['instruction'] = self.messages_2_instruction(messages)
        # The Anthropic API takes the system prompt as a separate argument and
        # rejects consecutive messages with the same role, so pull out the
        # system message and merge successive same-role messages.
        sys_prompt = messages[0]['content']
        messages = messages[1:]
        new_messages = []
        for message in messages:
            if new_messages and new_messages[-1]['role'] == message['role']:
                new_messages[-1]['content'] += "\n" + message['content']
            else:
                new_messages.append(message)
        model_parameters = self.get_chat_params()
        response = self.client.messages.create(
            model=self.model,
            system=sys_prompt,
            messages=new_messages,
            **model_parameters
        )
        return response.content[0].text


class OllamaClient(LLMClient):
    def __init__(self, model, session, **kwargs):
        super().__init__(model, session, **kwargs)
        self.client = ollama.Client(**kwargs)

    def send_question(self, question, options=None, context=None):
        messages = self.make_messages(question, options, context)
        self.stats['instruction'] = self.messages_2_instruction(messages)
        model_parameters = self.get_chat_params()
        response = self.client.chat(model=self.model, messages=messages, **model_parameters)
        return response['message']['content']


# client_type values as stored on the backend record; 'togheter' keeps the
# spelling used in the database.
CLIENT_CLASSES = {
    'openai': OpenAIClient,
    'ollama': OllamaClient,
    'genai': GoogleGenAI,
    'anthropic': AnthropicClient,
    'togheter': TogetherAIClient,
    'groq': GroqClient,
}


def get_client(llm_model, session):
    llm_backend = llm_model.backend
    client_class = CLIENT_CLASSES.get(llm_backend.client_type)
    if client_class is None:
        raise CommandError(f'Unknown client type: {llm_backend.client_type}')
    if not llm_backend.parameters:
        raise CommandError(f'{llm_backend.client_type} parameters not found')
    return client_class(llm_model.name, session, **llm_backend.parameters)
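

# Adding a backend means subclassing LLMClient, implementing send_question(),
# and registering the class in CLIENT_CLASSES under the client_type value
# stored on the backend record.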


class Command(BaseCommand):
    help = 'Evaluate MedQA'

    def add_arguments(self, parser: CommandParser) -> None:
        parser.add_argument('--sample-size', type=int, help='Sample size')
        parser.add_argument('--randomize', action='store_true', help='Randomize questions')
        parser.add_argument('--session-id', type=int, help='Session ID')
        parser.add_argument('--continue', action='store_true', help='Continue from last question')

    def handle(self, *args, **options):
        sample_size = options['sample_size']
        randomize = options['randomize']
        session_id = options['session_id']
        continue_from_last = options['continue']
        session = EvalSession.objects.filter(id=session_id).first()
        if not session:
            raise CommandError('Session not found')
        dataset = session.config.dataset
        target = session.dataset_target
        llm_model = session.llm_model
        llm_backend = llm_model.backend
        qa_set = QA.objects.filter(dataset=dataset).order_by('id')
        if target:
            qa_set = qa_set.filter(target=target)
        if randomize:
            qa_set = qa_set.order_by('?')
        if continue_from_last:
            # Skip questions already answered in this session. This must happen
            # before slicing: a Django queryset cannot be filtered after a slice.
            answered_ids = EvalAnswer.objects.filter(eval_session=session).values_list('question_id', flat=True)
            qa_set = qa_set.exclude(id__in=answered_ids)
        if sample_size:
            qa_set = qa_set[:sample_size]
        # Evaluate the queryset once; with order_by('?') every database hit
        # would otherwise return a different random sample.
        qa_list = list(qa_set)
        if not qa_list:
            raise CommandError('No questions found')
        self.stdout.write(f"Questions to evaluate: {len(qa_list)}")
        client = get_client(llm_model, session)
        stats = {
            "correct": 0,
            "total": len(qa_list)
        }
        for q in tqdm(qa_list, desc="Evaluating dataset"):
            correct = client.llm_eval(q)
            eval_answer = EvalAnswer(
                eval_session=session,
                question=q,
                is_correct=correct,
                instruction=client.stats['instruction'],
                assistant_answer=client.stats['answer'],
                hash=q.hash,
                llm_backend=llm_backend,
                llm_model=llm_model
            )
            if correct:
                stats["correct"] += 1
            eval_answer.save()
            if session.request_delay:
                # Throttle to respect the backend's rate limits.
                time.sleep(session.request_delay)
        if stats['total'] > 0:
            self.stdout.write(
                f"Accuracy: {stats['correct']}/{stats['total']} "
                f"({stats['correct'] / stats['total']:.2f})"
            )
        self.stdout.write(self.style.SUCCESS('Successfully evaluated MedQA'))
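

# Example invocations (the session id is hypothetical):
#   python manage.py eval_qa --session-id 3 --sample-size 100 --randomize
#   python manage.py eval_qa --session-id 3 --continue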