Skip to content

Commit

Permalink
The previous version uses tf.nest.map_structure to apply `add_noise…
Browse files Browse the repository at this point in the history
…` to a `tf.RaggedTensor`. This causes a bug when used in tensorflow federated because `tf.nest.map_structure` will also map `add_noise` to the tensor for shape information in `tf.RaggedTensor`. This causes failure when tff conducts automatic type conversion.

Also use fixed random seed to avoid flaky timeouts and testing failures.

PiperOrigin-RevId: 384573740
  • Loading branch information
tensorflower-gardener committed Jul 13, 2021
1 parent 7f44b02 commit 2cafe28
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 43 deletions.
49 changes: 30 additions & 19 deletions tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py
Expand Up @@ -25,10 +25,10 @@
"""
import distutils
import math
import attr
from typing import Optional

import attr
import tensorflow as tf

from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import tree_aggregation

Expand Down Expand Up @@ -442,16 +442,20 @@ def _loop_body(i, h):
return tree


def _get_add_noise(stddev):
def _get_add_noise(stddev, seed: int = None):
"""Utility function to decide which `add_noise` to use according to tf version."""
if distutils.version.LooseVersion(
tf.__version__) < distutils.version.LooseVersion('2.0.0'):

# The seed should be only used for testing purpose.
if seed is not None:
tf.random.set_seed(seed)

def add_noise(v):
return v + tf.random.normal(
tf.shape(input=v), stddev=stddev, dtype=v.dtype)
else:
random_normal = tf.random_normal_initializer(stddev=stddev)
random_normal = tf.random_normal_initializer(stddev=stddev, seed=seed)

def add_noise(v):
return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype)
Expand All @@ -478,17 +482,16 @@ class GlobalState(object):
"""Class defining global state for `CentralTreeSumQuery`.
Attributes:
stddev: The stddev of the noise added to each node in the tree.
arity: The branching factor of the tree (i.e. the number of children each
internal node has).
l1_bound: An upper bound on the L1 norm of the input record. This is
needed to bound the sensitivity and deploy differential privacy.
"""
stddev = attr.ib()
arity = attr.ib()
l1_bound = attr.ib()

def __init__(self, stddev: float, arity: int = 2, l1_bound: int = 10):
def __init__(self,
stddev: float,
arity: int = 2,
l1_bound: int = 10,
seed: Optional[int] = None):
"""Initializes the `CentralTreeSumQuery`.
Args:
Expand All @@ -497,15 +500,17 @@ def __init__(self, stddev: float, arity: int = 2, l1_bound: int = 10):
arity: The branching factor of the tree.
l1_bound: An upper bound on the L1 norm of the input record. This is
needed to bound the sensitivity and deploy differential privacy.
seed: Random seed to generate Gaussian noise. Defaults to `None`. Only for
test purpose.
"""
self._stddev = stddev
self._arity = arity
self._l1_bound = l1_bound
self._seed = seed

def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
return CentralTreeSumQuery.GlobalState(
stddev=self._stddev, arity=self._arity, l1_bound=self._l1_bound)
return CentralTreeSumQuery.GlobalState(l1_bound=self._l1_bound)

def derive_sample_params(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
Expand Down Expand Up @@ -536,10 +541,9 @@ def get_noised_result(self, sample_state, global_state):
The jth node on the ith layer of the tree can be accessed by tree[i][j]
where tree is the returned value.
"""
add_noise = _get_add_noise(self._stddev)
tree = _build_tree_from_leaf(sample_state, global_state.arity)
return tf.nest.map_structure(
add_noise, tree, expand_composites=True), global_state
add_noise = _get_add_noise(self._stddev, self._seed)
tree = _build_tree_from_leaf(sample_state, self._arity)
return tf.map_fn(add_noise, tree), global_state


class DistributedTreeSumQuery(dp_query.SumAggregationDPQuery):
Expand Down Expand Up @@ -577,18 +581,25 @@ class GlobalState(object):
arity = attr.ib()
l1_bound = attr.ib()

