Commit

Merge pull request #8651 from OpenMined/aziz/sqlitest
reduce sqlite tests flakiness
abyesilyurt committed Apr 29, 2024
2 parents 2c8521d + 94f51c3 commit b1c7bae
Showing 3 changed files with 35 additions and 158 deletions.
4 changes: 3 additions & 1 deletion packages/syft/src/syft/store/document_store.py
@@ -316,7 +316,9 @@ def __init__(
         self.root_verify_key = root_verify_key
         self.settings = settings
         self.store_config = store_config
-        self.init_store()
+        res = self.init_store()
+        if res.is_err():
+            raise Exception(f"Something went wrong initializing the store: {res.err()}")
 
         store_config.locking_config.lock_name = f"StorePartition-{settings.name}"
         self.lock = SyftLock(store_config.locking_config)
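
The constructor now fails fast instead of leaving a half-initialized partition behind for later calls to trip over. A minimal sketch of the pattern, assuming Ok/Err along the lines of the `result` package; `MockStore` and `init_backend` are hypothetical stand-ins, not Syft APIs:

# Sketch of the fail-fast pattern, using Ok/Err from the `result` package;
# MockStore and init_backend are hypothetical stand-ins, not Syft APIs.
from result import Err, Ok, Result


def init_backend(crashed: bool) -> Result[dict, str]:
    # Stand-in for init_store(): report failure as a value, never raise.
    return Err("backend is unavailable") if crashed else Ok({})


class MockStore:
    def __init__(self, crashed: bool = False) -> None:
        res = init_backend(crashed)
        if res.is_err():
            # Surface the failure at construction time instead of on first use.
            raise Exception(f"Something went wrong initializing the store: {res.err()}")
        self.db = res.ok()


MockStore()                 # succeeds
# MockStore(crashed=True)   # raises immediately, like the new __init__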
23 changes: 8 additions & 15 deletions packages/syft/tests/syft/stores/kv_document_store_test.py
@@ -44,12 +44,10 @@ def test_kv_store_partition_init_failed(root_verify_key) -> None:
     store_config = MockStoreConfig(is_crashed=True)
     settings = PartitionSettings(name="test", object_type=MockObjectType)
 
-    kv_store_partition = KeyValueStorePartition(
-        UID(), root_verify_key, settings=settings, store_config=store_config
-    )
-
-    res = kv_store_partition.init_store()
-    assert res.is_err()
+    with pytest.raises(Exception):
+        kv_store_partition = KeyValueStorePartition(
+            UID(), root_verify_key, settings=settings, store_config=store_config
+        )
 
 
 def test_kv_store_partition_set(
@@ -81,15 +79,10 @@ def test_kv_store_partition_set_backend_fail(root_verify_key) -> None:
     store_config = MockStoreConfig(is_crashed=True)
     settings = PartitionSettings(name="test", object_type=MockObjectType)
 
-    kv_store_partition = KeyValueStorePartition(
-        UID(), root_verify_key, settings=settings, store_config=store_config
-    )
-    kv_store_partition.init_store()
-
-    obj = MockSyftObject(data=1)
-
-    res = kv_store_partition.set(root_verify_key, obj, ignore_duplicates=False)
-    assert res.is_err()
+    with pytest.raises(Exception):
+        kv_store_partition = KeyValueStorePartition(
+            UID(), root_verify_key, settings=settings, store_config=store_config
+        )
 
 
 def test_kv_store_partition_delete(
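
Since a crashed backend now raises inside `__init__`, the tests no longer construct the partition and then probe `init_store()` separately; construction itself is the assertion. A condensed sketch of the pytest.raises pattern, with `FlakyStore` as a hypothetical stand-in for KeyValueStorePartition:

import pytest


class FlakyStore:
    # Hypothetical stand-in for KeyValueStorePartition with a crashed backend.
    def __init__(self, is_crashed: bool = False) -> None:
        if is_crashed:
            raise Exception("Something went wrong initializing the store")


def test_init_failed() -> None:
    # The constructor is now the unit under test: no separate init_store probe.
    with pytest.raises(Exception):
        FlakyStore(is_crashed=True)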
166 changes: 24 additions & 142 deletions packages/syft/tests/syft/stores/queue_stash_test.py
@@ -1,6 +1,8 @@
 # stdlib
+import threading
 from threading import Thread
 import time
+import traceback
 from typing import Any
 
 # third party
@@ -275,10 +277,12 @@ def helper_queue_set_threading(root_verify_key, create_queue_cbk) -> None:
     repeats = 5
 
     execution_err = None
+    lock = threading.Lock()
 
     def _kv_cbk(tid: int) -> None:
         nonlocal execution_err
-        queue = create_queue_cbk()
+        with lock:
+            queue = create_queue_cbk()
 
         for _ in range(repeats):
             obj = mock_queue_object()
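
This lock is the heart of the flakiness fix: each worker thread still gets its own queue handle, but handle creation, the part that races on a shared SQLite file, is serialized; the actual queue operations stay concurrent. A self-contained sketch of the same pattern against the stdlib's sqlite3 (file, table, and counts are illustrative, not Syft's API):

# Sketch: serialize only the racy construction step, as the diff above does.
import sqlite3
import tempfile
import threading
from pathlib import Path

db_path = Path(tempfile.mkdtemp()) / "queue.sqlite"
init_lock = threading.Lock()
thread_cnt, repeats = 3, 5


def worker(tid: int) -> None:
    # Connection and table setup is the part that raced in the flaky runs,
    # so only this block runs under the lock.
    with init_lock:
        conn = sqlite3.connect(db_path)
        conn.execute("CREATE TABLE IF NOT EXISTS queue (tid INTEGER, item INTEGER)")
        conn.commit()
    for item in range(repeats):
        with conn:  # writes stay concurrent; sqlite serializes the transactions
            conn.execute("INSERT INTO queue VALUES (?, ?)", (tid, item))
    conn.close()


threads = [threading.Thread(target=worker, args=(tid,)) for tid in range(thread_cnt)]
for t in threads:
    t.start()
for t in threads:
    t.join()

conn = sqlite3.connect(db_path)
assert conn.execute("SELECT COUNT(*) FROM queue").fetchone()[0] == thread_cnt * repeats
conn.close()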
@@ -308,69 +312,20 @@ def _kv_cbk(tid: int) -> None:
     assert len(queue) == thread_cnt * repeats
 
 
-# def helper_queue_set_joblib(root_verify_key, create_queue_cbk) -> None:
-#     thread_cnt = 3
-#     repeats = 5
-
-#     def _kv_cbk(tid: int) -> None:
-#         queue = create_queue_cbk()
-#         for _ in range(repeats):
-#             worker_pool_obj = WorkerPool(
-#                 name="mypool",
-#                 image_id=UID(),
-#                 max_count=0,
-#                 worker_list=[],
-#             )
-#             linked_worker_pool = LinkedObject.from_obj(
-#                 worker_pool_obj,
-#                 node_uid=UID(),
-#                 service_type=SyftWorkerPoolService,
-#             )
-#             obj = QueueItem(
-#                 id=UID(),
-#                 node_uid=UID(),
-#                 method="dummy_method",
-#                 service="dummy_service",
-#                 args=[],
-#                 kwargs={},
-#                 worker_pool=linked_worker_pool,
-#             )
-#             for _ in range(10):
-#                 res = queue.set(root_verify_key, obj, ignore_duplicates=False)
-#                 if res.is_ok():
-#                     break
-
-#             if res.is_err():
-#                 return res
-#         return None
-
-#     errs = Parallel(n_jobs=thread_cnt)(
-#         delayed(_kv_cbk)(idx) for idx in range(thread_cnt)
-#     )
-
-#     for execution_err in errs:
-#         assert execution_err is None
-
-#     queue = create_queue_cbk()
-#     assert len(queue) == thread_cnt * repeats
-
-
-@pytest.mark.parametrize("backend", [helper_queue_set_threading])
 @pytest.mark.flaky(reruns=3, reruns_delay=3)
-def test_queue_set_sqlite(root_verify_key, sqlite_workspace, backend):
+def test_queue_set_sqlite(root_verify_key, sqlite_workspace):
     def create_queue_cbk():
         return sqlite_queue_stash_fn(root_verify_key, sqlite_workspace)
 
-    backend(root_verify_key, create_queue_cbk)
+    helper_queue_set_threading(root_verify_key, create_queue_cbk)
 
 
-@pytest.mark.parametrize("backend", [helper_queue_set_threading])
 @pytest.mark.flaky(reruns=3, reruns_delay=3)
-def test_queue_set_threading_mongo(root_verify_key, mongo_document_store, backend):
+def test_queue_set_threading_mongo(root_verify_key, mongo_document_store):
     def create_queue_cbk():
         return mongo_queue_stash_fn(mongo_document_store)
 
-    backend(root_verify_key, create_queue_cbk)
+    helper_queue_set_threading(root_verify_key, create_queue_cbk)
 
 
 def helper_queue_update_threading(root_verify_key, create_queue_cbk) -> None:
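
The other cleanup visible above: `backend` used to arrive through @pytest.mark.parametrize even though each test had exactly one helper, so the tests now call the helper directly and collect as a single item each. A toy before/after, with `helper_threading` as a placeholder for the real helpers:

import pytest


def helper_threading(create_queue_cbk) -> None:
    # Placeholder for helper_queue_set_threading and friends.
    assert create_queue_cbk() is not None


# Before: the helper arrived via parametrize, although there was only ever one.
@pytest.mark.parametrize("backend", [helper_threading])
def test_old_style(backend):
    backend(lambda: object())


# After: call the helper directly; the indirection and its noisy test IDs go away.
def test_new_style():
    helper_threading(lambda: object())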
@@ -383,10 +338,12 @@ def helper_queue_update_threading(root_verify_key, create_queue_cbk) -> None:
     obj = mock_queue_object()
     queue.set(root_verify_key, obj, ignore_duplicates=False)
     execution_err = None
+    lock = threading.Lock()
 
     def _kv_cbk(tid: int) -> None:
         nonlocal execution_err
-        queue_local = create_queue_cbk()
+        with lock:
+            queue_local = create_queue_cbk()
 
         for repeat in range(repeats):
             obj.args = [repeat]
@@ -413,53 +370,20 @@ def _kv_cbk(tid: int) -> None:
     assert execution_err is None
 
 
-# def helper_queue_update_joblib(root_verify_key, create_queue_cbk) -> None:
-#     thread_cnt = 3
-#     repeats = 5
-
-#     def _kv_cbk(tid: int) -> None:
-#         queue_local = create_queue_cbk()
-
-#         for repeat in range(repeats):
-#             obj.args = [repeat]
-
-#             for _ in range(10):
-#                 res = queue_local.update(root_verify_key, obj)
-#                 if res.is_ok():
-#                     break
-
-#             if res.is_err():
-#                 return res
-#         return None
-
-#     queue = create_queue_cbk()
-
-#     obj = mock_queue_object()
-#     queue.set(root_verify_key, obj, ignore_duplicates=False)
-
-#     errs = Parallel(n_jobs=thread_cnt)(
-#         delayed(_kv_cbk)(idx) for idx in range(thread_cnt)
-#     )
-#     for execution_err in errs:
-#         assert execution_err is None
-
-
-@pytest.mark.parametrize("backend", [helper_queue_update_threading])
 @pytest.mark.flaky(reruns=3, reruns_delay=3)
-def test_queue_update_threading_sqlite(root_verify_key, sqlite_workspace, backend):
+def test_queue_update_threading_sqlite(root_verify_key, sqlite_workspace):
     def create_queue_cbk():
         return sqlite_queue_stash_fn(root_verify_key, sqlite_workspace)
 
-    backend(root_verify_key, create_queue_cbk)
+    helper_queue_update_threading(root_verify_key, create_queue_cbk)
 
 
-@pytest.mark.parametrize("backend", [helper_queue_update_threading])
 @pytest.mark.flaky(reruns=3, reruns_delay=3)
-def test_queue_update_threading_mongo(root_verify_key, mongo_document_store, backend):
+def test_queue_update_threading_mongo(root_verify_key, mongo_document_store):
     def create_queue_cbk():
         return mongo_queue_stash_fn(mongo_document_store)
 
-    backend(root_verify_key, create_queue_cbk)
+    helper_queue_update_threading(root_verify_key, create_queue_cbk)
 
 
 def helper_queue_set_delete_threading(
@@ -480,9 +404,12 @@ def helper_queue_set_delete_threading(
 
     assert res.is_ok()
 
+    lock = threading.Lock()
+
     def _kv_cbk(tid: int) -> None:
         nonlocal execution_err
-        queue = create_queue_cbk()
+        with lock:
+            queue = create_queue_cbk()
         for idx in range(repeats):
             item_idx = tid * repeats + idx
 
@@ -509,62 +436,17 @@ def _kv_cbk(tid: int) -> None:
     assert len(queue) == 0
 
 
-# def helper_queue_set_delete_joblib(
-#     root_verify_key,
-#     create_queue_cbk,
-# ) -> None:
-#     thread_cnt = 3
-#     repeats = 5
-
-#     def _kv_cbk(tid: int) -> None:
-#         nonlocal execution_err
-#         queue = create_queue_cbk()
-#         for idx in range(repeats):
-#             item_idx = tid * repeats + idx
-
-#             for _ in range(10):
-#                 res = queue.find_and_delete(root_verify_key, id=objs[item_idx].id)
-#                 if res.is_ok():
-#                     break
-
-#             if res.is_err():
-#                 execution_err = res
-#             assert res.is_ok()
-
-#     queue = create_queue_cbk()
-#     execution_err = None
-#     objs = []
-
-#     for _ in range(repeats * thread_cnt):
-#         obj = mock_queue_object()
-#         res = queue.set(root_verify_key, obj, ignore_duplicates=False)
-#         objs.append(obj)
-
-#         assert res.is_ok()
-
-#     errs = Parallel(n_jobs=thread_cnt)(
-#         delayed(_kv_cbk)(idx) for idx in range(thread_cnt)
-#     )
-
-#     for execution_err in errs:
-#         assert execution_err is None
-
-#     assert len(queue) == 0
-
-
-@pytest.mark.parametrize("backend", [helper_queue_set_delete_threading])
 @pytest.mark.flaky(reruns=3, reruns_delay=3)
-def test_queue_delete_threading_sqlite(root_verify_key, sqlite_workspace, backend):
+def test_queue_delete_threading_sqlite(root_verify_key, sqlite_workspace):
     def create_queue_cbk():
         return sqlite_queue_stash_fn(root_verify_key, sqlite_workspace)
 
-    backend(root_verify_key, create_queue_cbk)
+    helper_queue_set_delete_threading(root_verify_key, create_queue_cbk)
 
 
-@pytest.mark.parametrize("backend", [helper_queue_set_delete_threading])
 @pytest.mark.flaky(reruns=3, reruns_delay=3)
-def test_queue_delete_threading_mongo(root_verify_key, mongo_document_store, backend):
+def test_queue_delete_threading_mongo(root_verify_key, mongo_document_store):
     def create_queue_cbk():
         return mongo_queue_stash_fn(mongo_document_store)
 
-    backend(root_verify_key, create_queue_cbk)
+    helper_queue_set_delete_threading(root_verify_key, create_queue_cbk)

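The retained @pytest.mark.flaky(reruns=3, reruns_delay=3) markers come from the pytest-rerunfailures plugin: a failed test is rerun up to three times, waiting three seconds between attempts, before being reported as a failure. Minimal usage, with an artificially nondeterministic test body:

# Requires the pytest-rerunfailures plugin: pip install pytest-rerunfailures
import random

import pytest


@pytest.mark.flaky(reruns=3, reruns_delay=3)
def test_sometimes_races() -> None:
    # Illustrative nondeterminism: on failure the plugin reruns this test up
    # to three times, sleeping three seconds between attempts.
    assert random.random() < 0.9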