inference.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. import os
  2. import sys
  3. sys.path.append('RAFT/core')
  4. from argparse import ArgumentParser
  5. from collections import OrderedDict
  6. import cv2
  7. import numpy as np
  8. import torch
  9. from raft import RAFT
  10. from utils import flow_viz
  11. def frame_preprocess(frame, device):
  12. frame = torch.from_numpy(frame).permute(2, 0, 1).float()
  13. frame = frame.unsqueeze(0)
  14. frame = frame.to(device)
  15. return frame
  16. def vizualize_flow(img, flo, save, counter):
  17. # permute the channels and change device is necessary
  18. img = img[0].permute(1, 2, 0).cpu().numpy()
  19. flo = flo[0].permute(1, 2, 0).cpu().numpy()
  20. # map flow to rgb image
  21. flo = flow_viz.flow_to_image(flo)
  22. flo = cv2.cvtColor(flo, cv2.COLOR_RGB2BGR)
  23. # concatenate, save and show images
  24. img_flo = np.concatenate([img, flo], axis=0)
  25. if save:
  26. cv2.imwrite(f"demo_frames/frame_{str(counter)}.jpg", img_flo)
  27. cv2.imshow("Optical Flow", img_flo / 255.0)
  28. k = cv2.waitKey(25) & 0xFF
  29. if k == 27:
  30. return False
  31. return True
  32. def get_cpu_model(model):
  33. new_model = OrderedDict()
  34. # get all layer's names from model
  35. for name in model:
  36. # create new name and update new model
  37. new_name = name[7:]
  38. new_model[new_name] = model[name]
  39. return new_model
  40. def inference(args):
  41. # get the RAFT model
  42. model = RAFT(args)
  43. # load pretrained weights
  44. pretrained_weights = torch.load(args.model)
  45. save = args.save
  46. if save:
  47. if not os.path.exists("demo_frames"):
  48. os.mkdir("demo_frames")
  49. if torch.cuda.is_available():
  50. device = "cuda"
  51. # parallel between available GPUs
  52. model = torch.nn.DataParallel(model)
  53. # load the pretrained weights into model
  54. model.load_state_dict(pretrained_weights)
  55. model.to(device)
  56. else:
  57. device = "cpu"
  58. # change key names for CPU runtime
  59. pretrained_weights = get_cpu_model(pretrained_weights)
  60. # load the pretrained weights into model
  61. model.load_state_dict(pretrained_weights)
  62. # change model's mode to evaluation
  63. model.eval()
  64. video_path = args.video
  65. # capture the video and get the first frame
  66. cap = cv2.VideoCapture(video_path)
  67. ret, frame_1 = cap.read()
  68. # frame preprocessing
  69. frame_1 = frame_preprocess(frame_1, device)
  70. counter = 0
  71. with torch.no_grad():
  72. while True:
  73. # read the next frame
  74. ret, frame_2 = cap.read()
  75. if not ret:
  76. break
  77. # preprocessing
  78. frame_2 = frame_preprocess(frame_2, device)
  79. # predict the flow
  80. flow_low, flow_up = model(frame_1, frame_2, iters=args.iters, test_mode=True)
  81. # transpose the flow output and convert it into numpy array
  82. ret = vizualize_flow(frame_1, flow_up, save, counter)
  83. if not ret:
  84. break
  85. frame_1 = frame_2
  86. counter += 1
  87. def main():
  88. parser = ArgumentParser()
  89. parser.add_argument("--model", help="restore checkpoint")
  90. parser.add_argument("--iters", type=int, default=12)
  91. parser.add_argument("--video", type=str, default="./videos/car.mp4")
  92. parser.add_argument("--save", action="store_true", help="save demo frames")
  93. parser.add_argument("--small", action="store_true", help="use small model")
  94. parser.add_argument(
  95. "--mixed_precision", action="store_true", help="use mixed precision"
  96. )
  97. args = parser.parse_args()
  98. inference(args)
  99. if __name__ == "__main__":
  100. main()