瀏覽代碼

Merge pull request #355 from mfigurnov/refactor-precision

Refactoring of precision calculation.
Jon Shlens 9 年之前
父節點
當前提交
cdaa9b23d2
共有 1 個文件被更改,包括 5 次插入15 次删除
  1. 5 15
      resnet/resnet_main.py

+ 5 - 15
resnet/resnet_main.py

@@ -61,9 +61,6 @@ def train(hps):
   sess = sv.prepare_or_wait_for_session()
 
   step = 0
-  total_prediction = 0
-  correct_prediction = 0
-  precision = 0.0
   lrn_rate = 0.1
 
   while not sv.should_stop():
@@ -81,14 +78,9 @@ def train(hps):
     else:
       lrn_rate = 0.0001
 
-    predictions = np.argmax(predictions, axis=1)
     truth = np.argmax(truth, axis=1)
-    for (t, p) in zip(truth, predictions):
-      if t == p:
-        correct_prediction += 1
-      total_prediction += 1
-    precision = float(correct_prediction) / total_prediction
-    correct_prediction = total_prediction = 0
+    predictions = np.argmax(predictions, axis=1)
+    precision = np.mean(truth == predictions)
 
     step += 1
     if step % 100 == 0:
@@ -135,12 +127,10 @@ def evaluate(hps):
           [model.summaries, model.cost, model.predictions,
            model.labels, model.global_step])
 
-      best_predictions = np.argmax(predictions, axis=1)
       truth = np.argmax(truth, axis=1)
-      for (t, p) in zip(truth, best_predictions):
-        if t == p:
-          correct_prediction += 1
-        total_prediction += 1
+      predictions = np.argmax(predictions, axis=1)
+      correct_prediction += np.sum(truth == predictions)
+      total_prediction += predictions.shape[0]
 
     precision = 1.0 * correct_prediction / total_prediction
     best_precision = max(precision, best_precision)