import json
import os

from tqdm import tqdm
from datasets import load_dataset
from django.core.management.base import BaseCommand, CommandError, CommandParser

from commons.models import QA, Dataset


class Command(BaseCommand):
    """Management command that imports MMLU questions from Huggingface as QA rows.

    Each entry of the selected split of ``cais/mmlu`` is stored as one ``QA``
    object attached to a ``Dataset`` row that is looked up (or created) by the
    name given with ``--dataset``.
    """

    help = 'Import MMLU dataset from Huggingface'

    # Split names accepted by ``--target``; kept in sync with the help text
    # and the validation error message below.
    VALID_TARGETS = ('train', 'test', 'dev', 'validation')

    def add_arguments(self, parser: CommandParser) -> None:
        """Register the ``--subject``, ``--target`` and ``--dataset`` options."""
        parser.add_argument(
            '--subject',
            type=str,
            help='Subject or category of the dataset. Default is "all"',
            default='all',
        )
        parser.add_argument(
            '--target',
            type=str,
            help=f'Dataset target ({", ".join(self.VALID_TARGETS)})',
            default='test',
        )
        parser.add_argument('--dataset', type=str, help='Dataset name')

    def handle(self, *args, **options):
        """Validate the options, download the split and save one QA per entry.

        Rows that fail to save are reported and skipped so that a single bad
        row does not abort the whole import.

        Raises:
            CommandError: if ``--dataset`` or ``--subject`` is empty, if
                ``--target`` is not one of ``VALID_TARGETS``, or if the
                downloaded subject does not provide the requested split.
        """
        dataset_name = options['dataset']
        if not dataset_name:
            raise CommandError('Dataset name is required')

        subject = options['subject']
        if not subject:
            raise CommandError('Subject is required')

        target = options['target']
        if target not in self.VALID_TARGETS:
            raise CommandError(
                f'Invalid target. Must be one of {", ".join(self.VALID_TARGETS)}'
            )

        # Load the data before touching the database so that a failed
        # download or a missing split does not leave an empty Dataset row.
        splits = load_dataset("cais/mmlu", subject)
        if target not in splits:
            available = ', '.join(splits.keys())
            raise CommandError(
                f'Target "{target}" is not available for subject "{subject}". '
                f'Available targets: {available}'
            )
        ds = splits[target]

        dataset, _ = Dataset.objects.get_or_create(name=dataset_name)

        imported = 0
        failed = 0
        for entry in tqdm(ds, desc="Importing MMLU data"):
            question = entry['question']

            # Label the choices "A", "B", "C", ... in their original order.
            choices = {
                chr(65 + position): choice
                for position, choice in enumerate(entry['choices'])
            }
            # entry['answer'] is treated as a zero-based index into the
            # choices, so it maps onto the same letter scheme.
            correct_answer_idx = chr(65 + entry['answer'])

            qa = QA(
                dataset=dataset,
                question=question,
                category=entry['subject'],
                options=choices,
                correct_answer=choices[correct_answer_idx],
                correct_answer_idx=correct_answer_idx,
                target=target,
            )
            try:
                qa.save()
            except Exception as e:
                # Deliberately best-effort: report the failing row and go on.
                failed += 1
                self.stdout.write(
                    self.style.ERROR(f'Error importing MMLU question "{question}": {e}')
                )
                continue
            imported += 1

        if failed:
            self.stdout.write(
                self.style.WARNING(f'{failed} MMLU question(s) could not be imported')
            )
        self.stdout.write(
            self.style.SUCCESS(f'Successfully imported MMLU data ({imported} question(s))')
        )