Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

COCO mAP metric #2901

Open
wants to merge 23 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
34a1a3f
Keep only cocomap-related changes
sadra-barikbin May 17, 2023
24fe980
Some improvements
sadra-barikbin May 28, 2023
e2ac8ee
Update docs
sadra-barikbin May 28, 2023
e4683de
Merge branch 'master' into cocomap
sadra-barikbin May 28, 2023
7cf53e1
Fix a bug in docs
sadra-barikbin May 29, 2023
4aa9c5d
Fix a tiny bug related to allgather
sadra-barikbin Jun 15, 2023
950c388
Fix a few bugs
sadra-barikbin Jun 16, 2023
9f5f796
Redesign code:
sadra-barikbin Jun 16, 2023
ffb1ba4
Merge branch 'master' into cocomap
sadra-barikbin Jun 16, 2023
65cdd08
Remove all_gather with different shape
sadra-barikbin Jun 17, 2023
e54af52
Merge branch 'master' into cocomap
sadra-barikbin Jun 21, 2023
aac2e55
Add test for all_gather_with_different_shape func
sadra-barikbin Jun 21, 2023
4cf3972
Merge branch 'master' into cocomap
vfdev-5 Jun 21, 2023
6070e18
A few improvements
sadra-barikbin Aug 23, 2023
aa83e60
Merge remote-tracking branch 'upstream/cocomap' into cocomap
sadra-barikbin Aug 23, 2023
deebbde
Add an output transform
sadra-barikbin Aug 31, 2023
62ca5fb
Add a test for the output_transform
sadra-barikbin Aug 31, 2023
418fcf4
Remove 'flavor' because all DeciAI, Ultralytics, Detectron and pycoco…
sadra-barikbin Sep 1, 2023
5fea0cd
Merge branch 'master' into cocomap
sadra-barikbin Sep 1, 2023
79fa1e2
Revert Metric change and a few bug fix
sadra-barikbin Sep 10, 2023
26c96b8
A tiny improvement in local variable names
sadra-barikbin Sep 15, 2023
d18f793
Merge branch 'master' into cocomap
sadra-barikbin Sep 15, 2023
a361ca8
Add max_dep and area_range
sadra-barikbin Dec 4, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -332,11 +332,13 @@ Complete list of metrics
Frequency
Loss
MeanAbsoluteError
MeanAveragePrecision
MeanPairwiseDistance
MeanSquaredError
metric.Metric
metrics_lambda.MetricsLambda
MultiLabelConfusionMatrix
ObjectDetectionMAP
precision.Precision
PSNR
recall.Recall
Expand Down
34 changes: 33 additions & 1 deletion ignite/distributed/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import itertools
import socket
from contextlib import contextmanager
from functools import wraps
from typing import Any, Callable, List, Mapping, Optional, Tuple, Union
from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union

import torch
from torch import distributed as dist

from ignite.distributed.comp_models import (
_SerialModel,
Expand Down Expand Up @@ -43,6 +45,7 @@
"one_rank_only",
"new_group",
"one_rank_first",
"all_gather_tensors_with_shapes",
]

_model = _SerialModel()
Expand Down Expand Up @@ -350,6 +353,35 @@ def all_reduce(
return _model.all_reduce(tensor, op, group=group)


def all_gather_tensors_with_shapes(
    tensor: torch.Tensor, shapes: Sequence[Sequence[int]], group: Optional[Union[Any, List[int]]] = None
) -> List[torch.Tensor]:
    """All-gather tensors that may have a different shape on each rank.

    Each rank's ``tensor`` is right-padded with zeros up to the elementwise
    maximum of ``shapes`` so that ``all_gather`` can be used, then each rank's
    block is sliced back down to its original shape.

    Args:
        tensor: local tensor to gather.
        shapes: ``shapes[rank]`` is the shape of ``tensor`` on that rank.
            All entries must have the same number of dimensions as ``tensor``.
        group: list of ranks (a new group is created from it) or an existing
            process group to gather within. Default: the whole world.

    Returns:
        List of gathered tensors, one per rank of the group. In a
        non-distributed setting, or on a rank outside ``group``, a
        single-element list containing the local ``tensor``.
    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    if isinstance(group, list) and all(isinstance(item, int) for item in group):
        group = _model.new_group(group)

    # Serial execution, or this rank is not a member of `group`: nothing to gather.
    if isinstance(_model, _SerialModel) or group == dist.GroupMember.NON_GROUP_MEMBER:
        return [tensor]

    # Per-dimension maximum over all ranks' shapes; every tensor is padded up to it.
    max_shape = torch.tensor(shapes).amax(dim=0)
    padding_sizes = (max_shape - torch.tensor(tensor.shape)).tolist()
    # F.pad expects (left, right) pairs starting from the LAST dimension,
    # hence the reversed iteration; padding is applied on the right only.
    padded_tensor = torch.nn.functional.pad(
        tensor, tuple(itertools.chain.from_iterable((0, pad_size) for pad_size in reversed(padding_sizes)))
    )
    # The gathered result is stacked along dim 0: rank r occupies rows
    # [r * max_shape[0], (r + 1) * max_shape[0]).
    all_padded_tensors: torch.Tensor = _model.all_gather(padded_tensor, group=group)
    # Slice each rank's block back to its original (unpadded) shape.
    return [
        all_padded_tensors[
            [
                slice(rank * max_shape[0] if dim == 0 else 0, rank * max_shape[0] + dim_size if dim == 0 else dim_size)
                for dim, dim_size in enumerate(shape)
            ]
        ]
        for rank, shape in enumerate(shapes)
    ]


def all_gather(
tensor: Union[torch.Tensor, float, str], group: Optional[Union[Any, List[int]]] = None
) -> Union[torch.Tensor, float, List[float], List[str]]:
Expand Down
4 changes: 4 additions & 0 deletions ignite/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ignite.metrics.gan.inception_score import InceptionScore
from ignite.metrics.loss import Loss
from ignite.metrics.mean_absolute_error import MeanAbsoluteError
from ignite.metrics.mean_average_precision import MeanAveragePrecision
from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance
from ignite.metrics.mean_squared_error import MeanSquaredError
from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, MetricUsage
Expand All @@ -23,6 +24,7 @@
from ignite.metrics.running_average import RunningAverage
from ignite.metrics.ssim import SSIM
from ignite.metrics.top_k_categorical_accuracy import TopKCategoricalAccuracy
from ignite.metrics.vision.object_detection_map import ObjectDetectionMAP

__all__ = [
"Metric",
Expand Down Expand Up @@ -58,4 +60,6 @@
"Rouge",
"RougeN",
"RougeL",
"MeanAveragePrecision",
"ObjectDetectionMAP",
]
441 changes: 441 additions & 0 deletions ignite/metrics/mean_average_precision.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion ignite/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,7 @@ def _is_list_of_tensors_or_numbers(x: Sequence[Union[torch.Tensor, float]]) -> b
return isinstance(x, Sequence) and all([isinstance(t, (torch.Tensor, Number)) for t in x])


def _to_batched_tensor(x: Union[torch.Tensor, float], device: Optional[torch.device] = None) -> torch.Tensor:
def _to_batched_tensor(x: Union[torch.Tensor, Number], device: Optional[torch.device] = None) -> torch.Tensor:
if isinstance(x, torch.Tensor):
return x.unsqueeze(dim=0)
return torch.tensor([x], device=device)
3 changes: 3 additions & 0 deletions ignite/metrics/vision/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from ignite.metrics.vision.object_detection_map import ObjectDetectionMAP

__all__ = ["ObjectDetectionMAP"]