Add OUTPUT_PADDING to ConvTrans2D #889

Open
swfsql opened this issue Nov 15, 2023 · 0 comments · May be fixed by #890

swfsql commented Nov 15, 2023

This issue is a request to add an OUTPUT_PADDING parameter to ConvTrans2D.

It appears that dfdx's ConvTrans2D behaves like TensorFlow's Conv2DTranspose with output_padding=0.
Related code for dfdx and tf.

Code example for tf:

import tensorflow as tf
import numpy as np

x = np.zeros([1, 1, 2, 3], dtype=np.float32)
print(x.shape) # (1, 1, 2, 3)

a = tf.keras.layers.Conv2DTranspose(output_padding=0, filters=1, kernel_size=3, strides=2, padding='same', data_format='channels_first')
b = tf.keras.layers.Conv2DTranspose(output_padding=1, filters=1, kernel_size=3, strides=2, padding='same', data_format='channels_first')

ya = a(x).numpy().shape
yb = b(x).numpy().shape
print(ya) # (1, 1, 3, 5)
print(yb) # (1, 1, 4, 6)

Code example for dfdx:

use dfdx::prelude::*;

const IN_CHAN: usize = 1;
const OUT_CHAN: usize = 1;
const KERNEL_SIZE: usize = 3;
const STRIDE: usize = 2;
// for padding='same'
const PADDING: usize = ((KERNEL_SIZE - 1) * DILATION + 1) / 2; // = 1
const DILATION: usize = 1;
const GROUPS: usize = 1;
type Model = ConvTrans2DConstConfig<
    IN_CHAN,
    OUT_CHAN,
    KERNEL_SIZE,
    STRIDE,
    PADDING,
    DILATION,
    GROUPS,
>;

fn example() {
    let dev = Cpu::default();
    let model = dev.build_module::<f32>(Model::default());
    let x: Tensor<Rank4<1, 1, 2, 3>, _, _> = dev.zeros();
    let _prediction: Tensor<Rank4<1, 1, 3, 5>, _, _> = model.forward(x);
    // note: this is the same shape as `ya` in the tf example.
}
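
To make the shape difference concrete, here is a minimal sketch of the per-axis output-size arithmetic, assuming the formula documented for PyTorch's ConvTranspose2d. The function name `conv_trans_out_dim` and its `output_padding` argument are hypothetical and are not part of dfdx's API. With this issue's settings it reproduces both tf shapes, but it is not claimed to match Keras for every padding mode.

```rust
/// Output length of a transposed convolution along one spatial axis.
/// Assumption: this follows the formula documented for PyTorch's `ConvTranspose2d`;
/// it is an illustration, not dfdx's actual implementation.
/// `output_padding` stands in for the parameter this issue requests.
/// Requires `input >= 1` and `kernel >= 1`.
pub const fn conv_trans_out_dim(
    input: usize,
    kernel: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    output_padding: usize,
) -> usize {
    // Positive terms are summed before subtracting the padding, so the
    // unsigned subtraction does not underflow for typical configurations.
    (input - 1) * stride + dilation * (kernel - 1) + output_padding + 1 - 2 * padding
}
```

With KERNEL_SIZE = 3, STRIDE = 2, PADDING = 1 and DILATION = 1, an input of height 2 and width 3 gives 3 and 5 when `output_padding` is 0, and 4 and 6 when it is 1.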

swfsql linked a pull request Nov 16, 2023 that will close this issue