def __init__(self, stddev: float, arity: int = 2, l1_bound: int = 10):
def __init__(self,
stddev: float,
arity: int = 2,
l1_bound: int = 10,
seed: Optional[int] = None):
"""Initializes the `DistributedTreeSumQuery`.
Args:
stddev: The stddev of the noise added to each node in the tree.
arity: The branching factor of the tree.
l1_bound: An upper bound on the L1 norm of the input record. This is
needed to bound the sensitivity and deploy differential privacy.
seed: Random seed to generate Gaussian noise. Defaults to `None`. Only for
test purpose.
"""
self._stddev = stddev
self._arity = arity
self._l1_bound = l1_bound
self._seed = seed

def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
Expand Down Expand Up @@ -628,9 +639,9 @@ def preprocess_record(self, params, record):
use_norm=l1_norm)
preprocessed_record = preprocessed_record[0]

add_noise = _get_add_noise(self._stddev)
add_noise = _get_add_noise(self._stddev, self._seed)
tree = _build_tree_from_leaf(preprocessed_record, arity)
noisy_tree = tf.nest.map_structure(add_noise, tree, expand_composites=True)
noisy_tree = tf.map_fn(add_noise, tree)

# The following codes reshape the output vector so the output shape of can
# be statically inferred. This is useful when used with
Expand Down
34 changes: 10 additions & 24 deletions tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py
Expand Up @@ -502,21 +502,15 @@ def test_get_noised_result(self, arity, record, expected_tree):
('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
)
def test_get_noised_result_with_noise(self, stddev, record, expected_tree):
query = tree_aggregation_query.CentralTreeSumQuery(stddev=stddev)
query = tree_aggregation_query.CentralTreeSumQuery(stddev=stddev, seed=0)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
sample_state_list = []
for _ in range(1000):
sample_state, _ = query.get_noised_result(preprocessed_record,
global_state)
sample_state_list.append(sample_state.flat_values.numpy())
expectation = np.mean(sample_state_list, axis=0)
variance = np.std(sample_state_list, axis=0)

self.assertAllClose(expectation, expected_tree, rtol=3 * stddev, atol=1e-4)

sample_state, _ = query.get_noised_result(preprocessed_record, global_state)

self.assertAllClose(
variance, np.ones(len(variance)) * stddev, rtol=0.1, atol=1e-4)
sample_state.flat_values, expected_tree, atol=3 * stddev)

@parameterized.named_parameters(
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
Expand Down Expand Up @@ -556,8 +550,7 @@ def test_initial_global_state_type(self):
def test_derive_sample_params(self):
query = tree_aggregation_query.DistributedTreeSumQuery(stddev=NOISE_STD)
global_state = query.initial_global_state()
stddev, arity, l1_bound = query.derive_sample_params(
global_state)
stddev, arity, l1_bound = query.derive_sample_params(global_state)
self.assertAllClose(stddev, NOISE_STD)
self.assertAllClose(arity, 2)
self.assertAllClose(l1_bound, 10)
Expand Down Expand Up @@ -587,21 +580,14 @@ def test_preprocess_record(self, arity, record, expected_tree):
('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
)
def test_preprocess_record_with_noise(self, stddev, record, expected_tree):
query = tree_aggregation_query.DistributedTreeSumQuery(stddev=stddev)
query = tree_aggregation_query.DistributedTreeSumQuery(
stddev=stddev, seed=0)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)

preprocessed_record_list = []
for _ in range(1000):
preprocessed_record = query.preprocess_record(params, record)
preprocessed_record_list.append(preprocessed_record.numpy())

expectation = np.mean(preprocessed_record_list, axis=0)
variance = np.std(preprocessed_record_list, axis=0)
preprocessed_record = query.preprocess_record(params, record)

self.assertAllClose(expectation, expected_tree, rtol=3 * stddev, atol=1e-4)
self.assertAllClose(
variance, np.ones(len(variance)) * stddev, rtol=0.1, atol=1e-4)
self.assertAllClose(preprocessed_record, expected_tree, atol=3 * stddev)

@parameterized.named_parameters(
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
Expand Down

0 comments on commit 2cafe28

Please sign in to comment.