Feat/inference slicer segmentation #1178

Merged · 11 commits · May 13, 2024
6 changes: 6 additions & 0 deletions docs/detection/utils.md
@@ -65,6 +65,12 @@ status: new

:::supervision.detection.utils.move_boxes

<div class="md-typeset">
<h2><a href="#supervision.detection.utils.move_masks">move_masks</a></h2>
</div>

:::supervision.detection.utils.move_masks

<div class="md-typeset">
<h2><a href="#supervision.detection.utils.scale_boxes">scale_boxes</a></h2>
</div>
72 changes: 66 additions & 6 deletions docs/how_to/detect_small_objects.md
@@ -68,19 +68,19 @@ size relative to the image resolution.
import torch
import supervision as sv
from PIL import Image
from transformers import DetrImageProcessor, DetrForObjectDetection
from transformers import DetrImageProcessor, DetrForSegmentation

processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50")

image = Image.open(<SOURCE_IMAGE_PATH>)
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

width, height = image.size
target_size = torch.tensor([[height, width]])
width, height = image_slice.size
target_size = torch.tensor([[width, height]])
results = processor.post_process_object_detection(
    outputs=outputs, target_sizes=target_size)[0]
detections = sv.Detections.from_transformers(results)
@@ -239,8 +239,8 @@ objects within each, and aggregating the results.
    with torch.no_grad():
        outputs = model(**inputs)

    width, height = image.size
    target_size = torch.tensor([[height, width]])
    width, height = image_slice.size
    target_size = torch.tensor([[width, height]])
    results = processor.post_process_object_detection(
        outputs=outputs, target_sizes=target_size)[0]
    return sv.Detections.from_transformers(results)
@@ -264,3 +264,63 @@ objects within each, and aggregating the results.
```

![detection-with-inference-slicer](https://media.roboflow.com/supervision_detect_small_objects_example_3.png)

## Small Object Segmentation

[`InferenceSlicer`](/latest/detection/tools/inference_slicer/#supervision.detection.tools.inference_slicer.InferenceSlicer) can perform segmentation tasks too.

=== "Inference"

```{ .py hl_lines="6 16 19-20" }
import cv2
import numpy as np
import supervision as sv
from inference import get_model

model = get_model(model_id="yolov8x-seg-640")
image = cv2.imread(<SOURCE_IMAGE_PATH>)

def callback(image_slice: np.ndarray) -> sv.Detections:
    results = model.infer(image_slice)[0]
    return sv.Detections.from_inference(results)

slicer = sv.InferenceSlicer(callback=callback)
detections = slicer(image)

mask_annotator = sv.MaskAnnotator()
label_annotator = sv.LabelAnnotator()

annotated_image = mask_annotator.annotate(
    scene=image, detections=detections)
annotated_image = label_annotator.annotate(
    scene=annotated_image, detections=detections)
```

=== "Ultralytics"

```{ .py hl_lines="6 16 19-20" }
import cv2
import numpy as np
import supervision as sv
from ultralytics import YOLO

model = YOLO("yolov8x-seg.pt")
image = cv2.imread(<SOURCE_IMAGE_PATH>)

def callback(image_slice: np.ndarray) -> sv.Detections:
    result = model(image_slice)[0]
    return sv.Detections.from_ultralytics(result)

slicer = sv.InferenceSlicer(callback=callback)
detections = slicer(image)

mask_annotator = sv.MaskAnnotator()
label_annotator = sv.LabelAnnotator()

annotated_image = mask_annotator.annotate(
    scene=image, detections=detections)
annotated_image = label_annotator.annotate(
    scene=annotated_image, detections=detections)
```

![segmentation-with-inference-slicer](https://media.roboflow.com/supervision-docs/inference-slicer-segmentation-example.png)
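
For larger inputs it can also help to tune how the image is sliced. A minimal sketch, assuming the `slice_wh`, `overlap_ratio_wh`, and `iou_threshold` arguments accepted by `InferenceSlicer`:

```python
import cv2
import numpy as np
import supervision as sv
from ultralytics import YOLO

model = YOLO("yolov8x-seg.pt")
image = cv2.imread(<SOURCE_IMAGE_PATH>)

def callback(image_slice: np.ndarray) -> sv.Detections:
    result = model(image_slice)[0]
    return sv.Detections.from_ultralytics(result)

# Smaller slices with more overlap help when objects are tiny relative to the
# full frame; duplicate detections from overlapping slices are suppressed using
# the IoU threshold.
slicer = sv.InferenceSlicer(
    callback=callback,
    slice_wh=(640, 640),
    overlap_ratio_wh=(0.3, 0.3),
    iou_threshold=0.5,
)
detections = slicer(image)
```

The callback itself is unchanged; only the slicing geometry differs from the example above.
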
1 change: 1 addition & 0 deletions supervision/__init__.py
@@ -53,6 +53,7 @@
    mask_to_polygons,
    mask_to_xyxy,
    move_boxes,
    move_masks,
    pad_boxes,
    polygon_to_mask,
    polygon_to_xyxy,
27 changes: 23 additions & 4 deletions supervision/detection/tools/inference_slicer.py
@@ -4,20 +4,36 @@
import numpy as np

from supervision.detection.core import Detections
from supervision.detection.utils import move_boxes
from supervision.detection.utils import move_boxes, move_masks
from supervision.utils.image import crop_image


def move_detections(detections: Detections, offset: np.array) -> Detections:
def move_detections(
    detections: Detections,
    offset: np.ndarray,
    resolution_wh: Optional[Tuple[int, int]] = None,
) -> Detections:
    """
    Args:
        detections (sv.Detections): Detections object to be moved.
        offset (np.array): An array of shape `(2,)` containing offset values in format
        offset (np.ndarray): An array of shape `(2,)` containing offset values in format
            is `[dx, dy]`.
        resolution_wh (Tuple[int, int]): The width and height of the desired mask
            resolution. Required for segmentation detections.

    Returns:
        (sv.Detections) repositioned Detections object.
    """
    detections.xyxy = move_boxes(xyxy=detections.xyxy, offset=offset)
    if detections.mask is not None:
        if resolution_wh is None:
            raise ValueError(
                "Resolution width and height are required for moving segmentation "
                "detections. This should be the same as (width, height) of image shape."
            )
        detections.mask = move_masks(
            masks=detections.mask, offset=offset, resolution_wh=resolution_wh
        )
    return detections


@@ -126,7 +142,10 @@ def _run_callback(self, image, offset) -> Detections:
        """
        image_slice = crop_image(image=image, xyxy=offset)
        detections = self.callback(image_slice)
        detections = move_detections(detections=detections, offset=offset[:2])
        resolution_wh = (image.shape[1], image.shape[0])
        detections = move_detections(
            detections=detections, offset=offset[:2], resolution_wh=resolution_wh
        )

        return detections

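In effect, the slicer now shifts each slice's boxes by the slice offset and, when masks are present, pastes them onto a full-resolution canvas. A minimal sketch of calling `move_detections` directly, assuming it is imported from `supervision.detection.tools.inference_slicer` (it is not re-exported at the package level):

```python
import numpy as np
import supervision as sv
from supervision.detection.tools.inference_slicer import move_detections

# One detection found inside a 100x100 slice whose top-left corner sits at
# (x=50, y=200) of a 640x480 source image.
detections = sv.Detections(
    xyxy=np.array([[10.0, 20.0, 60.0, 80.0]]),
    mask=np.zeros((1, 100, 100), dtype=bool),
    confidence=np.array([0.9]),
    class_id=np.array([0]),
)
detections.mask[0, 20:80, 10:60] = True

moved = move_detections(
    detections=detections,
    offset=np.array([50, 200]),   # [dx, dy] of the slice within the full image
    resolution_wh=(640, 480),     # full image size, so masks are padded to (480, 640)
)

print(moved.xyxy)        # [[ 60. 220. 110. 280.]]
print(moved.mask.shape)  # (1, 480, 640)
```
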
34 changes: 34 additions & 0 deletions supervision/detection/utils.py
@@ -621,6 +621,40 @@ def move_boxes(xyxy: np.ndarray, offset: np.ndarray) -> np.ndarray:
    return xyxy + np.hstack([offset, offset])


def move_masks(
    masks: np.ndarray,
    offset: np.ndarray,
    resolution_wh: Tuple[int, int] = None,
) -> np.ndarray:
    """
    Offset the masks in an array by the specified (x, y) amount.

    Args:
        masks (np.ndarray): A 3D array of binary masks corresponding to the predictions.
            Shape: `(N, H, W)`, where N is the number of predictions, and H, W are the
            dimensions of each mask.
        offset (np.ndarray): An array of shape `(2,)` containing non-negative int values
            `[dx, dy]`.
        resolution_wh (Tuple[int, int]): The width and height of the desired mask
            resolution.

    Returns:
        (np.ndarray) repositioned masks, optionally padded to the specified shape.
    """

    if offset[0] < 0 or offset[1] < 0:
        raise ValueError(f"Offset values must be non-negative integers. Got: {offset}")

    mask_array = np.full((masks.shape[0], resolution_wh[1], resolution_wh[0]), False)
    mask_array[
        :,
        offset[1] : masks.shape[1] + offset[1],
        offset[0] : masks.shape[2] + offset[0],
    ] = masks

    return mask_array


def scale_boxes(xyxy: np.ndarray, factor: float) -> np.ndarray:
    """
    Scale the dimensions of bounding boxes.
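
A quick worked example of the padding behaviour implemented above, assuming `move_masks` is imported from `supervision.detection.utils` (this PR also re-exports it as `sv.move_masks`):

```python
import numpy as np
from supervision.detection.utils import move_masks

# A single 2x3 mask (N=1, H=2, W=3) detected in a slice whose top-left corner
# sits at (x=2, y=1) of a 6x4 image.
masks = np.array([[[True, True, False],
                   [False, True, True]]])

moved = move_masks(masks=masks, offset=np.array([2, 1]), resolution_wh=(6, 4))

print(moved.shape)            # (1, 4, 6)
print(moved[0].astype(int))
# [[0 0 0 0 0 0]
#  [0 0 1 1 0 0]
#  [0 0 0 1 1 0]
#  [0 0 0 0 0 0]]
```

The offset is applied as pure padding, which is why the function rejects negative values.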