Parcourir la source

Swivel: move the rest of the ops to GPU

Vadim Markovtsev il y a 8 ans
Parent
commit
0d1b00b1e9
1 fichiers modifiés avec 33 ajouts et 36 suppressions
  1. 33 36
      swivel/swivel.py

+ 33 - 36
swivel/swivel.py

@@ -207,46 +207,43 @@ class SwivelModel(object):
     sys.stdout.flush()
 
     # ===== CREATE VARIABLES ======
-
-    with tf.device('/cpu:0'):
-      # embeddings
-      self.row_embedding = embeddings_with_init(
-          embedding_dim=config.embedding_size,
-          vocab_size=self.n_rows,
-          name='row_embedding')
-      self.col_embedding = embeddings_with_init(
-          embedding_dim=config.embedding_size,
-          vocab_size=self.n_cols,
-          name='col_embedding')
-      tf.summary.histogram('row_emb', self.row_embedding)
-      tf.summary.histogram('col_emb', self.col_embedding)
-
-      matrix_log_sum = math.log(np.sum(row_sums) + 1)
-      row_bias_init = [math.log(x + 1) for x in row_sums]
-      col_bias_init = [math.log(x + 1) for x in col_sums]
-      self.row_bias = tf.Variable(row_bias_init,
-                                  trainable=config.trainable_bias)
-      self.col_bias = tf.Variable(col_bias_init,
-                                  trainable=config.trainable_bias)
-      tf.summary.histogram('row_bias', self.row_bias)
-      tf.summary.histogram('col_bias', self.col_bias)
+    # embeddings
+    self.row_embedding = embeddings_with_init(
+      embedding_dim=config.embedding_size,
+      vocab_size=self.n_rows,
+      name='row_embedding')
+    self.col_embedding = embeddings_with_init(
+      embedding_dim=config.embedding_size,
+      vocab_size=self.n_cols,
+      name='col_embedding')
+    tf.summary.histogram('row_emb', self.row_embedding)
+    tf.summary.histogram('col_emb', self.col_embedding)
+
+    matrix_log_sum = math.log(np.sum(row_sums) + 1)
+    row_bias_init = [math.log(x + 1) for x in row_sums]
+    col_bias_init = [math.log(x + 1) for x in col_sums]
+    self.row_bias = tf.Variable(
+        row_bias_init, trainable=config.trainable_bias)
+    self.col_bias = tf.Variable(
+        col_bias_init, trainable=config.trainable_bias)
+    tf.summary.histogram('row_bias', self.row_bias)
+    tf.summary.histogram('col_bias', self.col_bias)
 
     # ===== CREATE GRAPH =====
 
     # Get input
-    with tf.device('/cpu:0'):
-      global_row, global_col, count = count_matrix_input(
-          count_matrix_files, config.submatrix_rows, config.submatrix_cols)
-
-      # Fetch embeddings.
-      selected_row_embedding = tf.nn.embedding_lookup(self.row_embedding,
-                                                      global_row)
-      selected_col_embedding = tf.nn.embedding_lookup(self.col_embedding,
-                                                      global_col)
-
-      # Fetch biases.
-      selected_row_bias = tf.nn.embedding_lookup([self.row_bias], global_row)
-      selected_col_bias = tf.nn.embedding_lookup([self.col_bias], global_col)
+    global_row, global_col, count = count_matrix_input(
+      count_matrix_files, config.submatrix_rows, config.submatrix_cols)
+
+    # Fetch embeddings.
+    selected_row_embedding = tf.nn.embedding_lookup(
+        self.row_embedding, global_row)
+    selected_col_embedding = tf.nn.embedding_lookup(
+        self.col_embedding, global_col)
+
+    # Fetch biases.
+    selected_row_bias = tf.nn.embedding_lookup([self.row_bias], global_row)
+    selected_col_bias = tf.nn.embedding_lookup([self.col_bias], global_col)
 
     # Multiply the row and column embeddings to generate predictions.
     predictions = tf.matmul(