# Copyright (c) 2012, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS "AS IS" AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
import random
from collections import Counter
from datetime import datetime

import cv2
import numpy as np
import pandas as pd
import tensorflow as tf
from scipy import interpolate


def dummy(name, label, filenames, labels, i):
    """No-op augmentation hook matching the augment_fn signature used by
    load_dataset; leaves the dataset unchanged and returns the counter as-is."""
    return i


def load_image(name, interpolation=cv2.INTER_AREA):
    """Load an image as RGB, resize it to 256x256, and return a random 232x232 crop."""
    img = cv2.imread(name, cv2.IMREAD_COLOR)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    resized = cv2.resize(img, (256, 256), interpolation=interpolation)
    # Random crop offset in [0, 24) on each axis; 232 + 24 = 256
    start_pt = np.random.randint(24, size=2)
    end_pt = start_pt + [232, 232]
    return resized[start_pt[0]:end_pt[0], start_pt[1]:end_pt[1]]
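
# Illustrative usage of load_image ('sample.jpg' is a placeholder path, not a
# file shipped with this code):
#     crop = load_image('sample.jpg')                    # RGB array, shape (232, 232, 3)
#     crop = load_image('sample.jpg', cv2.INTER_LINEAR)  # alternate interpolation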


def load_dataset(augment_fn=dummy):
    """Walk the Dataset/tcdat/ image tree, match each infrared image to the storm
    track in atlantic_storms.csv, and label it with an intensity class (0-7)
    spline-interpolated from the maximum sustained wind in knots."""
    filenames = []
    labels = []
    i = 0
    # Read the storm-track CSV using pandas
    df = pd.read_csv('atlantic_storms.csv')
    base_dir = 'Dataset/tcdat/'
    # Output directory for augmented images; create it if it does not exist
    os.makedirs('Dataset/Aug', exist_ok=True)
    for j in os.listdir(base_dir):  # year directories, e.g. 'tc12' -> 2012
        for k in os.listdir(base_dir + j):  # basin directories
            for l in os.listdir(base_dir + j + '/' + k):  # storm directories
                print('.', end='')
                start_year = '20' + j[2:] + '-01-01'
                end_year = '20' + j[2:] + '-12-31'
                cyc_name = l[4:]
                # Select this storm's track rows for the year
                mask = ((df['date'] > start_year)
                        & (df['date'] <= end_year)
                        & (df['name'] == cyc_name))
                cyc_pd = df.loc[mask]
                first = datetime.strptime(cyc_pd['date'].iloc[0], "%Y-%m-%d %H:%M:%S")
                last = datetime.strptime(cyc_pd['date'].iloc[-1], "%Y-%m-%d %H:%M:%S")
                # Build (seconds since first fix, wind speed) samples for interpolation
                text_time = []
                text_vel = []
                for q in range(len(cyc_pd['date'])):
                    text_vel.append(cyc_pd['maximum_sustained_wind_knots'].iloc[q])
                    text_time.append((datetime.strptime(cyc_pd['date'].iloc[q], "%Y-%m-%d %H:%M:%S") - first).total_seconds())
                func = interpolate.splrep(text_time, text_vel)
                img_dir = base_dir + j + '/' + k + '/' + l + '/ir/geo/1km'
                e = os.listdir(img_dir)
                e.sort()
                for m in e:
                    try:
                        # Image filenames begin with a YYYYMMDD.HHMM timestamp
                        time = datetime.strptime(m[:13], "%Y%m%d.%H%M")
                        name = img_dir + '/' + m
                        if first < time < last:
                            # Interpolated wind speed (knots) at the image timestamp
                            val = int(interpolate.splev((time - first).total_seconds(), func))
                            filenames.append(name)
                            # Bin the wind speed into 8 intensity classes
                            if val <= 20:
                                labels.append(0)
                            elif val <= 33:
                                labels.append(1)
                            elif val <= 63:
                                labels.append(2)
                            elif val <= 82:
                                labels.append(3)
                            elif val <= 95:
                                labels.append(4)
                            elif val <= 112:
                                labels.append(5)
                            elif val <= 136:
                                labels.append(6)
                            else:
                                labels.append(7)
                            i = augment_fn(name, labels[-1], filenames, labels, i)
                    except Exception:
                        # Skip files whose names don't parse as timestamps
                        pass
    print('')
    print(len(filenames))
    # Shuffle the data: zip images with their labels so pairs stay aligned
    c = list(zip(filenames, labels))
    random.shuffle(c)
    # Unzip after shuffling
    filenames, labels = zip(*c)
    return list(filenames), list(labels)
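

# load_dataset's augment_fn hook receives (name, label, filenames, labels, i) and
# must return the updated counter i. As a sketch of that contract (this flip
# augmentation and its output naming are assumptions, not part of the original
# pipeline), a horizontal-flip augmenter could look like the following; it writes
# into the Dataset/Aug/ directory that load_dataset creates.
def flip_augment(name, label, filenames, labels, i):
    img = load_image(name)  # RGB 232x232 crop
    out = 'Dataset/Aug/' + str(i) + '.jpg'
    # load_image returns RGB, so convert back to BGR before writing with OpenCV
    cv2.imwrite(out, cv2.cvtColor(cv2.flip(img, 1), cv2.COLOR_RGB2BGR))
    filenames.append(out)
    labels.append(label)
    return i + 1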


# Build a validation set holding out `val` (default 10%) of the data, drawn
# equally from each of the 8 classes (1.25% of the total per class).
def make_test_set(filenames, labels, val=0.1):
    classes = 8
    val_filenames = []
    val_labels = []
    # Per-class quota of samples to move into the validation set
    new = [int(val * len(filenames) / classes)] * classes
    print(new)
    keep_filenames = []
    keep_labels = []
    for fname, label in zip(filenames, labels):
        if new[label] > 0:
            val_filenames.append(fname)
            val_labels.append(label)
            new[label] -= 1
        else:
            keep_filenames.append(fname)
            keep_labels.append(label)
    # Remove the held-out samples from the caller's lists in place
    filenames[:] = keep_filenames
    labels[:] = keep_labels
    # Shuffle the data: zip images with their labels so pairs stay aligned
    c = list(zip(val_filenames, val_labels))
    random.shuffle(c)
    # Unzip after shuffling
    val_filenames, val_labels = zip(*c)
    # Report the class balance of the remaining training data
    print(Counter(labels))
    return list(val_filenames), list(val_labels)


def parse_function(filename, label):
    image_string = tf.io.read_file(filename)
    # Don't use tf.image.decode_image, or the output shape will be undefined
    image = tf.image.decode_jpeg(image_string, channels=3)
    # This will convert to float values in [0, 1]
    image = tf.image.convert_image_dtype(image, tf.float32)
    # Resize image
    image = tf.image.resize(image, [232, 232])
    return image, label


def make_dataset(train_in, test_in, val_in):
    """Build tf.data pipelines from (filenames, labels, batch_size) triples."""
    train = tf.data.Dataset.from_tensor_slices((train_in[0], train_in[1]))
    train = train.shuffle(len(train_in[0]))
    train = train.map(parse_function, num_parallel_calls=8)
    train = train.batch(train_in[2])
    train = train.prefetch(1)
    test = tf.data.Dataset.from_tensor_slices((test_in[0], test_in[1]))
    test = test.shuffle(len(test_in[0]))
    test = test.map(parse_function, num_parallel_calls=8)
    test = test.batch(test_in[2])
    test = test.prefetch(1)
    # The validation pipeline is not shuffled, so evaluation order is deterministic
    val = tf.data.Dataset.from_tensor_slices((val_in[0], val_in[1]))
    val = val.map(parse_function, num_parallel_calls=8)
    val = val.batch(val_in[2])
    val = val.prefetch(1)
    return train, test, val
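

# A minimal end-to-end sketch of how these helpers compose, assuming the
# atlantic_storms.csv file and Dataset/tcdat/ tree described above are present.
# The batch size of 32 and the 10% test/validation splits are illustrative choices.
if __name__ == '__main__':
    filenames, labels = load_dataset()
    # make_test_set removes the held-out samples from filenames/labels in place,
    # so calling it twice carves out disjoint validation and test sets.
    val_filenames, val_labels = make_test_set(filenames, labels, val=0.1)
    test_filenames, test_labels = make_test_set(filenames, labels, val=0.1)
    train_ds, test_ds, val_ds = make_dataset(
        (filenames, labels, 32),
        (test_filenames, test_labels, 32),
        (val_filenames, val_labels, 32),
    )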