11B_lora_w2.yaml

# Top-level output directory
output_dir: ./outputs/Llama-3.2-11B-Instruct-w2-lora-80

# Model + LoRA settings
model:
  _component_: torchtune.models.llama3_2_vision.lora_llama3_2_vision_11b
  # preserve your hyperparams
  lora_rank: 8  # higher increases accuracy and memory
  lora_alpha: 16  # usually alpha=2*rank
  lora_dropout: 0.05
  image_size: 560  # Make sure this matches the image_size in the tokenizer
  # example’s fixed settings
  decoder_trainable: "frozen"
  encoder_trainable: "lora"
  fusion_trainable: "lora"
  lora_attn_modules:
    - 'q_proj'
    - 'v_proj'
    - 'output_proj'
  apply_lora_to_mlp: true
  apply_lora_to_output: false

# Tokenizer / vision transform
tokenizer:
  _component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform
  path: ./Llama-3.2-11B-Vision-Instruct/original/tokenizer.model
  image_size: 560
  max_seq_len: 8192

# Checkpointing
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: ./Llama-3.2-11B-Vision-Instruct
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: "00005"
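    # With this format string and max index, the checkpointer should resolve the shards
    # model-00001-of-00005.safetensors through model-00005-of-00005.safetensors.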
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3_VISION
resume_from_checkpoint: false
save_adapter_weights_only: false  # PEFT formatting not available yet; adapters are saved in torchtune format only.

# Dataset
dataset:
  _component_: torchtune.datasets.multimodal.vqa_dataset
  source: arrow
  data_files:
    # train: "w2_with_input/train/data-00000-of-00001.arrow"
    train: "fake_w2_us_tax_form_dataset_train80_test20/train/data-00000-of-00001.arrow"
  split: train
  column_map:
    input: input
    output: ground_truth
    image: image
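  # column_map maps the columns vqa_dataset expects (input, output, image) to the
  # actual column names in the Arrow dataset; here the answer text lives in "ground_truth".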

# General data handling
seed: null
shuffle: true
collate_fn: torchtune.data.padded_collate_tiled_images_and_mask

# Training loop & hyperparams
# example’s train-control
epochs: 10
max_steps_per_epoch: null
batch_size: 4
gradient_accumulation_steps: 8  # Use to increase effective batch size
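# With batch_size 4 and gradient_accumulation_steps 8, the effective batch size is 4 * 8 = 32.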

# explicit optimizer / scheduler / loss
optimizer:
  _component_: torch.optim.AdamW
  fused: true
  weight_decay: 0.01
  lr: 1e-4
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100
loss:
  _component_: torchtune.modules.loss.LinearCrossEntropyLoss
clip_grad_norm: 1.0
compile: false

# Device & memory
device: cuda
enable_activation_checkpointing: true
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.WandBLogger
  project: llama3_2_w2_extraction
  entity: <your_wandb_entity>
  job_type: lora_finetune_single_device
  group: llama-cookbook
log_every_n_steps: 5
save_steps: 100
log_peak_memory_stats: true
log_level: INFO

# Profiler (off by default)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: false
  output_dir: ${output_dir}/profiling_outputs
  cpu: true
  cuda: true
  profile_memory: false
  with_stack: false
  record_shapes: true
  with_flops: false
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
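
# To launch a fine-tuning run with this config (assuming torchtune is installed and
# the file is saved as 11B_lora_w2.yaml), the single-device LoRA recipe can be
# invoked along these lines:
#   tune run lora_finetune_single_device --config 11B_lora_w2.yaml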