import json
import os
import csv

from tqdm import tqdm

from django.core.management.base import BaseCommand, CommandError, CommandParser

from commons.models import QA, Dataset, EvalSession, EvalConfig, EvalAnswer


class Command(BaseCommand):
    """Export the per-question results of an EvalConfig's active sessions to CSV.

    The CSV has one row per question and one column per session (labelled with
    the session's LLM model name). Cells hold 1 for a correct answer, 0 for an
    incorrect one and are left empty when the session has no answer for that
    question. A final "Accuracy" row holds each column's percentage of correct
    answers.
    """

    help = 'Export benchmark results to CSV file'

    def add_arguments(self, parser: CommandParser) -> None:
        """Register the ``--config``, ``--exclude-sessions`` and ``--file`` options."""
        parser.add_argument('--config', type=str, help='EvalConfig ID')
        parser.add_argument('--exclude-sessions', type=str, help='Session IDs to exclude')
        parser.add_argument('--file', type=str, help='Path to CSV file')

    def handle(self, *args, **options):
        """Collect the results and write them to the requested CSV file.

        Raises:
            CommandError: if ``--config`` or ``--file`` is missing, the
                EvalConfig does not exist, or no active session remains after
                applying ``--exclude-sessions``.
        """
        config_id = options['config']
        if not config_id:
            raise CommandError('EvalConfig ID is required')

        file_path = options['file']
        if not file_path:
            raise CommandError('File path is required')

        config = EvalConfig.objects.filter(id=config_id).first()
        if not config:
            raise CommandError('EvalConfig not found')

        sessions = EvalSession.objects.filter(config=config, is_active=True)
        excluded_ids = self._parse_id_list(options['exclude_sessions'])
        if excluded_ids:
            sessions = sessions.exclude(id__in=excluded_ids)

        # Question text for every question of the config's dataset, by ID.
        questions = {
            qa.id: qa.question
            for qa in QA.objects.filter(dataset=config.dataset)
        }

        # Column label -> {question ID -> 1 (correct) / 0 (incorrect)}.
        models_data = {}
        for session in sessions:
            label = session.llm_model.name
            if label in models_data:
                # Two sessions ran the same model: keep both columns instead of
                # silently overwriting the earlier session's results.
                label = f"{label} (session {session.id})"
            answers = EvalAnswer.objects.filter(eval_session=session)
            # `question_id` reads the stored foreign key value and avoids
            # loading the related question row for every answer.
            models_data[label] = {
                answer.question_id: 1 if answer.is_correct else 0
                for answer in answers
            }

        if not models_data:
            # Fail before touching the output file so it is not truncated.
            raise CommandError('No active sessions found for this EvalConfig')

        self.stdout.write("Exporting the following data:")
        for model_name, data in models_data.items():
            self.stdout.write(f"{model_name}: {len(data)}")

        # Use every question answered by at least one session, so that rows
        # are not limited to whatever the first session happened to answer.
        question_ids = sorted(set().union(*models_data.values()))

        # newline='' is required by the csv module to avoid blank rows on
        # Windows; UTF-8 keeps non-ASCII question text portable.
        with open(file_path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(['Question ID', 'Question Text', *models_data])
            for question_id in tqdm(question_ids, desc="Exporting data"):
                writer.writerow([
                    question_id,
                    questions.get(question_id, ''),
                    *(data.get(question_id, '') for data in models_data.values()),
                ])
            writer.writerow([
                'Accuracy',
                '',
                *(self._format_accuracy(data) for data in models_data.values()),
            ])

        self.stdout.write(
            self.style.SUCCESS(f'Successfully exported benchmark results to {file_path}')
        )

    @staticmethod
    def _parse_id_list(raw):
        """Split a comma-separated ID string into a list of stripped, non-empty IDs."""
        if not raw:
            return []
        return [part.strip() for part in raw.split(',') if part.strip()]

    @staticmethod
    def _format_accuracy(data):
        """Return the share of correct answers as 'NN.NN%', or 'N/A' without answers."""
        if not data:
            return 'N/A'
        return f"{sum(data.values()) / len(data) * 100:.2f}%"