
Commit

Merge pull request #1178 from roboflow/feat/inference-slicer-segmentation

Feat/inference slicer segmentation
SkalskiP committed May 13, 2024
2 parents f41adca + 8901192 commit cd8a2be
Showing 5 changed files with 130 additions and 10 deletions.
6 changes: 6 additions & 0 deletions docs/detection/utils.md
@@ -65,6 +65,12 @@ status: new

:::supervision.detection.utils.move_boxes

<div class="md-typeset">
<h2><a href="#supervision.detection.utils.move_masks">move_masks</a></h2>
</div>

:::supervision.detection.utils.move_masks

<div class="md-typeset">
<h2><a href="#supervision.detection.utils.scale_boxes">scale_boxes</a></h2>
</div>
72 changes: 66 additions & 6 deletions docs/how_to/detect_small_objects.md
@@ -68,19 +68,19 @@ size relative to the image resolution.
import torch
import supervision as sv
from PIL import Image
- from transformers import DetrImageProcessor, DetrForObjectDetection
+ from transformers import DetrImageProcessor, DetrForSegmentation

processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
- model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
+ model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50")

image = Image.open(<SOURCE_IMAGE_PATH>)
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

- width, height = image.size
- target_size = torch.tensor([[height, width]])
+ width, height = image_slice.size
+ target_size = torch.tensor([[width, height]])
results = processor.post_process_object_detection(
    outputs=outputs, target_sizes=target_size)[0]
detections = sv.Detections.from_transformers(results)
@@ -239,8 +239,8 @@ objects within each, and aggregating the results.
with torch.no_grad():
    outputs = model(**inputs)

- width, height = image.size
- target_size = torch.tensor([[height, width]])
+ width, height = image_slice.size
+ target_size = torch.tensor([[width, height]])
results = processor.post_process_object_detection(
    outputs=outputs, target_sizes=target_size)[0]
return sv.Detections.from_transformers(results)
@@ -264,3 +264,63 @@ objects within each, and aggregating the results.
```

![detection-with-inference-slicer](https://media.roboflow.com/supervision_detect_small_objects_example_3.png)

## Small Object Segmentation

[`InferenceSlicer`](/latest/detection/tools/inference_slicer/#supervision.detection.tools.inference_slicer.InferenceSlicer) can perform segmentation tasks too.

=== "Inference"

    ```{ .py hl_lines="6 16 19-20" }
    import cv2
    import numpy as np
    import supervision as sv
    from inference import get_model

    model = get_model(model_id="yolov8x-seg-640")
    image = cv2.imread(<SOURCE_IMAGE_PATH>)

    def callback(image_slice: np.ndarray) -> sv.Detections:
        results = model.infer(image_slice)[0]
        return sv.Detections.from_inference(results)

    slicer = sv.InferenceSlicer(callback=callback)
    detections = slicer(image)

    mask_annotator = sv.MaskAnnotator()
    label_annotator = sv.LabelAnnotator()

    annotated_image = mask_annotator.annotate(
        scene=image, detections=detections)
    annotated_image = label_annotator.annotate(
        scene=annotated_image, detections=detections)
    ```

=== "Ultralytics"

    ```{ .py hl_lines="6 16 19-20" }
    import cv2
    import numpy as np
    import supervision as sv
    from ultralytics import YOLO

    model = YOLO("yolov8x-seg.pt")
    image = cv2.imread(<SOURCE_IMAGE_PATH>)

    def callback(image_slice: np.ndarray) -> sv.Detections:
        result = model(image_slice)[0]
        return sv.Detections.from_ultralytics(result)

    slicer = sv.InferenceSlicer(callback=callback)
    detections = slicer(image)

    mask_annotator = sv.MaskAnnotator()
    label_annotator = sv.LabelAnnotator()

    annotated_image = mask_annotator.annotate(
        scene=image, detections=detections)
    annotated_image = label_annotator.annotate(
        scene=annotated_image, detections=detections)
    ```

![detection-with-inference-slicer](https://media.roboflow.com/supervision-docs/inference-slicer-segmentation-example.png)
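
Both tabs build `annotated_image` without displaying or saving it. A minimal way to inspect the result (a sketch, assuming a notebook-style environment and the `annotated_image` produced by either snippet above):

```python
import supervision as sv

# Render the annotated frame inline; plot_image draws it with matplotlib.
sv.plot_image(annotated_image)
```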
1 change: 1 addition & 0 deletions supervision/__init__.py
@@ -53,6 +53,7 @@
mask_to_polygons,
mask_to_xyxy,
move_boxes,
move_masks,
pad_boxes,
polygon_to_mask,
polygon_to_xyxy,
27 changes: 23 additions & 4 deletions supervision/detection/tools/inference_slicer.py
@@ -4,20 +4,36 @@
import numpy as np

from supervision.detection.core import Detections
- from supervision.detection.utils import move_boxes
+ from supervision.detection.utils import move_boxes, move_masks
from supervision.utils.image import crop_image


- def move_detections(detections: Detections, offset: np.array) -> Detections:
+ def move_detections(
+     detections: Detections,
+     offset: np.ndarray,
+     resolution_wh: Optional[Tuple[int, int]] = None,
+ ) -> Detections:
"""
Args:
detections (sv.Detections): Detections object to be moved.
offset (np.array): An array of shape `(2,)` containing offset values in format
offset (np.ndarray): An array of shape `(2,)` containing offset values in format
is `[dx, dy]`.
resolution_wh (Tuple[int, int]): The width and height of the desired mask
resolution. Required for segmentation detections.
Returns:
(sv.Detections) repositioned Detections object.
"""
detections.xyxy = move_boxes(xyxy=detections.xyxy, offset=offset)
if detections.mask is not None:
if resolution_wh is None:
raise ValueError(
"Resolution width and height are required for moving segmentation "
"detections. This should be the same as (width, height) of image shape."
)
detections.mask = move_masks(
masks=detections.mask, offset=offset, resolution_wh=resolution_wh
)
return detections


@@ -126,7 +142,10 @@ def _run_callback(self, image, offset) -> Detections:
"""
image_slice = crop_image(image=image, xyxy=offset)
detections = self.callback(image_slice)
- detections = move_detections(detections=detections, offset=offset[:2])
+ resolution_wh = (image.shape[1], image.shape[0])
+ detections = move_detections(
+     detections=detections, offset=offset[:2], resolution_wh=resolution_wh
+ )

return detections

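
For orientation, a sketch of how the extended `move_detections` behaves on segmentation output (the box, mask, offset, and image size below are hypothetical, made up purely for illustration):

```python
import numpy as np
import supervision as sv
from supervision.detection.tools.inference_slicer import move_detections

# One box and one 100x100 mask predicted on a slice whose top-left corner
# sits at (x=256, y=128) inside a 1280x720 source image.
slice_detections = sv.Detections(
    xyxy=np.array([[10.0, 10.0, 60.0, 60.0]]),
    mask=np.zeros((1, 100, 100), dtype=bool),
)

moved = move_detections(
    detections=slice_detections,
    offset=np.array([256, 128]),
    resolution_wh=(1280, 720),  # (width, height) of the full image
)
print(moved.xyxy)        # [[266. 138. 316. 188.]]
print(moved.mask.shape)  # (1, 720, 1280)
```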
34 changes: 34 additions & 0 deletions supervision/detection/utils.py
@@ -621,6 +621,40 @@ def move_boxes(xyxy: np.ndarray, offset: np.ndarray) -> np.ndarray:
return xyxy + np.hstack([offset, offset])


def move_masks(
    masks: np.ndarray,
    offset: np.ndarray,
    resolution_wh: Tuple[int, int] = None,
) -> np.ndarray:
    """
    Offset the masks in an array by the specified (x, y) amount.

    Args:
        masks (np.ndarray): A 3D array of binary masks corresponding to the predictions.
            Shape: `(N, H, W)`, where N is the number of predictions, and H, W are the
            dimensions of each mask.
        offset (np.ndarray): An array of shape `(2,)` containing non-negative int values
            `[dx, dy]`.
        resolution_wh (Tuple[int, int]): The width and height of the desired mask
            resolution.

    Returns:
        (np.ndarray) repositioned masks, optionally padded to the specified shape.
    """

    if offset[0] < 0 or offset[1] < 0:
        raise ValueError(f"Offset values must be non-negative integers. Got: {offset}")

    mask_array = np.full((masks.shape[0], resolution_wh[1], resolution_wh[0]), False)
    mask_array[
        :,
        offset[1] : masks.shape[1] + offset[1],
        offset[0] : masks.shape[2] + offset[0],
    ] = masks

    return mask_array


def scale_boxes(xyxy: np.ndarray, factor: float) -> np.ndarray:
"""
Scale the dimensions of bounding boxes.
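
The newly added `move_masks` (exported from `supervision` via the `__init__.py` change above) can be exercised on its own; a minimal sketch with illustrative values:

```python
import numpy as np
import supervision as sv

# Two 100x100 binary slice masks; mark a small square on the first one.
masks = np.zeros((2, 100, 100), dtype=bool)
masks[0, 10:20, 10:20] = True

# Shift both masks by (dx=256, dy=128) and pad them onto a 640x480 canvas.
moved = sv.move_masks(
    masks=masks,
    offset=np.array([256, 128]),
    resolution_wh=(640, 480),
)
print(moved.shape)  # (2, 480, 640)
```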
