12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455 |
- import json
- import os
- from tqdm import tqdm
- from django.core.management.base import BaseCommand, CommandError, CommandParser
- from commons.models import QA, Dataset
class Command(BaseCommand):
    """Management command that loads MedQA records from a JSON-lines file.

    Each non-blank line of the input file is parsed as one JSON object and
    stored as a ``QA`` row attached to the named ``Dataset``.
    """

    help = 'Import MedQA data from JSON file'

    def add_arguments(self, parser: CommandParser) -> None:
        """Register the ``--file``, ``--target`` and ``--dataset`` options."""
        parser.add_argument('--file', type=str, help='Path to JSON file')
        parser.add_argument('--target', type=str, help='Dataset target (train, test, dev)', default='test')
        parser.add_argument('--dataset', type=str, help='Dataset name')

    def handle(self, *args, **options):
        """Validate the options, then import the file one line at a time.

        Raises:
            CommandError: if ``--dataset`` is missing, ``--file`` is missing
                or is not an existing file, or ``--target`` is not one of
                ``train``, ``test``, ``dev``.

        A record that cannot be parsed, lacks an expected key, or fails to
        save is reported on stdout and skipped; the remaining records are
        still imported.
        """
        dataset_name = options['dataset']
        if not dataset_name:
            raise CommandError('Dataset name is required')

        file_path = options['file']
        # ``--file`` has no default, so it is None when omitted; guard before
        # os.path.isfile, which raises TypeError on None.
        if not file_path or not os.path.isfile(file_path):
            raise CommandError('Invalid file path')

        target = options['target']
        if target not in ('train', 'test', 'dev'):
            raise CommandError('Invalid target. Must be one of train, test, dev')

        # Reuse the dataset with this name, creating it if it does not exist.
        dataset, _ = Dataset.objects.get_or_create(name=dataset_name)

        # Read everything up front so tqdm knows the total for its progress bar.
        with open(file_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        failed = 0
        progress = tqdm(lines, desc="Importing MedQA data")
        for line_no, line in enumerate(progress, start=1):
            # Blank lines (e.g. a trailing newline) are not records.
            if not line.strip():
                continue

            try:
                data = json.loads(line)
                qa = QA(
                    dataset=dataset,
                    question=data['question'],
                    options=data['options'],
                    correct_answer=data['answer'],
                    correct_answer_idx=data['answer_idx'],
                    target=target,
                )
            except (json.JSONDecodeError, KeyError, TypeError) as e:
                # Malformed JSON, a missing key, or a non-object record.
                failed += 1
                self.stdout.write(self.style.ERROR(
                    f'Error importing MedQA data: line {line_no}: {e!r}'
                ))
                continue

            try:
                qa.save()
            except Exception as e:
                # Best-effort import: report the failure and keep going.
                failed += 1
                self.stdout.write(self.style.ERROR(f'Error importing MedQA data: {e}'))
                continue

        if failed:
            self.stdout.write(self.style.WARNING(
                f'Skipped {failed} record(s) that could not be imported'
            ))
        self.stdout.write(self.style.SUCCESS('Successfully imported MedQA data'))
|