# 11B_full_w2.yaml
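# Example launch (a sketch; assumes the standard torchtune CLI is installed and that
# this file is saved as 11B_full_w2.yaml in the working directory):
#   tune run full_finetune_single_device --config 11B_full_w2.yaml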

# Top-level output directory
output_dir: ./outputs/Llama-3.2-11B-Instruct-w2-full

# Model
model:
  _component_: torchtune.models.llama3_2_vision.llama3_2_vision_11b
  decoder_trainable: False
  encoder_trainable: True
  fusion_trainable: True
  image_size: 560  # Make sure this matches the image_size in tokenizer

# Tokenizer / vision transform
tokenizer:
  _component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform
  path: ./Llama-3.2-11B-Vision-Instruct/original/tokenizer.model
  image_size: 560
  max_seq_len: 8192

# Checkpointing
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: ./Llama-3.2-11B-Vision-Instruct
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: "00005"
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3_VISION
resume_from_checkpoint: false
save_adapter_weights_only: False  # PEFT formatting is not available yet; weights are saved in torchtune format only.

# Dataset
dataset:
  _component_: torchtune.datasets.multimodal.vqa_dataset
  source: arrow
  data_files:
    train: "fake_w2_us_tax_form_dataset_train30_test70/train/data-00000-of-00001.arrow"
  split: train
  column_map:
    input: input
    output: ground_truth
    image: image

# General data handling
seed: null
shuffle: true
collate_fn: torchtune.data.padded_collate_tiled_images_and_mask

# Training loop & hyperparameters
epochs: 5
max_steps_per_epoch: null
batch_size: 4
gradient_accumulation_steps: 8  # Use to increase effective batch size

# Explicit optimizer / scheduler / loss
optimizer:
  _component_: bitsandbytes.optim.PagedAdamW8bit
  lr: 2e-5
optimizer_in_bwd: False  # True saves memory; requires gradient_accumulation_steps=1
loss:
  _component_: torchtune.modules.loss.LinearCrossEntropyLoss
clip_grad_norm: 1.0
compile: false

# Device & memory
device: cuda
enable_activation_checkpointing: true
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.WandBLogger
  project: llama3_2_w2_extraction
  entity: <your_wandb_entity>
  job_type: full_finetune_single_device
  group: llama-cookbook
log_every_n_steps: 5
save_steps: 100
log_peak_memory_stats: true
log_level: INFO

# Profiler (off by default)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: false
  output_dir: ${output_dir}/profiling_outputs
  cpu: true
  cuda: true
  profile_memory: false
  with_stack: false
  record_shapes: true
  with_flops: false
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1