CompVis/stable-diffusion-v1-4
Stable Diffusion es un modelo de difusión de texto a imagen latente capaz de generar imágenes fotorrealistas a partir de cualquier entrada de texto. Este modelo puede ser utilizado para generar y modificar imágenes según las indicaciones textuales. Es un Modelo de Difusión Latente que utiliza un codificador de texto fijo y preentrenado (CLIP ViT-L/14) como se sugiere en el artículo de Imagen.
Como usar
Recomendamos usar la biblioteca Diffusers de 🤗 para ejecutar Stable Diffusion.
PyTorch
pip install --upgrade diffusers transformers scipy
Ejecutar el pipeline con el programador PNDM predeterminado:
import torch
from diffusers import StableDiffusionPipeline
model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe = pipe.to(device)
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
image.save("astronaut_rides_horse.png")
Nota: Si tienes una memoria GPU limitada y tienes menos de 4GB de RAM disponibles, asegúrate de cargar el StableDiffusionPipeline en precisión float16 en lugar de la precisión predeterminada float32 como se hace arriba. Puedes hacerlo diciendo a los difusores que esperen los pesos en precisión float16:
import torch
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe = pipe.to(device)
pipe.enable_attention_slicing()
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
image.save("astronaut_rides_horse.png")
Para intercambiar el programador de ruido, pásalo a from_pretrained
:
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
model_id = "CompVis/stable-diffusion-v1-4"
# Usa el programador Euler aquí en su lugar
scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
image.save("astronaut_rides_horse.png")
JAX/Flax
Para usar StableDiffusion en TPUs y GPUs para inferencia más rápida puedes aprovechar JAX/Flax.
Ejecutar el pipeline con el programador PNDM predeterminado:
import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="flax", dtype=jax.numpy.bfloat16
)
prompt = "a photo of an astronaut riding a horse on mars"
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
Nota: Si tienes una memoria TPU limitada, asegúrate de cargar el FlaxStableDiffusionPipeline en precisión bfloat16 en lugar de la precisión predeterminada float32 como se hace arriba. Puedes hacerlo diciendo a los difusores que carguen los pesos desde la rama "bf16":
import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jax.numpy.bfloat16
)
prompt = "a photo of an astronaut riding a horse on mars"
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
Funcionalidades
- Generación de imágenes fotorrealistas a partir de indicaciones textuales
- Capacidad de modificar imágenes según indicaciones textuales
- Uso del codificador de texto CLIP ViT-L/14
- Compatible con la biblioteca Diffusers
- Incluye múltiples puntos de control entrenados
Casos de uso
- Despliegue seguro de modelos que tienen el potencial de generar contenido dañino.
- Investigación y comprensión de las limitaciones y sesgos de los modelos generativos.
- Generación de obras de arte y uso en diseño y otros procesos artísticos.
- Aplicaciones en herramientas educativas o creativas.
- Investigación sobre modelos generativos.