Skip to content

Commit

Permalink
transaction updates
Browse files Browse the repository at this point in the history
  • Loading branch information
richard-rogers committed Apr 26, 2024
1 parent f253434 commit d13bc5b
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 11 deletions.
43 changes: 42 additions & 1 deletion python/tests/api/writer/test_whylabs_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)

import whylogs as why
from whylogs.api.writer.whylabs import WhyLabsWriter
from whylogs.api.writer.whylabs import WhyLabsWriter, WhyLabsTransaction
from whylogs.api.writer.whylabs_client import TransactionAbortedException, WhyLabsClient
from whylogs.api.writer.whylabs_transaction_writer import WhyLabsTransactionWriter
from whylogs.core import DatasetProfileView
Expand Down Expand Up @@ -344,6 +344,47 @@ def test_transaction_context():
assert deserialized_view is not None


@pytest.mark.load
def test_old_transaction_context():
ORG_ID = os.environ.get("WHYLABS_DEFAULT_ORG_ID")
MODEL_ID = os.environ.get("WHYLABS_DEFAULT_DATASET_ID")
why.init(force_local=True)
schema = DatasetSchema()
csv_url = "https://whylabs-public.s3.us-west-2.amazonaws.com/datasets/tour/current.csv"
df = pd.read_csv(csv_url)
pdfs = np.array_split(df, 7)
tids = list()
writer = WhyLabsWriter()
with WhyLabsTransaction(writer):
assert writer._transaction_id is not None
assert writer._whylabs_client._transaction_id == writer._transaction_id
for data in pdfs:
trace_id = str(uuid4())
tids.append(trace_id)
result = why.log(data, schema=schema, trace_id=trace_id)
status, id = writer.write(result.profile())
if not status:
raise Exception() # or retry the profile...
status = writer.transaction_status()
assert len(status["files"]) == len(pdfs)

time.sleep(SLEEP_TIME) # platform needs time to become aware of the profile
dataset_api = DatasetProfileApi(writer._api_client)
for trace_id in tids:
response: ProfileTracesResponse = dataset_api.get_profile_traces(
org_id=ORG_ID,
dataset_id=MODEL_ID,
trace_id=trace_id,
)
download_url = response.get("traces")[0]["download_url"]
headers = {"Content-Type": "application/octet-stream"}
downloaded_profile = writer._s3_pool.request(
"GET", download_url, headers=headers, timeout=writer._timeout_seconds
)
deserialized_view = DatasetProfileView.deserialize(downloaded_profile.data)
assert deserialized_view is not None


@pytest.mark.load
def test_transaction_context_aborted():
why.init(force_local=True)
Expand Down
13 changes: 3 additions & 10 deletions python/whylogs/api/writer/whylabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Any, List, Optional, Tuple, Union

from whylabs_client import ApiClient
from whylabs_client.exceptions import NotFoundException

from whylogs.api.whylabs.session.session_manager import INIT_DOCS
from whylogs.api.writer.whylabs_base import WhyLabsWriterBase
Expand Down Expand Up @@ -260,15 +259,9 @@ def __init__(self, writer: WhyLabsWriter):
self._writer = writer

def __enter__(self) -> str:
self._writer.start_transaction()
if self._writer._transaction_id is None:
self._writer.start_transaction()
return self._writer._transaction_id # type: ignore

def __exit__(self, exc_type, exc_value, exc_tb) -> None:
id = self._writer._transaction_id
try:
self._writer.commit_transaction()
except NotFoundException as e:
if "Transaction has been aborted" in str(e): # TODO: perhaps not the most robust test?
logger.error(f"Transaction {id} was aborted; not committing")
else:
raise e
self._writer.commit_transaction()
5 changes: 5 additions & 0 deletions python/whylogs/api/writer/whylabs_transaction_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,14 @@ def write(
return self._write_view(view, **kwargs)

def __enter__(self) -> "WhyLabsTransactionWriter":
if self.transaction_id is None:
self._transaction_id = self._whylabs_client.get_transaction_id()
self._whylabs_client._transaction_id = self._transaction_id

return self

def __exit__(self, exc_type, exc_value, exc_tb) -> None:
id = self.transaction_id
self._transaction_id = None
self._whylabs_client._transaction_id = None # type: ignore
self._whylabs_client.commit_transaction(id) # type: ignore

0 comments on commit d13bc5b

Please sign in to comment.