OSDN Git Service

libgo: Update to weekly.2012-03-13.
[pf3gnuchains/gcc-fork.git] / libgo / go / database / sql / fakedb_test.go
index fc63f03..184e775 100644 (file)
@@ -82,6 +82,7 @@ type fakeConn struct {
        mu          sync.Mutex
        stmtsMade   int
        stmtsClosed int
+       numPrepare  int
 }
 
 func (c *fakeConn) incrStat(v *int) {
@@ -208,10 +209,13 @@ func (c *fakeConn) Begin() (driver.Tx, error) {
 
 func (c *fakeConn) Close() error {
        if c.currTx != nil {
-               return errors.New("can't close; in a Transaction")
+               return errors.New("can't close fakeConn; in a Transaction")
        }
        if c.db == nil {
-               return errors.New("can't close; already closed")
+               return errors.New("can't close fakeConn; already closed")
+       }
+       if c.stmtsMade > c.stmtsClosed {
+               return errors.New("can't close; dangling statement(s)")
        }
        c.db = nil
        return nil
@@ -249,6 +253,7 @@ func errf(msg string, args ...interface{}) error {
 //  just a limitation for fakedb)
 func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
        if len(parts) != 3 {
+               stmt.Close()
                return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
        }
        stmt.table = parts[0]
@@ -259,14 +264,17 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, e
                }
                nameVal := strings.Split(colspec, "=")
                if len(nameVal) != 2 {
+                       stmt.Close()
                        return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
                }
                column, value := nameVal[0], nameVal[1]
                _, ok := c.db.columnType(stmt.table, column)
                if !ok {
+                       stmt.Close()
                        return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
                }
                if value != "?" {
+                       stmt.Close()
                        return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
                                stmt.table, column)
                }
@@ -279,12 +287,14 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, e
 // parts are table|col=type,col2=type2
 func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
        if len(parts) != 2 {
+               stmt.Close()
                return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
        }
        stmt.table = parts[0]
        for n, colspec := range strings.Split(parts[1], ",") {
                nameType := strings.Split(colspec, "=")
                if len(nameType) != 2 {
+                       stmt.Close()
                        return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
                }
                stmt.colName = append(stmt.colName, nameType[0])
@@ -296,17 +306,20 @@ func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, e
 // parts are table|col=?,col2=val
 func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
        if len(parts) != 2 {
+               stmt.Close()
                return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
        }
        stmt.table = parts[0]
        for n, colspec := range strings.Split(parts[1], ",") {
                nameVal := strings.Split(colspec, "=")
                if len(nameVal) != 2 {
+                       stmt.Close()
                        return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
                }
                column, value := nameVal[0], nameVal[1]
                ctype, ok := c.db.columnType(stmt.table, column)
                if !ok {
+                       stmt.Close()
                        return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
                }
                stmt.colName = append(stmt.colName, column)
@@ -322,10 +335,12 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e
                        case "int32":
                                i, err := strconv.Atoi(value)
                                if err != nil {
+                                       stmt.Close()
                                        return nil, errf("invalid conversion to int32 from %q", value)
                                }
                                subsetVal = int64(i) // int64 is a subset type, but not int32
                        default:
+                               stmt.Close()
                                return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
                        }
                        stmt.colValue = append(stmt.colValue, subsetVal)
@@ -339,6 +354,7 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e
 }
 
 func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
+       c.numPrepare++
        if c.db == nil {
                panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
        }
@@ -360,6 +376,7 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
        case "INSERT":
                return c.prepareInsert(stmt, parts)
        default:
+               stmt.Close()
                return nil, errf("unsupported command type %q", cmd)
        }
        return stmt, nil