data_utils.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. # Copyright 2015 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Convolutional Gated Recurrent Networks for Algorithm Learning."""
  16. import math
  17. import random
  18. import sys
  19. import time
  20. import numpy as np
  21. import tensorflow as tf
  22. from tensorflow.python.platform import gfile
  23. FLAGS = tf.app.flags.FLAGS
  24. bins = [8, 12, 16, 20, 24, 28, 32, 36, 40, 48, 64, 128]
  25. all_tasks = ["sort", "kvsort", "id", "rev", "rev2", "incr", "add", "left",
  26. "right", "left-shift", "right-shift", "bmul", "mul", "dup",
  27. "badd", "qadd", "search"]
  28. forward_max = 128
  29. log_filename = ""
  30. def pad(l):
  31. for b in bins:
  32. if b >= l: return b
  33. return forward_max
  34. train_set = {}
  35. test_set = {}
  36. for some_task in all_tasks:
  37. train_set[some_task] = []
  38. test_set[some_task] = []
  39. for all_max_len in xrange(10000):
  40. train_set[some_task].append([])
  41. test_set[some_task].append([])
  42. def add(n1, n2, base=10):
  43. """Add two numbers represented as lower-endian digit lists."""
  44. k = max(len(n1), len(n2)) + 1
  45. d1 = n1 + [0 for _ in xrange(k - len(n1))]
  46. d2 = n2 + [0 for _ in xrange(k - len(n2))]
  47. res = []
  48. carry = 0
  49. for i in xrange(k):
  50. if d1[i] + d2[i] + carry < base:
  51. res.append(d1[i] + d2[i] + carry)
  52. carry = 0
  53. else:
  54. res.append(d1[i] + d2[i] + carry - base)
  55. carry = 1
  56. while res and res[-1] == 0:
  57. res = res[:-1]
  58. if res: return res
  59. return [0]
  60. def init_data(task, length, nbr_cases, nclass):
  61. """Data initialization."""
  62. def rand_pair(l, task):
  63. """Random data pair for a task. Total length should be <= l."""
  64. k = (l-1)/2
  65. base = 10
  66. if task[0] == "b": base = 2
  67. if task[0] == "q": base = 4
  68. d1 = [np.random.randint(base) for _ in xrange(k)]
  69. d2 = [np.random.randint(base) for _ in xrange(k)]
  70. if task in ["add", "badd", "qadd"]:
  71. res = add(d1, d2, base)
  72. elif task in ["mul", "bmul"]:
  73. d1n = sum([d * (base ** i) for i, d in enumerate(d1)])
  74. d2n = sum([d * (base ** i) for i, d in enumerate(d2)])
  75. if task == "bmul":
  76. res = [int(x) for x in list(reversed(str(bin(d1n * d2n))))[:-2]]
  77. else:
  78. res = [int(x) for x in list(reversed(str(d1n * d2n)))]
  79. else:
  80. sys.exit()
  81. sep = [12]
  82. if task in ["add", "badd", "qadd"]: sep = [11]
  83. inp = [d + 1 for d in d1] + sep + [d + 1 for d in d2]
  84. return inp, [r + 1 for r in res]
  85. def rand_dup_pair(l):
  86. """Random data pair for duplication task. Total length should be <= l."""
  87. k = l/2
  88. x = [np.random.randint(nclass - 1) + 1 for _ in xrange(k)]
  89. inp = x + [0 for _ in xrange(l - k)]
  90. res = x + x + [0 for _ in xrange(l - 2*k)]
  91. return inp, res
  92. def rand_rev2_pair(l):
  93. """Random data pair for reverse2 task. Total length should be <= l."""
  94. inp = [(np.random.randint(nclass - 1) + 1,
  95. np.random.randint(nclass - 1) + 1) for _ in xrange(l/2)]
  96. res = [i for i in reversed(inp)]
  97. return [x for p in inp for x in p], [x for p in res for x in p]
  98. def rand_search_pair(l):
  99. """Random data pair for search task. Total length should be <= l."""
  100. inp = [(np.random.randint(nclass - 1) + 1,
  101. np.random.randint(nclass - 1) + 1) for _ in xrange(l-1/2)]
  102. q = np.random.randint(nclass - 1) + 1
  103. res = 0
  104. for (k, v) in reversed(inp):
  105. if k == q:
  106. res = v
  107. return [x for p in inp for x in p] + [q], [res]
  108. def rand_kvsort_pair(l):
  109. """Random data pair for key-value sort. Total length should be <= l."""
  110. keys = [(np.random.randint(nclass - 1) + 1, i) for i in xrange(l/2)]
  111. vals = [np.random.randint(nclass - 1) + 1 for _ in xrange(l/2)]
  112. kv = [(k, vals[i]) for (k, i) in keys]
  113. sorted_kv = [(k, vals[i]) for (k, i) in sorted(keys)]
  114. return [x for p in kv for x in p], [x for p in sorted_kv for x in p]
  115. def spec(inp):
  116. """Return the target given the input for some tasks."""
  117. if task == "sort":
  118. return sorted(inp)
  119. elif task == "id":
  120. return inp
  121. elif task == "rev":
  122. return [i for i in reversed(inp)]
  123. elif task == "incr":
  124. carry = 1
  125. res = []
  126. for i in xrange(len(inp)):
  127. if inp[i] + carry < nclass:
  128. res.append(inp[i] + carry)
  129. carry = 0
  130. else:
  131. res.append(1)
  132. carry = 1
  133. return res
  134. elif task == "left":
  135. return [inp[0]]
  136. elif task == "right":
  137. return [inp[-1]]
  138. elif task == "left-shift":
  139. return [inp[l-1] for l in xrange(len(inp))]
  140. elif task == "right-shift":
  141. return [inp[l+1] for l in xrange(len(inp))]
  142. else:
  143. print_out("Unknown spec for task " + str(task))
  144. sys.exit()
  145. l = length
  146. cur_time = time.time()
  147. total_time = 0.0
  148. for case in xrange(nbr_cases):
  149. total_time += time.time() - cur_time
  150. cur_time = time.time()
  151. if l > 10000 and case % 100 == 1:
  152. print_out(" avg gen time %.4f s" % (total_time / float(case)))
  153. if task in ["add", "badd", "qadd", "bmul", "mul"]:
  154. i, t = rand_pair(l, task)
  155. train_set[task][len(i)].append([i, t])
  156. i, t = rand_pair(l, task)
  157. test_set[task][len(i)].append([i, t])
  158. elif task == "dup":
  159. i, t = rand_dup_pair(l)
  160. train_set[task][len(i)].append([i, t])
  161. i, t = rand_dup_pair(l)
  162. test_set[task][len(i)].append([i, t])
  163. elif task == "rev2":
  164. i, t = rand_rev2_pair(l)
  165. train_set[task][len(i)].append([i, t])
  166. i, t = rand_rev2_pair(l)
  167. test_set[task][len(i)].append([i, t])
  168. elif task == "search":
  169. i, t = rand_search_pair(l)
  170. train_set[task][len(i)].append([i, t])
  171. i, t = rand_search_pair(l)
  172. test_set[task][len(i)].append([i, t])
  173. elif task == "kvsort":
  174. i, t = rand_kvsort_pair(l)
  175. train_set[task][len(i)].append([i, t])
  176. i, t = rand_kvsort_pair(l)
  177. test_set[task][len(i)].append([i, t])
  178. else:
  179. inp = [np.random.randint(nclass - 1) + 1 for i in xrange(l)]
  180. target = spec(inp)
  181. train_set[task][l].append([inp, target])
  182. inp = [np.random.randint(nclass - 1) + 1 for i in xrange(l)]
  183. target = spec(inp)
  184. test_set[task][l].append([inp, target])
  185. def to_symbol(i):
  186. """Covert ids to text."""
  187. if i == 0: return ""
  188. if i == 11: return "+"
  189. if i == 12: return "*"
  190. return str(i-1)
  191. def to_id(s):
  192. """Covert text to ids."""
  193. if s == "+": return 11
  194. if s == "*": return 12
  195. return int(s) + 1
  196. def get_batch(max_length, batch_size, do_train, task, offset=None, preset=None):
  197. """Get a batch of data, training or testing."""
  198. inputs = []
  199. targets = []
  200. length = max_length
  201. if preset is None:
  202. cur_set = test_set[task]
  203. if do_train: cur_set = train_set[task]
  204. while not cur_set[length]:
  205. length -= 1
  206. pad_length = pad(length)
  207. for b in xrange(batch_size):
  208. if preset is None:
  209. elem = random.choice(cur_set[length])
  210. if offset is not None and offset + b < len(cur_set[length]):
  211. elem = cur_set[length][offset + b]
  212. else:
  213. elem = preset
  214. inp, target = elem[0], elem[1]
  215. assert len(inp) == length
  216. inputs.append(inp + [0 for l in xrange(pad_length - len(inp))])
  217. targets.append(target + [0 for l in xrange(pad_length - len(target))])
  218. res_input = []
  219. res_target = []
  220. for l in xrange(pad_length):
  221. new_input = np.array([inputs[b][l] for b in xrange(batch_size)],
  222. dtype=np.int32)
  223. new_target = np.array([targets[b][l] for b in xrange(batch_size)],
  224. dtype=np.int32)
  225. res_input.append(new_input)
  226. res_target.append(new_target)
  227. return res_input, res_target
  228. def print_out(s, newline=True):
  229. """Print a message out and log it to file."""
  230. if log_filename:
  231. try:
  232. with gfile.GFile(log_filename, mode="a") as f:
  233. f.write(s + ("\n" if newline else ""))
  234. # pylint: disable=bare-except
  235. except:
  236. sys.stdout.write("Error appending to %s\n" % log_filename)
  237. sys.stdout.write(s + ("\n" if newline else ""))
  238. sys.stdout.flush()
  239. def decode(output):
  240. return [np.argmax(o, axis=1) for o in output]
  241. def accuracy(inpt, output, target, batch_size, nprint):
  242. """Calculate output accuracy given target."""
  243. assert nprint < batch_size + 1
  244. def task_print(inp, output, target):
  245. stop_bound = 0
  246. print_len = 0
  247. while print_len < len(target) and target[print_len] > stop_bound:
  248. print_len += 1
  249. print_out(" i: " + " ".join([str(i - 1) for i in inp if i > 0]))
  250. print_out(" o: " +
  251. " ".join([str(output[l] - 1) for l in xrange(print_len)]))
  252. print_out(" t: " +
  253. " ".join([str(target[l] - 1) for l in xrange(print_len)]))
  254. decoded_target = target
  255. decoded_output = decode(output)
  256. total = 0
  257. errors = 0
  258. seq = [0 for b in xrange(batch_size)]
  259. for l in xrange(len(decoded_output)):
  260. for b in xrange(batch_size):
  261. if decoded_target[l][b] > 0:
  262. total += 1
  263. if decoded_output[l][b] != decoded_target[l][b]:
  264. seq[b] = 1
  265. errors += 1
  266. e = 0 # Previous error index
  267. for _ in xrange(min(nprint, sum(seq))):
  268. while seq[e] == 0:
  269. e += 1
  270. task_print([inpt[l][e] for l in xrange(len(inpt))],
  271. [decoded_output[l][e] for l in xrange(len(decoded_target))],
  272. [decoded_target[l][e] for l in xrange(len(decoded_target))])
  273. e += 1
  274. for b in xrange(nprint - errors):
  275. task_print([inpt[l][b] for l in xrange(len(inpt))],
  276. [decoded_output[l][b] for l in xrange(len(decoded_target))],
  277. [decoded_target[l][b] for l in xrange(len(decoded_target))])
  278. return errors, total, sum(seq)
  279. def safe_exp(x):
  280. perp = 10000
  281. if x < 100: perp = math.exp(x)
  282. if perp > 10000: return 10000
  283. return perp