Browse Source

added dynamic RNN

Aymeric Damien 8 years ago
parent
commit
4cce9f6690
2 changed files with 196 additions and 0 deletions
  1. 1 0
      README.md
  2. 195 0
      examples/3_NeuralNetworks/dynamic_rnn.py

+ 1 - 0
README.md

@@ -23,6 +23,7 @@ It is suitable for beginners who want to find clear and concise examples about T
 - Convolutional Neural Network ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/convolutional_network.ipynb)) ([code](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3_NeuralNetworks/convolutional_network.py))
 - Recurrent Neural Network (LSTM) ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/recurrent_network.ipynb)) ([code](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3_NeuralNetworks/recurrent_network.py))
 - Bidirectional Recurrent Neural Network (LSTM) ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/bidirectional_rnn.ipynb)) ([code](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3_NeuralNetworks/bidirectional_rnn.py))
+- Dynamic Recurrent Neural Network (LSTM) ([code](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3_NeuralNetworks/dynamic_rnn.py))
 - AutoEncoder ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/autoencoder.ipynb)) ([code](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3_NeuralNetworks/autoencoder.py))
 
 #### 4 - Utilities

+ 195 - 0
examples/3_NeuralNetworks/dynamic_rnn.py

@@ -0,0 +1,195 @@
+'''
+A Dynamic Reccurent Neural Network (LSTM) implementation example using
+TensorFlow library. This example is using a toy dataset to classify linear
+sequences. The generated sequences have variable length.
+
+Long Short Term Memory paper: http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf
+
+Author: Aymeric Damien
+Project: https://github.com/aymericdamien/TensorFlow-Examples/
+'''
+
+import tensorflow as tf
+import random
+
+
+# ====================
+#  TOY DATA GENERATOR
+# ====================
+class ToySequenceData(object):
+    """ Generate sequence of data with dynamic length.
+    This class generate samples for training:
+    - Class 0: linear sequences (i.e. [0, 1, 2, 3,...])
+    - Class 1: random sequences (i.e. [1, 3, 10, 7,...])
+
+    NOTICE:
+    We have to pad each sequence to reach 'max_seq_len' for TensorFlow
+    consistency (we cannot feed a numpy array with unconsistent
+    dimensions). The dynamic calculation will then be perform thanks to
+    'seqlen' attribute that records every actual sequence length.
+    """
+    def __init__(self, n_samples=1000, max_seq_len=20, min_seq_len=3,
+                 max_value=1000):
+        self.data = []
+        self.labels = []
+        self.seqlen = []
+        for i in range(n_samples):
+            # Random sequence length
+            len = random.randint(min_seq_len, max_seq_len)
+            # Monitor sequence length for TensorFlow dynamic calculation
+            self.seqlen.append(len)
+            # Add a random or linear int sequence (50% prob)
+            if random.random() < .5:
+                # Generate a linear sequence
+                rand_start = random.randint(0, max_value - len)
+                s = [[float(i)/max_value] for i in
+                     range(rand_start, rand_start + len)]
+                # Pad sequence for dimension consistency
+                s += [[0.] for i in range(max_seq_len - len)]
+                self.data.append(s)
+                self.labels.append([1., 0.])
+            else:
+                # Generate a random sequence
+                s = [[float(random.randint(0, max_value))/max_value]
+                     for i in range(len)]
+                # Pad sequence for dimension consistency
+                s += [[0.] for i in range(max_seq_len - len)]
+                self.data.append(s)
+                self.labels.append([0., 1.])
+        self.batch_id = 0
+
+    def next(self, batch_size):
+        """ Return a batch of data. When dataset end is reached, start over.
+        """
+        if self.batch_id == len(self.data):
+            self.batch_id = 0
+        batch_data = (self.data[self.batch_id:min(self.batch_id +
+                                                  batch_size, len(self.data))])
+        batch_labels = (self.labels[self.batch_id:min(self.batch_id +
+                                                  batch_size, len(self.data))])
+        batch_seqlen = (self.seqlen[self.batch_id:min(self.batch_id +
+                                                  batch_size, len(self.data))])
+        self.batch_id = min(self.batch_id + batch_size, len(self.data))
+        return batch_data, batch_labels, batch_seqlen
+
+
+# ==========
+#   MODEL
+# ==========
+
+# Parameters
+learning_rate = 0.01
+training_iters = 1000000
+batch_size = 128
+display_step = 10
+
+# Network Parameters
+seq_max_len = 20 # Sequence max length
+n_hidden = 64 # hidden layer num of features
+n_classes = 2 # linear sequence or not
+
+trainset = ToySequenceData(n_samples=1000, max_seq_len=seq_max_len)
+testset = ToySequenceData(n_samples=500, max_seq_len=seq_max_len)
+
+# tf Graph input
+x = tf.placeholder("float", [None, seq_max_len, 1])
+y = tf.placeholder("float", [None, n_classes])
+# A placeholder for indicating each sequence length
+seqlen = tf.placeholder(tf.int32, [None])
+
+# Define weights
+weights = {
+    'out': tf.Variable(tf.random_normal([n_hidden, n_classes]))
+}
+biases = {
+    'out': tf.Variable(tf.random_normal([n_classes]))
+}
+
+
+def dynamicRNN(x, seqlen, weights, biases):
+
+    # Prepare data shape to match `rnn` function requirements
+    # Current data input shape: (batch_size, n_steps, n_input)
+    # Required shape: 'n_steps' tensors list of shape (batch_size, n_input)
+
+    # Permuting batch_size and n_steps
+    x = tf.transpose(x, [1, 0, 2])
+    # Reshaping to (n_steps*batch_size, n_input)
+    x = tf.reshape(x, [-1, 1])
+    # Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)
+    x = tf.split(0, seq_max_len, x)
+
+    # Define a lstm cell with tensorflow
+    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden)
+
+    # Get lstm cell output, providing 'sequence_length' will perform dynamic
+    # calculation.
+    outputs, states = tf.nn.rnn(lstm_cell, x, dtype=tf.float32,
+                                sequence_length=seqlen)
+
+    # When performing dynamic calculation, we must retrieve the last
+    # dynamically computed output, i.e, if a sequence length is 10, we need
+    # to retrieve the 10th output.
+    # However TensorFlow doesn't support advanced indexing yet, so we build
+    # a custom op that for each sample in batch size, get its length and
+    # get the corresponding relevant output.
+
+    # 'outputs' is a list of output at every timestep, we pack them in a Tensor
+    # and change back dimension to [batch_size, n_step, n_input]
+    outputs = tf.pack(outputs)
+    outputs = tf.transpose(outputs, [1, 0, 2])
+
+    # Hack to build the indexing and retrieve the right output.
+    batch_size = tf.shape(outputs)[0]
+    # Start indices for each sample
+    index = tf.range(0, batch_size) * seq_max_len + (seqlen - 1)
+    # Indexing
+    outputs = tf.gather(tf.reshape(outputs, [-1, n_hidden]), index)
+
+    # Linear activation, using rnn inner loop last output
+    return tf.matmul(outputs, weights['out']) + biases['out']
+
+pred = dynamicRNN(x, seqlen, weights, biases)
+
+# Define loss and optimizer
+cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y))
+optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cost)
+
+# Evaluate model
+correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
+accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
+
+# Initializing the variables
+init = tf.initialize_all_variables()
+
+# Launch the graph
+with tf.Session() as sess:
+    sess.run(init)
+    step = 1
+    # Keep training until reach max iterations
+    while step * batch_size < training_iters:
+        batch_x, batch_y, batch_seqlen = trainset.next(batch_size)
+        # Run optimization op (backprop)
+        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y,
+                                       seqlen: batch_seqlen})
+        if step % display_step == 0:
+            # Calculate batch accuracy
+            acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y,
+                                                seqlen: batch_seqlen})
+            # Calculate batch loss
+            loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y,
+                                             seqlen: batch_seqlen})
+            print "Iter " + str(step*batch_size) + ", Minibatch Loss= " + \
+                  "{:.6f}".format(loss) + ", Training Accuracy= " + \
+                  "{:.5f}".format(acc)
+        step += 1
+    print "Optimization Finished!"
+
+    # Calculate accuracy for 128 mnist test images
+    test_len = 128
+    test_data = testset.data
+    test_label = testset.labels
+    test_seqlen = testset.seqlen
+    print "Testing Accuracy:", \
+        sess.run(accuracy, feed_dict={x: test_data, y: test_label,
+                                      seqlen: test_seqlen})