# livedemo.py — read frames from a USB camera, stylize them with a neural
# style-transfer network, and stream the result to a v4l2 virtual camera
# (usable by Zoom/Skype/Teams) through an ffmpeg pipe.
  1. import cv2
  2. import numpy as np
  3. import subprocess
  4. import torch
  5. from stylenet import StyleNetwork
  6. from torchvision import transforms as T
  7. net=StyleNetwork('./models/style_7.pth')
  8. for p in net.parameters():
  9. p.requires_grad=False
  10. device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
  11. net=net.eval().to(device) #use eval just for safety
  12. src=cv2.VideoCapture('/dev/video0') #USB camera ID
  13. ffstr='ffmpeg -re -f rawvideo -pix_fmt rgb24 -s 640x480 -i - -f v4l2 -pix_fmt yuv420p /dev/video2'
  14. #ffmpeg pipeline which accepts raw rgb frames from command line and writes to virtul camera in yuv420p format
  15. zoom=subprocess.Popen(ffstr, shell=True, stdin=subprocess.PIPE) #open process with shell so we can write to it
  16. dummyframe=255*np.ones((480,640,3), dtype=np.uint8) #blank frame if camera cannot be read
  17. preprocess=T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  18. #same normalization as that used in training data
  19. ret, frame=src.read()
  20. while True:
  21. try:
  22. if ret:
  23. frame=(frame[:,:,::-1]/255.0).astype(np.float32) #convert BGR to RGB, convert to 0-1 range and cast to float32
  24. frame_tensor=torch.unsqueeze(torch.from_numpy(frame),0).permute(0,3,1,2)
  25. # add batch dimension and convert to NCHW format
  26. tensor_in = preprocess(frame_tensor) #normalize
  27. tensor_in=tensor_in.to(device) #send to GPU
  28. tensor_out = net(tensor_in) #stylized tensor
  29. tensor_out=torch.squeeze(tensor_out).permute(1,2,0) #remove batch dimension and convert to HWC (opencv format)
  30. stylized_frame=(255*(tensor_out.to('cpu').detach().numpy())).astype(np.uint8) #convert to 0-255 range and cast as uint8
  31. #gaussian_blur = cv2.GaussianBlur(stylized_frame, (0, 0), 2.0)
  32. #stylized_frame = cv2.addWeighted(stylized_frame, 1.5, gaussian_blur, -0.5, 0, stylized_frame)
  33. else:
  34. stylized_frame=dummyframe #if camera cannot be read, blank white image will be shown
  35. zoom.stdin.write(stylized_frame.tobytes())
  36. #write to ffmpeg pipeline which in turn writes to virtual camera that can be accessed by zoom/skype/teams
  37. ret,frame=src.read()
  38. except KeyboardInterrupt:
  39. print('Received stop command')
  40. break
  41. zoom.terminate()
  42. src.release() #close ffmpeg pipeline and release camera