add spmd rule for amp ops #64202
Conversation
Your PR was submitted successfully. Thank you for your contribution to this open-source project!
}
auto dims_mapping = dist_attr.dims_mapping();
for (auto& m : dims_mapping) {
  if (m != -1) {
It would be better to record which mesh the tensor is split on, and on which dimension of that mesh.
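The reviewer's suggestion can be sketched as follows. This is a minimal, self-contained illustration, not Paddle's actual API: `ShardedMeshDims` is a hypothetical helper that scans a `dims_mapping` (where `-1` means replicated and any other value names the mesh axis sharding that tensor dimension) and records the mesh axes the split occurs on, instead of only a boolean flag.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper (not the Paddle API): collect the mesh axes a tensor
// is sharded on. In a dims_mapping, -1 means the corresponding tensor
// dimension is replicated; any other value is the index of the mesh axis
// that shards it.
std::vector<int64_t> ShardedMeshDims(const std::vector<int64_t>& dims_mapping) {
  std::vector<int64_t> mesh_dims;
  for (auto m : dims_mapping) {
    if (m != -1) {
      mesh_dims.push_back(m);  // record which mesh axis carries the split
    }
  }
  return mesh_dims;
}
```

With this information available, later steps could restrict communication to only the mesh axes that actually carry a split, rather than treating "split anywhere" as a single flag.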
TensorDistAttr found_infinite_attr =
    CopyTensorDistAttrForOutput(scale.dist_attr());
if (splited) {
  found_infinite_attr.set_partial(true);
It would be better to specify the partial dimension according to the mesh info recorded above. In that case, the synchronization would be limited to the specific group where the tensor is split.
But the code here is also OK: since found_infinite is a bool variable, synchronizing it across the global world size is cheap.
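A minimal sketch of the suggested refinement, using invented names (`FakeDistAttr`, `MakeFoundInfiniteAttr` are illustrative stand-ins, not Paddle types): instead of a blanket `set_partial(true)`, the output attribute is marked partial only on the mesh axes where gradients were actually split, so the eventual allreduce of `found_infinite` runs inside that subgroup rather than the global world.

```cpp
#include <cstdint>
#include <set>
#include <vector>

// Illustrative stand-in for a dist attr; not a Paddle type.
struct FakeDistAttr {
  std::set<int64_t> partial_dims;  // mesh axes on which the value is partial
  void set_partial(const std::vector<int64_t>& mesh_dims) {
    partial_dims.insert(mesh_dims.begin(), mesh_dims.end());
  }
};

// Mark found_infinite partial only on the mesh axes that carry a split.
// An empty split list means no partial status at all.
FakeDistAttr MakeFoundInfiniteAttr(const std::vector<int64_t>& split_mesh_dims) {
  FakeDistAttr attr;
  if (!split_mesh_dims.empty()) {
    attr.set_partial(split_mesh_dims);
  }
  return attr;
}
```

As the reviewer notes, the global synchronization in the PR is also acceptable here because the tensor being reduced is a single bool.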
bool splited = false;
for (auto& x : xs) {
  auto dist_attr = x.dist_attr();
  dist_attr.clean_partial_status();
We could allow the partial status to pass through for better performance:
If partial is not allowed through, an allreduce is conducted before this operation regardless of whether check_finite returns False or True.
But if we allow partial to pass through, the allreduce is triggered by a later operation (e.g. adam), and if check_finite returns True, adam is not called and therefore no allreduce is conducted.
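The communication trade-off above can be captured in a tiny cost sketch (purely illustrative, not Paddle code): `AllreduceCount` counts how many gradient allreduces happen for one step, depending on whether the partial status is cleaned before check_finite and on whether an inf/nan was found.

```cpp
// Illustrative cost model of the reviewer's argument:
// - cleaning partial before check_finite forces an allreduce every step;
// - passing partial through defers the allreduce to a later op (e.g. adam),
//   which is skipped entirely when an inf/nan is found.
int AllreduceCount(bool clean_partial_first, bool found_inf) {
  if (clean_partial_first) {
    return 1;  // allreduce happens before the check, unconditionally
  }
  return found_inf ? 0 : 1;  // deferred to adam; skipped when inf found
}
```

So the two strategies only differ on steps where an inf/nan occurs: passing partial through saves exactly the allreduce that the cleaned version would have wasted.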
std::vector<TensorDistAttr> xs_attrs;
for (auto& x : xs) {
  auto dist_attr = x.dist_attr();
  dist_attr.clean_partial_status();
The same as above.
return {
    {xs_attrs,
     found_infinite_attr,
     prev_loss_scaling.dist_attr(),
All these scalar tensors should be placed on a merged mesh of all the gradient tensors' meshes.
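One possible reading of "a merged mesh" can be sketched as a union of process ids; this is an assumption for illustration, not Paddle's actual mesh-merging logic, and `MergeProcessIds` is an invented helper.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical helper: flatten the process meshes of all gradient tensors
// into one sorted, deduplicated list of process ids. Scalar tensors such as
// the loss scaling could then be placed on this merged set of processes.
std::vector<int64_t> MergeProcessIds(
    const std::vector<std::vector<int64_t>>& meshes) {
  std::vector<int64_t> merged;
  for (const auto& mesh : meshes) {
    merged.insert(merged.end(), mesh.begin(), mesh.end());
  }
  std::sort(merged.begin(), merged.end());
  merged.erase(std::unique(merged.begin(), merged.end()), merged.end());
  return merged;
}
```

Placing the scalars on the merged mesh guarantees that every process holding a gradient shard can read the shared scale values without extra point-to-point transfers.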
Sorry to inform you that e4026d0's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
Force-pushed from 3c3e757 to 8b31e5d
Force-pushed from 8b31e5d to 970f561
Force-pushed from b461623 to c0a45da
Force-pushed from c0a45da to 200bf8c
LGTM for yaml
LGTM
lgtm
* add spmd for check_finite_and_unscale
* add spmd for update_loss_scaling
* fix partial
* fix ut
* fix spmd
* fix custom_relu test
* fix hybrid test
* fix ut
PR Category
Auto Parallel
PR Types
Devs
Description
Add spmd rule for amp ops
Fix spmd rule for layer_norm_grad
Pcard-67164