Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Gaussian Mixture based adaptive threshold #1051

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions src/anomalib/models/cfa/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def __init__(self, hparams: DictConfig | ListConfig) -> None:
)
self.hparams: DictConfig | ListConfig # type: ignore
self.save_hyperparameters(hparams)
self.image_threshold, self.pixel_threshold = self.configure_thresholds(hparams.metrics.threshold)

def configure_callbacks(self) -> list[Callback]:
"""Configure model-specific callbacks.
Expand Down
1 change: 1 addition & 0 deletions src/anomalib/models/cflow/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def __init__(self, hparams: DictConfig | ListConfig) -> None:
)
self.hparams: DictConfig | ListConfig # type: ignore
self.save_hyperparameters(hparams)
self.image_threshold, self.pixel_threshold = self.configure_thresholds(hparams.metrics.threshold)

def configure_callbacks(self) -> list[EarlyStopping]:
"""Configure model-specific callbacks.
Expand Down
45 changes: 41 additions & 4 deletions src/anomalib/models/components/base/anomaly_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,18 @@

import pytorch_lightning as pl
import torch
from omegaconf import DictConfig
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
from torch import Tensor, nn
from torchmetrics import Metric

from anomalib.data.utils import boxes_to_anomaly_maps, boxes_to_masks, masks_to_boxes
from anomalib.post_processing import ThresholdMethod
from anomalib.post_processing import ADAPTIVE_THRESHOLD_METHOD_MAP, ThresholdMethod
from anomalib.utils.metrics import (
AnomalibMetricCollection,
AnomalyScoreDistribution,
AnomalyScoreGaussianMixtureThreshold,
AnomalyScoreThreshold,
MinMax,
)
Expand All @@ -45,8 +47,8 @@ def __init__(self) -> None:
self.callbacks: list[Callback]

self.threshold_method: ThresholdMethod
self.image_threshold = AnomalyScoreThreshold().cpu()
self.pixel_threshold = AnomalyScoreThreshold().cpu()
self.image_threshold: AnomalyScoreThreshold
self.pixel_threshold: AnomalyScoreThreshold

self.normalization_metrics: Metric

Expand Down Expand Up @@ -141,7 +143,7 @@ def validation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
Args:
outputs: Batch of outputs from the validation step
"""
if self.threshold_method == ThresholdMethod.ADAPTIVE:
if self.threshold_method in ADAPTIVE_THRESHOLD_METHOD_MAP:
self._compute_adaptive_threshold(outputs)
self._collect_outputs(self.image_metrics, self.pixel_metrics, outputs)
self._log_metrics()
Expand Down Expand Up @@ -242,3 +244,38 @@ def load_state_dict(self, state_dict: OrderedDict[str, Tensor], strict: bool = T
# Used to load missing normalization and threshold parameters
self._load_normalization_class(state_dict)
return super().load_state_dict(state_dict, strict=strict)

def configure_thresholds(self, threshold_config: DictConfig) -> list[AnomalyScoreThreshold]:
"""Configure image and pixel thresholds that determine the prediction labels given scores.

Args:
threshold_config (DictConfig): The configuration of threshold.

Returns:
The image threshold and pixel threshold.
"""
image_threshold: AnomalyScoreThreshold
pixel_threshold: AnomalyScoreThreshold
if threshold_config.method == ThresholdMethod.GAUSSIAN_MIXTURE:
image_anomalous_rate = threshold_config.get(
"image_anomalous_rate", AnomalyScoreGaussianMixtureThreshold.DEFAULT_ANOMALOUS_RATE
)
pixel_anomalous_rate = threshold_config.get(
"pixel_anomalous_rate", AnomalyScoreGaussianMixtureThreshold.DEFAULT_ANOMALOUS_RATE
)
image_n_components = threshold_config.get(
"image_n_components", AnomalyScoreGaussianMixtureThreshold.DEFAULT_N_COMPONENTS
)
pixel_n_components = threshold_config.get(
"pixel_n_components", AnomalyScoreGaussianMixtureThreshold.DEFAULT_N_COMPONENTS
)
image_threshold = AnomalyScoreGaussianMixtureThreshold(
anomalous_rate=image_anomalous_rate, n_components=image_n_components
).cpu()
pixel_threshold = AnomalyScoreGaussianMixtureThreshold(
anomalous_rate=pixel_anomalous_rate, n_components=pixel_n_components
).cpu()
else:
image_threshold = AnomalyScoreThreshold().cpu()
pixel_threshold = AnomalyScoreThreshold().cpu()
return [image_threshold, pixel_threshold]
1 change: 1 addition & 0 deletions src/anomalib/models/csflow/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def __init__(self, hparams: DictConfig | ListConfig) -> None:
)
self.hparams: DictConfig | ListConfig # type: ignore
self.save_hyperparameters(hparams)
self.image_threshold, self.pixel_threshold = self.configure_thresholds(hparams.metrics.threshold)

def configure_callbacks(self) -> list[Callback]:
"""Configure model-specific callbacks.
Expand Down
1 change: 1 addition & 0 deletions src/anomalib/models/dfkde/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,4 @@ def __init__(self, hparams: DictConfig | ListConfig) -> None:
)
self.hparams: DictConfig | ListConfig # type: ignore
self.save_hyperparameters(hparams)
self.image_threshold, self.pixel_threshold = self.configure_thresholds(hparams.metrics.threshold)
1 change: 1 addition & 0 deletions src/anomalib/models/dfm/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,4 @@ def __init__(self, hparams: DictConfig | ListConfig) -> None:
)
self.hparams: DictConfig | ListConfig # type: ignore
self.save_hyperparameters(hparams)
self.image_threshold, self.pixel_threshold = self.configure_thresholds(hparams.metrics.threshold)
1 change: 1 addition & 0 deletions src/anomalib/models/draem/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def __init__(self, hparams: DictConfig | ListConfig) -> None:
)
self.hparams: DictConfig | ListConfig # type: ignore
self.save_hyperparameters(hparams)
self.image_threshold, self.pixel_threshold = self.configure_thresholds(hparams.metrics.threshold)

