Skip to content

Commit

Permalink
fix: request and flattened params are exclusive, surface transport in…
Browse files Browse the repository at this point in the history
… generated layer (#256)

- Restore path helper methods to generated clients.
- Enforce that 'request' argument to generated client methods is exclusive
  to flattened arguments.
- Surface 'transport' property for generated clients.

Closes #251
Closes #252
  • Loading branch information
tseaver committed Nov 13, 2020
1 parent ab19546 commit 386e85e
Show file tree
Hide file tree
Showing 16 changed files with 1,041 additions and 452 deletions.
Expand Up @@ -54,14 +54,58 @@ class FirestoreAdminAsyncClient:
DEFAULT_ENDPOINT = FirestoreAdminClient.DEFAULT_ENDPOINT
DEFAULT_MTLS_ENDPOINT = FirestoreAdminClient.DEFAULT_MTLS_ENDPOINT

collection_group_path = staticmethod(FirestoreAdminClient.collection_group_path)
parse_collection_group_path = staticmethod(
FirestoreAdminClient.parse_collection_group_path
)
database_path = staticmethod(FirestoreAdminClient.database_path)
parse_database_path = staticmethod(FirestoreAdminClient.parse_database_path)
field_path = staticmethod(FirestoreAdminClient.field_path)
parse_field_path = staticmethod(FirestoreAdminClient.parse_field_path)
index_path = staticmethod(FirestoreAdminClient.index_path)
parse_index_path = staticmethod(FirestoreAdminClient.parse_index_path)

common_billing_account_path = staticmethod(
FirestoreAdminClient.common_billing_account_path
)
parse_common_billing_account_path = staticmethod(
FirestoreAdminClient.parse_common_billing_account_path
)

common_folder_path = staticmethod(FirestoreAdminClient.common_folder_path)
parse_common_folder_path = staticmethod(
FirestoreAdminClient.parse_common_folder_path
)

common_organization_path = staticmethod(
FirestoreAdminClient.common_organization_path
)
parse_common_organization_path = staticmethod(
FirestoreAdminClient.parse_common_organization_path
)

common_project_path = staticmethod(FirestoreAdminClient.common_project_path)
parse_common_project_path = staticmethod(
FirestoreAdminClient.parse_common_project_path
)

common_location_path = staticmethod(FirestoreAdminClient.common_location_path)
parse_common_location_path = staticmethod(
FirestoreAdminClient.parse_common_location_path
)

from_service_account_file = FirestoreAdminClient.from_service_account_file
from_service_account_json = from_service_account_file

@property
def transport(self) -> FirestoreAdminTransport:
"""Return the transport used by the client instance.
Returns:
FirestoreAdminTransport: The transport used by the client instance.
"""
return self._client.transport

get_transport_class = functools.partial(
type(FirestoreAdminClient).get_transport_class, type(FirestoreAdminClient)
)
Expand Down Expand Up @@ -166,7 +210,8 @@ async def create_index(
# Create or coerce a protobuf request object.
# Sanity check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
if request is not None and any([parent, index]):
has_flattened_params = any([parent, index])
if request is not None and has_flattened_params:
raise ValueError(
"If the `request` argument is set, then none of "
"the individual field arguments should be set."
Expand Down Expand Up @@ -250,7 +295,8 @@ async def list_indexes(
# Create or coerce a protobuf request object.
# Sanity check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
if request is not None and any([parent]):
has_flattened_params = any([parent])
if request is not None and has_flattened_params:
raise ValueError(
"If the `request` argument is set, then none of "
"the individual field arguments should be set."
Expand Down Expand Up @@ -338,7 +384,8 @@ async def get_index(
# Create or coerce a protobuf request object.
# Sanity check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
if request is not None and any([name]):
has_flattened_params = any([name])
if request is not None and has_flattened_params:
raise ValueError(
"If the `request` argument is set, then none of "
"the individual field arguments should be set."
Expand Down Expand Up @@ -413,7 +460,8 @@ async def delete_index(
# Create or coerce a protobuf request object.
# Sanity check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
if request is not None and any([name]):
has_flattened_params = any([name])
if request is not None and has_flattened_params:
raise ValueError(
"If the `request` argument is set, then none of "
"the individual field arguments should be set."
Expand Down Expand Up @@ -496,7 +544,8 @@ async def get_field(
# Create or coerce a protobuf request object.
# Sanity check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
if request is not None and any([name]):
has_flattened_params = any([name])
if request is not None and has_flattened_params:
raise ValueError(
"If the `request` argument is set, then none of "
"the individual field arguments should be set."
Expand Down Expand Up @@ -598,7 +647,8 @@ async def update_field(
# Create or coerce a protobuf request object.
# Sanity check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
if request is not None and any([field]):
has_flattened_params = any([field])
if request is not None and has_flattened_params:
raise ValueError(
"If the `request` argument is set, then none of "
"the individual field arguments should be set."
Expand Down Expand Up @@ -689,7 +739,8 @@ async def list_fields(
# Create or coerce a protobuf request object.
# Sanity check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
if request is not None and any([parent]):
has_flattened_params = any([parent])
if request is not None and has_flattened_params:
raise ValueError(
"If the `request` argument is set, then none of "
"the individual field arguments should be set."
Expand Down Expand Up @@ -790,7 +841,8 @@ async def export_documents(
# Create or coerce a protobuf request object.
# Sanity check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
if request is not None and any([name]):
has_flattened_params = any([name])
if request is not None and has_flattened_params:
raise ValueError(
"If the `request` argument is set, then none of "
"the individual field arguments should be set."
Expand Down Expand Up @@ -890,7 +942,8 @@ async def import_documents(
# Create or coerce a protobuf request object.
# Sanity check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
if request is not None and any([name]):
has_flattened_params = any([name])
if request is not None and has_flattened_params:
raise ValueError(
"If the `request` argument is set, then none of "
"the individual field arguments should be set."
Expand Down
105 changes: 101 additions & 4 deletions google/cloud/firestore_admin_v1/services/firestore_admin/client.py
Expand Up @@ -140,6 +140,44 @@ def from_service_account_file(cls, filename: str, *args, **kwargs):

from_service_account_json = from_service_account_file

@property
def transport(self) -> FirestoreAdminTransport:
"""Return the transport used by the client instance.
Returns:
FirestoreAdminTransport: The transport used by the client instance.
"""
return self._transport

@staticmethod
def collection_group_path(project: str, database: str, collection: str,) -> str:
"""Return a fully-qualified collection_group string."""
return "projects/{project}/databases/{database}/collectionGroups/{collection}".format(
project=project, database=database, collection=collection,
)

@staticmethod
def parse_collection_group_path(path: str) -> Dict[str, str]:
"""Parse a collection_group path into its component segments."""
m = re.match(
r"^projects/(?P<project>.+?)/databases/(?P<database>.+?)/collectionGroups/(?P<collection>.+?)$",
path,
)
return m.groupdict() if m else {}

@staticmethod
def database_path(project: str, database: str,) -> str:
"""Return a fully-qualified database string."""
return "projects/{project}/databases/{database}".format(
project=project, database=database,
)

@staticmethod
def parse_database_path(path: str) -> Dict[str, str]:
"""Parse a database path into its component segments."""
m = re.match(r"^projects/(?P<project>.+?)/databases/(?P<database>.+?)$", path)
return m.groupdict() if m else {}

@staticmethod
def field_path(project: str, database: str, collection: str, field: str,) -> str:
"""Return a fully-qualified field string."""
Expand Down Expand Up @@ -172,6 +210,65 @@ def parse_index_path(path: str) -> Dict[str, str]:
)
return m.groupdict() if m else {}

@staticmethod
def common_billing_account_path(billing_account: str,) -> str:
"""Return a fully-qualified billing_account string."""
return "billingAccounts/{billing_account}".format(
billing_account=billing_account,
)

@staticmethod
def parse_common_billing_account_path(path: str) -> Dict[str, str]:
"""Parse a billing_account path into its component segments."""
m = re.match(r"^billingAccounts/(?P<billing_account>.+?)$", path)
return m.groupdict() if m else {}

@staticmethod
def common_folder_path(folder: str,) -> str:
"""Return a fully-qualified folder string."""
return "folders/{folder}".format(folder=folder,)

@staticmethod
def parse_common_folder_path(path: str) -> Dict[str, str]:
"""Parse a folder path into its component segments."""
m = re.match(r"^folders/(?P<folder>.+?)$", path)
return m.groupdict() if m else {}

@staticmethod
def common_organization_path(organization: str,) -> str:
"""Return a fully-qualified organization string."""
return "organizations/{organization}".format(organization=organization,)

@staticmethod
def parse_common_organization_path(path: str) -> Dict[str, str]:
"""Parse a organization path into its component segments."""
m = re.match(r"^organizations/(?P<organization>.+?)$", path)
return m.groupdict() if m else {}

@staticmethod
def common_project_path(project: str,) -> str:
"""Return a fully-qualified project string."""
return "projects/{project}".format(project=project,)

@staticmethod
def parse_common_project_path(path: str) -> Dict[str, str]:
"""Parse a project path into its component segments."""
m = re.match(r"^projects/(?P<project>.+?)$", path)
return m.groupdict() if m else {}

@staticmethod
def common_location_path(project: str, location: str,) -> str:
"""Return a fully-qualified location string."""
return "projects/{project}/locations/{location}".format(
project=project, location=location,
)

@staticmethod
def parse_common_location_path(path: str) -> Dict[str, str]:
"""Parse a location path into its component segments."""
m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)$", path)
return m.groupdict() if m else {}

def __init__(
self,
*,
Expand Down Expand Up @@ -207,10 +304,10 @@ def __init__(
not provided, the default SSL client certificate will be used if
present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
set, no client certificate will be used.
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
The client info used to send a user-agent string along with
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
The client info used to send a user-agent string along with
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.
Raises:
Expand Down
Expand Up @@ -95,10 +95,10 @@ def __init__(
for grpc channel. It is ignored if ``channel`` is provided.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
The client info used to send a user-agent string along with
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
The client info used to send a user-agent string along with
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.
Raises:
Expand All @@ -107,13 +107,16 @@ def __init__(
google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
and ``credentials_file`` are passed.
"""
self._ssl_channel_credentials = ssl_channel_credentials

if channel:
# Sanity check: Ensure that channel and credentials are not both
# provided.
credentials = False

# If a channel was explicitly provided, set it.
self._grpc_channel = channel
self._ssl_channel_credentials = None
elif api_mtls_endpoint:
warnings.warn(
"api_mtls_endpoint and client_cert_source are deprecated",
Expand Down Expand Up @@ -150,6 +153,7 @@ def __init__(
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
)
self._ssl_channel_credentials = ssl_credentials
else:
host = host if ":" in host else host + ":443"

Expand Down Expand Up @@ -227,12 +231,8 @@ def create_channel(

@property
def grpc_channel(self) -> grpc.Channel:
"""Create the channel designed to connect to this service.
This property caches on the instance; repeated calls return
the same channel.
"""Return the channel designed to connect to this service.
"""
# Return the channel from cache.
return self._grpc_channel

@property
Expand Down
Expand Up @@ -152,13 +152,16 @@ def __init__(
google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
and ``credentials_file`` are passed.
"""
self._ssl_channel_credentials = ssl_channel_credentials

if channel:
# Sanity check: Ensure that channel and credentials are not both
# provided.
credentials = False

# If a channel was explicitly provided, set it.
self._grpc_channel = channel
self._ssl_channel_credentials = None
elif api_mtls_endpoint:
warnings.warn(
"api_mtls_endpoint and client_cert_source are deprecated",
Expand Down Expand Up @@ -195,6 +198,7 @@ def __init__(
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
)
self._ssl_channel_credentials = ssl_credentials
else:
host = host if ":" in host else host + ":443"

Expand Down

0 comments on commit 386e85e

Please sign in to comment.