{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Create and Load TFRecords\n", "\n", "A simple TensorFlow example that converts a dataset to TFRecord format and then reads it back.\n", "\n", "In this example, the Titanic dataset (in CSV format) serves as a toy dataset: all of its features are converted to TFRecord format, and an input pipeline is then built that can be used for training models.\n", "\n", "- Author: Aymeric Damien\n", "- Project: https://github.com/aymericdamien/TensorFlow-Examples/" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Titanic Dataset\n", "\n", "The Titanic dataset is a popular ML dataset that lists the passengers onboard the Titanic, along with features such as age, sex, and class (1st, 2nd, 3rd), and whether each passenger survived the disaster.\n", "\n", "It shows that, although some luck was involved in surviving the sinking, some groups of people were more likely to survive than others, such as women, children, and the upper class.\n", "\n", "#### Overview\n", "survived|pclass|name|sex|age|sibsp|parch|ticket|fare\n", "--------|------|----|---|---|-----|-----|------|----\n", "1|1|\"Allen, Miss. Elisabeth Walton\"|female|29|0|0|24160|211.3375\n", "1|1|\"Allison, Master. Hudson Trevor\"|male|0.9167|1|2|113781|151.5500\n", "0|1|\"Allison, Miss. Helen Loraine\"|female|2|1|2|113781|151.5500\n", "0|1|\"Allison, Mr. 
Hudson Joshua Creighton\"|male|30|1|2|113781|151.5500\n", "...|...|...|...|...|...|...|...|...\n", "\n", "\n", "#### Variable Descriptions\n", "```\n", "survived        Survived\n", "                (0 = No; 1 = Yes)\n", "pclass          Passenger Class\n", "                (1 = 1st; 2 = 2nd; 3 = 3rd)\n", "name            Name\n", "sex             Sex\n", "age             Age\n", "sibsp           Number of Siblings/Spouses Aboard\n", "parch           Number of Parents/Children Aboard\n", "ticket          Ticket Number\n", "fare            Passenger Fare\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from __future__ import absolute_import, division, print_function\n", "\n", "import csv\n", "import requests\n", "import tensorflow as tf" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Download Titanic dataset (in csv format).\n", "d = requests.get(\"https://raw.githubusercontent.com/tflearn/tflearn.github.io/master/resources/titanic_dataset.csv\")\n", "with open(\"titanic_dataset.csv\", \"wb\") as f:\n", "    f.write(d.content)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create TFRecords" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Generate Integer Features.\n", "def build_int64_feature(data):\n", "    return tf.train.Feature(int64_list=tf.train.Int64List(value=[data]))\n", "\n", "# Generate Float Features.\n", "def build_float_feature(data):\n", "    return tf.train.Feature(float_list=tf.train.FloatList(value=[data]))\n", "\n", "# Generate String Features.\n", "# BytesList expects bytes, so text strings (Python 3) are encoded as UTF-8 first.\n", "def build_string_feature(data):\n", "    if not isinstance(data, bytes):\n", "        data = data.encode('utf-8')\n", "    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[data]))\n", "\n", "# Generate a TF `Example`, parsing all features of the dataset.\n", "def convert_to_tfexample(survived, pclass, name, sex, age, sibsp, parch, ticket, fare):\n", "    return tf.train.Example(\n", "        features=tf.train.Features(\n", "            feature={\n", "                'survived': build_int64_feature(survived),\n", "                'pclass': build_int64_feature(pclass),\n", "                
'name': build_string_feature(name),\n", "                'sex': build_string_feature(sex),\n", "                'age': build_float_feature(age),\n", "                'sibsp': build_int64_feature(sibsp),\n", "                'parch': build_int64_feature(parch),\n", "                'ticket': build_string_feature(ticket),\n", "                'fare': build_float_feature(fare),\n", "            })\n", "    )" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Open dataset file.\n", "with open(\"titanic_dataset.csv\") as f:\n", "    # Output TFRecord file.\n", "    with tf.io.TFRecordWriter(\"titanic_dataset.tfrecord\") as w:\n", "        # Generate a TF Example for each row in our dataset.\n", "        # The CSV reader reads and parses all rows.\n", "        reader = csv.reader(f, skipinitialspace=True)\n", "        for i, record in enumerate(reader):\n", "            # Skip header.\n", "            if i == 0:\n", "                continue\n", "            survived, pclass, name, sex, age, sibsp, parch, ticket, fare = record\n", "            # Convert each CSV row to a TF Example using the above functions.\n", "            example = convert_to_tfexample(int(survived), int(pclass), name, sex, float(age), int(sibsp), int(parch), ticket, float(fare))\n", "            # Serialize each TF Example to a string, and write it to the TFRecord file.\n", "            w.write(example.SerializeToString())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load TFRecords" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Build features template, with types.\n", "features = {\n", "    'survived': tf.io.FixedLenFeature([], tf.int64),\n", "    'pclass': tf.io.FixedLenFeature([], tf.int64),\n", "    'name': tf.io.FixedLenFeature([], tf.string),\n", "    'sex': tf.io.FixedLenFeature([], tf.string),\n", "    'age': tf.io.FixedLenFeature([], tf.float32),\n", "    'sibsp': tf.io.FixedLenFeature([], tf.int64),\n", "    'parch': tf.io.FixedLenFeature([], tf.int64),\n", "    'ticket': tf.io.FixedLenFeature([], tf.string),\n", "    'fare': tf.io.FixedLenFeature([], tf.float32),\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": 
{}, "outputs": [], "source": [ "# Create TensorFlow session.\n", "sess = tf.Session()\n", "\n", "# Load TFRecord data.\n", "filenames = [\"titanic_dataset.tfrecord\"]\n", "data = tf.data.TFRecordDataset(filenames)\n", "\n", "# Parse features, using the above template.\n", "def parse_record(record):\n", " return tf.io.parse_single_example(record, features=features)\n", "# Apply the parsing to each record from the dataset.\n", "data = data.map(parse_record)\n", "\n", "# Refill data indefinitely.\n", "data = data.repeat()\n", "# Shuffle data.\n", "data = data.shuffle(buffer_size=1000)\n", "# Batch data (aggregate records together).\n", "data = data.batch(batch_size=4)\n", "# Prefetch batch (pre-load batch for faster consumption).\n", "data = data.prefetch(buffer_size=1)\n", "\n", "# Create an iterator over the dataset.\n", "iterator = data.make_initializable_iterator()\n", "# Initialize the iterator.\n", "sess.run(iterator.initializer)\n", "\n", "# Get next data batch.\n", "x = iterator.get_next()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'fare': array([ 35.5 , 73.5 , 133.65 , 19.2583], dtype=float32), 'name': array(['Sloper, Mr. William Thompson', 'Davies, Mr. Charles Henry',\n", " 'Frauenthal, Dr. Henry William', 'Baclini, Miss. Marie Catherine'],\n", " dtype=object), 'age': array([28., 18., 50., 5.], dtype=float32), 'parch': array([0, 0, 0, 1]), 'pclass': array([1, 2, 1, 3]), 'sex': array(['male', 'male', 'male', 'female'], dtype=object), 'survived': array([1, 0, 1, 1]), 'sibsp': array([0, 0, 2, 2]), 'ticket': array(['113788', 'S.O.C. 14879', 'PC 17611', '2666'], dtype=object)}\n", "\n", "{'fare': array([ 18.75 , 106.425, 78.85 , 90. ], dtype=float32), 'name': array(['Richards, Mrs. Sidney (Emily Hocking)', 'LeRoy, Miss. Bertha',\n", " 'Cavendish, Mrs. Tyrell William (Julia Florence Siegel)',\n", " 'Hoyt, Mrs. 
Frederick Maxfield (Jane Anne Forby)'], dtype=object), 'age': array([24., 30., 76., 35.], dtype=float32), 'parch': array([3, 0, 0, 0]), 'pclass': array([2, 1, 1, 1]), 'sex': array(['female', 'female', 'female', 'female'], dtype=object), 'survived': array([1, 1, 1, 1]), 'sibsp': array([2, 0, 1, 1]), 'ticket': array(['29106', 'PC 17761', '19877', '19943'], dtype=object)}\n", "\n", "{'fare': array([19.9667, 15.5 , 15.0458, 66.6 ], dtype=float32), 'name': array(['Hagland, Mr. Konrad Mathias Reiersen', 'Lennon, Miss. Mary',\n", " 'Richard, Mr. Emile', 'Pears, Mr. Thomas Clinton'], dtype=object), 'age': array([ 0., 0., 23., 29.], dtype=float32), 'parch': array([0, 0, 0, 0]), 'pclass': array([3, 3, 2, 1]), 'sex': array(['male', 'female', 'male', 'male'], dtype=object), 'survived': array([0, 0, 0, 0]), 'sibsp': array([1, 1, 0, 1]), 'ticket': array(['65304', '370371', 'SC/PARIS 2133', '113776'], dtype=object)}\n", "\n" ] } ], "source": [ "# Dequeue data and display.\n", "for i in range(3):\n", " print(sess.run(x))\n", " print(\"\")" ] } ], "metadata": { "kernelspec": { "display_name": "tf1", "language": "python", "name": "tf1" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.15+" } }, "nbformat": 4, "nbformat_minor": 2 }