How to use mini-batch with PyGeoTemp #239

Open
gbastiandillo opened this issue Jul 8, 2023 · 0 comments


gbastiandillo commented Jul 8, 2023

Hi @benedekrozemberczki,
First of all, thanks for the great library.
I went through all the closed issues related to mini-batching, and I still have two main questions. I hope you have some time to look at them:

  1. In your experience (besides CPU/GPU performance), do you see any advantages to using mini-batches for spatio-temporal GNNs? In your examples you use "snapshots", which can be understood as frames of the same sequence. Does using batches instead (frames from different sequences, combined with the "diagonal block" trick) give any advantage in accuracy?

  2. Do you have any example of using mini-batches with the "diagonal block" trick in PyTorch Geometric Temporal? I searched online, but there are not many examples of how to apply the trick in PyGeoTemp. I know this is a big request, but I hope you can share an example you already have, or tell me whether the sample below is implemented correctly.
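For context, the "diagonal block" trick that PyTorch Geometric uses for mini-batching can be sketched in a few lines of NumPy: each copy of the graph has its node ids shifted by a per-sample offset, so the batched adjacency matrix is block-diagonal and no messages pass between samples, and a batch vector records which sample every node belongs to. This is a minimal sketch assuming a single static graph with `num_nodes` nodes and a `(2, num_edges)` edge index; `block_diagonal_batch` is a hypothetical helper, not a PyTorch Geometric Temporal API.

```python
import numpy as np


def block_diagonal_batch(edge_index, edge_weight, num_nodes, batch_size):
    """Replicate one static graph `batch_size` times as disjoint components.

    Returns the batched edge index, the repeated edge weights and a batch
    vector mapping every node to the sample it came from.
    """
    # First node id of each copy: 0, N, 2N, ...
    offsets = np.arange(batch_size) * num_nodes
    # Shift each copy's node ids by its offset and concatenate along the
    # edge axis: shape (2, batch_size * num_edges).
    batched_edges = np.concatenate([edge_index + off for off in offsets], axis=1)
    # Every copy keeps the original edge weights.
    batched_weights = np.tile(edge_weight, batch_size)
    # batch[i] is the sample index of node i: [0]*N + [1]*N + ...
    batch = np.repeat(np.arange(batch_size), num_nodes)
    return batched_edges, batched_weights, batch


# Usage: a 3-node path graph 0 -> 1 -> 2, batched twice.
edge_index = np.array([[0, 1], [1, 2]])
edge_weight = np.ones(2)
edges, weights, batch = block_diagonal_batch(
    edge_index, edge_weight, num_nodes=3, batch_size=2
)

# Dense adjacency of the batch: the two copies sit on the diagonal and the
# off-diagonal blocks stay zero, so no messages pass between samples.
adjacency = np.zeros((6, 6))
adjacency[edges[0], edges[1]] = weights
print(edges)  # [[0 1 3 4]
              #  [1 2 4 5]]
print(batch)  # [0 0 0 1 1 1]
```

Node features for the batch would be concatenated the same way, along the node axis, so that row `i` of the batched feature matrix matches `batch[i]`.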

I copied the MTMDatasetLoader example and modified the get_dataset method to use StaticGraphTemporalSignalBatch instead of StaticGraphTemporalSignal:

import json
import urllib.request

import numpy as np
from torch_geometric_temporal.signal import StaticGraphTemporalSignalBatch

class MTMDatasetLoader1:

    def __init__(self):
        self._read_web_data()
        self.batch = 16  # Modification 1

    def _read_web_data(self):
        url = 'https://raw.githubusercontent.com/benedekrozemberczki/pytorch_geometric_temporal/master/dataset/mtm_1.json'
        self._dataset = json.loads(urllib.request.urlopen(url).read())

    def _get_edges(self):
        self._edges = np.array(self._dataset["edges"]).T

    def _get_edge_weights(self):
        self._edge_weights = np.array([1 for d in self._dataset["edges"]]).T

    def _get_features(self):
        dic = self._dataset
        joints = [str(n) for n in range(21)]
        dataset_length = len(dic["0"].values())
        features = np.zeros((dataset_length, 21, 3))

        for j, joint in enumerate(joints):
            for t, xyz in enumerate(dic[joint].values()):
                xyz_tuple = list(map(float, xyz.strip("()").split(",")))
                features[t, j, :] = xyz_tuple

        self.features = [
            features[i : i + self.frames, :].T
            for i in range(len(features) - self.frames)
        ]

    def _get_targets(self):
        # target encoding: {0 : 'Grasp', 1 : 'Move', 2 : 'Negative',
        #                   3 : 'Position', 4 : 'Reach', 5 : 'Release'}
        targets = []
        for _, y in self._dataset["LABEL"].items():
            targets.append(y)

        n_values = np.max(targets) + 1
        targets_ohe = np.eye(n_values)[targets]

        self.targets = [
            targets_ohe[i : i + self.frames, :]
            for i in range(len(targets_ohe) - self.frames)
        ]

    def get_dataset(self, frames: int = 16) -> StaticGraphTemporalSignalBatch:
        """Returning the MTM-1 motion data iterator.

        Args types:
            * **frames** *(int)* - The number of consecutive frames T, default 16.
        Return types:
            * **dataset** *(StaticGraphTemporalSignalBatch)* - The MTM-1 dataset.
        """
        self.frames = frames
        self._get_edges()
        self._get_edge_weights()
        self._get_features()
        self._get_targets()

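        # Modification 2
        # NOTE: StaticGraphTemporalSignalBatch documents its fifth argument as a
        # batch index array (one graph id per node), not a batch size, so passing
        # the integer self.batch does not build 16-sample mini-batches.
        dataset = StaticGraphTemporalSignalBatch(
            self._edges, self._edge_weights, self.features, self.targets, self.batch
        )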
        # End of modification 2

        return dataset
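As far as I can tell from its docstring, `StaticGraphTemporalSignalBatch(edge_index, edge_weight, features, targets, batches)` keeps one static `edge_index` for every snapshot and expects `batches` to be a node-level batch index array, so the block-diagonal graph has to be built before the object is constructed. A minimal sketch of one way to do that, assuming every snapshot groups exactly `batch_size` consecutive windows (an incomplete trailing group is dropped so the static edge index stays valid), that each feature window is node-major with shape `(num_nodes, ...)`, and that targets are graph-level; `group_windows` is a hypothetical helper:

```python
import numpy as np


def group_windows(features, targets, edge_index, edge_weight, num_nodes, batch_size):
    """Group consecutive windows into block-diagonal mini-batch snapshots.

    features: list of node-major arrays, each of shape (num_nodes, ...).
    targets:  list of graph-level label arrays, one per window.
    """
    # Static block-diagonal graph shared by every snapshot.
    offsets = np.arange(batch_size) * num_nodes
    batched_edges = np.concatenate([edge_index + off for off in offsets], axis=1)
    batched_weights = np.tile(edge_weight, batch_size)
    batch = np.repeat(np.arange(batch_size), num_nodes)

    batched_features, batched_targets = [], []
    # Only complete groups are kept; leftover windows are dropped.
    for start in range(0, len(features) - batch_size + 1, batch_size):
        group = slice(start, start + batch_size)
        # Node features are concatenated along the node axis: (B * N, ...).
        batched_features.append(np.concatenate(features[group], axis=0))
        # Graph-level targets are stacked so row k belongs to sample k: (B, ...).
        batched_targets.append(np.stack(targets[group], axis=0))
    return batched_edges, batched_weights, batched_features, batched_targets, batch


# Usage with toy data: 5 windows on a 3-node graph, 2 features per node.
rng = np.random.default_rng(0)
toy_features = [rng.normal(size=(3, 2)) for _ in range(5)]
toy_targets = [np.eye(6)[k % 6] for k in range(5)]
edges, weights, feats, targs, batch = group_windows(
    toy_features, toy_targets, np.array([[0, 1], [1, 2]]), np.ones(2),
    num_nodes=3, batch_size=2,
)
print(len(feats), feats[0].shape, targs[0].shape)  # 2 (6, 2) (2, 6)
```

The returned values could then be passed as `StaticGraphTemporalSignalBatch(edges, weights, feats, targs, batch)`. The MTM windows built in the loader have shape `(3, 21, frames)`, with the joints on axis 1, so they would first need `np.moveaxis(x, 1, 0)` to become node-major. Grouping consecutive windows keeps the sketch short; shuffling the window order before grouping would instead mix windows from different parts of the recording.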

Thanks in advance,
Guille
