import uuid
import torch
import fire
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
)
import requests
from PIL import Image
from io import BytesIO


class Diffuser(object):
    default_backend = "cuda"

    def text_to_image(self,
                      text,
                      model="runwayml/stable-diffusion-v1-5",
                      ftype="f16",
                      backend=default_backend,
                      dst_path=None):
        # Generate an image from a text prompt and save it to dst_path.
        if dst_path is None:
            dst_path = "image-{}.png".format(str(uuid.uuid4()))

        torch_dtype = torch.float16 if ftype == "f16" else torch.float32
        pipeline = StableDiffusionPipeline.from_pretrained(model, torch_dtype=torch_dtype)
        pipeline.to(backend)

        image = pipeline(text).images[0]
        image.save(dst_path)

    def image_to_image(self,
                       text,
                       source_image,
                       model="runwayml/stable-diffusion-v1-5",
                       ftype="f16",
                       backend=default_backend,
                       dst_path=None):
        # Generate a new image from a text prompt plus a source image.
        if dst_path is None:
            dst_path = "image-{}.png".format(str(uuid.uuid4()))

        torch_dtype = torch.float16 if ftype == "f16" else torch.float32
        pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(model, torch_dtype=torch_dtype)
        pipeline.to(backend)

        # If the source image is a URL, download it; otherwise read it from disk.
        if source_image.startswith("http"):
            response = requests.get(source_image)
            source_image = Image.open(BytesIO(response.content)).convert("RGB")
        else:
            source_image = Image.open(source_image).convert("RGB")
        source_image = source_image.resize((768, 512))

        image = pipeline(prompt=text, image=source_image, strength=0.75, guidance_scale=7.5).images[0]
        image.save(dst_path)


if __name__ == "__main__":
    diffuser = Diffuser()
    fire.Fire(diffuser)
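
# A minimal usage sketch, kept as comments so the file stays runnable.
# Assumptions not in the original: the script filename (diffuse.py), the
# example prompts, and the example image URL are all illustrative.
# python-fire exposes each public method of the Diffuser instance as a
# subcommand, and the method's keyword arguments become --flags:
#
#   python diffuse.py text_to_image --text="an astronaut riding a horse"
#   python diffuse.py image_to_image --text="a watercolor landscape" \
#       --source_image="https://example.com/photo.jpg" --dst_path="out.png"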