import_pubmedqa.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import json
  2. import os
  3. from tqdm import tqdm
  4. from django.core.management.base import BaseCommand, CommandError, CommandParser
  5. from commons.models import QA, Dataset
  class Command(BaseCommand):
      """Import PubMedQA question/answer records from a JSON file into the QA table.

      The input JSON is read as a mapping of id (``xid``) to an entry dict; each
      entry's ``QUESTION``, ``LONG_ANSWER``, ``final_decision``, ``CONTEXTS``,
      ``reasoning_required_pred`` and ``reasoning_free_pred`` values are copied
      into one ``QA`` row attached to the ``Dataset`` named by ``--dataset``
      (created if it does not exist yet).

      When ``--file-test-ground-truth`` is given, only ids that appear as keys of
      that JSON object are imported (an empty object applies no filter).
      """

      help = 'Import PubMedQA data from JSON file'

      # Accepted values for --target; checked in handle().
      VALID_TARGETS = ('train', 'test', 'dev')

      def add_arguments(self, parser: CommandParser) -> None:
          """Register the command-line options for this command."""
          parser.add_argument('--file', type=str, help='Path to JSON file')
          parser.add_argument('--file-test-ground-truth', type=str, help='Path to test ground truth JSON file', default=None)
          parser.add_argument('--target', type=str, help='Dataset target (train, test, dev)', default='test')
          parser.add_argument('--dataset', type=str, help='Dataset name')

      @staticmethod
      def _load_json(path):
          """Read and parse the UTF-8 JSON document at *path*.

          Raises:
              CommandError: if the file does not contain valid JSON.
          """
          # Explicit encoding: the platform default is locale-dependent and can
          # mis-decode non-ASCII text in the abstracts.
          with open(path, 'r', encoding='utf-8') as f:
              try:
                  return json.load(f)
              except json.JSONDecodeError as err:
                  raise CommandError(f'Invalid JSON in {path}: {err}') from err

      def handle(self, *args, **options):
          """Validate the options, then create one QA row per selected entry.

          Raises:
              CommandError: if ``--dataset`` is missing, ``--file`` (or
                  ``--file-test-ground-truth`` when given) is not an existing
                  file or is not valid JSON, or ``--target`` is not one of
                  train/test/dev.
          """
          dataset_name = options['dataset']
          if not dataset_name:
              raise CommandError('Dataset name is required')

          file_path = options['file']
          # os.path.isfile(None) raises TypeError, so a missing --file has to be
          # rejected before the existence check.
          if not file_path or not os.path.isfile(file_path):
              raise CommandError('Invalid file path')

          target = options['target']
          # Validate all options before doing any file I/O or DB writes.
          if target not in self.VALID_TARGETS:
              raise CommandError('Invalid target. Must be one of train, test, dev')

          ground_truth_path = options.get('file_test_ground_truth')
          # A set gives O(1) membership checks per entry; empty means "no filter".
          test_xids = set()
          if ground_truth_path is not None:
              if not os.path.isfile(ground_truth_path):
                  raise CommandError('Invalid test ground truth file path')
              test_xids = set(self._load_json(ground_truth_path))

          # Parse the input before touching the DB so malformed JSON does not
          # leave a freshly created, empty Dataset behind.
          data = self._load_json(file_path)
          dataset, _ = Dataset.objects.get_or_create(name=dataset_name)

          failed = 0
          for xid, entry in tqdm(data.items(), desc="Importing PubMedQA data"):
              if test_xids and xid not in test_xids:
                  continue
              qa = QA(
                  dataset=dataset,
                  xid=xid,
                  question=entry['QUESTION'],
                  context="\n".join(entry['CONTEXTS']),
                  correct_answer=entry['LONG_ANSWER'],
                  correct_answer_idx=entry['final_decision'],
                  target=target,
                  extra_info={
                      'reasoning_required_pred': entry['reasoning_required_pred'],
                      'reasoning_free_pred': entry['reasoning_free_pred'],
                  },
              )
              try:
                  qa.save()
              except Exception as e:
                  # Best-effort import: report the failing record and keep going.
                  failed += 1
                  self.stdout.write(self.style.ERROR(f'Error importing PubMedQA data for {xid}: {e}'))

          # Don't claim full success when some records could not be saved.
          if failed:
              self.stdout.write(self.style.WARNING(f'Imported PubMedQA data with {failed} failed record(s)'))
          else:
              self.stdout.write(self.style.SUCCESS('Successfully imported PubMedQA data'))