1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859 |
- import json
- import os
- from tqdm import tqdm
- from datasets import load_dataset
- from django.core.management.base import BaseCommand, CommandError, CommandParser
- from commons.models import QA, Dataset
class Command(BaseCommand):
    """Import the MMLU dataset (``cais/mmlu`` on the Hugging Face Hub) as ``QA`` rows.

    Every question of the chosen subject/split is saved as a ``QA`` row linked
    to a ``Dataset`` record whose name is given by ``--dataset``.
    """

    help = 'Import MMLU dataset from Huggingface'

    # Values accepted by --target. Not every split exists for every MMLU
    # config; a split missing from the downloaded data is reported as a
    # CommandError in handle().
    VALID_TARGETS = ('train', 'test', 'dev', 'validation')

    def add_arguments(self, parser: CommandParser) -> None:
        """Register the --subject, --target and --dataset options."""
        parser.add_argument('--subject', type=str, help='Subject or category of the dataset. Default is "all"', default='all')
        parser.add_argument('--target', type=str, help='Dataset target (train, test, dev, validation)', default='test')
        parser.add_argument('--dataset', type=str, help='Dataset name')

    @staticmethod
    def _build_qa(entry, dataset, target):
        """Map one MMLU row onto an unsaved ``QA`` instance.

        Choices are keyed by letter ('A', 'B', ...) in their original order,
        and the integer ``answer`` field is translated to the matching letter.

        Raises:
            ValueError: if ``answer`` does not index one of the choices.
        """
        choices = {chr(ord('A') + i): choice for i, choice in enumerate(entry['choices'])}
        answer = entry['answer']
        if not 0 <= answer < len(choices):
            raise ValueError(f'answer index {answer} out of range for {len(choices)} choices')
        correct_answer_idx = chr(ord('A') + answer)
        return QA(
            dataset=dataset,
            question=entry['question'],
            category=entry['subject'],
            options=choices,
            correct_answer=choices[correct_answer_idx],
            correct_answer_idx=correct_answer_idx,
            target=target,
        )

    def handle(self, *args, **options):
        """Validate options, download the requested split and save each question.

        Rows that fail to map or save are reported on stdout and skipped; a
        summary with imported/failed counts is printed at the end.

        Raises:
            CommandError: if --dataset or --subject is empty, --target is not
                one of VALID_TARGETS, or the split is absent from the
                downloaded dataset.
        """
        dataset_name = options['dataset']
        if not dataset_name:
            raise CommandError('Dataset name is required')
        subject = options['subject']
        if not subject:
            raise CommandError('Subject is required')
        target = options['target']
        if target not in self.VALID_TARGETS:
            raise CommandError(f'Invalid target. Must be one of {", ".join(self.VALID_TARGETS)}')

        # Download first so a bad subject/split does not leave an empty
        # Dataset row behind in the database.
        splits = load_dataset("cais/mmlu", subject)
        if target not in splits:
            raise CommandError(
                f'Split "{target}" is not available for subject "{subject}". '
                f'Available splits: {", ".join(splits)}'
            )
        ds = splits[target]

        dataset, _ = Dataset.objects.get_or_create(name=dataset_name)
        imported = failed = 0
        for entry in tqdm(ds, desc="Importing MMLU data"):
            question = entry['question']
            try:
                self._build_qa(entry, dataset, target).save()
            except Exception as e:
                # Deliberate best effort: report the bad row and keep
                # importing the remaining questions.
                failed += 1
                self.stdout.write(self.style.ERROR(f'Error importing MMLU question "{question}": {e}'))
            else:
                imported += 1
        self.stdout.write(self.style.SUCCESS(f'Successfully imported MMLU data ({imported} imported, {failed} failed)'))
|