Error in evaluate function in Tutorial_Perturbation.ipynb #172

Open
Yonggie opened this issue Mar 27, 2024 · 2 comments


Yonggie commented Mar 27, 2024

I loaded the parameters from the full human link and tried the tutorial, but hit this error:

def evaluate(model: nn.Module, val_loader: torch.utils.data.DataLoader) -> float:
    """
    Evaluate the model on the evaluation data.
    """
    model.eval()
    total_loss = 0.0
    total_error = 0.0

    with torch.no_grad():
        for batch, batch_data in enumerate(val_loader):
            batch_size = len(batch_data.y)
            batch_data.to(device)
            x: torch.Tensor = batch_data.x  # (batch_size * n_genes, 2)
            ori_gene_values = x[:, 0].view(batch_size, n_genes)
            pert_flags = x[:, 1].long().view(batch_size, n_genes)  # error here

Exception has occurred: IndexError
index 1 is out of bounds for dimension 1 with size 1
  File "/p", line 224, in evaluate
    pert_flags = x[:, 1].long().view(batch_size, n_genes)
  File "t.py", line 282, in <module>
    val_loss, val_mre = evaluate(
IndexError: index 1 is out of bounds for dimension 1 with size 1
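
The traceback reports that dimension 1 of x has size 1, so each row carries a single feature column, while the tutorial code expects two (the original expression value in column 0 and the perturbation flag in column 1). Below is a minimal sketch of a guard that surfaces the mismatch with a clearer message before the indexing line runs. `require_columns` is a hypothetical helper, not part of the tutorial or of GEARS, and it only inspects a shape tuple such as `tuple(batch_data.x.shape)`.

```python
def require_columns(shape, needed: int) -> None:
    """Raise a descriptive error if a 2-D feature matrix has fewer than `needed` columns.

    `shape` is anything tuple-like, e.g. tuple(batch_data.x.shape).
    """
    shape = tuple(shape)
    if len(shape) != 2:
        raise ValueError(f"expected a 2-D feature matrix, got shape {shape}")
    if shape[1] < needed:
        raise ValueError(
            f"feature matrix has {shape[1]} column(s) but {needed} are needed; "
            "the loaded data may have been built by a different package version"
        )


# Shape like the one in the traceback: one column only, so index 1 is out of bounds.
try:
    require_columns((6, 1), needed=2)
except ValueError as err:
    print(err)

# Shape the tutorial expects: (batch_size * n_genes, 2). This call passes silently.
require_columns((6, 2), needed=2)
```

Calling the guard right after `x = batch_data.x` inside `evaluate` would turn the bare IndexError into a message that points at the data layout.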

Yonggie commented Mar 27, 2024

I tried pip install cell-gears==0.0.1, deleted the older copy of the data, and let it download again.

That produced a different error:

# code
train_loader = pert_data.dataloader["train_loader"]

# hint
'PertData' object has no attribute 'dataloader'
AttributeError: 'PertData' object has no attribute 'dataloader'


Yonggie commented Mar 27, 2024

I just checked the source code of gears==0.0.1. With that version, one more change is needed in the tutorial:
replace pert_data.dataloader[xxx] with pert_data.get_dataloader(**required_para)[xxx].

With that change, it worked.
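
For anyone who wants the tutorial code to tolerate both access styles mentioned in this thread, here is a minimal sketch. It assumes only what is reported above: some versions expose a `dataloader` attribute, and in 0.0.1 `get_dataloader(...)` returns something that can be indexed by loader name. `resolve_dataloaders`, the stand-in class, and the `batch_size` keyword are hypothetical and used for illustration; check the installed source for the real parameters.

```python
def resolve_dataloaders(pert_data, **loader_kwargs):
    """Return the dataloader mapping from either access style described in this thread."""
    # Style 1: the object already exposes a `dataloader` attribute.
    loaders = getattr(pert_data, "dataloader", None)
    if loaders is not None:
        return loaders
    # Style 2: get_dataloader(...) returns the mapping, as reported for 0.0.1.
    returned = pert_data.get_dataloader(**loader_kwargs)
    if returned is not None:
        return returned
    # Assumption: a version that returns nothing stores the mapping on the object.
    return pert_data.dataloader


class _OldStylePertData:
    """Stand-in for illustration only; not the real PertData class."""

    def get_dataloader(self, **kwargs):
        # The real method builds loaders; strings keep this sketch self-contained.
        return {"train_loader": "train", "val_loader": "val"}


# `batch_size` is a hypothetical keyword here; pass whatever the installed version requires.
loaders = resolve_dataloaders(_OldStylePertData(), batch_size=64)
train_loader = loaders["train_loader"]
print(train_loader)
```

Pinning one package version is the simpler fix; a shim like this only helps when the same notebook has to run in several environments.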
