# clean_pubmedqa_dataset.py (1.7 KB)

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. import json
  2. import os
  3. from tqdm import tqdm
  4. from django.core.management.base import BaseCommand, CommandError, CommandParser
  5. from commons.models import QA, Dataset
  class Command(BaseCommand):
      """Prune the stored PubMedQA questions down to the test split.

      The command performs two steps against the ``Dataset`` named
      ``'pubmedqa'``:

      1. Delete every ``QA`` row of that dataset whose ``xid`` is not a key of
         the ground-truth JSON file.
      2. For each remaining test ``xid`` that has a ``QA`` row, overwrite
         ``extra_info`` with the ``reasoning_required_pred`` and
         ``reasoning_free_pred`` values from the original PubMedQA JSON file.

      Both input files are loaded and cross-checked *before* anything is
      deleted, so a missing or malformed file cannot leave the database
      half-cleaned.
      """

      help = 'Clean PubMedQA dataset'

      # Input files, resolved relative to the current working directory.
      GROUND_TRUTH_PATH = 'datasets/pubmedqa/test_ground_truth.json'
      ORIGINAL_DATA_PATH = 'datasets/pubmedqa/ori_pqal.json'

      # Fields copied from each original entry into ``QA.extra_info``.
      EXTRA_INFO_KEYS = ('reasoning_required_pred', 'reasoning_free_pred')

      def handle(self, *args, **options):
          """Run the clean-up.

          Raises:
              CommandError: if the dataset row is missing, an input file is
                  missing / not valid JSON / not a JSON object, the
                  ground-truth file lists no ids, a test id is absent from the
                  original data file, or an entry lacks a prediction field.
          """
          dataset = Dataset.objects.filter(name='pubmedqa').first()
          if not dataset:
              raise CommandError('PubMedQA dataset not found')

          # Read and validate every input up front: the delete below is
          # destructive, so nothing may fail for input reasons after it starts.
          ground_truth = self._load_json_object(self.GROUND_TRUTH_PATH)
          original_data = self._load_json_object(self.ORIGINAL_DATA_PATH)

          # Only the keys (xids) of the ground-truth file are needed.
          test_xids = list(ground_truth)
          if not test_xids:
              # ``exclude(xid__in=[])`` excludes nothing, which would wipe the
              # entire dataset; refuse to continue instead.
              raise CommandError(
                  'No test ids found in {}'.format(self.GROUND_TRUTH_PATH)
              )

          missing = [xid for xid in test_xids if xid not in original_data]
          if missing:
              raise CommandError(
                  '{} test id(s) missing from {} (first: {})'.format(
                      len(missing), self.ORIGINAL_DATA_PATH, missing[0]
                  )
              )

          # Get all questions from the dataset whose xid is not in test_xids.
          # Deleted one instance at a time so that any model-level ``delete()``
          # logic on QA keeps running, exactly as before.
          stale_questions = QA.objects.filter(dataset=dataset).exclude(
              xid__in=test_xids
          )
          for question in stale_questions:
              question.delete()

          for xid in tqdm(test_xids, desc="Cleaning PubMedQA data"):
              qa = QA.objects.filter(dataset=dataset, xid=xid).first()
              if not qa:
                  # Test id without a stored question: nothing to annotate.
                  continue
              entry = original_data[xid]
              try:
                  extra_info = {key: entry[key] for key in self.EXTRA_INFO_KEYS}
              except (KeyError, TypeError) as err:
                  raise CommandError(
                      'Entry {} in {} lacks a prediction field: {}'.format(
                          xid, self.ORIGINAL_DATA_PATH, err
                      )
                  ) from err
              qa.extra_info = extra_info
              qa.save()

          self.stdout.write(
              self.style.SUCCESS('Successfully cleaned PubMedQA dataset')
          )

      @staticmethod
      def _load_json_object(file_path):
          """Return the JSON object stored at ``file_path`` as a dict.

          Raises:
              CommandError: if the path is not a file, the content is not
                  valid JSON, or the top-level JSON value is not an object.
          """
          if not os.path.isfile(file_path):
              raise CommandError('Invalid file path: {}'.format(file_path))
          # JSON is UTF-8 by specification; do not depend on the locale default.
          with open(file_path, 'r', encoding='utf-8') as f:
              try:
                  data = json.load(f)
              except ValueError as err:  # json.JSONDecodeError is a ValueError
                  raise CommandError(
                      'Invalid JSON in {}: {}'.format(file_path, err)
                  ) from err
          if not isinstance(data, dict):
              raise CommandError(
                  'Expected a JSON object in {}'.format(file_path)
              )
          return data