program_utils.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441
  1. # Copyright 2015 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Utilities for generating program synthesis and evaluation data."""
  16. import contextlib
  17. import sys
  18. import StringIO
  19. import random
  20. import os
  21. class ListType(object):
  22. def __init__(self, arg):
  23. self.arg = arg
  24. def __str__(self):
  25. return "[" + str(self.arg) + "]"
  26. def __eq__(self, other):
  27. if not isinstance(other, ListType):
  28. return False
  29. return self.arg == other.arg
  30. def __hash__(self):
  31. return hash(self.arg)
  32. class VarType(object):
  33. def __init__(self, arg):
  34. self.arg = arg
  35. def __str__(self):
  36. return str(self.arg)
  37. def __eq__(self, other):
  38. if not isinstance(other, VarType):
  39. return False
  40. return self.arg == other.arg
  41. def __hash__(self):
  42. return hash(self.arg)
  43. class FunctionType(object):
  44. def __init__(self, args):
  45. self.args = args
  46. def __str__(self):
  47. return str(self.args[0]) + " -> " + str(self.args[1])
  48. def __eq__(self, other):
  49. if not isinstance(other, FunctionType):
  50. return False
  51. return self.args == other.args
  52. def __hash__(self):
  53. return hash(tuple(self.args))
  54. class Function(object):
  55. def __init__(self, name, arg_types, output_type, fn_arg_types = None):
  56. self.name = name
  57. self.arg_types = arg_types
  58. self.fn_arg_types = fn_arg_types or []
  59. self.output_type = output_type
  60. Null = 100
  61. ## Functions
  62. f_head = Function("c_head", [ListType("Int")], "Int")
  63. def c_head(xs): return xs[0] if len(xs) > 0 else Null
  64. f_last = Function("c_last", [ListType("Int")], "Int")
  65. def c_last(xs): return xs[-1] if len(xs) > 0 else Null
  66. f_take = Function("c_take", ["Int", ListType("Int")], ListType("Int"))
  67. def c_take(n, xs): return xs[:n]
  68. f_drop = Function("c_drop", ["Int", ListType("Int")], ListType("Int"))
  69. def c_drop(n, xs): return xs[n:]
  70. f_access = Function("c_access", ["Int", ListType("Int")], "Int")
  71. def c_access(n, xs): return xs[n] if n >= 0 and len(xs) > n else Null
  72. f_max = Function("c_max", [ListType("Int")], "Int")
  73. def c_max(xs): return max(xs) if len(xs) > 0 else Null
  74. f_min = Function("c_min", [ListType("Int")], "Int")
  75. def c_min(xs): return min(xs) if len(xs) > 0 else Null
  76. f_reverse = Function("c_reverse", [ListType("Int")], ListType("Int"))
  77. def c_reverse(xs): return list(reversed(xs))
  78. f_sort = Function("sorted", [ListType("Int")], ListType("Int"))
  79. # def c_sort(xs): return sorted(xs)
  80. f_sum = Function("sum", [ListType("Int")], "Int")
  81. # def c_sum(xs): return sum(xs)
  82. ## Lambdas
  83. # Int -> Int
  84. def plus_one(x): return x + 1
  85. def minus_one(x): return x - 1
  86. def times_two(x): return x * 2
  87. def neg(x): return x * (-1)
  88. def div_two(x): return int(x/2)
  89. def sq(x): return x**2
  90. def times_three(x): return x * 3
  91. def div_three(x): return int(x/3)
  92. def times_four(x): return x * 4
  93. def div_four(x): return int(x/4)
  94. # Int -> Bool
  95. def pos(x): return x > 0
  96. def neg(x): return x < 0
  97. def even(x): return x%2 == 0
  98. def odd(x): return x%2 == 1
  99. # Int -> Int -> Int
  100. def add(x, y): return x + y
  101. def sub(x, y): return x - y
  102. def mul(x, y): return x * y
  103. # HOFs
  104. f_map = Function("map", [ListType("Int")],
  105. ListType("Int"),
  106. [FunctionType(["Int", "Int"])])
  107. f_filter = Function("filter", [ListType("Int")],
  108. ListType("Int"),
  109. [FunctionType(["Int", "Bool"])])
  110. f_count = Function("c_count", [ListType("Int")],
  111. "Int",
  112. [FunctionType(["Int", "Bool"])])
  113. def c_count(f, xs): return len([x for x in xs if f(x)])
  114. f_zipwith = Function("c_zipwith", [ListType("Int"), ListType("Int")],
  115. ListType("Int"),
  116. [FunctionType(["Int", "Int", "Int"])]) #FIX
  117. def c_zipwith(f, xs, ys): return [f(x, y) for (x, y) in zip(xs, ys)]
  118. f_scan = Function("c_scan", [ListType("Int")],
  119. ListType("Int"),
  120. [FunctionType(["Int", "Int", "Int"])])
  121. def c_scan(f, xs):
  122. out = xs
  123. for i in range(1, len(xs)):
  124. out[i] = f(xs[i], xs[i -1])
  125. return out
  126. @contextlib.contextmanager
  127. def stdoutIO(stdout=None):
  128. old = sys.stdout
  129. if stdout is None:
  130. stdout = StringIO.StringIO()
  131. sys.stdout = stdout
  132. yield stdout
  133. sys.stdout = old
  134. def evaluate(program_str, input_names_to_vals, default="ERROR"):
  135. exec_str = []
  136. for name, val in input_names_to_vals.iteritems():
  137. exec_str += name + " = " + str(val) + "; "
  138. exec_str += program_str
  139. if type(exec_str) is list:
  140. exec_str = "".join(exec_str)
  141. with stdoutIO() as s:
  142. # pylint: disable=bare-except
  143. try:
  144. exec exec_str + " print(out)"
  145. return s.getvalue()[:-1]
  146. except:
  147. return default
  148. # pylint: enable=bare-except
  149. class Statement(object):
  150. """Statement class."""
  151. def __init__(self, fn, output_var, arg_vars, fn_args=None):
  152. self.fn = fn
  153. self.output_var = output_var
  154. self.arg_vars = arg_vars
  155. self.fn_args = fn_args or []
  156. def __str__(self):
  157. return "%s = %s(%s%s%s)"%(self.output_var,
  158. self.fn.name,
  159. ", ".join(self.fn_args),
  160. ", " if self.fn_args else "",
  161. ", ".join(self.arg_vars))
  162. def substitute(self, env):
  163. self.output_var = env.get(self.output_var, self.output_var)
  164. self.arg_vars = [env.get(v, v) for v in self.arg_vars]
  165. class ProgramGrower(object):
  166. """Grow programs."""
  167. def __init__(self, functions, types_to_lambdas):
  168. self.functions = functions
  169. self.types_to_lambdas = types_to_lambdas
  170. def grow_body(self, new_var_name, dependencies, types_to_vars):
  171. """Grow the program body."""
  172. choices = []
  173. for f in self.functions:
  174. if all([a in types_to_vars.keys() for a in f.arg_types]):
  175. choices.append(f)
  176. f = random.choice(choices)
  177. args = []
  178. for t in f.arg_types:
  179. possible_vars = random.choice(types_to_vars[t])
  180. var = random.choice(possible_vars)
  181. args.append(var)
  182. dependencies.setdefault(new_var_name, []).extend(
  183. [var] + (dependencies[var]))
  184. fn_args = [random.choice(self.types_to_lambdas[t]) for t in f.fn_arg_types]
  185. types_to_vars.setdefault(f.output_type, []).append(new_var_name)
  186. return Statement(f, new_var_name, args, fn_args)
  187. def grow(self, program_len, input_types):
  188. """Grow the program."""
  189. var_names = list(reversed(map(chr, range(97, 123))))
  190. dependencies = dict()
  191. types_to_vars = dict()
  192. input_names = []
  193. for t in input_types:
  194. var = var_names.pop()
  195. dependencies[var] = []
  196. types_to_vars.setdefault(t, []).append(var)
  197. input_names.append(var)
  198. statements = []
  199. for _ in range(program_len - 1):
  200. var = var_names.pop()
  201. statements.append(self.grow_body(var, dependencies, types_to_vars))
  202. statements.append(self.grow_body("out", dependencies, types_to_vars))
  203. new_var_names = [c for c in map(chr, range(97, 123))
  204. if c not in input_names]
  205. new_var_names.reverse()
  206. keep_statements = []
  207. env = dict()
  208. for s in statements:
  209. if s.output_var in dependencies["out"]:
  210. keep_statements.append(s)
  211. env[s.output_var] = new_var_names.pop()
  212. if s.output_var == "out":
  213. keep_statements.append(s)
  214. for k in keep_statements:
  215. k.substitute(env)
  216. return Program(input_names, input_types, ";".join(
  217. [str(k) for k in keep_statements]))
  218. class Program(object):
  219. """The program class."""
  220. def __init__(self, input_names, input_types, body):
  221. self.input_names = input_names
  222. self.input_types = input_types
  223. self.body = body
  224. def evaluate(self, inputs):
  225. """Evaluate this program."""
  226. if len(inputs) != len(self.input_names):
  227. raise AssertionError("inputs and input_names have to"
  228. "have the same len. inp: %s , names: %s" %
  229. (str(inputs), str(self.input_names)))
  230. inp_str = ""
  231. for (name, inp) in zip(self.input_names, inputs):
  232. inp_str += name + " = " + str(inp) + "; "
  233. with stdoutIO() as s:
  234. # pylint: disable=exec-used
  235. exec inp_str + self.body + "; print(out)"
  236. # pylint: enable=exec-used
  237. return s.getvalue()[:-1]
  238. def flat_str(self):
  239. out = ""
  240. for s in self.body.split(";"):
  241. out += s + ";"
  242. return out
  243. def __str__(self):
  244. out = ""
  245. for (n, t) in zip(self.input_names, self.input_types):
  246. out += n + " = " + str(t) + "\n"
  247. for s in self.body.split(";"):
  248. out += s + "\n"
  249. return out
  250. prog_vocab = []
  251. prog_rev_vocab = {}
  252. def tokenize(string, tokens=None):
  253. """Tokenize the program string."""
  254. if tokens is None:
  255. tokens = prog_vocab
  256. tokens = sorted(tokens, key=len, reverse=True)
  257. out = []
  258. string = string.strip()
  259. while string:
  260. found = False
  261. for t in tokens:
  262. if string.startswith(t):
  263. out.append(t)
  264. string = string[len(t):]
  265. found = True
  266. break
  267. if not found:
  268. raise ValueError("Couldn't tokenize this: " + string)
  269. string = string.strip()
  270. return out
  271. def clean_up(output, max_val=100):
  272. o = eval(str(output))
  273. if isinstance(o, bool):
  274. return o
  275. if isinstance(o, int):
  276. if o >= 0:
  277. return min(o, max_val)
  278. else:
  279. return max(o, -1 * max_val)
  280. if isinstance(o, list):
  281. return [clean_up(l) for l in o]
  282. def make_vocab():
  283. gen(2, 0)
  284. def gen(max_len, how_many):
  285. """Generate some programs."""
  286. functions = [f_head, f_last, f_take, f_drop, f_access, f_max, f_min,
  287. f_reverse, f_sort, f_sum, f_map, f_filter, f_count, f_zipwith,
  288. f_scan]
  289. types_to_lambdas = {
  290. FunctionType(["Int", "Int"]): ["plus_one", "minus_one", "times_two",
  291. "div_two", "sq", "times_three",
  292. "div_three", "times_four", "div_four"],
  293. FunctionType(["Int", "Bool"]): ["pos", "neg", "even", "odd"],
  294. FunctionType(["Int", "Int", "Int"]): ["add", "sub", "mul"]
  295. }
  296. tokens = []
  297. for f in functions:
  298. tokens.append(f.name)
  299. for v in types_to_lambdas.values():
  300. tokens.extend(v)
  301. tokens.extend(["=", ";", ",", "(", ")", "[", "]", "Int", "out"])
  302. tokens.extend(map(chr, range(97, 123)))
  303. io_tokens = map(str, range(-220, 220))
  304. if not prog_vocab:
  305. prog_vocab.extend(["_PAD", "_EOS"] + tokens + io_tokens)
  306. for i, t in enumerate(prog_vocab):
  307. prog_rev_vocab[t] = i
  308. io_tokens += [",", "[", "]", ")", "(", "None"]
  309. grower = ProgramGrower(functions=functions,
  310. types_to_lambdas=types_to_lambdas)
  311. def mk_inp(l):
  312. return [random.choice(range(-5, 5)) for _ in range(l)]
  313. tar = [ListType("Int")]
  314. inps = [[mk_inp(3)], [mk_inp(5)], [mk_inp(7)], [mk_inp(15)]]
  315. save_prefix = None
  316. outcomes_to_programs = dict()
  317. tried = set()
  318. counter = 0
  319. choices = [0] if max_len == 0 else range(max_len)
  320. while counter < 100 * how_many and len(outcomes_to_programs) < how_many:
  321. counter += 1
  322. length = random.choice(choices)
  323. t = grower.grow(length, tar)
  324. while t in tried:
  325. length = random.choice(choices)
  326. t = grower.grow(length, tar)
  327. # print(t.flat_str())
  328. tried.add(t)
  329. outcomes = [clean_up(t.evaluate(i)) for i in inps]
  330. outcome_str = str(zip(inps, outcomes))
  331. if outcome_str in outcomes_to_programs:
  332. outcomes_to_programs[outcome_str] = min(
  333. [t.flat_str(), outcomes_to_programs[outcome_str]],
  334. key=lambda x: len(tokenize(x, tokens)))
  335. else:
  336. outcomes_to_programs[outcome_str] = t.flat_str()
  337. if counter % 5000 == 0:
  338. print "== proggen: tried: " + str(counter)
  339. print "== proggen: kept: " + str(len(outcomes_to_programs))
  340. if counter % 250000 == 0 and save_prefix is not None:
  341. print "saving..."
  342. save_counter = 0
  343. progfilename = os.path.join(save_prefix, "prog_" + str(counter) + ".txt")
  344. iofilename = os.path.join(save_prefix, "io_" + str(counter) + ".txt")
  345. prog_token_filename = os.path.join(save_prefix,
  346. "prog_tokens_" + str(counter) + ".txt")
  347. io_token_filename = os.path.join(save_prefix,
  348. "io_tokens_" + str(counter) + ".txt")
  349. with open(progfilename, "a+") as fp, \
  350. open(iofilename, "a+") as fi, \
  351. open(prog_token_filename, "a+") as ftp, \
  352. open(io_token_filename, "a+") as fti:
  353. for (o, p) in outcomes_to_programs.iteritems():
  354. save_counter += 1
  355. if save_counter % 500 == 0:
  356. print "saving %d of %d" % (save_counter, len(outcomes_to_programs))
  357. fp.write(p+"\n")
  358. fi.write(o+"\n")
  359. ftp.write(str(tokenize(p, tokens))+"\n")
  360. fti.write(str(tokenize(o, io_tokens))+"\n")
  361. return list(outcomes_to_programs.values())