triton dtype mapping for float inputs #210

Open
karen-sy opened this issue Jul 27, 2023 · 1 comment

@karen-sy

I have a question about this error:

File /usr/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py:1069, in ast_to_ttir(fn, signature, specialization, constants, debug)
   1067 all_constants = constants.copy()
   1068 all_constants.update(new_constants)
-> 1069 arg_types = [str_to_ty(v) for k, v in signature.items() if k not in constants]
   1071 prototype = language.function_type([], arg_types)
   1072 generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants,
   1073                           function_name=function_name, attributes=new_attrs,
   1074                           is_kernel=True, debug=debug)

File /usr/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py:1069, in <listcomp>(.0)
   1067 all_constants = constants.copy()
   1068 all_constants.update(new_constants)
-> 1069 arg_types = [str_to_ty(v) for k, v in signature.items() if k not in constants]
   1071 prototype = language.function_type([], arg_types)
   1072 generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants,
   1073                           function_name=function_name, attributes=new_attrs,
   1074                           is_kernel=True, debug=debug)

File /usr/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py:1036, in str_to_ty(name)
   1017     return language.pointer_type(ty)
   1018 tys = {
   1019     "fp8e5": language.float8e5,
   1020     "fp8e4": language.float8e4,
   (...)
   1034     "B": language.int1,
   1035 }
-> 1036 return tys[name]

KeyError: 'f'

Floats are mapped to the type string 'f' here, but 'f' is not a key in the JAX-to-Triton dtype mapping here, so the tys[name] lookup in str_to_ty raises the KeyError above.
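The mismatch can be sketched without Triton or JAX installed. This is only a stand-in under stated assumptions: `TYS` reproduces just the three keys visible in the traceback and uses plain strings as values, and `signature_string` is a hypothetical helper that imitates a caller emitting 'f' for Python floats. It is not the actual jax-triton or Triton code.

```python
# Stand-in for the `tys` dict inside str_to_ty. Only the keys visible in the
# traceback are reproduced; the values are plain strings, not Triton types.
TYS = {
    "fp8e5": "float8e5",
    "fp8e4": "float8e4",
    "B": "int1",
}


def signature_string(value):
    """Hypothetical helper: map a Python scalar to a type string."""
    if isinstance(value, bool):
        return "B"
    if isinstance(value, float):
        return "f"  # the string reported in the KeyError
    raise TypeError(f"unsupported scalar: {type(value).__name__}")


def str_to_ty(name):
    """Plain dict lookup, as in the last frame of the traceback."""
    return TYS[name]


def lookup_error(value):
    """Return the missing key if the lookup fails, otherwise None."""
    try:
        str_to_ty(signature_string(value))
    except KeyError as exc:
        return exc.args[0]
    return None


if __name__ == "__main__":
    print(lookup_error(1.5))
    print(lookup_error(True))
```

The point of the sketch is that the producer and the consumer of the type string must agree on its spelling: any string emitted on one side that is missing from the dict on the other side fails at lookup time.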

Additionally, adding tys['f'] = language.float32 results in a different error:

TypeError: create_scalar_parameter(): incompatible function arguments. The following argument types are supported:
    1. (arg0: bool, arg1: str) -> jaxlib.cuda._triton.TritonParameter
    2. (arg0: int, arg1: str) -> jaxlib.cuda._triton.TritonParameter
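The error message lists only two overloads, (bool, str) and (int, str), so a float first argument matches neither. A rough sketch of that behaviour follows. `create_scalar_parameter_stub` is a hypothetical pure-Python stand-in written for illustration, not the jaxlib binding, and the dtype strings "i32" and "fp32" are placeholders.

```python
def create_scalar_parameter_stub(value, dtype):
    """Stand-in that accepts only the (bool, str) and (int, str) overloads."""
    # bool is a subclass of int in Python, so one isinstance check covers both.
    if isinstance(value, int) and isinstance(dtype, str):
        return (type(value).__name__, value, dtype)
    raise TypeError(
        "create_scalar_parameter(): incompatible function arguments"
    )


def accepts(value, dtype):
    """Report whether the stand-in accepts this argument pair."""
    try:
        create_scalar_parameter_stub(value, dtype)
    except TypeError:
        return False
    return True


if __name__ == "__main__":
    print(accepts(True, "B"))
    print(accepts(3, "i32"))     # placeholder dtype string
    print(accepts(1.5, "fp32"))  # placeholder dtype string
```

Under this reading, patching the dict alone cannot be enough: the scalar-parameter constructor would also need an overload that takes a float.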

Attached is a minimal repro:
add.txt

@sharadmv
Collaborator

I think this should be fixed as of #211, but you'll need to install a jaxlib nightly as well.
