OSDN Git Service

fix rescan wallet (#1108)
[bytom/bytom-spv.git] / wallet / wallet.go
1 package wallet
2
3 import (
4         "encoding/json"
5         "sync"
6
7         log "github.com/sirupsen/logrus"
8         "github.com/tendermint/tmlibs/db"
9
10         "github.com/bytom/account"
11         "github.com/bytom/asset"
12         "github.com/bytom/blockchain/pseudohsm"
13         "github.com/bytom/protocol"
14         "github.com/bytom/protocol/bc"
15         "github.com/bytom/protocol/bc/types"
16 )
17
const (
	// SINGLE is the value used to denote single-signature mode.
	SINGLE = 1

	// maxTxChanSize is the buffer size of the channel that listens for new
	// transactions coming from the Txpool (newTxCh).
	maxTxChanSize = 10000
)

// walletKey is the database key under which the wallet status is stored.
var walletKey = []byte("walletInfo")
26
27 //StatusInfo is base valid block info to handle orphan block rollback
28 type StatusInfo struct {
29         WorkHeight uint64
30         WorkHash   bc.Hash
31         BestHeight uint64
32         BestHash   bc.Hash
33 }
34
35 //Wallet is related to storing account unspent outputs
36 type Wallet struct {
37         DB         db.DB
38         rw         sync.RWMutex
39         status     StatusInfo
40         AccountMgr *account.Manager
41         AssetReg   *asset.Registry
42         Hsm        *pseudohsm.HSM
43         chain      *protocol.Chain
44         rescanCh   chan struct{}
45         newTxCh    chan *types.Tx
46 }
47
48 //NewWallet return a new wallet instance
49 func NewWallet(walletDB db.DB, account *account.Manager, asset *asset.Registry, hsm *pseudohsm.HSM, chain *protocol.Chain) (*Wallet, error) {
50         w := &Wallet{
51                 DB:         walletDB,
52                 AccountMgr: account,
53                 AssetReg:   asset,
54                 chain:      chain,
55                 Hsm:        hsm,
56                 rescanCh:   make(chan struct{}, 1),
57                 newTxCh:    make(chan *types.Tx, maxTxChanSize),
58         }
59
60         if err := w.loadWalletInfo(); err != nil {
61                 return nil, err
62         }
63
64         go w.walletUpdater()
65         go w.UnconfirmedTxCollector()
66
67         return w, nil
68 }
69
70 //GetWalletInfo return stored wallet info and nil,if error,
71 //return initial wallet info and err
72 func (w *Wallet) loadWalletInfo() error {
73         if rawWallet := w.DB.Get(walletKey); rawWallet != nil {
74                 return json.Unmarshal(rawWallet, &w.status)
75         }
76
77         block, err := w.chain.GetBlockByHeight(0)
78         if err != nil {
79                 return err
80         }
81         return w.AttachBlock(block)
82 }
83
84 func (w *Wallet) commitWalletInfo(batch db.Batch) error {
85         rawWallet, err := json.Marshal(w.status)
86         if err != nil {
87                 log.WithField("err", err).Error("save wallet info")
88                 return err
89         }
90
91         batch.Set(walletKey, rawWallet)
92         batch.Write()
93         return nil
94 }
95
96 // AttachBlock attach a new block
97 func (w *Wallet) AttachBlock(block *types.Block) error {
98         w.rw.Lock()
99         defer w.rw.Unlock()
100
101         if block.PreviousBlockHash != w.status.WorkHash {
102                 log.Warn("wallet skip attachBlock due to status hash not equal to previous hash")
103                 return nil
104         }
105
106         blockHash := block.Hash()
107         txStatus, err := w.chain.GetTransactionStatus(&blockHash)
108         if err != nil {
109                 return err
110         }
111
112         storeBatch := w.DB.NewBatch()
113         w.indexTransactions(storeBatch, block, txStatus)
114         w.buildAccountUTXOs(storeBatch, block, txStatus)
115
116         w.status.WorkHeight = block.Height
117         w.status.WorkHash = block.Hash()
118         if w.status.WorkHeight >= w.status.BestHeight {
119                 w.status.BestHeight = w.status.WorkHeight
120                 w.status.BestHash = w.status.WorkHash
121         }
122         return w.commitWalletInfo(storeBatch)
123 }
124
125 // DetachBlock detach a block and rollback state
126 func (w *Wallet) DetachBlock(block *types.Block) error {
127         w.rw.Lock()
128         defer w.rw.Unlock()
129
130         blockHash := block.Hash()
131         txStatus, err := w.chain.GetTransactionStatus(&blockHash)
132         if err != nil {
133                 return err
134         }
135
136         storeBatch := w.DB.NewBatch()
137         w.reverseAccountUTXOs(storeBatch, block, txStatus)
138         w.deleteTransactions(storeBatch, w.status.BestHeight)
139
140         w.status.BestHeight = block.Height - 1
141         w.status.BestHash = block.PreviousBlockHash
142
143         if w.status.WorkHeight > w.status.BestHeight {
144                 w.status.WorkHeight = w.status.BestHeight
145                 w.status.WorkHash = w.status.BestHash
146         }
147
148         return w.commitWalletInfo(storeBatch)
149 }
150
151 //WalletUpdate process every valid block and reverse every invalid block which need to rollback
152 func (w *Wallet) walletUpdater() {
153         for {
154                 w.getRescanNotification()
155                 for !w.chain.InMainChain(w.status.BestHash) {
156                         block, err := w.chain.GetBlockByHash(&w.status.BestHash)
157                         if err != nil {
158                                 log.WithField("err", err).Error("walletUpdater GetBlockByHash")
159                                 return
160                         }
161
162                         if err := w.DetachBlock(block); err != nil {
163                                 log.WithField("err", err).Error("walletUpdater detachBlock stop")
164                                 return
165                         }
166                 }
167
168                 block, _ := w.chain.GetBlockByHeight(w.status.WorkHeight + 1)
169                 if block == nil {
170                         w.walletBlockWaiter()
171                         continue
172                 }
173
174                 if err := w.AttachBlock(block); err != nil {
175                         log.WithField("err", err).Error("walletUpdater AttachBlock stop")
176                         return
177                 }
178         }
179 }
180
181 //RescanBlocks provide a trigger to rescan blocks
182 func (w *Wallet) RescanBlocks() {
183         select {
184         case w.rescanCh <- struct{}{}:
185         default:
186                 return
187         }
188 }
189
190 func (w *Wallet) getRescanNotification() {
191         select {
192         case <-w.rescanCh:
193                 w.setRescanStatus()
194         default:
195                 return
196         }
197 }
198
199 func (w *Wallet) setRescanStatus() {
200         block, _ := w.chain.GetBlockByHeight(0)
201         w.status.WorkHash = bc.Hash{}
202         w.AttachBlock(block)
203 }
204
205 func (w *Wallet) walletBlockWaiter() {
206         select {
207         case <-w.chain.BlockWaiter(w.status.WorkHeight + 1):
208         case <-w.rescanCh:
209                 w.setRescanStatus()
210         }
211 }
212
213 // GetNewTxCh return a unconfirmed transaction feed channel
214 func (w *Wallet) GetNewTxCh() chan *types.Tx {
215         return w.newTxCh
216 }
217
218 // UnconfirmedTxCollector collector unconfirmed transaction
219 func (w *Wallet) UnconfirmedTxCollector() {
220         for {
221                 w.SaveUnconfirmedTx(<-w.newTxCh)
222         }
223 }
224
225 // GetWalletStatusInfo return current wallet StatusInfo
226 func (w *Wallet) GetWalletStatusInfo() StatusInfo {
227         w.rw.RLock()
228         defer w.rw.RUnlock()
229
230         return w.status
231 }
232
233 func (w *Wallet) createProgram(account *account.Account, XPub *pseudohsm.XPub, index uint64) error {
234         for i := uint64(0); i < index; i++ {
235                 if _, err := w.AccountMgr.CreateAddress(nil, account.ID, false); err != nil {
236                         return err
237                 }
238         }
239         return nil
240 }