diff --git a/storage/integration_test.go b/storage/integration_test.go index 500e950b114..137f720476c 100644 --- a/storage/integration_test.go +++ b/storage/integration_test.go @@ -18,7 +18,10 @@ import ( "bytes" "compress/gzip" "context" + "crypto" "crypto/md5" + cryptorand "crypto/rand" + "crypto/rsa" "crypto/sha256" "encoding/base64" "encoding/json" @@ -2018,11 +2021,12 @@ func TestIntegration_SignedURL(t *testing.T) { t.Errorf("%s: SignedURL: %v", test.desc, err) continue } - got, err := getURL(u, test.headers) + + err = verifySignedURL(u, test.headers, contents) if err != nil && !test.fail { - t.Errorf("%s: getURL %q: %v", test.desc, u, err) - } else if err == nil && !bytes.Equal(got, contents) { - t.Errorf("%s: got %q, want %q", test.desc, got, contents) + t.Errorf("%s: wanted success but got error:\n%v", test.desc, err) + } else if err == nil && test.fail { + t.Errorf("%s: wanted failure but test succeeded", test.desc) } } } @@ -2111,13 +2115,9 @@ func TestIntegration_SignedURL_WithEncryptionKeys(t *testing.T) { } if test.opts.Method == "GET" { - got, err := getURL(u, headers) - if err != nil { + if err := verifySignedURL(u, headers, contents); err != nil { t.Fatalf("%s: %v", test.desc, err) } - if !bytes.Equal(got, contents) { - t.Fatalf("%s: got %q, want %q", test.desc, got, contents) - } } } } @@ -4170,78 +4170,18 @@ func TestIntegration_PostPolicyV4(t *testing.T) { }, } - objectName := "my-object.txt" + objectName := uidSpace.New() + object := b.Object(objectName) + defer h.mustDeleteObject(object) + pv4, err := GenerateSignedPostPolicyV4(newBucketName, objectName, opts) if err != nil { t.Fatal(err) } - formBuf := new(bytes.Buffer) - mw := multipart.NewWriter(formBuf) - for fieldName, value := range pv4.Fields { - if err := mw.WriteField(fieldName, value); err != nil { - t.Fatalf("Failed to write form field: %q: %v", fieldName, err) - } - } - - // Now let's perform the upload. - fileBody := bytes.Repeat([]byte("a"), 25) - mf, err := mw.CreateFormFile("file", "myfile.txt") - if err != nil { - t.Fatal(err) - } - if _, err := mf.Write(fileBody); err != nil { + if err := verifyPostPolicy(pv4, object, bytes.Repeat([]byte("a"), 25), statusCodeToRespond); err != nil { t.Fatal(err) } - if err := mw.Close(); err != nil { - t.Fatal(err) - } - - // Compose the HTTP request. - req, err := http.NewRequest("POST", pv4.URL, formBuf) - if err != nil { - t.Fatalf("Failed to compose HTTP request: %v", err) - } - // Ensure the Content-Type is derived from the writer. - req.Header.Set("Content-Type", mw.FormDataContentType()) - res, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatal(err) - } - if g, w := res.StatusCode, statusCodeToRespond; g != w { - blob, _ := httputil.DumpResponse(res, true) - t.Fatalf("Status code in response mismatch: got %d want %d\nBody: %s", g, w, blob) - } - io.Copy(ioutil.Discard, res.Body) - - // Verify that the file was properly uploaded, by - // reading back its attributes and content! - obj := b.Object(objectName) - defer h.mustDeleteObject(obj) - - attrs, err := obj.Attrs(ctx) - if err != nil { - t.Fatalf("Failed to retrieve attributes: %v", err) - } - if g, w := attrs.Size, int64(len(fileBody)); g != w { - t.Errorf("ContentLength mismatch: got %d want %d", g, w) - } - if g, w := attrs.MD5, md5.Sum(fileBody); !bytes.Equal(g, w[:]) { - t.Errorf("MD5Checksum mismatch\nGot: %x\nWant: %x", g, w) - } - - // Compare the uploaded body with the expected. - rd, err := obj.NewReader(ctx) - if err != nil { - t.Fatalf("Failed to create a reader: %v", err) - } - gotBody, err := ioutil.ReadAll(rd) - if err != nil { - t.Fatalf("Failed to read the body: %v", err) - } - if diff := testutil.Diff(string(gotBody), string(fileBody)); diff != "" { - t.Fatalf("Body mismatch: got - want +\n%s", diff) - } } // Verify that custom scopes passed in by the user are applied correctly. @@ -4276,7 +4216,7 @@ func TestIntegration_Scopes(t *testing.T) { } -func TestBucketSignURL(t *testing.T) { +func TestIntegration_SignedURL_Bucket(t *testing.T) { ctx := context.Background() if testing.Short() && !replaying { @@ -4326,27 +4266,187 @@ func TestBucketSignURL(t *testing.T) { client: clientWithoutPrivateKey, }, } { - bkt := test.client.Bucket(bucketName) - url, err := bkt.SignedURL(obj, &test.opts) - if err != nil { - t.Fatalf("unable to create signed URL: %v", err) - } - resp, err := http.Get(url) - if err != nil { - t.Fatalf("http.Get(%q) errored: %q", url, err) - } - defer resp.Body.Close() - if resp.StatusCode != 200 { - t.Fatalf("resp.StatusCode = %v, want 200: %v", resp.StatusCode, err) - } - b, err := ioutil.ReadAll(resp.Body) + t.Run(test.desc, func(t *testing.T) { + bkt := test.client.Bucket(bucketName) + url, err := bkt.SignedURL(obj, &test.opts) + if err != nil { + t.Fatalf("unable to create signed URL: %v", err) + } + + if err := verifySignedURL(url, nil, contents); err != nil { + t.Fatalf("problem with the signed URL: %v", err) + } + }) + } +} + +// Tests that the same SignBytes function works for both +// SignRawBytes on GeneratePostPolicyV4 and SignBytes on SignedURL +func TestIntegration_PostPolicyV4_SignedURL_WithSignBytes(t *testing.T) { + ctx := context.Background() + + if testing.Short() && !replaying { + t.Skip("Integration tests skipped in short mode") + } + + client := testConfig(ctx, t) + defer client.Close() + + h := testHelper{t} + projectID := testutil.ProjID() + bucketName := uidSpace.New() + objectName := "my-object.txt" + fileBody := bytes.Repeat([]byte("b"), 25) + bucket := client.Bucket(bucketName) + + h.mustCreate(bucket, projectID, nil) + defer h.mustDeleteBucket(bucket) + + object := bucket.Object(objectName) + defer h.mustDeleteObject(object) + + jwtConf, err := testutil.JWTConfig() + if err != nil { + t.Fatal(err) + } + if jwtConf == nil { + t.Skip("JSON key file is not present") + } + + signingFunc := func(b []byte) ([]byte, error) { + parsedRSAPrivKey, err := parseKey(jwtConf.PrivateKey) if err != nil { - t.Fatalf("unable to read resp.Body: %v", err) + return nil, err } - if !bytes.Equal(b, contents) { - t.Fatalf("got %q, want %q", b, contents) + sum := sha256.Sum256(b) + return rsa.SignPKCS1v15(cryptorand.Reader, parsedRSAPrivKey, crypto.SHA256, sum[:]) + } + + // Test Post Policy + successStatusCode := 200 + ppv4Opts := &PostPolicyV4Options{ + GoogleAccessID: jwtConf.Email, + SignRawBytes: signingFunc, + Expires: time.Now().Add(30 * time.Minute), + Fields: &PolicyV4Fields{ + StatusCodeOnSuccess: successStatusCode, + ContentType: "text/plain", + ACL: "public-read", + }, + } + + pv4, err := GenerateSignedPostPolicyV4(bucketName, objectName, ppv4Opts) + if err != nil { + t.Fatal(err) + } + + if err := verifyPostPolicy(pv4, object, fileBody, successStatusCode); err != nil { + t.Fatal(err) + } + + // Test Signed URL + signURLOpts := &SignedURLOptions{ + GoogleAccessID: jwtConf.Email, + SignBytes: signingFunc, + Method: "GET", + Expires: time.Now().Add(30 * time.Second), + } + + url, err := bucket.SignedURL(objectName, signURLOpts) + if err != nil { + t.Fatalf("unable to create signed URL: %v", err) + } + + if err := verifySignedURL(url, nil, fileBody); err != nil { + t.Fatal(err) + } +} + +// verifySignedURL gets the bytes at the provided url and verifies them against the +// expectedFileBody. Make sure the SignedURLOptions set the method as "GET". +func verifySignedURL(url string, headers map[string][]string, expectedFileBody []byte) error { + got, err := getURL(url, headers) + if err != nil { + return fmt.Errorf("getURL %q: %v", url, err) + } + if !bytes.Equal(got, expectedFileBody) { + return fmt.Errorf("got %q, want %q", got, expectedFileBody) + } + return nil +} + +// verifyPostPolicy uploads a file to the obj using the provided post policy and +// verifies that it was uploaded correctly +func verifyPostPolicy(pv4 *PostPolicyV4, obj *ObjectHandle, bytesToWrite []byte, statusCodeOnSuccess int) error { + ctx := context.Background() + formBuf := new(bytes.Buffer) + mw := multipart.NewWriter(formBuf) + for fieldName, value := range pv4.Fields { + if err := mw.WriteField(fieldName, value); err != nil { + return fmt.Errorf("Failed to write form field: %q: %v", fieldName, err) } } + + // Now let's perform the upload + mf, err := mw.CreateFormFile("file", "myfile.txt") + if err != nil { + return err + } + if _, err := mf.Write(bytesToWrite); err != nil { + return err + } + if err := mw.Close(); err != nil { + return err + } + + // Compose the HTTP request + req, err := http.NewRequest("POST", pv4.URL, formBuf) + if err != nil { + return fmt.Errorf("Failed to compose HTTP request: %v", err) + } + + // Ensure the Content-Type is derived from the writer + req.Header.Set("Content-Type", mw.FormDataContentType()) + + // Send request + res, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + + // Check response + if g, w := res.StatusCode, statusCodeOnSuccess; g != w { + blob, _ := httputil.DumpResponse(res, true) + return fmt.Errorf("Status code in response mismatch: got %d want %d\nBody: %s", g, w, blob) + } + io.Copy(ioutil.Discard, res.Body) + + // Verify that the file was properly uploaded + // by reading back its attributes and content + attrs, err := obj.Attrs(ctx) + if err != nil { + return fmt.Errorf("Failed to retrieve attributes: %v", err) + } + if g, w := attrs.Size, int64(len(bytesToWrite)); g != w { + return fmt.Errorf("ContentLength mismatch: got %d want %d", g, w) + } + if g, w := attrs.MD5, md5.Sum(bytesToWrite); !bytes.Equal(g, w[:]) { + return fmt.Errorf("MD5Checksum mismatch\nGot: %x\nWant: %x", g, w) + } + + // Compare the uploaded body with the expected + rd, err := obj.NewReader(ctx) + if err != nil { + return fmt.Errorf("Failed to create a reader: %v", err) + } + gotBody, err := ioutil.ReadAll(rd) + if err != nil { + return fmt.Errorf("Failed to read the body: %v", err) + } + if diff := testutil.Diff(string(gotBody), string(bytesToWrite)); diff != "" { + return fmt.Errorf("Body mismatch: got - want +\n%s", diff) + } + return nil } func newTestClientWithExplicitCredentials(ctx context.Context, t *testing.T) *Client { diff --git a/storage/post_policy_v4.go b/storage/post_policy_v4.go index db9d1383849..5f418c3246b 100644 --- a/storage/post_policy_v4.go +++ b/storage/post_policy_v4.go @@ -52,22 +52,38 @@ type PostPolicyV4Options struct { // Exactly one of PrivateKey or SignBytes must be non-nil. PrivateKey []byte - // SignBytes is a function for implementing custom signing. For example, if + // SignBytes is a function for implementing custom signing. + // + // Deprecated: Use SignRawBytes. If both SignBytes and SignRawBytes are defined, + // SignBytes will be ignored. + // This SignBytes function expects the bytes it receives to be hashed, while + // SignRawBytes accepts the raw bytes without hashing, allowing more flexibility. + // Add the following to the top of your signing function to hash the bytes + // to use SignRawBytes instead: + // shaSum := sha256.Sum256(bytes) + // bytes = shaSum[:] + // + SignBytes func(hashBytes []byte) (signature []byte, err error) + + // SignRawBytes is a function for implementing custom signing. For example, if // your application is running on Google App Engine, you can use // appengine's internal signing function: - // ctx := appengine.NewContext(request) - // acc, _ := appengine.ServiceAccount(ctx) - // url, err := SignedURL("bucket", "object", &SignedURLOptions{ - // GoogleAccessID: acc, - // SignBytes: func(b []byte) ([]byte, error) { - // _, signedBytes, err := appengine.SignBytes(ctx, b) - // return signedBytes, err - // }, - // // etc. - // }) + // ctx := appengine.NewContext(request) + // acc, _ := appengine.ServiceAccount(ctx) + // &PostPolicyV4Options{ + // GoogleAccessID: acc, + // SignRawBytes: func(b []byte) ([]byte, error) { + // _, signedBytes, err := appengine.SignBytes(ctx, b) + // return signedBytes, err + // }, + // // etc. + // }) // - // Exactly one of PrivateKey or SignBytes must be non-nil. - SignBytes func(hashBytes []byte) (signature []byte, err error) + // SignRawBytes is equivalent to the SignBytes field on SignedURLOptions; + // that is, you may use the same signing function for the two. + // + // Exactly one of PrivateKey or SignRawBytes must be non-nil. + SignRawBytes func(bytes []byte) (signature []byte, err error) // Expires is the expiration time on the signed URL. // It must be a time in the future. @@ -96,6 +112,8 @@ type PostPolicyV4Options struct { // a 4XX status code, back with the message describing the problem. // Optional. Conditions []PostPolicyV4Condition + + shouldHashSignBytes bool } // PolicyV4Fields describes the attributes for a PostPolicyV4 request. @@ -220,20 +238,22 @@ func GenerateSignedPostPolicyV4(bucket, object string, opts *PostPolicyV4Options var signingFn func(hashedBytes []byte) ([]byte, error) switch { - case opts.SignBytes != nil: + case opts.SignRawBytes != nil: + signingFn = opts.SignRawBytes + case opts.shouldHashSignBytes: signingFn = opts.SignBytes - case len(opts.PrivateKey) != 0: parsedRSAPrivKey, err := parseKey(opts.PrivateKey) if err != nil { return nil, err } - signingFn = func(hashedBytes []byte) ([]byte, error) { - return rsa.SignPKCS1v15(rand.Reader, parsedRSAPrivKey, crypto.SHA256, hashedBytes) + signingFn = func(b []byte) ([]byte, error) { + sum := sha256.Sum256(b) + return rsa.SignPKCS1v15(rand.Reader, parsedRSAPrivKey, crypto.SHA256, sum[:]) } default: - return nil, errors.New("storage: exactly one of PrivateKey or SignedBytes must be set") + return nil, errors.New("storage: exactly one of PrivateKey or SignRawBytes must be set") } var descFields PolicyV4Fields @@ -307,10 +327,18 @@ func GenerateSignedPostPolicyV4(bucket, object string, opts *PostPolicyV4Options } b64Policy := base64.StdEncoding.EncodeToString(condsAsJSON) - shaSum := sha256.Sum256([]byte(b64Policy)) - signature, err := signingFn(shaSum[:]) - if err != nil { - return nil, err + var signature []byte + var signErr error + + if opts.shouldHashSignBytes { + // SignBytes expects hashed bytes as input instead of raw bytes, so we hash them + shaSum := sha256.Sum256([]byte(b64Policy)) + signature, signErr = signingFn(shaSum[:]) + } else { + signature, signErr = signingFn([]byte(b64Policy)) + } + if signErr != nil { + return nil, signErr } policyFields["policy"] = b64Policy @@ -348,15 +376,16 @@ func GenerateSignedPostPolicyV4(bucket, object string, opts *PostPolicyV4Options // validatePostPolicyV4Options checks that: // * GoogleAccessID is set -// * either but not both PrivateKey and SignBytes are set or nil, but not both -// * Expires, the deadline is not in the past +// * either PrivateKey or SignRawBytes/SignBytes is set, but not both +// * the deadline set in Expires is not in the past // * if Style is not set, it'll use PathStyle +// * sets shouldHashSignBytes to true if opts.SignBytes should be used func validatePostPolicyV4Options(opts *PostPolicyV4Options, now time.Time) error { if opts == nil || opts.GoogleAccessID == "" { return errors.New("storage: missing required GoogleAccessID") } - if privBlank, signBlank := len(opts.PrivateKey) == 0, opts.SignBytes == nil; privBlank == signBlank { - return errors.New("storage: exactly one of PrivateKey or SignedBytes must be set") + if privBlank, signBlank := len(opts.PrivateKey) == 0, opts.SignBytes == nil && opts.SignRawBytes == nil; privBlank == signBlank { + return errors.New("storage: exactly one of PrivateKey or SignRawBytes must be set") } if opts.Expires.Before(now) { return errors.New("storage: expecting Expires to be in the future") @@ -364,6 +393,9 @@ func validatePostPolicyV4Options(opts *PostPolicyV4Options, now time.Time) error if opts.Style == nil { opts.Style = PathStyle() } + if opts.SignRawBytes == nil && opts.SignBytes != nil { + opts.shouldHashSignBytes = true + } return nil }