Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RNGDet implementation #11156

Open
wants to merge 58 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 49 commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
845253b
Init commit
gunho1123 Jul 24, 2023
126d5a7
Add cityscale dataloader
gunho1123 Jul 27, 2023
fdbfa6c
Internal change
gunho1123 Jul 27, 2023
5ebc562
Add prepare_dataset
gunho1123 Jul 31, 2023
1a4da90
Add npz_to_tfrecord
gunho1123 Aug 1, 2023
8456a7a
Add create_tf_record
gunho1123 Aug 1, 2023
cd6ec6b
Internal change
gunho1123 Aug 1, 2023
f7e896d
Init modeling
gunho1123 Aug 2, 2023
f2a7148
Modify dataloader
gunho1123 Aug 3, 2023
6fcdef1
Add modeling
gunho1123 Aug 3, 2023
70578c2
Add losses
gunho1123 Aug 5, 2023
85b313c
Internal change
gunho1123 Aug 6, 2023
e424d1a
Add losses
gunho1123 Aug 10, 2023
eb2272b
Fix modeling
gunho1123 Aug 31, 2023
c07f480
Init evaluation
gunho1123 Sep 7, 2023
9c0c2c8
Internal change
gunho1123 Sep 12, 2023
a7e8cc5
Internal chnage
gunho1123 Sep 17, 2023
f17fe82
Init ckpt load debug
gunho1123 Oct 10, 2023
4ddab3c
Update rngdet.py
gunho1123 Oct 12, 2023
494362b
Internal change
gunho1123 Nov 1, 2023
131843e
Transformer model change
gunho1123 Nov 2, 2023
0382b99
ipynb change
gunho1123 Nov 2, 2023
bfdce70
Merge branch 'tensorflow:master' into master
mjyun01 Feb 7, 2024
88f8834
Update rngdet.py
mjyun01 Feb 7, 2024
e68e93d
Update README.md
mjyun01 Feb 7, 2024
800ffaf
Update rngdet.py
mjyun01 Feb 7, 2024
9fd5dbf
Update rngdet_test.py
mjyun01 Feb 7, 2024
537ecac
Update create_cityscale_tf_record.py
mjyun01 Feb 7, 2024
9f41d8a
Update create_label.py
mjyun01 Feb 7, 2024
6dd74eb
Update sampler.py
mjyun01 Feb 7, 2024
50b089a
Delete official/projects/rngdet/dataloaders/preprocess_ops.py
mjyun01 Feb 7, 2024
9daa439
Delete official/projects/rngdet/dataloaders/sampler.py
mjyun01 Feb 7, 2024
96a89f8
Update rngdet_input.py
mjyun01 Feb 7, 2024
08f5b02
Update rngdet_input_test.py
mjyun01 Feb 7, 2024
af5e730
Update do_train.sh
mjyun01 Feb 7, 2024
cc84a04
Update agent.py
mjyun01 Feb 7, 2024
671ab98
Update rngdet.py
mjyun01 Feb 7, 2024
dc28897
Delete official/projects/rngdet/modeling/fpn.py
mjyun01 Feb 7, 2024
6367850
Update rngdet_test.py
mjyun01 Feb 7, 2024
fe509db
Delete official/projects/rngdet/modeling/transformer.py
mjyun01 Feb 7, 2024
4c1e79e
Delete official/projects/rngdet/modeling/transformer_test.py
mjyun01 Feb 7, 2024
7577510
Delete official/projects/rngdet/rngdet.ipynb
mjyun01 Feb 7, 2024
879c927
Update train.py
mjyun01 Feb 7, 2024
b684f6d
Delete official/projects/rngdet/region_0_sat.png
mjyun01 Feb 7, 2024
2322237
Update rngdet.py
mjyun01 Feb 7, 2024
a07dd62
Update rngdet_test.py
mjyun01 Feb 7, 2024
5ed3946
Merge branch 'tensorflow:master' into master
mjyun01 Feb 8, 2024
63f6054
testing for push
mjyun01 Feb 8, 2024
e84f251
PR cleaning
mjyun01 Feb 8, 2024
6193d7b
missing found
mjyun01 Feb 12, 2024
f9da466
missing found
mjyun01 Feb 12, 2024
5d30c00
Merge branch 'tensorflow:master' into master
mjyun01 Feb 12, 2024
5c8090a
test
mjyun01 Feb 12, 2024
cc07f1e
train name
mjyun01 Feb 12, 2024
7e6843e
Merge branch 'master' of https://github.com/mjyun01/models
mjyun01 Feb 12, 2024
2701e00
rngdet
mjyun01 Feb 12, 2024
1e64358
rngdet
mjyun01 Feb 12, 2024
80d557b
rngdet
mjyun01 Feb 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
61 changes: 61 additions & 0 deletions official/projects/rngdet/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@

