"""
Script to compare time for fine-tuned Whisper models.

For every checkpoint listed in ``model_dirs`` the script builds an
automatic-speech-recognition pipeline, transcribes every ``.wav`` file found
directly inside ``input_dir`` ``num_passes`` times, and prints the mean
wall-clock time of a single transcription call.

Usage::

    python compare_time.py
"""
import os
import time

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

# Fine-tuned checkpoints to benchmark (paths relative to the working directory).
model_dirs = [
    'whisper_tiny_atco2_v2/best_model',
    'whisper_base_atco2/best_model',
    'whisper_small_atco2/best_model',
]
# Directory holding the audio clips used for timing.
input_dir = 'inference_data'
# Use the first GPU in half precision when CUDA is available, else CPU in float32.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Number of times the whole set of audio files is transcribed per model.
num_passes = 10


def list_wav_files(directory):
    """Return the sorted paths of the ``.wav`` files directly inside *directory*.

    Raises:
        OSError: if *directory* cannot be listed (e.g. it does not exist).
    """
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.endswith('.wav')
    )


def build_pipeline(model_id):
    """Load the checkpoint at *model_id* and wrap it in an ASR pipeline.

    The model and pipeline use the module-level ``device`` and ``torch_dtype``.
    """
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    )
    model.to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    return pipeline(
        'automatic-speech-recognition',
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )


def average_inference_time(pipe, audio_paths, passes=num_passes):
    """Return the mean duration, in seconds, of one ``pipe(path)`` call.

    Every path in *audio_paths* is transcribed *passes* times; each call is
    timed individually and the durations are averaged over all calls.

    Raises:
        ValueError: if there is nothing to time (no paths or ``passes < 1``).
    """
    if not audio_paths or passes < 1:
        raise ValueError('need at least one audio file and one pass to time')

    total_time = 0.0
    num_runs = 0
    for _ in range(passes):
        for path in audio_paths:
            # perf_counter is monotonic and high-resolution, unlike time.time,
            # which can jump if the system clock is adjusted.
            # NOTE(review): the very first call may include one-time warm-up
            # cost; it is still counted, as in the original measurement.
            start = time.perf_counter()
            pipe(path)
            total_time += time.perf_counter() - start
            num_runs += 1
    return total_time / num_runs


def main():
    """Benchmark every model in ``model_dirs`` and print its average time."""
    # List the inputs once, before any model is loaded, so an empty or missing
    # directory fails fast instead of ending in a ZeroDivisionError.
    audio_paths = list_wav_files(input_dir)
    if not audio_paths:
        raise SystemExit(
            f"No .wav files found in '{input_dir}'; nothing to benchmark."
        )

    for model_id in model_dirs:
        print(f"\nEvaluating model: {model_id}")
        pipe = build_pipeline(model_id)
        average_time = average_inference_time(pipe, audio_paths)
        print(f"\nAverage time taken for {model_id}: {average_time} seconds")


if __name__ == '__main__':
    main()