syntaxnet.bzl 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. # Copyright 2016 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. load("@protobuf//:protobuf.bzl", "cc_proto_library")
  16. load("@protobuf//:protobuf.bzl", "py_proto_library")
  17. def if_cuda(if_true, if_false = []):
  18. """Shorthand for select()'ing on whether we're building with CUDA."""
  19. return select({
  20. "@local_config_cuda//cuda:using_nvcc": if_true,
  21. "@local_config_cuda//cuda:using_clang": if_true,
  22. "//conditions:default": if_false
  23. })
  24. def tf_copts():
  25. return (["-fno-exceptions", "-DEIGEN_AVOID_STL_ARRAY",] +
  26. if_cuda(["-DGOOGLE_CUDA=1"]) +
  27. select({"@org_tensorflow//tensorflow:darwin": [],
  28. "//conditions:default": ["-pthread"]}))
  29. def tf_proto_library(name, srcs=[], has_services=False,
  30. deps=[], visibility=None, testonly=0,
  31. cc_api_version=2, go_api_version=2,
  32. java_api_version=2,
  33. py_api_version=2):
  34. native.filegroup(name=name + "_proto_srcs",
  35. srcs=srcs,
  36. testonly=testonly,)
  37. cc_proto_library(name=name,
  38. srcs=srcs,
  39. deps=deps,
  40. cc_libs = ["@protobuf//:protobuf"],
  41. protoc="@protobuf//:protoc",
  42. default_runtime="@protobuf//:protobuf",
  43. testonly=testonly,
  44. visibility=visibility,)
  45. def tf_proto_library_py(name, srcs=[], deps=[], visibility=None, testonly=0):
  46. py_proto_library(name=name,
  47. srcs=srcs,
  48. srcs_version = "PY2AND3",
  49. deps=deps,
  50. default_runtime="@protobuf//:protobuf_python",
  51. protoc="@protobuf//:protoc",
  52. visibility=visibility,
  53. testonly=testonly,)
  54. # Given a list of "op_lib_names" (a list of files in the ops directory
  55. # without their .cc extensions), generate a library for that file.
  56. def tf_gen_op_libs(op_lib_names):
  57. # Make library out of each op so it can also be used to generate wrappers
  58. # for various languages.
  59. for n in op_lib_names:
  60. native.cc_library(name=n + "_op_lib",
  61. copts=tf_copts(),
  62. srcs=["ops/" + n + ".cc"],
  63. deps=(["@org_tensorflow//tensorflow/core:framework"]),
  64. visibility=["//visibility:public"],
  65. alwayslink=1,
  66. linkstatic=1,)
  67. # Invoke this rule in .../tensorflow/python to build the wrapper library.
  68. def tf_gen_op_wrapper_py(name, out=None, hidden=[], visibility=None, deps=[],
  69. require_shape_functions=False):
  70. # Construct a cc_binary containing the specified ops.
  71. tool_name = "gen_" + name + "_py_wrappers_cc"
  72. if not deps:
  73. deps = ["//tensorflow/core:" + name + "_op_lib"]
  74. native.cc_binary(
  75. name = tool_name,
  76. linkopts = ["-lm"],
  77. copts = tf_copts(),
  78. linkstatic = 1, # Faster to link this one-time-use binary dynamically
  79. deps = (["@org_tensorflow//tensorflow/core:framework",
  80. "@org_tensorflow//tensorflow/python:python_op_gen_main"] + deps),
  81. )
  82. # Invoke the previous cc_binary to generate a python file.
  83. if not out:
  84. out = "ops/gen_" + name + ".py"
  85. native.genrule(
  86. name=name + "_pygenrule",
  87. outs=[out],
  88. tools=[tool_name],
  89. cmd=("$(location " + tool_name + ") " + ",".join(hidden)
  90. + " " + ("1" if require_shape_functions else "0") + " > $@"))
  91. # Make a py_library out of the generated python file.
  92. native.py_library(name=name,
  93. srcs=[out],
  94. srcs_version="PY2AND3",
  95. visibility=visibility,
  96. deps=[
  97. "@org_tensorflow//tensorflow/python:framework_for_generated_wrappers",
  98. ],)