theano.tensor.nnet.bn.batch_normalization verify_grad fail #6803

Open
cheyennee opened this issue Dec 4, 2023 · 0 comments

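problem:
`theano.gradient.verify_grad` fails for `theano.tensor.nnet.bn.batch_normalization` with `mode="low_mem"`: the numeric and analytic gradients for argument 4 (`std`) differ by more than the tolerance.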

repro code:

```python
import numpy as np
import theano
import theano.tensor as T


def custom_activation(inputs, gamma, beta, mean, std):
    return theano.tensor.nnet.bn.batch_normalization(inputs, gamma, beta, mean, std,
                                                     mode="low_mem")


# This symbolic graph (output, loss, gradients) is not used by verify_grad,
# which builds its own inputs and cost from custom_activation.
inputs = T.tensor3('inputs')
gamma = T.tensor3('gamma')
beta = T.tensor3('beta')
mean = T.tensor3('mean')
std = T.tensor3('std')
output = custom_activation(inputs, gamma, beta, mean, std)
loss = T.sum(output ** 2)
grad_input, grad_gamma, grad_beta, grad_mean, grad_std = T.grad(loss, [inputs, gamma, beta, mean, std])

# Test point: float32 arrays drawn uniformly from [0, 1), so std can be close to zero
input_data = np.random.random((5, 5, 5)).astype('float32')
gamma_data = np.random.random((5, 5, 5)).astype('float32')
beta_data = np.random.random((5, 5, 5)).astype('float32')
mean_data = np.random.random((5, 5, 5)).astype('float32')
std_data = np.random.random((5, 5, 5)).astype('float32')
rng = np.random.RandomState(123)
# verify_grad raises GradientError when the check fails (see the output below)
print(theano.gradient.verify_grad(custom_activation, pt=[input_data, gamma_data, beta_data, mean_data, std_data], rng=rng))
```
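As a point of comparison, here is a small NumPy sketch (not Theano's implementation) of the elementwise formula that `batch_normalization` is assumed to compute, `(inputs - mean) * gamma / std + beta`, together with its analytic gradient with respect to `std`. The helper names `bn_reference`, `grad_wrt_std` and `central_diff_wrt_std` are hypothetical. The output is linear in `inputs`, `gamma`, `beta` and `mean`, so a finite difference in those arguments has no truncation error, only rounding error; in `std` it behaves like `1/std`, so the estimate there depends on the step size. The check uses float64, a central difference, and `std` bounded away from zero.

```python
import numpy as np


def bn_reference(x, gamma, beta, mean, std):
    # Assumed elementwise batch normalization: (x - mean) * gamma / std + beta
    return (x - mean) * gamma / std + beta


def grad_wrt_std(x, gamma, mean, std, weights):
    # d/d(std) of sum(weights * bn_reference(...)) = -weights * gamma * (x - mean) / std**2
    return -weights * gamma * (x - mean) / std ** 2


def central_diff_wrt_std(x, gamma, beta, mean, std, weights, h=1e-6):
    # Central-difference estimate of the same gradient, one element at a time
    grad = np.zeros_like(std)
    for idx in np.ndindex(*std.shape):
        plus, minus = std.copy(), std.copy()
        plus[idx] += h
        minus[idx] -= h
        f_plus = np.sum(weights * bn_reference(x, gamma, beta, mean, plus))
        f_minus = np.sum(weights * bn_reference(x, gamma, beta, mean, minus))
        grad[idx] = (f_plus - f_minus) / (2 * h)
    return grad


rng = np.random.RandomState(0)
shape = (5, 5, 5)
x, gamma, beta, mean = [rng.random_sample(shape) for _ in range(4)]
std = rng.uniform(0.5, 1.5, size=shape)   # bounded away from zero on purpose
weights = rng.random_sample(shape) + 0.5  # arbitrary positive output weights

analytic = grad_wrt_std(x, gamma, mean, std, weights)
numeric = central_diff_wrt_std(x, gamma, beta, mean, std, weights)
max_abs_diff = np.max(np.abs(analytic - numeric))
print(max_abs_diff)
```

If the formula or its gradient were wrong, `max_abs_diff` would be on the order of the gradient itself instead of close to rounding error.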

output:

theano.gradient.GradientError: GradientError: numeric gradient and analytic gradient exceed tolerance:
        At position 44 of argument 4 with shape (5, 5, 5),
            val1 = 337197.187500      ,  val2 = 257521.421875
            abs. error = 79675.765625,  abs. tolerance = 0.010000
            rel. error = 0.133972,  rel. tolerance = 0.010000
Exception args: 
The error happened with the following inputs:, [array([[[9.7939932e-01, 6.0512030e-01, 5.9105251e-02, 9.6388167e-01,
         4.4261900e-01],
        [7.7799267e-01, 7.0422657e-02, 1.0935184e-01, 3.5948506e-01,
         6.3008845e-02],
        [9.0304464e-01, 7.3716635e-01, 3.5445350e-01, 1.1566298e-01,
         7.6367569e-01],
        [7.1888345e-01, 2.1192497e-01, 7.9529452e-01, 6.6473253e-02,
         6.4979762e-01],
        [3.4635994e-01, 2.4949747e-01, 2.8667966e-01, 2.6805249e-01,
         7.3419249e-01]],

       [[6.8733525e-01, 3.6414587e-01, 9.0999073e-01, 4.5993187e-02,
         1.5592946e-01],
        [8.2265323e-01, 5.4498470e-01, 1.5491189e-02, 6.5347099e-01,
         7.7562116e-02],
        [9.8731863e-01, 4.0907870e-04, 7.2883135e-01, 5.0818402e-01,
         5.4343271e-01],
        [7.3338479e-01, 9.7909147e-01, 9.7011948e-01, 9.2433876e-01,
         1.2736036e-01],
        [9.5852578e-01, 8.6468303e-01, 6.7888975e-01, 2.8732948e-02,
         6.3961381e-01]],

       [[6.6653842e-01, 5.8464664e-01, 8.2846749e-01, 8.5428429e-01,
         3.5124800e-01],
        [7.2151363e-01, 4.2462060e-03, 8.7091082e-01, 3.1344458e-01,
         3.1573704e-01],
        [2.3512026e-02, 9.1719097e-01, 2.2288257e-01, 1.5423688e-01,
         3.7209776e-01],
        [2.0503227e-01, 3.0329087e-01, 6.9673574e-01, 8.5965699e-01,
         9.0892595e-01],
        [7.3415957e-02, 3.9584714e-01, 8.3952808e-01, 1.6504223e-02,
         2.7753446e-01]],

       [[9.8763949e-01, 9.2118192e-01, 2.3881239e-01, 7.7424163e-01,
         4.2708892e-01],
        [2.6487780e-01, 5.9146440e-01, 5.6427348e-01, 1.0342072e-01,
         9.0923107e-01],
        [8.0888015e-01, 9.0396911e-01, 9.7936302e-02, 1.5835674e-01,
         5.5752736e-01],
        [8.8180834e-01, 4.0089506e-01, 3.4582397e-01, 3.7906596e-01,
         5.1031464e-01],
        [2.9850990e-01, 1.2627229e-02, 1.0453793e-01, 3.0912450e-01,
         2.1786679e-01]],

       [[5.7501400e-01, 1.6689327e-01, 8.8347268e-01, 9.2355028e-02,
         3.7855306e-01],
        [5.1518202e-01, 7.0542997e-01, 6.8055171e-01, 2.1280152e-01,
         3.8993579e-01],
        [7.7901608e-01, 9.9249631e-01, 6.8486166e-01, 2.9545444e-01,
         3.9950174e-01],
        [3.0383030e-01, 7.2793978e-01, 8.1536311e-01, 2.3393980e-01,
         2.7590412e-01],
        [3.7806883e-01, 3.3251804e-01, 1.2489691e-01, 3.3686895e-02,
         6.7742860e-01]]], dtype=float32), array([[[0.7150404 , 0.40841398, 0.31652665, 0.620758  , 0.694597  ],
        [0.45993772, 0.81753963, 0.79170954, 0.32062855, 0.8606753 ],
        [0.5857348 , 0.93213105, 0.23503849, 0.5581473 , 0.46650195],
        [0.9251172 , 0.7220746 , 0.48584402, 0.75921583, 0.3891684 ],
        [0.9828218 , 0.42923975, 0.8412026 , 0.5877801 , 0.9493382 ]],

       [[0.85240483, 0.26658458, 0.7024998 , 0.09510583, 0.67056227],
        [0.03600113, 0.6877504 , 0.72064924, 0.6854849 , 0.43024233],
        [0.76781225, 0.20391126, 0.90086097, 0.6127833 , 0.09259505],
        [0.5076855 , 0.36230293, 0.24938737, 0.4928362 , 0.5221543 ],
        [0.8250115 , 0.4711054 , 0.9584736 , 0.01985063, 0.1766647 ]],

       [[0.97382677, 0.46135062, 0.5279493 , 0.22401077, 0.5649083 ],
        [0.9251377 , 0.05980845, 0.19294846, 0.95128477, 0.12647961],
        [0.36235094, 0.06680529, 0.95786756, 0.28337717, 0.46313292],
        [0.10662496, 0.7768377 , 0.6297399 , 0.87481856, 0.7005593 ],
        [0.735211  , 0.42050675, 0.13973123, 0.6589808 , 0.22958525]],

       [[0.927033  , 0.5633377 , 0.71076614, 0.24838483, 0.97341037],
        [0.7256886 , 0.55211055, 0.8026648 , 0.72535825, 0.00234783],
        [0.25180233, 0.8195512 , 0.46333632, 0.6948011 , 0.24597001],
        [0.6305746 , 0.70882875, 0.936255  , 0.29074803, 0.19137461],
        [0.9830756 , 0.62294674, 0.6245399 , 0.5074565 , 0.50367147]],

       [[0.92177594, 0.13298833, 0.65016294, 0.9551079 , 0.6344468 ],
        [0.11023331, 0.61801314, 0.36919546, 0.8083733 , 0.35548696],
        [0.16382198, 0.07467195, 0.9310846 , 0.61180055, 0.665029  ],
        [0.07865816, 0.07841737, 0.98552144, 0.7862453 , 0.30993837],
        [0.34354326, 0.5491426 , 0.5765017 , 0.36303532, 0.7213097 ]]],
      dtype=float32), array([[[9.9878967e-01, 3.1841344e-01, 4.2503473e-01, 4.9110547e-01,
         8.2256037e-01],
        [4.6110809e-01, 6.8224736e-02, 5.0072283e-01, 1.7057499e-04,
         9.3317574e-01],
        [1.7212527e-01, 7.4707109e-01, 5.2101904e-01, 6.7507666e-01,
         1.8209482e-02],
        [5.9743953e-01, 9.0363592e-01, 1.2331127e-02, 1.0966164e-01,
         2.4822460e-01],
        [4.3086007e-01, 6.7671937e-01, 4.0093207e-01, 8.8474207e-02,
         3.5732627e-01]],

       [[5.7635993e-01, 8.7028688e-01, 4.3442804e-01, 4.7951195e-01,
         5.2287483e-01],
        [8.9554888e-01, 2.2197421e-01, 6.8254298e-01, 9.7774011e-01,
         1.9796742e-01],
        [7.3891991e-01, 8.8547754e-01, 5.2624130e-01, 2.3694012e-01,
         9.6865463e-01],
        [4.5104259e-01, 3.5033378e-01, 5.4974027e-02, 7.8019649e-01,
         9.9278307e-01],
        [8.4709150e-01, 9.9882737e-02, 5.0564313e-01, 1.3310011e-01,
         4.2155501e-01]],

       [[9.2953485e-01, 6.5445745e-01, 4.4966930e-01, 9.2948921e-02,
         1.9986467e-01],
        [9.7979516e-02, 7.8997719e-01, 2.9169336e-01, 5.0152874e-01,
         5.6729871e-01],
        [1.5857354e-01, 1.6574441e-01, 6.4803249e-01, 2.2072688e-01,
         1.8944979e-01],
        [7.2125363e-01, 6.2873036e-01, 2.8237939e-01, 5.5678743e-01,
         3.0471690e-02],
        [7.4451041e-01, 3.2813966e-01, 8.6858146e-02, 3.4012830e-01,
         5.1577009e-02]],

       [[4.5679083e-01, 2.2118850e-01, 6.7617589e-01, 3.4349966e-01,
         5.7558262e-01],
        [9.8718780e-01, 2.7051097e-01, 9.0775716e-01, 3.6578760e-01,
         9.1892135e-01],
        [7.1870023e-01, 9.8905450e-01, 4.7226384e-01, 1.6929044e-01,
         3.5370851e-01],
        [4.6522003e-01, 1.8272930e-01, 4.2220739e-01, 7.7767181e-01,
         5.0676256e-01],
        [2.7133554e-01, 6.1185062e-02, 8.9932466e-01, 7.6932591e-01,
         8.8670677e-01]],

       [[5.4032338e-01, 5.9924703e-02, 5.3200132e-01, 9.4982529e-01,
         3.7733430e-01],
        [1.2338323e-01, 4.6847749e-01, 9.7955000e-01, 6.3000694e-02,
         8.6266190e-01],
        [7.0759630e-01, 7.0414644e-01, 2.7299815e-01, 2.7607548e-01,
         5.9052742e-01],
        [7.6836914e-01, 7.3099899e-01, 3.5827240e-01, 6.4539665e-01,
         4.4737288e-01],
        [4.2478621e-01, 5.9447640e-01, 7.3484895e-03, 8.2456648e-02,
         3.9723247e-02]]], dtype=float32), array([[[0.91511077, 0.20812346, 0.24355769, 0.8048399 , 0.24112616],
        [0.4516349 , 0.36915535, 0.11045914, 0.12894228, 0.62363625],
        [0.81521237, 0.6250042 , 0.37365165, 0.74424773, 0.64257514],
        [0.62137175, 0.6198865 , 0.7555557 , 0.10046609, 0.15871146],
        [0.71965295, 0.01198723, 0.5422037 , 0.35671443, 0.98589116]],

       [[0.28980145, 0.81915635, 0.27434763, 0.4565161 , 0.7602793 ],
        [0.70379645, 0.01969267, 0.43138003, 0.8930782 , 0.40145037],
        [0.39907166, 0.87520516, 0.22992425, 0.00271561, 0.16970922],
        [0.40641797, 0.265631  , 0.9606829 , 0.5818419 , 0.57173854],
        [0.3830372 , 0.12285488, 0.15290801, 0.00749229, 0.69158065]],

       [[0.6749767 , 0.36833876, 0.49652702, 0.8544421 , 0.93666184],
        [0.82460165, 0.85327756, 0.36576688, 0.52830255, 0.8034938 ],
        [0.17840292, 0.5484962 , 0.23854317, 0.32817677, 0.03598966],
        [0.09942828, 0.4750383 , 0.77684134, 0.7089314 , 0.6673485 ],
        [0.72133726, 0.38298118, 0.73538774, 0.37344438, 0.0964972 ]],

       [[0.41027513, 0.8572303 , 0.03217429, 0.5555896 , 0.59615463],
        [0.19077149, 0.03644516, 0.35249287, 0.97083557, 0.5467115 ],
        [0.37559906, 0.18422042, 0.4018665 , 0.58514935, 0.81574273],
        [0.64561516, 0.6438398 , 0.4509427 , 0.31985915, 0.4540122 ],
        [0.06431327, 0.37160736, 0.68581146, 0.31691486, 0.8506746 ]],

       [[0.5526377 , 0.2611484 , 0.64599603, 0.8364884 , 0.9726285 ],
        [0.6962134 , 0.20442809, 0.3124854 , 0.00858229, 0.65034956],
        [0.7159022 , 0.6891795 , 0.6671626 , 0.93706894, 0.31457588],
        [0.47984928, 0.85692036, 0.5573992 , 0.8056776 , 0.7622931 ],
        [0.32013267, 0.2771406 , 0.7721897 , 0.21772784, 0.8941111 ]]],
      dtype=float32), array([[[4.69611466e-01, 2.50249326e-01, 8.52982700e-01, 7.61178136e-01,
         9.11713183e-01],
        [2.47502655e-01, 1.47886768e-01, 9.62902665e-01, 8.71025443e-01,
         4.59680080e-01],
        [2.75679290e-01, 5.88561356e-01, 9.85663652e-01, 1.54290691e-01,
         6.76708341e-01],
        [2.59546191e-01, 1.68499857e-01, 4.03166324e-01, 3.55708450e-01,
         8.91494974e-02],
        [4.83964652e-01, 6.17294133e-01, 5.25294960e-01, 5.12484871e-02,
         6.48911297e-01]],

       [[6.21343374e-01, 8.09643924e-01, 2.34679505e-01, 4.33339089e-01,
         8.66907179e-01],
        [5.87542415e-01, 1.23282537e-01, 5.52375615e-01, 1.40067428e-01,
         7.62945354e-01],
        [5.76609850e-01, 6.14476979e-01, 5.43591321e-01, 8.76403689e-01,
         2.20964290e-02],
        [8.83246481e-01, 7.34322190e-01, 6.39633179e-01, 8.90865549e-02,
         9.69635032e-04],
        [4.20043379e-01, 5.34268916e-01, 8.27656269e-01, 9.93158594e-02,
         6.25271678e-01]],

       [[3.30382019e-01, 4.11888808e-01, 6.25758410e-01, 8.98985937e-02,
         7.25209296e-01],
        [6.58063829e-01, 3.65502328e-01, 5.63403130e-01, 8.41521442e-01,
         2.55100459e-01],
        [3.19193870e-01, 3.14122401e-02, 5.40116608e-01, 2.04636857e-01,
         1.61556825e-01],
        [1.88845038e-01, 9.56100523e-01, 3.08048457e-01, 7.40940690e-01,
         1.66164130e-01],
        [1.16847813e-01, 6.80754542e-01, 3.66289318e-01, 2.76375204e-01,
         2.70700037e-01]],

       [[1.72759309e-01, 3.91623467e-01, 2.78773665e-01, 3.99597228e-01,
         2.22708285e-01],
        [9.71614301e-01, 2.07318157e-01, 6.05836809e-01, 3.00344467e-01,
         3.05748552e-01],
        [9.70926583e-01, 8.51833403e-01, 5.76314628e-01, 5.05498052e-01,
         8.85773301e-01],
        [7.42506623e-01, 8.02911103e-01, 8.24976981e-01, 9.01450396e-01,
         4.63226527e-01],
        [8.31520557e-02, 9.53320563e-01, 6.56577229e-01, 3.32093567e-01,
         8.20797384e-01]],

       [[5.94910145e-01, 6.47334158e-01, 3.20279121e-01, 5.95508516e-01,
         2.53212869e-01],
        [3.63010466e-01, 8.13195288e-01, 4.05604511e-01, 3.35060060e-01,
         1.68905750e-01],
        [1.16619848e-01, 7.88851976e-01, 3.65422994e-01, 8.35235059e-01,
         6.18344188e-01],
        [7.58756876e-01, 1.29215822e-01, 7.01661766e-01, 9.95708704e-02,
         2.75885642e-01],
        [8.49941552e-01, 2.16790542e-01, 1.15644135e-01, 8.35533857e-01,
         3.67812991e-01]]], dtype=float32)], 
The value of eps is:, None, 
The out_type is:, None
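The reported numbers can be traced back to a single element. Position 44 of argument 4 in a (5, 5, 5) array is index (1, 3, 4) of `std`, and the dumped `std` array holds 9.69635032e-04 there, the smallest value in that tensor. Because the output behaves like `1/std`, a one-sided finite difference with step `eps` gives `std / (std + eps)` times the true derivative. Assuming `verify_grad` uses a forward difference and a default step of `eps = 3e-4` for float32 (the report shows `eps` is `None`, i.e. the default), that factor is about 0.764, which matches `val2 / val1`. The sketch below redoes this arithmetic with values copied from the report; the step size, the forward difference, and the [0.5, 1.5) range of the random output weights are assumptions about `verify_grad` internals, not facts stated in the report.

```python
import numpy as np

# Numbers copied from the GradientError message above
shape = (5, 5, 5)
flat_position = 44          # "At position 44 of argument 4"
analytic = 337197.187500    # val1
numeric = 257521.421875     # val2

# Argument 4 (0-based) of [inputs, gamma, beta, mean, std] is std
idx = tuple(int(i) for i in np.unravel_index(flat_position, shape))

# Values at that index, copied from the dumped arrays
x_val = 1.2736036e-01       # input_data[1, 3, 4]
gamma_val = 0.5221543       # gamma_data[1, 3, 4]
mean_val = 0.57173854       # mean_data[1, 3, 4]
std_val = 9.69635032e-04    # std_data[1, 3, 4]

# Assumption: forward difference with a default float32 step of 3e-4.
# For f(s) = c / s: (f(s + eps) - f(s)) / eps = -c / (s * (s + eps)),
# while f'(s) = -c / s**2, so numeric / analytic = s / (s + eps).
eps = 3e-4
predicted_ratio = std_val / (std_val + eps)
observed_ratio = numeric / analytic

# Relative error as printed in the report: |a - b| / (|a| + |b|)
rel_error = abs(analytic - numeric) / (abs(analytic) + abs(numeric))

# Derivative of this output element w.r.t. std; the remaining factor is the
# random output weight (assumed to be drawn from [0.5, 1.5))
d_out_d_std = -gamma_val * (x_val - mean_val) / std_val ** 2
implied_weight = analytic / d_out_d_std

print(idx, predicted_ratio, observed_ratio, rel_error, implied_weight)
```

If this reading is right, the mismatch comes from the finite-difference estimate at a near-zero `std` rather than from the analytic gradient of `batch_normalization`, and drawing `std_data` from a range bounded away from zero (for example `np.random.uniform(0.5, 1.5, (5, 5, 5)).astype('float32')`) should let `verify_grad` pass. This has not been verified against Theano 1.0.4.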

version:
python: 3.6
theano: 1.0.4
