
Rename the train method to fit to avoid confusion with PyTorch's built-in train method. #205

Open
wants to merge 4 commits into master
Conversation

@linkedlist771 commented May 16, 2024

The train method in the KAN class defines the main training logic for the model. However, in PyTorch the name train is conventionally used to switch a model between training and evaluation modes, as in model.train(). In this implementation, if the user calls model.train() explicitly, it raises an error. To avoid confusion and improve code clarity, this PR renames the train method to fit.

The fit method better conveys the purpose of the function, which is to train the model using the provided dataset and hyperparameters. This change ensures that the method name does not clash with PyTorch's built-in functionality and makes it clear that it is a custom model training function.

The renaming is a minor change and does not affect the functionality of the method. All occurrences of train within the method have been replaced with fit, and any calls to the method in other parts of the codebase have been updated accordingly.

This change improves the readability and maintainability of the codebase by following a more consistent and intuitive naming convention.
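To make the clash concrete, here is a minimal sketch (the signature of the current train method shown below is an assumption for illustration, not copied from the repo). Because KAN subclasses nn.Module, a custom train method shadows nn.Module.train, so calling model.train() with no arguments to switch modes hits the custom signature instead:

import torch.nn as nn

class KAN(nn.Module):  # simplified stand-in for the real class
    def train(self, dataset, steps=100, **kwargs):  # hypothetical signature
        ...  # custom training logic lives here

model = KAN()
model.train()  # meant as nn.Module.train(mode=True), but raises
               # TypeError: missing required argument 'dataset'

# After this PR, the custom logic lives in fit(), so train()/eval()
# keep their usual PyTorch meaning:
# model.fit(dataset, steps=100)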

@ChrisD-7

I've been facing issues setting up the GPU as well.

@linkedlist771 (Author)

I've been facing issues setting up the GPU as well.

What do you mean?

@ChrisD-7 commented May 16, 2024

I tried loading the device tutorial on Colab and it didn't run.

When do you think they'll approve the model.fit change?

@linkedlist771 (Author)

Not sure about this, but as a temporary workaround you can patch the function like this (call it before you use the kan library):

from kan import KAN

def patch_kan_train_function():
    # Replace KAN.train with a no-op so that calls like model.train()
    # (PyTorch's mode switch) no longer hit the custom training method.
    def patch_train(*args, **kwargs):
        pass

    KAN.train = patch_train
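A hypothetical usage sketch (the constructor arguments are illustrative, not from the repo): apply the patch once, before any code that toggles model.train(). Note that because nn.Module.eval() delegates to train(False), the patch turns eval() into a no-op as well, so this really is only a stopgap.

patch_kan_train_function()    # apply once, before KAN is used elsewhere

model = KAN(width=[2, 5, 1])  # hypothetical constructor arguments
model.train()                 # now a no-op instead of raising an error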

@ChrisD-7

And then use it as model.fit from then on, right?

@linkedlist771 (Author)

And then use it as model.fit from then on, right?

In my specific use case, the KAN model is integrated as a component within my larger neural network model. I don't directly utilize the train method provided by the KAN class. Instead, I opt for a more customized approach, explicitly managing the training process using PyTorch's loss.backward() and optimizer.step() functions.

import torch.nn as nn
from kan import KAN

class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        ...
        self.output_module = KAN(...)  # integrate KAN as a component

# Training loop (custom, explicit control)
for train_batch in train_dataloader:
    ...
    loss.backward()  # compute gradients
    optim.step()     # update model parameters
    ...

If you prefer a Keras-style model.fit() API, you would typically rename your custom train function to fit. However, this isn't necessary in my approach, since I control the training loop directly.
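For reference, a fuller sketch of that kind of explicit loop might look like the following; the feature extractor, the KAN configuration, the loss, and the dataloader are all assumptions for illustration, not taken from this PR:

import torch
import torch.nn as nn
from kan import KAN

class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 2)            # hypothetical feature extractor
        self.output_module = KAN(width=[2, 5, 1])  # hypothetical KAN config

    def forward(self, x):
        return self.output_module(self.backbone(x))

model = MyNet()
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

for x, y in train_dataloader:  # train_dataloader defined elsewhere
    optim.zero_grad()
    pred = model(x)
    loss = loss_fn(pred, y)
    loss.backward()            # compute gradients
    optim.step()               # update model parameters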

@ChrisD-7

This is me trying to run their device example on Google Colab:
[screenshot of the error]

Their Code Example:
https://github.com/KindXiaoming/pykan/blob/master/tutorials/API_10_device.ipynb

@linkedlist771 (Author)

@KindXiaoming Could you please review this? I think this PR is important for those who want to integrate KAN into their own modules.

@KindXiaoming (Owner)

Hi @linkedlist771, thank you for your message. As I explained in an issue (I forget which), the tutorials use train() all the time, so it would be too confusing to switch to another API. One could also argue the design is deliberate, since you need to manually update the grid rather than just plug and play (I will improve the error message, though). I do understand that users want plug-and-play; I think any KAN variant that does not require a grid update can safely be used plug-and-play.
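For anyone integrating KAN as a plain module, the grid update mentioned here is typically done manually inside the training loop; a rough sketch extending the loop above (the method name update_grid_from_samples is what the tutorials appear to use, but treat the exact call, its arguments, and the update frequency as assumptions):

for step, (x, y) in enumerate(train_dataloader):
    # occasionally refresh the spline grid from the current input distribution
    if step % grid_update_freq == 0:                      # grid_update_freq is illustrative
        model.output_module.update_grid_from_samples(x)   # assumed pykan API
    ...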

@ChrisD-7

Could you also check the GPU issue, @KindXiaoming?

@KindXiaoming (Owner)

Hi @ChrisD-7, can you pull the latest version and see if the GPU issue still persists? The Apple GPU (MPS) issue doesn't seem to be solved yet, but I don't see any more issues regarding CUDA.

@linkedlist771 (Author)

Hi @linkedlist771, thank you for your message. As I explained in an issue (I forget which), the tutorials use train() all the time, so it would be too confusing to switch to another API. One could also argue the design is deliberate, since you need to manually update the grid rather than just plug and play (I will improve the error message, though). I do understand that users want plug-and-play; I think any KAN variant that does not require a grid update can safely be used plug-and-play.

This concern is valid, but for plug-and-play scenarios I believe a warning would be sufficient instead of raising an exception. Checking the invoking context to determine the user's intent (switching the model's training mode versus actually training the model) might complicate the train function, but if necessary I can write a slightly more complex train function to handle this.
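A rough sketch of what that slightly more complex train function could look like (a hypothetical dispatcher, not code from this PR): if train is called the way nn.Module.train is called, with no arguments or a single bool, fall back to the standard mode switch; otherwise run the custom training logic.

import torch.nn as nn

class KAN(nn.Module):  # simplified stand-in
    def train(self, *args, **kwargs):
        # nn.Module.train(mode=True) style call: no arguments or a single bool
        if not kwargs and (not args or (len(args) == 1 and isinstance(args[0], bool))):
            return nn.Module.train(self, *args)
        # otherwise treat it as the custom training entry point
        return self._fit(*args, **kwargs)  # hypothetical helper holding the old logic

    def _fit(self, dataset, steps=100, **kwargs):
        ...  # custom training loop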

@ChrisD-7

@KindXiaoming I tried pulling it again; could this be something with the Colab environment? I'm trying to run it for a classification model and hit a RAM issue, so I tried loading it on the GPU and ran into this error.

[screenshot of the error]
