inference.py

import argparse

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

# Read the model checkpoint and the audio file path from the command line.
parser = argparse.ArgumentParser()
parser.add_argument('--model', required=True)
parser.add_argument('--input', required=True)
args = parser.parse_args()

# Use the GPU with half precision when available, otherwise fall back to CPU in float32.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Load the speech-to-text model and its processor from the given checkpoint.
model_id = args.model
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)

# Wrap the model, tokenizer, and feature extractor in an ASR pipeline.
pipe = pipeline(
    'automatic-speech-recognition',
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
)

# Transcribe the input audio file and print the result.
result = pipe(args.input, generate_kwargs={'task': 'transcribe'})
print('\n', result)
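A typical invocation, assuming a Whisper-style speech-to-text checkpoint such as openai/whisper-large-v3 and a local audio file named sample.wav (both are placeholders, not part of the script itself):

    python inference.py --model openai/whisper-large-v3 --input sample.wav

The pipeline returns a dictionary whose 'text' field contains the transcription, which the final print statement writes to the console.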