12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364 |
- import json
- import os
- import csv
- from tqdm import tqdm
- from django.core.management.base import BaseCommand, CommandError, CommandParser
- from commons.models import QA, Dataset, EvalSession, EvalConfig, EvalAnswer
class Command(BaseCommand):
    """Export per-question benchmark results for one EvalConfig to a CSV file.

    The CSV has one row per answered question and one column per active
    EvalSession of the config, labelled by the session's LLM model name.
    Each cell is ``1`` for a correct answer, ``0`` for an incorrect one, and
    empty when that session has no answer for the question. A final
    ``Accuracy`` row gives, per session, the percentage of that session's
    answers that are correct.

    Options:
        --config            EvalConfig ID (required).
        --exclude-sessions  Comma-separated EvalSession IDs to leave out.
        --file              Output CSV path (required).
    """

    help = 'Export benchmark results to CSV file'

    def add_arguments(self, parser: CommandParser) -> None:
        parser.add_argument('--config', type=str, help='EvalConfig ID')
        parser.add_argument('--exclude-sessions', type=str, help='Session IDs to exclude')
        parser.add_argument('--file', type=str, help='Path to CSV file')

    def handle(self, *args, **options):
        """Validate options, collect per-session results and write the CSV.

        Raises:
            CommandError: if ``--config`` or ``--file`` is missing, the
                EvalConfig does not exist, or no sessions are left to export.
        """
        config_id = options['config']
        if not config_id:
            raise CommandError('EvalConfig ID is required')

        file_path = options['file']
        if not file_path:
            raise CommandError('File path is required')

        config = EvalConfig.objects.filter(id=config_id).first()
        if not config:
            raise CommandError('EvalConfig not found')

        sessions = self._get_sessions(config, options['exclude_sessions'])

        # question id -> question text for the config's dataset, in one query.
        questions = dict(
            QA.objects.filter(dataset=config.dataset).values_list('id', 'question')
        )

        models_data = self._collect_results(sessions)
        if not models_data:
            # Without this guard the CSV header/row logic would fail with an
            # IndexError on an empty mapping.
            raise CommandError('No active sessions found for this EvalConfig')

        self.stdout.write("Exporting the following data:")
        for model_name, data in models_data.items():
            self.stdout.write(f"{model_name}: {len(data)}")

        self._write_csv(file_path, questions, models_data)
        self.stdout.write(
            self.style.SUCCESS(f'Successfully exported benchmark results to {file_path}')
        )

    @staticmethod
    def _get_sessions(config, exclude_sessions):
        """Return the active sessions of *config*, minus the excluded IDs.

        *exclude_sessions* is a comma-separated string of session IDs (or
        ``None``); surrounding spaces and empty entries are ignored.
        """
        # select_related avoids one extra query per session for llm_model.
        sessions = EvalSession.objects.filter(
            config=config, is_active=True
        ).select_related('llm_model')
        if exclude_sessions:
            exclude_ids = [part.strip() for part in exclude_sessions.split(',') if part.strip()]
            if exclude_ids:
                sessions = sessions.exclude(id__in=exclude_ids)
        return sessions

    def _collect_results(self, sessions):
        """Map each session's column label to ``{question_id: 1 or 0}``.

        The label is the session's LLM model name; if two sessions share a
        model name, the later one is suffixed with its session id instead of
        silently overwriting the earlier column.
        """
        models_data = {}
        for session in sessions:
            label = session.llm_model.name
            if label in models_data:
                new_label = f"{label} (session {session.id})"
                self.stderr.write(self.style.WARNING(
                    f"Duplicate model name '{label}'; exporting session "
                    f"{session.id} as '{new_label}'"
                ))
                label = new_label
            # question_id reads the FK column directly (no per-answer query);
            # ordering keeps "last answer wins" behavior for duplicate answers.
            answers = (
                EvalAnswer.objects.filter(eval_session=session)
                .order_by('question__id')
                .values_list('question_id', 'is_correct')
            )
            models_data[label] = {qid: 1 if is_correct else 0 for qid, is_correct in answers}
        return models_data

    @staticmethod
    def _write_csv(file_path, questions, models_data):
        """Write the results table plus a trailing Accuracy row to *file_path*."""
        # Union of every session's answered questions, so a question missing
        # from the first session is still exported.
        question_ids = sorted({qid for data in models_data.values() for qid in data})

        # newline='' is required by the csv module to avoid blank lines on
        # Windows; utf-8 keeps non-ASCII question text writable everywhere.
        with open(file_path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(['Question ID', 'Question Text'] + list(models_data))
            for question_id in tqdm(question_ids, desc="Exporting data"):
                row = [question_id, questions.get(question_id, '')]
                row += [data.get(question_id, '') for data in models_data.values()]
                writer.writerow(row)

            accuracy = [
                # A session with no answers gets an empty cell rather than a
                # ZeroDivisionError.
                f"{sum(data.values()) / len(data) * 100:.2f}%" if data else ''
                for data in models_data.values()
            ]
            writer.writerow(['Accuracy', ''] + accuracy)
|