import json
import os

from tqdm import tqdm
from django.core.management.base import BaseCommand, CommandError, CommandParser

from commons.models import QA, Dataset


class Command(BaseCommand):
    """Import PubMedQA records from a JSON file into ``QA`` rows of a ``Dataset``.

    The input file is read as a JSON object mapping ``xid -> entry``. Each entry
    must provide ``QUESTION``, ``LONG_ANSWER``, ``final_decision``, ``CONTEXTS``
    (joined with newlines), ``reasoning_required_pred`` and
    ``reasoning_free_pred``. When ``--file-test-ground-truth`` is given, only
    entries whose xid is a top-level key of that JSON object are imported.
    """

    help = 'Import PubMedQA data from JSON file'

    # Allowed values for --target; checked explicitly in handle().
    VALID_TARGETS = ('train', 'test', 'dev')

    def add_arguments(self, parser: CommandParser) -> None:
        parser.add_argument('--file', type=str, help='Path to JSON file')
        parser.add_argument('--file-test-ground-truth', type=str, help='Path to test ground truth JSON file', default=None)
        parser.add_argument('--target', type=str, help='Dataset target (train, test, dev)', default='test')
        parser.add_argument('--dataset', type=str, help='Dataset name')

    def handle(self, *args, **options):
        """Validate options, then create one ``QA`` row per (filtered) JSON entry.

        Raises:
            CommandError: if the dataset name is missing, a file path is missing
                or invalid, the target is not one of ``VALID_TARGETS``, or an
                entry lacks a required field.
        """
        dataset_name = options['dataset']
        if not dataset_name:
            raise CommandError('Dataset name is required')

        file_path = options['file']
        # Guard None first: os.path.isfile(None) raises TypeError rather than returning False.
        if not file_path or not os.path.isfile(file_path):
            raise CommandError('Invalid file path')

        # Validate the target before doing any file or DB work.
        target = options['target']
        if target not in self.VALID_TARGETS:
            raise CommandError('Invalid target. Must be one of train, test, dev')

        # A set gives O(1) membership checks in the import loop below.
        test_xids = self._load_test_xids(options.get('file_test_ground_truth'))

        dataset, _ = Dataset.objects.get_or_create(name=dataset_name)

        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        imported = 0
        failed = 0
        for xid, entry in tqdm(data.items(), desc="Importing PubMedQA data"):
            # An empty filter (no ground-truth file) means "import everything".
            if test_xids and xid not in test_xids:
                continue
            qa = QA(dataset=dataset, xid=xid, target=target, **self._entry_fields(xid, entry))
            try:
                qa.save()
            except Exception as e:
                # Best-effort import: report the failing row and keep going.
                self.stdout.write(self.style.ERROR(f'Error importing PubMedQA data: {e}'))
                failed += 1
                continue
            imported += 1

        if failed:
            self.stdout.write(self.style.WARNING(f'Imported {imported} PubMedQA entries; {failed} failed'))
        else:
            self.stdout.write(self.style.SUCCESS('Successfully imported PubMedQA data'))

    @staticmethod
    def _load_test_xids(path):
        """Return the set of top-level keys of the JSON object at *path* (empty set if *path* is None).

        Raises:
            CommandError: if *path* is given but is not an existing file.
        """
        if path is None:
            return set()
        if not os.path.isfile(path):
            raise CommandError('Invalid test ground truth file path')
        with open(path, 'r', encoding='utf-8') as f:
            return set(json.load(f).keys())

    @staticmethod
    def _entry_fields(xid, entry):
        """Map one PubMedQA JSON entry to keyword arguments for ``QA``.

        Raises:
            CommandError: if *entry* lacks one of the required keys.
        """
        try:
            return {
                'question': entry['QUESTION'],
                'context': "\n".join(entry['CONTEXTS']),
                'correct_answer': entry['LONG_ANSWER'],
                # NOTE(review): final_decision is stored as the answer "index";
                # confirm QA.correct_answer_idx accepts PubMedQA's label values.
                'correct_answer_idx': entry['final_decision'],
                'extra_info': {
                    'reasoning_required_pred': entry['reasoning_required_pred'],
                    'reasoning_free_pred': entry['reasoning_free_pred'],
                },
            }
        except KeyError as err:
            raise CommandError(f'Entry {xid} is missing required field {err}') from err