OSDN Git Service

Merge pull request #375 from Bytom/dev
[bytom/bytom-spv.git] / blockchain / wallet / wallet.go
1 package wallet
2
3 import (
4         "encoding/json"
5         "fmt"
6
7         log "github.com/sirupsen/logrus"
8         "github.com/tendermint/go-wire/data/base58"
9         "github.com/tendermint/tmlibs/db"
10
11         "github.com/bytom/blockchain/account"
12         "github.com/bytom/blockchain/asset"
13         "github.com/bytom/blockchain/pseudohsm"
14         "github.com/bytom/crypto/ed25519/chainkd"
15         "github.com/bytom/crypto/sha3pool"
16         "github.com/bytom/protocol"
17         "github.com/bytom/protocol/bc"
18         "github.com/bytom/protocol/bc/legacy"
19 )
20
const (
	// SINGLE is the value this file passes to AccountMgr.Create when it
	// creates an account backed by exactly one key.
	// NOTE(review): presumably the signature quorum — confirm against account.Manager.
	SINGLE = 1

	// RecoveryIndex is the address count handed to ImportAccountXpubKey for
	// every xpub recovered when a wallet DB is initialised for the first time.
	RecoveryIndex = 5000
)

var (
	// walletKey is the wallet DB key holding the JSON-encoded StatusInfo.
	walletKey = []byte("walletInfo")
	// privKeyKey is the wallet DB key holding the JSON-encoded []KeyInfo.
	privKeyKey = []byte("keysInfo")
)
29
// StatusInfo records how far the wallet has followed the chain. It is the
// value persisted (as JSON) under walletKey and is what attachBlock and
// detachBlock update when blocks are applied or rolled back.
type StatusInfo struct {
	// WorkHeight and WorkHash identify the block most recently attached by
	// the wallet; a rescan rewinds them to the block at height 0.
	WorkHeight uint64
	WorkHash   bc.Hash
	// BestHeight and BestHash identify the highest block the wallet has
	// attached. detachBlock moves them back to the parent of the detached
	// block, and walletUpdater detaches until they are on the main chain.
	BestHeight uint64
	BestHash   bc.Hash
}
37
// KeyInfo describes the import/rescan status of a single imported key. A
// slice of these is persisted as JSON under privKeyKey.
type KeyInfo struct {
	// Alias is the key alias recorded when the key was imported.
	Alias    string       `json:"alias"`
	// XPub is the extended public key that was imported.
	XPub     chainkd.XPub `json:"xpub"`
	// Percent is the rescan progress in percent; GetRescanStatus fills it in.
	Percent  uint8        `json:"percent"`
	// Complete is false when the key is first recorded and is set to true
	// once the wallet's work height has caught up with its best height.
	Complete bool         `json:"complete"`
}
45
// Wallet tracks the chain on behalf of the local accounts: it indexes
// transactions and account UTXOs block by block into DB and keeps the
// status of key-import rescans.
type Wallet struct {
	// DB is the wallet database; the wallet status lives under walletKey and
	// the key import list under privKeyKey.
	DB             db.DB
	// status is the in-memory copy of the persisted StatusInfo.
	status         StatusInfo
	// AccountMgr is used to look up and create accounts and addresses.
	AccountMgr     *account.Manager
	// AssetReg is the asset registry supplied to NewWallet.
	AssetReg       *asset.Registry
	// chain is queried for blocks and for main-chain membership.
	chain          *protocol.Chain
	// rescanProgress is a capacity-1 signal channel; a value in it tells
	// walletUpdater to restart scanning from the block at height 0.
	rescanProgress chan struct{}
	// ImportPrivKey is true while at least one imported key has not finished
	// its rescan.
	ImportPrivKey  bool
	// keysInfo is the in-memory copy of the key import list.
	keysInfo       []KeyInfo
}
57
58 //NewWallet return a new wallet instance
59 func NewWallet(walletDB db.DB, account *account.Manager, asset *asset.Registry, chain *protocol.Chain, xpubs []pseudohsm.XPub) (*Wallet, error) {
60         w := &Wallet{
61                 DB:             walletDB,
62                 AccountMgr:     account,
63                 AssetReg:       asset,
64                 chain:          chain,
65                 rescanProgress: make(chan struct{}, 1),
66                 keysInfo:       make([]KeyInfo, 0),
67         }
68
69         if err := w.loadWalletInfo(xpubs); err != nil {
70                 return nil, err
71         }
72
73         if err := w.loadKeysInfo(); err != nil {
74                 return nil, err
75         }
76
77         w.ImportPrivKey = w.getImportKeyFlag()
78
79         go w.walletUpdater()
80
81         return w, nil
82 }
83
84 //GetWalletInfo return stored wallet info and nil,if error,
85 //return initial wallet info and err
86 func (w *Wallet) loadWalletInfo(xpubs []pseudohsm.XPub) error {
87         if rawWallet := w.DB.Get(walletKey); rawWallet != nil {
88                 return json.Unmarshal(rawWallet, &w.status)
89         }
90
91         for i, v := range xpubs {
92                 if err := w.ImportAccountXpubKey(i, v, RecoveryIndex); err != nil {
93                         return err
94                 }
95         }
96
97         block, err := w.chain.GetBlockByHeight(0)
98         if err != nil {
99                 return err
100         }
101         return w.attachBlock(block)
102 }
103
104 func (w *Wallet) commitWalletInfo(batch db.Batch) error {
105         rawWallet, err := json.Marshal(w.status)
106         if err != nil {
107                 log.WithField("err", err).Error("save wallet info")
108                 return err
109         }
110
111         batch.Set(walletKey, rawWallet)
112         batch.Write()
113         return nil
114 }
115
116 //GetWalletInfo return stored wallet info and nil,if error,
117 //return initial wallet info and err
118 func (w *Wallet) loadKeysInfo() error {
119         if rawKeyInfo := w.DB.Get(privKeyKey); rawKeyInfo != nil {
120                 json.Unmarshal(rawKeyInfo, &w.keysInfo)
121                 return nil
122         }
123         return nil
124 }
125
126 func (w *Wallet) commitkeysInfo() error {
127         rawKeysInfo, err := json.Marshal(w.keysInfo)
128         if err != nil {
129                 log.WithField("err", err).Error("save wallet info")
130                 return err
131         }
132         w.DB.Set(privKeyKey, rawKeysInfo)
133         return nil
134 }
135
136 func (w *Wallet) getImportKeyFlag() bool {
137         for _, v := range w.keysInfo {
138                 if v.Complete == false {
139                         return true
140                 }
141         }
142         return false
143 }
144
145 func (w *Wallet) attachBlock(block *legacy.Block) error {
146         if block.PreviousBlockHash != w.status.WorkHash {
147                 log.Warn("wallet skip attachBlock due to status hash not equal to previous hash")
148                 return nil
149         }
150
151         storeBatch := w.DB.NewBatch()
152         w.indexTransactions(storeBatch, block)
153         w.buildAccountUTXOs(storeBatch, block)
154
155         w.status.WorkHeight = block.Height
156         w.status.WorkHash = block.Hash()
157         if w.status.WorkHeight >= w.status.BestHeight {
158                 w.status.BestHeight = w.status.WorkHeight
159                 w.status.BestHash = w.status.WorkHash
160         }
161         return w.commitWalletInfo(storeBatch)
162 }
163
164 func (w *Wallet) detachBlock(block *legacy.Block) error {
165         storeBatch := w.DB.NewBatch()
166         w.reverseAccountUTXOs(storeBatch, block)
167         w.deleteTransactions(storeBatch, w.status.BestHeight)
168
169         w.status.BestHeight = block.Height - 1
170         w.status.BestHash = block.PreviousBlockHash
171
172         if w.status.WorkHeight > w.status.BestHeight {
173                 w.status.WorkHeight = w.status.BestHeight
174                 w.status.WorkHash = w.status.BestHash
175         }
176
177         return w.commitWalletInfo(storeBatch)
178 }
179
180 //WalletUpdate process every valid block and reverse every invalid block which need to rollback
181 func (w *Wallet) walletUpdater() {
182         for {
183                 getRescanNotification(w)
184                 checkRescanStatus(w)
185                 for !w.chain.InMainChain(w.status.BestHeight, w.status.BestHash) {
186                         block, err := w.chain.GetBlockByHash(&w.status.BestHash)
187                         if err != nil {
188                                 log.WithField("err", err).Error("walletUpdater GetBlockByHash")
189                                 return
190                         }
191
192                         if err := w.detachBlock(block); err != nil {
193                                 log.WithField("err", err).Error("walletUpdater detachBlock")
194                                 return
195                         }
196                 }
197
198                 block, _ := w.chain.GetBlockByHeight(w.status.WorkHeight + 1)
199                 if block == nil {
200                         <-w.chain.BlockWaiter(w.status.WorkHeight + 1)
201                         continue
202                 }
203
204                 if err := w.attachBlock(block); err != nil {
205                         log.WithField("err", err).Error("walletUpdater stop")
206                         return
207                 }
208         }
209 }
210
211 func getRescanNotification(w *Wallet) {
212         select {
213         case <-w.rescanProgress:
214                 w.status.WorkHeight = 0
215                 block, _ := w.chain.GetBlockByHeight(w.status.WorkHeight)
216                 w.status.WorkHash = block.Hash()
217         default:
218                 return
219         }
220 }
221
222 // ExportAccountPrivKey exports the account private key as a WIF for encoding as a string
223 // in the Wallet Import Formt.
224 func (w *Wallet) ExportAccountPrivKey(hsm *pseudohsm.HSM, xpub chainkd.XPub, auth string) (*string, error) {
225         xprv, err := hsm.LoadChainKDKey(xpub, auth)
226         if err != nil {
227                 return nil, err
228         }
229         var hashed [32]byte
230         sha3pool.Sum256(hashed[:], xprv[:])
231
232         tmp := append(xprv[:], hashed[:4]...)
233         res := base58.Encode(tmp)
234         return &res, nil
235 }
236
237 // ImportAccountPrivKey imports the account key in the Wallet Import Formt.
238 func (w *Wallet) ImportAccountPrivKey(hsm *pseudohsm.HSM, xprv chainkd.XPrv, keyAlias, auth string, index uint64, accountAlias string) (*pseudohsm.XPub, error) {
239         if hsm.HasAlias(keyAlias) {
240                 return nil, pseudohsm.ErrDuplicateKeyAlias
241         }
242         if hsm.HasKey(xprv) {
243                 return nil, pseudohsm.ErrDuplicateKey
244         }
245
246         if acc, _ := w.AccountMgr.FindByAlias(nil, accountAlias); acc != nil {
247                 return nil, account.ErrDuplicateAlias
248         }
249
250         xpub, _, err := hsm.ImportXPrvKey(auth, keyAlias, xprv)
251         if err != nil {
252                 return nil, err
253         }
254
255         newAccount, err := w.AccountMgr.Create(nil, []chainkd.XPub{xpub.XPub}, SINGLE, accountAlias, nil)
256         if err != nil {
257                 return nil, err
258         }
259         if err := w.recoveryAccountWalletDB(newAccount, xpub, index, keyAlias); err != nil {
260                 return nil, err
261         }
262         return xpub, nil
263 }
264
265 // ImportAccountXpubKey imports the account key in the Wallet Import Formt.
266 func (w *Wallet) ImportAccountXpubKey(xpubIndex int, xpub pseudohsm.XPub, cpIndex uint64) error {
267         accountAlias := fmt.Sprintf("recovery_%d", xpubIndex)
268
269         if acc, _ := w.AccountMgr.FindByAlias(nil, accountAlias); acc != nil {
270                 return account.ErrDuplicateAlias
271         }
272
273         newAccount, err := w.AccountMgr.Create(nil, []chainkd.XPub{xpub.XPub}, SINGLE, accountAlias, nil)
274         if err != nil {
275                 return err
276         }
277
278         return w.recoveryAccountWalletDB(newAccount, &xpub, cpIndex, xpub.Alias)
279 }
280
281 func (w *Wallet) recoveryAccountWalletDB(account *account.Account, XPub *pseudohsm.XPub, index uint64, keyAlias string) error {
282         if err := w.createProgram(account, XPub, index); err != nil {
283                 return err
284         }
285         w.ImportPrivKey = true
286         tmp := KeyInfo{
287                 Alias:    keyAlias,
288                 XPub:     XPub.XPub,
289                 Complete: false,
290         }
291         w.keysInfo = append(w.keysInfo, tmp)
292         w.commitkeysInfo()
293         w.rescanBlocks()
294
295         return nil
296 }
297
298 func (w *Wallet) createProgram(account *account.Account, XPub *pseudohsm.XPub, index uint64) error {
299         for i := uint64(0); i < index; i++ {
300                 if _, err := w.AccountMgr.CreateAddress(nil, account.ID, false); err != nil {
301                         return err
302                 }
303         }
304         return nil
305 }
306
307 func (w *Wallet) rescanBlocks() {
308         select {
309         case <-w.rescanProgress:
310                 w.rescanProgress <- struct{}{}
311         default:
312                 return
313         }
314 }
315
316 //GetRescanStatus return key import rescan status
317 func (w *Wallet) GetRescanStatus() ([]KeyInfo, error) {
318         keysInfo := make([]KeyInfo, len(w.keysInfo))
319
320         if rawKeyInfo := w.DB.Get(privKeyKey); rawKeyInfo != nil {
321                 if err := json.Unmarshal(rawKeyInfo, &keysInfo); err != nil {
322                         return nil, err
323                 }
324         }
325
326         var status StatusInfo
327         if rawWallet := w.DB.Get(walletKey); rawWallet != nil {
328                 if err := json.Unmarshal(rawWallet, &status); err != nil {
329                         return nil, err
330                 }
331         }
332
333         for i, v := range keysInfo {
334                 if v.Complete == true || status.BestHeight == 0 {
335                         keysInfo[i].Percent = 100
336                         continue
337                 }
338
339                 keysInfo[i].Percent = uint8(status.WorkHeight * 100 / status.BestHeight)
340                 if v.Percent == 100 {
341                         keysInfo[i].Complete = true
342                 }
343         }
344         return keysInfo, nil
345 }
346
347 func checkRescanStatus(w *Wallet) {
348         if !w.ImportPrivKey {
349                 return
350         }
351         if w.status.WorkHeight >= w.status.BestHeight {
352                 w.ImportPrivKey = false
353                 for i := range w.keysInfo {
354                         w.keysInfo[i].Complete = true
355                 }
356         }
357
358         w.commitkeysInfo()
359 }