Convention for tensor device #31

Open
28 tasks
RemyLau opened this issue Oct 7, 2022 · 0 comments

RemyLau commented Oct 7, 2022

Problem

Currently, there is no convention for which device data should be stored on when a function such as fit() is called. For example, even when the computation device is cuda, the input graph may not be on the GPU and then has to be transferred later, after the neighborhoods are subsampled for mini-batch training. This inconsistency causes many issues during development.

One naive solution is to call tensor.to(device) with the correct device every time a computation is performed. This is unsatisfactory because it scatters device transfers throughout the code base and makes it harder to keep clean.

Solution

  • By default, all data should sit on the CPU.
  • When a compute function such as fit() is called, it performs any necessary device conversion internally.
  • The only exceptions are predict() and score(), which can be more flexible because they are used in many places with different configurations, e.g., training on the GPU with mini-batches and evaluating on the CPU with full-batch.
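A minimal sketch of this convention, assuming a PyTorch model. `ToyModel`, its toy linear-regression task, and the optional `device` argument of `predict()`/`score()` are hypothetical illustrations, not DANCE's actual API. Callers keep tensors on the CPU, `fit()` moves them to the compute device once, and `predict()`/`score()` accept an optional device override and return results on the CPU.

```python
import copy

import torch
from torch import nn


class ToyModel:
    """Hypothetical model following the proposed device convention.

    Callers pass CPU tensors; fit() moves them to self.device internally,
    and predict()/score() may run on a different device if requested.
    """

    def __init__(self, in_dim: int, out_dim: int, device: str = "cpu"):
        self.device = torch.device(device)
        # Only the parameters live on the compute device.
        self.net = nn.Linear(in_dim, out_dim).to(self.device)

    def fit(self, x: torch.Tensor, y: torch.Tensor, epochs: int = 200, lr: float = 0.1):
        # The single device transfer for training happens here, not at call sites.
        x, y = x.to(self.device), y.to(self.device)
        opt = torch.optim.SGD(self.net.parameters(), lr=lr)
        self.net.train()
        for _ in range(epochs):
            opt.zero_grad()
            loss = nn.functional.mse_loss(self.net(x), y)
            loss.backward()
            opt.step()
        return self

    @torch.no_grad()
    def predict(self, x: torch.Tensor, device=None) -> torch.Tensor:
        # Exception to the convention: evaluation may use another device,
        # e.g., train on GPU with mini-batches, evaluate full-batch on CPU.
        device = self.device if device is None else torch.device(device)
        # Copy the network instead of moving it (nn.Module.to is in-place),
        # so the training device stays unchanged.
        net = self.net if device == self.device else copy.deepcopy(self.net).to(device)
        net.eval()
        # Results are returned on the CPU, the default location for data.
        return net(x.to(device)).cpu()

    def score(self, x: torch.Tensor, y: torch.Tensor, device=None) -> float:
        # Hypothetical metric: mean squared error (lower is better).
        pred = self.predict(x, device=device)
        return nn.functional.mse_loss(pred, y.cpu()).item()
```

For example, `ToyModel(3, 1, device="cuda").fit(x, y)` would accept CPU tensors `x` and `y` as-is, and `model.score(x, y, device="cpu")` would evaluate on the CPU without moving the trained network off the GPU.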

PR #30 is a corresponding example of fixing the issue according to this convention.

Need to check

grep "\.to(" -r dance | awk -F":" '{print $1}' | sort -u
  • dance/datasets/multimodality.py
  • dance/modules/multi_modality/joint_embedding/dcca.py
  • dance/modules/multi_modality/joint_embedding/jae.py
  • dance/modules/multi_modality/joint_embedding/scmogcn.py
  • dance/modules/multi_modality/joint_embedding/scmogcnv2.py
  • dance/modules/multi_modality/joint_embedding/scmvae.py
  • dance/modules/multi_modality/match_modality/scmm.py
  • dance/modules/multi_modality/match_modality/scmogcn.py
  • dance/modules/multi_modality/predict_modality/babel.py
  • dance/modules/multi_modality/predict_modality/scmm.py
  • dance/modules/multi_modality/predict_modality/scmogcn.py
  • dance/modules/single_modality/cell_type_annotation/actinn.py
  • dance/modules/single_modality/cell_type_annotation/scdeepsort.py
  • dance/modules/single_modality/clustering/graphsc.py
  • dance/modules/single_modality/clustering/scdcc.py
  • dance/modules/single_modality/clustering/scdeepcluster.py
  • dance/modules/single_modality/clustering/scdsc.py
  • dance/modules/single_modality/clustering/sctag.py
  • dance/modules/single_modality/imputation/deepimpute.py
  • dance/modules/single_modality/imputation/graphsci.py
  • dance/modules/single_modality/imputation/scgnn.py
  • dance/modules/spatial/cell_type_deconvo/dstg.py
  • dance/modules/spatial/cell_type_deconvo/spatialdecon.py
  • dance/modules/spatial/cell_type_deconvo/spotlight.py
  • dance/modules/spatial/spatial_domain/spagcn.py
  • dance/modules/spatial/spatial_domain/stagate.py
  • dance/transforms/graph_construct.py
  • dance/transforms/preprocess.py
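A rough Python equivalent of the grep pipeline above, e.g., for use in a check script, might look like the following sketch. `files_with_to_calls` is a hypothetical helper, not part of DANCE. Like `grep -r`, it scans every file under the root and matches the literal text `.to(`, so it also flags comments and non-tensor calls that still need a manual look. The ordering may differ slightly from `sort -u`, which is locale-dependent.

```python
import re
from pathlib import Path
from typing import List

# Matches a literal ".to(" as in tensor.to(device) or module.to(device).
TO_CALL = re.compile(r"\.to\(")


def files_with_to_calls(root: str) -> List[str]:
    r"""Return sorted, de-duplicated paths under `root` containing ".to(".

    Roughly mirrors: grep "\.to(" -r root | awk -F":" '{print $1}' | sort -u
    """
    hits = set()
    for path in Path(root).rglob("*"):
        if path.is_file():
            # Ignore undecodable bytes instead of failing on binary files.
            text = path.read_text(encoding="utf-8", errors="ignore")
            if TO_CALL.search(text):
                hits.add(path.as_posix())
    return sorted(hits)
```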