process_images_grp_normalization_wli.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. import sys
  2. import time
  3. from matplotlib import cm
  4. from tqdm import tqdm
  5. from skimage.filters import threshold_otsu
  6. from keras.models import load_model
  7. import numpy as np
  8. import pandas as pd
  9. import matplotlib.pyplot as plt
  10. import os.path as osp
  11. import openslide
  12. from pathlib import Path
  13. from skimage.filters import threshold_otsu
  14. import glob
  15. import math
  16. # before importing HDFStore, make sure 'tables' is installed by pip3 install tables
  17. from pandas import HDFStore
  18. from openslide.deepzoom import DeepZoomGenerator
  19. from sklearn.model_selection import StratifiedShuffleSplit
  20. import cv2
  21. from keras.utils.np_utils import to_categorical
  22. import os.path as osp
  23. import os
  24. import openslide
  25. from pathlib import Path
  26. from keras.models import Sequential
  27. from keras.layers import Lambda, Dropout
  28. from keras.layers.convolutional import Convolution2D, Conv2DTranspose
  29. from keras.layers.pooling import MaxPooling2D
  30. from keras.models import model_from_json
  31. import numpy as np
  32. import sys
  33. import skimage.io as io
  34. import skimage.transform as trans
  35. import numpy as np
  36. from keras.models import *
  37. from keras.layers import *
  38. from keras.optimizers import *
  39. from keras.callbacks import ModelCheckpoint, LearningRateScheduler
  40. from keras import backend as keras
  41. import re
  42. import staintools
  43. #############################################
  44. import h5py
  45. from keras.utils import HDF5Matrix
  46. import stain_utils as utils
  47. import stainNorm_Reinhard
  48. import stainNorm_Macenko
  49. import stainNorm_Vahadane
  50. from datetime import datetime
  51. cores = int(os.environ['NSLOTS'])
  52. keras.set_session(keras.tf.Session(config=keras.tf.ConfigProto(intra_op_parallelism_threads=cores, inter_op_parallelism_threads=cores)))
  53. pred_size = int(os.environ['PRED_SIZE'])
  54. stride = 16
  55. mcn = os.environ['MCN']
  56. model = load_model(mcn, compile=False)
  57. # '/home/weizhe.li/Training/googlenetmainmodel1119HNM-02-0.92.hdf5')
  58. # '/home/weizhe.li/Training/googlenetmainmodel1119HNM-02-0.92.hdf5', compile=False)
  59. # '/home/weizhe.li/Training/HNM_models/no_noise_no_norm/googlenetv1_no_noise_no_norm_0210_hnm_transfer_learn_02.10.20_09:31_original_256_patches-03-0.91.hdf5', compile=False)
  60. # /home/weizhe.li/Training/googlenetmainmodel1119HNM-02-0.92.hdf5
  61. def main():
  62. ''' Four command line arguments:
  63. - $DIR Directory where HDF5 is located
  64. - $HDF5_FILE HDF5 file name, like test_001.h5
  65. - $BASENAME subgroup suffix, like 1, 2, ...
  66. - $HEATMAP_DIR heatmap directory name
  67. '''
  68. # print command line arguments
  69. for arg in sys.argv[1:]:
  70. print (arg)
  71. print (os.environ['HOSTNAME'])
  72. print (os.environ['SGE_TASK_ID'])
  73. x = int(os.environ['SGE_TASK_ID'])
  74. print ("x = ", x)
  75. dir = sys.argv[1]
  76. hdf5_file = sys.argv[2]
  77. grp_suffix = sys.argv[3]
  78. heatmap_dir = sys.argv[4]
  79. print ("dir = " + dir)
  80. print ("hdf5_file = " + hdf5_file)
  81. print ("grp_suffix = " + grp_suffix)
  82. print ("heatmap_dir = " + heatmap_dir)
  83. # patches, coords = [], []
  84. start_time = time.time()
  85. # patches, coords = get_patches( dir,
  86. get_patches( dir, hdf5_file, grp_suffix, heatmap_dir, verbose=True)
  87. print("--- %s seconds ---" % (time.time() - start_time))
  88. # end of main ()
  89. ###########################################################################
  90. # HDF5-specific helper functions #
  91. ###########################################################################
  92. def get_patches(db_location, file_name, grp_suffix, heatmap_dir, verbose=False):
  93. """ Loads the numpy patches from HDF5 files.
  94. """
  95. patches, coords = [], []
  96. # Now load the images from H5 file.
  97. file = h5py.File(db_location + "/" + file_name,'r+')
  98. grp='t'+grp_suffix
  99. # dataset = file['/' + ds]
  100. group = file['/' + grp]
  101. for key, value in group.items():
  102. if key == 'img':
  103. dataset=value
  104. if key == 'coord':
  105. dataset2=value
  106. new_patches = np.array(dataset).astype('uint8')
  107. # for patch in new_patches:
  108. # patches.append(patch)
  109. print ("COLOR_NORM on line # 133 is: ", color_norm)
  110. output_preds_final_grp = []
  111. for patch in new_patches:
  112. ################################################color normalization##############################
  113. if color_norm:
  114. patch_normalized = color_norm_pred(patch, fit, log_file, current_time)
  115. else:
  116. patch_normalized = patch
  117. output_preds_final = patch_pred_collect_from_slide_window(pred_size, patch_normalized, model, stride)
  118. output_preds_final_grp.append(output_preds_final)
  119. output_preds_final_grp = np.array(output_preds_final_grp)
  120. np.save(osp.join(heatmap_dir, '%s_%s' % (file_name[:-3], grp)), output_preds_final_grp)
  121. print ("Group " + grp)
  122. # dataset2 = group['/' + "coord"]
  123. new_coords = np.array(dataset2).astype('int64')
  124. for coord in new_coords:
  125. coords.append(coord)
  126. file.close()
  127. # output_preds_final_160 = []
  128. # for i in range(len(patches)):
  129. # output_preds_final = patch_pred_collect_from_slide_window(pred_size, patches[i], model, stride)
  130. # output_preds_final_160.append(output_preds_final)
  131. # output_preds_final_160 = np.array(output_preds_final_160)
  132. # np.save(osp.join(heatmap_dir, '%s_%s' % (file_name[:-3], grp)), output_preds_final_160)
  133. if verbose:
  134. print("[py-wsi] loaded from", file_name, grp)
  135. # return patches, coords
  136. # end of get_patches ()
  137. def patch_pred_collect_from_slide_window(pred_size, fullimage, model, stride):
  138. """
  139. create a nxn matrix that includes all the patches extracted from one big patch by slide window sampling.
  140. :param integer pred_size: the size of patches to be extracted and predicted as tumor or normal patch.
  141. :param nxn matrix fullimage: the image used for slide window prediction, which is larger than the patch to be predicted to avoid side effect.
  142. :param object model: the trained network to predict the patches.
  143. :return a nxn matrix for one patch to be predicted by slide window method
  144. """
  145. output_preds_final = []
  146. for x in tqdm(range(0, pred_size, stride)):
  147. patchforprediction_batch = []
  148. for y in range(0, pred_size, stride):
  149. patchforprediction = fullimage[x:x+pred_size, y:y+pred_size]
  150. patchforprediction_batch.append(patchforprediction)
  151. X_train = np.array(patchforprediction_batch)
  152. preds = predict_batch_from_model(X_train, model)
  153. output_preds_final.append(preds)
  154. output_preds_final = np.array(output_preds_final)
  155. return output_preds_final
  156. # end of patch_pred_collect_from_slide_window
  157. def predict_batch_from_model(patches, model):
  158. """
  159. There are two values for each prediction: one is for the score of normal patches.
  160. ; the other one is for the score of tumor patches. The function is used to select
  161. the score of tumor patches
  162. :param array patches: a list of image patches to be predicted.
  163. :param object model: the trained neural network.
  164. :return lsit predictions: a list of scores for each predicted image patch.
  165. The score here is the probability of the image as a tumor
  166. image.
  167. """
  168. predictions = model.predict(patches)
  169. predictions = predictions[:, 1]
  170. return predictions
  171. # end of predict_batch_from_model
  172. def color_normalization(template_image_path, color_norm_method):
  173. """
  174. The function put all the color normalization methods together.
  175. :param string template_image_path: the path of the image used as a template
  176. :param string color_norm_method: one of the three methods: vahadane, macenko, reinhard.
  177. :return object
  178. """
  179. template_image = staintools.read_image(template_image_path)
  180. standardizer = staintools.LuminosityStandardizer.standardize(
  181. template_image)
  182. if color_norm_method == 'Reinhard':
  183. color_normalizer = stainNorm_Reinhard.Normalizer()
  184. color_normalizer.fit(standardizer)
  185. elif color_norm_method == 'Macenko':
  186. color_normalizer = stainNorm_Macenko.Normalizer()
  187. color_normalizer.fit(standardizer)
  188. elif color_norm_method == 'Vahadane':
  189. color_normalizer = staintools.StainNormalizer(method='vahadane')
  190. color_normalizer.fit(standardizer)
  191. return color_normalizer
  192. def color_norm_pred(image_patch, fit, log_file, current_time):
  193. """
  194. To perform color normalization based on the method used.
  195. :param matrix img: the image to be color normalized
  196. :param object fit: the initialized method for normalization
  197. :return matrix img_norm: the normalized images
  198. :note if the color normalization fails, the original image patches
  199. will be used. But this event will be written in the log file.
  200. """
  201. img = image_patch
  202. img_norm = img
  203. try:
  204. img_standard = staintools.LuminosityStandardizer.standardize(img)
  205. img_norm = fit.transform(img_standard)
  206. except Exception as e:
  207. log_file.write(str(image_patch) + ';' + str(e) + ';' + current_time)
  208. #print(img_norm)
  209. return img_norm
  210. # end of color_norm_pred
  211. if __name__ == "__main__":
  212. current_time = datetime.now().strftime("%d-%m-%Y_%I-%M-%S_%p")
  213. color_norm_methods = ['Vahadane', 'Reinhard', 'Macenko']
  214. template_image_path = '/home/weizhe.li/tumor_st.png'
  215. # log_path = '/home/weizhe.li/log_files'
  216. # color_norm = True
  217. cn = os.environ['COLOR_NORM']
  218. if (cn == "True"):
  219. color_norm = True
  220. else:
  221. color_norm = False
  222. if color_norm:
  223. print ("COLOR_NORM on line # 276 is: ", color_norm)
  224. color_norm_method = color_norm_methods[0]
  225. fit = color_normalization(template_image_path, color_norm_method)
  226. else:
  227. print ("COLOR_NORM on Line # 280 is: ", color_norm)
  228. color_norm_method = 'baseline'
  229. fit = None
  230. log_path = os.environ['LOG_DIR']
  231. log_file = open('%s/%s.txt' % (log_path, color_norm_method), 'w')
  232. main()