sentence_batch.h 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. /* Copyright 2016 Google Inc. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. #ifndef SYNTAXNET_SENTENCE_BATCH_H_
  13. #define SYNTAXNET_SENTENCE_BATCH_H_
  14. #include <memory>
  15. #include <string>
  16. #include <vector>
  17. #include "syntaxnet/embedding_feature_extractor.h"
  18. #include "syntaxnet/feature_extractor.h"
  19. #include "syntaxnet/parser_state.h"
  20. #include "syntaxnet/parser_transitions.h"
  21. #include "syntaxnet/sentence.pb.h"
  22. #include "syntaxnet/sparse.pb.h"
  23. #include "syntaxnet/task_context.h"
  24. #include "syntaxnet/task_spec.pb.h"
  25. #include "syntaxnet/term_frequency_map.h"
  26. namespace syntaxnet {
  27. // Helper class to manage generating batches of preprocessed ParserState objects
  28. // by reading in multiple sentences in parallel.
  29. class SentenceBatch {
  30. public:
  31. SentenceBatch(int batch_size, string input_name)
  32. : batch_size_(batch_size),
  33. input_name_(input_name),
  34. sentences_(batch_size) {}
  35. // Initializes all resources and opens the corpus file.
  36. void Init(TaskContext *context);
  37. // Advances the index'th sentence in the batch to the next sentence. This will
  38. // create and preprocess a new ParserState for that element. Returns false if
  39. // EOF is reached (if EOF, also sets the state to be nullptr.)
  40. bool AdvanceSentence(int index);
  41. // Rewinds the corpus reader.
  42. void Rewind() { reader_->Reset(); }
  43. int size() const { return size_; }
  44. Sentence *sentence(int index) { return sentences_[index].get(); }
  45. private:
  46. // Running tally of non-nullptr states in the batch.
  47. int size_;
  48. // Maximum number of states in the batch.
  49. int batch_size_;
  50. // Input to read from the TaskContext.
  51. string input_name_;
  52. // Reader for the corpus.
  53. std::unique_ptr<TextReader> reader_;
  54. // Batch: Sentence objects.
  55. std::vector<std::unique_ptr<Sentence>> sentences_;
  56. };
  57. } // namespace syntaxnet
  58. #endif // SYNTAXNET_SENTENCE_BATCH_H_