Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PLA-639][External] Allow properties to be imported for annotations without IDs #803

Merged
merged 3 commits into from Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion darwin/datatypes.py
Expand Up @@ -1468,7 +1468,8 @@ class ObjectStore:
name (str): The alias of the storage connection
prefix (str): The directory that files are written back to in the storage location
readonly (bool): Whether the storage configuration is read-only or not
self.provider (str): The cloud provider (aws, azure, or gcp)
provider (str): The cloud provider (aws, azure, or gcp)
default (bool): Whether the storage connection is the default one
"""

def __init__(
Expand Down
37 changes: 23 additions & 14 deletions darwin/exporter/formats/mask.py
Expand Up @@ -3,7 +3,7 @@
import os
from csv import writer as csv_writer
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Set, Tuple, get_args
from typing import Dict, Iterable, List, Set, Tuple, get_args

import numpy as np

Expand Down Expand Up @@ -162,8 +162,10 @@ def get_render_mode(annotations: List[dt.AnnotationLike]) -> dt.MaskTypes.TypeOf
)


def rle_decode(rle: dt.MaskTypes.UndecodedRLE, label_colours: Dict[int, int]) -> List[int]:
"""Decodes a run-length encoded list of integers and substitutes labels by colours.
def rle_decode(
rle: dt.MaskTypes.UndecodedRLE, label_colours: Dict[int, int]
) -> List[int]:
"""Decodes a run-length encoded list of integers and substitutes labels by colours.

Args:
rle (List[int]): A run-length encoded list of integers.
Expand Down Expand Up @@ -352,30 +354,35 @@ def render_raster(
categories.append(new_mask.name)

colour_to_draw = categories.index(new_mask.name)

if new_mask.id not in mask_colours:
mask_colours[new_mask.id] = colour_to_draw

if new_mask.name not in colours:
colours[new_mask.name] = colour_to_draw

colours[new_mask.name] = colour_to_draw

raster_layer_list = [a for a in annotations if a.annotation_class.annotation_type == "raster_layer"]
raster_layer_list = [
a for a in annotations if a.annotation_class.annotation_type == "raster_layer"
]

if len(raster_layer_list) == 0:
errors.append(ValueError(f"File {annotation_file.filename} has no raster layer"))
errors.append(
ValueError(f"File {annotation_file.filename} has no raster layer")
)
return errors, mask, categories, colours

if len(raster_layer_list) > 1:
errors.append(
ValueError(f"File {annotation_file.filename} has more than one raster layer")
ValueError(
f"File {annotation_file.filename} has more than one raster layer"
)
)
return errors, mask, categories, colours

rl = raster_layer_list[0]
if isinstance(rl, dt.VideoAnnotation):
return errors, mask, categories, colours

raster_layer = dt.RasterLayer(
rle=rl.data["dense_rle"],
slot_names=a.slot_names,
Expand All @@ -389,13 +396,15 @@ def render_raster(

if colour_to_draw is None:
errors.append(
ValueError(f"Could not find mask with uuid {uuid} among masks in the file {annotation_file.filename}.")
ValueError(
f"Could not find mask with uuid {uuid} among masks in the file {annotation_file.filename}."
)
)
return errors, mask, categories, colours

label_colours[label] = colour_to_draw

decoded = rle_decode(raster_layer.rle, label_colours)
decoded = rle_decode(raster_layer.rle, label_colours)
mask = np.array(decoded, dtype=np.uint8).reshape(height, width)

return errors, mask, categories, colours
Expand Down
23 changes: 12 additions & 11 deletions darwin/importer/importer.py
@@ -1,3 +1,4 @@
import uuid
from collections import defaultdict
from logging import getLogger
from multiprocessing import cpu_count
Expand Down Expand Up @@ -281,9 +282,9 @@ def _get_team_properties_annotation_lookup(client):
team_properties = client.get_team_properties()

# (property-name, annotation_class_id): FullProperty object
team_properties_annotation_lookup: Dict[
Tuple[str, Optional[int]], FullProperty
] = {}
team_properties_annotation_lookup: Dict[Tuple[str, Optional[int]], FullProperty] = (
{}
)
for prop in team_properties:
team_properties_annotation_lookup[(prop.name, prop.annotation_class_id)] = prop

Expand Down Expand Up @@ -895,9 +896,9 @@ def import_annotations( # noqa: C901

# Need to re parse the files since we didn't save the annotations in memory
for local_path in set(local_file.path for local_file in local_files): # noqa: C401
imported_files: Union[
List[dt.AnnotationFile], dt.AnnotationFile, None
] = importer(local_path)
imported_files: Union[List[dt.AnnotationFile], dt.AnnotationFile, None] = (
importer(local_path)
)
if imported_files is None:
parsed_files = []
elif not isinstance(imported_files, List):
Expand Down Expand Up @@ -1297,17 +1298,17 @@ def _import_annotations(
# Insert the default slot name if not available in the import source
annotation = _handle_slot_names(annotation, dataset.version, default_slot_name)

annotation_class_ids_map[
(annotation_class.name, annotation_type)
] = annotation_class_id
annotation_class_ids_map[(annotation_class.name, annotation_type)] = (
annotation_class_id
)
serial_obj = {
"annotation_class_id": annotation_class_id,
"data": data,
"context_keys": {"slot_names": annotation.slot_names},
}

if annotation.id:
serial_obj["id"] = annotation.id
annotation.id = annotation.id or str(uuid.uuid4())
serial_obj["id"] = annotation.id

if actors:
serial_obj["actors"] = actors # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion tests/darwin/exporter/formats/export_mask_test.py
Expand Up @@ -2,7 +2,7 @@
import platform
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Dict, List, Optional
from typing import List, Optional
from unittest.mock import patch

import numpy as np
Expand Down
13 changes: 12 additions & 1 deletion tests/darwin/importer/importer_test.py
Expand Up @@ -572,7 +572,18 @@ def test__import_annotations() -> None:
"overwrite": "test_append_out",
}

assert output["annotations"] == assertion["annotations"]
assert (
output["annotations"][0]["annotation_class_id"]
== assertion["annotations"][0]["annotation_class_id"]
)
assert output["annotations"][0]["data"] == assertion["annotations"][0]["data"]
assert (
output["annotations"][0]["actors"] == assertion["annotations"][0]["actors"]
)
assert (
output["annotations"][0]["context_keys"]
== assertion["annotations"][0]["context_keys"]
)
Comment on lines +575 to +586
Copy link
Contributor Author

@JBWilkie JBWilkie Mar 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because the importer now assigns a random UUID to any annotation that has no ID, comparing the entire `annotations` field no longer works. Instead, we need to compare the individual fields inside each annotation.

assert output["overwrite"] == assertion["overwrite"]


Expand Down