Skip to content

Commit

Permalink
fix threshold creation, doc, test
Browse files Browse the repository at this point in the history
Signed-off-by: Pan, Yujie <yujie.pan@intel.com>
  • Loading branch information
yujiepan-work committed Apr 27, 2023
1 parent 2e9599b commit d75524f
Show file tree
Hide file tree
Showing 12 changed files with 240 additions and 183 deletions.
44 changes: 40 additions & 4 deletions src/anomalib/models/components/base/anomaly_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,18 @@

import pytorch_lightning as pl
import torch
from omegaconf import DictConfig
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
from torch import Tensor, nn
from torchmetrics import Metric

from anomalib.data.utils import boxes_to_anomaly_maps, boxes_to_masks, masks_to_boxes
from anomalib.post_processing import ThresholdMethod, ADAPTIVE_THRESHOLD_METHOD_MAP
from anomalib.post_processing import ADAPTIVE_THRESHOLD_METHOD_MAP, ThresholdMethod
from anomalib.utils.metrics import (
AnomalibMetricCollection,
AnomalyScoreDistribution,
AnomalyScoreGaussianMixtureThreshold,
AnomalyScoreThreshold,
MinMax,
)
Expand All @@ -45,9 +47,8 @@ def __init__(self) -> None:
self.callbacks: list[Callback]

self.threshold_method: ThresholdMethod
threshold_cls = ADAPTIVE_THRESHOLD_METHOD_MAP.get(self.threshold_method, AnomalyScoreThreshold)
self.image_threshold = threshold_cls().cpu()
self.pixel_threshold = threshold_cls().cpu()
self.image_threshold: AnomalyScoreThreshold
self.pixel_threshold: AnomalyScoreThreshold

self.normalization_metrics: Metric

