123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657 |
- # Copyright 2016 Google Inc. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """Class for variational inference."""
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import tensorflow as tf
- sg = tf.contrib.bayesflow.stochastic_graph
- distributions = tf.contrib.distributions
class VariationalInference(object):
    """Builds the evidence lower bound (ELBO) ops for variational inference.

    The class holds a probability model, a variational family and the
    observed data, and `build_graph` wires them into TensorFlow ops for
    `log_p`, `log_q` and their difference (the ELBO).
    """

    def __init__(self, model, variational, data):
        """Initializes the VariationalInference class.

        Args:
          model: the probability model. Its `log_prob` is called with the
            variational samples and the observations.
          variational: the variational family for the model. Its `sample`
            attribute and `log_prob` method are used by `build_graph`.
          data: the observations we use to fit the model. `build_graph`
            reads the entry stored under the key 'x'.
        """
        self.model = model
        self.variational = variational
        self.data = data

    def build_graph(self):
        """Builds the graph for variational inference.

        Sets the following attributes on the instance:
          log_p: `model.log_prob` evaluated at the variational samples and
            `data['x']`.
          log_q: `variational.log_prob` evaluated at the same samples.
          elbo: `log_p - log_q`, not reduced.
          scalar_elbo: `elbo` reduced along its leading axes (see below).
        """
        # NOTE(review): `sample` is read as an attribute and is not called,
        # even though the constructor docs describe it as a method. Confirm
        # that the variational object exposes its samples this way.
        q_samples = self.variational.sample

        log_p = self.model.log_prob(q_samples, self.data['x'])
        log_q = self.variational.log_prob(q_samples)
        elbo = log_p - log_q

        # The static rank is None when TensorFlow cannot infer it. Comparing
        # None with an int raises TypeError on Python 3, so guard explicitly;
        # an unknown rank takes the same branch as rank <= 1, which is what
        # the bare comparison did on Python 2.
        rank = elbo.get_shape().ndims
        if rank is not None and rank > 1:
            # first dimension is samples, second is batch_size
            per_batch_elbo = tf.reduce_mean(elbo, 0)
            self.scalar_elbo = tf.reduce_mean(per_batch_elbo, 0)
        else:
            self.scalar_elbo = tf.reduce_sum(elbo, 0)

        self.elbo = elbo
        self.log_p = log_p
        self.log_q = log_q
|