Skip to content

Commit

Permalink
Fix for remote connections in Go runner (#134)
Browse files Browse the repository at this point in the history
* Fix remote db open

* Update arch

* Fix for remote db connection

* Fix for fmt

* Fix for tests
  • Loading branch information
eatonphil committed Dec 23, 2021
1 parent f8331c8 commit 85e22fd
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 34 deletions.
8 changes: 8 additions & 0 deletions ARCHITECTURE.md
Expand Up @@ -33,6 +33,10 @@ something to install all missing dependencies.

### ./desktop/panel

NOTE: This code is being migrated to Go. All panel types except for a
few database vendors have been ported to Go. A number of Node panel
handlers have been deleted since they are no longer used.

This is where eval handlers for each panel type (program, database,
etc.) are defined.

Expand All @@ -45,6 +49,10 @@ this on desktop. ./server/runner.ts is the equivalent on the server.

This allows easy resource cleanup and easy "kill" panel eval support.

## ./runner

This is where the Go port of the original Node.js panel eval code is.

## ./server

This directory contains the server (Express) app and code that proxies
Expand Down
61 changes: 33 additions & 28 deletions runner/database.go
Expand Up @@ -24,16 +24,22 @@ import (
_ "github.com/snowflakedb/gosnowflake"
)

func getDatabaseHostPort(raw, defaultPort string) (string, string, error) {
beforeQuery := strings.Split(raw, "?")[0]
func getDatabaseHostPortExtra(raw, defaultPort string) (string, string, string, error) {
addressAndArgs := strings.SplitN(raw, "?", 2)
extra := ""
beforeQuery := addressAndArgs[0]
if len(addressAndArgs) > 1 {
extra = addressAndArgs[1]
}
_, _, err := net.SplitHostPort(beforeQuery)
if err != nil && strings.HasSuffix(err.Error(), "missing port in address") {
beforeQuery += ":" + defaultPort
} else if err != nil {
return "", "", edsef("Could not split host-port: %s", err)
return "", "", "", edsef("Could not split host-port: %s", err)
}

return net.SplitHostPort(beforeQuery)
host, port, err := net.SplitHostPort(beforeQuery)
return host, port, extra, err
}

func debugObject(obj interface{}) {
Expand Down Expand Up @@ -97,7 +103,7 @@ var defaultPorts = map[DatabaseConnectorInfoType]string{

func getConnectionString(dbInfo DatabaseConnectorInfoDatabase) (string, string, error) {
address := dbInfo.Address
split := strings.Split(address, "?")
split := strings.SplitN(address, "?", 2)
address = split[0]
extraArgs := ""
if len(split) > 1 {
Expand Down Expand Up @@ -272,16 +278,6 @@ func EvalDatabasePanel(project *ProjectState, pageIndex int, panel *PanelInfo, p

dbInfo := connector.Database

vendor, connStr, err := getConnectionString(dbInfo)
if err != nil {
return err
}

db, err := sqlx.Open(vendor, connStr)
if err != nil {
return err
}

mangleInsert := defaultMangleInsert
qt := ansiSQLQuote
if dbInfo.Type == "postgres" {
Expand All @@ -306,16 +302,11 @@ func EvalDatabasePanel(project *ProjectState, pageIndex int, panel *PanelInfo, p
return err
}

// Require queries end with semicolon primarily for Oracle
// that blows up without this. This will still blow up if
// there's no semicolon and there are comments.
// e.g. `SELECT 1 -- flubber` -> `SELECT 1 -- flubber;`
//qWithoutWs := strings.TrimSpace(query)
//if qWithoutWs[len(qWithoutWs)-1] != ';' {
// query += ";"
//}

server, err := getServer(project, panel.ServerId)
serverId := panel.ServerId
if serverId == "" {
serverId = connector.ServerId
}
server, err := getServer(project, serverId)
if err != nil {
return err
}
Expand All @@ -338,10 +329,9 @@ func EvalDatabasePanel(project *ProjectState, pageIndex int, panel *PanelInfo, p
}

dbInfo.Database = tmp.Name()
tmp.Close()
}

host, port, err := getDatabaseHostPort(dbInfo.Address, defaultPorts[dbInfo.Type])
host, port, extra, err := getDatabaseHostPortExtra(dbInfo.Address, defaultPorts[dbInfo.Type])
if err != nil {
return err
}
Expand All @@ -360,7 +350,21 @@ func EvalDatabasePanel(project *ProjectState, pageIndex int, panel *PanelInfo, p
}
defer w.Close()

return withRemoteConnection(server, host, port, func(host, port string) error {
return withRemoteConnection(server, host, port, func(proxyHost, proxyPort string) error {
dbInfo.Address = proxyHost + ":" + proxyPort
if extra != "" {
dbInfo.Address += "?" + extra
}
vendor, connStr, err := getConnectionString(dbInfo)
if err != nil {
return err
}

db, err := sqlx.Open(vendor, connStr)
if err != nil {
return err
}

wroteFirstRow := false
return withJSONArrayOutWriterFile(w, func(w *JSONArrayWriter) error {
_, err := importAndRun(
Expand All @@ -377,6 +381,7 @@ func EvalDatabasePanel(project *ProjectState, pageIndex int, panel *PanelInfo, p
if err != nil {
return nil, err
}

defer rows.Close()

for rows.Next() {
Expand Down
14 changes: 13 additions & 1 deletion runner/database_test.go
Expand Up @@ -14,6 +14,7 @@ func Test_getConnectionString(t *testing.T) {
expErr error
expHost string
expPort string
expExtra string
}{
{
DatabaseConnectorInfoDatabase{Type: "postgres", Username: "jim", Password: Encrypt{Encrypted: false, Value: "pw"}, Database: "test", Address: "localhost?sslmode=disable"},
Expand All @@ -22,6 +23,7 @@ func Test_getConnectionString(t *testing.T) {
nil,
"localhost",
"5432",
"sslmode=disable",
},
{
DatabaseConnectorInfoDatabase{Type: "postgres", Database: "test", Address: "big.com:8888?sslmode=disable"},
Expand All @@ -30,6 +32,7 @@ func Test_getConnectionString(t *testing.T) {
nil,
"big.com",
"8888",
"sslmode=disable",
},
{
DatabaseConnectorInfoDatabase{Type: "mysql", Username: "jim", Password: Encrypt{Encrypted: false, Value: "pw"}, Database: "test", Address: "localhost:9090"},
Expand All @@ -38,6 +41,7 @@ func Test_getConnectionString(t *testing.T) {
nil,
"localhost",
"9090",
"",
},
{
DatabaseConnectorInfoDatabase{Type: "sqlite", Database: "test.sql"},
Expand All @@ -46,6 +50,7 @@ func Test_getConnectionString(t *testing.T) {
nil,
"",
"",
"",
},
{
DatabaseConnectorInfoDatabase{Type: "oracle", Username: "jim", Password: Encrypt{Encrypted: false, Value: "pw"}, Database: "test", Address: "localhost"},
Expand All @@ -54,6 +59,7 @@ func Test_getConnectionString(t *testing.T) {
nil,
"localhost",
"1521",
"",
},
{
DatabaseConnectorInfoDatabase{Type: "snowflake", Username: "jim", Password: Encrypt{Encrypted: false, Value: ""}, Database: "test", Address: "myid"},
Expand All @@ -62,6 +68,7 @@ func Test_getConnectionString(t *testing.T) {
nil,
"",
"",
"",
},
{
DatabaseConnectorInfoDatabase{Type: "snowflake", Username: "jim", Password: Encrypt{Encrypted: false, Value: "pw"}, Database: "test", Address: "myid?x=y"},
Expand All @@ -70,6 +77,7 @@ func Test_getConnectionString(t *testing.T) {
nil,
"",
"",
"x=y",
},
{
DatabaseConnectorInfoDatabase{Type: "sqlserver", Username: "jim", Password: Encrypt{Encrypted: false, Value: "pw"}, Database: "test", Address: "localhost"},
Expand All @@ -78,6 +86,7 @@ func Test_getConnectionString(t *testing.T) {
nil,
"localhost",
"1433",
"",
},
{
DatabaseConnectorInfoDatabase{Type: "clickhouse", Username: "jim", Password: Encrypt{Encrypted: false, Value: "pw"}, Database: "test", Address: "localhost"},
Expand All @@ -86,6 +95,7 @@ func Test_getConnectionString(t *testing.T) {
nil,
"localhost",
"9000",
"",
},
{
DatabaseConnectorInfoDatabase{Type: "clickhouse", Password: Encrypt{Encrypted: false, Value: ""}, Database: "test", Address: "localhost:9001"},
Expand All @@ -94,6 +104,7 @@ func Test_getConnectionString(t *testing.T) {
nil,
"localhost",
"9001",
"",
},
}
for _, test := range tests {
Expand All @@ -106,9 +117,10 @@ func Test_getConnectionString(t *testing.T) {
continue
}

host, port, err := getDatabaseHostPort(test.conn.Address, defaultPorts[DatabaseConnectorInfoType(test.expVendor)])
host, port, extra, err := getDatabaseHostPortExtra(test.conn.Address, defaultPorts[DatabaseConnectorInfoType(test.expVendor)])
assert.Nil(t, err)
assert.Equal(t, test.expHost, host)
assert.Equal(t, test.expPort, port)
assert.Equal(t, test.expExtra, extra)
}
}
14 changes: 9 additions & 5 deletions runner/ssh.go
Expand Up @@ -254,13 +254,17 @@ func withRemoteConnection(si *ServerInfo, host, port string, cb func(host, port
localPort := localConn.Addr().(*net.TCPAddr).Port
cbErr := cb("localhost", fmt.Sprintf("%d", localPort))
if cbErr != nil {
return err
return cbErr
}

err = <-errC
if err == io.EOF {
select {
case err = <-errC:
if err == io.EOF {
return nil
}

return err
default:
return nil
}

return err
}

0 comments on commit 85e22fd

Please sign in to comment.