Expand Down Expand Up @@ -243,3 +244,38 @@ def load_state_dict(self, state_dict: OrderedDict[str, Tensor], strict: bool = T
# Used to load missing normalization and threshold parameters
self._load_normalization_class(state_dict)
return super().load_state_dict(state_dict, strict=strict)

def configure_thresholds(self, threshold_config: DictConfig) -> list[AnomalyScoreThreshold]:
"""Configure image and pixel thresholds that determine the prediction labels given scores.
Args:
threshold_config (DictConfig): The configuration of threshold.
Returns:
The image threshold and pixel threshold.
"""
image_threshold: AnomalyScoreThreshold
pixel_threshold: AnomalyScoreThreshold
if threshold_config.method == ThresholdMethod.GAUSSIAN_MIXTURE:
image_anomalous_rate = threshold_config.get(
"image_anomalous_rate", AnomalyScoreGaussianMixtureThreshold.DEFAULT_ANOMALOUS_RATE
)
pixel_anomalous_rate = threshold_config.get(
"pixel_anomalous_rate", AnomalyScoreGaussianMixtureThreshold.DEFAULT_ANOMALOUS_RATE
)
image_n_components = threshold_config.get(
"image_n_components", AnomalyScoreGaussianMixtureThreshold.DEFAULT_N_COMPONENTS
)
pixel_n_components = threshold_config.get(
"pixel_n_components", AnomalyScoreGaussianMixtureThreshold.DEFAULT_N_COMPONENTS
)
image_threshold = AnomalyScoreGaussianMixtureThreshold(
anomalous_rate=image_anomalous_rate, n_components=image_n_components
).cpu()
pixel_threshold = AnomalyScoreGaussianMixtureThreshold(
anomalous_rate=pixel_anomalous_rate, n_components=pixel_n_components
).cpu()
else:
image_threshold = AnomalyScoreThreshold().cpu()
pixel_threshold = AnomalyScoreThreshold().cpu()
return [image_threshold, pixel_threshold]
3 changes: 3 additions & 0 deletions src/anomalib/models/patchcore/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
input_size: tuple[int, int],
backbone: str,
layers: list[str],
threshold: DictConfig,
pre_trained: bool = True,
coreset_sampling_ratio: float = 0.1,
num_neighbors: int = 9,
Expand All @@ -54,6 +55,7 @@ def __init__(
)
self.coreset_sampling_ratio = coreset_sampling_ratio
self.embeddings: list[Tensor] = []
self.image_threshold, self.pixel_threshold = self.configure_thresholds(threshold)

def configure_optimizers(self) -> None:
"""Configure optimizers.
Expand Down Expand Up @@ -128,6 +130,7 @@ def __init__(self, hparams) -> None:
pre_trained=hparams.model.pre_trained,
coreset_sampling_ratio=hparams.model.coreset_sampling_ratio,
num_neighbors=hparams.model.num_neighbors,
threshold=hparams.metrics.threshold,
)
self.hparams: DictConfig | ListConfig # type: ignore
self.save_hyperparameters(hparams)
2 changes: 1 addition & 1 deletion src/anomalib/post_processing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@

from .normalization import NormalizationMethod
from .post_process import (
ADAPTIVE_THRESHOLD_METHOD_MAP,
ThresholdMethod,
add_anomalous_label,
add_normal_label,
anomaly_map_to_color_map,
compute_mask,
superimpose_anomaly_map,
ADAPTIVE_THRESHOLD_METHOD_MAP,
)
from .visualizer import ImageResult, Visualizer

Expand Down
5 changes: 2 additions & 3 deletions src/anomalib/post_processing/post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
import numpy as np
from skimage import morphology

from anomalib.utils.metrics import (AnomalyScoreThreshold,
GaussianMixtureThresholdEstimator)
from anomalib.utils.metrics import AnomalyScoreGaussianMixtureThreshold, AnomalyScoreThreshold


class ThresholdMethod(str, Enum):
Expand All @@ -27,7 +26,7 @@ class ThresholdMethod(str, Enum):

ADAPTIVE_THRESHOLD_METHOD_MAP = {
ThresholdMethod.ADAPTIVE: AnomalyScoreThreshold,
ThresholdMethod.GAUSSIAN_MIXTURE: GaussianMixtureThresholdEstimator,
ThresholdMethod.GAUSSIAN_MIXTURE: AnomalyScoreGaussianMixtureThreshold,
}


Expand Down
30 changes: 17 additions & 13 deletions src/anomalib/utils/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pytorch_lightning.callbacks import Callback, ModelCheckpoint

from anomalib.deploy import ExportMode
from anomalib.utils.metrics import GaussianMixtureThresholdEstimator
from anomalib.utils.metrics import AnomalyScoreGaussianMixtureThreshold

from .cdf_normalization import CdfNormalizationCallback
from .graph import GraphLogger
Expand Down Expand Up @@ -83,28 +83,32 @@ def get_callbacks(config: DictConfig | ListConfig) -> list[Callback]:
config.metrics.threshold.manual_pixel if "manual_pixel" in config.metrics.threshold.keys() else None
)
# For Gaussian Mixture Estimation of threshold.
image_positive_rate = (
config.metrics.threshold.image_positive_rate if "image_positive_rate" in config.metrics.threshold.keys() \
else GaussianMixtureThresholdEstimator.DEFAULT_POSITIVE_RATE
image_anomalous_rate = (
config.metrics.threshold.image_anomalous_rate
if "image_anomalous_rate" in config.metrics.threshold.keys()
else AnomalyScoreGaussianMixtureThreshold.DEFAULT_ANOMALOUS_RATE
)
pixel_positive_rate = (
config.metrics.threshold.pixel_positive_rate if "pixel_positive_rate" in config.metrics.threshold.keys() \
else GaussianMixtureThresholdEstimator.DEFAULT_POSITIVE_RATE
pixel_anomalous_rate = (
config.metrics.threshold.pixel_anomalous_rate
if "pixel_anomalous_rate" in config.metrics.threshold.keys()
else AnomalyScoreGaussianMixtureThreshold.DEFAULT_ANOMALOUS_RATE
)
image_n_components = (
config.metrics.threshold.image_n_components if "image_n_components" in config.metrics.threshold.keys() \
else GaussianMixtureThresholdEstimator.DEFAULT_N_COMPONENTS
config.metrics.threshold.image_n_components
if "image_n_components" in config.metrics.threshold.keys()
else AnomalyScoreGaussianMixtureThreshold.DEFAULT_N_COMPONENTS
)
pixel_n_components = (
config.metrics.threshold.pixel_n_components if "pixel_n_components" in config.metrics.threshold.keys() \
else GaussianMixtureThresholdEstimator.DEFAULT_N_COMPONENTS
config.metrics.threshold.pixel_n_components
if "pixel_n_components" in config.metrics.threshold.keys()
else AnomalyScoreGaussianMixtureThreshold.DEFAULT_N_COMPONENTS
)
post_processing_callback = PostProcessingConfigurationCallback(
threshold_method=config.metrics.threshold.method,
manual_image_threshold=image_threshold,
manual_pixel_threshold=pixel_threshold,
image_positive_rate=image_positive_rate,
pixel_positive_rate=pixel_positive_rate,
image_anomalous_rate=image_anomalous_rate,
pixel_anomalous_rate=pixel_anomalous_rate,
image_n_components=image_n_components,
pixel_n_components=pixel_n_components,
)
Expand Down
32 changes: 16 additions & 16 deletions src/anomalib/utils/callbacks/post_processing_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from pytorch_lightning import Callback, LightningModule, Trainer

from anomalib.models.components.base.anomaly_module import AnomalyModule
from anomalib.utils.metrics import GaussianMixtureThresholdEstimator
from anomalib.post_processing import NormalizationMethod, ThresholdMethod, ADAPTIVE_THRESHOLD_METHOD_MAP
from anomalib.post_processing import ADAPTIVE_THRESHOLD_METHOD_MAP, NormalizationMethod, ThresholdMethod
from anomalib.utils.metrics import AnomalyScoreGaussianMixtureThreshold

logger = logging.getLogger(__name__)

Expand All @@ -28,10 +28,10 @@ class PostProcessingConfigurationCallback(Callback):
threshold_method (ThresholdMethod): Flag indicating whether threshold should be manual or adaptive.
manual_image_threshold (float | None): Default manual image threshold value.
manual_pixel_threshold (float | None): Default manual pixel threshold value.
image_positive_rate (float | None): TODO(yujie),
pixel_positive_rate (float | None):
image_n_components (float | None):
pixel_n_components (float | None):
image_anomalous_rate (float | None): Anticipated image anomalous rate for Gaussian Mixture based threshold.
pixel_anomalous_rate (float | None): Anticipated pixel anomalous rate for Gaussian Mixture based threshold.
image_n_components (int): Number of mixture components of Gaussian Mixture based threshold for images.
pixel_n_components (int): Number of mixture components of Gaussian Mixture based threshold for pixels.
"""

def __init__(
Expand All @@ -40,10 +40,10 @@ def __init__(
threshold_method: ThresholdMethod = ThresholdMethod.ADAPTIVE,
manual_image_threshold: float | None = None,
manual_pixel_threshold: float | None = None,
image_positive_rate: float | None = GaussianMixtureThresholdEstimator.DEFAULT_POSITIVE_RATE,
pixel_positive_rate: float | None = GaussianMixtureThresholdEstimator.DEFAULT_POSITIVE_RATE,
image_n_components: int = GaussianMixtureThresholdEstimator.DEFAULT_N_COMPONENTS,
pixel_n_components: int = GaussianMixtureThresholdEstimator.DEFAULT_N_COMPONENTS,
image_anomalous_rate: float | None = AnomalyScoreGaussianMixtureThreshold.DEFAULT_ANOMALOUS_RATE,
pixel_anomalous_rate: float | None = AnomalyScoreGaussianMixtureThreshold.DEFAULT_ANOMALOUS_RATE,
image_n_components: int = AnomalyScoreGaussianMixtureThreshold.DEFAULT_N_COMPONENTS,
pixel_n_components: int = AnomalyScoreGaussianMixtureThreshold.DEFAULT_N_COMPONENTS,
) -> None:
super().__init__()
self.normalization_method = normalization_method
Expand All @@ -67,8 +67,8 @@ def __init__(
self.threshold_method = threshold_method
self.manual_image_threshold = manual_image_threshold
self.manual_pixel_threshold = manual_pixel_threshold
self.image_positive_rate = image_positive_rate
self.pixel_positive_rate = pixel_positive_rate
self.image_anomalous_rate = image_anomalous_rate
self.pixel_anomalous_rate = pixel_anomalous_rate
self.image_n_components = image_n_components
self.pixel_n_components = pixel_n_components

Expand All @@ -88,9 +88,9 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str | None
pl_module.image_threshold.value = torch.tensor(self.manual_image_threshold).cpu()
pl_module.pixel_threshold.value = torch.tensor(self.manual_pixel_threshold).cpu()
if pl_module.threshold_method == ThresholdMethod.GAUSSIAN_MIXTURE:
image_threshold: GaussianMixtureThresholdEstimator = pl_module.image_threshold
pixel_threshold: GaussianMixtureThresholdEstimator = pl_module.pixel_threshold
image_threshold.positive_rate = self.image_positive_rate
pixel_threshold.positive_rate = self.pixel_positive_rate
image_threshold: AnomalyScoreGaussianMixtureThreshold = pl_module.image_threshold
pixel_threshold: AnomalyScoreGaussianMixtureThreshold = pl_module.pixel_threshold
image_threshold.anomalous_rate = self.image_anomalous_rate
pixel_threshold.anomalous_rate = self.pixel_anomalous_rate
image_threshold.n_components = self.image_n_components
pixel_threshold.n_components = self.pixel_n_components
10 changes: 5 additions & 5 deletions src/anomalib/utils/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
add_visualizer_callback,
)
from anomalib.utils.loggers import configure_logger
from anomalib.utils.metrics import GaussianMixtureThresholdEstimator
from anomalib.utils.metrics import AnomalyScoreGaussianMixtureThreshold

logger = logging.getLogger("anomalib.cli")

Expand Down Expand Up @@ -69,10 +69,10 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
"post_processing.threshold_method": "adaptive",
"post_processing.manual_image_threshold": None,
"post_processing.manual_pixel_threshold": None,
"post_processing.image_positive_rate": GaussianMixtureThresholdEstimator.DEFAULT_POSITIVE_RATE,
"post_processing.pixel_positive_rate": GaussianMixtureThresholdEstimator.DEFAULT_POSITIVE_RATE,
"post_processing.image_n_components": GaussianMixtureThresholdEstimator.DEFAULT_N_COMPONENTS,
"post_processing.pixel_n_components": GaussianMixtureThresholdEstimator.DEFAULT_N_COMPONENTS,
"post_processing.image_anomalous_rate": AnomalyScoreGaussianMixtureThreshold.DEFAULT_ANOMALOUS_RATE,
"post_processing.pixel_anomalous_rate": AnomalyScoreGaussianMixtureThreshold.DEFAULT_ANOMALOUS_RATE,
"post_processing.image_n_components": AnomalyScoreGaussianMixtureThreshold.DEFAULT_N_COMPONENTS,
"post_processing.pixel_n_components": AnomalyScoreGaussianMixtureThreshold.DEFAULT_N_COMPONENTS,
}
)

