1234567891011121314151617181920212223242526272829303132333435363738394041424344454647 |
- import json
- import os
- from tqdm import tqdm
- from django.core.management.base import BaseCommand, CommandError, CommandParser
- from commons.models import QA, Dataset
class Command(BaseCommand):
    """Prune the ``pubmedqa`` dataset to its test split and attach predictions.

    Keeps only the ``QA`` rows of the ``pubmedqa`` ``Dataset`` whose ``xid`` is
    a key of ``test_ground_truth.json``. For each kept row, it copies
    ``reasoning_required_pred`` and ``reasoning_free_pred`` from the matching
    entry of ``ori_pqal.json`` into ``extra_info``.
    """

    help = 'Clean PubMedQA dataset'

    # NOTE(review): relative paths resolve against the process's current
    # working directory (presumably the project root) -- confirm how this
    # command is invoked. Kept as class attributes so they can be overridden.
    ground_truth_path = 'datasets/pubmedqa/test_ground_truth.json'
    pqal_path = 'datasets/pubmedqa/ori_pqal.json'

    @staticmethod
    def _load_json(file_path):
        """Parse and return the UTF-8 JSON document at *file_path*.

        Raises:
            CommandError: if *file_path* is not an existing file.
        """
        if not os.path.isfile(file_path):
            raise CommandError(f'Invalid file path: {file_path}')
        # Explicit encoding: the default is locale-dependent, JSON is UTF-8.
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def handle(self, *args, **options):
        """Delete non-test PubMedQA rows, then store predictions on the rest.

        Raises:
            CommandError: if the dataset does not exist, an input file is
                missing, or a test xid has no entry in ``ori_pqal.json``.
                All of these checks run before anything is deleted.
        """
        dataset = Dataset.objects.filter(name='pubmedqa').first()
        if not dataset:
            raise CommandError('PubMedQA dataset not found')

        # Load and validate both inputs *before* touching the database, so a
        # missing or incomplete file cannot leave the dataset half-cleaned.
        test_xids = list(self._load_json(self.ground_truth_path))
        pqal = self._load_json(self.pqal_path)
        missing = [xid for xid in test_xids if xid not in pqal]
        if missing:
            raise CommandError(
                f'{len(missing)} test xid(s) not found in {self.pqal_path}, '
                f'e.g. {missing[:5]}'
            )

        # Delete row by row rather than with QuerySet.delete(), so that any
        # custom QA.delete() override still runs (bulk delete skips it).
        stale = QA.objects.filter(dataset=dataset).exclude(xid__in=test_xids)
        for question in stale:
            question.delete()

        for xid in tqdm(test_xids, desc="Cleaning PubMedQA data"):
            qa = QA.objects.filter(dataset=dataset, xid=xid).first()
            if not qa:
                # Test xid with no QA row in the DB: nothing to annotate.
                continue
            entry = pqal[xid]
            qa.extra_info = {
                'reasoning_required_pred': entry['reasoning_required_pred'],
                'reasoning_free_pred': entry['reasoning_free_pred'],
            }
            qa.save()
        self.stdout.write(self.style.SUCCESS('Successfully cleaned PubMedQA dataset'))
|