408 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
			
		
		
	
	
			408 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
package registry
 | 
						|
 | 
						|
import (
 | 
						|
	"bufio"
 | 
						|
	"context"
 | 
						|
	"crypto"
 | 
						|
	"crypto/ecdsa"
 | 
						|
	"crypto/elliptic"
 | 
						|
	"crypto/rand"
 | 
						|
	"crypto/rsa"
 | 
						|
	"crypto/tls"
 | 
						|
	"crypto/x509"
 | 
						|
	"crypto/x509/pkix"
 | 
						|
	"encoding/pem"
 | 
						|
	"fmt"
 | 
						|
	"io/ioutil"
 | 
						|
	"math/big"
 | 
						|
	"net"
 | 
						|
	"net/http"
 | 
						|
	"os"
 | 
						|
	"path"
 | 
						|
	"reflect"
 | 
						|
	"strings"
 | 
						|
	"testing"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/distribution/distribution/v3/configuration"
 | 
						|
	dcontext "github.com/distribution/distribution/v3/context"
 | 
						|
	_ "github.com/distribution/distribution/v3/registry/storage/driver/inmemory"
 | 
						|
	"github.com/sirupsen/logrus"
 | 
						|
	"gopkg.in/yaml.v2"
 | 
						|
)
 | 
						|
 | 
						|
// Tests to ensure nextProtos returns the correct protocols when:
 | 
						|
// * config.HTTP.HTTP2.Disabled is not explicitly set => [h2 http/1.1]
 | 
						|
// * config.HTTP.HTTP2.Disabled is explicitly set to false [h2 http/1.1]
 | 
						|
// * config.HTTP.HTTP2.Disabled is explicitly set to true [http/1.1]
 | 
						|
