Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reduce sqlite tests flakiness #8651

Merged
merged 24 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion packages/syft/src/syft/store/document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,9 @@ def __init__(
self.root_verify_key = root_verify_key
self.settings = settings
self.store_config = store_config
self.init_store()
res = self.init_store()
if res.is_err():
raise Exception(f"Something went wrong initializing the store: {res.err()}")

store_config.locking_config.lock_name = f"StorePartition-{settings.name}"
self.lock = SyftLock(store_config.locking_config)
Expand Down
23 changes: 8 additions & 15 deletions packages/syft/tests/syft/stores/kv_document_store_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,10 @@ def test_kv_store_partition_init_failed(root_verify_key) -> None:
store_config = MockStoreConfig(is_crashed=True)
settings = PartitionSettings(name="test", object_type=MockObjectType)

kv_store_partition = KeyValueStorePartition(
UID(), root_verify_key, settings=settings, store_config=store_config
)

res = kv_store_partition.init_store()
assert res.is_err()
with pytest.raises(Exception):
kv_store_partition = KeyValueStorePartition(
UID(), root_verify_key, settings=settings, store_config=store_config
)


def test_kv_store_partition_set(
Expand Down Expand Up @@ -81,15 +79,10 @@ def test_kv_store_partition_set_backend_fail(root_verify_key) -> None:
store_config = MockStoreConfig(is_crashed=True)
settings = PartitionSettings(name="test", object_type=MockObjectType)

kv_store_partition = KeyValueStorePartition(
UID(), root_verify_key, settings=settings, store_config=store_config
)
kv_store_partition.init_store()

obj = MockSyftObject(data=1)

res = kv_store_partition.set(root_verify_key, obj, ignore_duplicates=False)
assert res.is_err()
with pytest.raises(Exception):
kv_store_partition = KeyValueStorePartition(
UID(), root_verify_key, settings=settings, store_config=store_config
)


def test_kv_store_partition_delete(
Expand Down
166 changes: 24 additions & 142 deletions packages/syft/tests/syft/stores/queue_stash_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# stdlib
import threading
from threading import Thread
import time
import traceback
from typing import Any

# third party
Expand Down Expand Up @@ -275,10 +277,12 @@ def helper_queue_set_threading(root_verify_key, create_queue_cbk) -> None:
repeats = 5

execution_err = None
lock = threading.Lock()

def _kv_cbk(tid: int) -> None:
nonlocal execution_err
queue = create_queue_cbk()
with lock:
queue = create_queue_cbk()

for _ in range(repeats):
obj = mock_queue_object()
Expand Down Expand Up @@ -308,69 +312,20 @@ def _kv_cbk(tid: int) -> None:
assert len(queue) == thread_cnt * repeats


# def helper_queue_set_joblib(root_verify_key, create_queue_cbk) -> None:
# thread_cnt = 3
# repeats = 5

# def _kv_cbk(tid: int) -> None:
# queue = create_queue_cbk()
# for _ in range(repeats):
# worker_pool_obj = WorkerPool(
# name="mypool",
# image_id=UID(),
# max_count=0,
# worker_list=[],
# )
# linked_worker_pool = LinkedObject.from_obj(
# worker_pool_obj,
# node_uid=UID(),
# service_type=SyftWorkerPoolService,
# )
# obj = QueueItem(
# id=UID(),
# node_uid=UID(),
# method="dummy_method",
# service="dummy_service",
# args=[],
# kwargs={},
# worker_pool=linked_worker_pool,
# )
# for _ in range(10):
# res = queue.set(root_verify_key, obj, ignore_duplicates=False)
# if res.is_ok():
# break

# if res.is_err():
# return res
# return None

# errs = Parallel(n_jobs=thread_cnt)(
# delayed(_kv_cbk)(idx) for idx in range(thread_cnt)
# )

# for execution_err in errs:
# assert execution_err is None

# queue = create_queue_cbk()
# assert len(queue) == thread_cnt * repeats


@pytest.mark.parametrize("backend", [helper_queue_set_threading])
@pytest.mark.flaky(reruns=3, reruns_delay=3)
def test_queue_set_sqlite(root_verify_key, sqlite_workspace, backend):
def test_queue_set_sqlite(root_verify_key, sqlite_workspace):
def create_queue_cbk():
return sqlite_queue_stash_fn(root_verify_key, sqlite_workspace)

backend(root_verify_key, create_queue_cbk)
helper_queue_set_threading(root_verify_key, create_queue_cbk)


@pytest.mark.parametrize("backend", [helper_queue_set_threading])
@pytest.mark.flaky(reruns=3, reruns_delay=3)
def test_queue_set_threading_mongo(root_verify_key, mongo_document_store, backend):
def test_queue_set_threading_mongo(root_verify_key, mongo_document_store):
def create_queue_cbk():
return mongo_queue_stash_fn(mongo_document_store)

backend(root_verify_key, create_queue_cbk)
helper_queue_set_threading(root_verify_key, create_queue_cbk)


def helper_queue_update_threading(root_verify_key, create_queue_cbk) -> None:
Expand All @@ -383,10 +338,12 @@ def helper_queue_update_threading(root_verify_key, create_queue_cbk) -> None:
obj = mock_queue_object()
queue.set(root_verify_key, obj, ignore_duplicates=False)
execution_err = None
lock = threading.Lock()

def _kv_cbk(tid: int) -> None:
nonlocal execution_err
queue_local = create_queue_cbk()
with lock:
queue_local = create_queue_cbk()

for repeat in range(repeats):
obj.args = [repeat]
Expand All @@ -413,53 +370,20 @@ def _kv_cbk(tid: int) -> None:
assert execution_err is None


# def helper_queue_update_joblib(root_verify_key, create_queue_cbk) -> None:
# thread_cnt = 3
# repeats = 5

# def _kv_cbk(tid: int) -> None:
# queue_local = create_queue_cbk()

# for repeat in range(repeats):
# obj.args = [repeat]

# for _ in range(10):
# res = queue_local.update(root_verify_key, obj)
# if res.is_ok():
# break

# if res.is_err():
# return res
# return None

# queue = create_queue_cbk()

# obj = mock_queue_object()
# queue.set(root_verify_key, obj, ignore_duplicates=False)

# errs = Parallel(n_jobs=thread_cnt)(
# delayed(_kv_cbk)(idx) for idx in range(thread_cnt)
# )
# for execution_err in errs:
# assert execution_err is None


@pytest.mark.parametrize("backend", [helper_queue_update_threading])
@pytest.mark.flaky(reruns=3, reruns_delay=3)
def test_queue_update_threading_sqlite(root_verify_key, sqlite_workspace, backend):
def test_queue_update_threading_sqlite(root_verify_key, sqlite_workspace):
def create_queue_cbk():
return sqlite_queue_stash_fn(root_verify_key, sqlite_workspace)

backend(root_verify_key, create_queue_cbk)
helper_queue_update_threading(root_verify_key, create_queue_cbk)


@pytest.mark.parametrize("backend", [helper_queue_update_threading])
@pytest.mark.flaky(reruns=3, reruns_delay=3)
def test_queue_update_threading_mongo(root_verify_key, mongo_document_store, backend):
def test_queue_update_threading_mongo(root_verify_key, mongo_document_store):
def create_queue_cbk():
return mongo_queue_stash_fn(mongo_document_store)

backend(root_verify_key, create_queue_cbk)
helper_queue_update_threading(root_verify_key, create_queue_cbk)


def helper_queue_set_delete_threading(
Expand All @@ -480,9 +404,12 @@ def helper_queue_set_delete_threading(

assert res.is_ok()

lock = threading.Lock()

def _kv_cbk(tid: int) -> None:
nonlocal execution_err
queue = create_queue_cbk()
with lock:
queue = create_queue_cbk()
for idx in range(repeats):
item_idx = tid * repeats + idx

Expand All @@ -509,62 +436,17 @@ def _kv_cbk(tid: int) -> None:
assert len(queue) == 0


# def helper_queue_set_delete_joblib(
# root_verify_key,
# create_queue_cbk,
# ) -> None:
# thread_cnt = 3
# repeats = 5

# def _kv_cbk(tid: int) -> None:
# nonlocal execution_err
# queue = create_queue_cbk()
# for idx in range(repeats):
# item_idx = tid * repeats + idx

# for _ in range(10):
# res = queue.find_and_delete(root_verify_key, id=objs[item_idx].id)
# if res.is_ok():
# break

# if res.is_err():
# execution_err = res
# assert res.is_ok()

# queue = create_queue_cbk()
# execution_err = None
# objs = []

# for _ in range(repeats * thread_cnt):
# obj = mock_queue_object()
# res = queue.set(root_verify_key, obj, ignore_duplicates=False)
# objs.append(obj)

# assert res.is_ok()

# errs = Parallel(n_jobs=thread_cnt)(
# delayed(_kv_cbk)(idx) for idx in range(thread_cnt)
# )

# for execution_err in errs:
# assert execution_err is None

# assert len(queue) == 0


@pytest.mark.parametrize("backend", [helper_queue_set_delete_threading])
@pytest.mark.flaky(reruns=3, reruns_delay=3)
def test_queue_delete_threading_sqlite(root_verify_key, sqlite_workspace, backend):
def test_queue_delete_threading_sqlite(root_verify_key, sqlite_workspace):
def create_queue_cbk():
return sqlite_queue_stash_fn(root_verify_key, sqlite_workspace)

backend(root_verify_key, create_queue_cbk)
helper_queue_set_delete_threading(root_verify_key, create_queue_cbk)


@pytest.mark.parametrize("backend", [helper_queue_set_delete_threading])
@pytest.mark.flaky(reruns=3, reruns_delay=3)
def test_queue_delete_threading_mongo(root_verify_key, mongo_document_store, backend):
def test_queue_delete_threading_mongo(root_verify_key, mongo_document_store):
def create_queue_cbk():
return mongo_queue_stash_fn(mongo_document_store)

backend(root_verify_key, create_queue_cbk)
helper_queue_set_delete_threading(root_verify_key, create_queue_cbk)