Skip to content

Commit

Permalink
Merge pull request #16984 from kronbichler/optimize_matvec_kernel
Browse files (browse the repository at this point in the history)
Tensor product operations: Use loop unrolling for slow mat-vec
  • Loading branch information
kronbichler committed May 14, 2024
2 parents 1c172ea + 17fe7b8 commit 7801e2e
Showing 1 changed file with 176 additions and 24 deletions.
200 changes: 176 additions & 24 deletions include/deal.II/matrix_free/tensor_product_kernels.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -193,6 +193,7 @@ namespace internal
EvaluatorQuantity quantity,
bool transpose_matrix,
bool add,
bool consider_strides,
typename Number,
typename Number2>
std::enable_if_t<(variant == evaluate_general), void>
Expand All @@ -201,8 +202,8 @@ namespace internal
Number *out,
const int n_rows,
const int n_columns,
const int stride_in,
const int stride_out)
const int stride_in_given,
const int stride_out_given)
{
const int mm = transpose_matrix ? n_rows : n_columns,
nn = transpose_matrix ? n_columns : n_rows;
Expand All @@ -216,6 +217,11 @@ namespace internal
static_assert(quantity == EvaluatorQuantity::value,
"This function should only use EvaluatorQuantity::value");

Assert(consider_strides || (stride_in_given == 1 && stride_out_given == 1),
ExcInternalError());
const int stride_in = consider_strides ? stride_in_given : 1;
const int stride_out = consider_strides ? stride_out_given : 1;

// specialization for n_rows = 2 that manually unrolls the innermost loop
// to make the operation perform better (not completely as good as the
// templated one, but much better than the generic version down below,
Expand Down Expand Up @@ -249,8 +255,10 @@ namespace internal
out[stride_out * col] = result;
}
}
else if (mm <= 128)
else if (std::abs(in - out) < std::min(stride_out * nn, stride_in * mm))
{
Assert(mm <= 128,
ExcNotImplemented("For large sizes, arrays may not overlap"));
std::array<Number, 129> x;
for (int i = 0; i < mm; ++i)
x[i] = in[stride_in * i];
Expand Down Expand Up @@ -278,27 +286,168 @@ namespace internal
}
else
{
Assert(in != out,
ExcNotImplemented("For large sizes, arrays may not overlap"));
for (int col = 0; col < nn; ++col)
int nn_regular = (nn / 4) * 4;
for (int col = 0; col < nn_regular; col += 4)
{
Number res0;
Number res0, res1, res2, res3;
if (transpose_matrix == true)
{
res0 = matrix[col] * in[0];
const Number2 *matrix_ptr = matrix + col;
res0 = matrix_ptr[0] * in[0];
res1 = matrix_ptr[1] * in[0];
res2 = matrix_ptr[2] * in[0];
res3 = matrix_ptr[3] * in[0];
matrix_ptr += n_columns;
for (int i = 1; i < mm; ++i, matrix_ptr += n_columns)
{
res0 += matrix_ptr[0] * in[stride_in * i];
res1 += matrix_ptr[1] * in[stride_in * i];
res2 += matrix_ptr[2] * in[stride_in * i];
res3 += matrix_ptr[3] * in[stride_in * i];
}
}
else
{
const Number2 *matrix_0 = matrix + col * n_columns;
const Number2 *matrix_1 = matrix + (col + 1) * n_columns;
const Number2 *matrix_2 = matrix + (col + 2) * n_columns;
const Number2 *matrix_3 = matrix + (col + 3) * n_columns;

res0 = matrix_0[0] * in[0];
res1 = matrix_1[0] * in[0];
res2 = matrix_2[0] * in[0];
res3 = matrix_3[0] * in[0];
for (int i = 1; i < mm; ++i)
res0 += matrix[i * n_columns + col] * in[stride_in * i];
{
res0 += matrix_0[i] * in[stride_in * i];
res1 += matrix_1[i] * in[stride_in * i];
res2 += matrix_2[i] * in[stride_in * i];
res3 += matrix_3[i] * in[stride_in * i];
}
}
if (add)
{
out[0] += res0;
out[stride_out] += res1;
out[2 * stride_out] += res2;
out[3 * stride_out] += res3;
}
else
{
out[0] = res0;
out[stride_out] = res1;
out[2 * stride_out] = res2;
out[3 * stride_out] = res3;
}
out += 4 * stride_out;
}
if (nn - nn_regular == 3)
{
Number res0, res1, res2;
if (transpose_matrix == true)
{
const Number2 *matrix_ptr = matrix + nn_regular;
res0 = matrix_ptr[0] * in[0];
res1 = matrix_ptr[1] * in[0];
res2 = matrix_ptr[2] * in[0];
matrix_ptr += n_columns;
for (int i = 1; i < mm; ++i, matrix_ptr += n_columns)
{
res0 += matrix_ptr[0] * in[stride_in * i];
res1 += matrix_ptr[1] * in[stride_in * i];
res2 += matrix_ptr[2] * in[stride_in * i];
}
}
else
{
res0 = matrix[col * n_columns] * in[0];
const Number2 *matrix_0 = matrix + nn_regular * n_columns;
const Number2 *matrix_1 = matrix + (nn_regular + 1) * n_columns;
const Number2 *matrix_2 = matrix + (nn_regular + 2) * n_columns;

res0 = matrix_0[0] * in[0];
res1 = matrix_1[0] * in[0];
res2 = matrix_2[0] * in[0];
for (int i = 1; i < mm; ++i)
res0 += matrix[col * n_columns + i] * in[stride_in * i];
{
res0 += matrix_0[i] * in[stride_in * i];
res1 += matrix_1[i] * in[stride_in * i];
res2 += matrix_2[i] * in[stride_in * i];
}
}
if (add)
out[stride_out * col] += res0;
{
out[0] += res0;
out[stride_out] += res1;
out[2 * stride_out] += res2;
}
else
out[stride_out * col] = res0;
{
out[0] = res0;
out[stride_out] = res1;
out[2 * stride_out] = res2;
}
}
else if (nn - nn_regular == 2)
{
Number res0, res1;
if (transpose_matrix == true)
{
const Number2 *matrix_ptr = matrix + nn_regular;
res0 = matrix_ptr[0] * in[0];
res1 = matrix_ptr[1] * in[0];
matrix_ptr += n_columns;
for (int i = 1; i < mm; ++i, matrix_ptr += n_columns)
{
res0 += matrix_ptr[0] * in[stride_in * i];
res1 += matrix_ptr[1] * in[stride_in * i];
}
}
else
{
const Number2 *matrix_0 = matrix + nn_regular * n_columns;
const Number2 *matrix_1 = matrix + (nn_regular + 1) * n_columns;

res0 = matrix_0[0] * in[0];
res1 = matrix_1[0] * in[0];
for (int i = 1; i < mm; ++i)
{
res0 += matrix_0[i] * in[stride_in * i];
res1 += matrix_1[i] * in[stride_in * i];
}
}
if (add)
{
out[0] += res0;
out[stride_out] += res1;
}
else
{
out[0] = res0;
out[stride_out] = res1;
}
}
else if (nn - nn_regular == 1)
{
Number res0;
if (transpose_matrix == true)
{
const Number2 *matrix_ptr = matrix + nn_regular;
res0 = matrix_ptr[0] * in[0];
matrix_ptr += n_columns;
for (int i = 1; i < mm; ++i, matrix_ptr += n_columns)
res0 += matrix_ptr[0] * in[stride_in * i];
}
else
{
const Number2 *matrix_ptr = matrix + nn_regular * n_columns;
res0 = matrix_ptr[0] * in[0];
for (int i = 1; i < mm; ++i)
res0 += matrix_ptr[i] * in[stride_in * i];
}
if (add)
out[0] += res0;
else
out[0] = res0;
}
}
}
Expand Down Expand Up @@ -740,9 +889,9 @@ namespace internal
const int n_columns =
n_rows_static == 0 ? n_columns_runtime : n_columns_static;
const int stride_in =
n_rows_static == 0 ? stride_in_runtime : stride_in_static;
stride_in_static == 0 ? stride_in_runtime : stride_in_static;
const int stride_out =
n_rows_static == 0 ? stride_out_runtime : stride_out_static;
stride_out_static == 0 ? stride_out_runtime : stride_out_static;

Assert(n_rows > 0 && n_columns > 0,
ExcInternalError("The evaluation needs n_rows, n_columns > 0, but " +
Expand Down Expand Up @@ -908,6 +1057,7 @@ namespace internal
EvaluatorQuantity quantity,
bool transpose_matrix,
bool add,
bool consider_strides,
typename Number,
typename Number2>
std::enable_if_t<(variant == evaluate_evenodd), void>
Expand All @@ -923,8 +1073,8 @@ namespace internal
quantity,
0,
0,
0,
0,
consider_strides ? 0 : 1,
consider_strides ? 0 : 1,
transpose_matrix,
add>(
matrix, in, out, n_rows, n_columns, stride_in, stride_out);
Expand Down Expand Up @@ -1680,13 +1830,15 @@ namespace internal
apply_matrix_vector_product<restricted_variant,
quantity,
contract_over_rows,
add>(shape_data,
in,
out,
n_rows,
n_columns,
stride_operation * stride_in,
stride_operation * stride_out);
add,
(direction != 0 || stride != 1)>(
shape_data,
in,
out,
n_rows,
n_columns,
stride_operation * stride_in,
stride_operation * stride_out);

if (one_line == false)
{
Expand Down

0 comments on commit 7801e2e

Please sign in to comment.