proto_io.h 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. /* Copyright 2016 Google Inc. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. #ifndef SYNTAXNET_PROTO_IO_H_
  13. #define SYNTAXNET_PROTO_IO_H_
  14. #include <iostream>
  15. #include <memory>
  16. #include <string>
  17. #include <vector>
  18. #include "syntaxnet/document_format.h"
  19. #include "syntaxnet/feature_types.h"
  20. #include "syntaxnet/registry.h"
  21. #include "syntaxnet/sentence.pb.h"
  22. #include "syntaxnet/task_context.h"
  23. #include "syntaxnet/utils.h"
  24. #include "syntaxnet/workspace.h"
  25. #include "tensorflow/core/lib/core/errors.h"
  26. #include "tensorflow/core/lib/core/status.h"
  27. #include "tensorflow/core/lib/core/stringpiece.h"
  28. #include "tensorflow/core/lib/io/buffered_inputstream.h"
  29. #include "tensorflow/core/lib/io/random_inputstream.h"
  30. #include "tensorflow/core/lib/io/record_reader.h"
  31. #include "tensorflow/core/lib/io/record_writer.h"
  32. #include "tensorflow/core/lib/strings/strcat.h"
  33. #include "tensorflow/core/platform/env.h"
  34. namespace syntaxnet {
  35. // A convenience wrapper to read protos with a RecordReader.
  36. class ProtoRecordReader {
  37. public:
  38. explicit ProtoRecordReader(tensorflow::RandomAccessFile *file) {
  39. file_.reset(file);
  40. reader_.reset(new tensorflow::io::RecordReader(file_.get()));
  41. }
  42. explicit ProtoRecordReader(const string &filename) {
  43. TF_CHECK_OK(
  44. tensorflow::Env::Default()->NewRandomAccessFile(filename, &file_));
  45. reader_.reset(new tensorflow::io::RecordReader(file_.get()));
  46. }
  47. ~ProtoRecordReader() {
  48. reader_.reset();
  49. }
  50. template <typename T>
  51. tensorflow::Status Read(T *proto) {
  52. string buffer;
  53. tensorflow::Status status = reader_->ReadRecord(&offset_, &buffer);
  54. if (status.ok()) {
  55. CHECK(proto->ParseFromString(buffer));
  56. return tensorflow::Status::OK();
  57. } else {
  58. return status;
  59. }
  60. }
  61. private:
  62. uint64 offset_ = 0;
  63. std::unique_ptr<tensorflow::io::RecordReader> reader_;
  64. std::unique_ptr<tensorflow::RandomAccessFile> file_;
  65. };
  66. // A convenience wrapper to write protos with a RecordReader.
  67. class ProtoRecordWriter {
  68. public:
  69. explicit ProtoRecordWriter(const string &filename) {
  70. TF_CHECK_OK(tensorflow::Env::Default()->NewWritableFile(filename, &file_));
  71. writer_.reset(new tensorflow::io::RecordWriter(file_.get()));
  72. }
  73. ~ProtoRecordWriter() {
  74. writer_.reset();
  75. file_.reset();
  76. }
  77. template <typename T>
  78. void Write(const T &proto) {
  79. TF_CHECK_OK(writer_->WriteRecord(proto.SerializeAsString()));
  80. }
  81. private:
  82. std::unique_ptr<tensorflow::io::RecordWriter> writer_;
  83. std::unique_ptr<tensorflow::WritableFile> file_;
  84. };
  85. // A file implementation to read from stdin.
  86. class StdIn : public tensorflow::RandomAccessFile {
  87. public:
  88. StdIn() {}
  89. ~StdIn() override {}
  90. // Reads up to n bytes from standard input. Returns `OUT_OF_RANGE` if fewer
  91. // than n bytes were stored in `*result` because of EOF.
  92. tensorflow::Status Read(uint64 offset, size_t n,
  93. tensorflow::StringPiece *result,
  94. char *scratch) const override {
  95. CHECK_EQ(expected_offset_, offset);
  96. if (!eof_) {
  97. string line;
  98. eof_ = !std::getline(std::cin, line);
  99. buffer_.append(line);
  100. buffer_.append("\n");
  101. }
  102. CopyFromBuffer(std::min(buffer_.size(), n), result, scratch);
  103. if (eof_) {
  104. return tensorflow::errors::OutOfRange("End of file reached");
  105. } else {
  106. return tensorflow::Status::OK();
  107. }
  108. }
  109. private:
  110. void CopyFromBuffer(size_t n, tensorflow::StringPiece *result,
  111. char *scratch) const {
  112. memcpy(scratch, buffer_.data(), buffer_.size());
  113. buffer_ = buffer_.substr(n);
  114. *result = tensorflow::StringPiece(scratch, n);
  115. expected_offset_ += n;
  116. }
  117. mutable bool eof_ = false;
  118. mutable int64 expected_offset_ = 0;
  119. mutable string buffer_;
  120. TF_DISALLOW_COPY_AND_ASSIGN(StdIn);
  121. };
  122. // Reads sentence protos from a text file.
  123. class TextReader {
  124. public:
  125. explicit TextReader(const TaskInput &input, TaskContext *context) {
  126. CHECK_EQ(input.record_format_size(), 1)
  127. << "TextReader only supports inputs with one record format: "
  128. << input.DebugString();
  129. CHECK_EQ(input.part_size(), 1)
  130. << "TextReader only supports inputs with one part: "
  131. << input.DebugString();
  132. filename_ = TaskContext::InputFile(input);
  133. format_.reset(DocumentFormat::Create(input.record_format(0)));
  134. format_->Setup(context);
  135. Reset();
  136. }
  137. Sentence *Read() {
  138. // Skips emtpy sentences, e.g., blank lines at the beginning of a file or
  139. // commented out blocks.
  140. std::vector<Sentence *> sentences;
  141. string key, value;
  142. while (sentences.empty() && format_->ReadRecord(buffer_.get(), &value)) {
  143. key = tensorflow::strings::StrCat(filename_, ":", sentence_count_);
  144. format_->ConvertFromString(key, value, &sentences);
  145. CHECK_LE(sentences.size(), 1);
  146. }
  147. if (sentences.empty()) {
  148. // End of file reached.
  149. return nullptr;
  150. } else {
  151. ++sentence_count_;
  152. return sentences[0];
  153. }
  154. }
  155. void Reset() {
  156. sentence_count_ = 0;
  157. if (filename_ == "-") {
  158. static const int kInputBufferSize = 8 * 1024; /* bytes */
  159. file_.reset(new StdIn());
  160. stream_.reset(new tensorflow::io::RandomAccessInputStream(file_.get()));
  161. buffer_.reset(new tensorflow::io::BufferedInputStream(file_.get(),
  162. kInputBufferSize));
  163. } else {
  164. static const int kInputBufferSize = 1 * 1024 * 1024; /* bytes */
  165. TF_CHECK_OK(
  166. tensorflow::Env::Default()->NewRandomAccessFile(filename_, &file_));
  167. stream_.reset(new tensorflow::io::RandomAccessInputStream(file_.get()));
  168. buffer_.reset(new tensorflow::io::BufferedInputStream(file_.get(),
  169. kInputBufferSize));
  170. }
  171. }
  172. private:
  173. string filename_;
  174. int sentence_count_ = 0;
  175. std::unique_ptr<tensorflow::RandomAccessFile>
  176. file_; // must outlive buffer_, stream_
  177. std::unique_ptr<tensorflow::io::RandomAccessInputStream>
  178. stream_; // Must outlive buffer_
  179. std::unique_ptr<tensorflow::io::BufferedInputStream> buffer_;
  180. std::unique_ptr<DocumentFormat> format_;
  181. };
  182. // Writes sentence protos to a text conll file.
  183. class TextWriter {
  184. public:
  185. explicit TextWriter(const TaskInput &input, TaskContext *context) {
  186. CHECK_EQ(input.record_format_size(), 1)
  187. << "TextWriter only supports files with one record format: "
  188. << input.DebugString();
  189. CHECK_EQ(input.part_size(), 1)
  190. << "TextWriter only supports files with one part: "
  191. << input.DebugString();
  192. filename_ = TaskContext::InputFile(input);
  193. format_.reset(DocumentFormat::Create(input.record_format(0)));
  194. format_->Setup(context);
  195. if (filename_ != "-") {
  196. TF_CHECK_OK(
  197. tensorflow::Env::Default()->NewWritableFile(filename_, &file_));
  198. }
  199. }
  200. ~TextWriter() {
  201. if (file_) {
  202. TF_CHECK_OK(file_->Close());
  203. }
  204. }
  205. void Write(const Sentence &sentence) {
  206. string key, value;
  207. format_->ConvertToString(sentence, &key, &value);
  208. if (file_) {
  209. TF_CHECK_OK(file_->Append(value));
  210. } else {
  211. std::cout << value;
  212. }
  213. }
  214. private:
  215. string filename_;
  216. std::unique_ptr<DocumentFormat> format_;
  217. std::unique_ptr<tensorflow::WritableFile> file_;
  218. };
  219. } // namespace syntaxnet
  220. #endif // SYNTAXNET_PROTO_IO_H_