
TypeError: train() missing 3 required positional arguments: 'data_loader', 'optimizer', and 'epochs' #171

Open
yhliu2022 opened this issue Nov 30, 2022 · 1 comment
Labels
xgraph Interpretability of Graph Neural Networks

Comments

@yhliu2022

While testing the SubgraphX example:

```python
explainer = SubgraphX(grace, num_classes=4, device=device,
                      explain_graph=False, reward_method='nc_mc_l_shapley')
```

I get this error:

```
TypeError                                 Traceback (most recent call last)
Input In [19], in <cell line: 1>()
----> 1 explainer = SubgraphX(grace, num_classes=4, device=device,
      2                       explain_graph=False, reward_method='nc_mc_l_shapley')

File ~\DIG\dig\xgraph\method\subgraphx.py:636, in SubgraphX.__init__(self, model, num_classes, device, num_hops, verbose, explain_graph, rollout, min_atoms, c_puct, expand_atoms, high2low, local_radius, sample_num, reward_method, subgraph_building_method, save_dir, filename, vis)
    629 def __init__(self, model, num_classes: int, device, num_hops: Optional[int] = None, verbose: bool = False,
    630              explain_graph: bool = True, rollout: int = 20, min_atoms: int = 5, c_puct: float = 10.0,
    631              expand_atoms=14, high2low=False, local_radius=4, sample_num=100, reward_method='mc_l_shapley',
    632              subgraph_building_method='zero_filling', save_dir: Optional[str] = None,
    633              filename: str = 'example', vis: bool = True):
    635     self.model = model
--> 636     self.model.eval()
    637     self.device = device
    638     self.model.to(self.device)

File ~\anaconda3\envs\tf\lib\site-packages\torch\nn\modules\module.py:1926, in Module.eval(self)
   1910 def eval(self: T) -> T:
   1911     r"""Sets the module in evaluation mode.
   1912
   1913     This has any effect only on certain modules. See documentations of
   (...)
   1924         Module: self
   1925     """
-> 1926     return self.train(False)

TypeError: train() missing 3 required positional arguments: 'data_loader', 'optimizer', and 'epochs'
```
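The last frame shows that `torch.nn.Module.eval()` simply returns `self.train(False)`. A likely cause, assuming the `grace` model class defines its own `train(...)` method for its training loop (the model code is not shown in this issue), is that this method shadows `nn.Module.train(mode)`. The `self.model.eval()` call in `SubgraphX.__init__` then dispatches to the training routine and fails on its missing arguments. The sketch below reproduces the pattern with two hypothetical toy classes, `BrokenEncoder` and `FixedEncoder`, and shows one workaround: renaming the training loop (here to a hypothetical `fit`) so that `train(mode)` and `eval()` keep their standard `nn.Module` behavior.

```python
import torch
import torch.nn.functional as F
from torch import nn


class BrokenEncoder(nn.Module):
    """Hypothetical toy model whose own train() shadows nn.Module.train(mode)."""

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 2)

    def forward(self, x):
        return self.lin(x)

    # Training loop named train(): this overrides nn.Module.train(mode=True),
    # so nn.Module.eval(), which calls self.train(False), ends up here and
    # raises TypeError because optimizer and epochs are not supplied.
    def train(self, data_loader, optimizer, epochs):
        raise NotImplementedError("training loop omitted in this sketch")


class FixedEncoder(nn.Module):
    """Same toy model with the loop renamed to a hypothetical fit()."""

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 2)

    def forward(self, x):
        return self.lin(x)

    def fit(self, data_loader, optimizer, epochs):
        self.train()  # the standard nn.Module.train(mode=True), not overridden
        for _ in range(epochs):
            for x, y in data_loader:
                optimizer.zero_grad()
                loss = F.mse_loss(self(x), y)
                loss.backward()
                optimizer.step()
        return self


if __name__ == "__main__":
    try:
        BrokenEncoder().eval()
    except TypeError as err:
        print("BrokenEncoder.eval() failed:", err)

    model = FixedEncoder()
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    batches = [(torch.randn(8, 4), torch.randn(8, 2))]  # toy stand-in for a DataLoader
    model.fit(batches, opt, epochs=3)
    model.eval()  # dispatches to nn.Module.train(False) as intended
    print("training mode after eval():", model.training)
```

The exact argument list in the error message depends on the model's own `train` signature, so this toy version reports different missing arguments than the traceback above. With the loop renamed, `self.model.eval()` and `self.model.to(self.device)` in `SubgraphX.__init__` should behave as ordinary `nn.Module` calls.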

@Oceanusity
Collaborator

Hello, would you mind providing more details about the model used here? Judging from the provided traceback, the error seems to come from the call `self.model.eval()`.

ycremar added the xgraph (Interpretability of Graph Neural Networks) label on Jan 4, 2023