Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for giving a single score for the whole object #295

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 9 additions & 2 deletions norfair/tracker.py
Expand Up @@ -746,7 +746,7 @@ class Detection:
Parameters
----------
points : np.ndarray
Points detected. Must be a rank 2 array with shape `(n_points, n_dimensions)` where n_dimensions is 2 or 3.
Points detected. Must be a rank 2 array with shape `(n_points, n_dimensions)`.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did this change?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The thing is that it is not actually a restriction to have dimension 2 or 3. You can have any dimension you want. I decided to change that description since I was already modifying the Detection class anyway

scores : np.ndarray, optional
An array of length `n_points` which assigns a score to each of the points defined in `points`.

Expand All @@ -770,12 +770,19 @@ class Detection:
def __init__(
self,
points: np.ndarray,
scores: np.ndarray = None,
scores: Union[float, int, np.ndarray] = None,
data: Any = None,
label: Hashable = None,
embedding=None,
):
self.points = validate_points(points)

if isinstance(scores, np.ndarray):
assert len(scores) == len(
self.points
), "scores should be a np.ndarray with it's length being equal to the amount of points."
elif scores is not None:
scores = np.zeros((len(points),)) + scores
self.scores = scores
self.data = data
self.label = label
Expand Down
2 changes: 1 addition & 1 deletion norfair/utils.py
Expand Up @@ -20,7 +20,7 @@ def validate_points(points: np.ndarray) -> np.array:

def raise_detection_error_message(points):
message = "\n[red]INPUT ERROR:[/red]\n"
message += f"Each `Detection` object should have a property `points` of shape (num_of_points_to_track, 2), not {points.shape}. Check your `Detection` list creation code.\n"
message += f"Each `Detection` object should have a property `points` of shape (n_points, n_dimensions), not {points.shape}. Check your `Detection` list creation code.\n"
message += "You can read the documentation for the `Detection` class here:\n"
message += "https://tryolabs.github.io/norfair/reference/tracker/#norfair.tracker.Detection\n"
raise ValueError(message)
Expand Down