Hi OpenXLA community, I have a question about the code that generates sharding strategies and computes resharding costs for the HLO `reshape` op.

In the code, the `reshape` sharding strategy `output_spec` is generated (reshaped) from one of the operand sharding strategies, `src_strategy_group->strategies[sid]`:

xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc
Lines 1898 to 1901 in 421f4c4

Then the communication and memory resharding costs are computed:

xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc
Lines 1917 to 1923 in 421f4c4

It looks like the resharding costs are computed between all operand sharding strategies in `src_strategy_group` and the one operand strategy that `output_spec` was reshaped from, `src_strategy_group->strategies[sid]`. I guess this is why I got ZERO resharding cost on a `reshape` where an `all_gather` CC op is actually required.

If my understanding of the Alpa algorithm is correct, the resharding cost should be computed like:
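
To make the zero-cost observation above concrete (this illustrates the behavior being questioned, not the alternative computation), here is a minimal toy sketch. It is not XLA code: `ToySharding`, `ToyReshardingCost`, and `ToyCostVector` are hypothetical names, and the cost model (0 for identical shardings, 1 otherwise) is an assumption standing in for a real collective cost such as an all-gather.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical toy type: a sharding is just a label such as "S0" or "R".
using ToySharding = std::string;

// Hypothetical toy cost model (an assumption, not XLA's): identical shardings
// cost nothing, anything else costs one unit, standing in for a collective.
double ToyReshardingCost(const ToySharding& from, const ToySharding& to) {
  return from == to ? 0.0 : 1.0;
}

// For the reshape strategy derived from operand strategy `sid`, build the cost
// vector the way the issue describes it: every operand strategy is compared
// against operand strategy `sid` itself.
std::vector<double> ToyCostVector(
    const std::vector<ToySharding>& operand_strategies, std::size_t sid) {
  std::vector<double> costs;
  costs.reserve(operand_strategies.size());
  for (const ToySharding& s : operand_strategies) {
    costs.push_back(ToyReshardingCost(s, operand_strategies[sid]));
  }
  return costs;
}
```

In this toy model, entry `sid` of the vector is always zero, because that entry compares operand strategy `sid` with itself. Any communication that the reshape itself might need on top of that is not represented anywhere in the vector, which is the situation the question is about.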