[Bug] Cannot use "analysis_tools\get_flops.py" to get the parameter count and FLOPs #2822

Open
ya-92626 opened this issue Apr 14, 2024 · 3 comments

@ya-92626

Branch

main branch (1.x version, such as v1.0.0, or dev-1.x branch)

Prerequisite

Environment

Python 3.9.13
PyTorch 1.12.1

Describe the bug

Traceback (most recent call last):
  File "F:\opensjj\mmaction2-main\tools\analysis_tools\get_flops.py", line 72, in <module>
    main()
  File "F:\opensjj\mmaction2-main\tools\analysis_tools\get_flops.py", line 58, in main
    analysis_results = get_model_complexity_info(model, input_shape)
  File "C:\Users\YA\anaconda3\envs\pytorch\lib\site-packages\mmengine\analysis\print_helper.py", line 748, in get_model_complexity_info
    flops = flop_handler.total()
  File "C:\Users\YA\anaconda3\envs\pytorch\lib\site-packages\mmengine\analysis\jit_analysis.py", line 268, in total
    stats = self._analyze()
  File "C:\Users\YA\anaconda3\envs\pytorch\lib\site-packages\mmengine\analysis\jit_analysis.py", line 570, in _analyze
    graph = _get_scoped_trace_graph(self._model, self._inputs,
  File "C:\Users\YA\anaconda3\envs\pytorch\lib\site-packages\mmengine\analysis\jit_analysis.py", line 194, in _get_scoped_trace_graph
    graph, _ = _get_trace_graph(module, inputs)
  File "C:\Users\YA\anaconda3\envs\pytorch\lib\site-packages\torch\jit\_trace.py", line 1175, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "C:\Users\YA\anaconda3\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\YA\anaconda3\envs\pytorch\lib\site-packages\torch\jit\_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "C:\Users\YA\anaconda3\envs\pytorch\lib\site-packages\torch\jit\_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "C:\Users\YA\anaconda3\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 1148, in _call_impl
    result = forward_call(*input, **kwargs)
  File "C:\Users\YA\anaconda3\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 1118, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "F:\opensjj\mmaction2-main\mmaction\models\recognizers\recognizer3d_mm.py", line 37, in extract_feat
    for m, m_data in inputs.items():
AttributeError: 'Tensor' object has no attribute 'items'
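The traceback points at the mismatch: `get_flops.py` replaces `model.forward` with `model.extract_feat`, and mmengine traces it with a single dummy tensor built from `--shape`, but the multimodal recognizer in `recognizer3d_mm.py` loops over `inputs.items()`, so it expects a dict of tensors keyed by modality. The stdlib-only sketch below mimics that mismatch. `FakeTensor`, `mm_extract_feat`, and the modality keys `imgs` / `heatmap_imgs` are hypothetical stand-ins, not mmaction2 code.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: it has a shape but no .items()."""

    def __init__(self, *shape):
        self.shape = tuple(shape)


def mm_extract_feat(inputs):
    """Rough imitation (assumed) of the loop in recognizer3d_mm.py.

    Expects a mapping {modality: tensor} and merges the first two
    dimensions of each tensor: [N, num_clips, ...] -> [N * num_clips, ...].
    """
    merged = {}
    for m, m_data in inputs.items():  # AttributeError if not a mapping
        n, clips = m_data.shape[:2]
        merged[m] = (n * clips,) + m_data.shape[2:]
    return merged


def try_extract(inputs):
    """Return the merged shapes, or the error text if the call fails."""
    try:
        return mm_extract_feat(inputs)
    except AttributeError as err:
        return f'AttributeError: {err}'


# Roughly what get_flops.py passes: one tensor built from --shape.
print(try_extract(FakeTensor(12, 3, 8, 224, 224)))
# What a multimodal recognizer expects: one tensor per modality (keys assumed).
print(try_extract({'imgs': FakeTensor(1, 1, 3, 8, 224, 224),
                   'heatmap_imgs': FakeTensor(1, 1, 17, 32, 56, 56)}))
```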

Reproduces the problem - code sample

import argparse

from mmengine import Config
from mmengine.registry import init_default_scope

from mmaction.registry import MODELS

try:
    from mmengine.analysis import get_model_complexity_info
except ImportError:
    raise ImportError('Please upgrade mmcv to >0.6.2')


def parse_args():
    parser = argparse.ArgumentParser(description='Get model flops and params')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[12,3,8,224,224],
        help='input image size')
    args = parser.parse_args()
    return args

def main():

    args = parse_args()

    if len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (1, 3) + tuple(args.shape)
    elif len(args.shape) == 4:
        # n, c, h, w = args.shape for 2D recognizer
        input_shape = tuple(args.shape)
    elif len(args.shape) == 5:
        # n, c, t, h, w = args.shape for 3D recognizer or
        # n, m, t, v, c = args.shape for GCN-based recognizer
        input_shape = tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    cfg = Config.fromfile(args.config)
    init_default_scope(cfg.get('default_scope', 'mmaction'))
    model = MODELS.build(cfg.model)
    model.eval()

    # FLOPs are counted by tracing extract_feat, which is called with a
    # single dummy tensor built from input_shape.
    if hasattr(model, 'extract_feat'):
        model.forward = model.extract_feat
    else:
        raise NotImplementedError(
            'FLOPs counter is currently not supported with {}'.
            format(model.__class__.__name__))

    analysis_results = get_model_complexity_info(model, input_shape)
    flops = analysis_results['flops_str']
    params = analysis_results['params_str']
    table = analysis_results['out_table']
    print(table)
    split_line = '=' * 30
    print(f'\n{split_line}\nInput shape: {input_shape}\n'
          f'Flops: {flops}\nParams: {params}\n{split_line}')
    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify that the '
          'flops computation is correct.')


if __name__ == '__main__':
    main()

Reproduces the problem - command or script

No response

Reproduces the problem - error message

No response

Additional information

Another question: to analyze the parameter count and FLOPs of rgbposec3d, what should be filled in here?
“default=[12,3,8,224,224],”
And what should be filled in here for rgb_only?
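For rgbposec3d there is probably no `default=` value that works, because `--shape` always produces one tensor while the multimodal `extract_feat` wants a dict. An untested workaround sketch is to skip `input_shape` and pass the dict through the `inputs` argument of `mmengine.analysis.get_model_complexity_info`, assuming your mmengine version accepts `inputs`. The keys `imgs` / `heatmap_imgs` and the shapes in `RGBPOSE_SHAPES` are assumptions (8 RGB frames at 224x224, 32 frames of 17 keypoint heatmaps at 56x56, laid out as `[N, num_clips, C, T, H, W]`); check them against the pipeline in your config. `RGBPOSE_SHAPES`, `check_shapes`, and `count_multimodal` are hypothetical names.

```python
# Assumed rgbposec3d input shapes, laid out as [N, num_clips, C, T, H, W].
# Keys and sizes are guesses from the config pipeline; verify them locally.
RGBPOSE_SHAPES = {
    'imgs': (1, 1, 3, 8, 224, 224),          # RGB clip: 8 frames, 224x224
    'heatmap_imgs': (1, 1, 17, 32, 56, 56),  # 17 keypoint heatmaps, 32 frames
}


def check_shapes(shapes):
    """Raise ValueError unless every shape is 6-D with positive int sizes.

    The multimodal extract_feat merges the first two dimensions, so each
    tensor is assumed to be laid out as [N, num_clips, C, T, H, W].
    """
    for name, shape in shapes.items():
        valid = len(shape) == 6 and all(
            isinstance(d, int) and d > 0 for d in shape)
        if not valid:
            raise ValueError(
                f'{name}: expected [N, num_clips, C, T, H, W], got {shape}')
    return shapes


def count_multimodal(config_path, shapes=RGBPOSE_SHAPES):
    """Untested sketch: run get_model_complexity_info with dict inputs.

    Heavy imports are done lazily, so check_shapes works without them.
    """
    import torch
    from mmengine import Config
    from mmengine.analysis import get_model_complexity_info
    from mmengine.registry import init_default_scope

    from mmaction.registry import MODELS

    cfg = Config.fromfile(config_path)
    init_default_scope(cfg.get('default_scope', 'mmaction'))
    model = MODELS.build(cfg.model)
    model.eval()
    model.forward = model.extract_feat  # same trick as get_flops.py

    # Wrap the dict in a 1-tuple so the tracer passes the whole dict to
    # extract_feat as its single positional argument.
    dummy = {name: torch.randn(*shape)
             for name, shape in check_shapes(shapes).items()}
    return get_model_complexity_info(model, inputs=(dummy,))
```

Usage would be something like `results = count_multimodal('path/to/rgbposec3d_config.py')` followed by `print(results['out_table'])`, `results['flops_str']`, and `results['params_str']`, the same keys the original script reads. For rgb_only, assuming its config builds a single-input `Recognizer3D`, the existing `--shape` path may already work with a five-element shape matching its RGB clip; check `model.type` in the config first.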

@LiYH1234

I have the same problem. Could you let me know if you manage to solve it? Also, with version 1.0.0 I cannot produce the loss plot or the top-1 accuracy plot. Do you know how to fix that?

@ya-92626
Author

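Quoting @LiYH1234: "I have the same problem as you. If you have solved it, could you tell me? Also, with my version 1.0.0, the loss plot and the top-1 plot cannot be generated. Do you know how to solve it?"
I haven't solved it yet.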

@ya-92626
Author

The plots can be drawn from the log.
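Assuming the remark above means plotting from the training log, here is a stdlib sketch that pulls one scalar series out of a JSON-lines log. It assumes an mmengine-style `vis_data/scalars.json` file whose lines contain keys such as `step`, `loss`, and `acc/top1`; the file location and key names are assumptions to check against your own work directory. `read_scalar_series` and `plot_series` are hypothetical helpers, and `plot_series` additionally needs matplotlib.

```python
import json


def read_scalar_series(lines, key, x_key='step'):
    """Collect (x, y) pairs for `key` from JSON-lines log records.

    Blank lines, invalid JSON, and records lacking either key are skipped,
    so training and validation records can share one file.
    """
    xs, ys = [], []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        try:
            record = json.loads(line)
        except json.JSONDecodeError:
            continue
        if key in record and x_key in record:
            xs.append(record[x_key])
            ys.append(record[key])
    return xs, ys


def plot_series(path, keys=('loss', 'acc/top1')):
    """Plot each key against `step` in its own subplot (needs matplotlib)."""
    import matplotlib.pyplot as plt

    with open(path, encoding='utf-8') as f:
        lines = f.readlines()
    fig, axes = plt.subplots(len(keys), 1, squeeze=False)
    for ax, key in zip(axes[:, 0], keys):
        xs, ys = read_scalar_series(lines, key)
        ax.plot(xs, ys)
        ax.set_xlabel('step')
        ax.set_ylabel(key)
    fig.tight_layout()
    return fig
```

Separate subplots are used because, in such logs, training and validation records may count `step` differently.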
