Skip to content

Commit

Permalink
fix DiagEmbedInferMeta when x_dims contains -1
Browse files Browse the repository at this point in the history
Signed-off-by: ZelinMa557 <3388706467@qq.com>
  • Loading branch information
ZelinMa557 committed Apr 28, 2024
1 parent 71fd732 commit 242b7b5
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,9 @@ void DiagEmbedInferMeta(
dim1,
dim2));

int new_dim_len = static_cast<int>(offset_ + x_dims[x_dims.size() - 1]);
int x_last_dim = x_dims[x_dims.size() - 1];
int new_dim_len =
(x_last_dim == -1) ? static_cast<int>(offset_ + x_last_dim) : -1;
auto sizes = common::vectorize(x_dims);
sizes.pop_back();
sizes.insert(sizes.begin() + std::min(dim1_, dim2_), new_dim_len);
Expand Down

0 comments on commit 242b7b5

Please sign in to comment.