|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +# Copyright (c) Facebook, Inc. and its affiliates. |
| 3 | +# Modified by Bowen Cheng from https://github.com/facebookresearch/detectron2/blob/main/tools/analyze_model.py |
| 4 | + |
| 5 | +import logging |
| 6 | +import numpy as np |
| 7 | +from collections import Counter |
| 8 | +import tqdm |
| 9 | +from fvcore.nn import flop_count_table # can also try flop_count_str |
| 10 | + |
| 11 | +from detectron2.checkpoint import DetectionCheckpointer |
| 12 | +from detectron2.config import CfgNode, LazyConfig, get_cfg, instantiate |
| 13 | +from detectron2.data import build_detection_test_loader |
| 14 | +from detectron2.engine import default_argument_parser |
| 15 | +from detectron2.modeling import build_model |
| 16 | +from detectron2.projects.deeplab import add_deeplab_config |
| 17 | +from detectron2.utils.analysis import ( |
| 18 | + FlopCountAnalysis, |
| 19 | + activation_count_operators, |
| 20 | + parameter_count_table, |
| 21 | +) |
| 22 | +from detectron2.utils.logger import setup_logger |
| 23 | + |
| 24 | +# fmt: off |
| 25 | +import os |
| 26 | +import sys |
| 27 | +sys.path.insert(1, os.path.join(sys.path[0], '..')) |
| 28 | +# fmt: on |
| 29 | + |
| 30 | +from mask2former import add_maskformer2_config |
| 31 | + |
| 32 | +logger = logging.getLogger("detectron2") |
| 33 | + |
| 34 | + |
def setup(args):
    """Load the config named by ``args.config_file`` and initialise logging.

    A path ending in ``.yaml`` is treated as a yacs-style config: the defaults
    from ``get_cfg()`` are extended with the DeepLab and Mask2Former keys, then
    the file and the ``args.opts`` overrides are merged and the result is
    frozen. Any other path is loaded through ``LazyConfig`` with the same
    overrides applied.

    Args:
        args: parsed command-line namespace; ``config_file`` and ``opts`` are
            the attributes read here.

    Returns:
        The loaded config object (a frozen yacs node or a lazy config).
    """
    config_path = args.config_file
    if not config_path.endswith(".yaml"):
        cfg = LazyConfig.apply_overrides(LazyConfig.load(config_path), args.opts)
    else:
        cfg = get_cfg()
        # Register the project-specific keys before the file is merged.
        for extend_config in (add_deeplab_config, add_maskformer2_config):
            extend_config(cfg)
        cfg.merge_from_file(config_path)
        # Set before the command-line overrides so args.opts can still change it.
        cfg.DATALOADER.NUM_WORKERS = 0
        cfg.merge_from_list(args.opts)
        cfg.freeze()

    # One logger for fvcore, then the default one.
    setup_logger(name="fvcore")
    setup_logger()
    return cfg
| 50 | + |
| 51 | + |
def do_flop(cfg):
    """Count flops of the model over up to ``args.num_inputs`` test samples.

    The model and test data loader are built from ``cfg`` (yacs ``CfgNode`` or
    lazy config), weights are loaded, and the model is put in eval mode. For
    each sample a ``FlopCountAnalysis`` is run; per-operator counts are summed
    and per-sample totals collected. Three messages are logged: the flop table
    of the last analysed sample, the per-operator average in GFlops, and the
    mean and standard deviation of the total GFlops.

    Reads the module-level ``args`` namespace (``num_inputs`` and
    ``use_fixed_input_size``), which is only defined when the file runs as a
    script.

    Args:
        cfg: config object returned by ``setup``.
    """
    if isinstance(cfg, CfgNode):
        data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
        model = build_model(cfg)
        DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
    else:
        data_loader = instantiate(cfg.dataloader.test)
        model = instantiate(cfg.model)
        model.to(cfg.train.device)
        DetectionCheckpointer(model).load(cfg.train.init_checkpoint)
    model.eval()

    # The fixed-size override only applies to yacs configs; resolve the
    # loop-invariant pieces once instead of on every iteration.
    use_fixed_size = args.use_fixed_input_size and isinstance(cfg, CfgNode)  # noqa
    if use_fixed_size:
        import torch

        crop_size = cfg.INPUT.CROP.SIZE[0]

    counts = Counter()
    total_flops = []
    flops = None
    for idx, data in zip(tqdm.trange(args.num_inputs), data_loader):  # noqa
        if use_fixed_size:
            # Replace the first sample's image with a square all-zero tensor.
            data[0]["image"] = torch.zeros((3, crop_size, crop_size))
        flops = FlopCountAnalysis(model, data)
        if idx > 0:
            # Warnings are left enabled for the first sample only.
            flops.unsupported_ops_warnings(False).uncalled_modules_warnings(False)
        counts += flops.by_operator()
        total_flops.append(flops.total())

    if flops is None:
        # Zero iterations (num_inputs <= 0 or an empty loader): there is no
        # analysis to tabulate and nothing to average.
        logger.warning("No input samples were processed; skipping flop statistics.")
        return

    num_samples = len(total_flops)
    logger.info("Flops table computed from only one input sample:\n" + flop_count_table(flops))
    logger.info(
        "Average GFlops for each type of operators:\n"
        + str([(k, v / num_samples / 1e9) for k, v in counts.items()])
    )
    logger.info(
        "Total GFlops: {:.1f}±{:.1f}".format(np.mean(total_flops) / 1e9, np.std(total_flops) / 1e9)
    )