func TestNextProtos(t *testing.T) {
 | 
						|
	config := &configuration.Configuration{}
 | 
						|
	protos := nextProtos(config)
 | 
						|
	if !reflect.DeepEqual(protos, []string{"h2", "http/1.1"}) {
 | 
						|
		t.Fatalf("expected protos to equal [h2 http/1.1], got %s", protos)
 | 
						|
	}
 | 
						|
	config.HTTP.HTTP2.Disabled = false
 | 
						|
	protos = nextProtos(config)
 | 
						|
	if !reflect.DeepEqual(protos, []string{"h2", "http/1.1"}) {
 | 
						|
		t.Fatalf("expected protos to equal [h2 http/1.1], got %s", protos)
 | 
						|
	}
 | 
						|
	config.HTTP.HTTP2.Disabled = true
 | 
						|
	protos = nextProtos(config)
 | 
						|
	if !reflect.DeepEqual(protos, []string{"http/1.1"}) {
 | 
						|
		t.Fatalf("expected protos to equal [http/1.1], got %s", protos)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
type registryTLSConfig struct {
 | 
						|
	cipherSuites    []string
 | 
						|
	certificatePath string
 | 
						|
	privateKeyPath  string
 | 
						|
	certificate     *tls.Certificate
 | 
						|
}
 | 
						|
 | 
						|
func setupRegistry(tlsCfg *registryTLSConfig, addr string) (*Registry, error) {
 | 
						|
	config := &configuration.Configuration{}
 | 
						|
	// TODO: this needs to change to something ephemeral as the test will fail if there is any server
 | 
						|
	// already listening on port 5000
 | 
						|
	config.HTTP.Addr = addr
 | 
						|
	config.HTTP.DrainTimeout = time.Duration(10) * time.Second
 | 
						|
	if tlsCfg != nil {
 | 
						|
		config.HTTP.TLS.CipherSuites = tlsCfg.cipherSuites
 | 
						|
		config.HTTP.TLS.Certificate = tlsCfg.certificatePath
 | 
						|
		config.HTTP.TLS.Key = tlsCfg.privateKeyPath
 | 
						|
	}
 | 
						|
	config.Storage = map[string]configuration.Parameters{"inmemory": map[string]interface{}{}}
 | 
						|
	return NewRegistry(context.Background(), config)
 | 
						|
}
 | 
						|
 | 
						|
func TestGracefulShutdown(t *testing.T) {
 | 
						|
	registry, err := setupRegistry(nil, ":5000")
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
 | 
						|
	// run registry server
 | 
						|
	var errchan chan error
 | 
						|
	go func() {
 | 
						|
		errchan <- registry.ListenAndServe()
 | 
						|
	}()
 | 
						|
	select {
 | 
						|
	case err = <-errchan:
 | 
						|
		t.Fatalf("Error listening: %v", err)
 | 
						|
	default:
 | 
						|
	}
 | 
						|
 | 
						|
	// Wait for some unknown random time for server to start listening
 | 
						|
	time.Sleep(3 * time.Second)
 | 
						|
 | 
						|
	// send incomplete request
 | 
						|
	conn, err := net.Dial("tcp", "localhost:5000")
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
	fmt.Fprintf(conn, "GET /v2/ ")
 | 
						|
 | 
						|
	// send stop signal
 | 
						|
	quit <- os.Interrupt
 | 
						|
	time.Sleep(100 * time.Millisecond)
 | 
						|
 | 
						|
	// try connecting again. it shouldn't
 | 
						|
	_, err = net.Dial("tcp", "localhost:5000")
 | 
						|
	if err == nil {
 | 
						|
		t.Fatal("Managed to connect after stopping.")
 | 
						|
	}
 | 
						|
 | 
						|
	// make sure earlier request is not disconnected and response can be received
 | 
						|
	fmt.Fprintf(conn, "HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n")
 | 
						|
	resp, err := http.ReadResponse(bufio.NewReader(conn), nil)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
	if resp.Status != "200 OK" {
 | 
						|
		t.Error("response status is not 200 OK: ", resp.Status)
 | 
						|
	}
 | 
						|
	if body, err := ioutil.ReadAll(resp.Body); err != nil || string(body) != "{}" {
 | 
						|
		t.Error("Body is not {}; ", string(body))
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestGetCipherSuite(t *testing.T) {
 | 
						|
	resp, err := getCipherSuites([]string{"TLS_RSA_WITH_AES_128_CBC_SHA"})
 | 
						|
	if err != nil || len(resp) != 1 || resp[0] != tls.TLS_RSA_WITH_AES_128_CBC_SHA {
 | 
						|
		t.Errorf("expected cipher suite %q, got %q",
 | 
						|
			"TLS_RSA_WITH_AES_128_CBC_SHA",
 | 
						|
			strings.Join(getCipherSuiteNames(resp), ","),
 | 
						|
		)
 | 
						|
	}
 | 
						|
 | 
						|
	resp, err = getCipherSuites([]string{"TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_AES_128_GCM_SHA256"})
 | 
						|
	if err != nil || len(resp) != 2 ||
 | 
						|
		resp[0] != tls.TLS_RSA_WITH_AES_128_CBC_SHA || resp[1] != tls.TLS_AES_128_GCM_SHA256 {
 | 
						|
		t.Errorf("expected cipher suites %q, got %q",
 | 
						|
			"TLS_RSA_WITH_AES_128_CBC_SHA,TLS_AES_128_GCM_SHA256",
 | 
						|
			strings.Join(getCipherSuiteNames(resp), ","),
 | 
						|
		)
 | 
						|
	}
 | 
						|
 | 
						|
	_, err = getCipherSuites([]string{"TLS_RSA_WITH_AES_128_CBC_SHA", "bad_input"})
 | 
						|
	if err == nil {
 | 
						|
		t.Error("did not return expected error about unknown cipher suite")
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func buildRegistryTLSConfig(name, keyType string, cipherSuites []string) (*registryTLSConfig, error) {
 | 
						|
	var priv interface{}
 | 
						|
	var pub crypto.PublicKey
 | 
						|
	var err error
 | 
						|
	switch keyType {
 | 
						|
	case "rsa":
 | 
						|
		priv, err = rsa.GenerateKey(rand.Reader, 2048)
 | 
						|
		if err != nil {
 | 
						|
			return nil, fmt.Errorf("failed to create rsa private key: %v", err)
 | 
						|
		}
 | 
						|
		rsaKey := priv.(*rsa.PrivateKey)
 | 
						|
		pub = rsaKey.Public()
 | 
						|
	case "ecdsa":
 | 
						|
		priv, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
 | 
						|
		if err != nil {
 | 
						|
			return nil, fmt.Errorf("failed to create ecdsa private key: %v", err)
 | 
						|
		}
 | 
						|
		ecdsaKey := priv.(*ecdsa.PrivateKey)
 | 
						|
		pub = ecdsaKey.Public()
 | 
						|
	default:
 | 
						|
		return nil, fmt.Errorf("unsupported key type: %v", keyType)
 | 
						|
	}
 | 
						|
 | 
						|
	notBefore := time.Now()
 | 
						|
	notAfter := notBefore.Add(time.Minute)
 | 
						|
	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
 | 
						|
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("failed to create serial number: %v", err)
 | 
						|
	}
 | 
						|
	cert := x509.Certificate{
 | 
						|
		SerialNumber: serialNumber,
 | 
						|
		Subject: pkix.Name{
 | 
						|
			Organization: []string{"registry_test"},
 | 
						|
		},
 | 
						|
		NotBefore:             notBefore,
 | 
						|
		NotAfter:              notAfter,
 | 
						|
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
 | 
						|
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
 | 
						|
		BasicConstraintsValid: true,
 | 
						|
		IPAddresses:           []net.IP{net.ParseIP("127.0.0.1")},
 | 
						|
		DNSNames:              []string{"localhost"},
 | 
						|
		IsCA:                  true,
 | 
						|
	}
 | 
						|
	derBytes, err := x509.CreateCertificate(rand.Reader, &cert, &cert, pub, priv)
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("failed to create certificate: %v", err)
 | 
						|
	}
 | 
						|
	if _, err := os.Stat(os.TempDir()); os.IsNotExist(err) {
 | 
						|
		os.Mkdir(os.TempDir(), 1777)
 | 
						|
	}
 | 
						|
 | 
						|
	certPath := path.Join(os.TempDir(), name+".pem")
 | 
						|
	certOut, err := os.Create(certPath)
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("failed to create pem: %v", err)
 | 
						|
	}
 | 
						|
	if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
 | 
						|
		return nil, fmt.Errorf("failed to write data to %s: %v", certPath, err)
 | 
						|
	}
 | 
						|
	if err := certOut.Close(); err != nil {
 | 
						|
		return nil, fmt.Errorf("error closing %s: %v", certPath, err)
 | 
						|
	}
 | 
						|
 | 
						|
	keyPath := path.Join(os.TempDir(), name+".key")
 | 
						|
	keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("failed to open %s for writing: %v", keyPath, err)
 | 
						|
	}
 | 
						|
	privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("unable to marshal private key: %v", err)
 | 
						|
	}
 | 
						|
	if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil {
 | 
						|
		return nil, fmt.Errorf("failed to write data to key.pem: %v", err)
 | 
						|
	}
 | 
						|
	if err := keyOut.Close(); err != nil {
 | 
						|
		return nil, fmt.Errorf("error closing %s: %v", keyPath, err)
 | 
						|
	}
 | 
						|
 | 
						|
	tlsCert := tls.Certificate{
 | 
						|
		Certificate: [][]byte{derBytes},
 | 
						|
		PrivateKey:  priv,
 | 
						|
	}
 | 
						|
 | 
						|
	tlsTestCfg := registryTLSConfig{
 | 
						|
		cipherSuites:    cipherSuites,
 | 
						|
		certificatePath: certPath,
 | 
						|
		privateKeyPath:  keyPath,
 | 
						|
		certificate:     &tlsCert,
 | 
						|
	}
 | 
						|
 | 
						|
	return &tlsTestCfg, nil
 | 
						|
}
 | 
						|
 | 
						|