def configure_callbacks(self) -> list[EarlyStopping]:
"""Configure model-specific callbacks.
Expand Down
1 change: 1 addition & 0 deletions src/anomalib/models/fastflow/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def __init__(self, hparams: DictConfig | ListConfig) -> None:
)
self.hparams: DictConfig | ListConfig # type: ignore
self.save_hyperparameters(hparams)
self.image_threshold, self.pixel_threshold = self.configure_thresholds(hparams.metrics.threshold)

def configure_callbacks(self) -> list[EarlyStopping]:
"""Configure model-specific callbacks.
Expand Down
1 change: 1 addition & 0 deletions src/anomalib/models/ganomaly/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def __init__(self, hparams: DictConfig | ListConfig) -> None:
)
self.hparams: DictConfig | ListConfig # type: ignore
self.save_hyperparameters(hparams)
self.image_threshold, self.pixel_threshold = self.configure_thresholds(hparams.metrics.threshold)

def configure_callbacks(self) -> list[Callback]:
"""Configure model-specific callbacks.
Expand Down
1 change: 1 addition & 0 deletions src/anomalib/models/padim/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,4 @@ def __init__(self, hparams: DictConfig | ListConfig) -> None:
)
self.hparams: DictConfig | ListConfig # type: ignore
self.save_hyperparameters(hparams)
self.image_threshold, self.pixel_threshold = self.configure_thresholds(hparams.metrics.threshold)
1 change: 1 addition & 0 deletions src/anomalib/models/patchcore/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,4 @@ def __init__(self, hparams) -> None:
)
self.hparams: DictConfig | ListConfig # type: ignore
self.save_hyperparameters(hparams)
self.image_threshold, self.pixel_threshold = self.configure_thresholds(hparams.metrics.threshold)
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def __init__(self, hparams: DictConfig | ListConfig) -> None:
)
self.hparams: DictConfig | ListConfig # type: ignore
self.save_hyperparameters(hparams)
self.image_threshold, self.pixel_threshold = self.configure_thresholds(hparams.metrics.threshold)

def configure_callbacks(self) -> list[EarlyStopping]:
"""Configure model-specific callbacks.
Expand Down
1 change: 1 addition & 0 deletions src/anomalib/models/rkde/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,4 @@ def __init__(self, hparams: DictConfig | ListConfig) -> None:
)
self.hparams: DictConfig | ListConfig # type: ignore
self.save_hyperparameters(hparams)
self.image_threshold, self.pixel_threshold = self.configure_thresholds(hparams.metrics.threshold)
1 change: 1 addition & 0 deletions src/anomalib/models/stfpm/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def __init__(self, hparams: DictConfig | ListConfig) -> None:
)
self.hparams: DictConfig | ListConfig # type: ignore
self.save_hyperparameters(hparams)
self.image_threshold, self.pixel_threshold = self.configure_thresholds(hparams.metrics.threshold)

