@@ -244,10 +244,10 @@ class ShowAndTellModel(object):
     # This LSTM cell has biases and outputs tanh(new_c) * sigmoid(o), but the
     # modified LSTM in the "Show and Tell" paper has no biases and outputs
     # new_c * sigmoid(o).
-    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(
+    lstm_cell = tf.contrib.rnn.BasicLSTMCell(
         num_units=self.config.num_lstm_units, state_is_tuple=True)
     if self.mode == "train":
-      lstm_cell = tf.nn.rnn_cell.DropoutWrapper(
+      lstm_cell = tf.contrib.rnn.DropoutWrapper(
           lstm_cell,
           input_keep_prob=self.config.lstm_dropout_keep_prob,
           output_keep_prob=self.config.lstm_dropout_keep_prob)
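
The cell classes themselves are unchanged by this hunk; the TensorFlow 1.0 API reorganization moved them from `tf.nn.rnn_cell` to `tf.contrib.rnn`. A minimal standalone sketch of the new construction path; the 512 units and 0.7 keep probability are placeholders for the model's `config.num_lstm_units` and `config.lstm_dropout_keep_prob`:

    import tensorflow as tf

    # Sketch only: numeric values are placeholders, not the model's config.
    lstm_cell = tf.contrib.rnn.BasicLSTMCell(num_units=512, state_is_tuple=True)
    lstm_cell = tf.contrib.rnn.DropoutWrapper(
        lstm_cell, input_keep_prob=0.7, output_keep_prob=0.7)
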
@@ -264,13 +264,13 @@ class ShowAndTellModel(object):
       if self.mode == "inference":
         # In inference mode, use concatenated states for convenient feeding and
         # fetching.
-        tf.concat(1, initial_state, name="initial_state")
+        tf.concat_v2(initial_state, 1, name="initial_state")
 
         # Placeholder for feeding a batch of concatenated states.
         state_feed = tf.placeholder(dtype=tf.float32,
                                     shape=[None, sum(lstm_cell.state_size)],
                                     name="state_feed")
-        state_tuple = tf.split(1, 2, state_feed)
+        state_tuple = tf.split(value=state_feed, num_or_size_splits=2, axis=1)
 
         # Run a single LSTM step.
         lstm_outputs, state_tuple = lstm_cell(
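
Two argument-order changes are at work in this hunk: `tf.concat_v2` takes `(values, axis)` where the old `tf.concat` took `(axis, values)`, and `tf.split` now leads with the tensor, `(value, num_or_size_splits, axis)`, instead of `(axis, num_split, value)`. On TensorFlow 1.0 and later, `tf.concat` itself adopts the `(values, axis)` order and `tf.concat_v2` is deprecated. A toy check of the new `tf.split` call, unpacking a concatenated (c, h) state the way the inference code does; the 1024/512 sizes are illustrative:

    import tensorflow as tf

    # A [batch, 2 * num_units] concatenated state splits back into its
    # c and h halves along axis 1.
    state_feed = tf.placeholder(tf.float32, shape=[None, 1024])
    c, h = tf.split(value=state_feed, num_or_size_splits=2, axis=1)
    print(c.get_shape(), h.get_shape())  # (?, 512) (?, 512)
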
@@ -278,7 +278,7 @@ class ShowAndTellModel(object):
             state=state_tuple)
 
         # Concatenate the resulting state.
-        tf.concat(1, state_tuple, name="state")
+        tf.concat_v2(state_tuple, 1, name="state")
       else:
         # Run the batch of sequence embeddings through the LSTM.
         sequence_length = tf.reduce_sum(self.input_mask, 1)
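
The `else:` branch entered at the end of this hunk is the training path. A small, self-contained illustration of the mask-to-length reduction it begins with: `input_mask` is 1 for real caption tokens and 0 for padding, so summing each row gives the per-example sequence length (the mask values below are made up):

    import tensorflow as tf

    # Row-sums of a 0/1 padding mask recover each caption's true length.
    input_mask = tf.constant([[1, 1, 1, 0, 0],
                              [1, 1, 1, 1, 1]])
    sequence_length = tf.reduce_sum(input_mask, 1)
    with tf.Session() as sess:
        print(sess.run(sequence_length))  # [3 5]
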
@@ -307,18 +307,19 @@ class ShowAndTellModel(object):
     weights = tf.to_float(tf.reshape(self.input_mask, [-1]))
 
     # Compute losses.
-    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, targets)
-    batch_loss = tf.div(tf.reduce_sum(tf.mul(losses, weights)),
+    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets,
+                                                            logits=logits)
+    batch_loss = tf.div(tf.reduce_sum(tf.multiply(losses, weights)),
                         tf.reduce_sum(weights),
                         name="batch_loss")
-    tf.contrib.losses.add_loss(batch_loss)
-    total_loss = tf.contrib.losses.get_total_loss()
+    tf.losses.add_loss(batch_loss)
+    total_loss = tf.losses.get_total_loss()
 
     # Add summaries.
-    tf.scalar_summary("batch_loss", batch_loss)
-    tf.scalar_summary("total_loss", total_loss)
+    tf.summary.scalar("losses/batch_loss", batch_loss)
+    tf.summary.scalar("losses/total_loss", total_loss)
     for var in tf.trainable_variables():
-      tf.histogram_summary(var.op.name, var)
+      tf.summary.histogram("parameters/" + var.op.name, var)
 
     self.total_loss = total_loss
     self.target_cross_entropy_losses = losses  # Used in evaluation.
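
Beyond the renames in this hunk (`tf.mul` to `tf.multiply`, `tf.contrib.losses` to `tf.losses`, the old `*_summary` ops to the `tf.summary` module), note that TensorFlow 1.0 makes the cross-entropy arguments keyword-only, which guards against silently swapping `logits` and the targets; the `losses/` and `parameters/` prefixes merely group the summaries in TensorBoard. A toy reproduction of the masked loss itself under TF 1.x, with made-up values:

    import tensorflow as tf

    # Per-token cross-entropy, weighted by a 0/1 mask and normalized by the
    # number of real (unmasked) tokens, so padding does not dilute the loss.
    logits = tf.constant([[2.0, 0.5, 0.1],
                          [0.3, 1.7, 0.2],
                          [0.9, 0.9, 0.9]])
    targets = tf.constant([0, 1, 2])
    weights = tf.constant([1.0, 1.0, 0.0])  # third token is padding

    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets,
                                                            logits=logits)
    batch_loss = tf.div(tf.reduce_sum(tf.multiply(losses, weights)),
                        tf.reduce_sum(weights))
    with tf.Session() as sess:
        print(sess.run(batch_loss))  # average loss over the two real tokens
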