Skip to content

Commit

Permalink
fix error
Browse files Browse the repository at this point in the history
  • Loading branch information
hxzd5568 committed May 9, 2024
1 parent 58dcc1c commit 182593f
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@

#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.h"

#define OP_SAME_OPERANDS_AND_RESULT(name) \
bool name##OpInferSymbolicShape( \
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { \
const symbol::ShapeOrDataDimExprs &operand_shape_or_data = \
shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); \
shape_analysis->SetShapeOrDataForValue(op->result(0), \
#define OP_SAME_OPERANDS_AND_RESULT(name) \
bool name##OpInferSymbolicShape( \
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { \
const symbol::ShapeOrDataDimExprs &operand_shape_or_data = \
infer_context->GetShapeOrDataForValue(op->operand_source(0)); \
infer_context->SetShapeOrDataForValue(op->result(0), \
operand_shape_or_data); \
return true; \
return true; \
}

namespace paddle::dialect {
Expand Down Expand Up @@ -139,20 +139,20 @@ OP_SAME_OPERANDS_AND_RESULT(Sigmoid)
OP_SAME_OPERANDS_AND_RESULT(Sigmoid_)

bool ScaleOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
pir::Value operand_source = op->operand_source(0);
const symbol::ShapeOrDataDimExprs &operand_shape_or_data =
shape_analysis->GetShapeOrDataForValue(operand_source);
std::vector<symbol::DimExpr> shape(operand_shape_or_data.shape());
pir::InferSymbolicShapeContext *infer_context) {
pir::Value operand_source = op->operand_source(0);
const symbol::ShapeOrDataDimExprs &operand_shape_or_data =
infer_context->GetShapeOrDataForValue(operand_source);
std::vector<symbol::DimExpr> shape(operand_shape_or_data.shape());

if (operand_shape_or_data.data()) {
const std::vector<symbol::DimExpr> data = [&] {
const symbol::DimExpr scale = [&]() -> symbol::DimExpr {
if (op->num_operands() == 2) {
return shape_analysis->GetShapeOrDataForValue(op->operand_source(1))
.data()
->at(0);
}
if (operand_shape_or_data.data()) {
const std::vector<symbol::DimExpr> data = [&] {
const symbol::DimExpr scale = [&]() -> symbol::DimExpr {
if (op->num_operands() == 2) {
return infer_context->GetShapeOrDataForValue(op->operand_source(1))
.data()
->at(0);
}
return static_cast<int64_t>(
op->attribute("scale").dyn_cast<pir::FloatAttribute>().data());
}();
Expand All @@ -165,11 +165,10 @@ bool ScaleOpInferSymbolicShape(pir::Operation *op,
return data;
}();

shape_analysis->SetShapeOrDataForValue(
op->result(0), symbol::TensorShapeOrDataDimExprs(shape, data));
} else {
shape_analysis->SetShapeOrDataForValue(op->result(0),
operand_shape_or_data);
infer_context->SetShapeOrDataForValue(
op->result(0), symbol::TensorShapeOrDataDimExprs(shape, data));
} else {
infer_context->SetShapeOrDataForValue(op->result(0), operand_shape_or_data);
}

return true;
Expand All @@ -180,6 +179,7 @@ bool ScaleOpInferSymbolicShape(pir::Operation *op,
namespace cinn::dialect {
using paddle::dialect::ReverseOpInferSymbolicShape;
using paddle::dialect::ScaleOpInferSymbolicShape;
using paddle::dialect::SelectOpInferSymbolicShape;
} // namespace cinn::dialect

#undef OP_SAME_OPERANDS_AND_RESULT
Original file line number Diff line number Diff line change
Expand Up @@ -135,4 +135,5 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sigmoid_)
namespace cinn::dialect {
using paddle::dialect::ReverseOpInferSymbolicShape;
using paddle::dialect::ScaleOpInferSymbolicShape;
using paddle::dialect::SelectOpInferSymbolicShape;
} // namespace cinn::dialect

0 comments on commit 182593f

Please sign in to comment.