12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970 |
- import json
- import os
- from tqdm import tqdm
- from django.core.management.base import BaseCommand, CommandError, CommandParser
- from commons.models import QA, Dataset
class Command(BaseCommand):
    """Import PubMedQA entries from a JSON file into ``QA`` rows.

    The input file must be a JSON object mapping each entry id (``xid``) to
    an entry object. From each entry the fields ``QUESTION``, ``LONG_ANSWER``,
    ``final_decision``, ``CONTEXTS``, ``reasoning_required_pred`` and
    ``reasoning_free_pred`` are read. When ``--file-test-ground-truth`` is
    given, only ids that appear as keys in that JSON object are imported.
    """

    help = 'Import PubMedQA data from JSON file'

    # Allowed values for --target; kept in sync with the error message below.
    _VALID_TARGETS = ('train', 'test', 'dev')

    def add_arguments(self, parser: CommandParser) -> None:
        parser.add_argument('--file', type=str, help='Path to JSON file')
        parser.add_argument(
            '--file-test-ground-truth',
            type=str,
            help='Path to test ground truth JSON file',
            default=None,
        )
        parser.add_argument(
            '--target',
            type=str,
            help='Dataset target (train, test, dev)',
            default='test',
        )
        parser.add_argument('--dataset', type=str, help='Dataset name')

    def handle(self, *args, **options):
        """Validate options, then create one ``QA`` row per selected entry.

        Entries that are missing a required field or fail to save are
        reported and skipped; the import continues with the next entry.

        Raises:
            CommandError: if the dataset name is missing, ``--file`` or
                ``--file-test-ground-truth`` is not an existing file, either
                file is not valid JSON, or ``--target`` is not one of
                train/test/dev.
        """
        # Validate every option up front so a bad argument fails before any
        # file is parsed or any database row is created.
        dataset_name = options['dataset']
        if not dataset_name:
            raise CommandError('Dataset name is required')

        file_path = options['file']
        # os.path.isfile(None) raises TypeError, so a missing --file must be
        # checked explicitly to surface it as a CommandError.
        if not file_path or not os.path.isfile(file_path):
            raise CommandError('Invalid file path')

        ground_truth_path = options.get('file_test_ground_truth')
        if ground_truth_path is not None and not os.path.isfile(ground_truth_path):
            raise CommandError('Invalid test ground truth file path')

        target = options['target']
        if target not in self._VALID_TARGETS:
            raise CommandError('Invalid target. Must be one of train, test, dev')

        # None means "no ground-truth filter requested"; otherwise a set of
        # ids gives O(1) membership checks inside the import loop. An empty
        # ground-truth file therefore imports nothing rather than everything.
        allowed_xids = None
        if ground_truth_path is not None:
            allowed_xids = set(self._load_json(ground_truth_path))

        data = self._load_json(file_path)
        dataset, _ = Dataset.objects.get_or_create(name=dataset_name)

        imported = 0
        failed = 0
        for xid, entry in tqdm(data.items(), desc="Importing PubMedQA data"):
            if allowed_xids is not None and xid not in allowed_xids:
                continue

            try:
                qa = self._build_qa(dataset, xid, entry, target)
            except KeyError as e:
                failed += 1
                self.stdout.write(self.style.ERROR(
                    f'Error importing PubMedQA entry {xid}: missing field {e}'
                ))
                continue

            try:
                qa.save()
            except Exception as e:
                # Deliberate best-effort import: report the failing entry and
                # keep going with the rest of the file.
                failed += 1
                self.stdout.write(self.style.ERROR(
                    f'Error importing PubMedQA entry {xid}: {e}'
                ))
            else:
                imported += 1

        if failed:
            self.stdout.write(self.style.WARNING(
                f'Imported {imported} PubMedQA entries; {failed} failed'
            ))
        else:
            self.stdout.write(self.style.SUCCESS(
                f'Successfully imported PubMedQA data ({imported} entries)'
            ))

    @staticmethod
    def _load_json(path):
        """Parse the UTF-8 JSON document at *path*.

        Raises:
            CommandError: if the file content is not valid JSON.
        """
        with open(path, 'r', encoding='utf-8') as f:
            try:
                return json.load(f)
            except json.JSONDecodeError as e:
                raise CommandError(f'Invalid JSON in {path}: {e}') from e

    @staticmethod
    def _build_qa(dataset, xid, entry, target):
        """Build an unsaved ``QA`` instance from one PubMedQA entry.

        ``CONTEXTS`` is joined with newlines into a single context string.

        Raises:
            KeyError: if *entry* lacks one of the fields read below.
        """
        return QA(
            dataset=dataset,
            xid=xid,
            question=entry['QUESTION'],
            context="\n".join(entry['CONTEXTS']),
            correct_answer=entry['LONG_ANSWER'],
            correct_answer_idx=entry['final_decision'],
            target=target,
            extra_info={
                'reasoning_required_pred': entry['reasoning_required_pred'],
                'reasoning_free_pred': entry['reasoning_free_pred'],
            },
        )
|