Parcourir la source

Variational Autoencoder generate() function fixed (z fed in rather than z_mean)

Joshua Howard il y a 8 ans
Parent
commit
dec7c89f5b
1 fichiers modifiés avec 2 ajouts et 3 suppressions
  1. 2 3
      autoencoder/autoencoder_models/VariationalAutoencoder.py

+ 2 - 3
autoencoder/autoencoder_models/VariationalAutoencoder.py

@@ -1,5 +1,4 @@
 import tensorflow as tf
-import numpy as np
 
 class VariationalAutoencoder(object):
 
@@ -57,8 +56,8 @@ class VariationalAutoencoder(object):
 
     def generate(self, hidden = None):
         if hidden is None:
-            hidden = np.random.normal(size=self.weights["b1"])
-        return self.sess.run(self.reconstruction, feed_dict={self.z_mean: hidden})
+            hidden = self.sess.run(tf.random_normal([1, self.n_hidden]))
+        return self.sess.run(self.reconstruction, feed_dict={self.z: hidden})
 
     def reconstruct(self, X):
         return self.sess.run(self.reconstruction, feed_dict={self.x: X})