Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prototype: Revert cl/445467631. #4550

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ def create_null_federated_secure_modular_sum():
building_blocks.Struct([]), placements.CLIENTS
),
building_blocks.Struct([]),
preapply_modulus=False,
)


Expand Down
1 change: 0 additions & 1 deletion tensorflow_federated/python/core/impl/compiler/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ py_library(
deps = [
":building_blocks",
":intrinsic_defs",
":tensorflow_computation_factory",
":transformation_utils",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/common_libs:structure",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,10 @@
import string
from typing import Optional, Union

import tensorflow as tf

from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.common_libs import structure
from tensorflow_federated.python.core.impl.compiler import building_blocks
from tensorflow_federated.python.core.impl.compiler import intrinsic_defs
from tensorflow_federated.python.core.impl.compiler import tensorflow_computation_factory
from tensorflow_federated.python.core.impl.compiler import transformation_utils
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
Expand Down Expand Up @@ -759,49 +756,23 @@ def create_federated_max(
return building_blocks.Call(intrinsic, value)


def _cast(
comp: building_blocks.ComputationBuildingBlock,
type_signature: computation_types.TensorType,
) -> building_blocks.Call:
"""Casts `comp` to the provided type."""

def cast_fn(value):

def cast_element(element, type_signature: computation_types.TensorType):
return tf.cast(element, type_signature.dtype)

if isinstance(comp.type_signature, computation_types.StructType):
return structure.map_structure(cast_element, value, type_signature)
return cast_element(value, type_signature)

cast_proto, cast_type = tensorflow_computation_factory.create_unary_operator(
cast_fn, comp.type_signature
)
cast_comp = building_blocks.CompiledComputation(
cast_proto, type_signature=cast_type
)
return building_blocks.Call(cast_comp, comp)


def create_federated_secure_modular_sum(
value: building_blocks.ComputationBuildingBlock,
modulus: building_blocks.ComputationBuildingBlock,
preapply_modulus: bool = True,
) -> building_blocks.ComputationBuildingBlock:
r"""Creates a called secure modular sum.

Call
/ \
Intrinsic [Comp, Comp]

Args:
value: A `building_blocks.ComputationBuildingBlock` to use as the value.
modulus: A `building_blocks.ComputationBuildingBlock` to use as the
`modulus` value.
preapply_modulus: Whether or not to preapply `modulus` to the input `value`.
This can be `False` if `value` is guaranteed to already be in range.

Returns:
A computation building block which invokes `federated_secure_modular_sum`.

Raises:
TypeError: If any of the types do not match.
A `building_blocks.Call`.
"""
py_typecheck.check_type(value, building_blocks.ComputationBuildingBlock)
py_typecheck.check_type(modulus, building_blocks.ComputationBuildingBlock)
Expand All @@ -820,50 +791,8 @@ def create_federated_secure_modular_sum(
intrinsic_defs.FEDERATED_SECURE_MODULAR_SUM.uri, intrinsic_type
)

if not preapply_modulus:
values = building_blocks.Struct([value, modulus])
return building_blocks.Call(intrinsic, values)

# Pre-insert a modulus to ensure the the input values are within range.
mod_ref = building_blocks.Reference('mod', modulus.type_signature)

# In order to run `tf.math.floormod`, our modulus and value must be the same
# type.
casted_mod = _cast(
mod_ref,
value.type_signature.member, # pytype: disable=attribute-error
)
# Since in the preapply_modulus case the modulus is expected to be available
# at the client as well as at the server for aggregation, we need to broadcast
# the modulus to be able to avoid repeating the modulus value (which could
# cause accuracy issues if the modulus is non-deterministic).
casted_mod_at_server = create_federated_value(casted_mod, placements.SERVER)
value_with_mod = create_federated_zip(
building_blocks.Struct(
[value, create_federated_broadcast(casted_mod_at_server)]
)
)

def structural_modulus(value, mod):
return structure.map_structure(tf.math.floormod, value, mod)

structural_modulus_proto, structural_modulus_type = (
tensorflow_computation_factory.create_binary_operator(
structural_modulus,
value.type_signature.member, # pytype: disable=attribute-error
casted_mod.type_signature,
)
)
structural_modulus_tf = building_blocks.CompiledComputation(
structural_modulus_proto, type_signature=structural_modulus_type
)
value_modded = create_federated_map_or_apply(
structural_modulus_tf, value_with_mod
)
values = building_blocks.Struct([value_modded, mod_ref])
return building_blocks.Block(
[('mod', modulus)], building_blocks.Call(intrinsic, values)
)
values = building_blocks.Struct([value, modulus])
return building_blocks.Call(intrinsic, values)


def create_federated_secure_sum(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1059,7 +1059,6 @@ def test_splits_on_intrinsic_with_args_from_original_arg(self):
),
index=0,
),
preapply_modulus=False,
)
comp = building_blocks.Lambda('arg', arg_type, intrinsic_call)

Expand Down Expand Up @@ -1679,7 +1678,6 @@ def test_splits_on_multiple_intrinsics(self):
),
index=0,
),
preapply_modulus=False,
)
)
block_locals = [
Expand Down