# Confusion Matrix Numbers
# One row per classification threshold; each row's four counts sum to 100.
results = pd.DataFrame({'threshold': [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
                        'tp': [50, 48, 47, 45, 44, 42, 36, 30, 20, 12, 0],
                        'fp': [50, 47, 40, 31, 23, 16, 12, 11, 4, 3, 0],
                        'tn': [0, 3, 9, 16, 22, 29, 34, 38, 43, 45, 50],
                        'fn': [0, 2, 4, 8, 11, 13, 18, 21, 33, 40, 50]
                       })


# Calculate Precision, Recall, F1, TPR, FPR
def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def calculate_metrics(results, n_patients=100):
    """Compute recall, precision, F1, TPR and FPR for each threshold.

    Parameters
    ----------
    results : pandas.DataFrame
        Must contain the columns 'threshold', 'tp', 'fp', 'tn' and 'fn'.
        Columns are looked up by name, so their order does not matter.
    n_patients : int, optional
        Total that the four confusion-matrix counts of every row must
        add up to. Defaults to 100.

    Returns
    -------
    pandas.DataFrame
        Indexed by 'threshold', with float columns
        ['recall', 'precision', 'f1', 'tpr', 'fpr'].
        Any ratio whose denominator is zero is reported as 0.

    Raises
    ------
    AssertionError
        If tp + fp + tn + fn of a row differs from ``n_patients``.
    KeyError
        If one of the required columns is missing from ``results``.
    """
    count_columns = ['threshold', 'tp', 'fp', 'tn', 'fn']
    metric_columns = ['recall', 'precision', 'f1', 'tpr', 'fpr']

    records = []
    # Select by column name before iterating: the frame's column order
    # depends on how it was built, so positional unpacking of a raw row
    # would silently mix up the counts.
    rows = results[count_columns].itertuples(index=False, name=None)
    for threshold, tp, fp, tn, fn in rows:
        assert tp + fp + tn + fn == n_patients, \
            'Patients must add up to {}'.format(n_patients)

        recall = _safe_ratio(tp, tp + fn)
        precision = _safe_ratio(tp, tp + fp)
        # Precision and recall are non-negative, so their sum is zero only
        # when both are zero; F1 is defined as 0 in that case.
        f1 = _safe_ratio(2 * precision * recall, precision + recall)

        records.append({
            'threshold': threshold,
            'recall': recall,
            'precision': precision,
            'f1': f1,
            'tpr': recall,  # true positive rate is the same ratio as recall
            'fpr': _safe_ratio(fp, fp + tn),
        })

    # Build the frame once from all records; passing `columns` keeps the
    # expected layout even when `results` has no rows.
    roc = pd.DataFrame(records, columns=['threshold'] + metric_columns)
    return roc.set_index('threshold')
\n", " | threshold | \n", "recall | \n", "precision | \n", "f1 | \n", "tpr | \n", "fpr | \n", "
---|---|---|---|---|---|---|
0 | \n", "0.0 | \n", "1 | \n", "0.5 | \n", "0.666667 | \n", "1 | \n", "1 | \n", "
1 | \n", "0.1 | \n", "0.96 | \n", "0.505263 | \n", "0.662069 | \n", "0.96 | \n", "0.94 | \n", "
2 | \n", "0.2 | \n", "0.921569 | \n", "0.54023 | \n", "0.681159 | \n", "0.921569 | \n", "0.816327 | \n", "
3 | \n", "0.3 | \n", "0.849057 | \n", "0.592105 | \n", "0.697674 | \n", "0.849057 | \n", "0.659574 | \n", "
4 | \n", "0.4 | \n", "0.8 | \n", "0.656716 | \n", "0.721311 | \n", "0.8 | \n", "0.511111 | \n", "
5 | \n", "0.5 | \n", "0.763636 | \n", "0.724138 | \n", "0.743363 | \n", "0.763636 | \n", "0.355556 | \n", "
6 | \n", "0.6 | \n", "0.666667 | \n", "0.75 | \n", "0.705882 | \n", "0.666667 | \n", "0.26087 | \n", "
7 | \n", "0.7 | \n", "0.588235 | \n", "0.731707 | \n", "0.652174 | \n", "0.588235 | \n", "0.22449 | \n", "
8 | \n", "0.8 | \n", "0.377358 | \n", "0.833333 | \n", "0.519481 | \n", "0.377358 | \n", "0.0851064 | \n", "
9 | \n", "0.9 | \n", "0.230769 | \n", "0.8 | \n", "0.358209 | \n", "0.230769 | \n", "0.0625 | \n", "
10 | \n", "1.0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "