import json
import os

from django.core.management.base import BaseCommand, CommandError, CommandParser
from tqdm import tqdm

from commons.models import QA, Dataset


class Command(BaseCommand):
    """Management command that prunes and annotates the PubMedQA QA rows.

    The command keeps only the QA rows of the ``pubmedqa`` dataset whose
    ``xid`` appears as a key in ``test_ground_truth.json``, deletes the rest,
    and then stores two prediction fields from ``ori_pqal.json`` in each
    remaining row's ``extra_info``.
    """

    help = 'Clean PubMedQA dataset'

    @staticmethod
    def _load_json(file_path):
        """Return the parsed JSON content of *file_path*.

        Raises:
            CommandError: if *file_path* is not an existing regular file.
        """
        if not os.path.isfile(file_path):
            raise CommandError('Invalid file path')
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def handle(self, *args, **options):
        """Run the clean-up.

        Raises:
            CommandError: if the ``pubmedqa`` dataset row does not exist or
                either input file is missing.
        """
        dataset = Dataset.objects.filter(name='pubmedqa').first()
        if not dataset:
            raise CommandError('PubMedQA dataset not found')

        # Load BOTH input files before touching the database, so that a
        # missing or malformed file cannot abort the command after rows have
        # already been deleted.
        ground_truth = self._load_json('datasets/pubmedqa/test_ground_truth.json')
        original_entries = self._load_json('datasets/pubmedqa/ori_pqal.json')

        # The keys of the ground-truth mapping are the xids to keep.
        test_xids = []
        for xid in tqdm(ground_truth, desc="Cleaning PubMedQA data"):
            test_xids.append(xid)

        # Delete every question of the dataset whose xid is not in test_xids.
        # Rows are deleted one by one (rather than with QuerySet.delete()) so
        # that any custom QA.delete() logic keeps running.
        questions = QA.objects.filter(dataset=dataset).exclude(xid__in=test_xids)
        for question in questions:
            question.delete()

        for xid in test_xids:
            entry = original_entries.get(xid)
            if entry is None:
                # Previously this surfaced as a bare KeyError halfway through
                # the update; report it and keep processing the other rows.
                self.stderr.write(self.style.WARNING(
                    f'No entry for xid {xid} in ori_pqal.json; skipping'
                ))
                continue

            # Fetch the QA row with this xid; nothing to annotate if absent.
            qa = QA.objects.filter(dataset=dataset, xid=xid).first()
            if not qa:
                continue

            qa.extra_info = {
                'reasoning_required_pred': entry['reasoning_required_pred'],
                'reasoning_free_pred': entry['reasoning_free_pred'],
            }
            qa.save()

        self.stdout.write(self.style.SUCCESS('Successfully cleaned PubMedQA dataset'))