
Added SVE128 support for GEMMs #873

Open
wants to merge 9 commits into main
Conversation

@stefan0re commented Mar 21, 2024

  • Introduced SVE128 support for GEMM operations
  • Added support for FP32, FP64, BF16, and I8
  • Performance currently matches the NEON implementation (as expected; this will be improved in the future)
  • The hash.c test fails; the same behavior occurs with the LIBXSMM_TARGET=aarch64 flag (only this test: tests/hash.c:84)
  • Some xgemm tests fail due to missing support for some data types

tests/hash.c:84: this test fails
@alheinecke (Collaborator)

Thanks for the PR. There is still some important work before we can merge, e.g. auto-detection of V2 and fixing all the issues you raised.

Can you please add some functionality that we can unit test on GVT3? Then we can create a branch and go from there, e.g. by using prctl(PR_SVE_SET_VL, 16).
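A minimal sketch of that approach (Linux-only, an editor's illustration rather than PR code): prctl(PR_SVE_SET_VL, ...) pins the calling thread's SVE vector length in bytes, so 16 forces 128-bit vectors and lets the SVE128 kernels be exercised on hardware with a wider native vector length:

#include <stdio.h>
#include <sys/prctl.h>

int main(void) {
  /* PR_SVE_SET_VL takes the vector length in bytes: 16 bytes = 128 bit.
   * The call returns -1 if the kernel or CPU lacks SVE support. */
  if (prctl(PR_SVE_SET_VL, 16) == -1) {
    perror("prctl(PR_SVE_SET_VL)");
    return 1;
  }
  /* PR_SVE_GET_VL reports the current length; the low bits hold the bytes */
  printf("SVE VL: %ld bytes\n", (long)(prctl(PR_SVE_GET_VL) & PR_SVE_VL_LEN_MASK));
  return 0;
}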

  • used k unrolling with element access on B
  • added st1 instruction
  • with LIBXSMM_TARGET=aarch64 there is fast NEON code for Neoverse V2 (FP32 and FP64)
@stefan0re (Author) commented Apr 3, 2024

This new implementation is faster than the previous ASIMD/NEON kernel on NVIDIA Grace (FP32, FP64); the main changes are:

  • by-element access for B registers in the fmla instruction (see the sketch after this list)
  • a different blocking strategy (more registers for B and full vector loads)
  • different loads/stores: ld1/st1 instead of ldr/str
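For illustration, a minimal C sketch of that by-element pattern (ACLE NEON intrinsics, an editor's illustration rather than the generator's emitted code): one full-vector load fills a register with four B values, and each FMLA selects a lane of it, so no per-value broadcast of B is needed:

#include <arm_neon.h>

/* 4x4 FP32 tile update C += A*B: vfmaq_laneq_f32 maps to the by-element
 * form "fmla vD.4s, vA.4s, vB.s[lane]" */
static void tile_4x4_f32(const float *a, const float *b, float *c) {
  float32x4_t va = vld1q_f32(a);        /* one column of A (4 rows) */
  float32x4_t vb = vld1q_f32(b);        /* one row of B (4 values)  */
  float32x4_t c0 = vld1q_f32(c);
  float32x4_t c1 = vld1q_f32(c + 4);
  float32x4_t c2 = vld1q_f32(c + 8);
  float32x4_t c3 = vld1q_f32(c + 12);
  c0 = vfmaq_laneq_f32(c0, va, vb, 0);  /* c0 += va * vb[0] */
  c1 = vfmaq_laneq_f32(c1, va, vb, 1);  /* c1 += va * vb[1] */
  c2 = vfmaq_laneq_f32(c2, va, vb, 2);  /* c2 += va * vb[2] */
  c3 = vfmaq_laneq_f32(c3, va, vb, 3);  /* c3 += va * vb[3] */
  vst1q_f32(c,      c0);                /* full-vector stores */
  vst1q_f32(c + 4,  c1);
  vst1q_f32(c + 8,  c2);
  vst1q_f32(c + 12, c3);
}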

This is a plot of the FP32 performance with K fixed to 48 and N fixed to 40 (single core), old vs. new version:

[figure: N_40]

@breuera (Contributor) commented Apr 3, 2024

Since the old version blocks N using multiples of 6 and the new one uses multiples of 5: can you also share results for N=36? This would "favor" the old implementation w.r.t. the C blocking scheme used.

@stefan0re (Author) commented Apr 3, 2024

Above a certain size the old implementation loads all values from B into one register, so the blocking has no effect at that point. In this plot N is 36.
[figure: N_36]

@alheinecke (Collaborator) left a comment:

Please see comments

@@ -552,7 +560,7 @@ void libxsmm_aarch64_instruction_asimd_struct_r_move( libxsmm_generated_code*
   code[code_head] |= (unsigned int)((0x1 & (unsigned int)i_tupletype) << 30);

   /* load/store with offset register */
-  if ( (i_vmove_instr & 0x3) == 0x3 ) {
+  if ( (i_vmove_instr & 0x3) == 0x3 && ((i_vmove_instr == LIBXSMM_AARCH64_INSTR_ASIMD_LD1R) || (i_vmove_instr == LIBXSMM_AARCH64_INSTR_ASIMD_LD1R_R_POST))) {

Can you please rework this so that we don't test for specific instructions but for classes of instructions, to keep the code generator from becoming too convoluted and performance-bottlenecked by instruction-specific if conditions.
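One hypothetical way to express such a class-based test (the flag bit and its position are assumptions for illustration, not the actual libxsmm encoding): reserve a bit in the instruction constant that marks the whole class, and test that bit instead of comparing against individual opcodes:

/* hypothetical class flag inside the instruction constant */
#define LIBXSMM_AARCH64_INSTR_CLASS_STRUCT_LD1R 0x4

/* emitter side: one bit test instead of per-instruction comparisons */
if ( ((i_vmove_instr & 0x3) == 0x3) &&
     ((i_vmove_instr & LIBXSMM_AARCH64_INSTR_CLASS_STRUCT_LD1R) != 0) ) {
  /* ... emit the offset-register form ... */
}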

@@ -260,6 +260,14 @@
 #define LIBXSMM_AARCH64_INSTR_ASIMD_LD1R_R_POST 0x0dc0c003
 #define LIBXSMM_AARCH64_INSTR_ASIMD_LD1_I_POST 0x0ddf8002
 #define LIBXSMM_AARCH64_INSTR_ASIMD_LD1_R_POST 0x0dc08003
+#define LIBXSMM_AARCH64_INSTR_ASIMD_LD1_4 0x0c402000 // loads 4 values to vector register

Please use C and not C++ comments
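For reference, the same added line with the requested C comment (illustration only):

#define LIBXSMM_AARCH64_INSTR_ASIMD_LD1_4 0x0c402000 /* loads 4 values to vector register */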

@@ -1384,14 +1411,36 @@ void libxsmm_generator_store_2dregblock_aarch64_asimd( libxsmm_generated_code* i

   /* start register of accumulator */
   l_vec_reg_acc_start = i_vec_reg_count - (i_n_blocking * l_m_total_blocks);
+  /* set store instruction */
+  if( l_m_blocks[0] == 4 ){
+    l_a_store_instruction = LIBXSMM_AARCH64_INSTR_ASIMD_ST1_4;

Does this still work with the V1 GEMM kernels?

@@ -1113,16 +1246,17 @@ void libxsmm_generator_gemm_aarch64_kloop( libxsmm_generated_code* io
       l_k_stride = 4;
     }
   }

+  // TODO: implement new neoverse_v2 kernel

Please use a C comment.

@@ -344,7 +344,7 @@ void libxsmm_generator_gemm_vnni_store_C_from_scratch_aarch64( libxsmm_generated
   libxsmm_aarch64_instruction_alu_compute_imm12( io_generated_code, LIBXSMM_AARCH64_INSTR_GP_ADD_I, LIBXSMM_AARCH64_GP_REG_XSP, LIBXSMM_AARCH64_GP_REG_X0, 0, 0 );
   libxsmm_aarch64_instruction_alu_move( io_generated_code, LIBXSMM_AARCH64_INSTR_GP_STR_I_OFF, LIBXSMM_AARCH64_GP_REG_XSP, LIBXSMM_AARCH64_GP_REG_XZR, 64, i_gp_reg_mapping->gp_reg_c);
   libxsmm_aarch64_instruction_alu_move( io_generated_code, LIBXSMM_AARCH64_INSTR_GP_STR_I_OFF, LIBXSMM_AARCH64_GP_REG_XSP, LIBXSMM_AARCH64_GP_REG_XZR, 32, l_gp_reg_in);
-  if ( libxsmm_cpuid_arm_use_bfdot() == 0 ) {
+  if ( libxsmm_cpuid_arm_use_bfdot() == 0 ) { // TODO: check for SVE128

Please use a C-style comment.

@@ -230,7 +232,7 @@ libxsmm_blasint libxsmm_generator_mateltwise_aarch64_valid_arch_precision( libxs
        LIBXSMM_DATATYPE_I64 == libxsmm_meltw_getenum_precision(i_mateltwise_desc, LIBXSMM_MELTW_FIELD_COMP) ) {
     is_valid_arch_prec = 0;
   }
-}
+} // TODO: check for SVE128 Support!! add -> (&& (io_generated_code->arch != LIBXSMM_AARCH64_SVE128)

Please use a C-style comment.
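A hypothetical sketch of how the TODO itself might be resolved (editor's illustration, not PR code): conservatively reject SVE128 here until the element-wise kernels support it, using the arch id this PR introduces:

/* hypothetical: reject SVE128 until element-wise support lands */
if ( io_generated_code->arch == LIBXSMM_AARCH64_SVE128 ) {
  is_valid_arch_prec = 0;
}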

@@ -416,7 +416,7 @@ libxsmm_blasint libxsmm_generator_matequation_aarch64_valid_arch_precision( libx
   /* Binary not supported for fp64 */
   libxsmm_meltw_binary_type non_fp64_binary[2] = { LIBXSMM_MELTW_TYPE_BINARY_MUL_AND_REDUCE_TO_SCALAR_OP_ADD,
                                                    LIBXSMM_MELTW_TYPE_BINARY_ZIP };

+  // TODO: check for SVE128!

Please use a C-style comment.
