OSDN Git Service

p2p: refactor switch code and add test (#1538)
authoryahtoo <yahtoo.ma@gmail.com>
Mon, 18 Feb 2019 07:44:03 +0000 (15:44 +0800)
committerPaladz <yzhu101@uottawa.ca>
Mon, 18 Feb 2019 07:44:03 +0000 (15:44 +0800)
* p2p: refactor switch code and add test

* Adjust parameter position

* Resolve merging conflicts

* Fix test fail

* Resolve merging conflicts

* Fix review bugs

* Fix review bugs

* Fix review bug

* Fix review bug

17 files changed:
api/api.go
api/nodeinfo.go
cmd/bytomd/commands/run_node.go
netsync/handle.go
node/node.go
p2p/discover/dns_seeds.go [moved from p2p/dns_seeds.go with 97% similarity]
p2p/discover/dns_seeds_test.go [moved from p2p/dns_seeds_test.go with 99% similarity]
p2p/discover/net.go
p2p/discover/udp.go
p2p/listener.go
p2p/listener_test.go
p2p/node_info.go
p2p/peer.go
p2p/peer_test.go [new file with mode: 0644]
p2p/switch.go
p2p/switch_test.go [new file with mode: 0644]
p2p/test_util.go

index 6d2423d..c6945cf 100644 (file)
@@ -26,6 +26,7 @@ import (
        "github.com/bytom/net/http/static"
        "github.com/bytom/net/websocket"
        "github.com/bytom/netsync"
+       "github.com/bytom/p2p"
        "github.com/bytom/protocol"
        "github.com/bytom/wallet"
 )
@@ -105,7 +106,7 @@ func (wh *waitHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 
 // API is the scheduling center for server
 type API struct {
-       sync            *netsync.SyncManager
+       sync            NetSync
        wallet          *wallet.Wallet
        accessTokens    *accesstoken.CredentialStore
        chain           *protocol.Chain
@@ -168,8 +169,19 @@ func (a *API) StartServer(address string) {
        }()
 }
 
+type NetSync interface {
+       IsListening() bool
+       IsCaughtUp() bool
+       PeerCount() int
+       GetNetwork() string
+       BestPeer() *netsync.PeerInfo
+       DialPeerWithAddress(addr *p2p.NetAddress) error
+       GetPeerInfos() []*netsync.PeerInfo
+       StopPeer(peerID string) error
+}
+
 // NewAPI create and initialize the API
-func NewAPI(sync *netsync.SyncManager, wallet *wallet.Wallet, txfeeds *txfeed.Tracker, cpuMiner *cpuminer.CPUMiner, miningPool *miningpool.MiningPool, chain *protocol.Chain, config *cfg.Config, token *accesstoken.CredentialStore, dispatcher *event.Dispatcher, notificationMgr *websocket.WSNotificationManager) *API {
+func NewAPI(sync NetSync, wallet *wallet.Wallet, txfeeds *txfeed.Tracker, cpuMiner *cpuminer.CPUMiner, miningPool *miningpool.MiningPool, chain *protocol.Chain, config *cfg.Config, token *accesstoken.CredentialStore, dispatcher *event.Dispatcher, notificationMgr *websocket.WSNotificationManager) *API {
        api := &API{
                sync:          sync,
                wallet:        wallet,
index 6b95d90..bda4c19 100644 (file)
@@ -31,12 +31,12 @@ type NetInfo struct {
 // GetNodeInfo return net information
 func (a *API) GetNodeInfo() *NetInfo {
        info := &NetInfo{
-               Listening:    a.sync.Switch().IsListening(),
+               Listening:    a.sync.IsListening(),
                Syncing:      !a.sync.IsCaughtUp(),
                Mining:       a.cpuMiner.IsMining(),
-               PeerCount:    len(a.sync.Switch().Peers().List()),
+               PeerCount:    a.sync.PeerCount(),
                CurrentBlock: a.chain.BestBlockHeight(),
-               NetWorkID:    a.sync.NodeInfo().Network,
+               NetWorkID:    a.sync.GetNetwork(),
                Version: &VersionInfo{
                        Version: version.Version,
                        Update:  version.Status.VersionStatus(),
@@ -76,9 +76,8 @@ func (a *API) connectPeerByIpAndPort(ip string, port uint16) (*netsync.PeerInfo,
        }
 
        addr := p2p.NewNetAddressIPPort(netIp, port)
-       sw := a.sync.Switch()
 
-       if err := sw.DialPeerWithAddress(addr); err != nil {
+       if err := a.sync.DialPeerWithAddress(addr); err != nil {
                return nil, errors.Wrap(err, "can not connect to the address")
        }
        peer := a.getPeerInfoByAddr(addr.String())
index ed4e03c..587b2c5 100644 (file)
@@ -83,7 +83,7 @@ func runNode(cmd *cobra.Command, args []string) error {
                log.WithField("err", err).Fatal("failed to start node")
        }
 
-       nodeInfo := n.SyncManager().NodeInfo()
+       nodeInfo := n.NodeInfo()
        log.WithFields(log.Fields{
                "version":  nodeInfo.Version,
                "network":  nodeInfo.Network,
index 5b350a6..df4d54f 100644 (file)
@@ -1,27 +1,19 @@
 package netsync
 
 import (
-       "encoding/hex"
        "errors"
-       "net"
-       "path"
        "reflect"
-       "strconv"
-       "strings"
 
        log "github.com/sirupsen/logrus"
-       "github.com/tendermint/go-crypto"
-       cmn "github.com/tendermint/tmlibs/common"
 
        cfg "github.com/bytom/config"
        "github.com/bytom/consensus"
        "github.com/bytom/event"
        "github.com/bytom/p2p"
-       "github.com/bytom/p2p/discover"
        core "github.com/bytom/protocol"
        "github.com/bytom/protocol/bc"
        "github.com/bytom/protocol/bc/types"
-       "github.com/bytom/version"
+       "github.com/tendermint/go-crypto"
 )
 
 const (
@@ -32,8 +24,7 @@ const (
 )
 
 var (
-       errInvalidSeedIP   = errors.New("seed ip is invalid")
-       errInvalidSeedPort = errors.New("seed port is invalid")
+       errVaultModeDialPeer = errors.New("can't dial peer in vault mode")
 )
 
 // Chain is the interface for Bytom core
@@ -51,12 +42,22 @@ type Chain interface {
        ValidateTx(*types.Tx) (bool, error)
 }
 
+type Switch interface {
+       AddReactor(name string, reactor p2p.Reactor) p2p.Reactor
+       AddBannedPeer(string) error
+       StopPeerGracefully(string)
+       NodeInfo() *p2p.NodeInfo
+       Start() (bool, error)
+       Stop() bool
+       IsListening() bool
+       DialPeerWithAddress(addr *p2p.NetAddress) error
+       Peers() *p2p.PeerSet
+}
+
 //SyncManager Sync Manager is responsible for the business layer information synchronization
 type SyncManager struct {
-       sw          *p2p.Switch
-       genesisHash bc.Hash
-
-       privKey      crypto.PrivKeyEd25519 // local node's p2p key
+       sw           Switch
+       genesisHash  bc.Hash
        chain        Chain
        txPool       *core.TxPool
        blockFetcher *blockFetcher
@@ -72,21 +73,28 @@ type SyncManager struct {
        minedBlockSub   *event.Subscription
 }
 
-//NewSyncManager create a sync manager
+// CreateSyncManager create sync manager and set switch.
 func NewSyncManager(config *cfg.Config, chain Chain, txPool *core.TxPool, dispatcher *event.Dispatcher) (*SyncManager, error) {
-       genesisHeader, err := chain.GetHeaderByHeight(0)
+       sw, err := p2p.NewSwitch(config)
        if err != nil {
                return nil, err
        }
 
-       sw := p2p.NewSwitch(config)
+       return newSyncManager(config, sw, chain, txPool, dispatcher)
+}
+
+//NewSyncManager create a sync manager
+func newSyncManager(config *cfg.Config, sw Switch, chain Chain, txPool *core.TxPool, dispatcher *event.Dispatcher) (*SyncManager, error) {
+       genesisHeader, err := chain.GetHeaderByHeight(0)
+       if err != nil {
+               return nil, err
+       }
        peers := newPeerSet(sw)
        manager := &SyncManager{
                sw:              sw,
                genesisHash:     genesisHeader.Hash(),
                txPool:          txPool,
                chain:           chain,
-               privKey:         crypto.GenPrivKeyEd25519(),
                blockFetcher:    newBlockFetcher(chain, peers),
                blockKeeper:     newBlockKeeper(chain, peers),
                peers:           peers,
@@ -97,25 +105,10 @@ func NewSyncManager(config *cfg.Config, chain Chain, txPool *core.TxPool, dispat
                eventDispatcher: dispatcher,
        }
 
-       protocolReactor := NewProtocolReactor(manager, manager.peers)
-       manager.sw.AddReactor("PROTOCOL", protocolReactor)
-
-       // Create & add listener
-       var listenerStatus bool
-       var l p2p.Listener
        if !config.VaultMode {
-               p, address := protocolAndAddress(manager.config.P2P.ListenAddress)
-               l, listenerStatus = p2p.NewDefaultListener(p, address, manager.config.P2P.SkipUPNP)
-               manager.sw.AddListener(l)
-
-               discv, err := initDiscover(config, &manager.privKey, l.ExternalAddress().Port)
-               if err != nil {
-                       return nil, err
-               }
-               manager.sw.SetDiscv(discv)
+               protocolReactor := NewProtocolReactor(manager, peers)
+               manager.sw.AddReactor("PROTOCOL", protocolReactor)
        }
-       manager.sw.SetNodeInfo(manager.makeNodeInfo(listenerStatus))
-       manager.sw.SetNodePrivKey(manager.privKey)
        return manager, nil
 }
 
@@ -128,11 +121,23 @@ func (sm *SyncManager) BestPeer() *PeerInfo {
        return nil
 }
 
+func (sm *SyncManager) DialPeerWithAddress(addr *p2p.NetAddress) error {
+       if sm.config.VaultMode {
+               return errVaultModeDialPeer
+       }
+
+       return sm.sw.DialPeerWithAddress(addr)
+}
+
 // GetNewTxCh return a unconfirmed transaction feed channel
 func (sm *SyncManager) GetNewTxCh() chan *types.Tx {
        return sm.newTxCh
 }
 
+func (sm *SyncManager) GetNetwork() string {
+       return sm.config.ChainID
+}
+
 //GetPeerInfos return peer info of all peers
 func (sm *SyncManager) GetPeerInfos() []*PeerInfo {
        return sm.peers.getPeerInfos()
@@ -144,11 +149,6 @@ func (sm *SyncManager) IsCaughtUp() bool {
        return peer == nil || peer.Height() <= sm.chain.BestBlockHeight()
 }
 
-//NodeInfo get P2P peer node info
-func (sm *SyncManager) NodeInfo() *p2p.NodeInfo {
-       return sm.sw.NodeInfo()
-}
-
 //StopPeer try to stop peer by given ID
 func (sm *SyncManager) StopPeer(peerID string) error {
        if peer := sm.peers.getPeer(peerID); peer == nil {
@@ -158,11 +158,6 @@ func (sm *SyncManager) StopPeer(peerID string) error {
        return nil
 }
 
-//Switch get sync manager switch
-func (sm *SyncManager) Switch() *p2p.Switch {
-       return sm.sw
-}
-
 func (sm *SyncManager) handleBlockMsg(peer *peer, msg *BlockMessage) {
        block, err := msg.GetBlock()
        if err != nil {
@@ -360,6 +355,27 @@ func (sm *SyncManager) handleTransactionMsg(peer *peer, msg *TransactionMessage)
        }
 }
 
+func (sm *SyncManager) IsListening() bool {
+       if sm.config.VaultMode {
+               return false
+       }
+       return sm.sw.IsListening()
+}
+
+func (sm *SyncManager) NodeInfo() *p2p.NodeInfo {
+       if sm.config.VaultMode {
+               return p2p.NewNodeInfo(sm.config, crypto.PubKeyEd25519{}, "")
+       }
+       return sm.sw.NodeInfo()
+}
+
+func (sm *SyncManager) PeerCount() int {
+       if sm.config.VaultMode {
+               return 0
+       }
+       return len(sm.sw.Peers().List())
+}
+
 func (sm *SyncManager) processMsg(basePeer BasePeer, msgType byte, msg BlockchainMessage) {
        peer := sm.peers.getPeer(basePeer.ID())
        if peer == nil && msgType != StatusResponseByte && msgType != StatusRequestByte {
@@ -425,124 +441,34 @@ func (sm *SyncManager) processMsg(basePeer BasePeer, msgType byte, msg Blockchai
        }
 }
 
-// Defaults to tcp
-func protocolAndAddress(listenAddr string) (string, string) {
-       p, address := "tcp", listenAddr
-       parts := strings.SplitN(address, "://", 2)
-       if len(parts) == 2 {
-               p, address = parts[0], parts[1]
-       }
-       return p, address
-}
-
-func (sm *SyncManager) makeNodeInfo(listenerStatus bool) *p2p.NodeInfo {
-       nodeInfo := &p2p.NodeInfo{
-               PubKey:  sm.privKey.PubKey().Unwrap().(crypto.PubKeyEd25519),
-               Moniker: sm.config.Moniker,
-               Network: sm.config.ChainID,
-               Version: version.Version,
-               Other:   []string{strconv.FormatUint(uint64(consensus.DefaultServices), 10)},
-       }
-
-       if !sm.sw.IsListening() {
-               return nodeInfo
-       }
-
-       p2pListener := sm.sw.Listeners()[0]
-
-       // We assume that the rpcListener has the same ExternalAddress.
-       // This is probably true because both P2P and RPC listeners use UPnP,
-       // except of course if the rpc is only bound to localhost
-       if listenerStatus {
-               nodeInfo.ListenAddr = cmn.Fmt("%v:%v", p2pListener.ExternalAddress().IP.String(), p2pListener.ExternalAddress().Port)
-       } else {
-               nodeInfo.ListenAddr = cmn.Fmt("%v:%v", p2pListener.InternalAddress().IP.String(), p2pListener.InternalAddress().Port)
+func (sm *SyncManager) Start() error {
+       var err error
+       if _, err = sm.sw.Start(); err != nil {
+               log.Error("switch start err")
+               return err
        }
-       return nodeInfo
-}
 
-//Start start sync manager service
-func (sm *SyncManager) Start() {
-       _, err := sm.sw.Start()
-       if err != nil {
-               cmn.Exit(cmn.Fmt("fail on start SyncManager: %v", err))
-       }
        // broadcast transactions
        go sm.txBroadcastLoop()
 
        sm.minedBlockSub, err = sm.eventDispatcher.Subscribe(event.NewMinedBlockEvent{})
        if err != nil {
-               cmn.Exit(cmn.Fmt("fail on start SyncManager: %v", err))
+               return err
        }
 
        go sm.minedBroadcastLoop()
        go sm.txSyncLoop()
+
+       return nil
 }
 
 //Stop stop sync manager
 func (sm *SyncManager) Stop() {
        close(sm.quitSync)
        sm.minedBlockSub.Unsubscribe()
-       sm.sw.Stop()
-}
-
-func initDiscover(config *cfg.Config, priv *crypto.PrivKeyEd25519, port uint16) (*discover.Network, error) {
-       addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort("0.0.0.0", strconv.FormatUint(uint64(port), 10)))
-       if err != nil {
-               return nil, err
-       }
-
-       conn, err := net.ListenUDP("udp", addr)
-       if err != nil {
-               return nil, err
-       }
-
-       realaddr := conn.LocalAddr().(*net.UDPAddr)
-       ntab, err := discover.ListenUDP(priv, conn, realaddr, path.Join(config.DBDir(), "discover.db"), nil)
-       if err != nil {
-               return nil, err
-       }
-
-       seeds, err := p2p.QueryDNSSeeds(net.LookupHost)
-       if err != nil {
-               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on query dns seeds")
-       }
-
-       if config.P2P.Seeds != "" {
-               codedSeeds := strings.Split(config.P2P.Seeds, ",")
-               for _, codedSeed := range codedSeeds {
-                       ip, port, err := net.SplitHostPort(codedSeed)
-                       if err != nil {
-                               return nil, err
-                       }
-
-                       if validIP := net.ParseIP(ip); validIP == nil {
-                               return nil, errInvalidSeedIP
-                       }
-
-                       if _, err := strconv.ParseUint(port, 10, 16); err != nil {
-                               return nil, errInvalidSeedPort
-                       }
-
-                       seeds = append(seeds, codedSeed)
-               }
-       }
-
-       if len(seeds) == 0 {
-               return ntab, nil
-       }
-
-       var nodes []*discover.Node
-       for _, seed := range seeds {
-               version.Status.AddSeed(seed)
-               url := "enode://" + hex.EncodeToString(crypto.Sha256([]byte(seed))) + "@" + seed
-               nodes = append(nodes, discover.MustParseNode(url))
-       }
-
-       if err = ntab.SetFallbackNodes(nodes); err != nil {
-               return nil, err
+       if !sm.config.VaultMode {
+               sm.sw.Stop()
        }
-       return ntab, nil
 }
 
 func (sm *SyncManager) minedBroadcastLoop() {
index ccef22b..65b742d 100644 (file)
@@ -5,7 +5,6 @@ import (
        "errors"
        "net"
        "net/http"
-       _ "net/http/pprof"
        "os"
        "path/filepath"
 
@@ -31,19 +30,17 @@ import (
        "github.com/bytom/mining/tensority"
        "github.com/bytom/net/websocket"
        "github.com/bytom/netsync"
+       "github.com/bytom/p2p"
        "github.com/bytom/protocol"
        w "github.com/bytom/wallet"
 )
 
-const (
-       webHost           = "http://127.0.0.1"
-       maxNewBlockChSize = 1024
-)
+const webHost = "http://127.0.0.1"
 
+// Node represent bytom node
 type Node struct {
        cmn.BaseService
 
-       // config
        config          *cfg.Config
        eventDispatcher *event.Dispatcher
        syncManager     *netsync.SyncManager
@@ -59,6 +56,7 @@ type Node struct {
        miningEnable    bool
 }
 
+// NewNode create bytom node
 func NewNode(config *cfg.Config) *Node {
        ctx := context.Background()
        if err := lockDataDirectory(config); err != nil {
@@ -84,10 +82,10 @@ func NewNode(config *cfg.Config) *Node {
                cmn.Exit(cmn.Fmt("Failed to create chain structure: %v", err))
        }
 
-       var accounts *account.Manager = nil
-       var assets *asset.Registry = nil
-       var wallet *w.Wallet = nil
-       var txFeed *txfeed.Tracker = nil
+       var accounts *account.Manager
+       var assets *asset.Registry
+       var wallet *w.Wallet
+       var txFeed *txfeed.Tracker
 
        txFeedDB := dbm.NewDB("txfeeds", config.DBBackend, config.DBDir())
        txFeed = txfeed.NewTracker(txFeedDB, chain)
@@ -116,10 +114,11 @@ func NewNode(config *cfg.Config) *Node {
                        wallet.RescanBlocks()
                }
        }
+
        dispatcher := event.NewDispatcher()
        syncManager, err := netsync.NewSyncManager(config, chain, txPool, dispatcher)
        if err != nil {
-               cmn.Exit(cmn.Fmt("create sync manager failed: %v", err))
+               cmn.Exit(cmn.Fmt("Failed to create sync manager: %v", err))
        }
 
        notificationMgr := websocket.NewWsNotificationManager(config.Websocket.MaxNumWebsockets, config.Websocket.MaxNumConcurrentReqs, chain)
@@ -233,7 +232,7 @@ func launchWebBrowser(port string) {
        }
 }
 
-func (n *Node) initAndstartApiServer() {
+func (n *Node) initAndstartAPIServer() {
        n.api = api.NewAPI(n.syncManager, n.wallet, n.txfeed, n.cpuMiner, n.miningPool, n.chain, n.config, n.accessTokens, n.eventDispatcher, n.notificationMgr)
 
        listenAddr := env.String("LISTEN", n.config.ApiAddress)
@@ -251,9 +250,12 @@ func (n *Node) OnStart() error {
                }
        }
        if !n.config.VaultMode {
-               n.syncManager.Start()
+               if err := n.syncManager.Start(); err != nil {
+                       return err
+               }
        }
-       n.initAndstartApiServer()
+
+       n.initAndstartAPIServer()
        n.notificationMgr.Start()
        if !n.config.Web.Closed {
                _, port, err := net.SplitHostPort(n.config.ApiAddress)
@@ -286,8 +288,8 @@ func (n *Node) RunForever() {
        })
 }
 
-func (n *Node) SyncManager() *netsync.SyncManager {
-       return n.syncManager
+func (n *Node) NodeInfo() *p2p.NodeInfo {
+       return n.syncManager.NodeInfo()
 }
 
 func (n *Node) MiningPool() *miningpool.MiningPool {
similarity index 97%
rename from p2p/dns_seeds.go
rename to p2p/discover/dns_seeds.go
index e15bbf8..235f29f 100644 (file)
@@ -1,4 +1,4 @@
-package p2p
+package discover
 
 import (
        "net"
@@ -10,8 +10,6 @@ import (
        "github.com/bytom/errors"
 )
 
-const logModule = "p2p"
-
 var (
        errInvalidIP     = errors.New("invalid ip address")
        errDNSTimeout    = errors.New("get dns seed timeout")
similarity index 99%
rename from p2p/dns_seeds_test.go
rename to p2p/discover/dns_seeds_test.go
index 7da4681..3a31a10 100644 (file)
@@ -1,4 +1,4 @@
-package p2p
+package discover
 
 import (
        "reflect"
index cfcfed6..55fcf10 100644 (file)
@@ -167,6 +167,10 @@ func (net *Network) Self() *Node {
        return net.tab.self
 }
 
+func (net *Network) selfIP() net.IP {
+       return net.tab.self.IP
+}
+
 // ReadRandomNodes fills the given slice with random nodes from the
 // table. It will not write the same node more than once. The nodes in
 // the slice are copies and can be modified by the caller.
index cac340a..a230a6a 100644 (file)
@@ -3,9 +3,13 @@ package discover
 import (
        "bytes"
        "crypto/ecdsa"
+       "encoding/hex"
        "errors"
        "fmt"
        "net"
+       "path"
+       "strconv"
+       "strings"
        "time"
 
        log "github.com/sirupsen/logrus"
@@ -13,10 +17,15 @@ import (
        "github.com/tendermint/go-wire"
 
        "github.com/bytom/common"
+       cfg "github.com/bytom/config"
        "github.com/bytom/p2p/netutil"
+       "github.com/bytom/version"
 )
 
-const Version = 4
+const (
+       Version   = 4
+       logModule = "discover"
+)
 
 // Errors
 var (
@@ -28,6 +37,8 @@ var (
        errTimeout          = errors.New("RPC timeout")
        errClockWarp        = errors.New("reply deadline too far in the future")
        errClosed           = errors.New("socket closed")
+       errInvalidSeedIP    = errors.New("seed ip is invalid")
+       errInvalidSeedPort  = errors.New("seed port is invalid")
 )
 
 // Timeouts
@@ -238,13 +249,76 @@ type conn interface {
        LocalAddr() net.Addr
 }
 
+type netWork interface {
+       reqReadPacket(pkt ingressPacket)
+       selfIP() net.IP
+}
+
 // udp implements the RPC protocol.
 type udp struct {
        conn        conn
        priv        *crypto.PrivKeyEd25519
        ourEndpoint rpcEndpoint
        //nat         nat.Interface
-       net *Network
+       net netWork
+}
+
+func NewDiscover(config *cfg.Config, priv *crypto.PrivKeyEd25519, port uint16) (*Network, error) {
+       addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort("0.0.0.0", strconv.FormatUint(uint64(port), 10)))
+       if err != nil {
+               return nil, err
+       }
+
+       conn, err := net.ListenUDP("udp", addr)
+       if err != nil {
+               return nil, err
+       }
+
+       realaddr := conn.LocalAddr().(*net.UDPAddr)
+       ntab, err := ListenUDP(priv, conn, realaddr, path.Join(config.DBDir(), "discover.db"), nil)
+       if err != nil {
+               return nil, err
+       }
+       seeds, err := QueryDNSSeeds(net.LookupHost)
+       if err != nil {
+               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on query dns seeds")
+       }
+
+       if config.P2P.Seeds != "" {
+               codedSeeds := strings.Split(config.P2P.Seeds, ",")
+               for _, codedSeed := range codedSeeds {
+                       ip, port, err := net.SplitHostPort(codedSeed)
+                       if err != nil {
+                               return nil, err
+                       }
+
+                       if validIP := net.ParseIP(ip); validIP == nil {
+                               return nil, errInvalidSeedIP
+                       }
+
+                       if _, err := strconv.ParseUint(port, 10, 16); err != nil {
+                               return nil, errInvalidSeedPort
+                       }
+
+                       seeds = append(seeds, codedSeed)
+               }
+       }
+
+       if len(seeds) == 0 {
+               return ntab, nil
+       }
+
+       var nodes []*Node
+       for _, seed := range seeds {
+               version.Status.AddSeed(seed)
+               url := "enode://" + hex.EncodeToString(crypto.Sha256([]byte(seed))) + "@" + seed
+               nodes = append(nodes, MustParseNode(url))
+       }
+
+       if err = ntab.SetFallbackNodes(nodes); err != nil {
+               return nil, err
+       }
+       return ntab, nil
 }
 
 // ListenUDP returns a new table that listens for UDP packets on laddr.
@@ -330,7 +404,7 @@ func (t *udp) sendTopicNodes(remote *Node, queryHash common.Hash, nodes []*Node)
        p := topicNodes{Echo: queryHash}
        var sent bool
        for _, result := range nodes {
-               if result.IP.Equal(t.net.tab.self.IP) || netutil.CheckRelayIP(remote.IP, result.IP) == nil {
+               if result.IP.Equal(t.net.selfIP()) || netutil.CheckRelayIP(remote.IP, result.IP) == nil {
                        p.Nodes = append(p.Nodes, nodeToRPC(result))
                }
                if len(p.Nodes) == maxTopicNodes {
index 122b1d3..5ab438c 100644 (file)
@@ -4,11 +4,13 @@ import (
        "fmt"
        "net"
        "strconv"
+       "strings"
        "time"
 
        log "github.com/sirupsen/logrus"
        cmn "github.com/tendermint/tmlibs/common"
 
+       cfg "github.com/bytom/config"
        "github.com/bytom/errors"
        "github.com/bytom/p2p/upnp"
 )
@@ -28,6 +30,31 @@ type Listener interface {
        Stop() bool
 }
 
+// Defaults to tcp
+func protocolAndAddress(listenAddr string) (string, string) {
+       p, address := "tcp", listenAddr
+       parts := strings.SplitN(address, "://", 2)
+       if len(parts) == 2 {
+               p, address = parts[0], parts[1]
+       }
+       return p, address
+}
+
+// GetListener get listener and listen address.
+func GetListener(config *cfg.P2PConfig) (Listener, string) {
+       p, address := protocolAndAddress(config.ListenAddress)
+       l, listenerStatus := NewDefaultListener(p, address, config.SkipUPNP)
+
+       // We assume that the rpcListener has the same ExternalAddress.
+       // This is probably true because both P2P and RPC listeners use UPnP,
+       // except of course if the rpc is only bound to localhost
+       if listenerStatus {
+               return l, cmn.Fmt("%v:%v", l.ExternalAddress().IP.String(), l.ExternalAddress().Port)
+       }
+
+       return l, cmn.Fmt("%v:%v", l.InternalAddress().IP.String(), l.InternalAddress().Port)
+}
+
 //getUPNPExternalAddress UPNP external address discovery & port mapping
 func getUPNPExternalAddress(externalPort, internalPort int) (*NetAddress, error) {
        nat, err := upnp.Discover()
index bd2704b..c5c9a46 100644 (file)
@@ -9,16 +9,15 @@ import (
 
 func TestListener(t *testing.T) {
        // Create a listener
-       l, _ := NewDefaultListener("tcp", ":8001", true)
+       l, _ := NewDefaultListener("tcp", "localhost:8001", true)
 
        // Dial the listener
-       lAddr := l.ExternalAddress()
+       lAddr := l.InternalAddress()
        connOut, err := lAddr.Dial()
        if err != nil {
                t.Fatalf("Could not connect to listener address %v", lAddr)
-       } else {
-               t.Logf("Created a connection to listener address %v", lAddr)
        }
+
        connIn, ok := <-l.Connections()
        if !ok {
                t.Fatalf("Could not get inbound connection from listener")
index e826e3e..c8ee30d 100644 (file)
@@ -3,10 +3,10 @@ package p2p
 import (
        "fmt"
        "net"
-       "strconv"
 
-       crypto "github.com/tendermint/go-crypto"
+       "github.com/tendermint/go-crypto"
 
+       cfg "github.com/bytom/config"
        "github.com/bytom/version"
 )
 
@@ -23,6 +23,16 @@ type NodeInfo struct {
        Other      []string             `json:"other"`   // other application specific data
 }
 
+func NewNodeInfo(config *cfg.Config, pubkey crypto.PubKeyEd25519, listenAddr string) *NodeInfo {
+       return &NodeInfo{
+               PubKey:     pubkey,
+               Moniker:    config.Moniker,
+               Network:    config.ChainID,
+               ListenAddr: listenAddr,
+               Version:    version.Version,
+       }
+}
+
 // CompatibleWith checks if two NodeInfo are compatible with eachother.
 // CONTRACT: two nodes are compatible if the major version matches and network match
 func (info *NodeInfo) CompatibleWith(other *NodeInfo) error {
@@ -40,28 +50,27 @@ func (info *NodeInfo) CompatibleWith(other *NodeInfo) error {
        return nil
 }
 
+func (info *NodeInfo) getPubkey() crypto.PubKeyEd25519 {
+       return info.PubKey
+}
+
 //ListenHost peer listener ip address
-func (info *NodeInfo) ListenHost() string {
+func (info *NodeInfo) listenHost() string {
        host, _, _ := net.SplitHostPort(info.ListenAddr)
        return host
 }
 
-//ListenPort peer listener port
-func (info *NodeInfo) ListenPort() int {
-       _, port, _ := net.SplitHostPort(info.ListenAddr)
-       portInt, err := strconv.Atoi(port)
-       if err != nil {
-               return -1
-       }
-       return portInt
-}
-
 //RemoteAddrHost peer external ip address
-func (info *NodeInfo) RemoteAddrHost() string {
+func (info *NodeInfo) remoteAddrHost() string {
        host, _, _ := net.SplitHostPort(info.RemoteAddr)
        return host
 }
 
+//GetNetwork get node info network field
+func (info *NodeInfo) GetNetwork() string {
+       return info.Network
+}
+
 //String representation
 func (info NodeInfo) String() string {
        return fmt.Sprintf("NodeInfo{pk: %v, moniker: %v, network: %v [listen %v], version: %v (%v)}", info.PubKey, info.Moniker, info.Network, info.ListenAddr, info.Version, info.Other)
index 963c56b..42b8811 100644 (file)
@@ -6,6 +6,7 @@ import (
        "strconv"
        "time"
 
+       "github.com/btcsuite/go-socks/socks"
        "github.com/pkg/errors"
        log "github.com/sirupsen/logrus"
        crypto "github.com/tendermint/go-crypto"
@@ -16,7 +17,6 @@ import (
        cfg "github.com/bytom/config"
        "github.com/bytom/consensus"
        "github.com/bytom/p2p/connection"
-       "github.com/btcsuite/go-socks/socks"
 )
 
 // peerConn contains the raw connection and its config.
diff --git a/p2p/peer_test.go b/p2p/peer_test.go
new file mode 100644 (file)
index 0000000..f927619
--- /dev/null
@@ -0,0 +1,183 @@
+package p2p
+
+import (
+       "fmt"
+       "net"
+       "testing"
+       "time"
+
+       "github.com/tendermint/go-crypto"
+
+       cfg "github.com/bytom/config"
+       conn "github.com/bytom/p2p/connection"
+       "github.com/bytom/version"
+)
+
+const testCh = 0x01
+
+func TestPeerBasic(t *testing.T) {
+       // simulate remote peer
+       rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: testCfg}
+       rp.Start()
+       defer rp.Stop()
+
+       p, err := createOutboundPeerAndPerformHandshake(rp.Addr(), cfg.DefaultP2PConfig())
+       if err != nil {
+               t.Fatal(err)
+       }
+       _, err = p.Start()
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer p.Stop()
+}
+
+func TestPeerSend(t *testing.T) {
+       config := testCfg
+
+       // simulate remote peer
+       rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: config}
+       rp.Start()
+       defer rp.Stop()
+
+       p, err := createOutboundPeerAndPerformHandshake(rp.Addr(), config.P2P)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       _, err = p.Start()
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       defer p.Stop()
+       if ok := p.CanSend(testCh); !ok {
+               t.Fatal("TestPeerSend send err")
+       }
+
+       if ok := p.TrySend(testCh, []byte("test date")); !ok {
+               t.Fatal("TestPeerSend try send err")
+       }
+}
+
+func createOutboundPeerAndPerformHandshake(
+       addr *NetAddress,
+       config *cfg.P2PConfig,
+) (*Peer, error) {
+       chDescs := []*conn.ChannelDescriptor{
+               {ID: testCh, Priority: 1},
+       }
+       reactorsByCh := map[byte]Reactor{testCh: NewTestReactor(chDescs, true)}
+       privkey := crypto.GenPrivKeyEd25519()
+       peerConfig := DefaultPeerConfig(config)
+       pc, err := newOutboundPeerConn(addr, privkey, peerConfig)
+       if err != nil {
+               return nil, err
+       }
+       nodeInfo, err := pc.HandshakeTimeout(&NodeInfo{
+               Moniker: "host_peer",
+               Network: "testing",
+               Version: "123.123.123",
+       }, 5*time.Second)
+       if err != nil {
+               fmt.Println(err)
+               return nil, err
+       }
+       p := newPeer(pc, nodeInfo, reactorsByCh, chDescs, nil)
+       return p, nil
+}
+
+type remotePeer struct {
+       PrivKey    crypto.PrivKeyEd25519
+       Config     *cfg.Config
+       addr       *NetAddress
+       quit       chan struct{}
+       listenAddr string
+}
+
+func (rp *remotePeer) Addr() *NetAddress {
+       return rp.addr
+}
+
+func (rp *remotePeer) Start() {
+       if rp.listenAddr == "" {
+               rp.listenAddr = "127.0.0.1:0"
+       }
+
+       l, e := net.Listen("tcp", rp.listenAddr) // any available address
+       if e != nil {
+               fmt.Println("net.Listen tcp :0:", e)
+       }
+       rp.addr = NewNetAddress(l.Addr())
+       rp.quit = make(chan struct{})
+       go rp.accept(l)
+}
+
+func (rp *remotePeer) Stop() {
+       close(rp.quit)
+}
+
+func (rp *remotePeer) accept(l net.Listener) {
+       conns := []net.Conn{}
+
+       for {
+               conn, err := l.Accept()
+               if err != nil {
+                       fmt.Println("Failed to accept conn:", err)
+               }
+
+               pc, err := newInboundPeerConn(conn, rp.PrivKey, rp.Config.P2P)
+               if err != nil {
+                       fmt.Println("Failed to create a peer:", err)
+               }
+
+               _, err = pc.HandshakeTimeout(&NodeInfo{
+                       PubKey:     rp.PrivKey.PubKey().Unwrap().(crypto.PubKeyEd25519),
+                       Moniker:    "remote_peer",
+                       Network:    rp.Config.ChainID,
+                       Version:    version.Version,
+                       ListenAddr: l.Addr().String(),
+               }, 5*time.Second)
+               if err != nil {
+                       fmt.Println("Failed to perform handshake:", err)
+               }
+               conns = append(conns, conn)
+               select {
+               case <-rp.quit:
+                       for _, conn := range conns {
+                               if err := conn.Close(); err != nil {
+                                       fmt.Println(err)
+                               }
+                       }
+                       return
+               default:
+               }
+       }
+}
+
+type inboundPeer struct {
+       PrivKey crypto.PrivKeyEd25519
+       config  *cfg.Config
+}
+
+func (ip *inboundPeer) dial(addr *NetAddress) error {
+       pc, err := newOutboundPeerConn(addr, ip.PrivKey, DefaultPeerConfig(ip.config.P2P))
+       if err != nil {
+               fmt.Println("newOutboundPeerConn:", err)
+               return err
+       }
+
+       _, err = pc.HandshakeTimeout(&NodeInfo{
+               PubKey:     ip.PrivKey.PubKey().Unwrap().(crypto.PubKeyEd25519),
+               Moniker:    "remote_peer",
+               Network:    ip.config.ChainID,
+               Version:    version.Version,
+               ListenAddr: addr.String(),
+       }, 5*time.Second)
+       if err != nil {
+               fmt.Println("Failed to perform handshake:", err)
+               return err
+       }
+
+       return nil
+}
index b3bffe2..0643cbc 100644 (file)
@@ -35,6 +35,10 @@ var (
        ErrConnectSpvPeer    = errors.New("Outbound connect spv peer")
 )
 
+type discv interface {
+       ReadRandomNodes(buf []*discover.Node) (n int)
+}
+
 // Switch handles peer connections and exposes an API to receive incoming messages
 // on `Reactors`.  Each `Reactor` is responsible for handling incoming messages of one
 // or more `Channels`.  So while sending outgoing messages is typically performed on the peer,
@@ -52,14 +56,34 @@ type Switch struct {
        dialing      *cmn.CMap
        nodeInfo     *NodeInfo             // our node info
        nodePrivKey  crypto.PrivKeyEd25519 // our node privkey
-       discv        *discover.Network
+       discv        discv
        bannedPeer   map[string]time.Time
        db           dbm.DB
        mtx          sync.Mutex
 }
 
-// NewSwitch creates a new Switch with the given config.
-func NewSwitch(config *cfg.Config) *Switch {
+// NewSwitch create a new Switch and set discover.
+func NewSwitch(config *cfg.Config) (*Switch, error) {
+       blacklistDB := dbm.NewDB("trusthistory", config.DBBackend, config.DBDir())
+       privKey := crypto.GenPrivKeyEd25519()
+       var l Listener
+       var listenAddr string
+       var err error
+       var discv *discover.Network
+       if !config.VaultMode {
+               // Create listener
+               l, listenAddr = GetListener(config.P2P)
+               discv, err = discover.NewDiscover(config, &privKey, l.ExternalAddress().Port)
+               if err != nil {
+                       return nil, err
+               }
+       }
+
+       return newSwitch(discv, blacklistDB, l, config, privKey, listenAddr)
+}
+
+// newSwitch creates a new Switch with the given config.
+func newSwitch(discv discv, blacklistDB dbm.DB, l Listener, config *cfg.Config, priv crypto.PrivKeyEd25519, listenAddr string) (*Switch, error) {
        sw := &Switch{
                Config:       config,
                peerConfig:   DefaultPeerConfig(config.P2P),
@@ -68,18 +92,20 @@ func NewSwitch(config *cfg.Config) *Switch {
                reactorsByCh: make(map[byte]Reactor),
                peers:        NewPeerSet(),
                dialing:      cmn.NewCMap(),
-               nodeInfo:     nil,
-               db:           dbm.NewDB("trusthistory", config.DBBackend, config.DBDir()),
+               nodePrivKey:  priv,
+               discv:        discv,
+               db:           blacklistDB,
+               nodeInfo:     NewNodeInfo(config, priv.PubKey().Unwrap().(crypto.PubKeyEd25519), listenAddr),
+               bannedPeer:   make(map[string]time.Time),
        }
-       sw.BaseService = *cmn.NewBaseService(nil, "P2P Switch", sw)
-       sw.bannedPeer = make(map[string]time.Time)
-       if datajson := sw.db.Get([]byte(bannedPeerKey)); datajson != nil {
-               if err := json.Unmarshal(datajson, &sw.bannedPeer); err != nil {
-                       return nil
-               }
+       if err := sw.loadBannedPeers(); err != nil {
+               return nil, err
        }
+
+       sw.AddListener(l)
+       sw.BaseService = *cmn.NewBaseService(nil, "P2P Switch", sw)
        trust.Init()
-       return sw
+       return sw, nil
 }
 
 // OnStart implements BaseService. It starts all the reactors, peers, and listeners.
@@ -119,12 +145,12 @@ func (sw *Switch) AddBannedPeer(ip string) error {
        defer sw.mtx.Unlock()
 
        sw.bannedPeer[ip] = time.Now().Add(defaultBanDuration)
-       datajson, err := json.Marshal(sw.bannedPeer)
+       dataJSON, err := json.Marshal(sw.bannedPeer)
        if err != nil {
                return err
        }
 
-       sw.db.Set([]byte(bannedPeerKey), datajson)
+       sw.db.Set([]byte(bannedPeerKey), dataJSON)
        return nil
 }
 
@@ -134,7 +160,7 @@ func (sw *Switch) AddBannedPeer(ip string) error {
 // NOTE: This performs a blocking handshake before the peer is added.
 // CONTRACT: If error is returned, peer is nil, and conn is immediately closed.
 func (sw *Switch) AddPeer(pc *peerConn) error {
-       peerNodeInfo, err := pc.HandshakeTimeout(sw.nodeInfo, time.Duration(sw.peerConfig.HandshakeTimeout))
+       peerNodeInfo, err := pc.HandshakeTimeout(sw.nodeInfo, sw.peerConfig.HandshakeTimeout)
        if err != nil {
                return err
        }
@@ -161,6 +187,7 @@ func (sw *Switch) AddPeer(pc *peerConn) error {
                        return err
                }
        }
+
        return sw.peers.Add(peer)
 }
 
@@ -199,12 +226,12 @@ func (sw *Switch) DialPeerWithAddress(addr *NetAddress) error {
 
        pc, err := newOutboundPeerConn(addr, sw.nodePrivKey, sw.peerConfig)
        if err != nil {
-               log.WithFields(log.Fields{"address": addr, " err": err}).Debug("DialPeer fail on newOutboundPeerConn")
+               log.WithFields(log.Fields{"address": addr, " err": err}).Error("DialPeer fail on newOutboundPeerConn")
                return err
        }
 
        if err = sw.AddPeer(pc); err != nil {
-               log.WithFields(log.Fields{"address": addr, " err": err}).Debug("DialPeer fail on switch AddPeer")
+               log.WithFields(log.Fields{"address": addr, " err": err}).Error("DialPeer fail on switch AddPeer")
                pc.CloseConn()
                return err
        }
@@ -223,6 +250,17 @@ func (sw *Switch) IsListening() bool {
        return len(sw.listeners) > 0
 }
 
+// loadBannedPeers load banned peers from db
+func (sw *Switch) loadBannedPeers() error {
+       if dataJSON := sw.db.Get([]byte(bannedPeerKey)); dataJSON != nil {
+               if err := json.Unmarshal(dataJSON, &sw.bannedPeer); err != nil {
+                       return err
+               }
+       }
+
+       return nil
+}
+
 // Listeners returns the list of listeners the switch listens on.
 // NOTE: Not goroutine safe.
 func (sw *Switch) Listeners() []Listener {
@@ -254,21 +292,6 @@ func (sw *Switch) Peers() *PeerSet {
        return sw.peers
 }
 
-// SetNodeInfo sets the switch's NodeInfo for checking compatibility and handshaking with other nodes.
-// NOTE: Not goroutine safe.
-func (sw *Switch) SetNodeInfo(nodeInfo *NodeInfo) {
-       sw.nodeInfo = nodeInfo
-}
-
-// SetNodePrivKey sets the switch's private key for authenticated encryption.
-// NOTE: Not goroutine safe.
-func (sw *Switch) SetNodePrivKey(nodePrivKey crypto.PrivKeyEd25519) {
-       sw.nodePrivKey = nodePrivKey
-       if sw.nodeInfo != nil {
-               sw.nodeInfo.PubKey = nodePrivKey.PubKey().Unwrap().(crypto.PubKeyEd25519)
-       }
-}
-
 // StopPeerForError disconnects from a peer due to external error.
 func (sw *Switch) StopPeerForError(peer *Peer, reason interface{}) {
        log.WithFields(log.Fields{"peer": peer, " err": reason}).Debug("stopping peer for error")
@@ -285,14 +308,19 @@ func (sw *Switch) StopPeerGracefully(peerID string) {
 func (sw *Switch) addPeerWithConnection(conn net.Conn) error {
        peerConn, err := newInboundPeerConn(conn, sw.nodePrivKey, sw.Config.P2P)
        if err != nil {
-               conn.Close()
+               if err := conn.Close(); err != nil {
+                       log.WithFields(log.Fields{"remote peer:": conn.RemoteAddr().String(), " err:": err}).Error("closes connection err")
+               }
                return err
        }
 
        if err = sw.AddPeer(peerConn); err != nil {
-               conn.Close()
+               if err := conn.Close(); err != nil {
+                       log.WithFields(log.Fields{"remote peer:": conn.RemoteAddr().String(), " err:": err}).Error("closes connection err")
+               }
                return err
        }
+
        return nil
 }
 
@@ -304,7 +332,10 @@ func (sw *Switch) checkBannedPeer(peer string) error {
                if time.Now().Before(banEnd) {
                        return ErrConnectBannedPeer
                }
-               sw.delBannedPeer(peer)
+
+               if err := sw.delBannedPeer(peer); err != nil {
+                       return err
+               }
        }
        return nil
 }
@@ -324,18 +355,18 @@ func (sw *Switch) delBannedPeer(addr string) error {
 }
 
 func (sw *Switch) filterConnByIP(ip string) error {
-       if ip == sw.nodeInfo.ListenHost() {
+       if ip == sw.nodeInfo.listenHost() {
                return ErrConnectSelf
        }
        return sw.checkBannedPeer(ip)
 }
 
 func (sw *Switch) filterConnByPeer(peer *Peer) error {
-       if err := sw.checkBannedPeer(peer.RemoteAddrHost()); err != nil {
+       if err := sw.checkBannedPeer(peer.remoteAddrHost()); err != nil {
                return err
        }
 
-       if sw.nodeInfo.PubKey.Equals(peer.PubKey().Wrap()) {
+       if sw.nodeInfo.getPubkey().Equals(peer.PubKey().Wrap()) {
                return ErrConnectSelf
        }
 
@@ -354,7 +385,9 @@ func (sw *Switch) listenerRoutine(l Listener) {
 
                // disconnect if we alrady have MaxNumPeers
                if sw.peers.Size() >= sw.Config.P2P.MaxNumPeers {
-                       inConn.Close()
+                       if err := inConn.Close(); err != nil {
+                               log.WithFields(log.Fields{"remote peer:": inConn.RemoteAddr().String(), " err:": err}).Error("closes connection err")
+                       }
                        log.Info("Ignoring inbound connection: already have enough peers.")
                        continue
                }
@@ -367,11 +400,6 @@ func (sw *Switch) listenerRoutine(l Listener) {
        }
 }
 
-// SetDiscv connect the discv model to the switch
-func (sw *Switch) SetDiscv(discv *discover.Network) {
-       sw.discv = discv
-}
-
 func (sw *Switch) dialPeerWorker(a *NetAddress, wg *sync.WaitGroup) {
        if err := sw.DialPeerWithAddress(a); err != nil {
                log.WithFields(log.Fields{"addr": a, "err": err}).Error("dialPeerWorker fail on dial peer")
@@ -389,7 +417,7 @@ func (sw *Switch) ensureOutboundPeers() {
 
        connectedPeers := make(map[string]struct{})
        for _, peer := range sw.Peers().List() {
-               connectedPeers[peer.RemoteAddrHost()] = struct{}{}
+               connectedPeers[peer.remoteAddrHost()] = struct{}{}
        }
 
        var wg sync.WaitGroup
@@ -430,7 +458,11 @@ func (sw *Switch) ensureOutboundPeersRoutine() {
 }
 
 func (sw *Switch) startInitPeer(peer *Peer) error {
-       peer.Start() // spawn send/recv routines
+       // spawn send/recv routines
+       if _, err := peer.Start(); err != nil {
+               log.WithFields(log.Fields{"remote peer:": peer.RemoteAddr, " err:": err}).Error("init peer err")
+       }
+
        for _, reactor := range sw.reactors {
                if err := reactor.AddPeer(peer); err != nil {
                        return err
diff --git a/p2p/switch_test.go b/p2p/switch_test.go
new file mode 100644 (file)
index 0000000..ce505a5
--- /dev/null
@@ -0,0 +1,312 @@
+package p2p
+
+import (
+       "io/ioutil"
+       "os"
+       "sync"
+       "testing"
+       "time"
+
+       "github.com/tendermint/go-crypto"
+       dbm "github.com/tendermint/tmlibs/db"
+
+       cfg "github.com/bytom/config"
+       "github.com/bytom/errors"
+       conn "github.com/bytom/p2p/connection"
+)
+
+var (
+       testCfg *cfg.Config
+)
+
+func init() {
+       testCfg = cfg.DefaultConfig()
+}
+
+/*
+Each peer has one `MConnection` (multiplex connection) instance.
+
+__multiplex__ *noun* a system or signal involving simultaneous transmission of
+several messages along a single channel of communication.
+
+Each `MConnection` handles message transmission on multiple abstract communication
+`Channel`s.  Each channel has a globally unique byte id.
+The byte id and the relative priorities of each `Channel` are configured upon
+initialization of the connection.
+
+There are two methods for sending messages:
+       func (m MConnection) Send(chID byte, msgBytes []byte) bool {}
+       func (m MConnection) TrySend(chID byte, msgBytes []byte}) bool {}
+
+`Send(chID, msgBytes)` is a blocking call that waits until `msg` is
+successfully queued for the channel with the given id byte `chID`, or until the
+request times out.  The message `msg` is serialized using Go-Amino.
+
+`TrySend(chID, msgBytes)` is a nonblocking call that returns false if the
+channel's queue is full.
+
+Inbound message bytes are handled with an onReceive callback function.
+*/
+type PeerMessage struct {
+       PeerID  string
+       Bytes   []byte
+       Counter int
+}
+
+type TestReactor struct {
+       BaseReactor
+
+       mtx          sync.Mutex
+       channels     []*conn.ChannelDescriptor
+       logMessages  bool
+       msgsCounter  int
+       msgsReceived map[byte][]PeerMessage
+}
+
+func NewTestReactor(channels []*conn.ChannelDescriptor, logMessages bool) *TestReactor {
+       tr := &TestReactor{
+               channels:     channels,
+               logMessages:  logMessages,
+               msgsReceived: make(map[byte][]PeerMessage),
+       }
+       tr.BaseReactor = *NewBaseReactor("TestReactor", tr)
+
+       return tr
+}
+
+// GetChannels implements Reactor
+func (tr *TestReactor) GetChannels() []*conn.ChannelDescriptor {
+       return tr.channels
+}
+
+// OnStart implements BaseService
+func (tr *TestReactor) OnStart() error {
+       tr.BaseReactor.OnStart()
+       return nil
+}
+
+// OnStop implements BaseService
+func (tr *TestReactor) OnStop() {
+       tr.BaseReactor.OnStop()
+}
+
+// AddPeer implements Reactor by sending our state to peer.
+func (tr *TestReactor) AddPeer(peer *Peer) error {
+       return nil
+}
+
+// RemovePeer implements Reactor by removing peer from the pool.
+func (tr *TestReactor) RemovePeer(peer *Peer, reason interface{}) {
+}
+
+// Receive implements Reactor by handling 4 types of messages (look below).
+func (tr *TestReactor) Receive(chID byte, peer *Peer, msgBytes []byte) {
+       if tr.logMessages {
+               tr.mtx.Lock()
+               defer tr.mtx.Unlock()
+               tr.msgsReceived[chID] = append(tr.msgsReceived[chID], PeerMessage{peer.ID(), msgBytes, tr.msgsCounter})
+               tr.msgsCounter++
+       }
+}
+
+func initSwitchFunc(sw *Switch) *Switch {
+       // Make two reactors of two channels each
+       sw.AddReactor("foo", NewTestReactor([]*conn.ChannelDescriptor{
+               {ID: byte(0x00), Priority: 10},
+               {ID: byte(0x01), Priority: 10},
+       }, true))
+       sw.AddReactor("bar", NewTestReactor([]*conn.ChannelDescriptor{
+               {ID: byte(0x02), Priority: 10},
+               {ID: byte(0x03), Priority: 10},
+       }, true))
+
+       return sw
+}
+
+//Test connect self.
+func TestFiltersOutItself(t *testing.T) {
+       dirPath, err := ioutil.TempDir(".", "")
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer os.RemoveAll(dirPath)
+
+       testDB := dbm.NewDB("testdb", "leveldb", dirPath)
+
+       s1 := MakeSwitch(testCfg, testDB, initSwitchFunc)
+       s1.Start()
+       defer s1.Stop()
+       // simulate s1 having a public key and creating a remote peer with the same key
+       rp := &remotePeer{PrivKey: s1.nodePrivKey, Config: testCfg}
+       rp.Start()
+       defer rp.Stop()
+       if err = s1.DialPeerWithAddress(rp.addr); errors.Root(err) != ErrConnectSelf {
+               t.Fatal(err)
+       }
+
+       //S1 dialing itself ip address
+       addr, err := NewNetAddressString(s1.NodeInfo().ListenAddr)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       if err := s1.DialPeerWithAddress(addr); errors.Root(err) != ErrConnectSelf {
+               t.Fatal(err)
+       }
+}
+
+func TestDialBannedPeer(t *testing.T) {
+       dirPath, err := ioutil.TempDir(".", "")
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer os.RemoveAll(dirPath)
+
+       testDB := dbm.NewDB("testdb", "leveldb", dirPath)
+       s1 := MakeSwitch(testCfg, testDB, initSwitchFunc)
+       s1.Start()
+       defer s1.Stop()
+       rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: testCfg}
+       rp.Start()
+       defer rp.Stop()
+       s1.AddBannedPeer(rp.addr.IP.String())
+       if err := s1.DialPeerWithAddress(rp.addr); errors.Root(err) != ErrConnectBannedPeer {
+               t.Fatal(err)
+       }
+
+       s1.delBannedPeer(rp.addr.IP.String())
+       if err := s1.DialPeerWithAddress(rp.addr); err != nil {
+               t.Fatal(err)
+       }
+}
+
+func TestDuplicatePeer(t *testing.T) {
+       dirPath, err := ioutil.TempDir(".", "")
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer os.RemoveAll(dirPath)
+
+       testDB := dbm.NewDB("testdb", "leveldb", dirPath)
+       s1 := MakeSwitch(testCfg, testDB, initSwitchFunc)
+       s1.Start()
+       defer s1.Stop()
+       rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: testCfg}
+       rp.Start()
+       defer rp.Stop()
+       if err = s1.DialPeerWithAddress(rp.addr); err != nil {
+               t.Fatal(err)
+       }
+
+       if err = s1.DialPeerWithAddress(rp.addr); errors.Root(err) != ErrDuplicatePeer {
+               t.Fatal(err)
+       }
+
+       inp := &inboundPeer{PrivKey: crypto.GenPrivKeyEd25519(), config: testCfg}
+       addr, err := NewNetAddressString(s1.NodeInfo().ListenAddr)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       if err = inp.dial(addr); err != nil {
+               t.Fatal(err)
+       }
+
+       inp1 := &inboundPeer{PrivKey: inp.PrivKey, config: testCfg}
+
+       if err = inp1.dial(addr); err != nil {
+               t.Fatal(err)
+       }
+
+       time.Sleep(1 * time.Second)
+       if outbound, inbound, dialing := s1.NumPeers(); outbound+inbound+dialing != 2 {
+               t.Fatal("TestSwitchAddInboundPeer peer size error")
+       }
+}
+
+func TestAddInboundPeer(t *testing.T) {
+       dirPath, err := ioutil.TempDir(".", "")
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer os.RemoveAll(dirPath)
+
+       testDB := dbm.NewDB("testdb", "leveldb", dirPath)
+       cfg := *testCfg
+       cfg.P2P.MaxNumPeers = 2
+       s1 := MakeSwitch(&cfg, testDB, initSwitchFunc)
+       s1.Start()
+       defer s1.Stop()
+
+       inp := &inboundPeer{PrivKey: crypto.GenPrivKeyEd25519(), config: testCfg}
+       addr, err := NewNetAddressString(s1.NodeInfo().ListenAddr)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       if err := inp.dial(addr); err != nil {
+               t.Fatal(err)
+       }
+
+       rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: testCfg}
+       rp.Start()
+       defer rp.Stop()
+       if err := s1.DialPeerWithAddress(rp.addr); err != nil {
+               t.Fatal(err)
+       }
+
+       if outbound, inbound, dialing := s1.NumPeers(); outbound+inbound+dialing != 2 {
+               t.Fatal("TestSwitchAddInboundPeer peer size error")
+       }
+       inp2 := &inboundPeer{PrivKey: crypto.GenPrivKeyEd25519(), config: testCfg}
+
+       if err := inp2.dial(addr); err == nil {
+               t.Fatal("TestSwitchAddInboundPeer MaxNumPeers limit error")
+       }
+}
+
+func TestStopPeer(t *testing.T) {
+       dirPath, err := ioutil.TempDir(".", "")
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer os.RemoveAll(dirPath)
+
+       testDB := dbm.NewDB("testdb", "leveldb", dirPath)
+       cfg := *testCfg
+       cfg.P2P.MaxNumPeers = 2
+       s1 := MakeSwitch(&cfg, testDB, initSwitchFunc)
+       s1.Start()
+       defer s1.Stop()
+
+       inp := &inboundPeer{PrivKey: crypto.GenPrivKeyEd25519(), config: testCfg}
+       addr, err := NewNetAddressString(s1.NodeInfo().ListenAddr)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       if err := inp.dial(addr); err != nil {
+               t.Fatal(err)
+       }
+
+       rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: testCfg}
+       rp.Start()
+       defer rp.Stop()
+       if err := s1.DialPeerWithAddress(rp.addr); err != nil {
+               t.Fatal(err)
+       }
+
+       if outbound, inbound, dialing := s1.NumPeers(); outbound+inbound+dialing != 2 {
+               t.Fatal("TestSwitchAddInboundPeer peer size error")
+       }
+
+       s1.StopPeerGracefully(s1.peers.list[0].Key)
+       if outbound, inbound, dialing := s1.NumPeers(); outbound+inbound+dialing != 1 {
+               t.Fatal("TestSwitchAddInboundPeer peer size error")
+       }
+
+       s1.StopPeerForError(s1.peers.list[0], "stop for test")
+       if outbound, inbound, dialing := s1.NumPeers(); outbound+inbound+dialing != 0 {
+               t.Fatal("TestSwitchAddInboundPeer peer size error")
+       }
+}
index de18e51..02a300d 100644 (file)
@@ -1,14 +1,17 @@
 package p2p
 
 import (
-       "math/rand"
        "net"
 
+       log "github.com/sirupsen/logrus"
        "github.com/tendermint/go-crypto"
        cmn "github.com/tendermint/tmlibs/common"
+       dbm "github.com/tendermint/tmlibs/db"
 
        cfg "github.com/bytom/config"
+       "github.com/bytom/errors"
        "github.com/bytom/p2p/connection"
+       "github.com/bytom/p2p/discover"
 )
 
 //PanicOnAddPeerErr add peer error
@@ -48,10 +51,13 @@ func CreateRoutableAddr() (addr string, netAddr *NetAddress) {
 // If connect==Connect2Switches, the switches will be fully connected.
 // initSwitch defines how the ith switch should be initialized (ie. with what reactors).
 // NOTE: panics if any switch fails to start.
-func MakeConnectedSwitches(cfg *cfg.Config, n int, initSwitch func(int, *Switch) *Switch, connect func([]*Switch, int, int)) []*Switch {
+func MakeConnectedSwitches(cfg []*cfg.Config, n int, testDB dbm.DB, initSwitch func(*Switch) *Switch, connect func([]*Switch, int, int)) []*Switch {
+       if len(cfg) != n {
+               panic(errors.New("cfg number error"))
+       }
        switches := make([]*Switch, n)
        for i := 0; i < n; i++ {
-               switches[i] = MakeSwitch(cfg, i, "testing", "123.123.123", initSwitch)
+               switches[i] = MakeSwitch(cfg[i], testDB, initSwitch)
        }
 
        if err := startSwitches(switches); err != nil {
@@ -103,18 +109,23 @@ func startSwitches(switches []*Switch) error {
        return nil
 }
 
-func MakeSwitch(cfg *cfg.Config, i int, network, version string, initSwitch func(int, *Switch) *Switch) *Switch {
-       privKey := crypto.GenPrivKeyEd25519()
+type mockDiscv struct {
+}
+
+func (m *mockDiscv) ReadRandomNodes(buf []*discover.Node) (n int) {
+       return 0
+}
+
+func MakeSwitch(cfg *cfg.Config, testdb dbm.DB, initSwitch func(*Switch) *Switch) *Switch {
        // new switch, add reactors
        // TODO: let the config be passed in?
-       s := initSwitch(i, NewSwitch(cfg))
-       s.SetNodeInfo(&NodeInfo{
-               PubKey:     privKey.PubKey().Unwrap().(crypto.PubKeyEd25519),
-               Moniker:    cmn.Fmt("switch%d", i),
-               Network:    network,
-               Version:    version,
-               ListenAddr: cmn.Fmt("%v:%v", network, rand.Intn(64512)+1023),
-       })
-       s.SetNodePrivKey(privKey)
+       privKey := crypto.GenPrivKeyEd25519()
+       l, listenAddr := GetListener(cfg.P2P)
+       sw, err := newSwitch(new(mockDiscv), testdb, l, cfg, privKey, listenAddr)
+       if err != nil {
+               log.Errorf("create switch error: %s", err)
+               return nil
+       }
+       s := initSwitch(sw)
        return s
 }