diff --git a/google/cloud/spanner_v1/_opentelemetry_tracing.py b/google/cloud/spanner_v1/_opentelemetry_tracing.py index 86a9fb7c51..93357eda98 100644 --- a/google/cloud/spanner_v1/_opentelemetry_tracing.py +++ b/google/cloud/spanner_v1/_opentelemetry_tracing.py @@ -31,7 +31,7 @@ @contextmanager def trace_call(name, session, extra_attributes=None): - if not HAS_OPENTELEMETRY_INSTALLED: + if not HAS_OPENTELEMETRY_INSTALLED or not session: # empty context manager. users will have to check if the generated value is None or a span yield None return diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 56e3ebaedc..0b5ee1d894 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -41,10 +41,7 @@ def _restart_on_unavailable(restart, trace_name=None, session=None, attributes=N """ resume_token = b"" item_buffer = [] - if trace_name and session: - with trace_call(trace_name, session, attributes): - iterator = restart() - else: + with trace_call(trace_name, session, attributes): iterator = restart() while True: try: @@ -55,10 +52,7 @@ def _restart_on_unavailable(restart, trace_name=None, session=None, attributes=N break except ServiceUnavailable: del item_buffer[:] - if trace_name and session: - with trace_call(trace_name, session, attributes): - iterator = restart(resume_token=resume_token) - else: + with trace_call(trace_name, session, attributes): iterator = restart(resume_token=resume_token) continue diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index bbf676f2ce..80da63a3fd 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -279,7 +279,7 @@ def batch_update(self, statements): trace_attributes = { # Get just the queries from the DML statement batch - "db.statement": [statement[0] for statement in statements] + "db.statement": [statement["sql"] for statement in parsed] } with trace_call("CloudSpanner.DMLTransaction", self._session, trace_attributes): response = api.execute_batch_dml( diff --git a/tests/_helpers.py b/tests/_helpers.py index 2b013d8108..c4f90d8653 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -6,7 +6,9 @@ from opentelemetry.trace.status import StatusCanonicalCode from opentelemetry.sdk.trace import TracerProvider, export - from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( + InMemorySpanExporter, + ) HAS_OPENTELEMETRY_INSTALLED = True except ImportError: @@ -14,6 +16,7 @@ StatusCanonicalCode = mock.Mock() + class OpenTelemetryBase(unittest.TestCase): def setUp(self): if HAS_OPENTELEMETRY_INSTALLED: diff --git a/tests/unit/test__opentelemetry_tracing.py b/tests/unit/test__opentelemetry_tracing.py index 85d27a3553..8e26468dfe 100644 --- a/tests/unit/test__opentelemetry_tracing.py +++ b/tests/unit/test__opentelemetry_tracing.py @@ -6,8 +6,6 @@ try: from opentelemetry import trace as trace_api from opentelemetry.trace.status import StatusCanonicalCode - from opentelemetry.sdk.trace import TracerProvider, export - from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter except ImportError: pass @@ -16,6 +14,7 @@ from tests._helpers import OpenTelemetryBase, HAS_OPENTELEMETRY_INSTALLED + def _make_rpc_error(error_cls, trailing_metadata=None): import grpc @@ -29,8 +28,10 @@ def _make_session(): return mock.Mock(autospec=Session, instance=True) + # Skip all of these tests if we don't have OpenTelemetry if HAS_OPENTELEMETRY_INSTALLED: + class TestNoTracing(unittest.TestCase): def setUp(self): self._temp_opentelemetry = sys.modules["opentelemetry"] @@ -46,7 +47,6 @@ def test_no_trace_call(self): with _opentelemetry_tracing.trace_call("Test", _make_session()) as no_span: self.assertIsNone(no_span) - class TestTracing(OpenTelemetryBase): def test_trace_call(self): extra_attributes = { @@ -75,9 +75,7 @@ def test_trace_call(self): self.assertEqual(span.kind, trace_api.SpanKind.CLIENT) self.assertEqual(span.attributes, expected_attributes) self.assertEqual(span.name, "CloudSpanner.Test") - self.assertEqual( - span.status.canonical_code, trace_api.status.StatusCanonicalCode.OK - ) + self.assertEqual(span.status.canonical_code, StatusCanonicalCode.OK) def test_trace_error(self): extra_attributes = {"db.instance": "database_name"} diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 17b2ce4688..e95b9e1a06 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -15,7 +15,11 @@ import google.api_core.gapic_v1.method import mock -from tests._helpers import OpenTelemetryBase, StatusCanonicalCode, HAS_OPENTELEMETRY_INSTALLED +from tests._helpers import ( + OpenTelemetryBase, + StatusCanonicalCode, + HAS_OPENTELEMETRY_INSTALLED, +) def _make_rpc_error(error_cls, trailing_metadata=None): @@ -43,7 +47,7 @@ class TestSession(OpenTelemetryBase): BASE_ATTRIBUTES = { "db.type": "spanner", "db.url": "spanner.googleapis.com:443", - "db.instance": "projects/project-id/instances/instance-id/databases/database-id", + "db.instance": DATABASE_NAME, "net.host.name": "spanner.googleapis.com:443", } @@ -1112,7 +1116,9 @@ def _time(_results=[1, 1.5]): with mock.patch("opentelemetry.util.time", _ConstantTime()): with mock.patch("time.sleep") as sleep_mock: with self.assertRaises(Aborted): - session.run_in_transaction(unit_of_work, "abc", timeout_secs=1) + session.run_in_transaction( + unit_of_work, "abc", timeout_secs=1 + ) else: with mock.patch("time.sleep") as sleep_mock: with self.assertRaises(Aborted): diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 17e10c7a9e..5c53ee6a0e 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -15,7 +15,11 @@ import google.api_core.gapic_v1.method import mock -from tests._helpers import OpenTelemetryBase, StatusCanonicalCode, HAS_OPENTELEMETRY_INSTALLED +from tests._helpers import ( + OpenTelemetryBase, + StatusCanonicalCode, + HAS_OPENTELEMETRY_INSTALLED, +) TABLE_NAME = "citizens" COLUMNS = ["email", "first_name", "last_name", "age"]