func TestRegistrySupportedCipherSuite(t *testing.T) {
 | 
						|
	name := "registry_test_server_supported_cipher"
 | 
						|
	cipherSuites := []string{"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"}
 | 
						|
	serverTLS, err := buildRegistryTLSConfig(name, "rsa", cipherSuites)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
 | 
						|
	registry, err := setupRegistry(serverTLS, ":5001")
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
 | 
						|
	// run registry server
 | 
						|
	var errchan chan error
 | 
						|
	go func() {
 | 
						|
		errchan <- registry.ListenAndServe()
 | 
						|
	}()
 | 
						|
	select {
 | 
						|
	case err = <-errchan:
 | 
						|
		t.Fatalf("Error listening: %v", err)
 | 
						|
	default:
 | 
						|
	}
 | 
						|
 | 
						|
	// Wait for some unknown random time for server to start listening
 | 
						|
	time.Sleep(3 * time.Second)
 | 
						|
 | 
						|
	// send tls request with server supported cipher suite
 | 
						|
	clientCipherSuites, err := getCipherSuites(cipherSuites)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
	clientTLS := tls.Config{
 | 
						|
		InsecureSkipVerify: true,
 | 
						|
		CipherSuites:       clientCipherSuites,
 | 
						|
	}
 | 
						|
	dialer := net.Dialer{
 | 
						|
		Timeout: time.Second * 5,
 | 
						|
	}
 | 
						|
	conn, err := tls.DialWithDialer(&dialer, "tcp", "127.0.0.1:5001", &clientTLS)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
	fmt.Fprintf(conn, "GET /v2/ HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n")
 | 
						|
 | 
						|
	resp, err := http.ReadResponse(bufio.NewReader(conn), nil)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
	if resp.Status != "200 OK" {
 | 
						|
		t.Error("response status is not 200 OK: ", resp.Status)
 | 
						|
	}
 | 
						|
	if body, err := ioutil.ReadAll(resp.Body); err != nil || string(body) != "{}" {
 | 
						|
		t.Error("Body is not {}; ", string(body))
 | 
						|
	}
 | 
						|
 | 
						|
	// send stop signal
 | 
						|
	quit <- os.Interrupt
 | 
						|
	time.Sleep(100 * time.Millisecond)
 | 
						|
}
 | 
						|
 | 
						|
