Feature Request: KAN Model Does Not Support Tensor Input with More Than Two Dimensions #204
Comments
I'm encountering the same problem.
I also encountered a similar problem. I ran an experiment with the MLP and KAN set to the same number of hidden layers: with an input tensor of [64, 28×28], the memory usage of MLP and KAN is similar, but with an input tensor of [36848, 28×28], the memory usage of MLP and KAN differs considerably.
I think this is related to the number of parameters:

```python
import torch
from kan import KAN

model = KAN(width=[256, 1], grid=5, k=3, seed=0)
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params}")

model = torch.nn.Linear(256, 1)
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params}")
```

Output:

```
Total parameters: 3585
Total parameters: 257
```

If you have different findings, keep me updated.
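The gap is consistent with how the two layers are parameterized. An `nn.Linear(256, 1)` has exactly 256 weights plus 1 bias. In a B-spline KAN layer, each of the 256×1 edges carries its own set of roughly `grid + k` spline coefficients (here 5 + 3 = 8), plus per-edge scaling factors and a base branch, so the count grows multiplicatively with the grid size. A rough back-of-the-envelope check (the exact breakdown depends on the pykan version):

```python
# Rough parameter estimate; the exact term-by-term breakdown
# depends on the pykan version, so treat this as an approximation.
in_dim, out_dim = 256, 1
grid, k = 5, 3

# nn.Linear: one weight per input-output pair, plus one bias per output.
linear_params = in_dim * out_dim + out_dim

# KAN: roughly (grid + k) B-spline coefficients per edge,
# before counting per-edge scales and the base branch.
spline_coefs = in_dim * out_dim * (grid + k)

print(linear_params)  # 257, matching the nn.Linear count above
print(spline_coefs)   # 2048 spline coefficients alone
```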
In this case, sophisticated benchmark experiments should be designed to validate this approach. While not particularly challenging, the process can be monotonous: you would need to profile the resource consumption of each model configuration and analyze the storage requirements for the model's parameters, gradients, optimizer state, and input datasets. My previous code was just a simple proof of concept; I might work on this when I have some free time.
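As a starting point for such a benchmark, a minimal profiling helper could look like the sketch below (`profile_model` is a hypothetical name, not part of pykan; it counts trainable parameters and, on CUDA, the peak memory across a few forward/backward/optimizer steps, which captures parameters, gradients, activations, and optimizer state together):

```python
import torch

def profile_model(model, x, steps=3):
    # Hypothetical helper: profile one model configuration.
    # Count trainable parameters.
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    opt = torch.optim.Adam(model.parameters())
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.reset_peak_memory_stats()

    # A few training steps so gradients and optimizer state are allocated.
    for _ in range(steps):
        opt.zero_grad()
        model(x).sum().backward()
        opt.step()

    # Peak memory is only tracked on CUDA; None on CPU.
    peak = torch.cuda.max_memory_allocated() if use_cuda else None
    return params, peak
```

Running this for a grid of (model, input shape) configurations would make the MLP-vs-KAN memory comparison above reproducible.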
Description
The KAN (Kolmogorov-Arnold Network) model from the pykan library currently only supports two-dimensional input tensors (batch_size × hid_dim). A RuntimeError is raised when attempting to use a three-dimensional tensor (batch_size × atomic_number × hid_dim) as input.
Code Snippet
Error Message
Explanation
This limitation prevents the use of the KAN model in scenarios where input tensors exceed two dimensions, such as in certain natural language processing tasks where dimensions might include (batch_size x sequence_length x hid_dim).
Motivation for Using KAN Instead of MLP
The motivation for replacing MLP with KAN as the dimension reduction output module is rooted in KAN’s ability to offer more efficient computation and potentially better performance in capturing non-linear interactions between features. MLP, while versatile, can sometimes be computationally expensive and less effective at handling complex feature interactions in high-dimensional spaces. KAN's structured approach provides a promising alternative that could enhance model efficiency and effectiveness in many applications, particularly where dimensionality reduction is crucial.
Additionally, I have observed that under the same network structure, the parameter count of KAN is significantly higher than that of MLP, indicating greater model capacity.
Suggestion
It would be beneficial to update the KAN model implementation to support input tensors with arbitrary dimensions. This adjustment could mimic the functionality of "pointwise feed-forward networks" or "position-wise feed-forward networks" used in architectures like Transformers, which apply the same MLP (multi-layer perceptron) independently to each position in the sequence.
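Until pykan supports this natively, a thin wrapper can emulate the position-wise behavior by flattening all leading dimensions into the batch dimension and restoring them afterwards. A minimal sketch (`PointwiseKAN` is a hypothetical name; it assumes the wrapped module maps `(N, in_dim)` to `(N, out_dim)`):

```python
import torch
import torch.nn as nn

class PointwiseKAN(nn.Module):
    """Hypothetical wrapper: lets a 2-D-only module accept tensors of
    any rank by folding all leading dimensions into the batch axis."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module  # expected to map (N, in_dim) -> (N, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        lead = x.shape[:-1]                         # e.g. (batch, seq_len)
        out = self.module(x.reshape(-1, x.shape[-1]))
        return out.reshape(*lead, out.shape[-1])    # restore leading dims
```

With this, `PointwiseKAN(KAN(width=[256, 1], grid=5, k=3))` could be applied to a (batch_size × sequence_length × 256) tensor, at the cost of materializing the flattened (batch_size · sequence_length) batch, which is exactly where the memory concerns discussed above would show up.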