456 lines
		
	
	
		
			9.9 KiB
		
	
	
	
		
			Go
		
	
	
			
		
		
	
	
			456 lines
		
	
	
		
			9.9 KiB
		
	
	
	
		
			Go
		
	
	
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
 | 
						|
 | 
						|
package redis
 | 
						|
 | 
						|
import (
 | 
						|
	"bufio"
 | 
						|
	"bytes"
 | 
						|
	"errors"
 | 
						|
	"fmt"
 | 
						|
	"io"
 | 
						|
	"net"
 | 
						|
	"strconv"
 | 
						|
	"sync"
 | 
						|
	"time"
 | 
						|
)
 | 
						|
 | 
						|
// conn is the low-level implementation of Conn. It wraps one network
// connection with a buffered reader and a buffered writer.
type conn struct {

	// Shared state. pending and err are read and written with mu held.
	mu sync.Mutex
	// pending counts commands that were sent but whose replies have not
	// been received yet.
	pending int
	// err is the first fatal error seen on the connection, or the
	// "closed" error once Close has been called.
	err error
	// conn is the underlying network connection.
	conn net.Conn

	// Read side: per-reply timeout (zero means none) and buffered reader.
	readTimeout time.Duration
	br          *bufio.Reader

	// Write side: per-command timeout (zero means none) and buffered writer.
	writeTimeout time.Duration
	bw           *bufio.Writer

	// lenScratch is reusable space for formatting a length header:
	// '*' or '$', the decimal length, then "\r\n".
	lenScratch [32]byte

	// numScratch is reusable space for formatting integer and float
	// arguments.
	numScratch [40]byte
}
 | 
						|
 | 
						|
// Dial connects to the Redis server at the given network and address.
 | 
						|
func Dial(network, address string) (Conn, error) {
 | 
						|
	dialer := xDialer{}
 | 
						|
	return dialer.Dial(network, address)
 | 
						|
}
 | 
						|
 | 
						|
// DialTimeout acts like Dial but takes timeouts for establishing the
 | 
						|
// connection to the server, writing a command and reading a reply.
 | 
						|
func DialTimeout(network, address string, connectTimeout, readTimeout, writeTimeout time.Duration) (Conn, error) {
 | 
						|
	netDialer := net.Dialer{Timeout: connectTimeout}
 | 
						|
	dialer := xDialer{
 | 
						|
		NetDial:      netDialer.Dial,
 | 
						|
		ReadTimeout:  readTimeout,
 | 
						|
		WriteTimeout: writeTimeout,
 | 
						|
	}
 | 
						|
	return dialer.Dial(network, address)
 | 
						|
}
 | 
						|
 | 
						|
// xDialer specifies options for connecting to a Redis server. The zero value
// dials with net.Dial and applies no timeouts.
type xDialer struct {
	// NetDial is the function used to create the network connection. When
	// it is nil, net.Dial is used.
	NetDial func(network, addr string) (net.Conn, error)

	// ReadTimeout is the timeout for reading a single command reply. Zero
	// means no timeout.
	ReadTimeout time.Duration

	// WriteTimeout is the timeout for writing a single command. Zero means
	// no timeout.
	WriteTimeout time.Duration
}
 | 
						|
 | 
						|
// Dial connects to the Redis server at address on the named network.
 | 
						|
func (d *xDialer) Dial(network, address string) (Conn, error) {
 | 
						|
	dial := d.NetDial
 | 
						|
	if dial == nil {
 | 
						|
		dial = net.Dial
 | 
						|
	}
 | 
						|
	netConn, err := dial(network, address)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	return &conn{
 | 
						|
		conn:         netConn,
 | 
						|
		bw:           bufio.NewWriter(netConn),
 | 
						|
		br:           bufio.NewReader(netConn),
 | 
						|
		readTimeout:  d.ReadTimeout,
 | 
						|
		writeTimeout: d.WriteTimeout,
 | 
						|
	}, nil
 | 
						|
}
 | 
						|
 | 
						|
// NewConn returns a new Redigo connection for the given net connection.
 | 
						|
func NewConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) Conn {
 | 
						|
	return &conn{
 | 
						|
		conn:         netConn,
 | 
						|
		bw:           bufio.NewWriter(netConn),
 | 
						|
		br:           bufio.NewReader(netConn),
 | 
						|
		readTimeout:  readTimeout,
 | 
						|
		writeTimeout: writeTimeout,
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (c *conn) Close() error {
 | 
						|
	c.mu.Lock()
 | 
						|
	err := c.err
 | 
						|
	if c.err == nil {
 | 
						|
		c.err = errors.New("redigo: closed")
 | 
						|
		err = c.conn.Close()
 | 
						|
	}
 | 
						|
	c.mu.Unlock()
 | 
						|
	return err
 | 
						|
}
 | 
						|
 | 
						|
func (c *conn) fatal(err error) error {
 | 
						|
	c.mu.Lock()
 | 
						|
	if c.err == nil {
 | 
						|
		c.err = err
 | 
						|
		// Close connection to force errors on subsequent calls and to unblock
 | 
						|
		// other reader or writer.
 | 
						|
		c.conn.Close()
 | 
						|
	}
 | 
						|
	c.mu.Unlock()
 | 
						|
	return err
 | 
						|
}
 | 
						|
 | 
						|
func (c *conn) Err() error {
 | 
						|
	c.mu.Lock()
 | 
						|
	err := c.err
 | 
						|
	c.mu.Unlock()
 | 
						|
	return err
 | 
						|
}
 | 
						|
 | 
						|
func (c *conn) writeLen(prefix byte, n int) error {
 | 
						|
	c.lenScratch[len(c.lenScratch)-1] = '\n'
 | 
						|
	c.lenScratch[len(c.lenScratch)-2] = '\r'
 | 
						|
	i := len(c.lenScratch) - 3
 | 
						|
	for {
 | 
						|
		c.lenScratch[i] = byte('0' + n%10)
 | 
						|
		i -= 1
 | 
						|
		n = n / 10
 | 
						|
		if n == 0 {
 | 
						|
			break
 | 
						|
		}
 | 
						|
	}
 | 
						|
	c.lenScratch[i] = prefix
 | 
						|
	_, err := c.bw.Write(c.lenScratch[i:])
 | 
						|
	return err
 | 
						|
}
 | 
						|
 | 
						|
func (c *conn) writeString(s string) error {
 | 
						|
	c.writeLen('$', len(s))
 | 
						|
	c.bw.WriteString(s)
 | 
						|
	_, err := c.bw.WriteString("\r\n")
 | 
						|
	return err
 | 
						|
}
 | 
						|
 | 
						|
func (c *conn) writeBytes(p []byte) error {
 | 
						|
	c.writeLen('$', len(p))
 | 
						|
	c.bw.Write(p)
 | 
						|
	_, err := c.bw.WriteString("\r\n")
 | 
						|
	return err
 | 
						|
}
 | 
						|
 | 
						|
func (c *conn) writeInt64(n int64) error {
 | 
						|
	return c.writeBytes(strconv.AppendInt(c.numScratch[:0], n, 10))
 | 
						|
}
 | 
						|
 | 
						|
func (c *conn) writeFloat64(n float64) error {
 | 
						|
	return c.writeBytes(strconv.AppendFloat(c.numScratch[:0], n, 'g', -1, 64))
 | 
						|
}
 | 
						|
 | 
						|
func (c *conn) writeCommand(cmd string, args []interface{}) (err error) {
 | 
						|
	c.writeLen('*', 1+len(args))
 | 
						|
	err = c.writeString(cmd)
 | 
						|
	for _, arg := range args {
 | 
						|
		if err != nil {
 | 
						|
			break
 | 
						|
		}
 | 
						|
		switch arg := arg.(type) {
 | 
						|
		case string:
 | 
						|
			err = c.writeString(arg)
 | 
						|
		case []byte:
 | 
						|
			err = c.writeBytes(arg)
 | 
						|
		case int:
 | 
						|
			err = c.writeInt64(int64(arg))
 | 
						|
		case int64:
 | 
						|
			err = c.writeInt64(arg)
 | 
						|
		case float64:
 | 
						|
			err = c.writeFloat64(arg)
 | 
						|
		case bool:
 | 
						|
			if arg {
 | 
						|
				err = c.writeString("1")
 | 
						|
			} else {
 | 
						|
				err = c.writeString("0")
 | 
						|
			}
 | 
						|
		case nil:
 | 
						|
			err = c.writeString("")
 | 
						|
		default:
 | 
						|
			var buf bytes.Buffer
 | 
						|
			fmt.Fprint(&buf, arg)
 | 
						|
			err = c.writeBytes(buf.Bytes())
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return err
 | 
						|
}
 | 
						|
 | 
						|
// protocolError describes a reply that does not follow the expected wire
// format. The string value is the short description of what was wrong.
type protocolError string

// Error implements the error interface. It prefixes the description with
// "redigo: " and appends a hint about the likely causes.
func (pe protocolError) Error() string {
	const format = "redigo: %s (possible server error or unsupported concurrent read by application)"
	return fmt.Sprintf(format, string(pe))
}
 | 
						|
 | 
						|
func (c *conn) readLine() ([]byte, error) {
 | 
						|
	p, err := c.br.ReadSlice('\n')
 | 
						|
	if err == bufio.ErrBufferFull {
 | 
						|
		return nil, protocolError("long response line")
 | 
						|
	}
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	i := len(p) - 2
 | 
						|
	if i < 0 || p[i] != '\r' {
 | 
						|
		return nil, protocolError("bad response line terminator")
 | 
						|
	}
 | 
						|
	return p[:i], nil
 | 
						|
}
 | 
						|
 | 
						|
// parseLen parses bulk string and array lengths.
 | 
						|
func parseLen(p []byte) (int, error) {
 | 
						|
	if len(p) == 0 {
 | 
						|
		return -1, protocolError("malformed length")
 | 
						|
	}
 | 
						|
 | 
						|
	if p[0] == '-' && len(p) == 2 && p[1] == '1' {
 | 
						|
		// handle $-1 and $-1 null replies.
 | 
						|
		return -1, nil
 | 
						|
	}
 | 
						|
 | 
						|
	var n int
 | 
						|
	for _, b := range p {
 | 
						|
		n *= 10
 | 
						|
		if b < '0' || b > '9' {
 | 
						|
			return -1, protocolError("illegal bytes in length")
 | 
						|
		}
 | 
						|
		n += int(b - '0')
 | 
						|
	}
 | 
						|
 | 
						|
	return n, nil
 | 
						|
}
 | 
						|
 | 
						|
// parseInt parses an integer reply.
 | 
						|
func parseInt(p []byte) (interface{}, error) {
 | 
						|
	if len(p) == 0 {
 | 
						|
		return 0, protocolError("malformed integer")
 | 
						|
	}
 | 
						|
 | 
						|
	var negate bool
 | 
						|
	if p[0] == '-' {
 | 
						|
		negate = true
 | 
						|
		p = p[1:]
 | 
						|
		if len(p) == 0 {
 | 
						|
			return 0, protocolError("malformed integer")
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	var n int64
 | 
						|
	for _, b := range p {
 | 
						|
		n *= 10
 | 
						|
		if b < '0' || b > '9' {
 | 
						|
			return 0, protocolError("illegal bytes in length")
 | 
						|
		}
 | 
						|
		n += int64(b - '0')
 | 
						|
	}
 | 
						|
 | 
						|
	if negate {
 | 
						|
		n = -n
 | 
						|
	}
 | 
						|
	return n, nil
 | 
						|
}
 | 
						|
 | 
						|
// okReply and pongReply hold the "OK" and "PONG" status strings already
// stored in an interface value, so readReply can return them for the
// matching status lines without allocating.
var okReply interface{} = "OK"
var pongReply interface{} = "PONG"
 | 
						|
 | 
						|
func (c *conn) readReply() (interface{}, error) {
 | 
						|
	line, err := c.readLine()
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	if len(line) == 0 {
 | 
						|
		return nil, protocolError("short response line")
 | 
						|
	}
 | 
						|
	switch line[0] {
 | 
						|
	case '+':
 | 
						|
		switch {
 | 
						|
		case len(line) == 3 && line[1] == 'O' && line[2] == 'K':
 | 
						|
			// Avoid allocation for frequent "+OK" response.
 | 
						|
			return okReply, nil
 | 
						|
		case len(line) == 5 && line[1] == 'P' && line[2] == 'O' && line[3] == 'N' && line[4] == 'G':
 | 
						|
			// Avoid allocation in PING command benchmarks :)
 | 
						|
			return pongReply, nil
 | 
						|
		default:
 | 
						|
			return string(line[1:]), nil
 | 
						|
		}
 | 
						|
	case '-':
 | 
						|
		return Error(string(line[1:])), nil
 | 
						|
	case ':':
 | 
						|
		return parseInt(line[1:])
 | 
						|
	case '$':
 | 
						|
		n, err := parseLen(line[1:])
 | 
						|
		if n < 0 || err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
		p := make([]byte, n)
 | 
						|
		_, err = io.ReadFull(c.br, p)
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
		if line, err := c.readLine(); err != nil {
 | 
						|
			return nil, err
 | 
						|
		} else if len(line) != 0 {
 | 
						|
			return nil, protocolError("bad bulk string format")
 | 
						|
		}
 | 
						|
		return p, nil
 | 
						|
	case '*':
 | 
						|
		n, err := parseLen(line[1:])
 | 
						|
		if n < 0 || err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
		r := make([]interface{}, n)
 | 
						|
		for i := range r {
 | 
						|
			r[i], err = c.readReply()
 | 
						|
			if err != nil {
 | 
						|
				return nil, err
 | 
						|
			}
 | 
						|
		}
 | 
						|
		return r, nil
 | 
						|
	}
 | 
						|
	return nil, protocolError("unexpected response line")
 | 
						|
}
 | 
						|
 | 
						|
func (c *conn) Send(cmd string, args ...interface{}) error {
 | 
						|
	c.mu.Lock()
 | 
						|
	c.pending += 1
 | 
						|
	c.mu.Unlock()
 | 
						|
	if c.writeTimeout != 0 {
 | 
						|
		c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
 | 
						|
	}
 | 
						|
	if err := c.writeCommand(cmd, args); err != nil {
 | 
						|
		return c.fatal(err)
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (c *conn) Flush() error {
 | 
						|
	if c.writeTimeout != 0 {
 | 
						|
		c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
 | 
						|
	}
 | 
						|
	if err := c.bw.Flush(); err != nil {
 | 
						|
		return c.fatal(err)
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (c *conn) Receive() (reply interface{}, err error) {
 | 
						|
	if c.readTimeout != 0 {
 | 
						|
		c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
 | 
						|
	}
 | 
						|
	if reply, err = c.readReply(); err != nil {
 | 
						|
		return nil, c.fatal(err)
 | 
						|
	}
 | 
						|
	// When using pub/sub, the number of receives can be greater than the
 | 
						|
	// number of sends. To enable normal use of the connection after
 | 
						|
	// unsubscribing from all channels, we do not decrement pending to a
 | 
						|
	// negative value.
 | 
						|
	//
 | 
						|
	// The pending field is decremented after the reply is read to handle the
 | 
						|
	// case where Receive is called before Send.
 | 
						|
	c.mu.Lock()
 | 
						|
	if c.pending > 0 {
 | 
						|
		c.pending -= 1
 | 
						|
	}
 | 
						|
	c.mu.Unlock()
 | 
						|
	if err, ok := reply.(Error); ok {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	return
 | 
						|
}
 | 
						|
 | 
						|
func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) {
 | 
						|
	c.mu.Lock()
 | 
						|
	pending := c.pending
 | 
						|
	c.pending = 0
 | 
						|
	c.mu.Unlock()
 | 
						|
 | 
						|
	if cmd == "" && pending == 0 {
 | 
						|
		return nil, nil
 | 
						|
	}
 | 
						|
 | 
						|
	if c.writeTimeout != 0 {
 | 
						|
		c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
 | 
						|
	}
 | 
						|
 | 
						|
	if cmd != "" {
 | 
						|
		c.writeCommand(cmd, args)
 | 
						|
	}
 | 
						|
 | 
						|
	if err := c.bw.Flush(); err != nil {
 | 
						|
		return nil, c.fatal(err)
 | 
						|
	}
 | 
						|
 | 
						|
	if c.readTimeout != 0 {
 | 
						|
		c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
 | 
						|
	}
 | 
						|
 | 
						|
	if cmd == "" {
 | 
						|
		reply := make([]interface{}, pending)
 | 
						|
		for i := range reply {
 | 
						|
			r, e := c.readReply()
 | 
						|
			if e != nil {
 | 
						|
				return nil, c.fatal(e)
 | 
						|
			}
 | 
						|
			reply[i] = r
 | 
						|
		}
 | 
						|
		return reply, nil
 | 
						|
	}
 | 
						|
 | 
						|
	var err error
 | 
						|
	var reply interface{}
 | 
						|
	for i := 0; i <= pending; i++ {
 | 
						|
		var e error
 | 
						|
		if reply, e = c.readReply(); e != nil {
 | 
						|
			return nil, c.fatal(e)
 | 
						|
		}
 | 
						|
		if e, ok := reply.(Error); ok && err == nil {
 | 
						|
			err = e
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return reply, err
 | 
						|
}
 |