Skip to content

Commit

Permalink
Fix benchmark (#1155)
Browse files Browse the repository at this point in the history
* Fixed broken links in readme

* Fixed inference command in readme

* Fix benchmark torch throughput

* Add tests to cover throughput

* Add fixture path to conftest

* Format changed files

* Add implementation to conftest
  • Loading branch information
blaz-r committed Jun 30, 2023
1 parent 22ab4e1 commit 17efdb5
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 1 deletion.
68 changes: 68 additions & 0 deletions tests/pre_merge/utils/sweep/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Fixtures for the sweep tests."""

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from tempfile import TemporaryDirectory
from typing import Optional, Union

import pytest
from omegaconf import DictConfig, ListConfig
from pytorch_lightning import Trainer

from anomalib.config import get_configurable_parameters
from anomalib.data import get_datamodule
from anomalib.models import get_model
from anomalib.utils.callbacks import get_callbacks


def get_model_config(
project_path: str,
model_name: str,
dataset_path: str,
category: str,
task: str = "classification",
export_mode: Optional[str] = None,
):
model_config = get_configurable_parameters(model_name=model_name)
model_config.project.path = project_path
model_config.dataset.task = task
model_config.dataset.path = dataset_path
model_config.dataset.category = category
model_config.trainer.fast_dev_run = True
model_config.trainer.max_epochs = 1
model_config.trainer.devices = 1
model_config.trainer.accelerator = "gpu"
model_config.optimization.export_mode = export_mode
return model_config


@pytest.fixture(scope="package")
def generate_results_dir():
with TemporaryDirectory() as project_path:

def make(
model_name: str,
dataset_path: str,
category: str,
task: str = "classification",
export_mode: Optional[str] = None,
) -> Union[DictConfig, ListConfig]:
# then train the model
model_config = get_model_config(
project_path=project_path,
model_name=model_name,
dataset_path=dataset_path,
category=category,
task=task,
export_mode=export_mode,
)
model = get_model(model_config)
datamodule = get_datamodule(model_config)
callbacks = get_callbacks(model_config)
trainer = Trainer(**model_config.trainer, logger=False, callbacks=callbacks)
trainer.fit(model=model, datamodule=datamodule)

return model_config

yield make
62 changes: 62 additions & 0 deletions tests/pre_merge/utils/sweep/test_throughput.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Tests for Torch and OpenVINO inferencer throughput used in sweep."""

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import albumentations as A
from albumentations.pytorch import ToTensorV2

from anomalib.data import TaskType
from anomalib.data.folder import FolderDataset
from anomalib.deploy import ExportMode

from anomalib.utils.sweep.helpers import get_torch_throughput, get_openvino_throughput

from tests.helpers.dataset import TestDataset


transforms = A.Compose([A.ToFloat(max_value=255), ToTensorV2()])


@TestDataset(num_train=20, num_test=10)
def test_torch_throughput(generate_results_dir, path: str = None, category: str = "shapes"):
"""Test get_torch_throughput from utils/sweep/inference.py"""
# generate results with torch model exported
model_config = generate_results_dir(
model_name="padim",
dataset_path=path,
task=TaskType.CLASSIFICATION,
category=category,
export_mode=ExportMode.TORCH,
)

# create Dataset from generated TestDataset
dataset = FolderDataset(
task=TaskType.CLASSIFICATION, transform=transforms, root=path, normal_dir=f"{category}/test/good"
)
dataset.setup()

# run procedure using torch inferencer
get_torch_throughput(model_config.project.path, dataset, device=model_config.trainer.accelerator)


@TestDataset(num_train=20, num_test=10)
def test_openvino_throughput(generate_results_dir, path: str = None, category: str = "shapes"):
"""Test get_openvino_throughput from utils/sweep/inference.py"""
# generate results with torch model exported
model_config = generate_results_dir(
model_name="padim",
dataset_path=path,
task=TaskType.CLASSIFICATION,
category=category,
export_mode=ExportMode.OPENVINO,
)

# create Dataset from generated TestDataset
dataset = FolderDataset(
task=TaskType.CLASSIFICATION, transform=transforms, root=path, normal_dir=f"{category}/test/good"
)
dataset.setup()

# run procedure using openvino inferencer
get_openvino_throughput(model_config.project.path, dataset)
16 changes: 15 additions & 1 deletion tools/benchmarking/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,21 @@ def get_single_model_metrics(model_config: DictConfig | ListConfig, openvino_met
# get testing time
testing_time = time.time() - start_time

throughput = get_torch_throughput(model_config, model, datamodule.test_dataloader().dataset)
# Create dirs for torch export (as default only lighting model is produced)
export(
task=model_config.dataset.task,
transform=trainer.datamodule.test_data.transform.to_dict(),
input_size=model_config.model.input_size,
model=model,
export_mode=ExportMode.TORCH,
export_root=project_path,
)

throughput = get_torch_throughput(
model_path=project_path,
test_dataset=datamodule.test_dataloader().dataset,
device=model_config.trainer.accelerator,
)

# Get OpenVINO metrics
openvino_throughput = float("nan")
Expand Down

0 comments on commit 17efdb5

Please sign in to comment.