Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

COCO mAP metric #2901

Open
wants to merge 23 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
34a1a3f
Keep only cocomap-related changes
sadra-barikbin May 17, 2023
24fe980
Some improvements
sadra-barikbin May 28, 2023
e2ac8ee
Update docs
sadra-barikbin May 28, 2023
e4683de
Merge branch 'master' into cocomap
sadra-barikbin May 28, 2023
7cf53e1
Fix a bug in docs
sadra-barikbin May 29, 2023
4aa9c5d
Fix a tiny bug related to allgather
sadra-barikbin Jun 15, 2023
950c388
Fix a few bugs
sadra-barikbin Jun 16, 2023
9f5f796
Redesign code:
sadra-barikbin Jun 16, 2023
ffb1ba4
Merge branch 'master' into cocomap
sadra-barikbin Jun 16, 2023
65cdd08
Remove all_gather with different shape
sadra-barikbin Jun 17, 2023
e54af52
Merge branch 'master' into cocomap
sadra-barikbin Jun 21, 2023
aac2e55
Add test for all_gather_with_different_shape func
sadra-barikbin Jun 21, 2023
4cf3972
Merge branch 'master' into cocomap
vfdev-5 Jun 21, 2023
6070e18
A few improvements
sadra-barikbin Aug 23, 2023
aa83e60
Merge remote-tracking branch 'upstream/cocomap' into cocomap
sadra-barikbin Aug 23, 2023
deebbde
Add an output transform
sadra-barikbin Aug 31, 2023
62ca5fb
Add a test for the output_transform
sadra-barikbin Aug 31, 2023
418fcf4
Remove 'flavor' because all DeciAI, Ultralytics, Detectron and pycoco…
sadra-barikbin Sep 1, 2023
5fea0cd
Merge branch 'master' into cocomap
sadra-barikbin Sep 1, 2023
79fa1e2
Revert Metric change and a few bug fix
sadra-barikbin Sep 10, 2023
26c96b8
A tiny improvement in local variable names
sadra-barikbin Sep 15, 2023
d18f793
Merge branch 'master' into cocomap
sadra-barikbin Sep 15, 2023
a361ca8
Add max_dep and area_range
sadra-barikbin Dec 4, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -329,11 +329,13 @@ Complete list of metrics
Frequency
Loss
MeanAbsoluteError
MeanAveragePrecision
MeanPairwiseDistance
MeanSquaredError
metric.Metric
metrics_lambda.MetricsLambda
MultiLabelConfusionMatrix
ObjectDetectionMAP
precision.Precision
PSNR
recall.Recall
Expand Down
66 changes: 59 additions & 7 deletions ignite/distributed/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import itertools
import socket
from contextlib import contextmanager
from functools import wraps
from typing import Any, Callable, List, Mapping, Optional, Tuple, Union
from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union

import torch

Expand Down Expand Up @@ -350,29 +351,80 @@
return _model.all_reduce(tensor, op, group=group)


