data_utils.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. """Convolutional Gated Recurrent Networks for Algorithm Learning."""
  2. import math
  3. import random
  4. import sys
  5. import time
  6. import numpy as np
  7. import tensorflow as tf
  8. from tensorflow.python.platform import gfile
# Module-level handle to TensorFlow's command-line flag values (tf.app.flags).
# NOTE(review): not referenced in the visible part of this module; presumably
# read elsewhere in the file or by importers — confirm before removing.
FLAGS = tf.app.flags.FLAGS
  10. bins = [8, 16, 32, 64, 128]
  11. all_tasks = ["sort", "id", "rev", "incr", "left", "right", "left-shift", "add",
  12. "right-shift", "bmul", "dup", "badd", "qadd"]
  13. forward_max = 128
  14. log_filename = ""
  15. def pad(l):
  16. for b in bins:
  17. if b >= l: return b
  18. return forward_max
  19. train_set = {}
  20. test_set = {}
  21. for some_task in all_tasks:
  22. train_set[some_task] = []
  23. test_set[some_task] = []
  24. for all_max_len in xrange(10000):
  25. train_set[some_task].append([])
  26. test_set[some_task].append([])
  27. def add(n1, n2, base=10):
  28. """Add two numbers represented as lower-endian digit lists."""
  29. k = max(len(n1), len(n2)) + 1
  30. d1 = n1 + [0 for _ in xrange(k - len(n1))]
  31. d2 = n2 + [0 for _ in xrange(k - len(n2))]
  32. res = []
  33. carry = 0
  34. for i in xrange(k):
  35. if d1[i] + d2[i] + carry < base:
  36. res.append(d1[i] + d2[i] + carry)
  37. carry = 0
  38. else:
  39. res.append(d1[i] + d2[i] + carry - base)
  40. carry = 1
  41. while res and res[-1] == 0:
  42. res = res[:-1]
  43. if res: return res
  44. return [0]
  45. def init_data(task, length, nbr_cases, nclass):
  46. """Data initialization."""
  47. def rand_pair(l, task):
  48. """Random data pair for a task. Total length should be <= l."""
  49. k = (l-1)/2
  50. base = 10
  51. if task[0] == "b": base = 2
  52. if task[0] == "q": base = 4
  53. d1 = [np.random.randint(base) for _ in xrange(k)]
  54. d2 = [np.random.randint(base) for _ in xrange(k)]
  55. if task in ["add", "badd", "qadd"]:
  56. res = add(d1, d2, base)
  57. elif task in ["bmul"]:
  58. d1n = sum([d * (base ** i) for i, d in enumerate(d1)])
  59. d2n = sum([d * (base ** i) for i, d in enumerate(d2)])
  60. res = [int(x) for x in list(reversed(str(bin(d1n * d2n))))[:-2]]
  61. else:
  62. sys.exit()
  63. sep = [12]
  64. if task in ["add", "badd", "qadd"]: sep = [11]
  65. inp = [d + 1 for d in d1] + sep + [d + 1 for d in d2]
  66. return inp, [r + 1 for r in res]
  67. def rand_dup_pair(l):
  68. """Random data pair for duplication task. Total length should be <= l."""
  69. k = l/2
  70. x = [np.random.randint(nclass - 1) + 1 for _ in xrange(k)]
  71. inp = x + [0 for _ in xrange(l - k)]
  72. res = x + x + [0 for _ in xrange(l - 2*k)]
  73. return inp, res
  74. def spec(inp):
  75. """Return the target given the input for some tasks."""
  76. if task == "sort":
  77. return sorted(inp)
  78. elif task == "id":
  79. return inp
  80. elif task == "rev":
  81. return [i for i in reversed(inp)]
  82. elif task == "incr":
  83. carry = 1
  84. res = []
  85. for i in xrange(len(inp)):
  86. if inp[i] + carry < nclass:
  87. res.append(inp[i] + carry)
  88. carry = 0
  89. else:
  90. res.append(1)
  91. carry = 1
  92. return res
  93. elif task == "left":
  94. return [inp[0]]
  95. elif task == "right":
  96. return [inp[-1]]
  97. elif task == "left-shift":
  98. return [inp[l-1] for l in xrange(len(inp))]
  99. elif task == "right-shift":
  100. return [inp[l+1] for l in xrange(len(inp))]
  101. else:
  102. print_out("Unknown spec for task " + str(task))
  103. sys.exit()
  104. l = length
  105. cur_time = time.time()
  106. total_time = 0.0
  107. for case in xrange(nbr_cases):
  108. total_time += time.time() - cur_time
  109. cur_time = time.time()
  110. if l > 10000 and case % 100 == 1:
  111. print_out(" avg gen time %.4f s" % (total_time / float(case)))
  112. if task in ["add", "badd", "qadd", "bmul"]:
  113. i, t = rand_pair(l, task)
  114. train_set[task][len(i)].append([i, t])
  115. i, t = rand_pair(l, task)
  116. test_set[task][len(i)].append([i, t])
  117. elif task == "dup":
  118. i, t = rand_dup_pair(l)
  119. train_set[task][len(i)].append([i, t])
  120. i, t = rand_dup_pair(l)
  121. test_set[task][len(i)].append([i, t])
  122. else:
  123. inp = [np.random.randint(nclass - 1) + 1 for i in xrange(l)]
  124. target = spec(inp)
  125. train_set[task][l].append([inp, target])
  126. inp = [np.random.randint(nclass - 1) + 1 for i in xrange(l)]
  127. target = spec(inp)
  128. test_set[task][l].append([inp, target])
  129. def to_symbol(i):
  130. """Covert ids to text."""
  131. if i == 0: return ""
  132. if i == 11: return "+"
  133. if i == 12: return "*"
  134. return str(i-1)
  135. def to_id(s):
  136. """Covert text to ids."""
  137. if s == "+": return 11
  138. if s == "*": return 12
  139. return int(s) + 1
  140. def get_batch(max_length, batch_size, do_train, task, offset=None, preset=None):
  141. """Get a batch of data, training or testing."""
  142. inputs = []
  143. targets = []
  144. length = max_length
  145. if preset is None:
  146. cur_set = test_set[task]
  147. if do_train: cur_set = train_set[task]
  148. while not cur_set[length]:
  149. length -= 1
  150. pad_length = pad(length)
  151. for b in xrange(batch_size):
  152. if preset is None:
  153. elem = random.choice(cur_set[length])
  154. if offset is not None and offset + b < len(cur_set[length]):
  155. elem = cur_set[length][offset + b]
  156. else:
  157. elem = preset
  158. inp, target = elem[0], elem[1]
  159. assert len(inp) == length
  160. inputs.append(inp + [0 for l in xrange(pad_length - len(inp))])
  161. targets.append(target + [0 for l in xrange(pad_length - len(target))])
  162. res_input = []
  163. res_target = []
  164. for l in xrange(pad_length):
  165. new_input = np.array([inputs[b][l] for b in xrange(batch_size)],
  166. dtype=np.int32)
  167. new_target = np.array([targets[b][l] for b in xrange(batch_size)],
  168. dtype=np.int32)
  169. res_input.append(new_input)
  170. res_target.append(new_target)
  171. return res_input, res_target
  172. def print_out(s, newline=True):
  173. """Print a message out and log it to file."""
  174. if log_filename:
  175. try:
  176. with gfile.GFile(log_filename, mode="a") as f:
  177. f.write(s + ("\n" if newline else ""))
  178. # pylint: disable=bare-except
  179. except:
  180. sys.stdout.write("Error appending to %s\n" % log_filename)
  181. sys.stdout.write(s + ("\n" if newline else ""))
  182. sys.stdout.flush()
  183. def decode(output):
  184. return [np.argmax(o, axis=1) for o in output]
  185. def accuracy(inpt, output, target, batch_size, nprint):
  186. """Calculate output accuracy given target."""
  187. assert nprint < batch_size + 1
  188. def task_print(inp, output, target):
  189. stop_bound = 0
  190. print_len = 0
  191. while print_len < len(target) and target[print_len] > stop_bound:
  192. print_len += 1
  193. print_out(" i: " + " ".join([str(i - 1) for i in inp if i > 0]))
  194. print_out(" o: " +
  195. " ".join([str(output[l] - 1) for l in xrange(print_len)]))
  196. print_out(" t: " +
  197. " ".join([str(target[l] - 1) for l in xrange(print_len)]))
  198. decoded_target = target
  199. decoded_output = decode(output)
  200. total = 0
  201. errors = 0
  202. seq = [0 for b in xrange(batch_size)]
  203. for l in xrange(len(decoded_output)):
  204. for b in xrange(batch_size):
  205. if decoded_target[l][b] > 0:
  206. total += 1
  207. if decoded_output[l][b] != decoded_target[l][b]:
  208. seq[b] = 1
  209. errors += 1
  210. e = 0 # Previous error index
  211. for _ in xrange(min(nprint, sum(seq))):
  212. while seq[e] == 0:
  213. e += 1
  214. task_print([inpt[l][e] for l in xrange(len(inpt))],
  215. [decoded_output[l][e] for l in xrange(len(decoded_target))],
  216. [decoded_target[l][e] for l in xrange(len(decoded_target))])
  217. e += 1
  218. for b in xrange(nprint - errors):
  219. task_print([inpt[l][b] for l in xrange(len(inpt))],
  220. [decoded_output[l][b] for l in xrange(len(decoded_target))],
  221. [decoded_target[l][b] for l in xrange(len(decoded_target))])
  222. return errors, total, sum(seq)
  223. def safe_exp(x):
  224. perp = 10000
  225. if x < 100: perp = math.exp(x)
  226. if perp > 10000: return 10000
  227. return perp