
Bug with input_dim and pure inference #17

Open
fanOfJava opened this issue Mar 21, 2023 · 11 comments
Labels
bug Something isn't working

Comments

@fanOfJava

In the fine-tuning stage, no matter what input_tdim is set in run.sh, it always ends up being 1024.

@YuanGongND
Owner

Hi,

Can you elaborate on which argument you are referring to? Is it

target_length=512

Thanks!

-Yuan

@fanOfJava
Author

yes

@YuanGongND
Owner

Can you explain why the value would be 1024?

It seems to me that it is used in

ssast/src/run.py, lines 97 to 101 in a1a3eec:

audio_conf = {'num_mel_bins': args.num_mel_bins, 'target_length': args.target_length, 'freqm': args.freqm, 'timem': args.timem, 'mixup': args.mixup, 'dataset': args.dataset,
              'mode': 'train', 'mean': args.dataset_mean, 'std': args.dataset_std, 'noise': args.noise}
val_audio_conf = {'num_mel_bins': args.num_mel_bins, 'target_length': args.target_length, 'freqm': 0, 'timem': 0, 'mixup': 0, 'dataset': args.dataset,
                  'mode': 'evaluation', 'mean': args.dataset_mean, 'std': args.dataset_std, 'noise': False}

and

ssast/src/run.py, lines 132 to 138 in a1a3eec:

    audio_model = ASTModel(fshape=args.fshape, tshape=args.tshape, fstride=args.fshape, tstride=args.tshape,
                           input_fdim=args.num_mel_bins, input_tdim=args.target_length, model_size=args.model_size, pretrain_stage=True)
# in the fine-tuning stage
else:
    audio_model = ASTModel(label_dim=args.n_class, fshape=args.fshape, tshape=args.tshape, fstride=args.fstride, tstride=args.tstride,
                           input_fdim=args.num_mel_bins, input_tdim=args.target_length, model_size=args.model_size, pretrain_stage=False,
                           load_pretrained_mdl_path=args.pretrained_mdl_path)

for both data loading and model instantiation.
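
If you want to check this on your side, here is a minimal, untested sketch; the ASTModel import path, the hyper-parameter values, and the 'ft_avgtok' task name are assumptions that you would need to match to your own run:

```python
import torch
from models import ASTModel  # assumed import path, as used in ssast/src/run.py

# Hypothetical check: build the fine-tuning model the same way run.py does with
# --target_length 512, then feed a dummy batch shaped (batch, time_frames, mel_bins).
# If the model were silently built for input_tdim=1024, this forward pass would fail
# with a shape mismatch on the positional embedding.
audio_model = ASTModel(label_dim=35, fshape=16, tshape=16, fstride=10, tstride=10,
                       input_fdim=128, input_tdim=512, model_size='base',
                       pretrain_stage=False,
                       load_pretrained_mdl_path='ssast-base-patch-400.pth')
dummy = torch.zeros(1, 512, 128)
out = audio_model(dummy, task='ft_avgtok')
print(out.shape)
```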

@fanOfJava
Author

Because the process of loading the model file ssast-base-patch-400.pth overrides target_length. The relevant code (an excerpt from ast_models.py) is shown below:

try:
    p_fshape, p_tshape = sd['module.v.patch_embed.proj.weight'].shape[2], sd['module.v.patch_embed.proj.weight'].shape[3]
    p_input_fdim, p_input_tdim = sd['module.p_input_fdim'].item(), sd['module.p_input_tdim'].item()
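
You can confirm what the checkpoint stored with a quick inspection (a minimal sketch; the key names are the ones from the excerpt above, and the path is whatever you pass as --pretrained_mdl_path):

```python
import torch

# Load the self-supervised checkpoint's state dict and print the input dims it recorded.
# These stored values, not the target_length passed on the command line, are what the
# fine-tuning constructor reads back via sd['module.p_input_fdim'] / sd['module.p_input_tdim'].
sd = torch.load('ssast-base-patch-400.pth', map_location='cpu')
print(sd['module.p_input_fdim'].item(), sd['module.p_input_tdim'].item())  # e.g. 128 1024
```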

@fanOfJava
Author

I suspect this is also why loading the model file for pure inference throws an error after fine-tuning. I'm not sure whether my understanding is correct.

@YuanGongND
Owner

Can you paste the error message here?

@fanOfJava
Author

> Can you paste the error message here?

If you print p_input_tdim before line 156 of ast_models.py, you will see the error.

@YuanGongND
Owner

I don't have enough time to run it again. The code is a cleaned-up version of the development version; it went through a brief test and I thought I had taken care of this. So if you already have an error message, that would be very helpful. It might be due to something else.

@fanOfJava
Author

I believe many people have the same problem. The model saved after fine-tuning simply cannot be loaded for pure inference, so I don't know how to test the real performance of the trained model.

@YuanGongND
Owner

Oh I see, yes, that is a known problem. It should be fine to fine-tune a pretrained model that has a different target_length, but if you want to take the fine-tuned model for deployment, you will get an error.

For checking performance: once you fine-tune a pretrained model, the script will print out the accuracy (or mAP) and also save the result to disk.

To deploy the model for inference, you will need to fix the bug.
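
One possible workaround (a rough sketch, not a tested fix; the import path, the hyper-parameters, and the checkpoint file names are placeholders that have to match your own fine-tuning run): rebuild the model exactly as the fine-tuning branch of run.py does, still pointing load_pretrained_mdl_path at the original self-supervised checkpoint, and then overwrite its weights with the fine-tuned state dict.

```python
import torch
from models import ASTModel  # assumed import path, as in ssast/src/run.py

# Recreate the architecture used during fine-tuning so the patch and positional
# embeddings line up with the fine-tuned weights.
audio_model = ASTModel(label_dim=35, fshape=16, tshape=16, fstride=10, tstride=10,
                       input_fdim=128, input_tdim=512, model_size='base',
                       pretrain_stage=False,
                       load_pretrained_mdl_path='ssast-base-patch-400.pth')

# The fine-tuned checkpoint is assumed to have been saved from a DataParallel-wrapped
# model, so wrap it the same way before loading; 'best_audio_model.pth' is a placeholder.
audio_model = torch.nn.DataParallel(audio_model)
sd = torch.load('best_audio_model.pth', map_location='cpu')
audio_model.load_state_dict(sd, strict=False)  # strict=False so minor key mismatches do not abort loading
audio_model.eval()
```

Every argument must match the exact values used during fine-tuning; otherwise the state dict will not fit the rebuilt architecture.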

@YuanGongND
Owner

Can you check this: #4

@YuanGongND added the bug label Mar 21, 2023
@YuanGongND changed the title from "这个代码貌似有问题" ("This code seems to have a problem") to "Bug with input_dim and pure inference" Mar 21, 2023