{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Load and parse data with TensorFlow\n", "\n", "A TensorFlow example to build input pipelines for loading data efficiently.\n", "\n", "\n", "- Numpy Arrays\n", "- Images\n", "- CSV file\n", "- Custom data from a Generator\n", "\n", "For more information about creating and loading TensorFlow's `TFRecords` data format, see: [tfrecords.ipynb](tfrecords.ipynb)\n", "\n", "- Author: Aymeric Damien\n", "- Project: https://github.com/aymericdamien/TensorFlow-Examples/" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from __future__ import absolute_import, division, print_function\n", "\n", "import numpy as np\n", "import random\n", "import requests\n", "import string\n", "import tarfile\n", "import tensorflow as tf" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load Numpy Arrays\n", "\n", "Build a data pipeline over numpy arrays." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create a toy dataset (even and odd numbers, with respective labels of 0 and 1).\n", "evens = np.arange(0, 100, step=2, dtype=np.int32)\n", "evens_label = np.zeros(50, dtype=np.int32)\n", "odds = np.arange(1, 100, step=2, dtype=np.int32)\n", "odds_label = np.ones(50, dtype=np.int32)\n", "# Concatenate arrays\n", "features = np.concatenate([evens, odds])\n", "labels = np.concatenate([evens_label, odds_label])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with tf.Graph().as_default():\n", " # Create TF session.\n", " sess = tf.Session()\n", " \n", " # Slice the numpy arrays (each row becoming a record).\n", " data = tf.data.Dataset.from_tensor_slices((features, labels))\n", " # Refill data indefinitely. \n", " data = data.repeat()\n", " # Shuffle data.\n", " data = data.shuffle(buffer_size=100)\n", " # Batch data (aggregate records together).\n", " data = data.batch(batch_size=4)\n", " # Prefetch batch (pre-load batch for faster consumption).\n", " data = data.prefetch(buffer_size=1)\n", " \n", " # Create an iterator over the dataset.\n", " iterator = data.make_initializable_iterator()\n", " # Initialize the iterator.\n", " sess.run(iterator.initializer)\n", "\n", " # Get next data batch.\n", " d = iterator.get_next()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[82 58 80 23] [0 0 0 1]\n", "[16 91 74 96] [0 1 0 0]\n", "[ 4 17 32 34] [0 1 0 0]\n", "[16 8 77 21] [0 0 1 1]\n", "[20 99 48 18] [0 1 0 0]\n" ] } ], "source": [ "# Display data.\n", "for i in range(5):\n", " x, y = sess.run(d)\n", " print(x, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load CSV files\n", "\n", "Build a data pipeline from features stored in a CSV file. For this example, Titanic dataset will be used as a toy dataset stored in CSV format." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Titanic Dataset\n", "\n", "\n", "\n", "survived|pclass|name|sex|age|sibsp|parch|ticket|fare\n", "--------|------|----|---|---|-----|-----|------|----\n", "1|1|\"Allen, Miss. Elisabeth Walton\"|female|29|0|0|24160|211.3375\n", "1|1|\"Allison, Master. Hudson Trevor\"|male|0.9167|1|2|113781|151.5500\n", "0|1|\"Allison, Miss. Helen Loraine\"|female|2|1|2|113781|151.5500\n", "0|1|\"Allison, Mr. Hudson Joshua Creighton\"|male|30|1|2|113781|151.5500\n", "...|...|...|...|...|...|...|...|..." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Download Titanic dataset (in csv format).\n", "d = requests.get(\"https://raw.githubusercontent.com/tflearn/tflearn.github.io/master/resources/titanic_dataset.csv\")\n", "with open(\"titanic_dataset.csv\", \"wb\") as f:\n", " f.write(d.content)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Load Titanic dataset.\n", "# Original features: survived,pclass,name,sex,age,sibsp,parch,ticket,fare\n", "# Select specific columns: survived,pclass,name,sex,age,fare\n", "column_to_use = [0, 1, 2, 3, 4, 8]\n", "record_defaults = [tf.int32, tf.int32, tf.string, tf.string, tf.float32, tf.float32]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with tf.Graph().as_default():\n", " # Create TF session.\n", " sess = tf.Session()\n", " \n", " # Load the whole dataset file, and slice each line.\n", " data = tf.data.experimental.CsvDataset(\"titanic_dataset.csv\", record_defaults, header=True, select_cols=column_to_use)\n", " # Refill data indefinitely. \n", " data = data.repeat()\n", " # Shuffle data.\n", " data = data.shuffle(buffer_size=1000)\n", " # Batch data (aggregate records together).\n", " data = data.batch(batch_size=2)\n", " # Prefetch batch (pre-load batch for faster consumption).\n", " data = data.prefetch(buffer_size=1)\n", " \n", " # Create an iterator over the dataset.\n", " iterator = data.make_initializable_iterator()\n", " # Initialize the iterator.\n", " sess.run(iterator.initializer)\n", "\n", " # Get next data batch.\n", " d = iterator.get_next()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1 0]\n", "[3 1]\n", "['Lam, Mr. Ali' 'Widener, Mr. Harry Elkins']\n", "['male' 'male']\n", "[ 0. 27.]\n", "[ 56.4958 211.5 ]\n", "\n", "[0 1]\n", "[1 1]\n", "['Baumann, Mr. John D' 'Daly, Mr. Peter Denis ']\n", "['male' 'male']\n", "[ 0. 51.]\n", "[25.925 26.55 ]\n", "\n", "[0 1]\n", "[3 1]\n", "['Assam, Mr. Ali' 'Newell, Miss. Madeleine']\n", "['male' 'female']\n", "[23. 31.]\n", "[ 7.05 113.275]\n", "\n" ] } ], "source": [ "# Display data.\n", "for i in range(3):\n", " survived, pclass, name, sex, age, fare = sess.run(d)\n", " print(survived)\n", " print(pclass)\n", " print(name)\n", " print(sex)\n", " print(age)\n", " print(fare)\n", " print(\"\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load Images\n", "\n", "Build a data pipeline by loading images from disk. For this example, Oxford Flowers dataset will be used." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Download Oxford 17 flowers dataset.\n", "d = requests.get(\"http://www.robots.ox.ac.uk/~vgg/data/flowers/17/17flowers.tgz\")\n", "with open(\"17flowers.tgz\", \"wb\") as f:\n", " f.write(d.content)\n", "# Extract archive.\n", "with tarfile.open(\"17flowers.tgz\") as t:\n", " t.extractall()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create a file to list all images path and their corresponding label.\n", "with open('jpg/dataset.csv', 'w') as f:\n", " c = 0\n", " for i in range(1360):\n", " f.write(\"jpg/image_%04i.jpg,%i\\n\" % (i+1, c))\n", " if (i+1) % 80 == 0:\n", " c += 1" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with tf.Graph().as_default():\n", " \n", " # Load Images.\n", " with open(\"jpg/dataset.csv\") as f:\n", " dataset_file = f.read().splitlines()\n", " \n", " # Create TF session.\n", " sess = tf.Session()\n", "\n", " # Load the whole dataset file, and slice each line.\n", " data = tf.data.Dataset.from_tensor_slices(dataset_file)\n", " # Refill data indefinitely.\n", " data = data.repeat()\n", " # Shuffle data.\n", " data = data.shuffle(buffer_size=1000)\n", "\n", " # Load and pre-process images.\n", " def load_image(path):\n", " # Read image from path.\n", " image = tf.io.read_file(path)\n", " # Decode the jpeg image to array [0, 255].\n", " image = tf.image.decode_jpeg(image)\n", " # Resize images to a common size of 256x256.\n", " image = tf.image.resize(image, [256, 256])\n", " # Rescale values to [-1, 1].\n", " image = 1. - image / 127.5\n", " return image\n", " # Decode each line from the dataset file.\n", " def parse_records(line):\n", " # File is in csv format: \"image_path,label_id\".\n", " # TensorFlow requires a default value, but it will never be used.\n", " image_path, image_label = tf.io.decode_csv(line, [\"\", 0])\n", " # Apply the function to load images.\n", " image = load_image(image_path)\n", " return image, image_label\n", " # Use 'map' to apply the above functions in parallel.\n", " data = data.map(parse_records, num_parallel_calls=4)\n", "\n", " # Batch data (aggregate images-array together).\n", " data = data.batch(batch_size=2)\n", " # Prefetch batch (pre-load batch for faster consumption).\n", " data = data.prefetch(buffer_size=1)\n", " \n", " # Create an iterator over the dataset.\n", " iterator = data.make_initializable_iterator()\n", " # Initialize the iterator.\n", " sess.run(iterator.initializer)\n", "\n", " # Get next data batch.\n", " d = iterator.get_next()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[[[ 0.1294117 0.05098033 0.46666664]\n", " [ 0.1368872 0.05098033 0.48909312]\n", " [ 0.0931372 0.0068627 0.46029407]\n", " ...\n", " [ 0.23480386 0.0522058 0.6102941 ]\n", " [ 0.12696075 -0.05416667 0.38063723]\n", " [-0.10024512 -0.28848052 0.10367644]]\n", "\n", " [[ 0.04120708 -0.06118262 0.36256123]\n", " [ 0.08009624 -0.02229345 0.41640145]\n", " [ 0.06797445 -0.04132879 0.41923058]\n", " ...\n", " [ 0.2495715 0.06697345 0.6251221 ]\n", " [ 0.12058818 -0.06094813 0.37577546]\n", " [-0.05184889 -0.24009418 0.16777915]]\n", "\n", " [[-0.09234071 -0.22738981 0.20484066]\n", " [-0.03100491 -0.17312062 0.2811274 ]\n", " [ 0.01051998 -0.13237214 0.3376838 ]\n", " ...\n", " [ 0.27787983 0.07494056 0.64203525]\n", " [ 0.11533964 -0.09005249 0.3869906 
]\n", " [-0.02704227 -0.23958337 0.19454747]]\n", "\n", " ...\n", "\n", " [[ 0.07913595 -0.13069856 0.29874384]\n", " [ 0.10140878 -0.09445572 0.35912937]\n", " [ 0.08869672 -0.08415675 0.41446364]\n", " ...\n", " [ 0.25821072 0.22463232 0.69197303]\n", " [ 0.31636214 0.25750512 0.79362744]\n", " [ 0.09552741 0.01709598 0.57395875]]\n", "\n", " [[ 0.09019601 -0.12156868 0.3098039 ]\n", " [ 0.17446858 -0.02271283 0.43218917]\n", " [ 0.06583172 -0.10818791 0.39230233]\n", " ...\n", " [ 0.27021956 0.23664117 0.70269513]\n", " [ 0.19560927 0.1385014 0.6740407 ]\n", " [ 0.04364848 -0.03478289 0.5220798 ]]\n", "\n", " [[ 0.02830875 -0.18345594 0.24791664]\n", " [ 0.12937105 -0.06781042 0.38709164]\n", " [ 0.01120263 -0.162817 0.33767325]\n", " ...\n", " [ 0.25989532 0.22631687 0.69237083]\n", " [ 0.1200884 0.06298059 0.5985198 ]\n", " [ 0.05961001 -0.01882136 0.53804135]]]\n", "\n", "\n", " [[[ 0.3333333 0.25490195 0.05882347]\n", " [ 0.3333333 0.25490195 0.05882347]\n", " [ 0.3340686 0.24705875 0.03039211]\n", " ...\n", " [-0.5215688 -0.4599266 -0.14632356]\n", " [-0.5100491 -0.47083342 -0.03725493]\n", " [-0.43419123 -0.39497554 0.05992639]]\n", "\n", " [[ 0.34117645 0.26274508 0.0666666 ]\n", " [ 0.35646445 0.2630821 0.0744791 ]\n", " [ 0.3632046 0.2548713 0.04384762]\n", " ...\n", " [-0.9210479 -0.84267783 -0.4540485 ]\n", " [-0.9017464 -0.8390626 -0.3507018 ]\n", " [-0.83339334 -0.7632048 -0.2534927 ]]\n", "\n", " [[ 0.3646446 0.2706495 0.06678915]\n", " [ 0.37248772 0.27837008 0.07445425]\n", " [ 0.38033658 0.27053267 0.05950326]\n", " ...\n", " [-0.94302344 -0.84222686 -0.30278325]\n", " [-0.91017747 -0.8090074 -0.18615782]\n", " [-0.83437514 -0.7402575 -0.08192408]]\n", "\n", " ...\n", "\n", " [[ 0.64705884 0.654902 0.67058825]\n", " [ 0.6318321 0.63967526 0.65536153]\n", " [ 0.63128924 0.6391324 0.65481865]\n", " ...\n", " [ 0.6313726 0.57647055 0.51372546]\n", " [ 0.6078431 0.53725487 0.4823529 ]\n", " [ 0.6078431 0.53725487 0.4823529 ]]\n", "\n", " [[ 0.654902 0.654902 0.6704657 ]\n", " [ 0.654902 0.654902 0.6704657 ]\n", " [ 0.64778835 0.64778835 0.6492474 ]\n", " ...\n", " [ 0.6392157 0.5843137 0.5215686 ]\n", " [ 0.6393325 0.56874424 0.5138422 ]\n", " [ 0.63106614 0.5604779 0.50557595]]\n", "\n", " [[ 0.654902 0.64705884 0.6313726 ]\n", " [ 0.6548728 0.64702964 0.63134336]\n", " [ 0.64705884 0.63210785 0.6377451 ]\n", " ...\n", " [ 0.63244915 0.5775472 0.5148021 ]\n", " [ 0.6698529 0.5992647 0.5443627 ]\n", " [ 0.6545358 0.5839475 0.5290455 ]]]] [5 9]\n" ] } ], "source": [ "# Display data.\n", "for i in range(1):\n", " batch_x, batch_y = sess.run(d)\n", " print(batch_x, batch_y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load data from a Generator\n", "\n", "Build a data pipeline from a custom generator. For this example, a toy generator yielding random string, vector and it is used." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create a dummy generator.\n", "def generate_features():\n", " # Function to generate a random string.\n", " def random_string(length):\n", " return ''.join(random.choice(string.ascii_letters) for m in xrange(length))\n", " # Return a random string, a random vector, and a random int.\n", " yield random_string(4), np.random.uniform(size=4), random.randint(0, 10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with tf.Graph().as_default():\n", "\n", " # Create TF session.\n", " sess = tf.Session()\n", "\n", " # Create TF dataset from the generator.\n", " data = tf.data.Dataset.from_generator(generate_features, output_types=(tf.string, tf.float32, tf.int32))\n", " # Refill data indefinitely.\n", " data = data.repeat()\n", " # Shuffle data.\n", " data = data.shuffle(buffer_size=100)\n", " # Batch data (aggregate records together).\n", " data = data.batch(batch_size=4)\n", " # Prefetch batch (pre-load batch for faster consumption).\n", " data = data.prefetch(buffer_size=1)\n", "\n", " # Create an iterator over the dataset.\n", " iterator = data.make_initializable_iterator()\n", " # Initialize the iterator.\n", " sess.run(iterator.initializer)\n", "\n", " # Get next data batch.\n", " d = iterator.get_next()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['AvCS' 'kAaI' 'QwGX' 'IWOI'] [[0.6096093 0.32192084 0.26622605 0.70250475]\n", " [0.72534287 0.7637426 0.19977213 0.74121326]\n", " [0.6930984 0.09409562 0.4063325 0.5002103 ]\n", " [0.05160935 0.59411395 0.276416 0.98264974]] [1 3 5 6]\n", "['EXjS' 'brvx' 'kwNz' 'eFOb'] [[0.34355283 0.26881003 0.70575935 0.7503411 ]\n", " [0.9584373 0.27466875 0.27802315 0.9563204 ]\n", " [0.19129485 0.07014314 0.0932724 0.20726128]\n", " [0.28744072 0.81736153 0.37507302 0.8984588 ]] [1 9 7 0]\n", "['vpSa' 'UuqW' 'xaTO' 'milw'] [[0.2942028 0.8228986 0.5793326 0.16651365]\n", " [0.28259405 0.599063 0.2922477 0.95071274]\n", " [0.23645316 0.00258607 0.06772221 0.7291911 ]\n", " [0.12861755 0.31435087 0.576638 0.7333119 ]] [3 5 8 4]\n", "['UBBb' 'MUXs' 'nLJB' 'OBGl'] [[0.2677402 0.17931737 0.02607645 0.85898155]\n", " [0.58647937 0.727203 0.13329858 0.8898983 ]\n", " [0.13872191 0.47390288 0.7061665 0.08478573]\n", " [0.3786016 0.22002582 0.91989636 0.45837343]] [ 5 8 0 10]\n", "['kiiz' 'bQYG' 'WpUU' 'AuIY'] [[0.74781317 0.13744462 0.9236441 0.63558507]\n", " [0.23649399 0.35303807 0.0951511 0.03541444]\n", " [0.33599988 0.6906629 0.97166294 0.55850506]\n", " [0.90997607 0.5545979 0.43635726 0.9127501 ]] [8 1 4 4]\n" ] } ], "source": [ "# Display data.\n", "for i in range(5):\n", " batch_str, batch_vector, batch_int = sess.run(d)\n", " print(batch_str, batch_vector, batch_int)" ] } ], "metadata": { "kernelspec": { "display_name": "tf1", "language": "python", "name": "tf1" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.15+" } }, "nbformat": 4, "nbformat_minor": 2 }