Explorar o código

Add text summarization model to tensorflow/models.

Xin Pan %!s(int64=9) %!d(string=hai) anos
pai
achega
56a05f688e

+ 64 - 0
textsum/BUILD

@@ -0,0 +1,64 @@
+package(default_visibility = [":internal"])
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package_group(
+    name = "internal",
+    packages = [
+        "//textsum/...",
+    ],
+)
+
+py_library(
+    name = "seq2seq_attention_model",
+    srcs = ["seq2seq_attention_model.py"],
+    deps = [
+        ":seq2seq_lib",
+    ],
+)
+
+py_library(
+    name = "seq2seq_lib",
+    srcs = ["seq2seq_lib.py"],
+)
+
+py_binary(
+    name = "seq2seq_attention",
+    srcs = ["seq2seq_attention.py"],
+    deps = [
+        ":batch_reader",
+        ":data",
+        ":seq2seq_attention_decode",
+        ":seq2seq_attention_model",
+    ],
+)
+
+py_library(
+    name = "batch_reader",
+    srcs = ["batch_reader.py"],
+    deps = [
+        ":data",
+        ":seq2seq_attention_model",
+    ],
+)
+
+py_library(
+    name = "beam_search",
+    srcs = ["beam_search.py"],
+)
+
+py_library(
+    name = "seq2seq_attention_decode",
+    srcs = ["seq2seq_attention_decode.py"],
+    deps = [
+        ":beam_search",
+        ":data",
+    ],
+)
+
+py_library(
+    name = "data",
+    srcs = ["data.py"],
+)

A diferenza do arquivo foi suprimida porque é demasiado grande
+ 162 - 0
textsum/README.md


+ 263 - 0
textsum/batch_reader.py

