OSDN Git Service

add unit test for netsync (#1162)
authorPaladz <yzhu101@uottawa.ca>
Wed, 25 Jul 2018 03:20:30 +0000 (11:20 +0800)
committerGitHub <noreply@github.com>
Wed, 25 Jul 2018 03:20:30 +0000 (11:20 +0800)
* add unit test for block_keeper

* fix unit test error

* add unit test for locateHeaders

* add unit test for locateBlocks

* add unit test for TestNextCheckpoint

* add unit test for TestRequireBlock

* add unit test for regular block sync

* add unit test for TestFastBlockSync

* fix golint

netsync/block_fetcher.go
netsync/block_keeper.go
netsync/block_keeper_test.go [new file with mode: 0644]
netsync/handle.go
netsync/protocol_reactor.go
netsync/tool_test.go [new file with mode: 0644]
test/mock/chain.go [new file with mode: 0644]

index b55b5c7..2949f7b 100644 (file)
@@ -4,7 +4,6 @@ import (
        log "github.com/sirupsen/logrus"
        "gopkg.in/karalabe/cookiejar.v2/collections/prque"
 
-       "github.com/bytom/protocol"
        "github.com/bytom/protocol/bc"
 )
 
@@ -17,7 +16,7 @@ const (
 // blockFetcher is responsible for accumulating block announcements from various peers
 // and scheduling them for retrieval.
 type blockFetcher struct {
-       chain *protocol.Chain
+       chain Chain
        peers *peerSet
 
        newBlockCh chan *blockMsg
@@ -26,7 +25,7 @@ type blockFetcher struct {
 }
 
 //NewBlockFetcher creates a block fetcher to retrieve blocks of the new mined.
-func newBlockFetcher(chain *protocol.Chain, peers *peerSet) *blockFetcher {
+func newBlockFetcher(chain Chain, peers *peerSet) *blockFetcher {
        f := &blockFetcher{
                chain:      chain,
                peers:      peers,
index e0408c1..3a96d15 100644 (file)
@@ -9,22 +9,23 @@ import (
        "github.com/bytom/consensus"
        "github.com/bytom/errors"
        "github.com/bytom/mining/tensority"
-       "github.com/bytom/protocol"
        "github.com/bytom/protocol/bc"
        "github.com/bytom/protocol/bc/types"
 )
 
 const (
-       syncTimeout           = 30 * time.Second
-       syncCycle             = 5 * time.Second
-       blockProcessChSize    = 1024
-       blocksProcessChSize   = 128
-       headersProcessChSize  = 1024
-       maxBlockPerMsg        = 128
-       maxBlockHeadersPerMsg = 2048
+       syncCycle            = 5 * time.Second
+       blockProcessChSize   = 1024
+       blocksProcessChSize  = 128
+       headersProcessChSize = 1024
 )
 
 var (
+       maxBlockPerMsg        = 128
+       maxBlockHeadersPerMsg = uint64(2048)
+       syncTimeout           = 30 * time.Second
+
+       errAppendHeaders  = errors.New("fail to append list due to order dismatch")
        errRequestTimeout = errors.New("request timeout")
        errPeerDropped    = errors.New("Peer dropped")
        errPeerMisbehave  = errors.New("peer is misbehave")
@@ -46,7 +47,7 @@ type headersMsg struct {
 }
 
 type blockKeeper struct {
-       chain *protocol.Chain
+       chain Chain
        peers *peerSet
 
        syncPeer         *peer
@@ -57,7 +58,7 @@ type blockKeeper struct {
        headerList *list.List
 }
 
-func newBlockKeeper(chain *protocol.Chain, peers *peerSet) *blockKeeper {
+func newBlockKeeper(chain Chain, peers *peerSet) *blockKeeper {
        bk := &blockKeeper{
                chain:            chain,
                peers:            peers,
@@ -75,7 +76,7 @@ func (bk *blockKeeper) appendHeaderList(headers []*types.BlockHeader) error {
        for _, header := range headers {
                prevHeader := bk.headerList.Back().Value.(*types.BlockHeader)
                if prevHeader.Hash() != header.PreviousBlockHash {
-                       return errors.New("fail to append list due to order dismatch")
+                       return errAppendHeaders
                }
                bk.headerList.PushBack(header)
        }
@@ -105,7 +106,7 @@ func (bk *blockKeeper) blockLocator() []*bc.Hash {
                        break
                }
 
-               if len(locator) > 10 {
+               if len(locator) >= 9 {
                        step *= 2
                }
        }
@@ -181,7 +182,7 @@ func (bk *blockKeeper) locateBlocks(locator []*bc.Hash, stopHash *bc.Hash) ([]*t
 
        blocks := []*types.Block{}
        for i, header := range headers {
-               if i > maxBlockPerMsg {
+               if i >= maxBlockPerMsg {
                        break
                }
 
diff --git a/netsync/block_keeper_test.go b/netsync/block_keeper_test.go
new file mode 100644 (file)
index 0000000..57f2c7b
--- /dev/null
@@ -0,0 +1,512 @@
+package netsync
+
+import (
+       "container/list"
+       "testing"
+       "time"
+
+       "github.com/bytom/consensus"
+       "github.com/bytom/errors"
+       "github.com/bytom/protocol/bc"
+       "github.com/bytom/protocol/bc/types"
+       "github.com/bytom/test/mock"
+       "github.com/bytom/testutil"
+)
+
+func TestAppendHeaderList(t *testing.T) {
+       blocks := mockBlocks(nil, 7)
+       cases := []struct {
+               originalHeaders []*types.BlockHeader
+               inputHeaders    []*types.BlockHeader
+               wantHeaders     []*types.BlockHeader
+               err             error
+       }{
+               {
+                       originalHeaders: []*types.BlockHeader{&blocks[0].BlockHeader},
+                       inputHeaders:    []*types.BlockHeader{&blocks[1].BlockHeader, &blocks[2].BlockHeader},
+                       wantHeaders:     []*types.BlockHeader{&blocks[0].BlockHeader, &blocks[1].BlockHeader, &blocks[2].BlockHeader},
+                       err:             nil,
+               },
+               {
+                       originalHeaders: []*types.BlockHeader{&blocks[5].BlockHeader},
+                       inputHeaders:    []*types.BlockHeader{&blocks[6].BlockHeader},
+                       wantHeaders:     []*types.BlockHeader{&blocks[5].BlockHeader, &blocks[6].BlockHeader},
+                       err:             nil,
+               },
+               {
+                       originalHeaders: []*types.BlockHeader{&blocks[5].BlockHeader},
+                       inputHeaders:    []*types.BlockHeader{&blocks[7].BlockHeader},
+                       wantHeaders:     []*types.BlockHeader{&blocks[5].BlockHeader},
+                       err:             errAppendHeaders,
+               },
+               {
+                       originalHeaders: []*types.BlockHeader{&blocks[5].BlockHeader},
+                       inputHeaders:    []*types.BlockHeader{&blocks[7].BlockHeader, &blocks[6].BlockHeader},
+                       wantHeaders:     []*types.BlockHeader{&blocks[5].BlockHeader},
+                       err:             errAppendHeaders,
+               },
+               {
+                       originalHeaders: []*types.BlockHeader{&blocks[2].BlockHeader},
+                       inputHeaders:    []*types.BlockHeader{&blocks[3].BlockHeader, &blocks[4].BlockHeader, &blocks[6].BlockHeader},
+                       wantHeaders:     []*types.BlockHeader{&blocks[2].BlockHeader, &blocks[3].BlockHeader, &blocks[4].BlockHeader},
+                       err:             errAppendHeaders,
+               },
+       }
+
+       for i, c := range cases {
+               bk := &blockKeeper{headerList: list.New()}
+               for _, header := range c.originalHeaders {
+                       bk.headerList.PushBack(header)
+               }
+
+               if err := bk.appendHeaderList(c.inputHeaders); err != c.err {
+                       t.Errorf("case %d: got error %v want error %v", i, err, c.err)
+               }
+
+               gotHeaders := []*types.BlockHeader{}
+               for e := bk.headerList.Front(); e != nil; e = e.Next() {
+                       gotHeaders = append(gotHeaders, e.Value.(*types.BlockHeader))
+               }
+
+               if !testutil.DeepEqual(gotHeaders, c.wantHeaders) {
+                       t.Errorf("case %d: got %v want %v", i, gotHeaders, c.wantHeaders)
+               }
+       }
+}
+
+func TestBlockLocator(t *testing.T) {
+       blocks := mockBlocks(nil, 500)
+       cases := []struct {
+               bestHeight uint64
+               wantHeight []uint64
+       }{
+               {
+                       bestHeight: 0,
+                       wantHeight: []uint64{0},
+               },
+               {
+                       bestHeight: 1,
+                       wantHeight: []uint64{1, 0},
+               },
+               {
+                       bestHeight: 7,
+                       wantHeight: []uint64{7, 6, 5, 4, 3, 2, 1, 0},
+               },
+               {
+                       bestHeight: 10,
+                       wantHeight: []uint64{10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0},
+               },
+               {
+                       bestHeight: 100,
+                       wantHeight: []uint64{100, 99, 98, 97, 96, 95, 94, 93, 92, 91, 89, 85, 77, 61, 29, 0},
+               },
+               {
+                       bestHeight: 500,
+                       wantHeight: []uint64{500, 499, 498, 497, 496, 495, 494, 493, 492, 491, 489, 485, 477, 461, 429, 365, 237, 0},
+               },
+       }
+
+       for i, c := range cases {
+               mockChain := mock.NewChain()
+               bk := &blockKeeper{chain: mockChain}
+               mockChain.SetBestBlockHeader(&blocks[c.bestHeight].BlockHeader)
+               for i := uint64(0); i <= c.bestHeight; i++ {
+                       mockChain.SetBlockByHeight(i, blocks[i])
+               }
+
+               want := []*bc.Hash{}
+               for _, i := range c.wantHeight {
+                       hash := blocks[i].Hash()
+                       want = append(want, &hash)
+               }
+
+               if got := bk.blockLocator(); !testutil.DeepEqual(got, want) {
+                       t.Errorf("case %d: got %v want %v", i, got, want)
+               }
+       }
+}
+
+func TestFastBlockSync(t *testing.T) {
+       maxBlockPerMsg = 5
+       maxBlockHeadersPerMsg = 10
+       baseChain := mockBlocks(nil, 300)
+
+       cases := []struct {
+               syncTimeout time.Duration
+               aBlocks     []*types.Block
+               bBlocks     []*types.Block
+               checkPoint  *consensus.Checkpoint
+               want        []*types.Block
+               err         error
+       }{
+               {
+                       syncTimeout: 30 * time.Second,
+                       aBlocks:     baseChain[:100],
+                       bBlocks:     baseChain[:301],
+                       checkPoint: &consensus.Checkpoint{
+                               Height: baseChain[250].Height,
+                               Hash:   baseChain[250].Hash(),
+                       },
+                       want: baseChain[:251],
+                       err:  nil,
+               },
+               {
+                       syncTimeout: 30 * time.Second,
+                       aBlocks:     baseChain[:100],
+                       bBlocks:     baseChain[:301],
+                       checkPoint: &consensus.Checkpoint{
+                               Height: baseChain[100].Height,
+                               Hash:   baseChain[100].Hash(),
+                       },
+                       want: baseChain[:101],
+                       err:  nil,
+               },
+               {
+                       syncTimeout: 1 * time.Millisecond,
+                       aBlocks:     baseChain[:100],
+                       bBlocks:     baseChain[:100],
+                       checkPoint: &consensus.Checkpoint{
+                               Height: baseChain[200].Height,
+                               Hash:   baseChain[200].Hash(),
+                       },
+                       want: baseChain[:100],
+                       err:  errRequestTimeout,
+               },
+       }
+
+       for i, c := range cases {
+               syncTimeout = c.syncTimeout
+               a := mockSync(c.aBlocks)
+               b := mockSync(c.bBlocks)
+               netWork := NewNetWork()
+               netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
+               netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)
+               if err := netWork.HandsShake(a, b); err != nil {
+                       t.Errorf("fail on peer hands shake %v", err)
+               }
+
+               a.blockKeeper.syncPeer = a.peers.getPeer("test node B")
+               if err := a.blockKeeper.fastBlockSync(c.checkPoint); errors.Root(err) != c.err {
+                       t.Errorf("case %d: got %v want %v", i, err, c.err)
+               }
+
+               got := []*types.Block{}
+               for i := uint64(0); i <= a.chain.BestBlockHeight(); i++ {
+                       block, err := a.chain.GetBlockByHeight(i)
+                       if err != nil {
+                               t.Error("case %d got err %v", err)
+                       }
+                       got = append(got, block)
+               }
+
+               if !testutil.DeepEqual(got, c.want) {
+                       t.Errorf("case %d: got %v want %v", i, got, c.want)
+               }
+       }
+}
+
+func TestLocateBlocks(t *testing.T) {
+       maxBlockPerMsg = 5
+       blocks := mockBlocks(nil, 100)
+       cases := []struct {
+               locator    []uint64
+               stopHash   bc.Hash
+               wantHeight []uint64
+       }{
+               {
+                       locator:    []uint64{20},
+                       stopHash:   blocks[100].Hash(),
+                       wantHeight: []uint64{21, 22, 23, 24, 25},
+               },
+       }
+
+       mockChain := mock.NewChain()
+       bk := &blockKeeper{chain: mockChain}
+       for _, block := range blocks {
+               mockChain.SetBlockByHeight(block.Height, block)
+       }
+
+       for i, c := range cases {
+               locator := []*bc.Hash{}
+               for _, i := range c.locator {
+                       hash := blocks[i].Hash()
+                       locator = append(locator, &hash)
+               }
+
+               want := []*types.Block{}
+               for _, i := range c.wantHeight {
+                       want = append(want, blocks[i])
+               }
+
+               got, _ := bk.locateBlocks(locator, &c.stopHash)
+               if !testutil.DeepEqual(got, want) {
+                       t.Errorf("case %d: got %v want %v", i, got, want)
+               }
+       }
+}
+
+func TestLocateHeaders(t *testing.T) {
+       maxBlockHeadersPerMsg = 10
+       blocks := mockBlocks(nil, 150)
+       cases := []struct {
+               chainHeight uint64
+               locator     []uint64
+               stopHash    bc.Hash
+               wantHeight  []uint64
+               err         bool
+       }{
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{},
+                       stopHash:    blocks[100].Hash(),
+                       wantHeight:  []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
+                       err:         false,
+               },
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{20},
+                       stopHash:    blocks[100].Hash(),
+                       wantHeight:  []uint64{21, 22, 23, 24, 25, 26, 27, 28, 29, 30},
+                       err:         false,
+               },
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{20},
+                       stopHash:    blocks[24].Hash(),
+                       wantHeight:  []uint64{21, 22, 23, 24},
+                       err:         false,
+               },
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{20},
+                       stopHash:    blocks[20].Hash(),
+                       wantHeight:  []uint64{},
+                       err:         false,
+               },
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{20},
+                       stopHash:    bc.Hash{},
+                       wantHeight:  []uint64{},
+                       err:         true,
+               },
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{120, 70},
+                       stopHash:    blocks[78].Hash(),
+                       wantHeight:  []uint64{71, 72, 73, 74, 75, 76, 77, 78},
+                       err:         false,
+               },
+       }
+
+       for i, c := range cases {
+               mockChain := mock.NewChain()
+               bk := &blockKeeper{chain: mockChain}
+               for i := uint64(0); i <= c.chainHeight; i++ {
+                       mockChain.SetBlockByHeight(i, blocks[i])
+               }
+
+               locator := []*bc.Hash{}
+               for _, i := range c.locator {
+                       hash := blocks[i].Hash()
+                       locator = append(locator, &hash)
+               }
+
+               want := []*types.BlockHeader{}
+               for _, i := range c.wantHeight {
+                       want = append(want, &blocks[i].BlockHeader)
+               }
+
+               got, err := bk.locateHeaders(locator, &c.stopHash)
+               if err != nil != c.err {
+                       t.Errorf("case %d: got %v want err = %v", i, err, c.err)
+               }
+               if !testutil.DeepEqual(got, want) {
+                       t.Errorf("case %d: got %v want %v", i, got, want)
+               }
+       }
+}
+
+func TestNextCheckpoint(t *testing.T) {
+       cases := []struct {
+               checkPoints []consensus.Checkpoint
+               bestHeight  uint64
+               want        *consensus.Checkpoint
+       }{
+               {
+                       checkPoints: []consensus.Checkpoint{},
+                       bestHeight:  5000,
+                       want:        nil,
+               },
+               {
+                       checkPoints: []consensus.Checkpoint{
+                               {10000, bc.Hash{V0: 1}},
+                       },
+                       bestHeight: 5000,
+                       want:       &consensus.Checkpoint{10000, bc.Hash{V0: 1}},
+               },
+               {
+                       checkPoints: []consensus.Checkpoint{
+                               {10000, bc.Hash{V0: 1}},
+                               {20000, bc.Hash{V0: 2}},
+                               {30000, bc.Hash{V0: 3}},
+                       },
+                       bestHeight: 15000,
+                       want:       &consensus.Checkpoint{20000, bc.Hash{V0: 2}},
+               },
+               {
+                       checkPoints: []consensus.Checkpoint{
+                               {10000, bc.Hash{V0: 1}},
+                               {20000, bc.Hash{V0: 2}},
+                               {30000, bc.Hash{V0: 3}},
+                       },
+                       bestHeight: 10000,
+                       want:       &consensus.Checkpoint{20000, bc.Hash{V0: 2}},
+               },
+               {
+                       checkPoints: []consensus.Checkpoint{
+                               {10000, bc.Hash{V0: 1}},
+                               {20000, bc.Hash{V0: 2}},
+                               {30000, bc.Hash{V0: 3}},
+                       },
+                       bestHeight: 35000,
+                       want:       nil,
+               },
+       }
+
+       mockChain := mock.NewChain()
+       for i, c := range cases {
+               consensus.ActiveNetParams.Checkpoints = c.checkPoints
+               mockChain.SetBestBlockHeader(&types.BlockHeader{Height: c.bestHeight})
+               bk := &blockKeeper{chain: mockChain}
+
+               if got := bk.nextCheckpoint(); !testutil.DeepEqual(got, c.want) {
+                       t.Errorf("case %d: got %v want %v", i, got, c.want)
+               }
+       }
+}
+
+func TestRegularBlockSync(t *testing.T) {
+       baseChain := mockBlocks(nil, 50)
+       chainX := append(baseChain, mockBlocks(baseChain[50], 60)...)
+       chainY := append(baseChain, mockBlocks(baseChain[50], 70)...)
+       cases := []struct {
+               syncTimeout time.Duration
+               aBlocks     []*types.Block
+               bBlocks     []*types.Block
+               syncHeight  uint64
+               want        []*types.Block
+               err         error
+       }{
+               {
+                       syncTimeout: 30 * time.Second,
+                       aBlocks:     baseChain[:20],
+                       bBlocks:     baseChain[:50],
+                       syncHeight:  45,
+                       want:        baseChain[:46],
+                       err:         nil,
+               },
+               {
+                       syncTimeout: 30 * time.Second,
+                       aBlocks:     chainX,
+                       bBlocks:     chainY,
+                       syncHeight:  70,
+                       want:        chainY,
+                       err:         nil,
+               },
+               {
+                       syncTimeout: 30 * time.Second,
+                       aBlocks:     chainX[:52],
+                       bBlocks:     chainY[:53],
+                       syncHeight:  52,
+                       want:        chainY[:53],
+                       err:         nil,
+               },
+               {
+                       syncTimeout: 1 * time.Millisecond,
+                       aBlocks:     baseChain,
+                       bBlocks:     baseChain,
+                       syncHeight:  52,
+                       want:        baseChain,
+                       err:         errRequestTimeout,
+               },
+       }
+
+       for i, c := range cases {
+               syncTimeout = c.syncTimeout
+               a := mockSync(c.aBlocks)
+               b := mockSync(c.bBlocks)
+               netWork := NewNetWork()
+               netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
+               netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)
+               if err := netWork.HandsShake(a, b); err != nil {
+                       t.Errorf("fail on peer hands shake %v", err)
+               }
+
+               a.blockKeeper.syncPeer = a.peers.getPeer("test node B")
+               if err := a.blockKeeper.regularBlockSync(c.syncHeight); errors.Root(err) != c.err {
+                       t.Errorf("case %d: got %v want %v", i, err, c.err)
+               }
+
+               got := []*types.Block{}
+               for i := uint64(0); i <= a.chain.BestBlockHeight(); i++ {
+                       block, err := a.chain.GetBlockByHeight(i)
+                       if err != nil {
+                               t.Error("case %d got err %v", err)
+                       }
+                       got = append(got, block)
+               }
+
+               if !testutil.DeepEqual(got, c.want) {
+                       t.Errorf("case %d: got %v want %v", i, got, c.want)
+               }
+       }
+}
+
+func TestRequireBlock(t *testing.T) {
+       blocks := mockBlocks(nil, 5)
+       a := mockSync(blocks[:1])
+       b := mockSync(blocks[:5])
+       netWork := NewNetWork()
+       netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
+       netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)
+       if err := netWork.HandsShake(a, b); err != nil {
+               t.Errorf("fail on peer hands shake %v", err)
+       }
+
+       a.blockKeeper.syncPeer = a.peers.getPeer("test node B")
+       b.blockKeeper.syncPeer = b.peers.getPeer("test node A")
+       cases := []struct {
+               syncTimeout   time.Duration
+               testNode      *SyncManager
+               requireHeight uint64
+               want          *types.Block
+               err           error
+       }{
+               {
+                       syncTimeout:   30 * time.Second,
+                       testNode:      a,
+                       requireHeight: 4,
+                       want:          blocks[4],
+                       err:           nil,
+               },
+               {
+                       syncTimeout:   1 * time.Millisecond,
+                       testNode:      b,
+                       requireHeight: 4,
+                       want:          nil,
+                       err:           errRequestTimeout,
+               },
+       }
+
+       for i, c := range cases {
+               syncTimeout = c.syncTimeout
+               got, err := c.testNode.blockKeeper.requireBlock(c.requireHeight)
+               if !testutil.DeepEqual(got, c.want) {
+                       t.Errorf("case %d: got %v want %v", i, got, c.want)
+               }
+               if errors.Root(err) != c.err {
+                       t.Errorf("case %d: got %v want %v", i, err, c.err)
+               }
+       }
+}
index 033a93b..6a4526b 100644 (file)
@@ -5,6 +5,7 @@ import (
        "errors"
        "net"
        "path"
+       "reflect"
        "strconv"
        "strings"
 
@@ -27,13 +28,27 @@ const (
        maxTxChanSize = 10000
 )
 
+// Chain is the interface for Bytom core
+type Chain interface {
+       BestBlockHeader() *types.BlockHeader
+       BestBlockHeight() uint64
+       CalcNextSeed(*bc.Hash) (*bc.Hash, error)
+       GetBlockByHash(*bc.Hash) (*types.Block, error)
+       GetBlockByHeight(uint64) (*types.Block, error)
+       GetHeaderByHash(*bc.Hash) (*types.BlockHeader, error)
+       GetHeaderByHeight(uint64) (*types.BlockHeader, error)
+       InMainChain(bc.Hash) bool
+       ProcessBlock(*types.Block) (bool, error)
+       ValidateTx(*types.Tx) (bool, error)
+}
+
 //SyncManager Sync Manager is responsible for the business layer information synchronization
 type SyncManager struct {
        sw          *p2p.Switch
        genesisHash bc.Hash
 
        privKey      crypto.PrivKeyEd25519 // local node's p2p key
-       chain        *core.Chain
+       chain        Chain
        txPool       *core.TxPool
        blockFetcher *blockFetcher
        blockKeeper  *blockKeeper
@@ -47,7 +62,7 @@ type SyncManager struct {
 }
 
 //NewSyncManager create a sync manager
-func NewSyncManager(config *cfg.Config, chain *core.Chain, txPool *core.TxPool, newBlockCh chan *bc.Hash) (*SyncManager, error) {
+func NewSyncManager(config *cfg.Config, chain Chain, txPool *core.TxPool, newBlockCh chan *bc.Hash) (*SyncManager, error) {
        genesisHeader, err := chain.GetHeaderByHeight(0)
        if err != nil {
                return nil, err
@@ -288,6 +303,48 @@ func (sm *SyncManager) handleTransactionMsg(peer *peer, msg *TransactionMessage)
        }
 }
 
+func (sm *SyncManager) processMsg(basePeer BasePeer, msgType byte, msg BlockchainMessage) {
+       peer := sm.peers.getPeer(basePeer.ID())
+       if peer == nil && msgType != StatusResponseByte && msgType != StatusRequestByte {
+               return
+       }
+
+       switch msg := msg.(type) {
+       case *GetBlockMessage:
+               sm.handleGetBlockMsg(peer, msg)
+
+       case *BlockMessage:
+               sm.handleBlockMsg(peer, msg)
+
+       case *StatusRequestMessage:
+               sm.handleStatusRequestMsg(basePeer)
+
+       case *StatusResponseMessage:
+               sm.handleStatusResponseMsg(basePeer, msg)
+
+       case *TransactionMessage:
+               sm.handleTransactionMsg(peer, msg)
+
+       case *MineBlockMessage:
+               sm.handleMineBlockMsg(peer, msg)
+
+       case *GetHeadersMessage:
+               sm.handleGetHeadersMsg(peer, msg)
+
+       case *HeadersMessage:
+               sm.handleHeadersMsg(peer, msg)
+
+       case *GetBlocksMessage:
+               sm.handleGetBlocksMsg(peer, msg)
+
+       case *BlocksMessage:
+               sm.handleBlocksMsg(peer, msg)
+
+       default:
+               log.Errorf("unknown message type %v", reflect.TypeOf(msg))
+       }
+}
+
 // Defaults to tcp
 func protocolAndAddress(listenAddr string) (string, string) {
        p, address := "tcp", listenAddr
index cc45b88..9220e86 100644 (file)
@@ -1,7 +1,6 @@
 package netsync
 
 import (
-       "reflect"
        "time"
 
        log "github.com/sirupsen/logrus"
@@ -96,43 +95,5 @@ func (pr *ProtocolReactor) Receive(chID byte, src *p2p.Peer, msgBytes []byte) {
                return
        }
 
-       peer := pr.peers.getPeer(src.Key)
-       if peer == nil && msgType != StatusResponseByte && msgType != StatusRequestByte {
-               return
-       }
-
-       switch msg := msg.(type) {
-       case *GetBlockMessage:
-               pr.sm.handleGetBlockMsg(peer, msg)
-
-       case *BlockMessage:
-               pr.sm.handleBlockMsg(peer, msg)
-
-       case *StatusRequestMessage:
-               pr.sm.handleStatusRequestMsg(src)
-
-       case *StatusResponseMessage:
-               pr.sm.handleStatusResponseMsg(src, msg)
-
-       case *TransactionMessage:
-               pr.sm.handleTransactionMsg(peer, msg)
-
-       case *MineBlockMessage:
-               pr.sm.handleMineBlockMsg(peer, msg)
-
-       case *GetHeadersMessage:
-               pr.sm.handleGetHeadersMsg(peer, msg)
-
-       case *HeadersMessage:
-               pr.sm.handleHeadersMsg(peer, msg)
-
-       case *GetBlocksMessage:
-               pr.sm.handleGetBlocksMsg(peer, msg)
-
-       case *BlocksMessage:
-               pr.sm.handleBlocksMsg(peer, msg)
-
-       default:
-               log.Errorf("unknown message type %v", reflect.TypeOf(msg))
-       }
+       pr.sm.processMsg(src, msgType, msg)
 }
diff --git a/netsync/tool_test.go b/netsync/tool_test.go
new file mode 100644 (file)
index 0000000..88e6d37
--- /dev/null
@@ -0,0 +1,159 @@
+package netsync
+
+import (
+       "errors"
+       "math/rand"
+       "net"
+
+       wire "github.com/tendermint/go-wire"
+
+       "github.com/bytom/consensus"
+       "github.com/bytom/protocol/bc/types"
+       "github.com/bytom/test/mock"
+)
+
+type P2PPeer struct {
+       id   string
+       ip   *net.IPAddr
+       flag consensus.ServiceFlag
+
+       srcPeer    *P2PPeer
+       remoteNode *SyncManager
+       msgCh      chan []byte
+       async      bool
+}
+
+func NewP2PPeer(addr, id string, flag consensus.ServiceFlag) *P2PPeer {
+       return &P2PPeer{
+               id:    id,
+               ip:    &net.IPAddr{IP: net.ParseIP(addr)},
+               flag:  flag,
+               msgCh: make(chan []byte),
+               async: false,
+       }
+}
+
+func (p *P2PPeer) Addr() net.Addr {
+       return p.ip
+}
+
+func (p *P2PPeer) ID() string {
+       return p.id
+}
+
+func (p *P2PPeer) ServiceFlag() consensus.ServiceFlag {
+       return p.flag
+}
+
+func (p *P2PPeer) SetConnection(srcPeer *P2PPeer, node *SyncManager) {
+       p.srcPeer = srcPeer
+       p.remoteNode = node
+}
+
+func (p *P2PPeer) TrySend(b byte, msg interface{}) bool {
+       msgBytes := wire.BinaryBytes(msg)
+       if p.async {
+               p.msgCh <- msgBytes
+       } else {
+               msgType, msg, _ := DecodeMessage(msgBytes)
+               p.remoteNode.processMsg(p.srcPeer, msgType, msg)
+       }
+       return true
+}
+
+func (p *P2PPeer) setAsync(b bool) {
+       p.async = b
+}
+
+func (p *P2PPeer) postMan() {
+       for msgBytes := range p.msgCh {
+               msgType, msg, _ := DecodeMessage(msgBytes)
+               p.remoteNode.processMsg(p.srcPeer, msgType, msg)
+       }
+}
+
+type PeerSet struct{}
+
+func NewPeerSet() *PeerSet {
+       return &PeerSet{}
+}
+
+func (ps *PeerSet) AddBannedPeer(string) error { return nil }
+func (ps *PeerSet) StopPeerGracefully(string)  {}
+
+type NetWork struct {
+       nodes map[*SyncManager]P2PPeer
+}
+
+func NewNetWork() *NetWork {
+       return &NetWork{map[*SyncManager]P2PPeer{}}
+}
+
+func (nw *NetWork) Register(node *SyncManager, addr, id string, flag consensus.ServiceFlag) {
+       peer := NewP2PPeer(addr, id, flag)
+       nw.nodes[node] = *peer
+}
+
+func (nw *NetWork) HandsShake(nodeA, nodeB *SyncManager) error {
+       B2A, ok := nw.nodes[nodeA]
+       if !ok {
+               return errors.New("can't find nodeA's p2p peer on network")
+       }
+       A2B, ok := nw.nodes[nodeB]
+       if !ok {
+               return errors.New("can't find nodeB's p2p peer on network")
+       }
+
+       A2B.SetConnection(&B2A, nodeB)
+       B2A.SetConnection(&A2B, nodeA)
+       go A2B.postMan()
+       go B2A.postMan()
+
+       nodeA.handleStatusRequestMsg(&A2B)
+       nodeB.handleStatusRequestMsg(&B2A)
+
+       A2B.setAsync(true)
+       B2A.setAsync(true)
+       return nil
+}
+
+func mockBlocks(startBlock *types.Block, height uint64) []*types.Block {
+       blocks := []*types.Block{}
+       indexBlock := &types.Block{}
+       if startBlock == nil {
+               indexBlock = &types.Block{BlockHeader: types.BlockHeader{Nonce: uint64(rand.Uint32())}}
+               blocks = append(blocks, indexBlock)
+       } else {
+               indexBlock = startBlock
+       }
+
+       for indexBlock.Height < height {
+               block := &types.Block{
+                       BlockHeader: types.BlockHeader{
+                               Height:            indexBlock.Height + 1,
+                               PreviousBlockHash: indexBlock.Hash(),
+                               Nonce:             uint64(rand.Uint32()),
+                       },
+               }
+               blocks = append(blocks, block)
+               indexBlock = block
+       }
+       return blocks
+}
+
+func mockSync(blocks []*types.Block) *SyncManager {
+       chain := mock.NewChain()
+       peers := newPeerSet(NewPeerSet())
+       chain.SetBestBlockHeader(&blocks[len(blocks)-1].BlockHeader)
+       for _, block := range blocks {
+               chain.SetBlockByHeight(block.Height, block)
+       }
+
+       genesis, _ := chain.GetHeaderByHeight(0)
+       return &SyncManager{
+               genesisHash: genesis.Hash(),
+               chain:       chain,
+               blockKeeper: newBlockKeeper(chain, peers),
+               peers:       peers,
+       }
+}
diff --git a/test/mock/chain.go b/test/mock/chain.go
new file mode 100644 (file)
index 0000000..a3c1c64
--- /dev/null
@@ -0,0 +1,120 @@
+package mock
+
+import (
+       "errors"
+
+       "github.com/bytom/protocol/bc"
+       "github.com/bytom/protocol/bc/types"
+)
+
+type Chain struct {
+       bestBlockHeader *types.BlockHeader
+       heightMap       map[uint64]*types.Block
+       blockMap        map[bc.Hash]*types.Block
+
+       prevOrphans map[bc.Hash]*types.Block
+}
+
+func NewChain() *Chain {
+       return &Chain{
+               heightMap:   map[uint64]*types.Block{},
+               blockMap:    map[bc.Hash]*types.Block{},
+               prevOrphans: make(map[bc.Hash]*types.Block),
+       }
+}
+
+func (c *Chain) BestBlockHeader() *types.BlockHeader {
+       return c.bestBlockHeader
+}
+
+func (c *Chain) BestBlockHeight() uint64 {
+       return c.bestBlockHeader.Height
+}
+
+func (c *Chain) CalcNextSeed(hash *bc.Hash) (*bc.Hash, error) {
+       return &bc.Hash{V0: hash.V1, V1: hash.V2, V2: hash.V3, V3: hash.V0}, nil
+}
+
+func (c *Chain) GetBlockByHash(hash *bc.Hash) (*types.Block, error) {
+       block, ok := c.blockMap[*hash]
+       if !ok {
+               return nil, errors.New("can't find block")
+       }
+       return block, nil
+}
+
+func (c *Chain) GetBlockByHeight(height uint64) (*types.Block, error) {
+       block, ok := c.heightMap[height]
+       if !ok {
+               return nil, errors.New("can't find block")
+       }
+       return block, nil
+}
+
+func (c *Chain) GetHeaderByHash(hash *bc.Hash) (*types.BlockHeader, error) {
+       block, ok := c.blockMap[*hash]
+       if !ok {
+               return nil, errors.New("can't find block")
+       }
+       return &block.BlockHeader, nil
+}
+
+func (c *Chain) GetHeaderByHeight(height uint64) (*types.BlockHeader, error) {
+       block, ok := c.heightMap[height]
+       if !ok {
+               return nil, errors.New("can't find block")
+       }
+       return &block.BlockHeader, nil
+}
+
+func (c *Chain) InMainChain(hash bc.Hash) bool {
+       block, ok := c.blockMap[hash]
+       if !ok {
+               return false
+       }
+       return c.heightMap[block.Height] == block
+}
+
+func (c *Chain) ProcessBlock(block *types.Block) (bool, error) {
+       if c.bestBlockHeader.Hash() == block.PreviousBlockHash {
+               c.heightMap[block.Height] = block
+               c.blockMap[block.Hash()] = block
+               c.bestBlockHeader = &block.BlockHeader
+               return false, nil
+       }
+
+       if _, ok := c.blockMap[block.PreviousBlockHash]; !ok {
+               c.prevOrphans[block.PreviousBlockHash] = block
+               return true, nil
+       }
+
+       c.blockMap[block.Hash()] = block
+       for c.prevOrphans[block.Hash()] != nil {
+               block = c.prevOrphans[block.Hash()]
+               c.blockMap[block.Hash()] = block
+       }
+
+       if block.Height < c.bestBlockHeader.Height {
+               return false, nil
+       }
+
+       c.bestBlockHeader = &block.BlockHeader
+       for !c.InMainChain(block.Hash()) {
+               c.heightMap[block.Height] = block
+               block = c.blockMap[block.PreviousBlockHash]
+       }
+       return false, nil
+}
+
+func (c *Chain) SetBestBlockHeader(header *types.BlockHeader) {
+       c.bestBlockHeader = header
+}
+
+func (c *Chain) SetBlockByHeight(height uint64, block *types.Block) {
+       c.heightMap[height] = block
+       c.blockMap[block.Hash()] = block
+}
+
+func (c *Chain) ValidateTx(*types.Tx) (bool, error) {
+       return false, nil
+}