+
+func TestSendMerkleBlock(t *testing.T) {
+ cases := []struct {
+ txCount int
+ relatedTxIndex []int
+ }{
+ {
+ txCount: 10,
+ relatedTxIndex: []int{0, 2, 5},
+ },
+ {
+ txCount: 0,
+ relatedTxIndex: []int{},
+ },
+ {
+ txCount: 10,
+ relatedTxIndex: []int{},
+ },
+ {
+ txCount: 5,
+ relatedTxIndex: []int{0, 1, 2, 3, 4},
+ },
+ {
+ txCount: 20,
+ relatedTxIndex: []int{1, 6, 3, 9, 10, 19},
+ },
+ }
+
+ for _, c := range cases {
+ blocks := mockBlocks(nil, 2)
+ targetBlock := blocks[1]
+ txs, bcTxs := mockTxs(c.txCount)
+ var err error
+
+ targetBlock.Transactions = txs
+ if targetBlock.TransactionsMerkleRoot, err = types.TxMerkleRoot(bcTxs); err != nil {
+ t.Fatal(err)
+ }
+
+ spvNode := mockSync(blocks)
+ blockHash := targetBlock.Hash()
+ var statusResult *bc.TransactionStatus
+ if statusResult, err = spvNode.chain.GetTransactionStatus(&blockHash); err != nil {
+ t.Fatal(err)
+ }
+
+ if targetBlock.TransactionStatusHash, err = types.TxStatusMerkleRoot(statusResult.VerifyStatus); err != nil {
+ t.Fatal(err)
+ }
+
+ fullNode := mockSync(blocks)
+ netWork := NewNetWork()
+ netWork.Register(spvNode, "192.168.0.1", "spv_node", consensus.SFFastSync)
+ netWork.Register(fullNode, "192.168.0.2", "full_node", consensus.DefaultServices)
+
+ var F2S *P2PPeer
+ if F2S, _, err = netWork.HandsShake(spvNode, fullNode); err != nil {
+ t.Errorf("fail on peer hands shake %v", err)
+ }
+
+ completed := make(chan error)
+ go func() {
+ msgBytes := <-F2S.msgCh
+ _, msg, _ := DecodeMessage(msgBytes)
+ switch m := msg.(type) {
+ case *MerkleBlockMessage:
+ var relatedTxIDs []*bc.Hash
+ for _, rawTx := range m.RawTxDatas {
+ tx := &types.Tx{}
+ if err := tx.UnmarshalText(rawTx); err != nil {
+ completed <- err
+ }
+
+ relatedTxIDs = append(relatedTxIDs, &tx.ID)
+ }
+ var txHashes []*bc.Hash
+ for _, hashByte := range m.TxHashes {
+ hash := bc.NewHash(hashByte)
+ txHashes = append(txHashes, &hash)
+ }
+ if ok := types.ValidateTxMerkleTreeProof(txHashes, m.Flags, relatedTxIDs, targetBlock.TransactionsMerkleRoot); !ok {
+ completed <- errors.New("validate tx fail")
+ }
+
+ var statusHashes []*bc.Hash
+ for _, statusByte := range m.StatusHashes {
+ hash := bc.NewHash(statusByte)
+ statusHashes = append(statusHashes, &hash)
+ }
+ var relatedStatuses []*bc.TxVerifyResult
+ for _, statusByte := range m.RawTxStatuses {
+ status := &bc.TxVerifyResult{}
+ err := json.Unmarshal(statusByte, status)
+ if err != nil {
+ completed <- err
+ }
+ relatedStatuses = append(relatedStatuses, status)
+ }
+ if ok := types.ValidateStatusMerkleTreeProof(statusHashes, m.Flags, relatedStatuses, targetBlock.TransactionStatusHash); !ok {
+ completed <- errors.New("validate status fail")
+ }
+
+ completed <- nil
+ }
+ }()
+
+ spvPeer := fullNode.peers.getPeer("spv_node")
+ for i := 0; i < len(c.relatedTxIndex); i++ {
+ spvPeer.filterAdds.Add(hex.EncodeToString(txs[c.relatedTxIndex[i]].Outputs[0].ControlProgram))
+ }
+ msg := &GetMerkleBlockMessage{RawHash: targetBlock.Hash().Byte32()}
+ fullNode.handleGetMerkleBlockMsg(spvPeer, msg)
+ if err := <-completed; err != nil {
+ t.Fatal(err)
+ }
+ }
+}