prepare_w2_dataset.py

#!/usr/bin/env python3
"""
Prepare the W-2 dataset: unwrap the top-level 'gt_parse' attribute in the
ground_truth column (keeping all the keys under it) and create custom
train-test splits.
"""
import argparse
import json
import logging

from datasets import load_dataset

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Prepare W2 dataset with custom train-test splits"
    )
    parser.add_argument(
        "--train-ratio",
        type=float,
        default=0.8,
        help="Ratio of data to use for training (default: 0.8, i.e., 80%% train, 20%% test)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help=(
            "Custom output directory name. If not provided, defaults to "
            "'fake_w2_us_tax_form_dataset_train{train_pct}_test{test_pct}'"
        ),
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for dataset splitting (default: 42)",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="Parse this W-2 form and extract all fields into a single level json.",
        help="Custom prompt to use for the input field (default: Parse this W-2 form...)",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="singhsays/fake-w2-us-tax-form-dataset",
        help="Dataset name from HuggingFace Hub (default: singhsays/fake-w2-us-tax-form-dataset)",
    )
    parser.add_argument(
        "--skip-validation",
        action="store_true",
        help="Skip validation split loading (useful if the dataset doesn't have a validation split)",
    )
    return parser.parse_args()


# Define a function to modify the ground_truth column
def remove_gt_parse_wrapper(example):
    """Replace the ground_truth JSON with the contents of its 'gt_parse' key."""
    try:
        # Parse the ground_truth JSON
        ground_truth = json.loads(example["ground_truth"])
        # Check if gt_parse exists in the ground_truth
        if "gt_parse" in ground_truth:
            # Replace the ground_truth with just the contents of gt_parse
            example["ground_truth"] = json.dumps(ground_truth["gt_parse"])
        else:
            logger.warning("No 'gt_parse' key found in ground_truth, keeping original")
        return example
    except json.JSONDecodeError as e:
        logger.error(f"Failed to parse ground_truth JSON: {e}")
        logger.error(f"Problematic data: {example.get('ground_truth', 'N/A')}")
        # Return the example unchanged if we can't parse it
        return example
    except Exception as e:
        logger.error(f"Unexpected error in remove_gt_parse_wrapper: {e}")
        return example
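
# Illustrative example of the transformation above (the field names are
# hypothetical, not taken from the actual dataset schema): a ground_truth of
#   '{"gt_parse": {"employee_ssn": "123-45-6789", "wages": "55000.00"}}'
# becomes
#   '{"employee_ssn": "123-45-6789", "wages": "55000.00"}'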


def validate_dataset(dataset):
    """Validate the loaded dataset has required columns."""
    required_columns = ["ground_truth", "image"]
    missing_columns = [
        col for col in required_columns if col not in dataset.column_names
    ]
    if missing_columns:
        raise ValueError(f"Dataset missing required columns: {missing_columns}")
    logger.info(f"Dataset validation passed. Columns: {dataset.column_names}")


def validate_train_ratio(train_ratio):
    """Validate that train ratio is between 0 and 1 (exclusive)."""
    if train_ratio <= 0 or train_ratio >= 1:
        raise ValueError("Train ratio must be between 0 and 1 (exclusive)")
    return True


def create_output_directory_name(train_ratio, test_ratio, output_dir=None):
    """Create output directory name based on the split ratio if not provided."""
    if output_dir is None:
        # Round to 2 decimal places before converting to int to avoid floating point precision issues
        train_pct = int(round(train_ratio * 100, 2))
        test_pct = int(round(test_ratio * 100, 2))
        return f"fake_w2_us_tax_form_dataset_train{train_pct}_test{test_pct}"
    return output_dir
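
# For example, create_output_directory_name(0.8, 0.2) returns
# "fake_w2_us_tax_form_dataset_train80_test20". The round() call guards
# against ratios whose product with 100 lands just below an integer
# (e.g. 0.29 * 100 == 28.999999999999996, which a bare int() would
# truncate to 28).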


def load_dataset_safely(dataset_name, split="train+test"):
    """Load dataset with proper error handling."""
    try:
        return load_dataset(dataset_name, split=split)
    except Exception as e:
        logger.error(f"Failed to load dataset '{dataset_name}': {e}")
        raise


def create_splits(all_data, train_ratio, seed):
    """Create train-test splits from the dataset."""
    logger.info(f"Creating new splits with train ratio: {train_ratio}")
    return all_data.train_test_split(train_size=train_ratio, seed=seed)
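
# Note: train_test_split is a built-in datasets.Dataset method; it shuffles
# with the given seed and returns a DatasetDict with "train" and "test" keys.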


def load_validation_split(dataset_name, split_ds, skip_validation=False):
    """Load validation split if not skipped."""
    if skip_validation:
        logger.info("Skipping validation split as requested")
        return split_ds
    try:
        split_ds["validation"] = load_dataset(dataset_name, split="validation")
        logger.info(
            f"Loaded validation split with {len(split_ds['validation'])} examples"
        )
    except Exception as e:
        logger.warning(
            f"Could not load validation split: {e}. Continuing without validation split."
        )
    return split_ds


def apply_transformations(split_ds, prompt):
    """Apply data transformations to the dataset."""
    logger.info("Modifying dataset...")
    modified_ds = split_ds.map(remove_gt_parse_wrapper)
    logger.info(f"Adding custom prompt: {prompt}")
    modified_ds = modified_ds.map(lambda _: {"input": prompt})
    return modified_ds
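
# Note: DatasetDict.map applies the function to every split, and returning a
# dict merges its keys into each example, so the second map above adds a
# constant "input" column holding the prompt.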


def log_dataset_statistics(all_data, modified_ds):
    """Log comprehensive dataset statistics."""
    logger.info("\n=== Dataset Statistics ===")
    logger.info(f"Total examples: {len(all_data)}")
    logger.info(
        f"Train split: {len(modified_ds['train'])} examples ({len(modified_ds['train'])/len(all_data)*100:.1f}%)"
    )
    logger.info(
        f"Test split: {len(modified_ds['test'])} examples ({len(modified_ds['test'])/len(all_data)*100:.1f}%)"
    )
    if "validation" in modified_ds:
        logger.info(f"Validation split: {len(modified_ds['validation'])} examples")


def save_dataset(modified_ds, output_dir):
    """Save the modified dataset to disk."""
    logger.info(f"Saving modified dataset to '{output_dir}'...")
    modified_ds.save_to_disk(output_dir)
    logger.info(f"Done! Modified dataset saved to '{output_dir}'")
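
# The saved DatasetDict can be reloaded later with the standard counterpart
# to save_to_disk (directory name shown is the default for an 80/20 split):
#   from datasets import load_from_disk
#   ds = load_from_disk("fake_w2_us_tax_form_dataset_train80_test20")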


def main():
    try:
        args = parse_args()
        # Validate train ratio
        validate_train_ratio(args.train_ratio)
        train_ratio = args.train_ratio
        test_ratio = 1 - train_ratio
        # Create output directory name
        output_dir = create_output_directory_name(
            train_ratio, test_ratio, args.output_dir
        )
        logger.info(f"Using train-test split: {train_ratio:.2f}-{test_ratio:.2f}")
        logger.info(f"Output directory will be: {output_dir}")
        logger.info(f"Dataset: {args.dataset_name}")
        # Load the dataset with error handling
        logger.info("Loading dataset...")
        all_data = load_dataset_safely(args.dataset_name, "train+test")
        validate_dataset(all_data)
        logger.info(f"Loaded {len(all_data)} examples from dataset")
        # Create splits
        split_ds = create_splits(all_data, train_ratio, args.seed)
        # Load validation split
        split_ds = load_validation_split(
            args.dataset_name, split_ds, args.skip_validation
        )
        # Apply transformations
        modified_ds = apply_transformations(split_ds, args.prompt)
        # Log statistics
        log_dataset_statistics(all_data, modified_ds)
        # Save the modified dataset
        save_dataset(modified_ds, output_dir)
    except Exception as e:
        logger.error(f"Script failed with error: {e}")
        raise


if __name__ == "__main__":
    main()