"""Command-line wrapper around the Stable Diffusion pipelines from diffusers.

Exposes text-to-image and image-to-image generation as Fire sub-commands.
"""

import uuid
from io import BytesIO

import fire
import requests
import torch
from diffusers import (
    StableDiffusionImg2ImgPipeline,
    StableDiffusionPipeline,
)
from PIL import Image

# Size (width, height) every source image is resized to before img2img.
_SOURCE_IMAGE_SIZE = (768, 512)

# Upper bound, in seconds, on waiting for the remote server when downloading
# a source image; without it a stalled connection would block forever.
_DOWNLOAD_TIMEOUT_SECONDS = 30

# Prefixes that mark ``source_image`` as a remote URL rather than a file path.
_URL_PREFIXES = ("http://", "https://")


def _resolve_dst_path(dst_path):
    """Return ``dst_path``, or a fresh ``image-<uuid4>.png`` name when it is None."""
    if dst_path is None:
        return f"image-{uuid.uuid4()}.png"
    return dst_path


def _select_torch_dtype(ftype):
    """Map ``ftype`` to a torch dtype: 'f16' -> float16, anything else -> float32."""
    return torch.float16 if ftype == 'f16' else torch.float32


def _load_source_image(source_image):
    """Load ``source_image`` (URL or local path) as a resized RGB PIL image.

    Raises:
        requests.HTTPError: if the URL download returns an error status.
        requests.Timeout: if the server does not respond in time.
    """
    if source_image.startswith(_URL_PREFIXES):
        response = requests.get(source_image, timeout=_DOWNLOAD_TIMEOUT_SECONDS)
        # Fail here with the HTTP status instead of letting PIL choke on an
        # HTML error page further down.
        response.raise_for_status()
        rgb_image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        # The context manager closes the underlying file once the pixel data
        # has been copied out by convert().
        with Image.open(source_image) as opened:
            rgb_image = opened.convert("RGB")
    return rgb_image.resize(_SOURCE_IMAGE_SIZE)


class Diffuser(object):
    """Fire-exposed entry points for Stable Diffusion image generation."""

    default_backend = "cuda"

    def text_to_image(self, text, model="runwayml/stable-diffusion-v1-5",
                      ftype='f16', backend=default_backend, dst_path=None):
        """Generate an image from a text prompt and save it to disk.

        Args:
            text: Prompt passed to the pipeline.
            model: Model identifier handed to ``from_pretrained``.
            ftype: 'f16' selects ``torch.float16``; any other value selects
                ``torch.float32``.
            backend: Device the pipeline is moved to.
            dst_path: Output file path; defaults to ``image-<uuid4>.png``.

        Returns:
            The path the image was saved to.
        """
        dst_path = _resolve_dst_path(dst_path)

        pipeline = StableDiffusionPipeline.from_pretrained(
            model, torch_dtype=_select_torch_dtype(ftype))
        pipeline.to(backend)

        image = pipeline(text).images[0]
        image.save(dst_path)
        return dst_path

    def image_to_image(self, text, source_image,
                       model="runwayml/stable-diffusion-v1-5", ftype='f16',
                       backend=default_backend, dst_path=None):
        """Generate an image from a prompt plus a source image and save it.

        Args:
            text: Prompt passed to the pipeline.
            source_image: ``http://`` / ``https://`` URL or local file path of
                the starting image. It is converted to RGB and resized to
                768x512 before use.
            model: Model identifier handed to ``from_pretrained``.
            ftype: 'f16' selects ``torch.float16``; any other value selects
                ``torch.float32``.
            backend: Device the pipeline is moved to.
            dst_path: Output file path; defaults to ``image-<uuid4>.png``.

        Returns:
            The path the image was saved to.

        Raises:
            requests.HTTPError: if downloading ``source_image`` fails with an
                error status.
        """
        dst_path = _resolve_dst_path(dst_path)

        pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
            model, torch_dtype=_select_torch_dtype(ftype))
        pipeline.to(backend)

        init_image = _load_source_image(source_image)

        image = pipeline(prompt=text, image=init_image, strength=0.75,
                         guidance_scale=7.5).images[0]
        image.save(dst_path)
        return dst_path


if __name__ == "__main__":
    diffuser = Diffuser()
    fire.Fire(diffuser)