# BUILD (2.8 KB)

  1. # Domain Separation Networks
  2. package(
  3. default_visibility = [
  4. ":internal",
  5. ],
  6. )
  7. licenses(["notice"]) # Apache 2.0
  8. exports_files(["LICENSE"])
  9. package_group(
  10. name = "internal",
  11. packages = [
  12. "//domain_adaptation/...",
  13. ],
  14. )
  15. py_library(
  16. name = "models",
  17. srcs = [
  18. "models.py",
  19. ],
  20. deps = [
  21. ":utils",
  22. ],
  23. )
  24. py_library(
  25. name = "losses",
  26. srcs = [
  27. "losses.py",
  28. ],
  29. deps = [
  30. ":grl_op_grads_py",
  31. # ":grl_op_kernels",
  32. ":grl_op_shapes_py",
  33. ":grl_ops",
  34. # ":grl_ops_py",
  35. ":utils",
  36. ],
  37. )
  38. py_test(
  39. name = "losses_test",
  40. srcs = [
  41. "losses_test.py",
  42. ],
  43. deps = [
  44. ":losses",
  45. ":utils",
  46. ],
  47. )
  48. py_library(
  49. name = "dsn",
  50. srcs = [
  51. "dsn.py",
  52. ],
  53. deps = [
  54. ":grl_op_grads_py",
  55. #":grl_op_kernels",
  56. ":grl_op_shapes_py",
  57. ":grl_ops",
  58. #":grl_ops_py",
  59. ":losses",
  60. ":models",
  61. ":utils",
  62. ],
  63. )
  64. py_test(
  65. name = "dsn_test",
  66. srcs = [
  67. "dsn_test.py",
  68. ],
  69. deps = [
  70. ":dsn",
  71. ],
  72. )
  73. py_binary(
  74. name = "dsn_train",
  75. srcs = [
  76. "dsn_train.py",
  77. ],
  78. deps = [
  79. ":dsn",
  80. ":models",
  81. "//domain_adaptation/datasets:dataset_factory",
  82. ],
  83. )
  84. py_binary(
  85. name = "dsn_eval",
  86. srcs = [
  87. "dsn_eval.py",
  88. ],
  89. deps = [
  90. ":dsn",
  91. ":models",
  92. "//domain_adaptation/datasets:dataset_factory",
  93. ],
  94. )
  95. py_test(
  96. name = "models_test",
  97. srcs = [
  98. "models_test.py",
  99. ],
  100. deps = [
  101. ":models",
  102. "//domain_adaptation/datasets:dataset_factory",
  103. ],
  104. )
  105. py_library(
  106. name = "utils",
  107. srcs = [
  108. "utils.py",
  109. ],
  110. deps = [
  111. ],
  112. )
  113. py_library(
  114. name = "grl_op_grads_py",
  115. srcs = [
  116. "grl_op_grads.py",
  117. ],
  118. deps = [
  119. ":grl_ops",
  120. ],
  121. )
  122. py_library(
  123. name = "grl_op_shapes_py",
  124. srcs = [
  125. "grl_op_shapes.py",
  126. ],
  127. deps = [
  128. ],
  129. )
  130. py_library(
  131. name = "grl_ops",
  132. srcs = ["grl_ops.py"],
  133. data = ["_grl_ops.so"],
  134. )
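# Disabled: rules that build the GRL ops from C++ source (grl_ops.cc,
# grl_op_kernels.cc) and generate a Python wrapper for them. This file uses
# the `grl_ops` py_library, which ships a prebuilt `_grl_ops.so`, instead.
# Re-enabling these rules requires renaming or removing that py_library. It
# uses the same target name ("grl_ops") and the same file ("grl_ops.py")
# that the rules below would produce.
#cc_library(
#    name = "grl_ops",
#    srcs = ["grl_ops.cc"],
#    deps = ["//tensorflow/core:framework"],
#    alwayslink = 1,
#)
#tf_gen_op_wrapper_py(
#    name = "grl_ops_py",
#    out = "grl_ops.py",
#    deps = [":grl_ops"],
#)
#cc_library(
#    name = "grl_op_kernels",
#    srcs = ["grl_op_kernels.cc"],
#    deps = [
#        "//tensorflow/core:framework",
#        "//tensorflow/core:protos_all",
#    ],
#    alwayslink = 1,
#)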
  155. py_test(
  156. name = "grl_ops_test",
  157. size = "small",
  158. srcs = ["grl_ops_test.py"],
  159. deps = [
  160. ":grl_op_grads_py",
  161. # ":grl_op_kernels",
  162. ":grl_op_shapes_py",
  163. ":grl_ops",
  164. #":grl_ops_py",
  165. ],
  166. )