Skip to content

Commit 6b6afcd

Browse files
author
ADMIN
committed
update
1 parent 8f14aad commit 6b6afcd

File tree

4 files changed

+309
-0
lines changed

4 files changed

+309
-0
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ To use the model, follow these steps:
5353
### 🔍 Demo
5454
I have attached a .ipynb [file](demo.ipynb) in the repository. You can refer to it to know how to use the model.
5555

56+
Additionally, I have provided another .ipynb [file](cls_embeddings.ipynb) that illustrates the process of learning class embeddings for the model.
57+
5658
**Note**: You may need to adjust the threshold value to achieve the best results.
5759

5860
### 💡 Conclusion

SA_1B/sa_11027.jpg

668 KB
Loading

cls_embeddings.ipynb

Lines changed: 270 additions & 0 deletions
Large diffs are not rendered by default.

src/model.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,3 +453,40 @@ def forward_for_training_model_with_ref_points_lst_lst(
453453
out["dn_aux_outputs"] = self._set_aux_loss(dn_out_logits, dn_out_bboxes)
454454
out["dn_meta"] = dn_meta
455455
return out, visual_embeds_list
456+
457+
def forward_for_cls_embeddings(
458+
self,
459+
tensor_list: NestedTensor,
460+
embeds,
461+
):
462+
# memory
463+
memory, mask_flatten, spatial_shapes, level_start_index, valid_ratios = (
464+
self.image_encoder(tensor_list)
465+
)
466+
467+
# box decoder
468+
out_bboxes, out_logits, dn_out_bboxes, dn_out_logits, dn_meta = (
469+
self.box_decoder(
470+
memory=memory,
471+
mask_flatten=mask_flatten,
472+
spatial_shapes=spatial_shapes,
473+
level_start_index=level_start_index,
474+
valid_ratios=valid_ratios,
475+
visual_embed=embeds,
476+
targets=None,
477+
)
478+
)
479+
out = {"pred_logits": out_logits[-1], "pred_boxes": out_bboxes[-1]}
480+
out["aux_outputs"] = self._set_aux_loss(out_logits[:-1], out_bboxes[:-1])
481+
if dn_meta is not None:
482+
483+
out["dn_aux_outputs"] = self._set_aux_loss(dn_out_logits, dn_out_bboxes)
484+
out["dn_meta"] = dn_meta
485+
return out, embeds
486+
487+
@torch.jit.unused
488+
def _set_aux_loss(self, outputs_class, outputs_coord):
489+
return [
490+
{"pred_logits": a, "pred_boxes": b}
491+
for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
492+
]

0 commit comments

Comments
 (0)