CompVis/stable-diffusion-v1-4

CompVis
Texto a imagen

Stable Diffusion es un modelo de difusión de texto a imagen latente capaz de generar imágenes fotorrealistas a partir de cualquier entrada de texto. Este modelo puede ser utilizado para generar y modificar imágenes según las indicaciones textuales. Es un Modelo de Difusión Latente que utiliza un codificador de texto fijo y preentrenado (CLIP ViT-L/14) como se sugiere en el artículo de Imagen.

Como usar

Recomendamos usar la biblioteca Diffusers de 🤗 para ejecutar Stable Diffusion.

PyTorch

pip install --upgrade diffusers transformers scipy

Ejecutar el pipeline con el programador PNDM predeterminado:

import torch
from diffusers import StableDiffusionPipeline

model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda"

pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe = pipe.to(device)

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]  
    
image.save("astronaut_rides_horse.png")

Nota: Si tienes una memoria GPU limitada y tienes menos de 4GB de RAM disponibles, asegúrate de cargar el StableDiffusionPipeline en precisión float16 en lugar de la precisión predeterminada float32 como se hace arriba. Puedes hacerlo diciendo a los difusores que esperen los pesos en precisión float16:

import torch

pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe = pipe.to(device)
pipe.enable_attention_slicing()

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]  
    
image.save("astronaut_rides_horse.png")

Para intercambiar el programador de ruido, pásalo a from_pretrained:

from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler

model_id = "CompVis/stable-diffusion-v1-4"

# Usa el programador Euler aquí en su lugar
scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]  
    
image.save("astronaut_rides_horse.png")

JAX/Flax

Para usar StableDiffusion en TPUs y GPUs para inferencia más rápida puedes aprovechar JAX/Flax. Ejecutar el pipeline con el programador PNDM predeterminado:

import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard

from diffusers import FlaxStableDiffusionPipeline

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="flax", dtype=jax.numpy.bfloat16
)

prompt = "a photo of an astronaut riding a horse on mars"

prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50

num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)

# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))

Nota: Si tienes una memoria TPU limitada, asegúrate de cargar el FlaxStableDiffusionPipeline en precisión bfloat16 en lugar de la precisión predeterminada float32 como se hace arriba. Puedes hacerlo diciendo a los difusores que carguen los pesos desde la rama "bf16":

import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard

from diffusers import FlaxStableDiffusionPipeline

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jax.numpy.bfloat16
)

prompt = "a photo of an astronaut riding a horse on mars"

prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50

num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)

# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))

Funcionalidades

Generación de imágenes fotorrealistas a partir de indicaciones textuales
Capacidad de modificar imágenes según indicaciones textuales
Uso del codificador de texto CLIP ViT-L/14
Compatible con la biblioteca Diffusers
Incluye múltiples puntos de control entrenados

Casos de uso

Despliegue seguro de modelos que tienen el potencial de generar contenido dañino.
Investigación y comprensión de las limitaciones y sesgos de los modelos generativos.
Generación de obras de arte y uso en diseño y otros procesos artísticos.
Aplicaciones en herramientas educativas o creativas.
Investigación sobre modelos generativos.

Recibe las últimas noticias y actualizaciones sobre el mundo de IA directamente en tu bandeja de entrada.