[pir_save_load] add pir for test_jit_save_load.py #63958

Merged
93 changes: 67 additions & 26 deletions paddle/fluid/pybind/pir.cc
@@ -121,6 +121,16 @@ PyTypeObject *g_ir_value_pytype = nullptr;

void BindOpsAPI(pybind11::module *module);

pir::Value FakeValue() {
// create a fake value to simplify `ForwardBackwardSplit`.
return pir::Value(nullptr);
}

bool IsFakeValue(const pir::Value &value) {
// check whether `value` is the fake value created to simplify `ForwardBackwardSplit`.
return value.impl() == nullptr || !value.type();
}

inline int64_t GetProgramInt64Attr(const std::shared_ptr<Program> &program,
const std::string &attr_name,
int64_t default_value = 0) {
@@ -195,6 +205,51 @@ Value GetOutputValueByName(const Program &program, const std::string &name) {
return value;
}

void SetValueName(Value value, const std::string name) {
pir::Operation *define_op = value.defining_op();
if (define_op->isa<pir::ParameterOp>()) {
define_op->set_attribute(
"parameter_name",
pir::StrAttribute::get(pir::IrContext::Instance(), name));
} else if (define_op->isa<paddle::dialect::DataOp>()) {
define_op->set_attribute(
"name", pir::StrAttribute::get(pir::IrContext::Instance(), name));
} else if (auto block_arg = value.dyn_cast<BlockArgument>()) {
PADDLE_THROW(
phi::errors::InvalidArgument("Cannot set the name of a BlockArgument!"));
} else if (value.first_use()) {
auto next_op = value.first_use().owner();
if (next_op->isa<::pir::ShadowOutputOp>()) {
next_op->set_attribute(
"output_name",
pir::StrAttribute::get(pir::IrContext::Instance(), name));
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Currently, we can only set the name of a Value that is "
"consumed by a ShadowOutputOp."));
}
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Currently, we can only set the name of a Value that "
"is persistable."));
}
}

bool HasValueName(const Value &value) {
if (IsFakeValue(value)) {
return false;
}
if (value.defining_op()->isa<::pir::ParameterOp>() ||
value.defining_op()->isa<paddle::dialect::DataOp>() ||
value.isa<BlockArgument>() ||
(value.first_use() &&
(value.first_use().owner()->isa<::pir::ShadowOutputOp>()))) {
return true;
} else {
return false;
}
}

std::string GetValueName(Value value) {
if (auto param_op = value.defining_op<::pir::ParameterOp>()) {
return param_op.param_name();
@@ -968,21 +1023,12 @@ void BindValue(py::module *m) {
return ss.str();
}
})
.def_property_readonly("name",
[](Value self) { return GetValueName(self); })
.def_property_readonly(
"has_name",
[](Value self) {
if (self.defining_op()->isa<::pir::ParameterOp>() ||
self.defining_op()->isa<paddle::dialect::DataOp>() ||
self.isa<BlockArgument>() ||
(self.first_use() &&
self.first_use().owner()->isa<::pir::ShadowOutputOp>())) {
return true;
} else {
return false;
}
})
.def_property(
"name",
[](Value self) { return GetValueName(self); },
[](Value self, const std::string &name) { SetValueName(self, name); })
.def_property_readonly("has_name",
[](Value self) { return HasValueName(self); })
.def_property(
"shape",
[](Value self) { return phi::vectorize(GetValueDims(self)); },
@@ -1476,16 +1522,6 @@ using SplitedProgram = std::vector<std::shared_ptr<Program>>;
using SplitedAttribute = std::map<std::string, std::vector<pir::Value>>;
using SplitedResult = std::pair<SplitedProgram, SplitedAttribute>;

pir::Value FakeValue() {
// create a fake value to simplify `ForwardBackwardSplit`.
return pir::Value(nullptr);
}

bool IsFakeValue(const pir::Value &value) {
// create a fake value to simplify `ForwardBackwardSplit`.
return value.impl() == nullptr || !value.type();
}

static auto GetNoNeedBufferValue(
const ::pir::Block *whole_block,
std::vector<int> range,
@@ -1594,10 +1630,12 @@ int AppendShadowOutputs(Program *forward_program,
std::string name_prefix) {
int counter = 0;
std::unordered_set<pir::Value> added_value;

for (const auto &value : outputs) {
if (!added_value.count(value) || IsFakeValue(value)) {
std::string shadow_output_name = name_prefix + std::to_string(counter);
if (HasValueName(value)) {
shadow_output_name = GetValueName(value);
}
AppendShadowOutput(
forward_program, value, shadow_output_name, start_point + counter);
counter += 1;
@@ -1727,6 +1765,9 @@ SplitedResult SplitForwardBackward(
}
std::string shadow_output_name =
std::string("output_") + std::to_string(counter);
if (HasValueName(v)) {
shadow_output_name = GetValueName(v);
}
auto op_info = ctx->GetRegisteredOpInfo(pir::ShadowOutputOp::name());
pir::AttributeMap attribute_map = {
{"output_name", pir::StrAttribute::get(ctx, shadow_output_name)},
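The bindings above turn `name` into a read/write property and expose `has_name` on `paddle.pir.Value`, and the shadow-output helpers now reuse an existing value name instead of always generating `output_<n>`. A minimal usage sketch of the Python-side property (not part of this PR; it assumes PIR mode is active, e.g. via `paddle.pir_utils.IrGuard`, and that the value is defined by a `pd_op.data` op so `SetValueName` rewrites its "name" attribute):

import paddle

with paddle.pir_utils.IrGuard():
    main = paddle.static.Program()
    with paddle.static.program_guard(main):
        # `x` is a paddle.pir.Value produced by a pd_op.data op.
        x = paddle.static.data(name="x", shape=[2, 3], dtype="float32")
        assert x.has_name            # results of DataOp/ParameterOp (or ShadowOutputOp users) carry a name
        print(x.name)                # getter -> GetValueName
        x.name = "renamed_x"         # new setter -> SetValueName updates the op's "name" attribute
        assert x.name == "renamed_x"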
124 changes: 96 additions & 28 deletions python/paddle/jit/api.py
@@ -531,15 +531,45 @@ def _get_output_vars(outputs, output_spec, with_hook=False):
"in configs.output_spec is the output tensor of "
"Layer.forward method."
)
output_spec_is_not_value_error = (
"tensor `%s` is not support in pir mode, "
"because pir value has no name sometimes, especially as ouptut,"
" so we can't check tensor's name with output var name"
)
if output_spec and with_hook:
raise RuntimeError(
"Currently not support specify output_spec while founding pre/post hooks in your outermost layer."
)
result_list = []
if use_pir_api():
from paddle.autograd.backward_utils import ValueSet

for var in paddle.utils.flatten(outputs):
if isinstance(var, paddle.pir.Value):
result_list.append(var)

if output_spec is not None:
if len(output_spec) == len(result_list):
for var in output_spec:
if not isinstance(var, paddle.pir.Value):
warnings.warn(output_spec_is_not_value_error % var.name)
else:
if var not in ValueSet(result_list):
warnings.warn(name_no_exists_error % var.name)
else:
result_set = ValueSet(result_list)
result_list = []
for var in output_spec:
if not isinstance(var, paddle.pir.Value):
raise ValueError(
output_spec_is_not_value_error % var.name
)
else:
if var not in result_set:
raise ValueError(name_no_exists_error % var.name)
else:
result_list.append(var)

else:
output_vars_dict = OrderedDict()
for var in paddle.utils.flatten(outputs):
@@ -560,6 +590,7 @@ def _get_output_vars(outputs, output_spec, with_hook=False):
raise ValueError(name_no_exists_error % var.name)
else:
result_list.append(output_vars_dict[var.name])

return result_list
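
For reference, the PIR branch above distilled into a standalone helper (illustration only; `select_outputs` is hypothetical and not part of the diff): when `output_spec` matches the flattened outputs in length, every output is kept and unmatched spec entries only produce warnings; when it is shorter, the outputs are pruned to the specified values and any entry that is not among the traced outputs raises a ValueError (the real code also rejects entries that are not `paddle.pir.Value`).

from paddle.autograd.backward_utils import ValueSet  # same helper the diff imports

def select_outputs(outputs, output_spec):
    # Hypothetical distillation of _get_output_vars' PIR branch.
    if output_spec is None or len(output_spec) == len(outputs):
        return list(outputs)  # keep everything; the real code merely warns on mismatches
    kept, candidates = [], ValueSet(outputs)
    for v in output_spec:
        if v not in candidates:
            raise ValueError(f"{v} is not an output of the traced program")
        kept.append(v)
    return kept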


@@ -960,7 +991,9 @@ def save(layer, path, input_spec=None, **configs):
for var in paddle.utils.flatten(input_spec):
if isinstance(var, paddle.static.InputSpec):
inner_input_spec.append(var)
elif isinstance(var, (core.eager.Tensor, Variable)):
elif isinstance(
var, (core.eager.Tensor, Variable, paddle.pir.Value)
):
inner_input_spec.append(
paddle.static.InputSpec.from_tensor(var)
)
@@ -991,6 +1024,7 @@ def save(layer, path, input_spec=None, **configs):
]

combine_vars = {}
combine_program = []
property_vals = [] # (value, key)
concrete_program = None
for attr_func in functions:
@@ -1017,6 +1051,7 @@ def save(layer, path, input_spec=None, **configs):
is_prim_infer=is_prim_infer,
)
)

elif 'forward' == attr_func:
if configs.skip_forward:
# do not jit.save forward function
@@ -1034,6 +1069,7 @@ def save(layer, path, input_spec=None, **configs):
input_spec=inner_input_spec,
full_graph=True,
)

concrete_program = (
static_forward.concrete_program_specify_input_spec(
with_hook=with_hook, is_prim_infer=is_prim_infer
@@ -1099,7 +1135,6 @@ def save(layer, path, input_spec=None, **configs):
for structured_name, var in dygraph_state_dict.items():
state_names_dict[var.name] = structured_name
state_var_dict[var.name] = var

# 3. share parameters from Layer to scope & record var info
with dygraph.guard():
if use_pir_api():
@@ -1181,25 +1216,38 @@ def save(layer, path, input_spec=None, **configs):
params_filename = (
file_prefix + '.' + attr_func + INFER_PARAMS_SUFFIX
)
file_prefix = file_prefix + '.' + attr_func
file_prefix = os.path.join(model_path, file_prefix)

path_prefix = file_prefix + '.' + attr_func
file_path = os.path.join(model_path, path_prefix)
with scope_guard(scope):
if not use_pir_api():
if use_pir_api():
value_map = paddle.pir.IrMapping()
clone_program = concrete_program.main_program.clone(value_map)
clone_input_vars = []
for v in input_vars:
if type(v) is paddle.static.InputSpec:
name = v.name
for op in clone_program.global_block().ops:
if (
op.name() == 'pd_op.data'
and op.attrs()["name"] == name
):
clone_input_vars.append(op.result(0))
else:
clone_input_vars.append(value_map.look_up(v))

clone_output_vars = [value_map.look_up(v) for v in output_vars]

else:
input_vars = [
concrete_program.main_program.global_block().var(name)
for name in input_var_names
]
clone_program = concrete_program.main_program.clone()
clone_input_vars = input_vars
clone_output_vars = output_vars
else:
value_map = paddle.pir.IrMapping()
clone_program = concrete_program.main_program.clone(value_map)
clone_input_vars = [value_map.look_up(v) for v in input_vars]
clone_output_vars = [value_map.look_up(v) for v in output_vars]

save_inference_model(
path_prefix=file_prefix,
path_prefix=file_path,
feed_vars=clone_input_vars,
fetch_vars=clone_output_vars,
executor=Executor(_current_expected_place()),
@@ -1209,12 +1257,23 @@ def save(layer, path, input_spec=None, **configs):
)

if combine_params:
clone_main_program = concrete_program.main_program.clone()
clone_main_program = clone_main_program._prune_with_input(
input_var_names, output_vars
)
for block in clone_main_program.blocks:
combine_vars.update(block.vars)
if use_pir_api():
# NOTE(Ruting): concrete_program has already been pruned when initializing
# PartialProgramLayer, so we do not need to prune it again.

for var in concrete_program.main_program.list_vars():
if var.persistable:
combine_vars[var.name] = var
# NOTE(Ruting): concrete_program is released after this loop iteration and its
# values are released with it, so we keep main_program in a list to extend its lifetime.
combine_program.append(concrete_program.main_program)
else:
clone_main_program = concrete_program.main_program.clone()
clone_main_program = clone_main_program._prune_with_input(
input_var_names, output_vars
)
for block in clone_main_program.blocks:
combine_vars.update(block.vars)

# save shared params
if combine_params:
@@ -1226,16 +1285,25 @@

params_filename = file_prefix + INFER_PARAMS_SUFFIX
with scope_guard(scope):
paddle.static.save_vars(
Executor(_current_expected_place()),
dirname=model_path,
vars=list(
filter(
paddle.framework.io_utils.is_persistable, ordered_vars
)
),
filename=params_filename,
)
if use_pir_api():
paddle.static.save_vars(
Executor(_current_expected_place()),
dirname=model_path,
vars=ordered_vars,
filename=params_filename,
)
else:
paddle.static.save_vars(
Executor(_current_expected_place()),
dirname=model_path,
vars=list(
filter(
paddle.framework.io_utils.is_persistable,
ordered_vars,
)
),
filename=params_filename,
)
# save property
property_save_path = os.path.join(
os.path.normpath(model_path), file_prefix + INFER_PROPERTY_SUFFIX
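
Taken together, the `paddle.jit.save` changes above let the PIR path accept `InputSpec` and `paddle.pir.Value` entries in `input_spec`, clone and save one inference program per decorated method, and, when `combine_params` is set, collect every persistable value into a single parameters file via `paddle.static.save_vars`. A minimal end-to-end sketch (assumes PIR mode is enabled, e.g. `FLAGS_enable_pir_api=1`; the layer and paths are made up):

import paddle
from paddle.static import InputSpec

class Net(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

net = Net()
paddle.jit.save(
    net,
    path="./saved/net",  # program and parameter files are written under ./saved/
    input_spec=[InputSpec(shape=[None, 4], dtype="float32", name="x")],
    combine_params=True,  # same flag the `if combine_params:` branches above read from **configs
)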
3 changes: 3 additions & 0 deletions python/paddle/jit/dy2static/pir_partial_program.py
@@ -473,6 +473,9 @@ def __init__(
assert isinstance(self._build_strategy, BuildStrategy)

self._origin_main_program = self._verify_program(main_program)
if parameters is not None:
parameters[0][:] = self._params
parameters[1][:] = self._param_values
with paddle.base.framework._dygraph_guard(paddle.base.dygraph.Tracer()):
self._cuda_graph_vec = self._create_cuda_graph_vec()
self._cuda_graph_capture_mode = ""
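
The added `parameters[0][:] = self._params` / `parameters[1][:] = self._param_values` lines use in-place slice assignment so the caller's own list objects are updated with the filtered parameters, rather than rebinding a local name. A tiny standalone illustration of the idiom (hypothetical names, not Paddle code):

def sync_params(parameters, kept_params, kept_values):
    # `parameters` is a pair of lists owned by the caller; mutate them in place.
    parameters[0][:] = kept_params   # every holder of the original list sees the update
    parameters[1][:] = kept_values
    # `parameters[0] = kept_params` would not touch the list object the caller
    # still references through its own variable.

params, values = ["w0", "w1_pruned"], ["v0", "v1_pruned"]
sync_params([params, values], ["w0"], ["v0"])
print(params, values)  # ['w0'] ['v0'] -- the original lists were modified in place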