BUILD 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  # Description:
  #   Contains files for loading, training and evaluating TF-Slim-based models.

  package(default_visibility = [":internal"])

  licenses(["notice"])  # Apache 2.0

  exports_files(["LICENSE"])

  # Default visibility group for every target in this package.
  # NOTE(review): no `packages` are listed, so ":internal" grants visibility to
  # no package other than this one. Confirm whether sibling packages were
  # meant to be included.
  package_group(
      name = "internal",
  )
  # Shared helpers; every dataset download/convert library and every dataset
  # definition in this package depends on this target.
  py_library(
      name = "dataset_utils",
      srcs = ["datasets/dataset_utils.py"],
  )

  # One download-and-convert library per dataset, each depending only on
  # :dataset_utils.
  py_library(
      name = "download_and_convert_cifar10",
      srcs = ["datasets/download_and_convert_cifar10.py"],
      deps = [":dataset_utils"],
  )

  py_library(
      name = "download_and_convert_flowers",
      srcs = ["datasets/download_and_convert_flowers.py"],
      deps = [":dataset_utils"],
  )

  py_library(
      name = "download_and_convert_mnist",
      srcs = ["datasets/download_and_convert_mnist.py"],
      deps = [":dataset_utils"],
  )

  # Executable that bundles the three download-and-convert libraries above.
  py_binary(
      name = "download_and_convert_data",
      srcs = ["download_and_convert_data.py"],
      deps = [
          ":download_and_convert_cifar10",
          ":download_and_convert_flowers",
          ":download_and_convert_mnist",
      ],
  )
  # Per-dataset definitions. Each one is consumed only as a dependency of
  # :dataset_factory, so it is declared as a py_library (the rule meant for
  # importable code) rather than a py_binary.
  # NOTE(review): these were previously py_binary targets. Confirm nothing
  # relies on `bazel run` for them.
  py_library(
      name = "cifar10",
      srcs = ["datasets/cifar10.py"],
      deps = [":dataset_utils"],
  )

  py_library(
      name = "flowers",
      srcs = ["datasets/flowers.py"],
      deps = [":dataset_utils"],
  )

  py_library(
      name = "imagenet",
      srcs = ["datasets/imagenet.py"],
      deps = [":dataset_utils"],
  )

  py_library(
      name = "mnist",
      srcs = ["datasets/mnist.py"],
      deps = [":dataset_utils"],
  )

  # Puts the four dataset definitions behind a single dependency; used by the
  # training and evaluation binaries.
  py_library(
      name = "dataset_factory",
      srcs = ["datasets/dataset_factory.py"],
      deps = [
          ":cifar10",
          ":flowers",
          ":imagenet",
          ":mnist",
      ],
  )
  # Deployment helpers (deployment/model_deploy.py) and their unit test.
  py_library(
      name = "model_deploy",
      srcs = ["deployment/model_deploy.py"],
  )

  py_test(
      name = "model_deploy_test",
      srcs = ["deployment/model_deploy_test.py"],
      srcs_version = "PY2AND3",
      deps = [":model_deploy"],
  )
  # Per-network preprocessing libraries; none has dependencies of its own.
  py_library(
      name = "cifarnet_preprocessing",
      srcs = ["preprocessing/cifarnet_preprocessing.py"],
  )

  py_library(
      name = "inception_preprocessing",
      srcs = ["preprocessing/inception_preprocessing.py"],
  )

  py_library(
      name = "lenet_preprocessing",
      srcs = ["preprocessing/lenet_preprocessing.py"],
  )

  py_library(
      name = "vgg_preprocessing",
      srcs = ["preprocessing/vgg_preprocessing.py"],
  )

  # Single entry point that depends on all four preprocessing libraries.
  py_library(
      name = "preprocessing_factory",
      srcs = ["preprocessing/preprocessing_factory.py"],
      deps = [
          ":cifarnet_preprocessing",
          ":inception_preprocessing",
          ":lenet_preprocessing",
          ":vgg_preprocessing",
      ],
  )
  # Typical network definitions.
  #
  # Umbrella target: it has no sources of its own and only bundles the
  # individual network libraries declared below.
  py_library(
      name = "nets",
      deps = [
          ":alexnet",
          ":cifarnet",
          ":inception",
          ":lenet",
          ":overfeat",
          ":resnet_v1",
          ":resnet_v2",
          ":vgg",
      ],
  )
  # AlexNet and its medium-sized test.
  py_library(
      name = "alexnet",
      srcs = ["nets/alexnet.py"],
      srcs_version = "PY2AND3",
  )

  py_test(
      name = "alexnet_test",
      size = "medium",
      srcs = ["nets/alexnet_test.py"],
      srcs_version = "PY2AND3",
      deps = [":alexnet"],
  )

  # CifarNet (no test target is declared for it).
  py_library(
      name = "cifarnet",
      srcs = ["nets/cifarnet.py"],
  )
  # Umbrella for the Inception family. The per-version tests below depend on
  # this target rather than on the individual version libraries.
  py_library(
      name = "inception",
      srcs = ["nets/inception.py"],
      srcs_version = "PY2AND3",
      deps = [
          ":inception_resnet_v2",
          ":inception_v1",
          ":inception_v2",
          ":inception_v3",
      ],
  )

  py_library(
      name = "inception_v1",
      srcs = ["nets/inception_v1.py"],
      srcs_version = "PY2AND3",
  )

  py_library(
      name = "inception_v2",
      srcs = ["nets/inception_v2.py"],
      srcs_version = "PY2AND3",
  )

  py_library(
      name = "inception_v3",
      srcs = ["nets/inception_v3.py"],
      srcs_version = "PY2AND3",
  )

  py_library(
      name = "inception_resnet_v2",
      srcs = ["nets/inception_resnet_v2.py"],
      srcs_version = "PY2AND3",
  )

  # Large tests, each split across three shards.
  py_test(
      name = "inception_v1_test",
      size = "large",
      srcs = ["nets/inception_v1_test.py"],
      shard_count = 3,
      srcs_version = "PY2AND3",
      deps = [":inception"],
  )

  py_test(
      name = "inception_v2_test",
      size = "large",
      srcs = ["nets/inception_v2_test.py"],
      shard_count = 3,
      srcs_version = "PY2AND3",
      deps = [":inception"],
  )

  py_test(
      name = "inception_v3_test",
      size = "large",
      srcs = ["nets/inception_v3_test.py"],
      shard_count = 3,
      srcs_version = "PY2AND3",
      deps = [":inception"],
  )

  py_test(
      name = "inception_resnet_v2_test",
      size = "large",
      srcs = ["nets/inception_resnet_v2_test.py"],
      shard_count = 3,
      srcs_version = "PY2AND3",
      deps = [":inception"],
  )
  # LeNet (no test target is declared for it).
  py_library(
      name = "lenet",
      srcs = ["nets/lenet.py"],
  )

  # OverFeat and its medium-sized test.
  py_library(
      name = "overfeat",
      srcs = ["nets/overfeat.py"],
      srcs_version = "PY2AND3",
  )

  py_test(
      name = "overfeat_test",
      size = "medium",
      srcs = ["nets/overfeat_test.py"],
      srcs_version = "PY2AND3",
      deps = [":overfeat"],
  )
  # Shared ResNet building blocks, used by both ResNet variants below.
  py_library(
      name = "resnet_utils",
      srcs = ["nets/resnet_utils.py"],
      srcs_version = "PY2AND3",
  )

  # ResNet v1 and its test.
  py_library(
      name = "resnet_v1",
      srcs = ["nets/resnet_v1.py"],
      srcs_version = "PY2AND3",
      deps = [
          ":resnet_utils",
      ],
  )

  py_test(
      name = "resnet_v1_test",
      size = "medium",
      srcs = ["nets/resnet_v1_test.py"],
      srcs_version = "PY2AND3",
      deps = [":resnet_v1"],
  )

  # ResNet v2 and its test.
  py_library(
      name = "resnet_v2",
      srcs = ["nets/resnet_v2.py"],
      srcs_version = "PY2AND3",
      deps = [
          ":resnet_utils",
      ],
  )

  py_test(
      name = "resnet_v2_test",
      size = "medium",
      srcs = ["nets/resnet_v2_test.py"],
      srcs_version = "PY2AND3",
      deps = [":resnet_v2"],
  )
  # VGG and its medium-sized test.
  py_library(
      name = "vgg",
      srcs = ["nets/vgg.py"],
      srcs_version = "PY2AND3",
  )

  py_test(
      name = "vgg_test",
      size = "medium",
      srcs = ["nets/vgg_test.py"],
      srcs_version = "PY2AND3",
      deps = [":vgg"],
  )
  # Factory over the :nets umbrella target, plus its test.
  py_library(
      name = "nets_factory",
      srcs = ["nets/nets_factory.py"],
      deps = [":nets"],
  )

  py_test(
      name = "nets_factory_test",
      size = "medium",
      srcs = ["nets/nets_factory_test.py"],
      srcs_version = "PY2AND3",
      deps = [":nets_factory"],
  )
  # Top-level training and evaluation executables. Both depend on the same
  # four factory/deployment libraries.
  py_binary(
      name = "train_image_classifier",
      srcs = ["train_image_classifier.py"],
      deps = [
          ":dataset_factory",
          ":model_deploy",
          ":nets_factory",
          ":preprocessing_factory",
      ],
  )

  py_binary(
      name = "eval_image_classifier",
      srcs = ["eval_image_classifier.py"],
      deps = [
          ":dataset_factory",
          ":model_deploy",
          ":nets_factory",
          ":preprocessing_factory",
      ],
  )