Skip to content

Commit

Permalink
Add data structure enum to Python bindings (#145)
Browse files Browse the repository at this point in the history
  • Loading branch information
s0l0ist committed Jan 5, 2023
1 parent 5031425 commit 458d101
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 28 deletions.
32 changes: 32 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,35 @@
# Version 1.0.1

Feat:

- Update the Python bindings to take in an optional `DataStructure` argument for
`CreateSetupMessage`. This allows the user to customize the behavior of the
backing datastructure - i.e. to select between the default (`GCS`) or specify
`BloomFilter`. The previous behavior always selected `GCS` so if the parameter
is omitted, the behavior will remain the same.

Ex:

```python
import private_set_intersection.python as psi

c = psi.client.CreateWithNewKey(...)
s = psi.server.CreateWithNewKey(...)

#...

# Defaults to GCS
s.CreateSetupMessage(fpr, len(client_items), server_items)

# Same as above
s.CreateSetupMessage(fpr, len(client_items), server_items, psi.DataStructure.GCS)

# Specify BloomFilter
s.CreateSetupMessage(fpr, len(client_items), server_items, psi.DataStructure.BloomFilter)

# ...
```

# Version 1.0.0

Breaking:
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@openmined/psi.js",
"version": "1.0.0",
"version": "1.0.1",
"description": "Private Set Intersection for JavaScript",
"repository": {
"type": "git",
Expand Down
12 changes: 10 additions & 2 deletions private_set_intersection/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,16 @@
Request,
Response,
)
from enum import Enum

__version__ = psi.__version__


class DataStructure(Enum):
GCS = psi.data_structure.GCS
BLOOM_FILTER = psi.data_structure.BloomFilter


class client:
def __init__(self, data: psi.cpp_client):
"""Constructor method for the client object.
Expand Down Expand Up @@ -125,17 +131,18 @@ def CreateFromKey(cls, key_bytes: bytes, reveal_intersection: bool):
return cls(psi.cpp_server.CreateFromKey(key_bytes, reveal_intersection))

def CreateSetupMessage(
self, fpr: float, num_client_inputs: int, inputs: List[str]
self, fpr: float, num_client_inputs: int, inputs: List[str], ds=DataStructure.GCS
) -> ServerSetup:
"""Create a setup message from the server's dataset to be sent to the client.
Args:
fpr: the probability that any query of size `num_client_inputs` will result in a false positive.
num_client_inputs: Client set size.
inputs: Server items.
ds: The underlying data structure to use. Defaults to GCS.
Returns:
A Protobuf with the setup message.
"""
interm_msg = self.data.CreateSetupMessage(fpr, num_client_inputs, inputs).save()
interm_msg = self.data.CreateSetupMessage(fpr, num_client_inputs, inputs, ds.value).save()
msg = ServerSetup()
msg.ParseFromString(interm_msg)
return msg
Expand Down Expand Up @@ -164,6 +171,7 @@ def GetPrivateKeyBytes(self) -> bytes:


__all__ = [
"DataStructure",
"client",
"server",
"ServerSetup",
Expand Down
27 changes: 15 additions & 12 deletions private_set_intersection/python/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ def test_client_create_request(cnt, reveal_intersection, benchmark):
benchmark(helper_client_create_request, cnt, reveal_intersection)


def helper_client_process_response(cnt, reveal_intersection):
def helper_client_process_response(cnt, reveal_intersection, ds):
c = psi.client.CreateWithNewKey(reveal_intersection)
s = psi.server.CreateWithNewKey(reveal_intersection)

fpr = 1.0 / 1000000
inputs = ["Element " + str(i) for i in range(cnt)]
req = c.CreateRequest(inputs)

setup = s.CreateSetupMessage(fpr, len(inputs), inputs)
setup = s.CreateSetupMessage(fpr, len(inputs), inputs, ds)
request = c.CreateRequest(inputs)
resp = s.ProcessRequest(request)
if reveal_intersection:
Expand All @@ -34,40 +34,43 @@ def helper_client_process_response(cnt, reveal_intersection):

@pytest.mark.parametrize("cnt", [1, 10, 100, 1000, 10000])
@pytest.mark.parametrize("reveal_intersection", [False, True])
def test_client_process_response(cnt, reveal_intersection, benchmark):
benchmark(helper_client_process_response, cnt, reveal_intersection)
@pytest.mark.parametrize("ds", [psi.DataStructure.GCS, psi.DataStructure.BLOOM_FILTER])
def test_client_process_response(cnt, reveal_intersection, ds, benchmark):
benchmark(helper_client_process_response, cnt, reveal_intersection, ds)


def helper_server_setup(cnt, fpr, reveal_intersection):
def helper_server_setup(cnt, fpr, reveal_intersection, ds):
s = psi.server.CreateWithNewKey(reveal_intersection)
items = ["Element " + str(2 * i) for i in range(cnt)]
setup = s.CreateSetupMessage(fpr, 10000, items)
setup = s.CreateSetupMessage(fpr, 10000, items, ds)


@pytest.mark.parametrize("cnt", [1, 10, 100, 1000, 10000])
@pytest.mark.parametrize("fpr", [0.001, 0.000001])
@pytest.mark.parametrize("reveal_intersection", [False, True])
def test_server_setup(cnt, fpr, reveal_intersection, benchmark):
benchmark(helper_server_setup, cnt, fpr, reveal_intersection)
@pytest.mark.parametrize("ds", [psi.DataStructure.GCS, psi.DataStructure.BLOOM_FILTER])
def test_server_setup(cnt, fpr, reveal_intersection, ds, benchmark):
benchmark(helper_server_setup, cnt, fpr, reveal_intersection, ds)


def helper_server_process_request(cnt, reveal_intersection):
def helper_server_process_request(cnt, reveal_intersection, ds):
c = psi.client.CreateWithNewKey(reveal_intersection)
s = psi.server.CreateWithNewKey(reveal_intersection)

fpr = 1.0 / 1000000
inputs = ["Element " + str(i) for i in range(cnt)]
req = c.CreateRequest(inputs)

setup = s.CreateSetupMessage(fpr, len(inputs), inputs)
setup = s.CreateSetupMessage(fpr, len(inputs), inputs, ds)
request = c.CreateRequest(inputs)
resp = s.ProcessRequest(request)


@pytest.mark.parametrize("cnt", [1, 10, 100, 1000, 10000])
@pytest.mark.parametrize("reveal_intersection", [False, True])
def test_server_process_request(cnt, reveal_intersection, benchmark):
benchmark(helper_server_process_request, cnt, reveal_intersection)
@pytest.mark.parametrize("ds", [psi.DataStructure.GCS, psi.DataStructure.BLOOM_FILTER])
def test_server_process_request(cnt, reveal_intersection, ds, benchmark):
benchmark(helper_server_process_request, cnt, reveal_intersection, ds)


if __name__ == "__main__":
Expand Down
9 changes: 7 additions & 2 deletions private_set_intersection/python/psi_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ void bind(pybind11::module& m) {
"Filters";

m.attr("__version__") = ::private_set_intersection::Package::kVersion;

py::enum_<psi::DataStructure>(m, "data_structure", py::arithmetic())
.value("GCS", psi::DataStructure::GCS)
.value("BloomFilter", psi::DataStructure::BloomFilter);

py::class_<psi_proto::ServerSetup>(m, "cpp_proto_server_setup")
.def(py::init<>())
.def("load", [](psi_proto::ServerSetup& obj,
Expand Down Expand Up @@ -147,9 +152,9 @@ void bind(pybind11::module& m) {
.def(
"CreateSetupMessage",
[](const psi::PsiServer& obj, double fpr, int64_t num_client_inputs,
const std::vector<std::string>& inputs) {
const std::vector<std::string>& inputs, psi::DataStructure ds) {
return throwOrReturn(
obj.CreateSetupMessage(fpr, num_client_inputs, inputs));
obj.CreateSetupMessage(fpr, num_client_inputs, inputs, ds));
},
py::call_guard<py::gil_scoped_release>())
.def(
Expand Down
25 changes: 15 additions & 10 deletions private_set_intersection/python/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ def test_sanity(reveal_intersection):
assert c != None


@pytest.mark.parametrize("ds", [psi.DataStructure.GCS, psi.DataStructure.BLOOM_FILTER])
@pytest.mark.parametrize("reveal_intersection", [False, True])
@pytest.mark.parametrize("duplicate", [False, True])
def test_client_server(reveal_intersection, duplicate):
def test_client_server(ds, reveal_intersection, duplicate):
c = psi.client.CreateWithNewKey(reveal_intersection)
s = psi.server.CreateWithNewKey(reveal_intersection)

Expand All @@ -30,7 +31,7 @@ def test_client_server(reveal_intersection, duplicate):

fpr = 1.0 / (1000000000)
setup = dup(
duplicate, s.CreateSetupMessage(fpr, len(client_items), server_items), psi.ServerSetup()
duplicate, s.CreateSetupMessage(fpr, len(client_items), server_items, ds), psi.ServerSetup()
)
request = dup(duplicate, c.CreateRequest(client_items), psi.Request())
resp = dup(duplicate, s.ProcessRequest(request), psi.Response())
Expand Down Expand Up @@ -82,16 +83,17 @@ def test_client_sanity(reveal_intersection):
assert key == newkey


@pytest.mark.parametrize("ds", [psi.DataStructure.GCS, psi.DataStructure.BLOOM_FILTER])
@pytest.mark.parametrize("reveal_intersection", [False, True])
def test_server_client(reveal_intersection):
def test_server_client(ds, reveal_intersection):
c = psi.client.CreateWithNewKey(reveal_intersection)
s = psi.server.CreateWithNewKey(reveal_intersection)

client_items = ["Element " + str(i) for i in range(1000)]
server_items = ["Element " + str(2 * i) for i in range(10000)]

fpr = 1.0 / (1000000000)
setup = s.CreateSetupMessage(fpr, len(client_items), server_items)
setup = s.CreateSetupMessage(fpr, len(client_items), server_items, ds)
request = c.CreateRequest(client_items)
resp = s.ProcessRequest(request)

Expand All @@ -109,14 +111,15 @@ def test_server_client(reveal_intersection):
assert intersection <= (1.1 * len(client_items) / 2.0)


@pytest.mark.parametrize("ds", [psi.DataStructure.GCS, psi.DataStructure.BLOOM_FILTER])
@pytest.mark.parametrize("reveal_intersection", [False, True])
def test_serialization_setup_msg(reveal_intersection):
def test_serialization_setup_msg(ds, reveal_intersection):
s = psi.server.CreateWithNewKey(reveal_intersection)

server_items = ["Element " + str(2 * i) for i in range(10000)]

fpr = 1.0 / (1000000000)
setup = s.CreateSetupMessage(fpr, 1000, server_items)
setup = s.CreateSetupMessage(fpr, 1000, server_items, ds)

buff = setup.SerializeToString()
recreated = psi.ServerSetup()
Expand All @@ -139,16 +142,17 @@ def test_serialization_request(reveal_intersection):
assert request.reveal_intersection == recreated.reveal_intersection


@pytest.mark.parametrize("ds", [psi.DataStructure.GCS, psi.DataStructure.BLOOM_FILTER])
@pytest.mark.parametrize("reveal_intersection", [False, True])
def test_serialization_response(reveal_intersection):
def test_serialization_response(ds, reveal_intersection):
c = psi.client.CreateWithNewKey(reveal_intersection)
s = psi.server.CreateWithNewKey(reveal_intersection)

client_items = ["Element " + str(i) for i in range(1000)]
server_items = ["Element " + str(2 * i) for i in range(10000)]

fpr = 1.0 / (1000000000)
setup = s.CreateSetupMessage(fpr, len(client_items), server_items)
setup = s.CreateSetupMessage(fpr, len(client_items), server_items, ds)
req = c.CreateRequest(client_items)
resp = s.ProcessRequest(req)

Expand All @@ -160,16 +164,17 @@ def test_serialization_response(reveal_intersection):
assert resp.encrypted_elements == recreated.encrypted_elements


@pytest.mark.parametrize("ds", [psi.DataStructure.GCS, psi.DataStructure.BLOOM_FILTER])
@pytest.mark.parametrize("reveal_intersection", [False, True])
def test_empty_intersection(reveal_intersection):
def test_empty_intersection(ds, reveal_intersection):
c = psi.client.CreateWithNewKey(reveal_intersection)
s = psi.server.CreateWithNewKey(reveal_intersection)

client_items = ["Element " + str(i) for i in range(1000)]
server_items = ["Other " + str(2 * i) for i in range(10000)]

fpr = 1.0 / (1000000000)
setup = s.CreateSetupMessage(fpr, len(client_items), server_items)
setup = s.CreateSetupMessage(fpr, len(client_items), server_items, ds)
request = c.CreateRequest(client_items)
resp = s.ProcessRequest(request)

Expand Down
2 changes: 1 addition & 1 deletion tools/package.bzl
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
""" Version of the current release """
VERSION_LABEL = "1.0.0"
VERSION_LABEL = "1.0.1"

0 comments on commit 458d101

Please sign in to comment.