Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
cnnradams committed Jul 15, 2020
1 parent 70a9012 commit bee0e58
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 21 deletions.
2 changes: 1 addition & 1 deletion google/cloud/spanner_v1/_opentelemetry_tracing.py
Expand Up @@ -31,7 +31,7 @@

@contextmanager
def trace_call(name, session, extra_attributes=None):
if not HAS_OPENTELEMETRY_INSTALLED:
if not HAS_OPENTELEMETRY_INSTALLED or not session:
# empty context manager. users will have to check if the generated value is None or a span
yield None
return
Expand Down
10 changes: 2 additions & 8 deletions google/cloud/spanner_v1/snapshot.py
Expand Up @@ -41,10 +41,7 @@ def _restart_on_unavailable(restart, trace_name=None, session=None, attributes=N
"""
resume_token = b""
item_buffer = []
if trace_name and session:
with trace_call(trace_name, session, attributes):
iterator = restart()
else:
with trace_call(trace_name, session, attributes):
iterator = restart()
while True:
try:
Expand All @@ -55,10 +52,7 @@ def _restart_on_unavailable(restart, trace_name=None, session=None, attributes=N
break
except ServiceUnavailable:
del item_buffer[:]
if trace_name and session:
with trace_call(trace_name, session, attributes):
iterator = restart(resume_token=resume_token)
else:
with trace_call(trace_name, session, attributes):
iterator = restart(resume_token=resume_token)
continue

Expand Down
2 changes: 1 addition & 1 deletion google/cloud/spanner_v1/transaction.py
Expand Up @@ -279,7 +279,7 @@ def batch_update(self, statements):

trace_attributes = {
# Get just the queries from the DML statement batch
"db.statement": [statement[0] for statement in statements]
"db.statement": [statement["sql"] for statement in parsed]
}
with trace_call("CloudSpanner.DMLTransaction", self._session, trace_attributes):
response = api.execute_batch_dml(
Expand Down
5 changes: 4 additions & 1 deletion tests/_helpers.py
Expand Up @@ -6,14 +6,17 @@
from opentelemetry.trace.status import StatusCanonicalCode

from opentelemetry.sdk.trace import TracerProvider, export
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
from opentelemetry.sdk.trace.export.in_memory_span_exporter import (
InMemorySpanExporter,
)

HAS_OPENTELEMETRY_INSTALLED = True
except ImportError:
HAS_OPENTELEMETRY_INSTALLED = False

StatusCanonicalCode = mock.Mock()


class OpenTelemetryBase(unittest.TestCase):
def setUp(self):
if HAS_OPENTELEMETRY_INSTALLED:
Expand Down
10 changes: 4 additions & 6 deletions tests/unit/test__opentelemetry_tracing.py
Expand Up @@ -6,8 +6,6 @@
try:
from opentelemetry import trace as trace_api
from opentelemetry.trace.status import StatusCanonicalCode
from opentelemetry.sdk.trace import TracerProvider, export
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
except ImportError:
pass

Expand All @@ -16,6 +14,7 @@

from tests._helpers import OpenTelemetryBase, HAS_OPENTELEMETRY_INSTALLED


def _make_rpc_error(error_cls, trailing_metadata=None):
import grpc

Expand All @@ -29,8 +28,10 @@ def _make_session():

return mock.Mock(autospec=Session, instance=True)


# Skip all of these tests if we don't have OpenTelemetry
if HAS_OPENTELEMETRY_INSTALLED:

class TestNoTracing(unittest.TestCase):
def setUp(self):
self._temp_opentelemetry = sys.modules["opentelemetry"]
Expand All @@ -46,7 +47,6 @@ def test_no_trace_call(self):
with _opentelemetry_tracing.trace_call("Test", _make_session()) as no_span:
self.assertIsNone(no_span)


class TestTracing(OpenTelemetryBase):
def test_trace_call(self):
extra_attributes = {
Expand Down Expand Up @@ -75,9 +75,7 @@ def test_trace_call(self):
self.assertEqual(span.kind, trace_api.SpanKind.CLIENT)
self.assertEqual(span.attributes, expected_attributes)
self.assertEqual(span.name, "CloudSpanner.Test")
self.assertEqual(
span.status.canonical_code, trace_api.status.StatusCanonicalCode.OK
)
self.assertEqual(span.status.canonical_code, StatusCanonicalCode.OK)

def test_trace_error(self):
extra_attributes = {"db.instance": "database_name"}
Expand Down
12 changes: 9 additions & 3 deletions tests/unit/test_session.py
Expand Up @@ -15,7 +15,11 @@

import google.api_core.gapic_v1.method
import mock
from tests._helpers import OpenTelemetryBase, StatusCanonicalCode, HAS_OPENTELEMETRY_INSTALLED
from tests._helpers import (
OpenTelemetryBase,
StatusCanonicalCode,
HAS_OPENTELEMETRY_INSTALLED,
)


def _make_rpc_error(error_cls, trailing_metadata=None):
Expand Down Expand Up @@ -43,7 +47,7 @@ class TestSession(OpenTelemetryBase):
BASE_ATTRIBUTES = {
"db.type": "spanner",
"db.url": "spanner.googleapis.com:443",
"db.instance": "projects/project-id/instances/instance-id/databases/database-id",
"db.instance": DATABASE_NAME,
"net.host.name": "spanner.googleapis.com:443",
}

Expand Down Expand Up @@ -1112,7 +1116,9 @@ def _time(_results=[1, 1.5]):
with mock.patch("opentelemetry.util.time", _ConstantTime()):
with mock.patch("time.sleep") as sleep_mock:
with self.assertRaises(Aborted):
session.run_in_transaction(unit_of_work, "abc", timeout_secs=1)
session.run_in_transaction(
unit_of_work, "abc", timeout_secs=1
)
else:
with mock.patch("time.sleep") as sleep_mock:
with self.assertRaises(Aborted):
Expand Down
6 changes: 5 additions & 1 deletion tests/unit/test_snapshot.py
Expand Up @@ -15,7 +15,11 @@

import google.api_core.gapic_v1.method
import mock
from tests._helpers import OpenTelemetryBase, StatusCanonicalCode, HAS_OPENTELEMETRY_INSTALLED
from tests._helpers import (
OpenTelemetryBase,
StatusCanonicalCode,
HAS_OPENTELEMETRY_INSTALLED,
)

TABLE_NAME = "citizens"
COLUMNS = ["email", "first_name", "last_name", "age"]
Expand Down

0 comments on commit bee0e58

Please sign in to comment.