Rename the train method to fit to avoid confusion with PyTorch's built-in train method. #205
I've been facing issues setting up the GPU as well. |
what do you mean? |
I tried loading the device file on Colab, but it didn't run. By the way, when do you think they'll approve the model.fit change? |
Not sure about this, but as a temporary solution you can patch the function like this:
from kan import KAN

def patch_kan_train_function():
    # Replace KAN.train with a no-op so that calling model.train() no longer raises.
    def patch_train(*args, **kwargs):
        pass
    KAN.train = patch_train |
And then use it as model.fit from then on, right? |
In my specific use case, the
class MyNet(nn.Module):
    def __init__(self):
        ...
        self.output_module = KAN(...)  # Integrate KAN as a component

# Training loop (custom, explicit control)
for train_batch in train_dataloader:
    ...
    loss.backward()  # Compute gradients
    optim.step()  # Update model parameters
    ...
If you prefer to use the standard Keras |
This is me trying to run this code on Google Colab with their device example. Their code example: |
@KindXiaoming Could you please review this? I think this |
Hi @linkedlist771, thank you for your message. As I explained in an issue (forgot which), the tutorials use train() all the time, so it would be too confusing to switch to another API. One could also say this design is deliberate, since you need to update the grid manually rather than just plug and play (I will improve the error message, though). I do understand that users want plug-and-play; I think any KAN variant that does not require a grid update can be used safely and directly as plug-and-play. |
Could you also check the GPU issue, @KindXiaoming? |
Hi @ChrisD-7, can you pull the latest version and see if the GPU issue still persists? The Apple GPU (MPS) issue does not seem to be solved yet, but I don't see any more issues regarding CUDA. |
This concern is valid, but for plug-and-play scenarios, I believe a warning would be sufficient instead of raising an exception. It might complicate the train function if we check the invoking function to determine the user's intent (whether to switch to the model's training mode or simply train the model). However, if necessary, I can create a slightly more complex train function to handle this issue. |
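For reference, here is a rough sketch of what such a "slightly more complex" train function could look like. It is only an illustration under stated assumptions: Module below is a small stand-in that imitates the mode switching of torch.nn.Module, and DispatchingModel, _run_training, and the dataset/steps parameters are hypothetical names, not the actual KAN signature. The function treats a call with no argument or a plain bool as a mode switch (with a warning) and anything else as a request to train.

```python
import warnings


class Module:
    """Stand-in imitating torch.nn.Module's mode switching (not the real class)."""

    def __init__(self):
        self.training = True

    def train(self, mode=True):
        self.training = bool(mode)
        return self

    def eval(self):
        # Like torch.nn.Module, eval() is implemented as train(False).
        return self.train(False)


class DispatchingModel(Module):
    """Hypothetical model whose train() serves both purposes."""

    def train(self, dataset=None, steps=1):
        # No argument or a plain bool looks like a PyTorch-style mode switch.
        if dataset is None or isinstance(dataset, bool):
            mode = True if dataset is None else dataset
            warnings.warn(
                "train() was called as a mode switch; pass a dataset to run training.",
                stacklevel=2,
            )
            return super().train(mode)
        # Anything else is treated as a request to run the training loop.
        return self._run_training(dataset, steps)

    def _run_training(self, dataset, steps):
        # Placeholder for the real optimisation loop (hypothetical).
        return {"steps": steps, "samples": len(dataset)}
```

With this sketch, model.eval() and model.train() only emit a warning and switch the mode, while model.train(dataset, steps=...) runs the training placeholder. The cost is exactly the extra branching mentioned above.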
@KindXiaoming I tried pulling it again. Could it be an issue with the Colab environment? I'm trying to run it for a classification model and ran into a RAM issue, so I thought of moving to the GPU, and then faced this issue. |
The train method in the KAN class is used to define the main training logic for the model. However, in PyTorch, the name "train" is typically used for switching the model between training and evaluation states, such as model.train(). Also, in this implementation, if the user calls model.train() explicitly, it raises an error. To avoid confusion and improve code clarity, this PR renames the train method to fit.

The fit method better conveys the purpose of the function, which is to train the model using the provided dataset and hyperparameters. This change ensures that the method name does not clash with PyTorch's built-in functionality and makes it clear that it is a custom model training function.

The renaming is a minor change and does not affect the functionality of the method. All occurrences of train within the method have been replaced with fit, and any calls to the method in other parts of the codebase have been updated accordingly.

This change improves the readability and maintainability of the codebase by following a more consistent and intuitive naming convention.
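The clash described above can be illustrated with a minimal sketch. It does not use the real KAN class or torch: Module is a stand-in that mirrors how torch.nn.Module implements eval() as train(False), and ClashingModel, RenamedModel, and the dataset/steps parameters are hypothetical names chosen for illustration.

```python
class Module:
    """Stand-in mirroring torch.nn.Module's train()/eval() behaviour (not the real class)."""

    def __init__(self):
        self.training = True

    def train(self, mode=True):
        self.training = mode
        return self

    def eval(self):
        # In PyTorch, eval() is implemented as self.train(False).
        return self.train(False)


class ClashingModel(Module):
    """Overrides train() with a training loop, shadowing the mode switch."""

    def train(self, dataset, steps=1):
        # Hypothetical training entry point; `dataset` must be a sized collection.
        return {"steps": steps, "samples": len(dataset)}


class RenamedModel(Module):
    """Same training entry point, exposed as fit() so train()/eval() stay intact."""

    def fit(self, dataset, steps=1):
        return {"steps": steps, "samples": len(dataset)}
```

In this sketch, ClashingModel().eval() fails with a TypeError because False is passed where a dataset is expected, whereas RenamedModel keeps eval() and train() working and exposes training through fit().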