# finetune_retriever_distributed.sh (1.8 KB)
  #!/bin/bash
# Finetune a BERT or pretrained ICT model as a retriever on the Google
# Natural Questions (NQ) data, using one node with WORLD_SIZE processes.
#
# Datasets can be downloaded with:
#   https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
#
# Configuration (set in the environment, or edit the defaults below):
#   CHECKPOINT_PATH        (required) directory the finetuned retriever is
#                          saved to and loaded from.
#   PRETRAINED_CHECKPOINT  path of a pretrained ICT model.
#   BERT_LOAD_PATH         path of a pretrained BERT model.
#   At least one of PRETRAINED_CHECKPOINT / BERT_LOAD_PATH must be set; only
#   the ones that are set are passed to tasks/main.py.
#
# Any extra command-line arguments are appended to the tasks/main.py command
# line, e.g. to override a hyper-parameter:  ./finetune.sh --lr 1e-5
set -euo pipefail

die() { printf 'error: %s\n' "$*" >&2; exit 1; }

WORLD_SIZE=8
# Kept as an array so each launcher option stays a separate word.
DISTRIBUTED_ARGS=(
  --nproc_per_node "$WORLD_SIZE"
  --nnodes 1
  --node_rank 0
  --master_addr localhost
  --master_port 6000
)

CHECKPOINT_PATH="${CHECKPOINT_PATH:?set CHECKPOINT_PATH to the path for the finetuned retriever model}"

# Load either of the below (set the one you want, leave the other empty).
PRETRAINED_CHECKPOINT="${PRETRAINED_CHECKPOINT:-}"
BERT_LOAD_PATH="${BERT_LOAD_PATH:-}"

# Passing a flag with an empty value would leave it without an argument and
# the command line would be rejected, so only add the flags that are set.
INIT_ARGS=()
if [[ -n "$PRETRAINED_CHECKPOINT" ]]; then
  INIT_ARGS+=(--pretrained-checkpoint "$PRETRAINED_CHECKPOINT")
fi
if [[ -n "$BERT_LOAD_PATH" ]]; then
  INIT_ARGS+=(--bert-load "$BERT_LOAD_PATH")
fi
(( ${#INIT_ARGS[@]} > 0 )) \
  || die "set PRETRAINED_CHECKPOINT (ICT model) or BERT_LOAD_PATH (BERT model)"

# NOTE(review): torch.distributed.launch is deprecated in recent PyTorch in
# favour of torchrun; it is kept here because tasks/main.py may rely on the
# --local_rank argument it passes — confirm before switching.
python -m torch.distributed.launch "${DISTRIBUTED_ARGS[@]}" ./tasks/main.py \
  --task RET-FINETUNE-NQ \
  --train-with-neg \
  --train-hard-neg 1 \
  "${INIT_ARGS[@]}" \
  --num-layers 12 \
  --hidden-size 768 \
  --num-attention-heads 12 \
  --tensor-model-parallel-size 1 \
  --tokenizer-type BertWordPieceLowerCase \
  --train-data nq-train.json \
  --valid-data nq-dev.json \
  --save "$CHECKPOINT_PATH" \
  --load "$CHECKPOINT_PATH" \
  --vocab-file bert-vocab.txt \
  --save-interval 5000 \
  --log-interval 10 \
  --eval-interval 20000 \
  --eval-iters 100 \
  --indexer-log-interval 1000 \
  --faiss-use-gpu \
  --DDP-impl torch \
  --fp16 \
  --retriever-report-topk-accuracies 1 5 10 20 100 \
  --seq-length 512 \
  --retriever-seq-length 256 \
  --max-position-embeddings 512 \
  --retriever-score-scaling \
  --epochs 80 \
  --micro-batch-size 8 \
  --eval-micro-batch-size 16 \
  --indexer-batch-size 128 \
  --lr 2e-5 \
  --lr-warmup-fraction 0.01 \
  --weight-decay 1e-1 \
  "$@"