mock_component.h 3.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. // Copyright 2017 Google Inc. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // =============================================================================
  15. #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPONENT_H_
  16. #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPONENT_H_
  17. #include <gmock/gmock.h>
  18. #include "dragnn/components/util/bulk_feature_extractor.h"
  19. #include "dragnn/core/index_translator.h"
  20. #include "dragnn/core/interfaces/component.h"
  21. #include "dragnn/core/interfaces/transition_state.h"
  22. #include "dragnn/protos/data.pb.h"
  23. #include "dragnn/protos/spec.pb.h"
  24. #include "syntaxnet/base.h"
  25. #include "tensorflow/core/platform/test.h"
  26. namespace syntaxnet {
  27. namespace dragnn {
  28. class MockComponent : public Component {
  29. public:
  30. MOCK_METHOD1(InitializeComponent, void(const ComponentSpec &spec));
  31. MOCK_METHOD3(
  32. InitializeData,
  33. void(const std::vector<std::vector<const TransitionState *>> &states,
  34. int max_beam_size, InputBatchCache *input_data));
  35. MOCK_CONST_METHOD0(IsReady, bool());
  36. MOCK_METHOD0(InitializeTracing, void());
  37. MOCK_METHOD0(DisableTracing, void());
  38. MOCK_CONST_METHOD0(Name, string());
  39. MOCK_CONST_METHOD0(BatchSize, int());
  40. MOCK_CONST_METHOD0(BeamSize, int());
  41. MOCK_CONST_METHOD1(StepsTaken, int(int batch_index));
  42. MOCK_CONST_METHOD3(GetBeamIndexAtStep,
  43. int(int step, int current_index, int batch));
  44. MOCK_CONST_METHOD2(GetSourceBeamIndex, int(int current_index, int batch));
  45. MOCK_METHOD2(AdvanceFromPrediction,
  46. void(const float transition_matrix[], int matrix_length));
  47. MOCK_METHOD0(AdvanceFromOracle, void());
  48. MOCK_CONST_METHOD0(IsTerminal, bool());
  49. MOCK_METHOD0(GetBeam, std::vector<std::vector<const TransitionState *>>());
  50. MOCK_CONST_METHOD4(GetFixedFeatures,
  51. int(std::function<int32 *(int)> allocate_indices,
  52. std::function<int64 *(int)> allocate_ids,
  53. std::function<float *(int)> allocate_weights,
  54. int channel_id));
  55. MOCK_METHOD1(BulkGetFixedFeatures,
  56. int(const BulkFeatureExtractor &extractor));
  57. MOCK_CONST_METHOD1(GetRawLinkFeatures,
  58. std::vector<LinkFeatures>(int channel_id));
  59. MOCK_CONST_METHOD0(GetOracleLabels, std::vector<std::vector<int>>());
  60. MOCK_METHOD0(ResetComponent, void());
  61. MOCK_METHOD1(GetStepLookupFunction,
  62. std::function<int(int, int, int)>(const string &method));
  63. MOCK_METHOD0(FinalizeData, void());
  64. MOCK_CONST_METHOD0(GetTraceProtos,
  65. std::vector<std::vector<ComponentTrace>>());
  66. MOCK_METHOD2(AddTranslatedLinkFeaturesToTrace,
  67. void(const std::vector<LinkFeatures> &features, int channel_id));
  68. };
  69. } // namespace dragnn
  70. } // namespace syntaxnet
  71. #endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPONENT_H_