OSDN Git Service

c8a19974d641df1317c832829ba7929a7a71df9c
[pf3gnuchains/gcc-fork.git] / libgo / go / exp / sql / fakedb_test.go
1 // Copyright 2011 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package sql
6
7 import (
8         "errors"
9         "fmt"
10         "io"
11         "log"
12         "strconv"
13         "strings"
14         "sync"
15
16         "exp/sql/driver"
17 )
18
19 var _ = log.Printf
20
21 // fakeDriver is a fake database that implements Go's driver.Driver
22 // interface, just for testing.
23 //
24 // It speaks a query language that's semantically similar to but
25 // syntantically different and simpler than SQL.  The syntax is as
26 // follows:
27 //
28 //   WIPE
29 //   CREATE|<tablename>|<col>=<type>,<col>=<type>,...
30 //     where types are: "string", [u]int{8,16,32,64}, "bool"
31 //   INSERT|<tablename>|col=val,col2=val2,col3=?
32 //   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
33 //
34 // When opening a a fakeDriver's database, it starts empty with no
35 // tables.  All tables and data are stored in memory only.
36 type fakeDriver struct {
37         mu        sync.Mutex
38         openCount int
39         dbs       map[string]*fakeDB
40 }
41
42 type fakeDB struct {
43         name string
44
45         mu     sync.Mutex
46         free   []*fakeConn
47         tables map[string]*table
48 }
49
50 type table struct {
51         mu      sync.Mutex
52         colname []string
53         coltype []string
54         rows    []*row
55 }
56
57 func (t *table) columnIndex(name string) int {
58         for n, nname := range t.colname {
59                 if name == nname {
60                         return n
61                 }
62         }
63         return -1
64 }
65
66 type row struct {
67         cols []interface{} // must be same size as its table colname + coltype
68 }
69
70 func (r *row) clone() *row {
71         nrow := &row{cols: make([]interface{}, len(r.cols))}
72         copy(nrow.cols, r.cols)
73         return nrow
74 }
75
76 type fakeConn struct {
77         db *fakeDB // where to return ourselves to
78
79         currTx *fakeTx
80 }
81
82 type fakeTx struct {
83         c *fakeConn
84 }
85
86 type fakeStmt struct {
87         c *fakeConn
88         q string // just for debugging
89
90         cmd   string
91         table string
92
93         colName      []string      // used by CREATE, INSERT, SELECT (selected columns)
94         colType      []string      // used by CREATE
95         colValue     []interface{} // used by INSERT (mix of strings and "?" for bound params)
96         placeholders int           // used by INSERT/SELECT: number of ? params
97
98         whereCol []string // used by SELECT (all placeholders)
99
100         placeholderConverter []driver.ValueConverter // used by INSERT
101 }
102
103 var fdriver driver.Driver = &fakeDriver{}
104
105 func init() {
106         Register("test", fdriver)
107 }
108
109 // Supports dsn forms:
110 //    <dbname>
111 //    <dbname>;wipe
112 func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
113         d.mu.Lock()
114         defer d.mu.Unlock()
115         d.openCount++
116         if d.dbs == nil {
117                 d.dbs = make(map[string]*fakeDB)
118         }
119         parts := strings.Split(dsn, ";")
120         if len(parts) < 1 {
121                 return nil, errors.New("fakedb: no database name")
122         }
123         name := parts[0]
124         db, ok := d.dbs[name]
125         if !ok {
126                 db = &fakeDB{name: name}
127                 d.dbs[name] = db
128         }
129         return &fakeConn{db: db}, nil
130 }
131
132 func (db *fakeDB) wipe() {
133         db.mu.Lock()
134         defer db.mu.Unlock()
135         db.tables = nil
136 }
137
138 func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
139         db.mu.Lock()
140         defer db.mu.Unlock()
141         if db.tables == nil {
142                 db.tables = make(map[string]*table)
143         }
144         if _, exist := db.tables[name]; exist {
145                 return fmt.Errorf("table %q already exists", name)
146         }
147         if len(columnNames) != len(columnTypes) {
148                 return fmt.Errorf("create table of %q len(names) != len(types): %d vs %d",
149                         name, len(columnNames), len(columnTypes))
150         }
151         db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
152         return nil
153 }
154
155 // must be called with db.mu lock held
156 func (db *fakeDB) table(table string) (*table, bool) {
157         if db.tables == nil {
158                 return nil, false
159         }
160         t, ok := db.tables[table]
161         return t, ok
162 }
163
164 func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
165         db.mu.Lock()
166         defer db.mu.Unlock()
167         t, ok := db.table(table)
168         if !ok {
169                 return
170         }
171         for n, cname := range t.colname {
172                 if cname == column {
173                         return t.coltype[n], true
174                 }
175         }
176         return "", false
177 }
178
179 func (c *fakeConn) Begin() (driver.Tx, error) {
180         if c.currTx != nil {
181                 return nil, errors.New("already in a transaction")
182         }
183         c.currTx = &fakeTx{c: c}
184         return c.currTx, nil
185 }
186
187 func (c *fakeConn) Close() error {
188         if c.currTx != nil {
189                 return errors.New("can't close; in a Transaction")
190         }
191         if c.db == nil {
192                 return errors.New("can't close; already closed")
193         }
194         c.db = nil
195         return nil
196 }
197
198 func errf(msg string, args ...interface{}) error {
199         return errors.New("fakedb: " + fmt.Sprintf(msg, args...))
200 }
201
202 // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
203 // (note that where where columns must always contain ? marks,
204 //  just a limitation for fakedb)
205 func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
206         if len(parts) != 3 {
207                 return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
208         }
209         stmt.table = parts[0]
210         stmt.colName = strings.Split(parts[1], ",")
211         for n, colspec := range strings.Split(parts[2], ",") {
212                 nameVal := strings.Split(colspec, "=")
213                 if len(nameVal) != 2 {
214                         return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
215                 }
216                 column, value := nameVal[0], nameVal[1]
217                 _, ok := c.db.columnType(stmt.table, column)
218                 if !ok {
219                         return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
220                 }
221                 if value != "?" {
222                         return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
223                                 stmt.table, column)
224                 }
225                 stmt.whereCol = append(stmt.whereCol, column)
226                 stmt.placeholders++
227         }
228         return stmt, nil
229 }
230
231 // parts are table|col=type,col2=type2
232 func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
233         if len(parts) != 2 {
234                 return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
235         }
236         stmt.table = parts[0]
237         for n, colspec := range strings.Split(parts[1], ",") {
238                 nameType := strings.Split(colspec, "=")
239                 if len(nameType) != 2 {
240                         return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
241                 }
242                 stmt.colName = append(stmt.colName, nameType[0])
243                 stmt.colType = append(stmt.colType, nameType[1])
244         }
245         return stmt, nil
246 }
247
248 // parts are table|col=?,col2=val
249 func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
250         if len(parts) != 2 {
251                 return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
252         }
253         stmt.table = parts[0]
254         for n, colspec := range strings.Split(parts[1], ",") {
255                 nameVal := strings.Split(colspec, "=")
256                 if len(nameVal) != 2 {
257                         return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
258                 }
259                 column, value := nameVal[0], nameVal[1]
260                 ctype, ok := c.db.columnType(stmt.table, column)
261                 if !ok {
262                         return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
263                 }
264                 stmt.colName = append(stmt.colName, column)
265
266                 if value != "?" {
267                         var subsetVal interface{}
268                         // Convert to driver subset type
269                         switch ctype {
270                         case "string":
271                                 subsetVal = []byte(value)
272                         case "int32":
273                                 i, err := strconv.Atoi(value)
274                                 if err != nil {
275                                         return nil, errf("invalid conversion to int32 from %q", value)
276                                 }
277                                 subsetVal = int64(i) // int64 is a subset type, but not int32
278                         default:
279                                 return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
280                         }
281                         stmt.colValue = append(stmt.colValue, subsetVal)
282                 } else {
283                         stmt.placeholders++
284                         stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
285                         stmt.colValue = append(stmt.colValue, "?")
286                 }
287         }
288         return stmt, nil
289 }
290
291 func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
292         if c.db == nil {
293                 panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
294         }
295         parts := strings.Split(query, "|")
296         if len(parts) < 1 {
297                 return nil, errf("empty query")
298         }
299         cmd := parts[0]
300         parts = parts[1:]
301         stmt := &fakeStmt{q: query, c: c, cmd: cmd}
302         switch cmd {
303         case "WIPE":
304                 // Nothing
305         case "SELECT":
306                 return c.prepareSelect(stmt, parts)
307         case "CREATE":
308                 return c.prepareCreate(stmt, parts)
309         case "INSERT":
310                 return c.prepareInsert(stmt, parts)
311         default:
312                 return nil, errf("unsupported command type %q", cmd)
313         }
314         return stmt, nil
315 }
316
317 func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
318         return s.placeholderConverter[idx]
319 }
320
321 func (s *fakeStmt) Close() error {
322         return nil
323 }
324
325 func (s *fakeStmt) Exec(args []interface{}) (driver.Result, error) {
326         db := s.c.db
327         switch s.cmd {
328         case "WIPE":
329                 db.wipe()
330                 return driver.DDLSuccess, nil
331         case "CREATE":
332                 if err := db.createTable(s.table, s.colName, s.colType); err != nil {
333                         return nil, err
334                 }
335                 return driver.DDLSuccess, nil
336         case "INSERT":
337                 return s.execInsert(args)
338         }
339         fmt.Printf("EXEC statement, cmd=%q: %#v\n", s.cmd, s)
340         return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd)
341 }
342
343 func (s *fakeStmt) execInsert(args []interface{}) (driver.Result, error) {
344         db := s.c.db
345         if len(args) != s.placeholders {
346                 panic("error in pkg db; should only get here if size is correct")
347         }
348         db.mu.Lock()
349         t, ok := db.table(s.table)
350         db.mu.Unlock()
351         if !ok {
352                 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
353         }
354
355         t.mu.Lock()
356         defer t.mu.Unlock()
357
358         cols := make([]interface{}, len(t.colname))
359         argPos := 0
360         for n, colname := range s.colName {
361                 colidx := t.columnIndex(colname)
362                 if colidx == -1 {
363                         return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
364                 }
365                 var val interface{}
366                 if strvalue, ok := s.colValue[n].(string); ok && strvalue == "?" {
367                         val = args[argPos]
368                         argPos++
369                 } else {
370                         val = s.colValue[n]
371                 }
372                 cols[colidx] = val
373         }
374
375         t.rows = append(t.rows, &row{cols: cols})
376         return driver.RowsAffected(1), nil
377 }
378
379 func (s *fakeStmt) Query(args []interface{}) (driver.Rows, error) {
380         db := s.c.db
381         if len(args) != s.placeholders {
382                 panic("error in pkg db; should only get here if size is correct")
383         }
384
385         db.mu.Lock()
386         t, ok := db.table(s.table)
387         db.mu.Unlock()
388         if !ok {
389                 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
390         }
391         t.mu.Lock()
392         defer t.mu.Unlock()
393
394         colIdx := make(map[string]int) // select column name -> column index in table
395         for _, name := range s.colName {
396                 idx := t.columnIndex(name)
397                 if idx == -1 {
398                         return nil, fmt.Errorf("fakedb: unknown column name %q", name)
399                 }
400                 colIdx[name] = idx
401         }
402
403         mrows := []*row{}
404 rows:
405         for _, trow := range t.rows {
406                 // Process the where clause, skipping non-match rows. This is lazy
407                 // and just uses fmt.Sprintf("%v") to test equality.  Good enough
408                 // for test code.
409                 for widx, wcol := range s.whereCol {
410                         idx := t.columnIndex(wcol)
411                         if idx == -1 {
412                                 return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
413                         }
414                         tcol := trow.cols[idx]
415                         if bs, ok := tcol.([]byte); ok {
416                                 // lazy hack to avoid sprintf %v on a []byte
417                                 tcol = string(bs)
418                         }
419                         if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) {
420                                 continue rows
421                         }
422                 }
423                 mrow := &row{cols: make([]interface{}, len(s.colName))}
424                 for seli, name := range s.colName {
425                         mrow.cols[seli] = trow.cols[colIdx[name]]
426                 }
427                 mrows = append(mrows, mrow)
428         }
429
430         cursor := &rowsCursor{
431                 pos:  -1,
432                 rows: mrows,
433                 cols: s.colName,
434         }
435         return cursor, nil
436 }
437
438 func (s *fakeStmt) NumInput() int {
439         return s.placeholders
440 }
441
442 func (tx *fakeTx) Commit() error {
443         tx.c.currTx = nil
444         return nil
445 }
446
447 func (tx *fakeTx) Rollback() error {
448         tx.c.currTx = nil
449         return nil
450 }
451
452 type rowsCursor struct {
453         cols   []string
454         pos    int
455         rows   []*row
456         closed bool
457 }
458
459 func (rc *rowsCursor) Close() error {
460         rc.closed = true
461         return nil
462 }
463
464 func (rc *rowsCursor) Columns() []string {
465         return rc.cols
466 }
467
468 func (rc *rowsCursor) Next(dest []interface{}) error {
469         if rc.closed {
470                 return errors.New("fakedb: cursor is closed")
471         }
472         rc.pos++
473         if rc.pos >= len(rc.rows) {
474                 return io.EOF // per interface spec
475         }
476         for i, v := range rc.rows[rc.pos].cols {
477                 // TODO(bradfitz): convert to subset types? naah, I
478                 // think the subset types should only be input to
479                 // driver, but the sql package should be able to handle
480                 // a wider range of types coming out of drivers. all
481                 // for ease of drivers, and to prevent drivers from
482                 // messing up conversions or doing them differently.
483                 dest[i] = v
484         }
485         return nil
486 }
487
488 func converterForType(typ string) driver.ValueConverter {
489         switch typ {
490         case "bool":
491                 return driver.Bool
492         case "int32":
493                 return driver.Int32
494         case "string":
495                 return driver.String
496         }
497         panic("invalid fakedb column type of " + typ)
498 }