
Trouble with pure inference #4

Open · beyondbeneath opened this issue Apr 4, 2022 · 13 comments

Labels: question (Further information is requested)

Comments

@beyondbeneath commented Apr 4, 2022

Hello!

Firstly, thanks for this great work!

I managed to modify the AudioSet fine-tuning script and fine-tuned a model on a new binary audio classification task. I started with the "Tiny" Patch model and used a batch size of 2. The resulting predictions on the evaluation set looked very promising!

I'm now trying to write an inference script that takes that saved model and performs inference, and I'm running into some trouble. Which method do I actually need to call for pure inference? The documentation seems to describe only pre-training and fine-tuning, not inference.

More pressingly, I can't actually get the model to load. I am trying to load the best_audio_model.pth file as follows:

input_tdim = 1024
ast_mdl = ASTModel(label_dim=2,
                   fshape=16,
                   tshape=16,
                   fstride=10,
                   tstride=10,
                   input_fdim=128,
                   input_tdim=input_tdim,
                   model_size='tiny',
                   pretrain_stage=False,
                   load_pretrained_mdl_path=MODEL)

However, this results in the following errors:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
[/content/ssast/src/models/ast_models.py](https://localhost:8080/#) in __init__(self, label_dim, fshape, tshape, fstride, tstride, input_fdim, input_tdim, model_size, pretrain_stage, load_pretrained_mdl_path)
    146                 p_fshape, p_tshape = sd['module.v.patch_embed.proj.weight'].shape[2], sd['module.v.patch_embed.proj.weight'].shape[3]
--> 147                 p_input_fdim, p_input_tdim = sd['module.p_input_fdim'].item(), sd['module.p_input_tdim'].item()
    148             except:

KeyError: 'module.p_input_fdim'

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
1 frames
[/content/ssast/src/models/ast_models.py](https://localhost:8080/#) in __init__(self, label_dim, fshape, tshape, fstride, tstride, input_fdim, input_tdim, model_size, pretrain_stage, load_pretrained_mdl_path)
    147                 p_input_fdim, p_input_tdim = sd['module.p_input_fdim'].item(), sd['module.p_input_tdim'].item()
    148             except:
--> 149                 raise  ValueError('The model loaded is not from a torch.nn.Dataparallel object. Wrap it with torch.nn.Dataparallel and try again.')
    150 
    151             print('now load a SSL pretrained models from ' + load_pretrained_mdl_path)

ValueError: The model loaded is not from a torch.nn.Dataparallel object. Wrap it with torch.nn.Dataparallel and try again.

Is there anything obvious I'm missing or doing wrong? I'd appreciate any guidance on how to load this model and perform inference on a new .wav file. Thanks!

@YuanGongND (Owner) commented Apr 4, 2022

Thanks for the kind words.

We use multiple GPUs to train the model, so the saved model is a torch.nn.DataParallel object. Even if you only want to do single-GPU inference, you need to do the following:

input_tdim = 1024
ast_mdl = ASTModel(label_dim=2,
                   fshape=16,
                   tshape=16,
                   fstride=10,
                   tstride=10,
                   input_fdim=128,
                   input_tdim=input_tdim,
                   model_size='tiny',
                   pretrain_stage=False,
                   load_pretrained_mdl_path=MODEL)
# convert it to a torch.nn.DataParallel object
ast_mdl = torch.nn.DataParallel(ast_mdl)
# then do inference as normal (use the same task as fine-tuning, e.g. 'ft_avgtok')
output = ast_mdl(input, task='ft_avgtok')

Another option is to convert the torch.nn.DataParallel model back to a plain torch model object (see the sketch below); you can search online for more details.

-Yuan
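
For reference, a minimal sketch of that conversion, assuming the usual DataParallel checkpoint layout (every key prefixed with module.); this is not code from this repo:

import torch

# load the DataParallel checkpoint and strip the 'module.' prefix from each
# key so the weights fit a plain, unwrapped model (the path is illustrative)
sd = torch.load('best_audio_model.pth', map_location='cpu')
plain_sd = {(k[len('module.'):] if k.startswith('module.') else k): v
            for k, v in sd.items()}
# plain_sd can then be loaded into an unwrapped ASTModel:
# ast_mdl.load_state_dict(plain_sd)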

@YuanGongND (Owner) commented Apr 4, 2022

Also, the model input should be a spectrogram processed with the same normalization and feature extraction as in training:

fbank = (fbank - self.norm_mean) / (self.norm_std * 2)

and (from ssast/src/dataloader.py, lines 126 to 127 at 35ae7ab):

fbank = torchaudio.compliance.kaldi.fbank(waveform, htk_compat=True, sample_frequency=sr, use_energy=False,
                                          window_type='hanning', num_mel_bins=self.melbins, dither=0.0, frame_shift=10)

You can also refer to https://github.com/YuanGongND/ast/blob/master/egs/audioset/inference.py.
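
For reference, a minimal end-to-end sketch of preparing a new .wav for inference with the fbank settings and normalization quoted above. The mean/std and target length here are the values from the fine-tuning script later in this thread; substitute your own dataset statistics (wav_to_fbank is just an illustrative name):

import torch
import torchaudio

def wav_to_fbank(path, target_length=1024, mel_bins=128,
                 norm_mean=-4.2677393, norm_std=4.5689974):
    waveform, sr = torchaudio.load(path)
    waveform = waveform - waveform.mean()  # assumption: same mean removal as the dataloader
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform, htk_compat=True, sample_frequency=sr, use_energy=False,
        window_type='hanning', num_mel_bins=mel_bins, dither=0.0, frame_shift=10)
    # pad or truncate the time axis to target_length frames
    n_frames = fbank.shape[0]
    if n_frames < target_length:
        fbank = torch.nn.functional.pad(fbank, (0, 0, 0, target_length - n_frames))
    else:
        fbank = fbank[:target_length, :]
    # apply the same normalization as the dataloader
    fbank = (fbank - norm_mean) / (norm_std * 2)
    return fbank.unsqueeze(0)  # (1, time, freq) batch for the model

# usage, e.g.: output = ast_mdl(wav_to_fbank('my_clip.wav'), task='ft_avgtok')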

YuanGongND added the question label on Apr 4, 2022

@beyondbeneath (Author) commented
Thanks, Yuan, for your suggestions.

To be clear, where do I add the DataParallel wrapper? In your example you put it after the ASTModel() call, but it is that initial call which fails to load. So I suspect I need to modify ast_models.py; or are you suggesting I instead convert the serialized model to a parallel one?

@YuanGongND (Owner) commented
You should do something like this:
https://github.com/YuanGongND/ast/blob/7b2fe7084b622e540643b0d7d7ab736b5eb7683b/egs/audioset/inference.py#L82-L89

i.e., call audio_model.load_state_dict(checkpoint) after converting it to a DataParallel object.

@YuanGongND (Owner) commented
I don't suggest changing ast_models.py. Something like the below should work:

input_tdim = 1024
ast_mdl = ASTModel(label_dim=2,
                   fshape=16,
                   tshape=16,
                   fstride=10,
                   tstride=10,
                   input_fdim=128,
                   input_tdim=input_tdim,
                   model_size='tiny',
                   pretrain_stage=False,
                   load_pretrained_mdl_path=MODEL)
# convert it to a torch.nn.DataParallel object
ast_mdl = torch.nn.DataParallel(ast_mdl)
# load the checkpoint
checkpoint = torch.load(checkpoint_path, map_location='cuda')
ast_mdl.load_state_dict(checkpoint)
# then do inference as normal (same task as fine-tuning)
output = ast_mdl(input, task='ft_avgtok')

@beyondbeneath (Author) commented Apr 4, 2022

Sorry I might not have been clear here.

The ast_mdl = ASTModel(...) line is the one that fails, so I cannot convert the model afterwards, since that line never completes.

In that call, MODEL is the saved model file experiment/models/best_audio_model.pth from my previous fine-tuning run.

Or are you suggesting that I load the pre-trained model provided by this repo (SSAST-Tiny-Patch-400) and then load my checkpoint on top of it (and if so, is the checkpoint best_audio_model.pth or best_optim_state.pth)?

Does that make sense?

@YuanGongND (Owner) commented
That's weird; if you used my recipe to fine-tune the model, the saved model should already be a DataParallel object.

@beyondbeneath (Author) commented Apr 4, 2022

Yes... I most certainly used your code.

I essentially used the AudioSet fine-tuning script; here is the full .sh file with my modifications (there aren't many):


set -x
export TORCH_HOME=../../pretrained_models
mkdir -p ./exp

if [ -e SSAST-Tiny-Patch-400.pth ]
then
    echo "pretrained model already downloaded."
else
    wget https://www.dropbox.com/s/ewrzpco95n9jdz6/SSAST-Tiny-Patch-400.pth?dl=1 -O SSAST-Tiny-Patch-400.pth
fi

pretrain_exp=
pretrain_model=SSAST-Tiny-Patch-400
pretrain_path=./${pretrain_exp}/${pretrain_model}.pth

dataset=testdata

dataset_mean=-4.2677393
dataset_std=4.5689974
target_length=1024
noise=False

task=ft_avgtok
model_size=base
head_lr=1
warmup=True

bal=none
lr=5e-5
epoch=25
tr_data=/content/drive/MyDrive/ssast_train1.json
te_data=/content/drive/MyDrive/ssast_val1.json
freqm=48
timem=192
mixup=0.5
fstride=10
tstride=10
fshape=16
tshape=16
batch_size=2
exp_dir=./exp/test01-${dataset}-f${fstride}-${fshape}-t${tstride}-${tshape}-b${batch_size}-lr${lr}-${task}-${model_size}-${pretrain_exp}-${pretrain_model}-${head_lr}x-noise${noise}-3

CUDA_CACHE_DISABLE=1 python -W ignore ../../run.py --dataset ${dataset} \
--data-train ${tr_data} --data-val ${te_data} --exp-dir $exp_dir \
--label-csv ./data/class_labels_indices.csv --n_class 2 \
--lr $lr --n-epochs ${epoch} --batch-size $batch_size --save_model False \
--freqm $freqm --timem $timem --mixup ${mixup} --bal ${bal} \
--tstride $tstride --fstride $fstride --fshape ${fshape} --tshape ${tshape} --warmup False --task ${task} \
--model_size ${model_size} --adaptschedule False \
--pretrained_mdl_path ${pretrain_path} \
--dataset_mean ${dataset_mean} --dataset_std ${dataset_std} --target_length ${target_length} \
--num_mel_bins 128 --head_lr ${head_lr} --noise ${noise} \
--lrscheduler_start 10 --lrscheduler_step 5 --lrscheduler_decay 0.5 --wa True --wa_start 6 --wa_end 25 \
--loss BCE --metrics mAP

@YuanGongND (Owner) commented Apr 4, 2022

I see. It might be caused by a bug in the code. I didn't consider your use case.

If the model is not too large, can you send the .pth file to me at yuangong@mit.edu?

I can take a look, but not immediately; I'll need to find some spare time.

@beyondbeneath (Author) commented
FWIW, I got some kind of inference pipeline running, although the results do not match the output originally generated by your fine-tuning recipe, so I'm guessing there are major bugs in what I got working. I thought it might be relevant anyway. This was all done after I successfully ran the fine-tuning scripts on a new dataset for binary classification.

First, use the same JSON-style approach to make a dataloader (using your AudioDataset dataloader):

audio_conf = {
    'num_mel_bins': 128,
    'target_length': 1024,
    'freqm': 48,
    'timem': 192,
    'mixup': 0.5,
    'dataset': 'testdata',
    'mode':'evaluation',
    'mean':-4.2677393,
    'std':4.5689974,
    'noise':False
    }

train_loader = torch.utils.data.DataLoader(
    dataloader.AudioDataset(val_json,
                            label_csv=labels_csv,
                            audio_conf=audio_conf
                            ),
    batch_size=1,
    shuffle=False)

Next, load the original pre-trained model from which I fine-tuned:

input_tdim = 1024
ast_mdl = ASTModel(label_dim=2,
                   fshape=16,
                   tshape=16,
                   fstride=10,
                   tstride=10,
                   input_fdim=128,
                   input_tdim=input_tdim,
                   model_size='tiny',
                   pretrain_stage=False,
                   load_pretrained_mdl_path='SSAST-Tiny-Patch-400.pth')

Then, load the state checkpoint into it (I have no idea if this works as expected, but it's the only way I got anything to run without errors):

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sd = torch.load('best_optim_state.pth', map_location=device)
if not isinstance(ast_mdl, torch.nn.DataParallel):
  ast_mdl = torch.nn.DataParallel(ast_mdl)
ast_mdl.load_state_dict(sd, strict=False)

Then, copying bits and pieces from the supplied traintest.py:

ast_mdl.eval()
with torch.no_grad():
  for i, (audio_input, labels) in enumerate(train_loader):
    prediction = torch.sigmoid(ast_mdl(audio_input, task='ft_avgtok')).to('cpu').detach()
    print(prediction.shape)
    print(np.array(prediction))

For single samples, the prediction probabilities do not sum to 1, nor do they match the values I expected from the supplied fine-tuning recipe.

@YuanGongND (Owner) commented
The problem is that I used a trick to encode the pretraining hyperparameters in the model, and I use the existence of those hyperparameters to check whether the model is a DataParallel object. The SSL pretraining code does save the hyperparameters, but the fine-tuning code does not, so when you load a fine-tuned model, the code cannot find the hyperparameters and thinks the model is not a DataParallel object:

p_fshape, p_tshape = sd['module.v.patch_embed.proj.weight'].shape[2], sd['module.v.patch_embed.proj.weight'].shape[3]
p_input_fdim, p_input_tdim = sd['module.p_input_fdim'].item(), sd['module.p_input_tdim'].item()

As a temporary workaround, you can change these two lines of code.

I will find time to fix it.
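
One possible shape for that workaround, sketched here purely as an assumption (it presumes a fine-tuned checkpoint was trained with the same input dimensions you pass to the constructor), is to fall back instead of raising:

# hypothetical edit inside ASTModel.__init__ in src/models/ast_models.py
p_fshape = sd['module.v.patch_embed.proj.weight'].shape[2]
p_tshape = sd['module.v.patch_embed.proj.weight'].shape[3]
try:
    p_input_fdim = sd['module.p_input_fdim'].item()
    p_input_tdim = sd['module.p_input_tdim'].item()
except KeyError:
    # assumption: a fine-tuned checkpoint lacking the saved pretraining
    # hyperparameters was trained with the dims passed to this constructor
    p_input_fdim, p_input_tdim = input_fdim, input_tdim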

@fanOfJava commented (quoting the workaround above)

Even after changing these two lines, I still cannot load the fine-tuned model.

@Lindar1994 commented
Hi, I am running into exactly the same error and am having trouble loading a fine-tuned model. I am wondering if @beyondbeneath ever found a solution to this?
