Skip to content

Commit

Permalink
Support SetMaxNamePath API (#117)
Browse files Browse the repository at this point in the history
  • Loading branch information
goccy committed Jun 18, 2023
1 parent 6a475dd commit 3664acb
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 85 deletions.
29 changes: 23 additions & 6 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,33 @@ func (c *ZetaSQLiteConn) SetExplainMode(enabled bool) {
c.analyzer.SetExplainMode(enabled)
}

func (c *ZetaSQLiteConn) NamePath() []string {
return c.analyzer.NamePath()
// SetMaxNamePath specifies the maximum value of name path.
// If the name path in the query is the maximum value, the name path set as prefix is not used.
// Effective only when a value greater than zero is specified ( default zero ).
func (c *ZetaSQLiteConn) SetMaxNamePath(num int) {
c.analyzer.SetMaxNamePath(num)
}

// MaxNamePath returns maximum value of name path.
func (c *ZetaSQLiteConn) MaxNamePath() int {
return c.analyzer.MaxNamePath()
}

func (c *ZetaSQLiteConn) SetNamePath(path []string) {
c.analyzer.SetNamePath(path)
// SetNamePath set path to name path to be set as prefix.
// If max name path is specified, an error is returned if the number is exceeded.
func (c *ZetaSQLiteConn) SetNamePath(path []string) error {
return c.analyzer.SetNamePath(path)
}

// NamePath returns path to name path to be set as prefix.
func (c *ZetaSQLiteConn) NamePath() []string {
return c.analyzer.NamePath()
}

func (c *ZetaSQLiteConn) AddNamePath(path string) {
c.analyzer.AddNamePath(path)
// AddNamePath add path to name path to be set as prefix.
// If max name path is specified, an error is returned if the number is exceeded.
func (c *ZetaSQLiteConn) AddNamePath(path string) error {
return c.analyzer.AddNamePath(path)
}

func (s *ZetaSQLiteConn) CheckNamedValue(value *driver.NamedValue) error {
Expand Down
3 changes: 1 addition & 2 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ CREATE VIEW IF NOT EXISTS SingerNames AS SELECT FirstName || ' ' || LastName AS
func TestRegisterCustomDriver(t *testing.T) {
sql.Register("zetasqlite-custom", &zetasqlite.ZetaSQLiteDriver{
ConnectHook: func(conn *zetasqlite.ZetaSQLiteConn) error {
conn.SetNamePath([]string{"project-id", "datasetID"})
return nil
return conn.SetNamePath([]string{"project-id", "datasetID"})
},
})
db, err := sql.Open("zetasqlite-custom", ":memory:")
Expand Down
31 changes: 21 additions & 10 deletions internal/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
)

type Analyzer struct {
namePath []string
namePath *NamePath
isAutoIndexMode bool
isExplainMode bool
catalog *Catalog
Expand All @@ -22,8 +22,9 @@ type Analyzer struct {

func NewAnalyzer(catalog *Catalog) *Analyzer {
return &Analyzer{
catalog: catalog,
opt: newAnalyzerOptions(),
catalog: catalog,
opt: newAnalyzerOptions(),
namePath: &NamePath{},
}
}

Expand Down Expand Up @@ -103,15 +104,25 @@ func (a *Analyzer) SetExplainMode(enabled bool) {
}

func (a *Analyzer) NamePath() []string {
return a.namePath
return a.namePath.path
}

func (a *Analyzer) SetNamePath(path []string) {
a.namePath = path
func (a *Analyzer) SetNamePath(path []string) error {
return a.namePath.setPath(path)
}

func (a *Analyzer) AddNamePath(path string) {
a.namePath = append(a.namePath, path)
func (a *Analyzer) SetMaxNamePath(num int) {
if num > 0 {
a.namePath.maxNum = num
}
}

func (a *Analyzer) MaxNamePath() int {
return a.namePath.maxNum
}

func (a *Analyzer) AddNamePath(path string) error {
return a.namePath.addPath(path)
}

func (a *Analyzer) parseScript(query string) ([]parsed_ast.StatementNode, error) {
Expand Down Expand Up @@ -430,7 +441,7 @@ func (a *Analyzer) newDropStmtAction(ctx context.Context, query string, args []d
return nil, err
}
objectType := node.ObjectType()
name := FormatName(MergeNamePath(a.namePath, node.NamePath()))
name := a.namePath.format(node.NamePath())
return &DropStmtAction{
name: name,
objectType: objectType,
Expand All @@ -448,7 +459,7 @@ func (a *Analyzer) newDropFunctionStmtAction(ctx context.Context, query string,
if err != nil {
return nil, err
}
name := FormatName(MergeNamePath(a.namePath, node.NamePath()))
name := a.namePath.format(node.NamePath())
return &DropStmtAction{
name: name,
objectType: "FUNCTION",
Expand Down
6 changes: 3 additions & 3 deletions internal/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,11 @@ func (c *Catalog) formatNamePath(path []string) string {
return strings.Join(path, "_")
}

func (c *Catalog) getFunctions(path []string) []*FunctionSpec {
if len(path) == 0 {
func (c *Catalog) getFunctions(namePath *NamePath) []*FunctionSpec {
if namePath.empty() {
return c.functions
}
key := c.formatNamePath(path)
key := c.formatNamePath(namePath.path)
specs := make([]*FunctionSpec, 0, len(c.functions))
for _, fn := range c.functions {
if len(fn.NamePath) == 1 {
Expand Down
6 changes: 3 additions & 3 deletions internal/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ func withAnalyzer(ctx context.Context, analyzer *Analyzer) context.Context {
return context.WithValue(ctx, analyzerKey{}, analyzer)
}

func namePathFromContext(ctx context.Context) []string {
func namePathFromContext(ctx context.Context) *NamePath {
value := ctx.Value(namePathKey{})
if value == nil {
return nil
}
return value.([]string)
return value.(*NamePath)
}

func withNamePath(ctx context.Context, namePath []string) context.Context {
func withNamePath(ctx context.Context, namePath *NamePath) context.Context {
return context.WithValue(ctx, namePathKey{}, namePath)
}

Expand Down
55 changes: 6 additions & 49 deletions internal/formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,6 @@ func New(node ast.Node) Formatter {
return newNode(node)
}

func FormatName(namePath []string) string {
namePath = FormatPath(namePath)
return strings.Join(namePath, "_")
}

func FormatPath(path []string) []string {
ret := []string{}
for _, p := range path {
splitted := strings.Split(p, ".")
ret = append(ret, splitted...)
}
return ret
}

func getTableName(ctx context.Context, n ast.Node) (string, error) {
nodeMap := nodeMapFromContext(ctx)
found := nodeMap.FindNodeFromResolvedNode(n)
Expand All @@ -43,12 +29,8 @@ func getTableName(ctx context.Context, n ast.Node) (string, error) {
if err != nil {
return "", fmt.Errorf("failed to find path: %w", err)
}
return FormatName(
MergeNamePath(
namePathFromContext(ctx),
path,
),
), nil
namePath := namePathFromContext(ctx)
return namePath.format(path), nil
}

func getFuncName(ctx context.Context, n ast.Node) (string, error) {
Expand All @@ -73,12 +55,8 @@ func getFuncName(ctx context.Context, n ast.Node) (string, error) {
if err != nil {
return "", fmt.Errorf("failed to find path: %w", err)
}
return FormatName(
MergeNamePath(
namePathFromContext(ctx),
path,
),
), nil
namePath := namePathFromContext(ctx)
return namePath.format(path), nil
}

func getPathFromNode(n parsed_ast.Node) ([]string, error) {
Expand Down Expand Up @@ -152,23 +130,6 @@ func formatInput(input string) (string, error) {
return "", fmt.Errorf("unexpected input pattern: %s", input)
}

func MergeNamePath(namePath []string, queryPath []string) []string {
namePath = FormatPath(namePath)
queryPath = FormatPath(queryPath)
if len(queryPath) == 0 {
return namePath
}

merged := []string{}
for _, path := range namePath {
if queryPath[0] == path {
break
}
merged = append(merged, path)
}
return append(merged, queryPath...)
}

func getFuncNameAndArgs(ctx context.Context, node *ast.BaseFunctionCallNode, isWindowFunc bool) (string, []string, error) {
args := []string{}
for _, a := range node.ArgumentList() {
Expand Down Expand Up @@ -1381,12 +1342,8 @@ func (n *DropStmtNode) FormatSQL(ctx context.Context) (string, error) {
if n.node == nil {
return "", nil
}
tableName := FormatName(
MergeNamePath(
namePathFromContext(ctx),
n.node.NamePath(),
),
)
namePath := namePathFromContext(ctx)
tableName := namePath.format(n.node.NamePath())
if n.node.IsIfExists() {
return fmt.Sprintf("DROP TABLE IF EXISTS `%s`", tableName), nil
}
Expand Down
74 changes: 74 additions & 0 deletions internal/name_path.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package internal

import (
"fmt"
"strings"
)

type NamePath struct {
path []string
maxNum int
}

func (p *NamePath) normalizePath(path []string) []string {
ret := []string{}
for _, p := range path {
splitted := strings.Split(p, ".")
ret = append(ret, splitted...)
}
return ret
}

func (p *NamePath) mergePath(path []string) []string {
path = p.normalizePath(path)
if p.maxNum > 0 && len(path) == p.maxNum {
return path
}
if len(path) == 0 {
return p.path
}
merged := []string{}
for _, basePath := range p.path {
if path[0] == basePath {
break
}
merged = append(merged, basePath)
}
return append(merged, path...)
}

func (p *NamePath) format(path []string) string {
return formatPath(p.mergePath(path))
}

func formatPath(path []string) string {
return strings.Join(path, "_")
}

func (p *NamePath) setPath(path []string) error {
normalizedPath := p.normalizePath(path)
if p.maxNum > 0 && len(normalizedPath) > p.maxNum {
return fmt.Errorf("specified too many name paths %v(%d). max name path is %d", path, len(normalizedPath), p.maxNum)
}
p.path = normalizedPath
return nil
}

func (p *NamePath) addPath(path string) error {
normalizedPath := p.normalizePath([]string{path})
totalPath := len(p.path) + len(normalizedPath)
if p.maxNum > 0 && totalPath > p.maxNum {
return fmt.Errorf(
"specified too many name paths %v(%d). max name path is %d",
append(p.path, normalizedPath...),
totalPath,
p.maxNum,
)
}
p.path = append(p.path, normalizedPath...)
return nil
}

func (p *NamePath) empty() bool {
return len(p.path) == 0
}

0 comments on commit 3664acb

Please sign in to comment.