Skip to content

Commit

Permalink
modify new case
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaoguoguo626807 committed May 6, 2024
1 parent 4e5d6e2 commit aa8d7ae
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 109 deletions.
77 changes: 49 additions & 28 deletions python/paddle/jit/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,7 @@ def save(layer, path, input_spec=None, **configs):
]

combine_vars = {}
combine_program = []
property_vals = [] # (value, key)
concrete_program = None
for attr_func in functions:
Expand All @@ -1050,6 +1051,7 @@ def save(layer, path, input_spec=None, **configs):
is_prim_infer=is_prim_infer,
)
)

elif 'forward' == attr_func:
if configs.skip_forward:
# do not jit.save forward function
Expand All @@ -1067,6 +1069,7 @@ def save(layer, path, input_spec=None, **configs):
input_spec=inner_input_spec,
full_graph=True,
)

concrete_program = (
static_forward.concrete_program_specify_input_spec(
with_hook=with_hook, is_prim_infer=is_prim_infer
Expand Down Expand Up @@ -1132,7 +1135,6 @@ def save(layer, path, input_spec=None, **configs):
for structured_name, var in dygraph_state_dict.items():
state_names_dict[var.name] = structured_name
state_var_dict[var.name] = var

# 3. share parameters from Layer to scope & record var info
with dygraph.guard():
if use_pir_api():
Expand Down Expand Up @@ -1216,18 +1218,8 @@ def save(layer, path, input_spec=None, **configs):
)
file_prefix = file_prefix + '.' + attr_func
file_prefix = os.path.join(model_path, file_prefix)

with scope_guard(scope):
if not use_pir_api():
input_vars = [
concrete_program.main_program.global_block().var(name)
for name in input_var_names
]
clone_program = concrete_program.main_program.clone()
clone_input_vars = input_vars
clone_output_vars = output_vars

else:
if use_pir_api():
value_map = paddle.pir.IrMapping()
clone_program = concrete_program.main_program.clone(value_map)
clone_input_vars = []
Expand All @@ -1245,6 +1237,15 @@ def save(layer, path, input_spec=None, **configs):

clone_output_vars = [value_map.look_up(v) for v in output_vars]

else:
input_vars = [
concrete_program.main_program.global_block().var(name)
for name in input_var_names
]
clone_program = concrete_program.main_program.clone()
clone_input_vars = input_vars
clone_output_vars = output_vars

save_inference_model(
path_prefix=file_prefix,
feed_vars=clone_input_vars,
Expand All @@ -1256,12 +1257,23 @@ def save(layer, path, input_spec=None, **configs):
)

if combine_params:
clone_main_program = concrete_program.main_program.clone()
clone_main_program = clone_main_program._prune_with_input(
input_var_names, output_vars
)
for block in clone_main_program.blocks:
combine_vars.update(block.vars)
if use_pir_api():
# NOTE(Ruting): concrete_program has been pruned when init partialProgramLayer,
# so we do not neet to prune again.

for var in concrete_program.main_program.list_vars():
if var.persistable:
combine_vars[var.name] = var
# NOTE(Ruting): concrete_program will delete after this loop item,
# value delete at the same time, so we use list to Extend its lifecycle
combine_program.append(concrete_program.main_program)
else:
clone_main_program = concrete_program.main_program.clone()
clone_main_program = clone_main_program._prune_with_input(
input_var_names, output_vars
)
for block in clone_main_program.blocks:
combine_vars.update(block.vars)

# save shared params
if combine_params:
Expand All @@ -1273,16 +1285,25 @@ def save(layer, path, input_spec=None, **configs):

params_filename = file_prefix + INFER_PARAMS_SUFFIX
with scope_guard(scope):
paddle.static.save_vars(
Executor(_current_expected_place()),
dirname=model_path,
vars=list(
filter(
paddle.framework.io_utils.is_persistable, ordered_vars
)
),
filename=params_filename,
)
if use_pir_api():
paddle.static.save_vars(
Executor(_current_expected_place()),
dirname=model_path,
vars=ordered_vars,
filename=params_filename,
)
else:
paddle.static.save_vars(
Executor(_current_expected_place()),
dirname=model_path,
vars=list(
filter(
paddle.framework.io_utils.is_persistable,
ordered_vars,
)
),
filename=params_filename,
)
# save property
property_save_path = os.path.join(
os.path.normpath(model_path), file_prefix + INFER_PROPERTY_SUFFIX
Expand Down
3 changes: 3 additions & 0 deletions python/paddle/jit/dy2static/pir_partial_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,9 @@ def __init__(
assert isinstance(self._build_strategy, BuildStrategy)

self._origin_main_program = self._verify_program(main_program)
if parameters is not None:
parameters[0][:] = self._params
parameters[1][:] = self._param_values
with paddle.base.framework._dygraph_guard(paddle.base.dygraph.Tracer()):
self._cuda_graph_vec = self._create_cuda_graph_vec()
self._cuda_graph_capture_mode = ""
Expand Down
3 changes: 1 addition & 2 deletions python/paddle/static/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from paddle.base import (
CompiledProgram,
Variable,
default_main_program,
)


Expand Down Expand Up @@ -66,7 +65,7 @@ def _get_valid_program(program=None):
return default main program if program is None.
"""
if program is None:
program = default_main_program()
program = paddle.static.default_main_program()
elif isinstance(program, CompiledProgram):
program = program._program
if program is None:
Expand Down
2 changes: 1 addition & 1 deletion test/legacy_test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -855,7 +855,7 @@ set_tests_properties(test_vision_models PROPERTIES TIMEOUT 120)
set_tests_properties(test_dataset_uci_housing PROPERTIES TIMEOUT 120)
set_tests_properties(test_dataset_imdb PROPERTIES TIMEOUT 300)
set_tests_properties(test_callback_wandb PROPERTIES TIMEOUT 60)
set_tests_properties(test_jit_save_load_rename PROPERTIES TIMEOUT 50)
set_tests_properties(test_jit_save_load_rename PROPERTIES TIMEOUT 100)
if(WITH_COVERAGE)
set_tests_properties(test_hapi_hub PROPERTIES TIMEOUT 300)
endif()
Expand Down

0 comments on commit aa8d7ae

Please sign in to comment.