mock_component.h 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPONENT_H_
  2. #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPONENT_H_
  3. #include <gmock/gmock.h>
  4. #include "dragnn/components/util/bulk_feature_extractor.h"
  5. #include "dragnn/core/index_translator.h"
  6. #include "dragnn/core/interfaces/component.h"
  7. #include "dragnn/core/interfaces/transition_state.h"
  8. #include "dragnn/protos/data.pb.h"
  9. #include "dragnn/protos/spec.pb.h"
  10. #include "syntaxnet/base.h"
  11. #include "tensorflow/core/platform/test.h"
  12. namespace syntaxnet {
  13. namespace dragnn {
  14. class MockComponent : public Component {
  15. public:
  16. MOCK_METHOD1(InitializeComponent, void(const ComponentSpec &spec));
  17. MOCK_METHOD3(
  18. InitializeData,
  19. void(const std::vector<std::vector<const TransitionState *>> &states,
  20. int max_beam_size, InputBatchCache *input_data));
  21. MOCK_CONST_METHOD0(IsReady, bool());
  22. MOCK_METHOD0(InitializeTracing, void());
  23. MOCK_METHOD0(DisableTracing, void());
  24. MOCK_CONST_METHOD0(Name, string());
  25. MOCK_CONST_METHOD0(BatchSize, int());
  26. MOCK_CONST_METHOD0(BeamSize, int());
  27. MOCK_CONST_METHOD1(StepsTaken, int(int batch_index));
  28. MOCK_CONST_METHOD3(GetBeamIndexAtStep,
  29. int(int step, int current_index, int batch));
  30. MOCK_CONST_METHOD2(GetSourceBeamIndex, int(int current_index, int batch));
  31. MOCK_METHOD2(AdvanceFromPrediction,
  32. void(const float transition_matrix[], int matrix_length));
  33. MOCK_METHOD0(AdvanceFromOracle, void());
  34. MOCK_CONST_METHOD0(IsTerminal, bool());
  35. MOCK_METHOD0(GetBeam, std::vector<std::vector<const TransitionState *>>());
  36. MOCK_CONST_METHOD4(GetFixedFeatures,
  37. int(std::function<int32 *(int)> allocate_indices,
  38. std::function<int64 *(int)> allocate_ids,
  39. std::function<float *(int)> allocate_weights,
  40. int channel_id));
  41. MOCK_METHOD1(BulkGetFixedFeatures,
  42. int(const BulkFeatureExtractor &extractor));
  43. MOCK_CONST_METHOD1(GetRawLinkFeatures,
  44. std::vector<LinkFeatures>(int channel_id));
  45. MOCK_CONST_METHOD0(GetOracleLabels, std::vector<std::vector<int>>());
  46. MOCK_METHOD0(ResetComponent, void());
  47. MOCK_METHOD1(GetStepLookupFunction,
  48. std::function<int(int, int, int)>(const string &method));
  49. MOCK_METHOD0(FinalizeData, void());
  50. MOCK_CONST_METHOD0(GetTraceProtos,
  51. std::vector<std::vector<ComponentTrace>>());
  52. MOCK_METHOD2(AddTranslatedLinkFeaturesToTrace,
  53. void(const std::vector<LinkFeatures> &features, int channel_id));
  54. };
  55. } // namespace dragnn
  56. } // namespace syntaxnet
  57. #endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPONENT_H_