[pir_save_load] add pir for test_jit_save_load.py (PaddlePaddle#63958)
* add jit load.train

* modify backward program lost

* modify

* combine eval and train

* modify 8 cases of jit.save.load

* modify jit_save_load case

* rename jit_save_load

* change name all

* modify timeout

* modify new case

* modify TestJitSaveLoadMultiMethods

* fix cpu tensor no-holder bug
xiaoguoguo626807 authored and co63oc committed May 10, 2024
1 parent cac9d96 commit 6fcfe6a
Showing 11 changed files with 695 additions and 253 deletions.
93 changes: 67 additions & 26 deletions paddle/fluid/pybind/pir.cc
@@ -121,6 +121,16 @@ PyTypeObject *g_ir_value_pytype = nullptr;

void BindOpsAPI(pybind11::module *module);

pir::Value FakeValue() {
// create a fake value to simplify `ForwardBackwardSplit`.
return pir::Value(nullptr);
}

bool IsFakeValue(const pir::Value &value) {
// check whether `value` is the fake value created to simplify `ForwardBackwardSplit`.
return value.impl() == nullptr || !value.type();
}

inline int64_t GetProgramInt64Attr(const std::shared_ptr<Program> &program,
const std::string &attr_name,
int64_t default_value = 0) {
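
`FakeValue()` and `IsFakeValue()` are hoisted up here so the new name helpers can reuse them; they implement a sentinel that stands in for "no value" during forward/backward splitting. A minimal pure-Python sketch of the same pattern (all names below are invented for illustration):

```python
# Sentinel pattern behind FakeValue/IsFakeValue: one shared placeholder means
# "no value here", so splitting code can test for it instead of special-casing
# None or empty values at every call site.
_FAKE_VALUE = object()

def fake_value():
    """Return the shared placeholder used to pad missing inputs/outputs."""
    return _FAKE_VALUE

def is_fake_value(value):
    """True if `value` is the placeholder or otherwise carries no payload."""
    return value is _FAKE_VALUE or value is None

# Example: pad absent gradients so downstream code sees a uniform list.
grads = [g if g is not None else fake_value() for g in [None, 3.0, None]]
print([is_fake_value(g) for g in grads])  # [True, False, True]
```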
@@ -195,6 +205,51 @@ Value GetOutputValueByName(const Program &program, const std::string &name) {
return value;
}

void SetValueName(Value value, const std::string name) {
pir::Operation *define_op = value.defining_op();
if (define_op->isa<pir::ParameterOp>()) {
define_op->set_attribute(
"parameter_name",
pir::StrAttribute::get(pir::IrContext::Instance(), name));
} else if (define_op->isa<paddle::dialect::DataOp>()) {
define_op->set_attribute(
"name", pir::StrAttribute::get(pir::IrContext::Instance(), name));
} else if (auto block_arg = value.dyn_cast<BlockArgument>()) {
    PADDLE_THROW(
        phi::errors::InvalidArgument("Cannot set a name for a BlockArgument."));
} else if (value.first_use()) {
auto nextOp = value.first_use().owner();
if (nextOp->isa<::pir::ShadowOutputOp>()) {
nextOp->set_attribute(
"output_name",
pir::StrAttribute::get(pir::IrContext::Instance(), name));
} else {
      PADDLE_THROW(phi::errors::InvalidArgument(
          "Currently, we can only set the name of a Value that feeds a "
          "ShadowOutputOp."));
}
} else {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "Currently, we can only set the name of a Value that "
        "is persistable."));
}
}

bool HasValueName(const Value &value) {
if (IsFakeValue(value)) {
return false;
}
if (value.defining_op()->isa<::pir::ParameterOp>() ||
value.defining_op()->isa<paddle::dialect::DataOp>() ||
value.isa<BlockArgument>() ||
(value.first_use() &&
(value.first_use().owner()->isa<::pir::ShadowOutputOp>()))) {
return true;
} else {
return false;
}
}

std::string GetValueName(Value value) {
if (auto param_op = value.defining_op<::pir::ParameterOp>()) {
return param_op.param_name();
@@ -968,21 +1023,12 @@ void BindValue(py::module *m) {
return ss.str();
}
})
.def_property_readonly("name",
[](Value self) { return GetValueName(self); })
.def_property_readonly(
"has_name",
[](Value self) {
if (self.defining_op()->isa<::pir::ParameterOp>() ||
self.defining_op()->isa<paddle::dialect::DataOp>() ||
self.isa<BlockArgument>() ||
(self.first_use() &&
self.first_use().owner()->isa<::pir::ShadowOutputOp>())) {
return true;
} else {
return false;
}
})
.def_property(
"name",
[](Value self) { return GetValueName(self); },
[](Value self, const std::string &name) { SetValueName(self, name); })
.def_property_readonly("has_name",
[](Value self) { return HasValueName(self); })
.def_property(
"shape",
[](Value self) { return phi::vectorize(GetValueDims(self)); },
@@ -1476,16 +1522,6 @@ using SplitedProgram = std::vector<std::shared_ptr<Program>>;
using SplitedAttribute = std::map<std::string, std::vector<pir::Value>>;
using SplitedResult = std::pair<SplitedProgram, SplitedAttribute>;

pir::Value FakeValue() {
// create a fake value to simplify `ForwardBackwardSplit`.
return pir::Value(nullptr);
}

bool IsFakeValue(const pir::Value &value) {
// create a fake value to simplify `ForwardBackwardSplit`.
return value.impl() == nullptr || !value.type();
}

static auto GetNoNeedBufferValue(
const ::pir::Block *whole_block,
std::vector<int> range,
@@ -1594,10 +1630,12 @@ int AppendShadowOutputs(Program *forward_program,
std::string name_prefix) {
int counter = 0;
std::unordered_set<pir::Value> added_value;

for (const auto &value : outputs) {
if (!added_value.count(value) || IsFakeValue(value)) {
std::string shadow_output_name = name_prefix + std::to_string(counter);
if (HasValueName(value)) {
shadow_output_name = GetValueName(value);
}
AppendShadowOutput(
forward_program, value, shadow_output_name, start_point + counter);
counter += 1;
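
The new branch above lets `AppendShadowOutputs` reuse a value's existing name instead of always emitting `name_prefix + counter`; the `SplitForwardBackward` hunk below applies the same rule to its `output_N` names. A rough pure-Python sketch of that naming rule (fake-value handling omitted; the helper callbacks are stand-ins):

```python
# Illustrative naming rule: prefer the value's own name, otherwise fall back
# to prefix + running counter; each distinct value gets one shadow output.
def shadow_output_names(values, has_name, get_name, prefix="output_"):
    names, seen, counter = [], set(), 0
    for value in values:
        if value in seen:
            continue
        seen.add(value)
        name = get_name(value) if has_name(value) else f"{prefix}{counter}"
        names.append(name)
        counter += 1  # the counter advances even when an existing name is reused
    return names

print(shadow_output_names(
    ["x", "tmp", "x"],
    has_name=lambda v: v == "x",
    get_name=lambda v: v,
))  # ['x', 'output_1']
```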
@@ -1727,6 +1765,9 @@ SplitedResult SplitForwardBackward(
}
std::string shadow_output_name =
std::string("output_") + std::to_string(counter);
if (HasValueName(v)) {
shadow_output_name = GetValueName(v);
}
auto op_info = ctx->GetRegisteredOpInfo(pir::ShadowOutputOp::name());
pir::AttributeMap attribute_map = {
{"output_name", pir::StrAttribute::get(ctx, shadow_output_name)},
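
Taken together, the pir.cc changes turn `Value.name` into a read/write property backed by `SetValueName`/`GetValueName` and expose `Value.has_name` via `HasValueName`. A hedged usage sketch from the Python side, assuming PIR mode is enabled (for example with `FLAGS_enable_pir_api=1`); the tensor names are illustrative:

```python
# Reading, testing, and setting a value's name through the new bindings.
# Assumes PIR mode, where paddle.static.data yields a paddle.pir.Value
# defined by a pd_op.data op.
import paddle

paddle.enable_static()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
    x = paddle.static.data(name="x", shape=[2, 3], dtype="float32")
    y = paddle.nn.functional.relu(x)

    print(x.has_name)   # True: x is defined by a DataOp, so HasValueName succeeds
    print(x.name)       # "x"
    x.name = "feature"  # setter -> SetValueName rewrites the DataOp's "name" attribute
    print(y.has_name)   # False: y is neither persistable nor feeding a ShadowOutputOp
```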
124 changes: 96 additions & 28 deletions python/paddle/jit/api.py
@@ -531,15 +531,45 @@ def _get_output_vars(outputs, output_spec, with_hook=False):
"in configs.output_spec is the output tensor of "
"Layer.forward method."
)
    output_spec_is_not_value_error = (
        "tensor `%s` is not supported in PIR mode, "
        "because a PIR value may have no name, especially as an output, "
        "so we cannot match the tensor's name against the output var name"
    )
if output_spec and with_hook:
raise RuntimeError(
"Currently not support specify output_spec while founding pre/post hooks in your outermost layer."
)
result_list = []
if use_pir_api():
from paddle.autograd.backward_utils import ValueSet

for var in paddle.utils.flatten(outputs):
if isinstance(var, paddle.pir.Value):
result_list.append(var)

if output_spec is not None:
if len(output_spec) == len(result_list):
for var in output_spec:
if not isinstance(var, paddle.pir.Value):
warnings.warn(output_spec_is_not_value_error % var.name)
else:
if var not in ValueSet(result_list):
warnings.warn(name_no_exists_error % var.name)
else:
result_set = ValueSet(result_list)
result_list = []
for var in output_spec:
if not isinstance(var, paddle.pir.Value):
raise ValueError(
output_spec_is_not_value_error % var.name
)
else:
if var not in result_set:
raise ValueError(name_no_exists_error % var.name)
else:
result_list.append(var)

else:
output_vars_dict = OrderedDict()
for var in paddle.utils.flatten(outputs):
@@ -560,6 +590,7 @@ def _get_output_vars(outputs, output_spec, with_hook=False):
raise ValueError(name_no_exists_error % var.name)
else:
result_list.append(output_vars_dict[var.name])

return result_list
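
A plain-Python model of the PIR branch of `_get_output_vars` above may help: when `output_spec` covers every traced output, mismatches only produce warnings, while a strict subset must consist of traced output values or saving fails (the pir.Value type check is omitted; `select_outputs` is an invented name):

```python
import warnings

# Simplified stand-in for the PIR output_spec checks: full-length specs are
# tolerated with warnings, partial specs are validated strictly.
def select_outputs(traced_outputs, output_spec=None):
    if not output_spec:
        return list(traced_outputs)
    if len(output_spec) == len(traced_outputs):
        for item in output_spec:
            if item not in traced_outputs:
                warnings.warn(f"{item!r} is not one of the traced outputs")
        return list(traced_outputs)
    selected = []
    for item in output_spec:
        if item not in traced_outputs:
            raise ValueError(f"{item!r} does not exist in the traced outputs")
        selected.append(item)
    return selected
```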


@@ -960,7 +991,9 @@ def save(layer, path, input_spec=None, **configs):
for var in paddle.utils.flatten(input_spec):
if isinstance(var, paddle.static.InputSpec):
inner_input_spec.append(var)
elif isinstance(var, (core.eager.Tensor, Variable)):
elif isinstance(
var, (core.eager.Tensor, Variable, paddle.pir.Value)
):
inner_input_spec.append(
paddle.static.InputSpec.from_tensor(var)
)
@@ -991,6 +1024,7 @@ def save(layer, path, input_spec=None, **configs):
]

combine_vars = {}
combine_program = []
property_vals = [] # (value, key)
concrete_program = None
for attr_func in functions:
@@ -1017,6 +1051,7 @@ def save(layer, path, input_spec=None, **configs):
is_prim_infer=is_prim_infer,
)
)

elif 'forward' == attr_func:
if configs.skip_forward:
# do not jit.save forward function
@@ -1034,6 +1069,7 @@ def save(layer, path, input_spec=None, **configs):
input_spec=inner_input_spec,
full_graph=True,
)

concrete_program = (
static_forward.concrete_program_specify_input_spec(
with_hook=with_hook, is_prim_infer=is_prim_infer
@@ -1099,7 +1135,6 @@ def save(layer, path, input_spec=None, **configs):
for structured_name, var in dygraph_state_dict.items():
state_names_dict[var.name] = structured_name
state_var_dict[var.name] = var

# 3. share parameters from Layer to scope & record var info
with dygraph.guard():
if use_pir_api():
@@ -1181,25 +1216,38 @@ def save(layer, path, input_spec=None, **configs):
params_filename = (
file_prefix + '.' + attr_func + INFER_PARAMS_SUFFIX
)
file_prefix = file_prefix + '.' + attr_func
file_prefix = os.path.join(model_path, file_prefix)

path_prefix = file_prefix + '.' + attr_func
file_path = os.path.join(model_path, path_prefix)
with scope_guard(scope):
if not use_pir_api():
if use_pir_api():
value_map = paddle.pir.IrMapping()
clone_program = concrete_program.main_program.clone(value_map)
clone_input_vars = []
for v in input_vars:
if type(v) is paddle.static.InputSpec:
name = v.name
for op in clone_program.global_block().ops:
if (
op.name() == 'pd_op.data'
and op.attrs()["name"] == name
):
clone_input_vars.append(op.result(0))
else:
clone_input_vars.append(value_map.look_up(v))

clone_output_vars = [value_map.look_up(v) for v in output_vars]

else:
input_vars = [
concrete_program.main_program.global_block().var(name)
for name in input_var_names
]
clone_program = concrete_program.main_program.clone()
clone_input_vars = input_vars
clone_output_vars = output_vars
else:
value_map = paddle.pir.IrMapping()
clone_program = concrete_program.main_program.clone(value_map)
clone_input_vars = [value_map.look_up(v) for v in input_vars]
clone_output_vars = [value_map.look_up(v) for v in output_vars]

save_inference_model(
path_prefix=file_prefix,
path_prefix=file_path,
feed_vars=clone_input_vars,
fetch_vars=clone_output_vars,
executor=Executor(_current_expected_place()),
@@ -1209,12 +1257,23 @@ def save(layer, path, input_spec=None, **configs):
)

if combine_params:
clone_main_program = concrete_program.main_program.clone()
clone_main_program = clone_main_program._prune_with_input(
input_var_names, output_vars
)
for block in clone_main_program.blocks:
combine_vars.update(block.vars)
if use_pir_api():
                # NOTE(Ruting): concrete_program has already been pruned when
                # initializing PartialProgramLayer, so we do not need to prune it again.

for var in concrete_program.main_program.list_vars():
if var.persistable:
combine_vars[var.name] = var
                        # NOTE(Ruting): concrete_program is released after this loop
                        # iteration and its values with it, so keep main_program in a
                        # list to extend its lifetime.
combine_program.append(concrete_program.main_program)
else:
clone_main_program = concrete_program.main_program.clone()
clone_main_program = clone_main_program._prune_with_input(
input_var_names, output_vars
)
for block in clone_main_program.blocks:
combine_vars.update(block.vars)

# save shared params
if combine_params:
@@ -1226,16 +1285,25 @@ def save(layer, path, input_spec=None, **configs):

params_filename = file_prefix + INFER_PARAMS_SUFFIX
with scope_guard(scope):
paddle.static.save_vars(
Executor(_current_expected_place()),
dirname=model_path,
vars=list(
filter(
paddle.framework.io_utils.is_persistable, ordered_vars
)
),
filename=params_filename,
)
if use_pir_api():
paddle.static.save_vars(
Executor(_current_expected_place()),
dirname=model_path,
vars=ordered_vars,
filename=params_filename,
)
else:
paddle.static.save_vars(
Executor(_current_expected_place()),
dirname=model_path,
vars=list(
filter(
paddle.framework.io_utils.is_persistable,
ordered_vars,
)
),
filename=params_filename,
)
# save property
property_save_path = os.path.join(
os.path.normpath(model_path), file_prefix + INFER_PROPERTY_SUFFIX
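
For orientation, a hedged end-to-end sketch of the save/load round trip these api.py changes serve, run with PIR enabled (for example with `FLAGS_enable_pir_api=1`); the layer, shapes, and path are illustrative:

```python
# Minimal paddle.jit.save / paddle.jit.load round trip through the PIR branches above.
import paddle

layer = paddle.nn.Linear(4, 2)
static_layer = paddle.jit.to_static(
    layer,
    input_spec=[paddle.static.InputSpec(shape=[None, 4], dtype="float32", name="x")],
    full_graph=True,
)
# combine_params=True exercises the new PIR path that collects persistable vars
# from concrete_program.main_program instead of pruning a cloned program.
paddle.jit.save(static_layer, path="./linear_infer", combine_params=True)

loaded = paddle.jit.load("./linear_infer")
out = loaded(paddle.rand([3, 4]))
print(out.shape)  # [3, 2]
```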
3 changes: 3 additions & 0 deletions python/paddle/jit/dy2static/pir_partial_program.py
@@ -473,6 +473,9 @@ def __init__(
assert isinstance(self._build_strategy, BuildStrategy)

self._origin_main_program = self._verify_program(main_program)
if parameters is not None:
parameters[0][:] = self._params
parameters[1][:] = self._param_values
with paddle.base.framework._dygraph_guard(paddle.base.dygraph.Tracer()):
self._cuda_graph_vec = self._create_cuda_graph_vec()
self._cuda_graph_capture_mode = ""
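
The two added lines write through `parameters[0][:]` and `parameters[1][:]`, so the caller's lists are updated in place rather than rebound to new objects. A small pure-Python illustration of that pattern (names invented):

```python
# In-place slice assignment mutates the caller's list objects, so any code that
# already holds a reference to `holder` observes the refreshed params/values.
def sync_parameters(parameters, params, param_values):
    parameters[0][:] = params        # same list object, new contents
    parameters[1][:] = param_values

holder = [[], []]
sync_parameters(holder, ["linear.weight", "linear.bias"], [0.1, 0.2])
print(holder)  # [['linear.weight', 'linear.bias'], [0.1, 0.2]]
```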
