# BUILD (2.2 KB)

# Domain Separation Networks
  2. package(
  3. default_visibility = [
  4. ":internal",
  5. ],
  6. )
  7. licenses(["notice"]) # Apache 2.0
  8. exports_files(["LICENSE"])
  9. package_group(
  10. name = "internal",
  11. packages = [
  12. "//domain_adaptation/...",
  13. ],
  14. )
  15. py_library(
  16. name = "models",
  17. srcs = [
  18. "models.py",
  19. ],
  20. deps = [
  21. ":utils",
  22. ],
  23. )
  24. py_library(
  25. name = "losses",
  26. srcs = [
  27. "losses.py",
  28. ],
  29. deps = [
  30. ":grl_op_grads_py",
  31. # ":grl_op_kernels",
  32. ":grl_op_shapes_py",
  33. ":grl_ops",
  34. # ":grl_ops_py",
  35. ":utils",
  36. ],
  37. )
  38. py_test(
  39. name = "losses_test",
  40. srcs = [
  41. "losses_test.py",
  42. ],
  43. deps = [
  44. ":losses",
  45. ":utils",
  46. ],
  47. )
  48. py_library(
  49. name = "dsn",
  50. srcs = [
  51. "dsn.py",
  52. ],
  53. deps = [
  54. ":grl_op_grads_py",
  55. ":grl_op_shapes_py",
  56. ":grl_ops",
  57. ":losses",
  58. ":models",
  59. ":utils",
  60. ],
  61. )
  62. py_test(
  63. name = "dsn_test",
  64. srcs = [
  65. "dsn_test.py",
  66. ],
  67. deps = [
  68. ":dsn",
  69. ],
  70. )
  71. py_binary(
  72. name = "dsn_train",
  73. srcs = [
  74. "dsn_train.py",
  75. ],
  76. deps = [
  77. ":dsn",
  78. ":models",
  79. "//domain_adaptation/datasets:dataset_factory",
  80. ],
  81. )
  82. py_binary(
  83. name = "dsn_eval",
  84. srcs = [
  85. "dsn_eval.py",
  86. ],
  87. deps = [
  88. ":dsn",
  89. ":models",
  90. "//domain_adaptation/datasets:dataset_factory",
  91. ],
  92. )
  93. py_test(
  94. name = "models_test",
  95. srcs = [
  96. "models_test.py",
  97. ],
  98. deps = [
  99. ":models",
  100. "//domain_adaptation/datasets:dataset_factory",
  101. ],
  102. )
  103. py_library(
  104. name = "utils",
  105. srcs = [
  106. "utils.py",
  107. ],
  108. deps = [
  109. ],
  110. )
  111. py_library(
  112. name = "grl_op_grads_py",
  113. srcs = [
  114. "grl_op_grads.py",
  115. ],
  116. deps = [
  117. ":grl_ops",
  118. ],
  119. )
  120. py_library(
  121. name = "grl_op_shapes_py",
  122. srcs = [
  123. "grl_op_shapes.py",
  124. ],
  125. deps = [
  126. ],
  127. )
  128. py_library(
  129. name = "grl_ops",
  130. srcs = ["grl_ops.py"],
  131. data = ["_grl_ops.so"],
  132. )
  133. py_test(
  134. name = "grl_ops_test",
  135. size = "small",
  136. srcs = ["grl_ops_test.py"],
  137. deps = [
  138. ":grl_op_grads_py",
  139. ":grl_op_shapes_py",
  140. ":grl_ops",
  141. ],
  142. )