Merge pull request #2447 from tifayuki/cloudfront-s3-filter
add s3 region filters for cloudfrontmaster
						commit
						f411848591
					
				|  | @ -35,3 +35,4 @@ bin/* | ||||||
| # Editor/IDE specific files. | # Editor/IDE specific files. | ||||||
| *.sublime-project | *.sublime-project | ||||||
| *.sublime-workspace | *.sublime-workspace | ||||||
|  | .idea/* | ||||||
|  |  | ||||||
|  | @ -39,6 +39,8 @@ type Logger interface { | ||||||
| 	Warn(args ...interface{}) | 	Warn(args ...interface{}) | ||||||
| 	Warnf(format string, args ...interface{}) | 	Warnf(format string, args ...interface{}) | ||||||
| 	Warnln(args ...interface{}) | 	Warnln(args ...interface{}) | ||||||
|  | 
 | ||||||
|  | 	WithError(err error) *logrus.Entry | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type loggerKey struct{} | type loggerKey struct{} | ||||||
|  |  | ||||||
|  | @ -183,6 +183,10 @@ middleware: | ||||||
|         privatekey: /path/to/pem |         privatekey: /path/to/pem | ||||||
|         keypairid: cloudfrontkeypairid |         keypairid: cloudfrontkeypairid | ||||||
|         duration: 3000s |         duration: 3000s | ||||||
|  |         ipfilteredby: awsregion | ||||||
|  |         awsregion: us-east-1, use-east-2 | ||||||
|  |         updatefrenquency: 12h | ||||||
|  |         iprangesurl: https://ip-ranges.amazonaws.com/ip-ranges.json | ||||||
|   storage: |   storage: | ||||||
|     - name: redirect |     - name: redirect | ||||||
|       options: |       options: | ||||||
|  | @ -636,6 +640,10 @@ middleware: | ||||||
|         privatekey: /path/to/pem |         privatekey: /path/to/pem | ||||||
|         keypairid: cloudfrontkeypairid |         keypairid: cloudfrontkeypairid | ||||||
|         duration: 3000s |         duration: 3000s | ||||||
|  |         ipfilteredby: awsregion | ||||||
|  |         awsregion: us-east-1, use-east-2 | ||||||
|  |         updatefrenquency: 12h | ||||||
|  |         iprangesurl: https://ip-ranges.amazonaws.com/ip-ranges.json | ||||||
| ``` | ``` | ||||||
| 
 | 
 | ||||||
| Each middleware entry has `name` and `options` entries. The `name` must | Each middleware entry has `name` and `options` entries. The `name` must | ||||||
|  | @ -655,6 +663,14 @@ interpretation of the options. | ||||||
| | `privatekey` | yes   | The private key for Cloudfront, provided by AWS.        | | | `privatekey` | yes   | The private key for Cloudfront, provided by AWS.        | | ||||||
| | `keypairid` | yes    | The key pair ID provided by AWS.                         | | | `keypairid` | yes    | The key pair ID provided by AWS.                         | | ||||||
| | `duration` | no      | An integer and unit for the duration of the Cloudfront session. Valid time units are `ns`, `us` (or `µs`), `ms`, `s`, `m`, or `h`. For example, `3000s` is valid, but `3000 s` is not. If you do not specify a `duration` or you specify an integer without a time unit, the duration defaults to `20m` (20 minutes).| | | `duration` | no      | An integer and unit for the duration of the Cloudfront session. Valid time units are `ns`, `us` (or `µs`), `ms`, `s`, `m`, or `h`. For example, `3000s` is valid, but `3000 s` is not. If you do not specify a `duration` or you specify an integer without a time unit, the duration defaults to `20m` (20 minutes).| | ||||||
|  | |`ipfilteredby`|no     | A string with the following value `none|aws|awsregion`. | | ||||||
|  | |`awsregion`|no        | A comma separated string of AWS regions, only available when `ipfilteredby` is `awsregion`. For example, `us-east-1, us-west-2`| | ||||||
|  | |`updatefrenquency`|no | The frequency to update AWS IP regions, default: `12h`| | ||||||
|  | |`iprangesurl`|no      | The URL contains the AWS IP ranges information, default: `https://ip-ranges.amazonaws.com/ip-ranges.json`| | ||||||
|  | Then value of ipfilteredby: | ||||||
|  | `none`: default, do not filter by IP | ||||||
|  | `aws`: IP from AWS goes to S3 directly | ||||||
|  | `awsregion`: IP from certain AWS regions goes to S3 directly, use together with `awsregion` | ||||||
| 
 | 
 | ||||||
| ### `redirect` | ### `redirect` | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -16,7 +16,7 @@ import ( | ||||||
| 	"github.com/aws/aws-sdk-go/service/cloudfront/sign" | 	"github.com/aws/aws-sdk-go/service/cloudfront/sign" | ||||||
| 	dcontext "github.com/docker/distribution/context" | 	dcontext "github.com/docker/distribution/context" | ||||||
| 	storagedriver "github.com/docker/distribution/registry/storage/driver" | 	storagedriver "github.com/docker/distribution/registry/storage/driver" | ||||||
| 	storagemiddleware "github.com/docker/distribution/registry/storage/driver/middleware" | 	"github.com/docker/distribution/registry/storage/driver/middleware" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // cloudFrontStorageMiddleware provides a simple implementation of layerHandler that
 | // cloudFrontStorageMiddleware provides a simple implementation of layerHandler that
 | ||||||
|  | @ -24,6 +24,7 @@ import ( | ||||||
| // then issues HTTP Temporary Redirects to this CloudFront content URL.
 | // then issues HTTP Temporary Redirects to this CloudFront content URL.
 | ||||||
| type cloudFrontStorageMiddleware struct { | type cloudFrontStorageMiddleware struct { | ||||||
| 	storagedriver.StorageDriver | 	storagedriver.StorageDriver | ||||||
|  | 	awsIPs    *awsIPs | ||||||
| 	urlSigner *sign.URLSigner | 	urlSigner *sign.URLSigner | ||||||
| 	baseURL   string | 	baseURL   string | ||||||
| 	duration  time.Duration | 	duration  time.Duration | ||||||
|  | @ -34,7 +35,13 @@ var _ storagedriver.StorageDriver = &cloudFrontStorageMiddleware{} | ||||||
| // newCloudFrontLayerHandler constructs and returns a new CloudFront
 | // newCloudFrontLayerHandler constructs and returns a new CloudFront
 | ||||||
| // LayerHandler implementation.
 | // LayerHandler implementation.
 | ||||||
| // Required options: baseurl, privatekey, keypairid
 | // Required options: baseurl, privatekey, keypairid
 | ||||||
|  | 
 | ||||||
|  | // Optional options: ipFilteredBy, awsregion
 | ||||||
|  | // ipfilteredby: valid value "none|aws|awsregion". "none", do not filter any IP, default value. "aws", only aws IP goes
 | ||||||
|  | //               to S3 directly. "awsregion", only regions listed in awsregion options goes to S3 directly
 | ||||||
|  | // awsregion: a comma separated string of AWS regions.
 | ||||||
| func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) { | func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) { | ||||||
|  | 	// parse baseurl
 | ||||||
| 	base, ok := options["baseurl"] | 	base, ok := options["baseurl"] | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		return nil, fmt.Errorf("no baseurl provided") | 		return nil, fmt.Errorf("no baseurl provided") | ||||||
|  | @ -52,6 +59,8 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o | ||||||
| 	if _, err := url.Parse(baseURL); err != nil { | 	if _, err := url.Parse(baseURL); err != nil { | ||||||
| 		return nil, fmt.Errorf("invalid baseurl: %v", err) | 		return nil, fmt.Errorf("invalid baseurl: %v", err) | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
|  | 	// parse privatekey to get pkPath
 | ||||||
| 	pk, ok := options["privatekey"] | 	pk, ok := options["privatekey"] | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		return nil, fmt.Errorf("no privatekey provided") | 		return nil, fmt.Errorf("no privatekey provided") | ||||||
|  | @ -60,6 +69,8 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		return nil, fmt.Errorf("privatekey must be a string") | 		return nil, fmt.Errorf("privatekey must be a string") | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
|  | 	// parse keypairid
 | ||||||
| 	kpid, ok := options["keypairid"] | 	kpid, ok := options["keypairid"] | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		return nil, fmt.Errorf("no keypairid provided") | 		return nil, fmt.Errorf("no keypairid provided") | ||||||
|  | @ -69,6 +80,7 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o | ||||||
| 		return nil, fmt.Errorf("keypairid must be a string") | 		return nil, fmt.Errorf("keypairid must be a string") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// get urlSigner from the file specified in pkPath
 | ||||||
| 	pkBytes, err := ioutil.ReadFile(pkPath) | 	pkBytes, err := ioutil.ReadFile(pkPath) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("failed to read privatekey file: %s", err) | 		return nil, fmt.Errorf("failed to read privatekey file: %s", err) | ||||||
|  | @ -82,12 +94,11 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 |  | ||||||
| 	urlSigner := sign.NewURLSigner(keypairID, privateKey) | 	urlSigner := sign.NewURLSigner(keypairID, privateKey) | ||||||
| 
 | 
 | ||||||
|  | 	// parse duration
 | ||||||
| 	duration := 20 * time.Minute | 	duration := 20 * time.Minute | ||||||
| 	d, ok := options["duration"] | 	if d, ok := options["duration"]; ok { | ||||||
| 	if ok { |  | ||||||
| 		switch d := d.(type) { | 		switch d := d.(type) { | ||||||
| 		case time.Duration: | 		case time.Duration: | ||||||
| 			duration = d | 			duration = d | ||||||
|  | @ -100,11 +111,62 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// parse updatefrenquency
 | ||||||
|  | 	updateFrequency := defaultUpdateFrequency | ||||||
|  | 	if u, ok := options["updatefrenquency"]; ok { | ||||||
|  | 		switch u := u.(type) { | ||||||
|  | 		case time.Duration: | ||||||
|  | 			updateFrequency = u | ||||||
|  | 		case string: | ||||||
|  | 			updateFreq, err := time.ParseDuration(u) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return nil, fmt.Errorf("invalid updatefrenquency: %s", err) | ||||||
|  | 			} | ||||||
|  | 			duration = updateFreq | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// parse iprangesurl
 | ||||||
|  | 	ipRangesURL := defaultIPRangesURL | ||||||
|  | 	if i, ok := options["iprangesurl"]; ok { | ||||||
|  | 		if iprangeurl, ok := i.(string); ok { | ||||||
|  | 			ipRangesURL = iprangeurl | ||||||
|  | 		} else { | ||||||
|  | 			return nil, fmt.Errorf("iprangesurl must be a string") | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// parse ipfilteredby
 | ||||||
|  | 	var awsIPs *awsIPs | ||||||
|  | 	if ipFilteredBy := options["ipfilteredby"].(string); ok { | ||||||
|  | 		switch strings.ToLower(strings.TrimSpace(ipFilteredBy)) { | ||||||
|  | 		case "", "none": | ||||||
|  | 			awsIPs = nil | ||||||
|  | 		case "aws": | ||||||
|  | 			newAWSIPs(ipRangesURL, updateFrequency, nil) | ||||||
|  | 		case "awsregion": | ||||||
|  | 			var awsRegion []string | ||||||
|  | 			if regions, ok := options["awsregion"].(string); ok { | ||||||
|  | 				for _, awsRegions := range strings.Split(regions, ",") { | ||||||
|  | 					awsRegion = append(awsRegion, strings.ToLower(strings.TrimSpace(awsRegions))) | ||||||
|  | 				} | ||||||
|  | 				awsIPs = newAWSIPs(ipRangesURL, updateFrequency, awsRegion) | ||||||
|  | 			} else { | ||||||
|  | 				return nil, fmt.Errorf("awsRegion must be a comma separated string of valid aws regions") | ||||||
|  | 			} | ||||||
|  | 		default: | ||||||
|  | 			return nil, fmt.Errorf("ipfilteredby only allows a string the following value: none|aws|awsregion") | ||||||
|  | 		} | ||||||
|  | 	} else { | ||||||
|  | 		return nil, fmt.Errorf("ipfilteredby only allows a string with the following value: none|aws|awsregion") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	return &cloudFrontStorageMiddleware{ | 	return &cloudFrontStorageMiddleware{ | ||||||
| 		StorageDriver: storageDriver, | 		StorageDriver: storageDriver, | ||||||
| 		urlSigner:     urlSigner, | 		urlSigner:     urlSigner, | ||||||
| 		baseURL:       baseURL, | 		baseURL:       baseURL, | ||||||
| 		duration:      duration, | 		duration:      duration, | ||||||
|  | 		awsIPs:        awsIPs, | ||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -114,8 +176,8 @@ type S3BucketKeyer interface { | ||||||
| 	S3BucketKey(path string) string | 	S3BucketKey(path string) string | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Resolve returns an http.Handler which can serve the contents of the given
 | // URLFor attempts to find a url which may be used to retrieve the file at the given path.
 | ||||||
| // Layer, or an error if not supported by the storagedriver.
 | // Returns an error if the file cannot be found.
 | ||||||
| func (lh *cloudFrontStorageMiddleware) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) { | func (lh *cloudFrontStorageMiddleware) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) { | ||||||
| 	// TODO(endophage): currently only supports S3
 | 	// TODO(endophage): currently only supports S3
 | ||||||
| 	keyer, ok := lh.StorageDriver.(S3BucketKeyer) | 	keyer, ok := lh.StorageDriver.(S3BucketKeyer) | ||||||
|  | @ -124,6 +186,11 @@ func (lh *cloudFrontStorageMiddleware) URLFor(ctx context.Context, path string, | ||||||
| 		return lh.StorageDriver.URLFor(ctx, path, options) | 		return lh.StorageDriver.URLFor(ctx, path, options) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	if eligibleForS3(ctx, lh.awsIPs) { | ||||||
|  | 		return lh.StorageDriver.URLFor(ctx, path, options) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Get signed cloudfront url.
 | ||||||
| 	cfURL, err := lh.urlSigner.Sign(lh.baseURL+keyer.S3BucketKey(path), time.Now().Add(lh.duration)) | 	cfURL, err := lh.urlSigner.Sign(lh.baseURL+keyer.S3BucketKey(path), time.Now().Add(lh.duration)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return "", err | ||||||
|  |  | ||||||
|  | @ -0,0 +1,223 @@ | ||||||
|  | package middleware | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"fmt" | ||||||
|  | 	"io/ioutil" | ||||||
|  | 	"net" | ||||||
|  | 	"net/http" | ||||||
|  | 	"strings" | ||||||
|  | 	"sync" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	dcontext "github.com/docker/distribution/context" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | const ( | ||||||
|  | 	// ipRangesURL is the URL to get definition of AWS IPs
 | ||||||
|  | 	defaultIPRangesURL = "https://ip-ranges.amazonaws.com/ip-ranges.json" | ||||||
|  | 	// updateFrequency tells how frequently AWS IPs need to be updated
 | ||||||
|  | 	defaultUpdateFrequency = time.Hour * 12 | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // newAWSIPs returns a New awsIP object.
 | ||||||
|  | // If awsRegion is `nil`, it accepts any region. Otherwise, it only allow the regions specified
 | ||||||
|  | func newAWSIPs(host string, updateFrequency time.Duration, awsRegion []string) *awsIPs { | ||||||
|  | 	ips := &awsIPs{ | ||||||
|  | 		host:            host, | ||||||
|  | 		updateFrequency: updateFrequency, | ||||||
|  | 		awsRegion:       awsRegion, | ||||||
|  | 		updaterStopChan: make(chan bool), | ||||||
|  | 	} | ||||||
|  | 	if err := ips.tryUpdate(); err != nil { | ||||||
|  | 		dcontext.GetLogger(context.Background()).WithError(err).Warn("failed to update AWS IP") | ||||||
|  | 	} | ||||||
|  | 	go ips.updater() | ||||||
|  | 	return ips | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // awsIPs tracks a list of AWS ips, filtered by awsRegion
 | ||||||
|  | type awsIPs struct { | ||||||
|  | 	host            string | ||||||
|  | 	updateFrequency time.Duration | ||||||
|  | 	ipv4            []net.IPNet | ||||||
|  | 	ipv6            []net.IPNet | ||||||
|  | 	mutex           sync.RWMutex | ||||||
|  | 	awsRegion       []string | ||||||
|  | 	updaterStopChan chan bool | ||||||
|  | 	initialized     bool | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type awsIPResponse struct { | ||||||
|  | 	Prefixes   []prefixEntry `json:"prefixes"` | ||||||
|  | 	V6Prefixes []prefixEntry `json:"ipv6_prefixes"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type prefixEntry struct { | ||||||
|  | 	IPV4Prefix string `json:"ip_prefix"` | ||||||
|  | 	IPV6Prefix string `json:"ipv6_prefix"` | ||||||
|  | 	Region     string `json:"region"` | ||||||
|  | 	Service    string `json:"service"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func fetchAWSIPs(url string) (awsIPResponse, error) { | ||||||
|  | 	var response awsIPResponse | ||||||
|  | 	resp, err := http.Get(url) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return response, err | ||||||
|  | 	} | ||||||
|  | 	if resp.StatusCode != 200 { | ||||||
|  | 		body, _ := ioutil.ReadAll(resp.Body) | ||||||
|  | 		return response, fmt.Errorf("failed to fetch network data. response = %s", body) | ||||||
|  | 	} | ||||||
|  | 	decoder := json.NewDecoder(resp.Body) | ||||||
|  | 	err = decoder.Decode(&response) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return response, err | ||||||
|  | 	} | ||||||
|  | 	return response, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // tryUpdate attempts to download the new set of ip addresses.
 | ||||||
|  | // tryUpdate must be thread safe with contains
 | ||||||
|  | func (s *awsIPs) tryUpdate() error { | ||||||
|  | 	response, err := fetchAWSIPs(s.host) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var ipv4 []net.IPNet | ||||||
|  | 	var ipv6 []net.IPNet | ||||||
|  | 
 | ||||||
|  | 	processAddress := func(output *[]net.IPNet, prefix string, region string) { | ||||||
|  | 		regionAllowed := false | ||||||
|  | 		if len(s.awsRegion) > 0 { | ||||||
|  | 			for _, ar := range s.awsRegion { | ||||||
|  | 				if strings.ToLower(region) == ar { | ||||||
|  | 					regionAllowed = true | ||||||
|  | 					break | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} else { | ||||||
|  | 			regionAllowed = true | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		_, network, err := net.ParseCIDR(prefix) | ||||||
|  | 		if err != nil { | ||||||
|  | 			dcontext.GetLoggerWithFields(dcontext.Background(), map[interface{}]interface{}{ | ||||||
|  | 				"cidr": prefix, | ||||||
|  | 			}).Error("unparseable cidr") | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		if regionAllowed { | ||||||
|  | 			*output = append(*output, *network) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, prefix := range response.Prefixes { | ||||||
|  | 		processAddress(&ipv4, prefix.IPV4Prefix, prefix.Region) | ||||||
|  | 	} | ||||||
|  | 	for _, prefix := range response.V6Prefixes { | ||||||
|  | 		processAddress(&ipv6, prefix.IPV6Prefix, prefix.Region) | ||||||
|  | 	} | ||||||
|  | 	s.mutex.Lock() | ||||||
|  | 	defer s.mutex.Unlock() | ||||||
|  | 	// Update each attr of awsips atomically.
 | ||||||
|  | 	s.ipv4 = ipv4 | ||||||
|  | 	s.ipv6 = ipv6 | ||||||
|  | 	s.initialized = true | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // This function is meant to be run in a background goroutine.
 | ||||||
|  | // It will periodically update the ips from aws.
 | ||||||
|  | func (s *awsIPs) updater() { | ||||||
|  | 	defer close(s.updaterStopChan) | ||||||
|  | 	for { | ||||||
|  | 		time.Sleep(s.updateFrequency) | ||||||
|  | 		select { | ||||||
|  | 		case <-s.updaterStopChan: | ||||||
|  | 			dcontext.GetLogger(context.Background()).Info("aws ip updater received stop signal") | ||||||
|  | 			return | ||||||
|  | 		default: | ||||||
|  | 			err := s.tryUpdate() | ||||||
|  | 			if err != nil { | ||||||
|  | 				dcontext.GetLogger(context.Background()).WithError(err).Error("git  AWS IP") | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // getCandidateNetworks returns either the ipv4 or ipv6 networks
 | ||||||
|  | // that were last read from aws. The networks returned
 | ||||||
|  | // have the same type as the ip address provided.
 | ||||||
|  | func (s *awsIPs) getCandidateNetworks(ip net.IP) []net.IPNet { | ||||||
|  | 	s.mutex.RLock() | ||||||
|  | 	defer s.mutex.RUnlock() | ||||||
|  | 	if ip.To4() != nil { | ||||||
|  | 		return s.ipv4 | ||||||
|  | 	} else if ip.To16() != nil { | ||||||
|  | 		return s.ipv6 | ||||||
|  | 	} else { | ||||||
|  | 		dcontext.GetLoggerWithFields(dcontext.Background(), map[interface{}]interface{}{ | ||||||
|  | 			"ip": ip, | ||||||
|  | 		}).Error("unknown ip address format") | ||||||
|  | 		// assume mismatch, pass through cloudfront
 | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Contains determines whether the host is within aws.
 | ||||||
|  | func (s *awsIPs) contains(ip net.IP) bool { | ||||||
|  | 	networks := s.getCandidateNetworks(ip) | ||||||
|  | 	for _, network := range networks { | ||||||
|  | 		if network.Contains(ip) { | ||||||
|  | 			return true | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // parseIPFromRequest attempts to extract the ip address of the
 | ||||||
|  | // client that made the request
 | ||||||
|  | func parseIPFromRequest(ctx context.Context) (net.IP, error) { | ||||||
|  | 	request, err := dcontext.GetRequest(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	ipStr := dcontext.RemoteIP(request) | ||||||
|  | 	ip := net.ParseIP(ipStr) | ||||||
|  | 	if ip == nil { | ||||||
|  | 		return nil, fmt.Errorf("invalid ip address from requester: %s", ipStr) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return ip, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // eligibleForS3 checks if a request is eligible for using S3 directly
 | ||||||
|  | // Return true only when the IP belongs to a specific aws region and user-agent is docker
 | ||||||
|  | func eligibleForS3(ctx context.Context, awsIPs *awsIPs) bool { | ||||||
|  | 	if awsIPs != nil && awsIPs.initialized { | ||||||
|  | 		if addr, err := parseIPFromRequest(ctx); err == nil { | ||||||
|  | 			request, err := dcontext.GetRequest(ctx) | ||||||
|  | 			if err != nil { | ||||||
|  | 				dcontext.GetLogger(ctx).Warnf("the CloudFront middleware cannot parse the request: %s", err) | ||||||
|  | 			} else { | ||||||
|  | 				loggerField := map[interface{}]interface{}{ | ||||||
|  | 					"user-client": request.UserAgent(), | ||||||
|  | 					"ip":          dcontext.RemoteIP(request), | ||||||
|  | 				} | ||||||
|  | 				if awsIPs.contains(addr) { | ||||||
|  | 					dcontext.GetLoggerWithFields(ctx, loggerField).Info("request from the allowed AWS region, skipping CloudFront") | ||||||
|  | 					return true | ||||||
|  | 				} | ||||||
|  | 				dcontext.GetLoggerWithFields(ctx, loggerField).Warn("request not from the allowed AWS region, fallback to CloudFront") | ||||||
|  | 			} | ||||||
|  | 		} else { | ||||||
|  | 			dcontext.GetLogger(ctx).WithError(err).Warn("failed to parse ip address from context, fallback to CloudFront") | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  | @ -0,0 +1,401 @@ | ||||||
|  | package middleware | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"crypto/rand" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"fmt" | ||||||
|  | 	dcontext "github.com/docker/distribution/context" | ||||||
|  | 	"net" | ||||||
|  | 	"net/http" | ||||||
|  | 	"net/http/httptest" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"reflect" // used as a replacement for testify
 | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // Rather than pull in all of testify
 | ||||||
|  | func assertEqual(t *testing.T, x, y interface{}) { | ||||||
|  | 	if !reflect.DeepEqual(x, y) { | ||||||
|  | 		t.Errorf("%s: Not equal! Expected='%v', Actual='%v'\n", t.Name(), x, y) | ||||||
|  | 		t.FailNow() | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type mockIPRangeHandler struct { | ||||||
|  | 	data awsIPResponse | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m mockIPRangeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { | ||||||
|  | 	bytes, err := json.Marshal(m.data) | ||||||
|  | 	if err != nil { | ||||||
|  | 		w.WriteHeader(500) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	w.Write(bytes) | ||||||
|  | 
 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func newTestHandler(data awsIPResponse) *httptest.Server { | ||||||
|  | 	return httptest.NewServer(mockIPRangeHandler{ | ||||||
|  | 		data: data, | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func serverIPRanges(server *httptest.Server) string { | ||||||
|  | 	return fmt.Sprintf("%s/", server.URL) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func setupTest(data awsIPResponse) *httptest.Server { | ||||||
|  | 	// This is a basic schema which only claims the exact ip
 | ||||||
|  | 	// is in aws.
 | ||||||
|  | 	server := newTestHandler(data) | ||||||
|  | 	return server | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestS3TryUpdate(t *testing.T) { | ||||||
|  | 	t.Parallel() | ||||||
|  | 	server := setupTest(awsIPResponse{ | ||||||
|  | 		Prefixes: []prefixEntry{ | ||||||
|  | 			{IPV4Prefix: "123.231.123.231/32"}, | ||||||
|  | 		}, | ||||||
|  | 	}) | ||||||
|  | 	defer server.Close() | ||||||
|  | 
 | ||||||
|  | 	ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) | ||||||
|  | 
 | ||||||
|  | 	assertEqual(t, 1, len(ips.ipv4)) | ||||||
|  | 	assertEqual(t, 0, len(ips.ipv6)) | ||||||
|  | 
 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestMatchIPV6(t *testing.T) { | ||||||
|  | 	t.Parallel() | ||||||
|  | 	server := setupTest(awsIPResponse{ | ||||||
|  | 		V6Prefixes: []prefixEntry{ | ||||||
|  | 			{IPV6Prefix: "ff00::/16"}, | ||||||
|  | 		}, | ||||||
|  | 	}) | ||||||
|  | 	defer server.Close() | ||||||
|  | 
 | ||||||
|  | 	ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) | ||||||
|  | 	ips.tryUpdate() | ||||||
|  | 	assertEqual(t, true, ips.contains(net.ParseIP("ff00::"))) | ||||||
|  | 	assertEqual(t, 1, len(ips.ipv6)) | ||||||
|  | 	assertEqual(t, 0, len(ips.ipv4)) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestMatchIPV4(t *testing.T) { | ||||||
|  | 	t.Parallel() | ||||||
|  | 	server := setupTest(awsIPResponse{ | ||||||
|  | 		Prefixes: []prefixEntry{ | ||||||
|  | 			{IPV4Prefix: "192.168.0.0/24"}, | ||||||
|  | 		}, | ||||||
|  | 	}) | ||||||
|  | 	defer server.Close() | ||||||
|  | 
 | ||||||
|  | 	ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) | ||||||
|  | 	ips.tryUpdate() | ||||||
|  | 	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0"))) | ||||||
|  | 	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1"))) | ||||||
|  | 	assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestMatchIPV4_2(t *testing.T) { | ||||||
|  | 	t.Parallel() | ||||||
|  | 	server := setupTest(awsIPResponse{ | ||||||
|  | 		Prefixes: []prefixEntry{ | ||||||
|  | 			{ | ||||||
|  | 				IPV4Prefix: "192.168.0.0/24", | ||||||
|  | 				Region:     "us-east-1", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	}) | ||||||
|  | 	defer server.Close() | ||||||
|  | 
 | ||||||
|  | 	ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) | ||||||
|  | 	ips.tryUpdate() | ||||||
|  | 	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0"))) | ||||||
|  | 	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1"))) | ||||||
|  | 	assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestMatchIPV4WithRegionMatched(t *testing.T) { | ||||||
|  | 	t.Parallel() | ||||||
|  | 	server := setupTest(awsIPResponse{ | ||||||
|  | 		Prefixes: []prefixEntry{ | ||||||
|  | 			{ | ||||||
|  | 				IPV4Prefix: "192.168.0.0/24", | ||||||
|  | 				Region:     "us-east-1", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	}) | ||||||
|  | 	defer server.Close() | ||||||
|  | 
 | ||||||
|  | 	ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-east-1"}) | ||||||
|  | 	ips.tryUpdate() | ||||||
|  | 	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0"))) | ||||||
|  | 	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1"))) | ||||||
|  | 	assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestMatchIPV4WithRegionMatch_2(t *testing.T) { | ||||||
|  | 	t.Parallel() | ||||||
|  | 	server := setupTest(awsIPResponse{ | ||||||
|  | 		Prefixes: []prefixEntry{ | ||||||
|  | 			{ | ||||||
|  | 				IPV4Prefix: "192.168.0.0/24", | ||||||
|  | 				Region:     "us-east-1", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	}) | ||||||
|  | 	defer server.Close() | ||||||
|  | 
 | ||||||
|  | 	ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-west-2", "us-east-1"}) | ||||||
|  | 	ips.tryUpdate() | ||||||
|  | 	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0"))) | ||||||
|  | 	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1"))) | ||||||
|  | 	assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestMatchIPV4WithRegionNotMatched(t *testing.T) { | ||||||
|  | 	t.Parallel() | ||||||
|  | 	server := setupTest(awsIPResponse{ | ||||||
|  | 		Prefixes: []prefixEntry{ | ||||||
|  | 			{ | ||||||
|  | 				IPV4Prefix: "192.168.0.0/24", | ||||||
|  | 				Region:     "us-east-1", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	}) | ||||||
|  | 	defer server.Close() | ||||||
|  | 
 | ||||||
|  | 	ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-west-2"}) | ||||||
|  | 	ips.tryUpdate() | ||||||
|  | 	assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.0"))) | ||||||
|  | 	assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.1"))) | ||||||
|  | 	assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestInvalidData(t *testing.T) { | ||||||
|  | 	t.Parallel() | ||||||
|  | 	// Invalid entries from aws should be ignored.
 | ||||||
|  | 	server := setupTest(awsIPResponse{ | ||||||
|  | 		Prefixes: []prefixEntry{ | ||||||
|  | 			{IPV4Prefix: "9000"}, | ||||||
|  | 			{IPV4Prefix: "192.168.0.0/24"}, | ||||||
|  | 		}, | ||||||
|  | 	}) | ||||||
|  | 	defer server.Close() | ||||||
|  | 
 | ||||||
|  | 	ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) | ||||||
|  | 	ips.tryUpdate() | ||||||
|  | 	assertEqual(t, 1, len(ips.ipv4)) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestInvalidNetworkType(t *testing.T) { | ||||||
|  | 	t.Parallel() | ||||||
|  | 	server := setupTest(awsIPResponse{ | ||||||
|  | 		Prefixes: []prefixEntry{ | ||||||
|  | 			{IPV4Prefix: "192.168.0.0/24"}, | ||||||
|  | 		}, | ||||||
|  | 		V6Prefixes: []prefixEntry{ | ||||||
|  | 			{IPV6Prefix: "ff00::/8"}, | ||||||
|  | 			{IPV6Prefix: "fe00::/8"}, | ||||||
|  | 		}, | ||||||
|  | 	}) | ||||||
|  | 	defer server.Close() | ||||||
|  | 
 | ||||||
|  | 	ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) | ||||||
|  | 	assertEqual(t, 0, len(ips.getCandidateNetworks(make([]byte, 17)))) // 17 bytes does not correspond to any net type
 | ||||||
|  | 	assertEqual(t, 1, len(ips.getCandidateNetworks(make([]byte, 4))))  // netv4 networks
 | ||||||
|  | 	assertEqual(t, 2, len(ips.getCandidateNetworks(make([]byte, 16)))) // netv6 networks
 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestParsing(t *testing.T) { | ||||||
|  | 	var data = `{ | ||||||
|  |       "prefixes": [{ | ||||||
|  |         "ip_prefix": "192.168.0.0", | ||||||
|  |         "region": "someregion", | ||||||
|  |         "service": "s3"}], | ||||||
|  |       "ipv6_prefixes": [{ | ||||||
|  |         "ipv6_prefix": "2001:4860:4860::8888", | ||||||
|  |         "region": "anotherregion", | ||||||
|  |         "service": "ec2"}] | ||||||
|  |     }` | ||||||
|  | 	rawMockHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte(data)) }) | ||||||
|  | 	t.Parallel() | ||||||
|  | 	server := httptest.NewServer(rawMockHandler) | ||||||
|  | 	defer server.Close() | ||||||
|  | 	schema, err := fetchAWSIPs(server.URL) | ||||||
|  | 
 | ||||||
|  | 	assertEqual(t, nil, err) | ||||||
|  | 	assertEqual(t, 1, len(schema.Prefixes)) | ||||||
|  | 	assertEqual(t, prefixEntry{ | ||||||
|  | 		IPV4Prefix: "192.168.0.0", | ||||||
|  | 		Region:     "someregion", | ||||||
|  | 		Service:    "s3", | ||||||
|  | 	}, schema.Prefixes[0]) | ||||||
|  | 	assertEqual(t, 1, len(schema.V6Prefixes)) | ||||||
|  | 	assertEqual(t, prefixEntry{ | ||||||
|  | 		IPV6Prefix: "2001:4860:4860::8888", | ||||||
|  | 		Region:     "anotherregion", | ||||||
|  | 		Service:    "ec2", | ||||||
|  | 	}, schema.V6Prefixes[0]) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestUpdateCalledRegularly(t *testing.T) { | ||||||
|  | 	t.Parallel() | ||||||
|  | 
 | ||||||
|  | 	updateCount := 0 | ||||||
|  | 	server := httptest.NewServer(http.HandlerFunc( | ||||||
|  | 		func(rw http.ResponseWriter, req *http.Request) { | ||||||
|  | 			updateCount++ | ||||||
|  | 			rw.Write([]byte("ok")) | ||||||
|  | 		})) | ||||||
|  | 	defer server.Close() | ||||||
|  | 	newAWSIPs(fmt.Sprintf("%s/", server.URL), time.Second, nil) | ||||||
|  | 	time.Sleep(time.Second*4 + time.Millisecond*500) | ||||||
|  | 	if updateCount < 4 { | ||||||
|  | 		t.Errorf("Update should have been called at least 4 times, actual=%d", updateCount) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestEligibleForS3(t *testing.T) { | ||||||
|  | 	awsIPs := &awsIPs{ | ||||||
|  | 		ipv4: []net.IPNet{{ | ||||||
|  | 			IP:   net.ParseIP("192.168.1.1"), | ||||||
|  | 			Mask: net.IPv4Mask(255, 255, 255, 0), | ||||||
|  | 		}}, | ||||||
|  | 		initialized: true, | ||||||
|  | 	} | ||||||
|  | 	empty := context.TODO() | ||||||
|  | 	makeContext := func(ip string) context.Context { | ||||||
|  | 		req := &http.Request{ | ||||||
|  | 			RemoteAddr: ip, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return dcontext.WithRequest(empty, req) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	cases := []struct { | ||||||
|  | 		Context  context.Context | ||||||
|  | 		Expected bool | ||||||
|  | 	}{ | ||||||
|  | 		{Context: empty, Expected: false}, | ||||||
|  | 		{Context: makeContext("192.168.1.2"), Expected: true}, | ||||||
|  | 		{Context: makeContext("192.168.0.2"), Expected: false}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, testCase := range cases { | ||||||
|  | 		name := fmt.Sprintf("Client IP = %v", | ||||||
|  | 			testCase.Context.Value("http.request.ip")) | ||||||
|  | 		t.Run(name, func(t *testing.T) { | ||||||
|  | 			assertEqual(t, testCase.Expected, eligibleForS3(testCase.Context, awsIPs)) | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestEligibleForS3WithAWSIPNotInitialized(t *testing.T) { | ||||||
|  | 	awsIPs := &awsIPs{ | ||||||
|  | 		ipv4: []net.IPNet{{ | ||||||
|  | 			IP:   net.ParseIP("192.168.1.1"), | ||||||
|  | 			Mask: net.IPv4Mask(255, 255, 255, 0), | ||||||
|  | 		}}, | ||||||
|  | 		initialized: false, | ||||||
|  | 	} | ||||||
|  | 	empty := context.TODO() | ||||||
|  | 	makeContext := func(ip string) context.Context { | ||||||
|  | 		req := &http.Request{ | ||||||
|  | 			RemoteAddr: ip, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return dcontext.WithRequest(empty, req) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	cases := []struct { | ||||||
|  | 		Context  context.Context | ||||||
|  | 		Expected bool | ||||||
|  | 	}{ | ||||||
|  | 		{Context: empty, Expected: false}, | ||||||
|  | 		{Context: makeContext("192.168.1.2"), Expected: false}, | ||||||
|  | 		{Context: makeContext("192.168.0.2"), Expected: false}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, testCase := range cases { | ||||||
|  | 		name := fmt.Sprintf("Client IP = %v", | ||||||
|  | 			testCase.Context.Value("http.request.ip")) | ||||||
|  | 		t.Run(name, func(t *testing.T) { | ||||||
|  | 			assertEqual(t, testCase.Expected, eligibleForS3(testCase.Context, awsIPs)) | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // populate ips with a number of different ipv4 and ipv6 networks, for the purposes
 | ||||||
|  | // of benchmarking contains() performance.
 | ||||||
|  | func populateRandomNetworks(b *testing.B, ips *awsIPs, ipv4Count, ipv6Count int) { | ||||||
|  | 	generateNetworks := func(dest *[]net.IPNet, bytes int, count int) { | ||||||
|  | 		for i := 0; i < count; i++ { | ||||||
|  | 			ip := make([]byte, bytes) | ||||||
|  | 			_, err := rand.Read(ip) | ||||||
|  | 			if err != nil { | ||||||
|  | 				b.Fatalf("failed to generate network for test : %s", err.Error()) | ||||||
|  | 			} | ||||||
|  | 			mask := make([]byte, bytes) | ||||||
|  | 			for i := 0; i < bytes; i++ { | ||||||
|  | 				mask[i] = 0xff | ||||||
|  | 			} | ||||||
|  | 			*dest = append(*dest, net.IPNet{ | ||||||
|  | 				IP:   ip, | ||||||
|  | 				Mask: mask, | ||||||
|  | 			}) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	generateNetworks(&ips.ipv4, 4, ipv4Count) | ||||||
|  | 	generateNetworks(&ips.ipv6, 16, ipv6Count) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func BenchmarkContainsRandom(b *testing.B) { | ||||||
|  | 	// Generate a random network configuration, of size comparable to
 | ||||||
|  | 	// aws official networks list
 | ||||||
|  | 	// curl -s https://ip-ranges.amazonaws.com/ip-ranges.json | jq '.prefixes | length'
 | ||||||
|  | 	// 941
 | ||||||
|  | 	numNetworksPerType := 1000 // keep in sync with the above
 | ||||||
|  | 	// intentionally skip constructor when creating awsIPs, to avoid updater routine.
 | ||||||
|  | 	// This benchmark is only concerned with contains() performance.
 | ||||||
|  | 	awsIPs := awsIPs{} | ||||||
|  | 	populateRandomNetworks(b, &awsIPs, numNetworksPerType, numNetworksPerType) | ||||||
|  | 
 | ||||||
|  | 	ipv4 := make([][]byte, b.N) | ||||||
|  | 	ipv6 := make([][]byte, b.N) | ||||||
|  | 	for i := 0; i < b.N; i++ { | ||||||
|  | 		ipv4[i] = make([]byte, 4) | ||||||
|  | 		ipv6[i] = make([]byte, 16) | ||||||
|  | 		rand.Read(ipv4[i]) | ||||||
|  | 		rand.Read(ipv6[i]) | ||||||
|  | 	} | ||||||
|  | 	b.ResetTimer() | ||||||
|  | 	for i := 0; i < b.N; i++ { | ||||||
|  | 		awsIPs.contains(ipv4[i]) | ||||||
|  | 		awsIPs.contains(ipv6[i]) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func BenchmarkContainsProd(b *testing.B) { | ||||||
|  | 	awsIPs := newAWSIPs(defaultIPRangesURL, defaultUpdateFrequency, nil) | ||||||
|  | 	ipv4 := make([][]byte, b.N) | ||||||
|  | 	ipv6 := make([][]byte, b.N) | ||||||
|  | 	for i := 0; i < b.N; i++ { | ||||||
|  | 		ipv4[i] = make([]byte, 4) | ||||||
|  | 		ipv6[i] = make([]byte, 16) | ||||||
|  | 		rand.Read(ipv4[i]) | ||||||
|  | 		rand.Read(ipv6[i]) | ||||||
|  | 	} | ||||||
|  | 	b.ResetTimer() | ||||||
|  | 	for i := 0; i < b.N; i++ { | ||||||
|  | 		awsIPs.contains(ipv4[i]) | ||||||
|  | 		awsIPs.contains(ipv6[i]) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | @ -1,340 +0,0 @@ | ||||||
| //+build ignore
 |  | ||||||
| 
 |  | ||||||
| // msg_generate.go is meant to run with go generate. It will use
 |  | ||||||
| // go/{importer,types} to track down all the RR struct types. Then for each type
 |  | ||||||
| // it will generate pack/unpack methods based on the struct tags. The generated source is
 |  | ||||||
| // written to zmsg.go, and is meant to be checked into git.
 |  | ||||||
| package main |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"bytes" |  | ||||||
| 	"fmt" |  | ||||||
| 	"go/format" |  | ||||||
| 	"go/importer" |  | ||||||
| 	"go/types" |  | ||||||
| 	"log" |  | ||||||
| 	"os" |  | ||||||
| 	"strings" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| var packageHdr = ` |  | ||||||
| // *** DO NOT MODIFY ***
 |  | ||||||
| // AUTOGENERATED BY go generate from msg_generate.go
 |  | ||||||
| 
 |  | ||||||
| package dns |  | ||||||
| 
 |  | ||||||
| ` |  | ||||||
| 
 |  | ||||||
| // getTypeStruct will take a type and the package scope, and return the
 |  | ||||||
| // (innermost) struct if the type is considered a RR type (currently defined as
 |  | ||||||
| // those structs beginning with a RR_Header, could be redefined as implementing
 |  | ||||||
| // the RR interface). The bool return value indicates if embedded structs were
 |  | ||||||
| // resolved.
 |  | ||||||
| func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) { |  | ||||||
| 	st, ok := t.Underlying().(*types.Struct) |  | ||||||
| 	if !ok { |  | ||||||
| 		return nil, false |  | ||||||
| 	} |  | ||||||
| 	if st.Field(0).Type() == scope.Lookup("RR_Header").Type() { |  | ||||||
| 		return st, false |  | ||||||
| 	} |  | ||||||
| 	if st.Field(0).Anonymous() { |  | ||||||
| 		st, _ := getTypeStruct(st.Field(0).Type(), scope) |  | ||||||
| 		return st, true |  | ||||||
| 	} |  | ||||||
| 	return nil, false |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func main() { |  | ||||||
| 	// Import and type-check the package
 |  | ||||||
| 	pkg, err := importer.Default().Import("github.com/miekg/dns") |  | ||||||
| 	fatalIfErr(err) |  | ||||||
| 	scope := pkg.Scope() |  | ||||||
| 
 |  | ||||||
| 	// Collect actual types (*X)
 |  | ||||||
| 	var namedTypes []string |  | ||||||
| 	for _, name := range scope.Names() { |  | ||||||
| 		o := scope.Lookup(name) |  | ||||||
| 		if o == nil || !o.Exported() { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		if st, _ := getTypeStruct(o.Type(), scope); st == nil { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		if name == "PrivateRR" { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		// Check if corresponding TypeX exists
 |  | ||||||
| 		if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" { |  | ||||||
| 			log.Fatalf("Constant Type%s does not exist.", o.Name()) |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		namedTypes = append(namedTypes, o.Name()) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	b := &bytes.Buffer{} |  | ||||||
| 	b.WriteString(packageHdr) |  | ||||||
| 
 |  | ||||||
| 	fmt.Fprint(b, "// pack*() functions\n\n") |  | ||||||
| 	for _, name := range namedTypes { |  | ||||||
| 		o := scope.Lookup(name) |  | ||||||
| 		st, _ := getTypeStruct(o.Type(), scope) |  | ||||||
| 
 |  | ||||||
| 		fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {\n", name) |  | ||||||
| 		fmt.Fprint(b, `off, err := rr.Hdr.pack(msg, off, compression, compress) |  | ||||||
| if err != nil { |  | ||||||
| 	return off, err |  | ||||||
| } |  | ||||||
| headerEnd := off |  | ||||||
| `) |  | ||||||
| 		for i := 1; i < st.NumFields(); i++ { |  | ||||||
| 			o := func(s string) { |  | ||||||
| 				fmt.Fprintf(b, s, st.Field(i).Name()) |  | ||||||
| 				fmt.Fprint(b, `if err != nil { |  | ||||||
| return off, err |  | ||||||
| } |  | ||||||
| `) |  | ||||||
| 			} |  | ||||||
| 
 |  | ||||||
| 			if _, ok := st.Field(i).Type().(*types.Slice); ok { |  | ||||||
| 				switch st.Tag(i) { |  | ||||||
| 				case `dns:"-"`: // ignored
 |  | ||||||
| 				case `dns:"txt"`: |  | ||||||
| 					o("off, err = packStringTxt(rr.%s, msg, off)\n") |  | ||||||
| 				case `dns:"opt"`: |  | ||||||
| 					o("off, err = packDataOpt(rr.%s, msg, off)\n") |  | ||||||
| 				case `dns:"nsec"`: |  | ||||||
| 					o("off, err = packDataNsec(rr.%s, msg, off)\n") |  | ||||||
| 				case `dns:"domain-name"`: |  | ||||||
| 					o("off, err = packDataDomainNames(rr.%s, msg, off, compression, compress)\n") |  | ||||||
| 				default: |  | ||||||
| 					log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) |  | ||||||
| 				} |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 
 |  | ||||||
| 			switch { |  | ||||||
| 			case st.Tag(i) == `dns:"-"`: // ignored
 |  | ||||||
| 			case st.Tag(i) == `dns:"cdomain-name"`: |  | ||||||
| 				fallthrough |  | ||||||
| 			case st.Tag(i) == `dns:"domain-name"`: |  | ||||||
| 				o("off, err = PackDomainName(rr.%s, msg, off, compression, compress)\n") |  | ||||||
| 			case st.Tag(i) == `dns:"a"`: |  | ||||||
| 				o("off, err = packDataA(rr.%s, msg, off)\n") |  | ||||||
| 			case st.Tag(i) == `dns:"aaaa"`: |  | ||||||
| 				o("off, err = packDataAAAA(rr.%s, msg, off)\n") |  | ||||||
| 			case st.Tag(i) == `dns:"uint48"`: |  | ||||||
| 				o("off, err = packUint48(rr.%s, msg, off)\n") |  | ||||||
| 			case st.Tag(i) == `dns:"txt"`: |  | ||||||
| 				o("off, err = packString(rr.%s, msg, off)\n") |  | ||||||
| 
 |  | ||||||
| 			case strings.HasPrefix(st.Tag(i), `dns:"size-base32`): // size-base32 can be packed just like base32
 |  | ||||||
| 				fallthrough |  | ||||||
| 			case st.Tag(i) == `dns:"base32"`: |  | ||||||
| 				o("off, err = packStringBase32(rr.%s, msg, off)\n") |  | ||||||
| 
 |  | ||||||
| 			case strings.HasPrefix(st.Tag(i), `dns:"size-base64`): // size-base64 can be packed just like base64
 |  | ||||||
| 				fallthrough |  | ||||||
| 			case st.Tag(i) == `dns:"base64"`: |  | ||||||
| 				o("off, err = packStringBase64(rr.%s, msg, off)\n") |  | ||||||
| 
 |  | ||||||
| 			case strings.HasPrefix(st.Tag(i), `dns:"size-hex:SaltLength`): // Hack to fix empty salt length for NSEC3
 |  | ||||||
| 				o("if rr.%s == \"-\" { /* do nothing, empty salt */ }\n") |  | ||||||
| 				continue |  | ||||||
| 			case strings.HasPrefix(st.Tag(i), `dns:"size-hex`): // size-hex can be packed just like hex
 |  | ||||||
| 				fallthrough |  | ||||||
| 			case st.Tag(i) == `dns:"hex"`: |  | ||||||
| 				o("off, err = packStringHex(rr.%s, msg, off)\n") |  | ||||||
| 
 |  | ||||||
| 			case st.Tag(i) == `dns:"octet"`: |  | ||||||
| 				o("off, err = packStringOctet(rr.%s, msg, off)\n") |  | ||||||
| 			case st.Tag(i) == "": |  | ||||||
| 				switch st.Field(i).Type().(*types.Basic).Kind() { |  | ||||||
| 				case types.Uint8: |  | ||||||
| 					o("off, err = packUint8(rr.%s, msg, off)\n") |  | ||||||
| 				case types.Uint16: |  | ||||||
| 					o("off, err = packUint16(rr.%s, msg, off)\n") |  | ||||||
| 				case types.Uint32: |  | ||||||
| 					o("off, err = packUint32(rr.%s, msg, off)\n") |  | ||||||
| 				case types.Uint64: |  | ||||||
| 					o("off, err = packUint64(rr.%s, msg, off)\n") |  | ||||||
| 				case types.String: |  | ||||||
| 					o("off, err = packString(rr.%s, msg, off)\n") |  | ||||||
| 				default: |  | ||||||
| 					log.Fatalln(name, st.Field(i).Name()) |  | ||||||
| 				} |  | ||||||
| 			default: |  | ||||||
| 				log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 		// We have packed everything, only now we know the rdlength of this RR
 |  | ||||||
