Merge pull request #2447 from tifayuki/cloudfront-s3-filter
add s3 region filters for cloudfront
commit f411848591
				|  | @ -35,3 +35,4 @@ bin/* | |||
| # Editor/IDE specific files. | ||||
| *.sublime-project | ||||
| *.sublime-workspace | ||||
| .idea/* | ||||
|  |  | |||
|  | @ -39,6 +39,8 @@ type Logger interface { | |||
| 	Warn(args ...interface{}) | ||||
| 	Warnf(format string, args ...interface{}) | ||||
| 	Warnln(args ...interface{}) | ||||
| 
 | ||||
| 	WithError(err error) *logrus.Entry | ||||
| } | ||||
| 
 | ||||
| type loggerKey struct{} | ||||
|  |  | |||
|  | @ -183,6 +183,10 @@ middleware: | |||
|         privatekey: /path/to/pem | ||||
|         keypairid: cloudfrontkeypairid | ||||
|         duration: 3000s | ||||
|         ipfilteredby: awsregion | ||||
|         awsregion: us-east-1, us-east-2 | ||||
|         updatefrenquency: 12h | ||||
|         iprangesurl: https://ip-ranges.amazonaws.com/ip-ranges.json | ||||
|   storage: | ||||
|     - name: redirect | ||||
|       options: | ||||
|  | @ -636,6 +640,10 @@ middleware: | |||
|         privatekey: /path/to/pem | ||||
|         keypairid: cloudfrontkeypairid | ||||
|         duration: 3000s | ||||
|         ipfilteredby: awsregion | ||||
|         awsregion: us-east-1, us-east-2 | ||||
|         updatefrenquency: 12h | ||||
|         iprangesurl: https://ip-ranges.amazonaws.com/ip-ranges.json | ||||
| ``` | ||||
| 
 | ||||
| Each middleware entry has `name` and `options` entries. The `name` must | ||||
|  | @ -655,6 +663,14 @@ interpretation of the options. | |||
| | `privatekey` | yes   | The private key for Cloudfront, provided by AWS.        | | ||||
| | `keypairid` | yes    | The key pair ID provided by AWS.                         | | ||||
| | `duration` | no      | An integer and unit for the duration of the Cloudfront session. Valid time units are `ns`, `us` (or `µs`), `ms`, `s`, `m`, or `h`. For example, `3000s` is valid, but `3000 s` is not. If you do not specify a `duration` or you specify an integer without a time unit, the duration defaults to `20m` (20 minutes).| | ||||
| |`ipfilteredby`|no     | A string with one of the following values: `none`, `aws` or `awsregion`. | | ||||
| |`awsregion`|no        | A comma-separated list of AWS regions; only valid when `ipfilteredby` is `awsregion`. For example, `us-east-1, us-west-2`| | ||||
| |`updatefrenquency`|no | The frequency at which the AWS IP ranges are refreshed. Default: `12h`| | ||||
| |`iprangesurl`|no      | The URL from which the AWS IP ranges are fetched. Default: `https://ip-ranges.amazonaws.com/ip-ranges.json`| | ||||
| The value of `ipfilteredby` can be one of the following (see the sketch below): | ||||
| - `none`: the default; do not filter by IP. | ||||
| - `aws`: requests from any AWS IP go to S3 directly. | ||||
| - `awsregion`: requests from the AWS regions listed in `awsregion` go to S3 directly. | ||||
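For intuition, the filter reduces to standard CIDR containment checks against the prefixes published in `ip-ranges.json`. A minimal, self-contained Go sketch of that check (illustrative only; the prefix and addresses are assumed example values, not necessarily current AWS ranges):

```go
package main

import (
	"fmt"
	"net"
)

func main() {
	// Example prefix in the style of ip-ranges.json (assumed value).
	_, network, err := net.ParseCIDR("52.95.0.0/16")
	if err != nil {
		panic(err)
	}
	for _, addr := range []string{"52.95.1.2", "203.0.113.7"} {
		// A client inside the range is redirected to S3 directly;
		// everyone else keeps going through CloudFront.
		fmt.Printf("%s in AWS range: %v\n", addr, network.Contains(net.ParseIP(addr)))
	}
}
```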
| 
 | ||||
| ### `redirect` | ||||
| 
 | ||||
|  |  | |||
|  | @ -16,7 +16,7 @@ import ( | |||
| 	"github.com/aws/aws-sdk-go/service/cloudfront/sign" | ||||
| 	dcontext "github.com/docker/distribution/context" | ||||
| 	storagedriver "github.com/docker/distribution/registry/storage/driver" | ||||
| 	storagemiddleware "github.com/docker/distribution/registry/storage/driver/middleware" | ||||
| 	"github.com/docker/distribution/registry/storage/driver/middleware" | ||||
| ) | ||||
| 
 | ||||
| // cloudFrontStorageMiddleware provides a simple implementation of layerHandler that
 | ||||
|  | @ -24,6 +24,7 @@ import ( | |||
| // then issues HTTP Temporary Redirects to this CloudFront content URL.
 | ||||
| type cloudFrontStorageMiddleware struct { | ||||
| 	storagedriver.StorageDriver | ||||
| 	awsIPs    *awsIPs | ||||
| 	urlSigner *sign.URLSigner | ||||
| 	baseURL   string | ||||
| 	duration  time.Duration | ||||
|  | @ -34,7 +35,13 @@ var _ storagedriver.StorageDriver = &cloudFrontStorageMiddleware{} | |||
| // newCloudFrontStorageMiddleware constructs and returns a new CloudFront | ||||
| // layer handler implementation. | ||||
| // Required options: baseurl, privatekey, keypairid | ||||
| // Optional options: ipfilteredby, awsregion, updatefrenquency, iprangesurl | ||||
| // ipfilteredby: valid values are "none", "aws" and "awsregion". "none" (the default) does not | ||||
| //               filter any IP; "aws" sends requests from any AWS IP to S3 directly; "awsregion" | ||||
| //               sends only requests from the regions listed in the awsregion option to S3 directly. | ||||
| // awsregion: a comma-separated string of AWS regions. | ||||
| func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) { | ||||
| 	// parse baseurl
 | ||||
| 	base, ok := options["baseurl"] | ||||
| 	if !ok { | ||||
| 		return nil, fmt.Errorf("no baseurl provided") | ||||
|  | @ -52,6 +59,8 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o | |||
| 	if _, err := url.Parse(baseURL); err != nil { | ||||
| 		return nil, fmt.Errorf("invalid baseurl: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	// parse privatekey to get pkPath
 | ||||
| 	pk, ok := options["privatekey"] | ||||
| 	if !ok { | ||||
| 		return nil, fmt.Errorf("no privatekey provided") | ||||
|  | @ -60,6 +69,8 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o | |||
| 	if !ok { | ||||
| 		return nil, fmt.Errorf("privatekey must be a string") | ||||
| 	} | ||||
| 
 | ||||
| 	// parse keypairid
 | ||||
| 	kpid, ok := options["keypairid"] | ||||
| 	if !ok { | ||||
| 		return nil, fmt.Errorf("no keypairid provided") | ||||
|  | @ -69,6 +80,7 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o | |||
| 		return nil, fmt.Errorf("keypairid must be a string") | ||||
| 	} | ||||
| 
 | ||||
| 	// get urlSigner from the file specified in pkPath
 | ||||
| 	pkBytes, err := ioutil.ReadFile(pkPath) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to read privatekey file: %s", err) | ||||
|  | @ -82,12 +94,11 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o | |||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	urlSigner := sign.NewURLSigner(keypairID, privateKey) | ||||
| 
 | ||||
| 	// parse duration
 | ||||
| 	duration := 20 * time.Minute | ||||
| 	d, ok := options["duration"] | ||||
| 	if ok { | ||||
| 	if d, ok := options["duration"]; ok { | ||||
| 		switch d := d.(type) { | ||||
| 		case time.Duration: | ||||
| 			duration = d | ||||
|  | @ -100,11 +111,62 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o | |||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// parse updatefrenquency
 | ||||
| 	updateFrequency := defaultUpdateFrequency | ||||
| 	if u, ok := options["updatefrenquency"]; ok { | ||||
| 		switch u := u.(type) { | ||||
| 		case time.Duration: | ||||
| 			updateFrequency = u | ||||
| 		case string: | ||||
| 			updateFreq, err := time.ParseDuration(u) | ||||
| 			if err != nil { | ||||
| 				return nil, fmt.Errorf("invalid updatefrenquency: %s", err) | ||||
| 			} | ||||
| 			updateFrequency = updateFreq | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// parse iprangesurl
 | ||||
| 	ipRangesURL := defaultIPRangesURL | ||||
| 	if i, ok := options["iprangesurl"]; ok { | ||||
| 		if iprangeurl, ok := i.(string); ok { | ||||
| 			ipRangesURL = iprangeurl | ||||
| 		} else { | ||||
| 			return nil, fmt.Errorf("iprangesurl must be a string") | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// parse ipfilteredby
 | ||||
| 	var awsIPs *awsIPs | ||||
| 	if ipFilteredBy, ok := options["ipfilteredby"].(string); ok { | ||||
| 		switch strings.ToLower(strings.TrimSpace(ipFilteredBy)) { | ||||
| 		case "", "none": | ||||
| 			awsIPs = nil | ||||
| 		case "aws": | ||||
| 			awsIPs = newAWSIPs(ipRangesURL, updateFrequency, nil) | ||||
| 		case "awsregion": | ||||
| 			var awsRegion []string | ||||
| 			if regions, ok := options["awsregion"].(string); ok { | ||||
| 				for _, awsRegions := range strings.Split(regions, ",") { | ||||
| 					awsRegion = append(awsRegion, strings.ToLower(strings.TrimSpace(awsRegions))) | ||||
| 				} | ||||
| 				awsIPs = newAWSIPs(ipRangesURL, updateFrequency, awsRegion) | ||||
| 			} else { | ||||
| 				return nil, fmt.Errorf("awsRegion must be a comma separated string of valid aws regions") | ||||
| 			} | ||||
| 		default: | ||||
| 			return nil, fmt.Errorf("ipfilteredby only allows a string with the following values: none|aws|awsregion") | ||||
| 		} | ||||
| 	} else if _, ok := options["ipfilteredby"]; ok { | ||||
| 		return nil, fmt.Errorf("ipfilteredby must be a string with one of the following values: none|aws|awsregion") | ||||
| 	} | ||||
| 
 | ||||
| 	return &cloudFrontStorageMiddleware{ | ||||
| 		StorageDriver: storageDriver, | ||||
| 		urlSigner:     urlSigner, | ||||
| 		baseURL:       baseURL, | ||||
| 		duration:      duration, | ||||
| 		awsIPs:        awsIPs, | ||||
| 	}, nil | ||||
| } | ||||
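As a usage illustration, here is a hypothetical sketch of how this constructor might be wired up; `s3Driver` is an assumed S3-backed `storagedriver.StorageDriver`, and the option keys mirror the YAML configuration documented above (including the `updatefrenquency` spelling the code looks up):

```go
// Hypothetical wiring sketch, not part of this change.
options := map[string]interface{}{
	"baseurl":          "https://abcdefghijklmn.cloudfront.net/",
	"privatekey":       "/path/to/pem",
	"keypairid":        "cloudfrontkeypairid",
	"duration":         "3000s",
	"ipfilteredby":     "awsregion",
	"awsregion":        "us-east-1, us-west-2",
	"updatefrenquency": "12h",
}
driver, err := newCloudFrontStorageMiddleware(s3Driver, options)
if err != nil {
	// Bad baseurl, unreadable private key, invalid duration, etc.
	panic(err)
}
_ = driver // used as a drop-in storagedriver.StorageDriver
```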
| 
 | ||||
|  | @ -114,8 +176,8 @@ type S3BucketKeyer interface { | |||
| 	S3BucketKey(path string) string | ||||
| } | ||||
| 
 | ||||
| // Resolve returns an http.Handler which can serve the contents of the given
 | ||||
| // Layer, or an error if not supported by the storagedriver.
 | ||||
| // URLFor attempts to find a URL which may be used to retrieve the file at the given path. | ||||
| // It returns an error if the file cannot be found. | ||||
| func (lh *cloudFrontStorageMiddleware) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) { | ||||
| 	// TODO(endophage): currently only supports S3
 | ||||
| 	keyer, ok := lh.StorageDriver.(S3BucketKeyer) | ||||
|  | @ -124,6 +186,11 @@ func (lh *cloudFrontStorageMiddleware) URLFor(ctx context.Context, path string, | |||
| 		return lh.StorageDriver.URLFor(ctx, path, options) | ||||
| 	} | ||||
| 
 | ||||
| 	if eligibleForS3(ctx, lh.awsIPs) { | ||||
| 		return lh.StorageDriver.URLFor(ctx, path, options) | ||||
| 	} | ||||
| 
 | ||||
| 	// Get signed cloudfront url.
 | ||||
| 	cfURL, err := lh.urlSigner.Sign(lh.baseURL+keyer.S3BucketKey(path), time.Now().Add(lh.duration)) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
|  |  | |||
|  | @ -0,0 +1,223 @@ | |||
| package middleware | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"io/ioutil" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| 
 | ||||
| 	dcontext "github.com/docker/distribution/context" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	// defaultIPRangesURL is the URL of the published list of AWS IP ranges | ||||
| 	defaultIPRangesURL = "https://ip-ranges.amazonaws.com/ip-ranges.json" | ||||
| 	// defaultUpdateFrequency is how often the AWS IP ranges are refreshed | ||||
| 	defaultUpdateFrequency = time.Hour * 12 | ||||
| ) | ||||
| 
 | ||||
| // newAWSIPs returns a new awsIPs object. | ||||
| // If awsRegion is nil, it accepts any region; otherwise it allows only the regions specified. | ||||
| func newAWSIPs(host string, updateFrequency time.Duration, awsRegion []string) *awsIPs { | ||||
| 	ips := &awsIPs{ | ||||
| 		host:            host, | ||||
| 		updateFrequency: updateFrequency, | ||||
| 		awsRegion:       awsRegion, | ||||
| 		updaterStopChan: make(chan bool), | ||||
| 	} | ||||
| 	if err := ips.tryUpdate(); err != nil { | ||||
| 		dcontext.GetLogger(context.Background()).WithError(err).Warn("failed to update AWS IP") | ||||
| 	} | ||||
| 	go ips.updater() | ||||
| 	return ips | ||||
| } | ||||
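A quick sketch of the contract described above, using the helpers defined in this file (the address is an arbitrary example, not a known AWS IP):

```go
// nil as the region list would accept any AWS region.
ips := newAWSIPs(defaultIPRangesURL, defaultUpdateFrequency, []string{"us-east-1", "us-west-2"})
if ips.contains(net.ParseIP("52.95.1.2")) { // example address
	// the client may be sent to S3 directly
}
```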
| 
 | ||||
| // awsIPs tracks a list of AWS IP networks, filtered by awsRegion | ||||
| type awsIPs struct { | ||||
| 	host            string | ||||
| 	updateFrequency time.Duration | ||||
| 	ipv4            []net.IPNet | ||||
| 	ipv6            []net.IPNet | ||||
| 	mutex           sync.RWMutex | ||||
| 	awsRegion       []string | ||||
| 	updaterStopChan chan bool | ||||
| 	initialized     bool | ||||
| } | ||||
| 
 | ||||
| type awsIPResponse struct { | ||||
| 	Prefixes   []prefixEntry `json:"prefixes"` | ||||
| 	V6Prefixes []prefixEntry `json:"ipv6_prefixes"` | ||||
| } | ||||
| 
 | ||||
| type prefixEntry struct { | ||||
| 	IPV4Prefix string `json:"ip_prefix"` | ||||
| 	IPV6Prefix string `json:"ipv6_prefix"` | ||||
| 	Region     string `json:"region"` | ||||
| 	Service    string `json:"service"` | ||||
| } | ||||
| 
 | ||||
| func fetchAWSIPs(url string) (awsIPResponse, error) { | ||||
| 	var response awsIPResponse | ||||
| 	resp, err := http.Get(url) | ||||
| 	if err != nil { | ||||
| 		return response, err | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
| 	if resp.StatusCode != 200 { | ||||
| 		body, _ := ioutil.ReadAll(resp.Body) | ||||
| 		return response, fmt.Errorf("failed to fetch network data. response = %s", body) | ||||
| 	} | ||||
| 	decoder := json.NewDecoder(resp.Body) | ||||
| 	err = decoder.Decode(&response) | ||||
| 	if err != nil { | ||||
| 		return response, err | ||||
| 	} | ||||
| 	return response, nil | ||||
| } | ||||
| 
 | ||||
| // tryUpdate attempts to download the latest set of IP ranges. | ||||
| // tryUpdate must be safe to call concurrently with contains. | ||||
| func (s *awsIPs) tryUpdate() error { | ||||
| 	response, err := fetchAWSIPs(s.host) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	var ipv4 []net.IPNet | ||||
| 	var ipv6 []net.IPNet | ||||
| 
 | ||||
| 	processAddress := func(output *[]net.IPNet, prefix string, region string) { | ||||
| 		regionAllowed := false | ||||
| 		if len(s.awsRegion) > 0 { | ||||
| 			for _, ar := range s.awsRegion { | ||||
| 				if strings.ToLower(region) == ar { | ||||
| 					regionAllowed = true | ||||
| 					break | ||||
| 				} | ||||
| 			} | ||||
| 		} else { | ||||
| 			regionAllowed = true | ||||
| 		} | ||||
| 
 | ||||
| 		_, network, err := net.ParseCIDR(prefix) | ||||
| 		if err != nil { | ||||
| 			dcontext.GetLoggerWithFields(dcontext.Background(), map[interface{}]interface{}{ | ||||
| 				"cidr": prefix, | ||||
| 			}).Error("unparseable cidr") | ||||
| 			return | ||||
| 		} | ||||
| 		if regionAllowed { | ||||
| 			*output = append(*output, *network) | ||||
| 		} | ||||
| 
 | ||||
| 	} | ||||
| 
 | ||||
| 	for _, prefix := range response.Prefixes { | ||||
| 		processAddress(&ipv4, prefix.IPV4Prefix, prefix.Region) | ||||
| 	} | ||||
| 	for _, prefix := range response.V6Prefixes { | ||||
| 		processAddress(&ipv6, prefix.IPV6Prefix, prefix.Region) | ||||
| 	} | ||||
| 	s.mutex.Lock() | ||||
| 	defer s.mutex.Unlock() | ||||
| 	// Update each attr of awsips atomically.
 | ||||
| 	s.ipv4 = ipv4 | ||||
| 	s.ipv6 = ipv6 | ||||
| 	s.initialized = true | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // updater is meant to be run in a background goroutine. | ||||
| // It periodically updates the IP ranges from AWS. | ||||
| func (s *awsIPs) updater() { | ||||
| 	defer close(s.updaterStopChan) | ||||
| 	for { | ||||
| 		time.Sleep(s.updateFrequency) | ||||
| 		select { | ||||
| 		case <-s.updaterStopChan: | ||||
| 			dcontext.GetLogger(context.Background()).Info("aws ip updater received stop signal") | ||||
| 			return | ||||
| 		default: | ||||
| 			err := s.tryUpdate() | ||||
| 			if err != nil { | ||||
| 				dcontext.GetLogger(context.Background()).WithError(err).Error("failed to update AWS IP ranges") | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
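Note that nothing in this change ever signals `updaterStopChan`, so the updater goroutine runs for the life of the process. If a caller holding the `*awsIPs` wanted to stop it, a single send would do (a sketch, not code from this PR):

```go
// Received by the select in updater() once the current sleep finishes;
// updater then logs, returns, and its deferred close closes the channel.
ips.updaterStopChan <- true
```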
| 
 | ||||
| // getCandidateNetworks returns either the ipv4 or ipv6 networks
 | ||||
| // that were last read from aws. The networks returned
 | ||||
| // have the same type as the ip address provided.
 | ||||
| func (s *awsIPs) getCandidateNetworks(ip net.IP) []net.IPNet { | ||||
| 	s.mutex.RLock() | ||||
| 	defer s.mutex.RUnlock() | ||||
| 	if ip.To4() != nil { | ||||
| 		return s.ipv4 | ||||
| 	} else if ip.To16() != nil { | ||||
| 		return s.ipv6 | ||||
| 	} else { | ||||
| 		dcontext.GetLoggerWithFields(dcontext.Background(), map[interface{}]interface{}{ | ||||
| 			"ip": ip, | ||||
| 		}).Error("unknown ip address format") | ||||
| 		// assume mismatch, pass through cloudfront
 | ||||
| 		return nil | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // contains determines whether the given IP is within the tracked AWS networks. | ||||
| func (s *awsIPs) contains(ip net.IP) bool { | ||||
| 	networks := s.getCandidateNetworks(ip) | ||||
| 	for _, network := range networks { | ||||
| 		if network.Contains(ip) { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| // parseIPFromRequest attempts to extract the IP address of the | ||||
| // client that made the request. | ||||
| func parseIPFromRequest(ctx context.Context) (net.IP, error) { | ||||
| 	request, err := dcontext.GetRequest(ctx) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	ipStr := dcontext.RemoteIP(request) | ||||
| 	ip := net.ParseIP(ipStr) | ||||
| 	if ip == nil { | ||||
| 		return nil, fmt.Errorf("invalid ip address from requester: %s", ipStr) | ||||
| 	} | ||||
| 
 | ||||
| 	return ip, nil | ||||
| } | ||||
| 
 | ||||
| // eligibleForS3 checks whether a request is eligible to be served from S3 directly. | ||||
| // It returns true only when the client IP falls within the configured AWS IP ranges. | ||||
| func eligibleForS3(ctx context.Context, awsIPs *awsIPs) bool { | ||||
| 	if awsIPs != nil && awsIPs.initialized { | ||||
| 		if addr, err := parseIPFromRequest(ctx); err == nil { | ||||
| 			request, err := dcontext.GetRequest(ctx) | ||||
| 			if err != nil { | ||||
| 				dcontext.GetLogger(ctx).Warnf("the CloudFront middleware cannot parse the request: %s", err) | ||||
| 			} else { | ||||
| 				loggerField := map[interface{}]interface{}{ | ||||
| 					"user-client": request.UserAgent(), | ||||
| 					"ip":          dcontext.RemoteIP(request), | ||||
| 				} | ||||
| 				if awsIPs.contains(addr) { | ||||
| 					dcontext.GetLoggerWithFields(ctx, loggerField).Info("request from the allowed AWS region, skipping CloudFront") | ||||
| 					return true | ||||
| 				} | ||||
| 				dcontext.GetLoggerWithFields(ctx, loggerField).Warn("request not from the allowed AWS region, fallback to CloudFront") | ||||
| 			} | ||||
| 		} else { | ||||
| 			dcontext.GetLogger(ctx).WithError(err).Warn("failed to parse ip address from context, fallback to CloudFront") | ||||
| 		} | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
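Putting the pieces together, the request flow through the middleware condenses to roughly the following; `serveBlob` is a hypothetical name restating the `URLFor` logic above, not a function in this change:

```go
// Condensed sketch (in this package) of the decision implemented by URLFor.
func serveBlob(ctx context.Context, lh *cloudFrontStorageMiddleware, path string) (string, error) {
	if eligibleForS3(ctx, lh.awsIPs) {
		// Client is inside the allowed AWS ranges: hand back a plain S3 URL.
		return lh.StorageDriver.URLFor(ctx, path, nil)
	}
	// Everyone else receives a signed, expiring CloudFront URL.
	keyer := lh.StorageDriver.(S3BucketKeyer) // assumes an S3-backed driver
	return lh.urlSigner.Sign(lh.baseURL+keyer.S3BucketKey(path), time.Now().Add(lh.duration))
}
```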
|  | @ -0,0 +1,401 @@ | |||
| package middleware | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/rand" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"reflect" // used as a lightweight replacement for testify | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	dcontext "github.com/docker/distribution/context" | ||||
| ) | ||||
| 
 | ||||
| // Rather than pull in all of testify
 | ||||
| func assertEqual(t *testing.T, x, y interface{}) { | ||||
| 	if !reflect.DeepEqual(x, y) { | ||||
| 		t.Errorf("%s: Not equal! Expected='%v', Actual='%v'\n", t.Name(), x, y) | ||||
| 		t.FailNow() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type mockIPRangeHandler struct { | ||||
| 	data awsIPResponse | ||||
| } | ||||
| 
 | ||||
| func (m mockIPRangeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { | ||||
| 	bytes, err := json.Marshal(m.data) | ||||
| 	if err != nil { | ||||
| 		w.WriteHeader(500) | ||||
| 		return | ||||
| 	} | ||||
| 	w.Write(bytes) | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| func newTestHandler(data awsIPResponse) *httptest.Server { | ||||
| 	return httptest.NewServer(mockIPRangeHandler{ | ||||
| 		data: data, | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func serverIPRanges(server *httptest.Server) string { | ||||
| 	return fmt.Sprintf("%s/", server.URL) | ||||
| } | ||||
| 
 | ||||
| func setupTest(data awsIPResponse) *httptest.Server { | ||||
| 	// This is a basic schema which only claims the exact ip
 | ||||
| 	// is in aws.
 | ||||
| 	server := newTestHandler(data) | ||||
| 	return server | ||||
| } | ||||
| 
 | ||||
| func TestS3TryUpdate(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	server := setupTest(awsIPResponse{ | ||||
| 		Prefixes: []prefixEntry{ | ||||
| 			{IPV4Prefix: "123.231.123.231/32"}, | ||||
| 		}, | ||||
| 	}) | ||||
| 	defer server.Close() | ||||
| 
 | ||||
| 	ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) | ||||
| 
 | ||||
| 	assertEqual(t, 1, len(ips.ipv4)) | ||||
| 	assertEqual(t, 0, len(ips.ipv6)) | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| func TestMatchIPV6(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	server := setupTest(awsIPResponse{ | ||||
| 		V6Prefixes: []prefixEntry{ | ||||
| 			{IPV6Prefix: "ff00::/16"}, | ||||
| 		}, | ||||
| 	}) | ||||
| 	defer server.Close() | ||||
| 
 | ||||
| 	ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) | ||||
| 	ips.tryUpdate() | ||||
| 	assertEqual(t, true, ips.contains(net.ParseIP("ff00::"))) | ||||
| 	assertEqual(t, 1, len(ips.ipv6)) | ||||
| 	assertEqual(t, 0, len(ips.ipv4)) | ||||
| } | ||||
| 
 | ||||
| func TestMatchIPV4(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	server := setupTest(awsIPResponse{ | ||||
| 		Prefixes: []prefixEntry{ | ||||
| 			{IPV4Prefix: "192.168.0.0/24"}, | ||||
| 		}, | ||||
| 	}) | ||||
| 	defer server.Close() | ||||
| 
 | ||||
| 	ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) | ||||
| 	ips.tryUpdate() | ||||
| 	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0"))) | ||||
| 	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1"))) | ||||
| 	assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) | ||||
| } | ||||
| 
 | ||||
| func TestMatchIPV4_2(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	server := setupTest(awsIPResponse{ | ||||
| 		Prefixes: []prefixEntry{ | ||||
| 			{ | ||||
| 				IPV4Prefix: "192.168.0.0/24", | ||||
| 				Region:     "us-east-1", | ||||
| 			}, | ||||
| 		}, | ||||
| 	}) | ||||
| 	defer server.Close() | ||||
| 
 | ||||
| 	ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) | ||||
| 	ips.tryUpdate() | ||||
| 	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0"))) | ||||
| 	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1"))) | ||||
| 	assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) | ||||
| } | ||||
| 
 | ||||
| func TestMatchIPV4WithRegionMatched(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	server := setupTest(awsIPResponse{ | ||||
| 		Prefixes: []prefixEntry{ | ||||
| 			{ | ||||
| 				IPV4Prefix: "192.168.0.0/24", | ||||
| 				Region:     "us-east-1", | ||||
| 			}, | ||||
| 		}, | ||||
| 	}) | ||||
| 	defer server.Close() | ||||
| 
 | ||||
| 	ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-east-1"}) | ||||
| 	ips.tryUpdate() | ||||
| 	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0"))) | ||||
| 	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1"))) | ||||
| 	assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) | ||||
| } | ||||
| 
 | ||||
| func TestMatchIPV4WithRegionMatch_2(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	server := setupTest(awsIPResponse{ | ||||
| 		Prefixes: []prefixEntry{ | ||||
| 			{ | ||||
| 				IPV4Prefix: "192.168.0.0/24", | ||||
| 				Region:     "us-east-1", | ||||
| 			}, | ||||
| 		}, | ||||
| 	}) | ||||
| 	defer server.Close() | ||||
| 
 | ||||
| 	ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-west-2", "us-east-1"}) | ||||
| 	ips.tryUpdate() | ||||
| 	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0"))) | ||||
| 	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1"))) | ||||
| 	assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) | ||||
| } | ||||
| 
 | ||||
| func TestMatchIPV4WithRegionNotMatched(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	server := setupTest(awsIPResponse{ | ||||
| 		Prefixes: []prefixEntry{ | ||||
| 			{ | ||||
| 				IPV4Prefix: "192.168.0.0/24", | ||||
| 				Region:     "us-east-1", | ||||
| 			}, | ||||
| 		}, | ||||
| 	}) | ||||
| 	defer server.Close() | ||||
| 
 | ||||
| 	ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-west-2"}) | ||||
| 	ips.tryUpdate() | ||||
| 	assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.0"))) | ||||
| 	assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.1"))) | ||||
| 	assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) | ||||
| } | ||||
| 
 | ||||
| func TestInvalidData(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	// Invalid entries from aws should be ignored.
 | ||||
| 	server := setupTest(awsIPResponse{ | ||||
| 		Prefixes: []prefixEntry{ | ||||
| 			{IPV4Prefix: "9000"}, | ||||
| 			{IPV4Prefix: "192.168.0.0/24"}, | ||||
| 		}, | ||||
| 	}) | ||||
| 	defer server.Close() | ||||
| 
 | ||||
| 	ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) | ||||
| 	ips.tryUpdate() | ||||
| 	assertEqual(t, 1, len(ips.ipv4)) | ||||
| } | ||||
| 
 | ||||
| func TestInvalidNetworkType(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	server := setupTest(awsIPResponse{ | ||||
| 		Prefixes: []prefixEntry{ | ||||
| 			{IPV4Prefix: "192.168.0.0/24"}, | ||||
| 		}, | ||||
| 		V6Prefixes: []prefixEntry{ | ||||
| 			{IPV6Prefix: "ff00::/8"}, | ||||
| 			{IPV6Prefix: "fe00::/8"}, | ||||
| 		}, | ||||
| 	}) | ||||
| 	defer server.Close() | ||||
| 
 | ||||
| 	ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) | ||||
| 	assertEqual(t, 0, len(ips.getCandidateNetworks(make([]byte, 17)))) // 17 bytes does not correspond to any net type
 | ||||
| 	assertEqual(t, 1, len(ips.getCandidateNetworks(make([]byte, 4))))  // netv4 networks
 | ||||
| 	assertEqual(t, 2, len(ips.getCandidateNetworks(make([]byte, 16)))) // netv6 networks
 | ||||
| } | ||||
| 
 | ||||
| func TestParsing(t *testing.T) { | ||||
| 	var data = `{ | ||||
|       "prefixes": [{ | ||||
|         "ip_prefix": "192.168.0.0", | ||||
|         "region": "someregion", | ||||
|         "service": "s3"}], | ||||
|       "ipv6_prefixes": [{ | ||||
|         "ipv6_prefix": "2001:4860:4860::8888", | ||||
|         "region": "anotherregion", | ||||
|         "service": "ec2"}] | ||||
|     }` | ||||
| 	rawMockHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte(data)) }) | ||||
| 	t.Parallel() | ||||
| 	server := httptest.NewServer(rawMockHandler) | ||||
| 	defer server.Close() | ||||
| 	schema, err := fetchAWSIPs(server.URL) | ||||
| 
 | ||||
| 	assertEqual(t, nil, err) | ||||
| 	assertEqual(t, 1, len(schema.Prefixes)) | ||||
| 	assertEqual(t, prefixEntry{ | ||||
| 		IPV4Prefix: "192.168.0.0", | ||||
| 		Region:     "someregion", | ||||
| 		Service:    "s3", | ||||
| 	}, schema.Prefixes[0]) | ||||
| 	assertEqual(t, 1, len(schema.V6Prefixes)) | ||||
| 	assertEqual(t, prefixEntry{ | ||||
| 		IPV6Prefix: "2001:4860:4860::8888", | ||||
| 		Region:     "anotherregion", | ||||
| 		Service:    "ec2", | ||||
| 	}, schema.V6Prefixes[0]) | ||||
| } | ||||
| 
 | ||||
| func TestUpdateCalledRegularly(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 
 | ||||
| 	updateCount := 0 | ||||
| 	server := httptest.NewServer(http.HandlerFunc( | ||||
| 		func(rw http.ResponseWriter, req *http.Request) { | ||||
| 			updateCount++ | ||||
| 			rw.Write([]byte("ok")) | ||||
| 		})) | ||||
| 	defer server.Close() | ||||
| 	newAWSIPs(fmt.Sprintf("%s/", server.URL), time.Second, nil) | ||||
| 	time.Sleep(time.Second*4 + time.Millisecond*500) | ||||
| 	if updateCount < 4 { | ||||
| 		t.Errorf("Update should have been called at least 4 times, actual=%d", updateCount) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestEligibleForS3(t *testing.T) { | ||||
| 	awsIPs := &awsIPs{ | ||||
| 		ipv4: []net.IPNet{{ | ||||
| 			IP:   net.ParseIP("192.168.1.1"), | ||||
| 			Mask: net.IPv4Mask(255, 255, 255, 0), | ||||
| 		}}, | ||||
| 		initialized: true, | ||||
| 	} | ||||
| 	empty := context.TODO() | ||||
| 	makeContext := func(ip string) context.Context { | ||||
| 		req := &http.Request{ | ||||
| 			RemoteAddr: ip, | ||||
| 		} | ||||
| 
 | ||||
| 		return dcontext.WithRequest(empty, req) | ||||
| 	} | ||||
| 
 | ||||
| 	cases := []struct { | ||||
| 		Context  context.Context | ||||
| 		Expected bool | ||||
| 	}{ | ||||
| 		{Context: empty, Expected: false}, | ||||
| 		{Context: makeContext("192.168.1.2"), Expected: true}, | ||||
| 		{Context: makeContext("192.168.0.2"), Expected: false}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, testCase := range cases { | ||||
| 		name := fmt.Sprintf("Client IP = %v", | ||||
| 			testCase.Context.Value("http.request.ip")) | ||||
| 		t.Run(name, func(t *testing.T) { | ||||
| 			assertEqual(t, testCase.Expected, eligibleForS3(testCase.Context, awsIPs)) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestEligibleForS3WithAWSIPNotInitialized(t *testing.T) { | ||||
| 	awsIPs := &awsIPs{ | ||||
| 		ipv4: []net.IPNet{{ | ||||
| 			IP:   net.ParseIP("192.168.1.1"), | ||||
| 			Mask: net.IPv4Mask(255, 255, 255, 0), | ||||
| 		}}, | ||||
| 		initialized: false, | ||||
| 	} | ||||
| 	empty := context.TODO() | ||||
| 	makeContext := func(ip string) context.Context { | ||||
| 		req := &http.Request{ | ||||
| 			RemoteAddr: ip, | ||||
| 		} | ||||
| 
 | ||||
| 		return dcontext.WithRequest(empty, req) | ||||
| 	} | ||||
| 
 | ||||
| 	cases := []struct { | ||||
| 		Context  context.Context | ||||
| 		Expected bool | ||||
| 	}{ | ||||
| 		{Context: empty, Expected: false}, | ||||
| 		{Context: makeContext("192.168.1.2"), Expected: false}, | ||||
| 		{Context: makeContext("192.168.0.2"), Expected: false}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, testCase := range cases { | ||||
| 		name := fmt.Sprintf("Client IP = %v", | ||||
| 			testCase.Context.Value("http.request.ip")) | ||||
| 		t.Run(name, func(t *testing.T) { | ||||
| 			assertEqual(t, testCase.Expected, eligibleForS3(testCase.Context, awsIPs)) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // populate ips with a number of different ipv4 and ipv6 networks, for the purposes
 | ||||
| // of benchmarking contains() performance.
 | ||||
| func populateRandomNetworks(b *testing.B, ips *awsIPs, ipv4Count, ipv6Count int) { | ||||
| 	generateNetworks := func(dest *[]net.IPNet, bytes int, count int) { | ||||
| 		for i := 0; i < count; i++ { | ||||
| 			ip := make([]byte, bytes) | ||||
| 			_, err := rand.Read(ip) | ||||
| 			if err != nil { | ||||
| 				b.Fatalf("failed to generate network for test : %s", err.Error()) | ||||
| 			} | ||||
| 			mask := make([]byte, bytes) | ||||
| 			for i := 0; i < bytes; i++ { | ||||
| 				mask[i] = 0xff | ||||
| 			} | ||||
| 			*dest = append(*dest, net.IPNet{ | ||||
| 				IP:   ip, | ||||
| 				Mask: mask, | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	generateNetworks(&ips.ipv4, 4, ipv4Count) | ||||
| 	generateNetworks(&ips.ipv6, 16, ipv6Count) | ||||
| } | ||||
| 
 | ||||
| func BenchmarkContainsRandom(b *testing.B) { | ||||
| 	// Generate a random network configuration, of size comparable to
 | ||||
| 	// aws official networks list
 | ||||
| 	// curl -s https://ip-ranges.amazonaws.com/ip-ranges.json | jq '.prefixes | length'
 | ||||
| 	// 941
 | ||||
| 	numNetworksPerType := 1000 // keep in sync with the above
 | ||||
| 	// intentionally skip constructor when creating awsIPs, to avoid updater routine.
 | ||||
| 	// This benchmark is only concerned with contains() performance.
 | ||||
| 	awsIPs := awsIPs{} | ||||
| 	populateRandomNetworks(b, &awsIPs, numNetworksPerType, numNetworksPerType) | ||||
| 
 | ||||
| 	ipv4 := make([][]byte, b.N) | ||||
| 	ipv6 := make([][]byte, b.N) | ||||
| 	for i := 0; i < b.N; i++ { | ||||
| 		ipv4[i] = make([]byte, 4) | ||||
| 		ipv6[i] = make([]byte, 16) | ||||
| 		rand.Read(ipv4[i]) | ||||
| 		rand.Read(ipv6[i]) | ||||
| 	} | ||||
| 	b.ResetTimer() | ||||
| 	for i := 0; i < b.N; i++ { | ||||
| 		awsIPs.contains(ipv4[i]) | ||||
| 		awsIPs.contains(ipv6[i]) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func BenchmarkContainsProd(b *testing.B) { | ||||
| 	awsIPs := newAWSIPs(defaultIPRangesURL, defaultUpdateFrequency, nil) | ||||
| 	ipv4 := make([][]byte, b.N) | ||||
| 	ipv6 := make([][]byte, b.N) | ||||
| 	for i := 0; i < b.N; i++ { | ||||
| 		ipv4[i] = make([]byte, 4) | ||||
| 		ipv6[i] = make([]byte, 16) | ||||
| 		rand.Read(ipv4[i]) | ||||
| 		rand.Read(ipv6[i]) | ||||
| 	} | ||||
| 	b.ResetTimer() | ||||
| 	for i := 0; i < b.N; i++ { | ||||
| 		awsIPs.contains(ipv4[i]) | ||||
| 		awsIPs.contains(ipv6[i]) | ||||
| 	} | ||||
| } | ||||
|  | @ -1,340 +0,0 @@ | |||
| //+build ignore
 | ||||
| 
 | ||||
| // msg_generate.go is meant to run with go generate. It will use
 | ||||
| // go/{importer,types} to track down all the RR struct types. Then for each type
 | ||||
| // it will generate pack/unpack methods based on the struct tags. The generated source is
 | ||||
| // written to zmsg.go, and is meant to be checked into git.
 | ||||
| package main | ||||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"fmt" | ||||
| 	"go/format" | ||||
| 	"go/importer" | ||||
| 	"go/types" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| var packageHdr = ` | ||||
| // *** DO NOT MODIFY ***
 | ||||
| // AUTOGENERATED BY go generate from msg_generate.go
 | ||||
| 
 | ||||
| package dns | ||||
| 
 | ||||
| ` | ||||
| 
 | ||||
| // getTypeStruct will take a type and the package scope, and return the
 | ||||
| // (innermost) struct if the type is considered a RR type (currently defined as
 | ||||
| // those structs beginning with a RR_Header, could be redefined as implementing
 | ||||
| // the RR interface). The bool return value indicates if embedded structs were
 | ||||
| // resolved.
 | ||||
| func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) { | ||||
| 	st, ok := t.Underlying().(*types.Struct) | ||||
| 	if !ok { | ||||
| 		return nil, false | ||||
| 	} | ||||
| 	if st.Field(0).Type() == scope.Lookup("RR_Header").Type() { | ||||
| 		return st, false | ||||
| 	} | ||||
| 	if st.Field(0).Anonymous() { | ||||
| 		st, _ := getTypeStruct(st.Field(0).Type(), scope) | ||||
| 		return st, true | ||||
| 	} | ||||
| 	return nil, false | ||||
| } | ||||
| 
 | ||||
| func main() { | ||||
| 	// Import and type-check the package
 | ||||
| 	pkg, err := importer.Default().Import("github.com/miekg/dns") | ||||
| 	fatalIfErr(err) | ||||
| 	scope := pkg.Scope() | ||||
| 
 | ||||
| 	// Collect actual types (*X)
 | ||||
| 	var namedTypes []string | ||||
| 	for _, name := range scope.Names() { | ||||
| 		o := scope.Lookup(name) | ||||
| 		if o == nil || !o.Exported() { | ||||
| 			continue | ||||
| 		} | ||||
| 		if st, _ := getTypeStruct(o.Type(), scope); st == nil { | ||||
| 			continue | ||||
| 		} | ||||
| 		if name == "PrivateRR" { | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 		// Check if corresponding TypeX exists
 | ||||
| 		if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" { | ||||
| 			log.Fatalf("Constant Type%s does not exist.", o.Name()) | ||||
| 		} | ||||
| 
 | ||||
| 		namedTypes = append(namedTypes, o.Name()) | ||||
| 	} | ||||
| 
 | ||||
| 	b := &bytes.Buffer{} | ||||
| 	b.WriteString(packageHdr) | ||||
| 
 | ||||
| 	fmt.Fprint(b, "// pack*() functions\n\n") | ||||
| 	for _, name := range namedTypes { | ||||
| 		o := scope.Lookup(name) | ||||
| 		st, _ := getTypeStruct(o.Type(), scope) | ||||
| 
 | ||||
| 		fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {\n", name) | ||||
| 		fmt.Fprint(b, `off, err := rr.Hdr.pack(msg, off, compression, compress) | ||||
| if err != nil { | ||||
| 	return off, err | ||||
| } | ||||
| headerEnd := off | ||||
| `) | ||||
| 		for i := 1; i < st.NumFields(); i++ { | ||||
| 			o := func(s string) { | ||||
| 				fmt.Fprintf(b, s, st.Field(i).Name()) | ||||
| 				fmt.Fprint(b, `if err != nil { | ||||
| return off, err | ||||
| } | ||||
| `) | ||||
| 			} | ||||
| 
 | ||||
| 			if _, ok := st.Field(i).Type().(*types.Slice); ok { | ||||
| 				switch st.Tag(i) { | ||||
| 				case `dns:"-"`: // ignored
 | ||||
| 				case `dns:"txt"`: | ||||
| 					o("off, err = packStringTxt(rr.%s, msg, off)\n") | ||||
| 				case `dns:"opt"`: | ||||
| 					o("off, err = packDataOpt(rr.%s, msg, off)\n") | ||||
| 				case `dns:"nsec"`: | ||||
| 					o("off, err = packDataNsec(rr.%s, msg, off)\n") | ||||
| 				case `dns:"domain-name"`: | ||||
| 					o("off, err = packDataDomainNames(rr.%s, msg, off, compression, compress)\n") | ||||
| 				default: | ||||
| 					log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) | ||||
| 				} | ||||
| 				continue | ||||
| 			} | ||||
| 
 | ||||
| 			switch { | ||||
| 			case st.Tag(i) == `dns:"-"`: // ignored
 | ||||
| 			case st.Tag(i) == `dns:"cdomain-name"`: | ||||
| 				fallthrough | ||||
| 			case st.Tag(i) == `dns:"domain-name"`: | ||||
| 				o("off, err = PackDomainName(rr.%s, msg, off, compression, compress)\n") | ||||
| 			case st.Tag(i) == `dns:"a"`: | ||||
| 				o("off, err = packDataA(rr.%s, msg, off)\n") | ||||
| 			case st.Tag(i) == `dns:"aaaa"`: | ||||
| 				o("off, err = packDataAAAA(rr.%s, msg, off)\n") | ||||
| 			case st.Tag(i) == `dns:"uint48"`: | ||||
| 				o("off, err = packUint48(rr.%s, msg, off)\n") | ||||
| 			case st.Tag(i) == `dns:"txt"`: | ||||
| 				o("off, err = packString(rr.%s, msg, off)\n") | ||||
| 
 | ||||
| 			case strings.HasPrefix(st.Tag(i), `dns:"size-base32`): // size-base32 can be packed just like base32
 | ||||
| 				fallthrough | ||||
| 			case st.Tag(i) == `dns:"base32"`: | ||||
| 				o("off, err = packStringBase32(rr.%s, msg, off)\n") | ||||
| 
 | ||||
| 			case strings.HasPrefix(st.Tag(i), `dns:"size-base64`): // size-base64 can be packed just like base64
 | ||||
| 				fallthrough | ||||
| 			case st.Tag(i) == `dns:"base64"`: | ||||
| 				o("off, err = packStringBase64(rr.%s, msg, off)\n") | ||||
| 
 | ||||
| 			case strings.HasPrefix(st.Tag(i), `dns:"size-hex:SaltLength`): // Hack to fix empty salt length for NSEC3
 | ||||
| 				o("if rr.%s == \"-\" { /* do nothing, empty salt */ }\n") | ||||
| 				continue | ||||
| 			case strings.HasPrefix(st.Tag(i), `dns:"size-hex`): // size-hex can be packed just like hex
 | ||||
| 				fallthrough | ||||
| 			case st.Tag(i) == `dns:"hex"`: | ||||
| 				o("off, err = packStringHex(rr.%s, msg, off)\n") | ||||
| 
 | ||||
| 			case st.Tag(i) == `dns:"octet"`: | ||||
| 				o("off, err = packStringOctet(rr.%s, msg, off)\n") | ||||
| 			case st.Tag(i) == "": | ||||
| 				switch st.Field(i).Type().(*types.Basic).Kind() { | ||||
| 				case types.Uint8: | ||||
| 					o("off, err = packUint8(rr.%s, msg, off)\n") | ||||
| 				case types.Uint16: | ||||
| 					o("off, err = packUint16(rr.%s, msg, off)\n") | ||||
| 				case types.Uint32: | ||||
| 					o("off, err = packUint32(rr.%s, msg, off)\n") | ||||
| 				case types.Uint64: | ||||
| 					o("off, err = packUint64(rr.%s, msg, off)\n") | ||||
| 				case types.String: | ||||
| 					o("off, err = packString(rr.%s, msg, off)\n") | ||||
| 				default: | ||||
| 					log.Fatalln(name, st.Field(i).Name()) | ||||
| 				} | ||||
| 			default: | ||||
| 				log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) | ||||
| 			} | ||||
| 		} | ||||
| 		// We have packed everything, only now we know the rdlength of this RR
 | ||||
| 		fmt.Fprintln(b, "rr.Header().Rdlength = uint16(off-headerEnd)") | ||||
| 		fmt.Fprintln(b, "return off, nil }\n") | ||||
| 	} | ||||
| 
 | ||||
| 	fmt.Fprint(b, "// unpack*() functions\n\n") | ||||
| 	for _, name := range namedTypes { | ||||
| 		o := scope.Lookup(name) | ||||
| 		st, _ := getTypeStruct(o.Type(), scope) | ||||
| 
 | ||||
| 		fmt.Fprintf(b, "func unpack%s(h RR_Header, msg []byte, off int) (RR, int, error) {\n", name) | ||||
| 		fmt.Fprintf(b, "rr := new(%s)\n", name) | ||||
| 		fmt.Fprint(b, "rr.Hdr = h\n") | ||||
| 		fmt.Fprint(b, `if noRdata(h) { | ||||
| return rr, off, nil | ||||
| 	} | ||||
| var err error | ||||
| rdStart := off | ||||
| _ = rdStart | ||||
| 
 | ||||
| `) | ||||
| 		for i := 1; i < st.NumFields(); i++ { | ||||
| 			o := func(s string) { | ||||
| 				fmt.Fprintf(b, s, st.Field(i).Name()) | ||||
| 				fmt.Fprint(b, `if err != nil { | ||||
| return rr, off, err | ||||
| } | ||||
| `) | ||||
| 			} | ||||
| 
 | ||||
| 			// size-* are special, because they reference a struct member we should use for the length.
 | ||||
| 			if strings.HasPrefix(st.Tag(i), `dns:"size-`) { | ||||
| 				structMember := structMember(st.Tag(i)) | ||||
| 				structTag := structTag(st.Tag(i)) | ||||
| 				switch structTag { | ||||
| 				case "hex": | ||||
| 					fmt.Fprintf(b, "rr.%s, off, err = unpackStringHex(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember) | ||||
| 				case "base32": | ||||
| 					fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase32(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember) | ||||
| 				case "base64": | ||||
| 					fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase64(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember) | ||||
| 				default: | ||||
| 					log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) | ||||
| 				} | ||||
| 				fmt.Fprint(b, `if err != nil { | ||||
| return rr, off, err | ||||
| } | ||||
| `) | ||||
| 				continue | ||||
| 			} | ||||
| 
 | ||||
| 			if _, ok := st.Field(i).Type().(*types.Slice); ok { | ||||
| 				switch st.Tag(i) { | ||||
| 				case `dns:"-"`: // ignored
 | ||||
| 				case `dns:"txt"`: | ||||
| 					o("rr.%s, off, err = unpackStringTxt(msg, off)\n") | ||||
| 				case `dns:"opt"`: | ||||
| 					o("rr.%s, off, err = unpackDataOpt(msg, off)\n") | ||||
| 				case `dns:"nsec"`: | ||||
| 					o("rr.%s, off, err = unpackDataNsec(msg, off)\n") | ||||
| 				case `dns:"domain-name"`: | ||||
| 					o("rr.%s, off, err = unpackDataDomainNames(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") | ||||
| 				default: | ||||
| 					log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) | ||||
| 				} | ||||
| 				continue | ||||
| 			} | ||||
| 
 | ||||
| 			switch st.Tag(i) { | ||||
| 			case `dns:"-"`: // ignored
 | ||||
| 			case `dns:"cdomain-name"`: | ||||
| 				fallthrough | ||||
| 			case `dns:"domain-name"`: | ||||
| 				o("rr.%s, off, err = UnpackDomainName(msg, off)\n") | ||||
| 			case `dns:"a"`: | ||||
| 				o("rr.%s, off, err = unpackDataA(msg, off)\n") | ||||
| 			case `dns:"aaaa"`: | ||||
| 				o("rr.%s, off, err = unpackDataAAAA(msg, off)\n") | ||||
| 			case `dns:"uint48"`: | ||||
| 				o("rr.%s, off, err = unpackUint48(msg, off)\n") | ||||
| 			case `dns:"txt"`: | ||||
| 				o("rr.%s, off, err = unpackString(msg, off)\n") | ||||
| 			case `dns:"base32"`: | ||||
| 				o("rr.%s, off, err = unpackStringBase32(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") | ||||
| 			case `dns:"base64"`: | ||||
| 				o("rr.%s, off, err = unpackStringBase64(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") | ||||
| 			case `dns:"hex"`: | ||||
| 				o("rr.%s, off, err = unpackStringHex(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") | ||||
| 			case `dns:"octet"`: | ||||
| 				o("rr.%s, off, err = unpackStringOctet(msg, off)\n") | ||||
| 			case "": | ||||
| 				switch st.Field(i).Type().(*types.Basic).Kind() { | ||||
| 				case types.Uint8: | ||||
| 					o("rr.%s, off, err = unpackUint8(msg, off)\n") | ||||
| 				case types.Uint16: | ||||
| 					o("rr.%s, off, err = unpackUint16(msg, off)\n") | ||||
| 				case types.Uint32: | ||||
| 					o("rr.%s, off, err = unpackUint32(msg, off)\n") | ||||
| 				case types.Uint64: | ||||
| 					o("rr.%s, off, err = unpackUint64(msg, off)\n") | ||||
| 				case types.String: | ||||
| 					o("rr.%s, off, err = unpackString(msg, off)\n") | ||||
| 				default: | ||||
| 					log.Fatalln(name, st.Field(i).Name()) | ||||
| 				} | ||||
| 			default: | ||||
| 				log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) | ||||
| 			} | ||||
| 			// If we've hit len(msg) we return without error.
 | ||||
| 			if i < st.NumFields()-1 { | ||||
| 				fmt.Fprintf(b, `if off == len(msg) { | ||||
| return rr, off, nil | ||||
| 	} | ||||
| `) | ||||
| 			} | ||||
| 		} | ||||
| 		fmt.Fprintf(b, "return rr, off, err }\n\n") | ||||
| 	} | ||||
| 	// Generate typeToUnpack map
 | ||||
| 	fmt.Fprintln(b, "var typeToUnpack = map[uint16]func(RR_Header, []byte, int) (RR, int, error){") | ||||
| 	for _, name := range namedTypes { | ||||
| 		if name == "RFC3597" { | ||||
| 			continue | ||||
| 		} | ||||
| 		fmt.Fprintf(b, "Type%s: unpack%s,\n", name, name) | ||||
| 	} | ||||
| 	fmt.Fprintln(b, "}\n") | ||||
| 
 | ||||
| 	// gofmt
 | ||||
| 	res, err := format.Source(b.Bytes()) | ||||
| 	if err != nil { | ||||
| 		b.WriteTo(os.Stderr) | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
| 
 | ||||
| 	// write result
 | ||||
| 	f, err := os.Create("zmsg.go") | ||||
| 	fatalIfErr(err) | ||||
| 	defer f.Close() | ||||
| 	f.Write(res) | ||||
| } | ||||
| 
 | ||||
| // structMember will take a tag like dns:"size-base32:SaltLength" and return the last part of this string.
 | ||||
| func structMember(s string) string { | ||||
| 	fields := strings.Split(s, ":") | ||||
| 	if len(fields) == 0 { | ||||
| 		return "" | ||||
| 	} | ||||
| 	f := fields[len(fields)-1] | ||||
| 	// f should have a closing "
 | ||||
| 	if len(f) > 1 { | ||||
| 		return f[:len(f)-1] | ||||
| 	} | ||||
| 	return f | ||||
| } | ||||
| 
 | ||||
| // structTag will take a tag like dns:"size-base32:SaltLength" and return base32.
 | ||||
| func structTag(s string) string { | ||||
| 	fields := strings.Split(s, ":") | ||||
| 	if len(fields) < 2 { | ||||
| 		return "" | ||||
| 	} | ||||
| 	return fields[1][len("\"size-"):] | ||||
| } | ||||
| 
 | ||||
| func fatalIfErr(err error) { | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
| } | ||||
|  | @ -1,271 +0,0 @@ | |||
| //+build ignore
 | ||||
| 
 | ||||
| // types_generate.go is meant to run with go generate. It will use
 | ||||
| // go/{importer,types} to track down all the RR struct types. Then for each type
 | ||||
| // it will generate conversion tables (TypeToRR and TypeToString) and banal
 | ||||
| // methods (len, Header, copy) based on the struct tags. The generated source is
 | ||||
| // written to ztypes.go, and is meant to be checked into git.
 | ||||
| package main | ||||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"fmt" | ||||
| 	"go/format" | ||||
| 	"go/importer" | ||||
| 	"go/types" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"strings" | ||||
| 	"text/template" | ||||
| ) | ||||
| 
 | ||||
| var skipLen = map[string]struct{}{ | ||||
| 	"NSEC":  {}, | ||||
| 	"NSEC3": {}, | ||||
| 	"OPT":   {}, | ||||
| } | ||||
| 
 | ||||
| var packageHdr = ` | ||||
| // *** DO NOT MODIFY ***
 | ||||
| // AUTOGENERATED BY go generate from type_generate.go
 | ||||
| 
 | ||||
| package dns | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/base64" | ||||
| 	"net" | ||||
| ) | ||||
| 
 | ||||
| ` | ||||
| 
 | ||||
| var TypeToRR = template.Must(template.New("TypeToRR").Parse(` | ||||
| // TypeToRR is a map of constructors for each RR type.
 | ||||
| var TypeToRR = map[uint16]func() RR{ | ||||
| {{range .}}{{if ne . "RFC3597"}}  Type{{.}}:  func() RR { return new({{.}}) }, | ||||
| {{end}}{{end}}                    } | ||||
| 
 | ||||
| `)) | ||||
| 
 | ||||
| var typeToString = template.Must(template.New("typeToString").Parse(` | ||||
| // TypeToString is a map of strings for each RR type.
 | ||||
| var TypeToString = map[uint16]string{ | ||||
| {{range .}}{{if ne . "NSAPPTR"}}  Type{{.}}: "{{.}}", | ||||
| {{end}}{{end}}                    TypeNSAPPTR:    "NSAP-PTR", | ||||
| } | ||||
| 
 | ||||
| `)) | ||||
| 
 | ||||
| var headerFunc = template.Must(template.New("headerFunc").Parse(` | ||||
| // Header() functions
 | ||||
| {{range .}}  func (rr *{{.}}) Header() *RR_Header { return &rr.Hdr } | ||||
| {{end}} | ||||
| 
 | ||||
| `)) | ||||
| 
 | ||||
| // getTypeStruct will take a type and the package scope, and return the
 | ||||
| // (innermost) struct if the type is considered a RR type (currently defined as
 | ||||
| // those structs beginning with a RR_Header, could be redefined as implementing
 | ||||
| // the RR interface). The bool return value indicates if embedded structs were
 | ||||
| // resolved.
 | ||||
| func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) { | ||||
| 	st, ok := t.Underlying().(*types.Struct) | ||||
| 	if !ok { | ||||
| 		return nil, false | ||||
| 	} | ||||
| 	if st.Field(0).Type() == scope.Lookup("RR_Header").Type() { | ||||
| 		return st, false | ||||
| 	} | ||||
| 	if st.Field(0).Anonymous() { | ||||
| 		st, _ := getTypeStruct(st.Field(0).Type(), scope) | ||||
| 		return st, true | ||||
| 	} | ||||
| 	return nil, false | ||||
| } | ||||
| 
 | ||||
| func main() { | ||||
| 	// Import and type-check the package
 | ||||
| 	pkg, err := importer.Default().Import("github.com/miekg/dns") | ||||
| 	fatalIfErr(err) | ||||
| 	scope := pkg.Scope() | ||||
| 
 | ||||
| 	// Collect constants like TypeX
 | ||||
| 	var numberedTypes []string | ||||
| 	for _, name := range scope.Names() { | ||||
| 		o := scope.Lookup(name) | ||||
| 		if o == nil || !o.Exported() { | ||||
| 			continue | ||||
| 		} | ||||
| 		b, ok := o.Type().(*types.Basic) | ||||
| 		if !ok || b.Kind() != types.Uint16 { | ||||
| 			continue | ||||
| 		} | ||||
| 		if !strings.HasPrefix(o.Name(), "Type") { | ||||
| 			continue | ||||
| 		} | ||||
| 		name := strings.TrimPrefix(o.Name(), "Type") | ||||
| 		if name == "PrivateRR" { | ||||
| 			continue | ||||
| 		} | ||||
| 		numberedTypes = append(numberedTypes, name) | ||||
| 	} | ||||
| 
 | ||||
| 	// Collect actual types (*X)
 | ||||
| 	var namedTypes []string | ||||
| 	for _, name := range scope.Names() { | ||||
| 		o := scope.Lookup(name) | ||||
| 		if o == nil || !o.Exported() { | ||||
| 			continue | ||||
| 		} | ||||
| 		if st, _ := getTypeStruct(o.Type(), scope); st == nil { | ||||
| 			continue | ||||
| 		} | ||||
| 		if name == "PrivateRR" { | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 		// Check if corresponding TypeX exists
 | ||||
| 		if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" { | ||||
| 			log.Fatalf("Constant Type%s does not exist.", o.Name()) | ||||
| 		} | ||||
| 
 | ||||
| 		namedTypes = append(namedTypes, o.Name()) | ||||
| 	} | ||||
| 
 | ||||
| 	b := &bytes.Buffer{} | ||||
| 	b.WriteString(packageHdr) | ||||
| 
 | ||||
| 	// Generate TypeToRR
 | ||||
| 	fatalIfErr(TypeToRR.Execute(b, namedTypes)) | ||||
| 
 | ||||
| 	// Generate typeToString
 | ||||
| 	fatalIfErr(typeToString.Execute(b, numberedTypes)) | ||||
| 
 | ||||
| 	// Generate headerFunc
 | ||||
| 	fatalIfErr(headerFunc.Execute(b, namedTypes)) | ||||
| 
 | ||||
| 	// Generate len()
 | ||||
| 	fmt.Fprint(b, "// len() functions\n") | ||||
| 	for _, name := range namedTypes { | ||||
| 		if _, ok := skipLen[name]; ok { | ||||
| 			continue | ||||
| 		} | ||||
| 		o := scope.Lookup(name) | ||||
| 		st, isEmbedded := getTypeStruct(o.Type(), scope) | ||||
| 		if isEmbedded { | ||||
| 			continue | ||||
| 		} | ||||
| 		fmt.Fprintf(b, "func (rr *%s) len() int {\n", name) | ||||
| 		fmt.Fprintf(b, "l := rr.Hdr.len()\n") | ||||
| 		for i := 1; i < st.NumFields(); i++ { | ||||
| 			o := func(s string) { fmt.Fprintf(b, s, st.Field(i).Name()) } | ||||
| 
 | ||||
| 			if _, ok := st.Field(i).Type().(*types.Slice); ok { | ||||
| 				switch st.Tag(i) { | ||||
| 				case `dns:"-"`: | ||||
| 					// ignored
 | ||||
| 				case `dns:"cdomain-name"`, `dns:"domain-name"`, `dns:"txt"`: | ||||
| 					o("for _, x := range rr.%s { l += len(x) + 1 }\n") | ||||
| 				default: | ||||
| 					log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) | ||||
| 				} | ||||
| 				continue | ||||
| 			} | ||||
| 
 | ||||
| 			switch { | ||||
| 			case st.Tag(i) == `dns:"-"`: | ||||
| 				// ignored
 | ||||
| 			case st.Tag(i) == `dns:"cdomain-name"`, st.Tag(i) == `dns:"domain-name"`: | ||||
| 				o("l += len(rr.%s) + 1\n") | ||||
| 			case st.Tag(i) == `dns:"octet"`: | ||||
| 				o("l += len(rr.%s)\n") | ||||
| 			case strings.HasPrefix(st.Tag(i), `dns:"size-base64`): | ||||
| 				fallthrough | ||||
| 			case st.Tag(i) == `dns:"base64"`: | ||||
| 				o("l += base64.StdEncoding.DecodedLen(len(rr.%s))\n") | ||||
| 			case strings.HasPrefix(st.Tag(i), `dns:"size-hex`): | ||||
| 				fallthrough | ||||
| 			case st.Tag(i) == `dns:"hex"`: | ||||
| 				o("l += len(rr.%s)/2 + 1\n") | ||||
| 			case st.Tag(i) == `dns:"a"`: | ||||
| 				o("l += net.IPv4len // %s\n") | ||||
| 			case st.Tag(i) == `dns:"aaaa"`: | ||||
| 				o("l += net.IPv6len // %s\n") | ||||
| 			case st.Tag(i) == `dns:"txt"`: | ||||
| 				o("for _, t := range rr.%s { l += len(t) + 1 }\n") | ||||
| 			case st.Tag(i) == `dns:"uint48"`: | ||||
| 				o("l += 6 // %s\n") | ||||
| 			case st.Tag(i) == "": | ||||
| 				switch st.Field(i).Type().(*types.Basic).Kind() { | ||||
| 				case types.Uint8: | ||||
| 					o("l += 1 // %s\n") | ||||
| 				case types.Uint16: | ||||
| 					o("l += 2 // %s\n") | ||||
| 				case types.Uint32: | ||||
| 					o("l += 4 // %s\n") | ||||
| 				case types.Uint64: | ||||
| 					o("l += 8 // %s\n") | ||||
| 				case types.String: | ||||
| 					o("l += len(rr.%s) + 1\n") | ||||
| 				default: | ||||
| 					log.Fatalln(name, st.Field(i).Name()) | ||||
| 				} | ||||
| 			default: | ||||
| 				log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) | ||||
| 			} | ||||
| 		} | ||||
| 		fmt.Fprintf(b, "return l }\n") | ||||
| 	} | ||||
| 
 | ||||
| 	// Generate copy()
 | ||||
| 	fmt.Fprint(b, "// copy() functions\n") | ||||
| 	for _, name := range namedTypes { | ||||
| 		o := scope.Lookup(name) | ||||
| 		st, isEmbedded := getTypeStruct(o.Type(), scope) | ||||
| 		if isEmbedded { | ||||
| 			continue | ||||
| 		} | ||||
| 		fmt.Fprintf(b, "func (rr *%s) copy() RR {\n", name) | ||||
| 		fields := []string{"*rr.Hdr.copyHeader()"} | ||||
| 		for i := 1; i < st.NumFields(); i++ { | ||||
| 			f := st.Field(i).Name() | ||||
| 			if sl, ok := st.Field(i).Type().(*types.Slice); ok { | ||||
| 				t := sl.Underlying().String() | ||||
| 				t = strings.TrimPrefix(t, "[]") | ||||
| 				if strings.Contains(t, ".") { | ||||
| 					splits := strings.Split(t, ".") | ||||
| 					t = splits[len(splits)-1] | ||||
| 				} | ||||
| 				fmt.Fprintf(b, "%s := make([]%s, len(rr.%s)); copy(%s, rr.%s)\n", | ||||
| 					f, t, f, f, f) | ||||
| 				fields = append(fields, f) | ||||
| 				continue | ||||
| 			} | ||||
| 			if st.Field(i).Type().String() == "net.IP" { | ||||
| 				fields = append(fields, "copyIP(rr."+f+")") | ||||
| 				continue | ||||
| 			} | ||||
| 			fields = append(fields, "rr."+f) | ||||
| 		} | ||||
| 		fmt.Fprintf(b, "return &%s{%s}\n", name, strings.Join(fields, ",")) | ||||
| 		fmt.Fprintf(b, "}\n") | ||||
| 	} | ||||
| 
 | ||||
| 	// gofmt
 | ||||
| 	res, err := format.Source(b.Bytes()) | ||||
| 	if err != nil { | ||||
| 		b.WriteTo(os.Stderr) | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
| 
 | ||||
| 	// write result
 | ||||
| 	f, err := os.Create("ztypes.go") | ||||
| 	fatalIfErr(err) | ||||
| 	defer f.Close() | ||||
| 	f.Write(res) | ||||
| } | ||||
| 
 | ||||
| func fatalIfErr(err error) { | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
| } | ||||
|  | @ -1,68 +0,0 @@ | |||
| // Copyright 2012 The Go Authors. All rights reserved.
 | ||||
| // Use of this source code is governed by a BSD-style
 | ||||
| // license that can be found in the LICENSE file.
 | ||||
| 
 | ||||
| // Package idna implements IDNA2008 (Internationalized Domain Names for
 | ||||
| // Applications), defined in RFC 5890, RFC 5891, RFC 5892, RFC 5893 and
 | ||||
| // RFC 5894.
 | ||||
| package idna // import "golang.org/x/net/idna"
 | ||||
| 
 | ||||
| import ( | ||||
| 	"strings" | ||||
| 	"unicode/utf8" | ||||
| ) | ||||
| 
 | ||||
| // TODO(nigeltao): specify when errors occur. For example, is ToASCII(".") or
 | ||||
| // ToASCII("foo\x00") an error? See also http://www.unicode.org/faq/idn.html#11
 | ||||
| 
 | ||||
| // acePrefix is the ASCII Compatible Encoding prefix.
 | ||||
| const acePrefix = "xn--" | ||||
| 
 | ||||
| // ToASCII converts a domain or domain label to its ASCII form. For example,
 | ||||
| // ToASCII("bücher.example.com") is "xn--bcher-kva.example.com", and
 | ||||
| // ToASCII("golang") is "golang".
 | ||||
| func ToASCII(s string) (string, error) { | ||||
| 	if ascii(s) { | ||||
| 		return s, nil | ||||
| 	} | ||||
| 	labels := strings.Split(s, ".") | ||||
| 	for i, label := range labels { | ||||
| 		if !ascii(label) { | ||||
| 			a, err := encode(acePrefix, label) | ||||
| 			if err != nil { | ||||
| 				return "", err | ||||
| 			} | ||||
| 			labels[i] = a | ||||
| 		} | ||||
| 	} | ||||
| 	return strings.Join(labels, "."), nil | ||||
| } | ||||
| 
 | ||||
| // ToUnicode converts a domain or domain label to its Unicode form. For example,
 | ||||
| // ToUnicode("xn--bcher-kva.example.com") is "bücher.example.com", and
 | ||||
| // ToUnicode("golang") is "golang".
 | ||||
| func ToUnicode(s string) (string, error) { | ||||
| 	if !strings.Contains(s, acePrefix) { | ||||
| 		return s, nil | ||||
| 	} | ||||
| 	labels := strings.Split(s, ".") | ||||
| 	for i, label := range labels { | ||||
| 		if strings.HasPrefix(label, acePrefix) { | ||||
| 			u, err := decode(label[len(acePrefix):]) | ||||
| 			if err != nil { | ||||
| 				return "", err | ||||
| 			} | ||||
| 			labels[i] = u | ||||
| 		} | ||||
| 	} | ||||
| 	return strings.Join(labels, "."), nil | ||||
| } | ||||
| 
 | ||||
| func ascii(s string) bool { | ||||
| 	for i := 0; i < len(s); i++ { | ||||
| 		if s[i] >= utf8.RuneSelf { | ||||
| 			return false | ||||
| 		} | ||||
| 	} | ||||
| 	return true | ||||
| } | ||||
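Taken together, the two exported entry points round-trip exactly as their doc comments describe. A minimal usage sketch, assuming the `golang.org/x/net/idna` import path declared above:

```go
package main

import (
	"fmt"

	"golang.org/x/net/idna"
)

func main() {
	// Non-ASCII labels gain the "xn--" ACE prefix; ASCII labels pass through.
	a, err := idna.ToASCII("bücher.example.com")
	if err != nil {
		panic(err)
	}
	fmt.Println(a) // xn--bcher-kva.example.com

	// ToUnicode decodes only the labels that carry the ACE prefix.
	u, err := idna.ToUnicode(a)
	if err != nil {
		panic(err)
	}
	fmt.Println(u) // bücher.example.com
}
```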
|  | @ -1,200 +0,0 @@ | |||
| // Copyright 2012 The Go Authors. All rights reserved.
 | ||||
| // Use of this source code is governed by a BSD-style
 | ||||
| // license that can be found in the LICENSE file.
 | ||||
| 
 | ||||
| package idna | ||||
| 
 | ||||
| // This file implements the Punycode algorithm from RFC 3492.
 | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"math" | ||||
| 	"strings" | ||||
| 	"unicode/utf8" | ||||
| ) | ||||
| 
 | ||||
| // These parameter values are specified in section 5.
 | ||||
| //
 | ||||
| // All computation is done with int32s, so that overflow behavior is identical
 | ||||
| // regardless of whether int is 32-bit or 64-bit.
 | ||||
| const ( | ||||
| 	base        int32 = 36 | ||||
| 	damp        int32 = 700 | ||||
| 	initialBias int32 = 72 | ||||
| 	initialN    int32 = 128 | ||||
| 	skew        int32 = 38 | ||||
| 	tmax        int32 = 26 | ||||
| 	tmin        int32 = 1 | ||||
| ) | ||||
| 
 | ||||
| // decode decodes a string as specified in section 6.2.
 | ||||
| func decode(encoded string) (string, error) { | ||||
| 	if encoded == "" { | ||||
| 		return "", nil | ||||
| 	} | ||||
| 	pos := 1 + strings.LastIndex(encoded, "-") | ||||
| 	if pos == 1 { | ||||
| 		return "", fmt.Errorf("idna: invalid label %q", encoded) | ||||
| 	} | ||||
| 	if pos == len(encoded) { | ||||
| 		return encoded[:len(encoded)-1], nil | ||||
| 	} | ||||
| 	output := make([]rune, 0, len(encoded)) | ||||
| 	if pos != 0 { | ||||
| 		for _, r := range encoded[:pos-1] { | ||||
| 			output = append(output, r) | ||||
| 		} | ||||
| 	} | ||||
| 	i, n, bias := int32(0), initialN, initialBias | ||||
| 	for pos < len(encoded) { | ||||
| 		oldI, w := i, int32(1) | ||||
| 		for k := base; ; k += base { | ||||
| 			if pos == len(encoded) { | ||||
| 				return "", fmt.Errorf("idna: invalid label %q", encoded) | ||||
| 			} | ||||
| 			digit, ok := decodeDigit(encoded[pos]) | ||||
| 			if !ok { | ||||
| 				return "", fmt.Errorf("idna: invalid label %q", encoded) | ||||
| 			} | ||||
| 			pos++ | ||||
| 			i += digit * w | ||||
| 			if i < 0 { | ||||
| 				return "", fmt.Errorf("idna: invalid label %q", encoded) | ||||
| 			} | ||||
| 			t := k - bias | ||||
| 			if t < tmin { | ||||
| 				t = tmin | ||||
| 			} else if t > tmax { | ||||
| 				t = tmax | ||||
| 			} | ||||
| 			if digit < t { | ||||
| 				break | ||||
| 			} | ||||
| 			w *= base - t | ||||
| 			if w >= math.MaxInt32/base { | ||||
| 				return "", fmt.Errorf("idna: invalid label %q", encoded) | ||||
| 			} | ||||
| 		} | ||||
| 		x := int32(len(output) + 1) | ||||
| 		bias = adapt(i-oldI, x, oldI == 0) | ||||
| 		n += i / x | ||||
| 		i %= x | ||||
| 		if n > utf8.MaxRune || len(output) >= 1024 { | ||||
| 			return "", fmt.Errorf("idna: invalid label %q", encoded) | ||||
| 		} | ||||
| 		output = append(output, 0) | ||||
| 		copy(output[i+1:], output[i:]) | ||||
| 		output[i] = n | ||||
| 		i++ | ||||
| 	} | ||||
| 	return string(output), nil | ||||
| } | ||||
| 
 | ||||
| // encode encodes a string as specified in section 6.3 and prepends prefix to
 | ||||
| // the result.
 | ||||
| //
 | ||||
| // The "while h < length(input)" line in the specification becomes "for
 | ||||
| // remaining != 0" in the Go code, because len(s) in Go is in bytes, not runes.
 | ||||
| func encode(prefix, s string) (string, error) { | ||||
| 	output := make([]byte, len(prefix), len(prefix)+1+2*len(s)) | ||||
| 	copy(output, prefix) | ||||
| 	delta, n, bias := int32(0), initialN, initialBias | ||||
| 	b, remaining := int32(0), int32(0) | ||||
| 	for _, r := range s { | ||||
| 		if r < 0x80 { | ||||
| 			b++ | ||||
| 			output = append(output, byte(r)) | ||||
| 		} else { | ||||
| 			remaining++ | ||||
| 		} | ||||
| 	} | ||||
| 	h := b | ||||
| 	if b > 0 { | ||||
| 		output = append(output, '-') | ||||
| 	} | ||||
| 	for remaining != 0 { | ||||
| 		m := int32(0x7fffffff) | ||||
| 		for _, r := range s { | ||||
| 			if m > r && r >= n { | ||||
| 				m = r | ||||
| 			} | ||||
| 		} | ||||
| 		delta += (m - n) * (h + 1) | ||||
| 		if delta < 0 { | ||||
| 			return "", fmt.Errorf("idna: invalid label %q", s) | ||||
| 		} | ||||
| 		n = m | ||||
| 		for _, r := range s { | ||||
| 			if r < n { | ||||
| 				delta++ | ||||
| 				if delta < 0 { | ||||
| 					return "", fmt.Errorf("idna: invalid label %q", s) | ||||
| 				} | ||||
| 				continue | ||||
| 			} | ||||
| 			if r > n { | ||||
| 				continue | ||||
| 			} | ||||
| 			q := delta | ||||
| 			for k := base; ; k += base { | ||||
| 				t := k - bias | ||||
| 				if t < tmin { | ||||
| 					t = tmin | ||||
| 				} else if t > tmax { | ||||
| 					t = tmax | ||||
| 				} | ||||
| 				if q < t { | ||||
| 					break | ||||
| 				} | ||||
| 				output = append(output, encodeDigit(t+(q-t)%(base-t))) | ||||
| 				q = (q - t) / (base - t) | ||||
| 			} | ||||
| 			output = append(output, encodeDigit(q)) | ||||
| 			bias = adapt(delta, h+1, h == b) | ||||
| 			delta = 0 | ||||
| 			h++ | ||||
| 			remaining-- | ||||
| 		} | ||||
| 		delta++ | ||||
| 		n++ | ||||
| 	} | ||||
| 	return string(output), nil | ||||
| } | ||||
| 
 | ||||
| func decodeDigit(x byte) (digit int32, ok bool) { | ||||
| 	switch { | ||||
| 	case '0' <= x && x <= '9': | ||||
| 		return int32(x - ('0' - 26)), true | ||||
| 	case 'A' <= x && x <= 'Z': | ||||
| 		return int32(x - 'A'), true | ||||
| 	case 'a' <= x && x <= 'z': | ||||
| 		return int32(x - 'a'), true | ||||
| 	} | ||||
| 	return 0, false | ||||
| } | ||||
| 
 | ||||
| func encodeDigit(digit int32) byte { | ||||
| 	switch { | ||||
| 	case 0 <= digit && digit < 26: | ||||
| 		return byte(digit + 'a') | ||||
| 	case 26 <= digit && digit < 36: | ||||
| 		return byte(digit + ('0' - 26)) | ||||
| 	} | ||||
| 	panic("idna: internal error in punycode encoding") | ||||
| } | ||||
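As a sanity check on the base-36 alphabet above ('a'–'z' encode 0–25, '0'–'9' encode 26–35, and the two functions are inverses), a small standalone sketch that copies both mappings:

```go
package main

import "fmt"

// Standalone copies of the two digit mappings, for illustration only.
func decodeDigit(x byte) (int32, bool) {
	switch {
	case '0' <= x && x <= '9':
		return int32(x - ('0' - 26)), true
	case 'A' <= x && x <= 'Z':
		return int32(x - 'A'), true
	case 'a' <= x && x <= 'z':
		return int32(x - 'a'), true
	}
	return 0, false
}

func encodeDigit(d int32) byte {
	switch {
	case 0 <= d && d < 26:
		return byte(d + 'a')
	case 26 <= d && d < 36:
		return byte(d + ('0' - 26))
	}
	panic("digit out of range")
}

func main() {
	// Round-trip the whole alphabet: 0..25 -> 'a'..'z', 26..35 -> '0'..'9'.
	for d := int32(0); d < 36; d++ {
		c := encodeDigit(d)
		if back, ok := decodeDigit(c); !ok || back != d {
			panic("mismatch")
		}
	}
	fmt.Println(string(encodeDigit(0)), string(encodeDigit(25)), string(encodeDigit(26))) // a z 0
}
```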
| 
 | ||||
| // adapt is the bias adaptation function specified in section 6.1.
 | ||||
| func adapt(delta, numPoints int32, firstTime bool) int32 { | ||||
| 	if firstTime { | ||||
| 		delta /= damp | ||||
| 	} else { | ||||
| 		delta /= 2 | ||||
| 	} | ||||
| 	delta += delta / numPoints | ||||
| 	k := int32(0) | ||||
| 	for delta > ((base-tmin)*tmax)/2 { | ||||
| 		delta /= base - tmin | ||||
| 		k += base | ||||
| 	} | ||||
| 	return k + (base-tmin+1)*delta/(delta+skew) | ||||
| } | ||||
|  | @ -1,663 +0,0 @@ | |||
| // Copyright 2012 The Go Authors. All rights reserved.
 | ||||
| // Use of this source code is governed by a BSD-style
 | ||||
| // license that can be found in the LICENSE file.
 | ||||
| 
 | ||||
| // +build ignore
 | ||||
| 
 | ||||
| package main | ||||
| 
 | ||||
| // This program generates table.go and table_test.go.
 | ||||
| // Invoke as:
 | ||||
| //
 | ||||
| //	go run gen.go -version "xxx"       >table.go
 | ||||
| //	go run gen.go -version "xxx" -test >table_test.go
 | ||||
| //
 | ||||
| // Pass -v to print verbose progress information.
 | ||||
| //
 | ||||
| // The version is derived from information found at
 | ||||
| // https://github.com/publicsuffix/list/commits/master/public_suffix_list.dat
 | ||||
| //
 | ||||
| // To fetch a particular git revision, such as 5c70ccd250, pass
 | ||||
| // -url "https://raw.githubusercontent.com/publicsuffix/list/5c70ccd250/public_suffix_list.dat"
 | ||||
| 
 | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"bytes" | ||||
| 	"flag" | ||||
| 	"fmt" | ||||
| 	"go/format" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"os" | ||||
| 	"regexp" | ||||
| 	"sort" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"golang.org/x/net/idna" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	// The sum of these four values must be no greater than 32.
 | ||||
| 	nodesBitsChildren   = 9 | ||||
| 	nodesBitsICANN      = 1 | ||||
| 	nodesBitsTextOffset = 15 | ||||
| 	nodesBitsTextLength = 6 | ||||
| 
 | ||||
| 	// The sum of these four values must be no greater than 32.
 | ||||
| 	childrenBitsWildcard = 1 | ||||
| 	childrenBitsNodeType = 2 | ||||
| 	childrenBitsHi       = 14 | ||||
| 	childrenBitsLo       = 14 | ||||
| ) | ||||
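To make the bit budget concrete, here is a sketch of how a nodes entry packs and unpacks under the widths above. `packNode` is illustrative, not a function in this generator; it matches the MSB-to-LSB layout that printReal documents further down:

```go
package main

import "fmt"

const (
	nodesBitsChildren   = 9
	nodesBitsICANN      = 1
	nodesBitsTextOffset = 15
	nodesBitsTextLength = 6
)

// packNode packs a children index, the ICANN bit, and a text
// offset/length pair into one uint32, MSB to LSB.
func packNode(children uint32, icann bool, textOffset, textLength uint32) uint32 {
	v := textOffset<<nodesBitsTextLength | textLength
	if icann {
		v |= 1 << (nodesBitsTextLength + nodesBitsTextOffset)
	}
	return v | children<<(nodesBitsTextLength+nodesBitsTextOffset+nodesBitsICANN)
}

func main() {
	v := packNode(42, true, 1000, 5)
	// Unpack in reverse order, masking each field to its width.
	textLength := v & (1<<nodesBitsTextLength - 1)
	textOffset := v >> nodesBitsTextLength & (1<<nodesBitsTextOffset - 1)
	icann := v>>(nodesBitsTextLength+nodesBitsTextOffset)&1 == 1
	children := v >> (nodesBitsTextLength + nodesBitsTextOffset + nodesBitsICANN)
	fmt.Println(children, icann, textOffset, textLength) // 42 true 1000 5
}
```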
| 
 | ||||
| var ( | ||||
| 	maxChildren   int | ||||
| 	maxTextOffset int | ||||
| 	maxTextLength int | ||||
| 	maxHi         uint32 | ||||
| 	maxLo         uint32 | ||||
| ) | ||||
| 
 | ||||
| func max(a, b int) int { | ||||
| 	if a < b { | ||||
| 		return b | ||||
| 	} | ||||
| 	return a | ||||
| } | ||||
| 
 | ||||
| func u32max(a, b uint32) uint32 { | ||||
| 	if a < b { | ||||
| 		return b | ||||
| 	} | ||||
| 	return a | ||||
| } | ||||
| 
 | ||||
| const ( | ||||
| 	nodeTypeNormal     = 0 | ||||
| 	nodeTypeException  = 1 | ||||
| 	nodeTypeParentOnly = 2 | ||||
| 	numNodeType        = 3 | ||||
| ) | ||||
| 
 | ||||
| func nodeTypeStr(n int) string { | ||||
| 	switch n { | ||||
| 	case nodeTypeNormal: | ||||
| 		return "+" | ||||
| 	case nodeTypeException: | ||||
| 		return "!" | ||||
| 	case nodeTypeParentOnly: | ||||
| 		return "o" | ||||
| 	} | ||||
| 	panic("unreachable") | ||||
| } | ||||
| 
 | ||||
| var ( | ||||
| 	labelEncoding = map[string]uint32{} | ||||
| 	labelsList    = []string{} | ||||
| 	labelsMap     = map[string]bool{} | ||||
| 	rules         = []string{} | ||||
| 
 | ||||
| 	// validSuffix is used to check that the entries in the public suffix list
 | ||||
| 	// are in canonical form (after Punycode encoding). Specifically, capital
 | ||||
| 	// letters are not allowed.
 | ||||
| 	validSuffix = regexp.MustCompile(`^[a-z0-9_\!\*\-\.]+$`) | ||||
| 
 | ||||
| 	subset = flag.Bool("subset", false, "generate only a subset of the full table, for debugging") | ||||
| 	url    = flag.String("url", | ||||
| 		"https://publicsuffix.org/list/effective_tld_names.dat", | ||||
| 		"URL of the publicsuffix.org list. If empty, stdin is read instead") | ||||
| 	v       = flag.Bool("v", false, "verbose output (to stderr)") | ||||
| 	version = flag.String("version", "", "the effective_tld_names.dat version") | ||||
| 	test    = flag.Bool("test", false, "generate table_test.go") | ||||
| ) | ||||
| 
 | ||||
| func main() { | ||||
| 	if err := main1(); err != nil { | ||||
| 		fmt.Fprintln(os.Stderr, err) | ||||
| 		os.Exit(1) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func main1() error { | ||||
| 	flag.Parse() | ||||
| 	if nodesBitsTextLength+nodesBitsTextOffset+nodesBitsICANN+nodesBitsChildren > 32 { | ||||
| 		return fmt.Errorf("not enough bits to encode the nodes table") | ||||
| 	} | ||||
| 	if childrenBitsLo+childrenBitsHi+childrenBitsNodeType+childrenBitsWildcard > 32 { | ||||
| 		return fmt.Errorf("not enough bits to encode the children table") | ||||
| 	} | ||||
| 	if *version == "" { | ||||
| 		return fmt.Errorf("-version was not specified") | ||||
| 	} | ||||
| 	var r io.Reader = os.Stdin | ||||
| 	if *url != "" { | ||||
| 		res, err := http.Get(*url) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		if res.StatusCode != http.StatusOK { | ||||
| 			return fmt.Errorf("bad GET status for %s: %d", *url, res.Status) | ||||
| 		} | ||||
| 		r = res.Body | ||||
| 		defer res.Body.Close() | ||||
| 	} | ||||
| 
 | ||||
| 	var root node | ||||
| 	icann := false | ||||
| 	buf := new(bytes.Buffer) | ||||
| 	br := bufio.NewReader(r) | ||||
| 	for { | ||||
| 		s, err := br.ReadString('\n') | ||||
| 		if err != nil { | ||||
| 			if err == io.EOF { | ||||
| 				break | ||||
| 			} | ||||
| 			return err | ||||
| 		} | ||||
| 		s = strings.TrimSpace(s) | ||||
| 		if strings.Contains(s, "BEGIN ICANN DOMAINS") { | ||||
| 			icann = true | ||||
| 			continue | ||||
| 		} | ||||
| 		if strings.Contains(s, "END ICANN DOMAINS") { | ||||
| 			icann = false | ||||
| 			continue | ||||
| 		} | ||||
| 		if s == "" || strings.HasPrefix(s, "//") { | ||||
| 			continue | ||||
| 		} | ||||
| 		s, err = idna.ToASCII(s) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		if !validSuffix.MatchString(s) { | ||||
| 			return fmt.Errorf("bad publicsuffix.org list data: %q", s) | ||||
| 		} | ||||
| 
 | ||||
| 		if *subset { | ||||
| 			switch { | ||||
| 			case s == "ac.jp" || strings.HasSuffix(s, ".ac.jp"): | ||||
| 			case s == "ak.us" || strings.HasSuffix(s, ".ak.us"): | ||||
| 			case s == "ao" || strings.HasSuffix(s, ".ao"): | ||||
| 			case s == "ar" || strings.HasSuffix(s, ".ar"): | ||||
| 			case s == "arpa" || strings.HasSuffix(s, ".arpa"): | ||||
| 			case s == "cy" || strings.HasSuffix(s, ".cy"): | ||||
| 			case s == "dyndns.org" || strings.HasSuffix(s, ".dyndns.org"): | ||||
| 			case s == "jp": | ||||
| 			case s == "kobe.jp" || strings.HasSuffix(s, ".kobe.jp"): | ||||
| 			case s == "kyoto.jp" || strings.HasSuffix(s, ".kyoto.jp"): | ||||
| 			case s == "om" || strings.HasSuffix(s, ".om"): | ||||
| 			case s == "uk" || strings.HasSuffix(s, ".uk"): | ||||
| 			case s == "uk.com" || strings.HasSuffix(s, ".uk.com"): | ||||
| 			case s == "tw" || strings.HasSuffix(s, ".tw"): | ||||
| 			case s == "zw" || strings.HasSuffix(s, ".zw"): | ||||
| 			case s == "xn--p1ai" || strings.HasSuffix(s, ".xn--p1ai"): | ||||
| 				// xn--p1ai is Russian-Cyrillic "рф".
 | ||||
| 			default: | ||||
| 				continue | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		rules = append(rules, s) | ||||
| 
 | ||||
| 		nt, wildcard := nodeTypeNormal, false | ||||
| 		switch { | ||||
| 		case strings.HasPrefix(s, "*."): | ||||
| 			s, nt = s[2:], nodeTypeParentOnly | ||||
| 			wildcard = true | ||||
| 		case strings.HasPrefix(s, "!"): | ||||
| 			s, nt = s[1:], nodeTypeException | ||||
| 		} | ||||
| 		labels := strings.Split(s, ".") | ||||
| 		for n, i := &root, len(labels)-1; i >= 0; i-- { | ||||
| 			label := labels[i] | ||||
| 			n = n.child(label) | ||||
| 			if i == 0 { | ||||
| 				if nt != nodeTypeParentOnly && n.nodeType == nodeTypeParentOnly { | ||||
| 					n.nodeType = nt | ||||
| 				} | ||||
| 				n.icann = n.icann && icann | ||||
| 				n.wildcard = n.wildcard || wildcard | ||||
| 			} | ||||
| 			labelsMap[label] = true | ||||
| 		} | ||||
| 	} | ||||
| 	labelsList = make([]string, 0, len(labelsMap)) | ||||
| 	for label := range labelsMap { | ||||
| 		labelsList = append(labelsList, label) | ||||
| 	} | ||||
| 	sort.Strings(labelsList) | ||||
| 
 | ||||
| 	p := printReal | ||||
| 	if *test { | ||||
| 		p = printTest | ||||
| 	} | ||||
| 	if err := p(buf, &root); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	b, err := format.Source(buf.Bytes()) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	_, err = os.Stdout.Write(b) | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| func printTest(w io.Writer, n *node) error { | ||||
| 	fmt.Fprintf(w, "// generated by go run gen.go; DO NOT EDIT\n\n") | ||||
| 	fmt.Fprintf(w, "package publicsuffix\n\nvar rules = [...]string{\n") | ||||
| 	for _, rule := range rules { | ||||
| 		fmt.Fprintf(w, "%q,\n", rule) | ||||
| 	} | ||||
| 	fmt.Fprintf(w, "}\n\nvar nodeLabels = [...]string{\n") | ||||
| 	if err := n.walk(w, printNodeLabel); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	fmt.Fprintf(w, "}\n") | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func printReal(w io.Writer, n *node) error { | ||||
| 	const header = `// generated by go run gen.go; DO NOT EDIT
 | ||||
| 
 | ||||
| package publicsuffix | ||||
| 
 | ||||
| const version = %q | ||||
| 
 | ||||
| const ( | ||||
| 	nodesBitsChildren   = %d | ||||
| 	nodesBitsICANN      = %d | ||||
| 	nodesBitsTextOffset = %d | ||||
| 	nodesBitsTextLength = %d | ||||
| 
 | ||||
| 	childrenBitsWildcard = %d | ||||
| 	childrenBitsNodeType = %d | ||||
| 	childrenBitsHi       = %d | ||||
| 	childrenBitsLo       = %d | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	nodeTypeNormal     = %d | ||||
| 	nodeTypeException  = %d | ||||
| 	nodeTypeParentOnly = %d | ||||
| ) | ||||
| 
 | ||||
| // numTLD is the number of top level domains.
 | ||||
| const numTLD = %d | ||||
| 
 | ||||
| ` | ||||
| 	fmt.Fprintf(w, header, *version, | ||||
| 		nodesBitsChildren, nodesBitsICANN, nodesBitsTextOffset, nodesBitsTextLength, | ||||
| 		childrenBitsWildcard, childrenBitsNodeType, childrenBitsHi, childrenBitsLo, | ||||
| 		nodeTypeNormal, nodeTypeException, nodeTypeParentOnly, len(n.children)) | ||||
| 
 | ||||
| 	text := combineText(labelsList) | ||||
| 	if text == "" { | ||||
| 		return fmt.Errorf("internal error: makeText returned no text") | ||||
| 	} | ||||
| 	for _, label := range labelsList { | ||||
| 		offset, length := strings.Index(text, label), len(label) | ||||
| 		if offset < 0 { | ||||
| 			return fmt.Errorf("internal error: could not find %q in text %q", label, text) | ||||
| 		} | ||||
| 		maxTextOffset, maxTextLength = max(maxTextOffset, offset), max(maxTextLength, length) | ||||
| 		if offset >= 1<<nodesBitsTextOffset { | ||||
| 			return fmt.Errorf("text offset %d is too large, or nodeBitsTextOffset is too small", offset) | ||||
| 		} | ||||
| 		if length >= 1<<nodesBitsTextLength { | ||||
| 			return fmt.Errorf("text length %d is too large, or nodeBitsTextLength is too small", length) | ||||
| 		} | ||||
| 		labelEncoding[label] = uint32(offset)<<nodesBitsTextLength | uint32(length) | ||||
| 	} | ||||
| 	fmt.Fprintf(w, "// Text is the combined text of all labels.\nconst text = ") | ||||
| 	for len(text) > 0 { | ||||
| 		n, plus := len(text), "" | ||||
| 		if n > 64 { | ||||
| 			n, plus = 64, " +" | ||||
| 		} | ||||
| 		fmt.Fprintf(w, "%q%s\n", text[:n], plus) | ||||
| 		text = text[n:] | ||||
| 	} | ||||
| 
 | ||||
| 	if err := n.walk(w, assignIndexes); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	fmt.Fprintf(w, ` | ||||
| 
 | ||||
| // nodes is the list of nodes. Each node is represented as a uint32, which
 | ||||
| // encodes the node's children, wildcard bit and node type (as an index into
 | ||||
| // the children array), ICANN bit and text.
 | ||||
| //
 | ||||
| // In the //-comment after each node's data, the nodes indexes of the children
 | ||||
| // are formatted as (n0x1234-n0x1256), with * denoting the wildcard bit. The
 | ||||
| // nodeType is printed as + for normal, ! for exception, and o for parent-only
 | ||||
| // nodes that have children but don't match a domain label in their own right.
 | ||||
| // An I denotes an ICANN domain.
 | ||||
| //
 | ||||
| // The layout within the uint32, from MSB to LSB, is:
 | ||||
| //	[%2d bits] unused
 | ||||
| //	[%2d bits] children index
 | ||||
| //	[%2d bits] ICANN bit
 | ||||
| //	[%2d bits] text index
 | ||||
| //	[%2d bits] text length
 | ||||
| var nodes = [...]uint32{ | ||||
| `, | ||||
| 		32-nodesBitsChildren-nodesBitsICANN-nodesBitsTextOffset-nodesBitsTextLength, | ||||
| 		nodesBitsChildren, nodesBitsICANN, nodesBitsTextOffset, nodesBitsTextLength) | ||||
| 	if err := n.walk(w, printNode); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	fmt.Fprintf(w, `} | ||||
| 
 | ||||
| // children is the list of nodes' children, the parent's wildcard bit and the
 | ||||
| // parent's node type. If a node has no children then its children index
 | ||||
| // will be in the range [0, 6), depending on the wildcard bit and node type.
 | ||||
| //
 | ||||
| // The layout within the uint32, from MSB to LSB, is:
 | ||||
| //	[%2d bits] unused
 | ||||
| //	[%2d bits] wildcard bit
 | ||||
| //	[%2d bits] node type
 | ||||
| //	[%2d bits] high nodes index (exclusive) of children
 | ||||
| //	[%2d bits] low nodes index (inclusive) of children
 | ||||
| var children = [...]uint32{ | ||||
| `, | ||||
| 		32-childrenBitsWildcard-childrenBitsNodeType-childrenBitsHi-childrenBitsLo, | ||||
| 		childrenBitsWildcard, childrenBitsNodeType, childrenBitsHi, childrenBitsLo) | ||||
| 	for i, c := range childrenEncoding { | ||||
| 		s := "---------------" | ||||
| 		lo := c & (1<<childrenBitsLo - 1) | ||||
| 		hi := (c >> childrenBitsLo) & (1<<childrenBitsHi - 1) | ||||
| 		if lo != hi { | ||||
| 			s = fmt.Sprintf("n0x%04x-n0x%04x", lo, hi) | ||||
| 		} | ||||
| 		nodeType := int(c>>(childrenBitsLo+childrenBitsHi)) & (1<<childrenBitsNodeType - 1) | ||||
| 		wildcard := c>>(childrenBitsLo+childrenBitsHi+childrenBitsNodeType) != 0 | ||||
| 		fmt.Fprintf(w, "0x%08x, // c0x%04x (%s)%s %s\n", | ||||
| 			c, i, s, wildcardStr(wildcard), nodeTypeStr(nodeType)) | ||||
| 	} | ||||
| 	fmt.Fprintf(w, "}\n\n") | ||||
| 	fmt.Fprintf(w, "// max children %d (capacity %d)\n", maxChildren, 1<<nodesBitsChildren-1) | ||||
| 	fmt.Fprintf(w, "// max text offset %d (capacity %d)\n", maxTextOffset, 1<<nodesBitsTextOffset-1) | ||||
| 	fmt.Fprintf(w, "// max text length %d (capacity %d)\n", maxTextLength, 1<<nodesBitsTextLength-1) | ||||
| 	fmt.Fprintf(w, "// max hi %d (capacity %d)\n", maxHi, 1<<childrenBitsHi-1) | ||||
| 	fmt.Fprintf(w, "// max lo %d (capacity %d)\n", maxLo, 1<<childrenBitsLo-1) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| type node struct { | ||||
| 	label    string | ||||
| 	nodeType int | ||||
| 	icann    bool | ||||
| 	wildcard bool | ||||
| 	// nodesIndex and childrenIndex are the index of this node in the nodes
 | ||||
| 	// and the index of its children offset/length in the children arrays.
 | ||||
| 	nodesIndex, childrenIndex int | ||||
| 	// firstChild is the index of this node's first child, or zero if this
 | ||||
| 	// node has no children.
 | ||||
| 	firstChild int | ||||
| 	// children are the node's children, in strictly increasing node label order.
 | ||||
| 	children []*node | ||||
| } | ||||
| 
 | ||||
| func (n *node) walk(w io.Writer, f func(w1 io.Writer, n1 *node) error) error { | ||||
| 	if err := f(w, n); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	for _, c := range n.children { | ||||
| 		if err := c.walk(w, f); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // child returns the child of n with the given label. The child is created if
 | ||||
| // it did not exist beforehand.
 | ||||
| func (n *node) child(label string) *node { | ||||
| 	for _, c := range n.children { | ||||
| 		if c.label == label { | ||||
| 			return c | ||||
| 		} | ||||
| 	} | ||||
| 	c := &node{ | ||||
| 		label:    label, | ||||
| 		nodeType: nodeTypeParentOnly, | ||||
| 		icann:    true, | ||||
| 	} | ||||
| 	n.children = append(n.children, c) | ||||
| 	sort.Sort(byLabel(n.children)) | ||||
| 	return c | ||||
| } | ||||
| 
 | ||||
| type byLabel []*node | ||||
| 
 | ||||
| func (b byLabel) Len() int           { return len(b) } | ||||
| func (b byLabel) Swap(i, j int)      { b[i], b[j] = b[j], b[i] } | ||||
| func (b byLabel) Less(i, j int) bool { return b[i].label < b[j].label } | ||||
| 
 | ||||
| var nextNodesIndex int | ||||
| 
 | ||||
| // childrenEncoding are the encoded entries in the generated children array.
 | ||||
| // All these pre-defined entries have no children.
 | ||||
| var childrenEncoding = []uint32{ | ||||
| 	0 << (childrenBitsLo + childrenBitsHi), // Without wildcard bit, nodeTypeNormal.
 | ||||
| 	1 << (childrenBitsLo + childrenBitsHi), // Without wildcard bit, nodeTypeException.
 | ||||
| 	2 << (childrenBitsLo + childrenBitsHi), // Without wildcard bit, nodeTypeParentOnly.
 | ||||
| 	4 << (childrenBitsLo + childrenBitsHi), // With wildcard bit, nodeTypeNormal.
 | ||||
| 	5 << (childrenBitsLo + childrenBitsHi), // With wildcard bit, nodeTypeException.
 | ||||
| 	6 << (childrenBitsLo + childrenBitsHi), // With wildcard bit, nodeTypeParentOnly.
 | ||||
| } | ||||
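The six entries above let leaf nodes (those with no children) reuse a tiny childrenIndex in [0, 6), encoding only their wildcard bit and node type; assignIndexes below picks `nodeType`, plus `numNodeType` when the wildcard bit is set. A sketch of decoding one such entry (`decodeChildren` is illustrative, not part of the generator):

```go
package main

import "fmt"

const (
	childrenBitsWildcard = 1
	childrenBitsNodeType = 2
	childrenBitsHi       = 14
	childrenBitsLo       = 14
)

// decodeChildren splits a children-table entry into its four fields,
// mirroring the MSB-to-LSB layout printReal documents.
func decodeChildren(c uint32) (wildcard bool, nodeType, hi, lo uint32) {
	lo = c & (1<<childrenBitsLo - 1)
	hi = c >> childrenBitsLo & (1<<childrenBitsHi - 1)
	nodeType = c >> (childrenBitsLo + childrenBitsHi) & (1<<childrenBitsNodeType - 1)
	wildcard = c>>(childrenBitsLo+childrenBitsHi+childrenBitsNodeType) != 0
	return
}

func main() {
	// childrenEncoding[4] above is 5<<28: wildcard bit set (5 = 4|1) and
	// nodeTypeException, with an empty [lo, hi) child range.
	w, nt, hi, lo := decodeChildren(5 << (childrenBitsLo + childrenBitsHi))
	fmt.Println(w, nt, hi, lo) // true 1 0 0
}
```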
| 
 | ||||
| var firstCallToAssignIndexes = true | ||||
| 
 | ||||
| func assignIndexes(w io.Writer, n *node) error { | ||||
| 	if len(n.children) != 0 { | ||||
| 		// Assign nodesIndex.
 | ||||
| 		n.firstChild = nextNodesIndex | ||||
| 		for _, c := range n.children { | ||||
| 			c.nodesIndex = nextNodesIndex | ||||
| 			nextNodesIndex++ | ||||
| 		} | ||||
| 
 | ||||
| 		// The root node's children are implicit.
 | ||||
| 		if firstCallToAssignIndexes { | ||||
| 			firstCallToAssignIndexes = false | ||||
| 			return nil | ||||
| 		} | ||||
| 
 | ||||
| 		// Assign childrenIndex.
 | ||||
| 		maxChildren = max(maxChildren, len(childrenEncoding)) | ||||
| 		if len(childrenEncoding) >= 1<<nodesBitsChildren { | ||||
| 			return fmt.Errorf("children table size %d is too large, or nodeBitsChildren is too small", len(childrenEncoding)) | ||||
| 		} | ||||
| 		n.childrenIndex = len(childrenEncoding) | ||||
| 		lo := uint32(n.firstChild) | ||||
| 		hi := lo + uint32(len(n.children)) | ||||
| 		maxLo, maxHi = u32max(maxLo, lo), u32max(maxHi, hi) | ||||
| 		if lo >= 1<<childrenBitsLo { | ||||
| 			return fmt.Errorf("children lo %d is too large, or childrenBitsLo is too small", lo) | ||||
| 		} | ||||
| 		if hi >= 1<<childrenBitsHi { | ||||
| 			return fmt.Errorf("children hi %d is too large, or childrenBitsHi is too small", hi) | ||||
| 		} | ||||
| 		enc := hi<<childrenBitsLo | lo | ||||
| 		enc |= uint32(n.nodeType) << (childrenBitsLo + childrenBitsHi) | ||||
| 		if n.wildcard { | ||||
| 			enc |= 1 << (childrenBitsLo + childrenBitsHi + childrenBitsNodeType) | ||||
| 		} | ||||
| 		childrenEncoding = append(childrenEncoding, enc) | ||||
| 	} else { | ||||
| 		n.childrenIndex = n.nodeType | ||||
| 		if n.wildcard { | ||||
| 			n.childrenIndex += numNodeType | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func printNode(w io.Writer, n *node) error { | ||||
| 	for _, c := range n.children { | ||||
| 		s := "---------------" | ||||
| 		if len(c.children) != 0 { | ||||
| 			s = fmt.Sprintf("n0x%04x-n0x%04x", c.firstChild, c.firstChild+len(c.children)) | ||||
| 		} | ||||
| 		encoding := labelEncoding[c.label] | ||||
| 		if c.icann { | ||||
| 			encoding |= 1 << (nodesBitsTextLength + nodesBitsTextOffset) | ||||
| 		} | ||||
| 		encoding |= uint32(c.childrenIndex) << (nodesBitsTextLength + nodesBitsTextOffset + nodesBitsICANN) | ||||
| 		fmt.Fprintf(w, "0x%08x, // n0x%04x c0x%04x (%s)%s %s %s %s\n", | ||||
| 			encoding, c.nodesIndex, c.childrenIndex, s, wildcardStr(c.wildcard), | ||||
| 			nodeTypeStr(c.nodeType), icannStr(c.icann), c.label, | ||||
| 		) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func printNodeLabel(w io.Writer, n *node) error { | ||||
| 	for _, c := range n.children { | ||||
| 		fmt.Fprintf(w, "%q,\n", c.label) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func icannStr(icann bool) string { | ||||
| 	if icann { | ||||
| 		return "I" | ||||
| 	} | ||||
| 	return " " | ||||
| } | ||||
| 
 | ||||
| func wildcardStr(wildcard bool) string { | ||||
| 	if wildcard { | ||||
| 		return "*" | ||||
| 	} | ||||
| 	return " " | ||||
| } | ||||
| 
 | ||||
| // combineText combines all the strings in labelsList to form one giant string.
 | ||||
| // Overlapping strings will be merged: "arpa" and "parliament" could yield
 | ||||
| // "arparliament".
 | ||||
| func combineText(labelsList []string) string { | ||||
| 	beforeLength := 0 | ||||
| 	for _, s := range labelsList { | ||||
| 		beforeLength += len(s) | ||||
| 	} | ||||
| 
 | ||||
| 	text := crush(removeSubstrings(labelsList)) | ||||
| 	if *v { | ||||
| 		fmt.Fprintf(os.Stderr, "crushed %d bytes to become %d bytes\n", beforeLength, len(text)) | ||||
| 	} | ||||
| 	return text | ||||
| } | ||||
| 
 | ||||
| type byLength []string | ||||
| 
 | ||||
| func (s byLength) Len() int           { return len(s) } | ||||
| func (s byLength) Swap(i, j int)      { s[i], s[j] = s[j], s[i] } | ||||
| func (s byLength) Less(i, j int) bool { return len(s[i]) < len(s[j]) } | ||||
| 
 | ||||
| // removeSubstrings returns a copy of its input with any strings removed
 | ||||
| // that are substrings of other provided strings.
 | ||||
| func removeSubstrings(input []string) []string { | ||||
| 	// Make a copy of input.
 | ||||
| 	ss := append(make([]string, 0, len(input)), input...) | ||||
| 	sort.Sort(byLength(ss)) | ||||
| 
 | ||||
| 	for i, shortString := range ss { | ||||
| 		// For each string, only consider strings higher than it in sort order, i.e.
 | ||||
| 		// of equal length or greater.
 | ||||
| 		for _, longString := range ss[i+1:] { | ||||
| 			if strings.Contains(longString, shortString) { | ||||
| 				ss[i] = "" | ||||
| 				break | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// Remove the empty strings.
 | ||||
| 	sort.Strings(ss) | ||||
| 	for len(ss) > 0 && ss[0] == "" { | ||||
| 		ss = ss[1:] | ||||
| 	} | ||||
| 	return ss | ||||
| } | ||||
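For instance, "arp" would be dropped in favor of "arpa". A simplified, self-contained version of the same step (sort.Slice instead of the byLength type, compaction instead of sorting away the empty strings), for illustration:

```go
package main

import (
	"fmt"
	"sort"
	"strings"
)

// removeSubstrings drops every string that is a substring of a longer one.
func removeSubstrings(input []string) []string {
	ss := append([]string(nil), input...)
	sort.Slice(ss, func(i, j int) bool { return len(ss[i]) < len(ss[j]) })
	for i, short := range ss {
		// Only later (equal-length or longer) strings can contain this one.
		for _, long := range ss[i+1:] {
			if strings.Contains(long, short) {
				ss[i] = ""
				break
			}
		}
	}
	out := ss[:0]
	for _, s := range ss {
		if s != "" {
			out = append(out, s)
		}
	}
	return out
}

func main() {
	fmt.Println(removeSubstrings([]string{"arpa", "parliament", "arp"}))
	// [arpa parliament]: "arp" is a substring of "arpa" and is removed.
}
```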
| 
 | ||||
| // crush combines a list of strings, taking advantage of overlaps. It returns a
 | ||||
| // single string that contains each input string as a substring.
 | ||||
| func crush(ss []string) string { | ||||
| 	maxLabelLen := 0 | ||||
| 	for _, s := range ss { | ||||
| 		if maxLabelLen < len(s) { | ||||
| 			maxLabelLen = len(s) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	for prefixLen := maxLabelLen; prefixLen > 0; prefixLen-- { | ||||
| 		prefixes := makePrefixMap(ss, prefixLen) | ||||
| 		for i, s := range ss { | ||||
| 			if len(s) <= prefixLen { | ||||
| 				continue | ||||
| 			} | ||||
| 			mergeLabel(ss, i, prefixLen, prefixes) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return strings.Join(ss, "") | ||||
| } | ||||
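crush's savings come from suffix/prefix overlaps between labels. A sketch of just that overlap idea, using the example from the combineText comment (`overlap` is illustrative, not the generator's code):

```go
package main

import (
	"fmt"
	"strings"
)

// overlap returns the length of the longest suffix of a that is also a
// prefix of b.
func overlap(a, b string) int {
	limit := len(a)
	if len(b) < limit {
		limit = len(b)
	}
	for n := limit; n > 0; n-- {
		if strings.HasSuffix(a, b[:n]) {
			return n
		}
	}
	return 0
}

func main() {
	// "arpa" and "parliament" share the 2-byte overlap "pa", so the merged
	// text "arparliament" contains both labels as substrings.
	a, b := "arpa", "parliament"
	n := overlap(a, b)
	merged := a + b[n:]
	fmt.Println(n, merged)                                                // 2 arparliament
	fmt.Println(strings.Contains(merged, a), strings.Contains(merged, b)) // true true
}
```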
| 
 | ||||
| // mergeLabel merges the label at ss[i] with the first available matching label
 | ||||
| // in prefixMap, where the last "prefixLen" characters in ss[i] match the first
 | ||||
| // "prefixLen" characters in the matching label.
 | ||||
| // It will merge ss[i] repeatedly until no more matches are available.
 | ||||
| // All matching labels merged into ss[i] are replaced by "".
 | ||||
| func mergeLabel(ss []string, i, prefixLen int, prefixes prefixMap) { | ||||
| 	s := ss[i] | ||||
| 	suffix := s[len(s)-prefixLen:] | ||||
| 	for _, j := range prefixes[suffix] { | ||||
| 		// Empty strings mean "already used." Also avoid merging with self.
 | ||||
| 		if ss[j] == "" || i == j { | ||||
| 			continue | ||||
| 		} | ||||
| 		if *v { | ||||
| 			fmt.Fprintf(os.Stderr, "%d-length overlap at (%4d,%4d): %q and %q share %q\n", | ||||
| 				prefixLen, i, j, ss[i], ss[j], suffix) | ||||
| 		} | ||||
| 		ss[i] += ss[j][prefixLen:] | ||||
| 		ss[j] = "" | ||||
| 		// ss[i] has a new suffix, so merge again if possible.
 | ||||
| 		// Note: we only have to merge again at the same prefix length. Shorter
 | ||||
| 		// prefix lengths will be handled in the next iteration of crush's for loop.
 | ||||
| 		// Can there be matches for longer prefix lengths, introduced by the merge?
 | ||||
| 		// I believe that any such matches would by necessity have been eliminated
 | ||||
| 		// during substring removal or merged at a higher prefix length. For
 | ||||
| 		// instance, in crush("abc", "cde", "bcdef"), combining "abc" and "cde"
 | ||||
| 		// would yield "abcde", which could be merged with "bcdef." However, in
 | ||||
| 		// practice "cde" would already have been elimintated by removeSubstrings.
 | ||||
| 		mergeLabel(ss, i, prefixLen, prefixes) | ||||
| 		return | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // prefixMap maps from a prefix to a list of strings containing that prefix. The
 | ||||
| // list of strings is represented as indexes into a slice of strings stored
 | ||||
| // elsewhere.
 | ||||
| type prefixMap map[string][]int | ||||
| 
 | ||||
| // makePrefixMap constructs a prefixMap from a slice of strings.
 | ||||
| func makePrefixMap(ss []string, prefixLen int) prefixMap { | ||||
| 	prefixes := make(prefixMap) | ||||
| 	for i, s := range ss { | ||||
| 		// We use < rather than <= because if a label matches on a prefix equal to
 | ||||
| 		// its full length, that's actually a substring match handled by
 | ||||
| 		// removeSubstrings.
 | ||||
| 		if prefixLen < len(s) { | ||||
| 			prefix := s[:prefixLen] | ||||
| 			prefixes[prefix] = append(prefixes[prefix], i) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return prefixes | ||||
| } | ||||
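A standalone copy of makePrefixMap, showing the shape of the result: indexes into the input slice, keyed by fixed-length prefix.

```go
package main

import "fmt"

type prefixMap map[string][]int

// makePrefixMap indexes each string by its first prefixLen bytes; strings
// no longer than prefixLen are skipped, since a full-length match would be
// a substring already handled by removeSubstrings.
func makePrefixMap(ss []string, prefixLen int) prefixMap {
	prefixes := make(prefixMap)
	for i, s := range ss {
		if prefixLen < len(s) {
			prefix := s[:prefixLen]
			prefixes[prefix] = append(prefixes[prefix], i)
		}
	}
	return prefixes
}

func main() {
	fmt.Println(makePrefixMap([]string{"kyoto.jp", "kobe.jp", "jp"}, 2))
	// map[ko:[1] ky:[0]]: "jp" is skipped because prefixLen is not
	// strictly less than its length.
}
```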