def _all_gather_tensors_with_shapes(
    tensor: torch.Tensor, shapes: Sequence[Sequence[int]], group: Optional[Union[Any, List[int]]] = None
) -> List[torch.Tensor]:
    """Gather tensors of possibly different shapes from all processes.

    Args:
        tensor: this process' tensor to contribute. All participating tensors
            must have the same number of dimensions, but may differ in size.
        shapes: ``shapes[rank]`` is the shape of the tensor held by process ``rank``.
        group: list of ranks or a backend-specific process group. If ``None``,
            the default process group (all ranks) is used.

    Returns:
        List of gathered tensors, one per participating process, each trimmed
        back to its original shape. In serial mode, or when the current process
        is not part of ``group``, a list containing only ``tensor`` is returned.
    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    if isinstance(group, list) and all(isinstance(item, int) for item in group):
        group = _model.new_group(group)

    # Serial execution, or a rank outside ``group``, has nothing to exchange.
    if isinstance(_model, _SerialModel) or (group is not None and _model.get_rank() not in group):
        return [tensor]

    # ``shapes`` stacks into a (num_ranks, ndims) tensor, so the per-dimension
    # maximum across ranks is a reduction over dim=0 (dim=1 would reduce over
    # the dimensions of each single shape, which is wrong).
    max_shape = torch.tensor(shapes).amax(dim=0)
    padding_sizes = (max_shape - torch.tensor(tensor.shape)).tolist()
    # F.pad expects pairs (before, after) starting from the LAST dimension,
    # hence the reversal; we only pad at the end of each dimension.
    padded_tensor = torch.nn.functional.pad(
        tensor, tuple(itertools.chain.from_iterable(map(lambda dim_size: (0, dim_size), reversed(padding_sizes))))
    )
    all_padded_tensors: torch.Tensor = _model.all_gather(padded_tensor, group=group)
    # Slice each rank's chunk out of the dim-0 concatenation and trim the padding
    # back to that rank's original shape.
    # NOTE(review): after ``new_group`` above, ``group`` may no longer be a list
    # of ints, so ``rank in group`` below relies on the backend group supporting
    # membership tests — confirm against the backend implementations.
    return [
        all_padded_tensors[
            [
                slice(rank * max_shape[0] if dim == 0 else 0, rank * max_shape[0] + dim_size if dim == 0 else dim_size)
                for dim, dim_size in enumerate(shape)
            ]
        ]
        for rank, shape in enumerate(shapes)
        if group is None or rank in group
    ]


def all_gather(
    tensor: Union[torch.Tensor, float, str],
    group: Optional[Union[Any, List[int]]] = None,
    tensor_different_shape: bool = False,
) -> Union[torch.Tensor, float, List[float], List[str], List[torch.Tensor]]:
    """Helper method to perform all gather operation.

    Args:
        tensor: tensor or number or str to collect across participating processes. If tensor, it should have
            the same number of dimensions across processes.
        group: list of integer or the process group for each backend. If None, the default process group will be used.
        tensor_different_shape: If True, it accounts for difference in input shape across processes. In this case, it
            induces more collective operations. If False, `tensor` should have the same shape across processes.
            Ignored when `tensor` is not a tensor. Default False.

    Returns:
        If input is a tensor, returns a torch.Tensor of shape ``(world_size * tensor.shape[0], tensor.shape[1], ...)``
        if ``tensor_different_shape = False``, otherwise a list of tensors with length ``world_size`` (if ``group``
        is `None`) or ``len(group)``. If current process does not belong to `group`, a list with `tensor` as its only
        item is returned.
        If input is a number, a torch.Tensor of shape ``(world_size, )`` is returned and finally a list of strings
        is returned if input is a string.

    .. versionchanged:: 0.4.11
        added ``group``

    .. versionchanged:: 0.5.1
        added ``tensor_different_shape``
    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    if isinstance(group, list) and all(isinstance(item, int) for item in group):
        group = _model.new_group(group)

    if isinstance(tensor, torch.Tensor) and tensor_different_shape:
        # Serial execution, or a rank outside ``group``, keeps only its own tensor.
        if isinstance(_model, _SerialModel) or (group is not None and _model.get_rank() not in group):
            return [tensor]
        # Extra collective: first gather every rank's shape so the padded
        # gather in the helper can be trimmed back per rank.
        all_shapes: torch.Tensor = _model.all_gather(torch.tensor(tensor.shape), group=group).view(
            -1, len(tensor.shape)
        )
        return _all_gather_tensors_with_shapes(tensor, all_shapes.tolist(), group=group)

    return _model.all_gather(tensor, group=group)


Expand Down
4 changes: 4 additions & 0 deletions ignite/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ignite.metrics.gan.inception_score import InceptionScore
from ignite.metrics.loss import Loss
from ignite.metrics.mean_absolute_error import MeanAbsoluteError
from ignite.metrics.mean_average_precision import MeanAveragePrecision
from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance
from ignite.metrics.mean_squared_error import MeanSquaredError
from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, MetricUsage
Expand All @@ -23,6 +24,7 @@
from ignite.metrics.running_average import RunningAverage
from ignite.metrics.ssim import SSIM
from ignite.metrics.top_k_categorical_accuracy import TopKCategoricalAccuracy
from ignite.metrics.vision.object_detection_map import ObjectDetectionMAP

__all__ = [
"Metric",
Expand Down Expand Up @@ -58,4 +60,6 @@
"Rouge",
"RougeN",
"RougeL",
"MeanAveragePrecision",
"ObjectDetectionMAP",
]