# Road Network Graph Detection by Transformer

[![RNGDet](https://img.shields.io/badge/RNGDet-arXiv.2202.07824-B3181B?)](https://arxiv.org/abs/2202.07824)
[![RNGDet++](https://img.shields.io/badge/RNGDet++-arXiv.2209.10150-B3181B?)](https://arxiv.org/abs/2209.10150)

## Environment setup
The code can be run on multiple GPUs or TPUs with different distribution
strategies. See the TensorFlow distributed training
[guide](https://www.tensorflow.org/guide/distributed_training) for an overview
of `tf.distribute`.

## Data preparation
To download the dataset and generate labels, try the following command:

```
cd data
./prepare_dataset.bash
```

To generate training samples, try the following command:

```
python create_cityscale_tf_record.py \
--dataroot ./dataset/ \
--roi_size 128 \
--image_size 2048 \
--edge_move_ahead_length 30 \
--num_queries 10 \
--noise 8 \
--max_num_frame 10000 \
--num_shards 32
```
## Training
To edit training options of RNGDet, you can edit following commands in do_train.sh :

```
CUDA_VISIBLE_DEVICES=4 python3 train.py \
--mode=train \
--experiment=rngdet_cityscale \
--model_dir=./CKPT_DIR_NAME \
--config_file=./configs/experiments/cityscale_rngdet_r50_gpu.yaml \
```

To start training, try the following command :
```
sh do_train.sh
```

## Evaluation
To evaluate one image with internal step visualization,

```
python run_test.py -ckpt ./CKPT_DIR_NAME
```

To evaluate all images in the test dataset, and see score(P-P, P-R, R-F) for each images,

```
python run_test_all.py -ckpt ./CKPT_DIR_NAME
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
runtime:
distribution_strategy: 'mirrored'
mixed_precision_dtype: 'float32'
num_gpus: 1
task:
train_data:
dtype: 'float32'
validation_data:
dtype: 'float32'
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'float32'
task:
train_data:
dtype: 'float32'
validation_data:
dtype: 'float32'
227 changes: 227 additions & 0 deletions official/projects/rngdet/configs/rngdet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""DETR configurations."""

import dataclasses
import os
from typing import List, Optional, Union

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.vision.configs import common
from official.vision.configs import decoders
from official.vision.configs import backbones
#from official.projects.rngdet import optimization as optimization_detr


@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
"""Input config for training."""
input_path: str = ''
tfds_name: str = ''
tfds_split: str = 'train'
global_batch_size: int = 0
is_training: bool = False
dtype: str = 'float32'
decoder: common.DataDecoder = dataclasses.field(default_factory=common.DataDecoder)
shuffle_buffer_size: int = 10000
file_type: str = 'tfrecord'
drop_remainder: bool = True


@dataclasses.dataclass
class Losses(hyperparams.Config):
lambda_cls: float = 1.0
lambda_box: float = 5.0
background_cls_weight: float = 0.2

@dataclasses.dataclass
class Rngdet(hyperparams.Config):
"""Rngdet model definations."""
num_queries: int = 10
hidden_size: int = 256
num_classes: int = 2 # 0: vertices, 1: background
num_encoder_layers: int = 6
num_decoder_layers: int = 6
input_size: List[int] = dataclasses.field(default_factory=list)
roi_size: int = 128
backbone: backbones.Backbone = dataclasses.field(default_factory=lambda:backbones.Backbone(
type='resnet', resnet=backbones.ResNet(model_id=50, bn_trainable=False)))
decoder: decoders.Decoder = dataclasses.field(
default_factory=lambda: decoders.Decoder(type='fpn', fpn=decoders.FPN())
)
min_level: int = 2
max_level: int = 5
norm_activation: common.NormActivation = dataclasses.field(default_factory=common.NormActivation)
backbone_endpoint_name: str = '5'


@dataclasses.dataclass
class RngdetTask(cfg.TaskConfig):
model: Rngdet = dataclasses.field(default_factory=Rngdet)
train_data: cfg.DataConfig = dataclasses.field(default_factory=cfg.DataConfig)
validation_data: cfg.DataConfig = dataclasses.field(default_factory=cfg.DataConfig)
losses: Losses = dataclasses.field(default_factory=Losses)
init_checkpoint: Optional[str] = None
init_checkpoint_modules: Union[str, List[str]] = 'all' # all, backbone
per_category_metrics: bool = False


#CITYSCALE_INPUT_PATH_BASE = 'gs://ghpark-tfrecords/cityscale'
CITYSCALE_TRAIN_EXAMPLES = 420140
#CITYSCALE_TRAIN_EXAMPLES = 10140
CITYSCALE_INPUT_PATH_BASE = '/data2/cityscale/tfrecord'
#CITYSCALE_TRAIN_EXAMPLES = 1900
CITYSCALE_VAL_EXAMPLES = 5000

@exp_factory.register_config_factory('rngdet_cityscale')
def rngdet_cityscale() -> cfg.ExperimentConfig:
"""Config to get results that matches the paper."""
train_batch_size = 64
eval_batch_size = 64
steps_per_epoch = CITYSCALE_TRAIN_EXAMPLES // train_batch_size
train_steps = 50 * steps_per_epoch # 50 epochs
config = cfg.ExperimentConfig(
task=RngdetTask(
init_checkpoint='gs://ghpark-imagenet-tfrecord/ckpt/resnet50_imagenet',
init_checkpoint_modules='backbone',
model=Rngdet(
input_size=[128, 128, 3],
roi_size=128,
norm_activation=common.NormActivation()),
losses=Losses(),
train_data=DataConfig(
input_path=os.path.join(CITYSCALE_INPUT_PATH_BASE, 'train-noise*'),
#input_path=os.path.join(CITYSCALE_INPUT_PATH_BASE, 'train-noise-8-00000-of-00032.tfrecord*'),
is_training=True,
global_batch_size=train_batch_size,
shuffle_buffer_size=1000,
),
validation_data=DataConfig(
input_path=os.path.join(CITYSCALE_INPUT_PATH_BASE, 'train-noise*'),
is_training=False,
global_batch_size=eval_batch_size,
drop_remainder=False,
)),
trainer=cfg.TrainerConfig(
train_steps=train_steps,
validation_steps=CITYSCALE_VAL_EXAMPLES // eval_batch_size,
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=1*steps_per_epoch,
validation_interval=1*steps_per_epoch,
max_to_keep=1,
best_checkpoint_export_subdir='best_ckpt',
best_checkpoint_eval_metric='AP',
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw_experimental',
'adamw_experimental': {
'epsilon': 1.0e-08,
'weight_decay': 1.0e-05,
'global_clipnorm': -1.0,
},
},
'learning_rate': {
'type': 'polynomial',
'polynomial': {
'initial_learning_rate': 0.0001,
'end_learning_rate': 0.000001,
'offset': 0,
'power': 1.0,
'decay_steps': 50 * steps_per_epoch,
},
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 2 * steps_per_epoch,
'warmup_learning_rate': 0,
},
},
})),
restrictions=[
'task.train_data.is_training != None',
])
return config



@exp_factory.register_config_factory('rngdet_cityscale_detr')
def rngdet_cityscale() -> cfg.ExperimentConfig:
"""Config to get results that matches the paper."""
train_batch_size = 16
eval_batch_size = 64
steps_per_epoch = CITYSCALE_TRAIN_EXAMPLES // train_batch_size
train_steps = 50 * steps_per_epoch # 50 epochs
config = cfg.ExperimentConfig(
task=RngdetTask(
init_checkpoint='gs://ghpark-imagenet-tfrecord/ckpt/resnet50_imagenet',
init_checkpoint_modules='backbone',
model=Rngdet(
input_size=[128, 128, 3],
roi_size=128,
norm_activation=common.NormActivation()),
losses=Losses(),
train_data=DataConfig(
input_path=os.path.join(CITYSCALE_INPUT_PATH_BASE, 'train-noise*'),
#input_path=os.path.join(CITYSCALE_INPUT_PATH_BASE, 'train-noise-8-00000-of-00032.tfrecord*'),
is_training=True,
global_batch_size=train_batch_size,
shuffle_buffer_size=1000,
),
validation_data=DataConfig(
input_path=os.path.join(CITYSCALE_INPUT_PATH_BASE, 'train_noise*'),
is_training=False,
global_batch_size=eval_batch_size,
drop_remainder=False,
)),
trainer=cfg.TrainerConfig(
train_steps=train_steps,
validation_steps=CITYSCALE_VAL_EXAMPLES // eval_batch_size,
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=1*steps_per_epoch,
validation_interval=1*steps_per_epoch,
max_to_keep=1,
best_checkpoint_export_subdir='best_ckpt',
best_checkpoint_eval_metric='AP',
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate': 1e-5,
'epsilon': 1e-08,
'global_clipnorm': 0.1,
# Avoid AdamW legacy behavior.
'gradient_clip_norm': 0.0
}
},
'learning_rate': {
'type': 'stepwise',
'stepwise': {
'boundaries': [20 * steps_per_epoch,
30 * steps_per_epoch,
40 * steps_per_epoch],
'values': [1.0e-05, 1.0e-05, 1.0e-06, 1.0e-07]
}
},
})),
restrictions=[
'task.train_data.is_training != None',
])
return config
51 changes: 51 additions & 0 deletions official/projects/rngdet/configs/rngdet_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for detr."""

# pylint: disable=unused-import
from absl.testing import parameterized
import tensorflow as tf

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.projects.detr.configs import detr as exp_cfg
from official.projects.detr.dataloaders import coco


class DetrTest(tf.test.TestCase, parameterized.TestCase):

@parameterized.parameters(('detr_coco',))
def test_detr_configs_tfds(self, config_name):
config = exp_factory.get_exp_config(config_name)
self.assertIsInstance(config, cfg.ExperimentConfig)
self.assertIsInstance(config.task, exp_cfg.DetrTask)
self.assertIsInstance(config.task.train_data, coco.COCODataConfig)
config.task.train_data.is_training = None
with self.assertRaises(KeyError):
config.validate()

@parameterized.parameters(('detr_coco_tfrecord'), ('detr_coco_tfds'))
def test_detr_configs(self, config_name):
config = exp_factory.get_exp_config(config_name)
self.assertIsInstance(config, cfg.ExperimentConfig)
self.assertIsInstance(config.task, exp_cfg.DetrTask)
self.assertIsInstance(config.task.train_data, cfg.DataConfig)
config.task.train_data.is_training = None
with self.assertRaises(KeyError):
config.validate()


if __name__ == '__main__':
tf.test.main()