- # Copyright 2015 Google Inc. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """Neural GPU -- data generation and batching utilities."""
- import math
- import os
- import random
- import sys
- import time
- import numpy as np
- import tensorflow as tf
- import program_utils
- FLAGS = tf.app.flags.FLAGS
# Bucket sizes used for padding; bucket i holds sequences of length <= 2 + i.
# NOTE: range (not Python-2-only xrange) keeps this working on Python 3.
bins = [2 + bin_idx_i for bin_idx_i in range(256)]

# Names of all tasks this module can generate data for.
all_tasks = ["sort", "kvsort", "id", "rev", "rev2", "incr", "add", "left",
             "right", "left-shift", "right-shift", "bmul", "mul", "dup",
             "badd", "qadd", "search", "progeval", "progsynth"]

# Path of the log file; set by the main binary before print_out is used.
log_filename = ""

# Optional token vocabularies, populated externally (None until then).
vocab, rev_vocab = None, None
def pad(l):
  """Return the smallest bucket size in `bins` that fits length l.

  Falls back to the largest bucket when l exceeds every bucket.
  """
  return next((bucket for bucket in bins if bucket >= l), bins[-1])
def bin_for(l):
  """Return the index of the smallest bucket in `bins` that fits length l.

  Falls back to the last bucket index when l exceeds every bucket.
  """
  return next((idx for idx, bucket in enumerate(bins) if bucket >= l),
              len(bins) - 1)
# Per-task, per-bucket storage for generated examples:
# train_set[task][bucket_index] is a list of cases for that length bucket.
# range (not Python-2-only xrange) keeps this working on Python 3.
train_set = {}
test_set = {}
for some_task in all_tasks:
  train_set[some_task] = [[] for _ in range(10000)]
  test_set[some_task] = [[] for _ in range(10000)]
def read_tmp_file(name):
  """Read from a file with the given name in our log directory or above.

  Looks for <name>.txt first in the log directory, then in its parent and
  grandparent directories (the original code repeated the same
  lookup three times; this collapses the duplication into one loop).

  Args:
    name: base file name without the ".txt" extension.

  Returns:
    List of stripped lines from the first file found, or None if the file
    exists in none of the three directories.
  """
  dirname = os.path.dirname(log_filename)
  # Search the log directory, then up to two levels above it.
  for prefix in ["", "../", "../../"]:
    fname = os.path.join(dirname, prefix + name + ".txt")
    if tf.gfile.Exists(fname):
      print_out("== found file: " + fname)
      with tf.gfile.GFile(fname, mode="r") as f:
        return [line.strip() for line in f]
    print_out("== not found file: " + fname)
  return None
def write_tmp_file(name, lines):
  """Write `lines` to <name>.txt in the log directory, one line per entry."""
  target = os.path.join(os.path.dirname(log_filename), name + ".txt")
  with tf.gfile.GFile(target, mode="w") as f:
    for entry in lines:
      f.write(entry + "\n")
def add(n1, n2, base=10):
  """Add two numbers represented as lower-endian digit lists.

  Args:
    n1: first number as a list of digits, least-significant digit first.
    n2: second number in the same representation.
    base: numeric base of the digits.

  Returns:
    Digit list (least-significant first) of n1 + n2 with high-order zero
    digits stripped; [0] when the sum is zero.
  """
  # One extra slot leaves room for a final carry.
  k = max(len(n1), len(n2)) + 1
  d1 = n1 + [0] * (k - len(n1))
  d2 = n2 + [0] * (k - len(n2))
  res = []
  carry = 0
  for i in range(k):  # range, not Python-2-only xrange.
    carry, digit = divmod(d1[i] + d2[i] + carry, base)
    res.append(digit)
  # Strip high-order (trailing) zeros, but keep a single 0 for zero.
  while len(res) > 1 and res[-1] == 0:
    res.pop()
  return res