@@ -0,0 +1,263 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Batch reader to seq2seq attention model, with bucketing support."""
+
+from collections import namedtuple
+import Queue
+from random import shuffle
+from threading import Thread
+import time
+
+import numpy as np
+import tensorflow as tf
+
+import data
+
+ModelInput = namedtuple('ModelInput',
+                        'enc_input dec_input target enc_len dec_len '
+                        'origin_article origin_abstract')
+
+BUCKET_CACHE_BATCH = 100
+QUEUE_NUM_BATCH = 100
+
+
+class Batcher(object):
+  """Batch reader with shuffling and bucketing support."""
+
+  def __init__(self, data_path, vocab, hps,
+               article_key, abstract_key, max_article_sentences,
+               max_abstract_sentences, bucketing=True, truncate_input=False):
+    """Batcher constructor.
+
+    Args:
+      data_path: tf.Example filepattern.
+      vocab: Vocabulary.
+      hps: Seq2SeqAttention model hyperparameters.
+      article_key: article feature key in tf.Example.
+      abstract_key: abstract feature key in tf.Example.
+      max_article_sentences: Max number of sentences used from article.
+      max_abstract_sentences: Max number of sentences used from abstract.
+      bucketing: Whether bucket articles of similar length into the same batch.
+      truncate_input: Whether to truncate input that is too long. Alternative is
+        to discard such examples.
+    """
+    self._data_path = data_path
+    self._vocab = vocab
+    self._hps = hps
+    self._article_key = article_key
+    self._abstract_key = abstract_key
+    self._max_article_sentences = max_article_sentences
+    self._max_abstract_sentences = max_abstract_sentences
+    self._bucketing = bucketing
+    self._truncate_input = truncate_input
+    self._input_queue = Queue.Queue(QUEUE_NUM_BATCH * self._hps.batch_size)
+    self._bucket_input_queue = Queue.Queue(QUEUE_NUM_BATCH)
+    self._input_threads = []
+    for _ in xrange(16):
+      self._input_threads.append(Thread(target=self._FillInputQueue))
+      self._input_threads[-1].daemon = True
+      self._input_threads[-1].start()
+    self._bucketing_threads = []
+    for _ in xrange(4):
+      self._bucketing_threads.append(Thread(target=self._FillBucketInputQueue))
+      self._bucketing_threads[-1].daemon = True
+      self._bucketing_threads[-1].start()
+
+    self._watch_thread = Thread(target=self._WatchThreads)
+    self._watch_thread.daemon = True
+    self._watch_thread.start()
+
+  def NextBatch(self):
+    """Returns a batch of inputs for seq2seq attention model.
+
+    Returns:
+      enc_batch: A batch of encoder inputs [batch_size, hps.enc_timestamps].
+      dec_batch: A batch of decoder inputs [batch_size, hps.dec_timestamps].
+      target_batch: A batch of targets [batch_size, hps.dec_timestamps].
+      enc_input_len: encoder input lengths of the batch.
+      dec_input_len: decoder input lengths of the batch.
+      loss_weights: weights for loss function, 1 if not padded, 0 if padded.
+      origin_articles: original article words.
+      origin_abstracts: original abstract words.
+    """
+    enc_batch = np.zeros(
+        (self._hps.batch_size, self._hps.enc_timesteps), dtype=np.int32)
+    enc_input_lens = np.zeros(
+        (self._hps.batch_size), dtype=np.int32)
+    dec_batch = np.zeros(
+        (self._hps.batch_size, self._hps.dec_timesteps), dtype=np.int32)
+    dec_output_lens = np.zeros(
+        (self._hps.batch_size), dtype=np.int32)
+    target_batch = np.zeros(
+        (self._hps.batch_size, self._hps.dec_timesteps), dtype=np.int32)
+    loss_weights = np.zeros(
+        (self._hps.batch_size, self._hps.dec_timesteps), dtype=np.float32)
+    origin_articles = ['None'] * self._hps.batch_size
+    origin_abstracts = ['None'] * self._hps.batch_size
+
+    buckets = self._bucket_input_queue.get()
+    for i in xrange(self._hps.batch_size):
+      (enc_inputs, dec_inputs, targets, enc_input_len, dec_output_len,
+       article, abstract) = buckets[i]
+
+      origin_articles[i] = article
+      origin_abstracts[i] = abstract
+      enc_input_lens[i] = enc_input_len
+      dec_output_lens[i] = dec_output_len
+      enc_batch[i, :] = enc_inputs[:]
+      dec_batch[i, :] = dec_inputs[:]
+      target_batch[i, :] = targets[:]
+      for j in xrange(dec_output_len):
+        loss_weights[i][j] = 1
+    return (enc_batch, dec_batch, target_batch, enc_input_lens, dec_output_lens,
+            loss_weights, origin_articles, origin_abstracts)
+
+  def _FillInputQueue(self):
+    """Fill input queue with ModelInput."""
+    start_id = self._vocab.WordToId(data.SENTENCE_START)
+    end_id = self._vocab.WordToId(data.SENTENCE_END)
+    pad_id = self._vocab.WordToId(data.PAD_TOKEN)
+    input_gen = self._TextGenerator(data.ExampleGen(self._data_path))
+    while True:
+      (article, abstract) = input_gen.next()
+      article_sentences = [sent.strip() for sent in
+                           data.ToSentences(article, include_token=False)]
+      abstract_sentences = [sent.strip() for sent in
+                            data.ToSentences(abstract, include_token=False)]
+
+      enc_inputs = []
+      # Use the <s> as the <GO> symbol for decoder inputs.
+      dec_inputs = [start_id]
+
+      # Convert first N sentences to word IDs, stripping existing <s> and </s>.
+      for i in xrange(min(self._max_article_sentences,
+                          len(article_sentences))):
+        enc_inputs += data.GetWordIds(article_sentences[i], self._vocab)
+      for i in xrange(min(self._max_abstract_sentences,
+                          len(abstract_sentences))):
+        dec_inputs += data.GetWordIds(abstract_sentences[i], self._vocab)
+
+      # Filter out too-short input
+      if (len(enc_inputs) < self._hps.min_input_len or
+          len(dec_inputs) < self._hps.min_input_len):
+        tf.logging.warning('Drop an example - too short.\nenc:%d\ndec:%d',
+                           len(enc_inputs), len(dec_inputs))
+        continue
+
+      # If we're not truncating input, throw out too-long input
+      if not self._truncate_input:
+        if (len(enc_inputs) > self._hps.enc_timesteps or
+            len(dec_inputs) > self._hps.dec_timesteps):
+          tf.logging.warning('Drop an example - too long.\nenc:%d\ndec:%d',
+                             len(enc_inputs), len(dec_inputs))
+          continue
+      # If we are truncating input, do so if necessary
+      else:
+        if len(enc_inputs) > self._hps.enc_timesteps:
+          enc_inputs = enc_inputs[:self._hps.enc_timesteps]
+        if len(dec_inputs) > self._hps.dec_timesteps:
+          dec_inputs = dec_inputs[:self._hps.dec_timesteps]
+
+      # targets is dec_inputs without <s> at beginning, plus </s> at end
+      targets = dec_inputs[1:]
+      targets.append(end_id)
+
+      # Now len(enc_inputs) should be <= enc_timesteps, and
+      # len(targets) = len(dec_inputs) should be <= dec_timesteps
+
+      enc_input_len = len(enc_inputs)
+      dec_output_len = len(targets)
+
+      # Pad if necessary
+      while len(enc_inputs) < self._hps.enc_timesteps:
+        enc_inputs.append(pad_id)
+      while len(dec_inputs) < self._hps.dec_timesteps:
+        dec_inputs.append(end_id)
+      while len(targets) < self._hps.dec_timesteps:
+        targets.append(end_id)
+
+      element = ModelInput(enc_inputs, dec_inputs, targets, enc_input_len,
+                           dec_output_len, ' '.join(article_sentences),
+                           ' '.join(abstract_sentences))
+      self._input_queue.put(element)
+
+  def _FillBucketInputQueue(self):
+    """Fill bucketed batches into the bucket_input_queue."""
+    while True:
+      inputs = []
+      for _ in xrange(self._hps.batch_size * BUCKET_CACHE_BATCH):
+        inputs.append(self._input_queue.get())
+      if self._bucketing:
+        inputs = sorted(inputs, key=lambda inp: inp.enc_len)
+
+      batches = []
+      for i in xrange(0, len(inputs), self._hps.batch_size):
+        batches.append(inputs[i:i+self._hps.batch_size])
+      shuffle(batches)
+      for b in batches:
+        self._bucket_input_queue.put(b)
+
+  def _WatchThreads(self):
+    """Watch the daemon input threads and restart if dead."""
+    while True:
+      time.sleep(60)
+      input_threads = []
+      for t in self._input_threads:
+        if t.is_alive():
+          input_threads.append(t)
+        else:
+          tf.logging.error('Found input thread dead.')
+          new_t = Thread(target=self._FillInputQueue)
+          input_threads.append(new_t)
+          input_threads[-1].daemon = True
+          input_threads[-1].start()
+      self._input_threads = input_threads
+
+      bucketing_threads = []
+      for t in self._bucketing_threads:
+        if t.is_alive():
+          bucketing_threads.append(t)
+        else:
+          tf.logging.error('Found bucketing thread dead.')
+          new_t = Thread(target=self._FillBucketInputQueue)
+          bucketing_threads.append(new_t)
+          bucketing_threads[-1].daemon = True
+          bucketing_threads[-1].start()
+      self._bucketing_threads = bucketing_threads
+
+  def _TextGenerator(self, example_gen):
+    """Generates article and abstract text from tf.Example."""
+    while True:
+      e = example_gen.next()
+      try:
+        article_text = self._GetExFeatureText(e, self._article_key)
+        abstract_text = self._GetExFeatureText(e, self._abstract_key)
+      except ValueError:
+        tf.logging.error('Failed to get article or abstract from example')
+        continue
+
+      yield (article_text, abstract_text)
+
+  def _GetExFeatureText(self, ex, key):
+    """Extract text for a feature from td.Example.
+
+    Args:
+      ex: tf.Example.
+      key: key of the feature to be extracted.
+    Returns:
+      feature: a feature text extracted.
+    """
+    return ex.features.feature[key].bytes_list.value[0]

+ 155 - 0
textsum/beam_search.py

@@ -0,0 +1,155 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Beam search module.
+
+Beam search takes the top K results from the model, predicts the K results for
+each of the previous K result, getting K*K results. Pick the top K results from
+K*K results, and start over again until certain number of results are fully
+decoded.
+"""
+
+import tensorflow as tf
+
+FLAGS = tf.flags.FLAGS
+tf.flags.DEFINE_bool('normalize_by_length', True, 'Whether normalize')
+
+
+class Hypothesis(object):
+  """Defines a hypothesis during beam search."""
+
+  def __init__(self, tokens, log_prob, state):
+    """Hypothesis constructor.
+
+    Args:
+      tokens: start tokens for decoding.
+      log_prob: log prob of the start tokens, usually 1.
+      state: decoder initial states.
+    """
+    self.tokens = tokens
+    self.log_prob = log_prob
+    self.state = state
+
+  def Extend(self, token, log_prob, new_state):
+    """Extend the hypothesis with result from latest step.
+
+    Args:
+      token: latest token from decoding.
+      log_prob: log prob of the latest decoded tokens.
+      new_state: decoder output state. Fed to the decoder for next step.
+    Returns:
+      New Hypothesis with the results from latest step.
+    """
+    return Hypothesis(self.tokens + [token], self.log_prob + log_prob,
+                      new_state)
+
+  @property
+  def latest_token(self):
+    return self.tokens[-1]
+
+  def __str__(self):
+    return ('Hypothesis(log prob = %.4f, tokens = %s)' % (self.log_prob,
+                                                          self.tokens))
+
+
+class BeamSearch(object):
+  """Beam search."""
+
+  def __init__(self, model, beam_size, start_token, end_token, max_steps):
+    """Creates BeamSearch object.
+
+    Args:
+      model: Seq2SeqAttentionModel.
+      beam_size: int.
+      start_token: int, id of the token to start decoding with
+      end_token: int, id of the token that completes an hypothesis
+      max_steps: int, upper limit on the size of the hypothesis
+    """
+    self._model = model
+    self._beam_size = beam_size
+    self._start_token = start_token
+    self._end_token = end_token
+    self._max_steps = max_steps
+
+  def BeamSearch(self, sess, enc_inputs, enc_seqlen):
+    """Performs beam search for decoding.
+
+    Args:
+      sess: tf.Session, session
+      enc_inputs: ndarray of shape (enc_length, 1), the document ids to encode
+      enc_seqlen: ndarray of shape (1), the length of the sequnce
+
+    Returns:
+      hyps: list of Hypothesis, the best hypotheses found by beam search,
+          ordered by score
+    """
+
+    # Run the encoder and extract the outputs and final state.
+    enc_top_states, dec_in_state = self._model.encode_top_state(
+        sess, enc_inputs, enc_seqlen)
+    # Replicate the initial states K times for the first step.
+    hyps = [Hypothesis([self._start_token], 0.0, dec_in_state)
+           ] * self._beam_size
+    results = []
+
+    steps = 0
+    while steps < self._max_steps and len(results) < self._beam_size:
+      latest_tokens = [h.latest_token for h in hyps]
+      states = [h.state for h in hyps]
+
+      topk_ids, topk_log_probs, new_states = self._model.decode_topk(
+          sess, latest_tokens, enc_top_states, states)
+      # Extend each hypothesis.
+      all_hyps = []
+      # The first step takes the best K results from first hyps. Following
+      # steps take the best K results from K*K hyps.
+      num_beam_source = 1 if steps == 0 else len(hyps)
+      for i in xrange(num_beam_source):
+        h, ns = hyps[i], new_states[i]
+        for j in xrange(self._beam_size*2):
+          all_hyps.append(h.Extend(topk_ids[i, j], topk_log_probs[i, j], ns))
+
+      # Filter and collect any hypotheses that have the end token.
+      hyps = []
+      for h in self._BestHyps(all_hyps):
+        if h.latest_token == self._end_token:
+          # Pull the hypothesis off the beam if the end token is reached.
+          results.append(h)
+        else:
+          # Otherwise continue to the extend the hypothesis.
+          hyps.append(h)
+        if len(hyps) == self._beam_size or len(results) == self._beam_size:
+          break
+
+      steps += 1
+
+    if steps == self._max_steps:
+      results.extend(hyps)
+
+    return self._BestHyps(results)
+
+  def _BestHyps(self, hyps):
+    """Sort the hyps based on log probs and length.
+
+    Args:
+      hyps: A list of hypothesis.
+    Returns:
+      hyps: A list of sorted hypothesis in reverse log_prob order.
+    """
+    # This length normalization is only effective for the final results.
+    if FLAGS.normalize_by_length:
+      return sorted(hyps, key=lambda h: h.log_prob/len(h.tokens), reverse=True)
+    else:
+      return sorted(hyps, key=lambda h: h.log_prob, reverse=True)

