syntaxnet.bzl 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. # Copyright 2016 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. load("@tf//google/protobuf:protobuf.bzl", "cc_proto_library")
  16. load("@tf//google/protobuf:protobuf.bzl", "py_proto_library")
  17. def if_cuda(a, b=[]):
  18. return select({
  19. "@tf//third_party/gpus/cuda:cuda_crosstool_condition": a,
  20. "//conditions:default": b,
  21. })
  22. def tf_copts():
  23. return (["-fno-exceptions", "-DEIGEN_AVOID_STL_ARRAY",] +
  24. if_cuda(["-DGOOGLE_CUDA=1"]) +
  25. select({"@tf//tensorflow:darwin": [],
  26. "//conditions:default": ["-pthread"]}))
  27. def tf_proto_library(name, srcs=[], has_services=False,
  28. deps=[], visibility=None, testonly=0,
  29. cc_api_version=2, go_api_version=2,
  30. java_api_version=2,
  31. py_api_version=2):
  32. native.filegroup(name=name + "_proto_srcs",
  33. srcs=srcs,
  34. testonly=testonly,)
  35. cc_proto_library(name=name,
  36. srcs=srcs,
  37. deps=deps,
  38. cc_libs = ["@tf//google/protobuf:protobuf"],
  39. protoc="@tf//google/protobuf:protoc",
  40. default_runtime="@tf//google/protobuf:protobuf",
  41. testonly=testonly,
  42. visibility=visibility,)
  43. def tf_proto_library_py(name, srcs=[], deps=[], visibility=None, testonly=0):
  44. py_proto_library(name=name,
  45. srcs=srcs,
  46. srcs_version = "PY2AND3",
  47. deps=deps,
  48. default_runtime="@tf//google/protobuf:protobuf_python",
  49. protoc="@tf//google/protobuf:protoc",
  50. visibility=visibility,
  51. testonly=testonly,)
  52. # Given a list of "op_lib_names" (a list of files in the ops directory
  53. # without their .cc extensions), generate a library for that file.
  54. def tf_gen_op_libs(op_lib_names):
  55. # Make library out of each op so it can also be used to generate wrappers
  56. # for various languages.
  57. for n in op_lib_names:
  58. native.cc_library(name=n + "_op_lib",
  59. copts=tf_copts(),
  60. srcs=["ops/" + n + ".cc"],
  61. deps=(["@tf//tensorflow/core:framework"]),
  62. visibility=["//visibility:public"],
  63. alwayslink=1,
  64. linkstatic=1,)
  65. # Invoke this rule in .../tensorflow/python to build the wrapper library.
  66. def tf_gen_op_wrapper_py(name, out=None, hidden=[], visibility=None, deps=[],
  67. require_shape_functions=False):
  68. # Construct a cc_binary containing the specified ops.
  69. tool_name = "gen_" + name + "_py_wrappers_cc"
  70. if not deps:
  71. deps = ["//tensorflow/core:" + name + "_op_lib"]
  72. native.cc_binary(
  73. name = tool_name,
  74. linkopts = ["-lm"],
  75. copts = tf_copts(),
  76. linkstatic = 1, # Faster to link this one-time-use binary dynamically
  77. deps = (["@tf//tensorflow/core:framework",
  78. "@tf//tensorflow/python:python_op_gen_main"] + deps),
  79. )
  80. # Invoke the previous cc_binary to generate a python file.
  81. if not out:
  82. out = "ops/gen_" + name + ".py"
  83. native.genrule(
  84. name=name + "_pygenrule",
  85. outs=[out],
  86. tools=[tool_name],
  87. cmd=("$(location " + tool_name + ") " + ",".join(hidden)
  88. + " " + ("1" if require_shape_functions else "0") + " > $@"))
  89. # Make a py_library out of the generated python file.
  90. native.py_library(name=name,
  91. srcs=[out],
  92. srcs_version="PY2AND3",
  93. visibility=visibility,
  94. deps=[
  95. "@tf//tensorflow/python:framework_for_generated_wrappers",
  96. ],)