Skip to content
This repository has been archived by the owner on Nov 21, 2023. It is now read-only.

Feature/replaceable detection helper #690

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions detectron/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,11 @@
# (e.g., 'generalized_rcnn', 'mask_rcnn', ...)
__C.MODEL.TYPE = ''

# Detection model helper class to use
#
# Allows to apply custom DetectionModelHelper implementation
__C.MODEL.MODEL_HELPER_CLASS = 'detectron.modeling.detector.DetectionModelHelper'

# The backbone conv body to use
# The string must match a function that is imported in modeling.model_builder
# (e.g., 'FPN.add_fpn_ResNet101_conv5_body' to specify a ResNet-101-FPN
Expand Down
19 changes: 17 additions & 2 deletions detectron/modeling/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,16 @@ def create(model_type_func, train=False, gpu_id=0):
targeted to a specific GPU by specifying gpu_id. This is used by
optimizer.build_data_parallel_model() during test time.
"""
model = DetectionModelHelper(
parts = cfg.MODEL.MODEL_HELPER_CLASS.split('.')
try:
module_name = '.'.join(parts[:-1])
module = importlib.import_module(module_name)
model_helper_class = getattr(module, parts[-1])
except (IndexError, ImportError, AttributeError):
logger.error('Failed to find model helper: %s', model_helper_class)
raise

model = model_helper_class(
name=model_type_func,
train=train,
num_classes=cfg.MODEL.NUM_CLASSES,
Expand Down Expand Up @@ -145,7 +154,13 @@ def get_func(func_name):
return globals()[parts[0]]
# Otherwise, assume we're referencing a module under modeling
module_name = 'detectron.modeling.' + '.'.join(parts[:-1])
module = importlib.import_module(module_name)
try:
module = importlib.import_module(module_name)
except ImportError:
# Finally check if we're referencing a module from the environment
module_name = '.'.join(parts[:-1])
module = importlib.import_module(module_name)
logger.debug('Using function %s from the environment', func_name)
return getattr(module, parts[-1])
except Exception:
logger.error('Failed to find function: {}'.format(func_name))
Expand Down