OSDN Git Service

5caf6a93b61c86cabae0c5ea938437ffc59d6acf
[bytom/bytom.git] / p2p / discover / udp.go
1 package discover
2
3 import (
4         "bytes"
5         "crypto/ecdsa"
6         "encoding/hex"
7         "errors"
8         "fmt"
9         "net"
10         "path"
11         "strconv"
12         "strings"
13         "time"
14
15         log "github.com/sirupsen/logrus"
16         "github.com/tendermint/go-wire"
17
18         "github.com/bytom/common"
19         cfg "github.com/bytom/config"
20         "github.com/bytom/crypto"
21         "github.com/bytom/crypto/ed25519"
22         "github.com/bytom/p2p/netutil"
23         "github.com/bytom/version"
24 )
25
26 const (
27         Version   = 4
28         logModule = "discover"
29 )
30
31 // Errors
32 var (
33         errPacketTooSmall   = errors.New("too small")
34         errBadPrefix        = errors.New("bad prefix")
35         errExpired          = errors.New("expired")
36         errUnsolicitedReply = errors.New("unsolicited reply")
37         errUnknownNode      = errors.New("unknown node")
38         errTimeout          = errors.New("RPC timeout")
39         errClockWarp        = errors.New("reply deadline too far in the future")
40         errClosed           = errors.New("socket closed")
41         errInvalidSeedIP    = errors.New("seed ip is invalid")
42         errInvalidSeedPort  = errors.New("seed port is invalid")
43 )
44
45 // Timeouts
46 const (
47         respTimeout = 1 * time.Second
48         queryDelay  = 1000 * time.Millisecond
49         expiration  = 20 * time.Second
50
51         ntpFailureThreshold = 32               // Continuous timeouts after which to check NTP
52         ntpWarningCooldown  = 10 * time.Minute // Minimum amount of time to pass before repeating NTP warning
53         driftThreshold      = 10 * time.Second // Allowed clock drift before warning user
54 )
55
56 // ReadPacket is sent to the unhandled channel when it could not be processed
57 type ReadPacket struct {
58         Data []byte
59         Addr *net.UDPAddr
60 }
61
62 // Config holds Table-related settings.
63 type Config struct {
64         // These settings are required and configure the UDP listener:
65         PrivateKey *ecdsa.PrivateKey
66
67         // These settings are optional:
68         AnnounceAddr *net.UDPAddr // local address announced in the DHT
69         NodeDBPath   string       // if set, the node database is stored at this filesystem location
70         //NetRestrict  *netutil.Netlist  // network whitelist
71         Bootnodes []*Node           // list of bootstrap nodes
72         Unhandled chan<- ReadPacket // unhandled packets are sent on this channel
73 }
74
75 // RPC request structures
76 type (
77         ping struct {
78                 Version    uint
79                 From, To   rpcEndpoint
80                 Expiration uint64
81
82                 // v5
83                 Topics []Topic
84
85                 // Ignore additional fields (for forward compatibility).
86                 Rest []byte
87         }
88
89         // pong is the reply to ping.
90         pong struct {
91                 // This field should mirror the UDP envelope address
92                 // of the ping packet, which provides a way to discover the
93                 // the external address (after NAT).
94                 To rpcEndpoint
95
96                 ReplyTok   []byte // This contains the hash of the ping packet.
97                 Expiration uint64 // Absolute timestamp at which the packet becomes invalid.
98
99                 // v5
100                 TopicHash    common.Hash
101                 TicketSerial uint32
102                 WaitPeriods  []uint32
103
104                 // Ignore additional fields (for forward compatibility).
105                 Rest []byte
106         }
107
108         // findnode is a query for nodes close to the given target.
109         findnode struct {
110                 Target     NodeID // doesn't need to be an actual public key
111                 Expiration uint64
112                 // Ignore additional fields (for forward compatibility).
113                 Rest []byte
114         }
115
116         // findnode is a query for nodes close to the given target.
117         findnodeHash struct {
118                 Target     common.Hash
119                 Expiration uint64
120                 // Ignore additional fields (for forward compatibility).
121                 Rest []byte
122         }
123
124         // reply to findnode
125         neighbors struct {
126                 Nodes      []rpcNode
127                 Expiration uint64
128                 // Ignore additional fields (for forward compatibility).
129                 Rest []byte
130         }
131
132         topicRegister struct {
133                 Topics []Topic
134                 Idx    uint
135                 Pong   []byte
136         }
137
138         topicQuery struct {
139                 Topic      Topic
140                 Expiration uint64
141         }
142
143         // reply to topicQuery
144         topicNodes struct {
145                 Echo  common.Hash
146                 Nodes []rpcNode
147         }
148
149         rpcNode struct {
150                 IP  net.IP // len 4 for IPv4 or 16 for IPv6
151                 UDP uint16 // for discovery protocol
152                 TCP uint16 // for RLPx protocol
153                 ID  NodeID
154         }
155
156         rpcEndpoint struct {
157                 IP  net.IP // len 4 for IPv4 or 16 for IPv6
158                 UDP uint16 // for discovery protocol
159                 TCP uint16 // for RLPx protocol
160         }
161 )
162
163 var (
164         versionPrefix     = []byte("bytom discovery")
165         versionPrefixSize = len(versionPrefix)
166         nodeIDSize        = 32
167         sigSize           = 520 / 8
168         headSize          = versionPrefixSize + nodeIDSize + sigSize // space of packet frame data
169 )
170
171 // Neighbors replies are sent across multiple packets to
172 // stay below the 1280 byte limit. We compute the maximum number
173 // of entries by stuffing a packet until it grows too large.
174 var maxNeighbors = func() int {
175         p := neighbors{Expiration: ^uint64(0)}
176         maxSizeNode := rpcNode{IP: make(net.IP, 16), UDP: ^uint16(0), TCP: ^uint16(0)}
177         for n := 0; ; n++ {
178                 p.Nodes = append(p.Nodes, maxSizeNode)
179                 var size int
180                 var err error
181                 b := new(bytes.Buffer)
182                 wire.WriteJSON(p, b, &size, &err)
183                 if err != nil {
184                         // If this ever happens, it will be caught by the unit tests.
185                         panic("cannot encode: " + err.Error())
186                 }
187                 if headSize+size+1 >= 1280 {
188                         return n
189                 }
190         }
191 }()
192
193 var maxTopicNodes = func() int {
194         p := topicNodes{}
195         maxSizeNode := rpcNode{IP: make(net.IP, 16), UDP: ^uint16(0), TCP: ^uint16(0)}
196         for n := 0; ; n++ {
197                 p.Nodes = append(p.Nodes, maxSizeNode)
198                 var size int
199                 var err error
200                 b := new(bytes.Buffer)
201                 wire.WriteJSON(p, b, &size, &err)
202                 if err != nil {
203                         // If this ever happens, it will be caught by the unit tests.
204                         panic("cannot encode: " + err.Error())
205                 }
206                 if headSize+size+1 >= 1280 {
207                         return n
208                 }
209         }
210 }()
211
212 func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint {
213         ip := addr.IP.To4()
214         if ip == nil {
215                 ip = addr.IP.To16()
216         }
217         return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort}
218 }
219
220 func (e1 rpcEndpoint) equal(e2 rpcEndpoint) bool {
221         return e1.UDP == e2.UDP && e1.TCP == e2.TCP && e1.IP.Equal(e2.IP)
222 }
223
224 func nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) {
225         if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil {
226                 return nil, err
227         }
228         n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP)
229         err := n.validateComplete()
230         return n, err
231 }
232
233 func nodeToRPC(n *Node) rpcNode {
234         return rpcNode{ID: n.ID, IP: n.IP, UDP: n.UDP, TCP: n.TCP}
235 }
236
237 type ingressPacket struct {
238         remoteID   NodeID
239         remoteAddr *net.UDPAddr
240         ev         nodeEvent
241         hash       []byte
242         data       interface{} // one of the RPC structs
243         rawData    []byte
244 }
245
246 type conn interface {
247         ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
248         WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error)
249         Close() error
250         LocalAddr() net.Addr
251 }
252
253 type netWork interface {
254         reqReadPacket(pkt ingressPacket)
255         selfIP() net.IP
256 }
257
258 // udp implements the RPC protocol.
259 type udp struct {
260         conn        conn
261         priv        ed25519.PrivateKey
262         ourEndpoint rpcEndpoint
263         //nat         nat.Interface
264         net netWork
265 }
266
267 func NewDiscover(config *cfg.Config, priv ed25519.PrivateKey, port uint16) (*Network, error) {
268         addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort("0.0.0.0", strconv.FormatUint(uint64(port), 10)))
269         if err != nil {
270                 return nil, err
271         }
272
273         conn, err := net.ListenUDP("udp", addr)
274         if err != nil {
275                 return nil, err
276         }
277
278         realaddr := conn.LocalAddr().(*net.UDPAddr)
279         ntab, err := ListenUDP(priv, conn, realaddr, path.Join(config.DBDir(), "discover.db"), nil)
280         if err != nil {
281                 return nil, err
282         }
283         seeds, err := QueryDNSSeeds(net.LookupHost)
284         if err != nil {
285                 log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on query dns seeds")
286         }
287
288         if config.P2P.Seeds != "" {
289                 codedSeeds := strings.Split(config.P2P.Seeds, ",")
290                 for _, codedSeed := range codedSeeds {
291                         ip, port, err := net.SplitHostPort(codedSeed)
292                         if err != nil {
293                                 return nil, err
294                         }
295
296                         if validIP := net.ParseIP(ip); validIP == nil {
297                                 return nil, errInvalidSeedIP
298                         }
299
300                         if _, err := strconv.ParseUint(port, 10, 16); err != nil {
301                                 return nil, errInvalidSeedPort
302                         }
303
304                         seeds = append(seeds, codedSeed)
305                 }
306         }
307
308         if len(seeds) == 0 {
309                 return ntab, nil
310         }
311
312         var nodes []*Node
313         for _, seed := range seeds {
314                 version.Status.AddSeed(seed)
315                 url := "enode://" + hex.EncodeToString(crypto.Sha256([]byte(seed))) + "@" + seed
316                 nodes = append(nodes, MustParseNode(url))
317         }
318
319         if err = ntab.SetFallbackNodes(nodes); err != nil {
320                 return nil, err
321         }
322         return ntab, nil
323 }
324
325 // ListenUDP returns a new table that listens for UDP packets on laddr.
326 func ListenUDP(priv ed25519.PrivateKey, conn conn, realaddr *net.UDPAddr, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) {
327         transport, err := listenUDP(priv, conn, realaddr)
328         if err != nil {
329                 return nil, err
330         }
331
332         net, err := newNetwork(transport, priv.Public(), nodeDBPath, netrestrict)
333         if err != nil {
334                 return nil, err
335         }
336         log.WithFields(log.Fields{"module": logModule, "net": net.tab.self}).Info("UDP listener up v5")
337         transport.net = net
338         go transport.readLoop()
339         return net, nil
340 }
341
342 func listenUDP(priv ed25519.PrivateKey, conn conn, realaddr *net.UDPAddr) (*udp, error) {
343         return &udp{conn: conn, priv: priv, ourEndpoint: makeEndpoint(realaddr, uint16(realaddr.Port))}, nil
344 }
345
346 func (t *udp) localAddr() *net.UDPAddr {
347         return t.conn.LocalAddr().(*net.UDPAddr)
348 }
349
350 func (t *udp) Close() {
351         t.conn.Close()
352 }
353
354 func (t *udp) send(remote *Node, ptype nodeEvent, data interface{}) (hash []byte) {
355         hash, _ = t.sendPacket(remote.ID, remote.addr(), byte(ptype), data)
356         return hash
357 }
358
359 func (t *udp) sendPing(remote *Node, toaddr *net.UDPAddr, topics []Topic) (hash []byte) {
360         hash, _ = t.sendPacket(remote.ID, toaddr, byte(pingPacket), ping{
361                 Version:    Version,
362                 From:       t.ourEndpoint,
363                 To:         makeEndpoint(toaddr, uint16(toaddr.Port)), // TODO: maybe use known TCP port from DB
364                 Expiration: uint64(time.Now().Add(expiration).Unix()),
365                 Topics:     topics,
366         })
367         return hash
368 }
369
370 func (t *udp) sendFindnode(remote *Node, target NodeID) {
371         t.sendPacket(remote.ID, remote.addr(), byte(findnodePacket), findnode{
372                 Target:     target,
373                 Expiration: uint64(time.Now().Add(expiration).Unix()),
374         })
375 }
376
377 func (t *udp) sendNeighbours(remote *Node, results []*Node) {
378         // Send neighbors in chunks with at most maxNeighbors per packet
379         // to stay below the 1280 byte limit.
380         p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
381         for i, result := range results {
382                 p.Nodes = append(p.Nodes, nodeToRPC(result))
383                 if len(p.Nodes) == maxNeighbors || i == len(results)-1 {
384                         t.sendPacket(remote.ID, remote.addr(), byte(neighborsPacket), p)
385                         p.Nodes = p.Nodes[:0]
386                 }
387         }
388 }
389
390 func (t *udp) sendFindnodeHash(remote *Node, target common.Hash) {
391         t.sendPacket(remote.ID, remote.addr(), byte(findnodeHashPacket), findnodeHash{
392                 Target:     common.Hash(target),
393                 Expiration: uint64(time.Now().Add(expiration).Unix()),
394         })
395 }
396
397 func (t *udp) sendTopicRegister(remote *Node, topics []Topic, idx int, pong []byte) {
398         t.sendPacket(remote.ID, remote.addr(), byte(topicRegisterPacket), topicRegister{
399                 Topics: topics,
400                 Idx:    uint(idx),
401                 Pong:   pong,
402         })
403 }
404
405 func (t *udp) sendTopicNodes(remote *Node, queryHash common.Hash, nodes []*Node) {
406         p := topicNodes{Echo: queryHash}
407         var sent bool
408         for _, result := range nodes {
409                 if result.IP.Equal(t.net.selfIP()) || netutil.CheckRelayIP(remote.IP, result.IP) == nil {
410                         p.Nodes = append(p.Nodes, nodeToRPC(result))
411                 }
412                 if len(p.Nodes) == maxTopicNodes {
413                         t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p)
414                         p.Nodes = p.Nodes[:0]
415                         sent = true
416                 }
417         }
418         if !sent || len(p.Nodes) > 0 {
419                 t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p)
420         }
421 }
422
423 func (t *udp) sendPacket(toid NodeID, toaddr *net.UDPAddr, ptype byte, req interface{}) (hash []byte, err error) {
424         packet, hash, err := encodePacket(t.priv, ptype, req)
425         if err != nil {
426                 return hash, err
427         }
428         log.WithFields(log.Fields{"module": logModule, "event": nodeEvent(ptype), "to id": hex.EncodeToString(toid[:8]), "to addr": toaddr}).Debug("send packet")
429         if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
430                 log.WithFields(log.Fields{"module": logModule, "error": err}).Info(fmt.Sprint("UDP send failed"))
431         }
432         return hash, err
433 }
434
435 // zeroed padding space for encodePacket.
436 var headSpace = make([]byte, headSize)
437
438 func encodePacket(priv ed25519.PrivateKey, ptype byte, req interface{}) (p, hash []byte, err error) {
439         b := new(bytes.Buffer)
440         b.Write(headSpace)
441         b.WriteByte(ptype)
442         var size int
443         wire.WriteJSON(req, b, &size, &err)
444         if err != nil {
445                 log.WithFields(log.Fields{"module": logModule, "error": err}).Error("error encoding packet")
446                 return nil, nil, err
447         }
448         packet := b.Bytes()
449         nodeID := priv.Public()
450         sig := ed25519.Sign(priv, common.BytesToHash(packet[headSize:]).Bytes())
451         copy(packet, versionPrefix)
452         copy(packet[versionPrefixSize:], nodeID[:])
453         copy(packet[versionPrefixSize+nodeIDSize:], sig)
454
455         hash = common.BytesToHash(packet[versionPrefixSize:]).Bytes()
456         return packet, hash, nil
457 }
458
459 // readLoop runs in its own goroutine. it injects ingress UDP packets
460 // into the network loop.
461 func (t *udp) readLoop() {
462         defer t.conn.Close()
463         // Discovery packets are defined to be no larger than 1280 bytes.
464         // Packets larger than this size will be cut at the end and treated
465         // as invalid because their hash won't match.
466         buf := make([]byte, 1280)
467         for {
468                 nbytes, from, err := t.conn.ReadFromUDP(buf)
469                 if netutil.IsTemporaryError(err) {
470                         // Ignore temporary read errors.
471                         log.WithFields(log.Fields{"module": logModule, "error": err}).Debug("Temporary read error")
472                         continue
473                 } else if err != nil {
474                         // Shut down the loop for permament errors.
475                         log.WithFields(log.Fields{"module": logModule, "error": err}).Debug("Read error")
476                         return
477                 }
478                 t.handlePacket(from, buf[:nbytes])
479         }
480 }
481
482 func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error {
483         pkt := ingressPacket{remoteAddr: from}
484         if err := decodePacket(buf, &pkt); err != nil {
485                 log.WithFields(log.Fields{"module": logModule, "from": from, "error": err}).Error("Bad packet")
486                 return err
487         }
488         t.net.reqReadPacket(pkt)
489         return nil
490 }
491
492 func decodePacket(buffer []byte, pkt *ingressPacket) error {
493         if len(buffer) < headSize+1 {
494                 return errPacketTooSmall
495         }
496         buf := make([]byte, len(buffer))
497         copy(buf, buffer)
498         prefix, fromID, sigdata := buf[:versionPrefixSize], buf[versionPrefixSize:versionPrefixSize+nodeIDSize], buf[headSize:]
499         if !bytes.Equal(prefix, versionPrefix) {
500                 return errBadPrefix
501         }
502         pkt.rawData = buf
503         pkt.hash = common.BytesToHash(buf[versionPrefixSize:]).Bytes()
504         pkt.remoteID = ByteID(fromID)
505         switch pkt.ev = nodeEvent(sigdata[0]); pkt.ev {
506         case pingPacket:
507                 pkt.data = new(ping)
508         case pongPacket:
509                 pkt.data = new(pong)
510         case findnodePacket:
511                 pkt.data = new(findnode)
512         case neighborsPacket:
513                 pkt.data = new(neighbors)
514         case findnodeHashPacket:
515                 pkt.data = new(findnodeHash)
516         case topicRegisterPacket:
517                 pkt.data = new(topicRegister)
518         case topicQueryPacket:
519                 pkt.data = new(topicQuery)
520         case topicNodesPacket:
521                 pkt.data = new(topicNodes)
522         default:
523                 return fmt.Errorf("unknown packet type: %d", sigdata[0])
524         }
525         var err error
526         wire.ReadJSON(pkt.data, sigdata[1:], &err)
527         if err != nil {
528                 log.WithFields(log.Fields{"module": logModule, "error": err}).Error("wire readjson err")
529         }
530
531         return err
532 }