Skip to content

Commit

Permalink
move fmm into AnomalyScoreThreshold file
Browse files Browse the repository at this point in the history
  • Loading branch information
yujiepan-work committed Apr 27, 2023
1 parent 0d31112 commit 6c3b47d
Show file tree
Hide file tree
Showing 10 changed files with 107 additions and 147 deletions.
14 changes: 7 additions & 7 deletions src/anomalib/models/components/base/anomaly_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from anomalib.utils.metrics import (
AnomalibMetricCollection,
AnomalyScoreDistribution,
AnomalyScoreGaussianMixtureThreshold,
AnomalyScoreThreshold,
GaussianMixtureThresholdEstimator,
MinMax,
)

Expand Down Expand Up @@ -251,21 +251,21 @@ def configure_thresholds(self, threshold_config: DictConfig) -> list[AnomalyScor
pixel_threshold: AnomalyScoreThreshold
if threshold_config.method == ThresholdMethod.GAUSSIAN_MIXTURE:
image_positive_rate = threshold_config.get(
"image_positive_rate", GaussianMixtureThresholdEstimator.DEFAULT_POSITIVE_RATE
"image_positive_rate", AnomalyScoreGaussianMixtureThreshold.DEFAULT_POSITIVE_RATE
)
pixel_positive_rate = threshold_config.get(
"pixel_positive_rate", GaussianMixtureThresholdEstimator.DEFAULT_POSITIVE_RATE
"pixel_positive_rate", AnomalyScoreGaussianMixtureThreshold.DEFAULT_POSITIVE_RATE
)
image_n_components = threshold_config.get(
"image_n_components", GaussianMixtureThresholdEstimator.DEFAULT_N_COMPONENTS
"image_n_components", AnomalyScoreGaussianMixtureThreshold.DEFAULT_N_COMPONENTS
)
pixel_n_components = threshold_config.get(
"pixel_n_components", GaussianMixtureThresholdEstimator.DEFAULT_N_COMPONENTS
"pixel_n_components", AnomalyScoreGaussianMixtureThreshold.DEFAULT_N_COMPONENTS
)
image_threshold = GaussianMixtureThresholdEstimator(
image_threshold = AnomalyScoreGaussianMixtureThreshold(
positive_rate=image_positive_rate, n_components=image_n_components
).cpu()
pixel_threshold = GaussianMixtureThresholdEstimator(
pixel_threshold = AnomalyScoreGaussianMixtureThreshold(
positive_rate=pixel_positive_rate, n_components=pixel_n_components
).cpu()
else:
Expand Down
4 changes: 2 additions & 2 deletions src/anomalib/post_processing/post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpy as np
from skimage import morphology

from anomalib.utils.metrics import AnomalyScoreThreshold, GaussianMixtureThresholdEstimator
from anomalib.utils.metrics import AnomalyScoreGaussianMixtureThreshold, AnomalyScoreThreshold


class ThresholdMethod(str, Enum):
Expand All @@ -26,7 +26,7 @@ class ThresholdMethod(str, Enum):

ADAPTIVE_THRESHOLD_METHOD_MAP = {
ThresholdMethod.ADAPTIVE: AnomalyScoreThreshold,
ThresholdMethod.GAUSSIAN_MIXTURE: GaussianMixtureThresholdEstimator,
ThresholdMethod.GAUSSIAN_MIXTURE: AnomalyScoreGaussianMixtureThreshold,
}


Expand Down
10 changes: 5 additions & 5 deletions src/anomalib/utils/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pytorch_lightning.callbacks import Callback, ModelCheckpoint

from anomalib.deploy import ExportMode
from anomalib.utils.metrics import GaussianMixtureThresholdEstimator
from anomalib.utils.metrics import AnomalyScoreGaussianMixtureThreshold

from .cdf_normalization import CdfNormalizationCallback
from .graph import GraphLogger
Expand Down Expand Up @@ -86,22 +86,22 @@ def get_callbacks(config: DictConfig | ListConfig) -> list[Callback]:
image_positive_rate = (
config.metrics.threshold.image_positive_rate
if "image_positive_rate" in config.metrics.threshold.keys()
else GaussianMixtureThresholdEstimator.DEFAULT_POSITIVE_RATE
else AnomalyScoreGaussianMixtureThreshold.DEFAULT_POSITIVE_RATE
)
pixel_positive_rate = (
config.metrics.threshold.pixel_positive_rate
if "pixel_positive_rate" in config.metrics.threshold.keys()
else GaussianMixtureThresholdEstimator.DEFAULT_POSITIVE_RATE
else AnomalyScoreGaussianMixtureThreshold.DEFAULT_POSITIVE_RATE
)
image_n_components = (
config.metrics.threshold.image_n_components
if "image_n_components" in config.metrics.threshold.keys()
else GaussianMixtureThresholdEstimator.DEFAULT_N_COMPONENTS
else AnomalyScoreGaussianMixtureThreshold.DEFAULT_N_COMPONENTS
)
pixel_n_components = (
config.metrics.threshold.pixel_n_components
if "pixel_n_components" in config.metrics.threshold.keys()
else GaussianMixtureThresholdEstimator.DEFAULT_N_COMPONENTS
else AnomalyScoreGaussianMixtureThreshold.DEFAULT_N_COMPONENTS
)
post_processing_callback = PostProcessingConfigurationCallback(
threshold_method=config.metrics.threshold.method,
Expand Down
14 changes: 7 additions & 7 deletions src/anomalib/utils/callbacks/post_processing_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from anomalib.models.components.base.anomaly_module import AnomalyModule
from anomalib.post_processing import ADAPTIVE_THRESHOLD_METHOD_MAP, NormalizationMethod, ThresholdMethod
from anomalib.utils.metrics import GaussianMixtureThresholdEstimator
from anomalib.utils.metrics import AnomalyScoreGaussianMixtureThreshold

logger = logging.getLogger(__name__)

Expand All @@ -40,10 +40,10 @@ def __init__(
threshold_method: ThresholdMethod = ThresholdMethod.ADAPTIVE,
manual_image_threshold: float | None = None,
manual_pixel_threshold: float | None = None,
image_positive_rate: float | None = GaussianMixtureThresholdEstimator.DEFAULT_POSITIVE_RATE,
pixel_positive_rate: float | None = GaussianMixtureThresholdEstimator.DEFAULT_POSITIVE_RATE,
image_n_components: int = GaussianMixtureThresholdEstimator.DEFAULT_N_COMPONENTS,
pixel_n_components: int = GaussianMixtureThresholdEstimator.DEFAULT_N_COMPONENTS,
image_positive_rate: float | None = AnomalyScoreGaussianMixtureThreshold.DEFAULT_POSITIVE_RATE,
pixel_positive_rate: float | None = AnomalyScoreGaussianMixtureThreshold.DEFAULT_POSITIVE_RATE,
image_n_components: int = AnomalyScoreGaussianMixtureThreshold.DEFAULT_N_COMPONENTS,
pixel_n_components: int = AnomalyScoreGaussianMixtureThreshold.DEFAULT_N_COMPONENTS,
) -> None:
super().__init__()
self.normalization_method = normalization_method
Expand Down Expand Up @@ -88,8 +88,8 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str | None
pl_module.image_threshold.value = torch.tensor(self.manual_image_threshold).cpu()
pl_module.pixel_threshold.value = torch.tensor(self.manual_pixel_threshold).cpu()
if pl_module.threshold_method == ThresholdMethod.GAUSSIAN_MIXTURE:
image_threshold: GaussianMixtureThresholdEstimator = pl_module.image_threshold
pixel_threshold: GaussianMixtureThresholdEstimator = pl_module.pixel_threshold
image_threshold: AnomalyScoreGaussianMixtureThreshold = pl_module.image_threshold
pixel_threshold: AnomalyScoreGaussianMixtureThreshold = pl_module.pixel_threshold
image_threshold.positive_rate = self.image_positive_rate
pixel_threshold.positive_rate = self.pixel_positive_rate
image_threshold.n_components = self.image_n_components
Expand Down
10 changes: 5 additions & 5 deletions src/anomalib/utils/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
add_visualizer_callback,
)
from anomalib.utils.loggers import configure_logger
from anomalib.utils.metrics import GaussianMixtureThresholdEstimator
from anomalib.utils.metrics import AnomalyScoreGaussianMixtureThreshold

logger = logging.getLogger("anomalib.cli")

Expand Down Expand Up @@ -69,10 +69,10 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
"post_processing.threshold_method": "adaptive",
"post_processing.manual_image_threshold": None,
"post_processing.manual_pixel_threshold": None,
"post_processing.image_positive_rate": GaussianMixtureThresholdEstimator.DEFAULT_POSITIVE_RATE,
"post_processing.pixel_positive_rate": GaussianMixtureThresholdEstimator.DEFAULT_POSITIVE_RATE,
"post_processing.image_n_components": GaussianMixtureThresholdEstimator.DEFAULT_N_COMPONENTS,
"post_processing.pixel_n_components": GaussianMixtureThresholdEstimator.DEFAULT_N_COMPONENTS,
"post_processing.image_positive_rate": AnomalyScoreGaussianMixtureThreshold.DEFAULT_POSITIVE_RATE,
"post_processing.pixel_positive_rate": AnomalyScoreGaussianMixtureThreshold.DEFAULT_POSITIVE_RATE,
"post_processing.image_n_components": AnomalyScoreGaussianMixtureThreshold.DEFAULT_N_COMPONENTS,
"post_processing.pixel_n_components": AnomalyScoreGaussianMixtureThreshold.DEFAULT_N_COMPONENTS,
}
)

Expand Down
5 changes: 2 additions & 3 deletions src/anomalib/utils/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
from omegaconf import DictConfig, ListConfig

from .anomaly_score_distribution import AnomalyScoreDistribution
from .anomaly_score_threshold import AnomalyScoreThreshold
from .anomaly_score_threshold_estimator import GaussianMixtureThresholdEstimator
from .anomaly_score_threshold import AnomalyScoreGaussianMixtureThreshold, AnomalyScoreThreshold
from .aupr import AUPR
from .aupro import AUPRO
from .auroc import AUROC
Expand All @@ -32,7 +31,7 @@
"MinMax",
"PRO",
"AnomalyScoreThreshold",
"GaussianMixtureThresholdEstimator",
"AnomalyScoreGaussianMixtureThreshold",
]


