diffuser.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. import uuid
  2. import torch
  3. import fire
  4. from diffusers import (
  5. StableDiffusionPipeline,
  6. StableDiffusionImg2ImgPipeline
  7. )
  8. import requests
  9. from PIL import Image
  10. from io import BytesIO
  11. class Diffuser(object):
  12. default_backend = "cuda"
  13. def text_to_image(self,
  14. text,
  15. model = "runwayml/stable-diffusion-v1-5",
  16. ftype='f16',
  17. backend=default_backend,
  18. dst_path=None):
  19. if dst_path is None:
  20. dst_path = "image-{}.png".format(str(uuid.uuid4()))
  21. torch_dtype = torch.float16 if ftype == 'f16' else torch.float32
  22. pipeline = StableDiffusionPipeline.from_pretrained(model, torch_dtype=torch_dtype)
  23. pipeline.to(backend)
  24. image = pipeline(text).images[0]
  25. image.save(dst_path)
  26. def image_to_image(self,
  27. text,
  28. source_image,
  29. model = "runwayml/stable-diffusion-v1-5",
  30. ftype='f16',
  31. backend=default_backend,
  32. dst_path=None):
  33. if dst_path is None:
  34. dst_path = "image-{}.png".format(str(uuid.uuid4()))
  35. torch_dtype = torch.float16 if ftype == 'f16' else torch.float32
  36. pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(model, torch_dtype=torch_dtype)
  37. pipeline.to(backend)
  38. # if image is a url, download it
  39. if source_image.startswith("http"):
  40. response = requests.get(source_image)
  41. source_image = Image.open(BytesIO(response.content)).convert("RGB")
  42. source_image = source_image.resize((768, 512))
  43. else:
  44. source_image = Image.open(source_image).convert("RGB")
  45. source_image = source_image.resize((768, 512))
  46. image = pipeline(prompt=text, image=source_image, strength=0.75, guidance_scale=7.5).images[0]
  47. image.save(dst_path)
  48. if __name__ == "__main__":
  49. diffuser = Diffuser()
  50. fire.Fire(diffuser)