OSDN Git Service

refactor: split ensureOutboundPeers (#1662)
[bytom/bytom.git] / p2p / switch.go
1 package p2p
2
3 import (
4         "encoding/hex"
5         "encoding/json"
6         "fmt"
7         "net"
8         "sync"
9         "time"
10
11         log "github.com/sirupsen/logrus"
12         "github.com/tendermint/go-crypto"
13         cmn "github.com/tendermint/tmlibs/common"
14         dbm "github.com/tendermint/tmlibs/db"
15
16         cfg "github.com/bytom/config"
17         "github.com/bytom/consensus"
18         "github.com/bytom/crypto/ed25519"
19         "github.com/bytom/errors"
20         "github.com/bytom/p2p/connection"
21         "github.com/bytom/p2p/discover"
22         "github.com/bytom/p2p/netutil"
23         "github.com/bytom/p2p/trust"
24         "github.com/bytom/version"
25 )
26
27 const (
28         bannedPeerKey      = "BannedPeer"
29         defaultBanDuration = time.Hour * 1
30         logModule          = "p2p"
31
32         minNumOutboundPeers = 4
33 )
34
35 //pre-define errors for connecting fail
36 var (
37         ErrDuplicatePeer     = errors.New("Duplicate peer")
38         ErrConnectSelf       = errors.New("Connect self")
39         ErrConnectBannedPeer = errors.New("Connect banned peer")
40         ErrConnectSpvPeer    = errors.New("Outbound connect spv peer")
41 )
42
43 type discv interface {
44         ReadRandomNodes(buf []*discover.Node) (n int)
45 }
46
47 // Switch handles peer connections and exposes an API to receive incoming messages
48 // on `Reactors`.  Each `Reactor` is responsible for handling incoming messages of one
49 // or more `Channels`.  So while sending outgoing messages is typically performed on the peer,
50 // incoming messages are received on the reactor.
51 type Switch struct {
52         cmn.BaseService
53
54         Config       *cfg.Config
55         peerConfig   *PeerConfig
56         listeners    []Listener
57         reactors     map[string]Reactor
58         chDescs      []*connection.ChannelDescriptor
59         reactorsByCh map[byte]Reactor
60         peers        *PeerSet
61         dialing      *cmn.CMap
62         nodeInfo     *NodeInfo             // our node info
63         nodePrivKey  crypto.PrivKeyEd25519 // our node privkey
64         discv        discv
65         bannedPeer   map[string]time.Time
66         db           dbm.DB
67         mtx          sync.Mutex
68 }
69
70 // NewSwitch create a new Switch and set discover.
71 func NewSwitch(config *cfg.Config) (*Switch, error) {
72         var err error
73         var l Listener
74         var listenAddr string
75         var discv *discover.Network
76
77         blacklistDB := dbm.NewDB("trusthistory", config.DBBackend, config.DBDir())
78         config.P2P.PrivateKey, err = config.NodeKey()
79         if err != nil {
80                 return nil, err
81         }
82
83         bytes, err := hex.DecodeString(config.P2P.PrivateKey)
84         if err != nil {
85                 return nil, err
86         }
87
88         var newKey [64]byte
89         copy(newKey[:], bytes)
90         privKey := crypto.PrivKeyEd25519(newKey)
91         if !config.VaultMode {
92                 // Create listener
93                 l, listenAddr = GetListener(config.P2P)
94                 discv, err = discover.NewDiscover(config, ed25519.PrivateKey(bytes), l.ExternalAddress().Port)
95                 if err != nil {
96                         return nil, err
97                 }
98         }
99
100         return newSwitch(config, discv, blacklistDB, l, privKey, listenAddr)
101 }
102
103 // newSwitch creates a new Switch with the given config.
104 func newSwitch(config *cfg.Config, discv discv, blacklistDB dbm.DB, l Listener, priv crypto.PrivKeyEd25519, listenAddr string) (*Switch, error) {
105         sw := &Switch{
106                 Config:       config,
107                 peerConfig:   DefaultPeerConfig(config.P2P),
108                 reactors:     make(map[string]Reactor),
109                 chDescs:      make([]*connection.ChannelDescriptor, 0),
110                 reactorsByCh: make(map[byte]Reactor),
111                 peers:        NewPeerSet(),
112                 dialing:      cmn.NewCMap(),
113                 nodePrivKey:  priv,
114                 discv:        discv,
115                 db:           blacklistDB,
116                 nodeInfo:     NewNodeInfo(config, priv.PubKey().Unwrap().(crypto.PubKeyEd25519), listenAddr),
117                 bannedPeer:   make(map[string]time.Time),
118         }
119         if err := sw.loadBannedPeers(); err != nil {
120                 return nil, err
121         }
122
123         sw.AddListener(l)
124         sw.BaseService = *cmn.NewBaseService(nil, "P2P Switch", sw)
125         trust.Init()
126         return sw, nil
127 }
128
129 // OnStart implements BaseService. It starts all the reactors, peers, and listeners.
130 func (sw *Switch) OnStart() error {
131         for _, reactor := range sw.reactors {
132                 if _, err := reactor.Start(); err != nil {
133                         return err
134                 }
135         }
136         for _, listener := range sw.listeners {
137                 go sw.listenerRoutine(listener)
138         }
139         go sw.ensureOutboundPeersRoutine()
140         return nil
141 }
142
143 // OnStop implements BaseService. It stops all listeners, peers, and reactors.
144 func (sw *Switch) OnStop() {
145         for _, listener := range sw.listeners {
146                 listener.Stop()
147         }
148         sw.listeners = nil
149
150         for _, peer := range sw.peers.List() {
151                 peer.Stop()
152                 sw.peers.Remove(peer)
153         }
154
155         for _, reactor := range sw.reactors {
156                 reactor.Stop()
157         }
158 }
159
160 //AddBannedPeer add peer to blacklist
161 func (sw *Switch) AddBannedPeer(ip string) error {
162         sw.mtx.Lock()
163         defer sw.mtx.Unlock()
164
165         sw.bannedPeer[ip] = time.Now().Add(defaultBanDuration)
166         dataJSON, err := json.Marshal(sw.bannedPeer)
167         if err != nil {
168                 return err
169         }
170
171         sw.db.Set([]byte(bannedPeerKey), dataJSON)
172         return nil
173 }
174
175 // AddPeer performs the P2P handshake with a peer
176 // that already has a SecretConnection. If all goes well,
177 // it starts the peer and adds it to the switch.
178 // NOTE: This performs a blocking handshake before the peer is added.
179 // CONTRACT: If error is returned, peer is nil, and conn is immediately closed.
180 func (sw *Switch) AddPeer(pc *peerConn) error {
181         peerNodeInfo, err := pc.HandshakeTimeout(sw.nodeInfo, sw.peerConfig.HandshakeTimeout)
182         if err != nil {
183                 return err
184         }
185
186         if err := version.Status.CheckUpdate(sw.nodeInfo.Version, peerNodeInfo.Version, peerNodeInfo.RemoteAddr); err != nil {
187                 return err
188         }
189         if err := sw.nodeInfo.CompatibleWith(peerNodeInfo); err != nil {
190                 return err
191         }
192
193         peer := newPeer(pc, peerNodeInfo, sw.reactorsByCh, sw.chDescs, sw.StopPeerForError)
194         if err := sw.filterConnByPeer(peer); err != nil {
195                 return err
196         }
197
198         if pc.outbound && !peer.ServiceFlag().IsEnable(consensus.SFFullNode) {
199                 return ErrConnectSpvPeer
200         }
201
202         // Start peer
203         if sw.IsRunning() {
204                 if err := sw.startInitPeer(peer); err != nil {
205                         return err
206                 }
207         }
208
209         return sw.peers.Add(peer)
210 }
211
212 // AddReactor adds the given reactor to the switch.
213 // NOTE: Not goroutine safe.
214 func (sw *Switch) AddReactor(name string, reactor Reactor) Reactor {
215         // Validate the reactor.
216         // No two reactors can share the same channel.
217         for _, chDesc := range reactor.GetChannels() {
218                 chID := chDesc.ID
219                 if sw.reactorsByCh[chID] != nil {
220                         cmn.PanicSanity(fmt.Sprintf("Channel %X has multiple reactors %v & %v", chID, sw.reactorsByCh[chID], reactor))
221                 }
222                 sw.chDescs = append(sw.chDescs, chDesc)
223                 sw.reactorsByCh[chID] = reactor
224         }
225         sw.reactors[name] = reactor
226         reactor.SetSwitch(sw)
227         return reactor
228 }
229
230 // AddListener adds the given listener to the switch for listening to incoming peer connections.
231 // NOTE: Not goroutine safe.
232 func (sw *Switch) AddListener(l Listener) {
233         sw.listeners = append(sw.listeners, l)
234 }
235
236 //DialPeerWithAddress dial node from net address
237 func (sw *Switch) DialPeerWithAddress(addr *NetAddress) error {
238         log.WithFields(log.Fields{"module": logModule, "address": addr}).Debug("Dialing peer")
239         sw.dialing.Set(addr.IP.String(), addr)
240         defer sw.dialing.Delete(addr.IP.String())
241         if err := sw.filterConnByIP(addr.IP.String()); err != nil {
242                 return err
243         }
244
245         pc, err := newOutboundPeerConn(addr, sw.nodePrivKey, sw.peerConfig)
246         if err != nil {
247                 log.WithFields(log.Fields{"module": logModule, "address": addr, " err": err}).Error("DialPeer fail on newOutboundPeerConn")
248                 return err
249         }
250
251         if err = sw.AddPeer(pc); err != nil {
252                 log.WithFields(log.Fields{"module": logModule, "address": addr, " err": err}).Error("DialPeer fail on switch AddPeer")
253                 pc.CloseConn()
254                 return err
255         }
256         log.WithFields(log.Fields{"module": logModule, "address": addr, "peer num": sw.peers.Size()}).Debug("DialPeer added peer")
257         return nil
258 }
259
260 //IsDialing prevent duplicate dialing
261 func (sw *Switch) IsDialing(addr *NetAddress) bool {
262         return sw.dialing.Has(addr.IP.String())
263 }
264
265 // IsListening returns true if the switch has at least one listener.
266 // NOTE: Not goroutine safe.
267 func (sw *Switch) IsListening() bool {
268         return len(sw.listeners) > 0
269 }
270
271 // loadBannedPeers load banned peers from db
272 func (sw *Switch) loadBannedPeers() error {
273         if dataJSON := sw.db.Get([]byte(bannedPeerKey)); dataJSON != nil {
274                 if err := json.Unmarshal(dataJSON, &sw.bannedPeer); err != nil {
275                         return err
276                 }
277         }
278
279         return nil
280 }
281
282 // Listeners returns the list of listeners the switch listens on.
283 // NOTE: Not goroutine safe.
284 func (sw *Switch) Listeners() []Listener {
285         return sw.listeners
286 }
287
288 // NumPeers Returns the count of outbound/inbound and outbound-dialing peers.
289 func (sw *Switch) NumPeers() (outbound, inbound, dialing int) {
290         peers := sw.peers.List()
291         for _, peer := range peers {
292                 if peer.outbound {
293                         outbound++
294                 } else {
295                         inbound++
296                 }
297         }
298         dialing = sw.dialing.Size()
299         return
300 }
301
302 // NodeInfo returns the switch's NodeInfo.
303 // NOTE: Not goroutine safe.
304 func (sw *Switch) NodeInfo() *NodeInfo {
305         return sw.nodeInfo
306 }
307
308 //Peers return switch peerset
309 func (sw *Switch) Peers() *PeerSet {
310         return sw.peers
311 }
312
313 // StopPeerForError disconnects from a peer due to external error.
314 func (sw *Switch) StopPeerForError(peer *Peer, reason interface{}) {
315         log.WithFields(log.Fields{"module": logModule, "peer": peer, " err": reason}).Debug("stopping peer for error")
316         sw.stopAndRemovePeer(peer, reason)
317 }
318
319 // StopPeerGracefully disconnect from a peer gracefully.
320 func (sw *Switch) StopPeerGracefully(peerID string) {
321         if peer := sw.peers.Get(peerID); peer != nil {
322                 sw.stopAndRemovePeer(peer, nil)
323         }
324 }
325
326 func (sw *Switch) addPeerWithConnection(conn net.Conn) error {
327         peerConn, err := newInboundPeerConn(conn, sw.nodePrivKey, sw.Config.P2P)
328         if err != nil {
329                 if err := conn.Close(); err != nil {
330                         log.WithFields(log.Fields{"module": logModule, "remote peer:": conn.RemoteAddr().String(), " err:": err}).Error("closes connection err")
331                 }
332                 return err
333         }
334
335         if err = sw.AddPeer(peerConn); err != nil {
336                 if err := conn.Close(); err != nil {
337                         log.WithFields(log.Fields{"module": logModule, "remote peer:": conn.RemoteAddr().String(), " err:": err}).Error("closes connection err")
338                 }
339                 return err
340         }
341
342         log.WithFields(log.Fields{"module": logModule, "address": conn.RemoteAddr().String(), "peer num": sw.peers.Size()}).Debug("add inbound peer")
343         return nil
344 }
345
346 func (sw *Switch) checkBannedPeer(peer string) error {
347         sw.mtx.Lock()
348         defer sw.mtx.Unlock()
349
350         if banEnd, ok := sw.bannedPeer[peer]; ok {
351                 if time.Now().Before(banEnd) {
352                         return ErrConnectBannedPeer
353                 }
354
355                 if err := sw.delBannedPeer(peer); err != nil {
356                         return err
357                 }
358         }
359         return nil
360 }
361
362 func (sw *Switch) delBannedPeer(addr string) error {
363         sw.mtx.Lock()
364         defer sw.mtx.Unlock()
365
366         delete(sw.bannedPeer, addr)
367         datajson, err := json.Marshal(sw.bannedPeer)
368         if err != nil {
369                 return err
370         }
371
372         sw.db.Set([]byte(bannedPeerKey), datajson)
373         return nil
374 }
375
376 func (sw *Switch) filterConnByIP(ip string) error {
377         if ip == sw.nodeInfo.listenHost() {
378                 return ErrConnectSelf
379         }
380         return sw.checkBannedPeer(ip)
381 }
382
383 func (sw *Switch) filterConnByPeer(peer *Peer) error {
384         if err := sw.checkBannedPeer(peer.remoteAddrHost()); err != nil {
385                 return err
386         }
387
388         if sw.nodeInfo.getPubkey().Equals(peer.PubKey().Wrap()) {
389                 return ErrConnectSelf
390         }
391
392         if sw.peers.Has(peer.Key) {
393                 return ErrDuplicatePeer
394         }
395         return nil
396 }
397
398 func (sw *Switch) listenerRoutine(l Listener) {
399         for {
400                 inConn, ok := <-l.Connections()
401                 if !ok {
402                         break
403                 }
404
405                 // disconnect if we alrady have MaxNumPeers
406                 if sw.peers.Size() >= sw.Config.P2P.MaxNumPeers {
407                         if err := inConn.Close(); err != nil {
408                                 log.WithFields(log.Fields{"module": logModule, "remote peer:": inConn.RemoteAddr().String(), " err:": err}).Error("closes connection err")
409                         }
410                         log.Info("Ignoring inbound connection: already have enough peers.")
411                         continue
412                 }
413
414                 // New inbound connection!
415                 if err := sw.addPeerWithConnection(inConn); err != nil {
416                         log.Info("Ignoring inbound connection: error while adding peer.", " address:", inConn.RemoteAddr().String(), " error:", err)
417                         continue
418                 }
419         }
420 }
421
422 func (sw *Switch) dialPeerWorker(a *NetAddress, wg *sync.WaitGroup) {
423         if err := sw.DialPeerWithAddress(a); err != nil {
424                 log.WithFields(log.Fields{"module": logModule, "addr": a, "err": err}).Error("dialPeerWorker fail on dial peer")
425         }
426         wg.Done()
427 }
428
429 func (sw *Switch) ensureKeepConnectPeers() {
430         keepDials := netutil.CheckAndSplitAddresses(sw.Config.P2P.KeepDial)
431         connectedPeers := make(map[string]struct{})
432         for _, peer := range sw.Peers().List() {
433                 connectedPeers[peer.remoteAddrHost()] = struct{}{}
434         }
435
436         var wg sync.WaitGroup
437         for _, keepDial := range keepDials {
438                 try, err := NewNetAddressString(keepDial)
439                 if err != nil {
440                         log.WithFields(log.Fields{"module": logModule, "err": err, "address": keepDial}).Warn("parse address to NetAddress")
441                         continue
442                 }
443
444                 if sw.NodeInfo().ListenAddr == try.String() {
445                         continue
446                 }
447                 if dialling := sw.IsDialing(try); dialling {
448                         continue
449                 }
450                 if _, ok := connectedPeers[try.IP.String()]; ok {
451                         continue
452                 }
453
454                 wg.Add(1)
455                 go sw.dialPeerWorker(try, &wg)
456         }
457         wg.Wait()
458 }
459
460 func (sw *Switch) ensureOutboundPeers() {
461         numOutPeers, _, numDialing := sw.NumPeers()
462         numToDial := (minNumOutboundPeers - (numOutPeers + numDialing))
463         log.WithFields(log.Fields{"module": logModule, "numOutPeers": numOutPeers, "numDialing": numDialing, "numToDial": numToDial}).Debug("ensure peers")
464         if numToDial <= 0 {
465                 return
466         }
467
468         connectedPeers := make(map[string]struct{})
469         for _, peer := range sw.Peers().List() {
470                 connectedPeers[peer.remoteAddrHost()] = struct{}{}
471         }
472
473         var wg sync.WaitGroup
474         nodes := make([]*discover.Node, numToDial)
475         n := sw.discv.ReadRandomNodes(nodes)
476         for i := 0; i < n; i++ {
477                 try := NewNetAddressIPPort(nodes[i].IP, nodes[i].TCP)
478                 if sw.NodeInfo().ListenAddr == try.String() {
479                         continue
480                 }
481                 if dialling := sw.IsDialing(try); dialling {
482                         continue
483                 }
484                 if _, ok := connectedPeers[try.IP.String()]; ok {
485                         continue
486                 }
487
488                 wg.Add(1)
489                 go sw.dialPeerWorker(try, &wg)
490         }
491         wg.Wait()
492 }
493
494 func (sw *Switch) ensureOutboundPeersRoutine() {
495         sw.ensureKeepConnectPeers()
496         sw.ensureOutboundPeers()
497
498         ticker := time.NewTicker(10 * time.Second)
499         defer ticker.Stop()
500
501         for {
502                 select {
503                 case <-ticker.C:
504                         sw.ensureKeepConnectPeers()
505                         sw.ensureOutboundPeers()
506                 case <-sw.Quit:
507                         return
508                 }
509         }
510 }
511
512 func (sw *Switch) startInitPeer(peer *Peer) error {
513         // spawn send/recv routines
514         if _, err := peer.Start(); err != nil {
515                 log.WithFields(log.Fields{"module": logModule, "remote peer:": peer.RemoteAddr, " err:": err}).Error("init peer err")
516         }
517
518         for _, reactor := range sw.reactors {
519                 if err := reactor.AddPeer(peer); err != nil {
520                         return err
521                 }
522         }
523         return nil
524 }
525
526 func (sw *Switch) stopAndRemovePeer(peer *Peer, reason interface{}) {
527         sw.peers.Remove(peer)
528         for _, reactor := range sw.reactors {
529                 reactor.RemovePeer(peer, reason)
530         }
531         peer.Stop()
532
533         sentStatus, receivedStatus := peer.TrafficStatus()
534         log.WithFields(log.Fields{
535                 "module":                logModule,
536                 "address":               peer.Addr().String(),
537                 "reason":                reason,
538                 "duration":              sentStatus.Duration.String(),
539                 "total_sent":            sentStatus.Bytes,
540                 "total_received":        receivedStatus.Bytes,
541                 "average_sent_rate":     sentStatus.AvgRate,
542                 "average_received_rate": receivedStatus.AvgRate,
543                 "peer num":              sw.peers.Size(),
544         }).Info("disconnect with peer")
545 }