Add tutorial examples of per-channel quantization #867

Merged
226 changes: 223 additions & 3 deletions notebooks/02_quant_activation_overview.ipynb
@@ -59,7 +59,7 @@
"source": [
"import torch\n",
"from brevitas.nn import QuantConv2d, QuantIdentity\n",
"from brevitas.quant.scaled_int import Int8ActPerTensorFloat \n",
"from brevitas.quant.scaled_int import Int8ActPerTensorFloat\n",
"\n",
"torch.manual_seed(0) # set a seed to make sure the random weight init is reproducible\n",
"output_quant_conv = QuantConv2d(\n",
@@ -100,7 +100,7 @@
"source": [
"torch.manual_seed(0)\n",
"input_output_quant_conv = QuantConv2d(\n",
" in_channels=2, out_channels=3, kernel_size=(3,3), \n",
" in_channels=2, out_channels=3, kernel_size=(3,3),\n",
" input_quant=Int8ActPerTensorFloat, output_quant=Int8ActPerTensorFloat)\n",
"\n",
"torch.manual_seed(0)\n",
@@ -594,7 +594,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"By default, the only layer that is an exception to this is `QuantHardTanh`. That is because the interface to `torch.nn.HardTanh` already requires users to manually specify `min_val` and `max_val`, so Brevitas preserves that both when quantization is enabled or disabled. With quantization enabled, by default those values are used for initialization, but then the range is learned. Let's look at an example:"
"By default, the only layer that is an exception to this is `QuantHardTanh`. That is because the interface to `torch.nn.HardTanh` already requires users to manually specify `min_val` and `max_val`, so Brevitas preserves that both when quantization is enabled or disabled. With quantization enabled, by default those values are used for initialization, but then the range is learned. Let's look at an example. Run the cell below, and we expect it to throw an error because of missing attributes:"
]
},
{
@@ -676,6 +676,226 @@
"assert_with_message(out1_train.scale.isclose(out2_eval.scale).item())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In all of the examples that have currently been looked at in this tutorial, we have used per-tensor quantization. I.e., the output tensor of the activation, if quantized, was always quantized on a per-tensor level, with a single scale and zero-point quantization parameter per output tensor. However, one can also do per-channel quantization, where each output channel of the tensor has its own quantization parameters. In the example below, we look at per-tensor quantization of an input tensor that has 3 channels and 256 elements in the height and width dimensions. We purposely mutate the 1st channel to have its dynamic range be 3 times larger than the other 2 channels. We then feed it through a `QuantReLU`, whose default behavior is to quantize at a per-tensor granularity."
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(2.9998, grad_fn=<MulBackward0>)"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"out_channels = 3\n",
"inp3 = torch.rand(1, out_channels, 256, 256) # (B, C, H, W)\n",
"inp3[:, 0, :, :] *= 3\n",
"\n",
"per_tensor_quant_relu = QuantReLU(return_quant_tensor=True)\n",
"out_tensor = per_tensor_quant_relu(inp3)\n",
"out_tensor.scale * ((2**8) -1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see that the per-tensor scale parameter has calibrated itself to provide a full quantization range of 3, matching that of the channel with the largest dynamic range. \n",
"\n",
"We can take a look at the `QuantReLU` object, and in particular look at what the `scaling_impl` object is composed of. It is responsible for gathering statistics for determining the quantization parameters, and we can see that its `stats_input_view_shape_impl` attribute is set to be an instance of `OverTensorView`. This is defined [here](https://github.com/Xilinx/brevitas/blob/200456825f3b4b8db414f2b25b64311f82d3991a/src/brevitas/core/function_wrapper/shape.py#L78), and serves to flatten out the observed tensor into a 1D tensor and, in this case, use the `AbsPercentile` observer to calculate the quantization parameters during the gathering statistics stage of QAT."
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"QuantReLU(\n",
" (input_quant): ActQuantProxyFromInjector(\n",
" (_zero_hw_sentinel): StatelessBuffer()\n",
" )\n",
" (act_quant): ActQuantProxyFromInjector(\n",
" (_zero_hw_sentinel): StatelessBuffer()\n",
" (fused_activation_quant_proxy): FusedActivationQuantProxy(\n",
" (activation_impl): ReLU()\n",
" (tensor_quant): RescalingIntQuant(\n",
" (int_quant): IntQuant(\n",
" (float_to_int_impl): RoundSte()\n",
" (tensor_clamp_impl): TensorClamp()\n",
" (delay_wrapper): DelayWrapper(\n",
" (delay_impl): _NoDelay()\n",
" )\n",
" )\n",
" (scaling_impl): ParameterFromRuntimeStatsScaling(\n",
" (stats_input_view_shape_impl): OverTensorView()\n",
" (stats): _Stats(\n",
" (stats_impl): AbsPercentile()\n",
" )\n",
" (restrict_scaling): _RestrictValue(\n",
" (restrict_value_impl): FloatRestrictValue()\n",
" )\n",
" (clamp_scaling): _ClampValue(\n",
" (clamp_min_ste): ScalarClampMinSte()\n",
" )\n",
" (restrict_inplace_preprocess): Identity()\n",
" (restrict_preprocess): Identity()\n",
" )\n",
" (int_scaling_impl): IntScaling()\n",
" (zero_point_impl): ZeroZeroPoint(\n",
" (zero_point): StatelessBuffer()\n",
" )\n",
" (msb_clamp_bit_width_impl): BitWidthConst(\n",
" (bit_width): StatelessBuffer()\n",
" )\n",
" )\n",
" )\n",
" )\n",
")"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"per_tensor_quant_relu"
]
},
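{
"cell_type": "markdown",
"metadata": {},
"source": [
"To build some intuition, the cell below sketches what this statistics pipeline amounts to. This is an illustration rather than Brevitas' exact code path, and it assumes the default `AbsPercentile` percentile of 99.999 and an unsigned 8-bit integer range: flatten the observed tensor to a 1D view, take the absolute-value percentile, and divide by the integer range. The result should land close to the scale learned by `per_tensor_quant_relu` above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative re-implementation only; the 99.999 percentile and the unsigned\n",
"# 8-bit range are assumptions about the defaults, not Brevitas' actual code.\n",
"stats_input = inp3.relu().abs().flatten() # OverTensorView: flatten to a 1D tensor\n",
"q = torch.quantile(stats_input, 0.99999) # AbsPercentile-style statistic\n",
"manual_scale = q / ((2**8) - 1) # map the observed range onto 255 integer steps\n",
"manual_scale * ((2**8) - 1)"
]
},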
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we initialise a new `QuantRelU` instance, but this time we specify that we desire per-channel quantization i.e. `scaling_per_output_channel=True`. This will implictly call `scaling_stats_input_view_shape_impl`, defined [here](https://github.com/Xilinx/brevitas/blob/200456825f3b4b8db414f2b25b64311f82d3991a/src/brevitas/quant/solver/common.py#L184), and will change the `QuantReLU` from using a per-tensor view when gathering stats to a per output channel view ([`OverOutputChannelView`](https://github.com/Xilinx/brevitas/blob/200456825f3b4b8db414f2b25b64311f82d3991a/src/brevitas/core/function_wrapper/shape.py#L52)). This simply permutes the tensor into a 2D tensor, with dim 0 equal to the number of output channels.\n",
"\n",
"To accomplish this, we also need to give it some extra information: `scaling_stats_permute_dims` and `per_channel_broadcastable_shape`. `scaling_stats_permute_dims` is responsible for defining how we do the permutation. `per_channel_broadcastable_shape` is necessary to understand along which dimensions the scale factor has to be broadcasted, so that the scale factor values are applied along the channel dimensions of the input.\n",
"By default, PyTorch will broadcast along the first rightmost dimension for which the shapes of the two tensors match. To make sure that we apply the scale factor in our desired output channel dimension, we need to tell PyTorch how to correctly broadcast the scale factors. Therefore the scale factor will have as many dimensions as the input tensors, with all the shapes equal to 1 apart from the channel dimension.\n",
"\n",
"Below, we can see that in the per-channel ` QuantReLU` instance, the `stats_input_view_shape_impl` is now ` OverOutputChannelView`, and uses a `PermuteDims` [instance](https://github.com/Xilinx/brevitas/blob/200456825f3b4b8db414f2b25b64311f82d3991a/src/brevitas/core/function_wrapper/shape.py#L21) to do the permutation of the tensor to, in this case, a 2D tensor. "
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"QuantReLU(\n",
" (input_quant): ActQuantProxyFromInjector(\n",
" (_zero_hw_sentinel): StatelessBuffer()\n",
" )\n",
" (act_quant): ActQuantProxyFromInjector(\n",
" (_zero_hw_sentinel): StatelessBuffer()\n",
" (fused_activation_quant_proxy): FusedActivationQuantProxy(\n",
" (activation_impl): ReLU()\n",
" (tensor_quant): RescalingIntQuant(\n",
" (int_quant): IntQuant(\n",
" (float_to_int_impl): RoundSte()\n",
" (tensor_clamp_impl): TensorClamp()\n",
" (delay_wrapper): DelayWrapper(\n",
" (delay_impl): _NoDelay()\n",
" )\n",
" )\n",
" (scaling_impl): ParameterFromRuntimeStatsScaling(\n",
" (stats_input_view_shape_impl): OverOutputChannelView(\n",
" (permute_impl): PermuteDims()\n",
" )\n",
" (stats): _Stats(\n",
" (stats_impl): AbsPercentile()\n",
" )\n",
" (restrict_scaling): _RestrictValue(\n",
" (restrict_value_impl): FloatRestrictValue()\n",
" )\n",
" (clamp_scaling): _ClampValue(\n",
" (clamp_min_ste): ScalarClampMinSte()\n",
" )\n",
" (restrict_inplace_preprocess): Identity()\n",
" (restrict_preprocess): Identity()\n",
" )\n",
" (int_scaling_impl): IntScaling()\n",
" (zero_point_impl): ZeroZeroPoint(\n",
" (zero_point): StatelessBuffer()\n",
" )\n",
" (msb_clamp_bit_width_impl): BitWidthConst(\n",
" (bit_width): StatelessBuffer()\n",
" )\n",
" )\n",
" )\n",
" )\n",
")"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"per_chan_quant_relu = QuantReLU(return_quant_tensor=True,\n",
" scaling_per_output_channel=True,\n",
" per_channel_broadcastable_shape=(1, out_channels, 1 , 1),\n",
" scaling_stats_permute_dims=(1, 0, 2, 3),\n",
" )\n",
"per_chan_quant_relu"
]
},
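{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick illustration of the two mechanisms described above (the scale values below are made up for demonstration, not taken from the trained module), the cell below mimics the per-output-channel statistics view and shows how a scale of shape `per_channel_broadcastable_shape` broadcasts along the channel dimension:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch of the per-output-channel view: one row of statistics per channel\n",
"stats_view = inp3.permute(1, 0, 2, 3).reshape(out_channels, -1)\n",
"print(stats_view.shape) # torch.Size([3, 65536])\n",
"\n",
"# Sketch of broadcasting a per-channel scale (illustrative values)\n",
"scale = torch.tensor([3.0, 1.0, 1.0]).view(1, out_channels, 1, 1)\n",
"print((inp3 * scale).shape) # each channel is scaled independently"
]
},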
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also observe the effect on the quantization parameters:"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[2.9999]],\n",
"\n",
" [[1.0000]],\n",
"\n",
" [[1.0000]]]], grad_fn=<MulBackward0>)"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"out_channel = per_chan_quant_relu(inp3)\n",
"out_channel.scale * ((2**8) -1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see that the number of elements in the quantization scale of the outputted tensor is now 3, matching those of the 3-channel tensor! Furthermore, we see that each channel has an 8-bit quantization range that matches its data distribution, which is much more ideal in terms of reducing quantization mismatch. However, it's important to note that some hardware providers don't efficiently support per-channel quantization in production, so it's best to check if your targetted hardware will allow per-channel quantization."
]
},
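{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough check of that claim, we can compare the mean reconstruction error of the per-tensor and per-channel activations on the same input. This assumes the `value` attribute of the returned `QuantTensor` holds the dequantized output; we would expect the per-channel version to track the input more closely:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Rough comparison, not a rigorous benchmark\n",
"err_per_tensor = (per_tensor_quant_relu(inp3).value - inp3.relu()).abs().mean()\n",
"err_per_chan = (per_chan_quant_relu(inp3).value - inp3.relu()).abs().mean()\n",
"err_per_tensor, err_per_chan"
]
},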
{
"cell_type": "markdown",
"metadata": {},