def init_data(task, length, nbr_cases, nclass):
  """Data initialization: generate cases for `task` into the global sets.

  Appends `nbr_cases` training and testing examples of total length
  <= `length` into the module-level train_set/test_set buckets.

  All integer divisions are written with `//` (and loops with `range`)
  so the code behaves identically on Python 2 and Python 3.

  Args:
    task: one of the names in all_tasks.
    length: maximum total sequence length of generated cases.
    nbr_cases: number of cases to generate.
    nclass: number of symbol classes (vocabulary size).
  """
  def rand_pair(l, task):
    """Random data pair for a task. Total length should be <= l."""
    k = (l - 1) // 2  # Digits per operand; one slot reserved for separator.
    base = 10
    if task[0] == "b": base = 2
    if task[0] == "q": base = 4
    d1 = [np.random.randint(base) for _ in range(k)]
    d2 = [np.random.randint(base) for _ in range(k)]
    if task in ["add", "badd", "qadd"]:
      res = add(d1, d2, base)
    elif task in ["mul", "bmul"]:
      d1n = sum([d * (base ** i) for i, d in enumerate(d1)])
      d2n = sum([d * (base ** i) for i, d in enumerate(d2)])
      if task == "bmul":
        # Reversed binary string; [:-2] drops the reversed "0b" prefix.
        res = [int(x) for x in list(reversed(str(bin(d1n * d2n))))[:-2]]
      else:
        res = [int(x) for x in list(reversed(str(d1n * d2n)))]
    else:
      sys.exit()
    # Ids are digits shifted by 1 (0 is padding); 11 = "+", 12 = "*".
    sep = [12]
    if task in ["add", "badd", "qadd"]: sep = [11]
    inp = [d + 1 for d in d1] + sep + [d + 1 for d in d2]
    return inp, [r + 1 for r in res]

  def rand_dup_pair(l):
    """Random data pair for duplication task. Total length should be <= l."""
    k = l // 2
    x = [np.random.randint(nclass - 1) + 1 for _ in range(k)]
    inp = x + [0 for _ in range(l - k)]
    res = x + x + [0 for _ in range(l - 2 * k)]
    return inp, res

  def rand_rev2_pair(l):
    """Random data pair for reverse2 task. Total length should be <= l."""
    inp = [(np.random.randint(nclass - 1) + 1,
            np.random.randint(nclass - 1) + 1) for _ in range(l // 2)]
    res = [i for i in reversed(inp)]
    return [x for p in inp for x in p], [x for p in res for x in p]

  def rand_search_pair(l):
    """Random data pair for search task. Total length should be <= l."""
    # BUG FIX: the original used xrange(l - 1/2), which in Python 2 is
    # xrange(l) and yields inputs of length 2l+1 > l, violating the
    # contract above. (l - 1) // 2 pairs plus the query fit within l.
    inp = [(np.random.randint(nclass - 1) + 1,
            np.random.randint(nclass - 1) + 1) for _ in range((l - 1) // 2)]
    q = np.random.randint(nclass - 1) + 1
    res = 0
    for (k, v) in reversed(inp):
      if k == q:
        res = v
    return [x for p in inp for x in p] + [q], [res]

  def rand_kvsort_pair(l):
    """Random data pair for key-value sort. Total length should be <= l."""
    keys = [(np.random.randint(nclass - 1) + 1, i) for i in range(l // 2)]
    vals = [np.random.randint(nclass - 1) + 1 for _ in range(l // 2)]
    kv = [(k, vals[i]) for (k, i) in keys]
    sorted_kv = [(k, vals[i]) for (k, i) in sorted(keys)]
    return [x for p in kv for x in p], [x for p in sorted_kv for x in p]

  def prog_io_pair(prog, max_len, counter=0):
    """Random input/output token pair for program `prog`; retries on failure."""
    try:
      ilen = np.random.randint(max_len - 3) + 1
      # Shrink the value range as retries accumulate to ease generation.
      bound = max(15 - (counter // 20), 1)
      inp = [random.choice(range(-bound, bound)) for _ in range(ilen)]
      inp_toks = [program_utils.prog_rev_vocab[t]
                  for t in program_utils.tokenize(str(inp)) if t != ","]
      out = program_utils.evaluate(prog, {"a": inp})
      out_toks = [program_utils.prog_rev_vocab[t]
                  for t in program_utils.tokenize(str(out)) if t != ","]
      if counter > 400:
        out_toks = []  # Give up: force the length check below to succeed.
      if (out_toks and out_toks[0] == program_utils.prog_rev_vocab["["] and
          len(out_toks) != len([o for o in out if o == ","]) + 3):
        raise ValueError("generated list with too long ints")
      if (out_toks and out_toks[0] != program_utils.prog_rev_vocab["["] and
          len(out_toks) > 1):
        raise ValueError("generated one int but tokenized it to many")
      if len(out_toks) > max_len:
        raise ValueError("output too long")
      return (inp_toks, out_toks)
    except ValueError:
      return prog_io_pair(prog, max_len, counter + 1)

  def spec(inp):
    """Return the target given the input for some tasks."""
    if task == "sort":
      return sorted(inp)
    elif task == "id":
      return inp
    elif task == "rev":
      return [i for i in reversed(inp)]
    elif task == "incr":
      carry = 1
      res = []
      for i in range(len(inp)):
        if inp[i] + carry < nclass:
          res.append(inp[i] + carry)
          carry = 0
        else:
          res.append(1)
          carry = 1
      return res
    elif task == "left":
      return [inp[0]]
    elif task == "right":
      return [inp[-1]]
    elif task == "left-shift":
      return [inp[l - 1] for l in range(len(inp))]
    elif task == "right-shift":
      # NOTE(review): this indexes inp[len(inp)] and will raise IndexError
      # for non-empty inp, exactly as the original did — kept as-is.
      return [inp[l + 1] for l in range(len(inp))]
    else:
      print_out("Unknown spec for task " + str(task))
      sys.exit()

  l = length
  cur_time = time.time()
  total_time = 0.0
  is_prog = task in ["progeval", "progsynth"]
  if is_prog:
    inputs_per_prog = 5
    program_utils.make_vocab()
    progs = read_tmp_file("programs_len%d" % (l // 10))
    if not progs:
      progs = program_utils.gen(l // 10, 1.2 * nbr_cases / inputs_per_prog)
      write_tmp_file("programs_len%d" % (l // 10), progs)
    prog_ios = read_tmp_file("programs_len%d_io" % (l // 10))
    nbr_cases = min(nbr_cases, len(progs) * inputs_per_prog) / 1.2
    if not prog_ios:
      # Generate program io data.
      prog_ios = []
      for pidx, prog in enumerate(progs):
        if pidx % 500 == 0:
          print_out("== generating io pairs for program %d" % pidx)
        if pidx * inputs_per_prog > nbr_cases * 1.2:
          break
        ptoks = [program_utils.prog_rev_vocab[t]
                 for t in program_utils.tokenize(prog)]
        ptoks.append(program_utils.prog_rev_vocab["_EOS"])
        plen = len(ptoks)
        for _ in range(inputs_per_prog):
          if task == "progeval":
            inp, out = prog_io_pair(prog, plen)
            prog_ios.append(str(inp) + "\t" + str(out) + "\t" + prog)
          elif task == "progsynth":
            plen = max(len(ptoks), 8)
            for _ in range(3):
              inp, out = prog_io_pair(prog, plen // 2)
              prog_ios.append(str(inp) + "\t" + str(out) + "\t" + prog)
      write_tmp_file("programs_len%d_io" % (l // 10), prog_ios)
    prog_ios_dict = {}
    for s in prog_ios:
      i, o, p = s.split("\t")
      # Strip brackets/signs; keep digits and spaces only.
      i_clean = "".join([c for c in i if c.isdigit() or c == " "])
      o_clean = "".join([c for c in o if c.isdigit() or c == " "])
      inp = [int(x) for x in i_clean.split()]
      out = [int(x) for x in o_clean.split()]
      if inp and out:
        if p in prog_ios_dict:
          prog_ios_dict[p].append([inp, out])
        else:
          prog_ios_dict[p] = [[inp, out]]
    # Use prog_ios_dict to create data.
    progs = []
    for prog in prog_ios_dict:
      if len([c for c in prog if c == ";"]) <= (l // 10):
        progs.append(prog)
    nbr_cases = min(nbr_cases, len(progs) * inputs_per_prog) / 1.2
    print_out("== %d training cases on %d progs" % (nbr_cases, len(progs)))
    for pidx, prog in enumerate(progs):
      if pidx * inputs_per_prog > nbr_cases * 1.2:
        break
      ptoks = [program_utils.prog_rev_vocab[t]
               for t in program_utils.tokenize(prog)]
      ptoks.append(program_utils.prog_rev_vocab["_EOS"])
      plen = len(ptoks)
      # First chunk of programs goes to training, the rest to testing.
      dset = train_set if pidx < nbr_cases / inputs_per_prog else test_set
      for _ in range(inputs_per_prog):
        if task == "progeval":
          inp, out = prog_ios_dict[prog].pop()
          dset[task][bin_for(plen)].append([[ptoks, inp, [], []], [out]])
        elif task == "progsynth":
          plen, ilist = max(len(ptoks), 8), [[]]
          for _ in range(3):
            inp, out = prog_ios_dict[prog].pop()
            ilist.append(inp + out)
          dset[task][bin_for(plen)].append([ilist, [ptoks]])

  for case in range(0 if is_prog else nbr_cases):
    total_time += time.time() - cur_time
    cur_time = time.time()
    if l > 10000 and case % 100 == 1:
      print_out("  avg gen time %.4f s" % (total_time / float(case)))
    if task in ["add", "badd", "qadd", "bmul", "mul"]:
      i, t = rand_pair(l, task)
      train_set[task][bin_for(len(i))].append([[[], i, [], []], [t]])
      i, t = rand_pair(l, task)
      test_set[task][bin_for(len(i))].append([[[], i, [], []], [t]])
    elif task == "dup":
      i, t = rand_dup_pair(l)
      train_set[task][bin_for(len(i))].append([[i], [t]])
      i, t = rand_dup_pair(l)
      test_set[task][bin_for(len(i))].append([[i], [t]])
    elif task == "rev2":
      i, t = rand_rev2_pair(l)
      train_set[task][bin_for(len(i))].append([[i], [t]])
      i, t = rand_rev2_pair(l)
      test_set[task][bin_for(len(i))].append([[i], [t]])
    elif task == "search":
      i, t = rand_search_pair(l)
      train_set[task][bin_for(len(i))].append([[i], [t]])
      i, t = rand_search_pair(l)
      test_set[task][bin_for(len(i))].append([[i], [t]])
    elif task == "kvsort":
      i, t = rand_kvsort_pair(l)
      train_set[task][bin_for(len(i))].append([[i], [t]])
      i, t = rand_kvsort_pair(l)
      test_set[task][bin_for(len(i))].append([[i], [t]])
    elif task not in ["progeval", "progsynth"]:
      inp = [np.random.randint(nclass - 1) + 1 for i in range(l)]
      target = spec(inp)
      train_set[task][bin_for(l)].append([[inp], [target]])
      inp = [np.random.randint(nclass - 1) + 1 for i in range(l)]
      target = spec(inp)
      test_set[task][bin_for(l)].append([[inp], [target]])
def to_symbol(i):
  """Convert an id to its text symbol ("" for padding, "+"/"*" operators)."""
  special = {0: "", 11: "+", 12: "*"}
  if i in special:
    return special[i]
  return str(i - 1)
def to_id(s):
  """Convert a text symbol to its id (inverse of to_symbol)."""
  operators = {"+": 11, "*": 12}
  return operators[s] if s in operators else int(s) + 1
def get_batch(bin_id, batch_size, data_set, height, offset=None, preset=None):
  """Get a batch of data, training or testing.

  Uses range (not Python-2-only xrange) so it runs on Python 3.

  Args:
    bin_id: index into the global `bins`; selects the padding length.
    batch_size: number of examples in the batch.
    data_set: per-bucket list of [inputs, targets] cases.
    height: number of input rows; single-row inputs get zero rows
      appended up to this height.
    offset: if set, take data_set[bin_id][offset + b] for slot b when in
      range (otherwise that slot falls back to a random choice).
    preset: if set, this single [inputs, targets] element fills every slot.

  Returns:
    (inputs, targets) int32 arrays shaped [batch_size, height, pad_length]
    and [batch_size, 1, pad_length].
  """
  inputs, targets = [], []
  pad_length = bins[bin_id]
  for b in range(batch_size):
    if preset is None:
      elem = random.choice(data_set[bin_id])
      if offset is not None and offset + b < len(data_set[bin_id]):
        elem = data_set[bin_id][offset + b]
    else:
      elem = preset
    inpt, targett, inpl, targetl = elem[0], elem[1], [], []
    for inp in inpt:
      inpl.append(inp + [0] * (pad_length - len(inp)))
    if len(inpl) == 1:
      # Single-row input: pad with all-zero rows up to `height`.
      for _ in range(height - 1):
        inpl.append([0] * pad_length)
    for target in targett:
      targetl.append(target + [0] * (pad_length - len(target)))
    inputs.append(inpl)
    targets.append(targetl)
  res_input = np.array(inputs, dtype=np.int32)
  res_target = np.array(targets, dtype=np.int32)
  assert list(res_input.shape) == [batch_size, height, pad_length]
  assert list(res_target.shape) == [batch_size, 1, pad_length]
  return res_input, res_target
def print_out(s, newline=True):
  """Print a message out and log it to file.

  Writes `s` to stdout always, and appends it to `log_filename` when that
  global is set. Logging is best-effort: failures are reported to stderr
  but never raised to the caller.
  """
  if log_filename:
    try:
      with tf.gfile.GFile(log_filename, mode="a") as f:
        f.write(s + ("\n" if newline else ""))
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
    # still propagate; any other logging failure stays best-effort.
    except Exception:
      sys.stderr.write("Error appending to %s\n" % log_filename)
  sys.stdout.write(s + ("\n" if newline else ""))
  sys.stdout.flush()
def decode(output):
  """Argmax-decode: pick the highest-scoring class at every step."""
  decoded = []
  for logits in output:
    decoded.append(np.argmax(logits, axis=1))
  return decoded
def accuracy(inpt_t, output, target_t, batch_size, nprint,
             beam_out=None, beam_scores=None):
  """Calculate output accuracy given target.

  Uses range (not Python-2-only xrange) so it runs on Python 3; the
  logic is otherwise unchanged.

  Args:
    inpt_t: int array [batch, height, length] of inputs.
    output: list of per-step logits arrays, each [batch, nclass].
    target_t: int array [batch, 1, length] of targets.
    batch_size: number of sequences in the batch.
    nprint: how many example sequences to print for inspection.
    beam_out: optional beam-search output, indexed [b, 0, l].
    beam_scores: optional per-sequence scores; beam output replaces the
      greedy decode for a sequence when its score is >= 10.0.

  Returns:
    (errors, total, seq_errors): wrong-symbol count, counted (non-padding)
    symbol count, and number of sequences with at least one error.
  """
  assert nprint < batch_size + 1
  inpt = []
  for h in range(inpt_t.shape[1]):
    inpt.extend([inpt_t[:, h, l] for l in range(inpt_t.shape[2])])
  target = [target_t[:, 0, l] for l in range(target_t.shape[2])]
  def tok(i):
    # Use the reverse vocabulary when present, else the digit convention.
    if rev_vocab and i < len(rev_vocab):
      return rev_vocab[i]
    return str(i - 1)
  def task_print(inp, output, target):
    # Print input/output/target up to the first padding symbol in target.
    stop_bound = 0
    print_len = 0
    while print_len < len(target) and target[print_len] > stop_bound:
      print_len += 1
    print_out(" i: " + " ".join([tok(i) for i in inp if i > 0]))
    print_out(" o: " +
              " ".join([tok(output[l]) for l in range(print_len)]))
    print_out(" t: " +
              " ".join([tok(target[l]) for l in range(print_len)]))
  decoded_target = target
  decoded_output = decode(output)
  # Use beam output if given and score is high enough.
  if beam_out is not None:
    for b in range(batch_size):
      if beam_scores[b] >= 10.0:
        for l in range(min(len(decoded_output), beam_out.shape[2])):
          decoded_output[l][b] = int(beam_out[b, 0, l])
  total = 0
  errors = 0
  seq = [0 for b in range(batch_size)]
  for l in range(len(decoded_output)):
    for b in range(batch_size):
      if decoded_target[l][b] > 0:  # Padding positions are not counted.
        total += 1
        if decoded_output[l][b] != decoded_target[l][b]:
          seq[b] = 1
          errors += 1
  e = 0  # Previous error index
  # Print up to nprint sequences that contain errors...
  for _ in range(min(nprint, sum(seq))):
    while seq[e] == 0:
      e += 1
    task_print([inpt[l][e] for l in range(len(inpt))],
               [decoded_output[l][e] for l in range(len(decoded_target))],
               [decoded_target[l][e] for l in range(len(decoded_target))])
    e += 1
  # ...then fill the remaining print budget from the start of the batch.
  for b in range(nprint - errors):
    task_print([inpt[l][b] for l in range(len(inpt))],
               [decoded_output[l][b] for l in range(len(decoded_target))],
               [decoded_target[l][b] for l in range(len(decoded_target))])
  return errors, total, sum(seq)
def safe_exp(x):
  """Exponential of x, clipped at 10000 to avoid overflow for large x."""
  x = float(x)
  if x >= 100:
    return 10000
  return min(math.exp(x), 10000)
|