
Commit 57889ca

Add MIT LICENSE and a pretrained model zoo.

1 parent c19e7fc

File tree (10 files changed: +201 −36 lines)

  .gitignore
  LICENSE
  configs/_base_bicaptioning_R_50_L1_H1024.yaml
  configs/task_ablations/multilabel_classification_R_50.yaml
  docs/virtex/config.rst
  setup.py
  virtex/config.py
  virtex/factories.py
  virtex/model_zoo/__init__.py
  virtex/model_zoo/model_zoo.py

.gitignore

Lines changed: 1 addition & 1 deletion

@@ -47,8 +47,8 @@ ENV/
 scripts/test_*
 
 # Data (symlinks) directory, model checkpoints, tensorboard logs etc.
-data/
 datasets/
 checkpoints/
 virtex/utils/assets/
 !virtex/data/
+virtex/model_zoo/configs

LICENSE

Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
+Copyright (c) 2020, Karan Desai.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
+associated documentation files (the "Software"), to deal in the Software without restriction,
+including without limitation the rights to use, copy, modify, merge, publish, distribute,
+sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies or substantial
+portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
+NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES
+OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

configs/_base_bicaptioning_R_50_L1_H1024.yaml

Lines changed: 5 additions & 0 deletions

@@ -9,6 +9,11 @@ DATA:
   ROOT: "datasets/coco"
   TOKENIZER_VOCAB: "datasets/vocab/coco_10k.vocab"
   TOKENIZER_MODEL: "datasets/vocab/coco_10k.model"
+  VOCAB_SIZE: 10000
+  UNK_INDEX: 0
+  SOS_INDEX: 1
+  EOS_INDEX: 2
+  MASK_INDEX: 3
 
   IMAGE_CROP_SIZE: 224
   MAX_CAPTION_LENGTH: 30

configs/task_ablations/multilabel_classification_R_50.yaml

Lines changed: 3 additions & 0 deletions

@@ -1,5 +1,8 @@
 _BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"
 
+DATA:
+  VOCAB_SIZE: 81
+
 MODEL:
   NAME: "multilabel_classification"
   TEXTUAL:

docs/virtex/config.rst

Lines changed: 1 addition & 1 deletion

@@ -14,5 +14,5 @@ Config References
 .. literalinclude:: ../../virtex/config.py
     :language: python
     :linenos:
-    :lines: 53-171
+    :lines: 53-189
     :dedent: 8

setup.py

Lines changed: 40 additions & 1 deletion

@@ -1,12 +1,51 @@
 #!/usr/bin/env python
+import glob
+import os
 from setuptools import setup
+import shutil
+from typing import List
+
+
+def get_model_zoo_configs() -> List[str]:
+    """
+    Return a list of config paths to include in the package for the model
+    zoo. These configs are symlinked (or copied) inside virtex/model_zoo.
+    """
+
+    # Use absolute paths while symlinking.
+    source_configs_dir = os.path.join(
+        os.path.dirname(os.path.realpath(__file__)), "configs"
+    )
+    destination = os.path.join(
+        os.path.dirname(os.path.realpath(__file__)), "virtex", "model_zoo", "configs"
+    )
+    # Symlink the config directory inside the package for a cleaner pip install.
+
+    # Remove stale symlink/directory from a previous build.
+    if os.path.exists(source_configs_dir):
+        if os.path.islink(destination):
+            os.unlink(destination)
+        elif os.path.isdir(destination):
+            shutil.rmtree(destination)
+
+    if not os.path.exists(destination):
+        try:
+            os.symlink(source_configs_dir, destination)
+        except OSError:
+            # Fall back to copying if symlinking fails, e.g. on Windows.
+            shutil.copytree(source_configs_dir, destination)
+
+    config_paths = glob.glob("configs/**/*.yaml", recursive=True)
+    return config_paths
 
 
 setup(
     name="virtex",
     version="0.9",
-    author="Karan Desai, Justin Johnson",
+    author="Karan Desai and Justin Johnson",
     description="VirTex: Learning Visual Representations with Textual Annotations",
+    package_data={"virtex.model_zoo": get_model_zoo_configs()},
+    python_requires=">=3.6",
    license="Apache 2.0",
     zip_safe=True,
 )
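
Once installed, the YAML files under configs/ travel with the virtex.model_zoo package. A minimal sketch (not part of this commit) of how a packaged config can be located at runtime; virtex/model_zoo/model_zoo.py below resolves paths the same way:

    import os
    import pkg_resources

    # Hypothetical check: resolve a packaged config after `pip install .`
    config_file = pkg_resources.resource_filename(
        "virtex.model_zoo",
        os.path.join("configs", "width_ablations/bicaptioning_R_50_L1_H2048.yaml"),
    )
    print(os.path.exists(config_file))  # True if the configs were packaged correctly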

virtex/config.py

Lines changed: 18 additions & 0 deletions

@@ -62,6 +62,24 @@ def __init__(
         # Path to .model file generated by ``sentencepiece``.
         _C.DATA.TOKENIZER_MODEL = "datasets/vocab/coco_10k.model"
 
+        # Handy config params for vocab size and indices of special tokens.
+        # While these can be picked up from the tokenizer, having them in the
+        # config makes it easy to create a model without instantiating a
+        # tokenizer (especially when it is not needed, e.g. in the model zoo).
+        # These must match what's present in ``TOKENIZER_VOCAB`` and
+        # ``TOKENIZER_MODEL`` above.
+        _C.DATA.VOCAB_SIZE = 10000
+        # Index of the out-of-vocabulary (and padding) token.
+        _C.DATA.UNK_INDEX = 0
+        # Index of the start-of-sentence [SOS] token.
+        _C.DATA.SOS_INDEX = 1
+        # Index of the end-of-sentence [EOS] token.
+        _C.DATA.EOS_INDEX = 2
+        # Index of the word-masking token. While not used for captioning, this
+        # extra token makes it possible to train an MLM model without
+        # re-creating the vocab mapping.
+        _C.DATA.MASK_INDEX = 3
+
         # Size of the image (square) to crop from original input image.
         _C.DATA.IMAGE_CROP_SIZE = 224
         # Maximum length of input caption (number of tokens).
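
A minimal sketch (assuming a local clone with the configs/ directory present): with these additions, everything a model needs to know about the vocabulary can be read straight from the config, without loading the SentencePiece files.

    from virtex.config import Config

    _C = Config("configs/_base_bicaptioning_R_50_L1_H1024.yaml")
    print(_C.DATA.VOCAB_SIZE)  # 10000
    print(_C.DATA.UNK_INDEX, _C.DATA.SOS_INDEX, _C.DATA.EOS_INDEX, _C.DATA.MASK_INDEX)  # 0 1 2 3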

virtex/factories.py

Lines changed: 11 additions & 33 deletions

@@ -346,23 +346,17 @@ class TextualHeadFactory(Factory):
     # fmt: on
 
     @classmethod
-    def from_config(
-        cls, config: Config, tokenizer: Optional[SentencePieceBPETokenizer] = None
-    ) -> nn.Module:
+    def from_config(cls, config: Config) -> nn.Module:
         r"""
         Create a textual head directly from config.
 
         Parameters
         ----------
         config: virtex.config.Config
             Config object with all the parameters.
-        tokenizer: virtex.data.tokenizers.SentencePieceBPETokenizer, optional (default = None)
-            A tokenizer which has the mapping between word tokens and their
-            integer IDs.
         """
 
         _C = config
-        tokenizer = tokenizer or TokenizerFactory.from_config(_C)
 
         # Get architectural hyper-params as per name by matching regex.
         name, architecture = _C.MODEL.TEXTUAL.NAME.split("::")
@@ -374,7 +368,7 @@ def from_config(
         feedforward_size = int(architecture.group(4))
 
         kwargs = {
-            "vocab_size": tokenizer.get_vocab_size(),
+            "vocab_size": _C.DATA.VOCAB_SIZE,
             "hidden_size": hidden_size,
         }
 
@@ -384,7 +378,7 @@ def from_config(
             attention_heads=attention_heads,
             feedforward_size=feedforward_size,
             dropout=_C.MODEL.TEXTUAL.DROPOUT,
-            padding_idx=tokenizer.token_to_id("[UNK]"),
+            padding_idx=_C.DATA.UNK_INDEX,
             max_caption_length=_C.DATA.MAX_CAPTION_LENGTH,
         )
         return cls.create(name, **kwargs)
@@ -406,55 +400,39 @@ class PretrainingModelFactory(Factory):
     }
 
     @classmethod
-    def from_config(
-        cls, config: Config, tokenizer: Optional[SentencePieceBPETokenizer] = None
-    ) -> nn.Module:
+    def from_config(cls, config: Config) -> nn.Module:
         r"""
         Create a model directly from config.
 
         Parameters
         ----------
         config: virtex.config.Config
             Config object with all the parameters.
-        tokenizer: virtex.data.tokenizers.SentencePieceBPETokenizer, optional (default = None)
-            A tokenizer which has the mapping between word tokens and their
-            integer IDs.
         """
 
         _C = config
-        tokenizer = tokenizer or TokenizerFactory.from_config(_C)
-
-        if _C.MODEL.NAME == "multilabel_classification":
-            # Pass a dummy tokenizer object to TextualHeadFactory for
-            # `multilabel_classification`, which can return vocab size as `81`
-            # (80 COCO categories + background).
-            class DummyTokenizer(object):
-                def get_vocab_size(self) -> int:
-                    return 81
-
-            tokenizer = DummyTokenizer()  # type: ignore
 
         # Build visual and textual streams based on config.
         visual = VisualBackboneFactory.from_config(_C)
-        textual = TextualHeadFactory.from_config(_C, tokenizer)
+        textual = TextualHeadFactory.from_config(_C)
 
         # Add model specific kwargs. Refer call signatures of specific models
         # for matching kwargs here.
         kwargs = {}
         if "captioning" in _C.MODEL.NAME:
             kwargs.update(
                 max_decoding_steps=_C.DATA.MAX_CAPTION_LENGTH,
-                sos_index=tokenizer.token_to_id("[SOS]"),
-                eos_index=tokenizer.token_to_id("[EOS]"),
+                sos_index=_C.DATA.SOS_INDEX,
+                eos_index=_C.DATA.EOS_INDEX,
             )
 
         elif _C.MODEL.NAME == "token_classification":
             kwargs.update(
                 ignore_indices=[
-                    tokenizer.token_to_id("[UNK]"),
-                    tokenizer.token_to_id("[SOS]"),
-                    tokenizer.token_to_id("[EOS]"),
-                    tokenizer.token_to_id("[MASK]"),
+                    _C.DATA.UNK_INDEX,
+                    _C.DATA.SOS_INDEX,
+                    _C.DATA.EOS_INDEX,
+                    _C.DATA.MASK_INDEX
                 ],
             )
         elif _C.MODEL.NAME == "multilabel_classification":
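
With the tokenizer argument gone, both factories are driven entirely by the config. A small sketch of the new call pattern (assuming a local clone), using the multilabel config whose vocab size of 81 (80 COCO categories + background) now comes from the YAML override above rather than the removed DummyTokenizer:

    from virtex.config import Config
    from virtex.factories import PretrainingModelFactory

    _C = Config("configs/task_ablations/multilabel_classification_R_50.yaml")
    assert _C.DATA.VOCAB_SIZE == 81
    model = PretrainingModelFactory.from_config(_C)  # no tokenizer needed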

virtex/model_zoo/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+from .model_zoo import get
+
+__all__ = ["get"]

virtex/model_zoo/model_zoo.py

Lines changed: 103 additions & 0 deletions

@@ -0,0 +1,103 @@
+r"""
+A utility module which provides functionality to easily load common VirTex
+models (optionally with pretrained weights) using a single line of code.
+
+Get our best performing VirTex model (with pretrained weights) as:
+
+>>> import virtex.model_zoo as mz
+>>> model = mz.get("width_ablations/bicaptioning_R_50_L1_H2048.yaml", pretrained=True)
+
+Any config available in the ``configs/`` directory under project root can be
+specified here, although this command need not be executed from project root.
+
+Part of this code is adapted from Detectron2's model zoo, which was originally
+implemented by the developers of this codebase, with reviews and further
+changes by the Detectron2 developers.
+"""
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import os
+import pkg_resources
+
+from fvcore.common.download import download
+import torch
+
+from virtex.config import Config
+from virtex.factories import PretrainingModelFactory
+from virtex.utils.checkpointing import CheckpointManager
+
+
+class _ModelZooUrls(object):
+    r"""Mapping from config names to URL suffixes of pretrained weights."""
+
+    URL_PREFIX = "https://umich.box.com/shared/static"
+
+    CONFIG_PATH_TO_URL_SUFFIX = {
+
+        # Pretraining Task Ablations
+        "task_ablations/bicaptioning_R_50_L1_H2048.yaml": "fm1nq819q74vr0kqcd3gkivlzf06xvko.pth",
+        "task_ablations/captioning_R_50_L1_H2048.yaml": "7fopt8k2eutz9qvth2hh6j00o7z4o7ps.pth",
+        "task_ablations/token_classification_R_50.yaml": "qwvfnji51g4gvba7i5mrw2ph5z8yfty9.pth",
+        "task_ablations/multilabel_classification_R_50.yaml": "tk1hlcue9c3268bds3h036ckk7a9btlr.pth",
+
+        # Width Ablations
+        "width_ablations/bicaptioning_R_50_L1_H512.yaml": "qostt3be0pgnd0xf55vdte3wa49x6k99.pth",
+        "width_ablations/bicaptioning_R_50_L1_H768.yaml": "v0p80tya0wjgsj0liqyvt386903xbwxc.pth",
+        "width_ablations/bicaptioning_R_50_L1_H1024.yaml": "s2o3tvujcx2djoz1ouvuea27hrys1fbm.pth",
+        "width_ablations/bicaptioning_R_50_L1_H2048.yaml": "fm1nq819q74vr0kqcd3gkivlzf06xvko.pth",
+
+        # Depth Ablations
+        "depth_ablations/bicaptioning_R_50_L1_H1024.yaml": "s2o3tvujcx2djoz1ouvuea27hrys1fbm.pth",
+        "depth_ablations/bicaptioning_R_50_L2_H1024.yaml": "5enura2ao2b0iyigcuikfsdd0osun0it.pth",
+        "depth_ablations/bicaptioning_R_50_L3_H1024.yaml": "xit11ev6h3q7h8wth5qokewxcn6yot2n.pth",
+        "depth_ablations/bicaptioning_R_50_L4_H1024.yaml": "secpwhjx9oq59mkzsztjaews6n3680bj.pth",
+
+        # Backbone Ablations
+        "backbone_ablations/bicaptioning_R_50_L1_H1024.yaml": "s2o3tvujcx2djoz1ouvuea27hrys1fbm.pth",
+        "backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml": "0rlu15xq796tz3ebvz7lf5dbpti421le.pth",
+        "backbone_ablations/bicaptioning_R_101_L1_H1024.yaml": "i3p45pr78jdz74r29qkj23v8kzb6gcsq.pth",
+    }
+    # Backbone from best model: fotpti1uk6bpoobeazysfc6fdbndvy90.pth
+
+
+def get(config_path: str, pretrained: bool = False):
+    r"""
+    Get a model specified by a relative path under the ``configs/`` directory of the project root.
+
+    Parameters
+    ----------
+    config_path: str
+        Name of config file relative to ``configs/`` directory under project
+        root. (For example, ``width_ablations/bicaptioning_R_50_L1_H2048.yaml``)
+    pretrained: bool, optional (default = False)
+        If ``True``, will initialize the model with the pretrained weights. If
+        ``False``, the weights will be initialized randomly.
+    """
+
+    # Get the original path to the config file (shipped inside the package).
+    _pkg_config_path = pkg_resources.resource_filename(
+        "virtex.model_zoo", os.path.join("configs", config_path)
+    )
+    if not os.path.exists(_pkg_config_path):
+        raise RuntimeError("{} not available in Model Zoo!".format(config_path))
+
+    _C = Config(_pkg_config_path)
+    model = PretrainingModelFactory.from_config(_C)
+
+    if pretrained:
+        # Get the checkpoint URL for this config path.
+        if config_path in _ModelZooUrls.CONFIG_PATH_TO_URL_SUFFIX:
+            url_suffix = _ModelZooUrls.CONFIG_PATH_TO_URL_SUFFIX[config_path]
+            checkpoint_url = f"{_ModelZooUrls.URL_PREFIX}/{url_suffix}"
+        else:
+            raise RuntimeError("{} not available in Model Zoo!".format(config_path))
+
+        # Download the pretrained model weights and save them with a sensible
+        # name. They are downloaded only if they do not already exist.
+        checkpoint_path = download(
+            checkpoint_url,
+            dir=os.path.expanduser("~/.torch/virtex_cache"),
+            filename=os.path.basename(config_path).replace(".yaml", ".pth")
+        )
+        CheckpointManager(model=model).load(checkpoint_path)
+
+    return model
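
A small usage sketch: building a zoo model with and without pretrained weights. Per the code above, checkpoints are cached under ~/.torch/virtex_cache, so repeated calls reuse the downloaded .pth file.

    import virtex.model_zoo as mz

    # Architecture only, randomly initialized weights:
    model_scratch = mz.get("width_ablations/bicaptioning_R_50_L1_H2048.yaml")

    # Pretrained weights (downloaded once, then loaded from the local cache):
    model = mz.get("width_ablations/bicaptioning_R_50_L1_H2048.yaml", pretrained=True)
    model.eval()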
