Skip to content

Commit

Permalink
Add some details on how the permutation is done behind the scenes for…
Browse files Browse the repository at this point in the history
… per-channel quantization
  • Loading branch information
OscarSavolainenDR committed Feb 21, 2024
1 parent ab3c5c1 commit 4837b17
Showing 1 changed file with 154 additions and 27 deletions.
181 changes: 154 additions & 27 deletions notebooks/02_quant_activation_overview.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -664,20 +664,9 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"out1_train = quant_hard_tanh(inp1)\n",
"quant_hard_tanh.eval()\n",
Expand All @@ -694,7 +683,7 @@
},
{
"cell_type": "code",
"execution_count": 161,
"execution_count": 32,
"metadata": {},
"outputs": [
{
Expand All @@ -703,7 +692,7 @@
"tensor(2.9998, grad_fn=<MulBackward0>)"
]
},
"execution_count": 161,
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -722,29 +711,136 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see that the per-tensor scale parameter has calibrated itself to provide a full quantization range of 3, matching that of the most extreme channel.\n",
"We can see that the per-tensor scale parameter has calibrated itself to provide a full quantization range of 3, matching that of the most extreme channel. \n",
"\n",
"Next, we initialise a new `QuantRelU` instance, but this time we specify that we desire per-channel quantization i.e. `scaling_per_output_channel=True`. To accomplish this, we also need to give it some extra information on the dimensions of the inputted tensor, so that it knows which dimensions to interpret as the output channels. This is done via the `per_channel_broadcastable_shape` and `scaling_stats_permute_dims` attributes. \n",
"We can take a look at the `QuantReLU` object, and in particular look at what the ` scaling_impl` object is composed of. It is responsible for gathering statistics for determining the quantization parameters, and we can see that its ` stats_input_view_shape_impl` attribute is set to be an instance of `OverTensorView`. This is defined [here](https://github.com/Xilinx/brevitas/blob/200456825f3b4b8db414f2b25b64311f82d3991a/src/brevitas/core/function_wrapper/shape.py#L78), and serves to flatten out the observed tensor into a 1D tensor and, in this case, use the `AbsPercentile` observer to calculate the quantization parameters during the gathering statistics stage of QAT."
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"QuantReLU(\n",
" (input_quant): ActQuantProxyFromInjector(\n",
" (_zero_hw_sentinel): StatelessBuffer()\n",
" )\n",
" (act_quant): ActQuantProxyFromInjector(\n",
" (_zero_hw_sentinel): StatelessBuffer()\n",
" (fused_activation_quant_proxy): FusedActivationQuantProxy(\n",
" (activation_impl): ReLU()\n",
" (tensor_quant): RescalingIntQuant(\n",
" (int_quant): IntQuant(\n",
" (float_to_int_impl): RoundSte()\n",
" (tensor_clamp_impl): TensorClamp()\n",
" (delay_wrapper): DelayWrapper(\n",
" (delay_impl): _NoDelay()\n",
" )\n",
" )\n",
" (scaling_impl): ParameterFromRuntimeStatsScaling(\n",
" (stats_input_view_shape_impl): OverTensorView()\n",
" (stats): _Stats(\n",
" (stats_impl): AbsPercentile()\n",
" )\n",
" (restrict_scaling): _RestrictValue(\n",
" (restrict_value_impl): FloatRestrictValue()\n",
" )\n",
" (clamp_scaling): _ClampValue(\n",
" (clamp_min_ste): ScalarClampMinSte()\n",
" )\n",
" (restrict_inplace_preprocess): Identity()\n",
" (restrict_preprocess): Identity()\n",
" )\n",
" (int_scaling_impl): IntScaling()\n",
" (zero_point_impl): ZeroZeroPoint(\n",
" (zero_point): StatelessBuffer()\n",
" )\n",
" (msb_clamp_bit_width_impl): BitWidthConst(\n",
" (bit_width): StatelessBuffer()\n",
" )\n",
" )\n",
" )\n",
" )\n",
")"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"per_tensor_quant_relu"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we initialise a new `QuantRelU` instance, but this time we specify that we desire per-channel quantization i.e. `scaling_per_output_channel=True`. This will implictly call `scaling_stats_input_view_shape_impl`, defined [here](https://github.com/Xilinx/brevitas/blob/200456825f3b4b8db414f2b25b64311f82d3991a/src/brevitas/quant/solver/common.py#L184), and will change the `QuantReLU` from using a per-tensor view when gathering stats to a per output channel view ([`OverOutputChannelView`](https://github.com/Xilinx/brevitas/blob/200456825f3b4b8db414f2b25b64311f82d3991a/src/brevitas/core/function_wrapper/shape.py#L52)). This simply permutes the tensor into a 2D tensor, with dim 0 equal to the number of output channels.\n",
"\n",
"`per_channel_broadcastable_shape` represents what the dimensions of the quantization parameters will be, and should be laid out to match those of the output channels of the outputted tensor. We also need to specify the permutation dimensions via `scaling_stats_permute_dims` so as to shape the tensor into a standard format of output channels first. This is so that during the statistics gathering stage of QAT the correct stats will be gathered."
"To accomplish this, we also need to give it some extra information: `scaling_stats_permute_dims` and `per_channel_broadcastable_shape`. `scaling_stats_permute_dims` is responsible for defining how we do the permutation. `per_channel_broadcastable_shape` simply represents what the dimensions of the quantization parameters will be, i.e. there should be one parameter per output channel.\n",
"\n",
"Below, we can see that in the per-channel ` QuantReLU` instance, the `stats_input_view_shape_impl` is now ` OverOutputChannelView`, and uses a `PermuteDims` [instance](https://github.com/Xilinx/brevitas/blob/200456825f3b4b8db414f2b25b64311f82d3991a/src/brevitas/core/function_wrapper/shape.py#L21) to do the permutation of the tensor to, in this case, a 2D tensor. "
]
},
{
"cell_type": "code",
"execution_count": 160,
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[2.9999]],\n",
"\n",
" [[1.0000]],\n",
"\n",
" [[1.0000]]]], grad_fn=<MulBackward0>)"
"QuantReLU(\n",
" (input_quant): ActQuantProxyFromInjector(\n",
" (_zero_hw_sentinel): StatelessBuffer()\n",
" )\n",
" (act_quant): ActQuantProxyFromInjector(\n",
" (_zero_hw_sentinel): StatelessBuffer()\n",
" (fused_activation_quant_proxy): FusedActivationQuantProxy(\n",
" (activation_impl): ReLU()\n",
" (tensor_quant): RescalingIntQuant(\n",
" (int_quant): IntQuant(\n",
" (float_to_int_impl): RoundSte()\n",
" (tensor_clamp_impl): TensorClamp()\n",
" (delay_wrapper): DelayWrapper(\n",
" (delay_impl): _NoDelay()\n",
" )\n",
" )\n",
" (scaling_impl): ParameterFromRuntimeStatsScaling(\n",
" (stats_input_view_shape_impl): OverOutputChannelView(\n",
" (permute_impl): PermuteDims()\n",
" )\n",
" (stats): _Stats(\n",
" (stats_impl): AbsPercentile()\n",
" )\n",
" (restrict_scaling): _RestrictValue(\n",
" (restrict_value_impl): FloatRestrictValue()\n",
" )\n",
" (clamp_scaling): _ClampValue(\n",
" (clamp_min_ste): ScalarClampMinSte()\n",
" )\n",
" (restrict_inplace_preprocess): Identity()\n",
" (restrict_preprocess): Identity()\n",
" )\n",
" (int_scaling_impl): IntScaling()\n",
" (zero_point_impl): ZeroZeroPoint(\n",
" (zero_point): StatelessBuffer()\n",
" )\n",
" (msb_clamp_bit_width_impl): BitWidthConst(\n",
" (bit_width): StatelessBuffer()\n",
" )\n",
" )\n",
" )\n",
" )\n",
")"
]
},
"execution_count": 160,
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -755,6 +851,37 @@
" per_channel_broadcastable_shape=(1, out_channels, 1 , 1),\n",
" scaling_stats_permute_dims=(1, 0, 2, 3),\n",
" )\n",
"per_chan_quant_relu"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also observe the effect on the quantization parameters:"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[2.9999]],\n",
"\n",
" [[1.0000]],\n",
"\n",
" [[1.0000]]]], grad_fn=<MulBackward0>)"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"out_channel = per_chan_quant_relu(inp3)\n",
"out_channel.scale * ((2**8) -1)"
]
Expand All @@ -763,7 +890,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Above, we can see that the number of elements in the quantization scale of the outputted tensor is now 3, matching those of the 3-channel tensor! Furthermore, we see that each channel has an 8-bit quantization range that matches its data distribution, which is much more ideal in terms of reducing quantization mismatch. However, it's important to note that some hardware providers don't efficiently support per-channel quantization in production, so it's best to check if your targetted hardware will allow per-channel quantization."
"We can see that the number of elements in the quantization scale of the outputted tensor is now 3, matching those of the 3-channel tensor! Furthermore, we see that each channel has an 8-bit quantization range that matches its data distribution, which is much more ideal in terms of reducing quantization mismatch. However, it's important to note that some hardware providers don't efficiently support per-channel quantization in production, so it's best to check if your targetted hardware will allow per-channel quantization."
]
},
{
Expand Down

0 comments on commit 4837b17

Please sign in to comment.