taxi_preprocessing_example.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. """Download data needed for the examples"""
  2. from __future__ import print_function
  3. if __name__ == "__main__":
  4. from os import path, makedirs, remove
  5. from download_sample_data import bar as progressbar
  6. import pandas as pd
  7. import numpy as np
  8. import sys
  9. try:
  10. import requests
  11. except ImportError:
  12. print('Download script required requests package: conda install requests')
  13. sys.exit(1)
  def _download_dataset(url):
      """Stream ``url`` into a file in the current working directory.

      The output file is named after the last path component of ``url``.
      Progress is reported through ``progressbar`` when the server sends a
      ``Content-Length`` header; without that header the body is still
      written, just without a progress display.

      Parameters
      ----------
      url : str
          Address of the file to download.

      Raises
      ------
      requests.exceptions.HTTPError
          If the server answers with a 4xx/5xx status, so that an error page
          is never saved as if it were the dataset.
      """
      r = requests.get(url, stream=True)
      try:
          # Check the status before creating the output file.
          r.raise_for_status()

          output_path = path.split(url)[1]
          content_length = r.headers.get('content-length')

          with open(output_path, 'wb') as f:
              chunks = r.iter_content(chunk_size=1024)
              if content_length is not None:
                  # Floor division keeps the expected number of 1 KB chunks an
                  # integer on both Python 2 and Python 3.
                  expected_chunks = (int(content_length) // 1024) + 1
                  chunks = progressbar(chunks, expected_size=expected_chunks)
              for chunk in chunks:
                  # Empty chunks carry no data and are skipped.
                  if chunk:
                      f.write(chunk)
              # No per-chunk flush: leaving the ``with`` block flushes and
              # closes the file.
      finally:
          # Release the streamed connection whether or not the download worked.
          r.close()
  # Downloaded/processed files are kept in a ``data`` folder that sits next to
  # this script; the folder is created when it does not exist yet.
  examples_dir = path.split(path.realpath(__file__))[0]
  data_dir = path.join(examples_dir, 'data')
  if not path.exists(data_dir):
      makedirs(data_dir)
  27. # Taxi data
  def latlng_to_meters(df, lat_name, lng_name):
      """Overwrite two degree-valued columns of ``df`` with Mercator coordinates.

      The column ``lng_name`` is replaced by the projected x value and the
      column ``lat_name`` by the projected y value, using a spherical Mercator
      projection with a sphere radius of 6378137. ``df`` is modified in place
      and nothing is returned.

      Parameters
      ----------
      df : pandas.DataFrame
          Frame holding the two coordinate columns.
      lat_name : str
          Name of the latitude column (degrees).
      lng_name : str
          Name of the longitude column (degrees).
      """
      # Half the circumference of the sphere: the projected x value at 180 deg.
      half_circumference = 2 * np.pi * 6378137 / 2.0

      # Both projected columns are computed before either one is written back.
      easting = df[lng_name] * half_circumference / 180.0
      half_angle = (90 + df[lat_name]) * np.pi / 360.0
      northing = (np.log(np.tan(half_angle)) / (np.pi / 180.0)
                  * half_circumference / 180.0)

      df.loc[:, lng_name] = easting
      df.loc[:, lat_name] = northing
  # Build data/nyc_taxi.csv once: download the January 2015 yellow-cab trips,
  # keep only trips whose pickup and dropoff both fall inside a fixed
  # longitude/latitude box, reproject the coordinates, and save the result.
  taxi_path = path.join(data_dir, 'nyc_taxi.csv')
  if not path.exists(taxi_path):
      print("Downloading Taxi Data...")
      url = ('https://storage.googleapis.com/tlc-trip-data/2015/'
             'yellow_tripdata_2015-01.csv')
      # _download_dataset saves into the current working directory under the
      # URL's last path component; derive that name once instead of repeating
      # the literal.
      raw_csv = path.split(url)[1]
      try:
          _download_dataset(url)
          df = pd.read_csv(raw_csv)

          print('Filtering Taxi Data')
          # Strict inequalities: rows exactly on the box edge are dropped.
          inside_box = ((df.pickup_longitude < -73.75) &
                        (df.pickup_longitude > -74.15) &
                        (df.dropoff_longitude < -73.75) &
                        (df.dropoff_longitude > -74.15) &
                        (df.pickup_latitude > 40.68) &
                        (df.pickup_latitude < 40.84) &
                        (df.dropoff_latitude > 40.68) &
                        (df.dropoff_latitude < 40.84))
          # Copy so the in-place reprojection below works on an independent
          # frame rather than a slice of the full download.
          df = df.loc[inside_box].copy()

          print('Reprojecting Taxi Data')
          latlng_to_meters(df, 'pickup_latitude', 'pickup_longitude')
          latlng_to_meters(df, 'dropoff_latitude', 'dropoff_longitude')
          # The columns now hold projected values, so rename them to x/y.
          df.rename(columns={'pickup_longitude': 'pickup_x', 'dropoff_longitude': 'dropoff_x',
                             'pickup_latitude': 'pickup_y', 'dropoff_latitude': 'dropoff_y'},
                    inplace=True)
          df.to_csv(taxi_path, index=False)
      finally:
          # Remove the raw download even when a step above fails; guard the
          # call so a failed download does not mask the original error.
          if path.exists(raw_csv):
              remove(raw_csv)

  print("\nAll data downloaded.")