|
@@ -109,7 +109,7 @@
|
|
|
" # Clip prediction values to avoid log(0) error.\n",
|
|
|
" y_pred = tf.clip_by_value(y_pred, 1e-9, 1.)\n",
|
|
|
" # Compute cross-entropy.\n",
|
|
|
- " return tf.reduce_mean(-tf.reduce_sum(y_true * tf.math.log(y_pred)))\n",
|
|
|
+ " return tf.reduce_mean(-tf.reduce_sum(y_true * tf.math.log(y_pred),1))\n",
|
|
|
"\n",
|
|
|
"# Accuracy metric.\n",
|
|
|
"def accuracy(y_pred, y_true):\n",
|