export_results.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. import json
  2. import os
  3. import csv
  4. from tqdm import tqdm
  5. from django.core.management.base import BaseCommand, CommandError, CommandParser
  6. from commons.models import QA, Dataset, EvalSession, EvalConfig, EvalAnswer
  7. class Command(BaseCommand):
  8. help = 'Export benchmark results to CSV file'
  9. def add_arguments(self, parser: CommandParser) -> None:
  10. parser.add_argument('--config', type=str, help='EvalConfig ID')
  11. parser.add_argument('--exclude-sessions', type=str, help='Session IDs to exclude')
  12. parser.add_argument('--file', type=str, help='Path to CSV file')
  13. def handle(self, *args, **options):
  14. config_id = options['config']
  15. if not config_id:
  16. raise CommandError('EvalConfig ID is required')
  17. exclude_session_ids = options['exclude_sessions']
  18. file_path = options['file']
  19. if not file_path:
  20. raise CommandError('File path is required')
  21. # check if EvalConfig exists
  22. config = EvalConfig.objects.filter(id=config_id).first()
  23. if not config:
  24. raise CommandError('EvalConfig not found')
  25. # get all sessions for this config
  26. sessions = EvalSession.objects.filter(config=config).filter(is_active=True)
  27. if exclude_session_ids:
  28. exclude_session_ids = exclude_session_ids.split(',')
  29. sessions = sessions.exclude(id__in=exclude_session_ids)
  30. # get all questions for the dataset in this config
  31. questions = dict()
  32. for qa in QA.objects.filter(dataset=config.dataset):
  33. questions[qa.id] = qa.question
  34. models_data = {}
  35. for session in sessions:
  36. eval_answers = EvalAnswer.objects.filter(eval_session=session).order_by('question__id')
  37. data = {}
  38. for eval_answer in eval_answers:
  39. data[eval_answer.question.id] = 1 if eval_answer.is_correct else 0
  40. models_data[session.llm_model.name] = data
  41. print("Exporting the following data:")
  42. for model_name, data in models_data.items():
  43. print(f"{model_name}: {len(data)}")
  44. with open(file_path, 'w') as f:
  45. writer = csv.writer(f)
  46. writer.writerow(['Question ID'] + ['Question Text'] + [model_name for model_name in models_data.keys()])
  47. for question_id in tqdm(models_data[list(models_data.keys())[0]].keys(), desc="Exporting data"):
  48. row = [question_id] + [questions[question_id]] + [data.get(question_id, '') for data in models_data.values()]
  49. writer.writerow(row)
  50. row = ['Accuracy'] + [''] + [f"{sum(data.values()) / len(data) * 100:.2f}%" for data in models_data.values()]
  51. writer.writerow(row)
  52. self.stdout.write(self.style.SUCCESS('Successfully imported MedQA data'))