소스 검색

Merge pull request #1104 from vmarkovtsev/patch-3

Update Swivel code to work with Python 3
Chris Waterson 8 년 전
부모
커밋
89bccc6300
1개의 변경된 파일10개의 추가작업 그리고 9개의 파일을 삭제
  1. 10 9
      swivel/swivel.py

+ 10 - 9
swivel/swivel.py

@@ -51,6 +51,7 @@ embeddings are stored in separate files.
 
 """
 
+from __future__ import print_function
 import argparse
 import glob
 import math
@@ -163,7 +164,7 @@ def write_embeddings_to_disk(config, model, sess):
   # Row Embedding
   row_vocab_path = config.input_base_path + '/row_vocab.txt'
   row_embedding_output_path = config.output_base_path + '/row_embedding.tsv'
-  print 'Writing row embeddings to:', row_embedding_output_path
+  print('Writing row embeddings to:', row_embedding_output_path)
   sys.stdout.flush()
   write_embedding_tensor_to_disk(row_vocab_path, row_embedding_output_path,
                                  sess, model.row_embedding)
@@ -171,7 +172,7 @@ def write_embeddings_to_disk(config, model, sess):
   # Column Embedding
   col_vocab_path = config.input_base_path + '/col_vocab.txt'
   col_embedding_output_path = config.output_base_path + '/col_embedding.tsv'
-  print 'Writing column embeddings to:', col_embedding_output_path
+  print('Writing column embeddings to:', col_embedding_output_path)
   sys.stdout.flush()
   write_embedding_tensor_to_disk(col_vocab_path, col_embedding_output_path,
                                  sess, model.col_embedding)
@@ -185,7 +186,7 @@ class SwivelModel(object):
     self._config = config
 
     # Create paths to input data files
-    print 'Reading model from:', config.input_base_path
+    print('Reading model from:', config.input_base_path)
     sys.stdout.flush()
     count_matrix_files = glob.glob(config.input_base_path + '/shard-*.pb')
     row_sums_path = config.input_base_path + '/row_sums.txt'
@@ -197,12 +198,12 @@ class SwivelModel(object):
 
     self.n_rows = len(row_sums)
     self.n_cols = len(col_sums)
-    print 'Matrix dim: (%d,%d) SubMatrix dim: (%d,%d) ' % (
-        self.n_rows, self.n_cols, config.submatrix_rows, config.submatrix_cols)
+    print('Matrix dim: (%d,%d) SubMatrix dim: (%d,%d) ' % (
+        self.n_rows, self.n_cols, config.submatrix_rows, config.submatrix_cols))
     sys.stdout.flush()
     self.n_submatrices = (self.n_rows * self.n_cols /
                           (config.submatrix_rows * config.submatrix_cols))
-    print 'n_submatrices: %d' % (self.n_submatrices)
+    print('n_submatrices: %d' % (self.n_submatrices))
     sys.stdout.flush()
 
     # ===== CREATE VARIABLES ======
@@ -316,15 +317,15 @@ def main(_):
     t0 = [time.time()]
 
     def TrainingFn():
-      for _ in range(n_steps_per_thread):
+      for _ in range(int(n_steps_per_thread)):
         _, global_step = sess.run([model.train_op, model.global_step])
         n_steps_between_status_updates = 100
         if (global_step % n_steps_between_status_updates) == 0:
           elapsed = float(time.time() - t0[0])
-          print '%d/%d submatrices trained (%.1f%%), %.1f submatrices/sec' % (
+          print('%d/%d submatrices trained (%.1f%%), %.1f submatrices/sec' % (
               global_step, n_submatrices_to_train,
               100.0 * global_step / n_submatrices_to_train,
-              n_steps_between_status_updates / elapsed)
+              n_steps_between_status_updates / elapsed))
           sys.stdout.flush()
           t0[0] = time.time()