[PIR save/load] Migrate paddle.save, paddle.load and add program state_dict to pir (PaddlePaddle#63957)

* static.save and static.load passed

* refine

* fix pybind stop_gradient

* fix CI bug

* add create_loaded_params

* fix CI bug

* refine

* migrate paddle.save and paddle.load and add program state_dict to pir

* fix
changeyoung98 authored and runzhech committed Apr 30, 2024
1 parent 0a64b8a commit 2d8063e
Showing 6 changed files with 260 additions and 55 deletions.
112 changes: 84 additions & 28 deletions paddle/fluid/pybind/pir.cc
@@ -194,6 +194,33 @@ Value GetOutputValueByName(const Program &program, const std::string &name) {
  return value;
}

std::string GetValueName(Value value) {
  if (auto param_op = value.defining_op<::pir::ParameterOp>()) {
    return param_op.param_name();
  } else if (auto data_op = value.defining_op<paddle::dialect::DataOp>()) {
    return data_op.attribute<pir::StrAttribute>("name").AsString();
  } else if (auto block_arg = value.dyn_cast<BlockArgument>()) {
    if (block_arg.is_kwarg()) {
      return block_arg.keyword();
    } else {
      return "arg_" + std::to_string(block_arg.index());
    }
  } else if (value.first_use()) {
    auto next_op = value.first_use().owner();
    if (next_op->isa<::pir::ShadowOutputOp>()) {
      return next_op->attribute<pir::StrAttribute>("output_name").AsString();
    } else {
      PADDLE_THROW(phi::errors::InvalidArgument(
          "Currently, we can only get the name of a Value whose first use "
          "is a ShadowOutputOp."));
    }
  } else {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "Currently, we can only get the name of a persistable Value."));
  }
}
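
For orientation, a minimal Python-side sketch (not part of the diff) of where these naming cases come from, assuming the usual PIR lowering of static-graph ops:

import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    # pd_op.data: GetValueName reads the "name" attribute -> "x"
    x = paddle.static.data(name="x", shape=[-1, 4], dtype="float32")
    # builtin.parameter: GetValueName returns param_name
    w = paddle.create_parameter(shape=[4, 2], dtype="float32")
    # an intermediate value only gets a name once a ShadowOutputOp uses it
    y = paddle.matmul(x, w)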

void BindProgram(py::module *m) {
  py::class_<Program, std::shared_ptr<Program>> program(
      *m, "Program", py::dynamic_attr(), R"DOC(
@@ -336,7 +363,63 @@ void BindProgram(py::module *m) {
          [](Program &self, const std::string &name) {
            return GetOutputValueByName(self, name);
          })
      .def("num_ops", [](Program &self) { return self.num_ops(); });
      .def("num_ops", [](Program &self) { return self.num_ops(); })
      .def(
          "state_dict",
          [](std::shared_ptr<Program> self,
             const std::string &mode = "all",
             const framework::Scope &scope = framework::Scope()) {
            std::unordered_map<std::string, phi::DenseTensor> state_dict_all;
            std::unordered_map<std::string, phi::DenseTensor> state_dict_param;
            std::unordered_map<std::string, phi::DenseTensor> state_dict_opt;
            for (auto op : self->block()->ops()) {
              for (auto var : op->results()) {
                auto is_persistable =
                    var.attribute<BoolAttribute>("persistable");
                if (is_persistable && is_persistable.data()) {
                  if (var.defining_op()->isa<::pir::ParameterOp>()) {
                    std::string var_name = GetValueName(var);
                    auto tensor =
                        scope.FindVar(var_name)->GetMutable<phi::DenseTensor>();
                    state_dict_param[var_name] = *tensor;
                    state_dict_all[var_name] = *tensor;
                  } else if (var.defining_op()
                                 ->isa<paddle::dialect::DataOp>()) {
                    std::string var_name = GetValueName(var);
                    auto tensor =
                        scope.FindVar(var_name)->GetMutable<phi::DenseTensor>();
                    state_dict_opt[var_name] = *tensor;
                    state_dict_all[var_name] = *tensor;
                  }
                }
              }
            }
            if (mode == "all") {
              return state_dict_all;
            } else if (mode == "param") {
              return state_dict_param;
            } else if (mode == "opt") {
              return state_dict_opt;
            } else {
              PADDLE_THROW(
                  phi::errors::InvalidArgument("The mode is not supported."));
            }
          })
      .def("set_state_dict",
           [](std::shared_ptr<Program> self,
              const std::unordered_map<std::string, phi::DenseTensor>
                  &state_dict,
              const framework::Scope &scope = framework::Scope()) {
             for (auto item : state_dict) {
               auto var = scope.FindVar(item.first);
               if (var == nullptr) {
                 PADDLE_THROW(phi::errors::NotFound(
                     "The variable %s is not found.", item.first));
               } else {
                 *var->GetMutable<phi::DenseTensor>() = item.second;
               }
             }
           });
}

std::shared_ptr<Program> ParseProgram(const std::string &program_str) {
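
A hedged usage sketch for the two new bindings. The argument order is read off the lambdas above; note that pybind11 does not turn C++ lambda default arguments into Python defaults, so mode and scope are passed explicitly here:

import paddle

paddle.enable_static()
main = paddle.static.default_main_program()
scope = paddle.static.global_scope()

params = main.state_dict("param", scope)  # persistable ParameterOp results
opts = main.state_dict("opt", scope)      # persistable DataOp results
full = main.state_dict("all", scope)      # union of both

main.set_state_dict(full, scope)  # copies each tensor back into the scope by name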
@@ -755,33 +838,6 @@ phi::DataType GetValueDtype(Value value) {
}
}

std::string GetValueName(Value value) {
  if (auto param_op = value.defining_op<::pir::ParameterOp>()) {
    return param_op.param_name();
  } else if (auto data_op = value.defining_op<paddle::dialect::DataOp>()) {
    return data_op.attribute<pir::StrAttribute>("name").AsString();
  } else if (auto block_arg = value.dyn_cast<BlockArgument>()) {
    if (block_arg.is_kwarg()) {
      return block_arg.keyword();
    } else {
      return "arg_" + std::to_string(block_arg.index());
    }
  } else if (value.first_use()) {
    auto next_op = value.first_use().owner();
    if (next_op->isa<::pir::ShadowOutputOp>()) {
      return next_op->attribute<pir::StrAttribute>("output_name").AsString();
    } else {
      PADDLE_THROW(phi::errors::InvalidArgument(
          "Currently, we can only get the name of a Value whose first use "
          "is a ShadowOutputOp."));
    }
  } else {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "Currently, we can only get the name of a persistable Value."));
  }
}

const phi::DDim &GetValueDims(Value value) {
  if (!value.type()) {
    PADDLE_THROW(phi::errors::InvalidArgument("The type of value is nullptr."));
15 changes: 11 additions & 4 deletions python/paddle/framework/io.py
@@ -34,8 +34,10 @@
    Variable,
    _create_tensor,
    _current_expected_place,
    _current_expected_place_,
    _dygraph_tracer,
    in_dygraph_mode,
    in_pir_mode,
)

from .io_utils import (
@@ -519,7 +521,7 @@ def _to_LodTensor(ndarray):
            f'Type of `ndarray` should be numpy.ndarray, but received {type(ndarray)}.'
        )
    t = core.LoDTensor()
    place = _current_expected_place()
    place = _current_expected_place_()
    t.set(ndarray, place)
    return t

@@ -888,9 +890,14 @@ def save(obj, path, protocol=4, **configs):
        )

    if isinstance(obj, Program):
        obj.desc.flush()
        with _open_file_buffer(path, "wb") as f:
            f.write(obj.desc.serialize_to_string())
        if in_pir_mode():
            paddle.core.serialize_pir_program(
                obj, path, 1, True, False, True
            )
        else:
            obj.desc.flush()
            with _open_file_buffer(path, "wb") as f:
                f.write(obj.desc.serialize_to_string())

    elif _is_state_dict(obj):
        if in_dygraph_mode():
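
A short sketch of the new save path; the positional flags passed to serialize_pir_program are not documented in this diff, so their meanings are left alone here:

import paddle

paddle.enable_static()
main = paddle.static.default_main_program()
# Under PIR, paddle.save now dispatches to core.serialize_pir_program
# instead of flushing and serializing the legacy ProgramDesc.
paddle.save(main, "./checkpoint/main_program")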
54 changes: 53 additions & 1 deletion python/paddle/framework/io_utils.py
@@ -18,11 +18,12 @@
import pickle
import sys
from io import BytesIO
from types import FunctionType, MethodType

import numpy as np

import paddle
from paddle.base import core
from paddle.base import core, global_scope
from paddle.base.framework import Parameter, Variable, static_only
from paddle.base.log_helper import get_logger
from paddle.base.wrapped_decorator import signature_safe_contextmanager
@@ -271,3 +272,54 @@ def _unpack_saved_dict(saved_obj, protocol):
        saved_obj[part] = temp_saved_obj[part]
    saved_obj['UnpackBigParamInfor@@'] = unpack_infor
    return saved_obj


def set_value(var, value, scope=None):
    if not (isinstance(value, np.ndarray) or hasattr(value, "__array__")):
        raise TypeError(
            f"`value` should be `numpy.ndarray` or `LoDTensor`, but received {type(value)}."
        )

    if scope is not None and not isinstance(scope, core._Scope):
        raise TypeError(
            f"`scope` should be None or `paddle.static.Scope` type, but received {type(scope)}."
        )

    if scope is None:
        scope = global_scope()

    var_temp = scope.find_var(var.name)
    if var_temp is None:
        raise ValueError(f"Can not find Variable '{var.name}' in the Scope.")

    t = var_temp.get_tensor()

    if hasattr(value, "shape"):
        if isinstance(value.shape, (MethodType, FunctionType)):
            value_shape = value.shape()
        else:
            value_shape = value.shape
        if list(t.shape()) != list(value_shape):
            raise ValueError(
                f"{var.name} expected a shape {list(t.shape())}, but the received shape is {list(value_shape)}."
            )

    p = t._place()
    if p.is_cpu_place():
        place = core.CPUPlace()
    elif p.is_cuda_pinned_place():
        place = core.CUDAPinnedPlace()
    elif p.is_xpu_place():
        p = core.Place()
        p.set_place(t._place())
        place = core.XPUPlace(p.xpu_device_id())
    elif p.is_custom_place():
        p = core.Place()
        p.set_place(t._place())
        place = core.CustomPlace(p.custom_device_type(), p.custom_device_id())
    else:
        p = core.Place()
        p.set_place(t._place())
        place = core.CUDAPlace(p.gpu_device_id())

    t.set(value, place)
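
A minimal usage sketch for the new helper, assuming a static-graph parameter already initialized in the global scope:

import numpy as np
import paddle
from paddle.framework.io_utils import set_value

paddle.enable_static()
startup = paddle.static.Program()
main = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    w = paddle.create_parameter(shape=[4, 2], dtype="float32")

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup)

# Overwrites w's tensor in global_scope(), keeping its original place;
# a shape mismatch raises ValueError per the check above.
set_value(w, np.ones([4, 2], dtype="float32"))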
7 changes: 6 additions & 1 deletion python/paddle/static/io.py
@@ -59,6 +59,7 @@
    _safe_load_pickle,
)
from .pir_io import (
    get_pir_parameters,
    load_pir,
    load_pir_inference_model,
    load_vars_pir,
@@ -1731,7 +1732,11 @@ def set_program_state(program, state_dict):
        >>> static.set_program_state(prog, program_state)
    """
    state_dict = _pack_loaded_dict(state_dict)
    parameter_list = list(filter(is_persistable, program.list_vars()))
    if in_pir_mode():
        params, opts = get_pir_parameters(program)
        parameter_list = params + opts
    else:
        parameter_list = list(filter(is_persistable, program.list_vars()))

    used_para_list = {}
    for para in parameter_list:
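
Usage is unchanged from the docstring above; under PIR only the source of the parameter list differs (get_pir_parameters is assumed, from the code above, to return parameter and optimizer Values as two lists). A hedged end-to-end sketch:

import paddle

paddle.enable_static()
state = paddle.static.load_program_state("./model_dir/model")
paddle.static.set_program_state(paddle.static.default_main_program(), state)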
46 changes: 46 additions & 0 deletions test/deprecated/legacy_test/test_static_save_load.py
@@ -705,6 +705,7 @@ def set_place(self):
            else base.CUDAPlace(0)
        )

    @test_with_pir_api
    def test_ptb_rnn_cpu_float32(self):
        seed = 90
        hidden_size = 10
@@ -797,6 +798,11 @@ def test_ptb_rnn_cpu_float32(self):
        base_map = {}
        for var in main_program.list_vars():
            if isinstance(var, framework.Parameter) or var.persistable:
                if (
                    in_pir_mode()
                    and var.get_defining_op().name() == "pd_op.fetch"
                ):
                    continue
                t = np.array(
                    base.global_scope().find_var(var.name).get_tensor()
                )
@@ -811,6 +817,11 @@ def test_ptb_rnn_cpu_float32(self):
        # set var to zero
        for var in main_program.list_vars():
            if isinstance(var, framework.Parameter) or var.persistable:
                if (
                    in_pir_mode()
                    and var.get_defining_op().name() == "pd_op.fetch"
                ):
                    continue
                ten = base.global_scope().find_var(var.name).get_tensor()
                ten.set(np.zeros_like(np.array(ten)), place)

@@ -841,6 +852,11 @@

        for var in test_program.list_vars():
            if isinstance(var, framework.Parameter) or var.persistable:
                if (
                    in_pir_mode()
                    and var.get_defining_op().name() == "pd_op.fetch"
                ):
                    continue
                new_t = np.array(
                    base.global_scope().find_var(var.name).get_tensor()
                )
@@ -850,6 +866,11 @@ def test_ptb_rnn_cpu_float32(self):
        # check 1
        for var in main_program.list_vars():
            if isinstance(var, framework.Parameter) or var.persistable:
                if (
                    in_pir_mode()
                    and var.get_defining_op().name() == "pd_op.fetch"
                ):
                    continue
                ten = base.global_scope().find_var(var.name).get_tensor()
                ten.set(np.zeros_like(np.array(ten)), place)

@@ -863,6 +884,11 @@ def test_ptb_rnn_cpu_float32(self):

        for var in test_program.list_vars():
            if isinstance(var, framework.Parameter) or var.persistable:
                if (
                    in_pir_mode()
                    and var.get_defining_op().name() == "pd_op.fetch"
                ):
                    continue
                new_t = np.array(
                    base.global_scope().find_var(var.name).get_tensor()
                )
@@ -872,6 +898,11 @@ def test_ptb_rnn_cpu_float32(self):
        # check 2
        for var in main_program.list_vars():
            if isinstance(var, framework.Parameter) or var.persistable:
                if (
                    in_pir_mode()
                    and var.get_defining_op().name() == "pd_op.fetch"
                ):
                    continue
                ten = base.global_scope().find_var(var.name).get_tensor()
                ten.set(np.zeros_like(np.array(ten)), place)

@@ -885,6 +916,11 @@ def test_ptb_rnn_cpu_float32(self):

        for var in test_program.list_vars():
            if isinstance(var, framework.Parameter) or var.persistable:
                if (
                    in_pir_mode()
                    and var.get_defining_op().name() == "pd_op.fetch"
                ):
                    continue
                new_t = np.array(
                    base.global_scope().find_var(var.name).get_tensor()
                )
@@ -894,6 +930,11 @@ def test_ptb_rnn_cpu_float32(self):
        # check 3
        for var in main_program.list_vars():
            if isinstance(var, framework.Parameter) or var.persistable:
                if (
                    in_pir_mode()
                    and var.get_defining_op().name() == "pd_op.fetch"
                ):
                    continue
                ten = base.global_scope().find_var(var.name).get_tensor()
                ten.set(np.zeros_like(np.array(ten)), place)

@@ -907,6 +948,11 @@ def test_ptb_rnn_cpu_float32(self):

        for var in test_program.list_vars():
            if isinstance(var, framework.Parameter) or var.persistable:
                if (
                    in_pir_mode()
                    and var.get_defining_op().name() == "pd_op.fetch"
                ):
                    continue
                new_t = np.array(
                    base.global_scope().find_var(var.name).get_tensor()
                )
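
The pd_op.fetch guard above repeats in all nine loops of the test; a small helper would keep the copies identical (a sketch, not part of the diff):

def _skip_under_pir(var):
    # Fetch results are persistable under PIR but are not model state,
    # so the test leaves them out of the save/load comparison.
    return in_pir_mode() and var.get_defining_op().name() == "pd_op.fetch"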