| 85 | + |
| 86 | + |
def do_activation(cfg):
    """Count activations of the model over up to ``args.num_inputs`` test samples.

    The model and test data loader are built from ``cfg`` (yacs ``CfgNode`` or
    lazy config), weights are loaded, and the model is put in eval mode. For
    each sample ``activation_count_operators`` is called; per-operator counts
    are summed and per-sample totals collected. The per-operator average and
    the mean and standard deviation of the totals are logged. The log labels
    describe the values as millions of activations.

    Reads the module-level ``args`` namespace (``num_inputs``), which is only
    defined when the file runs as a script.

    Args:
        cfg: config object returned by ``setup``.
    """
    if isinstance(cfg, CfgNode):
        data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
        model = build_model(cfg)
        DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
    else:
        data_loader = instantiate(cfg.dataloader.test)
        model = instantiate(cfg.model)
        model.to(cfg.train.device)
        DetectionCheckpointer(model).load(cfg.train.init_checkpoint)
    model.eval()

    counts = Counter()
    total_activations = []
    for _, data in zip(tqdm.trange(args.num_inputs), data_loader):  # noqa
        count = activation_count_operators(model, data)
        counts += count
        total_activations.append(sum(count.values()))

    num_samples = len(total_activations)
    if num_samples == 0:
        # Zero iterations (num_inputs <= 0 or an empty loader): nothing to average.
        logger.warning("No input samples were processed; skipping activation statistics.")
        return

    # Average over the number of samples seen. Dividing by the last loop index
    # would be off by one and would divide by zero for a single sample.
    logger.info(
        "(Million) Activations for Each Type of Operators:\n"
        + str([(k, v / num_samples) for k, v in counts.items()])
    )
    logger.info(
        "Total (Million) Activations: {}±{}".format(
            np.mean(total_activations), np.std(total_activations)
        )
    )
| 114 | + |
| 115 | + |
def do_parameter(cfg):
    """Build the model described by ``cfg`` and log its parameter count table.

    No weights are loaded. The table is requested with ``max_depth=5``.

    Args:
        cfg: config object returned by ``setup`` (yacs ``CfgNode`` or lazy config).
    """
    model = build_model(cfg) if isinstance(cfg, CfgNode) else instantiate(cfg.model)
    table = parameter_count_table(model, max_depth=5)
    logger.info("Parameter Count:\n" + table)
| 122 | + |
| 123 | + |
def do_structure(cfg):
    """Build the model described by ``cfg`` and log its string representation.

    No weights are loaded.

    Args:
        cfg: config object returned by ``setup`` (yacs ``CfgNode`` or lazy config).
    """
    model = build_model(cfg) if isinstance(cfg, CfgNode) else instantiate(cfg.model)
    description = str(model)
    logger.info("Model Structure:\n" + description)
| 130 | + |
| 131 | + |
if __name__ == "__main__":
    # Usage examples passed to the parser as its epilog.
    usage_examples = """
Examples:
To show parameters of a model:
$ ./analyze_model.py --tasks parameter \\
 --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml
Flops and activations are data-dependent, therefore inputs and model weights
are needed to count them:
$ ./analyze_model.py --num-inputs 100 --tasks flop \\
 --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml \\
 MODEL.WEIGHTS /path/to/model.pkl
"""
    parser = default_argument_parser(epilog=usage_examples)

    # One or more analyses to run, in the order given on the command line.
    parser.add_argument(
        "--tasks",
        choices=["flop", "activation", "parameter", "structure"],
        required=True,
        nargs="+",
    )
    parser.add_argument(
        "-n",
        "--num-inputs",
        default=100,
        type=int,
        help="number of inputs used to compute statistics for flops/activations, "
        "both are data dependent.",
    )
    parser.add_argument(
        "--use-fixed-input-size",
        action="store_true",
        help="use fixed input size when calculating flops",
    )

    # `args` stays a module-level name: do_flop and do_activation read it.
    args = parser.parse_args()
    assert not args.eval_only
    assert args.num_gpus == 1

    cfg = setup(args)

    # Map each --tasks choice to its handler.
    task_handlers = {
        "flop": do_flop,
        "activation": do_activation,
        "parameter": do_parameter,
        "structure": do_structure,
    }
    for task in args.tasks:
        task_handlers[task](cfg)
0 commit comments