+ 206 - 0
textsum/data.py

@@ -0,0 +1,206 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Data batchers for data described in ..//data_prep/README.md."""
+
+import glob
+import random
+import struct
+import sys
+
+from tensorflow.core.example import example_pb2
+
+
+# Special tokens
+PARAGRAPH_START = '<p>'
+PARAGRAPH_END = '</p>'
+SENTENCE_START = '<s>'
+SENTENCE_END = '</s>'
+UNKNOWN_TOKEN = '<UNK>'
+PAD_TOKEN = '<PAD>'
+DOCUMENT_START = '<d>'
+DOCUMENT_END = '</d>'
+
+
+class Vocab(object):
+  """Vocabulary class for mapping words and ids."""
+
+  def __init__(self, vocab_file, max_size):
+    self._word_to_id = {}
+    self._id_to_word = {}
+    self._count = 0
+
+    with open(vocab_file, 'r') as vocab_f:
+      for line in vocab_f:
+        pieces = line.split()
+        if len(pieces) != 2:
+          sys.stderr.write('Bad line: %s\n' % line)
+          continue
+        if pieces[0] in self._word_to_id:
+          raise ValueError('Duplicated word: %s.' % pieces[0])
+        self._word_to_id[pieces[0]] = self._count
+        self._id_to_word[self._count] = pieces[0]
+        self._count += 1
+        if self._count > max_size:
+          raise ValueError('Too many words: >%d.' % max_size)
+
+  def WordToId(self, word):
+    if word not in self._word_to_id:
+      return self._word_to_id[UNKNOWN_TOKEN]
+    return self._word_to_id[word]
+
+  def IdToWord(self, word_id):
+    if word_id not in self._id_to_word:
+      raise ValueError('id not found in vocab: %d.' % word_id)
+    return self._id_to_word[word_id]
+
+  def NumIds(self):
+    return self._count
+
+
+def ExampleGen(recordio_path, num_epochs=None):
+  """Generates tf.Examples from path of recordio files.
+
+  Args:
+    recordio_path: CNS path to tf.Example recordio
+    num_epochs: Number of times to go through the data. None means infinite.
+
+  Yields:
+    Deserialized tf.Example.
+
+  If there are multiple files specified, they accessed in a random order.
+  """
+  epoch = 0
+  while True:
+    if num_epochs is not None and epoch >= num_epochs:
+      break
+    filelist = glob.glob(recordio_path)
+    assert filelist, 'Empty filelist.'
+    random.shuffle(filelist)
+    for f in filelist:
+      reader = open(f, 'rb')
+      while True:
+        len_bytes = reader.read(8)
+        if not len_bytes: break
+        str_len = struct.unpack('q', len_bytes)[0]
+        example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
+        yield example_pb2.Example.FromString(example_str)
+
+    epoch += 1
+
+
+def Pad(ids, pad_id, length):
+  """Pad or trim list to len length.
+
+  Args:
+    ids: list of ints to pad
+    pad_id: what to pad with
+    length: length to pad or trim to
+
+  Returns:
+    ids trimmed or padded with pad_id
+  """
+  assert pad_id is not None
+  assert length is not None
+
+  if len(ids) < length:
+    a = [pad_id] * (length - len(ids))
+    return ids + a
+  else:
+    return ids[:length]
+
+
+def GetWordIds(text, vocab, pad_len=None, pad_id=None):
+  """Get ids corresponding to words in text.
+
+  Assumes tokens separated by space.
+
+  Args:
+    text: a string
+    vocab: TextVocabularyFile object
+    pad_len: int, length to pad to
+    pad_id: int, word id for pad symbol
+
+  Returns:
+    A list of ints representing word ids.
+  """
+  ids = []
+  for w in text.split():
+    i = vocab.WordToId(w)
+    if i >= 0:
+      ids.append(i)
+    else:
+      ids.append(vocab.WordToId(UNKNOWN_TOKEN))
+  if pad_len is not None:
+    return Pad(ids, pad_id, pad_len)
+  return ids
+
+
+def Ids2Words(ids_list, vocab):
+  """Get words from ids.
+
+  Args:
+    ids_list: list of int32
+    vocab: TextVocabulary object
+
+  Returns:
+    List of words corresponding to ids.
+  """
+  assert isinstance(ids_list, list), '%s  is not a list' % ids_list
+  return [vocab.IdToWord(i) for i in ids_list]
+
+
+def SnippetGen(text, start_tok, end_tok, inclusive=True):
+  """Generates consecutive snippets between start and end tokens.
+
+  Args:
+    text: a string
+    start_tok: a string denoting the start of snippets
+    end_tok: a string denoting the end of snippets
+    inclusive: Whether include the tokens in the returned snippets.
+
+  Yields:
+    String snippets
+  """
+  cur = 0
+  while True:
+    try:
+      start_p = text.index(start_tok, cur)
+      end_p = text.index(end_tok, start_p + 1)
+      cur = end_p + len(end_tok)
+      if inclusive:
+        yield text[start_p:cur]
+      else:
+        yield text[start_p+len(start_tok):end_p]
+    except ValueError as e:
+      raise StopIteration('no more snippets in text: %s' % e)
+
+
+def GetExFeatureText(ex, key):
+  return ex.features.feature[key].bytes_list.value[0]
+
+
+def ToSentences(paragraph, include_token=True):
+  """Takes tokens of a paragraph and returns list of sentences.
+
+  Args:
+    paragraph: string, text of paragraph
+    include_token: Whether include the sentence separation tokens result.
+
+  Returns:
+    List of sentence strings.
+  """
+  s_gen = SnippetGen(paragraph, SENTENCE_START, SENTENCE_END, include_token)
+  return [s for s in s_gen]

