Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use multiview cameras during training #13

Open
vyeevani opened this issue Sep 19, 2023 · 2 comments
Open

Use multiview cameras during training #13

vyeevani opened this issue Sep 19, 2023 · 2 comments

Comments

@vyeevani
Copy link

vyeevani commented Sep 19, 2023

It appears that multiview cameras aren't currently used during training. I'm planning a slightly janky workaround: create a new top-level dataset for each camera view. This isn't an ideal implementation, but it requires the fewest code changes.

Just to be really clear:
    bridgedata_raw/
        rss/
            toykitchen2/
                set_table/
                    00/
                        2022-01-01_00-00-00/
                            collection_metadata.json
                            config.json
                            diagnostics.png
                            raw/
                                traj_group0/
                                    traj0/
                                        obs_dict.pkl
                                        policy_out.pkl
                                        agent_data.pkl
                                        images0/
                                            im_0.jpg
                                            im_1.jpg
                                            ...
                                        images1/
                                            im_0.jpg
                                            im_1.jpg
                                            ...
                                    ...
                                ...
                    01/
                    ...

would become:

    bridgedata_raw/
        rss_camera_position_0/
            toykitchen2/
                set_table/
                    00/
                        2022-01-01_00-00-00/
                            collection_metadata.json
                            config.json
                            diagnostics.png
                            raw/
                                traj_group0/
                                    traj0/
                                        obs_dict.pkl
                                        policy_out.pkl
                                        agent_data.pkl
                                        images0/
                                            im_0.jpg
                                            im_1.jpg
                                            ...
                                    ...
                                ...
                    01/
                    ...
        rss_camera_position_1/
            toykitchen2/
                set_table/
                    00/
                        2022-01-01_00-00-00/
                            collection_metadata.json
                            config.json
                            diagnostics.png
                            raw/
                                traj_group0/
                                    traj0/
                                        obs_dict.pkl
                                        policy_out.pkl
                                        agent_data.pkl
                                        images0/
                                            im_0.jpg
                                            im_1.jpg
                                            ...
                                    ...
                                ...
                    01/
                    ...
@vyeevani
Copy link
Author

vyeevani commented Sep 21, 2023

Sample code for doing this:

import os
import shutil
from absl import app, flags
import tqdm
import glob
import multiprocessing

"""
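Splits a multi-camera dataset into one top-level dataset per camera view. The first tree below is the input layout; the second tree is the resulting output layout.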
    bridgedata_raw/
        rss/
            toykitchen2/
                set_table/
                    00/
                        2022-01-01_00-00-00/
                            collection_metadata.json
                            config.json
                            diagnostics.png
                            raw/
                                traj_group0/
                                    traj0/
                                        obs_dict.pkl
                                        policy_out.pkl
                                        agent_data.pkl
                                        images0/
                                            im_0.jpg
                                            im_1.jpg
                                            ...
                                        images1/
                                            im_0.jpg
                                            im_1.jpg
                                            ...
                                    ...
                                ...
                    01/
                    ...

    bridgedata_raw/
        rss_camera_position_0/
            toykitchen2/
                set_table/
                    00/
                        2022-01-01_00-00-00/
                            collection_metadata.json
                            config.json
                            diagnostics.png
                            raw/
                                traj_group0/
                                    traj0/
                                        obs_dict.pkl
                                        policy_out.pkl
                                        agent_data.pkl
                                        images0/
                                            im_0.jpg
                                            im_1.jpg
                                            ...
                                    ...
                                ...
                    01/
                    ...
        rss_camera_position_1/
            toykitchen2/
                set_table/
                    00/
                        2022-01-01_00-00-00/
                            collection_metadata.json
                            config.json
                            diagnostics.png
                            raw/
                                traj_group0/
                                    traj0/
                                        obs_dict.pkl
                                        policy_out.pkl
                                        agent_data.pkl
                                        images0/
                                            im_0.jpg
                                            im_1.jpg
                                            ...
                                    ...
                                ...
                    01/
                    ...
    
"""

FLAGS = flags.FLAGS

# Root of the source tree; main() globs `depth` directory levels below it.
flags.DEFINE_string("input_path", None, "Input path", required=True)
# Root under which the per-camera copies are written.
flags.DEFINE_string("output_path", None, "Output path", required=True)
flags.DEFINE_integer(
    "depth",
    5,
    # Adjacent literals are concatenated verbatim, so the first one must end
    # with a space to keep the rendered help text readable.
    "Number of directories deep to traverse to the dated directory. Looks for "
    "{input_path}/dir_1/dir_2/.../dir_{depth-1}/2022-01-01_00-00-00/...",
)
# Sizes the multiprocessing.Pool created in main(), i.e. worker processes.
flags.DEFINE_integer("num_workers", 8, "Number of worker processes to use")

def _copy_top_level_files(src_dir, dest_dir):
    """Copy the regular files directly inside ``src_dir`` into ``dest_dir``.

    Subdirectories of ``src_dir`` are skipped. ``dest_dir`` is created if it
    does not exist yet.
    """
    os.makedirs(dest_dir, exist_ok=True)
    for name in os.listdir(src_dir):
        src_file = os.path.join(src_dir, name)
        if os.path.isfile(src_file):
            shutil.copy2(src_file, os.path.join(dest_dir, name))


def _camera_dest_dirs(src_date_path, src_image_path, output_path, depth):
    """Compute the destination directories for one ``images<N>`` directory.

    Args:
        src_date_path: Dated source directory (the last of the ``depth``
            directory levels below the input root).
        src_image_path: An ``images<N>`` directory located below
            ``src_date_path``.
        output_path: Root directory of the converted tree.
        depth: Number of trailing components of ``src_date_path`` to
            reproduce below ``output_path``.

    Returns:
        A ``(dest_traj_dir, dest_image_dir)`` tuple of absolute paths. The
        first reproduced component carries a ``_camera_position_<N>`` suffix
        and the image directory is always named ``images0``.

    Raises:
        ValueError: If ``depth`` is smaller than 1.
    """
    if depth < 1:
        raise ValueError(f"depth must be >= 1, got {depth}")

    src_date_path = os.path.abspath(src_date_path)
    src_image_path = os.path.abspath(src_image_path)

    # Text after "images" in the directory name, e.g. "images1" -> "1".
    camera_id = os.path.basename(src_image_path).split("images")[-1]

    # Keep the last `depth` components (dataset dir ... dated dir) and tag the
    # first of them, so every camera view becomes its own top-level dataset.
    date_parts = src_date_path.split(os.sep)[-depth:]
    date_parts[0] = f"{date_parts[0]}_camera_position_{camera_id}"

    # Derive the trajectory path relative to the dated directory instead of
    # slicing by `depth`: the number of levels below the dated directory does
    # not depend on how deep the dated directory itself sits.
    traj_rel = os.path.relpath(os.path.dirname(src_image_path), src_date_path)

    dest_traj_dir = os.path.join(os.path.abspath(output_path), *date_parts, traj_rel)
    return dest_traj_dir, os.path.join(dest_traj_dir, "images0")


def make_multiview(src_date_path):
    """Split one dated directory into a separate dataset per camera view.

    Every directory matching ``raw/traj_group*/traj*/images*`` below
    ``src_date_path`` is copied to::

        {output_path}/{dataset}_camera_position_{N}/.../{date}/raw/
            traj_group*/traj*/images0

    where ``{dataset}`` is the first of the last ``FLAGS.depth`` components of
    ``src_date_path`` and ``{N}`` is the text following ``images`` in the
    source directory name. The regular files that sit next to the image
    directory (the trajectory-level files) are copied alongside every view.

    Reads ``FLAGS.output_path`` and ``FLAGS.depth``.

    Args:
        src_date_path: Path of a dated source directory. Paths containing
            ``"lmdb"`` are skipped.
    """
    if "lmdb" in src_date_path:
        return

    search_path = os.path.join(src_date_path, "raw", "traj_group*", "traj*", "images*")
    for image_path in glob.glob(search_path):
        # The glob can also match plain files whose name starts with "images".
        if not os.path.isdir(image_path):
            continue

        src_image_path = os.path.abspath(image_path)
        dest_traj_path, dest_image_path = _camera_dest_dirs(
            src_date_path, src_image_path, FLAGS.output_path, FLAGS.depth
        )

        # Trajectory-level files are duplicated into each camera's dataset.
        _copy_top_level_files(os.path.dirname(src_image_path), dest_traj_path)
        _copy_top_level_files(src_image_path, dest_image_path)

def main(_):
    """Find every dated directory and convert them with a worker pool.

    Builds a glob pattern with ``FLAGS.depth`` wildcard levels below the
    absolute ``FLAGS.input_path`` and maps ``make_multiview`` over the matches
    using ``FLAGS.num_workers`` pool workers, showing a progress bar.
    """
    wildcard_levels = ["*"] * FLAGS.depth
    input_root = os.path.abspath(FLAGS.input_path)
    src_date_paths = glob.glob(os.path.join(input_root, *wildcard_levels))

    with multiprocessing.Pool(FLAGS.num_workers) as pool:
        results = pool.imap(make_multiview, src_date_paths)
        # Draining the iterator is what advances the progress bar.
        for _unused in tqdm.tqdm(results, total=len(src_date_paths)):
            pass
    
# Script entry point: absl's app.run parses the command-line flags and then
# calls main.
if __name__ == "__main__":
    app.run(main)

@vyeevani
Copy link
Author

I validated that training is working. However, for some odd reason, validation is getting stuck. Not sure why.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant