"github.com/bytom/protocol/bc/types"
)
-func genesisTx() *types.Tx {
+func GenesisTx() *types.Tx {
contract, err := hex.DecodeString("00148c9d063ff74ee6d9ffa88d83aeb038068366c4c4")
if err != nil {
log.Panicf("fail on decode genesis tx output control program")
}
func mainNetGenesisBlock() *types.Block {
- tx := genesisTx()
+ tx := GenesisTx()
txStatus := bc.NewTransactionStatus()
if err := txStatus.SetStatus(0, false); err != nil {
log.Panicf(err.Error())
}
func testNetGenesisBlock() *types.Block {
- tx := genesisTx()
+ tx := GenesisTx()
txStatus := bc.NewTransactionStatus()
if err := txStatus.SetStatus(0, false); err != nil {
log.Panicf(err.Error())
}
func soloNetGenesisBlock() *types.Block {
- tx := genesisTx()
+ tx := GenesisTx()
txStatus := bc.NewTransactionStatus()
if err := txStatus.SetStatus(0, false); err != nil {
log.Panicf(err.Error())
TxPrefix = "TXS:"
//TxIndexPrefix is wallet database tx index prefix
TxIndexPrefix = "TID:"
+ //TxIndexPrefix is wallet database global tx index prefix
+ GlobalTxIndexPrefix = "GTID:"
)
func formatKey(blockHeight uint64, position uint32) string {
return []byte(TxIndexPrefix + txID)
}
+func calcGlobalTxIndexKey(txID string) []byte {
+ return []byte(GlobalTxIndexPrefix + txID)
+}
+
+func calcGlobalTxIndex(blockHash *bc.Hash, position int) []byte {
+ return []byte(fmt.Sprintf("%064x%08x", blockHash.String(), position))
+}
+
// deleteTransaction delete transactions when orphan block rollback
func (w *Wallet) deleteTransactions(batch db.Batch, height uint64) {
tmpTx := query.AnnotatedTx{}
// delete unconfirmed transaction
batch.Delete(calcUnconfirmedTxKey(tx.ID.String()))
}
+
+ for position, globalTx := range b.Transactions {
+ blockHash := b.BlockHeader.Hash()
+ batch.Set(calcGlobalTxIndexKey(globalTx.ID.String()), calcGlobalTxIndex(&blockHash, position))
+ }
+
return nil
}
"github.com/bytom/account"
"github.com/bytom/asset"
"github.com/bytom/blockchain/pseudohsm"
+ "github.com/bytom/errors"
"github.com/bytom/event"
"github.com/bytom/protocol"
"github.com/bytom/protocol/bc"
logModule = "wallet"
)
-var walletKey = []byte("walletInfo")
+var (
+ currentVersion = uint(1)
+ walletKey = []byte("walletInfo")
+
+ errBestBlockNotFoundInCore = errors.New("best block not found in core")
+ errWalletVersionMismatch = errors.New("wallet version mismatch")
+)
//StatusInfo is base valid block info to handle orphan block rollback
type StatusInfo struct {
+ Version uint
WorkHeight uint64
WorkHash bc.Hash
BestHeight uint64
}
}
-//GetWalletInfo return stored wallet info and nil,if error,
-//return initial wallet info and err
+func (w *Wallet) checkWalletInfo() error {
+ if w.status.Version != currentVersion {
+ return errWalletVersionMismatch
+ } else if !w.chain.BlockExist(&w.status.BestHash) {
+ return errBestBlockNotFoundInCore
+ }
+
+ return nil
+}
+
+//loadWalletInfo return stored wallet info and nil,
+//if error, return initial wallet info and err
func (w *Wallet) loadWalletInfo() error {
if rawWallet := w.DB.Get(walletKey); rawWallet != nil {
if err := json.Unmarshal(rawWallet, &w.status); err != nil {
return err
}
- //handle the case than use replace the coreDB during status in fork chain
- if w.chain.BlockExist(&w.status.BestHash) {
+ err := w.checkWalletInfo()
+ if err == nil {
return nil
}
- log.WithFields(log.Fields{"module": logModule}).Warn("reset the wallet status due to core doesn't have wallet best block")
+ log.WithFields(log.Fields{"module": logModule}).Warn(err.Error())
w.deleteAccountTxs()
w.deleteUtxos()
- w.status = StatusInfo{}
}
+ w.status.Version = currentVersion
block, err := w.chain.GetBlockByHeight(0)
if err != nil {
return err
package wallet
import (
+ "encoding/json"
"io/ioutil"
"os"
+ "reflect"
"testing"
"time"
"github.com/bytom/blockchain/pseudohsm"
"github.com/bytom/blockchain/signers"
"github.com/bytom/blockchain/txbuilder"
+ "github.com/bytom/config"
"github.com/bytom/consensus"
"github.com/bytom/crypto/ed25519/chainkd"
"github.com/bytom/database/leveldb"
"github.com/bytom/protocol/bc/types"
)
+func TestWalletVersion(t *testing.T) {
+ // prepare wallet
+ dirPath, err := ioutil.TempDir(".", "")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.RemoveAll(dirPath)
+
+ testDB := dbm.NewDB("testdb", "leveldb", "temp")
+ defer os.RemoveAll("temp")
+
+ dispatcher := event.NewDispatcher()
+ w := mockWallet(testDB, nil, nil, nil, dispatcher)
+
+ // legacy status test case
+ type legacyStatusInfo struct {
+ WorkHeight uint64
+ WorkHash bc.Hash
+ BestHeight uint64
+ BestHash bc.Hash
+ }
+ rawWallet, err := json.Marshal(legacyStatusInfo{})
+ if err != nil {
+ t.Fatal("Marshal legacyStatusInfo")
+ }
+
+ w.DB.Set(walletKey, rawWallet)
+ rawWallet = w.DB.Get(walletKey)
+ if rawWallet == nil {
+ t.Fatal("fail to load wallet StatusInfo")
+ }
+
+ if err := json.Unmarshal(rawWallet, &w.status); err != nil {
+ t.Fatal(err)
+ }
+
+ if err := w.checkWalletInfo(); err != errWalletVersionMismatch {
+ t.Fatal("fail to detect legacy wallet version")
+ }
+
+ // lower wallet version test case
+ lowerVersion := StatusInfo{Version: currentVersion - 1}
+ rawWallet, err = json.Marshal(lowerVersion)
+ if err != nil {
+ t.Fatal("save wallet info")
+ }
+
+ w.DB.Set(walletKey, rawWallet)
+ rawWallet = w.DB.Get(walletKey)
+ if rawWallet == nil {
+ t.Fatal("fail to load wallet StatusInfo")
+ }
+
+ if err := json.Unmarshal(rawWallet, &w.status); err != nil {
+ t.Fatal(err)
+ }
+
+ if err := w.checkWalletInfo(); err != errWalletVersionMismatch {
+ t.Fatal("fail to detect expired wallet version")
+ }
+}
+
func TestWalletUpdate(t *testing.T) {
dirPath, err := ioutil.TempDir(".", "")
if err != nil {
block := mockSingleBlock(tx)
txStatus := bc.NewTransactionStatus()
txStatus.SetStatus(0, false)
+ txStatus.SetStatus(1, false)
store.SaveBlock(block, txStatus)
w := mockWallet(testDB, accountManager, reg, chain, dispatcher)
if len(wants) != 1 {
t.Fatal(err)
}
+
+ if wants[0].ID != tx.ID {
+ t.Fatal("account txID mismatch")
+ }
+
+ for position, tx := range block.Transactions {
+ get := w.DB.Get(calcGlobalTxIndexKey(tx.ID.String()))
+ bh := block.BlockHeader.Hash()
+ expect := calcGlobalTxIndex(&bh, position)
+ if !reflect.DeepEqual(get, expect) {
+ t.Fatalf("position#%d: compare retrieved globalTxIdx err", position)
+ }
+ }
}
func TestMemPoolTxQueryLoop(t *testing.T) {
Height: 1,
Bits: 2305843009230471167,
},
- Transactions: []*types.Tx{tx},
+ Transactions: []*types.Tx{config.GenesisTx(), tx},
}
}