
Commit 363b50d: Open-source data processing code
1 parent 9924ce0

22 files changed, +1868 -1 lines changed

.gitignore

Lines changed: 167 additions & 0 deletions
@@ -0,0 +1,167 @@
# dataset
/b2d_dynamic_camera
/b2d_fixed_camera
/log

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

README.md

Lines changed: 23 additions & 1 deletion
@@ -4,10 +4,32 @@
</div>


-We are cleaning and organizing the code, and will open source all the training and inference code. Thanks for your patience.
+We are currently cleaning and organizing the code; the data preprocessing part is now publicly available. Thank you for your patience while we prepare the training and inference code.

[Project Page](https://thinklab-sjtu.github.io/DriveMoE/), [Paper](https://arxiv.org/abs/2505.16278)

+
+## Installation
+Before you begin, make sure your CUDA version is greater than 12.1.
+
+Clone this repository to your working directory and run `pip install -e .` to install the dependencies.
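For reference, the clone-and-install step looks roughly as follows; the repository URL is an assumption based on the project page and may differ:

```console
# Assumed repository URL; adjust if the actual remote differs.
git clone https://github.com/Thinklab-SJTU/DriveMoE.git
cd DriveMoE
pip install -e .
```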
+
+Download the PaliGemma weights to your directory:
+```console
+git clone https://huggingface.co/google/paligemma-3b-pt-224
+```
+
+If you want to train DrivePi0 and DriveMoE with this code, or to run open-loop testing with the provided checkpoints, you will need the Bench2Drive dataset and our camera labels. You can download them from https://huggingface.co/datasets/rethinklab/Bench2Drive and https://huggingface.co/rethinklab/DriveMoE.
+
+Set the environment variables `DATA_DIR` (if you download the dataset for training), `CAMERA_LABEL_DIR`, `LOG_DIR`, and `WANDB_ENTITY` by running `source script/set_path.sh`.
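If you prefer a non-interactive setup, the same variables can also be exported by hand. A minimal sketch using the default paths from `script/set_path.sh` (adjust the placeholders to your setup):

```console
# Mirrors the defaults in script/set_path.sh (the script additionally persists
# DATA_DIR, CAMERA_LABEL_DIR and LOG_DIR to ~/.bashrc).
export DATA_DIR="${PWD}/Bench2Drive-Base"
export CAMERA_LABEL_DIR="${PWD}/camera_labels"
export LOG_DIR="${PWD}/log"
export PYTHONPATH="${PWD}"
export WANDB_ENTITY="your-wandb-entity"   # placeholder; if unset, pass wandb=null when running scripts
```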
+
+## Data processing
+Run these two scripts to preprocess the training data:
+```console
+sh script/generate_data.sh && sh script/window.sh
+```
+We provide dataset statistics for normalizing the data during training. You can also run `sh script/get_statistics.sh` to regenerate them.
+

## Citation <a name="citation"></a>

```bibtex
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
current_dir: "YOUR_CURRENT_WORK_DIR" # You need to set path

data:
  val:
    statistics_path: ${current_dir}/statistics.json
    use_fixed_images: False
    work_dir: ${current_dir}
    split: val
    num_of_action_experts: 6
    shuffle_buffer_size: 200000
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
current_dir: "YOUR_CURRENT_WORK_DIR" # You need to set path

data:
  train:
    statistics_path: ${current_dir}/statistics.json
    use_fixed_images: False
    work_dir: ${current_dir}
    split: train
    num_of_action_experts: 13
    shuffle_buffer_size: 200000
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
current_dir: "YOUR_CURRENT_WORK_DIR" # You need to set path

data:
  train:
    statistics_path: ${current_dir}/statistics.json
    use_fixed_images: True
    work_dir: ${current_dir}
    split: None
    num_of_action_experts: 44
    shuffle_buffer_size: 200000
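The `${current_dir}` placeholders in these configs use OmegaConf interpolation syntax (omegaconf and hydra-core are listed as dependencies in pyproject.toml below). As an illustration only, here is a sketch of loading one of the training configs; the file name `config/train_dynamic.yaml` is hypothetical, since the commit view does not show the config file paths:

```python
from omegaconf import OmegaConf

# Hypothetical file name: the commit does not show where these configs live.
cfg = OmegaConf.load("config/train_dynamic.yaml")

# Replace the "YOUR_CURRENT_WORK_DIR" placeholder before resolving interpolations.
cfg.current_dir = "/path/to/your/work_dir"

# ${current_dir} interpolations resolve on access.
print(cfg.data.train.statistics_path)        # /path/to/your/work_dir/statistics.json
print(cfg.data.train.num_of_action_experts)  # e.g. 13 for the second config shown above
```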

pyproject.toml

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
[project]
name = "DriveMoE"
version = "0.1.0"
description = "Open source codes for Drive-pi0 and DriveMoE"
readme = "README.md"
requires-python = "==3.10.*"
classifiers = [
    "Programming Language :: Python :: 3",
]
dependencies = [
    "opencv-python",
    "joblib",
    "bitsandbytes",
    "einops",
    "gsutil>=5.32",
    "hydra-core",
    "imageio",
    "matplotlib",
    "numpy==1.26.4",
    "omegaconf",
    "pillow",
    "pre-commit>=4.0.1",
    "pretty_errors",
    "protobuf==3.20.3",
    "tensorflow==2.15.0",
    "tensorflow_datasets==4.9.2",
    "torch==2.5.0",
    "torchvision==0.20.0",
    "transformers",
    "tqdm",
    "wandb",
]

[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
exclude = []

[tool.ruff]
line-length = 88
target-version = "py310"

[tool.ruff.lint]
select = ["A", "B", "E", "F", "I", "RUF", "W"]
ignore = ["E203", "E501", "B006", "B026", "B905"]

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["E402", "F401", "F403"]

script/generate_data.sh

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
#!/bin/bash

DATASET_PATH="$DATA_DIR"
WORK_DIR="${PWD}"
CAM_ID_PATH="$CAMERA_LABEL_DIR"

python "src/data_processing/generate_data.py" --dataset_path "$DATASET_PATH" --cam_id_path "$CAM_ID_PATH" --work_dir "$WORK_DIR"

script/get_statistics.sh

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
#!/bin/bash

DATA_PATH="${PWD}/b2d_dynamic_camera/train"

python "src/data_processing/get_statistics.py" --data_path "$DATA_PATH"

script/set_path.sh

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
#!/bin/bash

##################### Paths #####################

# Set default paths
DEFAULT_DATA_DIR="${PWD}/Bench2Drive-Base"
DEFAULT_LOG_DIR="${PWD}/log"
DEFAULT_CAMERA_LABEL_DIR="${PWD}/camera_labels"
PYTHONPATH="${PWD}"

# Prompt the user for input, allowing overrides
read -p "Enter the desired Bench2Drive Dataset directory [default: ${DEFAULT_DATA_DIR}], leave empty to use default: " DATA_DIR
DATA_DIR=${DATA_DIR:-$DEFAULT_DATA_DIR} # Use user input or default if input is empty

read -p "Enter the desired camera labels directory [default: ${DEFAULT_CAMERA_LABEL_DIR}], leave empty to use default: " CAMERA_LABEL_DIR
CAMERA_LABEL_DIR=${CAMERA_LABEL_DIR:-$DEFAULT_CAMERA_LABEL_DIR} # Use user input or default if input is empty

read -p "Enter the desired logging directory [default: ${DEFAULT_LOG_DIR}], leave empty to use default: " LOG_DIR
LOG_DIR=${LOG_DIR:-$DEFAULT_LOG_DIR} # Use user input or default if input is empty

# Export to current session
export DATA_DIR="$DATA_DIR"
export LOG_DIR="$LOG_DIR"
export CAMERA_LABEL_DIR="$CAMERA_LABEL_DIR"
export PYTHONPATH="$PYTHONPATH"

# Confirm the paths with the user
echo "Data directory set to: $DATA_DIR"
echo "Camera label directory set to: $CAMERA_LABEL_DIR"
echo "Log directory set to: $LOG_DIR"

# Append environment variables to .bashrc
echo "export DATA_DIR=\"$DATA_DIR\"" >> ~/.bashrc
echo "export CAMERA_LABEL_DIR=\"$CAMERA_LABEL_DIR\"" >> ~/.bashrc
echo "export LOG_DIR=\"$LOG_DIR\"" >> ~/.bashrc

echo "Environment variables DATA_DIR, CAMERA_LABEL_DIR and LOG_DIR added to .bashrc and applied to the current session."

##################### WandB #####################

# Prompt the user for input, allowing overrides
read -p "Enter your WandB entity (username or team name), leave empty to skip: " ENTITY

# Check if ENTITY is not empty
if [ -n "$ENTITY" ]; then
    # If ENTITY is not empty, set the environment variable
    export WANDB_ENTITY="$ENTITY"

    # Confirm the entity with the user
    echo "WandB entity set to: $WANDB_ENTITY"

    # Append environment variable to .bashrc
    echo "export WANDB_ENTITY=\"$ENTITY\"" >> ~/.bashrc

    echo "Environment variable WANDB_ENTITY added to .bashrc and applied to the current session."
else
    # If ENTITY is empty, skip setting the environment variable
    echo "No WandB entity provided. Please set wandb=null when running scripts to disable wandb logging and avoid error."
fi

script/window.sh

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
#!/bin/bash

WORK_DIR="${PWD}"
WINDOW_SIZE=5
HORIZON=10

python "src/data_processing/window.py" --work_dir "$WORK_DIR" --window_size "$WINDOW_SIZE" --horizon "$HORIZON"