Expand Down
68 changes: 68 additions & 0 deletions src/anomalib/utils/metrics/anomaly_score_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@

import warnings

import numpy as np
import scipy.stats
import torch
from sklearn.mixture import GaussianMixture
from torch import Tensor
from torchmetrics import PrecisionRecallCurve
from torchmetrics.utilities.data import dim_zero_cat


class AnomalyScoreThreshold(PrecisionRecallCurve):
Expand Down Expand Up @@ -60,3 +64,67 @@ def compute(self) -> Tensor:
else:
self.value = thresholds[torch.argmax(f1_score)]
return self.value


class AnomalyScoreGaussianMixtureThreshold(AnomalyScoreThreshold):
DEFAULT_POSITIVE_RATE = None
DEFAULT_N_COMPONENTS = 1
N_CANDIDATE_THRESHOLDS = 10**5

def __init__(
self,
default_value: float = 0.5,
positive_rate: float | None = DEFAULT_POSITIVE_RATE,
n_components: int = DEFAULT_N_COMPONENTS,
**kwargs,
) -> None:
super().__init__(default_value=default_value, **kwargs)
assert positive_rate is None or 0.0 < positive_rate < 1.0, "Estimated positive rate should be in range (0, 1)."
self.positive_rate = positive_rate
self.n_components = n_components
self.kwargs = kwargs

