BUILD 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375
  1. # Description:
  2. # Contains files for loading, training and evaluating TF-Slim-based models.
  3. package(default_visibility = [
  4. ":internal",
  5. "//domain_adaptation:__subpackages__",
  6. ])
  7. licenses(["notice"]) # Apache 2.0
  8. exports_files(["LICENSE"])
  9. package_group(name = "internal")
  10. py_library(
  11. name = "dataset_utils",
  12. srcs = ["datasets/dataset_utils.py"],
  13. )
  14. py_library(
  15. name = "download_and_convert_cifar10",
  16. srcs = ["datasets/download_and_convert_cifar10.py"],
  17. deps = [":dataset_utils"],
  18. )
  19. py_library(
  20. name = "download_and_convert_flowers",
  21. srcs = ["datasets/download_and_convert_flowers.py"],
  22. deps = [":dataset_utils"],
  23. )
  24. py_library(
  25. name = "download_and_convert_mnist",
  26. srcs = ["datasets/download_and_convert_mnist.py"],
  27. deps = [":dataset_utils"],
  28. )
  29. py_binary(
  30. name = "download_and_convert_data",
  31. srcs = ["download_and_convert_data.py"],
  32. deps = [
  33. ":download_and_convert_cifar10",
  34. ":download_and_convert_flowers",
  35. ":download_and_convert_mnist",
  36. ],
  37. )
  38. py_binary(
  39. name = "cifar10",
  40. srcs = ["datasets/cifar10.py"],
  41. deps = [":dataset_utils"],
  42. )
  43. py_binary(
  44. name = "flowers",
  45. srcs = ["datasets/flowers.py"],
  46. deps = [":dataset_utils"],
  47. )
  48. py_binary(
  49. name = "imagenet",
  50. srcs = ["datasets/imagenet.py"],
  51. deps = [":dataset_utils"],
  52. )
  53. py_binary(
  54. name = "mnist",
  55. srcs = ["datasets/mnist.py"],
  56. deps = [":dataset_utils"],
  57. )
  58. py_library(
  59. name = "dataset_factory",
  60. srcs = ["datasets/dataset_factory.py"],
  61. deps = [
  62. ":cifar10",
  63. ":flowers",
  64. ":imagenet",
  65. ":mnist",
  66. ],
  67. )
  68. py_library(
  69. name = "model_deploy",
  70. srcs = ["deployment/model_deploy.py"],
  71. )
  72. py_test(
  73. name = "model_deploy_test",
  74. srcs = ["deployment/model_deploy_test.py"],
  75. srcs_version = "PY2AND3",
  76. deps = [":model_deploy"],
  77. )
  78. py_library(
  79. name = "cifarnet_preprocessing",
  80. srcs = ["preprocessing/cifarnet_preprocessing.py"],
  81. )
  82. py_library(
  83. name = "inception_preprocessing",
  84. srcs = ["preprocessing/inception_preprocessing.py"],
  85. )
  86. py_library(
  87. name = "lenet_preprocessing",
  88. srcs = ["preprocessing/lenet_preprocessing.py"],
  89. )
  90. py_library(
  91. name = "vgg_preprocessing",
  92. srcs = ["preprocessing/vgg_preprocessing.py"],
  93. )
  94. py_library(
  95. name = "preprocessing_factory",
  96. srcs = ["preprocessing/preprocessing_factory.py"],
  97. deps = [
  98. ":cifarnet_preprocessing",
  99. ":inception_preprocessing",
  100. ":lenet_preprocessing",
  101. ":vgg_preprocessing",
  102. ],
  103. )
  104. # Typical networks definitions.
  105. py_library(
  106. name = "nets",
  107. deps = [
  108. ":alexnet",
  109. ":cifarnet",
  110. ":inception",
  111. ":lenet",
  112. ":overfeat",
  113. ":resnet_v1",
  114. ":resnet_v2",
  115. ":vgg",
  116. ],
  117. )
  118. py_library(
  119. name = "alexnet",
  120. srcs = ["nets/alexnet.py"],
  121. srcs_version = "PY2AND3",
  122. )
  123. py_test(
  124. name = "alexnet_test",
  125. size = "medium",
  126. srcs = ["nets/alexnet_test.py"],
  127. srcs_version = "PY2AND3",
  128. deps = [":alexnet"],
  129. )
  130. py_library(
  131. name = "cifarnet",
  132. srcs = ["nets/cifarnet.py"],
  133. )
  134. py_library(
  135. name = "inception",
  136. srcs = ["nets/inception.py"],
  137. srcs_version = "PY2AND3",
  138. deps = [
  139. ":inception_resnet_v2",
  140. ":inception_v1",
  141. ":inception_v2",
  142. ":inception_v3",
  143. ":inception_v4",
  144. ],
  145. )
  146. py_library(
  147. name = "inception_utils",
  148. srcs = ["nets/inception_utils.py"],
  149. srcs_version = "PY2AND3",
  150. )
  151. py_library(
  152. name = "inception_v1",
  153. srcs = ["nets/inception_v1.py"],
  154. srcs_version = "PY2AND3",
  155. deps = [
  156. ":inception_utils",
  157. ],
  158. )
  159. py_library(
  160. name = "inception_v2",
  161. srcs = ["nets/inception_v2.py"],
  162. srcs_version = "PY2AND3",
  163. deps = [
  164. ":inception_utils",
  165. ],
  166. )
  167. py_library(
  168. name = "inception_v3",
  169. srcs = ["nets/inception_v3.py"],
  170. srcs_version = "PY2AND3",
  171. deps = [
  172. ":inception_utils",
  173. ],
  174. )
  175. py_library(
  176. name = "inception_v4",
  177. srcs = ["nets/inception_v4.py"],
  178. srcs_version = "PY2AND3",
  179. deps = [
  180. ":inception_utils",
  181. ],
  182. )
  183. py_library(
  184. name = "inception_resnet_v2",
  185. srcs = ["nets/inception_resnet_v2.py"],
  186. srcs_version = "PY2AND3",
  187. )
  188. py_test(
  189. name = "inception_v1_test",
  190. size = "large",
  191. srcs = ["nets/inception_v1_test.py"],
  192. shard_count = 3,
  193. srcs_version = "PY2AND3",
  194. deps = [":inception"],
  195. )
  196. py_test(
  197. name = "inception_v2_test",
  198. size = "large",
  199. srcs = ["nets/inception_v2_test.py"],
  200. shard_count = 3,
  201. srcs_version = "PY2AND3",
  202. deps = [":inception"],
  203. )
  204. py_test(
  205. name = "inception_v3_test",
  206. size = "large",
  207. srcs = ["nets/inception_v3_test.py"],
  208. shard_count = 3,
  209. srcs_version = "PY2AND3",
  210. deps = [":inception"],
  211. )
  212. py_test(
  213. name = "inception_v4_test",
  214. size = "large",
  215. srcs = ["nets/inception_v4_test.py"],
  216. shard_count = 3,
  217. srcs_version = "PY2AND3",
  218. deps = [":inception"],
  219. )
  220. py_test(
  221. name = "inception_resnet_v2_test",
  222. size = "large",
  223. srcs = ["nets/inception_resnet_v2_test.py"],
  224. shard_count = 3,
  225. srcs_version = "PY2AND3",
  226. deps = [":inception"],
  227. )
  228. py_library(
  229. name = "lenet",
  230. srcs = ["nets/lenet.py"],
  231. )
  232. py_library(
  233. name = "overfeat",
  234. srcs = ["nets/overfeat.py"],
  235. srcs_version = "PY2AND3",
  236. )
  237. py_test(
  238. name = "overfeat_test",
  239. size = "medium",
  240. srcs = ["nets/overfeat_test.py"],
  241. srcs_version = "PY2AND3",
  242. deps = [":overfeat"],
  243. )
  244. py_library(
  245. name = "resnet_utils",
  246. srcs = ["nets/resnet_utils.py"],
  247. srcs_version = "PY2AND3",
  248. )
  249. py_library(
  250. name = "resnet_v1",
  251. srcs = ["nets/resnet_v1.py"],
  252. srcs_version = "PY2AND3",
  253. deps = [
  254. ":resnet_utils",
  255. ],
  256. )
  257. py_test(
  258. name = "resnet_v1_test",
  259. size = "medium",
  260. srcs = ["nets/resnet_v1_test.py"],
  261. srcs_version = "PY2AND3",
  262. deps = [":resnet_v1"],
  263. )
  264. py_library(
  265. name = "resnet_v2",
  266. srcs = ["nets/resnet_v2.py"],
  267. srcs_version = "PY2AND3",
  268. deps = [
  269. ":resnet_utils",
  270. ],
  271. )
  272. py_test(
  273. name = "resnet_v2_test",
  274. size = "medium",
  275. srcs = ["nets/resnet_v2_test.py"],
  276. srcs_version = "PY2AND3",
  277. deps = [":resnet_v2"],
  278. )
  279. py_library(
  280. name = "vgg",
  281. srcs = ["nets/vgg.py"],
  282. srcs_version = "PY2AND3",
  283. )
  284. py_test(
  285. name = "vgg_test",
  286. size = "medium",
  287. srcs = ["nets/vgg_test.py"],
  288. srcs_version = "PY2AND3",
  289. deps = [":vgg"],
  290. )
  291. py_library(
  292. name = "nets_factory",
  293. srcs = ["nets/nets_factory.py"],
  294. deps = [":nets"],
  295. )
  296. py_test(
  297. name = "nets_factory_test",
  298. size = "medium",
  299. srcs = ["nets/nets_factory_test.py"],
  300. srcs_version = "PY2AND3",
  301. deps = [":nets_factory"],
  302. )
  303. py_binary(
  304. name = "train_image_classifier",
  305. srcs = ["train_image_classifier.py"],
  306. deps = [
  307. ":dataset_factory",
  308. ":model_deploy",
  309. ":nets_factory",
  310. ":preprocessing_factory",
  311. ],
  312. )
  313. py_binary(
  314. name = "eval_image_classifier",
  315. srcs = ["eval_image_classifier.py"],
  316. deps = [
  317. ":dataset_factory",
  318. ":model_deploy",
  319. ":nets_factory",
  320. ":preprocessing_factory",
  321. ],
  322. )