[NestedTensor] Graph breaks with SDPA + NT constructor #126472
Labels
module: nestedtensor
NestedTensor tag; see issue #25032
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
🐛 Describe the bug
When we use SDPA with nested tensors, we need max_seqlen and min_seqlen. Getting them normally requires a `.item()` call, which (I think) usually causes a graph break.
So this issue focuses on removing the graph breaks that occur when an NT constructor is combined with SDPA.
General repro: call `nested_view_from_values_offsets_lengths` with `max_seqlen` and `min_seqlen` passed in:

Failure 1: With #122836 (rebased onto 7f1d5ab)
Failure 2: Based on that failure, I tried again with @soulitzer's PR #124624 patched on top:
Failure 3: Based on this, I tried a quick patch: this change
I haven't investigated this yet. #126198 may be related (judging only from the fact that it also involves unbacked SymInt and NT interactions).
Failure 4: As another attempt, I tried #124803 to see whether it would fix the issue without the unbacked SymInt problems, but it runs into a different problem: we get multiple NestedInts for the same dimension. (So we should probably go with #124624 and figure out what the unbacked SymInt issue is about.)
Versions
Described above; all of these were built on 7f1d5ab for H100.
cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer