176 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Go
		
	
	
			
		
		
	
	
			176 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Go
		
	
	
package libtrust
 | 
						|
 | 
						|
import (
 | 
						|
	"crypto/tls"
 | 
						|
	"crypto/x509"
 | 
						|
	"fmt"
 | 
						|
	"io/ioutil"
 | 
						|
	"net"
 | 
						|
	"os"
 | 
						|
	"path"
 | 
						|
	"sync"
 | 
						|
)
 | 
						|
 | 
						|
// ClientKeyManager manages client keys on the filesystem
 | 
						|
type ClientKeyManager struct {
 | 
						|
	key        PrivateKey
 | 
						|
	clientFile string
 | 
						|
	clientDir  string
 | 
						|
 | 
						|
	clientLock sync.RWMutex
 | 
						|
	clients    []PublicKey
 | 
						|
 | 
						|
	configLock sync.Mutex
 | 
						|
	configs    []*tls.Config
 | 
						|
}
 | 
						|
 | 
						|
// NewClientKeyManager loads a new manager from a set of key files
 | 
						|
// and managed by the given private key.
 | 
						|
func NewClientKeyManager(trustKey PrivateKey, clientFile, clientDir string) (*ClientKeyManager, error) {
 | 
						|
	m := &ClientKeyManager{
 | 
						|
		key:        trustKey,
 | 
						|
		clientFile: clientFile,
 | 
						|
		clientDir:  clientDir,
 | 
						|
	}
 | 
						|
	if err := m.loadKeys(); err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	// TODO Start watching file and directory
 | 
						|
 | 
						|
	return m, nil
 | 
						|
}
 | 
						|
 | 
						|
func (c *ClientKeyManager) loadKeys() (err error) {
 | 
						|
	// Load authorized keys file
 | 
						|
	var clients []PublicKey
 | 
						|
	if c.clientFile != "" {
 | 
						|
		clients, err = LoadKeySetFile(c.clientFile)
 | 
						|
		if err != nil {
 | 
						|
			return fmt.Errorf("unable to load authorized keys: %s", err)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	// Add clients from authorized keys directory
 | 
						|
	files, err := ioutil.ReadDir(c.clientDir)
 | 
						|
	if err != nil && !os.IsNotExist(err) {
 | 
						|
		return fmt.Errorf("unable to open authorized keys directory: %s", err)
 | 
						|
	}
 | 
						|
	for _, f := range files {
 | 
						|
		if !f.IsDir() {
 | 
						|
			publicKey, err := LoadPublicKeyFile(path.Join(c.clientDir, f.Name()))
 | 
						|
			if err != nil {
 | 
						|
				return fmt.Errorf("unable to load authorized key file: %s", err)
 | 
						|
			}
 | 
						|
			clients = append(clients, publicKey)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	c.clientLock.Lock()
 | 
						|
	c.clients = clients
 | 
						|
	c.clientLock.Unlock()
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// RegisterTLSConfig registers a tls configuration to manager
 | 
						|
// such that any changes to the keys may be reflected in
 | 
						|
// the tls client CA pool
 | 
						|
func (c *ClientKeyManager) RegisterTLSConfig(tlsConfig *tls.Config) error {
 | 
						|
	c.clientLock.RLock()
 | 
						|
	certPool, err := GenerateCACertPool(c.key, c.clients)
 | 
						|
	if err != nil {
 | 
						|
		return fmt.Errorf("CA pool generation error: %s", err)
 | 
						|
	}
 | 
						|
	c.clientLock.RUnlock()
 | 
						|
 | 
						|
	tlsConfig.ClientCAs = certPool
 | 
						|
 | 
						|
	c.configLock.Lock()
 | 
						|
	c.configs = append(c.configs, tlsConfig)
 | 
						|
	c.configLock.Unlock()
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// NewIdentityAuthTLSConfig creates a tls.Config for the server to use for
 | 
						|
// libtrust identity authentication for the domain specified
 | 
						|
func NewIdentityAuthTLSConfig(trustKey PrivateKey, clients *ClientKeyManager, addr string, domain string) (*tls.Config, error) {
 | 
						|
	tlsConfig := newTLSConfig()
 | 
						|
 | 
						|
	tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
 | 
						|
	if err := clients.RegisterTLSConfig(tlsConfig); err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	// Generate cert
 | 
						|
	ips, domains, err := parseAddr(addr)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	// add domain that it expects clients to use
 | 
						|
	domains = append(domains, domain)
 | 
						|
	x509Cert, err := GenerateSelfSignedServerCert(trustKey, domains, ips)
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("certificate generation error: %s", err)
 | 
						|
	}
 | 
						|
	tlsConfig.Certificates = []tls.Certificate{{
 | 
						|
		Certificate: [][]byte{x509Cert.Raw},
 | 
						|
		PrivateKey:  trustKey.CryptoPrivateKey(),
 | 
						|
		Leaf:        x509Cert,
 | 
						|
	}}
 | 
						|
 | 
						|
	return tlsConfig, nil
 | 
						|
}
 | 
						|
 | 
						|
// NewCertAuthTLSConfig creates a tls.Config for the server to use for
 | 
						|
// certificate authentication
 | 
						|
func NewCertAuthTLSConfig(caPath, certPath, keyPath string) (*tls.Config, error) {
 | 
						|
	tlsConfig := newTLSConfig()
 | 
						|
 | 
						|
	cert, err := tls.LoadX509KeyPair(certPath, keyPath)
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("Couldn't load X509 key pair (%s, %s): %s. Key encrypted?", certPath, keyPath, err)
 | 
						|
	}
 | 
						|
	tlsConfig.Certificates = []tls.Certificate{cert}
 | 
						|
 | 
						|
	// Verify client certificates against a CA?
 | 
						|
	if caPath != "" {
 | 
						|
		certPool := x509.NewCertPool()
 | 
						|
		file, err := ioutil.ReadFile(caPath)
 | 
						|
		if err != nil {
 | 
						|
			return nil, fmt.Errorf("Couldn't read CA certificate: %s", err)
 | 
						|
		}
 | 
						|
		certPool.AppendCertsFromPEM(file)
 | 
						|
 | 
						|
		tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
 | 
						|
		tlsConfig.ClientCAs = certPool
 | 
						|
	}
 | 
						|
 | 
						|
	return tlsConfig, nil
 | 
						|
}
 | 
						|
 | 
						|
// newTLSConfig returns the base TLS configuration shared by the constructors
// in this file: it advertises only "http/1.1" and sets TLS 1.0 as the
// minimum protocol version.
func newTLSConfig() *tls.Config {
	cfg := new(tls.Config)
	cfg.NextProtos = []string{"http/1.1"}
	// Avoid fallback on insecure SSL protocols.
	// NOTE(review): TLS 1.0 is itself deprecated; consider raising this floor
	// once client compatibility has been confirmed.
	cfg.MinVersion = tls.VersionTLS10
	return cfg
}
 | 
						|
 | 
						|
// parseAddr splits a "host:port" address and classifies its host part.
//
// If the host is an IP literal it is returned as the only element of the IP
// slice and the domain slice is nil; otherwise the host is returned as the
// only element of the domain slice and the IP slice is nil. An address that
// net.SplitHostPort cannot split yields that error and two nil slices.
func parseAddr(addr string) ([]net.IP, []string, error) {
	hostname, _, splitErr := net.SplitHostPort(addr)
	if splitErr != nil {
		return nil, nil, splitErr
	}

	if parsed := net.ParseIP(hostname); parsed != nil {
		return []net.IP{parsed}, nil, nil
	}
	return nil, []string{hostname}, nil
}
 |