Skip to content

Commit

Permalink
Add a new tutorial on encrypted inference on resnet18 (#4539)
Browse files Browse the repository at this point in the history
* Improve native.encrypt()

* Fix typo

* Improve batchnorm by caching inverse results

* Add support for encrypt / decrypt on nn.Module

* Add early version of tutorial

* Fix AST serialization to send the protocol attribute

* Silence __del__ error at program termination

* Add serialiation of the data_centric_fl_client

* Update the resnet tutorial and rename folder

* Add demo for George and Alex

* Add working version of the tutorial with websockets

* Respond to comments on the tutorial

* Clean def encrypt

* Fix tests for AST serde

* Exlcude the notebook from the CI

* Blend functional.py inverse in FPT.reciprocal

* Update the tutorial

* Update tutorial

* Make N_Process dynamic depending on the machine

* UPdate tuto

* up
  • Loading branch information
LaRiffle committed Sep 11, 2020
1 parent 51ec7ad commit 88c2606
Show file tree
Hide file tree
Showing 25 changed files with 580 additions and 42 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
"source": [
"epochs = 10\n",
"# We don't use the whole dataset for efficiency purpose, but feel free to increase these numbers\n",
"n_train_items = 1280\n",
"n_test_items = 1280"
"n_train_items = 640\n",
"n_test_items = 640"
]
},
{
Expand Down
File renamed without changes.
34 changes: 31 additions & 3 deletions syft/frameworks/torch/hook/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,9 +667,9 @@ def module_send_(nn_self, *dest, force_send=False, **kwargs):

def module_move_(nn_self, destination):

params = list(nn_self.parameters())
for p in params:
p.move_(destination)
for element_iter in tensor_iterator(nn_self):
for p in element_iter():
p.move_(destination)

self.torch.nn.Module.move = module_move_

Expand Down Expand Up @@ -707,6 +707,34 @@ def module_get_(nn_self):
self.torch.nn.Module.get_ = module_get_
self.torch.nn.Module.get = module_get_

def module_encrypt_(nn_self, **kwargs):
"""Overloads fix_precision for torch.nn.Module."""
if module_is_missing_grad(nn_self):
create_grad_objects(nn_self)

for element_iter in tensor_iterator(nn_self):
for p in element_iter():
p.encrypt(inplace=True, **kwargs)

return nn_self

self.torch.nn.Module.encrypt_ = module_encrypt_
self.torch.nn.Module.encrypt = module_encrypt_

def module_decrypt_(nn_self):
"""Overloads fix_precision for torch.nn.Module."""
if module_is_missing_grad(nn_self):
create_grad_objects(nn_self)

for element_iter in tensor_iterator(nn_self):
for p in element_iter():
p.decrypt(inplace=True)

return nn_self

self.torch.nn.Module.decrypt_ = module_decrypt_
self.torch.nn.Module.decrypt = module_decrypt_

def module_share_(nn_self, *args, **kwargs):
"""Overloads fix_precision for torch.nn.Module."""
if module_is_missing_grad(nn_self):
Expand Down
2 changes: 1 addition & 1 deletion syft/frameworks/torch/mpc/fss.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def full_name(f):
COMP = 1

# number of processes
N_CORES = 8
N_CORES = max(4, multiprocessing.cpu_count())
MULTI_LIMIT = 50_000


Expand Down
12 changes: 10 additions & 2 deletions syft/frameworks/torch/mpc/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,11 @@ def get_keys(self, op: str, n_instances: int = 1, remove: bool = True, **kwargs)
f"n_instances={n_instances}"
)
raise EmptyCryptoPrimitiveStoreError(
self, available_instances, n_instances=n_instances, op=op, **kwargs
self,
available_instances=available_instances,
n_instances=n_instances,
op=op,
**kwargs,
)
elif op in {"fss_eq", "fss_comp"}:
available_instances = len(primitive_stack[0]) if len(primitive_stack) > 0 else -1
Expand Down Expand Up @@ -147,7 +151,11 @@ def get_keys(self, op: str, n_instances: int = 1, remove: bool = True, **kwargs)
f"n_instances={n_instances}"
)
raise EmptyCryptoPrimitiveStoreError(
self, available_instances, n_instances=n_instances, op=op, **kwargs
self,
available_instances=available_instances,
n_instances=n_instances,
op=op,
**kwargs,
)

