# config.yaml
---
  1. task1:
  2. dataset: singhsays/fake-w2-us-tax-form-dataset
  3. is_local: false
  4. system_prompt: null
  5. user_prompt: "You are an expert document information extraction system. I will show you an image of a W-2 tax form. Please extract all the information from this form and return it in a JSON format. Include all fields such as employee details, employer details, wages, federal income tax withheld, social security wages, social security tax withheld, medicare wages and tips, medicare tax withheld, and any other information present on the form. Return ONLY the JSON output without any additional text or explanations following this schema {'properties': {'box_b_employer_identification_number': {'title': 'Box B Employer Identification Number', 'type': 'string'}, 'box_c_employer_name': {'title': 'Box C Employer Name', 'type': 'string'}, 'box_c_employer_street_address': {'title': 'Box C Employer Street Address', 'type': 'string'}, 'box_c_employer_city_state_zip': {'title': 'Box C Employer City State Zip', 'type': 'string'}, 'box_a_employee_ssn': {'title': 'Box A Employee Ssn', 'type': 'string'}, 'box_e_employee_name': {'title': 'Box E Employee Name', 'type': 'string'}, 'box_e_employee_street_address': {'title': 'Box E Employee Street Address', 'type': 'string'}, 'box_e_employee_city_state_zip': {'title': 'Box E Employee City State Zip', 'type': 'string'}, 'box_d_control_number': {'title': 'Box D Control Number', 'type': 'integer'}, 'box_1_wages': {'title': 'Box 1 Wages', 'type': 'number'}, 'box_2_federal_tax_withheld': {'title': 'Box 2 Federal Tax Withheld', 'type': 'number'}, 'box_3_social_security_wages': {'title': 'Box 3 Social Security Wages', 'type': 'number'}, 'box_4_social_security_tax_withheld': {'title': 'Box 4 Social Security Tax Withheld', 'type': 'number'}, 'box_5_medicare_wages': {'title': 'Box 5 Medicare Wages', 'type': 'number'}, 'box_6_medicare_wages_tax_withheld': {'title': 'Box 6 Medicare Wages Tax Withheld', 'type': 'number'}, 'box_7_social_security_tips': {'title': 'Box 7 Social Security Tips', 'type': 'number'}, 'box_8_allocated_tips': {'title': 
