From dcfbac267fbf66d189b0cc7e76f4712122a74b7b Mon Sep 17 00:00:00 2001 From: Peter Lamut Date: Wed, 9 Sep 2020 22:21:31 +0200 Subject: [PATCH] feat: add custom cell magic parser to handle complex `--params` values (#213) * chore: Move cell magic code into its own directory * Add custom argument parser for cell magic * Add AST node visitor * Use a custom parser for cell magic arguments * Improve cell magic parser test coverage * Generalize valid option values The parser should accept as wide a range of values as possible and let the code that deals with the semantics decide whether the values are good or not. * Fix recognizing --params option in state 3 The --params option spec must be followed by a non-alphanumeric character, otherwise it's a different option spec (e.g. --paramsX). * Fix typo in comment * Cover missing parser code path with a test * Preserve the cell magic context's import path The context still needs to be importable from the old path * Clarify lexer states * Replace re.scanner with finditer() * Fix typo in docstring * Simplify string literal in a single line Apparently black just places all implicitly concatenated string literals in a single line when short enough without replacing them with a single string literal. * Explain the visitors module. * Pass pos as a positional arg to finditer() This is necessary to retain Python 2 compatibility. * Resolve coverage complaint about a code path The tokens are designed in a way that the scanner *always* returns some match, even if just UNKNOWN or EOL. The "no matches" code path can thus never be taken, but the coverage check can't know that. 
--- docs/magics.rst | 2 +- google/cloud/bigquery/__init__.py | 2 +- google/cloud/bigquery/magics/__init__.py | 20 + .../magics/line_arg_parser/__init__.py | 34 ++ .../magics/line_arg_parser/exceptions.py | 25 + .../bigquery/magics/line_arg_parser/lexer.py | 268 ++++++++++ .../bigquery/magics/line_arg_parser/parser.py | 484 ++++++++++++++++++ .../magics/line_arg_parser/visitors.py | 159 ++++++ google/cloud/bigquery/{ => magics}/magics.py | 70 ++- tests/unit/line_arg_parser/__init__.py | 13 + tests/unit/line_arg_parser/test_lexer.py | 32 ++ tests/unit/line_arg_parser/test_parser.py | 204 ++++++++ tests/unit/line_arg_parser/test_visitors.py | 34 ++ tests/unit/test_magics.py | 362 +++++++++++-- 14 files changed, 1644 insertions(+), 65 deletions(-) create mode 100644 google/cloud/bigquery/magics/__init__.py create mode 100644 google/cloud/bigquery/magics/line_arg_parser/__init__.py create mode 100644 google/cloud/bigquery/magics/line_arg_parser/exceptions.py create mode 100644 google/cloud/bigquery/magics/line_arg_parser/lexer.py create mode 100644 google/cloud/bigquery/magics/line_arg_parser/parser.py create mode 100644 google/cloud/bigquery/magics/line_arg_parser/visitors.py rename google/cloud/bigquery/{ => magics}/magics.py (91%) create mode 100644 tests/unit/line_arg_parser/__init__.py create mode 100644 tests/unit/line_arg_parser/test_lexer.py create mode 100644 tests/unit/line_arg_parser/test_parser.py create mode 100644 tests/unit/line_arg_parser/test_visitors.py diff --git a/docs/magics.rst b/docs/magics.rst index 732c27af9..bcaad8fa3 100644 --- a/docs/magics.rst +++ b/docs/magics.rst @@ -1,5 +1,5 @@ IPython Magics for BigQuery =========================== -.. automodule:: google.cloud.bigquery.magics +.. 
automodule:: google.cloud.bigquery.magics.magics :members: diff --git a/google/cloud/bigquery/__init__.py b/google/cloud/bigquery/__init__.py index 63d71694c..89c5a3624 100644 --- a/google/cloud/bigquery/__init__.py +++ b/google/cloud/bigquery/__init__.py @@ -150,7 +150,7 @@ def load_ipython_extension(ipython): """Called by IPython when this module is loaded as an IPython extension.""" - from google.cloud.bigquery.magics import _cell_magic + from google.cloud.bigquery.magics.magics import _cell_magic ipython.register_magic_function( _cell_magic, magic_kind="cell", magic_name="bigquery" diff --git a/google/cloud/bigquery/magics/__init__.py b/google/cloud/bigquery/magics/__init__.py new file mode 100644 index 000000000..d228a35bb --- /dev/null +++ b/google/cloud/bigquery/magics/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from google.cloud.bigquery.magics.magics import context + + +# For backwards compatibility we need to make the context available in the path +# google.cloud.bigquery.magics.context +__all__ = ("context",) diff --git a/google/cloud/bigquery/magics/line_arg_parser/__init__.py b/google/cloud/bigquery/magics/line_arg_parser/__init__.py new file mode 100644 index 000000000..9471446c5 --- /dev/null +++ b/google/cloud/bigquery/magics/line_arg_parser/__init__.py @@ -0,0 +1,34 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from google.cloud.bigquery.magics.line_arg_parser.exceptions import ParseError +from google.cloud.bigquery.magics.line_arg_parser.exceptions import ( + DuplicateQueryParamsError, + QueryParamsParseError, +) +from google.cloud.bigquery.magics.line_arg_parser.lexer import Lexer +from google.cloud.bigquery.magics.line_arg_parser.lexer import TokenType +from google.cloud.bigquery.magics.line_arg_parser.parser import Parser +from google.cloud.bigquery.magics.line_arg_parser.visitors import QueryParamsExtractor + + +__all__ = ( + "DuplicateQueryParamsError", + "Lexer", + "Parser", + "ParseError", + "QueryParamsExtractor", + "QueryParamsParseError", + "TokenType", +) diff --git a/google/cloud/bigquery/magics/line_arg_parser/exceptions.py b/google/cloud/bigquery/magics/line_arg_parser/exceptions.py new file mode 100644 index 000000000..6b2081186 --- /dev/null +++ b/google/cloud/bigquery/magics/line_arg_parser/exceptions.py @@ -0,0 +1,25 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +class ParseError(Exception): + pass + + +class QueryParamsParseError(ParseError): + """Raised when --params option is syntactically incorrect.""" + + +class DuplicateQueryParamsError(ParseError): + pass diff --git a/google/cloud/bigquery/magics/line_arg_parser/lexer.py b/google/cloud/bigquery/magics/line_arg_parser/lexer.py new file mode 100644 index 000000000..17e1ffdae --- /dev/null +++ b/google/cloud/bigquery/magics/line_arg_parser/lexer.py @@ -0,0 +1,268 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import namedtuple +from collections import OrderedDict +import itertools +import re + +import enum + + +Token = namedtuple("Token", ("type_", "lexeme", "pos")) +StateTransition = namedtuple("StateTransition", ("new_state", "total_offset")) + +# Pattern matching is done with regexes, and the order in which the token patterns are +# defined is important. +# +# Suppose we had the following token definitions: +# * INT - a token matching integers, +# * FLOAT - a token matching floating point numbers, +# * DOT - a token matching a single literal dot character, i.e. "." +# +# The FLOAT token would have to be defined first, since we would want the input "1.23" +# to be tokenized as a single FLOAT token, and *not* three tokens (INT, DOT, INT). +# +# Sometimes, however, different tokens match too similar patterns, and it is not +# possible to define them in order that would avoid any ambiguity. 
One such case are +# the OPT_VAL and PY_NUMBER tokens, as both can match an integer literal, say "42". +# +# In order to avoid the dilemmas, the lexer implements a concept of STATES. States are +# used to split token definitions into subgroups, and in each lexer state only a single +# subgroup is used for tokenizing the input. Lexer states can therefore be though of as +# token namespaces. +# +# For example, while parsing the value of the "--params" option, we do not want to +# "recognize" it as a single OPT_VAL token, but instead want to parse it as a Python +# dictionary and verify its syntactial correctness. On the other hand, while parsing +# the value of an option other than "--params", we do not really care about its +# structure, and thus do not want to use any of the "Python tokens" for pattern matching. +# +# Since token definition order is important, an OrderedDict is needed with tightly +# controlled member definitions (i.e. passed as a sequence, and *not* via kwargs). +token_types = OrderedDict( + [ + ( + "state_parse_pos_args", + OrderedDict( + [ + ( + "GOTO_PARSE_NON_PARAMS_OPTIONS", + r"(?P(?=--))", # double dash - starting the options list + ), + ( + "DEST_VAR", + r"(?P[^\d\W]\w*)", # essentially a Python ID + ), + ] + ), + ), + ( + "state_parse_non_params_options", + OrderedDict( + [ + ( + "GOTO_PARSE_PARAMS_OPTION", + r"(?P(?=--params(?:\s|=|--|$)))", # the --params option + ), + ("OPTION_SPEC", r"(?P--\w+)"), + ("OPTION_EQ", r"(?P=)"), + ("OPT_VAL", r"(?P\S+?(?=\s|--|$))"), + ] + ), + ), + ( + "state_parse_params_option", + OrderedDict( + [ + ( + "PY_STRING", + r"(?P(?:{})|(?:{}))".format( + r"'(?:[^'\\]|\.)*'", + r'"(?:[^"\\]|\.)*"', # single and double quoted strings + ), + ), + ("PARAMS_OPT_SPEC", r"(?P--params(?=\s|=|--|$))"), + ("PARAMS_OPT_EQ", r"(?P=)"), + ( + "GOTO_PARSE_NON_PARAMS_OPTIONS", + r"(?P(?=--\w+))", # found another option spec + ), + ("PY_BOOL", r"(?PTrue|False)"), + ("DOLLAR_PY_ID", r"(?P\$[^\d\W]\w*)"), + ( + "PY_NUMBER", + 
r"(?P-?[1-9]\d*(?:\.\d+)?(:?[e|E][+-]?\d+)?)", + ), + ("SQUOTE", r"(?P')"), + ("DQUOTE", r'(?P")'), + ("COLON", r"(?P:)"), + ("COMMA", r"(?P,)"), + ("LCURL", r"(?P\{)"), + ("RCURL", r"(?P})"), + ("LSQUARE", r"(?P\[)"), + ("RSQUARE", r"(?P])"), + ("LPAREN", r"(?P\()"), + ("RPAREN", r"(?P\))"), + ] + ), + ), + ( + "common", + OrderedDict( + [ + ("WS", r"(?P\s+)"), + ("EOL", r"(?P$)"), + ( + # anything not a whitespace or matched by something else + "UNKNOWN", + r"(?P\S+)", + ), + ] + ), + ), + ] +) + + +# The _generate_next_value_() enum hook is only available in Python 3.6+, thus we +# need to do some acrobatics to implement an "auto str enum" base class. Implementation +# based on the recipe provided by the very author of the Enum library: +# https://stackoverflow.com/a/32313954/5040035 +class StrEnumMeta(enum.EnumMeta): + @classmethod + def __prepare__(metacls, name, bases, **kwargs): + # Having deterministic enum members definition order is nice. + return OrderedDict() + + def __new__(metacls, name, bases, oldclassdict): + # Scan through the declared enum members and convert any value that is a plain + # empty tuple into a `str` of the name instead. + newclassdict = enum._EnumDict() + for key, val in oldclassdict.items(): + if val == (): + val = key + newclassdict[key] = val + return super(StrEnumMeta, metacls).__new__(metacls, name, bases, newclassdict) + + +# The @six.add_metaclass decorator does not work, Enum complains about _sunder_ names, +# and we cannot use class syntax directly, because the Python 3 version would cause +# a syntax error under Python 2. 
+AutoStrEnum = StrEnumMeta( + "AutoStrEnum", + (str, enum.Enum), + {"__doc__": "Base enum class for for name=value str enums."}, +) + +TokenType = AutoStrEnum( + "TokenType", + [ + (name, name) + for name in itertools.chain.from_iterable(token_types.values()) + if not name.startswith("GOTO_") + ], +) + + +class LexerState(AutoStrEnum): + PARSE_POS_ARGS = () # parsing positional arguments + PARSE_NON_PARAMS_OPTIONS = () # parsing options other than "--params" + PARSE_PARAMS_OPTION = () # parsing the "--params" option + STATE_END = () + + +class Lexer(object): + """Lexical analyzer for tokenizing the cell magic input line.""" + + _GRAND_PATTERNS = { + LexerState.PARSE_POS_ARGS: re.compile( + "|".join( + itertools.chain( + token_types["state_parse_pos_args"].values(), + token_types["common"].values(), + ) + ) + ), + LexerState.PARSE_NON_PARAMS_OPTIONS: re.compile( + "|".join( + itertools.chain( + token_types["state_parse_non_params_options"].values(), + token_types["common"].values(), + ) + ) + ), + LexerState.PARSE_PARAMS_OPTION: re.compile( + "|".join( + itertools.chain( + token_types["state_parse_params_option"].values(), + token_types["common"].values(), + ) + ) + ), + } + + def __init__(self, input_text): + self._text = input_text + + def __iter__(self): + # Since re.scanner does not seem to support manipulating inner scanner states, + # we need to implement lexer state transitions manually using special + # non-capturing lookahead token patterns to signal when a state transition + # should be made. + # Since we don't have "nested" states, we don't really need a stack and + # this simple mechanism is sufficient. 
+ state = LexerState.PARSE_POS_ARGS + offset = 0 # the number of characters processed so far + + while state != LexerState.STATE_END: + token_stream = self._find_state_tokens(state, offset) + + for maybe_token in token_stream: # pragma: NO COVER + if isinstance(maybe_token, StateTransition): + state = maybe_token.new_state + offset = maybe_token.total_offset + break + + if maybe_token.type_ != TokenType.WS: + yield maybe_token + + if maybe_token.type_ == TokenType.EOL: + state = LexerState.STATE_END + break + + def _find_state_tokens(self, state, current_offset): + """Scan the input for current state's tokens starting at ``current_offset``. + + Args: + state (LexerState): The current lexer state. + current_offset (int): The offset in the input text, i.e. the number + of characters already scanned so far. + + Yields: + The next ``Token`` or ``StateTransition`` instance. + """ + pattern = self._GRAND_PATTERNS[state] + scanner = pattern.finditer(self._text, current_offset) + + for match in scanner: # pragma: NO COVER + token_type = match.lastgroup + + if token_type.startswith("GOTO_"): + yield StateTransition( + new_state=getattr(LexerState, token_type[5:]), # w/o "GOTO_" prefix + total_offset=match.start(), + ) + + yield Token(token_type, match.group(), match.start()) diff --git a/google/cloud/bigquery/magics/line_arg_parser/parser.py b/google/cloud/bigquery/magics/line_arg_parser/parser.py new file mode 100644 index 000000000..b9da20cd7 --- /dev/null +++ b/google/cloud/bigquery/magics/line_arg_parser/parser.py @@ -0,0 +1,484 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.cloud.bigquery.magics.line_arg_parser import DuplicateQueryParamsError +from google.cloud.bigquery.magics.line_arg_parser import ParseError +from google.cloud.bigquery.magics.line_arg_parser import QueryParamsParseError +from google.cloud.bigquery.magics.line_arg_parser import TokenType + + +class ParseNode(object): + """A base class for nodes in the input parsed to an abstract syntax tree.""" + + +class InputLine(ParseNode): + def __init__(self, destination_var, option_list): + self.destination_var = destination_var + self.option_list = option_list + + +class DestinationVar(ParseNode): + def __init__(self, token): + # token type is DEST_VAR + self.token = token + self.name = token.lexeme if token is not None else None + + +class CmdOptionList(ParseNode): + def __init__(self, option_nodes): + self.options = [node for node in option_nodes] # shallow copy + + +class CmdOption(ParseNode): + def __init__(self, name, value): + self.name = name # string + self.value = value # CmdOptionValue node + + +class ParamsOption(CmdOption): + def __init__(self, value): + super(ParamsOption, self).__init__("params", value) + + +class CmdOptionValue(ParseNode): + def __init__(self, token): + # token type is OPT_VAL + self.token = token + self.value = token.lexeme + + +class PyVarExpansion(ParseNode): + def __init__(self, token): + self.token = token + self.raw_value = token.lexeme + + +class PyDict(ParseNode): + def __init__(self, dict_items): + self.items = [item for item in dict_items] # shallow copy + + +class PyDictItem(ParseNode): + def __init__(self, 
key, value): + self.key = key + self.value = value + + +class PyDictKey(ParseNode): + def __init__(self, token): + self.token = token + self.key_value = token.lexeme + + +class PyScalarValue(ParseNode): + def __init__(self, token, raw_value): + self.token = token + self.raw_value = raw_value + + +class PyTuple(ParseNode): + def __init__(self, tuple_items): + self.items = [item for item in tuple_items] # shallow copy + + +class PyList(ParseNode): + def __init__(self, list_items): + self.items = [item for item in list_items] # shallow copy + + +class Parser(object): + """Parser for the tokenized cell magic input line. + + The parser recognizes a simplified subset of Python grammar, specifically + a dictionary representation in typical use cases when the "--params" option + is used with the %%bigquery cell magic. + + The grammar (terminal symbols are CAPITALIZED): + + input_line : destination_var option_list + destination_var : DEST_VAR | EMPTY + option_list : (OPTION_SPEC [OPTION_EQ] option_value)* + (params_option | EMPTY) + (OPTION_SPEC [OPTION_EQ] option_value)* + + option_value : OPT_VAL | EMPTY + + # DOLLAR_PY_ID can occur if a variable passed to --params does not exist + # and is thus not expanded to a dict. + params_option : PARAMS_OPT_SPEC [PARAMS_OPT_EQ] \ + (DOLLAR_PY_ID | PY_STRING | py_dict) + + py_dict : LCURL dict_items RCURL + dict_items : dict_item | (dict_item COMMA dict_items) + dict_item : (dict_key COLON py_value) | EMPTY + + # dict items are actually @parameter names in the cell body (i.e. the query), + # thus restricting them to strings. 
+ dict_key : PY_STRING + + py_value : PY_BOOL + | PY_NUMBER + | PY_STRING + | py_tuple + | py_list + | py_dict + + py_tuple : LPAREN collection_items RPAREN + py_list : LSQUARE collection_items RSQUARE + collection_items : collection_item | (collection_item COMMA collection_items) + collection_item : py_value | EMPTY + + Args: + lexer (line_arg_parser.lexer.Lexer): + An iterable producing a tokenized cell magic argument line. + """ + + def __init__(self, lexer): + self._lexer = lexer + self._tokens_iter = iter(self._lexer) + self.get_next_token() + + def get_next_token(self): + """Obtain the next token from the token stream and store it as current.""" + token = next(self._tokens_iter) + self._current_token = token + + def consume(self, expected_type, exc_type=ParseError): + """Move to the next token in token stream if it matches the expected type. + + Args: + expected_type (lexer.TokenType): The expected token type to be consumed. + exc_type (Optional[ParseError]): The type of the exception to raise. Should be + the ``ParseError`` class or one of its subclasses. Defaults to + ``ParseError``. + + Raises: + ParseError: If the current token does not match the expected type. + """ + if self._current_token.type_ == expected_type: + if expected_type != TokenType.EOL: + self.get_next_token() + else: + if self._current_token.type_ == TokenType.EOL: + msg = "Unexpected end of input, expected {}.".format(expected_type) + else: + msg = "Expected token type {}, but found {} at position {}.".format( + expected_type, self._current_token.lexeme, self._current_token.pos + ) + self.error(message=msg, exc_type=exc_type) + + def error(self, message="Syntax error.", exc_type=ParseError): + """Raise an error with the given message. + + Args: + expected_type (lexer.TokenType): The expected token type to be consumed. + exc_type (Optional[ParseError]): The type of the exception to raise. Should be + the ``ParseError`` class or one of its subclasses. Defaults to + ``ParseError``. 
+ + Raises: + ParseError: If the current token does not match the expected type. + """ + raise exc_type(message) + + def input_line(self): + """The top level method for parsing the cell magic arguments line. + + Implements the following grammar production rule: + + input_line : destination_var option_list + """ + dest_var = self.destination_var() + options = self.option_list() + + token = self._current_token + + if token.type_ != TokenType.EOL: + msg = "Unexpected input at position {}: {}".format(token.pos, token.lexeme) + self.error(msg) + + return InputLine(dest_var, options) + + def destination_var(self): + """Implementation of the ``destination_var`` grammar production rule. + + Production: + + destination_var : DEST_VAR | EMPTY + """ + token = self._current_token + + if token.type_ == TokenType.DEST_VAR: + self.consume(TokenType.DEST_VAR) + result = DestinationVar(token) + elif token.type_ == TokenType.UNKNOWN: + msg = "Unknown input at position {}: {}".format(token.pos, token.lexeme) + self.error(msg) + else: + result = DestinationVar(None) + + return result + + def option_list(self): + """Implementation of the ``option_list`` grammar production rule. 
+ + Production: + + option_list : (OPTION_SPEC [OPTION_EQ] option_value)* + (params_option | EMPTY) + (OPTION_SPEC [OPTION_EQ] option_value)* + """ + all_options = [] + + def parse_nonparams_options(): + while self._current_token.type_ == TokenType.OPTION_SPEC: + token = self._current_token + self.consume(TokenType.OPTION_SPEC) + + opt_name = token.lexeme[2:] # cut off the "--" prefix + + # skip the optional "=" character + if self._current_token.type_ == TokenType.OPTION_EQ: + self.consume(TokenType.OPTION_EQ) + + opt_value = self.option_value() + option = CmdOption(opt_name, opt_value) + all_options.append(option) + + parse_nonparams_options() + + token = self._current_token + + if token.type_ == TokenType.PARAMS_OPT_SPEC: + option = self.params_option() + all_options.append(option) + + parse_nonparams_options() + + if self._current_token.type_ == TokenType.PARAMS_OPT_SPEC: + self.error( + message="Duplicate --params option", exc_type=DuplicateQueryParamsError + ) + + return CmdOptionList(all_options) + + def option_value(self): + """Implementation of the ``option_value`` grammar production rule. + + Production: + + option_value : OPT_VAL | EMPTY + """ + token = self._current_token + + if token.type_ == TokenType.OPT_VAL: + self.consume(TokenType.OPT_VAL) + result = CmdOptionValue(token) + elif token.type_ == TokenType.UNKNOWN: + msg = "Unknown input at position {}: {}".format(token.pos, token.lexeme) + self.error(msg) + else: + result = None + + return result + + def params_option(self): + """Implementation of the ``params_option`` grammar production rule. 
+ + Production: + + params_option : PARAMS_OPT_SPEC [PARAMS_OPT_EQ] \ + (DOLLAR_PY_ID | PY_STRING | py_dict) + """ + self.consume(TokenType.PARAMS_OPT_SPEC) + + # skip the optional "=" character + if self._current_token.type_ == TokenType.PARAMS_OPT_EQ: + self.consume(TokenType.PARAMS_OPT_EQ) + + if self._current_token.type_ == TokenType.DOLLAR_PY_ID: + token = self._current_token + self.consume(TokenType.DOLLAR_PY_ID) + opt_value = PyVarExpansion(token) + elif self._current_token.type_ == TokenType.PY_STRING: + token = self._current_token + self.consume(TokenType.PY_STRING, exc_type=QueryParamsParseError) + opt_value = PyScalarValue(token, token.lexeme) + else: + opt_value = self.py_dict() + + result = ParamsOption(opt_value) + + return result + + def py_dict(self): + """Implementation of the ``py_dict`` grammar production rule. + + Production: + + py_dict : LCURL dict_items RCURL + """ + self.consume(TokenType.LCURL, exc_type=QueryParamsParseError) + dict_items = self.dict_items() + self.consume(TokenType.RCURL, exc_type=QueryParamsParseError) + + return PyDict(dict_items) + + def dict_items(self): + """Implementation of the ``dict_items`` grammar production rule. + + Production: + + dict_items : dict_item | (dict_item COMMA dict_items) + """ + result = [] + + item = self.dict_item() + if item is not None: + result.append(item) + + while self._current_token.type_ == TokenType.COMMA: + self.consume(TokenType.COMMA, exc_type=QueryParamsParseError) + item = self.dict_item() + if item is not None: + result.append(item) + + return result + + def dict_item(self): + """Implementation of the ``dict_item`` grammar production rule. 
+ + Production: + + dict_item : (dict_key COLON py_value) | EMPTY + """ + token = self._current_token + + if token.type_ == TokenType.PY_STRING: + key = self.dict_key() + self.consume(TokenType.COLON, exc_type=QueryParamsParseError) + value = self.py_value() + result = PyDictItem(key, value) + elif token.type_ == TokenType.UNKNOWN: + msg = "Unknown input at position {}: {}".format(token.pos, token.lexeme) + self.error(msg, exc_type=QueryParamsParseError) + else: + result = None + + return result + + def dict_key(self): + """Implementation of the ``dict_key`` grammar production rule. + + Production: + + dict_key : PY_STRING + """ + token = self._current_token + self.consume(TokenType.PY_STRING, exc_type=QueryParamsParseError) + return PyDictKey(token) + + def py_value(self): + """Implementation of the ``py_value`` grammar production rule. + + Production: + + py_value : PY_BOOL | PY_NUMBER | PY_STRING | py_tuple | py_list | py_dict + """ + token = self._current_token + + if token.type_ == TokenType.PY_BOOL: + self.consume(TokenType.PY_BOOL, exc_type=QueryParamsParseError) + return PyScalarValue(token, token.lexeme) + elif token.type_ == TokenType.PY_NUMBER: + self.consume(TokenType.PY_NUMBER, exc_type=QueryParamsParseError) + return PyScalarValue(token, token.lexeme) + elif token.type_ == TokenType.PY_STRING: + self.consume(TokenType.PY_STRING, exc_type=QueryParamsParseError) + return PyScalarValue(token, token.lexeme) + elif token.type_ == TokenType.LPAREN: + tuple_node = self.py_tuple() + return tuple_node + elif token.type_ == TokenType.LSQUARE: + list_node = self.py_list() + return list_node + elif token.type_ == TokenType.LCURL: + dict_node = self.py_dict() + return dict_node + else: + msg = "Unexpected token type {} at position {}.".format( + token.type_, token.pos + ) + self.error(msg, exc_type=QueryParamsParseError) + + def py_tuple(self): + """Implementation of the ``py_tuple`` grammar production rule. 
+ + Production: + + py_tuple : LPAREN collection_items RPAREN + """ + self.consume(TokenType.LPAREN, exc_type=QueryParamsParseError) + items = self.collection_items() + self.consume(TokenType.RPAREN, exc_type=QueryParamsParseError) + + return PyTuple(items) + + def py_list(self): + """Implementation of the ``py_list`` grammar production rule. + + Production: + + py_list : LSQUARE collection_items RSQUARE + """ + self.consume(TokenType.LSQUARE, exc_type=QueryParamsParseError) + items = self.collection_items() + self.consume(TokenType.RSQUARE, exc_type=QueryParamsParseError) + + return PyList(items) + + def collection_items(self): + """Implementation of the ``collection_items`` grammar production rule. + + Production: + + collection_items : collection_item | (collection_item COMMA collection_items) + """ + result = [] + + item = self.collection_item() + if item is not None: + result.append(item) + + while self._current_token.type_ == TokenType.COMMA: + self.consume(TokenType.COMMA, exc_type=QueryParamsParseError) + item = self.collection_item() + if item is not None: + result.append(item) + + return result + + def collection_item(self): + """Implementation of the ``collection_item`` grammar production rule. + + Production: + + collection_item : py_value | EMPTY + """ + if self._current_token.type_ not in {TokenType.RPAREN, TokenType.RSQUARE}: + result = self.py_value() + else: + result = None # end of list/tuple items + + return result diff --git a/google/cloud/bigquery/magics/line_arg_parser/visitors.py b/google/cloud/bigquery/magics/line_arg_parser/visitors.py new file mode 100644 index 000000000..cbe236c06 --- /dev/null +++ b/google/cloud/bigquery/magics/line_arg_parser/visitors.py @@ -0,0 +1,159 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains classes that traverse AST and convert it to something else. + +If the parser successfully accepts a valid input (the bigquery cell magic arguments), +the result is an Abstract Syntax Tree (AST) that represents the input as a tree +with notes containing various useful metadata. + +Node visitors can process such tree and convert it to something else that can +be used for further processing, for example: + + * An optimized version of the tree with redundancy removed/simplified (not used here). + * The same tree, but with semantic errors checked, because an otherwise syntactically + valid input might still contain errors (not used here, semantic errors are detected + elsewhere). + * A form that can be directly handed to the code that operates on the input. The + ``QueryParamsExtractor`` class, for instance, splits the input arguments into + the "--params <...>" part and everything else. + The "everything else" part can be then parsed by the default Jupyter argument parser, + while the --params option is processed separately by the Python evaluator. 
+ +More info on the visitor design pattern: +https://en.wikipedia.org/wiki/Visitor_pattern + +""" + +from __future__ import print_function + + +class NodeVisitor(object): + """Base visitor class implementing the dispatch machinery.""" + + def visit(self, node): + method_name = "visit_{}".format(type(node).__name__) + visitor_method = getattr(self, method_name, self.method_missing) + return visitor_method(node) + + def method_missing(self, node): + raise Exception("No visit_{} method".format(type(node).__name__)) + + +class QueryParamsExtractor(NodeVisitor): + """A visitor that extracts the "--params <...>" part from input line arguments.""" + + def visit_InputLine(self, node): + params_dict_parts = [] + other_parts = [] + + dest_var_parts = self.visit(node.destination_var) + params, other_options = self.visit(node.option_list) + + if dest_var_parts: + other_parts.extend(dest_var_parts) + + if dest_var_parts and other_options: + other_parts.append(" ") + other_parts.extend(other_options) + + params_dict_parts.extend(params) + + return "".join(params_dict_parts), "".join(other_parts) + + def visit_DestinationVar(self, node): + return [node.name] if node.name is not None else [] + + def visit_CmdOptionList(self, node): + params_opt_parts = [] + other_parts = [] + + for i, opt in enumerate(node.options): + option_parts = self.visit(opt) + list_to_extend = params_opt_parts if opt.name == "params" else other_parts + + if list_to_extend: + list_to_extend.append(" ") + list_to_extend.extend(option_parts) + + return params_opt_parts, other_parts + + def visit_CmdOption(self, node): + result = ["--{}".format(node.name)] + + if node.value is not None: + result.append(" ") + value_parts = self.visit(node.value) + result.extend(value_parts) + + return result + + def visit_CmdOptionValue(self, node): + return [node.value] + + def visit_ParamsOption(self, node): + value_parts = self.visit(node.value) + return value_parts + + def visit_PyVarExpansion(self, node): + return 
[node.raw_value] + + def visit_PyDict(self, node): + result = ["{"] + + for i, item in enumerate(node.items): + if i > 0: + result.append(", ") + item_parts = self.visit(item) + result.extend(item_parts) + + result.append("}") + return result + + def visit_PyDictItem(self, node): + result = self.visit(node.key) # key parts + result.append(": ") + value_parts = self.visit(node.value) + result.extend(value_parts) + return result + + def visit_PyDictKey(self, node): + return [node.key_value] + + def visit_PyScalarValue(self, node): + return [node.raw_value] + + def visit_PyTuple(self, node): + result = ["("] + + for i, item in enumerate(node.items): + if i > 0: + result.append(", ") + item_parts = self.visit(item) + result.extend(item_parts) + + result.append(")") + return result + + def visit_PyList(self, node): + result = ["["] + + for i, item in enumerate(node.items): + if i > 0: + result.append(", ") + item_parts = self.visit(item) + result.extend(item_parts) + + result.append("]") + return result diff --git a/google/cloud/bigquery/magics.py b/google/cloud/bigquery/magics/magics.py similarity index 91% rename from google/cloud/bigquery/magics.py rename to google/cloud/bigquery/magics/magics.py index 7128e32bf..4842c7680 100644 --- a/google/cloud/bigquery/magics.py +++ b/google/cloud/bigquery/magics/magics.py @@ -65,13 +65,6 @@ the variable name (ex. ``$my_dict_var``). See ``In[6]`` and ``In[7]`` in the Examples section below. - .. note:: - - Due to the way IPython argument parser works, negative numbers in - dictionaries are incorrectly "recognized" as additional arguments, - resulting in an error ("unrecognized arguments"). To get around this, - pass such dictionary as a JSON string variable. - * ```` (required, cell argument): SQL query to run. 
If the query does not contain any whitespace (aside from leading and trailing whitespace), it is assumed to represent a @@ -159,13 +152,15 @@ except ImportError: # pragma: NO COVER raise ImportError("This module can only be loaded in IPython.") +import six + from google.api_core import client_info from google.api_core.exceptions import NotFound import google.auth from google.cloud import bigquery import google.cloud.bigquery.dataset from google.cloud.bigquery.dbapi import _helpers -import six +from google.cloud.bigquery.magics import line_arg_parser as lap IPYTHON_USER_AGENT = "ipython-{}".format(IPython.__version__) @@ -473,7 +468,27 @@ def _cell_magic(line, query): Returns: pandas.DataFrame: the query results. """ - args = magic_arguments.parse_argstring(_cell_magic, line) + # The built-in parser does not recognize Python structures such as dicts, thus + # we extract the "--params" option and interpret it separately. + try: + params_option_value, rest_of_args = _split_args_line(line) + except lap.exceptions.QueryParamsParseError as exc: + rebranded_error = SyntaxError( + "--params is not a correctly formatted JSON string or a JSON " + "serializable dictionary" + ) + six.raise_from(rebranded_error, exc) + except lap.exceptions.DuplicateQueryParamsError as exc: + rebranded_error = ValueError("Duplicate --params option.") + six.raise_from(rebranded_error, exc) + except lap.exceptions.ParseError as exc: + rebranded_error = ValueError( + "Unrecognized input, are option values correct? 
" + "Error details: {}".format(exc.args[0]) + ) + six.raise_from(rebranded_error, exc) + + args = magic_arguments.parse_argstring(_cell_magic, rest_of_args) if args.use_bqstorage_api is not None: warnings.warn( @@ -484,16 +499,16 @@ def _cell_magic(line, query): use_bqstorage_api = not args.use_rest_api params = [] - if args.params is not None: - try: - params = _helpers.to_query_parameters( - ast.literal_eval("".join(args.params)) - ) - except Exception: - raise SyntaxError( - "--params is not a correctly formatted JSON string or a JSON " - "serializable dictionary" + if params_option_value: + # A non-existing params variable is not expanded and ends up in the input + # in its raw form, e.g. "$query_params". + if params_option_value.startswith("$"): + msg = 'Parameter expansion failed, undefined variable "{}".'.format( + params_option_value[1:] ) + raise NameError(msg) + + params = _helpers.to_query_parameters(ast.literal_eval(params_option_value)) project = args.project or context.project client = bigquery.Client( @@ -598,6 +613,25 @@ def _cell_magic(line, query): close_transports() +def _split_args_line(line): + """Split out the --params option value from the input line arguments. + + Args: + line (str): The line arguments passed to the cell magic. 
+ + Returns: + Tuple[str, str] + """ + lexer = lap.Lexer(line) + scanner = lap.Parser(lexer) + tree = scanner.input_line() + + extractor = lap.QueryParamsExtractor() + params_option_value, rest_of_args = extractor.visit(tree) + + return params_option_value, rest_of_args + + def _make_bqstorage_client(use_bqstorage_api, credentials): if not use_bqstorage_api: return None diff --git a/tests/unit/line_arg_parser/__init__.py b/tests/unit/line_arg_parser/__init__.py new file mode 100644 index 000000000..c6334245a --- /dev/null +++ b/tests/unit/line_arg_parser/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/line_arg_parser/test_lexer.py b/tests/unit/line_arg_parser/test_lexer.py new file mode 100644 index 000000000..22fa96f22 --- /dev/null +++ b/tests/unit/line_arg_parser/test_lexer.py @@ -0,0 +1,32 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + + +@pytest.fixture(scope="session") +def lexer_class(): + from google.cloud.bigquery.magics.line_arg_parser.lexer import Lexer + + return Lexer + + +def test_empy_input(lexer_class): + from google.cloud.bigquery.magics.line_arg_parser import TokenType + from google.cloud.bigquery.magics.line_arg_parser.lexer import Token + + lexer = lexer_class("") + tokens = list(lexer) + + assert tokens == [Token(TokenType.EOL, lexeme="", pos=0)] diff --git a/tests/unit/line_arg_parser/test_parser.py b/tests/unit/line_arg_parser/test_parser.py new file mode 100644 index 000000000..3edff88e9 --- /dev/null +++ b/tests/unit/line_arg_parser/test_parser.py @@ -0,0 +1,204 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + + +@pytest.fixture(scope="session") +def parser_class(): + from google.cloud.bigquery.magics.line_arg_parser.parser import Parser + + return Parser + + +def test_consume_expected_eol(parser_class): + from google.cloud.bigquery.magics.line_arg_parser import TokenType + from google.cloud.bigquery.magics.line_arg_parser.lexer import Token + + # A simple iterable of Tokens is sufficient. 
+ fake_lexer = [Token(TokenType.EOL, lexeme="", pos=0)] + parser = parser_class(fake_lexer) + + parser.consume(TokenType.EOL) # no error + + +def test_consume_unexpected_eol(parser_class): + from google.cloud.bigquery.magics.line_arg_parser import ParseError + from google.cloud.bigquery.magics.line_arg_parser import TokenType + from google.cloud.bigquery.magics.line_arg_parser.lexer import Token + + # A simple iterable of Tokens is sufficient. + fake_lexer = [Token(TokenType.EOL, lexeme="", pos=0)] + parser = parser_class(fake_lexer) + + with pytest.raises(ParseError, match=r"Unexpected end of input.*expected COLON.*"): + parser.consume(TokenType.COLON) + + +def test_input_line_unexpected_input(parser_class): + from google.cloud.bigquery.magics.line_arg_parser import ParseError + from google.cloud.bigquery.magics.line_arg_parser import TokenType + from google.cloud.bigquery.magics.line_arg_parser.lexer import Token + + # A simple iterable of Tokens is sufficient. + fake_lexer = [ + Token(TokenType.DEST_VAR, lexeme="results", pos=0), + Token(TokenType.UNKNOWN, lexeme="boo!", pos=8), + Token(TokenType.EOL, lexeme="", pos=12), + ] + parser = parser_class(fake_lexer) + + with pytest.raises(ParseError, match=r"Unexpected input.*position 8.*boo!.*"): + parser.input_line() + + +def test_destination_var_unexpected_input(parser_class): + from google.cloud.bigquery.magics.line_arg_parser import ParseError + from google.cloud.bigquery.magics.line_arg_parser import TokenType + from google.cloud.bigquery.magics.line_arg_parser.lexer import Token + + # A simple iterable of Tokens is sufficient. 
+ fake_lexer = [ + Token(TokenType.UNKNOWN, lexeme="@!#", pos=2), + Token(TokenType.EOL, lexeme="", pos=5), + ] + parser = parser_class(fake_lexer) + + with pytest.raises(ParseError, match=r"Unknown.*position 2.*@!#.*"): + parser.destination_var() + + +def test_option_value_unexpected_input(parser_class): + from google.cloud.bigquery.magics.line_arg_parser import ParseError + from google.cloud.bigquery.magics.line_arg_parser import TokenType + from google.cloud.bigquery.magics.line_arg_parser.lexer import Token + + # A simple iterable of Tokens is sufficient. + fake_lexer = [ + Token(TokenType.UNKNOWN, lexeme="@!#", pos=8), + Token(TokenType.OPTION_SPEC, lexeme="--foo", pos=13), + ] + parser = parser_class(fake_lexer) + + with pytest.raises(ParseError, match=r"Unknown input.*position 8.*@!#.*"): + parser.option_value() + + +def test_dict_items_empty_dict(parser_class): + from google.cloud.bigquery.magics.line_arg_parser import TokenType + from google.cloud.bigquery.magics.line_arg_parser.lexer import Token + + # A simple iterable of Tokens is sufficient. + fake_lexer = [Token(TokenType.RCURL, lexeme="}", pos=22)] + parser = parser_class(fake_lexer) + + result = parser.dict_items() + + assert result == [] + + +def test_dict_items_trailing_comma(parser_class): + from google.cloud.bigquery.magics.line_arg_parser import TokenType + from google.cloud.bigquery.magics.line_arg_parser.lexer import Token + + # A simple iterable of Tokens is sufficient. 
+ fake_lexer = [ + Token(TokenType.PY_STRING, lexeme="'age'", pos=10), + Token(TokenType.COLON, lexeme=":", pos=17), + Token(TokenType.PY_NUMBER, lexeme="18", pos=19), + Token(TokenType.COMMA, lexeme=",", pos=21), + Token(TokenType.RCURL, lexeme="}", pos=22), + ] + parser = parser_class(fake_lexer) + + result = parser.dict_items() + + assert len(result) == 1 + dict_item = result[0] + assert dict_item.key.key_value == "'age'" + assert dict_item.value.raw_value == "18" + + +def test_dict_item_unknown_input(parser_class): + from google.cloud.bigquery.magics.line_arg_parser import ParseError + from google.cloud.bigquery.magics.line_arg_parser import TokenType + from google.cloud.bigquery.magics.line_arg_parser.lexer import Token + + # A simple iterable of Tokens is sufficient. + fake_lexer = [Token(TokenType.UNKNOWN, lexeme="#/%", pos=35)] + parser = parser_class(fake_lexer) + + with pytest.raises(ParseError, match=r"Unknown.*position 35.*#/%.*"): + parser.dict_item() + + +def test_pyvalue_list_containing_dict(parser_class): + from google.cloud.bigquery.magics.line_arg_parser import TokenType + from google.cloud.bigquery.magics.line_arg_parser.lexer import Token + from google.cloud.bigquery.magics.line_arg_parser.parser import PyDict + from google.cloud.bigquery.magics.line_arg_parser.parser import PyList + + # A simple iterable of Tokens is sufficient. 
+ fake_lexer = [ + Token(TokenType.LSQUARE, lexeme="[", pos=21), + Token(TokenType.LCURL, lexeme="{", pos=22), + Token(TokenType.PY_STRING, lexeme="'age'", pos=23), + Token(TokenType.COLON, lexeme=":", pos=28), + Token(TokenType.PY_NUMBER, lexeme="18", pos=30), + Token(TokenType.RCURL, lexeme="}", pos=32), + Token(TokenType.COMMA, lexeme=",", pos=33), # trailing comma + Token(TokenType.RSQUARE, lexeme="]", pos=34), + Token(TokenType.EOL, lexeme="", pos=40), + ] + parser = parser_class(fake_lexer) + + result = parser.py_value() + + assert isinstance(result, PyList) + assert len(result.items) == 1 + + element = result.items[0] + assert isinstance(element, PyDict) + assert len(element.items) == 1 + + dict_item = element.items[0] + assert dict_item.key.key_value == "'age'" + assert dict_item.value.raw_value == "18" + + +def test_pyvalue_invalid_token(parser_class): + from google.cloud.bigquery.magics.line_arg_parser import ParseError + from google.cloud.bigquery.magics.line_arg_parser import TokenType + from google.cloud.bigquery.magics.line_arg_parser.lexer import Token + + # A simple iterable of Tokens is sufficient. + fake_lexer = [Token(TokenType.OPTION_SPEC, lexeme="--verbose", pos=75)] + parser = parser_class(fake_lexer) + + error_pattern = r"Unexpected token.*OPTION_SPEC.*position 75.*" + with pytest.raises(ParseError, match=error_pattern): + parser.py_value() + + +def test_collection_items_empty(parser_class): + from google.cloud.bigquery.magics.line_arg_parser import TokenType + from google.cloud.bigquery.magics.line_arg_parser.lexer import Token + + # A simple iterable of Tokens is sufficient. 
+ fake_lexer = [Token(TokenType.RPAREN, lexeme=")", pos=30)] + parser = parser_class(fake_lexer) + + result = parser.collection_items() + + assert result == [] diff --git a/tests/unit/line_arg_parser/test_visitors.py b/tests/unit/line_arg_parser/test_visitors.py new file mode 100644 index 000000000..51d4f837a --- /dev/null +++ b/tests/unit/line_arg_parser/test_visitors.py @@ -0,0 +1,34 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + + +@pytest.fixture +def base_visitor(): + from google.cloud.bigquery.magics.line_arg_parser.visitors import NodeVisitor + + return NodeVisitor() + + +def test_unknown_node(base_visitor): + from google.cloud.bigquery.magics.line_arg_parser.parser import ParseNode + + class UnknownNode(ParseNode): + pass + + node = UnknownNode() + + with pytest.raises(Exception, match=r"No visit_UnknownNode method"): + base_visitor.visit(node) diff --git a/tests/unit/test_magics.py b/tests/unit/test_magics.py index 7b07626ad..73e44f311 100644 --- a/tests/unit/test_magics.py +++ b/tests/unit/test_magics.py @@ -43,7 +43,7 @@ from google.cloud import bigquery from google.cloud.bigquery import job from google.cloud.bigquery import table -from google.cloud.bigquery import magics +from google.cloud.bigquery.magics import magics from tests.unit.helpers import make_connection from test_utils.imports import maybe_fail_import @@ -69,6 +69,21 @@ def ipython_interactive(request, ipython): yield ipython +@pytest.fixture() 
+def ipython_ns_cleanup(): + """A helper to clean up user namespace after the test + + for the duration of the test scope. + """ + names_to_clean = [] # pairs (IPython_instance, name_to_clean) + + yield names_to_clean + + for ip, name in names_to_clean: + if name in ip.user_ns: + del ip.user_ns[name] + + @pytest.fixture(scope="session") def missing_bq_storage(): """Provide a patcher that can make the bigquery storage import to fail.""" @@ -256,7 +271,7 @@ def test__run_query(): ] client_patch = mock.patch( - "google.cloud.bigquery.magics.bigquery.Client", autospec=True + "google.cloud.bigquery.magics.magics.bigquery.Client", autospec=True ) with client_patch as client_mock, io.capture_output() as captured: client_mock().query(sql).result.side_effect = responses @@ -284,7 +299,7 @@ def test__run_query_dry_run_without_errors_is_silent(): sql = "SELECT 17" client_patch = mock.patch( - "google.cloud.bigquery.magics.bigquery.Client", autospec=True + "google.cloud.bigquery.magics.magics.bigquery.Client", autospec=True ) job_config = job.QueryJobConfig() @@ -350,7 +365,7 @@ def test__create_dataset_if_necessary_exists(): dataset_reference = bigquery.dataset.DatasetReference(project, dataset_id) dataset = bigquery.Dataset(dataset_reference) client_patch = mock.patch( - "google.cloud.bigquery.magics.bigquery.Client", autospec=True + "google.cloud.bigquery.magics.magics.bigquery.Client", autospec=True ) with client_patch as client_mock: client = client_mock() @@ -364,7 +379,7 @@ def test__create_dataset_if_necessary_not_exist(): project = "project_id" dataset_id = "dataset_id" client_patch = mock.patch( - "google.cloud.bigquery.magics.bigquery.Client", autospec=True + "google.cloud.bigquery.magics.magics.bigquery.Client", autospec=True ) with client_patch as client_mock: client = client_mock() @@ -382,7 +397,7 @@ def test_extension_load(): # verify that the magic is registered and has the correct source magic = ip.magics_manager.magics["cell"].get("bigquery") - assert 
magic.__module__ == "google.cloud.bigquery.magics" + assert magic.__module__ == "google.cloud.bigquery.magics.magics" @pytest.mark.usefixtures("ipython_interactive") @@ -415,7 +430,7 @@ def test_bigquery_magic_without_optional_arguments(monkeypatch): sql = "SELECT 17 AS num" result = pandas.DataFrame([17], columns=["num"]) run_query_patch = mock.patch( - "google.cloud.bigquery.magics._run_query", autospec=True + "google.cloud.bigquery.magics.magics._run_query", autospec=True ) query_job_mock = mock.create_autospec( google.cloud.bigquery.job.QueryJob, instance=True @@ -445,7 +460,7 @@ def test_bigquery_magic_default_connection_user_agent(): "google.auth.default", return_value=(credentials_mock, "general-project") ) run_query_patch = mock.patch( - "google.cloud.bigquery.magics._run_query", autospec=True + "google.cloud.bigquery.magics.magics._run_query", autospec=True ) conn_patch = mock.patch("google.cloud.bigquery.client.Connection", autospec=True) @@ -466,7 +481,7 @@ def test_bigquery_magic_with_legacy_sql(): ) run_query_patch = mock.patch( - "google.cloud.bigquery.magics._run_query", autospec=True + "google.cloud.bigquery.magics.magics._run_query", autospec=True ) with run_query_patch as run_query_mock: ip.run_cell_magic("bigquery", "--use_legacy_sql", "SELECT 17 AS num") @@ -477,19 +492,21 @@ def test_bigquery_magic_with_legacy_sql(): @pytest.mark.usefixtures("ipython_interactive") @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") -def test_bigquery_magic_with_result_saved_to_variable(): +def test_bigquery_magic_with_result_saved_to_variable(ipython_ns_cleanup): ip = IPython.get_ipython() ip.extension_manager.load_extension("google.cloud.bigquery") magics.context.credentials = mock.create_autospec( google.auth.credentials.Credentials, instance=True ) + ipython_ns_cleanup.append((ip, "df")) + sql = "SELECT 17 AS num" result = pandas.DataFrame([17], columns=["num"]) assert "df" not in ip.user_ns run_query_patch = mock.patch( - 
"google.cloud.bigquery.magics._run_query", autospec=True + "google.cloud.bigquery.magics.magics._run_query", autospec=True ) query_job_mock = mock.create_autospec( google.cloud.bigquery.job.QueryJob, instance=True @@ -516,10 +533,10 @@ def test_bigquery_magic_does_not_clear_display_in_verbose_mode(): ) clear_patch = mock.patch( - "google.cloud.bigquery.magics.display.clear_output", autospec=True + "google.cloud.bigquery.magics.magics.display.clear_output", autospec=True, ) run_query_patch = mock.patch( - "google.cloud.bigquery.magics._run_query", autospec=True + "google.cloud.bigquery.magics.magics._run_query", autospec=True ) with clear_patch as clear_mock, run_query_patch: ip.run_cell_magic("bigquery", "--verbose", "SELECT 17 as num") @@ -536,10 +553,10 @@ def test_bigquery_magic_clears_display_in_verbose_mode(): ) clear_patch = mock.patch( - "google.cloud.bigquery.magics.display.clear_output", autospec=True + "google.cloud.bigquery.magics.magics.display.clear_output", autospec=True, ) run_query_patch = mock.patch( - "google.cloud.bigquery.magics._run_query", autospec=True + "google.cloud.bigquery.magics.magics._run_query", autospec=True ) with clear_patch as clear_mock, run_query_patch: ip.run_cell_magic("bigquery", "", "SELECT 17 as num") @@ -576,7 +593,7 @@ def test_bigquery_magic_with_bqstorage_from_argument(monkeypatch): sql = "SELECT 17 AS num" result = pandas.DataFrame([17], columns=["num"]) run_query_patch = mock.patch( - "google.cloud.bigquery.magics._run_query", autospec=True + "google.cloud.bigquery.magics.magics._run_query", autospec=True ) query_job_mock = mock.create_autospec( google.cloud.bigquery.job.QueryJob, instance=True @@ -635,7 +652,7 @@ def test_bigquery_magic_with_rest_client_requested(monkeypatch): sql = "SELECT 17 AS num" result = pandas.DataFrame([17], columns=["num"]) run_query_patch = mock.patch( - "google.cloud.bigquery.magics._run_query", autospec=True + "google.cloud.bigquery.magics.magics._run_query", autospec=True ) 
query_job_mock = mock.create_autospec( google.cloud.bigquery.job.QueryJob, instance=True @@ -719,7 +736,7 @@ def test_bigquery_magic_w_max_results_query_job_results_fails(): "google.cloud.bigquery.client.Client.query", autospec=True ) close_transports_patch = mock.patch( - "google.cloud.bigquery.magics._close_transports", autospec=True, + "google.cloud.bigquery.magics.magics._close_transports", autospec=True, ) sql = "SELECT 17 AS num" @@ -751,7 +768,7 @@ def test_bigquery_magic_w_table_id_invalid(): ) list_rows_patch = mock.patch( - "google.cloud.bigquery.magics.bigquery.Client.list_rows", + "google.cloud.bigquery.magics.magics.bigquery.Client.list_rows", autospec=True, side_effect=exceptions.BadRequest("Not a valid table ID"), ) @@ -792,11 +809,13 @@ def test_bigquery_magic_w_missing_query(): @pytest.mark.usefixtures("ipython_interactive") @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") -def test_bigquery_magic_w_table_id_and_destination_var(): +def test_bigquery_magic_w_table_id_and_destination_var(ipython_ns_cleanup): ip = IPython.get_ipython() ip.extension_manager.load_extension("google.cloud.bigquery") magics.context._project = None + ipython_ns_cleanup.append((ip, "df")) + credentials_mock = mock.create_autospec( google.auth.credentials.Credentials, instance=True ) @@ -809,7 +828,7 @@ def test_bigquery_magic_w_table_id_and_destination_var(): ) client_patch = mock.patch( - "google.cloud.bigquery.magics.bigquery.Client", autospec=True + "google.cloud.bigquery.magics.magics.bigquery.Client", autospec=True ) table_id = "bigquery-public-data.samples.shakespeare" @@ -849,7 +868,7 @@ def test_bigquery_magic_w_table_id_and_bqstorage_client(): ) client_patch = mock.patch( - "google.cloud.bigquery.magics.bigquery.Client", autospec=True + "google.cloud.bigquery.magics.magics.bigquery.Client", autospec=True ) bqstorage_mock = mock.create_autospec(bigquery_storage_v1.BigQueryReadClient) @@ -882,7 +901,7 @@ def 
test_bigquery_magic_dryrun_option_sets_job_config(): ) run_query_patch = mock.patch( - "google.cloud.bigquery.magics._run_query", autospec=True + "google.cloud.bigquery.magics.magics._run_query", autospec=True ) sql = "SELECT 17 AS num" @@ -905,7 +924,7 @@ def test_bigquery_magic_dryrun_option_returns_query_job(): google.cloud.bigquery.job.QueryJob, instance=True ) run_query_patch = mock.patch( - "google.cloud.bigquery.magics._run_query", autospec=True + "google.cloud.bigquery.magics.magics._run_query", autospec=True ) sql = "SELECT 17 AS num" @@ -919,15 +938,17 @@ def test_bigquery_magic_dryrun_option_returns_query_job(): @pytest.mark.usefixtures("ipython_interactive") -def test_bigquery_magic_dryrun_option_variable_error_message(): +def test_bigquery_magic_dryrun_option_variable_error_message(ipython_ns_cleanup): ip = IPython.get_ipython() ip.extension_manager.load_extension("google.cloud.bigquery") magics.context.credentials = mock.create_autospec( google.auth.credentials.Credentials, instance=True ) + ipython_ns_cleanup.append((ip, "q_job")) + run_query_patch = mock.patch( - "google.cloud.bigquery.magics._run_query", + "google.cloud.bigquery.magics.magics._run_query", autospec=True, side_effect=exceptions.BadRequest("Syntax error in SQL query"), ) @@ -944,7 +965,7 @@ def test_bigquery_magic_dryrun_option_variable_error_message(): @pytest.mark.usefixtures("ipython_interactive") -def test_bigquery_magic_dryrun_option_saves_query_job_to_variable(): +def test_bigquery_magic_dryrun_option_saves_query_job_to_variable(ipython_ns_cleanup): ip = IPython.get_ipython() ip.extension_manager.load_extension("google.cloud.bigquery") magics.context.credentials = mock.create_autospec( @@ -954,9 +975,11 @@ def test_bigquery_magic_dryrun_option_saves_query_job_to_variable(): google.cloud.bigquery.job.QueryJob, instance=True ) run_query_patch = mock.patch( - "google.cloud.bigquery.magics._run_query", autospec=True + "google.cloud.bigquery.magics.magics._run_query", autospec=True ) 
+ ipython_ns_cleanup.append((ip, "q_job")) + sql = "SELECT 17 AS num" assert "q_job" not in ip.user_ns @@ -972,13 +995,15 @@ def test_bigquery_magic_dryrun_option_saves_query_job_to_variable(): @pytest.mark.usefixtures("ipython_interactive") -def test_bigquery_magic_saves_query_job_to_variable_on_error(): +def test_bigquery_magic_saves_query_job_to_variable_on_error(ipython_ns_cleanup): ip = IPython.get_ipython() ip.extension_manager.load_extension("google.cloud.bigquery") magics.context.credentials = mock.create_autospec( google.auth.credentials.Credentials, instance=True ) + ipython_ns_cleanup.append((ip, "result")) + client_query_patch = mock.patch( "google.cloud.bigquery.client.Client.query", autospec=True ) @@ -1151,7 +1176,7 @@ def test_bigquery_magic_with_project(): "google.auth.default", return_value=(credentials_mock, "general-project") ) run_query_patch = mock.patch( - "google.cloud.bigquery.magics._run_query", autospec=True + "google.cloud.bigquery.magics.magics._run_query", autospec=True ) with run_query_patch as run_query_mock, default_patch: ip.run_cell_magic("bigquery", "--project=specific-project", "SELECT 17 as num") @@ -1162,30 +1187,65 @@ def test_bigquery_magic_with_project(): assert magics.context.project == "general-project" +@pytest.mark.usefixtures("ipython_interactive") +def test_bigquery_magic_with_multiple_options(): + ip = IPython.get_ipython() + ip.extension_manager.load_extension("google.cloud.bigquery") + magics.context._project = None + + credentials_mock = mock.create_autospec( + google.auth.credentials.Credentials, instance=True + ) + default_patch = mock.patch( + "google.auth.default", return_value=(credentials_mock, "general-project") + ) + run_query_patch = mock.patch( + "google.cloud.bigquery.magics.magics._run_query", autospec=True + ) + with run_query_patch as run_query_mock, default_patch: + ip.run_cell_magic( + "bigquery", + "--project=specific-project --use_legacy_sql --maximum_bytes_billed 1024", + "SELECT 17 as num", + ) 
+ + args, kwargs = run_query_mock.call_args + client_used = args[0] + assert client_used.project == "specific-project" + + job_config_used = kwargs["job_config"] + assert job_config_used.use_legacy_sql + assert job_config_used.maximum_bytes_billed == 1024 + + @pytest.mark.usefixtures("ipython_interactive") @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") -def test_bigquery_magic_with_string_params(): +def test_bigquery_magic_with_string_params(ipython_ns_cleanup): ip = IPython.get_ipython() ip.extension_manager.load_extension("google.cloud.bigquery") magics.context.credentials = mock.create_autospec( google.auth.credentials.Credentials, instance=True ) + ipython_ns_cleanup.append((ip, "params_dict_df")) + sql = "SELECT @num AS num" result = pandas.DataFrame([17], columns=["num"]) - assert "params_string_df" not in ip.user_ns + + assert "params_dict_df" not in ip.user_ns run_query_patch = mock.patch( - "google.cloud.bigquery.magics._run_query", autospec=True + "google.cloud.bigquery.magics.magics._run_query", autospec=True ) query_job_mock = mock.create_autospec( google.cloud.bigquery.job.QueryJob, instance=True ) query_job_mock.to_dataframe.return_value = result + with run_query_patch as run_query_mock: run_query_mock.return_value = query_job_mock - ip.run_cell_magic("bigquery", 'params_string_df --params {"num":17}', sql) + ip.run_cell_magic("bigquery", "params_string_df --params='{\"num\":17}'", sql) run_query_mock.assert_called_once_with(mock.ANY, sql.format(num=17), mock.ANY) @@ -1197,19 +1257,24 @@ def test_bigquery_magic_with_string_params(): @pytest.mark.usefixtures("ipython_interactive") @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") -def test_bigquery_magic_with_dict_params(): +def test_bigquery_magic_with_dict_params(ipython_ns_cleanup): ip = IPython.get_ipython() ip.extension_manager.load_extension("google.cloud.bigquery") magics.context.credentials = mock.create_autospec( google.auth.credentials.Credentials, instance=True 
) - sql = "SELECT @num AS num" - result = pandas.DataFrame([17], columns=["num"]) + ipython_ns_cleanup.append((ip, "params_dict_df")) + + sql = "SELECT @num AS num, @tricky_value as tricky_value" + result = pandas.DataFrame( + [(False, '--params "value"')], columns=["valid", "tricky_value"] + ) + assert "params_dict_df" not in ip.user_ns run_query_patch = mock.patch( - "google.cloud.bigquery.magics._run_query", autospec=True + "google.cloud.bigquery.magics.magics._run_query", autospec=True ) query_job_mock = mock.create_autospec( google.cloud.bigquery.job.QueryJob, instance=True @@ -1218,7 +1283,7 @@ def test_bigquery_magic_with_dict_params(): with run_query_patch as run_query_mock: run_query_mock.return_value = query_job_mock - params = {"num": 17} + params = {"valid": False, "tricky_value": '--params "value"'} # Insert dictionary into user namespace so that it can be expanded ip.user_ns["params"] = params ip.run_cell_magic("bigquery", "params_dict_df --params $params", sql) @@ -1230,6 +1295,194 @@ def test_bigquery_magic_with_dict_params(): assert len(df) == len(result) # verify row count assert list(df) == list(result) # verify column names + assert not df["valid"][0] + assert df["tricky_value"][0] == '--params "value"' + + +@pytest.mark.usefixtures("ipython_interactive") +@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") +def test_bigquery_magic_with_dict_params_nonexisting(): + ip = IPython.get_ipython() + ip.extension_manager.load_extension("google.cloud.bigquery") + magics.context.credentials = mock.create_autospec( + google.auth.credentials.Credentials, instance=True + ) + + sql = "SELECT @foo AS foo" + + with pytest.raises(NameError, match=r".*undefined variable.*unknown_name.*"): + ip.run_cell_magic("bigquery", "params_dict_df --params $unknown_name", sql) + + +@pytest.mark.usefixtures("ipython_interactive") +@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") +def test_bigquery_magic_with_dict_params_incorrect_syntax(): + ip = 
IPython.get_ipython() + ip.extension_manager.load_extension("google.cloud.bigquery") + magics.context.credentials = mock.create_autospec( + google.auth.credentials.Credentials, instance=True + ) + + sql = "SELECT @foo AS foo" + + with pytest.raises(SyntaxError, match=r".*--params.*"): + cell_magic_args = "params_dict_df --params {'foo': 1; 'bar': 2}" + ip.run_cell_magic("bigquery", cell_magic_args, sql) + + +@pytest.mark.usefixtures("ipython_interactive") +@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") +def test_bigquery_magic_with_dict_params_duplicate(): + ip = IPython.get_ipython() + ip.extension_manager.load_extension("google.cloud.bigquery") + magics.context.credentials = mock.create_autospec( + google.auth.credentials.Credentials, instance=True + ) + + sql = "SELECT @foo AS foo" + + with pytest.raises(ValueError, match=r"Duplicate --params option\."): + cell_magic_args = ( + "params_dict_df --params {'foo': 1} --verbose --params {'bar': 2} " + ) + ip.run_cell_magic("bigquery", cell_magic_args, sql) + + +@pytest.mark.usefixtures("ipython_interactive") +@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") +def test_bigquery_magic_with_option_value_incorrect(): + ip = IPython.get_ipython() + ip.extension_manager.load_extension("google.cloud.bigquery") + magics.context.credentials = mock.create_autospec( + google.auth.credentials.Credentials, instance=True + ) + + sql = "SELECT @foo AS foo" + + with pytest.raises(ValueError, match=r".*invalid literal.*\[PLENTY!\].*"): + cell_magic_args = "params_dict_df --max_results [PLENTY!]" + ip.run_cell_magic("bigquery", cell_magic_args, sql) + + +@pytest.mark.usefixtures("ipython_interactive") +@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") +def test_bigquery_magic_with_dict_params_negative_value(ipython_ns_cleanup): + ip = IPython.get_ipython() + ip.extension_manager.load_extension("google.cloud.bigquery") + magics.context.credentials = mock.create_autospec( + 
google.auth.credentials.Credentials, instance=True + ) + + ipython_ns_cleanup.append((ip, "params_dict_df")) + + sql = "SELECT @num AS num" + result = pandas.DataFrame([-17], columns=["num"]) + + assert "params_dict_df" not in ip.user_ns + + run_query_patch = mock.patch( + "google.cloud.bigquery.magics.magics._run_query", autospec=True + ) + query_job_mock = mock.create_autospec( + google.cloud.bigquery.job.QueryJob, instance=True + ) + query_job_mock.to_dataframe.return_value = result + with run_query_patch as run_query_mock: + run_query_mock.return_value = query_job_mock + + params = {"num": -17} + # Insert dictionary into user namespace so that it can be expanded + ip.user_ns["params"] = params + ip.run_cell_magic("bigquery", "params_dict_df --params $params", sql) + + run_query_mock.assert_called_once_with(mock.ANY, sql.format(num=-17), mock.ANY) + + assert "params_dict_df" in ip.user_ns # verify that the variable exists + df = ip.user_ns["params_dict_df"] + assert len(df) == len(result) # verify row count + assert list(df) == list(result) # verify column names + assert df["num"][0] == -17 + + +@pytest.mark.usefixtures("ipython_interactive") +@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") +def test_bigquery_magic_with_dict_params_array_value(ipython_ns_cleanup): + ip = IPython.get_ipython() + ip.extension_manager.load_extension("google.cloud.bigquery") + magics.context.credentials = mock.create_autospec( + google.auth.credentials.Credentials, instance=True + ) + + ipython_ns_cleanup.append((ip, "params_dict_df")) + + sql = "SELECT @num AS num" + result = pandas.DataFrame(["foo bar", "baz quux"], columns=["array_data"]) + + assert "params_dict_df" not in ip.user_ns + + run_query_patch = mock.patch( + "google.cloud.bigquery.magics.magics._run_query", autospec=True + ) + query_job_mock = mock.create_autospec( + google.cloud.bigquery.job.QueryJob, instance=True + ) + query_job_mock.to_dataframe.return_value = result + with run_query_patch as 
run_query_mock: + run_query_mock.return_value = query_job_mock + + params = {"array_data": ["foo bar", "baz quux"]} + # Insert dictionary into user namespace so that it can be expanded + ip.user_ns["params"] = params + ip.run_cell_magic("bigquery", "params_dict_df --params $params", sql) + + run_query_mock.assert_called_once_with(mock.ANY, sql, mock.ANY) + + assert "params_dict_df" in ip.user_ns  # verify that the variable exists + df = ip.user_ns["params_dict_df"] + assert len(df) == len(result)  # verify row count + assert list(df) == list(result)  # verify column names + assert list(df["array_data"]) == ["foo bar", "baz quux"] + + +@pytest.mark.usefixtures("ipython_interactive") +@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") +def test_bigquery_magic_with_dict_params_tuple_value(ipython_ns_cleanup): + ip = IPython.get_ipython() + ip.extension_manager.load_extension("google.cloud.bigquery") + magics.context.credentials = mock.create_autospec( + google.auth.credentials.Credentials, instance=True + ) + + ipython_ns_cleanup.append((ip, "params_dict_df")) + + sql = "SELECT @num AS num" + result = pandas.DataFrame(["foo bar", "baz quux"], columns=["array_data"]) + + assert "params_dict_df" not in ip.user_ns + + run_query_patch = mock.patch( + "google.cloud.bigquery.magics.magics._run_query", autospec=True + ) + query_job_mock = mock.create_autospec( + google.cloud.bigquery.job.QueryJob, instance=True + ) + query_job_mock.to_dataframe.return_value = result + with run_query_patch as run_query_mock: + run_query_mock.return_value = query_job_mock + + params = {"array_data": ("foo bar", "baz quux")} + # Insert dictionary into user namespace so that it can be expanded + ip.user_ns["params"] = params + ip.run_cell_magic("bigquery", "params_dict_df --params $params", sql) + + run_query_mock.assert_called_once_with(mock.ANY, sql, mock.ANY) + + assert "params_dict_df" in ip.user_ns  # verify that the variable exists + df = 
ip.user_ns["params_dict_df"] + assert len(df) == len(result) # verify row count + assert list(df) == list(result) # verify column names + assert list(df["array_data"]) == ["foo bar", "baz quux"] + @pytest.mark.usefixtures("ipython_interactive") @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") @@ -1246,6 +1499,24 @@ def test_bigquery_magic_with_improperly_formatted_params(): ip.run_cell_magic("bigquery", "--params {17}", sql) +@pytest.mark.usefixtures("ipython_interactive") +@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") +def test_bigquery_magic_with_invalid_multiple_option_values(): + ip = IPython.get_ipython() + ip.extension_manager.load_extension("google.cloud.bigquery") + magics.context.credentials = mock.create_autospec( + google.auth.credentials.Credentials, instance=True + ) + + sql = "SELECT @foo AS foo" + + exc_pattern = r".*[Uu]nrecognized input.*option values correct\?.*567.*" + + with pytest.raises(ValueError, match=exc_pattern): + cell_magic_args = "params_dict_df --max_results 10 567" + ip.run_cell_magic("bigquery", cell_magic_args, sql) + + @pytest.mark.usefixtures("ipython_interactive") def test_bigquery_magic_omits_tracebacks_from_error_message(): ip = IPython.get_ipython() @@ -1259,7 +1530,7 @@ def test_bigquery_magic_omits_tracebacks_from_error_message(): ) run_query_patch = mock.patch( - "google.cloud.bigquery.magics._run_query", + "google.cloud.bigquery.magics.magics._run_query", autospec=True, side_effect=exceptions.BadRequest("Syntax error in SQL query"), ) @@ -1287,7 +1558,7 @@ def test_bigquery_magic_w_destination_table_invalid_format(): ) client_patch = mock.patch( - "google.cloud.bigquery.magics.bigquery.Client", autospec=True + "google.cloud.bigquery.magics.magics.bigquery.Client", autospec=True ) with client_patch, default_patch, pytest.raises(ValueError) as exc_context: @@ -1310,11 +1581,12 @@ def test_bigquery_magic_w_destination_table(): ) create_dataset_if_necessary_patch = mock.patch( - 
"google.cloud.bigquery.magics._create_dataset_if_necessary", autospec=True + "google.cloud.bigquery.magics.magics._create_dataset_if_necessary", + autospec=True, ) run_query_patch = mock.patch( - "google.cloud.bigquery.magics._run_query", autospec=True + "google.cloud.bigquery.magics.magics._run_query", autospec=True ) with create_dataset_if_necessary_patch, run_query_patch as run_query_mock: @@ -1341,12 +1613,12 @@ def test_bigquery_magic_create_dataset_fails(): ) create_dataset_if_necessary_patch = mock.patch( - "google.cloud.bigquery.magics._create_dataset_if_necessary", + "google.cloud.bigquery.magics.magics._create_dataset_if_necessary", autospec=True, side_effect=OSError, ) close_transports_patch = mock.patch( - "google.cloud.bigquery.magics._close_transports", autospec=True, + "google.cloud.bigquery.magics.magics._close_transports", autospec=True, ) with pytest.raises(