heatmap_assembly.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. ####################################################################################################
  2. """
  3. Heatmap assembly
  4. =================
  5. After prediction, each patch from a WSI image has one prediction matrix (for example, 14x14). This script is
  6. used to stitch all these small matrices into one big map corresponding to a rectangular tissue region of a
  7. WSI image.
  8. How to use
  9. ----------
  10. The following variables need to be set:
  11. : param Folder_Prediction_Results: the location of the prediction for individual patches
  12. : type Folder_Prediction_Results: str
  13. : param slide_category: the category of the slide, for example, 'tumor', 'normal', 'test'
  14. : type slide_category: str
  15. : param Folder_Heatmap: the folder to store the stitched heatmap
  16. : type Folder_Heatmap: str
  17. : param Stride: the number of pixels skipped between predictions, for example, 16, 64
  18. : type Stride: int
  19. Note
  20. ----
  21. The following files are necessary to perform the task:
  22. '/home/weizhe.li/li-code4hpc/pred_dim_0314/training-updated/normal/dimensions',
  23. '/home/weizhe.li/li-code4hpc/pred_dim_0314/training-updated/tumor/dimensions',
  24. '/home/weizhe.li/li-code4hpc/pred_dim_0314/testing/dimensions'
  25. These files store the dimension of the heatmap and location of the heatmap in the WSI image.
  26. """
  27. ####################################################################################################
  28. import os
  29. import os.path as osp
  30. import pandas as pd
  31. import numpy as np
  32. import matplotlib
  33. import matplotlib.pyplot as plt
  34. import glob
  35. #import multiprocess as mlp
  36. import re
  37. ####################################################################################################
  38. def pred_collect(pred_folder):
  39. files = glob.glob(osp.join(pred_folder, '*.npy'))
  40. files.sort(key=lambda f: int(f.split('_t')[1].split('.')[0]))
  41. # create a empty list to store all the small heatmap files (160, 14, 14) list.
  42. heat_map = []
  43. for file in files:
  44. regions = np.load(file)
  45. heat_map.extend(regions)
  46. heat_map_array = np.array(heat_map)
  47. return heat_map_array
  48. def stitch_preded_patches(dim, index, pred_folder, Folder_Heatmap, Stride):
  49. """
  50. stitching the prediction based on each small patches to a big heatmap
  51. :param dimension_files: a list of all the dimension files for one category of slides, foe example, 'tumor'
  52. :type dimension_files: list
  53. :param pred_folder: the folder having all the patch prediction results for a single WSI image.
  54. :type pred_folder: str
  55. :param Folder_Heatmap: the folder to store the big stitched heatmap.
  56. :type Folder_Heatmap: str
  57. :param stride: the stride during prediction
  58. :type stride: int
  59. :return: no return
  60. :note: two files will saved to the Folder_Heatmap:
  61. 1. the stitched heatmap in npy format
  62. 2. the heatmap picture in png format
  63. """
  64. num_of_pred_per_patch = int(224/Stride)
  65. # heat_map_big = np.zeros([dim[7]*num_of_pred_per_patch, dim[8]*num_of_pred_per_patch], dtype=np.float32)
  66. # generate a list of all npy files inside one folder.
  67. heat_map_array = pred_collect(pred_folder)
  68. heat_map_array_iter = iter(heat_map_array)
  69. # construct the full heat_map array
  70. no_tissue_region = np.zeros([int(224/Stride), int(224/Stride)], dtype = np.float32)
  71. # no_tissue_region = np.zeros([12, 12], dtype=np.float32)
  72. heat_map_all = []
  73. for _, item in index.iterrows():
  74. if item.is_tissue:
  75. patch_pred = next(heat_map_array_iter)
  76. heat_map_all.extend(patch_pred)
  77. #print(patch_pred)
  78. else:
  79. heat_map_all.extend(no_tissue_region)
  80. heat_map_all = np.array(heat_map_all)
  81. print(heat_map_all.shape)
  82. # These are critical steps to construct heatmap in a time saving manner.
  83. heat_map_reshape = heat_map_all.reshape(dim[7], dim[8], num_of_pred_per_patch, num_of_pred_per_patch)
  84. b = heat_map_reshape.transpose((0, 2, 1, 3))
  85. # c = b.reshape(heat_map_big.shape[0], heat_map_big.shape[1])
  86. # c = b.reshape(heat_map_big.shape[0], heat_map_big.shape[1])
  87. c = b.reshape(dim[7]*num_of_pred_per_patch, dim[8]*num_of_pred_per_patch)
  88. np.save('%s/%s_oldmodel_test_001n' % (Folder_Heatmap, osp.basename(pred_folder)), c)
  89. matplotlib.image.imsave('%s/%s_oldmodel_test_001n.png' % (Folder_Heatmap, osp.basename(pred_folder)), c)
  90. if __name__ == "__main__":
  91. taskid = int(os.environ['SGE_TASK_ID'])
  92. # Here is the folder for prediction results.The prediction results are organized into folders. Each folder corresponds to a WSI image.
  93. # Folder_Prediction_Results = '/scratch/mikem/UserSupport/weizhe.li/runs_process_cn_True/normal_wnorm_448_400_7690666'
  94. Folder_Prediction_Results = os.environ['Folder_Prediction_Results']
  95. Folder_dimension = os.environ['Folder_dimension']
  96. index_path = os.environ['index_path']
  97. dimension_files = glob.glob(osp.join(Folder_dimension, '*.npy'))
  98. dimension_files.sort()
  99. print(dimension_files)
  100. index_files = glob.glob(osp.join(index_path, '*.pkl'))
  101. index_files.sort()
  102. # Folder_Heatmap = '/scratch/weizhe.li/heat_map/HPC/test'
  103. Folder_Heatmap = os.environ['Folder_Heatmap']
  104. Stride = 16
  105. i = taskid - 1
  106. dir = os.environ['dir']
  107. pred_folder = osp.join(Folder_Prediction_Results, dir, 'preds')
  108. dimension_file = dimension_files[i]
  109. dimension = np.load(dimension_file)
  110. print(dimension[7])
  111. index_file = index_files[i]
  112. index = np.load(index_file)
  113. print(pred_folder)
  114. stitch_preded_patches(dimension, index, pred_folder, Folder_Heatmap, Stride)