Merge pull request #2340 from stevvooe/limit-payload-size
registry/{storage,handlers}: limit content sizes
			
			
				master
			
			
		
						commit
						91c507a39a
					
				|  | @ -179,8 +179,8 @@ func (buh *blobUploadHandler) PatchBlobData(w http.ResponseWriter, r *http.Reque | |||
| 
 | ||||
| 	// TODO(dmcgowan): support Content-Range header to seek and write range
 | ||||
| 
 | ||||
| 	if err := copyFullPayload(w, r, buh.Upload, buh, "blob PATCH", &buh.Errors); err != nil { | ||||
| 		// copyFullPayload reports the error if necessary
 | ||||
| 	if err := copyFullPayload(w, r, buh.Upload, -1, buh, "blob PATCH"); err != nil { | ||||
| 		buh.Errors = append(buh.Errors, errcode.ErrorCodeUnknown.WithDetail(err.Error())) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
|  | @ -218,8 +218,8 @@ func (buh *blobUploadHandler) PutBlobUploadComplete(w http.ResponseWriter, r *ht | |||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	if err := copyFullPayload(w, r, buh.Upload, buh, "blob PUT", &buh.Errors); err != nil { | ||||
| 		// copyFullPayload reports the error if necessary
 | ||||
| 	if err := copyFullPayload(w, r, buh.Upload, -1, buh, "blob PUT"); err != nil { | ||||
| 		buh.Errors = append(buh.Errors, errcode.ErrorCodeUnknown.WithDetail(err.Error())) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
|  |  | |||
|  | @ -6,7 +6,6 @@ import ( | |||
| 	"net/http" | ||||
| 
 | ||||
| 	ctxu "github.com/docker/distribution/context" | ||||
| 	"github.com/docker/distribution/registry/api/errcode" | ||||
| ) | ||||
| 
 | ||||
| // closeResources closes all the provided resources after running the target
 | ||||
|  | @ -23,7 +22,9 @@ func closeResources(handler http.Handler, closers ...io.Closer) http.Handler { | |||
| // copyFullPayload copies the payload of an HTTP request to destWriter. If it
 | ||||
| // receives less content than expected, and the client disconnected during the
 | ||||
| // upload, it avoids sending a 400 error to keep the logs cleaner.
 | ||||
| func copyFullPayload(responseWriter http.ResponseWriter, r *http.Request, destWriter io.Writer, context ctxu.Context, action string, errSlice *errcode.Errors) error { | ||||
| //
 | ||||
| // The copy will be limited to `limit` bytes, if limit is greater than zero.
 | ||||
| func copyFullPayload(responseWriter http.ResponseWriter, r *http.Request, destWriter io.Writer, limit int64, context ctxu.Context, action string) error { | ||||
| 	// Get a channel that tells us if the client disconnects
 | ||||
| 	var clientClosed <-chan bool | ||||
| 	if notifier, ok := responseWriter.(http.CloseNotifier); ok { | ||||
|  | @ -32,8 +33,13 @@ func copyFullPayload(responseWriter http.ResponseWriter, r *http.Request, destWr | |||
| 		ctxu.GetLogger(context).Warnf("the ResponseWriter does not implement CloseNotifier (type: %T)", responseWriter) | ||||
| 	} | ||||
| 
 | ||||
| 	var body = r.Body | ||||
| 	if limit > 0 { | ||||
| 		body = http.MaxBytesReader(responseWriter, body, limit) | ||||
| 	} | ||||
| 
 | ||||
| 	// Read in the data, if any.
 | ||||
| 	copied, err := io.Copy(destWriter, r.Body) | ||||
| 	copied, err := io.Copy(destWriter, body) | ||||
| 	if clientClosed != nil && (err != nil || (r.ContentLength > 0 && copied < r.ContentLength)) { | ||||
| 		// Didn't receive as much content as expected. Did the client
 | ||||
| 		// disconnect during the request? If so, avoid returning a 400
 | ||||
|  | @ -58,7 +64,6 @@ func copyFullPayload(responseWriter http.ResponseWriter, r *http.Request, destWr | |||
| 
 | ||||
| 	if err != nil { | ||||
| 		ctxu.GetLogger(context).Errorf("unknown error reading request payload: %v", err) | ||||
| 		*errSlice = append(*errSlice, errcode.ErrorCodeUnknown.WithDetail(err)) | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
|  |  | |||
|  | @ -22,8 +22,9 @@ import ( | |||
| // These constants determine which architecture and OS to choose from a
 | ||||
| // manifest list when downconverting it to a schema1 manifest.
 | ||||
| const ( | ||||
| 	defaultArch = "amd64" | ||||
| 	defaultOS   = "linux" | ||||
| 	defaultArch         = "amd64" | ||||
| 	defaultOS           = "linux" | ||||
| 	maxManifestBodySize = 4 << 20 | ||||
| ) | ||||
| 
 | ||||
| // manifestDispatcher takes the request context and builds the
 | ||||
|  | @ -259,8 +260,9 @@ func (imh *manifestHandler) PutManifest(w http.ResponseWriter, r *http.Request) | |||
| 	} | ||||
| 
 | ||||
| 	var jsonBuf bytes.Buffer | ||||
| 	if err := copyFullPayload(w, r, &jsonBuf, imh, "image manifest PUT", &imh.Errors); err != nil { | ||||
| 	if err := copyFullPayload(w, r, &jsonBuf, maxManifestBodySize, imh, "image manifest PUT"); err != nil { | ||||
| 		// copyFullPayload reports the error if necessary
 | ||||
| 		imh.Errors = append(imh.Errors, v2.ErrorCodeManifestInvalid.WithDetail(err.Error())) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
|  |  | |||
|  | @ -27,7 +27,7 @@ func (bs *blobStore) Get(ctx context.Context, dgst digest.Digest) ([]byte, error | |||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	p, err := bs.driver.GetContent(ctx, bp) | ||||
| 	p, err := getContent(ctx, bs.driver, bp) | ||||
| 	if err != nil { | ||||
| 		switch err.(type) { | ||||
| 		case driver.PathNotFoundError: | ||||
|  | @ -37,7 +37,7 @@ func (bs *blobStore) Get(ctx context.Context, dgst digest.Digest) ([]byte, error | |||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return p, err | ||||
| 	return p, nil | ||||
| } | ||||
| 
 | ||||
| func (bs *blobStore) Open(ctx context.Context, dgst digest.Digest) (distribution.ReadSeekCloser, error) { | ||||
|  |  | |||
|  | @ -0,0 +1,71 @@ | |||
| package storage | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"io" | ||||
| 	"io/ioutil" | ||||
| 
 | ||||
| 	"github.com/docker/distribution/context" | ||||
| 	"github.com/docker/distribution/registry/storage/driver" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	maxBlobGetSize = 4 << 20 | ||||
| ) | ||||
| 
 | ||||
| func getContent(ctx context.Context, driver driver.StorageDriver, p string) ([]byte, error) { | ||||
| 	r, err := driver.Reader(ctx, p, 0) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return readAllLimited(r, maxBlobGetSize) | ||||
| } | ||||
| 
 | ||||
| func readAllLimited(r io.Reader, limit int64) ([]byte, error) { | ||||
| 	r = limitReader(r, limit) | ||||
| 	return ioutil.ReadAll(r) | ||||
| } | ||||
| 
 | ||||
| // limitReader returns a new reader limited to n bytes. Unlike io.LimitReader,
 | ||||
| // this returns an error when the limit reached.
 | ||||
| func limitReader(r io.Reader, n int64) io.Reader { | ||||
| 	return &limitedReader{r: r, n: n} | ||||
| } | ||||
| 
 | ||||
| // limitedReader implements a reader that errors when the limit is reached.
 | ||||
| //
 | ||||
| // Partially cribbed from net/http.MaxBytesReader.
 | ||||
| type limitedReader struct { | ||||
| 	r   io.Reader // underlying reader
 | ||||
| 	n   int64     // max bytes remaining
 | ||||
| 	err error     // sticky error
 | ||||
| } | ||||
| 
 | ||||
| func (l *limitedReader) Read(p []byte) (n int, err error) { | ||||
| 	if l.err != nil { | ||||
| 		return 0, l.err | ||||
| 	} | ||||
| 	if len(p) == 0 { | ||||
| 		return 0, nil | ||||
| 	} | ||||
| 	// If they asked for a 32KB byte read but only 5 bytes are
 | ||||
| 	// remaining, no need to read 32KB. 6 bytes will answer the
 | ||||
| 	// question of the whether we hit the limit or go past it.
 | ||||
| 	if int64(len(p)) > l.n+1 { | ||||
| 		p = p[:l.n+1] | ||||
| 	} | ||||
| 	n, err = l.r.Read(p) | ||||
| 
 | ||||
| 	if int64(n) <= l.n { | ||||
| 		l.n -= int64(n) | ||||
| 		l.err = err | ||||
| 		return n, err | ||||
| 	} | ||||
| 
 | ||||
| 	n = int(l.n) | ||||
| 	l.n = 0 | ||||
| 
 | ||||
| 	l.err = errors.New("storage: read exceeds limit") | ||||
| 	return n, l.err | ||||
| } | ||||
		Loading…
	
		Reference in New Issue