|
|
@@ -94,7 +94,8 @@ def _Train(model, data_batcher):
|
|
|
save_summaries_secs=60,
|
|
|
save_model_secs=FLAGS.checkpoint_secs,
|
|
|
global_step=model.global_step)
|
|
|
- sess = sv.prepare_or_wait_for_session()
|
|
|
+ sess = sv.prepare_or_wait_for_session(config=tf.ConfigProto(
|
|
|
+ allow_soft_placement=True))
|
|
|
running_avg_loss = 0
|
|
|
step = 0
|
|
|
while not sv.should_stop() and step < FLAGS.max_run_steps:
|