def provide_primitives(
Expand Down
12 changes: 1 addition & 11 deletions syft/frameworks/torch/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,7 @@ def batch_norm(
mean = running_mean
var = running_var

x = None

C = 20

for i in range(80):
if x is not None:
y = C + 1 - var * (x * x)
x = y * x / C
else:
y = C + 1 - var
x = y / C
x = var.reciprocal(method="newton")

normalized = x * (input - mean)
result = normalized * weight + bias
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1170,6 +1170,7 @@ def simplify(worker: AbstractWorker, tensor: "AdditiveSharingTensor") -> tuple:
return (
_simplify(tensor.id),
_simplify(tensor.field),
_simplify(tensor.protocol),
tensor.dtype.encode("utf-8"),
_simplify(tensor.crypto_provider.id),
chain,
Expand All @@ -1191,14 +1192,15 @@ def detail(worker: AbstractWorker, tensor_tuple: tuple) -> "AdditiveSharingTenso
"""
_detail = lambda x: sy.serde.msgpack.serde._detail(worker, x)

tensor_id, field, dtype, crypto_provider, chain, garbage_collect = tensor_tuple
tensor_id, field, protocol, dtype, crypto_provider, chain, garbage_collect = tensor_tuple

crypto_provider = _detail(crypto_provider)

tensor = AdditiveSharingTensor(
owner=worker,
id=_detail(tensor_id),
field=_detail(field),
protocol=_detail(protocol),
dtype=dtype.decode("utf-8"),
crypto_provider=worker.get_worker(crypto_provider),
)
Expand Down
57 changes: 42 additions & 15 deletions syft/frameworks/torch/tensors/interpreters/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,7 @@ def get(self, *args, inplace: bool = False, user=None, reason: str = "", **kwarg
return tensor

if inplace:
self.set_(tensor)
self.set_(tensor.native_type(self.dtype))
if hasattr(tensor, "child"):
self.child = tensor.child
else:
Expand Down Expand Up @@ -990,15 +990,26 @@ def torch_type(self):
else:
return self.child.torch_type()

def encrypt(self, protocol="mpc", **kwargs):
def encrypt(self, protocol="mpc", inplace=False, **kwargs):
"""
This method will encrypt each value in the tensor using Multi Party
Computation (default) or Paillier Homomorphic Encryption
Args:
protocol (str): Currently supports 'mpc' for Multi Party
Computation and 'paillier' for Paillier Homomorphic Encryption
protocol (str): Currently supports the following crypto protocols:
- 'snn' for SecureNN
- 'fss' for Function Secret Sharing (see AriaNN paper)
- 'mpc' (Multi Party Computation) defaults to most standard protocol,
currently 'snn'
- 'paillier' for Paillier Homomorphic Encryption
inplace (bool): compute the operation inplace (default is False)
**kwargs:
With respect to Fixed Precision accepts:
precision_fractional (int)
dtype (str)
With Respect to MPC accepts:
workers (list): Parties involved in the sharing of the Tensor
crypto_provider (syft.VirtualWorker): Worker responsible for the
Expand All @@ -1019,22 +1030,33 @@ def encrypt(self, protocol="mpc", **kwargs):
NotImplementedError: If protocols other than the ones mentioned above are queried
"""
if protocol.lower() == "mpc":
protocol = protocol.lower()

if protocol in {"mpc", "snn", "fss"}:
if protocol == "mpc":
protocol = "snn"
workers = kwargs.pop("workers")
crypto_provider = kwargs.pop("crypto_provider")
requires_grad = kwargs.pop("requires_grad", False)
no_wrap = kwargs.pop("no_wrap", False)
dtype = kwargs.get("dtype")
kwargs_fix_prec = kwargs # Rest of kwargs for fix_prec method

x_shared = self.fix_prec(**kwargs_fix_prec).share(
*workers,
kwargs_share = dict(
crypto_provider=crypto_provider,
requires_grad=requires_grad,
no_wrap=no_wrap,
protocol=protocol,
dtype=dtype,
)
return x_shared

elif protocol.lower() == "paillier":
if not inplace:
x_shared = self.fix_prec(**kwargs_fix_prec).share(*workers, **kwargs_share)
return x_shared
else:
self.fix_prec_(**kwargs_fix_prec).share_(*workers, **kwargs_share)
return self

elif protocol == "paillier":
public_key = kwargs.get("public_key")

x = self.copy()
Expand All @@ -1046,15 +1068,16 @@ def encrypt(self, protocol="mpc", **kwargs):
else:
raise NotImplementedError(
"Currently the .encrypt() method only supports Paillier Homomorphic "
"Encryption and Secure Multi-Party Computation"
f"Encryption and Secure Multi-Party Computation, but {protocol} was given"
)

def decrypt(self, **kwargs):
def decrypt(self, inplace=False, **kwargs):
"""
This method will decrypt each value in the tensor using Multi Party
Computation (default) or Paillier Homomorphic Encryption
Args:
inplace (bool): compute the operation inplace (default is False)
**kwargs:
With Respect to MPC accepts:
None
Expand All @@ -1075,9 +1098,13 @@ def decrypt(self, **kwargs):
warnings.warn("protocol should no longer be used in decrypt")

if isinstance(self.child, (syft.FixedPrecisionTensor, syft.AutogradTensor)):
x_encrypted = self.copy()
x_decrypted = x_encrypted.get().float_prec()
return x_decrypted
if not inplace:
x_encrypted = self.copy()
x_decrypted = x_encrypted.get().float_prec()
return x_decrypted
else:
self.get_().float_prec_()
return self

elif isinstance(self.child, PaillierTensor):
# self.copy() not required as PaillierTensor's decrypt method is not inplace
Expand Down
20 changes: 17 additions & 3 deletions syft/frameworks/torch/tensors/interpreters/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,17 +496,30 @@ def reciprocal(self, method="NR", nr_iters=10):
Returns:
Reciprocal of `self`
"""
method = method.lower()

if method.lower() == "nr":
if method == "nr":
new_self = self.modulus()
result = 3 * (0.5 - new_self).exp() + 0.003
for i in range(nr_iters):
result = 2 * result - result * result * new_self
return result * self.signum()
elif method.lower() == "division":
elif method == "newton":
# it is assumed here that input values are taken in [-20, 20]
x = None
C = 20
for i in range(80):
if x is not None:
y = C + 1 - self * (x * x)
x = y * x / C
else:
y = C + 1 - self
x = y / C
return x
elif method == "division":
ones = self * 0 + 1
return ones / self
elif method.lower() == "log":
elif method == "log":
new_self = self.modulus()
return (-new_self.log()).exp() * self.signum()
else:
Expand Down Expand Up @@ -939,6 +952,7 @@ def share_(self, *args, **kwargs):
dtype == self.dtype
), "When sharing a FixedPrecisionTensor, the dtype of the resulting AdditiveSharingTensor \
must be the same as the one of the original tensor"
kwargs.pop("no_wrap", None)
self.child = self.child.share_(*args, no_wrap=True, **kwargs)
return self

Expand Down
6 changes: 5 additions & 1 deletion syft/generic/pointers/object_pointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Union
from typing import TYPE_CHECKING
import weakref
from websocket._exceptions import WebSocketConnectionClosedException

import syft
from syft import exceptions
Expand Down Expand Up @@ -340,7 +341,10 @@ def __del__(self):
if hasattr(self, "owner") and self.garbage_collect_data:
# attribute pointers are not in charge of GC
if self.point_to_attr is None:
self.owner.garbage(self.id_at_location, self.location)
try:
self.owner.garbage(self.id_at_location, self.location)
except (BrokenPipeError, WebSocketConnectionClosedException):
pass

def _create_attr_name_string(self, attr_name):
if self.point_to_attr is not None:
Expand Down
9 changes: 6 additions & 3 deletions syft/generic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@ class memorize(dict):
def __init__(self, func):
self.func = func

def __call__(self, *args):
return self[args]
def __call__(self, *args, **kwargs):
key = (args, tuple(sorted(kwargs.items())))
return self[key]

def __missing__(self, key):
result = self[key] = self.func(*key)
args, kwargs = key
kwargs = {k: v for k, v in kwargs}
result = self[key] = self.func(*args, **kwargs)
return result


Expand Down
2 changes: 2 additions & 0 deletions test/notebooks/test_notebooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
exclusion_list_notebooks = [
# Part 10 needs either torch.log2 to be implemented or numpy to be hooked
"Part 10 - Federated Learning with Secure Aggregation.ipynb",
# Part 11 bis needs a lot of RAM and runs for > 300s for sure
"Part 11 bis - Encrypted inference on ResNet-18.ipynb",
# Part 13b and c need fixing of the tensorflow serving with PySyft
"Part 13b - Secure Classification with Syft Keras and TFE - Secure Model Serving.ipynb",
"Part 13c - Secure Classification with Syft Keras and TFE - Private Prediction Client.ipynb",
Expand Down
2 changes: 2 additions & 0 deletions test/serde/serde_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,7 @@ def compare(detailed, original):
)
assert detailed.id == original.id
assert detailed.field == original.field
assert detailed.protocol == original.protocol
assert detailed.child.keys() == original.child.keys()
return True

Expand All @@ -552,6 +553,7 @@ def compare(detailed, original):
(CODE[str], (str(ast.field).encode("utf-8"),))
if ast.field == 2 ** 64
else ast.field, # (int or str) field
(CODE[str], (str(ast.protocol).encode("utf-8"),)), # (str) protocol
ast.dtype.encode("utf-8"),
(CODE[str], (ast.crypto_provider.id.encode("utf-8"),)), # (str) worker_id
msgpack.serde._simplify(
Expand Down

0 comments on commit 88c2606

Please sign in to comment.