def configure_callbacks(self) -> list[EarlyStopping]:
"""Configure model-specific callbacks.
Expand Down
2 changes: 2 additions & 0 deletions src/anomalib/post_processing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .normalization import NormalizationMethod
from .post_process import (
ADAPTIVE_THRESHOLD_METHOD_MAP,
ThresholdMethod,
add_anomalous_label,
add_normal_label,
Expand All @@ -24,4 +25,5 @@
"NormalizationMethod",
"Visualizer",
"ThresholdMethod",
"ADAPTIVE_THRESHOLD_METHOD_MAP",
]
9 changes: 9 additions & 0 deletions src/anomalib/post_processing/post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,23 @@
import numpy as np
from skimage import morphology

from anomalib.utils.metrics import AnomalyScoreGaussianMixtureThreshold, AnomalyScoreThreshold


class ThresholdMethod(str, Enum):
"""Threshold method to apply post-processing to the output predictions."""

ADAPTIVE = "adaptive"
GAUSSIAN_MIXTURE = "gaussian_mixture"
MANUAL = "manual"


ADAPTIVE_THRESHOLD_METHOD_MAP = {
ThresholdMethod.ADAPTIVE: AnomalyScoreThreshold,
ThresholdMethod.GAUSSIAN_MIXTURE: AnomalyScoreGaussianMixtureThreshold,
}


def add_label(
image: np.ndarray,
label_name: str,
Expand Down
26 changes: 26 additions & 0 deletions src/anomalib/utils/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pytorch_lightning.callbacks import Callback, ModelCheckpoint

from anomalib.deploy import ExportMode
from anomalib.utils.metrics import AnomalyScoreGaussianMixtureThreshold

from .cdf_normalization import CdfNormalizationCallback
from .graph import GraphLogger
Expand Down Expand Up @@ -81,10 +82,35 @@ def get_callbacks(config: DictConfig | ListConfig) -> list[Callback]:
pixel_threshold = (
config.metrics.threshold.manual_pixel if "manual_pixel" in config.metrics.threshold.keys() else None
)
# For Gaussian Mixture Estimation of threshold.
image_anomalous_rate = (
config.metrics.threshold.image_anomalous_rate
if "image_anomalous_rate" in config.metrics.threshold.keys()
else AnomalyScoreGaussianMixtureThreshold.DEFAULT_ANOMALOUS_RATE
)
pixel_anomalous_rate = (
config.metrics.threshold.pixel_anomalous_rate
if "pixel_anomalous_rate" in config.metrics.threshold.keys()
else AnomalyScoreGaussianMixtureThreshold.DEFAULT_ANOMALOUS_RATE
)
image_n_components = (
config.metrics.threshold.image_n_components
if "image_n_components" in config.metrics.threshold.keys()
else AnomalyScoreGaussianMixtureThreshold.DEFAULT_N_COMPONENTS
)
pixel_n_components = (
config.metrics.threshold.pixel_n_components
if "pixel_n_components" in config.metrics.threshold.keys()
else AnomalyScoreGaussianMixtureThreshold.DEFAULT_N_COMPONENTS
)
post_processing_callback = PostProcessingConfigurationCallback(
threshold_method=config.metrics.threshold.method,
manual_image_threshold=image_threshold,
manual_pixel_threshold=pixel_threshold,
image_anomalous_rate=image_anomalous_rate,
pixel_anomalous_rate=pixel_anomalous_rate,
image_n_components=image_n_components,
pixel_n_components=pixel_n_components,
)
callbacks.append(post_processing_callback)

Expand Down
24 changes: 22 additions & 2 deletions src/anomalib/utils/callbacks/post_processing_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from pytorch_lightning import Callback, LightningModule, Trainer

from anomalib.models.components.base.anomaly_module import AnomalyModule
from anomalib.post_processing import NormalizationMethod, ThresholdMethod
from anomalib.post_processing import ADAPTIVE_THRESHOLD_METHOD_MAP, NormalizationMethod, ThresholdMethod
from anomalib.utils.metrics import AnomalyScoreGaussianMixtureThreshold

logger = logging.getLogger(__name__)

Expand All @@ -27,6 +28,10 @@ class PostProcessingConfigurationCallback(Callback):
threshold_method (ThresholdMethod): Flag indicating whether threshold should be manual or adaptive.
manual_image_threshold (float | None): Default manual image threshold value.
manual_pixel_threshold (float | None): Default manual pixel threshold value.
image_anomalous_rate (float | None): Anticipated image anomalous rate for Gaussian Mixture based threshold.
pixel_anomalous_rate (float | None): Anticipated pixel anomalous rate for Gaussian Mixture based threshold.
image_n_components (int): Number of mixture components of Gaussian Mixture based threshold for images.
pixel_n_components (int): Number of mixture components of Gaussian Mixture based threshold for pixels.
"""

