import (
"container/list"
+ "encoding/hex"
+ "encoding/json"
"testing"
"time"
netWork := NewNetWork()
netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)
- if err := netWork.HandsShake(a, b); err != nil {
+ if B2A, A2B, err := netWork.HandsShake(a, b); err != nil {
t.Errorf("fail on peer hands shake %v", err)
+ } else {
+ go B2A.postMan()
+ go A2B.postMan()
}
a.blockKeeper.syncPeer = a.peers.getPeer("test node B")
netWork := NewNetWork()
netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)
- if err := netWork.HandsShake(a, b); err != nil {
+ if B2A, A2B, err := netWork.HandsShake(a, b); err != nil {
t.Errorf("fail on peer hands shake %v", err)
+ } else {
+ go B2A.postMan()
+ go A2B.postMan()
}
a.blockKeeper.syncPeer = a.peers.getPeer("test node B")
netWork := NewNetWork()
netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)
- if err := netWork.HandsShake(a, b); err != nil {
+ if B2A, A2B, err := netWork.HandsShake(a, b); err != nil {
t.Errorf("fail on peer hands shake %v", err)
+ } else {
+ go B2A.postMan()
+ go A2B.postMan()
}
a.blockKeeper.syncPeer = a.peers.getPeer("test node B")
}
}
}
+
+func TestSendMerkleBlock(t *testing.T) {
+ cases := []struct {
+ txCount int
+ relatedTxIndex []int
+ }{
+ {
+ txCount: 10,
+ relatedTxIndex: []int{0, 2, 5},
+ },
+ {
+ txCount: 0,
+ relatedTxIndex: []int{},
+ },
+ {
+ txCount: 10,
+ relatedTxIndex: []int{},
+ },
+ {
+ txCount: 5,
+ relatedTxIndex: []int{0, 1, 2, 3, 4},
+ },
+ {
+ txCount: 20,
+ relatedTxIndex: []int{1, 6, 3, 9, 10, 19},
+ },
+ }
+
+ for _, c := range cases {
+ blocks := mockBlocks(nil, 2)
+ targetBlock := blocks[1]
+ txs, bcTxs := mockTxs(c.txCount)
+ var err error
+
+ targetBlock.Transactions = txs
+ if targetBlock.TransactionsMerkleRoot, err = types.TxMerkleRoot(bcTxs); err != nil {
+ t.Fatal(err)
+ }
+
+ spvNode := mockSync(blocks)
+ blockHash := targetBlock.Hash()
+ var statusResult *bc.TransactionStatus
+ if statusResult, err = spvNode.chain.GetTransactionStatus(&blockHash); err != nil {
+ t.Fatal(err)
+ }
+
+ if targetBlock.TransactionStatusHash, err = types.TxStatusMerkleRoot(statusResult.VerifyStatus); err != nil {
+ t.Fatal(err)
+ }
+
+ fullNode := mockSync(blocks)
+ netWork := NewNetWork()
+ netWork.Register(spvNode, "192.168.0.1", "spv_node", consensus.SFFastSync)
+ netWork.Register(fullNode, "192.168.0.2", "full_node", consensus.DefaultServices)
+
+ var F2S *P2PPeer
+ if F2S, _, err = netWork.HandsShake(spvNode, fullNode); err != nil {
+ t.Errorf("fail on peer hands shake %v", err)
+ }
+
+ completed := make(chan error)
+ go func() {
+ msgBytes := <-F2S.msgCh
+ _, msg, _ := DecodeMessage(msgBytes)
+ switch m := msg.(type) {
+ case *MerkleBlockMessage:
+ var relatedTxIDs []*bc.Hash
+ for _, rawTx := range m.RawTxDatas {
+ tx := &types.Tx{}
+ if err := tx.UnmarshalText(rawTx); err != nil {
+ completed <- err
+ }
+
+ relatedTxIDs = append(relatedTxIDs, &tx.ID)
+ }
+ var txHashes []*bc.Hash
+ for _, hashByte := range m.TxHashes {
+ hash := bc.NewHash(hashByte)
+ txHashes = append(txHashes, &hash)
+ }
+ if ok := types.ValidateTxMerkleTreeProof(txHashes, m.Flags, relatedTxIDs, targetBlock.TransactionsMerkleRoot); !ok {
+ completed <- errors.New("validate tx fail")
+ }
+
+ var statusHashes []*bc.Hash
+ for _, statusByte := range m.StatusHashes {
+ hash := bc.NewHash(statusByte)
+ statusHashes = append(statusHashes, &hash)
+ }
+ var relatedStatuses []*bc.TxVerifyResult
+ for _, statusByte := range m.RawTxStatuses {
+ status := &bc.TxVerifyResult{}
+ err := json.Unmarshal(statusByte, status)
+ if err != nil {
+ completed <- err
+ }
+ relatedStatuses = append(relatedStatuses, status)
+ }
+ if ok := types.ValidateStatusMerkleTreeProof(statusHashes, m.Flags, relatedStatuses, targetBlock.TransactionStatusHash); !ok {
+ completed <- errors.New("validate status fail")
+ }
+
+ completed <- nil
+ }
+ }()
+
+ spvPeer := fullNode.peers.getPeer("spv_node")
+ for i := 0; i < len(c.relatedTxIndex); i++ {
+ spvPeer.filterAdds.Add(hex.EncodeToString(txs[c.relatedTxIndex[i]].Outputs[0].ControlProgram))
+ }
+ msg := &GetMerkleBlockMessage{RawHash: targetBlock.Hash().Byte32()}
+ fullNode.handleGetMerkleBlockMsg(spvPeer, msg)
+ if err := <-completed; err != nil {
+ t.Fatal(err)
+ }
+ }
+}