OSDN Git Service

P2P: fixed node startup id (#1573)
[bytom/bytom.git] / p2p / switch.go
index b3bffe2..ec4b73d 100644 (file)
@@ -1,6 +1,7 @@
 package p2p
 
 import (
+       "encoding/hex"
        "encoding/json"
        "fmt"
        "net"
@@ -35,6 +36,10 @@ var (
        ErrConnectSpvPeer    = errors.New("Outbound connect spv peer")
 )
 
+type discv interface {
+       ReadRandomNodes(buf []*discover.Node) (n int)
+}
+
 // Switch handles peer connections and exposes an API to receive incoming messages
 // on `Reactors`.  Each `Reactor` is responsible for handling incoming messages of one
 // or more `Channels`.  So while sending outgoing messages is typically performed on the peer,
@@ -52,14 +57,47 @@ type Switch struct {
        dialing      *cmn.CMap
        nodeInfo     *NodeInfo             // our node info
        nodePrivKey  crypto.PrivKeyEd25519 // our node privkey
-       discv        *discover.Network
+       discv        discv
        bannedPeer   map[string]time.Time
        db           dbm.DB
        mtx          sync.Mutex
 }
 
-// NewSwitch creates a new Switch with the given config.
-func NewSwitch(config *cfg.Config) *Switch {
+// NewSwitch create a new Switch and set discover.
+func NewSwitch(config *cfg.Config) (*Switch, error) {
+       var err error
+       var l Listener
+       var listenAddr string
+       var discv *discover.Network
+
+       blacklistDB := dbm.NewDB("trusthistory", config.DBBackend, config.DBDir())
+       config.P2P.PrivateKey, err = config.NodeKey()
+       if err != nil {
+               return nil, err
+       }
+
+       bytes, err := hex.DecodeString(config.P2P.PrivateKey)
+       if err != nil {
+               return nil, err
+       }
+
+       var newKey [64]byte
+       copy(newKey[:], bytes)
+       privKey := crypto.PrivKeyEd25519(newKey)
+       if !config.VaultMode {
+               // Create listener
+               l, listenAddr = GetListener(config.P2P)
+               discv, err = discover.NewDiscover(config, &privKey, l.ExternalAddress().Port)
+               if err != nil {
+                       return nil, err
+               }
+       }
+
+       return newSwitch(config, discv, blacklistDB, l, privKey, listenAddr)
+}
+
+// newSwitch creates a new Switch with the given config.
+func newSwitch(config *cfg.Config, discv discv, blacklistDB dbm.DB, l Listener, priv crypto.PrivKeyEd25519, listenAddr string) (*Switch, error) {
        sw := &Switch{
                Config:       config,
                peerConfig:   DefaultPeerConfig(config.P2P),
@@ -68,18 +106,20 @@ func NewSwitch(config *cfg.Config) *Switch {
                reactorsByCh: make(map[byte]Reactor),
                peers:        NewPeerSet(),
                dialing:      cmn.NewCMap(),
-               nodeInfo:     nil,
-               db:           dbm.NewDB("trusthistory", config.DBBackend, config.DBDir()),
+               nodePrivKey:  priv,
+               discv:        discv,
+               db:           blacklistDB,
+               nodeInfo:     NewNodeInfo(config, priv.PubKey().Unwrap().(crypto.PubKeyEd25519), listenAddr),
+               bannedPeer:   make(map[string]time.Time),
        }
-       sw.BaseService = *cmn.NewBaseService(nil, "P2P Switch", sw)
-       sw.bannedPeer = make(map[string]time.Time)
-       if datajson := sw.db.Get([]byte(bannedPeerKey)); datajson != nil {
-               if err := json.Unmarshal(datajson, &sw.bannedPeer); err != nil {
-                       return nil
-               }
+       if err := sw.loadBannedPeers(); err != nil {
+               return nil, err
        }
+
+       sw.AddListener(l)
+       sw.BaseService = *cmn.NewBaseService(nil, "P2P Switch", sw)
        trust.Init()
-       return sw
+       return sw, nil
 }
 
 // OnStart implements BaseService. It starts all the reactors, peers, and listeners.
@@ -119,12 +159,12 @@ func (sw *Switch) AddBannedPeer(ip string) error {
        defer sw.mtx.Unlock()
 
        sw.bannedPeer[ip] = time.Now().Add(defaultBanDuration)
-       datajson, err := json.Marshal(sw.bannedPeer)
+       dataJSON, err := json.Marshal(sw.bannedPeer)
        if err != nil {
                return err
        }
 
-       sw.db.Set([]byte(bannedPeerKey), datajson)
+       sw.db.Set([]byte(bannedPeerKey), dataJSON)
        return nil
 }
 
@@ -134,7 +174,7 @@ func (sw *Switch) AddBannedPeer(ip string) error {
 // NOTE: This performs a blocking handshake before the peer is added.
 // CONTRACT: If error is returned, peer is nil, and conn is immediately closed.
 func (sw *Switch) AddPeer(pc *peerConn) error {
-       peerNodeInfo, err := pc.HandshakeTimeout(sw.nodeInfo, time.Duration(sw.peerConfig.HandshakeTimeout))
+       peerNodeInfo, err := pc.HandshakeTimeout(sw.nodeInfo, sw.peerConfig.HandshakeTimeout)
        if err != nil {
                return err
        }
@@ -161,6 +201,7 @@ func (sw *Switch) AddPeer(pc *peerConn) error {
                        return err
                }
        }
+
        return sw.peers.Add(peer)
 }
 
@@ -199,12 +240,12 @@ func (sw *Switch) DialPeerWithAddress(addr *NetAddress) error {
 
        pc, err := newOutboundPeerConn(addr, sw.nodePrivKey, sw.peerConfig)
        if err != nil {
-               log.WithFields(log.Fields{"address": addr, " err": err}).Debug("DialPeer fail on newOutboundPeerConn")
+               log.WithFields(log.Fields{"address": addr, " err": err}).Error("DialPeer fail on newOutboundPeerConn")
                return err
        }
 
        if err = sw.AddPeer(pc); err != nil {
-               log.WithFields(log.Fields{"address": addr, " err": err}).Debug("DialPeer fail on switch AddPeer")
+               log.WithFields(log.Fields{"address": addr, " err": err}).Error("DialPeer fail on switch AddPeer")
                pc.CloseConn()
                return err
        }
@@ -223,6 +264,17 @@ func (sw *Switch) IsListening() bool {
        return len(sw.listeners) > 0
 }
 
+// loadBannedPeers load banned peers from db
+func (sw *Switch) loadBannedPeers() error {
+       if dataJSON := sw.db.Get([]byte(bannedPeerKey)); dataJSON != nil {
+               if err := json.Unmarshal(dataJSON, &sw.bannedPeer); err != nil {
+                       return err
+               }
+       }
+
+       return nil
+}
+
 // Listeners returns the list of listeners the switch listens on.
 // NOTE: Not goroutine safe.
 func (sw *Switch) Listeners() []Listener {
@@ -254,21 +306,6 @@ func (sw *Switch) Peers() *PeerSet {
        return sw.peers
 }
 
-// SetNodeInfo sets the switch's NodeInfo for checking compatibility and handshaking with other nodes.
-// NOTE: Not goroutine safe.
-func (sw *Switch) SetNodeInfo(nodeInfo *NodeInfo) {
-       sw.nodeInfo = nodeInfo
-}
-
-// SetNodePrivKey sets the switch's private key for authenticated encryption.
-// NOTE: Not goroutine safe.
-func (sw *Switch) SetNodePrivKey(nodePrivKey crypto.PrivKeyEd25519) {
-       sw.nodePrivKey = nodePrivKey
-       if sw.nodeInfo != nil {
-               sw.nodeInfo.PubKey = nodePrivKey.PubKey().Unwrap().(crypto.PubKeyEd25519)
-       }
-}
-
 // StopPeerForError disconnects from a peer due to external error.
 func (sw *Switch) StopPeerForError(peer *Peer, reason interface{}) {
        log.WithFields(log.Fields{"peer": peer, " err": reason}).Debug("stopping peer for error")
@@ -285,14 +322,19 @@ func (sw *Switch) StopPeerGracefully(peerID string) {
 func (sw *Switch) addPeerWithConnection(conn net.Conn) error {
        peerConn, err := newInboundPeerConn(conn, sw.nodePrivKey, sw.Config.P2P)
        if err != nil {
-               conn.Close()
+               if err := conn.Close(); err != nil {
+                       log.WithFields(log.Fields{"remote peer:": conn.RemoteAddr().String(), " err:": err}).Error("closes connection err")
+               }
                return err
        }
 
        if err = sw.AddPeer(peerConn); err != nil {
-               conn.Close()
+               if err := conn.Close(); err != nil {
+                       log.WithFields(log.Fields{"remote peer:": conn.RemoteAddr().String(), " err:": err}).Error("closes connection err")
+               }
                return err
        }
+
        return nil
 }
 
@@ -304,7 +346,10 @@ func (sw *Switch) checkBannedPeer(peer string) error {
                if time.Now().Before(banEnd) {
                        return ErrConnectBannedPeer
                }
-               sw.delBannedPeer(peer)
+
+               if err := sw.delBannedPeer(peer); err != nil {
+                       return err
+               }
        }
        return nil
 }
@@ -324,18 +369,18 @@ func (sw *Switch) delBannedPeer(addr string) error {
 }
 
 func (sw *Switch) filterConnByIP(ip string) error {
-       if ip == sw.nodeInfo.ListenHost() {
+       if ip == sw.nodeInfo.listenHost() {
                return ErrConnectSelf
        }
        return sw.checkBannedPeer(ip)
 }
 
 func (sw *Switch) filterConnByPeer(peer *Peer) error {
-       if err := sw.checkBannedPeer(peer.RemoteAddrHost()); err != nil {
+       if err := sw.checkBannedPeer(peer.remoteAddrHost()); err != nil {
                return err
        }
 
-       if sw.nodeInfo.PubKey.Equals(peer.PubKey().Wrap()) {
+       if sw.nodeInfo.getPubkey().Equals(peer.PubKey().Wrap()) {
                return ErrConnectSelf
        }
 
@@ -354,7 +399,9 @@ func (sw *Switch) listenerRoutine(l Listener) {
 
                // disconnect if we alrady have MaxNumPeers
                if sw.peers.Size() >= sw.Config.P2P.MaxNumPeers {
-                       inConn.Close()
+                       if err := inConn.Close(); err != nil {
+                               log.WithFields(log.Fields{"remote peer:": inConn.RemoteAddr().String(), " err:": err}).Error("closes connection err")
+                       }
                        log.Info("Ignoring inbound connection: already have enough peers.")
                        continue
                }
@@ -367,11 +414,6 @@ func (sw *Switch) listenerRoutine(l Listener) {
        }
 }
 
-// SetDiscv connect the discv model to the switch
-func (sw *Switch) SetDiscv(discv *discover.Network) {
-       sw.discv = discv
-}
-
 func (sw *Switch) dialPeerWorker(a *NetAddress, wg *sync.WaitGroup) {
        if err := sw.DialPeerWithAddress(a); err != nil {
                log.WithFields(log.Fields{"addr": a, "err": err}).Error("dialPeerWorker fail on dial peer")
@@ -389,7 +431,7 @@ func (sw *Switch) ensureOutboundPeers() {
 
        connectedPeers := make(map[string]struct{})
        for _, peer := range sw.Peers().List() {
-               connectedPeers[peer.RemoteAddrHost()] = struct{}{}
+               connectedPeers[peer.remoteAddrHost()] = struct{}{}
        }
 
        var wg sync.WaitGroup
@@ -430,7 +472,11 @@ func (sw *Switch) ensureOutboundPeersRoutine() {
 }
 
 func (sw *Switch) startInitPeer(peer *Peer) error {
-       peer.Start() // spawn send/recv routines
+       // spawn send/recv routines
+       if _, err := peer.Start(); err != nil {
+               log.WithFields(log.Fields{"remote peer:": peer.RemoteAddr, " err:": err}).Error("init peer err")
+       }
+
        for _, reactor := range sw.reactors {
                if err := reactor.AddPeer(peer); err != nil {
                        return err