#!/bin/bash
#
# This script performs the following operations:
# 1. Downloads the MNIST dataset.
# 2. Trains a LeNet model on the MNIST training set.
# 3. Evaluates the model on the MNIST testing set.
#
# The Python entry points below are invoked by relative path, so the script
# must be run from the directory that contains them.
#
# Usage:
#   cd slim
#   ./scripts/train_lenet_on_mnist.sh
#
# Optional environment overrides:
#   TRAIN_DIR    directory for checkpoints and logs (default: /tmp/lenet-model)
#   DATASET_DIR  directory for the dataset          (default: /tmp/mnist)

# Stop at the first failing step: a failed download must not fall through to
# training, and a failed training run must not fall through to evaluation.
# -u additionally turns a misspelled variable name into an error.
set -euo pipefail

# Where the checkpoint and logs will be saved to.
TRAIN_DIR="${TRAIN_DIR:-/tmp/lenet-model}"

# Where the dataset is saved to.
DATASET_DIR="${DATASET_DIR:-/tmp/mnist}"

# Download the dataset.
python download_and_convert_data.py \
  --dataset_name=mnist \
  --dataset_dir="${DATASET_DIR}"

# Run training.
python train_image_classifier.py \
  --train_dir="${TRAIN_DIR}" \
  --dataset_name=mnist \
  --dataset_split_name=train \
  --dataset_dir="${DATASET_DIR}" \
  --model_name=lenet \
  --preprocessing_name=lenet \
  --max_number_of_steps=20000 \
  --batch_size=50 \
  --learning_rate=0.01 \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --log_every_n_steps=100 \
  --optimizer=sgd \
  --learning_rate_decay_type=fixed \
  --weight_decay=0

# Run evaluation. The evaluation output is written to the same directory the
# checkpoint is read from.
python eval_image_classifier.py \
  --checkpoint_path="${TRAIN_DIR}" \
  --eval_dir="${TRAIN_DIR}" \
  --dataset_name=mnist \
  --dataset_split_name=test \
  --dataset_dir="${DATASET_DIR}" \
  --model_name=lenet