feat(text-to-image): integrate NSFW safety checker (#76)
* feat(pipelines): add optional NSFW safety check to T2I and I2I pipelines

This commit incorporates the CompVis/stable-diffusion-safety-checker
into the text-to-image and image-to-image pipelines. When the
`safety_check` input is enabled, each returned image carries an
`nsfw` flag so users are notified when NSFW content is generated
(a minimal request sketch follows below, before the file diffs).

* refactor(runner): enable safety checker by default

This commit enables the safety checker by default. For more information
about this decision see
#76 (comment).
rickstaa committed May 15, 2024
1 parent 8bd32b0 commit 272ac74
Showing 11 changed files with 220 additions and 52 deletions.
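
For orientation, a minimal request sketch against the runner's text-to-image route; the local host/port and the response handling are illustrative assumptions, not part of this commit:

    import requests

    resp = requests.post(
        "http://localhost:8000/text-to-image",  # assumed local runner address
        json={
            "prompt": "a mountain lion",
            "safety_check": True,  # enabled by default as of this commit
        },
    )
    for image in resp.json()["images"]:
        # Each returned image now carries an `nsfw` flag alongside `url` and `seed`.
        print(image["seed"], image["nsfw"])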
22 changes: 18 additions & 4 deletions runner/app/pipelines/image_to_image.py
@@ -1,5 +1,5 @@
from app.pipelines.base import Pipeline
from app.pipelines.util import get_torch_device, get_model_dir
from app.pipelines.util import get_torch_device, get_model_dir, SafetyChecker

from diffusers import (
AutoPipelineForImage2Image,
@@ -11,7 +11,7 @@
from huggingface_hub import file_download, hf_hub_download
import torch
import PIL
from typing import List
from typing import List, Tuple
import logging
import os

@@ -128,7 +128,14 @@ def __init__(self, model_id: str):

self.ldm = enable_deepcache(self.ldm)

def __call__(self, prompt: str, image: PIL.Image, **kwargs) -> List[PIL.Image]:
safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower()
self._safety_checker = SafetyChecker(device=safety_checker_device)

def __call__(
self, prompt: str, image: PIL.Image, **kwargs
) -> Tuple[List[PIL.Image], List[bool]]:
safety_check = kwargs.pop("safety_check", True)

seed = kwargs.pop("seed", None)
if seed is not None:
if isinstance(seed, int):
@@ -170,7 +177,14 @@ def __call__(self, prompt: str, image: PIL.Image, **kwargs) -> List[PIL.Image]:
# Default to 2step
kwargs["num_inference_steps"] = 2

return self.ldm(prompt, image=image, **kwargs).images
output = self.ldm(prompt, image=image, **kwargs)

if safety_check:
_, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images)
else:
has_nsfw_concept = [None] * len(output.images)

return output.images, has_nsfw_concept

def __str__(self) -> str:
return f"ImageToImagePipeline model_id={self.model_id}"
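
For local experimentation, a rough sketch of calling the updated pipeline directly; the model id and image path are illustrative and assume the weights already exist in the model directory, and `SAFETY_CHECKER_DEVICE` is read when the pipeline is constructed:

    import os
    from PIL import Image
    from app.pipelines.image_to_image import ImageToImagePipeline

    os.environ["SAFETY_CHECKER_DEVICE"] = "cpu"  # optional override; defaults to "cuda"

    pipeline = ImageToImagePipeline("runwayml/stable-diffusion-v1-5")  # example model id
    images, has_nsfw_concept = pipeline(
        "a mountain lion", image=Image.open("images/test.png"), safety_check=True
    )
    for img, flagged in zip(images, has_nsfw_concept):
        print(flagged)  # True when the safety checker flags the image as NSFW

The same tuple unpacking applies to TextToImagePipeline below.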
37 changes: 26 additions & 11 deletions runner/app/pipelines/text_to_image.py
@@ -1,19 +1,20 @@
from app.pipelines.base import Pipeline
from app.pipelines.util import get_torch_device, get_model_dir
import logging
import os
from typing import List, Tuple, Optional

import PIL
import torch
from diffusers import (
AutoPipelineForText2Image,
EulerDiscreteScheduler,
StableDiffusionXLPipeline,
UNet2DConditionModel,
EulerDiscreteScheduler,
)
from safetensors.torch import load_file
from huggingface_hub import file_download, hf_hub_download
import torch
import PIL
from typing import List
import logging
import os
from safetensors.torch import load_file

from app.pipelines.base import Pipeline
from app.pipelines.util import get_model_dir, get_torch_device, SafetyChecker

logger = logging.getLogger(__name__)

@@ -147,7 +148,14 @@ def __init__(self, model_id: str):

self.ldm = enable_deepcache(self.ldm)

def __call__(self, prompt: str, **kwargs) -> List[PIL.Image]:
safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower()
self._safety_checker = SafetyChecker(device=safety_checker_device)

def __call__(
self, prompt: str, **kwargs
) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
safety_check = kwargs.pop("safety_check", True)

seed = kwargs.pop("seed", None)
if seed is not None:
if isinstance(seed, int):
@@ -184,7 +192,14 @@ def __call__(self, prompt: str, **kwargs) -> List[PIL.Image]:
# Default to 2step
kwargs["num_inference_steps"] = 2

return self.ldm(prompt, **kwargs).images
output = self.ldm(prompt, **kwargs)

if safety_check:
_, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images)
else:
has_nsfw_concept = [None] * len(output.images)

return output.images, has_nsfw_concept

def __str__(self) -> str:
return f"TextToImagePipeline model_id={self.model_id}"
88 changes: 88 additions & 0 deletions runner/app/pipelines/util.py
@@ -1,6 +1,15 @@
import torch
import os
import numpy as np
from torch import dtype as TorchDtype
from pathlib import Path
from PIL import Image
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from transformers import CLIPFeatureExtractor
from typing import Optional
import logging

logger = logging.getLogger(__name__)


def get_model_dir() -> Path:
@@ -18,3 +27,82 @@ def get_torch_device():
        return torch.device("mps")
    else:
        return torch.device("cpu")


def validate_torch_device(device_name: str) -> bool:
    """Checks if the given PyTorch device name is valid and available.
    Args:
        device_name: Name of the device ('cuda:0', 'cuda', 'cpu').
    Returns:
        True if valid and available, False otherwise.
    """
    try:
        device = torch.device(device_name)
        if device.type == "cuda":
            # Check if CUDA is available and the specified index is within range
            if device.index is None:
                return torch.cuda.is_available()
            else:
                return device.index < torch.cuda.device_count()
        return True
    except RuntimeError:
        return False


class SafetyChecker:
    """Checks images for unsafe or inappropriate content using a pretrained model.
    Attributes:
        device (str): Device for inference.
    """

    def __init__(
        self,
        device: Optional[str] = "cuda",
        dtype: Optional[TorchDtype] = torch.float16,
    ):
        """Initializes the SafetyChecker.
        Args:
            device: Device for inference. Defaults to "cuda".
            dtype: Data type for inference. Defaults to `torch.float16`.
        """
        device = device.lower() if device else device
        if not validate_torch_device(device):
            default_device = get_torch_device()
            logger.warning(
                f"Device '{device}' not found. Defaulting to '{default_device}'."
            )
            device = default_device

        self.device = device
        self._dtype = dtype
        self._safety_checker = StableDiffusionSafetyChecker.from_pretrained(
            "CompVis/stable-diffusion-safety-checker"
        ).to(self.device)
        self._feature_extractor = CLIPFeatureExtractor.from_pretrained(
            "openai/clip-vit-base-patch32"
        )

    def check_nsfw_images(
        self, images: list[Image.Image]
    ) -> tuple[list[Image.Image], list[bool]]:
        """Checks images for unsafe content.
        Args:
            images: Images to check.
        Returns:
            Tuple of images and corresponding NSFW flags.
        """
        safety_checker_input = self._feature_extractor(images, return_tensors="pt").to(
            self.device
        )
        images_np = [np.array(img) for img in images]
        _, has_nsfw_concept = self._safety_checker(
            images=images_np,
            clip_input=safety_checker_input.pixel_values.to(self._dtype),
        )
        return images, has_nsfw_concept
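
A standalone sketch of how the pipelines invoke the new helper; the image path is illustrative and the checker weights are assumed to be cached locally:

    from PIL import Image
    from app.pipelines.util import SafetyChecker

    checker = SafetyChecker(device="cuda")  # falls back to the default device if "cuda" is unavailable
    images = [Image.open("images/test.png")]  # example input
    _, has_nsfw_concept = checker.check_nsfw_images(images)
    if any(has_nsfw_concept):
        print("At least one image was flagged as NSFW.")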
15 changes: 11 additions & 4 deletions runner/app/routes/image_to_image.py
@@ -21,7 +21,7 @@
responses = {400: {"model": HTTPError}, 500: {"model": HTTPError}}


# TODO: Make model_id and other properties optional once Go codegen tool supports
# TODO: Make model_id and other None properties optional once Go codegen tool supports
# OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
@router.post("/image-to-image", response_model=ImageResponse, responses=responses)
@router.post(
@@ -37,6 +37,7 @@ async def image_to_image(
strength: Annotated[float, Form()] = 0.8,
guidance_scale: Annotated[float, Form()] = 7.5,
negative_prompt: Annotated[str, Form()] = "",
safety_check: Annotated[bool, Form()] = True,
seed: Annotated[int, Form()] = None,
num_images_per_prompt: Annotated[int, Form()] = 1,
pipeline: Pipeline = Depends(get_pipeline),
@@ -76,12 +77,13 @@
image = img

try:
images = pipeline(
images, has_nsfw_concept = pipeline(
prompt=prompt,
image=image,
strength=strength,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
safety_check=safety_check,
seed=seed,
num_images_per_prompt=num_images_per_prompt,
)
@@ -97,7 +99,12 @@
seeds = [seeds]

output_images = []
for img, s in zip(images, seeds):
output_images.append({"url": image_to_data_url(img), "seed": s})
for img, sd, is_nsfw in zip(images, seeds, has_nsfw_concept):
# TODO: Return None once Go codegen tool supports optional properties
# OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
is_nsfw = is_nsfw or False
output_images.append(
{"url": image_to_data_url(img), "seed": sd, "nsfw": is_nsfw}
)

return {"images": output_images}
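
Similarly, a hedged sketch of an image-to-image request (multipart form data, since this route takes an uploaded image); host and file path are illustrative:

    import requests

    with open("images/test.png", "rb") as f:
        resp = requests.post(
            "http://localhost:8000/image-to-image",  # assumed local runner address
            data={"prompt": "a mountain lion", "safety_check": "true"},
            files={"image": ("test.png", f, "image/png")},
        )
    for image in resp.json()["images"]:
        print(image["seed"], image["nsfw"])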
2 changes: 1 addition & 1 deletion runner/app/routes/image_to_video.py
@@ -21,7 +21,7 @@
responses = {400: {"model": HTTPError}, 500: {"model": HTTPError}}


# TODO: Make model_id and other properties optional once Go codegen tool supports
# TODO: Make model_id and other None properties optional once Go codegen tool supports
# OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
@router.post("/image-to-video", response_model=VideoResponse, responses=responses)
@router.post(
16 changes: 11 additions & 5 deletions runner/app/routes/text_to_image.py
@@ -15,14 +15,15 @@


class TextToImageParams(BaseModel):
# TODO: Make model_id and other properties optional once Go codegen tool supports
# OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
# TODO: Make model_id and other None properties optional once Go codegen tool
# supports OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
model_id: str = ""
prompt: str
height: int = None
width: int = None
guidance_scale: float = 7.5
negative_prompt: str = ""
safety_check: bool = True
seed: int = None
num_inference_steps: int = 50 # TODO: Make optional.
num_images_per_prompt: int = 1
@@ -64,7 +65,7 @@ async def text_to_image(
]

try:
images = pipeline(**params.model_dump())
images, has_nsfw_concept = pipeline(**params.model_dump())
except Exception as e:
logger.error(f"TextToImagePipeline error: {e}")
logger.exception(e)
@@ -77,7 +78,12 @@
seeds = [seeds]

output_images = []
for img, sd in zip(images, seeds):
output_images.append({"url": image_to_data_url(img), "seed": sd})
for img, sd, is_nsfw in zip(images, seeds, has_nsfw_concept):
# TODO: Return None once Go codegen tool supports optional properties
# OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
is_nsfw = is_nsfw or False
output_images.append(
{"url": image_to_data_url(img), "seed": sd, "nsfw": is_nsfw}
)

return {"images": output_images}
8 changes: 6 additions & 2 deletions runner/app/routes/util.py
@@ -1,13 +1,17 @@
import base64
import io
from typing import List

from PIL import Image
import base64
from pydantic import BaseModel
from typing import List


class Media(BaseModel):
url: str
seed: int
# TODO: Make nsfw property optional once Go codegen tool supports
# OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
nsfw: bool


class ImageResponse(BaseModel):
11 changes: 7 additions & 4 deletions runner/bench.py
@@ -1,16 +1,17 @@
import argparse
import os
from time import time
from typing import List

import numpy as np
import torch
from PIL import Image
from app.main import load_pipeline
from app.pipelines.base import Pipeline
from app.pipelines.text_to_image import TextToImagePipeline
from app.pipelines.image_to_image import ImageToImagePipeline
from app.pipelines.image_to_video import ImageToVideoPipeline
from app.pipelines.text_to_image import TextToImagePipeline
from PIL import Image
from pydantic import BaseModel
import os
import numpy as np

PROMPT = "a mountain lion"
IMAGE = "images/test.png"
@@ -47,6 +48,8 @@ def bench_pipeline(pipeline: Pipeline, batch_size=1, runs=1) -> BenchMetrics:
for i in range(runs):
start = time()
output = call_pipeline(pipeline, batch_size)
if isinstance(output, tuple):
output = output[0]
assert len(output) == batch_size

inference_time[i] = time() - start