Expand Down
16 changes: 12 additions & 4 deletions src/anomalib/utils/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
from omegaconf import DictConfig, ListConfig

from .anomaly_score_distribution import AnomalyScoreDistribution
from .anomaly_score_threshold import AnomalyScoreThreshold
from .anomaly_score_threshold_estimator import GaussianMixtureThresholdEstimator
from .anomaly_score_threshold import AnomalyScoreGaussianMixtureThreshold, AnomalyScoreThreshold
from .aupr import AUPR
from .aupro import AUPRO
from .auroc import AUROC
Expand All @@ -23,8 +22,17 @@
from .optimal_f1 import OptimalF1
from .pro import PRO

__all__ = ["AUROC", "AUPR", "AUPRO", "OptimalF1", "AnomalyScoreDistribution", "MinMax", "PRO",
"AnomalyScoreThreshold", "GaussianMixtureThresholdEstimator"]
__all__ = [
"AUROC",
"AUPR",
"AUPRO",
"OptimalF1",
"AnomalyScoreDistribution",
"MinMax",
"PRO",
"AnomalyScoreThreshold",
"AnomalyScoreGaussianMixtureThreshold",
]


def metric_collection_from_names(metric_names: list[str], prefix: str | None) -> AnomalibMetricCollection:
Expand Down

0 comments on commit d75524f

Please sign in to comment.