app.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. import os
  2. import time
  3. from flask import Flask
  4. from flask import render_template, Response, request, send_from_directory, flash, url_for
  5. from flask import current_app as app
  6. from werkzeug.utils import secure_filename
  7. from src.lstm import ActionClassificationLSTM
  8. from src.video_analyzer import analyse_video, stream_video
  9. # import some common Detectron2 utilities
  10. from detectron2 import model_zoo
  11. from detectron2.engine import DefaultPredictor
  12. from detectron2.config import get_cfg
  13. app = Flask(__name__)
  14. UPLOAD_FOLDER = './'
  15. app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
  16. app.secret_key = "secret key"
  17. start = time.time()
  18. # obtain detectron2's default config
  19. cfg = get_cfg()
  20. # load the pre trained model from Detectron2 model zoo
  21. cfg.merge_from_file(model_zoo.get_config_file("COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml"))
  22. # set confidence threshold for this model
  23. cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
  24. # load model weights
  25. cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml")
  26. # create the predictor for pose estimation using the config
  27. pose_detector = DefaultPredictor(cfg)
  28. model_load_done = time.time()
  29. print("Detectron model loaded in ", model_load_done - start)
  30. # Load pretrained LSTM model from checkpoint file
  31. lstm_classifier = ActionClassificationLSTM.load_from_checkpoint("models/saved_model.ckpt")
  32. lstm_classifier.eval()
  33. class DataObject():
  34. pass
  35. def checkFileType(f: str):
  36. return f.split('.')[-1] in ['mp4']
  37. def cleanString(v: str):
  38. out_str = v
  39. delm = ['_', '-', '.']
  40. for d in delm:
  41. out_str = out_str.split(d)
  42. out_str = " ".join(out_str)
  43. return out_str
  44. @app.route('/', methods=['GET'])
  45. def index():
  46. obj = DataObject
  47. obj.video = "sample_video.mp4"
  48. return render_template('/index.html', obj=obj)
  49. @app.route('/upload', methods=['POST'])
  50. def upload():
  51. obj = DataObject
  52. obj.is_video_display = False
  53. obj.video = ""
  54. print("files", request.files)
  55. if request.method == 'POST' and 'video' in request.files:
  56. video_file = request.files['video']
  57. if checkFileType(video_file.filename):
  58. filename = secure_filename(video_file.filename)
  59. print("filename", filename)
  60. # save file to /static/uploads
  61. filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
  62. print("filepath", filepath)
  63. video_file.save(filepath)
  64. obj.video = filename
  65. obj.is_video_display = True
  66. return render_template('/index.html', obj=obj)
  67. else:
  68. if video_file.filename:
  69. msg = f"{video_file.filename} is not a video file"
  70. else:
  71. msg = "Please select a video file"
  72. flash(msg)
  73. return render_template('/index.html', obj=obj)
  74. return render_template('/index.html', obj=obj)
  75. @app.route('/sample', methods=['POST'])
  76. def sample():
  77. obj = DataObject
  78. obj.is_video_display = True
  79. obj.video = "sample_video.mp4"
  80. return render_template('/index.html', obj=obj)
  81. @app.route('/files/<filename>')
  82. def get_file(filename):
  83. return send_from_directory(app.config['UPLOAD_FOLDER'], filename)
  84. @app.route('/analyzed_files/<filename>')
  85. def get_analyzed_file(filename):
  86. return send_from_directory(app.config['UPLOAD_FOLDER'], "res_{}".format(filename), as_attachment=True)
  87. @app.route('/result_video/<filename>')
  88. def get_result_video(filename):
  89. stream = stream_video("{}res_{}".format(
  90. app.config['UPLOAD_FOLDER'], filename))
  91. return Response(stream, mimetype='multipart/x-mixed-replace; boundary=frame')
  92. # route definition for video upload for analysis
  93. @app.route('/analyze/<filename>')
  94. def analyze(filename):
  95. # invokes method analyse_video
  96. return Response(analyse_video(pose_detector, lstm_classifier, filename), mimetype='text/event-stream')
  97. if __name__ == '__main__':
  98. app.run(debug=True, use_reloader=True)