Skip to content

Commit

Permalink
Do not require dictionary keys to be str (#1886)
Browse files Browse the repository at this point in the history
* Unbreak task packing

When submitting a task group, only attempt to upload task inputs
corresponding to nodes external to the task group since only those
will have been resolved.

* Lock parent job record when persisting sublattices

* Dictionary collector nodes no longer need keys to be `str`

The collector electron now assembles the dictionary from two lists --
one list of keys, one list of corresponding values.

* Fix tests

* Changelog
  • Loading branch information
cjao committed Feb 28, 2024
1 parent e6ad647 commit d7841c7
Show file tree
Hide file tree
Showing 11 changed files with 137 additions and 26 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,10 @@ jobs:
if: env.BUILD_AND_RUN_ALL
id: covalent_start
run: |
export COVALENT_ENABLE_TASK_PACKING=1
covalent db migrate
if [ "${{ matrix.backend }}" = 'dask' ] ; then
COVALENT_ENABLE_TASK_PACKING=1 covalent start -d
covalent start -d
elif [ "${{ matrix.backend }}" = 'local' ] ; then
covalent start --no-cluster -d
else
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed

- Sublattice electron function strings are now parsed correctly
- The keys of dictionary inputs to electrons no longer need be strings.
- Fixed inaccuracies in task packing exposed by no longer uploading null attributes upon dispatch.

### Operations

Expand Down
6 changes: 3 additions & 3 deletions covalent/_workflow/electron.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,16 +572,16 @@ def _auto_list_node(*args, **kwargs):

elif isinstance(param_value, dict):

def _auto_dict_node(*args, **kwargs):
return dict(kwargs)
def _auto_dict_node(keys, values):
return {keys[i]: values[i] for i in range(len(keys))}

dict_electron = Electron(
function=_auto_dict_node,
metadata=collection_metadata,
task_group_id=self.task_group_id,
packing_tasks=True and active_lattice.task_packing,
) # Group the auto-generated node with the main node.
bound_electron = dict_electron(**param_value)
bound_electron = dict_electron(list(param_value.keys()), list(param_value.values()))
transport_graph.set_node_value(bound_electron.node_id, "name", electron_dict_prefix)
transport_graph.add_edge(
dict_electron.node_id,
Expand Down
15 changes: 13 additions & 2 deletions covalent_dispatcher/_core/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ async def _submit_task_group(dispatch_id: str, sorted_nodes: List[int], task_gro
app_log.debug("8A: Update node success (run_planned_workflow).")

else:
# Nodes whose values have already been resolved
known_nodes = []

# Skip the group if all task outputs can be reused from a
Expand All @@ -196,6 +197,8 @@ async def _submit_task_group(dispatch_id: str, sorted_nodes: List[int], task_gro
# Gather inputs for each task and send the task spec sequence to the runner
task_specs = []

sorted_nodes_set = set(sorted_nodes)

for node_id in sorted_nodes:
app_log.debug(f"Gathering inputs for task {node_id} (run_planned_workflow).")

Expand All @@ -214,8 +217,16 @@ async def _submit_task_group(dispatch_id: str, sorted_nodes: List[int], task_gro
"args_ids": abs_task_input["args"],
"kwargs_ids": abs_task_input["kwargs"],
}
known_nodes += abs_task_input["args"]
known_nodes += list(abs_task_input["kwargs"].values())
# Task inputs that don't belong to the task group have already beeen resolved
external_task_args = filter(
lambda x: x not in sorted_nodes_set, abs_task_input["args"]
)
known_nodes.extend(external_task_args)
external_task_kwargs = filter(
lambda x: x not in sorted_nodes_set, abs_task_input["kwargs"].values()
)
known_nodes.extend(external_task_kwargs)

task_specs.append(task_spec)

app_log.debug(
Expand Down
2 changes: 1 addition & 1 deletion covalent_dispatcher/_dal/importers/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def import_result(
# Main case: insert new lattice, electron, edge, and job records

storage_path = os.path.join(base_path, dispatch_id)
os.makedirs(storage_path)

lattice_record_kwargs = _get_result_meta(res, storage_path, electron_id)
lattice_record_kwargs.update(_get_lattice_meta(res.lattice, storage_path))
Expand Down Expand Up @@ -143,6 +142,7 @@ def _connect_result_to_electron(
fields={"id", "cancel_requested"},
equality_filters={"id": parent_electron_record.job_id},
membership_filters={},
for_update=True,
)[0]
cancel_requested = parent_job_record.cancel_requested

Expand Down
4 changes: 3 additions & 1 deletion covalent_dispatcher/_dal/importers/tg.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def import_transport_graph(
# Propagate parent electron id's `cancel_requested` property to the sublattice electrons
if electron_id is not None:
parent_e_record = Electron.meta_type.get_by_primary_key(session, electron_id)
job_record = Job.get_by_primary_key(session=session, primary_key=parent_e_record.job_id)
job_record = Job.get_by_primary_key(
session=session, primary_key=parent_e_record.job_id, for_update=True
)
cancel_requested = job_record.cancel_requested
else:
cancel_requested = False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def list_workflow(arg):

@ct.lattice
def dict_workflow(arg):
return dict_task(arg)
return dict_task(arg=arg)

# 1 2
# \ \
Expand Down Expand Up @@ -159,7 +159,7 @@ async def mock_get_incoming_edges(dispatch_id, node_id):

# dict-type inputs

# Nodes 0=task, 1=:electron_dict:, 2=1, 3=2
# Nodes 0=task, 1=:electron_dict:, 2=["a" (3), "b" (4)], 5=[1 (6), 2 (7)]
dict_workflow.build_graph({"a": 1, "b": 2})
abstract_args = {"a": 2, "b": 3}
tg = dict_workflow.transport_graph
Expand All @@ -172,10 +172,31 @@ async def mock_get_incoming_edges(dispatch_id, node_id):
mock_get_incoming_edges,
)

task_inputs = await _get_abstract_task_inputs(
result_object.dispatch_id, 0, tg.get_node_value(0, "name")
)
expected_inputs = {"args": [], "kwargs": {"arg": 1}}

assert task_inputs == expected_inputs

task_inputs = await _get_abstract_task_inputs(
result_object.dispatch_id, 1, tg.get_node_value(1, "name")
)
expected_inputs = {"args": [], "kwargs": abstract_args}
expected_inputs = {"args": [2, 5], "kwargs": {}}

assert task_inputs == expected_inputs

task_inputs = await _get_abstract_task_inputs(
result_object.dispatch_id, 2, tg.get_node_value(2, "name")
)
expected_inputs = {"args": [3, 4], "kwargs": {}}

assert task_inputs == expected_inputs

task_inputs = await _get_abstract_task_inputs(
result_object.dispatch_id, 5, tg.get_node_value(5, "name")
)
expected_inputs = {"args": [6, 7], "kwargs": {}}

assert task_inputs == expected_inputs

Expand Down
28 changes: 22 additions & 6 deletions tests/covalent_dispatcher_tests/_core/execution_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def list_workflow(arg):

@ct.lattice
def dict_workflow(arg):
return dict_task(arg)
return dict_task(arg=arg)

# 1 2
# \ \
Expand Down Expand Up @@ -167,20 +167,36 @@ def multivar_workflow(x, y):
# dict-type inputs

dict_workflow.build_graph({"a": 1, "b": 2})
serialized_args = {"a": ct.TransportableObject(1), "b": ct.TransportableObject(2)}

# Nodes 0=task, 1=:electron_dict:, 2=1, 3=2
# Nodes 0=task, 1=:electron_dict:, 2=["a" (3), "b" (4)], 5=[1 (6), 2 (7)]

sdkres = Result(lattice=dict_workflow, dispatch_id="asdf_dict_workflow")
result_object = get_mock_srvresult(sdkres, test_db)
tg = result_object.lattice.transport_graph
tg.set_node_value(2, "output", ct.TransportableObject(1))
tg.set_node_value(3, "output", ct.TransportableObject(2))

tg.set_node_value(1, "output", ct.TransportableObject("node_1_output"))
tg.set_node_value(3, "output", ct.TransportableObject("a"))
tg.set_node_value(4, "output", ct.TransportableObject("b"))
tg.set_node_value(6, "output", ct.TransportableObject(1))
tg.set_node_value(7, "output", ct.TransportableObject(2))

mock_get_result = mocker.patch(
"covalent_dispatcher._core.runner.datasvc.get_result_object", return_value=result_object
)
task_inputs = await _get_task_inputs(1, tg.get_node_value(1, "name"), result_object)
expected_inputs = {"args": [], "kwargs": serialized_args}
serialized_kwargs = {"arg": ct.TransportableObject("node_1_output")}
task_inputs = await _get_task_inputs(0, tg.get_node_value(0, "name"), result_object)
expected_inputs = {"args": [], "kwargs": serialized_kwargs}

serialized_args = [ct.TransportableObject("a"), ct.TransportableObject("b")]
task_inputs = await _get_task_inputs(2, tg.get_node_value(2, "name"), result_object)
expected_inputs = {"args": serialized_args, "kwargs": {}}

assert task_inputs == expected_inputs

serialized_args = [ct.TransportableObject(1), ct.TransportableObject(2)]
task_inputs = await _get_task_inputs(5, tg.get_node_value(5, "name"), result_object)
expected_inputs = {"args": serialized_args, "kwargs": {}}

assert task_inputs == expected_inputs

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from covalent._shared_files.schemas.result import AssetSchema, ResultSchema
from covalent._shared_files.util_classes import RESULT_STATUS
from covalent_dispatcher._dal.importers.result import SERVER_URL, handle_redispatch, import_result
from covalent_dispatcher._dal.job import Job
from covalent_dispatcher._dal.result import get_result_object
from covalent_dispatcher._db.datastore import DataStore

Expand Down Expand Up @@ -140,6 +141,7 @@ def test_import_previously_imported_result(mocker, test_db):
prefix="covalent-"
) as srv_dir:
sub_res = get_mock_result(sub_dispatch_id, sdk_dir)
sub_res.metadata.root_dispatch_id = dispatch_id
import_result(sub_res, srv_dir, None)
srv_res = get_result_object(dispatch_id, bare=True)
parent_node = srv_res.lattice.transport_graph.get_node(0)
Expand All @@ -152,6 +154,49 @@ def test_import_previously_imported_result(mocker, test_db):
assert sub_srv_res._electron_id == parent_node._electron_id


def test_import_subdispatch_cancel_req(mocker, test_db):
"""Test that Job.cancel_requested is propagated to sublattices"""

dispatch_id = "test_propagate_cancel_requested"
sub_dispatch_id = "test_propagate_cancel_requested_sub"

mocker.patch("covalent_dispatcher._dal.base.workflow_db", test_db)

mock_filter_uris = mocker.patch(
"covalent_dispatcher._dal.importers.result._filter_remote_uris"
)

with tempfile.TemporaryDirectory(prefix="covalent-") as sdk_dir, tempfile.TemporaryDirectory(
prefix="covalent-"
) as srv_dir:
res = get_mock_result(dispatch_id, sdk_dir)
import_result(res, srv_dir, None)

with test_db.Session() as session:
Job.update_bulk(
session, values={"cancel_requested": True}, equality_filters={}, membership_filters={}
)
session.commit()

with tempfile.TemporaryDirectory(prefix="covalent-") as sdk_dir, tempfile.TemporaryDirectory(
prefix="covalent-"
) as srv_dir:
sub_res = get_mock_result(sub_dispatch_id, sdk_dir)
sub_res.metadata.root_dispatch_id = dispatch_id
srv_res = get_result_object(dispatch_id, bare=True)
parent_node = srv_res.lattice.transport_graph.get_node(0)
import_result(sub_res, srv_dir, parent_node._electron_id)

with tempfile.TemporaryDirectory(prefix="covalent-") as srv_dir:
import_result(sub_res, srv_dir, parent_node._electron_id)

with test_db.Session() as session:
uncancelled = Job.get(
session, fields=[], equality_filters={"cancel_requested": False}, membership_filters={}
)
assert len(uncancelled) == 0


@pytest.mark.parametrize(
"parent_status,new_status",
[
Expand Down
28 changes: 20 additions & 8 deletions tests/covalent_tests/workflow/electron_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,18 +377,30 @@ def workflow(x):
g = workflow.transport_graph._graph

# Account for postprocessing node
assert list(g.nodes) == [0, 1, 2, 3, 4]
assert list(g.nodes) == [0, 1, 2, 3, 4, 5, 6, 7, 8]
fn = g.nodes[1]["function"].get_deserialized()
assert fn(x=2, y=5, z=7) == {"x": 2, "y": 5, "z": 7}
assert g.nodes[2]["value"].get_deserialized() == 5
assert g.nodes[3]["value"].get_deserialized() == 7
assert fn(["x", "y", "z"], [2, 5, 7]) == {"x": 2, "y": 5, "z": 7}
fn = g.nodes[2]["function"].get_deserialized()
assert fn("x", "y") == ["x", "y"]
keys = [g.nodes[3]["value"].get_deserialized(), g.nodes[4]["value"].get_deserialized()]
fn = g.nodes[5]["function"].get_deserialized()
assert fn(2, 3) == [2, 3]
vals = [g.nodes[6]["value"].get_deserialized(), g.nodes[7]["value"].get_deserialized()]
assert keys == ["x", "y"]
assert vals == [5, 7]
assert set(g.edges) == {
(1, 0, 0),
(2, 1, 0),
(3, 1, 0),
(0, 4, 0),
(0, 4, 1),
(1, 4, 0),
(3, 2, 0),
(4, 2, 0),
(5, 1, 0),
(6, 5, 0),
(7, 5, 0),
(0, 8, 0),
(0, 8, 1),
(1, 8, 0),
(2, 8, 0),
(5, 8, 0),
}


Expand Down
3 changes: 2 additions & 1 deletion tests/functional_tests/workflow_stack_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,8 @@ def workflow(x):
res_1 = sum_values(x)
return square(res_1)

dispatch_id = ct.dispatch(workflow)({"x": 1, "y": 2, "z": 3})
# Check that non-string keys are allowed
dispatch_id = ct.dispatch(workflow)({"x": 1, "y": 2, 3: 3})

res_obj = rm.get_result(dispatch_id, wait=True)

Expand Down

0 comments on commit d7841c7

Please sign in to comment.