
Small bug in classification_inference for csv_data_configuration #224

Open
HFooladi opened this issue Jan 11, 2024 · 1 comment

@HFooladi

There is a small bug in examples/property_prediction/csv_data_configuration/classification_inference.py.

On line 37, the output of the predict function is a logit, so in theory it can take any value from -inf to inf.

batch_pred = predict(args, model, bg)  # raw logits, unbounded
if not args['soft_classification']:
    # Bug: the 0.5 threshold is compared against a logit, not a probability
    batch_pred = (batch_pred >= 0.5).float()
predictions.append(batch_pred.detach().cpu())

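So the logits should first be converted to probabilities in [0, 1] with the sigmoid function, and only then be returned as soft predictions or thresholded into hard labels. As written, soft classification returns raw logits instead of probabilities, and hard classification effectively uses a probability threshold of sigmoid(0.5), about 0.62, instead of 0.5.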
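To illustrate why the order matters, here is a minimal plain-Python sketch (the logit values are hypothetical, chosen only for illustration) that compares the original rule, thresholding the raw logit at 0.5, with the proposed rule, thresholding sigmoid(logit) at 0.5. Since sigmoid(0) = 0.5, the correct hard-label boundary is a logit of 0, so any logit in [0, 0.5) is labelled 0 by the original code but 1 by the fix.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic sigmoid: maps a logit in (-inf, inf) to (0, 1)."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    # For negative x, use exp(x) to avoid overflow in exp(-x)
    e = math.exp(x)
    return e / (1.0 + e)


# Hypothetical logits for illustration only
logits = [-2.0, 0.0, 0.3, 0.5, 2.0]

for z in logits:
    buggy = z >= 0.5           # original code: threshold applied to the raw logit
    fixed = sigmoid(z) >= 0.5  # proposed fix: threshold applied to the probability
    print(f"logit={z:+.1f}  prob={sigmoid(z):.3f}  buggy={int(buggy)}  fixed={int(fixed)}")
```

The two rules disagree for the logits 0.0 and 0.3, which sit between the correct boundary (0) and the one the original code uses (0.5).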

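# Assumes torch is imported at the top of the script
batch_logit = predict(args, model, bg)
# Map logits to probabilities in [0, 1] before any thresholding
batch_pred = torch.sigmoid(batch_logit)
if not args['soft_classification']:
    # Hard labels: probability >= 0.5 (equivalently, logit >= 0)
    batch_pred = (batch_pred >= 0.5).float()
predictions.append(batch_pred.detach().cpu())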
@mufeili
Contributor

mufeili commented Jan 11, 2024

Nice catch! Thank you for the report. Unfortunately, I've left AWS and cannot update the codebase or approve PRs from others. You may modify your own fork if you need to use this functionality.
