OSDN Git Service

add test case for send merkle block (#1289)
[bytom/bytom.git] / netsync / block_keeper_test.go
index fa3b6ff..4f4ae15 100644 (file)
@@ -2,6 +2,8 @@ package netsync
 
 import (
        "container/list"
+       "encoding/hex"
+       "encoding/json"
        "testing"
        "time"
 
@@ -181,8 +183,11 @@ func TestFastBlockSync(t *testing.T) {
                netWork := NewNetWork()
                netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
                netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)
-               if err := netWork.HandsShake(a, b); err != nil {
+               if B2A, A2B, err := netWork.HandsShake(a, b); err != nil {
                        t.Errorf("fail on peer hands shake %v", err)
+               } else {
+                       go B2A.postMan()
+                       go A2B.postMan()
                }
 
                a.blockKeeper.syncPeer = a.peers.getPeer("test node B")
@@ -439,8 +444,11 @@ func TestRegularBlockSync(t *testing.T) {
                netWork := NewNetWork()
                netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
                netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)
-               if err := netWork.HandsShake(a, b); err != nil {
+               if B2A, A2B, err := netWork.HandsShake(a, b); err != nil {
                        t.Errorf("fail on peer hands shake %v", err)
+               } else {
+                       go B2A.postMan()
+                       go A2B.postMan()
                }
 
                a.blockKeeper.syncPeer = a.peers.getPeer("test node B")
@@ -470,8 +478,11 @@ func TestRequireBlock(t *testing.T) {
        netWork := NewNetWork()
        netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
        netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)
-       if err := netWork.HandsShake(a, b); err != nil {
+       if B2A, A2B, err := netWork.HandsShake(a, b); err != nil {
                t.Errorf("fail on peer hands shake %v", err)
+       } else {
+               go B2A.postMan()
+               go A2B.postMan()
        }
 
        a.blockKeeper.syncPeer = a.peers.getPeer("test node B")
@@ -510,3 +521,120 @@ func TestRequireBlock(t *testing.T) {
                }
        }
 }
+
+func TestSendMerkleBlock(t *testing.T) {
+       cases := []struct {
+               txCount        int
+               relatedTxIndex []int
+       }{
+               {
+                       txCount:        10,
+                       relatedTxIndex: []int{0, 2, 5},
+               },
+               {
+                       txCount:        0,
+                       relatedTxIndex: []int{},
+               },
+               {
+                       txCount:        10,
+                       relatedTxIndex: []int{},
+               },
+               {
+                       txCount:        5,
+                       relatedTxIndex: []int{0, 1, 2, 3, 4},
+               },
+               {
+                       txCount:        20,
+                       relatedTxIndex: []int{1, 6, 3, 9, 10, 19},
+               },
+       }
+
+       for _, c := range cases {
+               blocks := mockBlocks(nil, 2)
+               targetBlock := blocks[1]
+               txs, bcTxs := mockTxs(c.txCount)
+               var err error
+
+               targetBlock.Transactions = txs
+               if targetBlock.TransactionsMerkleRoot, err = types.TxMerkleRoot(bcTxs); err != nil {
+                       t.Fatal(err)
+               }
+
+               spvNode := mockSync(blocks)
+               blockHash := targetBlock.Hash()
+               var statusResult *bc.TransactionStatus
+               if statusResult, err = spvNode.chain.GetTransactionStatus(&blockHash); err != nil {
+                       t.Fatal(err)
+               }
+               
+               if targetBlock.TransactionStatusHash, err = types.TxStatusMerkleRoot(statusResult.VerifyStatus); err != nil {
+                       t.Fatal(err)
+               }
+               
+               fullNode := mockSync(blocks)
+               netWork := NewNetWork()
+               netWork.Register(spvNode, "192.168.0.1", "spv_node", consensus.SFFastSync)
+               netWork.Register(fullNode, "192.168.0.2", "full_node", consensus.DefaultServices)
+
+               var F2S *P2PPeer
+               if F2S, _, err = netWork.HandsShake(spvNode, fullNode); err != nil {
+                       t.Errorf("fail on peer hands shake %v", err)
+               }
+
+               completed := make(chan error)
+               go func() {
+                       msgBytes := <-F2S.msgCh
+                       _, msg, _ := DecodeMessage(msgBytes)
+                       switch m := msg.(type) {
+                       case *MerkleBlockMessage:
+                               var relatedTxIDs []*bc.Hash
+                               for _, rawTx := range m.RawTxDatas {
+                                       tx := &types.Tx{}
+                                       if err := tx.UnmarshalText(rawTx); err != nil {
+                                               completed <- err
+                                       }
+
+                                       relatedTxIDs = append(relatedTxIDs, &tx.ID)
+                               }
+                               var txHashes []*bc.Hash
+                               for _, hashByte := range m.TxHashes {
+                                       hash := bc.NewHash(hashByte)
+                                       txHashes = append(txHashes, &hash)
+                               }
+                               if ok := types.ValidateTxMerkleTreeProof(txHashes, m.Flags, relatedTxIDs, targetBlock.TransactionsMerkleRoot); !ok {
+                                       completed <- errors.New("validate tx fail")
+                               }
+
+                               var statusHashes []*bc.Hash
+                               for _, statusByte := range m.StatusHashes {
+                                       hash := bc.NewHash(statusByte)
+                                       statusHashes = append(statusHashes, &hash)
+                               }
+                               var relatedStatuses []*bc.TxVerifyResult
+                               for _, statusByte := range m.RawTxStatuses {
+                                       status := &bc.TxVerifyResult{}
+                                       err := json.Unmarshal(statusByte, status)
+                                       if err != nil {
+                                               completed <- err
+                                       }
+                                       relatedStatuses = append(relatedStatuses, status)
+                               }
+                               if ok := types.ValidateStatusMerkleTreeProof(statusHashes, m.Flags, relatedStatuses, targetBlock.TransactionStatusHash); !ok {
+                                       completed <- errors.New("validate status fail")
+                               }
+
+                               completed <- nil
+                       }
+               }()
+
+               spvPeer := fullNode.peers.getPeer("spv_node")
+               for i := 0; i < len(c.relatedTxIndex); i++ {
+                       spvPeer.filterAdds.Add(hex.EncodeToString(txs[c.relatedTxIndex[i]].Outputs[0].ControlProgram))
+               }
+               msg := &GetMerkleBlockMessage{RawHash: targetBlock.Hash().Byte32()}
+               fullNode.handleGetMerkleBlockMsg(spvPeer, msg)
+               if err := <-completed; err != nil {
+                       t.Fatal(err)
+               }
+       }
+}