Missing eq_param_index when loading BundleSolution #211

Open
ptflores1 opened this issue Dec 30, 2023 · 4 comments · May be fixed by #212

@ptflores1

When calling BundleSolution1D.load(), the new instance is initialized without an eq_param_index argument. Because the _diff_eqs_wrapper function reads eq_param_index from a local variable captured at construction time, there is no way to recover the original value after loading, and the functionality is lost.
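A minimal sketch of the problem, assuming a simplified solver. `ToyBundleSolver`, `load_without_index`, and the toy `ode_system` are hypothetical and not part of neurodiffeq. The wrapper closes over the constructor's local `eq_param_index`, so an instance rebuilt without that argument forwards no bundle parameters, and setting the attribute afterwards does not change the closure.

```python
class ToyBundleSolver:
    """Hypothetical stand-in for a bundle solver; not neurodiffeq code."""

    def __init__(self, ode_system, n_funcs, eq_param_index=()):
        # Stored on the instance, but the wrapper below never reads self.
        self.eq_param_index = tuple(eq_param_index)

        def _diff_eqs_wrapper(*variables):
            # eq_param_index is __init__'s local, captured by this closure.
            funcs_and_t = variables[: n_funcs + 1]
            params = tuple(variables[n_funcs + 1 + i] for i in eq_param_index)
            return ode_system(*funcs_and_t, *params)

        self.diff_eqs = _diff_eqs_wrapper

    @classmethod
    def load_without_index(cls, ode_system, n_funcs):
        # Mimics the loader described above: eq_param_index is not forwarded.
        return cls(ode_system, n_funcs)


def ode_system(u, t, *params):
    # Toy system that just reports which bundle parameters reached it.
    return params


original = ToyBundleSolver(ode_system, n_funcs=1, eq_param_index=(0,))
print(original.diff_eqs(1.0, 0.5, 2.0))  # (2.0,)

reloaded = ToyBundleSolver.load_without_index(ode_system, n_funcs=1)
reloaded.eq_param_index = (0,)  # patching the attribute does not reach the closure
print(reloaded.diff_eqs(1.0, 0.5, 2.0))  # ()
```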

@ptflores1
Author

The code in

elif load_dict["type_name"] == "BundleSolver1D" or load_dict["parent_type_name"] == "BundleSolver1D":
    t_min = load_dict['solver'].r_min[0]
    t_max = load_dict['solver'].r_max[0]
    solver = cls(ode_system=de_system,
                 conditions=cond,
                 metrics=load_dict['metrics'],
                 nets=nets,
                 optimizer=optimizer,
                 train_generator=train_generator,
                 valid_generator=valid_generator,
                 t_min=t_min,
                 t_max=t_max,
                 theta_min=tuple(load_dict['solver'].r_min[1:]),
                 theta_max=tuple(load_dict['solver'].r_max[1:]))

can be modified to

elif load_dict["type_name"] == "BundleSolver1D" or load_dict["parent_type_name"] == "BundleSolver1D":
    t_min = load_dict['solver'].r_min[0]
    t_max = load_dict['solver'].r_max[0]

    solver = cls(ode_system=de_system,
                 conditions=cond,
                 metrics=load_dict['metrics'],
                 nets=nets,
                 optimizer=optimizer,
                 train_generator=train_generator,
                 valid_generator=valid_generator,
                 t_min=t_min,
                 t_max=t_max,
                 theta_min=tuple(load_dict['solver'].r_min[1:]),
                 theta_max=tuple(load_dict['solver'].r_max[1:]),
                 eq_param_index=tuple(index - len(cond) - 1 for index in load_dict['solver'].eq_param_index))  # new line
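Writing the indices as a tuple rather than a bare generator expression matters: a generator can be iterated only once, so any code that loops over `eq_param_index` more than once (for example, a wrapper called on every training step) would see the indices only the first time. A minimal sketch, where `stored_index` and `n_cond` are hypothetical values standing in for `load_dict['solver'].eq_param_index` and `len(cond)`:

```python
# Hypothetical stand-ins for load_dict['solver'].eq_param_index and len(cond).
stored_index = (3, 4)
n_cond = 2

as_generator = (i - n_cond - 1 for i in stored_index)
print(list(as_generator))  # [0, 1]
print(list(as_generator))  # [] : the generator is already exhausted

as_tuple = tuple(i - n_cond - 1 for i in stored_index)
print(list(as_tuple))  # [0, 1]
print(list(as_tuple))  # [0, 1] : a tuple can be iterated repeatedly
```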

@ptflores1
Author

@shuheng-liu would that be okay for a PR?

@sathvikbhagavan
Collaborator

sathvikbhagavan commented Dec 31, 2023

Thanks for the report! I don't think we need to recompute it, since it is already stored in https://github.com/NeuroDiffGym/neurodiffeq/blob/master/neurodiffeq/solvers.py#L1354

It should be fixed in #212.
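A minimal sketch of that idea, assuming a simplified solver. The class, `save_dict`, and `load` below are hypothetical and not taken from #212. The loader forwards the stored `eq_param_index` unchanged, so the constructor rebuilds the wrapper's closure with the original indices and no index arithmetic is needed.

```python
class ToyBundleSolver:
    """Hypothetical stand-in for a bundle solver; not neurodiffeq code."""

    def __init__(self, ode_system, n_funcs, eq_param_index=()):
        self.n_funcs = n_funcs
        self.eq_param_index = tuple(eq_param_index)
        idx = self.eq_param_index  # captured by the wrapper below

        def _diff_eqs_wrapper(*variables):
            funcs_and_t = variables[: n_funcs + 1]
            params = tuple(variables[n_funcs + 1 + i] for i in idx)
            return ode_system(*funcs_and_t, *params)

        self.diff_eqs = _diff_eqs_wrapper

    def save_dict(self):
        # Persist constructor inputs, including eq_param_index.
        return {"n_funcs": self.n_funcs, "eq_param_index": self.eq_param_index}

    @classmethod
    def load(cls, ode_system, load_dict):
        # Pass the stored value through as-is so the closure is rebuilt.
        return cls(ode_system, load_dict["n_funcs"],
                   eq_param_index=load_dict["eq_param_index"])


def ode_system(u, t, *params):
    # Toy system that just reports which bundle parameters reached it.
    return params


original = ToyBundleSolver(ode_system, n_funcs=1, eq_param_index=(0,))
restored = ToyBundleSolver.load(ode_system, original.save_dict())
print(restored.diff_eqs(1.0, 0.5, 2.0))  # (2.0,)
```

Here `ode_system` is passed to `load` explicitly because the sketch does not serialize functions; how the real loader recovers the system is outside the scope of this example.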

@ptflores1
Author

Great! Thanks
