Prechádzať zdrojové kódy

Replace old tf.nn modules with 1.0-compatible versions

Neal Wu 8 rokov pred
rodič
commit
73ae53ac28
1 zmenil súbory, kde vykonal 3 pridanie a 3 odobranie
  1. 3 3
      textsum/seq2seq_attention_model.py

+ 3 - 3
textsum/seq2seq_attention_model.py

@@ -166,7 +166,7 @@ class Seq2SeqAttentionModel(object):
               hps.num_hidden,
               initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113),
               state_is_tuple=False)
-          (emb_encoder_inputs, fw_state, _) = tf.nn.bidirectional_rnn(
+          (emb_encoder_inputs, fw_state, _) = tf.contrib.rnn.static_bidirectional_rnn(
               cell_fw, cell_bw, emb_encoder_inputs, dtype=tf.float32,
               sequence_length=article_lens)
       encoder_outputs = emb_encoder_inputs
@@ -200,7 +200,7 @@ class Seq2SeqAttentionModel(object):
         # During decoding, follow up _dec_in_state are fed from beam_search.
         # dec_out_state are stored by beam_search for next step feeding.
         initial_state_attention = (hps.mode == 'decode')
-        decoder_outputs, self._dec_out_state = tf.nn.seq2seq.attention_decoder(
+        decoder_outputs, self._dec_out_state = tf.contrib.legacy_seq2seq.attention_decoder(
             emb_decoder_inputs, self._dec_in_state, self._enc_top_states,
             cell, num_heads=1, loop_function=loop_function,
             initial_state_attention=initial_state_attention)
@@ -234,7 +234,7 @@ class Seq2SeqAttentionModel(object):
           self._loss = seq2seq_lib.sampled_sequence_loss(
               decoder_outputs, targets, loss_weights, sampled_loss_func)
         else:
-          self._loss = tf.nn.seq2seq.sequence_loss(
+          self._loss = tf.contrib.legacy_seq2seq.sequence_loss(
               model_outputs, targets, loss_weights)
         tf.summary.scalar('loss', tf.minimum(12.0, self._loss))