Remora model finetuning #171

Open
palakela opened this issue May 2, 2024 · 8 comments

@palakela
palakela commented May 2, 2024

Hello,

I am trying to finetune this model to call 6mA (dna_r10.4.1_e8.2_5khz_400bps_sup_v4.2.0_6ma_v2.pt), but regardless of how many layers I freeze (I've also tried 0), I run into this error. Any idea what the issue could be?

Traceback (most recent call last):
  File "./bin/remora", line 8, in <module>
    sys.exit(run())
  File "./lib/python3.10/site-packages/remora/main.py", line 71, in run
    cmd_func(args)
  File "./lib/python3.10/site-packages/remora/parsers.py", line 857, in run_model_train
    train_model(
  File "./lib/python3.10/site-packages/remora/train_model.py", line 180, in train_model
    ckpt, model = model_util.continue_from_checkpoint(
  File "./lib/python3.10/site-packages/remora/model_util.py", line 247, in continue_from_checkpoint
    if ckpt["state_dict"] is None:
  File "./lib/python3.10/site-packages/torch/jit/_script.py", line 862, in __getitem__
    return self.forward_magic_method("__getitem__", idx)
  File "./lib/python3.10/site-packages/torch/jit/_script.py", line 855, in forward_magic_method
    raise NotImplementedError()
NotImplementedError

For reference, I am using remora v. 3.1.0 and this is the command:

remora model train \
	${wd}/data/prepData/train_dataset.jsn \
	--model ${wd}/data/ONT/ConvLSTM_w_ref.py \
	--finetune-path ${wd}/dna_r10.4.1_e8.2_5khz_400bps_sup_v4.2.0_6ma_v2.pt \
	--device 0 \
	--chunk-context 50 50 \
	--output-path ${wd}/data/models/train_results
@marcus1487
Collaborator

The resume feature requires a checkpoint model, not a TorchScript model (which is optimized for inference and lacks training state). You can recreate the checkpoint with the following snippet. I will try to add this to the core API to make this a bit simpler.

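import torch
from remora import model_util

# Example paths only: point these at your own input model and output file.
model_path = "dna_r10.4.1_e8.2_5khz_400bps_sup_v4.2.0_6ma_v2.pt"
out_path = "model_checkpoint.pth"

# Load the TorchScript model together with its metadata dictionary.
model, model_metadata = model_util._raw_load_torchscript_model(model_path)
model_metadata["epoch"] = 0
state_dict = model.state_dict()
# Drop these two entries if present; pop() with a default is a no-op otherwise.
state_dict.pop("total_ops", None)
state_dict.pop("total_params", None)
model_metadata["state_dict"] = state_dict
model_metadata["opt"] = None  # no optimizer state to restore
torch.save(model_metadata, out_path)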

@palakela
Author

palakela commented May 10, 2024

Thanks for clarifying!

I've managed to produce the checkpoint file to use in the remora training command. However, I am now running into a different error. Any idea where the issue is coming from?

******************** WARNING [11:35:38.198:MainProcess:MainThread:train_model.py:196]: Size mismatch between pretrained model and selected size. Using pretrained model size. ********************
Traceback (most recent call last):
  File "./bin/remora", line 8, in <module>
    sys.exit(run())
             ^^^^^
  File "./lib/python3.12/site-packages/remora/main.py", line 71, in run
    cmd_func(args)
  File "./lib/python3.12/site-packages/remora/parsers.py", line 857, in run_model_train
    train_model(
  File "./lib/python3.12/site-packages/remora/train_model.py", line 202, in train_model
    raise RemoraError(
remora.RemoraError: The chunk context of the pre-trained model and the dataset do not match.

Reference command line:

remora model train \
	${wd}/data/prepData/train_dataset.jsn \
	--model ${wd}/data/ONT/ConvLSTM_w_ref.py \
	--finetune-path ${wd}/model_checkpoint.pth \
	--device 0 \
	--chunk-context 50 50 \
	--output-path ${wd}/data/models/train_results

@marcus1487
Collaborator

Training from a checkpoint file requires that the same data input size be used. Setting --chunk-context to the same value as the pre-trained model should resolve this issue. We will look into setting these parameters automatically from the pre-trained model in the future. The best way to check this value is either to load the model using the Python API, or to export the .pt file to a dorado model using the remora model export command and view the metadata in the config.toml file produced. I will flag up making a remora model inspect command to print out this information more easily in the future.
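
Until such an inspect command exists, the metadata can also be read straight from a checkpoint file. The following is a minimal sketch, assuming the checkpoint was written with torch.save from the metadata dictionary as in the snippet earlier in this thread. The helper names summarize_metadata and print_checkpoint_metadata are hypothetical, and which keys actually appear depends on what Remora stored in the metadata.

```python
def summarize_metadata(ckpt, skip=("state_dict", "opt")):
    """Return {key: (type name, repr)} for every non-weight entry of a checkpoint dict."""
    return {
        key: (type(value).__name__, repr(value))
        for key, value in ckpt.items()
        if key not in skip
    }


def print_checkpoint_metadata(path):
    """Load a checkpoint file and print each metadata entry with its Python type."""
    import torch  # imported lazily; only needed when reading a real file

    # weights_only=False unpickles arbitrary objects, so only use it on
    # checkpoint files you created yourself.
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    for key, (type_name, value) in sorted(summarize_metadata(ckpt).items()):
        print(f"{key:>28} : {value} ({type_name})")
```

Calling print_checkpoint_metadata("model_checkpoint.pth") lists every stored setting next to its type, which makes it easy to compare against the "Dataset summary" block that remora model train prints.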

@marcus1487 changed the title from "NotImplementedError() when finetuning remora model" to "Remora model finetuning" on May 21, 2024
@palakela
Author

I have double-checked: both the pre-trained model and the training dataset have the same chunk_context [100,100]. Changing this parameter in the remora model train command does not solve the issue.
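
When both settings read as 100/100 and the check still fails, one thing that may be worth ruling out is a difference in representation rather than in value: in Python, a list and a tuple holding the same numbers compare as unequal. Whether Remora's check is affected by this is an assumption, not something confirmed in this thread. The helper compare_context below is a hypothetical diagnostic, not part of the Remora API.

```python
def compare_context(model_value, dataset_value):
    """Compare two (before, after) context settings.

    Returns "match", "type-only mismatch" (same numbers but a different
    container type, e.g. list vs tuple), or "value mismatch".
    """
    if model_value == dataset_value:
        return "match"
    if tuple(model_value) == tuple(dataset_value):
        return "type-only mismatch"
    return "value mismatch"


# A list and a tuple with identical contents are not equal in Python.
print(compare_context([100, 100], (100, 100)))  # type-only mismatch
print(compare_context((50, 50), (100, 100)))    # value mismatch
```

A "type-only mismatch" result would point at how the value was stored or loaded rather than at the --chunk-context flag itself.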

@marcus1487
Collaborator

Could you post the exact command and error message here to help resolve the issue?

@palakela
Author

Here are the full command lines I used:

remora dataset prepare \
	--output-path ${wd}data/prepData/MOCK_6mA \
	--refine-kmer-level-table ${wd}data/ONT/9mer_levels_v1.txt \
	--refine-rough-rescale \
	--motif A 0 \
	--mod-base-control \
	--max-chunks-per-read 20 \
	--num-extract-alignment-workers 24 \
	--num-extract-chunks-workers 24 \
	--chunk-context 100 100 \
	--kmer-context-bases 4 4 \
	${wd}data/6mA_unmeth.pod5 \
	${wd}data/6mA_unmeth.pass.bam

remora dataset prepare \
	--output-path ${wd}data/prepData/MOD_6mA \
	--refine-kmer-level-table ${wd}data/ONT/9mer_levels_v1.txt \
	--refine-rough-rescale \
	--motif A 0 \
	--mod-base a 6mA \
	--max-chunks-per-read 20 \
	--num-extract-alignment-workers 24 \
	--num-extract-chunks-workers 24 \
	--chunk-context 100 100 \
	--kmer-context-bases 4 4 \
	${wd}data/7_6mA.pod5 \
	${wd}data/7_6mA.pass.bam

remora dataset make_config \
	${wd}data/prepData/train_dataset.jsn \
	${wd}data/prepData/MOCK_6mA \
	${wd}data/prepData/MOD_6mA \
	--dataset-weights 1 1 \
	--log-filename ${wd}data/prepData/train_dataset.log

python make_checkpoint.py \
	--model dna_r10.4.1_e8.2_5khz_400bps_sup_v4.2.0_6ma_v2.pt \
	--output ${wd}data/models/checkpoint_model.pth

remora model train \
	${wd}data/prepData/train_dataset.jsn \
	--model ${wd}data/ONT/ConvLSTM_w_ref.py \
	--finetune-path ${wd}data/models/checkpoint_model.pth \
	--freeze-num-layers 15 \
	--device 0 \
	--chunk-context 100 100 \
	--output-path ${wd}data/models/train_results_freeze15 \
	--kmer-context-bases 4 4

make_checkpoint.py is the snippet you provided, taking the .pt file as input and returning the model checkpoint as output.
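
The script itself is not shown in the thread. A hypothetical reconstruction, assuming it simply wraps the earlier snippet behind the --model and --output flags used in the command above, might look like this. The function names build_checkpoint, parse_args and run are illustrative; only model_util._raw_load_torchscript_model comes from the original snippet.

```python
import argparse


def build_checkpoint(model, metadata):
    """Turn a loaded model plus its metadata dict into a training checkpoint dict."""
    ckpt = dict(metadata)
    state_dict = model.state_dict()
    # Remove these two entries if present, as in the original snippet.
    state_dict.pop("total_ops", None)
    state_dict.pop("total_params", None)
    ckpt["epoch"] = 0
    ckpt["state_dict"] = state_dict
    ckpt["opt"] = None  # no optimizer state to restore
    return ckpt


def parse_args(argv=None):
    parser = argparse.ArgumentParser(
        description="Convert a TorchScript Remora model into a checkpoint."
    )
    parser.add_argument("--model", required=True, help="input TorchScript .pt file")
    parser.add_argument("--output", required=True, help="output checkpoint path")
    return parser.parse_args(argv)


def run(argv=None):
    # Imported lazily so the helpers above work without these packages installed.
    import torch
    from remora import model_util

    args = parse_args(argv)
    # Private helper taken from the snippet earlier in this thread.
    model, metadata = model_util._raw_load_torchscript_model(args.model)
    torch.save(build_checkpoint(model, metadata), args.output)
```

Calling run() under the usual `if __name__ == "__main__":` guard at the bottom of the file gives the command-line behaviour shown above. Keeping build_checkpoint separate from the file handling makes the conversion easy to check on its own.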

Here is the full error message raised by the last command:

[11:30:12.563] Seed selected is 1106960644
[11:30:12.637] Loading dataset from Remora dataset config
[11:30:15.316] Dataset summary:
                     size : 33,433,400
     modified_base_labels : True
                mod_bases : ['a']
           mod_long_names : ['6mA']
       kmer_context_bases : (4, 4)
            chunk_context : (100, 100)
                   motifs : [('A', 0)]
           reverse_signal : False
 chunk_extract_base_start : False
     chunk_extract_offset : 0
          sig_map_refiner : Loaded 9-mer table with 7 central position. Rough re-scaling will be executed.

[11:30:15.317] Loading model
[11:30:15.467] Model structure:
network(
  (sig_conv1): Conv1d(1, 4, kernel_size=(5,), stride=(1,))
  (sig_bn1): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (sig_conv2): Conv1d(4, 16, kernel_size=(5,), stride=(1,))
  (sig_bn2): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (sig_conv3): Conv1d(16, 64, kernel_size=(9,), stride=(3,))
  (sig_bn3): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (seq_conv1): Conv1d(36, 16, kernel_size=(5,), stride=(1,))
  (seq_bn1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (seq_conv2): Conv1d(16, 64, kernel_size=(13,), stride=(3,))
  (seq_bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (merge_conv1): Conv1d(128, 64, kernel_size=(5,), stride=(1,))
  (merge_bn): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (lstm1): LSTM(64, 64)
  (lstm2): LSTM(64, 64)
  (fc): Linear(in_features=64, out_features=2, bias=True)
  (dropout): Dropout(p=0.3, inplace=False)
)
******************** WARNING [11:30:15.499:MainProcess:MainThread:train_model.py:196]: Size mismatch between pretrained model and selected size. Using pretrained model size. ********************
Traceback (most recent call last):
  File "./envs/remora_v3.1.0/bin/remora", line 8, in <module>
    sys.exit(run())
             ^^^^^
  File "./envs/remora_v3.1.0/lib/python3.12/site-packages/remora/main.py", line 71, in run
    cmd_func(args)
  File "./envs/remora_v3.1.0/lib/python3.12/site-packages/remora/parsers.py", line 857, in run_model_train
    train_model(
  File "./envs/remora_v3.1.0/lib/python3.12/site-packages/remora/train_model.py", line 202, in train_model
    raise RemoraError(
remora.RemoraError: The chunk context of the pre-trained model and the dataset do not match.

@marcus1487
Collaborator

I've made some minor changes around this logic in the latest version. Could you upgrade and report if this is resolved in the latest version?

@palakela
Author

palakela commented Jun 5, 2024

Update using the newly released remora v3.2.0: I ran all the commands mentioned above again, but got a very similar error:

[11:47:09.706] Seed selected is 442297807
[11:47:09.790] Loading dataset from Remora dataset config
[11:47:09.849] Dataset summary:
                     size : 33,432,925
     modified_base_labels : True
                mod_bases : ['a']
           mod_long_names : ['6mA']
       kmer_context_bases : (4, 4)
            chunk_context : (100, 100)
                   motifs : [('A', 0)]
           reverse_signal : False
 chunk_extract_base_start : False
     chunk_extract_offset : 0
               pa_scaling : None
          sig_map_refiner : Loaded 9-mer table with 7 central position. Rough re-scaling will be executed.

[11:47:09.850] Loading model
[11:47:09.949] Model structure:
network(
  (sig_conv1): Conv1d(1, 4, kernel_size=(5,), stride=(1,))
  (sig_bn1): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (sig_conv2): Conv1d(4, 16, kernel_size=(5,), stride=(1,))
  (sig_bn2): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (sig_conv3): Conv1d(16, 64, kernel_size=(9,), stride=(3,))
  (sig_bn3): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (seq_conv1): Conv1d(36, 16, kernel_size=(5,), stride=(1,))
  (seq_bn1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (seq_conv2): Conv1d(16, 64, kernel_size=(13,), stride=(3,))
  (seq_bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (merge_conv1): Conv1d(128, 64, kernel_size=(5,), stride=(1,))
  (merge_bn): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (lstm1): LSTM(64, 64)
  (lstm2): LSTM(64, 64)
  (fc): Linear(in_features=64, out_features=2, bias=True)
  (dropout): Dropout(p=0.3, inplace=False)
)
[11:47:09.949] Gradients will be clipped (by value) at 0.00 MADs above the median of the last 1000 gradient maximums.
******************** WARNING [11:47:10.002:MainProcess:MainThread:train_model.py:289]: Size mismatch between pretrained model and selected size. Using pretrained model size. ********************
Traceback (most recent call last):
  File "./envs/remora_v3.2.0/bin/remora", line 8, in <module>
    sys.exit(run())
             ^^^^^
  File "./envs/remora_v3.2.0/lib/python3.12/site-packages/remora/main.py", line 71, in run
    cmd_func(args)
  File "./envs/remora_v3.2.0/lib/python3.12/site-packages/remora/parsers.py", line 1008, in run_model_train
    train_model(
  File "./envs/remora_v3.2.0/lib/python3.12/site-packages/remora/train_model.py", line 295, in train_model
    raise RemoraError(
remora.RemoraError: The chunk context of the pre-trained model and the dataset do not match.
