Skip to content

Commit

Permalink
fix DiagEmbedInferMeta when x_dims contains -1 (#63961)
Browse files Browse the repository at this point in the history
Signed-off-by: ZelinMa557 <3388706467@qq.com>
  • Loading branch information
ZelinMa557 committed May 8, 2024
1 parent dae0e23 commit 6a58ff4
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,9 @@ void DiagEmbedInferMeta(
dim1,
dim2));

int new_dim_len = static_cast<int>(offset_ + x_dims[x_dims.size() - 1]);
int x_last_dim = x_dims[x_dims.size() - 1];
int new_dim_len =
(x_last_dim == -1) ? -1 : static_cast<int>(offset_ + x_last_dim);
auto sizes = common::vectorize(x_dims);
sizes.pop_back();
sizes.insert(sizes.begin() + std::min(dim1_, dim2_), new_dim_len);
Expand Down Expand Up @@ -3653,7 +3655,7 @@ void RepeatInterleaveInferMeta(const MetaTensor& x,
phi::errors::InvalidArgument(
"repeat_interleave's output tensor can't be nullptr"));

output_dim[n_dim] = input_dim[n_dim] * repeats;
if (input_dim[n_dim] != -1) output_dim[n_dim] = input_dim[n_dim] * repeats;
out->set_dims(common::make_ddim(output_dim));
out->share_lod(x);
out->set_dtype(x.dtype());
Expand Down

0 comments on commit 6a58ff4

Please sign in to comment.