# import_medqa.py — Django management command that imports MedQA records
# from a JSON Lines file (one JSON object per line).
  1. import json
  2. import os
  3. from tqdm import tqdm
  4. from django.core.management.base import BaseCommand, CommandError, CommandParser
  5. from commons.models import QA, Dataset
  6. class Command(BaseCommand):
  7. help = 'Import MedQA data from JSON file'
  8. def add_arguments(self, parser: CommandParser) -> None:
  9. parser.add_argument('--file', type=str, help='Path to JSON file')
  10. parser.add_argument('--target', type=str, help='Dataset target (train, test, dev)', default='test')
  11. parser.add_argument('--dataset', type=str, help='Dataset name')
  12. def handle(self, *args, **options):
  13. dataset_name = options['dataset']
  14. if not dataset_name:
  15. raise CommandError('Dataset name is required')
  16. file_path = options['file']
  17. if not os.path.isfile(file_path):
  18. raise CommandError('Invalid file path')
  19. target = options['target']
  20. if target not in ['train', 'test', 'dev']:
  21. raise CommandError('Invalid target. Must be one of train, test, dev')
  22. # check if dataset exists
  23. dataset, _ = Dataset.objects.get_or_create(name=dataset_name)
  24. with open(file_path, 'r', encoding='utf-8') as f:
  25. lines = f.readlines()
  26. for line in tqdm(lines, desc="Importing MedQA data"):
  27. data = json.loads(line)
  28. question = data['question']
  29. options = data['options']
  30. correct_answer = data['answer']
  31. correct_answer_idx = data['answer_idx']
  32. qa = QA(
  33. dataset=dataset,
  34. question=question,
  35. options=options,
  36. correct_answer=correct_answer,
  37. correct_answer_idx=correct_answer_idx,
  38. target=target
  39. )
  40. try:
  41. qa.save()
  42. except Exception as e:
  43. self.stdout.write(self.style.ERROR(f'Error importing MedQA data: {e}'))
  44. continue
  45. self.stdout.write(self.style.SUCCESS('Successfully imported MedQA data'))