Browse Source

Rename dldp/dldp/heatmap_pred/process_images_grp_normalization_wli.py to dldp/dldp/heatmap_pred/HPC_SC20/process_images_grp_normalization_wli.py

Weizhe Li 5 years ago
parent
commit
094da1632f
1 changed files with 286 additions and 0 deletions
  1. 286 0
      process_images_grp_normalization_wli.py

+ 286 - 0
process_images_grp_normalization_wli.py

@@ -0,0 +1,286 @@
+
+import sys
+import time
+
+from matplotlib import cm
+from tqdm import tqdm
+from skimage.filters import threshold_otsu
+from keras.models import load_model
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+import os.path as osp
+import openslide
+from pathlib import Path
+from skimage.filters import threshold_otsu
+import glob
+import math
+# before importing HDFStore, make sure 'tables' is installed by pip3 install tables
+from pandas import HDFStore
+from openslide.deepzoom import DeepZoomGenerator
+from sklearn.model_selection import StratifiedShuffleSplit
+import cv2
+from keras.utils.np_utils import to_categorical
+
+import os.path as osp
+import os 
+import openslide
+from pathlib import Path
+from keras.models import Sequential
+from keras.layers import Lambda, Dropout
+from keras.layers.convolutional import Convolution2D, Conv2DTranspose
+from keras.layers.pooling import MaxPooling2D
+from keras.models import model_from_json
+import numpy as np
+import sys
+import skimage.io as io
+import skimage.transform as trans
+import numpy as np
+from keras.models import *
+from keras.layers import *
+from keras.optimizers import *
+from keras.callbacks import ModelCheckpoint, LearningRateScheduler
+from keras import backend as keras
+import re
+import staintools
+#############################################
+import h5py
+from keras.utils import HDF5Matrix
+import stain_utils as utils
+import stainNorm_Reinhard
+import stainNorm_Macenko
+import stainNorm_Vahadane
+
+from datetime import datetime
+
+cores = int(os.environ['NSLOTS'])
+keras.set_session(keras.tf.Session(config=keras.tf.ConfigProto(intra_op_parallelism_threads=cores, inter_op_parallelism_threads=cores)))
+
+pred_size = int(os.environ['PRED_SIZE'])
+stride = 16
+
+mcn = os.environ['MCN']
+model = load_model(mcn, compile=False)
+        # '/home/weizhe.li/Training/googlenetmainmodel1119HNM-02-0.92.hdf5')
+        # '/home/weizhe.li/Training/googlenetmainmodel1119HNM-02-0.92.hdf5', compile=False)
+#    '/home/weizhe.li/Training/HNM_models/no_noise_no_norm/googlenetv1_no_noise_no_norm_0210_hnm_transfer_learn_02.10.20_09:31_original_256_patches-03-0.91.hdf5', compile=False)
+
+# /home/weizhe.li/Training/googlenetmainmodel1119HNM-02-0.92.hdf5
+
+def main():
+    ''' Four command line arguments:
+        - $DIR           Directory where HDF5 is located
+        - $HDF5_FILE     HDF5 file name, like test_001.h5
+        - $BASENAME      subgroup suffix, like 1, 2, ... 
+        - $HEATMAP_DIR   heatmap directory name
+    '''
+
+    # print command line arguments
+    for arg in sys.argv[1:]:
+        print (arg)
+
+    print (os.environ['HOSTNAME'])
+    print (os.environ['SGE_TASK_ID'])
+    
+    x = int(os.environ['SGE_TASK_ID'])
+    print ("x = ", x)
+
+    dir = sys.argv[1] 
+    hdf5_file = sys.argv[2] 
+    grp_suffix = sys.argv[3]
+    heatmap_dir = sys.argv[4] 
+
+    print ("dir = " + dir)
+    print ("hdf5_file = " + hdf5_file)
+    print ("grp_suffix = " + grp_suffix)
+    print ("heatmap_dir = " + heatmap_dir)
+     
+    # patches, coords = [], [] 
+     
+    start_time = time.time()
+     
+    # patches, coords = get_patches( dir,
+    get_patches( dir, hdf5_file, grp_suffix, heatmap_dir, verbose=True)
+    
+    print("--- %s seconds ---" % (time.time() - start_time))
+    
+# end of main () 
+
+###########################################################################
+#                HDF5-specific helper functions                           #
+###########################################################################
+
+def get_patches(db_location, file_name, grp_suffix, heatmap_dir, verbose=False):
+    """ Loads the numpy patches from HDF5 files.
+    """
+    patches, coords = [], []
+    
+    # Now load the images from H5 file.
+    file = h5py.File(db_location + "/" + file_name,'r+')
+    grp='t'+grp_suffix
+    # dataset = file['/' + ds]
+    group = file['/' + grp]
+    
+    for key, value in group.items():
+        if key == 'img':
+            dataset=value
+        if key == 'coord':
+            dataset2=value    
+    
+    new_patches = np.array(dataset).astype('uint8')
+    # for patch in new_patches:
+    #    patches.append(patch)
+    print ("COLOR_NORM on line # 133 is: ", color_norm)
+    output_preds_final_grp = []
+    for patch in new_patches:
+        ################################################color normalization##############################
+        if color_norm:
+            patch_normalized = color_norm_pred(patch, fit, log_file, current_time)
+        else:
+            patch_normalized = patch
+        output_preds_final = patch_pred_collect_from_slide_window(pred_size, patch_normalized, model, stride)
+        output_preds_final_grp.append(output_preds_final)
+    output_preds_final_grp = np.array(output_preds_final_grp)
+    np.save(osp.join(heatmap_dir, '%s_%s' % (file_name[:-3], grp)), output_preds_final_grp)
+       
+    print ("Group " + grp)
+        
+    # dataset2 = group['/' + "coord"]
+    new_coords = np.array(dataset2).astype('int64') 
+    for coord in new_coords:
+        coords.append(coord)
+    
+    file.close()
+
+    # output_preds_final_160 = []
+    # for i in range(len(patches)): 
+    #    output_preds_final = patch_pred_collect_from_slide_window(pred_size, patches[i], model, stride)
+    #    output_preds_final_160.append(output_preds_final)
+    # output_preds_final_160 = np.array(output_preds_final_160)
+    # np.save(osp.join(heatmap_dir, '%s_%s' % (file_name[:-3], grp)), output_preds_final_160)
+
+    if verbose:
+        print("[py-wsi] loaded from", file_name, grp)
+
+    # return patches, coords
+
+# end of get_patches ()
+
+
+def patch_pred_collect_from_slide_window(pred_size, fullimage, model, stride):
+    """
+    create a nxn matrix that includes all the patches extracted from one big patch by slide window sampling.
+    
+    :param integer pred_size: the size of patches to be extracted and predicted as tumor or normal patch.
+    :param nxn matrix fullimage: the image used for slide window prediction, which is larger than the patch to be predicted to avoid side effect.
+    :param object model: the trained network to predict the patches.
+    
+    :return a nxn matrix for one patch to be predicted by slide window method
+    """
+    
+    output_preds_final = []
+    
+    for x in tqdm(range(0, pred_size, stride)):
+        patchforprediction_batch = []
+        
+        for y in range(0, pred_size, stride):
+            patchforprediction = fullimage[x:x+pred_size, y:y+pred_size]
+            patchforprediction_batch.append(patchforprediction)
+    
+        X_train = np.array(patchforprediction_batch)
+        preds = predict_batch_from_model(X_train, model)
+        output_preds_final.append(preds)
+    
+    output_preds_final = np.array(output_preds_final)
+    
+    return output_preds_final
+
+# end of patch_pred_collect_from_slide_window
+
+
+def predict_batch_from_model(patches, model):
+    """
+    There are two values for each prediction: one is for the score of normal patches.
+    ; the other one is for the score of tumor patches. The function is used to select
+    the score of tumor patches
+    :param array patches: a list of image patches to be predicted.
+    :param object model: the trained neural network.
+    :return lsit predictions: a list of scores for each predicted image patch. 
+                              The score here is the probability of the image as a tumor
+                              image.
+            
+    """
+    predictions = model.predict(patches)
+    predictions = predictions[:, 1]
+    return predictions
+
+# end of predict_batch_from_model
+
+def color_normalization(template_image_path, color_norm_method):
+    """
+    The function put all the color normalization methods together.
+    :param string template_image_path: the path of the image used as a template
+    :param string color_norm_method: one of the three methods: vahadane, macenko, reinhard.
+    :return object 
+    """
+    template_image = staintools.read_image(template_image_path)
+    standardizer = staintools.LuminosityStandardizer.standardize(
+        template_image)
+    if color_norm_method == 'Reinhard':
+        color_normalizer = stainNorm_Reinhard.Normalizer()
+        color_normalizer.fit(standardizer)
+    elif color_norm_method == 'Macenko':
+        color_normalizer = stainNorm_Macenko.Normalizer()
+        color_normalizer.fit(standardizer)
+    elif color_norm_method == 'Vahadane':
+        color_normalizer = staintools.StainNormalizer(method='vahadane')
+        color_normalizer.fit(standardizer)
+    
+    return color_normalizer
+
+def color_norm_pred(image_patch, fit, log_file, current_time):
+    """
+    To perform color normalization based on the method used.
+    :param matrix img: the image to be color normalized
+    :param object fit: the initialized method for normalization
+    :return matrix img_norm: the normalized images
+    :note if the color normalization fails, the original image patches
+          will be used. But this event will be written in the log file.
+    """
+    img = image_patch
+    img_norm = img
+    try:
+        img_standard = staintools.LuminosityStandardizer.standardize(img)
+        img_norm = fit.transform(img_standard)
+    except Exception as e:
+        log_file.write(str(image_patch) + ';' + str(e) + ';' + current_time)
+    #print(img_norm)
+    return img_norm
+
+# end of color_norm_pred
+
+if __name__ == "__main__":
+    current_time = datetime.now().strftime("%d-%m-%Y_%I-%M-%S_%p")
+    color_norm_methods = ['Vahadane', 'Reinhard', 'Macenko']
+    template_image_path = '/home/weizhe.li/tumor_st.png'
+    # log_path = '/home/weizhe.li/log_files'
+    # color_norm = True
+    cn = os.environ['COLOR_NORM']
+    if (cn == "True"): 
+        color_norm = True
+    else:     
+        color_norm = False
+    
+    if color_norm:
+        print ("COLOR_NORM on line # 276 is: ", color_norm)
+        color_norm_method = color_norm_methods[0]
+        fit = color_normalization(template_image_path, color_norm_method)    
+    else:
+        print ("COLOR_NORM on Line # 280 is: ", color_norm)
+        color_norm_method = 'baseline'
+        fit = None
+    
+    log_path = os.environ['LOG_DIR']
+    log_file = open('%s/%s.txt' % (log_path, color_norm_method), 'w')
+
+    main()