| 		fmt.Fprintln(b, "rr.Header().Rdlength = uint16(off-headerEnd)") |  | ||||||
| 		fmt.Fprintln(b, "return off, nil }\n") |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	fmt.Fprint(b, "// unpack*() functions\n\n") |  | ||||||
| 	for _, name := range namedTypes { |  | ||||||
| 		o := scope.Lookup(name) |  | ||||||
| 		st, _ := getTypeStruct(o.Type(), scope) |  | ||||||
| 
 |  | ||||||
| 		fmt.Fprintf(b, "func unpack%s(h RR_Header, msg []byte, off int) (RR, int, error) {\n", name) |  | ||||||
| 		fmt.Fprintf(b, "rr := new(%s)\n", name) |  | ||||||
| 		fmt.Fprint(b, "rr.Hdr = h\n") |  | ||||||
| 		fmt.Fprint(b, `if noRdata(h) { |  | ||||||
| return rr, off, nil |  | ||||||
| 	} |  | ||||||
| var err error |  | ||||||
| rdStart := off |  | ||||||
| _ = rdStart |  | ||||||
| 
 |  | ||||||
| `) |  | ||||||
| 		for i := 1; i < st.NumFields(); i++ { |  | ||||||
| 			o := func(s string) { |  | ||||||
| 				fmt.Fprintf(b, s, st.Field(i).Name()) |  | ||||||
| 				fmt.Fprint(b, `if err != nil { |  | ||||||
| return rr, off, err |  | ||||||
| } |  | ||||||
| `) |  | ||||||
| 			} |  | ||||||
| 
 |  | ||||||
| 			// size-* are special, because they reference a struct member we should use for the length.
 |  | ||||||
| 			if strings.HasPrefix(st.Tag(i), `dns:"size-`) { |  | ||||||
| 				structMember := structMember(st.Tag(i)) |  | ||||||
| 				structTag := structTag(st.Tag(i)) |  | ||||||
| 				switch structTag { |  | ||||||
| 				case "hex": |  | ||||||
| 					fmt.Fprintf(b, "rr.%s, off, err = unpackStringHex(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember) |  | ||||||
| 				case "base32": |  | ||||||
| 					fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase32(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember) |  | ||||||
| 				case "base64": |  | ||||||
| 					fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase64(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember) |  | ||||||
| 				default: |  | ||||||
| 					log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) |  | ||||||
| 				} |  | ||||||
| 				fmt.Fprint(b, `if err != nil { |  | ||||||
| return rr, off, err |  | ||||||
| } |  | ||||||
| `) |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 
 |  | ||||||
| 			if _, ok := st.Field(i).Type().(*types.Slice); ok { |  | ||||||
| 				switch st.Tag(i) { |  | ||||||
| 				case `dns:"-"`: // ignored
 |  | ||||||
| 				case `dns:"txt"`: |  | ||||||
| 					o("rr.%s, off, err = unpackStringTxt(msg, off)\n") |  | ||||||
| 				case `dns:"opt"`: |  | ||||||
| 					o("rr.%s, off, err = unpackDataOpt(msg, off)\n") |  | ||||||
| 				case `dns:"nsec"`: |  | ||||||
| 					o("rr.%s, off, err = unpackDataNsec(msg, off)\n") |  | ||||||
| 				case `dns:"domain-name"`: |  | ||||||
| 					o("rr.%s, off, err = unpackDataDomainNames(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") |  | ||||||
| 				default: |  | ||||||
| 					log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) |  | ||||||
| 				} |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 
 |  | ||||||
| 			switch st.Tag(i) { |  | ||||||
| 			case `dns:"-"`: // ignored
 |  | ||||||
| 			case `dns:"cdomain-name"`: |  | ||||||
| 				fallthrough |  | ||||||
| 			case `dns:"domain-name"`: |  | ||||||
| 				o("rr.%s, off, err = UnpackDomainName(msg, off)\n") |  | ||||||
| 			case `dns:"a"`: |  | ||||||
| 				o("rr.%s, off, err = unpackDataA(msg, off)\n") |  | ||||||
| 			case `dns:"aaaa"`: |  | ||||||
| 				o("rr.%s, off, err = unpackDataAAAA(msg, off)\n") |  | ||||||
| 			case `dns:"uint48"`: |  | ||||||
| 				o("rr.%s, off, err = unpackUint48(msg, off)\n") |  | ||||||
| 			case `dns:"txt"`: |  | ||||||
| 				o("rr.%s, off, err = unpackString(msg, off)\n") |  | ||||||
| 			case `dns:"base32"`: |  | ||||||
| 				o("rr.%s, off, err = unpackStringBase32(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") |  | ||||||
| 			case `dns:"base64"`: |  | ||||||
| 				o("rr.%s, off, err = unpackStringBase64(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") |  | ||||||
| 			case `dns:"hex"`: |  | ||||||
| 				o("rr.%s, off, err = unpackStringHex(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") |  | ||||||
| 			case `dns:"octet"`: |  | ||||||
| 				o("rr.%s, off, err = unpackStringOctet(msg, off)\n") |  | ||||||
| 			case "": |  | ||||||
| 				switch st.Field(i).Type().(*types.Basic).Kind() { |  | ||||||
| 				case types.Uint8: |  | ||||||
| 					o("rr.%s, off, err = unpackUint8(msg, off)\n") |  | ||||||
| 				case types.Uint16: |  | ||||||
| 					o("rr.%s, off, err = unpackUint16(msg, off)\n") |  | ||||||
| 				case types.Uint32: |  | ||||||
| 					o("rr.%s, off, err = unpackUint32(msg, off)\n") |  | ||||||
| 				case types.Uint64: |  | ||||||
| 					o("rr.%s, off, err = unpackUint64(msg, off)\n") |  | ||||||
| 				case types.String: |  | ||||||
| 					o("rr.%s, off, err = unpackString(msg, off)\n") |  | ||||||
| 				default: |  | ||||||
| 					log.Fatalln(name, st.Field(i).Name()) |  | ||||||
| 				} |  | ||||||
| 			default: |  | ||||||
| 				log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) |  | ||||||
| 			} |  | ||||||
| 			// If we've hit len(msg) we return without error.
 |  | ||||||
| 			if i < st.NumFields()-1 { |  | ||||||
| 				fmt.Fprintf(b, `if off == len(msg) { |  | ||||||
| return rr, off, nil |  | ||||||
| 	} |  | ||||||
| `) |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 		fmt.Fprintf(b, "return rr, off, err }\n\n") |  | ||||||
| 	} |  | ||||||
| 	// Generate typeToUnpack map
 |  | ||||||
| 	fmt.Fprintln(b, "var typeToUnpack = map[uint16]func(RR_Header, []byte, int) (RR, int, error){") |  | ||||||
| 	for _, name := range namedTypes { |  | ||||||
| 		if name == "RFC3597" { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		fmt.Fprintf(b, "Type%s: unpack%s,\n", name, name) |  | ||||||
| 	} |  | ||||||
| 	fmt.Fprintln(b, "}\n") |  | ||||||
| 
 |  | ||||||
| 	// gofmt
 |  | ||||||
| 	res, err := format.Source(b.Bytes()) |  | ||||||
| 	if err != nil { |  | ||||||
| 		b.WriteTo(os.Stderr) |  | ||||||
| 		log.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// write result
 |  | ||||||
| 	f, err := os.Create("zmsg.go") |  | ||||||
| 	fatalIfErr(err) |  | ||||||
| 	defer f.Close() |  | ||||||
| 	f.Write(res) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // structMember will take a tag like dns:"size-base32:SaltLength" and return the last part of this string.
 |  | ||||||
| func structMember(s string) string { |  | ||||||
| 	fields := strings.Split(s, ":") |  | ||||||
| 	if len(fields) == 0 { |  | ||||||
| 		return "" |  | ||||||
| 	} |  | ||||||
| 	f := fields[len(fields)-1] |  | ||||||
| 	// f should have a closing "
 |  | ||||||
| 	if len(f) > 1 { |  | ||||||
| 		return f[:len(f)-1] |  | ||||||
| 	} |  | ||||||
| 	return f |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // structTag will take a tag like dns:"size-base32:SaltLength" and return base32.
 |  | ||||||
| func structTag(s string) string { |  | ||||||
| 	fields := strings.Split(s, ":") |  | ||||||
| 	if len(fields) < 2 { |  | ||||||
| 		return "" |  | ||||||
| 	} |  | ||||||
| 	return fields[1][len("\"size-"):] |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func fatalIfErr(err error) { |  | ||||||
| 	if err != nil { |  | ||||||
| 		log.Fatal(err) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  | @ -1,271 +0,0 @@ | ||||||
| //+build ignore
 |  | ||||||
| 
 |  | ||||||
| // types_generate.go is meant to run with go generate. It will use
 |  | ||||||
| // go/{importer,types} to track down all the RR struct types. Then for each type
 |  | ||||||
| // it will generate conversion tables (TypeToRR and TypeToString) and banal
 |  | ||||||
| // methods (len, Header, copy) based on the struct tags. The generated source is
 |  | ||||||
| // written to ztypes.go, and is meant to be checked into git.
 |  | ||||||
| package main |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"bytes" |  | ||||||
| 	"fmt" |  | ||||||
| 	"go/format" |  | ||||||
| 	"go/importer" |  | ||||||
| 	"go/types" |  | ||||||
| 	"log" |  | ||||||
| 	"os" |  | ||||||
| 	"strings" |  | ||||||
| 	"text/template" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| var skipLen = map[string]struct{}{ |  | ||||||
| 	"NSEC":  {}, |  | ||||||
| 	"NSEC3": {}, |  | ||||||
| 	"OPT":   {}, |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| var packageHdr = ` |  | ||||||
| // *** DO NOT MODIFY ***
 |  | ||||||
| // AUTOGENERATED BY go generate from type_generate.go
 |  | ||||||
| 
 |  | ||||||
| package dns |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"encoding/base64" |  | ||||||
| 	"net" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| ` |  | ||||||
| 
 |  | ||||||
| var TypeToRR = template.Must(template.New("TypeToRR").Parse(` |  | ||||||
| // TypeToRR is a map of constructors for each RR type.
 |  | ||||||
| var TypeToRR = map[uint16]func() RR{ |  | ||||||
| {{range .}}{{if ne . "RFC3597"}}  Type{{.}}:  func() RR { return new({{.}}) }, |  | ||||||
| {{end}}{{end}}                    } |  | ||||||
| 
 |  | ||||||
| `)) |  | ||||||
| 
 |  | ||||||
| var typeToString = template.Must(template.New("typeToString").Parse(` |  | ||||||
| // TypeToString is a map of strings for each RR type.
 |  | ||||||
| var TypeToString = map[uint16]string{ |  | ||||||
| {{range .}}{{if ne . "NSAPPTR"}}  Type{{.}}: "{{.}}", |  | ||||||
| {{end}}{{end}}                    TypeNSAPPTR:    "NSAP-PTR", |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| `)) |  | ||||||
| 
 |  | ||||||
| var headerFunc = template.Must(template.New("headerFunc").Parse(` |  | ||||||
| // Header() functions
 |  | ||||||
| {{range .}}  func (rr *{{.}}) Header() *RR_Header { return &rr.Hdr } |  | ||||||
| {{end}} |  | ||||||
| 
 |  | ||||||
| `)) |  | ||||||
| 
 |  | ||||||
| // getTypeStruct will take a type and the package scope, and return the
 |  | ||||||
| // (innermost) struct if the type is considered a RR type (currently defined as
 |  | ||||||
| // those structs beginning with a RR_Header, could be redefined as implementing
 |  | ||||||
| // the RR interface). The bool return value indicates if embedded structs were
 |  | ||||||
| // resolved.
 |  | ||||||
| func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) { |  | ||||||
| 	st, ok := t.Underlying().(*types.Struct) |  | ||||||
| 	if !ok { |  | ||||||
| 		return nil, false |  | ||||||
| 	} |  | ||||||
| 	if st.Field(0).Type() == scope.Lookup("RR_Header").Type() { |  | ||||||
| 		return st, false |  | ||||||
| 	} |  | ||||||
| 	if st.Field(0).Anonymous() { |  | ||||||
| 		st, _ := getTypeStruct(st.Field(0).Type(), scope) |  | ||||||
| 		return st, true |  | ||||||
| 	} |  | ||||||
| 	return nil, false |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func main() { |  | ||||||
| 	// Import and type-check the package
 |  | ||||||
| 	pkg, err := importer.Default().Import("github.com/miekg/dns") |  | ||||||
| 	fatalIfErr(err) |  | ||||||
| 	scope := pkg.Scope() |  | ||||||
| 
 |  | ||||||
| 	// Collect constants like TypeX
 |  | ||||||
| 	var numberedTypes []string |  | ||||||
| 	for _, name := range scope.Names() { |  | ||||||
| 		o := scope.Lookup(name) |  | ||||||
| 		if o == nil || !o.Exported() { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		b, ok := o.Type().(*types.Basic) |  | ||||||
| 		if !ok || b.Kind() != types.Uint16 { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		if !strings.HasPrefix(o.Name(), "Type") { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		name := strings.TrimPrefix(o.Name(), "Type") |  | ||||||
| 		if name == "PrivateRR" { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		numberedTypes = append(numberedTypes, name) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// Collect actual types (*X)
 |  | ||||||
| 	var namedTypes []string |  | ||||||
| 	for _, name := range scope.Names() { |  | ||||||
| 		o := scope.Lookup(name) |  | ||||||
| 		if o == nil || !o.Exported() { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		if st, _ := getTypeStruct(o.Type(), scope); st == nil { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		if name == "PrivateRR" { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		// Check if corresponding TypeX exists
 |  | ||||||
| 		if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" { |  | ||||||
| 			log.Fatalf("Constant Type%s does not exist.", o.Name()) |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		namedTypes = append(namedTypes, o.Name()) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	b := &bytes.Buffer{} |  | ||||||
| 	b.WriteString(packageHdr) |  | ||||||
| 
 |  | ||||||
| 	// Generate TypeToRR
 |  | ||||||
| 	fatalIfErr(TypeToRR.Execute(b, namedTypes)) |  | ||||||
| 
 |  | ||||||
| 	// Generate typeToString
 |  | ||||||
| 	fatalIfErr(typeToString.Execute(b, numberedTypes)) |  | ||||||
| 
 |  | ||||||
| 	// Generate headerFunc
 |  | ||||||
| 	fatalIfErr(headerFunc.Execute(b, namedTypes)) |  | ||||||
| 
 |  | ||||||
| 	// Generate len()
 |  | ||||||
| 	fmt.Fprint(b, "// len() functions\n") |  | ||||||
| 	for _, name := range namedTypes { |  | ||||||
| 		if _, ok := skipLen[name]; ok { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		o := scope.Lookup(name) |  | ||||||
| 		st, isEmbedded := getTypeStruct(o.Type(), scope) |  | ||||||
| 		if isEmbedded { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		fmt.Fprintf(b, "func (rr *%s) len() int {\n", name) |  | ||||||
| 		fmt.Fprintf(b, "l := rr.Hdr.len()\n") |  | ||||||
| 		for i := 1; i < st.NumFields(); i++ { |  | ||||||
| 			o := func(s string) { fmt.Fprintf(b, s, st.Field(i).Name()) } |  | ||||||
| 
 |  | ||||||
| 			if _, ok := st.Field(i).Type().(*types.Slice); ok { |  | ||||||
| 				switch st.Tag(i) { |  | ||||||
| 				case `dns:"-"`: |  | ||||||
| 					// ignored
 |  | ||||||
| 				case `dns:"cdomain-name"`, `dns:"domain-name"`, `dns:"txt"`: |  | ||||||
| 					o("for _, x := range rr.%s { l += len(x) + 1 }\n") |  | ||||||
| 				default: |  | ||||||
| 					log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) |  | ||||||
| 				} |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 
 |  | ||||||
| 			switch { |  | ||||||
| 			case st.Tag(i) == `dns:"-"`: |  | ||||||
| 				// ignored
 |  | ||||||
| 			case st.Tag(i) == `dns:"cdomain-name"`, st.Tag(i) == `dns:"domain-name"`: |  | ||||||
| 				o("l += len(rr.%s) + 1\n") |  | ||||||
| 			case st.Tag(i) == `dns:"octet"`: |  | ||||||
| 				o("l += len(rr.%s)\n") |  | ||||||
| 			case strings.HasPrefix(st.Tag(i), `dns:"size-base64`): |  | ||||||
| 				fallthrough |  | ||||||
| 			case st.Tag(i) == `dns:"base64"`: |  | ||||||
| 				o("l += base64.StdEncoding.DecodedLen(len(rr.%s))\n") |  | ||||||
| 			case strings.HasPrefix(st.Tag(i), `dns:"size-hex`): |  | ||||||
| 				fallthrough |  | ||||||
| 			case st.Tag(i) == `dns:"hex"`: |  | ||||||
| 				o("l += len(rr.%s)/2 + 1\n") |  | ||||||
| 			case st.Tag(i) == `dns:"a"`: |  | ||||||
| 				o("l += net.IPv4len // %s\n") |  | ||||||
| 			case st.Tag(i) == `dns:"aaaa"`: |  | ||||||
| 				o("l += net.IPv6len // %s\n") |  | ||||||
| 			case st.Tag(i) == `dns:"txt"`: |  | ||||||
| 				o("for _, t := range rr.%s { l += len(t) + 1 }\n") |  | ||||||
| 			case st.Tag(i) == `dns:"uint48"`: |  | ||||||
| 				o("l += 6 // %s\n") |  | ||||||
| 			case st.Tag(i) == "": |  | ||||||
| 				switch st.Field(i).Type().(*types.Basic).Kind() { |  | ||||||
| 				case types.Uint8: |  | ||||||
| 					o("l += 1 // %s\n") |  | ||||||
| 				case types.Uint16: |  | ||||||
| 					o("l += 2 // %s\n") |  | ||||||
| 				case types.Uint32: |  | ||||||
| 					o("l += 4 // %s\n") |  | ||||||
| 				case types.Uint64: |  | ||||||
| 					o("l += 8 // %s\n") |  | ||||||
| 				case types.String: |  | ||||||
| 					o("l += len(rr.%s) + 1\n") |  | ||||||
| 				default: |  | ||||||
| 					log.Fatalln(name, st.Field(i).Name()) |  | ||||||
| 				} |  | ||||||
| 			default: |  | ||||||
| 				log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 		fmt.Fprintf(b, "return l }\n") |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// Generate copy()
 |  | ||||||
| 	fmt.Fprint(b, "// copy() functions\n") |  | ||||||
| 	for _, name := range namedTypes { |  | ||||||
| 		o := scope.Lookup(name) |  | ||||||
| 		st, isEmbedded := getTypeStruct(o.Type(), scope) |  | ||||||
| 		if isEmbedded { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		fmt.Fprintf(b, "func (rr *%s) copy() RR {\n", name) |  | ||||||
| 		fields := []string{"*rr.Hdr.copyHeader()"} |  | ||||||
| 		for i := 1; i < st.NumFields(); i++ { |  | ||||||
| 			f := st.Field(i).Name() |  | ||||||
| 			if sl, ok := st.Field(i).Type().(*types.Slice); ok { |  | ||||||
| 				t := sl.Underlying().String() |  | ||||||
| 				t = strings.TrimPrefix(t, "[]") |  | ||||||
| 				if strings.Contains(t, ".") { |  | ||||||
| 					splits := strings.Split(t, ".") |  | ||||||
| 					t = splits[len(splits)-1] |  | ||||||
| 				} |  | ||||||
| 				fmt.Fprintf(b, "%s := make([]%s, len(rr.%s)); copy(%s, rr.%s)\n", |  | ||||||
| 					f, t, f, f, f) |  | ||||||
| 				fields = append(fields, f) |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			if st.Field(i).Type().String() == "net.IP" { |  | ||||||
| 				fields = append(fields, "copyIP(rr."+f+")") |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			fields = append(fields, "rr."+f) |  | ||||||
| 		} |  | ||||||
| 		fmt.Fprintf(b, "return &%s{%s}\n", name, strings.Join(fields, ",")) |  | ||||||
| 		fmt.Fprintf(b, "}\n") |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// gofmt
 |  | ||||||
| 	res, err := format.Source(b.Bytes()) |  | ||||||
| 	if err != nil { |  | ||||||
| 		b.WriteTo(os.Stderr) |  | ||||||
| 		log.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// write result
 |  | ||||||
| 	f, err := os.Create("ztypes.go") |  | ||||||
| 	fatalIfErr(err) |  | ||||||
| 	defer f.Close() |  | ||||||
| 	f.Write(res) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func fatalIfErr(err error) { |  | ||||||
| 	if err != nil { |  | ||||||
| 		log.Fatal(err) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  | @ -1,68 +0,0 @@ | ||||||
| // Copyright 2012 The Go Authors. All rights reserved.
 |  | ||||||
| // Use of this source code is governed by a BSD-style
 |  | ||||||
| // license that can be found in the LICENSE file.
 |  | ||||||
| 
 |  | ||||||
| // Package idna implements IDNA2008 (Internationalized Domain Names for
 |  | ||||||
| // Applications), defined in RFC 5890, RFC 5891, RFC 5892, RFC 5893 and
 |  | ||||||
| // RFC 5894.
 |  | ||||||
| package idna // import "golang.org/x/net/idna"
 |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"strings" |  | ||||||
| 	"unicode/utf8" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| // TODO(nigeltao): specify when errors occur. For example, is ToASCII(".") or
 |  | ||||||
| // ToASCII("foo\x00") an error? See also http://www.unicode.org/faq/idn.html#11
 |  | ||||||
| 
 |  | ||||||
| // acePrefix is the ASCII Compatible Encoding prefix.
 |  | ||||||
| const acePrefix = "xn--" |  | ||||||
| 
 |  | ||||||
| // ToASCII converts a domain or domain label to its ASCII form. For example,
 |  | ||||||
| // ToASCII("bücher.example.com") is "xn--bcher-kva.example.com", and
 |  | ||||||
| // ToASCII("golang") is "golang".
 |  | ||||||
| func ToASCII(s string) (string, error) { |  | ||||||
| 	if ascii(s) { |  | ||||||
| 		return s, nil |  | ||||||
| 	} |  | ||||||
| 	labels := strings.Split(s, ".") |  | ||||||
| 	for i, label := range labels { |  | ||||||
| 		if !ascii(label) { |  | ||||||
| 			a, err := encode(acePrefix, label) |  | ||||||
| 			if err != nil { |  | ||||||
| 				return "", err |  | ||||||
| 			} |  | ||||||
| 			labels[i] = a |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return strings.Join(labels, "."), nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // ToUnicode converts a domain or domain label to its Unicode form. For example,
 |  | ||||||
| // ToUnicode("xn--bcher-kva.example.com") is "bücher.example.com", and
 |  | ||||||
| // ToUnicode("golang") is "golang".
 |  | ||||||
| func ToUnicode(s string) (string, error) { |  | ||||||
| 	if !strings.Contains(s, acePrefix) { |  | ||||||
| 		return s, nil |  | ||||||
| 	} |  | ||||||
| 	labels := strings.Split(s, ".") |  | ||||||
| 	for i, label := range labels { |  | ||||||
| 		if strings.HasPrefix(label, acePrefix) { |  | ||||||
| 			u, err := decode(label[len(acePrefix):]) |  | ||||||
| 			if err != nil { |  | ||||||
| 				return "", err |  | ||||||
| 			} |  | ||||||
| 			labels[i] = u |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return strings.Join(labels, "."), nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func ascii(s string) bool { |  | ||||||
| 	for i := 0; i < len(s); i++ { |  | ||||||
| 		if s[i] >= utf8.RuneSelf { |  | ||||||
| 			return false |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return true |  | ||||||
| } |  | ||||||
|  | @ -1,200 +0,0 @@ | ||||||
| // Copyright 2012 The Go Authors. All rights reserved.
 |  | ||||||
| // Use of this source code is governed by a BSD-style
 |  | ||||||
| // license that can be found in the LICENSE file.
 |  | ||||||
| 
 |  | ||||||
| package idna |  | ||||||
| 
 |  | ||||||
| // This file implements the Punycode algorithm from RFC 3492.
 |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"fmt" |  | ||||||
| 	"math" |  | ||||||
| 	"strings" |  | ||||||
| 	"unicode/utf8" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| // These parameter values are specified in section 5.
 |  | ||||||
| //
 |  | ||||||
| // All computation is done with int32s, so that overflow behavior is identical
 |  | ||||||
| // regardless of whether int is 32-bit or 64-bit.
 |  | ||||||
| const ( |  | ||||||
| 	base        int32 = 36 |  | ||||||
| 	damp        int32 = 700 |  | ||||||
| 	initialBias int32 = 72 |  | ||||||
| 	initialN    int32 = 128 |  | ||||||
| 	skew        int32 = 38 |  | ||||||
| 	tmax        int32 = 26 |  | ||||||
| 	tmin        int32 = 1 |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| // decode decodes a string as specified in section 6.2.
 |  | ||||||
| func decode(encoded string) (string, error) { |  | ||||||
| 	if encoded == "" { |  | ||||||
| 		return "", nil |  | ||||||
| 	} |  | ||||||
| 	pos := 1 + strings.LastIndex(encoded, "-") |  | ||||||
| 	if pos == 1 { |  | ||||||
| 		return "", fmt.Errorf("idna: invalid label %q", encoded) |  | ||||||
| 	} |  | ||||||
| 	if pos == len(encoded) { |  | ||||||
| 		return encoded[:len(encoded)-1], nil |  | ||||||
| 	} |  | ||||||
| 	output := make([]rune, 0, len(encoded)) |  | ||||||
| 	if pos != 0 { |  | ||||||
| 		for _, r := range encoded[:pos-1] { |  | ||||||
| 			output = append(output, r) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	i, n, bias := int32(0), initialN, initialBias |  | ||||||
| 	for pos < len(encoded) { |  | ||||||
| 		oldI, w := i, int32(1) |  | ||||||
| 		for k := base; ; k += base { |  | ||||||
| 			if pos == len(encoded) { |  | ||||||
| 				return "", fmt.Errorf("idna: invalid label %q", encoded) |  | ||||||
| 			} |  | ||||||
| 			digit, ok := decodeDigit(encoded[pos]) |  | ||||||
| 			if !ok { |  | ||||||
| 				return "", fmt.Errorf("idna: invalid label %q", encoded) |  | ||||||
| 			} |  | ||||||
| 			pos++ |  | ||||||
| 			i += digit * w |  | ||||||
| 			if i < 0 { |  | ||||||
| 				return "", fmt.Errorf("idna: invalid label %q", encoded) |  | ||||||
| 			} |  | ||||||
| 			t := k - bias |  | ||||||
| 			if t < tmin { |  | ||||||
| 				t = tmin |  | ||||||
| 			} else if t > tmax { |  | ||||||
| 				t = tmax |  | ||||||
| 			} |  | ||||||
| 			if digit < t { |  | ||||||
| 				break |  | ||||||
| 			} |  | ||||||
| 			w *= base - t |  | ||||||
| 			if w >= math.MaxInt32/base { |  | ||||||
| 				return "", fmt.Errorf("idna: invalid label %q", encoded) |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 		x := int32(len(output) + 1) |  | ||||||
| 		bias = adapt(i-oldI, x, oldI == 0) |  | ||||||
| 		n += i / x |  | ||||||
| 		i %= x |  | ||||||
| 		if n > utf8.MaxRune || len(output) >= 1024 { |  | ||||||
| 			return "", fmt.Errorf("idna: invalid label %q", encoded) |  | ||||||
| 		} |  | ||||||
| 		output = append(output, 0) |  | ||||||
| 		copy(output[i+1:], output[i:]) |  | ||||||
| 		output[i] = n |  | ||||||
| 		i++ |  | ||||||
| 	} |  | ||||||
| 	return string(output), nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // encode encodes a string as specified in section 6.3 and prepends prefix to
 |  | ||||||
| // the result.
 |  | ||||||
| //
 |  | ||||||
| // The "while h < length(input)" line in the specification becomes "for
 |  | ||||||
| // remaining != 0" in the Go code, because len(s) in Go is in bytes, not runes.
 |  | ||||||
| func encode(prefix, s string) (string, error) { |  | ||||||
| 	output := make([]byte, len(prefix), len(prefix)+1+2*len(s)) |  | ||||||
| 	copy(output, prefix) |  | ||||||
| 	delta, n, bias := int32(0), initialN, initialBias |  | ||||||
| 	b, remaining := int32(0), int32(0) |  | ||||||
| 	for _, r := range s { |  | ||||||
| 		if r < 0x80 { |  | ||||||
| 			b++ |  | ||||||
| 			output = append(output, byte(r)) |  | ||||||
| 		} else { |  | ||||||
| 			remaining++ |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	h := b |  | ||||||
| 	if b > 0 { |  | ||||||
| 		output = append(output, '-') |  | ||||||
| 	} |  | ||||||
| 	for remaining != 0 { |  | ||||||
| 		m := int32(0x7fffffff) |  | ||||||
| 		for _, r := range s { |  | ||||||
| 			if m > r && r >= n { |  | ||||||
| 				m = r |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 		delta += (m - n) * (h + 1) |  | ||||||
| 		if delta < 0 { |  | ||||||
| 			return "", fmt.Errorf("idna: invalid label %q", s) |  | ||||||
| 		} |  | ||||||
| 		n = m |  | ||||||
| 		for _, r := range s { |  | ||||||
| 			if r < n { |  | ||||||
| 				delta++ |  | ||||||
| 				if delta < 0 { |  | ||||||
| 					return "", fmt.Errorf("idna: invalid label %q", s) |  | ||||||
| 				} |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			if r > n { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			q := delta |  | ||||||
| 			for k := base; ; k += base { |  | ||||||
| 				t := k - bias |  | ||||||
| 				if t < tmin { |  | ||||||
| 					t = tmin |  | ||||||
| 				} else if t > tmax { |  | ||||||
| 					t = tmax |  | ||||||
| 				} |  | ||||||
| 				if q < t { |  | ||||||
| 					break |  | ||||||
| 				} |  | ||||||
| 				output = append(output, encodeDigit(t+(q-t)%(base-t))) |  | ||||||
| 				q = (q - t) / (base - t) |  | ||||||
| 			} |  | ||||||
| 			output = append(output, encodeDigit(q)) |  | ||||||
| 			bias = adapt(delta, h+1, h == b) |  | ||||||
| 			delta = 0 |  | ||||||
| 			h++ |  | ||||||
| 			remaining-- |  | ||||||
| 		} |  | ||||||
| 		delta++ |  | ||||||
| 		n++ |  | ||||||
| 	} |  | ||||||
| 	return string(output), nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func decodeDigit(x byte) (digit int32, ok bool) { |  | ||||||
| 	switch { |  | ||||||
| 	case '0' <= x && x <= '9': |  | ||||||
| 		return int32(x - ('0' - 26)), true |  | ||||||
| 	case 'A' <= x && x <= 'Z': |  | ||||||
| 		return int32(x - 'A'), true |  | ||||||
| 	case 'a' <= x && x <= 'z': |  | ||||||
| 		return int32(x - 'a'), true |  | ||||||
| 	} |  | ||||||
| 	return 0, false |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func encodeDigit(digit int32) byte { |  | ||||||
| 	switch { |  | ||||||
| 	case 0 <= digit && digit < 26: |  | ||||||
| 		return byte(digit + 'a') |  | ||||||
| 	case 26 <= digit && digit < 36: |  | ||||||
| 		return byte(digit + ('0' - 26)) |  | ||||||
| 	} |  | ||||||
| 	panic("idna: internal error in punycode encoding") |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // adapt is the bias adaptation function specified in section 6.1.
 |  | ||||||
| func adapt(delta, numPoints int32, firstTime bool) int32 { |  | ||||||
| 	if firstTime { |  | ||||||
| 		delta /= damp |  | ||||||
| 	} else { |  | ||||||
| 		delta /= 2 |  | ||||||
| 	} |  | ||||||
| 	delta += delta / numPoints |  | ||||||
| 	k := int32(0) |  | ||||||
| 	for delta > ((base-tmin)*tmax)/2 { |  | ||||||
| 		delta /= base - tmin |  | ||||||
| 		k += base |  | ||||||
| 	} |  | ||||||
| 	return k + (base-tmin+1)*delta/(delta+skew) |  | ||||||
| } |  | ||||||
|  | @ -1,663 +0,0 @@ | ||||||
| // Copyright 2012 The Go Authors. All rights reserved.
 |  | ||||||
| // Use of this source code is governed by a BSD-style
 |  | ||||||
| // license that can be found in the LICENSE file.
 |  | ||||||
| 
 |  | ||||||
| // +build ignore
 |  | ||||||
| 
 |  | ||||||
| package main |  | ||||||
| 
 |  | ||||||
| // This program generates table.go and table_test.go.
 |  | ||||||
| // Invoke as:
 |  | ||||||
| //
 |  | ||||||
| //	go run gen.go -version "xxx"       >table.go
 |  | ||||||
| //	go run gen.go -version "xxx" -test >table_test.go
 |  | ||||||
| //
 |  | ||||||
| // Pass -v to print verbose progress information.
 |  | ||||||
| //
 |  | ||||||
| // The version is derived from information found at
 |  | ||||||
| // https://github.com/publicsuffix/list/commits/master/public_suffix_list.dat
 |  | ||||||
| //
 |  | ||||||
| // To fetch a particular git revision, such as 5c70ccd250, pass
 |  | ||||||
| // -url "https://raw.githubusercontent.com/publicsuffix/list/5c70ccd250/public_suffix_list.dat"
 |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"bufio" |  | ||||||
| 	"bytes" |  | ||||||
| 	"flag" |  | ||||||
| 	"fmt" |  | ||||||
| 	"go/format" |  | ||||||
| 	"io" |  | ||||||
| 	"net/http" |  | ||||||
| 	"os" |  | ||||||
| 	"regexp" |  | ||||||
| 	"sort" |  | ||||||
| 	"strings" |  | ||||||
| 
 |  | ||||||
| 	"golang.org/x/net/idna" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| const ( |  | ||||||
| 	// These sum of these four values must be no greater than 32.
 |  | ||||||
| 	nodesBitsChildren   = 9 |  | ||||||
| 	nodesBitsICANN      = 1 |  | ||||||
| 	nodesBitsTextOffset = 15 |  | ||||||
| 	nodesBitsTextLength = 6 |  | ||||||
| 
 |  | ||||||
| 	// These sum of these four values must be no greater than 32.
 |  | ||||||
| 	childrenBitsWildcard = 1 |  | ||||||
| 	childrenBitsNodeType = 2 |  | ||||||
| 	childrenBitsHi       = 14 |  | ||||||
| 	childrenBitsLo       = 14 |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| var ( |  | ||||||
| 	maxChildren   int |  | ||||||
| 	maxTextOffset int |  | ||||||
| 	maxTextLength int |  | ||||||
| 	maxHi         uint32 |  | ||||||
| 	maxLo         uint32 |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| func max(a, b int) int { |  | ||||||
| 	if a < b { |  | ||||||
| 		return b |  | ||||||
| 	} |  | ||||||
| 	return a |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func u32max(a, b uint32) uint32 { |  | ||||||
| 	if a < b { |  | ||||||
| 		return b |  | ||||||
| 	} |  | ||||||
| 	return a |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| const ( |  | ||||||
| 	nodeTypeNormal     = 0 |  | ||||||
| 	nodeTypeException  = 1 |  | ||||||
| 	nodeTypeParentOnly = 2 |  | ||||||
| 	numNodeType        = 3 |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| func nodeTypeStr(n int) string { |  | ||||||
| 	switch n { |  | ||||||
| 	case nodeTypeNormal: |  | ||||||
| 		return "+" |  | ||||||
| 	case nodeTypeException: |  | ||||||
| 		return "!" |  | ||||||
| 	case nodeTypeParentOnly: |  | ||||||
| 		return "o" |  | ||||||
| 	} |  | ||||||
| 	panic("unreachable") |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| var ( |  | ||||||
| 	labelEncoding = map[string]uint32{} |  | ||||||
| 	labelsList    = []string{} |  | ||||||
| 	labelsMap     = map[string]bool{} |  | ||||||
| 	rules         = []string{} |  | ||||||
| 
 |  | ||||||
| 	// validSuffix is used to check that the entries in the public suffix list
 |  | ||||||
| 	// are in canonical form (after Punycode encoding). Specifically, capital
 |  | ||||||
| 	// letters are not allowed.
 |  | ||||||
| 	validSuffix = regexp.MustCompile(`^[a-z0-9_\!\*\-\.]+$`) |  | ||||||
| 
 |  | ||||||
| 	subset = flag.Bool("subset", false, "generate only a subset of the full table, for debugging") |  | ||||||
| 	url    = flag.String("url", |  | ||||||
| 		"https://publicsuffix.org/list/effective_tld_names.dat", |  | ||||||
| 		"URL of the publicsuffix.org list. If empty, stdin is read instead") |  | ||||||
| 	v       = flag.Bool("v", false, "verbose output (to stderr)") |  | ||||||
| 	version = flag.String("version", "", "the effective_tld_names.dat version") |  | ||||||
| 	test    = flag.Bool("test", false, "generate table_test.go") |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| func main() { |  | ||||||
| 	if err := main1(); err != nil { |  | ||||||
| 		fmt.Fprintln(os.Stderr, err) |  | ||||||
| 		os.Exit(1) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func main1() error { |  | ||||||
| 	flag.Parse() |  | ||||||
| 	if nodesBitsTextLength+nodesBitsTextOffset+nodesBitsICANN+nodesBitsChildren > 32 { |  | ||||||
| 		return fmt.Errorf("not enough bits to encode the nodes table") |  | ||||||
| 	} |  | ||||||
| 	if childrenBitsLo+childrenBitsHi+childrenBitsNodeType+childrenBitsWildcard > 32 { |  | ||||||
| 		return fmt.Errorf("not enough bits to encode the children table") |  | ||||||
| 	} |  | ||||||
| 	if *version == "" { |  | ||||||
| 		return fmt.Errorf("-version was not specified") |  | ||||||
| 	} |  | ||||||
| 	var r io.Reader = os.Stdin |  | ||||||
| 	if *url != "" { |  | ||||||
| 		res, err := http.Get(*url) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 		if res.StatusCode != http.StatusOK { |  | ||||||
| 			return fmt.Errorf("bad GET status for %s: %d", *url, res.Status) |  | ||||||
| 		} |  | ||||||
| 		r = res.Body |  | ||||||
| 		defer res.Body.Close() |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	var root node |  | ||||||
| 	icann := false |  | ||||||
| 	buf := new(bytes.Buffer) |  | ||||||
| 	br := bufio.NewReader(r) |  | ||||||
| 	for { |  | ||||||
| 		s, err := br.ReadString('\n') |  | ||||||
| 		if err != nil { |  | ||||||
| 			if err == io.EOF { |  | ||||||
| 				break |  | ||||||
| 			} |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 		s = strings.TrimSpace(s) |  | ||||||
| 		if strings.Contains(s, "BEGIN ICANN DOMAINS") { |  | ||||||
| 			icann = true |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		if strings.Contains(s, "END ICANN DOMAINS") { |  | ||||||
| 			icann = false |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		if s == "" || strings.HasPrefix(s, "//") { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		s, err = idna.ToASCII(s) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 		if !validSuffix.MatchString(s) { |  | ||||||
| 			return fmt.Errorf("bad publicsuffix.org list data: %q", s) |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if *subset { |  | ||||||
| 			switch { |  | ||||||
| 			case s == "ac.jp" || strings.HasSuffix(s, ".ac.jp"): |  | ||||||
| 			case s == "ak.us" || strings.HasSuffix(s, ".ak.us"): |  | ||||||
| 			case s == "ao" || strings.HasSuffix(s, ".ao"): |  | ||||||
| 			case s == "ar" || strings.HasSuffix(s, ".ar"): |  | ||||||
| 			case s == "arpa" || strings.HasSuffix(s, ".arpa"): |  | ||||||
| 			case s == "cy" || strings.HasSuffix(s, ".cy"): |  | ||||||
| 			case s == "dyndns.org" || strings.HasSuffix(s, ".dyndns.org"): |  | ||||||
| 			case s == "jp": |  | ||||||
| 			case s == "kobe.jp" || strings.HasSuffix(s, ".kobe.jp"): |  | ||||||
| 			case s == "kyoto.jp" || strings.HasSuffix(s, ".kyoto.jp"): |  | ||||||
| 			case s == "om" || strings.HasSuffix(s, ".om"): |  | ||||||
| 			case s == "uk" || strings.HasSuffix(s, ".uk"): |  | ||||||
| 			case s == "uk.com" || strings.HasSuffix(s, ".uk.com"): |  | ||||||
| 			case s == "tw" || strings.HasSuffix(s, ".tw"): |  | ||||||
| 			case s == "zw" || strings.HasSuffix(s, ".zw"): |  | ||||||
| 			case s == "xn--p1ai" || strings.HasSuffix(s, ".xn--p1ai"): |  | ||||||
| 				// xn--p1ai is Russian-Cyrillic "рф".
 |  | ||||||
| 			default: |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		rules = append(rules, s) |  | ||||||
| 
 |  | ||||||
| 		nt, wildcard := nodeTypeNormal, false |  | ||||||
| 		switch { |  | ||||||
| 		case strings.HasPrefix(s, "*."): |  | ||||||
| 			s, nt = s[2:], nodeTypeParentOnly |  | ||||||
| 			wildcard = true |  | ||||||
| 		case strings.HasPrefix(s, "!"): |  | ||||||
| 			s, nt = s[1:], nodeTypeException |  | ||||||
| 		} |  | ||||||
| 		labels := strings.Split(s, ".") |  | ||||||
| 		for n, i := &root, len(labels)-1; i >= 0; i-- { |  | ||||||
| 			label := labels[i] |  | ||||||
| 			n = n.child(label) |  | ||||||
| 			if i == 0 { |  | ||||||
| 				if nt != nodeTypeParentOnly && n.nodeType == nodeTypeParentOnly { |  | ||||||
| 					n.nodeType = nt |  | ||||||
| 				} |  | ||||||
| 				n.icann = n.icann && icann |  | ||||||
| 				n.wildcard = n.wildcard || wildcard |  | ||||||
| 			} |  | ||||||
| 			labelsMap[label] = true |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	labelsList = make([]string, 0, len(labelsMap)) |  | ||||||
| 	for label := range labelsMap { |  | ||||||
| 		labelsList = append(labelsList, label) |  | ||||||
| 	} |  | ||||||
| 	sort.Strings(labelsList) |  | ||||||
| 
 |  | ||||||
| 	p := printReal |  | ||||||
| 	if *test { |  | ||||||
| 		p = printTest |  | ||||||
| 	} |  | ||||||
| 	if err := p(buf, &root); err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	b, err := format.Source(buf.Bytes()) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 	_, err = os.Stdout.Write(b) |  | ||||||
| 	return err |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func printTest(w io.Writer, n *node) error { |  | ||||||
| 	fmt.Fprintf(w, "// generated by go run gen.go; DO NOT EDIT\n\n") |  | ||||||
| 	fmt.Fprintf(w, "package publicsuffix\n\nvar rules = [...]string{\n") |  | ||||||
| 	for _, rule := range rules { |  | ||||||
| 		fmt.Fprintf(w, "%q,\n", rule) |  | ||||||
| 	} |  | ||||||
| 	fmt.Fprintf(w, "}\n\nvar nodeLabels = [...]string{\n") |  | ||||||
| 	if err := n.walk(w, printNodeLabel); err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 	fmt.Fprintf(w, "}\n") |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func printReal(w io.Writer, n *node) error { |  | ||||||
| 	const header = `// generated by go run gen.go; DO NOT EDIT
 |  | ||||||
| 
 |  | ||||||
| package publicsuffix |  | ||||||
| 
 |  | ||||||
| const version = %q |  | ||||||
| 
 |  | ||||||
| const ( |  | ||||||
| 	nodesBitsChildren   = %d |  | ||||||
| 	nodesBitsICANN      = %d |  | ||||||
| 	nodesBitsTextOffset = %d |  | ||||||
| 	nodesBitsTextLength = %d |  | ||||||
| 
 |  | ||||||
| 	childrenBitsWildcard = %d |  | ||||||
| 	childrenBitsNodeType = %d |  | ||||||
| 	childrenBitsHi       = %d |  | ||||||
| 	childrenBitsLo       = %d |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| const ( |  | ||||||
| 	nodeTypeNormal     = %d |  | ||||||
| 	nodeTypeException  = %d |  | ||||||
| 	nodeTypeParentOnly = %d |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| // numTLD is the number of top level domains.
 |  | ||||||
| const numTLD = %d |  | ||||||
| 
 |  | ||||||
| ` |  | ||||||
| 	fmt.Fprintf(w, header, *version, |  | ||||||
| 		nodesBitsChildren, nodesBitsICANN, nodesBitsTextOffset, nodesBitsTextLength, |  | ||||||
| 		childrenBitsWildcard, childrenBitsNodeType, childrenBitsHi, childrenBitsLo, |  | ||||||
| 		nodeTypeNormal, nodeTypeException, nodeTypeParentOnly, len(n.children)) |  | ||||||
| 
 |  | ||||||
| 	text := combineText(labelsList) |  | ||||||
| 	if text == "" { |  | ||||||
| 		return fmt.Errorf("internal error: makeText returned no text") |  | ||||||
| 	} |  | ||||||
| 	for _, label := range labelsList { |  | ||||||
| 		offset, length := strings.Index(text, label), len(label) |  | ||||||
| 		if offset < 0 { |  | ||||||
| 			return fmt.Errorf("internal error: could not find %q in text %q", label, text) |  | ||||||
| 		} |  | ||||||
| 		maxTextOffset, maxTextLength = max(maxTextOffset, offset), max(maxTextLength, length) |  | ||||||
| 		if offset >= 1<<nodesBitsTextOffset { |  | ||||||
| 			return fmt.Errorf("text offset %d is too large, or nodeBitsTextOffset is too small", offset) |  | ||||||
| 		} |  | ||||||
| 		if length >= 1<<nodesBitsTextLength { |  | ||||||
| 			return fmt.Errorf("text length %d is too large, or nodeBitsTextLength is too small", length) |  | ||||||
| 		} |  | ||||||
| 		labelEncoding[label] = uint32(offset)<<nodesBitsTextLength | uint32(length) |  | ||||||
| 	} |  | ||||||
| 	fmt.Fprintf(w, "// Text is the combined text of all labels.\nconst text = ") |  | ||||||
| 	for len(text) > 0 { |  | ||||||
| 		n, plus := len(text), "" |  | ||||||
| 		if n > 64 { |  | ||||||
| 			n, plus = 64, " +" |  | ||||||
| 		} |  | ||||||
| 		fmt.Fprintf(w, "%q%s\n", text[:n], plus) |  | ||||||
| 		text = text[n:] |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if err := n.walk(w, assignIndexes); err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	fmt.Fprintf(w, ` |  | ||||||
| 
 |  | ||||||
| // nodes is the list of nodes. Each node is represented as a uint32, which
 |  | ||||||
| // encodes the node's children, wildcard bit and node type (as an index into
 |  | ||||||
| // the children array), ICANN bit and text.
 |  | ||||||
| //
 |  | ||||||
| // In the //-comment after each node's data, the nodes indexes of the children
 |  | ||||||
| // are formatted as (n0x1234-n0x1256), with * denoting the wildcard bit. The
 |  | ||||||
| // nodeType is printed as + for normal, ! for exception, and o for parent-only
 |  | ||||||
| // nodes that have children but don't match a domain label in their own right.
 |  | ||||||
| // An I denotes an ICANN domain.
 |  | ||||||
| //
 |  | ||||||
| // The layout within the uint32, from MSB to LSB, is:
 |  | ||||||
| //	[%2d bits] unused
 |  | ||||||
| //	[%2d bits] children index
 |  | ||||||
| //	[%2d bits] ICANN bit
 |  | ||||||
| //	[%2d bits] text index
 |  | ||||||
| //	[%2d bits] text length
 |  | ||||||
| var nodes = [...]uint32{ |  | ||||||
| `, |  | ||||||
| 		32-nodesBitsChildren-nodesBitsICANN-nodesBitsTextOffset-nodesBitsTextLength, |  | ||||||
| 		nodesBitsChildren, nodesBitsICANN, nodesBitsTextOffset, nodesBitsTextLength) |  | ||||||
| 	if err := n.walk(w, printNode); err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 	fmt.Fprintf(w, `} |  | ||||||
| 
 |  | ||||||
| // children is the list of nodes' children, the parent's wildcard bit and the
 |  | ||||||
| // parent's node type. If a node has no children then their children index
 |  | ||||||
| // will be in the range [0, 6), depending on the wildcard bit and node type.
 |  | ||||||
| //
 |  | ||||||
| // The layout within the uint32, from MSB to LSB, is:
 |  | ||||||
| //	[%2d bits] unused
 |  | ||||||
| //	[%2d bits] wildcard bit
 |  | ||||||
| //	[%2d bits] node type
 |  | ||||||
| //	[%2d bits] high nodes index (exclusive) of children
 |  | ||||||
| //	[%2d bits] low nodes index (inclusive) of children
 |  | ||||||
| var children=[...]uint32{ |  | ||||||
| `, |  | ||||||
| 		32-childrenBitsWildcard-childrenBitsNodeType-childrenBitsHi-childrenBitsLo, |  | ||||||
| 		childrenBitsWildcard, childrenBitsNodeType, childrenBitsHi, childrenBitsLo) |  | ||||||
| 	for i, c := range childrenEncoding { |  | ||||||
| 		s := "---------------" |  | ||||||
| 		lo := c & (1<<childrenBitsLo - 1) |  | ||||||
| 		hi := (c >> childrenBitsLo) & (1<<childrenBitsHi - 1) |  | ||||||
| 		if lo != hi { |  | ||||||
| 			s = fmt.Sprintf("n0x%04x-n0x%04x", lo, hi) |  | ||||||
| 		} |  | ||||||
| 		nodeType := int(c>>(childrenBitsLo+childrenBitsHi)) & (1<<childrenBitsNodeType - 1) |  | ||||||
| 		wildcard := c>>(childrenBitsLo+childrenBitsHi+childrenBitsNodeType) != 0 |  | ||||||
| 		fmt.Fprintf(w, "0x%08x, // c0x%04x (%s)%s %s\n", |  | ||||||
| 			c, i, s, wildcardStr(wildcard), nodeTypeStr(nodeType)) |  | ||||||
| 	} |  | ||||||
| 	fmt.Fprintf(w, "}\n\n") |  | ||||||
| 	fmt.Fprintf(w, "// max children %d (capacity %d)\n", maxChildren, 1<<nodesBitsChildren-1) |  | ||||||
| 	fmt.Fprintf(w, "// max text offset %d (capacity %d)\n", maxTextOffset, 1<<nodesBitsTextOffset-1) |  | ||||||
| 	fmt.Fprintf(w, "// max text length %d (capacity %d)\n", maxTextLength, 1<<nodesBitsTextLength-1) |  | ||||||
| 	fmt.Fprintf(w, "// max hi %d (capacity %d)\n", maxHi, 1<<childrenBitsHi-1) |  | ||||||
| 	fmt.Fprintf(w, "// max lo %d (capacity %d)\n", maxLo, 1<<childrenBitsLo-1) |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type node struct { |  | ||||||
| 	label    string |  | ||||||
| 	nodeType int |  | ||||||
| 	icann    bool |  | ||||||
| 	wildcard bool |  | ||||||
| 	// nodesIndex and childrenIndex are the index of this node in the nodes
 |  | ||||||
| 	// and the index of its children offset/length in the children arrays.
 |  | ||||||
| 	nodesIndex, childrenIndex int |  | ||||||
| 	// firstChild is the index of this node's first child, or zero if this
 |  | ||||||
| 	// node has no children.
 |  | ||||||
| 	firstChild int |  | ||||||
| 	// children are the node's children, in strictly increasing node label order.
 |  | ||||||
| 	children []*node |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (n *node) walk(w io.Writer, f func(w1 io.Writer, n1 *node) error) error { |  | ||||||
| 	if err := f(w, n); err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 	for _, c := range n.children { |  | ||||||
| 		if err := c.walk(w, f); err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // child returns the child of n with the given label. The child is created if
 |  | ||||||
| // it did not exist beforehand.
 |  | ||||||
| func (n *node) child(label string) *node { |  | ||||||
| 	for _, c := range n.children { |  | ||||||
| 		if c.label == label { |  | ||||||
| 			return c |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	c := &node{ |  | ||||||
| 		label:    label, |  | ||||||
| 		nodeType: nodeTypeParentOnly, |  | ||||||
| 		icann:    true, |  | ||||||
| 	} |  | ||||||
| 	n.children = append(n.children, c) |  | ||||||
| 	sort.Sort(byLabel(n.children)) |  | ||||||
| 	return c |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type byLabel []*node |  | ||||||
| 
 |  | ||||||
| func (b byLabel) Len() int           { return len(b) } |  | ||||||
| func (b byLabel) Swap(i, j int)      { b[i], b[j] = b[j], b[i] } |  | ||||||
| func (b byLabel) Less(i, j int) bool { return b[i].label < b[j].label } |  | ||||||
| 
 |  | ||||||
| var nextNodesIndex int |  | ||||||
| 
 |  | ||||||
| // childrenEncoding are the encoded entries in the generated children array.
 |  | ||||||
| // All these pre-defined entries have no children.
 |  | ||||||
| var childrenEncoding = []uint32{ |  | ||||||
| 	0 << (childrenBitsLo + childrenBitsHi), // Without wildcard bit, nodeTypeNormal.
 |  | ||||||
| 	1 << (childrenBitsLo + childrenBitsHi), // Without wildcard bit, nodeTypeException.
 |  | ||||||
| 	2 << (childrenBitsLo + childrenBitsHi), // Without wildcard bit, nodeTypeParentOnly.
 |  | ||||||
| 	4 << (childrenBitsLo + childrenBitsHi), // With wildcard bit, nodeTypeNormal.
 |  | ||||||
| 	5 << (childrenBitsLo + childrenBitsHi), // With wildcard bit, nodeTypeException.
 |  | ||||||
| 	6 << (childrenBitsLo + childrenBitsHi), // With wildcard bit, nodeTypeParentOnly.
 |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| var firstCallToAssignIndexes = true |  | ||||||
| 
 |  | ||||||
| func assignIndexes(w io.Writer, n *node) error { |  | ||||||
| 	if len(n.children) != 0 { |  | ||||||
| 		// Assign nodesIndex.
 |  | ||||||
| 		n.firstChild = nextNodesIndex |  | ||||||
| 		for _, c := range n.children { |  | ||||||
| 			c.nodesIndex = nextNodesIndex |  | ||||||
| 			nextNodesIndex++ |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		// The root node's children is implicit.
 |  | ||||||
| 		if firstCallToAssignIndexes { |  | ||||||
| 			firstCallToAssignIndexes = false |  | ||||||
| 			return nil |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		// Assign childrenIndex.
 |  | ||||||
| 		maxChildren = max(maxChildren, len(childrenEncoding)) |  | ||||||
| 		if len(childrenEncoding) >= 1<<nodesBitsChildren { |  | ||||||
| 			return fmt.Errorf("children table size %d is too large, or nodeBitsChildren is too small", len(childrenEncoding)) |  | ||||||
| 		} |  | ||||||
| 		n.childrenIndex = len(childrenEncoding) |  | ||||||
| 		lo := uint32(n.firstChild) |  | ||||||
| 		hi := lo + uint32(len(n.children)) |  | ||||||
| 		maxLo, maxHi = u32max(maxLo, lo), u32max(maxHi, hi) |  | ||||||
| 		if lo >= 1<<childrenBitsLo { |  | ||||||
| 			return fmt.Errorf("children lo %d is too large, or childrenBitsLo is too small", lo) |  | ||||||
| 		} |  | ||||||
| 		if hi >= 1<<childrenBitsHi { |  | ||||||
| 			return fmt.Errorf("children hi %d is too large, or childrenBitsHi is too small", hi) |  | ||||||
| 		} |  | ||||||
| 		enc := hi<<childrenBitsLo | lo |  | ||||||
| 		enc |= uint32(n.nodeType) << (childrenBitsLo + childrenBitsHi) |  | ||||||
| 		if n.wildcard { |  | ||||||
| 			enc |= 1 << (childrenBitsLo + childrenBitsHi + childrenBitsNodeType) |  | ||||||
| 		} |  | ||||||
| 		childrenEncoding = append(childrenEncoding, enc) |  | ||||||
| 	} else { |  | ||||||
| 		n.childrenIndex = n.nodeType |  | ||||||
| 		if n.wildcard { |  | ||||||
| 			n.childrenIndex += numNodeType |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func printNode(w io.Writer, n *node) error { |  | ||||||
| 	for _, c := range n.children { |  | ||||||
| 		s := "---------------" |  | ||||||
| 		if len(c.children) != 0 { |  | ||||||
| 			s = fmt.Sprintf("n0x%04x-n0x%04x", c.firstChild, c.firstChild+len(c.children)) |  | ||||||
| 		} |  | ||||||
| 		encoding := labelEncoding[c.label] |  | ||||||
| 		if c.icann { |  | ||||||
| 			encoding |= 1 << (nodesBitsTextLength + nodesBitsTextOffset) |  | ||||||
| 		} |  | ||||||
| 		encoding |= uint32(c.childrenIndex) << (nodesBitsTextLength + nodesBitsTextOffset + nodesBitsICANN) |  | ||||||
| 		fmt.Fprintf(w, "0x%08x, // n0x%04x c0x%04x (%s)%s %s %s %s\n", |  | ||||||
| 			encoding, c.nodesIndex, c.childrenIndex, s, wildcardStr(c.wildcard), |  | ||||||
| 			nodeTypeStr(c.nodeType), icannStr(c.icann), c.label, |  | ||||||
| 		) |  | ||||||
| 	} |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func printNodeLabel(w io.Writer, n *node) error { |  | ||||||
| 	for _, c := range n.children { |  | ||||||
| 		fmt.Fprintf(w, "%q,\n", c.label) |  | ||||||
| 	} |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func icannStr(icann bool) string { |  | ||||||
| 	if icann { |  | ||||||
| 		return "I" |  | ||||||
| 	} |  | ||||||
| 	return " " |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func wildcardStr(wildcard bool) string { |  | ||||||
| 	if wildcard { |  | ||||||
| 		return "*" |  | ||||||
| 	} |  | ||||||
| 	return " " |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // combineText combines all the strings in labelsList to form one giant string.
 |  | ||||||
| // Overlapping strings will be merged: "arpa" and "parliament" could yield
 |  | ||||||
| // "arparliament".
 |  | ||||||
| func combineText(labelsList []string) string { |  | ||||||
| 	beforeLength := 0 |  | ||||||
| 	for _, s := range labelsList { |  | ||||||
| 		beforeLength += len(s) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	text := crush(removeSubstrings(labelsList)) |  | ||||||
| 	if *v { |  | ||||||
| 		fmt.Fprintf(os.Stderr, "crushed %d bytes to become %d bytes\n", beforeLength, len(text)) |  | ||||||
| 	} |  | ||||||
| 	return text |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type byLength []string |  | ||||||
| 
 |  | ||||||
| func (s byLength) Len() int           { return len(s) } |  | ||||||
| func (s byLength) Swap(i, j int)      { s[i], s[j] = s[j], s[i] } |  | ||||||
| func (s byLength) Less(i, j int) bool { return len(s[i]) < len(s[j]) } |  | ||||||
| 
 |  | ||||||
| // removeSubstrings returns a copy of its input with any strings removed
 |  | ||||||
| // that are substrings of other provided strings.
 |  | ||||||
| func removeSubstrings(input []string) []string { |  | ||||||
| 	// Make a copy of input.
 |  | ||||||
| 	ss := append(make([]string, 0, len(input)), input...) |  | ||||||
| 	sort.Sort(byLength(ss)) |  | ||||||
| 
 |  | ||||||
| 	for i, shortString := range ss { |  | ||||||
| 		// For each string, only consider strings higher than it in sort order, i.e.
 |  | ||||||
| 		// of equal length or greater.
 |  | ||||||
| 		for _, longString := range ss[i+1:] { |  | ||||||
| 			if strings.Contains(longString, shortString) { |  | ||||||
| 				ss[i] = "" |  | ||||||
| 				break |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// Remove the empty strings.
 |  | ||||||
| 	sort.Strings(ss) |  | ||||||
| 	for len(ss) > 0 && ss[0] == "" { |  | ||||||
| 		ss = ss[1:] |  | ||||||
| 	} |  | ||||||
| 	return ss |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // crush combines a list of strings, taking advantage of overlaps. It returns a
 |  | ||||||
| // single string that contains each input string as a substring.
 |  | ||||||
| func crush(ss []string) string { |  | ||||||
| 	maxLabelLen := 0 |  | ||||||
| 	for _, s := range ss { |  | ||||||
| 		if maxLabelLen < len(s) { |  | ||||||
| 			maxLabelLen = len(s) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	for prefixLen := maxLabelLen; prefixLen > 0; prefixLen-- { |  | ||||||
| 		prefixes := makePrefixMap(ss, prefixLen) |  | ||||||
| 		for i, s := range ss { |  | ||||||
| 			if len(s) <= prefixLen { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			mergeLabel(ss, i, prefixLen, prefixes) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return strings.Join(ss, "") |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // mergeLabel merges the label at ss[i] with the first available matching label
 |  | ||||||
| // in prefixMap, where the last "prefixLen" characters in ss[i] match the first
 |  | ||||||
| // "prefixLen" characters in the matching label.
 |  | ||||||
| // It will merge ss[i] repeatedly until no more matches are available.
 |  | ||||||
| // All matching labels merged into ss[i] are replaced by "".
 |  | ||||||
| func mergeLabel(ss []string, i, prefixLen int, prefixes prefixMap) { |  | ||||||
| 	s := ss[i] |  | ||||||
| 	suffix := s[len(s)-prefixLen:] |  | ||||||
| 	for _, j := range prefixes[suffix] { |  | ||||||
| 		// Empty strings mean "already used." Also avoid merging with self.
 |  | ||||||
| 		if ss[j] == "" || i == j { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		if *v { |  | ||||||
| 			fmt.Fprintf(os.Stderr, "%d-length overlap at (%4d,%4d): %q and %q share %q\n", |  | ||||||
| 				prefixLen, i, j, ss[i], ss[j], suffix) |  | ||||||
| 		} |  | ||||||
| 		ss[i] += ss[j][prefixLen:] |  | ||||||
| 		ss[j] = "" |  | ||||||
| 		// ss[i] has a new suffix, so merge again if possible.
 |  | ||||||
| 		// Note: we only have to merge again at the same prefix length. Shorter
 |  | ||||||
| 		// prefix lengths will be handled in the next iteration of crush's for loop.
 |  | ||||||
| 		// Can there be matches for longer prefix lengths, introduced by the merge?
 |  | ||||||
| 		// I believe that any such matches would by necessity have been eliminated
 |  | ||||||
| 		// during substring removal or merged at a higher prefix length. For
 |  | ||||||
| 		// instance, in crush("abc", "cde", "bcdef"), combining "abc" and "cde"
 |  | ||||||
| 		// would yield "abcde", which could be merged with "bcdef." However, in
 |  | ||||||
| 		// practice "cde" would already have been elimintated by removeSubstrings.
 |  | ||||||
| 		mergeLabel(ss, i, prefixLen, prefixes) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // prefixMap maps from a prefix to a list of strings containing that prefix. The
 |  | ||||||
| // list of strings is represented as indexes into a slice of strings stored
 |  | ||||||
| // elsewhere.
 |  | ||||||
| type prefixMap map[string][]int |  | ||||||
| 
 |  | ||||||
| // makePrefixMap constructs a prefixMap from a slice of strings.
 |  | ||||||
| func makePrefixMap(ss []string, prefixLen int) prefixMap { |  | ||||||
| 	prefixes := make(prefixMap) |  | ||||||
| 	for i, s := range ss { |  | ||||||
| 		// We use < rather than <= because if a label matches on a prefix equal to
 |  | ||||||
| 		// its full length, that's actually a substring match handled by
 |  | ||||||
| 		// removeSubstrings.
 |  | ||||||
| 		if prefixLen < len(s) { |  | ||||||
| 			prefix := s[:prefixLen] |  | ||||||
| 			prefixes[prefix] = append(prefixes[prefix], i) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return prefixes |  | ||||||
| } |  | ||||||
		Loading…
	
		Reference in New Issue