Browse source

add the namignizer model (#147)

K. Nathaniel Tucker 9 years ago
parent
commit
76f567df5f

+ 6 - 0
namignizer/.gitignore

@@ -0,0 +1,6 @@
+# Remove the pyc files
+*.pyc
+
+# Ignore the model and the data
+model/
+data/

+ 82 - 0
namignizer/README.md

@@ -0,0 +1,82 @@
+# Namignizer
+
+Use a variation of the [PTB](https://www.tensorflow.org/versions/r0.8/tutorials/recurrent/index.html#recurrent-neural-networks) model to recognize and generate names using the [Kaggle Baby Name Database](https://www.kaggle.com/kaggle/us-baby-names).
+
+### API
+Namignizer is implemented in TensorFlow r0.8 and uses the Python package `pandas` for some data processing.
+
+#### How to use
+Download the data from Kaggle and place it in your data directory (or use the small training data provided). The example data looks like so:
+
+```
+Id,Name,Year,Gender,Count
+1,Mary,1880,F,7065
+2,Anna,1880,F,2604
+3,Emma,1880,F,2003
+4,Elizabeth,1880,F,1939
+5,Minnie,1880,F,1746
+6,Margaret,1880,F,1578
+7,Ida,1880,F,1472
+8,Alice,1880,F,1414
+9,Bertha,1880,F,1320
+```
+
+But any data with the two columns `Name` and `Count` will work.
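+
+If you start from the full Kaggle download, a minimal pandas sketch like the
+one below (the file names are only examples) collapses it into those two
+columns; this mirrors what `data_utils.read_names` does internally:
+
+```python
+import pandas as pd
+
+# Sum the counts for each name across all years and genders.
+raw = pd.read_csv("data/NationalNames.csv")
+raw["Name"] = raw["Name"].str.lower()
+name_counts = raw.groupby("Name", as_index=False)["Count"].sum()
+name_counts.to_csv("data/AllNames.txt", index=False)
+```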
+
+With the data, we can then train the model:
+
+```python
+train("data/SmallNames.txt", "model/namignizer", SmallConfig)
+```
+
+And you will get the output:
+
+```
+Reading Name data in data/SmallNames.txt
+Epoch: 1 Learning rate: 1.000
+0.090 perplexity: 18.539 speed: 282 lps
+...
+0.890 perplexity: 1.478 speed: 285 lps
+0.990 perplexity: 1.477 speed: 284 lps
+Epoch: 13 Train Perplexity: 1.477
+```
+
+As a side effect, this will write model checkpoints to the `model` directory. With these you can determine the perplexity your model gives you for any arbitrary set of names like so:
+
+```python
+namignize(["mary", "ida", "gazorpazorp", "houyhnhnms", "bob"],
+  tf.train.latest_checkpoint("model"), SmallConfig)
+```
+Provide the same config and the same checkpoint directory. This lets you use the model you just trained. You will then get a perplexity output for each name like so:
+
+```
+Name mary gives us a perplexity of 1.03105580807
+Name ida gives us a perplexity of 1.07770049572
+Name gazorpazorp gives us a perplexity of 175.940353394
+Name houyhnhnms gives us a perplexity of 9.53870773315
+Name bob gives us a perplexity of 6.03938627243
+```
+
+Finally, you will also be able to generate names using the model like so:
+
+```python
+namignator(tf.train.latest_checkpoint("model"), SmallConfig)
+```
+
+Again, you will need to provide the same config and the same checkpoint directory so that the model you just trained is used. You will then get a single generated name. Examples of the output I got when using the provided data are:
+
+```
+['b', 'e', 'r', 't', 'h', 'a', '`']
+['m', 'a', 'r', 'y', '`']
+['a', 'n', 'n', 'a', '`']
+['m', 'a', 'r', 'y', '`']
+['b', 'e', 'r', 't', 'h', 'a', '`']
+['a', 'n', 'n', 'a', '`']
+['e', 'l', 'i', 'z', 'a', 'b', 'e', 't', 'h', '`']
+```
+
+Notice that each name ends with a backtick. This marks the end of the name.
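+
+A small sketch (not part of the module) for turning one of those lists back
+into a plain string by joining the characters and stripping the terminator:
+
+```python
+letters = ['b', 'e', 'r', 't', 'h', 'a', '`']
+name = "".join(letters).rstrip('`')  # -> 'bertha'
+```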
+
+### Contact Info
+
+Feel free to reach out to me at knt(at google) or k.nathaniel.tucker(at gmail)

+ 119 - 0
namignizer/data_utils.py

@@ -0,0 +1,119 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities for parsing Kaggle baby names files."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import os
+
+import numpy as np
+import tensorflow as tf
+import pandas as pd
+
+# the default end of name rep will be zero
+_EON = 0
+
+
+def read_names(names_path):
+    """Reads data from a downloaded file. See SmallNames.txt for an example format
+    or go to https://www.kaggle.com/kaggle/us-baby-names for full lists
+
+    Args:
+        names_path: path to the csv file similar to the example type
+    Returns:
+        Dataset: a namedtuple of two elements: deduped names and their associated
+            counts. The names contain only 26 chars and are all lower case
+    """
+    names_data = pd.read_csv(names_path)
+    names_data.Name = names_data.Name.str.lower()
+
+    name_data = names_data.groupby(by=["Name"])["Count"].sum()
+    name_counts = np.array(name_data.tolist())
+    names_deduped = np.array(name_data.index.tolist())
+
+    Dataset = collections.namedtuple('Dataset', ['Name', 'Count'])
+    return Dataset(names_deduped, name_counts)
+
+
+def _letter_to_number(letter):
+    """Converts a lowercase letter to a number between 1 and 26."""
+    # ord of lower case 'a' is 97
+    return ord(letter) - 96
+
+
+def namignizer_iterator(names, counts, batch_size, num_steps, epoch_size):
+    """Takes a list of names and counts like those output from read_names, and
+    makes an iterator yielding a batch_size by num_steps array of random names
+    separated by an end of name token. The names are chosen randomly according
+    to their counts. The batch may end mid-name
+
+    Args:
+        names: a set of lowercase names composed of 26 characters
+        counts: a list of the frequency of those names
+        batch_size: int
+        num_steps: int
+        epoch_size: number of batches to yield
+    Yields:
+        (x, y): a batch_size by num_steps array of ints representing letters, where
+            x will be the input and y will be the target
+    """
+    name_distribution = counts / counts.sum()
+
+    for i in range(epoch_size):
+        data = np.zeros(batch_size * num_steps + 1)
+        samples = np.random.choice(names, size=batch_size * num_steps // 2,
+                                   replace=True, p=name_distribution)
+
+        data_index = 0
+        for sample in samples:
+            if data_index >= batch_size * num_steps:
+                break
+            # wrap map() in list() so this also works on Python 3, where map is lazy
+            for letter in list(map(_letter_to_number, sample)) + [_EON]:
+                if data_index >= batch_size * num_steps:
+                    break
+                data[data_index] = letter
+                data_index += 1
+
+        x = data[:batch_size * num_steps].reshape((batch_size, num_steps))
+        y = data[1:batch_size * num_steps + 1].reshape((batch_size, num_steps))
+
+        yield (x, y)
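+
+# Illustrative usage only (not called in this file): with names and counts
+# from read_names, each yielded (x, y) is a batch_size x num_steps array of
+# letter ids, with y shifted one position ahead of x:
+#
+#   names, counts = read_names("data/SmallNames.txt")
+#   for x, y in namignizer_iterator(names, counts, 20, 20, epoch_size=1):
+#       assert x.shape == (20, 20) and y.shape == (20, 20)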
+
+
+def name_to_batch(name, batch_size, num_steps):
+    """ Takes a single name and fills a batch with it
+
+    Args:
+        name: lowercase composed of 26 characters
+        batch_size: int
+        num_steps: int
+    Returns:
+        x, y: a batch_size by num_steps array of ints representing letters, where
+            x will be the input and y will be the target. The array is filled up
+            to the length of the string, the rest is filled with zeros
+    """
+    data = np.zeros(batch_size * num_steps + 1)
+
+    data_index = 0
+    for letter in list(map(_letter_to_number, name)) + [_EON]:
+        data[data_index] = letter
+        data_index += 1
+
+    x = data[:batch_size * num_steps].reshape((batch_size, num_steps))
+    y = data[1:batch_size * num_steps + 1].reshape((batch_size, num_steps))
+
+    return x, y

+ 133 - 0
namignizer/model.py

@@ -0,0 +1,133 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""RNN model with embeddings"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+class NamignizerModel(object):
+    """The Namignizer model ~ strongly based on PTB"""
+
+    def __init__(self, is_training, config):
+        self.batch_size = batch_size = config.batch_size
+        self.num_steps = num_steps = config.num_steps
+        size = config.hidden_size
+        # will always be 27
+        vocab_size = config.vocab_size
+
+        # placeholders for inputs
+        self._input_data = tf.placeholder(tf.int32, [batch_size, num_steps])
+        self._targets = tf.placeholder(tf.int32, [batch_size, num_steps])
+        # weights for the loss function
+        self._weights = tf.placeholder(tf.float32, [batch_size * num_steps])
+
+        # lstm for our RNN cell (GRU supported too)
+        lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(size, forget_bias=0.0)
+        if is_training and config.keep_prob < 1:
+            lstm_cell = tf.nn.rnn_cell.DropoutWrapper(
+                lstm_cell, output_keep_prob=config.keep_prob)
+        cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * config.num_layers)
+
+        self._initial_state = cell.zero_state(batch_size, tf.float32)
+
+        with tf.device("/cpu:0"):
+            embedding = tf.get_variable("embedding", [vocab_size, size])
+            inputs = tf.nn.embedding_lookup(embedding, self._input_data)
+
+        if is_training and config.keep_prob < 1:
+            inputs = tf.nn.dropout(inputs, config.keep_prob)
+
+        outputs = []
+        state = self._initial_state
+        with tf.variable_scope("RNN"):
+            for time_step in range(num_steps):
+                if time_step > 0:
+                    tf.get_variable_scope().reuse_variables()
+                (cell_output, state) = cell(inputs[:, time_step, :], state)
+                outputs.append(cell_output)
+
+        output = tf.reshape(tf.concat(1, outputs), [-1, size])
+        softmax_w = tf.get_variable("softmax_w", [size, vocab_size])
+        softmax_b = tf.get_variable("softmax_b", [vocab_size])
+        logits = tf.matmul(output, softmax_w) + softmax_b
+        loss = tf.nn.seq2seq.sequence_loss_by_example(
+            [logits],
+            [tf.reshape(self._targets, [-1])],
+            [self._weights])
+        self._loss = loss
+        self._cost = cost = tf.reduce_sum(loss) / batch_size
+        self._final_state = state
+
+        # probabilities of each letter
+        self._activations = tf.nn.softmax(logits)
+
+        # ability to save the model
+        self.saver = tf.train.Saver(tf.all_variables())
+
+        if not is_training:
+            return
+
+        self._lr = tf.Variable(0.0, trainable=False)
+        tvars = tf.trainable_variables()
+        grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars),
+                                          config.max_grad_norm)
+        optimizer = tf.train.GradientDescentOptimizer(self.lr)
+        self._train_op = optimizer.apply_gradients(zip(grads, tvars))
+
+    def assign_lr(self, session, lr_value):
+        session.run(tf.assign(self.lr, lr_value))
+
+    @property
+    def input_data(self):
+        return self._input_data
+
+    @property
+    def targets(self):
+        return self._targets
+
+    @property
+    def activations(self):
+        return self._activations
+
+    @property
+    def weights(self):
+        return self._weights
+
+    @property
+    def initial_state(self):
+        return self._initial_state
+
+    @property
+    def cost(self):
+        return self._cost
+
+    @property
+    def loss(self):
+        return self._loss
+
+    @property
+    def final_state(self):
+        return self._final_state
+
+    @property
+    def lr(self):
+        return self._lr
+
+    @property
+    def train_op(self):
+        return self._train_op

+ 262 - 0
namignizer/names.py

@@ -0,0 +1,262 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""A library showing off sequence recognition and generation with the simple
+example of names.
+
+We use recurrent neural nets to learn complex functions able to recognize and
+generate sequences of a given form. This can be used for natural language
+syntax recognition, dynamically generating maps or puzzles and of course
+baby name generation.
+
+Before using this module, it is recommended to read the TensorFlow tutorial on
+recurrent neural nets, as it explains the basic concepts of this model and
+introduces the PTB model on which this one is based.
+
+Here is an overview of the functions available in this module:
+
+* RNN Module for sequence functions based on PTB
+
+* Name recognition specifically for recognizing names, but can be adapted to
+    recognizing sequence patterns
+
+* Name generation specifically for generating names, but can be adapted to
+    generating arbitrary sequence patterns
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import tensorflow as tf
+import numpy as np
+
+from model import NamignizerModel
+import data_utils
+
+
+class SmallConfig(object):
+    """Small config."""
+    init_scale = 0.1
+    learning_rate = 1.0
+    max_grad_norm = 5
+    num_layers = 2
+    num_steps = 20
+    hidden_size = 200
+    max_epoch = 4
+    max_max_epoch = 13
+    keep_prob = 1.0
+    lr_decay = 0.5
+    batch_size = 20
+    vocab_size = 27
+    epoch_size = 100
+
+
+class LargeConfig(object):
+    """Large config."""
+    init_scale = 0.05
+    learning_rate = 1.0
+    max_grad_norm = 5
+    num_layers = 2
+    num_steps = 35
+    hidden_size = 650
+    max_epoch = 6
+    max_max_epoch = 39
+    keep_prob = 0.5
+    lr_decay = 0.8
+    batch_size = 20
+    vocab_size = 27
+    epoch_size = 100
+
+
+class TestConfig(object):
+    """Tiny config, for testing."""
+    init_scale = 0.1
+    learning_rate = 1.0
+    max_grad_norm = 1
+    num_layers = 1
+    num_steps = 2
+    hidden_size = 2
+    max_epoch = 1
+    max_max_epoch = 1
+    keep_prob = 1.0
+    lr_decay = 0.5
+    batch_size = 20
+    vocab_size = 27
+    epoch_size = 100
+
+
+def run_epoch(session, m, names, counts, epoch_size, eval_op, verbose=False):
+    """Runs the model on the given data for one epoch
+
+    Args:
+        session: the tf session holding the model graph
+        m: an instance of the NamignizerModel
+        names: a set of lowercase names composed of the 26 letters a-z
+        counts: a list of the frequency of the above names
+        epoch_size: the number of batches to run
+        eval_op: the op to run alongside the cost, e.g. the train op to update
+            the parameters or tf.no_op() to leave them unchanged
+    Kwargs:
+        verbose: whether to print out state of training during the epoch
+    Returns:
+        perplexity: the exponential of the average per-step cost over the epoch
+    """
+    start_time = time.time()
+    costs = 0.0
+    iters = 0
+    for step, (x, y) in enumerate(data_utils.namignizer_iterator(names, counts,
+                                                                 m.batch_size, m.num_steps, epoch_size)):
+
+        cost, _ = session.run([m.cost, eval_op],
+                              {m.input_data: x,
+                               m.targets: y,
+                               m.initial_state: m.initial_state.eval(),
+                               m.weights: np.ones(m.batch_size * m.num_steps)})
+        costs += cost
+        iters += m.num_steps
+
+        if verbose and step % (epoch_size // 10) == 9:
+            print("%.3f perplexity: %.3f speed: %.0f lps" %
+                  (step * 1.0 / epoch_size, np.exp(costs / iters),
+                   iters * m.batch_size / (time.time() - start_time)))
+
+        if step >= epoch_size:
+            break
+
+    return np.exp(costs / iters)
+
+
+def train(data_dir, checkpoint_path, config):
+    """Trains the model with the given data
+
+    Args:
+        data_dir: path to the data for the model (see data_utils for data
+            format)
+        checkpoint_path: the path to save the trained model checkpoints
+        config: one of the above configs that specify the model and how it
+            should be run and trained
+    Returns:
+        None
+    """
+    # Prepare Name data.
+    print("Reading Name data in %s" % data_dir)
+    names, counts = data_utils.read_names(data_dir)
+
+    with tf.Graph().as_default(), tf.Session() as session:
+        initializer = tf.random_uniform_initializer(-config.init_scale,
+                                                    config.init_scale)
+        with tf.variable_scope("model", reuse=None, initializer=initializer):
+            m = NamignizerModel(is_training=True, config=config)
+
+        tf.initialize_all_variables().run()
+
+        for i in range(config.max_max_epoch):
+            lr_decay = config.lr_decay ** max(i - config.max_epoch, 0.0)
+            m.assign_lr(session, config.learning_rate * lr_decay)
+
+            print("Epoch: %d Learning rate: %.3f" % (i + 1, session.run(m.lr)))
+            train_perplexity = run_epoch(session, m, names, counts, config.epoch_size, m.train_op,
+                                         verbose=True)
+            print("Epoch: %d Train Perplexity: %.3f" %
+                  (i + 1, train_perplexity))
+
+            m.saver.save(session, checkpoint_path, global_step=i)
+
+
+def namignize(names, checkpoint_path, config):
+    """Recognizes names and prints the perplexity of the model for each name
+    in the list
+
+    Args:
+        names: a list of names in the model format
+        checkpoint_path: the checkpoint to restore the trained model from
+            (e.g. the path returned by tf.train.latest_checkpoint)
+        config: one of the above configs that specify the model and how it
+            should be run and trained
+    Returns:
+        None
+    """
+    with tf.Graph().as_default(), tf.Session() as session:
+
+        with tf.variable_scope("model"):
+            m = NamignizerModel(is_training=False, config=config)
+
+        m.saver.restore(session, checkpoint_path)
+
+        for name in names:
+            x, y = data_utils.name_to_batch(name, m.batch_size, m.num_steps)
+
+            cost, loss, _ = session.run([m.cost, m.loss, tf.no_op()],
+                                  {m.input_data: x,
+                                   m.targets: y,
+                                   m.initial_state: m.initial_state.eval(),
+                                   m.weights: np.concatenate((
+                                       np.ones(len(name)), np.zeros(m.batch_size * m.num_steps - len(name))))})
+
+            print("Name {} gives us a perplexity of {}".format(
+                name, np.exp(cost)))
+
+
+def namignator(checkpoint_path, config):
+    """Generates names randomly according to a given model
+
+    Args:
+        checkpoint_path: the checkpoint to restore the trained model from
+            (e.g. the path returned by tf.train.latest_checkpoint)
+        config: one of the above configs that specify the model and how it
+            should be run and trained
+    Returns:
+        None
+    """
+    # mutate the config to become a name generator config
+    config.num_steps = 1
+    config.batch_size = 1
+
+    with tf.Graph().as_default(), tf.Session() as session:
+
+        with tf.variable_scope("model"):
+            m = NamignizerModel(is_training=False, config=config)
+
+        m.saver.restore(session, checkpoint_path)
+
+        activations, final_state, _ = session.run([m.activations, m.final_state, tf.no_op()],
+                                                  {m.input_data: np.zeros((1, 1)),
+                                                   m.targets: np.zeros((1, 1)),
+                                                   m.initial_state: m.initial_state.eval(),
+                                                   m.weights: np.ones(1)})
+
+        # sample from our softmax activations
+        next_letter = np.random.choice(27, p=activations[0])
+        name = [next_letter]
+        while next_letter != 0:
+            activations, final_state, _ = session.run([m.activations, m.final_state, tf.no_op()],
+                                                      {m.input_data: [[next_letter]],
+                                                       m.targets: np.zeros((1, 1)),
+                                                       m.initial_state: final_state,
+                                                       m.weights: np.ones(1)})
+
+            next_letter = np.random.choice(27, p=activations[0])
+            name += [next_letter]
+
+        # chr(x + 96) maps 1-26 back to a-z; the end-of-name token 0 becomes '`'
+        print([chr(x + 96) for x in name])
+
+
+if __name__ == "__main__":
+    # train("data/SmallNames.txt", "model/namignizer", SmallConfig)
+
+    # namignize(["mary", "ida", "gazorbazorb", "mmmhmm", "bob"],
+    #     tf.train.latest_checkpoint("model"), SmallConfig)
+
+    # namignator(tf.train.latest_checkpoint("model"), SmallConfig)