package ice import ( "io" "net" "os" "strings" "sync" "github.com/pion/logging" "github.com/pion/stun" ) // UDPMux allows multiple connections to go over a single UDP port type UDPMux interface { io.Closer GetConn(ufrag string, isIPv6 bool) (net.PacketConn, error) RemoveConnByUfrag(ufrag string) } // UDPMuxDefault is an implementation of the interface type UDPMuxDefault struct { params UDPMuxParams closedChan chan struct{} closeOnce sync.Once // connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType connsIPv4, connsIPv6 map[string]*udpMuxedConn addressMapMu sync.RWMutex addressMap map[string]*udpMuxedConn // buffer pool to recycle buffers for net.UDPAddr encodes/decodes pool *sync.Pool mu sync.Mutex } const maxAddrSize = 512 // UDPMuxParams are parameters for UDPMux. type UDPMuxParams struct { Logger logging.LeveledLogger UDPConn net.PacketConn } // NewUDPMuxDefault creates an implementation of UDPMux func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { if params.Logger == nil { params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") } m := &UDPMuxDefault{ addressMap: map[string]*udpMuxedConn{}, params: params, connsIPv4: make(map[string]*udpMuxedConn), connsIPv6: make(map[string]*udpMuxedConn), closedChan: make(chan struct{}, 1), pool: &sync.Pool{ New: func() interface{} { // big enough buffer to fit both packet and address return newBufferHolder(receiveMTU + maxAddrSize) }, }, } go m.connWorker() return m } // LocalAddr returns the listening address of this UDPMuxDefault func (m *UDPMuxDefault) LocalAddr() net.Addr { return m.params.UDPConn.LocalAddr() } // GetConn returns a PacketConn given the connection's ufrag and network // creates the connection if an existing one can't be found func (m *UDPMuxDefault) GetConn(ufrag string, isIPv6 bool) (net.PacketConn, error) { m.mu.Lock() defer m.mu.Unlock() if m.IsClosed() { return nil, io.ErrClosedPipe } if conn, ok := m.getConn(ufrag, isIPv6); ok { return conn, nil } c := m.createMuxedConn(ufrag) go func() { <-c.CloseChannel() m.removeConn(ufrag) }() if isIPv6 { m.connsIPv6[ufrag] = c } else { m.connsIPv4[ufrag] = c } return c, nil } // RemoveConnByUfrag stops and removes the muxed packet connection func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { removedConns := make([]*udpMuxedConn, 0, 2) // Keep lock section small to avoid deadlock with conn lock m.mu.Lock() if c, ok := m.connsIPv4[ufrag]; ok { delete(m.connsIPv4, ufrag) removedConns = append(removedConns, c) } if c, ok := m.connsIPv6[ufrag]; ok { delete(m.connsIPv6, ufrag) removedConns = append(removedConns, c) } m.mu.Unlock() m.addressMapMu.Lock() defer m.addressMapMu.Unlock() for _, c := range removedConns { addresses := c.getAddresses() for _, addr := range addresses { delete(m.addressMap, addr) } } } // IsClosed returns true if the mux had been closed func (m *UDPMuxDefault) IsClosed() bool { select { case <-m.closedChan: return true default: return false } } // Close the mux, no further connections could be created func (m *UDPMuxDefault) Close() error { var err error m.closeOnce.Do(func() { m.mu.Lock() defer m.mu.Unlock() for _, c := range m.connsIPv4 { _ = c.Close() } for _, c := range m.connsIPv6 { _ = c.Close() } m.connsIPv4 = make(map[string]*udpMuxedConn) m.connsIPv6 = make(map[string]*udpMuxedConn) close(m.closedChan) }) return err } func (m *UDPMuxDefault) removeConn(key string) { // keep lock section small to avoid deadlock with conn lock c := func() *udpMuxedConn { m.mu.Lock() defer m.mu.Unlock() if c, ok := m.connsIPv4[key]; ok { delete(m.connsIPv4, key) return c } if c, ok := m.connsIPv6[key]; ok { delete(m.connsIPv6, key) return c } return nil }() if c == nil { return } m.addressMapMu.Lock() defer m.addressMapMu.Unlock() addresses := c.getAddresses() for _, addr := range addresses { delete(m.addressMap, addr) } } func (m *UDPMuxDefault) writeTo(buf []byte, raddr net.Addr) (n int, err error) { return m.params.UDPConn.WriteTo(buf, raddr) } func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) { if m.IsClosed() { return } m.addressMapMu.Lock() defer m.addressMapMu.Unlock() existing, ok := m.addressMap[addr] if ok { existing.removeAddress(addr) } m.addressMap[addr] = conn m.params.Logger.Debugf("Registered %s for %s", addr, conn.params.Key) } func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn { c := newUDPMuxedConn(&udpMuxedConnParams{ Mux: m, Key: key, AddrPool: m.pool, LocalAddr: m.LocalAddr(), Logger: m.params.Logger, }) return c } func (m *UDPMuxDefault) connWorker() { logger := m.params.Logger defer func() { _ = m.Close() }() buf := make([]byte, receiveMTU) for { n, addr, err := m.params.UDPConn.ReadFrom(buf) if m.IsClosed() { return } else if err != nil { if os.IsTimeout(err) { continue } else if err != io.EOF { logger.Errorf("could not read udp packet: %v", err) } return } udpAddr, ok := addr.(*net.UDPAddr) if !ok { logger.Errorf("underlying PacketConn did not return a UDPAddr") return } // If we have already seen this address dispatch to the appropriate destination m.addressMapMu.Lock() destinationConn := m.addressMap[addr.String()] m.addressMapMu.Unlock() // If we haven't seen this address before but is a STUN packet lookup by ufrag if destinationConn == nil && stun.IsMessage(buf[:n]) { msg := &stun.Message{ Raw: append([]byte{}, buf[:n]...), } if err = msg.Decode(); err != nil { m.params.Logger.Warnf("Failed to handle decode ICE from %s: %v\n", addr.String(), err) continue } attr, stunAttrErr := msg.Get(stun.AttrUsername) if stunAttrErr != nil { m.params.Logger.Warnf("No Username attribute in STUN message from %s\n", addr.String()) continue } ufrag := strings.Split(string(attr), ":")[0] isIPv6 := udpAddr.IP.To4() == nil m.mu.Lock() destinationConn, _ = m.getConn(ufrag, isIPv6) m.mu.Unlock() } if destinationConn == nil { m.params.Logger.Tracef("dropping packet from %s, addr: %s", udpAddr.String(), addr.String()) continue } if err = destinationConn.writePacket(buf[:n], udpAddr); err != nil { m.params.Logger.Errorf("could not write packet: %v", err) } } } func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) { if isIPv6 { val, ok = m.connsIPv6[ufrag] } else { val, ok = m.connsIPv4[ufrag] } return } type bufferHolder struct { buffer []byte } func newBufferHolder(size int) *bufferHolder { return &bufferHolder{ buffer: make([]byte, size), } }