This is the code I changed based on Example 1. I only increased the number of grid sizes in the list. When the grid size reached 200, the loss suddenly became larger and started to jitter. I don't know what went wrong.
import sys
sys.path.append("..")
import numpy as np
import torch
from kan import *

# initialize KAN with G=3
model = KAN(width=[2,1,1], grid=3, k=3)

# create dataset
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2)

grids = np.array([3,5,10,20,50,100,200,500,1000])

train_losses = []
test_losses = []
steps = 200
k = 3

for i in range(grids.shape[0]):
    if i == 0:
        model = KAN(width=[2,1,1], grid=grids[i], k=k)
    else:
        # refine the grid, warm-starting from the previously trained model
        model = KAN(width=[2,1,1], grid=grids[i], k=k).initialize_from_another_model(model, dataset['train_input'])
    results = model.train(dataset, opt="LBFGS", steps=steps, stop_grid_update_step=30)
    train_losses += results['train_loss']
    test_losses += results['test_loss']
Hi, at high precision the results can be quite sensitive to random seeds. When I made the plot, noise_scale_base=0.0 was used by default; the default has since become noise_scale_base=0.1.
Please try whether model = KAN(width=[2,1,1], grid=3, k=3, noise_scale_base=0.0) helps. You may also try different random seeds to see how they affect the results, e.g., model = KAN(width=[2,1,1], grid=3, k=3, noise_scale_base=0.0, seed=42). Also, stop_grid_update_step=50 is used by default, while you are using 30. Overall, since changes are landing very fast, exact reproducibility is hard, but you should get something similar.
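To illustrate why the seed and noise_scale_base matter: the spline coefficients get a random perturbation at initialization, so different seeds give different starting points, and at high precision that shows up in the final loss. The following is a minimal plain-Python sketch of that idea (noisy_init is a hypothetical stand-in for the library's initialization, not its actual code):

```python
import random

def noisy_init(seed, n=5, noise_scale_base=0.1):
    """Hypothetical sketch: draw n random coefficients scaled by noise_scale_base."""
    rng = random.Random(seed)
    return [noise_scale_base * rng.gauss(0.0, 1.0) for _ in range(n)]

a = noisy_init(seed=42)
b = noisy_init(seed=42)  # same seed -> identical initialization
c = noisy_init(seed=7)   # different seed -> different initialization
assert a == b
assert a != c
# With noise_scale_base=0.0 the random component vanishes entirely:
assert noisy_init(seed=42, noise_scale_base=0.0) == [0.0] * 5
```

Fixing seed=... (and setting noise_scale_base=0.0) removes this source of run-to-run variance, which is why the suggestion above should make results more repeatable.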
The following are the training results: [loss plot attached in the original issue]