import_mmlu.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. import json
  2. import os
  3. from tqdm import tqdm
  4. from datasets import load_dataset
  5. from django.core.management.base import BaseCommand, CommandError, CommandParser
  6. from commons.models import QA, Dataset
  7. class Command(BaseCommand):
  8. help = 'Import MMLU dataset from Huggingface'
  9. def add_arguments(self, parser: CommandParser) -> None:
  10. parser.add_argument('--subject', type=str, help='Subject or category of the dataset. Default is "all"', default='all')
  11. parser.add_argument('--target', type=str, help='Dataset target (train, test, dev)', default='test')
  12. parser.add_argument('--dataset', type=str, help='Dataset name')
  13. def handle(self, *args, **options):
  14. dataset_name = options['dataset']
  15. if not dataset_name:
  16. raise CommandError('Dataset name is required')
  17. subject = options['subject']
  18. if not subject:
  19. raise CommandError('Subject is required')
  20. target = options['target']
  21. if target not in ['train', 'test', 'dev', 'validation']:
  22. raise CommandError('Invalid target. Must be one of train, test, dev')
  23. dataset, _ = Dataset.objects.get_or_create(name=dataset_name)
  24. ds = load_dataset("cais/mmlu", subject)[target]
  25. for entry in tqdm(ds, desc="Importing MMLU data"):
  26. question = entry['question']
  27. category = entry['subject']
  28. choices = dict()
  29. i = 0
  30. for choice in entry['choices']:
  31. choices[chr(65+i)] = choice
  32. i += 1
  33. correct_answer_idx = chr(65 + entry['answer'])
  34. correct_answer = choices[correct_answer_idx]
  35. qa = QA(
  36. dataset=dataset,
  37. question=question,
  38. category=category,
  39. options=choices,
  40. correct_answer=correct_answer,
  41. correct_answer_idx=correct_answer_idx,
  42. target=target
  43. )
  44. try:
  45. qa.save()
  46. except Exception as e:
  47. self.stdout.write(self.style.ERROR(f'Error importing MMLU question "{question}": {e}'))
  48. continue
  49. self.stdout.write(self.style.SUCCESS('Successfully imported PubMedQA data'))