OSDN Git Service

p2p: refactor switch code and add test (#1538)
[bytom/bytom.git] / p2p / switch.go
1 package p2p
2
3 import (
4         "encoding/json"
5         "fmt"
6         "net"
7         "sync"
8         "time"
9
10         log "github.com/sirupsen/logrus"
11         "github.com/tendermint/go-crypto"
12         cmn "github.com/tendermint/tmlibs/common"
13         dbm "github.com/tendermint/tmlibs/db"
14
15         cfg "github.com/bytom/config"
16         "github.com/bytom/consensus"
17         "github.com/bytom/errors"
18         "github.com/bytom/p2p/connection"
19         "github.com/bytom/p2p/discover"
20         "github.com/bytom/p2p/trust"
21         "github.com/bytom/version"
22 )
23
24 const (
25         bannedPeerKey       = "BannedPeer"
26         defaultBanDuration  = time.Hour * 1
27         minNumOutboundPeers = 3
28 )
29
30 //pre-define errors for connecting fail
31 var (
32         ErrDuplicatePeer     = errors.New("Duplicate peer")
33         ErrConnectSelf       = errors.New("Connect self")
34         ErrConnectBannedPeer = errors.New("Connect banned peer")
35         ErrConnectSpvPeer    = errors.New("Outbound connect spv peer")
36 )
37
38 type discv interface {
39         ReadRandomNodes(buf []*discover.Node) (n int)
40 }
41
42 // Switch handles peer connections and exposes an API to receive incoming messages
43 // on `Reactors`.  Each `Reactor` is responsible for handling incoming messages of one
44 // or more `Channels`.  So while sending outgoing messages is typically performed on the peer,
45 // incoming messages are received on the reactor.
46 type Switch struct {
47         cmn.BaseService
48
49         Config       *cfg.Config
50         peerConfig   *PeerConfig
51         listeners    []Listener
52         reactors     map[string]Reactor
53         chDescs      []*connection.ChannelDescriptor
54         reactorsByCh map[byte]Reactor
55         peers        *PeerSet
56         dialing      *cmn.CMap
57         nodeInfo     *NodeInfo             // our node info
58         nodePrivKey  crypto.PrivKeyEd25519 // our node privkey
59         discv        discv
60         bannedPeer   map[string]time.Time
61         db           dbm.DB
62         mtx          sync.Mutex
63 }
64
65 // NewSwitch create a new Switch and set discover.
66 func NewSwitch(config *cfg.Config) (*Switch, error) {
67         blacklistDB := dbm.NewDB("trusthistory", config.DBBackend, config.DBDir())
68         privKey := crypto.GenPrivKeyEd25519()
69         var l Listener
70         var listenAddr string
71         var err error
72         var discv *discover.Network
73         if !config.VaultMode {
74                 // Create listener
75                 l, listenAddr = GetListener(config.P2P)
76                 discv, err = discover.NewDiscover(config, &privKey, l.ExternalAddress().Port)
77                 if err != nil {
78                         return nil, err
79                 }
80         }
81
82         return newSwitch(discv, blacklistDB, l, config, privKey, listenAddr)
83 }
84
85 // newSwitch creates a new Switch with the given config.
86 func newSwitch(discv discv, blacklistDB dbm.DB, l Listener, config *cfg.Config, priv crypto.PrivKeyEd25519, listenAddr string) (*Switch, error) {
87         sw := &Switch{
88                 Config:       config,
89                 peerConfig:   DefaultPeerConfig(config.P2P),
90                 reactors:     make(map[string]Reactor),
91                 chDescs:      make([]*connection.ChannelDescriptor, 0),
92                 reactorsByCh: make(map[byte]Reactor),
93                 peers:        NewPeerSet(),
94                 dialing:      cmn.NewCMap(),
95                 nodePrivKey:  priv,
96                 discv:        discv,
97                 db:           blacklistDB,
98                 nodeInfo:     NewNodeInfo(config, priv.PubKey().Unwrap().(crypto.PubKeyEd25519), listenAddr),
99                 bannedPeer:   make(map[string]time.Time),
100         }
101         if err := sw.loadBannedPeers(); err != nil {
102                 return nil, err
103         }
104
105         sw.AddListener(l)
106         sw.BaseService = *cmn.NewBaseService(nil, "P2P Switch", sw)
107         trust.Init()
108         return sw, nil
109 }
110
111 // OnStart implements BaseService. It starts all the reactors, peers, and listeners.
112 func (sw *Switch) OnStart() error {
113         for _, reactor := range sw.reactors {
114                 if _, err := reactor.Start(); err != nil {
115                         return err
116                 }
117         }
118         for _, listener := range sw.listeners {
119                 go sw.listenerRoutine(listener)
120         }
121         go sw.ensureOutboundPeersRoutine()
122         return nil
123 }
124
125 // OnStop implements BaseService. It stops all listeners, peers, and reactors.
126 func (sw *Switch) OnStop() {
127         for _, listener := range sw.listeners {
128                 listener.Stop()
129         }
130         sw.listeners = nil
131
132         for _, peer := range sw.peers.List() {
133                 peer.Stop()
134                 sw.peers.Remove(peer)
135         }
136
137         for _, reactor := range sw.reactors {
138                 reactor.Stop()
139         }
140 }
141
142 //AddBannedPeer add peer to blacklist
143 func (sw *Switch) AddBannedPeer(ip string) error {
144         sw.mtx.Lock()
145         defer sw.mtx.Unlock()
146
147         sw.bannedPeer[ip] = time.Now().Add(defaultBanDuration)
148         dataJSON, err := json.Marshal(sw.bannedPeer)
149         if err != nil {
150                 return err
151         }
152
153         sw.db.Set([]byte(bannedPeerKey), dataJSON)
154         return nil
155 }
156
157 // AddPeer performs the P2P handshake with a peer
158 // that already has a SecretConnection. If all goes well,
159 // it starts the peer and adds it to the switch.
160 // NOTE: This performs a blocking handshake before the peer is added.
161 // CONTRACT: If error is returned, peer is nil, and conn is immediately closed.
162 func (sw *Switch) AddPeer(pc *peerConn) error {
163         peerNodeInfo, err := pc.HandshakeTimeout(sw.nodeInfo, sw.peerConfig.HandshakeTimeout)
164         if err != nil {
165                 return err
166         }
167
168         if err := version.Status.CheckUpdate(sw.nodeInfo.Version, peerNodeInfo.Version, peerNodeInfo.RemoteAddr); err != nil {
169                 return err
170         }
171         if err := sw.nodeInfo.CompatibleWith(peerNodeInfo); err != nil {
172                 return err
173         }
174
175         peer := newPeer(pc, peerNodeInfo, sw.reactorsByCh, sw.chDescs, sw.StopPeerForError)
176         if err := sw.filterConnByPeer(peer); err != nil {
177                 return err
178         }
179
180         if pc.outbound && !peer.ServiceFlag().IsEnable(consensus.SFFullNode) {
181                 return ErrConnectSpvPeer
182         }
183
184         // Start peer
185         if sw.IsRunning() {
186                 if err := sw.startInitPeer(peer); err != nil {
187                         return err
188                 }
189         }
190
191         return sw.peers.Add(peer)
192 }
193
194 // AddReactor adds the given reactor to the switch.
195 // NOTE: Not goroutine safe.
196 func (sw *Switch) AddReactor(name string, reactor Reactor) Reactor {
197         // Validate the reactor.
198         // No two reactors can share the same channel.
199         for _, chDesc := range reactor.GetChannels() {
200                 chID := chDesc.ID
201                 if sw.reactorsByCh[chID] != nil {
202                         cmn.PanicSanity(fmt.Sprintf("Channel %X has multiple reactors %v & %v", chID, sw.reactorsByCh[chID], reactor))
203                 }
204                 sw.chDescs = append(sw.chDescs, chDesc)
205                 sw.reactorsByCh[chID] = reactor
206         }
207         sw.reactors[name] = reactor
208         reactor.SetSwitch(sw)
209         return reactor
210 }
211
212 // AddListener adds the given listener to the switch for listening to incoming peer connections.
213 // NOTE: Not goroutine safe.
214 func (sw *Switch) AddListener(l Listener) {
215         sw.listeners = append(sw.listeners, l)
216 }
217
218 //DialPeerWithAddress dial node from net address
219 func (sw *Switch) DialPeerWithAddress(addr *NetAddress) error {
220         log.Debug("Dialing peer address:", addr)
221         sw.dialing.Set(addr.IP.String(), addr)
222         defer sw.dialing.Delete(addr.IP.String())
223         if err := sw.filterConnByIP(addr.IP.String()); err != nil {
224                 return err
225         }
226
227         pc, err := newOutboundPeerConn(addr, sw.nodePrivKey, sw.peerConfig)
228         if err != nil {
229                 log.WithFields(log.Fields{"address": addr, " err": err}).Error("DialPeer fail on newOutboundPeerConn")
230                 return err
231         }
232
233         if err = sw.AddPeer(pc); err != nil {
234                 log.WithFields(log.Fields{"address": addr, " err": err}).Error("DialPeer fail on switch AddPeer")
235                 pc.CloseConn()
236                 return err
237         }
238         log.Debug("DialPeer added peer:", addr)
239         return nil
240 }
241
242 //IsDialing prevent duplicate dialing
243 func (sw *Switch) IsDialing(addr *NetAddress) bool {
244         return sw.dialing.Has(addr.IP.String())
245 }
246
247 // IsListening returns true if the switch has at least one listener.
248 // NOTE: Not goroutine safe.
249 func (sw *Switch) IsListening() bool {
250         return len(sw.listeners) > 0
251 }
252
253 // loadBannedPeers load banned peers from db
254 func (sw *Switch) loadBannedPeers() error {
255         if dataJSON := sw.db.Get([]byte(bannedPeerKey)); dataJSON != nil {
256                 if err := json.Unmarshal(dataJSON, &sw.bannedPeer); err != nil {
257                         return err
258                 }
259         }
260
261         return nil
262 }
263
264 // Listeners returns the list of listeners the switch listens on.
265 // NOTE: Not goroutine safe.
266 func (sw *Switch) Listeners() []Listener {
267         return sw.listeners
268 }
269
270 // NumPeers Returns the count of outbound/inbound and outbound-dialing peers.
271 func (sw *Switch) NumPeers() (outbound, inbound, dialing int) {
272         peers := sw.peers.List()
273         for _, peer := range peers {
274                 if peer.outbound {
275                         outbound++
276                 } else {
277                         inbound++
278                 }
279         }
280         dialing = sw.dialing.Size()
281         return
282 }
283
284 // NodeInfo returns the switch's NodeInfo.
285 // NOTE: Not goroutine safe.
286 func (sw *Switch) NodeInfo() *NodeInfo {
287         return sw.nodeInfo
288 }
289
290 //Peers return switch peerset
291 func (sw *Switch) Peers() *PeerSet {
292         return sw.peers
293 }
294
295 // StopPeerForError disconnects from a peer due to external error.
296 func (sw *Switch) StopPeerForError(peer *Peer, reason interface{}) {
297         log.WithFields(log.Fields{"peer": peer, " err": reason}).Debug("stopping peer for error")
298         sw.stopAndRemovePeer(peer, reason)
299 }
300
301 // StopPeerGracefully disconnect from a peer gracefully.
302 func (sw *Switch) StopPeerGracefully(peerID string) {
303         if peer := sw.peers.Get(peerID); peer != nil {
304                 sw.stopAndRemovePeer(peer, nil)
305         }
306 }
307
308 func (sw *Switch) addPeerWithConnection(conn net.Conn) error {
309         peerConn, err := newInboundPeerConn(conn, sw.nodePrivKey, sw.Config.P2P)
310         if err != nil {
311                 if err := conn.Close(); err != nil {
312                         log.WithFields(log.Fields{"remote peer:": conn.RemoteAddr().String(), " err:": err}).Error("closes connection err")
313                 }
314                 return err
315         }
316
317         if err = sw.AddPeer(peerConn); err != nil {
318                 if err := conn.Close(); err != nil {
319                         log.WithFields(log.Fields{"remote peer:": conn.RemoteAddr().String(), " err:": err}).Error("closes connection err")
320                 }
321                 return err
322         }
323
324         return nil
325 }
326
327 func (sw *Switch) checkBannedPeer(peer string) error {
328         sw.mtx.Lock()
329         defer sw.mtx.Unlock()
330
331         if banEnd, ok := sw.bannedPeer[peer]; ok {
332                 if time.Now().Before(banEnd) {
333                         return ErrConnectBannedPeer
334                 }
335
336                 if err := sw.delBannedPeer(peer); err != nil {
337                         return err
338                 }
339         }
340         return nil
341 }
342
343 func (sw *Switch) delBannedPeer(addr string) error {
344         sw.mtx.Lock()
345         defer sw.mtx.Unlock()
346
347         delete(sw.bannedPeer, addr)
348         datajson, err := json.Marshal(sw.bannedPeer)
349         if err != nil {
350                 return err
351         }
352
353         sw.db.Set([]byte(bannedPeerKey), datajson)
354         return nil
355 }
356
357 func (sw *Switch) filterConnByIP(ip string) error {
358         if ip == sw.nodeInfo.listenHost() {
359                 return ErrConnectSelf
360         }
361         return sw.checkBannedPeer(ip)
362 }
363
364 func (sw *Switch) filterConnByPeer(peer *Peer) error {
365         if err := sw.checkBannedPeer(peer.remoteAddrHost()); err != nil {
366                 return err
367         }
368
369         if sw.nodeInfo.getPubkey().Equals(peer.PubKey().Wrap()) {
370                 return ErrConnectSelf
371         }
372
373         if sw.peers.Has(peer.Key) {
374                 return ErrDuplicatePeer
375         }
376         return nil
377 }
378
379 func (sw *Switch) listenerRoutine(l Listener) {
380         for {
381                 inConn, ok := <-l.Connections()
382                 if !ok {
383                         break
384                 }
385
386                 // disconnect if we alrady have MaxNumPeers
387                 if sw.peers.Size() >= sw.Config.P2P.MaxNumPeers {
388                         if err := inConn.Close(); err != nil {
389                                 log.WithFields(log.Fields{"remote peer:": inConn.RemoteAddr().String(), " err:": err}).Error("closes connection err")
390                         }
391                         log.Info("Ignoring inbound connection: already have enough peers.")
392                         continue
393                 }
394
395                 // New inbound connection!
396                 if err := sw.addPeerWithConnection(inConn); err != nil {
397                         log.Info("Ignoring inbound connection: error while adding peer.", " address:", inConn.RemoteAddr().String(), " error:", err)
398                         continue
399                 }
400         }
401 }
402
403 func (sw *Switch) dialPeerWorker(a *NetAddress, wg *sync.WaitGroup) {
404         if err := sw.DialPeerWithAddress(a); err != nil {
405                 log.WithFields(log.Fields{"addr": a, "err": err}).Error("dialPeerWorker fail on dial peer")
406         }
407         wg.Done()
408 }
409
410 func (sw *Switch) ensureOutboundPeers() {
411         numOutPeers, _, numDialing := sw.NumPeers()
412         numToDial := (minNumOutboundPeers - (numOutPeers + numDialing))
413         log.WithFields(log.Fields{"numOutPeers": numOutPeers, "numDialing": numDialing, "numToDial": numToDial}).Debug("ensure peers")
414         if numToDial <= 0 {
415                 return
416         }
417
418         connectedPeers := make(map[string]struct{})
419         for _, peer := range sw.Peers().List() {
420                 connectedPeers[peer.remoteAddrHost()] = struct{}{}
421         }
422
423         var wg sync.WaitGroup
424         nodes := make([]*discover.Node, numToDial)
425         n := sw.discv.ReadRandomNodes(nodes)
426         for i := 0; i < n; i++ {
427                 try := NewNetAddressIPPort(nodes[i].IP, nodes[i].TCP)
428                 if sw.NodeInfo().ListenAddr == try.String() {
429                         continue
430                 }
431                 if dialling := sw.IsDialing(try); dialling {
432                         continue
433                 }
434                 if _, ok := connectedPeers[try.IP.String()]; ok {
435                         continue
436                 }
437
438                 wg.Add(1)
439                 go sw.dialPeerWorker(try, &wg)
440         }
441         wg.Wait()
442 }
443
444 func (sw *Switch) ensureOutboundPeersRoutine() {
445         sw.ensureOutboundPeers()
446
447         ticker := time.NewTicker(10 * time.Second)
448         defer ticker.Stop()
449
450         for {
451                 select {
452                 case <-ticker.C:
453                         sw.ensureOutboundPeers()
454                 case <-sw.Quit:
455                         return
456                 }
457         }
458 }
459
460 func (sw *Switch) startInitPeer(peer *Peer) error {
461         // spawn send/recv routines
462         if _, err := peer.Start(); err != nil {
463                 log.WithFields(log.Fields{"remote peer:": peer.RemoteAddr, " err:": err}).Error("init peer err")
464         }
465
466         for _, reactor := range sw.reactors {
467                 if err := reactor.AddPeer(peer); err != nil {
468                         return err
469                 }
470         }
471         return nil
472 }
473
474 func (sw *Switch) stopAndRemovePeer(peer *Peer, reason interface{}) {
475         sw.peers.Remove(peer)
476         for _, reactor := range sw.reactors {
477                 reactor.RemovePeer(peer, reason)
478         }
479         peer.Stop()
480
481         sentStatus, receivedStatus := peer.TrafficStatus()
482         log.WithFields(log.Fields{
483                 "address":               peer.Addr().String(),
484                 "reason":                reason,
485                 "duration":              sentStatus.Duration.String(),
486                 "total_sent":            sentStatus.Bytes,
487                 "total_received":        receivedStatus.Bytes,
488                 "average_sent_rate":     sentStatus.AvgRate,
489                 "average_received_rate": receivedStatus.AvgRate,
490         }).Info("disconnect with peer")
491 }