BIN=BIN
textsum/data/data


A diferenza do arquivo foi suprimida porque é demasiado grande
+ 10003 - 0
textsum/data/vocab


+ 212 - 0
textsum/seq2seq_attention.py

@@ -0,0 +1,212 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Trains a seq2seq model.
+
+WORK IN PROGRESS.
+
+Implement "Abstractive Text Summarization using Sequence-to-sequence RNNS and
+Beyond."
+
+"""
+import sys
+import time
+
+import tensorflow as tf
+import batch_reader
+import data
+import seq2seq_attention_decode
+import seq2seq_attention_model
+
+FLAGS = tf.app.flags.FLAGS
+tf.app.flags.DEFINE_string('data_path',
+                           '', 'Path expression to tf.Example.')
+tf.app.flags.DEFINE_string('vocab_path',
+                           '', 'Path expression to text vocabulary file.')
+tf.app.flags.DEFINE_string('article_key', 'article',
+                           'tf.Example feature key for article.')
+tf.app.flags.DEFINE_string('abstract_key', 'headline',
+                           'tf.Example feature key for abstract.')
+tf.app.flags.DEFINE_string('log_root', '', 'Directory for model root.')
+tf.app.flags.DEFINE_string('train_dir', '', 'Directory for train.')
+tf.app.flags.DEFINE_string('eval_dir', '', 'Directory for eval.')
+tf.app.flags.DEFINE_string('decode_dir', '', 'Directory for decode summaries.')
+tf.app.flags.DEFINE_string('mode', 'train', 'train/eval/decode mode')
+tf.app.flags.DEFINE_integer('max_run_steps', 10000000,
+                            'Maximum number of run steps.')
+tf.app.flags.DEFINE_integer('max_article_sentences', 2,
+                            'Max number of first sentences to use from the '
+                            'article')
+tf.app.flags.DEFINE_integer('max_abstract_sentences', 100,
+                            'Max number of first sentences to use from the '
+                            'abstract')
+tf.app.flags.DEFINE_integer('beam_size', 4,
+                            'beam size for beam search decoding.')
+tf.app.flags.DEFINE_integer('eval_interval_secs', 60, 'How often to run eval.')
+tf.app.flags.DEFINE_integer('checkpoint_secs', 60, 'How often to checkpoint.')
+tf.app.flags.DEFINE_bool('use_bucketing', False,
+                         'Whether bucket articles of similar length.')
+tf.app.flags.DEFINE_bool('truncate_input', False,
+                         'Truncate inputs that are too long. If False, '
+                         'examples that are too long are discarded.')
+tf.app.flags.DEFINE_integer('num_gpus', 0, 'Number of gpus used.')
+tf.app.flags.DEFINE_integer('random_seed', 111, 'A seed value for randomness.')
+
+
+def _RunningAvgLoss(loss, running_avg_loss, summary_writer, step, decay=0.999):
+  """Calculate the running average of losses."""
+  if running_avg_loss == 0:
+    running_avg_loss = loss
+  else:
+    running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
+  running_avg_loss = min(running_avg_loss, 12)
+  loss_sum = tf.Summary()
+  loss_sum.value.add(tag='running_avg_loss', simple_value=running_avg_loss)
+  summary_writer.add_summary(loss_sum, step)
+  sys.stdout.write('running_avg_loss: %f\n' % running_avg_loss)
+  return running_avg_loss
+
+
+def _Train(model, data_batcher):
+  """Runs model training."""
+  with tf.device('/cpu:0'):
+    model.build_graph()
+    saver = tf.train.Saver()
+    # Train dir is different from log_root to avoid summary directory
+    # conflict with Supervisor.
+    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir)
+    sv = tf.train.Supervisor(logdir=FLAGS.log_root,
+                             is_chief=True,
+                             saver=saver,
+                             summary_op=None,
+                             save_summaries_secs=60,
+                             save_model_secs=FLAGS.checkpoint_secs,
+                             global_step=model.global_step)
+    sess = sv.prepare_or_wait_for_session()
+    running_avg_loss = 0
+    step = 0
+    while not sv.should_stop() and step < FLAGS.max_run_steps:
+      (article_batch, abstract_batch, targets, article_lens, abstract_lens,
+       loss_weights, _, _) = data_batcher.NextBatch()
+      (_, summaries, loss, train_step) = model.run_train_step(
+          sess, article_batch, abstract_batch, targets, article_lens,
+          abstract_lens, loss_weights)
+
+      summary_writer.add_summary(summaries, train_step)
+      running_avg_loss = _RunningAvgLoss(
+          running_avg_loss, loss, summary_writer, train_step)
+      step += 1
+      if step % 100 == 0:
+        summary_writer.flush()
+    sv.Stop()
+    return running_avg_loss
+
+
+def _Eval(model, data_batcher, vocab=None):
+  """Runs model eval."""
+  model.build_graph()
+  saver = tf.train.Saver()
+  summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir)
+  sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
+  running_avg_loss = 0
+  step = 0
+  while True:
+    time.sleep(FLAGS.eval_interval_secs)
+    try:
+      ckpt_state = tf.train.get_checkpoint_state(FLAGS.log_root)
+    except tf.errors.OutOfRangeError as e:
+      tf.logging.error('Cannot restore checkpoint: %s', e)
+      continue
+
+    if not (ckpt_state and ckpt_state.model_checkpoint_path):
+      tf.logging.info('No model to eval yet at %s', FLAGS.train_dir)
+      continue
+
+    tf.logging.info('Loading checkpoint %s', ckpt_state.model_checkpoint_path)
+    saver.restore(sess, ckpt_state.model_checkpoint_path)
+
+    (article_batch, abstract_batch, targets, article_lens, abstract_lens,
+     loss_weights, _, _) = data_batcher.NextBatch()
+    (summaries, loss, train_step) = model.run_eval_step(
+        sess, article_batch, abstract_batch, targets, article_lens,
+        abstract_lens, loss_weights)
+    tf.logging.info(
+        'article:  %s',
+        ' '.join(data.Ids2Words(article_batch[0][:].tolist(), vocab)))
+    tf.logging.info(
+        'abstract: %s',
+        ' '.join(data.Ids2Words(abstract_batch[0][:].tolist(), vocab)))
+
+    summary_writer.add_summary(summaries, train_step)
+    running_avg_loss = _RunningAvgLoss(
+        running_avg_loss, loss, summary_writer, train_step)
+    if step % 100 == 0:
+      summary_writer.flush()
+
+
+def main(unused_argv):
+  vocab = data.Vocab(FLAGS.vocab_path, 1000000)
+  # Check for presence of required special tokens.
+  assert vocab.WordToId(data.PAD_TOKEN) > 0
+  assert vocab.WordToId(data.UNKNOWN_TOKEN) >= 0
+  assert vocab.WordToId(data.SENTENCE_START) > 0
+  assert vocab.WordToId(data.SENTENCE_END) > 0
+
+  batch_size = 4
+  if FLAGS.mode == 'decode':
+    batch_size = FLAGS.beam_size
+
+  hps = seq2seq_attention_model.HParams(
+      mode=FLAGS.mode,  # train, eval, decode
+      min_lr=0.01,  # min learning rate.
+      lr=0.15,  # learning rate
+      batch_size=batch_size,
+      enc_layers=4,
+      enc_timesteps=120,
+      dec_timesteps=30,
+      min_input_len=2,  # discard articles/summaries < than this
+      num_hidden=256,  # for rnn cell
+      emb_dim=128,  # If 0, don't use embedding
+      max_grad_norm=2,
+      num_softmax_samples=4096)  # If 0, no sampled softmax.
+
+  batcher = batch_reader.Batcher(
+      FLAGS.data_path, vocab, hps, FLAGS.article_key,
+      FLAGS.abstract_key, FLAGS.max_article_sentences,
+      FLAGS.max_abstract_sentences, bucketing=FLAGS.use_bucketing,
+      truncate_input=FLAGS.truncate_input)
+  tf.set_random_seed(FLAGS.random_seed)
+
+  if hps.mode == 'train':
+    model = seq2seq_attention_model.Seq2SeqAttentionModel(
+        hps, vocab, num_gpus=FLAGS.num_gpus)
+    _Train(model, batcher)
+  elif hps.mode == 'eval':
+    model = seq2seq_attention_model.Seq2SeqAttentionModel(
+        hps, vocab, num_gpus=FLAGS.num_gpus)
+    _Eval(model, batcher, vocab=vocab)
+  elif hps.mode == 'decode':
+    decode_mdl_hps = hps
+    # Only need to restore the 1st step and reuse it since
+    # we keep and feed in state for each step's output.
+    decode_mdl_hps = hps._replace(dec_timesteps=1)
+    model = seq2seq_attention_model.Seq2SeqAttentionModel(
+        decode_mdl_hps, vocab, num_gpus=FLAGS.num_gpus)
+    decoder = seq2seq_attention_decode.BSDecoder(model, batcher, hps, vocab)
+    decoder.DecodeLoop()
+
+
+if __name__ == '__main__':
+  tf.app.run()

+ 161 - 0
textsum/seq2seq_attention_decode.py

@@ -0,0 +1,161 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Module for decoding."""
+
+import os
+import time
+
+import tensorflow as tf
+import beam_search
+import data
+
+FLAGS = tf.app.flags.FLAGS
+tf.app.flags.DEFINE_integer('max_decode_steps', 1000000,
+                            'Number of decoding steps.')
+tf.app.flags.DEFINE_integer('decode_batches_per_ckpt', 8000,
+                            'Number of batches to decode before restoring next '
+                            'checkpoint')
+
+DECODE_LOOP_DELAY_SECS = 60
+DECODE_IO_FLUSH_INTERVAL = 100
+
+
+class DecodeIO(object):
+  """Writes the decoded and references to RKV files for Rouge score.
+
+    See nlp/common/utils/internal/rkv_parser.py for detail about rkv file.
+  """
+
+  def __init__(self, outdir):
+    self._cnt = 0
+    self._outdir = outdir
+    if not os.path.exists(self._outdir):
+      os.mkdir(self._outdir)
+    self._ref_file = None
+    self._decode_file = None
+
+  def Write(self, reference, decode):
+    """Writes the reference and decoded outputs to RKV files.
+
+    Args:
+      reference: The human (correct) result.
+      decode: The machine-generated result
+    """
+    self._ref_file.write('output=%s\n' % reference)
+    self._decode_file.write('output=%s\n' % decode)
+    self._cnt += 1
+    if self._cnt % DECODE_IO_FLUSH_INTERVAL == 0:
+      self._ref_file.flush()
+      self._decode_file.flush()
+
+  def ResetFiles(self):
+    """Resets the output files. Must be called once before Write()."""
+    if self._ref_file: self._ref_file.close()
+    if self._decode_file: self._decode_file.close()
+    timestamp = int(time.time())
+    self._ref_file = open(
+        os.path.join(self._outdir, 'ref%d'%timestamp), 'w')
+    self._decode_file = open(
+        os.path.join(self._outdir, 'decode%d'%timestamp), 'w')
+
+
+class BSDecoder(object):
+  """Beam search decoder."""
+
+  def __init__(self, model, batch_reader, hps, vocab):
+    """Beam search decoding.
+
+    Args:
+      model: The seq2seq attentional model.
+      batch_reader: The batch data reader.
+      hps: Hyperparamters.
+      vocab: Vocabulary
+    """
+    self._model = model
+    self._model.build_graph()
+    self._batch_reader = batch_reader
+    self._hps = hps
+    self._vocab = vocab
+    self._saver = tf.train.Saver()
+    self._decode_io = DecodeIO(FLAGS.decode_dir)
+
+  def DecodeLoop(self):
+    """Decoding loop for long running process."""
+    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
+    step = 0
+    while step < FLAGS.max_decode_steps:
+      time.sleep(DECODE_LOOP_DELAY_SECS)
+      if not self._Decode(self._saver, sess):
+        continue
+      step += 1
+
+  def _Decode(self, saver, sess):
+    """Restore a checkpoint and decode it.
+
+    Args:
+      saver: Tensorflow checkpoint saver.
+      sess: Tensorflow session.
+    Returns:
+      If success, returns true, otherwise, false.
+    """
+    ckpt_state = tf.train.get_checkpoint_state(FLAGS.log_root)
+    if not (ckpt_state and ckpt_state.model_checkpoint_path):
+      tf.logging.info('No model to decode yet at %s', FLAGS.log_root)
+      return False
+
+    tf.logging.info('checkpoint path %s', ckpt_state.model_checkpoint_path)
+    ckpt_path = os.path.join(
+        FLAGS.log_root, os.path.basename(ckpt_state.model_checkpoint_path))
+    tf.logging.info('renamed checkpoint path %s', ckpt_path)
+    saver.restore(sess, ckpt_path)
+
+    self._decode_io.ResetFiles()
+    for _ in xrange(FLAGS.decode_batches_per_ckpt):
+      (article_batch, _, _, article_lens, _, _, origin_articles,
+       origin_abstracts) = self._batch_reader.NextBatch()
+      for i in xrange(self._hps.batch_size):
+        bs = beam_search.BeamSearch(
+            self._model, self._hps.batch_size,
+            self._vocab.WordToId(data.SENTENCE_START),
+            self._vocab.WordToId(data.SENTENCE_END),
+            self._hps.dec_timesteps)
+
+        article_batch_cp = article_batch.copy()
+        article_batch_cp[:] = article_batch[i:i+1]
+        article_lens_cp = article_lens.copy()
+        article_lens_cp[:] = article_lens[i:i+1]
+        best_beam = bs.BeamSearch(sess, article_batch_cp, article_lens_cp)[0]
+        decode_output = [int(t) for t in best_beam.tokens[1:]]
+        self._DecodeBatch(
+            origin_articles[i], origin_abstracts[i], decode_output)
+    return True
+
+  def _DecodeBatch(self, article, abstract, output_ids):
+    """Convert id to words and writing results.
+
+    Args:
+      article: The original article string.
+      abstract: The human (correct) abstract string.
+      output_ids: The abstract word ids output by machine.
+    """
+    decoded_output = ' '.join(data.Ids2Words(output_ids, self._vocab))
+    end_p = decoded_output.find(data.SENTENCE_END, 0)
+    if end_p != -1:
+      decoded_output = decoded_output[:end_p]
+    tf.logging.info('article:  %s', article)
+    tf.logging.info('abstract: %s', abstract)
+    tf.logging.info('decoded:  %s', decoded_output)
+    self._decode_io.Write(abstract, decoded_output.strip())