def compute(self) -> Tensor:
preds = dim_zero_cat(self.preds)
target = dim_zero_cat(self.target)
if not self._is_sufficient_dataset(target):
warnings.warn(
"The validation set contains too few anomalous or normal images to conduct a Gaussian "
"Mixture estimator. Falling back to Adaptive Threshold without density estimation."
)
return super().compute()

thresholds = self._get_sorted_candidate_thresholds(preds)
positive_cdf = self._compute_estimated_cdf(preds[target == 1], thresholds)
negative_cdf = self._compute_estimated_cdf(preds[target == 0], thresholds)
positive_rate = self.positive_rate or float((target == 1).float().mean())
f1_scores = self._compute_f1_scores(negative_cdf, positive_cdf, 1.0 - positive_rate, positive_rate)
self.value = thresholds[torch.argmax(f1_scores)]
return self.value

def _is_sufficient_dataset(self, target: Tensor) -> bool:
min_samples = max(2, self.n_components)
return bool((target == 0).sum() >= min_samples and (target == 1).sum() >= min_samples)

def _get_sorted_candidate_thresholds(self, preds: Tensor) -> Tensor:
return torch.linspace(preds.min(), preds.max(), self.N_CANDIDATE_THRESHOLDS)

def _compute_estimated_cdf(self, preds: Tensor, sorted_thresholds: Tensor) -> Tensor:
estimator = GaussianMixture(self.n_components, covariance_type="full", **self.kwargs)
estimator.fit(preds.reshape(-1, 1).numpy())
cdf = np.zeros(sorted_thresholds.shape)
for weight, mean, var in zip(estimator.weights_, estimator.means_, estimator.covariances_):
# TODO(yujie): var should be sqrt?
mean = mean.flatten()[0]
var = var.flatten()[0]
cdf += scipy.stats.norm.cdf(sorted_thresholds, loc=mean, scale=var**0.5) * weight
return torch.from_numpy(cdf)

def _compute_f1_scores(
self, negative_cdf: Tensor, positive_cdf: Tensor, negative_weight: float, positive_weight: float
) -> Tensor:
fp = (1.0 - negative_cdf) * negative_weight
tp = (1.0 - positive_cdf) * positive_weight
fn = positive_cdf * positive_weight
f1_scores = (tp * 2.0) / (tp * 2.0 + fp + fn)
return f1_scores
107 changes: 0 additions & 107 deletions src/anomalib/utils/metrics/anomaly_score_threshold_estimator.py

This file was deleted.

0 comments on commit 6c3b47d

Please sign in to comment.