Skip to content

Commit

Permalink
fix: validate instances in fuse mode (#875)
Browse files Browse the repository at this point in the history
## Change Description

Validate instances in fuse mode to avoid incorrectly picking up file system requests.


## Checklist

- [x] Make sure to open an issue as a 
  [bug/issue](https://github.com/GoogleCloudPlatform/cloudsql-proxy/issues/new/choose) 
  before writing your code!  That way we can discuss the change, evaluate 
  designs, and agree on the general idea.
- [x] Ensure the tests and linter pass
- [x] Appropriate documentation is updated (if necessary)

## Relevant issues:

- Fixes #873
  • Loading branch information
monazhn committed Aug 9, 2021
1 parent 247cbad commit 96f8b65
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 12 deletions.
12 changes: 3 additions & 9 deletions cmd/cloud_sql_proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (
"github.com/GoogleCloudPlatform/cloudsql-proxy/logging"
"github.com/GoogleCloudPlatform/cloudsql-proxy/proxy/fuse"
"github.com/GoogleCloudPlatform/cloudsql-proxy/proxy/proxy"
"github.com/GoogleCloudPlatform/cloudsql-proxy/proxy/util"
sqladmin "google.golang.org/api/sqladmin/v1beta4"
)

Expand Down Expand Up @@ -225,17 +224,12 @@ var validNets = func() map[string]bool {

func parseInstanceConfig(dir, instance string, cl *http.Client) (instanceConfig, error) {
var ret instanceConfig
args := strings.Split(instance, "=")
if len(args) > 2 {
return instanceConfig{}, fmt.Errorf("invalid instance argument: must be either form - `<instance_connection_string>` or `<instance_connection_string>=<options>`; invalid arg was %q", instance)
proj, region, name, args, err := proxy.ParseInstanceConnectionName(instance)
if err != nil {
return instanceConfig{}, err
}
// Parse the instance connection name - everything before the "=".
ret.Instance = args[0]
proj, region, name := util.SplitName(ret.Instance)
regionName := fmt.Sprintf("%s~%s", region, name)
if proj == "" || region == "" || name == "" {
return instanceConfig{}, fmt.Errorf("invalid instance connection string: must be in the form `project:region:instance-name`; invalid name was %q", args[0])
}
if len(args) == 1 {
// Default to listening via unix socket in specified directory
ret.Network = "unix"
Expand Down
4 changes: 4 additions & 0 deletions proxy/fuse/fuse.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,10 @@ func (r *fsRoot) Lookup(ctx context.Context, req *fuse.LookupRequest, resp *fuse
r.sockLock.Lock()
defer r.sockLock.Unlock()

if _, _, _, _, err := proxy.ParseInstanceConnectionName(instance); err != nil {
return nil, err
}

if ret, ok := r.links[instance]; ok {
return ret, nil
}
Expand Down
6 changes: 3 additions & 3 deletions proxy/fuse/fuse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func TestBadDir(t *testing.T) {
}
defer fuse.Close()

_, err = os.Stat(filepath.Join(dir, "dir1", "dir2"))
_, err = os.Stat(filepath.Join(dir, "dir:ectory:1", "dir:ectory:2"))
if err == nil {
t.Fatal("able to find a directory inside the mount point, expected only regular files")
}
Expand Down Expand Up @@ -90,7 +90,7 @@ func TestSingleInstance(t *testing.T) {
}
defer fuse.Close()

const want = "test:instance"
const want = "test:instance:string"
path := filepath.Join(dir, want)

fi, err := os.Stat(path)
Expand Down Expand Up @@ -155,7 +155,7 @@ func BenchmarkNewConnection(b *testing.B) {
}
}()

const instance = "test:instance"
const instance = "test:instance:string"
path := filepath.Join(dir, instance)

b.ResetTimer()
Expand Down
17 changes: 17 additions & 0 deletions proxy/proxy/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ import (
"errors"
"fmt"
"net"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/GoogleCloudPlatform/cloudsql-proxy/logging"
"github.com/GoogleCloudPlatform/cloudsql-proxy/proxy/util"
"golang.org/x/net/proxy"
"golang.org/x/time/rate"
)
Expand Down Expand Up @@ -522,6 +524,21 @@ func (c *Client) InstanceVersionContext(ctx context.Context, instance string) (s
return version, nil
}

// ParseInstanceConnectionName verifies that instances are in the expected format and include
// the necessary components.
func ParseInstanceConnectionName(instance string) (string, string, string, []string, error) {
args := strings.Split(instance, "=")
if len(args) > 2 {
return "", "", "", nil, fmt.Errorf("invalid instance argument: must be either form - `<instance_connection_string>` or `<instance_connection_string>=<options>`; invalid arg was %q", instance)
}
// Parse the instance connection name - everything before the "=".
proj, region, name := util.SplitName(args[0])
if proj == "" || region == "" || name == "" {
return "", "", "", nil, fmt.Errorf("invalid instance connection string: must be in the form `project:region:instance-name`; invalid name was %q", args[0])
}
return proj, region, name, args, nil
}

// AvailableConn returns false if MaxConnections has been reached, true otherwise.
// When MaxConnections is 0, there is no limit.
func (c *Client) AvailableConn() bool {
Expand Down
25 changes: 25 additions & 0 deletions proxy/proxy/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,3 +317,28 @@ func TestRemoteCertError(t *testing.T) {
}

}

func TestValidate(t *testing.T) {
// SplitName has its own tests and is not specifically tested here.
table := []struct {
in string
wantValid bool
wantErrorStr string
}{
{"proj:region:my-db", true, ""},
{"proj:region:my-db=options", true, ""},
{"proj=region=my-db", false, "invalid instance argument: must be either form - `<instance_connection_string>` or `<instance_connection_string>=<options>`; invalid arg was \"proj=region=my-db\""},
{"projregionmy-db", false, "invalid instance connection string: must be in the form `project:region:instance-name`; invalid name was \"projregionmy-db\""},
}

for _, test := range table {
_, _, _, _, gotError := ParseInstanceConnectionName(test.in)
var gotErrorStr string
if gotError != nil {
gotErrorStr = gotError.Error()
}
if gotErrorStr != test.wantErrorStr {
t.Errorf("Validate(%q): got \"%v\" for error, want \"%v\"", test.in, gotError, test.wantErrorStr)
}
}
}

0 comments on commit 96f8b65

Please sign in to comment.