+ 295 - 0
textsum/seq2seq_attention_model.py

@@ -0,0 +1,295 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Sequence-to-Sequence with attention model for text summarization.
+"""
+from collections import namedtuple
+
+import numpy as np
+import tensorflow as tf
+import seq2seq_lib
+
+
+HParams = namedtuple('HParams',
+                     'mode, min_lr, lr, batch_size, '
+                     'enc_layers, enc_timesteps, dec_timesteps, '
+                     'min_input_len, num_hidden, emb_dim, max_grad_norm, '
+                     'num_softmax_samples')
+
+
+def _extract_argmax_and_embed(embedding, output_projection=None,
+                              update_embedding=True):
+  """Get a loop_function that extracts the previous symbol and embeds it.
+
+  Args:
+    embedding: embedding tensor for symbols.
+    output_projection: None or a pair (W, B). If provided, each fed previous
+      output will first be multiplied by W and added B.
+    update_embedding: Boolean; if False, the gradients will not propagate
+      through the embeddings.
+
+  Returns:
+    A loop function.
+  """
+  def loop_function(prev, _):
+    """function that feed previous model output rather than ground truth."""
+    if output_projection is not None:
+      prev = tf.nn.xw_plus_b(
+          prev, output_projection[0], output_projection[1])
+    prev_symbol = tf.argmax(prev, 1)
+    # Note that gradients will not propagate through the second parameter of
+    # embedding_lookup.
+    emb_prev = tf.nn.embedding_lookup(embedding, prev_symbol)
+    if not update_embedding:
+      emb_prev = tf.stop_gradient(emb_prev)
+    return emb_prev
+  return loop_function
+
+
+class Seq2SeqAttentionModel(object):
+  """Wrapper for Tensorflow model graph for text sum vectors."""
+
+  def __init__(self, hps, vocab, num_gpus=0):
+    self._hps = hps
+    self._vocab = vocab
+    self._num_gpus = num_gpus
+    self._cur_gpu = 0
+
+  def run_train_step(self, sess, article_batch, abstract_batch, targets,
+                     article_lens, abstract_lens, loss_weights):
+    to_return = [self._train_op, self._summaries, self._loss, self.global_step]
+    return sess.run(to_return,
+                    feed_dict={self._articles: article_batch,
+                               self._abstracts: abstract_batch,
+                               self._targets: targets,
+                               self._article_lens: article_lens,
+                               self._abstract_lens: abstract_lens,
+                               self._loss_weights: loss_weights})
+
+  def run_eval_step(self, sess, article_batch, abstract_batch, targets,
+                    article_lens, abstract_lens, loss_weights):
+    to_return = [self._summaries, self._loss, self.global_step]
+    return sess.run(to_return,
+                    feed_dict={self._articles: article_batch,
+                               self._abstracts: abstract_batch,
+                               self._targets: targets,
+                               self._article_lens: article_lens,
+                               self._abstract_lens: abstract_lens,
+                               self._loss_weights: loss_weights})
+
+  def run_decode_step(self, sess, article_batch, abstract_batch, targets,
+                      article_lens, abstract_lens, loss_weights):
+    to_return = [self._outputs, self.global_step]
+    return sess.run(to_return,
+                    feed_dict={self._articles: article_batch,
+                               self._abstracts: abstract_batch,
+                               self._targets: targets,
+                               self._article_lens: article_lens,
+                               self._abstract_lens: abstract_lens,
+                               self._loss_weights: loss_weights})
+
+  def _next_device(self):
+    """Round robin the gpu device. (Reserve last gpu for expensive op)."""
+    if self._num_gpus == 0:
+      return ''
+    dev = '/gpu:%d' % self._cur_gpu
+    self._cur_gpu = (self._cur_gpu + 1) % (self._num_gpus-1)
+    return dev
+
+  def _get_gpu(self, gpu_id):
+    if self._num_gpus <= 0 or gpu_id >= self._num_gpus:
+      return ''
+    return '/gpu:%d' % gpu_id
+
+  def _add_placeholders(self):
+    """Inputs to be fed to the graph."""
+    hps = self._hps
+    self._articles = tf.placeholder(tf.int32,
+                                    [hps.batch_size, hps.enc_timesteps],
+                                    name='articles')
+    self._abstracts = tf.placeholder(tf.int32,
+                                     [hps.batch_size, hps.dec_timesteps],
+                                     name='abstracts')
+    self._targets = tf.placeholder(tf.int32,
+                                   [hps.batch_size, hps.dec_timesteps],
+                                   name='targets')
+    self._article_lens = tf.placeholder(tf.int32, [hps.batch_size],
+                                        name='article_lens')
+    self._abstract_lens = tf.placeholder(tf.int32, [hps.batch_size],
+                                         name='abstract_lens')
+    self._loss_weights = tf.placeholder(tf.float32,
+                                        [hps.batch_size, hps.dec_timesteps],
+                                        name='loss_weights')
+
+  def _add_seq2seq(self):
+    hps = self._hps
+    vsize = self._vocab.NumIds()
+
+    with tf.variable_scope('seq2seq'):
+      encoder_inputs = tf.unpack(tf.transpose(self._articles))
+      decoder_inputs = tf.unpack(tf.transpose(self._abstracts))
+      targets = tf.unpack(tf.transpose(self._targets))
+      loss_weights = tf.unpack(tf.transpose(self._loss_weights))
+      article_lens = self._article_lens
+
+      # Embedding shared by the input and outputs.
+      with tf.variable_scope('embedding'), tf.device('/cpu:0'):
+        embedding = tf.get_variable(
+            'embedding', [vsize, hps.emb_dim], dtype=tf.float32,
+            initializer=tf.truncated_normal_initializer(stddev=1e-4))
+        emb_encoder_inputs = [tf.nn.embedding_lookup(embedding, x)
+                              for x in encoder_inputs]
+        emb_decoder_inputs = [tf.nn.embedding_lookup(embedding, x)
+                              for x in decoder_inputs]
+
+      for layer_i in xrange(hps.enc_layers):
+        with tf.variable_scope('encoder%d'%layer_i), tf.device(
+            self._next_device()):
+          cell_fw = tf.nn.rnn_cell.LSTMCell(
+              hps.num_hidden,
+              initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=123))
+          cell_bw = tf.nn.rnn_cell.LSTMCell(
+              hps.num_hidden,
+              initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113))
+          (emb_encoder_inputs, fw_state, _) = tf.nn.bidirectional_rnn(
+              cell_fw, cell_bw, emb_encoder_inputs, dtype=tf.float32,
+              sequence_length=article_lens)
+      encoder_outputs = emb_encoder_inputs
+
+      with tf.variable_scope('output_projection'):
+        w = tf.get_variable(
+            'w', [hps.num_hidden, vsize], dtype=tf.float32,
+            initializer=tf.truncated_normal_initializer(stddev=1e-4))
+        w_t = tf.transpose(w)
+        v = tf.get_variable(
+            'v', [vsize], dtype=tf.float32,
+            initializer=tf.truncated_normal_initializer(stddev=1e-4))
+
+      with tf.variable_scope('decoder'), tf.device(self._next_device()):
+        # When decoding, use model output from the previous step
+        # for the next step.
+        loop_function = None
+        if hps.mode == 'decode':
+          loop_function = _extract_argmax_and_embed(
+              embedding, (w, v), update_embedding=False)
+
+        cell = tf.nn.rnn_cell.LSTMCell(
+            hps.num_hidden,
+            initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113))
+
+        encoder_outputs = [tf.reshape(x, [hps.batch_size, 1, 2*hps.num_hidden])
+                           for x in encoder_outputs]
+        self._enc_top_states = tf.concat(1, encoder_outputs)
+        self._dec_in_state = fw_state
+        # During decoding, follow up _dec_in_state are fed from beam_search.
+        # dec_out_state are stored by beam_search for next step feeding.
+        initial_state_attention = (hps.mode == 'decode')
+        decoder_outputs, self._dec_out_state = tf.nn.seq2seq.attention_decoder(
+            emb_decoder_inputs, self._dec_in_state, self._enc_top_states,
+            cell, num_heads=1, loop_function=loop_function,
+            initial_state_attention=initial_state_attention)
+
+      with tf.variable_scope('output'), tf.device(self._next_device()):
+        model_outputs = []
+        for i in xrange(len(decoder_outputs)):
+          if i > 0:
+            tf.get_variable_scope().reuse_variables()
+          model_outputs.append(
+              tf.nn.xw_plus_b(decoder_outputs[i], w, v))
+
+      if hps.mode == 'decode':
+        with tf.variable_scope('decode_output'), tf.device('/cpu:0'):
+          best_outputs = [tf.argmax(x, 1) for x in model_outputs]
+          tf.logging.info('best_outputs%s', best_outputs[0].get_shape())
+          self._outputs = tf.concat(
+              1, [tf.reshape(x, [hps.batch_size, 1]) for x in best_outputs])
+
+          self._topk_log_probs, self._topk_ids = tf.nn.top_k(
+              tf.log(tf.nn.softmax(model_outputs[-1])), hps.batch_size*2)
+
+      with tf.variable_scope('loss'), tf.device(self._next_device()):
+        def sampled_loss_func(inputs, labels):
+          with tf.device('/cpu:0'):  # Try gpu.
+            labels = tf.reshape(labels, [-1, 1])
+            return tf.nn.sampled_softmax_loss(w_t, v, inputs, labels,
+                                              hps.num_softmax_samples, vsize)
+
+        if hps.num_softmax_samples != 0 and hps.mode == 'train':
+          self._loss = seq2seq_lib.sampled_sequence_loss(
+              decoder_outputs, targets, loss_weights, sampled_loss_func)
+        else:
+          self._loss = tf.nn.seq2seq.sequence_loss(
+              model_outputs, targets, loss_weights)
+        tf.scalar_summary('loss', tf.minimum(12.0, self._loss))
+
+  def _add_train_op(self):
+    """Sets self._train_op, op to run for training."""
+    hps = self._hps
+
+    self._lr_rate = tf.maximum(
+        hps.min_lr,  # min_lr_rate.
+        tf.train.exponential_decay(hps.lr, self.global_step, 30000, 0.98))
+
+    tvars = tf.trainable_variables()
+    with tf.device(self._get_gpu(self._num_gpus-1)):
+      grads, global_norm = tf.clip_by_global_norm(
+          tf.gradients(self._loss, tvars), hps.max_grad_norm)
+    tf.scalar_summary('global_norm', global_norm)
+    optimizer = tf.train.GradientDescentOptimizer(self._lr_rate)
+    tf.scalar_summary('learning rate', self._lr_rate)
+    self._train_op = optimizer.apply_gradients(
+        zip(grads, tvars), global_step=self.global_step, name='train_step')
+
+  def encode_top_state(self, sess, enc_inputs, enc_len):
+    """Return the top states from encoder for decoder.
+
+    Args:
+      sess: tensorflow session.
+      enc_inputs: encoder inputs of shape [batch_size, enc_timesteps].
+      enc_len: encoder input length of shape [batch_size]
+    Returns:
+      enc_top_states: The top level encoder states.
+      dec_in_state: The decoder layer initial state.
+    """
+    results = sess.run([self._enc_top_states, self._dec_in_state],
+                       feed_dict={self._articles: enc_inputs,
+                                  self._article_lens: enc_len})
+    return results[0], results[1][0]
+
+  def decode_topk(self, sess, latest_tokens, enc_top_states, dec_init_states):
+    """Return the topK results and new decoder states."""
+    feed = {
+        self._enc_top_states: enc_top_states,
+        self._dec_in_state:
+            np.squeeze(np.array(dec_init_states)),
+        self._abstracts:
+            np.transpose(np.array([latest_tokens])),
+        self._abstract_lens: np.ones([len(dec_init_states)], np.int32)}
+
+    results = sess.run(
+        [self._topk_ids, self._topk_log_probs, self._dec_out_state],
+        feed_dict=feed)
+
+    ids, probs, states = results[0], results[1], results[2]
+    new_states = [s for s in states]
+    return ids, probs, new_states
+
+  def build_graph(self):
+    self._add_placeholders()
+    self._add_seq2seq()
+    self.global_step = tf.Variable(0, name='global_step', trainable=False)
+    if self._hps.mode == 'train':
+      self._add_train_op()
+    self._summaries = tf.merge_all_summaries()

+ 136 - 0
textsum/seq2seq_lib.py

@@ -0,0 +1,136 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""seq2seq library codes copied from elsewhere for customization."""
+
+import tensorflow as tf
+
+
+# Adapted to support sampled_softmax loss function, which accepts activations
+# instead of logits.
+def sequence_loss_by_example(inputs, targets, weights, loss_function,
+                             average_across_timesteps=True, name=None):
+  """Sampled softmax loss for a sequence of inputs (per example).
+
+  Args:
+    inputs: List of 2D Tensors of shape [batch_size x hid_dim].
+    targets: List of 1D batch-sized int32 Tensors of the same length as logits.
+    weights: List of 1D batch-sized float-Tensors of the same length as logits.
+    loss_function: Sampled softmax function (inputs, labels) -> loss
+    average_across_timesteps: If set, divide the returned cost by the total
+      label weight.
+    name: Optional name for this operation, default: 'sequence_loss_by_example'.
+
+  Returns:
+    1D batch-sized float Tensor: The log-perplexity for each sequence.
+
+  Raises:
+    ValueError: If len(inputs) is different from len(targets) or len(weights).
+  """
+  if len(targets) != len(inputs) or len(weights) != len(inputs):
+    raise ValueError('Lengths of logits, weights, and targets must be the same '
+                     '%d, %d, %d.' % (len(inputs), len(weights), len(targets)))
+  with tf.op_scope(inputs + targets + weights, name,
+                   'sequence_loss_by_example'):
+    log_perp_list = []
+    for inp, target, weight in zip(inputs, targets, weights):
+      crossent = loss_function(inp, target)
+      log_perp_list.append(crossent * weight)
+    log_perps = tf.add_n(log_perp_list)
+    if average_across_timesteps:
+      total_size = tf.add_n(weights)
+      total_size += 1e-12  # Just to avoid division by 0 for all-0 weights.
+      log_perps /= total_size
+  return log_perps
+
+
+def sampled_sequence_loss(inputs, targets, weights, loss_function,
+                          average_across_timesteps=True,
+                          average_across_batch=True, name=None):
+  """Weighted cross-entropy loss for a sequence of logits, batch-collapsed.
+
+  Args:
+    inputs: List of 2D Tensors of shape [batch_size x hid_dim].
+    targets: List of 1D batch-sized int32 Tensors of the same length as inputs.
+    weights: List of 1D batch-sized float-Tensors of the same length as inputs.
+    loss_function: Sampled softmax function (inputs, labels) -> loss
+    average_across_timesteps: If set, divide the returned cost by the total
+      label weight.
+    average_across_batch: If set, divide the returned cost by the batch size.
+    name: Optional name for this operation, defaults to 'sequence_loss'.
+
+  Returns:
+    A scalar float Tensor: The average log-perplexity per symbol (weighted).
+
+  Raises:
+    ValueError: If len(inputs) is different from len(targets) or len(weights).
+  """
+  with tf.op_scope(inputs + targets + weights, name, 'sampled_sequence_loss'):
+    cost = tf.reduce_sum(sequence_loss_by_example(
+        inputs, targets, weights, loss_function,
+        average_across_timesteps=average_across_timesteps))
+    if average_across_batch:
+      batch_size = tf.shape(targets[0])[0]
+      return cost / tf.cast(batch_size, tf.float32)
+    else:
+      return cost
+
+
+def linear(args, output_size, bias, bias_start=0.0, scope=None):
+  """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.
+
+  Args:
+    args: a 2D Tensor or a list of 2D, batch x n, Tensors.
+    output_size: int, second dimension of W[i].
+    bias: boolean, whether to add a bias term or not.
+    bias_start: starting value to initialize the bias; 0 by default.
+    scope: VariableScope for the created subgraph; defaults to "Linear".
+
+  Returns:
+    A 2D Tensor with shape [batch x output_size] equal to
+    sum_i(args[i] * W[i]), where W[i]s are newly created matrices.
+
+  Raises:
+    ValueError: if some of the arguments has unspecified or wrong shape.
+  """
+  if args is None or (isinstance(args, (list, tuple)) and not args):
+    raise ValueError('`args` must be specified')
+  if not isinstance(args, (list, tuple)):
+    args = [args]
+
+  # Calculate the total size of arguments on dimension 1.
+  total_arg_size = 0
+  shapes = [a.get_shape().as_list() for a in args]
+  for shape in shapes:
+    if len(shape) != 2:
+      raise ValueError('Linear is expecting 2D arguments: %s' % str(shapes))
+    if not shape[1]:
+      raise ValueError('Linear expects shape[1] of arguments: %s' % str(shapes))
+    else:
+      total_arg_size += shape[1]
+
+  # Now the computation.
+  with tf.variable_scope(scope or 'Linear'):
+    matrix = tf.get_variable('Matrix', [total_arg_size, output_size])
+    if len(args) == 1:
+      res = tf.matmul(args[0], matrix)
+    else:
+      res = tf.matmul(tf.concat(1, args), matrix)
+    if not bias:
+      return res
+    bias_term = tf.get_variable(
+        'Bias', [output_size],
+        initializer=tf.constant_initializer(bias_start))
+  return res + bias_term