'Box 8 Allocated Tips', 'type': 'number'}, 'box_9_advance_eic_payment': {'anyOf': [{'type': 'string'}, {'type': 'null'}], 'title': 'Box 9 Advance Eic Payment'}, 'box_10_dependent_care_benefits': {'title': 'Box 10 Dependent Care Benefits', 'type': 'number'}, 'box_11_nonqualified_plans': {'title': 'Box 11 Nonqualified Plans', 'type': 'number'}, 'box_12a_code': {'title': 'Box 12A Code', 'type': 'string'}, 'box_12a_value': {'title': 'Box 12A Value', 'type': 'number'}, 'box_12b_code': {'title': 'Box 12B Code', 'type': 'string'}, 'box_12b_value': {'title': 'Box 12B Value', 'type': 'number'}, 'box_12c_code': {'title': 'Box 12C Code', 'type': 'string'}, 'box_12c_value': {'title': 'Box 12C Value', 'type': 'number'}, 'box_12d_code': {'anyOf': [{'type': 'string'}, {'type': 'null'}], 'title': 'Box 12D Code'}, 'box_12d_value': {'title': 'Box 12D Value', 'type': 'number'}, 'box_13_statutary_employee': {'anyOf': [{'type': 'string'}, {'type': 'null'}], 'title': 'Box 13 Statutary Employee'}, 'box_13_retirement_plan': {'anyOf': [{'type': 'string'}, {'type': 'null'}], 'title': 'Box 13 Retirement Plan'}, 'box_13_third_part_sick_pay': {'anyOf': [{'type': 'string'}, {'type': 'null'}], 'title': 'Box 13 Third Part Sick Pay'}, 'box_15_1_state': {'title': 'Box 15 1 State', 'type': 'string'}, 'box_15_1_employee_state_id': {'title': 'Box 15 1 Employee State Id', 'type': 'string'}, 'box_16_1_state_wages': {'title': 'Box 16 1 State Wages', 'type': 'number'}, 'box_17_1_state_income_tax': {'title': 'Box 17 1 State Income Tax', 'type': 'number'}, 'box_18_1_local_wages': {'title': 'Box 18 1 Local Wages', 'type': 'number'}, 'box_19_1_local_income_tax': {'title': 'Box 19 1 Local Income Tax', 'type': 'number'}, 'box_20_1_locality': {'title': 'Box 20 1 Locality', 'type': 'string'}, 'box_15_2_state': {'title': 'Box 15 2 State', 'type': 'string'}, 'box_15_2_employee_state_id': {'title': 'Box 15 2 Employee State Id', 'type': 'string'}, 'box_16_2_state_wages': {'title': 'Box 16 2 State Wages', 'type': 
'number'}, 'box_17_2_state_income_tax': {'title': 'Box 17 2 State Income Tax', 'type': 'number'}, 'box_18_2_local_wages': {'title': 'Box 18 2 Local Wages', 'type': 'number'}, 'box_19_2_local_income_tax': {'title': 'Box 19 2 Local Income Tax', 'type': 'number'}, 'box_20_2_locality': {'title': 'Box 20 2 Locality', 'type': 'string'}}, 'required': ['box_b_employer_identification_number', 'box_c_employer_name', 'box_c_employer_street_address', 'box_c_employer_city_state_zip', 'box_a_employee_ssn', 'box_e_employee_name', 'box_e_employee_street_address', 'box_e_employee_city_state_zip', 'box_d_control_number', 'box_1_wages', 'box_2_federal_tax_withheld', 'box_3_social_security_wages', 'box_4_social_security_tax_withheld', 'box_5_medicare_wages', 'box_6_medicare_wages_tax_withheld', 'box_7_social_security_tips', 'box_8_allocated_tips', 'box_9_advance_eic_payment', 'box_10_dependent_care_benefits', 'box_11_nonqualified_plans', 'box_12a_code', 'box_12a_value', 'box_12b_code', 'box_12b_value', 'box_12c_code', 'box_12c_value', 'box_12d_code', 'box_12d_value', 'box_13_statutary_employee', 'box_13_retirement_plan', 'box_13_third_part_sick_pay', 'box_15_1_state', 'box_15_1_employee_state_id', 'box_16_1_state_wages', 'box_17_1_state_income_tax', 'box_18_1_local_wages', 'box_19_1_local_income_tax', 'box_20_1_locality', 'box_15_2_state', 'box_15_2_employee_state_id', 'box_16_2_state_wages', 'box_17_2_state_income_tax', 'box_18_2_local_wages', 'box_19_2_local_income_tax', 'box_20_2_locality'], 'title': 'W2Form', 'type': 'object'}"
  6. sample_percent: 1 # % of the dataset to use; 1.0 means use the entire dataset
  7. resplit_train_percent: 0.3 # % of the sampled dataset to use for training; the rest is used for validation
  8. image_column: image
  9. user_text_column: null
  10. assistant_text_column: ground_truth
  11. grader: JSONGrader # Task-specific grader
  12. task2:
  13. dataset: getomni-ai/ocr-benchmark
  14. is_local: false
  15. system_prompt: You are a helpful assistant, you will always respond only in JSON following the provided JSON schema.
  16. user_prompt: "Extract the data in this image as a JSON. Use the following JSON schema:\n"
  17. sample_percent: 0.6
  18. resplit_train_percent: 0.0
  19. image_column: image
  20. user_text_column: json_schema
  21. assistant_text_column: true_json_output
  22. grader: JSONGrader
  23. finetuning:
  24. ### FFT LAYERS TO TRAIN - ALL FALSE FOR NO FFT
  25. fusion: true
  26. fusion+encoder: false
  27. fusion+decoder: false
  28. fusion+encoder+decoder: true
  29. ### LORA RANKS TO TRAIN - EMPTY LIST FOR NO LORA
  30. lora_ranks: [8, 64]
  31. ### TORCHTUNE CONFIG
  32. fft_torchtune_config: transferability/finetune/8b_full.yaml
  33. lora_torchtune_config: transferability/finetune/8b_lora.yaml
  34. ### TORCHTUNE ARGS
  35. model_path: /path/to/llama31/ckpt
  36. tokenizer_path: /path/to/llama31/ckpt/tokenizer.model
  37. epochs: 5 # Number of training epochs
  38. batch_size: 8 # Batch size per device for training
  39. ngpu: 4
  40. distributed: true # Whether to use distributed training
  41. evals:
  42. nb_eval_samples: null # Number of samples to use for evaluation; null means use the entire dataset.
  43. checkpoint_to_eval: -1
  44. model_server_args:
  45. tensor_parallel_size: 2
  46. max_model_len: 8192
  47. max_num_seqs: 128
  48. enforce_eager: true
  49. inference_params:
  50. temperature: 0
  51. top_p: 1.0
  52. max_completion_tokens: 4096
  53. seed: 42