evaluate.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Script to evaluate a skip-thoughts model.
  16. This script can evaluate a model with a unidirectional encoder ("uni-skip" in
  17. the paper); or a model with a bidirectional encoder ("bi-skip"); or the
  18. combination of a model with a unidirectional encoder and a model with a
  19. bidirectional encoder ("combine-skip").
  20. The uni-skip model (if it exists) is specified by the flags
  21. --uni_vocab_file, --uni_embeddings_file, --uni_checkpoint_path.
  22. The bi-skip model (if it exists) is specified by the flags
  23. --bi_vocab_file, --bi_embeddings_file, --bi_checkpoint_path.
  24. The evaluation tasks have different running times. SICK may take 5-10 minutes.
  25. MSRP, TREC and CR may take 20-60 minutes. SUBJ, MPQA and MR may take 2+ hours.
  26. """
  27. from __future__ import absolute_import
  28. from __future__ import division
  29. from __future__ import print_function
  30. from skipthoughts import eval_classification
  31. from skipthoughts import eval_msrp
  32. from skipthoughts import eval_sick
  33. from skipthoughts import eval_trec
  34. import tensorflow as tf
  35. from skip_thoughts import configuration
  36. from skip_thoughts import encoder_manager
  37. FLAGS = tf.flags.FLAGS
  38. tf.flags.DEFINE_string("eval_task", "CR",
  39. "Name of the evaluation task to run. Available tasks: "
  40. "MR, CR, SUBJ, MPQA, SICK, MSRP, TREC.")
  41. tf.flags.DEFINE_string("data_dir", None, "Directory containing training data.")
  42. tf.flags.DEFINE_string("uni_vocab_file", None,
  43. "Path to vocabulary file containing a list of newline-"
  44. "separated words where the word id is the "
  45. "corresponding 0-based index in the file.")
  46. tf.flags.DEFINE_string("bi_vocab_file", None,
  47. "Path to vocabulary file containing a list of newline-"
  48. "separated words where the word id is the "
  49. "corresponding 0-based index in the file.")
  50. tf.flags.DEFINE_string("uni_embeddings_file", None,
  51. "Path to serialized numpy array of shape "
  52. "[vocab_size, embedding_dim].")
  53. tf.flags.DEFINE_string("bi_embeddings_file", None,
  54. "Path to serialized numpy array of shape "
  55. "[vocab_size, embedding_dim].")
  56. tf.flags.DEFINE_string("uni_checkpoint_path", None,
  57. "Checkpoint file or directory containing a checkpoint "
  58. "file.")
  59. tf.flags.DEFINE_string("bi_checkpoint_path", None,
  60. "Checkpoint file or directory containing a checkpoint "
  61. "file.")
  62. tf.logging.set_verbosity(tf.logging.INFO)
  63. def main(unused_argv):
  64. if not FLAGS.data_dir:
  65. raise ValueError("--data_dir is required.")
  66. encoder = encoder_manager.EncoderManager()
  67. # Maybe load unidirectional encoder.
  68. if FLAGS.uni_checkpoint_path:
  69. print("Loading unidirectional model...")
  70. uni_config = configuration.model_config()
  71. encoder.load_model(uni_config, FLAGS.uni_vocab_file,
  72. FLAGS.uni_embeddings_file, FLAGS.uni_checkpoint_path)
  73. # Maybe load bidirectional encoder.
  74. if FLAGS.bi_checkpoint_path:
  75. print("Loading bidirectional model...")
  76. bi_config = configuration.model_config(bidirectional_encoder=True)
  77. encoder.load_model(bi_config, FLAGS.bi_vocab_file, FLAGS.bi_embeddings_file,
  78. FLAGS.bi_checkpoint_path)
  79. if FLAGS.eval_task in ["MR", "CR", "SUBJ", "MPQA"]:
  80. eval_classification.eval_nested_kfold(
  81. encoder, FLAGS.eval_task, FLAGS.data_dir, use_nb=False)
  82. elif FLAGS.eval_task == "SICK":
  83. eval_sick.evaluate(encoder, evaltest=True, loc=FLAGS.data_dir)
  84. elif FLAGS.eval_task == "MSRP":
  85. eval_msrp.evaluate(
  86. encoder, evalcv=True, evaltest=True, use_feats=True, loc=FLAGS.data_dir)
  87. elif FLAGS.eval_task == "TREC":
  88. eval_trec.evaluate(encoder, evalcv=True, evaltest=True, loc=FLAGS.data_dir)
  89. else:
  90. raise ValueError("Unrecognized eval_task: %s" % FLAGS.eval_task)
  91. encoder.close()
  92. if __name__ == "__main__":
  93. tf.app.run()