def __init__(
Expand All @@ -35,11 +40,15 @@ def __init__(
threshold_method: ThresholdMethod = ThresholdMethod.ADAPTIVE,
manual_image_threshold: float | None = None,
manual_pixel_threshold: float | None = None,
image_anomalous_rate: float | None = AnomalyScoreGaussianMixtureThreshold.DEFAULT_ANOMALOUS_RATE,
pixel_anomalous_rate: float | None = AnomalyScoreGaussianMixtureThreshold.DEFAULT_ANOMALOUS_RATE,
image_n_components: int = AnomalyScoreGaussianMixtureThreshold.DEFAULT_N_COMPONENTS,
pixel_n_components: int = AnomalyScoreGaussianMixtureThreshold.DEFAULT_N_COMPONENTS,
) -> None:
super().__init__()
self.normalization_method = normalization_method

if threshold_method == ThresholdMethod.ADAPTIVE and all(
if threshold_method in ADAPTIVE_THRESHOLD_METHOD_MAP and all(
i is not None for i in (manual_image_threshold, manual_pixel_threshold)
):
raise ValueError(
Expand All @@ -58,6 +67,10 @@ def __init__(
self.threshold_method = threshold_method
self.manual_image_threshold = manual_image_threshold
self.manual_pixel_threshold = manual_pixel_threshold
self.image_anomalous_rate = image_anomalous_rate
self.pixel_anomalous_rate = pixel_anomalous_rate
self.image_n_components = image_n_components
self.pixel_n_components = pixel_n_components

def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str | None = None) -> None:
"""Setup post-processing configuration within Anomalib Model.
Expand All @@ -74,3 +87,10 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str | None
if pl_module.threshold_method == ThresholdMethod.MANUAL:
pl_module.image_threshold.value = torch.tensor(self.manual_image_threshold).cpu()
pl_module.pixel_threshold.value = torch.tensor(self.manual_pixel_threshold).cpu()
if pl_module.threshold_method == ThresholdMethod.GAUSSIAN_MIXTURE:
image_threshold: AnomalyScoreGaussianMixtureThreshold = pl_module.image_threshold
pixel_threshold: AnomalyScoreGaussianMixtureThreshold = pl_module.pixel_threshold
image_threshold.anomalous_rate = self.image_anomalous_rate
pixel_threshold.anomalous_rate = self.pixel_anomalous_rate
image_threshold.n_components = self.image_n_components
pixel_threshold.n_components = self.pixel_n_components
5 changes: 5 additions & 0 deletions src/anomalib/utils/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
add_visualizer_callback,
)
from anomalib.utils.loggers import configure_logger
from anomalib.utils.metrics import AnomalyScoreGaussianMixtureThreshold

logger = logging.getLogger("anomalib.cli")

Expand Down Expand Up @@ -68,6 +69,10 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
"post_processing.threshold_method": "adaptive",
"post_processing.manual_image_threshold": None,
"post_processing.manual_pixel_threshold": None,
"post_processing.image_anomalous_rate": AnomalyScoreGaussianMixtureThreshold.DEFAULT_ANOMALOUS_RATE,
"post_processing.pixel_anomalous_rate": AnomalyScoreGaussianMixtureThreshold.DEFAULT_ANOMALOUS_RATE,
"post_processing.image_n_components": AnomalyScoreGaussianMixtureThreshold.DEFAULT_N_COMPONENTS,
"post_processing.pixel_n_components": AnomalyScoreGaussianMixtureThreshold.DEFAULT_N_COMPONENTS,
}
)

Expand Down
14 changes: 12 additions & 2 deletions src/anomalib/utils/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from omegaconf import DictConfig, ListConfig

from .anomaly_score_distribution import AnomalyScoreDistribution
from .anomaly_score_threshold import AnomalyScoreThreshold
from .anomaly_score_threshold import AnomalyScoreGaussianMixtureThreshold, AnomalyScoreThreshold
from .aupr import AUPR
from .aupro import AUPRO
from .auroc import AUROC
Expand All @@ -22,7 +22,17 @@
from .optimal_f1 import OptimalF1
from .pro import PRO

__all__ = ["AUROC", "AUPR", "AUPRO", "OptimalF1", "AnomalyScoreThreshold", "AnomalyScoreDistribution", "MinMax", "PRO"]
__all__ = [
"AUROC",
"AUPR",
"AUPRO",
"OptimalF1",
"AnomalyScoreDistribution",
"MinMax",
"PRO",
"AnomalyScoreThreshold",
"AnomalyScoreGaussianMixtureThreshold",
]


def metric_collection_from_names(metric_names: list[str], prefix: str | None) -> AnomalibMetricCollection:
Expand Down