[xla:auto_sharding] Question about resharding costs of Reshape strategies #12392

Open
Nullkooland opened this issue May 13, 2024 · 0 comments

Hi OpenXLA community, I have a question about the code that generates sharding strategies and computes resharding costs for the HLO reshape op.

In the code, the reshape sharding strategy output_spec is generated (reshaped) from one of the operand's sharding strategies, src_strategy_group->strategies[sid]:

std::optional<HloSharding> output_spec =
    hlo_sharding_util::ReshapeSharding(
        operand->shape(), ins->shape(),
        src_strategy_group->strategies[sid].output_sharding);

Then the communication and memory resharding costs are computed:

std::vector<double> communication_resharding_costs =
    CommunicationReshardingCostVector(
        src_strategy_group, operand->shape(),
        src_strategy_group->strategies[sid].output_sharding, cluster_env);
std::vector<double> memory_resharding_costs = MemoryReshardingCostVector(
    src_strategy_group, operand->shape(),
    src_strategy_group->strategies[sid].output_sharding, cluster_env);

It looks like the resharding costs are computed between every operand sharding strategy in src_strategy_group and the single operand strategy src_strategy_group->strategies[sid] that output_spec was reshaped from, rather than against output_spec itself. I guess this is why I got ZERO resharding cost on a reshape where an all_gather collective communication op is actually required:

%5 = stablehlo.add %4, %cst_0 {mhlo.sharding = "{devices=[2,4,1,1]0,1,2,3,4,5,6,7}", result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f16[32,16,1000,256]{3,1,2,0}"} : tensor<32x16x1000x256xf16>
%6 = stablehlo.reshape %5 {mhlo.sharding = "{devices=[2,1,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}"} : (tensor<32x16x1000x256xf16>) -> tensor<512x1000x256xf16>
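
To make the zero-cost entry concrete, here is a minimal toy sketch of the call pattern quoted above. It is not XLA's implementation: `ToySharding`, `ToyReshardingCost`, `ToyReshardingCostVector`, and `CostsAgainstOwnStrategy` are hypothetical names, and the flat "tensor size if the shardings differ" cost is an assumption chosen only to show which entry becomes zero.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy stand-in for a sharding: the number of tiles along each tensor
// dimension. This is NOT xla::HloSharding; it only models equality.
using ToySharding = std::vector<int>;

// Hypothetical cost model: free if the source already matches the required
// sharding, otherwise a flat cost equal to the tensor size in bytes.
double ToyReshardingCost(const ToySharding& src, const ToySharding& required,
                         double tensor_bytes) {
  return src == required ? 0.0 : tensor_bytes;
}

// One cost entry per candidate strategy of the operand.
std::vector<double> ToyReshardingCostVector(
    const std::vector<ToySharding>& operand_strategies,
    const ToySharding& required, double tensor_bytes) {
  std::vector<double> costs;
  costs.reserve(operand_strategies.size());
  for (const ToySharding& src : operand_strategies) {
    costs.push_back(ToyReshardingCost(src, required, tensor_bytes));
  }
  return costs;
}

// Mirrors the call pattern in question: the required sharding is the
// operand's own strategy at index sid, so entry sid is always zero.
std::vector<double> CostsAgainstOwnStrategy(
    const std::vector<ToySharding>& operand_strategies, std::size_t sid,
    double tensor_bytes) {
  assert(sid < operand_strategies.size());
  return ToyReshardingCostVector(operand_strategies, operand_strategies[sid],
                                 tensor_bytes);
}
```

In this toy model, if the operand keeps the strategy at index sid (for example the [2,4,1,1] tiling of %5), the entry for that strategy is zero regardless of what sharding the reshape output ends up with.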

If my understanding of the Alpa algorithm is correct, the resharding cost should be computed like:

 std::vector<double> communication_resharding_costs =  CommunicationReshardingCostVector( 
     src_strategy_group, operand->shape(), 
     /* required_sharding */ output_spec, cluster_env); 
 std::vector<double> memory_resharding_costs = MemoryReshardingCostVector( 
     src_strategy_group, operand->shape(), 
     /* required_sharding */ output_spec, cluster_env); 