瀏覽代碼

Add sys.stdout.flush()

蠍ヤ缘 9 年之前
父節點
當前提交
f3144eb061
共有 1 個文件被更改,包括 6 次插入0 次删除
  1. 6 0
      swivel/swivel.py

+ 6 - 0
swivel/swivel.py

@@ -164,6 +164,7 @@ def write_embeddings_to_disk(config, model, sess):
   row_vocab_path = config.input_base_path + '/row_vocab.txt'
   row_embedding_output_path = config.output_base_path + '/row_embedding.tsv'
   print 'Writing row embeddings to:', row_embedding_output_path
+  sys.stdout.flush()
   write_embedding_tensor_to_disk(row_vocab_path, row_embedding_output_path,
                                  sess, model.row_embedding)
 
@@ -171,6 +172,7 @@ def write_embeddings_to_disk(config, model, sess):
   col_vocab_path = config.input_base_path + '/col_vocab.txt'
   col_embedding_output_path = config.output_base_path + '/col_embedding.tsv'
   print 'Writing column embeddings to:', col_embedding_output_path
+  sys.stdout.flush()
   write_embedding_tensor_to_disk(col_vocab_path, col_embedding_output_path,
                                  sess, model.col_embedding)
 
@@ -184,6 +186,7 @@ class SwivelModel(object):
 
     # Create paths to input data files
     print 'Reading model from:', config.input_base_path
+    sys.stdout.flush()
     count_matrix_files = glob.glob(config.input_base_path + '/shard-*.pb')
     row_sums_path = config.input_base_path + '/row_sums.txt'
     col_sums_path = config.input_base_path + '/col_sums.txt'
@@ -196,9 +199,11 @@ class SwivelModel(object):
     self.n_cols = len(col_sums)
     print 'Matrix dim: (%d,%d) SubMatrix dim: (%d,%d) ' % (
         self.n_rows, self.n_cols, config.submatrix_rows, config.submatrix_cols)
+    sys.stdout.flush()
     self.n_submatrices = (self.n_rows * self.n_cols /
                           (config.submatrix_rows * config.submatrix_cols))
     print 'n_submatrices: %d' % (self.n_submatrices)
+    sys.stdout.flush()
 
     # ===== CREATE VARIABLES ======
 
@@ -320,6 +325,7 @@ def main(_):
               global_step, n_submatrices_to_train,
               100.0 * global_step / n_submatrices_to_train,
               n_steps_between_status_updates / elapsed)
+          sys.stdout.flush()
           t0[0] = time.time()
 
     # Start training threads