
add more examples

aymericdamien 7 years ago
parent
commit
cefdb70c32

+ 2 - 0
README.md

@@ -27,6 +27,8 @@ It is suitable for beginners who want to find clear and concise examples about T
 - **Nearest Neighbor** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/2_BasicModels/nearest_neighbor.ipynb)) ([code](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/nearest_neighbor.py)). Implement Nearest Neighbor algorithm with TensorFlow.
 - **K-Means** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/2_BasicModels/kmeans.ipynb)) ([code](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/kmeans.py)). Build a K-Means classifier with TensorFlow.
 - **Random Forest** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/2_BasicModels/random_forest.ipynb)) ([code](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/random_forest.py)). Build a Random Forest classifier with TensorFlow.
+- **Gradient Boosted Decision Tree (GBDT)** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/2_BasicModels/gradient_boosted_decision_tree.ipynb)) ([code](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/gradient_boosted_decision_tree.py)). Build a Gradient Boosted Decision Tree (GBDT) with TensorFlow.
+- **Word2Vec (Word Embedding)** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/2_BasicModels/word2vec.ipynb)) ([code](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/word2vec.py)). Build a Word Embedding Model (Word2Vec) from Wikipedia data, with TensorFlow.
 
 #### 3 - Neural Networks
 ##### Supervised

+ 41 - 74
examples/2_BasicModels/gradient_boosted_decision_tree.py

@@ -1,6 +1,11 @@
-""" GBDT (Gradient Boosted Decision Tree).
+""" Gradient Boosted Decision Tree (GBDT).
 
+Implement a Gradient Boosted Decision Tree with TensorFlow to classify
+handwritten digit images. This example uses the MNIST database of
+handwritten digits as training samples (http://yann.lecun.com/exdb/mnist/).
 
+Links:
+    [MNIST Dataset](http://yann.lecun.com/exdb/mnist/).
 
 Author: Aymeric Damien
 Project: https://github.com/aymericdamien/TensorFlow-Examples/
@@ -8,27 +13,26 @@ Project: https://github.com/aymericdamien/TensorFlow-Examples/
 
 from __future__ import print_function
 
-import numpy as np
 import tensorflow as tf
 from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeClassifier
-from tensorflow.contrib.boosted_trees.proto import learner_pb2
-from tensorflow.contrib.learn import learn_runner
+from tensorflow.contrib.boosted_trees.proto import learner_pb2 as gbdt_learner
 
-# Ignore all GPUs, tf random forest does not benefit from it.
-# import os
-# os.environ["CUDA_VISIBLE_DEVICES"] = ""
+# Ignore all GPUs (current TF GBDT does not support GPU).
+import os
+os.environ["CUDA_VISIBLE_DEVICES"] = ""
 
 # Import MNIST data
+# Set verbosity to display errors only (remove this line to show warnings)
+tf.logging.set_verbosity(tf.logging.ERROR)
 from tensorflow.examples.tutorials.mnist import input_data
 mnist = input_data.read_data_sets("/tmp/data/", one_hot=False,
                                   source_url='http://yann.lecun.com/exdb/mnist/')
 
 # Parameters
-log_dir = "/tmp/tf_gbdt"
-num_steps = 500 # Total steps to train
-batch_size = 1024 # The number of samples per batch
+batch_size = 4096 # The number of samples per batch
 num_classes = 10 # The 10 digits
 num_features = 784 # Each image is 28x28 pixels
+max_steps = 10000 # Total steps to train
 
 # GBDT Parameters
 learning_rate = 0.1
@@ -36,83 +40,46 @@ l1_regul = 0.
 l2_regul = 1.
 examples_per_layer = 1000
 num_trees = 10
-max_depth = 4
+max_depth = 16
 
-
-def get_input_fn(x, y):
-    """Input function over MNIST data."""
-
-    def input_fn():
-        images_batch, labels_batch = tf.train.shuffle_batch(
-                tensors=[x, y],
-                batch_size=batch_size,
-                capacity=batch_size * 10,
-                min_after_dequeue=batch_size * 2,
-                enqueue_many=True,
-                num_threads=4)
-        features_map = {"images": images_batch}
-        return features_map, labels_batch
-
-    return input_fn
-
-
-learner_config = learner_pb2.LearnerConfig()
+# Fill GBDT parameters into the config proto
+learner_config = gbdt_learner.LearnerConfig()
 learner_config.learning_rate_tuner.fixed.learning_rate = learning_rate
-learner_config.num_classes = num_classes
 learner_config.regularization.l1 = l1_regul
 learner_config.regularization.l2 = l2_regul / examples_per_layer
 learner_config.constraints.max_tree_depth = max_depth
-
-growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER
+growing_mode = gbdt_learner.LearnerConfig.LAYER_BY_LAYER
 learner_config.growing_mode = growing_mode
 run_config = tf.contrib.learn.RunConfig(save_checkpoints_secs=300)
-
 learner_config.multi_class_strategy = (
-    learner_pb2.LearnerConfig.DIAGONAL_HESSIAN)
+    gbdt_learner.LearnerConfig.DIAGONAL_HESSIAN)
 
-# Create a TF Boosted trees estimator that can take in custom loss.
-estimator = GradientBoostedDecisionTreeClassifier(
+# Create a TensorFlow GBDT Estimator
+gbdt_model = GradientBoostedDecisionTreeClassifier(
+    model_dir=None, # No save directory specified
     learner_config=learner_config,
     n_classes=num_classes,
     examples_per_layer=examples_per_layer,
-    model_dir=log_dir,
     num_trees=num_trees,
     center_bias=False,
     config=run_config)
 
-
-def _make_experiment_fn(output_dir):
-  """Creates experiment for gradient boosted decision trees."""
-  train_input_fn = get_input_fn(mnist.train.images,
-                            mnist.train.labels.astype(np.int32))
-  eval_input_fn = get_input_fn(mnist.test.images,
-                           mnist.test.labels.astype(np.int32))
-
-  return tf.contrib.learn.Experiment(
-      estimator=estimator,
-      train_input_fn=train_input_fn,
-      eval_input_fn=eval_input_fn,
-      train_steps=None,
-      eval_steps=1,
-      eval_metrics=None)
-
-# Training
-learn_runner.run(
-      experiment_fn=_make_experiment_fn,
-      output_dir=log_dir,
-      schedule="train_and_evaluate")
-
-
-# Accuracy
-test_input_fn = get_input_fn(
-    mnist.test.images, mnist.test.labels.astype(np.int32))
-results = estimator.predict(x=mnist.test.images)
-
-acc = 0.
-n = 0
-for i, r in enumerate(results):
-    if np.argmax(r['probabilities']) == int(mnist.test.labels[i]):
-        acc += 1
-    n += 1
-
-print(acc / n)
+# Display TF info logs
+tf.logging.set_verbosity(tf.logging.INFO)
+
+# Define the input function for training
+input_fn = tf.estimator.inputs.numpy_input_fn(
+    x={'images': mnist.train.images}, y=mnist.train.labels,
+    batch_size=batch_size, num_epochs=None, shuffle=True)
+# Train the Model
+gbdt_model.fit(input_fn=input_fn, max_steps=max_steps)
+
+# Evaluate the Model
+# Define the input function for evaluating
+input_fn = tf.estimator.inputs.numpy_input_fn(
+    x={'images': mnist.test.images}, y=mnist.test.labels,
+    batch_size=batch_size, shuffle=False)
+# Use the Estimator 'evaluate' method
+e = gbdt_model.evaluate(input_fn=input_fn)
+
+print("Testing Accuracy:", e['accuracy'])

+ 195 - 0
examples/2_BasicModels/word2vec.py

@@ -0,0 +1,195 @@
+""" Word2Vec.
+
+Implement Word2Vec algorithm to compute vector representations of words.
+This example uses a small chunk of Wikipedia articles as training data.
+
+References:
+    - Mikolov, Tomas et al. "Efficient Estimation of Word Representations
+    in Vector Space.", 2013.
+
+Links:
+    - [Word2Vec] https://arxiv.org/pdf/1301.3781.pdf
+
+Author: Aymeric Damien
+Project: https://github.com/aymericdamien/TensorFlow-Examples/
+"""
+from __future__ import division, print_function, absolute_import
+
+import collections
+import os
+import random
+import urllib
+import zipfile
+
+import numpy as np
+import tensorflow as tf
+
+# Training Parameters
+learning_rate = 0.1
+batch_size = 128
+num_steps = 3000000
+display_step = 10000
+eval_step = 200000
+
+# Evaluation Parameters
+eval_words = ['five', 'of', 'going', 'hardware', 'american', 'britain']
+
+# Word2Vec Parameters
+embedding_size = 200 # Dimension of the embedding vector
+max_vocabulary_size = 50000 # Total number of different words in the vocabulary
+min_occurrence = 10 # Remove all words that do not appear at least n times
+skip_window = 3 # How many words to consider left and right
+num_skips = 2 # How many times to reuse an input to generate a label
+num_sampled = 64 # Number of negative examples to sample
+
+
+# Download a small chunk of Wikipedia articles collection
+url = 'http://mattmahoney.net/dc/text8.zip'
+data_path = 'text8.zip'
+if not os.path.exists(data_path):
+    print("Downloading the dataset... (It may take some time)")
+    filename, _ = urllib.urlretrieve(url, data_path)
+    print("Done!")
+# Unzip the dataset file. Text has already been processed
+with zipfile.ZipFile(data_path) as f:
+    text_words = f.read(f.namelist()[0]).lower().split()
+
+# Build the dictionary and replace rare words with UNK token
+count = [('UNK', -1)]
+# Retrieve the most common words
+count.extend(collections.Counter(text_words).most_common(max_vocabulary_size - 1))
+# Remove samples with fewer than 'min_occurrence' occurrences
+for i in range(len(count) - 1, -1, -1):
+    if count[i][1] < min_occurrence:
+        count.pop(i)
+    else:
+        # The collection is ordered, so stop when 'min_occurrence' is reached
+        break
+# Compute the vocabulary size
+vocabulary_size = len(count)
+# Assign an id to each word
+word2id = dict()
+for i, (word, _) in enumerate(count):
+    word2id[word] = i
+
+data = list()
+unk_count = 0
+for word in text_words:
+    # Retrieve a word id, or assign it index 0 ('UNK') if not in dictionary
+    index = word2id.get(word, 0)
+    if index == 0:
+        unk_count += 1
+    data.append(index)
+count[0] = ('UNK', unk_count)
+id2word = dict(zip(word2id.values(), word2id.keys()))
+
+print("Words count:", len(text_words))
+print("Unique words:", len(set(text_words)))
+print("Vocabulary size:", vocabulary_size)
+print("Most common words:", count[:10])
+
+data_index = 0
+# Generate training batch for the skip-gram model
+def next_batch(batch_size, num_skips, skip_window):
+    global data_index
+    assert batch_size % num_skips == 0
+    assert num_skips <= 2 * skip_window
+    batch = np.ndarray(shape=(batch_size), dtype=np.int32)
+    labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)
+    # get window size (words left and right + current one)
+    span = 2 * skip_window + 1
+    buffer = collections.deque(maxlen=span)
+    if data_index + span > len(data):
+        data_index = 0
+    buffer.extend(data[data_index:data_index + span])
+    data_index += span
+    for i in range(batch_size // num_skips):
+        context_words = [w for w in range(span) if w != skip_window]
+        words_to_use = random.sample(context_words, num_skips)
+        for j, context_word in enumerate(words_to_use):
+            batch[i * num_skips + j] = buffer[skip_window]
+            labels[i * num_skips + j, 0] = buffer[context_word]
+        if data_index == len(data):
+            buffer.extend(data[0:span])
+            data_index = span
+        else:
+            buffer.append(data[data_index])
+            data_index += 1
+    # Backtrack a little bit to avoid skipping words at the end of a batch
+    data_index = (data_index + len(data) - span) % len(data)
+    return batch, labels
+
+
+# Input data
+X = tf.placeholder(tf.int32, shape=[None])
+# Input label
+Y = tf.placeholder(tf.int32, shape=[None, 1])
+
+# Ensure the following ops & var are assigned on CPU
+# (some ops are not compatible on GPU)
+with tf.device('/cpu:0'):
+    # Create the embedding variable (each row represents a word embedding vector)
+    embedding = tf.Variable(tf.random_normal([vocabulary_size, embedding_size]))
+    # Lookup the corresponding embedding vectors for each sample in X
+    X_embed = tf.nn.embedding_lookup(embedding, X)
+
+    # Construct the variables for the NCE loss
+    nce_weights = tf.Variable(tf.random_normal([vocabulary_size, embedding_size]))
+    nce_biases = tf.Variable(tf.zeros([vocabulary_size]))
+
+# Compute the average NCE loss for the batch
+loss_op = tf.reduce_mean(
+    tf.nn.nce_loss(weights=nce_weights,
+                   biases=nce_biases,
+                   labels=Y,
+                   inputs=X_embed,
+                   num_sampled=num_sampled,
+                   num_classes=vocabulary_size))
+
+# Define the optimizer
+optimizer = tf.train.GradientDescentOptimizer(learning_rate)
+train_op = optimizer.minimize(loss_op)
+
+# Evaluation
+# Compute the cosine similarity between the input data embedding and every embedding vector
+X_embed_norm = X_embed / tf.sqrt(tf.reduce_sum(tf.square(X_embed), 1, keepdims=True))
+embedding_norm = embedding / tf.sqrt(tf.reduce_sum(tf.square(embedding), 1, keepdims=True))
+cosine_sim_op = tf.matmul(X_embed_norm, embedding_norm, transpose_b=True)
+
+# Initialize the variables (i.e. assign their default value)
+init = tf.global_variables_initializer()
+
+with tf.Session() as sess:
+
+    # Run the initializer
+    sess.run(init)
+
+    # Testing data
+    x_test = np.array([word2id[w] for w in eval_words])
+
+    average_loss = 0
+    for step in xrange(1, num_steps + 1):
+        # Get a new batch of data
+        batch_x, batch_y = next_batch(batch_size, num_skips, skip_window)
+        # Run training op
+        _, loss = sess.run([train_op, loss_op], feed_dict={X: batch_x, Y: batch_y})
+        average_loss += loss
+
+        if step % display_step == 0 or step == 1:
+            if step > 1:
+                average_loss /= display_step
+            print("Step " + str(step) + ", Average Loss= " + \
+                  "{:.4f}".format(average_loss))
+            average_loss = 0
+
+        # Evaluation
+        if step % eval_step == 0 or step == 1:
+            print("Evaluation...")
+            sim = sess.run(cosine_sim_op, feed_dict={X: x_test})
+            for i in xrange(len(eval_words)):
+                top_k = 8  # number of nearest neighbors
+                nearest = (-sim[i, :]).argsort()[1:top_k + 1]
+                log_str = '"%s" nearest neighbors:' % eval_words[i]
+                for k in xrange(top_k):
+                    log_str = '%s %s,' % (log_str, id2word[nearest[k]])
+                print(log_str)
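
A tiny, self-contained sketch (not part of this commit) of the skip-gram pairing that next_batch() above implements: every center word is paired with the words inside its window, and those (center, context) pairs are what the NCE loss trains on. The sentence and window below are made-up illustration values.

# Illustration only: enumerate skip-gram pairs for a toy sentence.
sentence = ['the', 'quick', 'brown', 'fox', 'jumps']
window = 1  # plays the role of 'skip_window' above
pairs = []
for i, center in enumerate(sentence):
    for j in range(max(0, i - window), min(len(sentence), i + window + 1)):
        if j != i:
            pairs.append((center, sentence[j]))
print(pairs)
# [('the', 'quick'), ('quick', 'the'), ('quick', 'brown'), ('brown', 'quick'), ...]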

+ 1 - 1
examples/3_NeuralNetworks/neural_network.py

@@ -61,7 +61,7 @@ def model_fn(features, labels, mode):
     if mode == tf.estimator.ModeKeys.PREDICT:
         return tf.estimator.EstimatorSpec(mode, predictions=pred_classes)
 
-        # Define loss and optimizer
+    # Define loss and optimizer
     loss_op = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
         logits=logits, labels=tf.cast(labels, dtype=tf.int32)))
     optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
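
Context for the one-line re-indent above: in a tf.estimator model_fn, the early return for ModeKeys.PREDICT means everything after it only runs for TRAIN and EVAL, so the "Define loss and optimizer" comment belongs at function level next to the code it documents. Below is a minimal model_fn skeleton showing that shape; it is a sketch with toy values, not the repository's exact file.

import tensorflow as tf

def model_fn(features, labels, mode):
    # Toy network: a single dense layer over the flattened images.
    logits = tf.layers.dense(features['images'], 10)
    pred_classes = tf.argmax(logits, axis=1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Prediction needs no labels, loss or optimizer.
        return tf.estimator.EstimatorSpec(mode, predictions=pred_classes)

    # TRAIN and EVAL both need a loss; TRAIN additionally needs a train_op.
    loss_op = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=tf.cast(labels, dtype=tf.int32)))
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
        loss_op, global_step=tf.train.get_global_step())
    acc_op = tf.metrics.accuracy(labels=labels, predictions=pred_classes)

    return tf.estimator.EstimatorSpec(
        mode, predictions=pred_classes, loss=loss_op, train_op=train_op,
        eval_metric_ops={'accuracy': acc_op})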