Skip to content

Commit

Permalink
Handle OCI-Chunk-Max-Length header field in blob pushes
Browse files Browse the repository at this point in the history
This header comes from the proposal of adding a way for registries to
limit chunk sizes from the client.

The implementation is simple enough; we're just adding a io.LimitReader
and don't let it read more than the bytes that were specified by the
registry.

Signed-off-by: Gabi Villalonga <gvillalongasimon@cloudflare.com>
  • Loading branch information
gabivlj committed Oct 31, 2023
1 parent 6c694cb commit 66f6603
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 0 deletions.
9 changes: 9 additions & 0 deletions internal/client/blob_writer.go
Expand Up @@ -23,6 +23,7 @@ type httpBlobUpload struct {
location string // always the last value of the location header.
offset int64
closed bool
maxRange int64
}

func (hbu *httpBlobUpload) Reader() (io.ReadCloser, error) {
Expand All @@ -37,6 +38,10 @@ func (hbu *httpBlobUpload) handleErrorResponse(resp *http.Response) error {
}

func (hbu *httpBlobUpload) ReadFrom(r io.Reader) (n int64, err error) {
if hbu.maxRange != 0 {
r = io.LimitReader(r, hbu.maxRange)
}

req, err := http.NewRequestWithContext(hbu.ctx, http.MethodPatch, hbu.location, io.NopCloser(r))
if err != nil {
return 0, err
Expand Down Expand Up @@ -73,6 +78,10 @@ func (hbu *httpBlobUpload) ReadFrom(r io.Reader) (n int64, err error) {
}

func (hbu *httpBlobUpload) Write(p []byte) (n int, err error) {
if hbu.maxRange != 0 && hbu.maxRange < int64(len(p)) {
p = p[:hbu.maxRange]
}

req, err := http.NewRequestWithContext(hbu.ctx, http.MethodPatch, hbu.location, bytes.NewReader(p))
if err != nil {
return 0, err
Expand Down
144 changes: 144 additions & 0 deletions internal/client/blob_writer_test.go
Expand Up @@ -3,7 +3,9 @@ package client
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"testing"

Expand Down Expand Up @@ -500,3 +502,145 @@ func TestUploadWrite(t *testing.T) {
t.Fatalf("Unexpected response status: %s, expected %s", uploadErr.Status, expected)
}
}

// tests the case of sending only the bytes that we're limiting on
func TestUploadLimitRange(t *testing.T) {
const numberOfBlobs = 10
const blobSize = 5
const lastBlobOffset = 2

_, b := newRandomBlob(numberOfBlobs*5 + 2)
repo := "test/upload/write"
locationPath := fmt.Sprintf("/v2/%s/uploads/testid", repo)
requests := []testutil.RequestResponseMapping{
{
Request: testutil.Request{
Method: http.MethodGet,
Route: "/v2/",
},
Response: testutil.Response{
StatusCode: http.StatusOK,
Headers: http.Header(map[string][]string{
"Docker-Distribution-API-Version": {"registry/2.0"},
}),
},
},
}

for blob := 0; blob < numberOfBlobs; blob++ {
start := blob * blobSize
end := ((blob + 1) * blobSize) - 1

requests = append(requests, testutil.RequestResponseMapping{
Request: testutil.Request{
Method: http.MethodPatch,
Route: locationPath,
Body: b[start : end+1],
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
Headers: http.Header(map[string][]string{
"Docker-Upload-UUID": {"46603072-7a1b-4b41-98f9-fd8a7da89f9b"},
"Location": {locationPath},
"Range": {fmt.Sprintf("%d-%d", start, end)},
}),
},
})
}

requests = append(requests, testutil.RequestResponseMapping{
Request: testutil.Request{
Method: http.MethodPatch,
Route: locationPath,
Body: b[numberOfBlobs*blobSize:],
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
Headers: http.Header(map[string][]string{
"Docker-Upload-UUID": {"46603072-7a1b-4b41-98f9-fd8a7da89f9b"},
"Location": {locationPath},
"Range": {fmt.Sprintf("%d-%d", numberOfBlobs*blobSize, len(b)-1)},
}),
},
})

t.Run("reader chunked upload", func(t *testing.T) {
m := testutil.RequestResponseMap(requests)
e, c := testServer(m)
defer c()

blobUpload := &httpBlobUpload{
ctx: context.Background(),
client: &http.Client{},
maxRange: int64(blobSize),
}

reader := bytes.NewBuffer(b)
for i := 0; i < numberOfBlobs; i++ {
blobUpload.location = e + locationPath
n, err := blobUpload.ReadFrom(reader)
if err != nil {
t.Fatalf("Error calling Write: %s", err)
}

if n != blobSize {
t.Fatalf("Unexpected n %v != %v blobSize", n, blobSize)
}
}

n, err := blobUpload.ReadFrom(reader)
if err != nil {
t.Fatalf("Error calling Write: %s", err)
}

if n != lastBlobOffset {
t.Fatalf("Expected last write to have written %v but wrote %v", lastBlobOffset, n)
}

_, err = reader.Read([]byte{0, 0, 0})
if !errors.Is(err, io.EOF) {
t.Fatalf("Expected io.EOF when reading blob as the test should've read the whole thing")
}
})

t.Run("buffer chunked upload", func(t *testing.T) {
buff := b
m := testutil.RequestResponseMap(requests)
e, c := testServer(m)
defer c()

blobUpload := &httpBlobUpload{
ctx: context.Background(),
client: &http.Client{},
maxRange: int64(blobSize),
}

for i := 0; i < numberOfBlobs; i++ {
blobUpload.location = e + locationPath
n, err := blobUpload.Write(buff)
if err != nil {
t.Fatalf("Error calling Write: %s", err)
}

if n != blobSize {
t.Fatalf("Unexpected n %v != %v blobSize", n, blobSize)
}

buff = buff[n:]
}

n, err := blobUpload.Write(buff)
if err != nil {
t.Fatalf("Error calling Write: %s", err)
}

if n != lastBlobOffset {
t.Fatalf("Expected last write to have written %v but wrote %v", lastBlobOffset, n)
}

buff = buff[n:]
if len(buff) != 0 {
t.Fatalf("Expected length 0 on the buffer body as we should've read the whole thing, but got %v", len(buff))
}
})
}
6 changes: 6 additions & 0 deletions internal/client/repository.go
Expand Up @@ -809,13 +809,19 @@ func (bs *blobs) Create(ctx context.Context, options ...distribution.BlobCreateO
return nil, err
}

maxRange, err := v2.GetOCIMaxRange(resp)
if err != nil {
return nil, err
}

return &httpBlobUpload{
ctx: ctx,
statter: bs.statter,
client: bs.client,
uuid: uuid,
startedAt: time.Now(),
location: location,
maxRange: maxRange,
}, nil
default:
return nil, HandleHTTPResponseError(resp)
Expand Down
18 changes: 18 additions & 0 deletions registry/api/v2/headerparser.go
Expand Up @@ -2,7 +2,9 @@ package v2

import (
"fmt"
"net/http"
"regexp"
"strconv"
"strings"
"unicode"
)
Expand Down Expand Up @@ -159,3 +161,19 @@ Loop:

return res, parse, nil
}

// GetOCIMaxRang gets the OCI-Chunk-Max-Length from the headers. If not set it will return 0
func GetOCIMaxRange(resp *http.Response) (int64, error) {
maxRangeStr := resp.Header.Get("OCI-Chunk-Max-Length")
maxRange := int64(0)
if maxRangeStr != "" {
maxRangeRes, err := strconv.ParseInt(maxRangeStr, 10, 64)
if err != nil {
return 0, fmt.Errorf("OCI-Chunk-Max-Length is malformed %q: %w", maxRangeStr, err)
}

maxRange = maxRangeRes
}

return maxRange, nil
}

0 comments on commit 66f6603

Please sign in to comment.