func TestRegistryUnsupportedCipherSuite(t *testing.T) {
 | 
						|
	name := "registry_test_server_unsupported_cipher"
 | 
						|
	serverTLS, err := buildRegistryTLSConfig(name, "rsa", []string{"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA358"})
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
 | 
						|
	registry, err := setupRegistry(serverTLS, ":5002")
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
 | 
						|
	// run registry server
 | 
						|
	var errchan chan error
 | 
						|
	go func() {
 | 
						|
		errchan <- registry.ListenAndServe()
 | 
						|
	}()
 | 
						|
	select {
 | 
						|
	case err = <-errchan:
 | 
						|
		t.Fatalf("Error listening: %v", err)
 | 
						|
	default:
 | 
						|
	}
 | 
						|
 | 
						|
	// Wait for some unknown random time for server to start listening
 | 
						|
	time.Sleep(3 * time.Second)
 | 
						|
 | 
						|
	// send tls request with server unsupported cipher suite
 | 
						|
	clientTLS := tls.Config{
 | 
						|
		InsecureSkipVerify: true,
 | 
						|
		CipherSuites:       []uint16{tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
 | 
						|
	}
 | 
						|
	dialer := net.Dialer{
 | 
						|
		Timeout: time.Second * 5,
 | 
						|
	}
 | 
						|
	_, err = tls.DialWithDialer(&dialer, "tcp", "127.0.0.1:5002", &clientTLS)
 | 
						|
	if err == nil {
 | 
						|
		t.Error("expected TLS connection to timeout")
 | 
						|
	}
 | 
						|
 | 
						|
	// send stop signal
 | 
						|
	quit <- os.Interrupt
 | 
						|
	time.Sleep(100 * time.Millisecond)
 | 
						|
}
 | 
						|
 | 
						|
func TestConfigureLogging(t *testing.T) {
 | 
						|
	yamlConfig := `---
 | 
						|
log:
 | 
						|
  level: warn
 | 
						|
  fields:
 | 
						|
    foo: bar
 | 
						|
    baz: xyzzy
 | 
						|
`
 | 
						|
 | 
						|
	var config configuration.Configuration
 | 
						|
	err := yaml.Unmarshal([]byte(yamlConfig), &config)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal("failed to parse config: ", err)
 | 
						|
	}
 | 
						|
 | 
						|
	ctx, err := configureLogging(context.Background(), &config)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal("failed to configure logging: ", err)
 | 
						|
	}
 | 
						|
 | 
						|
	// Check that the log level was set to Warn.
 | 
						|
	if logrus.IsLevelEnabled(logrus.InfoLevel) {
 | 
						|
		t.Error("expected Info to be disabled, is enabled")
 | 
						|
	}
 | 
						|
 | 
						|
	// Check that the returned context's logger includes the right fields.
 | 
						|
	logger := dcontext.GetLogger(ctx)
 | 
						|
	entry, ok := logger.(*logrus.Entry)
 | 
						|
	if !ok {
 | 
						|
		t.Fatalf("expected logger to be a *logrus.Entry, is: %T", entry)
 | 
						|
	}
 | 
						|
	val, ok := entry.Data["foo"].(string)
 | 
						|
	if !ok || val != "bar" {
 | 
						|
		t.Error("field foo not configured correctly; expected 'bar' got: ", val)
 | 
						|
	}
 | 
						|
	val, ok = entry.Data["baz"].(string)
 | 
						|
	if !ok || val != "xyzzy" {
 | 
						|
		t.Error("field baz not configured correctly; expected 'xyzzy' got: ", val)
 | 
						|
	}
 | 
						|
 | 
						|
	// Get a logger for a new, empty context and make sure it also has the right fields.
 | 
						|
	logger = dcontext.GetLogger(context.Background())
 | 
						|
	entry, ok = logger.(*logrus.Entry)
 | 
						|
	if !ok {
 | 
						|
		t.Fatalf("expected logger to be a *logrus.Entry, is: %T", entry)
 | 
						|
	}
 | 
						|
	val, ok = entry.Data["foo"].(string)
 | 
						|
	if !ok || val != "bar" {
 | 
						|
		t.Error("field foo not configured correctly; expected 'bar' got: ", val)
 | 
						|
	}
 | 
						|
	val, ok = entry.Data["baz"].(string)
 | 
						|
	if !ok || val != "xyzzy" {
 | 
						|
		t.Error("field baz not configured correctly; expected 'xyzzy' got: ", val)
 | 
						|
	}
 | 
						|
}
 |