OSDN Git Service

libgo: Update to weekly.2012-03-13.
authorian <ian@138bc75d-0d04-0410-961f-82ee72b054a4>
Fri, 30 Mar 2012 21:26:46 +0000 (21:26 +0000)
committerian <ian@138bc75d-0d04-0410-961f-82ee72b054a4>
Fri, 30 Mar 2012 21:26:46 +0000 (21:26 +0000)
git-svn-id: svn+ssh://gcc.gnu.org/svn/gcc/branches/gcc-4_7-branch@186022 138bc75d-0d04-0410-961f-82ee72b054a4

137 files changed:
libgo/MERGE
libgo/Makefile.am
libgo/Makefile.in
libgo/go/archive/tar/reader.go
libgo/go/archive/tar/writer.go
libgo/go/archive/tar/writer_test.go
libgo/go/archive/zip/reader.go
libgo/go/archive/zip/reader_test.go
libgo/go/archive/zip/struct.go
libgo/go/archive/zip/testdata/crc32-not-streamed.zip [new file with mode: 0644]
libgo/go/archive/zip/testdata/go-no-datadesc-sig.zip [new file with mode: 0644]
libgo/go/archive/zip/testdata/go-with-datadesc-sig.zip [new file with mode: 0644]
libgo/go/archive/zip/writer.go
libgo/go/archive/zip/writer_test.go
libgo/go/crypto/tls/common.go
libgo/go/crypto/tls/handshake_client.go
libgo/go/crypto/tls/handshake_server_test.go
libgo/go/crypto/tls/root_test.go
libgo/go/crypto/tls/root_windows.go [deleted file]
libgo/go/crypto/tls/tls.go
libgo/go/crypto/x509/pkcs1.go
libgo/go/crypto/x509/root.go [new file with mode: 0644]
libgo/go/crypto/x509/root_darwin.go [moved from libgo/go/crypto/tls/root_darwin.go with 90% similarity]
libgo/go/crypto/x509/root_stub.go [moved from libgo/go/crypto/tls/root_stub.go with 51% similarity]
libgo/go/crypto/x509/root_unix.go [moved from libgo/go/crypto/tls/root_unix.go with 76% similarity]
libgo/go/crypto/x509/root_windows.go [new file with mode: 0644]
libgo/go/crypto/x509/verify.go
libgo/go/crypto/x509/verify_test.go
libgo/go/crypto/x509/x509.go
libgo/go/database/sql/driver/driver.go
libgo/go/database/sql/fakedb_test.go
libgo/go/database/sql/sql.go
libgo/go/database/sql/sql_test.go
libgo/go/encoding/asn1/asn1.go
libgo/go/encoding/asn1/asn1_test.go
libgo/go/encoding/asn1/common.go
libgo/go/encoding/asn1/marshal.go
libgo/go/encoding/asn1/marshal_test.go
libgo/go/encoding/binary/binary.go
libgo/go/encoding/csv/reader.go
libgo/go/encoding/gob/decode.go
libgo/go/encoding/gob/encoder_test.go
libgo/go/encoding/gob/gobencdec_test.go
libgo/go/encoding/json/encode.go
libgo/go/exp/norm/maketables.go
libgo/go/exp/wingui/gui.go [deleted file]
libgo/go/exp/wingui/winapi.go [deleted file]
libgo/go/exp/wingui/zwinapi.go [deleted file]
libgo/go/expvar/expvar.go
libgo/go/fmt/doc.go
libgo/go/fmt/export_test.go [new file with mode: 0644]
libgo/go/fmt/fmt_test.go
libgo/go/fmt/format.go
libgo/go/fmt/print.go
libgo/go/fmt/scan.go
libgo/go/go/build/build.go
libgo/go/go/build/deps_test.go [new file with mode: 0644]
libgo/go/go/parser/error_test.go [new file with mode: 0644]
libgo/go/go/parser/parser.go
libgo/go/go/parser/parser_test.go
libgo/go/go/parser/short_test.go [new file with mode: 0644]
libgo/go/go/parser/testdata/commas.src [new file with mode: 0644]
libgo/go/go/parser/testdata/issue3106.src [new file with mode: 0644]
libgo/go/go/printer/nodes.go
libgo/go/go/printer/testdata/statements.golden
libgo/go/go/printer/testdata/statements.input
libgo/go/go/scanner/scanner.go
libgo/go/html/template/doc.go
libgo/go/io/io.go
libgo/go/log/log.go
libgo/go/net/dial_test.go
libgo/go/net/dnsclient.go
libgo/go/net/dnsmsg.go
libgo/go/net/dnsmsg_test.go
libgo/go/net/fd_linux.go
libgo/go/net/file_test.go
libgo/go/net/http/client_test.go
libgo/go/net/http/request.go
libgo/go/net/http/request_test.go
libgo/go/net/http/server.go
libgo/go/net/http/transport.go
libgo/go/net/http/transport_test.go
libgo/go/net/interface.go
libgo/go/net/interface_linux.go
libgo/go/net/iprawsock_posix.go
libgo/go/net/ipsock_posix.go
libgo/go/net/mac.go
libgo/go/net/mac_test.go
libgo/go/net/mail/message.go
libgo/go/net/multicast_test.go
libgo/go/net/net.go
libgo/go/net/net_test.go
libgo/go/net/parse_test.go
libgo/go/net/rpc/client.go
libgo/go/net/server_test.go
libgo/go/net/sock.go
libgo/go/net/sockopt.go
libgo/go/net/sockopt_bsd.go
libgo/go/net/sockopt_linux.go
libgo/go/net/sockopt_windows.go
libgo/go/net/tcpsock_posix.go
libgo/go/net/timeout_test.go
libgo/go/net/udp_test.go
libgo/go/net/udpsock_posix.go
libgo/go/net/unicast_test.go
libgo/go/net/unixsock_posix.go
libgo/go/os/error_posix.go
libgo/go/os/error_test.go [new file with mode: 0644]
libgo/go/os/error_windows.go [new file with mode: 0644]
libgo/go/os/exec/exec.go
libgo/go/os/types.go
libgo/go/path/filepath/path.go
libgo/go/path/filepath/path_test.go
libgo/go/path/filepath/symlink.go [new file with mode: 0644]
libgo/go/path/filepath/symlink_windows.go [new file with mode: 0644]
libgo/go/reflect/type.go
libgo/go/reflect/value.go
libgo/go/runtime/compiler.go [new file with mode: 0644]
libgo/go/runtime/debug/stack_test.go
libgo/go/runtime/pprof/pprof_test.go
libgo/go/strconv/isprint.go [new file with mode: 0644]
libgo/go/strconv/makeisprint.go [new file with mode: 0644]
libgo/go/strconv/quote.go
libgo/go/strconv/quote_test.go
libgo/go/strings/example_test.go
libgo/go/sync/atomic/atomic_test.go
libgo/go/testing/testing.go
libgo/go/time/tick_test.go
libgo/go/time/time.go
libgo/go/unicode/utf16/export_test.go [new file with mode: 0644]
libgo/go/unicode/utf16/utf16.go
libgo/go/unicode/utf16/utf16_test.go
libgo/go/unicode/utf8/utf8.go
libgo/go/unicode/utf8/utf8_test.go
libgo/runtime/malloc.goc
libgo/runtime/proc.c
libgo/runtime/runtime.h

index 17d01ce..13b0438 100644 (file)
@@ -1,4 +1,4 @@
-f4470a54e6db
+3cdba7b0650c
 
 The first line of this file holds the Mercurial revision number of the
 last merge done from the master library sources.
index 99294f1..14f72ec 100644 (file)
@@ -813,6 +813,7 @@ go_net_rpc_files = \
        go/net/rpc/server.go
 
 go_runtime_files = \
+       go/runtime/compiler.go \
        go/runtime/debug.go \
        go/runtime/error.go \
        go/runtime/extern.go \
@@ -843,6 +844,7 @@ go_strconv_files = \
        go/strconv/decimal.go \
        go/strconv/extfloat.go \
        go/strconv/ftoa.go \
+       go/strconv/isprint.go \
        go/strconv/itoa.go \
        go/strconv/quote.go
 
@@ -1000,12 +1002,13 @@ go_crypto_tls_files = \
        go/crypto/tls/handshake_server.go \
        go/crypto/tls/key_agreement.go \
        go/crypto/tls/prf.go \
-       go/crypto/tls/root_unix.go \
        go/crypto/tls/tls.go
 go_crypto_x509_files = \
        go/crypto/x509/cert_pool.go \
        go/crypto/x509/pkcs1.go \
        go/crypto/x509/pkcs8.go \
+       go/crypto/x509/root.go \
+       go/crypto/x509/root_unix.go \
        go/crypto/x509/verify.go \
        go/crypto/x509/x509.go
 
@@ -1320,7 +1323,8 @@ go_os_user_files = \
 go_path_filepath_files = \
        go/path/filepath/match.go \
        go/path/filepath/path.go \
-       go/path/filepath/path_unix.go
+       go/path/filepath/path_unix.go \
+       go/path/filepath/symlink.go
 
 go_regexp_syntax_files = \
        go/regexp/syntax/compile.go \
index b57d929..720d57e 100644 (file)
@@ -1131,6 +1131,7 @@ go_net_rpc_files = \
        go/net/rpc/server.go
 
 go_runtime_files = \
+       go/runtime/compiler.go \
        go/runtime/debug.go \
        go/runtime/error.go \
        go/runtime/extern.go \
@@ -1150,6 +1151,7 @@ go_strconv_files = \
        go/strconv/decimal.go \
        go/strconv/extfloat.go \
        go/strconv/ftoa.go \
+       go/strconv/isprint.go \
        go/strconv/itoa.go \
        go/strconv/quote.go
 
@@ -1315,13 +1317,14 @@ go_crypto_tls_files = \
        go/crypto/tls/handshake_server.go \
        go/crypto/tls/key_agreement.go \
        go/crypto/tls/prf.go \
-       go/crypto/tls/root_unix.go \
        go/crypto/tls/tls.go
 
 go_crypto_x509_files = \
        go/crypto/x509/cert_pool.go \
        go/crypto/x509/pkcs1.go \
        go/crypto/x509/pkcs8.go \
+       go/crypto/x509/root.go \
+       go/crypto/x509/root_unix.go \
        go/crypto/x509/verify.go \
        go/crypto/x509/x509.go
 
@@ -1677,7 +1680,8 @@ go_os_user_files = \
 go_path_filepath_files = \
        go/path/filepath/match.go \
        go/path/filepath/path.go \
-       go/path/filepath/path_unix.go
+       go/path/filepath/path_unix.go \
+       go/path/filepath/symlink.go
 
 go_regexp_syntax_files = \
        go/regexp/syntax/compile.go \
index 755a730..1b40af8 100644 (file)
@@ -18,7 +18,7 @@ import (
 )
 
 var (
-       ErrHeader = errors.New("invalid tar header")
+       ErrHeader = errors.New("archive/tar: invalid tar header")
 )
 
 // A Reader provides sequential access to the contents of a tar archive.
index d35726b..b2b7a58 100644 (file)
@@ -5,18 +5,19 @@
 package tar
 
 // TODO(dsymonds):
-// - catch more errors (no first header, write after close, etc.)
+// - catch more errors (no first header, etc.)
 
 import (
        "errors"
+       "fmt"
        "io"
        "strconv"
 )
 
 var (
-       ErrWriteTooLong    = errors.New("write too long")
-       ErrFieldTooLong    = errors.New("header field too long")
-       ErrWriteAfterClose = errors.New("write after close")
+       ErrWriteTooLong    = errors.New("archive/tar: write too long")
+       ErrFieldTooLong    = errors.New("archive/tar: header field too long")
+       ErrWriteAfterClose = errors.New("archive/tar: write after close")
 )
 
 // A Writer provides sequential writing of a tar archive in POSIX.1 format.
@@ -48,6 +49,11 @@ func NewWriter(w io.Writer) *Writer { return &Writer{w: w} }
 
 // Flush finishes writing the current file (optional).
 func (tw *Writer) Flush() error {
+       if tw.nb > 0 {
+               tw.err = fmt.Errorf("archive/tar: missed writing %d bytes", tw.nb)
+               return tw.err
+       }
+
        n := tw.nb + tw.pad
        for n > 0 && tw.err == nil {
                nr := n
@@ -193,6 +199,9 @@ func (tw *Writer) Close() error {
        }
        tw.Flush()
        tw.closed = true
+       if tw.err != nil {
+               return tw.err
+       }
 
        // trailer: two zero blocks
        for i := 0; i < 2; i++ {
index 0b41372..a214e57 100644 (file)
@@ -9,6 +9,7 @@ import (
        "fmt"
        "io"
        "io/ioutil"
+       "strings"
        "testing"
        "testing/iotest"
        "time"
@@ -95,7 +96,8 @@ var writerTests = []*writerTest{
                                        Uname:    "dsymonds",
                                        Gname:    "eng",
                                },
-                               // no contents
+                               // fake contents
+                               contents: strings.Repeat("\x00", 4<<10),
                        },
                },
        },
@@ -150,7 +152,9 @@ testLoop:
 
                buf := new(bytes.Buffer)
                tw := NewWriter(iotest.TruncateWriter(buf, 4<<10)) // only catch the first 4 KB
+               big := false
                for j, entry := range test.entries {
+                       big = big || entry.header.Size > 1<<10
                        if err := tw.WriteHeader(entry.header); err != nil {
                                t.Errorf("test %d, entry %d: Failed writing header: %v", i, j, err)
                                continue testLoop
@@ -160,7 +164,8 @@ testLoop:
                                continue testLoop
                        }
                }
-               if err := tw.Close(); err != nil {
+               // Only interested in Close failures for the small tests.
+               if err := tw.Close(); err != nil && !big {
                        t.Errorf("test %d: Failed closing archive: %v", i, err)
                        continue testLoop
                }
index f3826dc..ddd5075 100644 (file)
@@ -124,10 +124,6 @@ func (f *File) Open() (rc io.ReadCloser, err error) {
                return
        }
        size := int64(f.CompressedSize)
-       if size == 0 && f.hasDataDescriptor() {
-               // permit SectionReader to see the rest of the file
-               size = f.zipsize - (f.headerOffset + bodyOffset)
-       }
        r := io.NewSectionReader(f.zipr, f.headerOffset+bodyOffset, size)
        switch f.Method {
        case Store: // (no compression)
@@ -136,10 +132,13 @@ func (f *File) Open() (rc io.ReadCloser, err error) {
                rc = flate.NewReader(r)
        default:
                err = ErrAlgorithm
+               return
        }
-       if rc != nil {
-               rc = &checksumReader{rc, crc32.NewIEEE(), f, r}
+       var desr io.Reader
+       if f.hasDataDescriptor() {
+               desr = io.NewSectionReader(f.zipr, f.headerOffset+bodyOffset+size, dataDescriptorLen)
        }
+       rc = &checksumReader{rc, crc32.NewIEEE(), f, desr, nil}
        return
 }
 
@@ -147,23 +146,36 @@ type checksumReader struct {
        rc   io.ReadCloser
        hash hash.Hash32
        f    *File
-       zipr io.Reader // for reading the data descriptor
+       desr io.Reader // if non-nil, where to read the data descriptor
+       err  error     // sticky error
 }
 
 func (r *checksumReader) Read(b []byte) (n int, err error) {
+       if r.err != nil {
+               return 0, r.err
+       }
        n, err = r.rc.Read(b)
        r.hash.Write(b[:n])
-       if err != io.EOF {
+       if err == nil {
                return
        }
-       if r.f.hasDataDescriptor() {
-               if err = readDataDescriptor(r.zipr, r.f); err != nil {
-                       return
+       if err == io.EOF {
+               if r.desr != nil {
+                       if err1 := readDataDescriptor(r.desr, r.f); err1 != nil {
+                               err = err1
+                       } else if r.hash.Sum32() != r.f.CRC32 {
+                               err = ErrChecksum
+                       }
+               } else {
+                       // If there's not a data descriptor, we still compare
+                       // the CRC32 of what we've read against the file header
+                       // or TOC's CRC32, if it seems like it was set.
+                       if r.f.CRC32 != 0 && r.hash.Sum32() != r.f.CRC32 {
+                               err = ErrChecksum
+                       }
                }
        }
-       if r.hash.Sum32() != r.f.CRC32 {
-               err = ErrChecksum
-       }
+       r.err = err
        return
 }
 
@@ -226,10 +238,31 @@ func readDirectoryHeader(f *File, r io.Reader) error {
 
 func readDataDescriptor(r io.Reader, f *File) error {
        var buf [dataDescriptorLen]byte
-       if _, err := io.ReadFull(r, buf[:]); err != nil {
+
+       // The spec says: "Although not originally assigned a
+       // signature, the value 0x08074b50 has commonly been adopted
+       // as a signature value for the data descriptor record.
+       // Implementers should be aware that ZIP files may be
+       // encountered with or without this signature marking data
+       // descriptors and should account for either case when reading
+       // ZIP files to ensure compatibility."
+       //
+       // dataDescriptorLen includes the size of the signature but
+       // first read just those 4 bytes to see if it exists.
+       if _, err := io.ReadFull(r, buf[:4]); err != nil {
                return err
        }
-       b := readBuf(buf[:])
+       off := 0
+       maybeSig := readBuf(buf[:4])
+       if maybeSig.uint32() != dataDescriptorSignature {
+               // No data descriptor signature. Keep these four
+               // bytes.
+               off += 4
+       }
+       if _, err := io.ReadFull(r, buf[off:12]); err != nil {
+               return err
+       }
+       b := readBuf(buf[:12])
        f.CRC32 = b.uint32()
        f.CompressedSize = b.uint32()
        f.UncompressedSize = b.uint32()
index 066a615..c2db0dc 100644 (file)
@@ -10,23 +10,26 @@ import (
        "io"
        "io/ioutil"
        "os"
+       "path/filepath"
        "testing"
        "time"
 )
 
 type ZipTest struct {
        Name    string
+       Source  func() (r io.ReaderAt, size int64) // if non-nil, used instead of testdata/<Name> file
        Comment string
        File    []ZipTestFile
        Error   error // the error that Opening this file should return
 }
 
 type ZipTestFile struct {
-       Name    string
-       Content []byte // if blank, will attempt to compare against File
-       File    string // name of file to compare to (relative to testdata/)
-       Mtime   string // modified time in format "mm-dd-yy hh:mm:ss"
-       Mode    os.FileMode
+       Name       string
+       Content    []byte // if blank, will attempt to compare against File
+       ContentErr error
+       File       string // name of file to compare to (relative to testdata/)
+       Mtime      string // modified time in format "mm-dd-yy hh:mm:ss"
+       Mode       os.FileMode
 }
 
 // Caution: The Mtime values found for the test files should correspond to
@@ -107,6 +110,99 @@ var tests = []ZipTest{
                Name: "unix.zip",
                File: crossPlatform,
        },
+       {
+               // created by Go, before we wrote the "optional" data
+               // descriptor signatures (which are required by OS X)
+               Name: "go-no-datadesc-sig.zip",
+               File: []ZipTestFile{
+                       {
+                               Name:    "foo.txt",
+                               Content: []byte("foo\n"),
+                               Mtime:   "03-08-12 16:59:10",
+                               Mode:    0644,
+                       },
+                       {
+                               Name:    "bar.txt",
+                               Content: []byte("bar\n"),
+                               Mtime:   "03-08-12 16:59:12",
+                               Mode:    0644,
+                       },
+               },
+       },
+       {
+               // created by Go, after we wrote the "optional" data
+               // descriptor signatures (which are required by OS X)
+               Name: "go-with-datadesc-sig.zip",
+               File: []ZipTestFile{
+                       {
+                               Name:    "foo.txt",
+                               Content: []byte("foo\n"),
+                               Mode:    0666,
+                       },
+                       {
+                               Name:    "bar.txt",
+                               Content: []byte("bar\n"),
+                               Mode:    0666,
+                       },
+               },
+       },
+       {
+               Name:   "Bad-CRC32-in-data-descriptor",
+               Source: returnCorruptCRC32Zip,
+               File: []ZipTestFile{
+                       {
+                               Name:       "foo.txt",
+                               Content:    []byte("foo\n"),
+                               Mode:       0666,
+                               ContentErr: ErrChecksum,
+                       },
+                       {
+                               Name:    "bar.txt",
+                               Content: []byte("bar\n"),
+                               Mode:    0666,
+                       },
+               },
+       },
+       // Tests that we verify (and accept valid) crc32s on files
+       // with crc32s in their file header (not in data descriptors)
+       {
+               Name: "crc32-not-streamed.zip",
+               File: []ZipTestFile{
+                       {
+                               Name:    "foo.txt",
+                               Content: []byte("foo\n"),
+                               Mtime:   "03-08-12 16:59:10",
+                               Mode:    0644,
+                       },
+                       {
+                               Name:    "bar.txt",
+                               Content: []byte("bar\n"),
+                               Mtime:   "03-08-12 16:59:12",
+                               Mode:    0644,
+                       },
+               },
+       },
+       // Tests that we verify (and reject invalid) crc32s on files
+       // with crc32s in their file header (not in data descriptors)
+       {
+               Name:   "crc32-not-streamed.zip",
+               Source: returnCorruptNotStreamedZip,
+               File: []ZipTestFile{
+                       {
+                               Name:       "foo.txt",
+                               Content:    []byte("foo\n"),
+                               Mtime:      "03-08-12 16:59:10",
+                               Mode:       0644,
+                               ContentErr: ErrChecksum,
+                       },
+                       {
+                               Name:    "bar.txt",
+                               Content: []byte("bar\n"),
+                               Mtime:   "03-08-12 16:59:12",
+                               Mode:    0644,
+                       },
+               },
+       },
 }
 
 var crossPlatform = []ZipTestFile{
@@ -139,7 +235,18 @@ func TestReader(t *testing.T) {
 }
 
 func readTestZip(t *testing.T, zt ZipTest) {
-       z, err := OpenReader("testdata/" + zt.Name)
+       var z *Reader
+       var err error
+       if zt.Source != nil {
+               rat, size := zt.Source()
+               z, err = NewReader(rat, size)
+       } else {
+               var rc *ReadCloser
+               rc, err = OpenReader(filepath.Join("testdata", zt.Name))
+               if err == nil {
+                       z = &rc.Reader
+               }
+       }
        if err != zt.Error {
                t.Errorf("error=%v, want %v", err, zt.Error)
                return
@@ -149,11 +256,6 @@ func readTestZip(t *testing.T, zt ZipTest) {
        if err == ErrFormat {
                return
        }
-       defer func() {
-               if err := z.Close(); err != nil {
-                       t.Errorf("error %q when closing zip file", err)
-               }
-       }()
 
        // bail here if no Files expected to be tested
        // (there may actually be files in the zip, but we don't care)
@@ -170,7 +272,7 @@ func readTestZip(t *testing.T, zt ZipTest) {
 
        // test read of each file
        for i, ft := range zt.File {
-               readTestFile(t, ft, z.File[i])
+               readTestFile(t, zt, ft, z.File[i])
        }
 
        // test simultaneous reads
@@ -179,7 +281,7 @@ func readTestZip(t *testing.T, zt ZipTest) {
        for i := 0; i < 5; i++ {
                for j, ft := range zt.File {
                        go func(j int, ft ZipTestFile) {
-                               readTestFile(t, ft, z.File[j])
+                               readTestFile(t, zt, ft, z.File[j])
                                done <- true
                        }(j, ft)
                        n++
@@ -188,26 +290,11 @@ func readTestZip(t *testing.T, zt ZipTest) {
        for ; n > 0; n-- {
                <-done
        }
-
-       // test invalid checksum
-       if !z.File[0].hasDataDescriptor() { // skip test when crc32 in dd
-               z.File[0].CRC32++ // invalidate
-               r, err := z.File[0].Open()
-               if err != nil {
-                       t.Error(err)
-                       return
-               }
-               var b bytes.Buffer
-               _, err = io.Copy(&b, r)
-               if err != ErrChecksum {
-                       t.Errorf("%s: copy error=%v, want %v", z.File[0].Name, err, ErrChecksum)
-               }
-       }
 }
 
-func readTestFile(t *testing.T, ft ZipTestFile, f *File) {
+func readTestFile(t *testing.T, zt ZipTest, ft ZipTestFile, f *File) {
        if f.Name != ft.Name {
-               t.Errorf("name=%q, want %q", f.Name, ft.Name)
+               t.Errorf("%s: name=%q, want %q", zt.Name, f.Name, ft.Name)
        }
 
        if ft.Mtime != "" {
@@ -217,11 +304,11 @@ func readTestFile(t *testing.T, ft ZipTestFile, f *File) {
                        return
                }
                if ft := f.ModTime(); !ft.Equal(mtime) {
-                       t.Errorf("%s: mtime=%s, want %s", f.Name, ft, mtime)
+                       t.Errorf("%s: %s: mtime=%s, want %s", zt.Name, f.Name, ft, mtime)
                }
        }
 
-       testFileMode(t, f, ft.Mode)
+       testFileMode(t, zt.Name, f, ft.Mode)
 
        size0 := f.UncompressedSize
 
@@ -237,8 +324,10 @@ func readTestFile(t *testing.T, ft ZipTestFile, f *File) {
        }
 
        _, err = io.Copy(&b, r)
+       if err != ft.ContentErr {
+               t.Errorf("%s: copying contents: %v (want %v)", zt.Name, err, ft.ContentErr)
+       }
        if err != nil {
-               t.Error(err)
                return
        }
        r.Close()
@@ -264,12 +353,12 @@ func readTestFile(t *testing.T, ft ZipTestFile, f *File) {
        }
 }
 
-func testFileMode(t *testing.T, f *File, want os.FileMode) {
+func testFileMode(t *testing.T, zipName string, f *File, want os.FileMode) {
        mode := f.Mode()
        if want == 0 {
-               t.Errorf("%s mode: got %v, want none", f.Name, mode)
+               t.Errorf("%s: %s mode: got %v, want none", zipName, f.Name, mode)
        } else if mode != want {
-               t.Errorf("%s mode: want %v, got %v", f.Name, want, mode)
+               t.Errorf("%s: %s mode: want %v, got %v", zipName, f.Name, want, mode)
        }
 }
 
@@ -294,3 +383,35 @@ func TestInvalidFiles(t *testing.T) {
                t.Errorf("sigs: error=%v, want %v", err, ErrFormat)
        }
 }
+
+func messWith(fileName string, corrupter func(b []byte)) (r io.ReaderAt, size int64) {
+       data, err := ioutil.ReadFile(filepath.Join("testdata", fileName))
+       if err != nil {
+               panic("Error reading " + fileName + ": " + err.Error())
+       }
+       corrupter(data)
+       return bytes.NewReader(data), int64(len(data))
+}
+
+func returnCorruptCRC32Zip() (r io.ReaderAt, size int64) {
+       return messWith("go-with-datadesc-sig.zip", func(b []byte) {
+               // Corrupt one of the CRC32s in the data descriptor:
+               b[0x2d]++
+       })
+}
+
+func returnCorruptNotStreamedZip() (r io.ReaderAt, size int64) {
+       return messWith("crc32-not-streamed.zip", func(b []byte) {
+               // Corrupt foo.txt's final crc32 byte, in both
+               // the file header and TOC. (0x7e -> 0x7f)
+               b[0x11]++
+               b[0x9d]++
+
+               // TODO(bradfitz): add a new test that only corrupts
+               // one of these values, and verify that that's also an
+               // error. Currently, the reader code doesn't verify the
+               // fileheader and TOC's crc32 match if they're both
+               // non-zero and only the second line above, the TOC,
+               // is what matters.
+       })
+}
index fdbd16d..55f3dcf 100644 (file)
@@ -27,10 +27,11 @@ const (
        fileHeaderSignature      = 0x04034b50
        directoryHeaderSignature = 0x02014b50
        directoryEndSignature    = 0x06054b50
-       fileHeaderLen            = 30 // + filename + extra
-       directoryHeaderLen       = 46 // + filename + extra + comment
-       directoryEndLen          = 22 // + comment
-       dataDescriptorLen        = 12
+       dataDescriptorSignature  = 0x08074b50 // de-facto standard; required by OS X Finder
+       fileHeaderLen            = 30         // + filename + extra
+       directoryHeaderLen       = 46         // + filename + extra + comment
+       directoryEndLen          = 22         // + comment
+       dataDescriptorLen        = 16         // four uint32: descriptor signature, crc32, compressed size, size
 
        // Constants for the first byte in CreatorVersion
        creatorFAT    = 0
diff --git a/libgo/go/archive/zip/testdata/crc32-not-streamed.zip b/libgo/go/archive/zip/testdata/crc32-not-streamed.zip
new file mode 100644 (file)
index 0000000..f268d88
Binary files /dev/null and b/libgo/go/archive/zip/testdata/crc32-not-streamed.zip differ
diff --git a/libgo/go/archive/zip/testdata/go-no-datadesc-sig.zip b/libgo/go/archive/zip/testdata/go-no-datadesc-sig.zip
new file mode 100644 (file)
index 0000000..c3d593f
Binary files /dev/null and b/libgo/go/archive/zip/testdata/go-no-datadesc-sig.zip differ
diff --git a/libgo/go/archive/zip/testdata/go-with-datadesc-sig.zip b/libgo/go/archive/zip/testdata/go-with-datadesc-sig.zip
new file mode 100644 (file)
index 0000000..bcfe121
Binary files /dev/null and b/libgo/go/archive/zip/testdata/go-with-datadesc-sig.zip differ
index b2cc55b..45eb6bd 100644 (file)
@@ -224,6 +224,7 @@ func (w *fileWriter) close() error {
        // write data descriptor
        var buf [dataDescriptorLen]byte
        b := writeBuf(buf[:])
+       b.uint32(dataDescriptorSignature) // de-facto standard, required by OS X
        b.uint32(fh.CRC32)
        b.uint32(fh.CompressedSize)
        b.uint32(fh.UncompressedSize)
index 88e5211..8b1c4df 100644 (file)
@@ -108,7 +108,7 @@ func testReadFile(t *testing.T, f *File, wt *WriteTest) {
        if f.Name != wt.Name {
                t.Fatalf("File name: got %q, want %q", f.Name, wt.Name)
        }
-       testFileMode(t, f, wt.Mode)
+       testFileMode(t, wt.Name, f, wt.Mode)
        rc, err := f.Open()
        if err != nil {
                t.Fatal("opening:", err)
index 25f7a92..4ba0bf8 100644 (file)
@@ -198,14 +198,6 @@ func (c *Config) time() time.Time {
        return t()
 }
 
-func (c *Config) rootCAs() *x509.CertPool {
-       s := c.RootCAs
-       if s == nil {
-               s = defaultRoots()
-       }
-       return s
-}
-
 func (c *Config) cipherSuites() []uint16 {
        s := c.CipherSuites
        if s == nil {
@@ -311,28 +303,16 @@ func defaultConfig() *Config {
        return &emptyConfig
 }
 
-var once sync.Once
-
-func defaultRoots() *x509.CertPool {
-       once.Do(initDefaults)
-       return varDefaultRoots
-}
+var (
+       once                   sync.Once
+       varDefaultCipherSuites []uint16
+)
 
 func defaultCipherSuites() []uint16 {
-       once.Do(initDefaults)
+       once.Do(initDefaultCipherSuites)
        return varDefaultCipherSuites
 }
 
-func initDefaults() {
-       initDefaultRoots()
-       initDefaultCipherSuites()
-}
-
-var (
-       varDefaultRoots        *x509.CertPool
-       varDefaultCipherSuites []uint16
-)
-
 func initDefaultCipherSuites() {
        varDefaultCipherSuites = make([]uint16, len(cipherSuites))
        for i, suite := range cipherSuites {
index 0d7b806..266eb8f 100644 (file)
@@ -102,7 +102,7 @@ func (c *Conn) clientHandshake() error {
 
        if !c.config.InsecureSkipVerify {
                opts := x509.VerifyOptions{
-                       Roots:         c.config.rootCAs(),
+                       Roots:         c.config.RootCAs,
                        CurrentTime:   c.config.time(),
                        DNSName:       c.config.ServerName,
                        Intermediates: x509.NewCertPool(),
index bd31d31..08a0ccb 100644 (file)
@@ -143,7 +143,7 @@ func testServerScript(t *testing.T, name string, serverScript [][]byte, config *
        if peers != nil {
                gotpeers := <-pchan
                if len(peers) == len(gotpeers) {
-                       for i, _ := range peers {
+                       for i := range peers {
                                if !peers[i].Equal(gotpeers[i]) {
                                        t.Fatalf("%s: mismatch on peer cert %d", name, i)
                                }
index 95a89d8..e61c218 100644 (file)
@@ -5,25 +5,25 @@
 package tls
 
 import (
+       "crypto/x509"
+       "runtime"
        "testing"
 )
 
 var tlsServers = []string{
-       "google.com:443",
-       "github.com:443",
-       "twitter.com:443",
+       "google.com",
+       "github.com",
+       "twitter.com",
 }
 
 func TestOSCertBundles(t *testing.T) {
-       defaultRoots()
-
        if testing.Short() {
                t.Logf("skipping certificate tests in short mode")
                return
        }
 
        for _, addr := range tlsServers {
-               conn, err := Dial("tcp", addr, nil)
+               conn, err := Dial("tcp", addr+":443", &Config{ServerName: addr})
                if err != nil {
                        t.Errorf("unable to verify %v: %v", addr, err)
                        continue
@@ -34,3 +34,28 @@ func TestOSCertBundles(t *testing.T) {
                }
        }
 }
+
+func TestCertHostnameVerifyWindows(t *testing.T) {
+       if runtime.GOOS != "windows" {
+               return
+       }
+
+       if testing.Short() {
+               t.Logf("skipping certificate tests in short mode")
+               return
+       }
+
+       for _, addr := range tlsServers {
+               cfg := &Config{ServerName: "example.com"}
+               conn, err := Dial("tcp", addr+":443", cfg)
+               if err == nil {
+                       conn.Close()
+                       t.Errorf("should fail to verify for example.com: %v", addr)
+                       continue
+               }
+               _, ok := err.(x509.HostnameError)
+               if !ok {
+                       t.Errorf("error type mismatch, got: %v", err)
+               }
+       }
+}
diff --git a/libgo/go/crypto/tls/root_windows.go b/libgo/go/crypto/tls/root_windows.go
deleted file mode 100644 (file)
index 319309a..0000000
+++ /dev/null
@@ -1,47 +0,0 @@
-// Copyright 2011 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package tls
-
-import (
-       "crypto/x509"
-       "syscall"
-       "unsafe"
-)
-
-func loadStore(roots *x509.CertPool, name string) {
-       store, err := syscall.CertOpenSystemStore(syscall.InvalidHandle, syscall.StringToUTF16Ptr(name))
-       if err != nil {
-               return
-       }
-       defer syscall.CertCloseStore(store, 0)
-
-       var cert *syscall.CertContext
-       for {
-               cert, err = syscall.CertEnumCertificatesInStore(store, cert)
-               if err != nil {
-                       return
-               }
-
-               buf := (*[1 << 20]byte)(unsafe.Pointer(cert.EncodedCert))[:]
-               // ParseCertificate requires its own copy of certificate data to keep.
-               buf2 := make([]byte, cert.Length)
-               copy(buf2, buf)
-               if c, err := x509.ParseCertificate(buf2); err == nil {
-                       roots.AddCert(c)
-               }
-       }
-}
-
-func initDefaultRoots() {
-       roots := x509.NewCertPool()
-
-       // Roots
-       loadStore(roots, "ROOT")
-
-       // Intermediates
-       loadStore(roots, "CA")
-
-       varDefaultRoots = roots
-}
index 9184e8e..09df5ad 100644 (file)
@@ -2,8 +2,7 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-// Package tls partially implements the TLS 1.1 protocol, as specified in RFC
-// 4346.
+// Package tls partially implements TLS 1.0, as specified in RFC 2246.
 package tls
 
 import (
@@ -98,7 +97,9 @@ func Dial(network, addr string, config *Config) (*Conn, error) {
        if config == nil {
                config = defaultConfig()
        }
-       if config.ServerName != "" {
+       // If no ServerName is set, infer the ServerName
+       // from the hostname we're connecting to.
+       if config.ServerName == "" {
                // Make a copy to avoid polluting argument or default.
                c := *config
                c.ServerName = hostname
index 3aaa8c5..873d396 100644 (file)
@@ -24,7 +24,7 @@ type pkcs1PrivateKey struct {
        Dq   *big.Int `asn1:"optional"`
        Qinv *big.Int `asn1:"optional"`
 
-       AdditionalPrimes []pkcs1AdditionalRSAPrime `asn1:"optional"`
+       AdditionalPrimes []pkcs1AdditionalRSAPrime `asn1:"optional,omitempty"`
 }
 
 type pkcs1AdditionalRSAPrime struct {
diff --git a/libgo/go/crypto/x509/root.go b/libgo/go/crypto/x509/root.go
new file mode 100644 (file)
index 0000000..8aae14e
--- /dev/null
@@ -0,0 +1,17 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package x509
+
+import "sync"
+
+var (
+       once        sync.Once
+       systemRoots *CertPool
+)
+
+func systemRootsPool() *CertPool {
+       once.Do(initSystemRoots)
+       return systemRoots
+}
similarity index 90%
rename from libgo/go/crypto/tls/root_darwin.go
rename to libgo/go/crypto/x509/root_darwin.go
index 911a9a6..0f99581 100644 (file)
@@ -2,7 +2,7 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-package tls
+package x509
 
 /*
 #cgo CFLAGS: -mmacosx-version-min=10.6 -D__MAC_OS_X_VERSION_MAX_ALLOWED=1060
@@ -59,13 +59,14 @@ int FetchPEMRoots(CFDataRef *pemRoots) {
 }
 */
 import "C"
-import (
-       "crypto/x509"
-       "unsafe"
-)
+import "unsafe"
 
-func initDefaultRoots() {
-       roots := x509.NewCertPool()
+func (c *Certificate) systemVerify(opts *VerifyOptions) (chains [][]*Certificate, err error) {
+       return nil, nil
+}
+
+func initSystemRoots() {
+       roots := NewCertPool()
 
        var data C.CFDataRef = nil
        err := C.FetchPEMRoots(&data)
@@ -75,5 +76,5 @@ func initDefaultRoots() {
                roots.AppendCertsFromPEM(buf)
        }
 
-       varDefaultRoots = roots
+       systemRoots = roots
 }
similarity index 51%
rename from libgo/go/crypto/tls/root_stub.go
rename to libgo/go/crypto/x509/root_stub.go
index ee2c3e0..5680041 100644 (file)
@@ -4,7 +4,12 @@
 
 // +build plan9 darwin,!cgo
 
-package tls
+package x509
 
-func initDefaultRoots() {
+func (c *Certificate) systemVerify(opts *VerifyOptions) (chains [][]*Certificate, err error) {
+       return nil, nil
+}
+
+func initSystemRoots() {
+       systemRoots = NewCertPool()
 }
similarity index 76%
rename from libgo/go/crypto/tls/root_unix.go
rename to libgo/go/crypto/x509/root_unix.go
index acaf3dd..76e79f4 100644 (file)
@@ -4,12 +4,9 @@
 
 // +build freebsd linux openbsd netbsd
 
-package tls
+package x509
 
-import (
-       "crypto/x509"
-       "io/ioutil"
-)
+import "io/ioutil"
 
 // Possible certificate files; stop after finding one.
 var certFiles = []string{
@@ -20,8 +17,12 @@ var certFiles = []string{
        "/usr/local/share/certs/ca-root-nss.crt", // FreeBSD
 }
 
-func initDefaultRoots() {
-       roots := x509.NewCertPool()
+func (c *Certificate) systemVerify(opts *VerifyOptions) (chains [][]*Certificate, err error) {
+       return nil, nil
+}
+
+func initSystemRoots() {
+       roots := NewCertPool()
        for _, file := range certFiles {
                data, err := ioutil.ReadFile(file)
                if err == nil {
@@ -29,5 +30,6 @@ func initDefaultRoots() {
                        break
                }
        }
-       varDefaultRoots = roots
+
+       systemRoots = roots
 }
diff --git a/libgo/go/crypto/x509/root_windows.go b/libgo/go/crypto/x509/root_windows.go
new file mode 100644 (file)
index 0000000..7e8f2af
--- /dev/null
@@ -0,0 +1,226 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package x509
+
+import (
+       "errors"
+       "syscall"
+       "unsafe"
+)
+
+// Creates a new *syscall.CertContext representing the leaf certificate in an in-memory
+// certificate store containing itself and all of the intermediate certificates specified
+// in the opts.Intermediates CertPool.
+//
+// A pointer to the in-memory store is available in the returned CertContext's Store field.
+// The store is automatically freed when the CertContext is freed using
+// syscall.CertFreeCertificateContext.
+func createStoreContext(leaf *Certificate, opts *VerifyOptions) (*syscall.CertContext, error) {
+       var storeCtx *syscall.CertContext
+
+       leafCtx, err := syscall.CertCreateCertificateContext(syscall.X509_ASN_ENCODING|syscall.PKCS_7_ASN_ENCODING, &leaf.Raw[0], uint32(len(leaf.Raw)))
+       if err != nil {
+               return nil, err
+       }
+       defer syscall.CertFreeCertificateContext(leafCtx)
+
+       handle, err := syscall.CertOpenStore(syscall.CERT_STORE_PROV_MEMORY, 0, 0, syscall.CERT_STORE_DEFER_CLOSE_UNTIL_LAST_FREE_FLAG, 0)
+       if err != nil {
+               return nil, err
+       }
+       defer syscall.CertCloseStore(handle, 0)
+
+       err = syscall.CertAddCertificateContextToStore(handle, leafCtx, syscall.CERT_STORE_ADD_ALWAYS, &storeCtx)
+       if err != nil {
+               return nil, err
+       }
+
+       if opts.Intermediates != nil {
+               for _, intermediate := range opts.Intermediates.certs {
+                       ctx, err := syscall.CertCreateCertificateContext(syscall.X509_ASN_ENCODING|syscall.PKCS_7_ASN_ENCODING, &intermediate.Raw[0], uint32(len(intermediate.Raw)))
+                       if err != nil {
+                               return nil, err
+                       }
+
+                       err = syscall.CertAddCertificateContextToStore(handle, ctx, syscall.CERT_STORE_ADD_ALWAYS, nil)
+                       syscall.CertFreeCertificateContext(ctx)
+                       if err != nil {
+                               return nil, err
+                       }
+               }
+       }
+
+       return storeCtx, nil
+}
+
+// extractSimpleChain extracts the final certificate chain from a CertSimpleChain.
+func extractSimpleChain(simpleChain **syscall.CertSimpleChain, count int) (chain []*Certificate, err error) {
+       if simpleChain == nil || count == 0 {
+               return nil, errors.New("x509: invalid simple chain")
+       }
+
+       simpleChains := (*[1 << 20]*syscall.CertSimpleChain)(unsafe.Pointer(simpleChain))[:]
+       lastChain := simpleChains[count-1]
+       elements := (*[1 << 20]*syscall.CertChainElement)(unsafe.Pointer(lastChain.Elements))[:]
+       for i := 0; i < int(lastChain.NumElements); i++ {
+               // Copy the buf, since ParseCertificate does not create its own copy.
+               cert := elements[i].CertContext
+               encodedCert := (*[1 << 20]byte)(unsafe.Pointer(cert.EncodedCert))[:]
+               buf := make([]byte, cert.Length)
+               copy(buf, encodedCert[:])
+               parsedCert, err := ParseCertificate(buf)
+               if err != nil {
+                       return nil, err
+               }
+               chain = append(chain, parsedCert)
+       }
+
+       return chain, nil
+}
+
+// checkChainTrustStatus checks the trust status of the certificate chain, translating
+// any errors it finds into Go errors in the process.
+func checkChainTrustStatus(c *Certificate, chainCtx *syscall.CertChainContext) error {
+       if chainCtx.TrustStatus.ErrorStatus != syscall.CERT_TRUST_NO_ERROR {
+               status := chainCtx.TrustStatus.ErrorStatus
+               switch status {
+               case syscall.CERT_TRUST_IS_NOT_TIME_VALID:
+                       return CertificateInvalidError{c, Expired}
+               default:
+                       return UnknownAuthorityError{c}
+               }
+       }
+       return nil
+}
+
+// checkChainSSLServerPolicy checks that the certificate chain in chainCtx is valid for
+// use as a certificate chain for a SSL/TLS server.
+func checkChainSSLServerPolicy(c *Certificate, chainCtx *syscall.CertChainContext, opts *VerifyOptions) error {
+       sslPara := &syscall.SSLExtraCertChainPolicyPara{
+               AuthType:   syscall.AUTHTYPE_SERVER,
+               ServerName: syscall.StringToUTF16Ptr(opts.DNSName),
+       }
+       sslPara.Size = uint32(unsafe.Sizeof(*sslPara))
+
+       para := &syscall.CertChainPolicyPara{
+               ExtraPolicyPara: uintptr(unsafe.Pointer(sslPara)),
+       }
+       para.Size = uint32(unsafe.Sizeof(*para))
+
+       status := syscall.CertChainPolicyStatus{}
+       err := syscall.CertVerifyCertificateChainPolicy(syscall.CERT_CHAIN_POLICY_SSL, chainCtx, para, &status)
+       if err != nil {
+               return err
+       }
+
+       // TODO(mkrautz): use the lChainIndex and lElementIndex fields
+       // of the CertChainPolicyStatus to provide proper context, instead
+       // using c.
+       if status.Error != 0 {
+               switch status.Error {
+               case syscall.CERT_E_EXPIRED:
+                       return CertificateInvalidError{c, Expired}
+               case syscall.CERT_E_CN_NO_MATCH:
+                       return HostnameError{c, opts.DNSName}
+               case syscall.CERT_E_UNTRUSTEDROOT:
+                       return UnknownAuthorityError{c}
+               default:
+                       return UnknownAuthorityError{c}
+               }
+       }
+
+       return nil
+}
+
+// systemVerify is like Verify, except that it uses CryptoAPI calls
+// to build certificate chains and verify them.
+func (c *Certificate) systemVerify(opts *VerifyOptions) (chains [][]*Certificate, err error) {
+       hasDNSName := opts != nil && len(opts.DNSName) > 0
+
+       storeCtx, err := createStoreContext(c, opts)
+       if err != nil {
+               return nil, err
+       }
+       defer syscall.CertFreeCertificateContext(storeCtx)
+
+       para := new(syscall.CertChainPara)
+       para.Size = uint32(unsafe.Sizeof(*para))
+
+       // If there's a DNSName set in opts, assume we're verifying
+       // a certificate from a TLS server.
+       if hasDNSName {
+               oids := []*byte{
+                       &syscall.OID_PKIX_KP_SERVER_AUTH[0],
+                       // Both IE and Chrome allow certificates with
+                       // Server Gated Crypto as well. Some certificates
+                       // in the wild require them.
+                       &syscall.OID_SERVER_GATED_CRYPTO[0],
+                       &syscall.OID_SGC_NETSCAPE[0],
+               }
+               para.RequestedUsage.Type = syscall.USAGE_MATCH_TYPE_OR
+               para.RequestedUsage.Usage.Length = uint32(len(oids))
+               para.RequestedUsage.Usage.UsageIdentifiers = &oids[0]
+       } else {
+               para.RequestedUsage.Type = syscall.USAGE_MATCH_TYPE_AND
+               para.RequestedUsage.Usage.Length = 0
+               para.RequestedUsage.Usage.UsageIdentifiers = nil
+       }
+
+       var verifyTime *syscall.Filetime
+       if opts != nil && !opts.CurrentTime.IsZero() {
+               ft := syscall.NsecToFiletime(opts.CurrentTime.UnixNano())
+               verifyTime = &ft
+       }
+
+       // CertGetCertificateChain will traverse Windows's root stores
+       // in an attempt to build a verified certificate chain.  Once
+       // it has found a verified chain, it stops. MSDN docs on
+       // CERT_CHAIN_CONTEXT:
+       //
+       //   When a CERT_CHAIN_CONTEXT is built, the first simple chain
+       //   begins with an end certificate and ends with a self-signed
+       //   certificate. If that self-signed certificate is not a root
+       //   or otherwise trusted certificate, an attempt is made to
+       //   build a new chain. CTLs are used to create the new chain
+       //   beginning with the self-signed certificate from the original
+       //   chain as the end certificate of the new chain. This process
+       //   continues building additional simple chains until the first
+       //   self-signed certificate is a trusted certificate or until
+       //   an additional simple chain cannot be built.
+       //
+       // The result is that we'll only get a single trusted chain to
+       // return to our caller.
+       var chainCtx *syscall.CertChainContext
+       err = syscall.CertGetCertificateChain(syscall.Handle(0), storeCtx, verifyTime, storeCtx.Store, para, 0, 0, &chainCtx)
+       if err != nil {
+               return nil, err
+       }
+       defer syscall.CertFreeCertificateChain(chainCtx)
+
+       err = checkChainTrustStatus(c, chainCtx)
+       if err != nil {
+               return nil, err
+       }
+
+       if hasDNSName {
+               err = checkChainSSLServerPolicy(c, chainCtx, opts)
+               if err != nil {
+                       return nil, err
+               }
+       }
+
+       chain, err := extractSimpleChain(chainCtx.Chains, int(chainCtx.ChainCount))
+       if err != nil {
+               return nil, err
+       }
+
+       chains = append(chains, chain)
+
+       return chains, nil
+}
+
+func initSystemRoots() {
+       systemRoots = NewCertPool()
+}
index 3859dd8..307c5ef 100644 (file)
@@ -5,6 +5,7 @@
 package x509
 
 import (
+       "runtime"
        "strings"
        "time"
        "unicode/utf8"
@@ -23,6 +24,9 @@ const (
        // certificate has a name constraint which doesn't include the name
        // being checked.
        CANotAuthorizedForThisName
+       // TooManyIntermediates results when a path length constraint is
+       // violated.
+       TooManyIntermediates
 )
 
 // CertificateInvalidError results when an odd error occurs. Users of this
@@ -40,6 +44,8 @@ func (e CertificateInvalidError) Error() string {
                return "x509: certificate has expired or is not yet valid"
        case CANotAuthorizedForThisName:
                return "x509: a root or intermediate certificate is not authorized to sign in this domain"
+       case TooManyIntermediates:
+               return "x509: too many intermediates for path length constraint"
        }
        return "x509: unknown error"
 }
@@ -76,7 +82,7 @@ func (e UnknownAuthorityError) Error() string {
 type VerifyOptions struct {
        DNSName       string
        Intermediates *CertPool
-       Roots         *CertPool
+       Roots         *CertPool // if nil, the system roots are used
        CurrentTime   time.Time // if zero, the current time is used
 }
 
@@ -87,7 +93,7 @@ const (
 )
 
 // isValid performs validity checks on the c.
-func (c *Certificate) isValid(certType int, opts *VerifyOptions) error {
+func (c *Certificate) isValid(certType int, currentChain []*Certificate, opts *VerifyOptions) error {
        now := opts.CurrentTime
        if now.IsZero() {
                now = time.Now()
@@ -130,26 +136,44 @@ func (c *Certificate) isValid(certType int, opts *VerifyOptions) error {
                return CertificateInvalidError{c, NotAuthorizedToSign}
        }
 
+       if c.BasicConstraintsValid && c.MaxPathLen >= 0 {
+               numIntermediates := len(currentChain) - 1
+               if numIntermediates > c.MaxPathLen {
+                       return CertificateInvalidError{c, TooManyIntermediates}
+               }
+       }
+
        return nil
 }
 
 // Verify attempts to verify c by building one or more chains from c to a
-// certificate in opts.roots, using certificates in opts.Intermediates if
+// certificate in opts.Roots, using certificates in opts.Intermediates if
 // needed. If successful, it returns one or more chains where the first
 // element of the chain is c and the last element is from opts.Roots.
 //
 // WARNING: this doesn't do any revocation checking.
 func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err error) {
-       err = c.isValid(leafCertificate, &opts)
+       // Use Windows's own verification and chain building.
+       if opts.Roots == nil && runtime.GOOS == "windows" {
+               return c.systemVerify(&opts)
+       }
+
+       if opts.Roots == nil {
+               opts.Roots = systemRootsPool()
+       }
+
+       err = c.isValid(leafCertificate, nil, &opts)
        if err != nil {
                return
        }
+
        if len(opts.DNSName) > 0 {
                err = c.VerifyHostname(opts.DNSName)
                if err != nil {
                        return
                }
        }
+
        return c.buildChains(make(map[int][][]*Certificate), []*Certificate{c}, &opts)
 }
 
@@ -163,7 +187,7 @@ func appendToFreshChain(chain []*Certificate, cert *Certificate) []*Certificate
 func (c *Certificate) buildChains(cache map[int][][]*Certificate, currentChain []*Certificate, opts *VerifyOptions) (chains [][]*Certificate, err error) {
        for _, rootNum := range opts.Roots.findVerifiedParents(c) {
                root := opts.Roots.certs[rootNum]
-               err = root.isValid(rootCertificate, opts)
+               err = root.isValid(rootCertificate, currentChain, opts)
                if err != nil {
                        continue
                }
@@ -178,7 +202,7 @@ nextIntermediate:
                                continue nextIntermediate
                        }
                }
-               err = intermediate.isValid(intermediateCertificate, opts)
+               err = intermediate.isValid(intermediateCertificate, currentChain, opts)
                if err != nil {
                        continue
                }
index 2cdd66a..7b171b2 100644 (file)
@@ -8,6 +8,7 @@ import (
        "crypto/x509/pkix"
        "encoding/pem"
        "errors"
+       "runtime"
        "strings"
        "testing"
        "time"
@@ -19,7 +20,7 @@ type verifyTest struct {
        roots         []string
        currentTime   int64
        dnsName       string
-       nilRoots      bool
+       systemSkip    bool
 
        errorCallback  func(*testing.T, int, error) bool
        expectedChains [][]string
@@ -60,14 +61,6 @@ var verifyTests = []verifyTest{
        {
                leaf:          googleLeaf,
                intermediates: []string{thawteIntermediate},
-               nilRoots:      true, // verifies that we don't crash
-               currentTime:   1302726541,
-               dnsName:       "www.google.com",
-               errorCallback: expectAuthorityUnknown,
-       },
-       {
-               leaf:          googleLeaf,
-               intermediates: []string{thawteIntermediate},
                roots:         []string{verisignRoot},
                currentTime:   1,
                dnsName:       "www.example.com",
@@ -80,6 +73,9 @@ var verifyTests = []verifyTest{
                currentTime: 1302726541,
                dnsName:     "www.google.com",
 
+               // Skip when using systemVerify, since Windows
+               // *will* find the missing intermediate cert.
+               systemSkip:    true,
                errorCallback: expectAuthorityUnknown,
        },
        {
@@ -109,6 +105,9 @@ var verifyTests = []verifyTest{
                roots:         []string{startComRoot},
                currentTime:   1302726541,
 
+               // Skip when using systemVerify, since Windows
+               // can only return a single chain to us (for now).
+               systemSkip: true,
                expectedChains: [][]string{
                        {"dnssec-exp", "StartCom Class 1", "StartCom Certification Authority"},
                        {"dnssec-exp", "StartCom Class 1", "StartCom Certification Authority", "StartCom Certification Authority"},
@@ -148,23 +147,26 @@ func certificateFromPEM(pemBytes string) (*Certificate, error) {
        return ParseCertificate(block.Bytes)
 }
 
-func TestVerify(t *testing.T) {
+func testVerify(t *testing.T, useSystemRoots bool) {
        for i, test := range verifyTests {
+               if useSystemRoots && test.systemSkip {
+                       continue
+               }
+
                opts := VerifyOptions{
-                       Roots:         NewCertPool(),
                        Intermediates: NewCertPool(),
                        DNSName:       test.dnsName,
                        CurrentTime:   time.Unix(test.currentTime, 0),
                }
-               if test.nilRoots {
-                       opts.Roots = nil
-               }
 
-               for j, root := range test.roots {
-                       ok := opts.Roots.AppendCertsFromPEM([]byte(root))
-                       if !ok {
-                               t.Errorf("#%d: failed to parse root #%d", i, j)
-                               return
+               if !useSystemRoots {
+                       opts.Roots = NewCertPool()
+                       for j, root := range test.roots {
+                               ok := opts.Roots.AppendCertsFromPEM([]byte(root))
+                               if !ok {
+                                       t.Errorf("#%d: failed to parse root #%d", i, j)
+                                       return
+                               }
                        }
                }
 
@@ -225,6 +227,19 @@ func TestVerify(t *testing.T) {
        }
 }
 
+func TestGoVerify(t *testing.T) {
+       testVerify(t, false)
+}
+
+func TestSystemVerify(t *testing.T) {
+       if runtime.GOOS != "windows" {
+               t.Logf("skipping verify test using system APIs on %q", runtime.GOOS)
+               return
+       }
+
+       testVerify(t, true)
+}
+
 func chainToDebugString(chain []*Certificate) string {
        var chainStr string
        for _, cert := range chain {
index f5da86b..8dae7e7 100644 (file)
@@ -429,7 +429,7 @@ func (h UnhandledCriticalExtension) Error() string {
 
 type basicConstraints struct {
        IsCA       bool `asn1:"optional"`
-       MaxPathLen int  `asn1:"optional"`
+       MaxPathLen int  `asn1:"optional,default:-1"`
 }
 
 // RFC 5280 4.2.1.4
index 7f986b8..2f5280d 100644 (file)
@@ -43,6 +43,17 @@ type Driver interface {
 // documented.
 var ErrSkip = errors.New("driver: skip fast-path; continue as if unimplemented")
 
+// ErrBadConn should be returned by a driver to signal to the sql
+// package that a driver.Conn is in a bad state (such as the server
+// having earlier closed the connection) and the sql package should
+// retry on a new connection.
+//
+// To prevent duplicate operations, ErrBadConn should NOT be returned
+// if there's a possibility that the database server might have
+// performed the operation. Even if the server sends back an error,
+// you shouldn't return ErrBadConn.
+var ErrBadConn = errors.New("driver: bad connection")
+
 // Execer is an optional interface that may be implemented by a Conn.
 //
 // If a Conn does not implement Execer, the db package's DB.Exec will
index fc63f03..184e775 100644 (file)
@@ -82,6 +82,7 @@ type fakeConn struct {
        mu          sync.Mutex
        stmtsMade   int
        stmtsClosed int
+       numPrepare  int
 }
 
 func (c *fakeConn) incrStat(v *int) {
@@ -208,10 +209,13 @@ func (c *fakeConn) Begin() (driver.Tx, error) {
 
 func (c *fakeConn) Close() error {
        if c.currTx != nil {
-               return errors.New("can't close; in a Transaction")
+               return errors.New("can't close fakeConn; in a Transaction")
        }
        if c.db == nil {
-               return errors.New("can't close; already closed")
+               return errors.New("can't close fakeConn; already closed")
+       }
+       if c.stmtsMade > c.stmtsClosed {
+               return errors.New("can't close; dangling statement(s)")
        }
        c.db = nil
        return nil
@@ -249,6 +253,7 @@ func errf(msg string, args ...interface{}) error {
 //  just a limitation for fakedb)
 func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
        if len(parts) != 3 {
+               stmt.Close()
                return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
        }
        stmt.table = parts[0]
@@ -259,14 +264,17 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, e
                }
                nameVal := strings.Split(colspec, "=")
                if len(nameVal) != 2 {
+                       stmt.Close()
                        return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
                }
                column, value := nameVal[0], nameVal[1]
                _, ok := c.db.columnType(stmt.table, column)
                if !ok {
+                       stmt.Close()
                        return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
                }
                if value != "?" {
+                       stmt.Close()
                        return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
                                stmt.table, column)
                }
@@ -279,12 +287,14 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, e
 // parts are table|col=type,col2=type2
 func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
        if len(parts) != 2 {
+               stmt.Close()
                return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
        }
        stmt.table = parts[0]
        for n, colspec := range strings.Split(parts[1], ",") {
                nameType := strings.Split(colspec, "=")
                if len(nameType) != 2 {
+                       stmt.Close()
                        return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
                }
                stmt.colName = append(stmt.colName, nameType[0])
@@ -296,17 +306,20 @@ func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, e
 // parts are table|col=?,col2=val
 func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
        if len(parts) != 2 {
+               stmt.Close()
                return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
        }
        stmt.table = parts[0]
        for n, colspec := range strings.Split(parts[1], ",") {
                nameVal := strings.Split(colspec, "=")
                if len(nameVal) != 2 {
+                       stmt.Close()
                        return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
                }
                column, value := nameVal[0], nameVal[1]
                ctype, ok := c.db.columnType(stmt.table, column)
                if !ok {
+                       stmt.Close()
                        return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
                }
                stmt.colName = append(stmt.colName, column)
@@ -322,10 +335,12 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e
                        case "int32":
                                i, err := strconv.Atoi(value)
                                if err != nil {
+                                       stmt.Close()
                                        return nil, errf("invalid conversion to int32 from %q", value)
                                }
                                subsetVal = int64(i) // int64 is a subset type, but not int32
                        default:
+                               stmt.Close()
                                return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
                        }
                        stmt.colValue = append(stmt.colValue, subsetVal)
@@ -339,6 +354,7 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e
 }
 
 func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
+       c.numPrepare++
        if c.db == nil {
                panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
        }
@@ -360,6 +376,7 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
        case "INSERT":
                return c.prepareInsert(stmt, parts)
        default:
+               stmt.Close()
                return nil, errf("unsupported command type %q", cmd)
        }
        return stmt, nil
index 62b551d..51a357b 100644 (file)
@@ -175,6 +175,16 @@ var ErrNoRows = errors.New("sql: no rows in result set")
 
 // DB is a database handle. It's safe for concurrent use by multiple
 // goroutines.
+//
+// If the underlying database driver has the concept of a connection
+// and per-connection session state, the sql package manages creating
+// and freeing connections automatically, including maintaining a free
+// pool of idle connections. If observing session state is required,
+// either do not share a *DB between multiple concurrent goroutines or
+// create and observe all state only within a transaction. Once
+// DB.Open is called, the returned Tx is bound to a single isolated
+// connection. Once Tx.Commit or Tx.Rollback is called, that
+// connection is returned to DB's idle connection pool.
 type DB struct {
        driver driver.Driver
        dsn    string
@@ -241,34 +251,56 @@ func (db *DB) conn() (driver.Conn, error) {
 func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) {
        db.mu.Lock()
        defer db.mu.Unlock()
-       for n, conn := range db.freeConn {
-               if conn == wanted {
-                       db.freeConn[n] = db.freeConn[len(db.freeConn)-1]
-                       db.freeConn = db.freeConn[:len(db.freeConn)-1]
-                       return wanted, true
+       for i, conn := range db.freeConn {
+               if conn != wanted {
+                       continue
                }
+               db.freeConn[i] = db.freeConn[len(db.freeConn)-1]
+               db.freeConn = db.freeConn[:len(db.freeConn)-1]
+               return wanted, true
        }
        return nil, false
 }
 
-func (db *DB) putConn(c driver.Conn) {
+// putConnHook is a hook for testing.
+var putConnHook func(*DB, driver.Conn)
+
+// putConn adds a connection to the db's free pool.
+// err is optionally the last error that occured on this connection.
+func (db *DB) putConn(c driver.Conn, err error) {
+       if err == driver.ErrBadConn {
+               // Don't reuse bad connections.
+               return
+       }
        db.mu.Lock()
-       defer db.mu.Unlock()
+       if putConnHook != nil {
+               putConnHook(db, c)
+       }
        if n := len(db.freeConn); !db.closed && n < db.maxIdleConns() {
                db.freeConn = append(db.freeConn, c)
+               db.mu.Unlock()
                return
        }
-       db.closeConn(c) // TODO(bradfitz): release lock before calling this?
-}
-
-func (db *DB) closeConn(c driver.Conn) {
-       // TODO: check to see if we need this Conn for any prepared statements
-       // that are active.
+       // TODO: check to see if we need this Conn for any prepared
+       // statements which are still active?
+       db.mu.Unlock()
        c.Close()
 }
 
 // Prepare creates a prepared statement for later execution.
 func (db *DB) Prepare(query string) (*Stmt, error) {
+       var stmt *Stmt
+       var err error
+       for i := 0; i < 10; i++ {
+               stmt, err = db.prepare(query)
+               if err != driver.ErrBadConn {
+                       break
+               }
+       }
+       return stmt, err
+}
+
+func (db *DB) prepare(query string) (stmt *Stmt, err error) {
        // TODO: check if db.driver supports an optional
        // driver.Preparer interface and call that instead, if so,
        // otherwise we make a prepared statement that's bound
@@ -279,12 +311,12 @@ func (db *DB) Prepare(query string) (*Stmt, error) {
        if err != nil {
                return nil, err
        }
-       defer db.putConn(ci)
+       defer db.putConn(ci, err)
        si, err := ci.Prepare(query)
        if err != nil {
                return nil, err
        }
-       stmt := &Stmt{
+       stmt = &Stmt{
                db:    db,
                query: query,
                css:   []connStmt{{ci, si}},
@@ -295,15 +327,22 @@ func (db *DB) Prepare(query string) (*Stmt, error) {
 // Exec executes a query without returning any rows.
 func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
        sargs, err := subsetTypeArgs(args)
-       if err != nil {
-               return nil, err
+       var res Result
+       for i := 0; i < 10; i++ {
+               res, err = db.exec(query, sargs)
+               if err != driver.ErrBadConn {
+                       break
+               }
        }
+       return res, err
+}
 
+func (db *DB) exec(query string, sargs []driver.Value) (res Result, err error) {
        ci, err := db.conn()
        if err != nil {
                return nil, err
        }
-       defer db.putConn(ci)
+       defer db.putConn(ci, err)
 
        if execer, ok := ci.(driver.Execer); ok {
                resi, err := execer.Exec(query, sargs)
@@ -354,13 +393,25 @@ func (db *DB) QueryRow(query string, args ...interface{}) *Row {
 // Begin starts a transaction. The isolation level is dependent on
 // the driver.
 func (db *DB) Begin() (*Tx, error) {
+       var tx *Tx
+       var err error
+       for i := 0; i < 10; i++ {
+               tx, err = db.begin()
+               if err != driver.ErrBadConn {
+                       break
+               }
+       }
+       return tx, err
+}
+
+func (db *DB) begin() (tx *Tx, err error) {
        ci, err := db.conn()
        if err != nil {
                return nil, err
        }
        txi, err := ci.Begin()
        if err != nil {
-               db.putConn(ci)
+               db.putConn(ci, err)
                return nil, fmt.Errorf("sql: failed to Begin transaction: %v", err)
        }
        return &Tx{
@@ -406,7 +457,7 @@ func (tx *Tx) close() {
                panic("double close") // internal error
        }
        tx.done = true
-       tx.db.putConn(tx.ci)
+       tx.db.putConn(tx.ci, nil)
        tx.ci = nil
        tx.txi = nil
 }
@@ -561,9 +612,11 @@ func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
                return nil, err
        }
        rows, err := stmt.Query(args...)
-       if err == nil {
-               rows.closeStmt = stmt
+       if err != nil {
+               stmt.Close()
+               return nil, err
        }
+       rows.closeStmt = stmt
        return rows, err
 }
 
@@ -609,7 +662,7 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
        if err != nil {
                return nil, err
        }
-       defer releaseConn()
+       defer releaseConn(nil)
 
        // -1 means the driver doesn't know how to count the number of
        // placeholders, so we won't sanity check input here and instead let the
@@ -672,7 +725,7 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
 // connStmt returns a free driver connection on which to execute the
 // statement, a function to call to release the connection, and a
 // statement bound to that connection.
-func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(), si driver.Stmt, err error) {
+func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.Stmt, err error) {
        if err = s.stickyErr; err != nil {
                return
        }
@@ -691,7 +744,7 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(), si driver.Stmt, e
                if err != nil {
                        return
                }
-               releaseConn = func() { s.tx.releaseConn() }
+               releaseConn = func(error) { s.tx.releaseConn() }
                return ci, releaseConn, s.txsi, nil
        }
 
@@ -700,7 +753,7 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(), si driver.Stmt, e
        for _, v := range s.css {
                // TODO(bradfitz): lazily clean up entries in this
                // list with dead conns while enumerating
-               if _, match = s.db.connIfFree(cs.ci); match {
+               if _, match = s.db.connIfFree(v.ci); match {
                        cs = v
                        break
                }
@@ -710,22 +763,28 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(), si driver.Stmt, e
        // Make a new conn if all are busy.
        // TODO(bradfitz): or wait for one? make configurable later?
        if !match {
-               ci, err := s.db.conn()
-               if err != nil {
-                       return nil, nil, nil, err
-               }
-               si, err := ci.Prepare(s.query)
-               if err != nil {
-                       return nil, nil, nil, err
+               for i := 0; ; i++ {
+                       ci, err := s.db.conn()
+                       if err != nil {
+                               return nil, nil, nil, err
+                       }
+                       si, err := ci.Prepare(s.query)
+                       if err == driver.ErrBadConn && i < 10 {
+                               continue
+                       }
+                       if err != nil {
+                               return nil, nil, nil, err
+                       }
+                       s.mu.Lock()
+                       cs = connStmt{ci, si}
+                       s.css = append(s.css, cs)
+                       s.mu.Unlock()
+                       break
                }
-               s.mu.Lock()
-               cs = connStmt{ci, si}
-               s.css = append(s.css, cs)
-               s.mu.Unlock()
        }
 
        conn := cs.ci
-       releaseConn = func() { s.db.putConn(conn) }
+       releaseConn = func(err error) { s.db.putConn(conn, err) }
        return conn, releaseConn, cs.si, nil
 }
 
@@ -749,7 +808,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
        }
        rowsi, err := si.Query(sargs)
        if err != nil {
-               s.db.putConn(ci)
+               releaseConn(err)
                return nil, err
        }
        // Note: ownership of ci passes to the *Rows, to be freed
@@ -800,7 +859,7 @@ func (s *Stmt) Close() error {
                for _, v := range s.css {
                        if ci, match := s.db.connIfFree(v.ci); match {
                                v.si.Close()
-                               s.db.putConn(ci)
+                               s.db.putConn(ci, nil)
                        } else {
                                // TODO(bradfitz): care that we can't close
                                // this statement because the statement's
@@ -827,7 +886,7 @@ func (s *Stmt) Close() error {
 type Rows struct {
        db          *DB
        ci          driver.Conn // owned; must call putconn when closed to release
-       releaseConn func()
+       releaseConn func(error)
        rowsi       driver.Rows
 
        closed    bool
@@ -939,7 +998,7 @@ func (rs *Rows) Close() error {
        }
        rs.closed = true
        err := rs.rowsi.Close()
-       rs.releaseConn()
+       rs.releaseConn(err)
        if rs.closeStmt != nil {
                rs.closeStmt.Close()
        }
@@ -963,7 +1022,7 @@ func (r *Row) Scan(dest ...interface{}) error {
        }
 
        // TODO(bradfitz): for now we need to defensively clone all
-       // []byte that the driver returned (not permitting 
+       // []byte that the driver returned (not permitting
        // *RawBytes in Rows.Scan), since we're about to close
        // the Rows in our defer, when we return from this function.
        // the contract with the driver.Next(...) interface is that it
index c985a10..b296705 100644 (file)
@@ -5,13 +5,35 @@
 package sql
 
 import (
+       "database/sql/driver"
        "fmt"
        "reflect"
+       "runtime"
        "strings"
        "testing"
        "time"
 )
 
+func init() {
+       type dbConn struct {
+               db *DB
+               c  driver.Conn
+       }
+       freedFrom := make(map[dbConn]string)
+       putConnHook = func(db *DB, c driver.Conn) {
+               for _, oc := range db.freeConn {
+                       if oc == c {
+                               // print before panic, as panic may get lost due to conflicting panic
+                               // (all goroutines asleep) elsewhere, since we might not unlock
+                               // the mutex in freeConn here.
+                               println("double free of conn. conflicts are:\nA) " + freedFrom[dbConn{db, c}] + "\n\nand\nB) " + stack())
+                               panic("double free of conn.")
+                       }
+               }
+               freedFrom[dbConn{db, c}] = stack()
+       }
+}
+
 const fakeDBName = "foo"
 
 var chrisBirthday = time.Unix(123456789, 0)
@@ -47,9 +69,19 @@ func closeDB(t *testing.T, db *DB) {
        }
 }
 
+// numPrepares assumes that db has exactly 1 idle conn and returns
+// its count of calls to Prepare
+func numPrepares(t *testing.T, db *DB) int {
+       if n := len(db.freeConn); n != 1 {
+               t.Fatalf("free conns = %d; want 1", n)
+       }
+       return db.freeConn[0].(*fakeConn).numPrepare
+}
+
 func TestQuery(t *testing.T) {
        db := newTestDB(t, "people")
        defer closeDB(t, db)
+       prepares0 := numPrepares(t, db)
        rows, err := db.Query("SELECT|people|age,name|")
        if err != nil {
                t.Fatalf("Query: %v", err)
@@ -83,7 +115,10 @@ func TestQuery(t *testing.T) {
        // And verify that the final rows.Next() call, which hit EOF,
        // also closed the rows connection.
        if n := len(db.freeConn); n != 1 {
-               t.Errorf("free conns after query hitting EOF = %d; want 1", n)
+               t.Fatalf("free conns after query hitting EOF = %d; want 1", n)
+       }
+       if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
+               t.Errorf("executed %d Prepare statements; want 1", prepares)
        }
 }
 
@@ -216,6 +251,7 @@ func TestStatementQueryRow(t *testing.T) {
        if err != nil {
                t.Fatalf("Prepare: %v", err)
        }
+       defer stmt.Close()
        var age int
        for n, tt := range []struct {
                name string
@@ -256,6 +292,7 @@ func TestExec(t *testing.T) {
        if err != nil {
                t.Errorf("Stmt, err = %v, %v", stmt, err)
        }
+       defer stmt.Close()
 
        type execTest struct {
                args    []interface{}
@@ -297,11 +334,14 @@ func TestTxStmt(t *testing.T) {
        if err != nil {
                t.Fatalf("Stmt, err = %v, %v", stmt, err)
        }
+       defer stmt.Close()
        tx, err := db.Begin()
        if err != nil {
                t.Fatalf("Begin = %v", err)
        }
-       _, err = tx.Stmt(stmt).Exec("Bobby", 7)
+       txs := tx.Stmt(stmt)
+       defer txs.Close()
+       _, err = txs.Exec("Bobby", 7)
        if err != nil {
                t.Fatalf("Exec = %v", err)
        }
@@ -330,6 +370,7 @@ func TestTxQuery(t *testing.T) {
        if err != nil {
                t.Fatal(err)
        }
+       defer r.Close()
 
        if !r.Next() {
                if r.Err() != nil {
@@ -345,6 +386,22 @@ func TestTxQuery(t *testing.T) {
        }
 }
 
+func TestTxQueryInvalid(t *testing.T) {
+       db := newTestDB(t, "")
+       defer closeDB(t, db)
+
+       tx, err := db.Begin()
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer tx.Rollback()
+
+       _, err = tx.Query("SELECT|t1|name|")
+       if err == nil {
+               t.Fatal("Error expected")
+       }
+}
+
 // Tests fix for issue 2542, that we release a lock when querying on
 // a closed connection.
 func TestIssue2542Deadlock(t *testing.T) {
@@ -450,48 +507,48 @@ type nullTestSpec struct {
 
 func TestNullStringParam(t *testing.T) {
        spec := nullTestSpec{"nullstring", "string", [6]nullTestRow{
-               nullTestRow{NullString{"aqua", true}, "", NullString{"aqua", true}},
-               nullTestRow{NullString{"brown", false}, "", NullString{"", false}},
-               nullTestRow{"chartreuse", "", NullString{"chartreuse", true}},
-               nullTestRow{NullString{"darkred", true}, "", NullString{"darkred", true}},
-               nullTestRow{NullString{"eel", false}, "", NullString{"", false}},
-               nullTestRow{"foo", NullString{"black", false}, nil},
+               {NullString{"aqua", true}, "", NullString{"aqua", true}},
+               {NullString{"brown", false}, "", NullString{"", false}},
+               {"chartreuse", "", NullString{"chartreuse", true}},
+               {NullString{"darkred", true}, "", NullString{"darkred", true}},
+               {NullString{"eel", false}, "", NullString{"", false}},
+               {"foo", NullString{"black", false}, nil},
        }}
        nullTestRun(t, spec)
 }
 
 func TestNullInt64Param(t *testing.T) {
        spec := nullTestSpec{"nullint64", "int64", [6]nullTestRow{
-               nullTestRow{NullInt64{31, true}, 1, NullInt64{31, true}},
-               nullTestRow{NullInt64{-22, false}, 1, NullInt64{0, false}},
-               nullTestRow{22, 1, NullInt64{22, true}},
-               nullTestRow{NullInt64{33, true}, 1, NullInt64{33, true}},
-               nullTestRow{NullInt64{222, false}, 1, NullInt64{0, false}},
-               nullTestRow{0, NullInt64{31, false}, nil},
+               {NullInt64{31, true}, 1, NullInt64{31, true}},
+               {NullInt64{-22, false}, 1, NullInt64{0, false}},
+               {22, 1, NullInt64{22, true}},
+               {NullInt64{33, true}, 1, NullInt64{33, true}},
+               {NullInt64{222, false}, 1, NullInt64{0, false}},
+               {0, NullInt64{31, false}, nil},
        }}
        nullTestRun(t, spec)
 }
 
 func TestNullFloat64Param(t *testing.T) {
        spec := nullTestSpec{"nullfloat64", "float64", [6]nullTestRow{
-               nullTestRow{NullFloat64{31.2, true}, 1, NullFloat64{31.2, true}},
-               nullTestRow{NullFloat64{13.1, false}, 1, NullFloat64{0, false}},
-               nullTestRow{-22.9, 1, NullFloat64{-22.9, true}},
-               nullTestRow{NullFloat64{33.81, true}, 1, NullFloat64{33.81, true}},
-               nullTestRow{NullFloat64{222, false}, 1, NullFloat64{0, false}},
-               nullTestRow{10, NullFloat64{31.2, false}, nil},
+               {NullFloat64{31.2, true}, 1, NullFloat64{31.2, true}},
+               {NullFloat64{13.1, false}, 1, NullFloat64{0, false}},
+               {-22.9, 1, NullFloat64{-22.9, true}},
+               {NullFloat64{33.81, true}, 1, NullFloat64{33.81, true}},
+               {NullFloat64{222, false}, 1, NullFloat64{0, false}},
+               {10, NullFloat64{31.2, false}, nil},
        }}
        nullTestRun(t, spec)
 }
 
 func TestNullBoolParam(t *testing.T) {
        spec := nullTestSpec{"nullbool", "bool", [6]nullTestRow{
-               nullTestRow{NullBool{false, true}, true, NullBool{false, true}},
-               nullTestRow{NullBool{true, false}, false, NullBool{false, false}},
-               nullTestRow{true, true, NullBool{true, true}},
-               nullTestRow{NullBool{true, true}, false, NullBool{true, true}},
-               nullTestRow{NullBool{true, false}, true, NullBool{false, false}},
-               nullTestRow{true, NullBool{true, false}, nil},
+               {NullBool{false, true}, true, NullBool{false, true}},
+               {NullBool{true, false}, false, NullBool{false, false}},
+               {true, true, NullBool{true, true}},
+               {NullBool{true, true}, false, NullBool{true, true}},
+               {NullBool{true, false}, true, NullBool{false, false}},
+               {true, NullBool{true, false}, nil},
        }}
        nullTestRun(t, spec)
 }
@@ -510,6 +567,7 @@ func nullTestRun(t *testing.T, spec nullTestSpec) {
        if err != nil {
                t.Fatalf("prepare: %v", err)
        }
+       defer stmt.Close()
        if _, err := stmt.Exec(3, "chris", spec.rows[2].nullParam, spec.rows[2].notNullParam); err != nil {
                t.Errorf("exec insert chris: %v", err)
        }
@@ -549,3 +607,8 @@ func nullTestRun(t *testing.T, spec nullTestSpec) {
                }
        }
 }
+
+func stack() string {
+       buf := make([]byte, 1024)
+       return string(buf[:runtime.Stack(buf, false)])
+}
index 4d1ae38..3bf81a6 100644 (file)
@@ -250,10 +250,14 @@ func parseBase128Int(bytes []byte, initOffset int) (ret, offset int, err error)
 func parseUTCTime(bytes []byte) (ret time.Time, err error) {
        s := string(bytes)
        ret, err = time.Parse("0601021504Z0700", s)
-       if err == nil {
-               return
+       if err != nil {
+               ret, err = time.Parse("060102150405Z0700", s)
        }
-       ret, err = time.Parse("060102150405Z0700", s)
+       if err == nil && ret.Year() >= 2050 {
+               // UTCTime only encodes times prior to 2050. See https://tools.ietf.org/html/rfc5280#section-4.1.2.5.1
+               ret = ret.AddDate(-100, 0, 0)
+       }
+
        return
 }
 
index 92c9eb6..93803f4 100644 (file)
@@ -321,7 +321,7 @@ var parseFieldParametersTestData []parseFieldParametersTest = []parseFieldParame
        {"default:42", fieldParameters{defaultValue: newInt64(42)}},
        {"tag:17", fieldParameters{tag: newInt(17)}},
        {"optional,explicit,default:42,tag:17", fieldParameters{optional: true, explicit: true, defaultValue: newInt64(42), tag: newInt(17)}},
-       {"optional,explicit,default:42,tag:17,rubbish1", fieldParameters{true, true, false, newInt64(42), newInt(17), 0, false}},
+       {"optional,explicit,default:42,tag:17,rubbish1", fieldParameters{true, true, false, newInt64(42), newInt(17), 0, false, false}},
        {"set", fieldParameters{set: true}},
 }
 
index f7cb3ac..03856bc 100644 (file)
@@ -75,6 +75,7 @@ type fieldParameters struct {
        tag          *int   // the EXPLICIT or IMPLICIT tag (maybe nil).
        stringType   int    // the string tag to use when marshaling.
        set          bool   // true iff this should be encoded as a SET
+       omitEmpty    bool   // true iff this should be omitted if empty when marshaling.
 
        // Invariants:
        //   if explicit is set, tag is non-nil.
@@ -116,6 +117,8 @@ func parseFieldParameters(str string) (ret fieldParameters) {
                        if ret.tag == nil {
                                ret.tag = new(int)
                        }
+               case part == "omitempty":
+                       ret.omitEmpty = true
                }
        }
        return
index 774bee7..163bca5 100644 (file)
@@ -463,6 +463,10 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters)
                return marshalField(out, v.Elem(), params)
        }
 
+       if v.Kind() == reflect.Slice && v.Len() == 0 && params.omitEmpty {
+               return
+       }
+
        if params.optional && reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
                return
        }
index a7447f9..f43bcae 100644 (file)
@@ -54,6 +54,10 @@ type optionalRawValueTest struct {
        A RawValue `asn1:"optional"`
 }
 
+type omitEmptyTest struct {
+       A []string `asn1:"omitempty"`
+}
+
 type testSET []int
 
 var PST = time.FixedZone("PST", -8*60*60)
@@ -116,6 +120,8 @@ var marshalTests = []marshalTest{
        {rawContentsStruct{[]byte{0x30, 3, 1, 2, 3}, 64}, "3003010203"},
        {RawValue{Tag: 1, Class: 2, IsCompound: false, Bytes: []byte{1, 2, 3}}, "8103010203"},
        {testSET([]int{10}), "310302010a"},
+       {omitEmptyTest{[]string{}}, "3000"},
+       {omitEmptyTest{[]string{"1"}}, "30053003130131"},
 }
 
 func TestMarshal(t *testing.T) {
index 02f090d..712e490 100644 (file)
@@ -2,12 +2,17 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-// Package binary implements translation between
-// unsigned integer values and byte sequences
-// and the reading and writing of fixed-size values.
+// Package binary implements translation between numbers and byte sequences
+// and encoding and decoding of varints.
+//
+// Numbers are translated by reading and writing fixed-size values.
 // A fixed-size value is either a fixed-size arithmetic
 // type (int8, uint8, int16, float32, complex64, ...)
 // or an array or struct containing only fixed-size values.
+//
+// Varints are a method of encoding integers using one or more bytes;
+// numbers with smaller absolute value take a smaller number of bytes.
+// For a specification, see http://code.google.com/apis/protocolbuffers/docs/encoding.html.
 package binary
 
 import (
index 9aa398e..db4d988 100644 (file)
@@ -92,7 +92,8 @@ var (
 // If FieldsPerRecord is positive, Read requires each record to
 // have the given number of fields.  If FieldsPerRecord is 0, Read sets it to
 // the number of fields in the first record, so that future records must
-// have the same field count.
+// have the same field count.  If FieldsPerRecord is negative, no check is
+// made and records may have a variable number of fields.
 //
 // If LazyQuotes is true, a quote may appear in an unquoted field and a
 // non-doubled quote may appear in a quoted field.
index 0708a83..e32a178 100644 (file)
@@ -707,6 +707,9 @@ func (dec *Decoder) decodeInterface(ityp reflect.Type, state *decoderState, p ui
        if name == "" {
                // Copy the representation of the nil interface value to the target.
                // This is horribly unsafe and special.
+               if indir > 0 {
+                       p = allocate(ityp, p, 1) // All but the last level has been allocated by dec.Indirect
+               }
                *(*[2]uintptr)(unsafe.Pointer(p)) = ivalue.InterfaceData()
                return
        }
index 050786d..c4947cb 100644 (file)
@@ -694,8 +694,8 @@ type Bug3 struct {
 
 func TestGobPtrSlices(t *testing.T) {
        in := []*Bug3{
-               &Bug3{1, nil},
-               &Bug3{2, nil},
+               {1, nil},
+               {2, nil},
        }
        b := new(bytes.Buffer)
        err := NewEncoder(b).Encode(&in)
index 83644c0..45240d7 100644 (file)
@@ -573,3 +573,22 @@ func TestGobEncodeIsZero(t *testing.T) {
                t.Fatalf("%v != %v", x, y)
        }
 }
+
+func TestGobEncodePtrError(t *testing.T) {
+       var err error
+       b := new(bytes.Buffer)
+       enc := NewEncoder(b)
+       err = enc.Encode(&err)
+       if err != nil {
+               t.Fatal("encode:", err)
+       }
+       dec := NewDecoder(b)
+       err2 := fmt.Errorf("foo")
+       err = dec.Decode(&err2)
+       if err != nil {
+               t.Fatal("decode:", err)
+       }
+       if err2 != nil {
+               t.Fatalf("expected nil, got %v", err2)
+       }
+}
index 5425a3a..edbafcf 100644 (file)
@@ -43,7 +43,8 @@ import (
 // to keep some browsers from misinterpreting JSON output as HTML.
 //
 // Array and slice values encode as JSON arrays, except that
-// []byte encodes as a base64-encoded string.
+// []byte encodes as a base64-encoded string, and a nil slice
+// encodes as the null JSON object.
 //
 // Struct values encode as JSON objects. Each exported struct field
 // becomes a member of the object unless
index bb21bb5..1deedc9 100644 (file)
@@ -577,7 +577,7 @@ type decompSet [4]map[string]bool
 
 func makeDecompSet() decompSet {
        m := decompSet{}
-       for i, _ := range m {
+       for i := range m {
                m[i] = make(map[string]bool)
        }
        return m
@@ -646,7 +646,7 @@ func printCharInfoTables() int {
        fmt.Println("const (")
        for i, m := range decompSet {
                sa := []string{}
-               for s, _ := range m {
+               for s := range m {
                        sa = append(sa, s)
                }
                sort.Strings(sa)
diff --git a/libgo/go/exp/wingui/gui.go b/libgo/go/exp/wingui/gui.go
deleted file mode 100644 (file)
index 3b79873..0000000
+++ /dev/null
@@ -1,155 +0,0 @@
-// Copyright 2011 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// +build windows
-
-package main
-
-import (
-       "fmt"
-       "os"
-       "syscall"
-       "unsafe"
-)
-
-// some help functions
-
-func abortf(format string, a ...interface{}) {
-       fmt.Fprintf(os.Stdout, format, a...)
-       os.Exit(1)
-}
-
-func abortErrNo(funcname string, err error) {
-       errno, _ := err.(syscall.Errno)
-       abortf("%s failed: %d %s\n", funcname, uint32(errno), err)
-}
-
-// global vars
-
-var (
-       mh syscall.Handle
-       bh syscall.Handle
-)
-
-// WinProc called by windows to notify us of all windows events we might be interested in.
-func WndProc(hwnd syscall.Handle, msg uint32, wparam, lparam uintptr) (rc uintptr) {
-       switch msg {
-       case WM_CREATE:
-               var e error
-               // CreateWindowEx
-               bh, e = CreateWindowEx(
-                       0,
-                       syscall.StringToUTF16Ptr("button"),
-                       syscall.StringToUTF16Ptr("Quit"),
-                       WS_CHILD|WS_VISIBLE|BS_DEFPUSHBUTTON,
-                       75, 70, 140, 25,
-                       hwnd, 1, mh, 0)
-               if e != nil {
-                       abortErrNo("CreateWindowEx", e)
-               }
-               fmt.Printf("button handle is %x\n", bh)
-               rc = DefWindowProc(hwnd, msg, wparam, lparam)
-       case WM_COMMAND:
-               switch syscall.Handle(lparam) {
-               case bh:
-                       e := PostMessage(hwnd, WM_CLOSE, 0, 0)
-                       if e != nil {
-                               abortErrNo("PostMessage", e)
-                       }
-               default:
-                       rc = DefWindowProc(hwnd, msg, wparam, lparam)
-               }
-       case WM_CLOSE:
-               DestroyWindow(hwnd)
-       case WM_DESTROY:
-               PostQuitMessage(0)
-       default:
-               rc = DefWindowProc(hwnd, msg, wparam, lparam)
-       }
-       //fmt.Printf("WndProc(0x%08x, %d, 0x%08x, 0x%08x) (%d)\n", hwnd, msg, wparam, lparam, rc)
-       return
-}
-
-func rungui() int {
-       var e error
-
-       // GetModuleHandle
-       mh, e = GetModuleHandle(nil)
-       if e != nil {
-               abortErrNo("GetModuleHandle", e)
-       }
-
-       // Get icon we're going to use.
-       myicon, e := LoadIcon(0, IDI_APPLICATION)
-       if e != nil {
-               abortErrNo("LoadIcon", e)
-       }
-
-       // Get cursor we're going to use.
-       mycursor, e := LoadCursor(0, IDC_ARROW)
-       if e != nil {
-               abortErrNo("LoadCursor", e)
-       }
-
-       // Create callback
-       wproc := syscall.NewCallback(WndProc)
-
-       // RegisterClassEx
-       wcname := syscall.StringToUTF16Ptr("myWindowClass")
-       var wc Wndclassex
-       wc.Size = uint32(unsafe.Sizeof(wc))
-       wc.WndProc = wproc
-       wc.Instance = mh
-       wc.Icon = myicon
-       wc.Cursor = mycursor
-       wc.Background = COLOR_BTNFACE + 1
-       wc.MenuName = nil
-       wc.ClassName = wcname
-       wc.IconSm = myicon
-       if _, e := RegisterClassEx(&wc); e != nil {
-               abortErrNo("RegisterClassEx", e)
-       }
-
-       // CreateWindowEx
-       wh, e := CreateWindowEx(
-               WS_EX_CLIENTEDGE,
-               wcname,
-               syscall.StringToUTF16Ptr("My window"),
-               WS_OVERLAPPEDWINDOW,
-               CW_USEDEFAULT, CW_USEDEFAULT, 300, 200,
-               0, 0, mh, 0)
-       if e != nil {
-               abortErrNo("CreateWindowEx", e)
-       }
-       fmt.Printf("main window handle is %x\n", wh)
-
-       // ShowWindow
-       ShowWindow(wh, SW_SHOWDEFAULT)
-
-       // UpdateWindow
-       if e := UpdateWindow(wh); e != nil {
-               abortErrNo("UpdateWindow", e)
-       }
-
-       // Process all windows messages until WM_QUIT.
-       var m Msg
-       for {
-               r, e := GetMessage(&m, 0, 0, 0)
-               if e != nil {
-                       abortErrNo("GetMessage", e)
-               }
-               if r == 0 {
-                       // WM_QUIT received -> get out
-                       break
-               }
-               TranslateMessage(&m)
-               DispatchMessage(&m)
-       }
-       return int(m.Wparam)
-}
-
-func main() {
-       rc := rungui()
-       os.Exit(rc)
-}
diff --git a/libgo/go/exp/wingui/winapi.go b/libgo/go/exp/wingui/winapi.go
deleted file mode 100644 (file)
index f876088..0000000
+++ /dev/null
@@ -1,134 +0,0 @@
-// Copyright 2011 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// +build windows
-
-package main
-
-import (
-       "syscall"
-       "unsafe"
-)
-
-type Wndclassex struct {
-       Size       uint32
-       Style      uint32
-       WndProc    uintptr
-       ClsExtra   int32
-       WndExtra   int32
-       Instance   syscall.Handle
-       Icon       syscall.Handle
-       Cursor     syscall.Handle
-       Background syscall.Handle
-       MenuName   *uint16
-       ClassName  *uint16
-       IconSm     syscall.Handle
-}
-
-type Point struct {
-       X uintptr
-       Y uintptr
-}
-
-type Msg struct {
-       Hwnd    syscall.Handle
-       Message uint32
-       Wparam  uintptr
-       Lparam  uintptr
-       Time    uint32
-       Pt      Point
-}
-
-const (
-       // Window styles
-       WS_OVERLAPPED   = 0
-       WS_POPUP        = 0x80000000
-       WS_CHILD        = 0x40000000
-       WS_MINIMIZE     = 0x20000000
-       WS_VISIBLE      = 0x10000000
-       WS_DISABLED     = 0x8000000
-       WS_CLIPSIBLINGS = 0x4000000
-       WS_CLIPCHILDREN = 0x2000000
-       WS_MAXIMIZE     = 0x1000000
-       WS_CAPTION      = WS_BORDER | WS_DLGFRAME
-       WS_BORDER       = 0x800000
-       WS_DLGFRAME     = 0x400000
-       WS_VSCROLL      = 0x200000
-       WS_HSCROLL      = 0x100000
-       WS_SYSMENU      = 0x80000
-       WS_THICKFRAME   = 0x40000
-       WS_GROUP        = 0x20000
-       WS_TABSTOP      = 0x10000
-       WS_MINIMIZEBOX  = 0x20000
-       WS_MAXIMIZEBOX  = 0x10000
-       WS_TILED        = WS_OVERLAPPED
-       WS_ICONIC       = WS_MINIMIZE
-       WS_SIZEBOX      = WS_THICKFRAME
-       // Common Window Styles
-       WS_OVERLAPPEDWINDOW = WS_OVERLAPPED | WS_CAPTION | WS_SYSMENU | WS_THICKFRAME | WS_MINIMIZEBOX | WS_MAXIMIZEBOX
-       WS_TILEDWINDOW      = WS_OVERLAPPEDWINDOW
-       WS_POPUPWINDOW      = WS_POPUP | WS_BORDER | WS_SYSMENU
-       WS_CHILDWINDOW      = WS_CHILD
-
-       WS_EX_CLIENTEDGE = 0x200
-
-       // Some windows messages
-       WM_CREATE  = 1
-       WM_DESTROY = 2
-       WM_CLOSE   = 16
-       WM_COMMAND = 273
-
-       // Some button control styles
-       BS_DEFPUSHBUTTON = 1
-
-       // Some color constants
-       COLOR_WINDOW  = 5
-       COLOR_BTNFACE = 15
-
-       // Default window position
-       CW_USEDEFAULT = 0x80000000 - 0x100000000
-
-       // Show window default style
-       SW_SHOWDEFAULT = 10
-)
-
-var (
-       // Some globally known cursors
-       IDC_ARROW = MakeIntResource(32512)
-       IDC_IBEAM = MakeIntResource(32513)
-       IDC_WAIT  = MakeIntResource(32514)
-       IDC_CROSS = MakeIntResource(32515)
-
-       // Some globally known icons
-       IDI_APPLICATION = MakeIntResource(32512)
-       IDI_HAND        = MakeIntResource(32513)
-       IDI_QUESTION    = MakeIntResource(32514)
-       IDI_EXCLAMATION = MakeIntResource(32515)
-       IDI_ASTERISK    = MakeIntResource(32516)
-       IDI_WINLOGO     = MakeIntResource(32517)
-       IDI_WARNING     = IDI_EXCLAMATION
-       IDI_ERROR       = IDI_HAND
-       IDI_INFORMATION = IDI_ASTERISK
-)
-
-//sys  GetModuleHandle(modname *uint16) (handle syscall.Handle, err error) = GetModuleHandleW
-//sys  RegisterClassEx(wndclass *Wndclassex) (atom uint16, err error) = user32.RegisterClassExW
-//sys  CreateWindowEx(exstyle uint32, classname *uint16, windowname *uint16, style uint32, x int32, y int32, width int32, height int32, wndparent syscall.Handle, menu syscall.Handle, instance syscall.Handle, param uintptr) (hwnd syscall.Handle, err error) = user32.CreateWindowExW
-//sys  DefWindowProc(hwnd syscall.Handle, msg uint32, wparam uintptr, lparam uintptr) (lresult uintptr) = user32.DefWindowProcW
-//sys  DestroyWindow(hwnd syscall.Handle) (err error) = user32.DestroyWindow
-//sys  PostQuitMessage(exitcode int32) = user32.PostQuitMessage
-//sys  ShowWindow(hwnd syscall.Handle, cmdshow int32) (wasvisible bool) = user32.ShowWindow
-//sys  UpdateWindow(hwnd syscall.Handle) (err error) = user32.UpdateWindow
-//sys  GetMessage(msg *Msg, hwnd syscall.Handle, MsgFilterMin uint32, MsgFilterMax uint32) (ret int32, err error) [failretval==-1] = user32.GetMessageW
-//sys  TranslateMessage(msg *Msg) (done bool) = user32.TranslateMessage
-//sys  DispatchMessage(msg *Msg) (ret int32) = user32.DispatchMessageW
-//sys  LoadIcon(instance syscall.Handle, iconname *uint16) (icon syscall.Handle, err error) = user32.LoadIconW
-//sys  LoadCursor(instance syscall.Handle, cursorname *uint16) (cursor syscall.Handle, err error) = user32.LoadCursorW
-//sys  SetCursor(cursor syscall.Handle) (precursor syscall.Handle, err error) = user32.SetCursor
-//sys  SendMessage(hwnd syscall.Handle, msg uint32, wparam uintptr, lparam uintptr) (lresult uintptr) = user32.SendMessageW
-//sys  PostMessage(hwnd syscall.Handle, msg uint32, wparam uintptr, lparam uintptr) (err error) = user32.PostMessageW
-
-func MakeIntResource(id uint16) *uint16 {
-       return (*uint16)(unsafe.Pointer(uintptr(id)))
-}
diff --git a/libgo/go/exp/wingui/zwinapi.go b/libgo/go/exp/wingui/zwinapi.go
deleted file mode 100644 (file)
index 5666c6d..0000000
+++ /dev/null
@@ -1,192 +0,0 @@
-// +build windows
-// mksyscall_windows.pl winapi.go
-// MACHINE GENERATED BY THE COMMAND ABOVE; DO NOT EDIT
-
-package main
-
-import "unsafe"
-import "syscall"
-
-var (
-       modkernel32 = syscall.NewLazyDLL("kernel32.dll")
-       moduser32   = syscall.NewLazyDLL("user32.dll")
-
-       procGetModuleHandleW = modkernel32.NewProc("GetModuleHandleW")
-       procRegisterClassExW = moduser32.NewProc("RegisterClassExW")
-       procCreateWindowExW  = moduser32.NewProc("CreateWindowExW")
-       procDefWindowProcW   = moduser32.NewProc("DefWindowProcW")
-       procDestroyWindow    = moduser32.NewProc("DestroyWindow")
-       procPostQuitMessage  = moduser32.NewProc("PostQuitMessage")
-       procShowWindow       = moduser32.NewProc("ShowWindow")
-       procUpdateWindow     = moduser32.NewProc("UpdateWindow")
-       procGetMessageW      = moduser32.NewProc("GetMessageW")
-       procTranslateMessage = moduser32.NewProc("TranslateMessage")
-       procDispatchMessageW = moduser32.NewProc("DispatchMessageW")
-       procLoadIconW        = moduser32.NewProc("LoadIconW")
-       procLoadCursorW      = moduser32.NewProc("LoadCursorW")
-       procSetCursor        = moduser32.NewProc("SetCursor")
-       procSendMessageW     = moduser32.NewProc("SendMessageW")
-       procPostMessageW     = moduser32.NewProc("PostMessageW")
-)
-
-func GetModuleHandle(modname *uint16) (handle syscall.Handle, err error) {
-       r0, _, e1 := syscall.Syscall(procGetModuleHandleW.Addr(), 1, uintptr(unsafe.Pointer(modname)), 0, 0)
-       handle = syscall.Handle(r0)
-       if handle == 0 {
-               if e1 != 0 {
-                       err = error(e1)
-               } else {
-                       err = syscall.EINVAL
-               }
-       }
-       return
-}
-
-func RegisterClassEx(wndclass *Wndclassex) (atom uint16, err error) {
-       r0, _, e1 := syscall.Syscall(procRegisterClassExW.Addr(), 1, uintptr(unsafe.Pointer(wndclass)), 0, 0)
-       atom = uint16(r0)
-       if atom == 0 {
-               if e1 != 0 {
-                       err = error(e1)
-               } else {
-                       err = syscall.EINVAL
-               }
-       }
-       return
-}
-
-func CreateWindowEx(exstyle uint32, classname *uint16, windowname *uint16, style uint32, x int32, y int32, width int32, height int32, wndparent syscall.Handle, menu syscall.Handle, instance syscall.Handle, param uintptr) (hwnd syscall.Handle, err error) {
-       r0, _, e1 := syscall.Syscall12(procCreateWindowExW.Addr(), 12, uintptr(exstyle), uintptr(unsafe.Pointer(classname)), uintptr(unsafe.Pointer(windowname)), uintptr(style), uintptr(x), uintptr(y), uintptr(width), uintptr(height), uintptr(wndparent), uintptr(menu), uintptr(instance), uintptr(param))
-       hwnd = syscall.Handle(r0)
-       if hwnd == 0 {
-               if e1 != 0 {
-                       err = error(e1)
-               } else {
-                       err = syscall.EINVAL
-               }
-       }
-       return
-}
-
-func DefWindowProc(hwnd syscall.Handle, msg uint32, wparam uintptr, lparam uintptr) (lresult uintptr) {
-       r0, _, _ := syscall.Syscall6(procDefWindowProcW.Addr(), 4, uintptr(hwnd), uintptr(msg), uintptr(wparam), uintptr(lparam), 0, 0)
-       lresult = uintptr(r0)
-       return
-}
-
-func DestroyWindow(hwnd syscall.Handle) (err error) {
-       r1, _, e1 := syscall.Syscall(procDestroyWindow.Addr(), 1, uintptr(hwnd), 0, 0)
-       if int(r1) == 0 {
-               if e1 != 0 {
-                       err = error(e1)
-               } else {
-                       err = syscall.EINVAL
-               }
-       }
-       return
-}
-
-func PostQuitMessage(exitcode int32) {
-       syscall.Syscall(procPostQuitMessage.Addr(), 1, uintptr(exitcode), 0, 0)
-       return
-}
-
-func ShowWindow(hwnd syscall.Handle, cmdshow int32) (wasvisible bool) {
-       r0, _, _ := syscall.Syscall(procShowWindow.Addr(), 2, uintptr(hwnd), uintptr(cmdshow), 0)
-       wasvisible = bool(r0 != 0)
-       return
-}
-
-func UpdateWindow(hwnd syscall.Handle) (err error) {
-       r1, _, e1 := syscall.Syscall(procUpdateWindow.Addr(), 1, uintptr(hwnd), 0, 0)
-       if int(r1) == 0 {
-               if e1 != 0 {
-                       err = error(e1)
-               } else {
-                       err = syscall.EINVAL
-               }
-       }
-       return
-}
-
-func GetMessage(msg *Msg, hwnd syscall.Handle, MsgFilterMin uint32, MsgFilterMax uint32) (ret int32, err error) {
-       r0, _, e1 := syscall.Syscall6(procGetMessageW.Addr(), 4, uintptr(unsafe.Pointer(msg)), uintptr(hwnd), uintptr(MsgFilterMin), uintptr(MsgFilterMax), 0, 0)
-       ret = int32(r0)
-       if ret == -1 {
-               if e1 != 0 {
-                       err = error(e1)
-               } else {
-                       err = syscall.EINVAL
-               }
-       }
-       return
-}
-
-func TranslateMessage(msg *Msg) (done bool) {
-       r0, _, _ := syscall.Syscall(procTranslateMessage.Addr(), 1, uintptr(unsafe.Pointer(msg)), 0, 0)
-       done = bool(r0 != 0)
-       return
-}
-
-func DispatchMessage(msg *Msg) (ret int32) {
-       r0, _, _ := syscall.Syscall(procDispatchMessageW.Addr(), 1, uintptr(unsafe.Pointer(msg)), 0, 0)
-       ret = int32(r0)
-       return
-}
-
-func LoadIcon(instance syscall.Handle, iconname *uint16) (icon syscall.Handle, err error) {
-       r0, _, e1 := syscall.Syscall(procLoadIconW.Addr(), 2, uintptr(instance), uintptr(unsafe.Pointer(iconname)), 0)
-       icon = syscall.Handle(r0)
-       if icon == 0 {
-               if e1 != 0 {
-                       err = error(e1)
-               } else {
-                       err = syscall.EINVAL
-               }
-       }
-       return
-}
-
-func LoadCursor(instance syscall.Handle, cursorname *uint16) (cursor syscall.Handle, err error) {
-       r0, _, e1 := syscall.Syscall(procLoadCursorW.Addr(), 2, uintptr(instance), uintptr(unsafe.Pointer(cursorname)), 0)
-       cursor = syscall.Handle(r0)
-       if cursor == 0 {
-               if e1 != 0 {
-                       err = error(e1)
-               } else {
-                       err = syscall.EINVAL
-               }
-       }
-       return
-}
-
-func SetCursor(cursor syscall.Handle) (precursor syscall.Handle, err error) {
-       r0, _, e1 := syscall.Syscall(procSetCursor.Addr(), 1, uintptr(cursor), 0, 0)
-       precursor = syscall.Handle(r0)
-       if precursor == 0 {
-               if e1 != 0 {
-                       err = error(e1)
-               } else {
-                       err = syscall.EINVAL
-               }
-       }
-       return
-}
-
-func SendMessage(hwnd syscall.Handle, msg uint32, wparam uintptr, lparam uintptr) (lresult uintptr) {
-       r0, _, _ := syscall.Syscall6(procSendMessageW.Addr(), 4, uintptr(hwnd), uintptr(msg), uintptr(wparam), uintptr(lparam), 0, 0)
-       lresult = uintptr(r0)
-       return
-}
-
-func PostMessage(hwnd syscall.Handle, msg uint32, wparam uintptr, lparam uintptr) (err error) {
-       r1, _, e1 := syscall.Syscall6(procPostMessageW.Addr(), 4, uintptr(hwnd), uintptr(msg), uintptr(wparam), uintptr(lparam), 0, 0)
-       if int(r1) == 0 {
-               if e1 != 0 {
-                       err = error(e1)
-               } else {
-                       err = syscall.EINVAL
-               }
-       }
-       return
-}
index 1919296..b065995 100644 (file)
@@ -41,10 +41,14 @@ type Var interface {
 // Int is a 64-bit integer variable that satisfies the Var interface.
 type Int struct {
        i  int64
-       mu sync.Mutex
+       mu sync.RWMutex
 }
 
-func (v *Int) String() string { return strconv.FormatInt(v.i, 10) }
+func (v *Int) String() string {
+       v.mu.RLock()
+       defer v.mu.RUnlock()
+       return strconv.FormatInt(v.i, 10)
+}
 
 func (v *Int) Add(delta int64) {
        v.mu.Lock()
@@ -61,10 +65,14 @@ func (v *Int) Set(value int64) {
 // Float is a 64-bit float variable that satisfies the Var interface.
 type Float struct {
        f  float64
-       mu sync.Mutex
+       mu sync.RWMutex
 }
 
-func (v *Float) String() string { return strconv.FormatFloat(v.f, 'g', -1, 64) }
+func (v *Float) String() string {
+       v.mu.RLock()
+       defer v.mu.RUnlock()
+       return strconv.FormatFloat(v.f, 'g', -1, 64)
+}
 
 // Add adds delta to v.
 func (v *Float) Add(delta float64) {
@@ -95,17 +103,17 @@ type KeyValue struct {
 func (v *Map) String() string {
        v.mu.RLock()
        defer v.mu.RUnlock()
-       b := new(bytes.Buffer)
-       fmt.Fprintf(b, "{")
+       var b bytes.Buffer
+       fmt.Fprintf(&b, "{")
        first := true
        for key, val := range v.m {
                if !first {
-                       fmt.Fprintf(b, ", ")
+                       fmt.Fprintf(&b, ", ")
                }
-               fmt.Fprintf(b, "\"%s\": %v", key, val)
+               fmt.Fprintf(&b, "\"%s\": %v", key, val)
                first = false
        }
-       fmt.Fprintf(b, "}")
+       fmt.Fprintf(&b, "}")
        return b.String()
 }
 
@@ -180,12 +188,21 @@ func (v *Map) Do(f func(KeyValue)) {
 
 // String is a string variable, and satisfies the Var interface.
 type String struct {
-       s string
+       s  string
+       mu sync.RWMutex
 }
 
-func (v *String) String() string { return strconv.Quote(v.s) }
+func (v *String) String() string {
+       v.mu.RLock()
+       defer v.mu.RUnlock()
+       return strconv.Quote(v.s)
+}
 
-func (v *String) Set(value string) { v.s = value }
+func (v *String) Set(value string) {
+       v.mu.Lock()
+       defer v.mu.Unlock()
+       v.s = value
+}
 
 // Func implements Var by calling the function
 // and formatting the returned value using JSON.
index 7d4178d..9660370 100644 (file)
@@ -7,7 +7,8 @@
        to C's printf and scanf.  The format 'verbs' are derived from C's but
        are simpler.
 
-       Printing:
+
+       Printing
 
        The verbs:
 
        by a single character (the verb) and end with a parenthesized
        description.
 
-       Scanning:
+
+       Scanning
 
        An analogous set of functions scans formatted text to yield
        values.  Scan, Scanf and Scanln read from os.Stdin; Fscan,
diff --git a/libgo/go/fmt/export_test.go b/libgo/go/fmt/export_test.go
new file mode 100644 (file)
index 0000000..89d57ee
--- /dev/null
@@ -0,0 +1,7 @@
+// Copyright 2012 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fmt
+
+var IsSpace = isSpace
index f34df59..e0c587a 100644 (file)
@@ -13,6 +13,7 @@ import (
        "strings"
        "testing"
        "time"
+       "unicode"
 )
 
 type (
@@ -830,3 +831,13 @@ func TestBadVerbRecursion(t *testing.T) {
                t.Error("fail with value")
        }
 }
+
+func TestIsSpace(t *testing.T) {
+       // This tests the internal isSpace function.
+       // IsSpace = isSpace is defined in export_test.go.
+       for i := rune(0); i <= unicode.MaxRune; i++ {
+               if IsSpace(i) != unicode.IsSpace(i) {
+                       t.Errorf("isSpace(%U) = %v, want %v", i, IsSpace(i), unicode.IsSpace(i))
+               }
+       }
+}
index 78d9e99..2186f33 100644 (file)
@@ -5,9 +5,7 @@
 package fmt
 
 import (
-       "bytes"
        "strconv"
-       "unicode"
        "unicode/utf8"
 )
 
@@ -36,10 +34,10 @@ func init() {
 }
 
 // A fmt is the raw formatter used by Printf etc.
-// It prints into a bytes.Buffer that must be set up externally.
+// It prints into a buffer that must be set up separately.
 type fmt struct {
        intbuf [nByte]byte
-       buf    *bytes.Buffer
+       buf    *buffer
        // width, precision
        wid  int
        prec int
@@ -69,7 +67,7 @@ func (f *fmt) clearflags() {
        f.zero = false
 }
 
-func (f *fmt) init(buf *bytes.Buffer) {
+func (f *fmt) init(buf *buffer) {
        f.buf = buf
        f.clearflags()
 }
@@ -247,7 +245,7 @@ func (f *fmt) integer(a int64, base uint64, signedness bool, digits string) {
        }
 
        // If we want a quoted char for %#U, move the data up to make room.
-       if f.unicode && f.uniQuote && a >= 0 && a <= unicode.MaxRune && unicode.IsPrint(rune(a)) {
+       if f.unicode && f.uniQuote && a >= 0 && a <= utf8.MaxRune && strconv.IsPrint(rune(a)) {
                runeWidth := utf8.RuneLen(rune(a))
                width := 1 + 1 + runeWidth + 1 // space, quote, rune, quote
                copy(buf[i-width:], buf[i:])   // guaranteed to have enough room.
@@ -290,16 +288,15 @@ func (f *fmt) fmt_s(s string) {
 // fmt_sx formats a string as a hexadecimal encoding of its bytes.
 func (f *fmt) fmt_sx(s, digits string) {
        // TODO: Avoid buffer by pre-padding.
-       var b bytes.Buffer
+       var b []byte
        for i := 0; i < len(s); i++ {
                if i > 0 && f.space {
-                       b.WriteByte(' ')
+                       b = append(b, ' ')
                }
                v := s[i]
-               b.WriteByte(digits[v>>4])
-               b.WriteByte(digits[v&0xF])
+               b = append(b, digits[v>>4], digits[v&0xF])
        }
-       f.pad(b.Bytes())
+       f.pad(b)
 }
 
 // fmt_q formats a string as a double-quoted, escaped Go string constant.
index c3ba2f3..1343824 100644 (file)
@@ -5,13 +5,11 @@
 package fmt
 
 import (
-       "bytes"
        "errors"
        "io"
        "os"
        "reflect"
        "sync"
-       "unicode"
        "unicode/utf8"
 )
 
@@ -71,11 +69,45 @@ type GoStringer interface {
        GoString() string
 }
 
+// Use simple []byte instead of bytes.Buffer to avoid large dependency.
+type buffer []byte
+
+func (b *buffer) Write(p []byte) (n int, err error) {
+       *b = append(*b, p...)
+       return len(p), nil
+}
+
+func (b *buffer) WriteString(s string) (n int, err error) {
+       *b = append(*b, s...)
+       return len(s), nil
+}
+
+func (b *buffer) WriteByte(c byte) error {
+       *b = append(*b, c)
+       return nil
+}
+
+func (bp *buffer) WriteRune(r rune) error {
+       if r < utf8.RuneSelf {
+               *bp = append(*bp, byte(r))
+               return nil
+       }
+
+       b := *bp
+       n := len(b)
+       for n+utf8.UTFMax > cap(b) {
+               b = append(b, 0)
+       }
+       w := utf8.EncodeRune(b[n:n+utf8.UTFMax], r)
+       *bp = b[:n+w]
+       return nil
+}
+
 type pp struct {
        n         int
        panicking bool
        erroring  bool // printing an error condition
-       buf       bytes.Buffer
+       buf       buffer
        // field holds the current item, as an interface{}.
        field interface{}
        // value holds the current item, as a reflect.Value, and will be
@@ -133,10 +165,10 @@ func newPrinter() *pp {
 // Save used pp structs in ppFree; avoids an allocation per invocation.
 func (p *pp) free() {
        // Don't hold on to pp structs with large buffers.
-       if cap(p.buf.Bytes()) > 1024 {
+       if cap(p.buf) > 1024 {
                return
        }
-       p.buf.Reset()
+       p.buf = p.buf[:0]
        p.field = nil
        p.value = reflect.Value{}
        ppFree.put(p)
@@ -179,7 +211,7 @@ func (p *pp) Write(b []byte) (ret int, err error) {
 func Fprintf(w io.Writer, format string, a ...interface{}) (n int, err error) {
        p := newPrinter()
        p.doPrintf(format, a)
-       n64, err := p.buf.WriteTo(w)
+       n64, err := w.Write(p.buf)
        p.free()
        return int(n64), err
 }
@@ -194,7 +226,7 @@ func Printf(format string, a ...interface{}) (n int, err error) {
 func Sprintf(format string, a ...interface{}) string {
        p := newPrinter()
        p.doPrintf(format, a)
-       s := p.buf.String()
+       s := string(p.buf)
        p.free()
        return s
 }
@@ -213,7 +245,7 @@ func Errorf(format string, a ...interface{}) error {
 func Fprint(w io.Writer, a ...interface{}) (n int, err error) {
        p := newPrinter()
        p.doPrint(a, false, false)
-       n64, err := p.buf.WriteTo(w)
+       n64, err := w.Write(p.buf)
        p.free()
        return int(n64), err
 }
@@ -230,7 +262,7 @@ func Print(a ...interface{}) (n int, err error) {
 func Sprint(a ...interface{}) string {
        p := newPrinter()
        p.doPrint(a, false, false)
-       s := p.buf.String()
+       s := string(p.buf)
        p.free()
        return s
 }
@@ -245,7 +277,7 @@ func Sprint(a ...interface{}) string {
 func Fprintln(w io.Writer, a ...interface{}) (n int, err error) {
        p := newPrinter()
        p.doPrint(a, true, true)
-       n64, err := p.buf.WriteTo(w)
+       n64, err := w.Write(p.buf)
        p.free()
        return int(n64), err
 }
@@ -262,7 +294,7 @@ func Println(a ...interface{}) (n int, err error) {
 func Sprintln(a ...interface{}) string {
        p := newPrinter()
        p.doPrint(a, true, true)
-       s := p.buf.String()
+       s := string(p.buf)
        p.free()
        return s
 }
@@ -352,7 +384,7 @@ func (p *pp) fmtInt64(v int64, verb rune) {
        case 'o':
                p.fmt.integer(v, 8, signed, ldigits)
        case 'q':
-               if 0 <= v && v <= unicode.MaxRune {
+               if 0 <= v && v <= utf8.MaxRune {
                        p.fmt.fmt_qc(v)
                } else {
                        p.badVerb(verb)
@@ -416,7 +448,7 @@ func (p *pp) fmtUint64(v uint64, verb rune, goSyntax bool) {
        case 'o':
                p.fmt.integer(int64(v), 8, unsigned, ldigits)
        case 'q':
-               if 0 <= v && v <= unicode.MaxRune {
+               if 0 <= v && v <= utf8.MaxRune {
                        p.fmt.fmt_qc(int64(v))
                } else {
                        p.badVerb(verb)
index fa9a558..0b3e040 100644 (file)
@@ -5,15 +5,12 @@
 package fmt
 
 import (
-       "bytes"
        "errors"
        "io"
        "math"
        "os"
        "reflect"
        "strconv"
-       "strings"
-       "unicode"
        "unicode/utf8"
 )
 
@@ -87,25 +84,36 @@ func Scanf(format string, a ...interface{}) (n int, err error) {
        return Fscanf(os.Stdin, format, a...)
 }
 
+type stringReader string
+
+func (r *stringReader) Read(b []byte) (n int, err error) {
+       n = copy(b, *r)
+       *r = (*r)[n:]
+       if n == 0 {
+               err = io.EOF
+       }
+       return
+}
+
 // Sscan scans the argument string, storing successive space-separated
 // values into successive arguments.  Newlines count as space.  It
 // returns the number of items successfully scanned.  If that is less
 // than the number of arguments, err will report why.
 func Sscan(str string, a ...interface{}) (n int, err error) {
-       return Fscan(strings.NewReader(str), a...)
+       return Fscan((*stringReader)(&str), a...)
 }
 
 // Sscanln is similar to Sscan, but stops scanning at a newline and
 // after the final item there must be a newline or EOF.
 func Sscanln(str string, a ...interface{}) (n int, err error) {
-       return Fscanln(strings.NewReader(str), a...)
+       return Fscanln((*stringReader)(&str), a...)
 }
 
 // Sscanf scans the argument string, storing successive space-separated
 // values into successive arguments as determined by the format.  It
 // returns the number of items successfully parsed.
 func Sscanf(str string, format string, a ...interface{}) (n int, err error) {
-       return Fscanf(strings.NewReader(str), format, a...)
+       return Fscanf((*stringReader)(&str), format, a...)
 }
 
 // Fscan scans text read from r, storing successive space-separated
@@ -149,7 +157,7 @@ const eof = -1
 // ss is the internal implementation of ScanState.
 type ss struct {
        rr       io.RuneReader // where to read input
-       buf      bytes.Buffer  // token accumulator
+       buf      buffer        // token accumulator
        peekRune rune          // one-rune lookahead
        prevRune rune          // last rune returned by ReadRune
        count    int           // runes consumed so far.
@@ -262,14 +270,46 @@ func (s *ss) Token(skipSpace bool, f func(rune) bool) (tok []byte, err error) {
        if f == nil {
                f = notSpace
        }
-       s.buf.Reset()
+       s.buf = s.buf[:0]
        tok = s.token(skipSpace, f)
        return
 }
 
+// space is a copy of the unicode.White_Space ranges,
+// to avoid depending on package unicode.
+var space = [][2]uint16{
+       {0x0009, 0x000d},
+       {0x0020, 0x0020},
+       {0x0085, 0x0085},
+       {0x00a0, 0x00a0},
+       {0x1680, 0x1680},
+       {0x180e, 0x180e},
+       {0x2000, 0x200a},
+       {0x2028, 0x2029},
+       {0x202f, 0x202f},
+       {0x205f, 0x205f},
+       {0x3000, 0x3000},
+}
+
+func isSpace(r rune) bool {
+       if r >= 1<<16 {
+               return false
+       }
+       rx := uint16(r)
+       for _, rng := range space {
+               if rx < rng[0] {
+                       return false
+               }
+               if rx <= rng[1] {
+                       return true
+               }
+       }
+       return false
+}
+
 // notSpace is the default scanning function used in Token.
 func notSpace(r rune) bool {
-       return !unicode.IsSpace(r)
+       return !isSpace(r)
 }
 
 // skipSpace provides Scan() methods the ability to skip space and newline characters 
@@ -378,10 +418,10 @@ func (s *ss) free(old ssave) {
                return
        }
        // Don't hold on to ss structs with large buffers.
-       if cap(s.buf.Bytes()) > 1024 {
+       if cap(s.buf) > 1024 {
                return
        }
-       s.buf.Reset()
+       s.buf = s.buf[:0]
        s.rr = nil
        ssFree.put(s)
 }
@@ -403,7 +443,7 @@ func (s *ss) skipSpace(stopAtNewline bool) {
                        s.errorString("unexpected newline")
                        return
                }
-               if !unicode.IsSpace(r) {
+               if !isSpace(r) {
                        s.UnreadRune()
                        break
                }
@@ -429,7 +469,7 @@ func (s *ss) token(skipSpace bool, f func(rune) bool) []byte {
                }
                s.buf.WriteRune(r)
        }
-       return s.buf.Bytes()
+       return s.buf
 }
 
 // typeError indicates that the type of the operand did not match the format
@@ -440,6 +480,15 @@ func (s *ss) typeError(field interface{}, expected string) {
 var complexError = errors.New("syntax error scanning complex number")
 var boolError = errors.New("syntax error scanning boolean")
 
+func indexRune(s string, r rune) int {
+       for i, c := range s {
+               if c == r {
+                       return i
+               }
+       }
+       return -1
+}
+
 // consume reads the next rune in the input and reports whether it is in the ok string.
 // If accept is true, it puts the character into the input token.
 func (s *ss) consume(ok string, accept bool) bool {
@@ -447,7 +496,7 @@ func (s *ss) consume(ok string, accept bool) bool {
        if r == eof {
                return false
        }
-       if strings.IndexRune(ok, r) >= 0 {
+       if indexRune(ok, r) >= 0 {
                if accept {
                        s.buf.WriteRune(r)
                }
@@ -465,7 +514,7 @@ func (s *ss) peek(ok string) bool {
        if r != eof {
                s.UnreadRune()
        }
-       return strings.IndexRune(ok, r) >= 0
+       return indexRune(ok, r) >= 0
 }
 
 func (s *ss) notEOF() {
@@ -560,7 +609,7 @@ func (s *ss) scanNumber(digits string, haveDigits bool) string {
        }
        for s.accept(digits) {
        }
-       return s.buf.String()
+       return string(s.buf)
 }
 
 // scanRune returns the next rune value in the input.
@@ -660,16 +709,16 @@ func (s *ss) scanUint(verb rune, bitSize int) uint64 {
 // if the width is specified. It's not rigorous about syntax because it doesn't check that
 // we have at least some digits, but Atof will do that.
 func (s *ss) floatToken() string {
-       s.buf.Reset()
+       s.buf = s.buf[:0]
        // NaN?
        if s.accept("nN") && s.accept("aA") && s.accept("nN") {
-               return s.buf.String()
+               return string(s.buf)
        }
        // leading sign?
        s.accept(sign)
        // Inf?
        if s.accept("iI") && s.accept("nN") && s.accept("fF") {
-               return s.buf.String()
+               return string(s.buf)
        }
        // digits?
        for s.accept(decimalDigits) {
@@ -688,7 +737,7 @@ func (s *ss) floatToken() string {
                for s.accept(decimalDigits) {
                }
        }
-       return s.buf.String()
+       return string(s.buf)
 }
 
 // complexTokens returns the real and imaginary parts of the complex number starting here.
@@ -698,13 +747,13 @@ func (s *ss) complexTokens() (real, imag string) {
        // TODO: accept N and Ni independently?
        parens := s.accept("(")
        real = s.floatToken()
-       s.buf.Reset()
+       s.buf = s.buf[:0]
        // Must now have a sign.
        if !s.accept("+-") {
                s.error(complexError)
        }
        // Sign is now in buffer
-       imagSign := s.buf.String()
+       imagSign := string(s.buf)
        imag = s.floatToken()
        if !s.accept("i") {
                s.error(complexError)
@@ -717,7 +766,7 @@ func (s *ss) complexTokens() (real, imag string) {
 
 // convertFloat converts the string to a float64value.
 func (s *ss) convertFloat(str string, n int) float64 {
-       if p := strings.Index(str, "p"); p >= 0 {
+       if p := indexRune(str, 'p'); p >= 0 {
                // Atof doesn't handle power-of-2 exponents,
                // but they're easy to evaluate.
                f, err := strconv.ParseFloat(str[:p], n)
@@ -794,7 +843,7 @@ func (s *ss) quotedString() string {
                        }
                        s.buf.WriteRune(r)
                }
-               return s.buf.String()
+               return string(s.buf)
        case '"':
                // Double-quoted: Include the quotes and let strconv.Unquote do the backslash escapes.
                s.buf.WriteRune(quote)
@@ -811,7 +860,7 @@ func (s *ss) quotedString() string {
                                break
                        }
                }
-               result, err := strconv.Unquote(s.buf.String())
+               result, err := strconv.Unquote(string(s.buf))
                if err != nil {
                        s.error(err)
                }
@@ -844,7 +893,7 @@ func (s *ss) hexByte() (b byte, ok bool) {
        if rune1 == eof {
                return
        }
-       if unicode.IsSpace(rune1) {
+       if isSpace(rune1) {
                s.UnreadRune()
                return
        }
@@ -862,11 +911,11 @@ func (s *ss) hexString() string {
                }
                s.buf.WriteByte(b)
        }
-       if s.buf.Len() == 0 {
+       if len(s.buf) == 0 {
                s.errorString("Scan: no hex data for %x string")
                return ""
        }
-       return s.buf.String()
+       return string(s.buf)
 }
 
 const floatVerbs = "beEfFgGv"
@@ -875,7 +924,7 @@ const hugeWid = 1 << 30
 
 // scanOne scans a single value, deriving the scanner from the type of the argument.
 func (s *ss) scanOne(verb rune, field interface{}) {
-       s.buf.Reset()
+       s.buf = s.buf[:0]
        var err error
        // If the parameter has its own Scan method, use that.
        if v, ok := field.(Scanner); ok {
@@ -1004,7 +1053,7 @@ func (s *ss) doScan(a []interface{}) (numProcessed int, err error) {
                        if r == '\n' || r == eof {
                                break
                        }
-                       if !unicode.IsSpace(r) {
+                       if !isSpace(r) {
                                s.errorString("Scan: expected newline")
                                break
                        }
@@ -1032,7 +1081,7 @@ func (s *ss) advance(format string) (i int) {
                        i += w // skip the first %
                }
                sawSpace := false
-               for unicode.IsSpace(fmtc) && i < len(format) {
+               for isSpace(fmtc) && i < len(format) {
                        sawSpace = true
                        i += w
                        fmtc, w = utf8.DecodeRuneInString(format[i:])
@@ -1044,7 +1093,7 @@ func (s *ss) advance(format string) (i int) {
                        if inputc == eof {
                                return
                        }
-                       if !unicode.IsSpace(inputc) {
+                       if !isSpace(inputc) {
                                // Space in format but not in input: error
                                s.errorString("expected space in input to match format")
                        }
index eece761..dc9dcd1 100644 (file)
@@ -34,7 +34,7 @@ type Context struct {
        CgoEnabled  bool     // whether cgo can be used
        BuildTags   []string // additional tags to recognize in +build lines
        UseAllFiles bool     // use files regardless of +build lines, file names
-       Gccgo       bool     // assume use of gccgo when computing object paths
+       Compiler    string   // compiler to assume when computing target paths
 
        // By default, Import uses the operating system's file system calls
        // to read directories and files.  To read from other sources,
@@ -210,6 +210,7 @@ func (ctxt *Context) SrcDirs() []string {
 // if set, or else the compiled code's GOARCH, GOOS, and GOROOT.
 var Default Context = defaultContext()
 
+// This list is also known to ../../../cmd/dist/build.c.
 var cgoEnabled = map[string]bool{
        "darwin/386":    true,
        "darwin/amd64":  true,
@@ -228,6 +229,7 @@ func defaultContext() Context {
        c.GOOS = envOr("GOOS", runtime.GOOS)
        c.GOROOT = runtime.GOROOT()
        c.GOPATH = envOr("GOPATH", "")
+       c.Compiler = runtime.Compiler
 
        switch os.Getenv("CGO_ENABLED") {
        case "1":
@@ -277,11 +279,12 @@ type Package struct {
        PkgObj     string // installed .a file
 
        // Source files
-       GoFiles  []string // .go source files (excluding CgoFiles, TestGoFiles, XTestGoFiles)
-       CgoFiles []string // .go source files that import "C"
-       CFiles   []string // .c source files
-       HFiles   []string // .h source files
-       SFiles   []string // .s source files
+       GoFiles   []string // .go source files (excluding CgoFiles, TestGoFiles, XTestGoFiles)
+       CgoFiles  []string // .go source files that import "C"
+       CFiles    []string // .c source files
+       HFiles    []string // .h source files
+       SFiles    []string // .s source files
+       SysoFiles []string // .syso system object files to add to archive
 
        // Cgo directives
        CgoPkgConfig []string // Cgo pkg-config directives
@@ -314,6 +317,16 @@ func (ctxt *Context) ImportDir(dir string, mode ImportMode) (*Package, error) {
        return ctxt.Import(".", dir, mode)
 }
 
+// NoGoError is the error used by Import to describe a directory
+// containing no Go source files.
+type NoGoError struct {
+       Dir string
+}
+
+func (e *NoGoError) Error() string {
+       return "no Go source files in " + e.Dir
+}
+
 // Import returns details about the Go package named by the import path,
 // interpreting local import paths relative to the src directory.  If the path
 // is a local import path naming a package that can be imported using a
@@ -336,11 +349,16 @@ func (ctxt *Context) Import(path string, src string, mode ImportMode) (*Package,
        }
 
        var pkga string
-       if ctxt.Gccgo {
+       var pkgerr error
+       switch ctxt.Compiler {
+       case "gccgo":
                dir, elem := pathpkg.Split(p.ImportPath)
                pkga = "pkg/gccgo/" + dir + "lib" + elem + ".a"
-       } else {
+       case "gc":
                pkga = "pkg/" + ctxt.GOOS + "_" + ctxt.GOARCH + "/" + p.ImportPath + ".a"
+       default:
+               // Save error for end of function.
+               pkgerr = fmt.Errorf("import %q: unknown compiler %q", path, ctxt.Compiler)
        }
 
        binaryOnly := false
@@ -396,7 +414,7 @@ func (ctxt *Context) Import(path string, src string, mode ImportMode) (*Package,
                if ctxt.GOROOT != "" {
                        dir := ctxt.joinPath(ctxt.GOROOT, "src", "pkg", path)
                        isDir := ctxt.isDir(dir)
-                       binaryOnly = !isDir && mode&AllowBinary != 0 && ctxt.isFile(ctxt.joinPath(ctxt.GOROOT, pkga))
+                       binaryOnly = !isDir && mode&AllowBinary != 0 && pkga != "" && ctxt.isFile(ctxt.joinPath(ctxt.GOROOT, pkga))
                        if isDir || binaryOnly {
                                p.Dir = dir
                                p.Goroot = true
@@ -407,7 +425,7 @@ func (ctxt *Context) Import(path string, src string, mode ImportMode) (*Package,
                for _, root := range ctxt.gopath() {
                        dir := ctxt.joinPath(root, "src", path)
                        isDir := ctxt.isDir(dir)
-                       binaryOnly = !isDir && mode&AllowBinary != 0 && ctxt.isFile(ctxt.joinPath(root, pkga))
+                       binaryOnly = !isDir && mode&AllowBinary != 0 && pkga != "" && ctxt.isFile(ctxt.joinPath(root, pkga))
                        if isDir || binaryOnly {
                                p.Dir = dir
                                p.Root = root
@@ -426,14 +444,16 @@ Found:
                }
                p.PkgRoot = ctxt.joinPath(p.Root, "pkg")
                p.BinDir = ctxt.joinPath(p.Root, "bin")
-               p.PkgObj = ctxt.joinPath(p.Root, pkga)
+               if pkga != "" {
+                       p.PkgObj = ctxt.joinPath(p.Root, pkga)
+               }
        }
 
        if mode&FindOnly != 0 {
-               return p, nil
+               return p, pkgerr
        }
        if binaryOnly && (mode&AllowBinary) != 0 {
-               return p, nil
+               return p, pkgerr
        }
 
        dirs, err := ctxt.readDir(p.Dir)
@@ -467,7 +487,13 @@ Found:
                ext := name[i:]
                switch ext {
                case ".go", ".c", ".s", ".h", ".S":
-                       // tentatively okay
+                       // tentatively okay - read to make sure
+               case ".syso":
+                       // binary objects to add to package archive
+                       // Likely of the form foo_windows.syso, but
+                       // the name was vetted above with goodOSArchFile.
+                       p.SysoFiles = append(p.SysoFiles, name)
+                       continue
                default:
                        // skip
                        continue
@@ -586,7 +612,7 @@ Found:
                }
        }
        if p.Name == "" {
-               return p, fmt.Errorf("no Go source files in %s", p.Dir)
+               return p, &NoGoError{p.Dir}
        }
 
        p.Imports, p.ImportPos = cleanImports(imported)
@@ -601,7 +627,7 @@ Found:
                sort.Strings(p.SFiles)
        }
 
-       return p, nil
+       return p, pkgerr
 }
 
 func cleanImports(m map[string][]token.Position) ([]string, map[string][]token.Position) {
diff --git a/libgo/go/go/build/deps_test.go b/libgo/go/go/build/deps_test.go
new file mode 100644 (file)
index 0000000..4e9f32a
--- /dev/null
@@ -0,0 +1,424 @@
+// Copyright 2012 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file exercises the import parser but also checks that
+// some low-level packages do not have new dependencies added.
+
+package build_test
+
+import (
+       "go/build"
+       "sort"
+       "testing"
+)
+
+// pkgDeps defines the expected dependencies between packages in
+// the Go source tree.  It is a statement of policy.
+// Changes should not be made to this map without prior discussion.
+//
+// The map contains two kinds of entries:
+// 1) Lower-case keys are standard import paths and list the
+// allowed imports in that package.
+// 2) Upper-case keys define aliases for package sets, which can then
+// be used as dependencies by other rules.
+//
+// DO NOT CHANGE THIS DATA TO FIX BUILDS.
+// 
+var pkgDeps = map[string][]string{
+       // L0 is the lowest level, core, nearly unavoidable packages.
+       "errors":      {},
+       "io":          {"errors", "sync"},
+       "runtime":     {"unsafe"},
+       "sync":        {"sync/atomic"},
+       "sync/atomic": {"unsafe"},
+       "unsafe":      {},
+
+       "L0": {
+               "errors",
+               "io",
+               "runtime",
+               "sync",
+               "sync/atomic",
+               "unsafe",
+       },
+
+       // L1 adds simple functions and strings processing,
+       // but not Unicode tables.
+       "math":          {"unsafe"},
+       "math/cmplx":    {"math"},
+       "math/rand":     {"L0", "math"},
+       "sort":          {"math"},
+       "strconv":       {"L0", "unicode/utf8", "math"},
+       "unicode/utf16": {},
+       "unicode/utf8":  {},
+
+       "L1": {
+               "L0",
+               "math",
+               "math/cmplx",
+               "math/rand",
+               "sort",
+               "strconv",
+               "unicode/utf16",
+               "unicode/utf8",
+       },
+
+       // L2 adds Unicode and strings processing.
+       "bufio":   {"L0", "unicode/utf8", "bytes"},
+       "bytes":   {"L0", "unicode", "unicode/utf8"},
+       "path":    {"L0", "unicode/utf8", "strings"},
+       "strings": {"L0", "unicode", "unicode/utf8"},
+       "unicode": {},
+
+       "L2": {
+               "L1",
+               "bufio",
+               "bytes",
+               "path",
+               "strings",
+               "unicode",
+       },
+
+       // L3 adds reflection and some basic utility packages
+       // and interface definitions, but nothing that makes
+       // system calls.
+       "crypto":          {"L2", "hash"}, // interfaces
+       "crypto/cipher":   {"L2"},         // interfaces
+       "encoding/base32": {"L2"},
+       "encoding/base64": {"L2"},
+       "encoding/binary": {"L2", "reflect"},
+       "hash":            {"L2"}, // interfaces
+       "hash/adler32":    {"L2", "hash"},
+       "hash/crc32":      {"L2", "hash"},
+       "hash/crc64":      {"L2", "hash"},
+       "hash/fnv":        {"L2", "hash"},
+       "image":           {"L2", "image/color"}, // interfaces
+       "image/color":     {"L2"},                // interfaces
+       "reflect":         {"L2"},
+
+       "L3": {
+               "L2",
+               "crypto",
+               "crypto/cipher",
+               "encoding/base32",
+               "encoding/base64",
+               "encoding/binary",
+               "hash",
+               "hash/adler32",
+               "hash/crc32",
+               "hash/crc64",
+               "hash/fnv",
+               "image",
+               "image/color",
+               "reflect",
+       },
+
+       // End of linear dependency definitions.
+
+       // Operating system access.
+       "syscall":       {"L0", "unicode/utf16"},
+       "time":          {"L0", "syscall"},
+       "os":            {"L1", "os", "syscall", "time"},
+       "path/filepath": {"L2", "os", "syscall"},
+       "io/ioutil":     {"L2", "os", "path/filepath", "time"},
+       "os/exec":       {"L2", "os", "syscall"},
+       "os/signal":     {"L2", "os", "syscall"},
+
+       // OS enables basic operating system functionality,
+       // but not direct use of package syscall, nor os/signal.
+       "OS": {
+               "io/ioutil",
+               "os",
+               "os/exec",
+               "path/filepath",
+               "time",
+       },
+
+       // Formatted I/O: few dependencies (L1) but we must add reflect.
+       "fmt": {"L1", "os", "reflect"},
+       "log": {"L1", "os", "fmt", "time"},
+
+       // Packages used by testing must be low-level (L2+fmt).
+       "regexp":         {"L2", "regexp/syntax"},
+       "regexp/syntax":  {"L2"},
+       "runtime/debug":  {"L2", "fmt", "io/ioutil", "os"},
+       "runtime/pprof":  {"L2", "fmt", "text/tabwriter"},
+       "text/tabwriter": {"L2"},
+
+       "testing":        {"L2", "flag", "fmt", "os", "runtime/pprof", "time"},
+       "testing/iotest": {"L2", "log"},
+       "testing/quick":  {"L2", "flag", "fmt", "reflect"},
+
+       // L4 is defined as L3+fmt+log+time, because in general once
+       // you're using L3 packages, use of fmt, log, or time is not a big deal.
+       "L4": {
+               "L3",
+               "fmt",
+               "log",
+               "time",
+       },
+
+       // Go parser.
+       "go/ast":     {"L4", "OS", "go/scanner", "go/token"},
+       "go/doc":     {"L4", "go/ast", "go/token", "regexp", "text/template"},
+       "go/parser":  {"L4", "OS", "go/ast", "go/scanner", "go/token"},
+       "go/printer": {"L4", "OS", "go/ast", "go/scanner", "go/token", "text/tabwriter"},
+       "go/scanner": {"L4", "OS", "go/token"},
+       "go/token":   {"L4"},
+
+       "GOPARSER": {
+               "go/ast",
+               "go/doc",
+               "go/parser",
+               "go/printer",
+               "go/scanner",
+               "go/token",
+       },
+
+       // One of a kind.
+       "archive/tar":         {"L4", "OS"},
+       "archive/zip":         {"L4", "OS", "compress/flate"},
+       "compress/bzip2":      {"L4"},
+       "compress/flate":      {"L4"},
+       "compress/gzip":       {"L4", "compress/flate"},
+       "compress/lzw":        {"L4"},
+       "compress/zlib":       {"L4", "compress/flate"},
+       "database/sql":        {"L4", "database/sql/driver"},
+       "database/sql/driver": {"L4", "time"},
+       "debug/dwarf":         {"L4"},
+       "debug/elf":           {"L4", "OS", "debug/dwarf"},
+       "debug/gosym":         {"L4"},
+       "debug/macho":         {"L4", "OS", "debug/dwarf"},
+       "debug/pe":            {"L4", "OS", "debug/dwarf"},
+       "encoding/ascii85":    {"L4"},
+       "encoding/asn1":       {"L4", "math/big"},
+       "encoding/csv":        {"L4"},
+       "encoding/gob":        {"L4", "OS"},
+       "encoding/hex":        {"L4"},
+       "encoding/json":       {"L4"},
+       "encoding/pem":        {"L4"},
+       "encoding/xml":        {"L4"},
+       "flag":                {"L4", "OS"},
+       "go/build":            {"L4", "OS", "GOPARSER"},
+       "html":                {"L4"},
+       "image/draw":          {"L4"},
+       "image/gif":           {"L4", "compress/lzw"},
+       "image/jpeg":          {"L4"},
+       "image/png":           {"L4", "compress/zlib"},
+       "index/suffixarray":   {"L4", "regexp"},
+       "math/big":            {"L4"},
+       "mime":                {"L4", "OS", "syscall"},
+       "net/url":             {"L4"},
+       "text/scanner":        {"L4", "OS"},
+       "text/template/parse": {"L4"},
+
+       "html/template": {
+               "L4", "OS", "encoding/json", "html", "text/template",
+               "text/template/parse",
+       },
+       "text/template": {
+               "L4", "OS", "net/url", "text/template/parse",
+       },
+
+       // Cgo.
+       "runtime/cgo": {"L0", "C"},
+       "CGO":         {"C", "runtime/cgo"},
+
+       // Fake entry to satisfy the pseudo-import "C"
+       // that shows up in programs that use cgo.
+       "C": {},
+
+       "os/user": {"L4", "CGO", "syscall"},
+
+       // Basic networking.
+       // Because net must be used by any package that wants to
+       // do networking portably, it must have a small dependency set: just L1+basic os.
+       "net": {"L1", "CGO", "os", "syscall", "time"},
+
+       // NET enables use of basic network-related packages.
+       "NET": {
+               "net",
+               "mime",
+               "net/textproto",
+               "net/url",
+       },
+
+       // Uses of networking.
+       "log/syslog":    {"L4", "OS", "net"},
+       "net/mail":      {"L4", "NET", "OS"},
+       "net/textproto": {"L4", "OS", "net"},
+
+       // Core crypto.
+       "crypto/aes":    {"L3"},
+       "crypto/des":    {"L3"},
+       "crypto/hmac":   {"L3"},
+       "crypto/md5":    {"L3"},
+       "crypto/rc4":    {"L3"},
+       "crypto/sha1":   {"L3"},
+       "crypto/sha256": {"L3"},
+       "crypto/sha512": {"L3"},
+       "crypto/subtle": {"L3"},
+
+       "CRYPTO": {
+               "crypto/aes",
+               "crypto/des",
+               "crypto/hmac",
+               "crypto/md5",
+               "crypto/rc4",
+               "crypto/sha1",
+               "crypto/sha256",
+               "crypto/sha512",
+               "crypto/subtle",
+       },
+
+       // Random byte, number generation.
+       // This would be part of core crypto except that it imports
+       // math/big, which imports fmt.
+       "crypto/rand": {"L4", "CRYPTO", "OS", "math/big", "syscall"},
+
+       // Mathematical crypto: dependencies on fmt (L4) and math/big.
+       // We could avoid some of the fmt, but math/big imports fmt anyway.
+       "crypto/dsa":      {"L4", "CRYPTO", "math/big"},
+       "crypto/ecdsa":    {"L4", "CRYPTO", "crypto/elliptic", "math/big"},
+       "crypto/elliptic": {"L4", "CRYPTO", "math/big"},
+       "crypto/rsa":      {"L4", "CRYPTO", "crypto/rand", "math/big"},
+
+       "CRYPTO-MATH": {
+               "CRYPTO",
+               "crypto/dsa",
+               "crypto/ecdsa",
+               "crypto/elliptic",
+               "crypto/rand",
+               "crypto/rsa",
+               "encoding/asn1",
+               "math/big",
+       },
+
+       // SSL/TLS.
+       "crypto/tls": {
+               "L4", "CRYPTO-MATH", "CGO", "OS",
+               "crypto/x509", "encoding/pem", "net", "syscall",
+       },
+       "crypto/x509":      {"L4", "CRYPTO-MATH", "OS", "CGO", "crypto/x509/pkix", "encoding/pem", "syscall"},
+       "crypto/x509/pkix": {"L4", "CRYPTO-MATH"},
+
+       // Simple net+crypto-aware packages.
+       "mime/multipart": {"L4", "OS", "mime", "crypto/rand", "net/textproto"},
+       "net/smtp":       {"L4", "CRYPTO", "NET", "crypto/tls"},
+
+       // HTTP, kingpin of dependencies.
+       "net/http": {
+               "L4", "NET", "OS",
+               "compress/gzip", "crypto/tls", "mime/multipart", "runtime/debug",
+       },
+
+       // HTTP-using packages.
+       "expvar":            {"L4", "OS", "encoding/json", "net/http"},
+       "net/http/cgi":      {"L4", "NET", "OS", "crypto/tls", "net/http", "regexp"},
+       "net/http/fcgi":     {"L4", "NET", "OS", "net/http", "net/http/cgi"},
+       "net/http/httptest": {"L4", "NET", "OS", "crypto/tls", "flag", "net/http"},
+       "net/http/httputil": {"L4", "NET", "OS", "net/http"},
+       "net/http/pprof":    {"L4", "OS", "html/template", "net/http", "runtime/pprof"},
+       "net/rpc":           {"L4", "NET", "encoding/gob", "net/http", "text/template"},
+       "net/rpc/jsonrpc":   {"L4", "NET", "encoding/json", "net/rpc"},
+}
+
+// isMacro reports whether p is a package dependency macro
+// (uppercase name).
+func isMacro(p string) bool {
+       return 'A' <= p[0] && p[0] <= 'Z'
+}
+
+func allowed(pkg string) map[string]bool {
+       m := map[string]bool{}
+       var allow func(string)
+       allow = func(p string) {
+               if m[p] {
+                       return
+               }
+               m[p] = true // set even for macros, to avoid loop on cycle
+
+               // Upper-case names are macro-expanded.
+               if isMacro(p) {
+                       for _, pp := range pkgDeps[p] {
+                               allow(pp)
+                       }
+               }
+       }
+       for _, pp := range pkgDeps[pkg] {
+               allow(pp)
+       }
+       return m
+}
+
+var bools = []bool{false, true}
+var geese = []string{"darwin", "freebsd", "linux", "netbsd", "openbsd", "plan9", "windows"}
+var goarches = []string{"386", "amd64", "arm"}
+
+type osPkg struct {
+       goos, pkg string
+}
+
+// allowedErrors are the operating systems and packages known to contain errors
+// (currently just "no Go source files")
+var allowedErrors = map[osPkg]bool{
+       osPkg{"windows", "log/syslog"}: true,
+       osPkg{"plan9", "log/syslog"}:   true,
+}
+
+func TestDependencies(t *testing.T) {
+       var all []string
+
+       for k := range pkgDeps {
+               all = append(all, k)
+       }
+       sort.Strings(all)
+
+       ctxt := build.Default
+       test := func(mustImport bool) {
+               for _, pkg := range all {
+                       if isMacro(pkg) {
+                               continue
+                       }
+                       p, err := ctxt.Import(pkg, "", 0)
+                       if err != nil {
+                               if allowedErrors[osPkg{ctxt.GOOS, pkg}] {
+                                       continue
+                               }
+                               // Some of the combinations we try might not
+                               // be reasonable (like arm,plan9,cgo), so ignore
+                               // errors for the auto-generated combinations.
+                               if !mustImport {
+                                       continue
+                               }
+                               t.Errorf("%s/%s/cgo=%v %v", ctxt.GOOS, ctxt.GOARCH, ctxt.CgoEnabled, err)
+                               continue
+                       }
+                       ok := allowed(pkg)
+                       var bad []string
+                       for _, imp := range p.Imports {
+                               if !ok[imp] {
+                                       bad = append(bad, imp)
+                               }
+                       }
+                       if bad != nil {
+                               t.Errorf("%s/%s/cgo=%v unexpected dependency: %s imports %v", ctxt.GOOS, ctxt.GOARCH, ctxt.CgoEnabled, pkg, bad)
+                       }
+               }
+       }
+       test(true)
+
+       if testing.Short() {
+               t.Logf("skipping other systems")
+               return
+       }
+
+       for _, ctxt.GOOS = range geese {
+               for _, ctxt.GOARCH = range goarches {
+                       for _, ctxt.CgoEnabled = range bools {
+                               test(false)
+                       }
+               }
+       }
+}
diff --git a/libgo/go/go/parser/error_test.go b/libgo/go/go/parser/error_test.go
new file mode 100644 (file)
index 0000000..377c8b8
--- /dev/null
@@ -0,0 +1,166 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file implements a parser test harness. The files in the testdata
+// directory are parsed and the errors reported are compared against the
+// error messages expected in the test files. The test files must end in
+// .src rather than .go so that they are not disturbed by gofmt runs.
+//
+// Expected errors are indicated in the test files by putting a comment
+// of the form /* ERROR "rx" */ immediately following an offending token.
+// The harness will verify that an error matching the regular expression
+// rx is reported at that source position.
+//
+// For instance, the following test file indicates that a "not declared"
+// error should be reported for the undeclared variable x:
+//
+//     package p
+//     func f() {
+//             _ = x /* ERROR "not declared" */ + 1
+//     }
+
+package parser
+
+import (
+       "go/scanner"
+       "go/token"
+       "io/ioutil"
+       "path/filepath"
+       "regexp"
+       "strings"
+       "testing"
+)
+
+const testdata = "testdata"
+
+// getFile assumes that each filename occurs at most once
+func getFile(filename string) (file *token.File) {
+       fset.Iterate(func(f *token.File) bool {
+               if f.Name() == filename {
+                       if file != nil {
+                               panic(filename + " used multiple times")
+                       }
+                       file = f
+               }
+               return true
+       })
+       return file
+}
+
+func getPos(filename string, offset int) token.Pos {
+       if f := getFile(filename); f != nil {
+               return f.Pos(offset)
+       }
+       return token.NoPos
+}
+
+// ERROR comments must be of the form /* ERROR "rx" */ and rx is
+// a regular expression that matches the expected error message.
+//
+var errRx = regexp.MustCompile(`^/\* *ERROR *"([^"]*)" *\*/$`)
+
+// expectedErrors collects the regular expressions of ERROR comments found
+// in files and returns them as a map of error positions to error messages.
+//
+func expectedErrors(t *testing.T, filename string, src []byte) map[token.Pos]string {
+       errors := make(map[token.Pos]string)
+
+       var s scanner.Scanner
+       // file was parsed already - do not add it again to the file
+       // set otherwise the position information returned here will
+       // not match the position information collected by the parser
+       s.Init(getFile(filename), src, nil, scanner.ScanComments)
+       var prev token.Pos // position of last non-comment, non-semicolon token
+
+       for {
+               pos, tok, lit := s.Scan()
+               switch tok {
+               case token.EOF:
+                       return errors
+               case token.COMMENT:
+                       s := errRx.FindStringSubmatch(lit)
+                       if len(s) == 2 {
+                               errors[prev] = string(s[1])
+                       }
+               default:
+                       prev = pos
+               }
+       }
+
+       panic("unreachable")
+}
+
+// compareErrors compares the map of expected error messages with the list
+// of found errors and reports discrepancies.
+//
+func compareErrors(t *testing.T, expected map[token.Pos]string, found scanner.ErrorList) {
+       for _, error := range found {
+               // error.Pos is a token.Position, but we want
+               // a token.Pos so we can do a map lookup
+               pos := getPos(error.Pos.Filename, error.Pos.Offset)
+               if msg, found := expected[pos]; found {
+                       // we expect a message at pos; check if it matches
+                       rx, err := regexp.Compile(msg)
+                       if err != nil {
+                               t.Errorf("%s: %v", error.Pos, err)
+                               continue
+                       }
+                       if match := rx.MatchString(error.Msg); !match {
+                               t.Errorf("%s: %q does not match %q", error.Pos, error.Msg, msg)
+                               continue
+                       }
+                       // we have a match - eliminate this error
+                       delete(expected, pos)
+               } else {
+                       // To keep in mind when analyzing failed test output:
+                       // If the same error position occurs multiple times in errors,
+                       // this message will be triggered (because the first error at
+                       // the position removes this position from the expected errors).
+                       t.Errorf("%s: unexpected error: %s", error.Pos, error.Msg)
+               }
+       }
+
+       // there should be no expected errors left
+       if len(expected) > 0 {
+               t.Errorf("%d errors not reported:", len(expected))
+               for pos, msg := range expected {
+                       t.Errorf("%s: %s\n", fset.Position(pos), msg)
+               }
+       }
+}
+
+func checkErrors(t *testing.T, filename string, input interface{}) {
+       src, err := readSource(filename, input)
+       if err != nil {
+               t.Error(err)
+               return
+       }
+
+       _, err = ParseFile(fset, filename, src, DeclarationErrors)
+       found, ok := err.(scanner.ErrorList)
+       if err != nil && !ok {
+               t.Error(err)
+               return
+       }
+
+       // we are expecting the following errors
+       // (collect these after parsing a file so that it is found in the file set)
+       expected := expectedErrors(t, filename, src)
+
+       // verify errors returned by the parser
+       compareErrors(t, expected, found)
+}
+
+func TestErrors(t *testing.T) {
+       list, err := ioutil.ReadDir(testdata)
+       if err != nil {
+               t.Fatal(err)
+       }
+       for _, fi := range list {
+               name := fi.Name()
+               if !fi.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".src") {
+                       checkErrors(t, filepath.Join(testdata, name), nil)
+               }
+       }
+}
index a122baf..e362e13 100644 (file)
@@ -40,6 +40,13 @@ type parser struct {
        tok token.Token // one token look-ahead
        lit string      // token literal
 
+       // Error recovery
+       // (used to limit the number of calls to syncXXX functions
+       // w/o making scanning progress - avoids potential endless
+       // loops across multiple parser functions during error recovery)
+       syncPos token.Pos // last synchronization position
+       syncCnt int       // number of calls to syncXXX without progress
+
        // Non-syntactic parser control
        exprLev int // < 0: in control clause, >= 0: in expression
 
@@ -362,18 +369,36 @@ func (p *parser) expect(tok token.Token) token.Pos {
 // expectClosing is like expect but provides a better error message
 // for the common case of a missing comma before a newline.
 //
-func (p *parser) expectClosing(tok token.Token, construct string) token.Pos {
+func (p *parser) expectClosing(tok token.Token, context string) token.Pos {
        if p.tok != tok && p.tok == token.SEMICOLON && p.lit == "\n" {
-               p.error(p.pos, "missing ',' before newline in "+construct)
+               p.error(p.pos, "missing ',' before newline in "+context)
                p.next()
        }
        return p.expect(tok)
 }
 
 func (p *parser) expectSemi() {
+       // semicolon is optional before a closing ')' or '}'
        if p.tok != token.RPAREN && p.tok != token.RBRACE {
-               p.expect(token.SEMICOLON)
+               if p.tok == token.SEMICOLON {
+                       p.next()
+               } else {
+                       p.errorExpected(p.pos, "';'")
+                       syncStmt(p)
+               }
+       }
+}
+
+func (p *parser) atComma(context string) bool {
+       if p.tok == token.COMMA {
+               return true
+       }
+       if p.tok == token.SEMICOLON && p.lit == "\n" {
+               p.error(p.pos, "missing ',' before newline in "+context)
+               return true // "insert" the comma and continue
+
        }
+       return false
 }
 
 func assert(cond bool, msg string) {
@@ -382,6 +407,68 @@ func assert(cond bool, msg string) {
        }
 }
 
+// syncStmt advances to the next statement.
+// Used for synchronization after an error.
+//
+func syncStmt(p *parser) {
+       for {
+               switch p.tok {
+               case token.BREAK, token.CONST, token.CONTINUE, token.DEFER,
+                       token.FALLTHROUGH, token.FOR, token.GO, token.GOTO,
+                       token.IF, token.RETURN, token.SELECT, token.SWITCH,
+                       token.TYPE, token.VAR:
+                       // Return only if parser made some progress since last
+                       // sync or if it has not reached 10 sync calls without
+                       // progress. Otherwise consume at least one token to
+                       // avoid an endless parser loop (it is possible that
+                       // both parseOperand and parseStmt call syncStmt and
+                       // correctly do not advance, thus the need for the
+                       // invocation limit p.syncCnt).
+                       if p.pos == p.syncPos && p.syncCnt < 10 {
+                               p.syncCnt++
+                               return
+                       }
+                       if p.pos > p.syncPos {
+                               p.syncPos = p.pos
+                               p.syncCnt = 0
+                               return
+                       }
+                       // Reaching here indicates a parser bug, likely an
+                       // incorrect token list in this function, but it only
+                       // leads to skipping of possibly correct code if a
+                       // previous error is present, and thus is preferred
+                       // over a non-terminating parse.
+               case token.EOF:
+                       return
+               }
+               p.next()
+       }
+}
+
+// syncDecl advances to the next declaration.
+// Used for synchronization after an error.
+//
+func syncDecl(p *parser) {
+       for {
+               switch p.tok {
+               case token.CONST, token.TYPE, token.VAR:
+                       // see comments in syncStmt
+                       if p.pos == p.syncPos && p.syncCnt < 10 {
+                               p.syncCnt++
+                               return
+                       }
+                       if p.pos > p.syncPos {
+                               p.syncPos = p.pos
+                               p.syncCnt = 0
+                               return
+                       }
+               case token.EOF:
+                       return
+               }
+               p.next()
+       }
+}
+
 // ----------------------------------------------------------------------------
 // Identifiers
 
@@ -522,9 +609,11 @@ func (p *parser) makeIdentList(list []ast.Expr) []*ast.Ident {
        for i, x := range list {
                ident, isIdent := x.(*ast.Ident)
                if !isIdent {
-                       pos := x.Pos()
-                       p.errorExpected(pos, "identifier")
-                       ident = &ast.Ident{NamePos: pos, Name: "_"}
+                       if _, isBad := x.(*ast.BadExpr); !isBad {
+                               // only report error if it's a new one
+                               p.errorExpected(x.Pos(), "identifier")
+                       }
+                       ident = &ast.Ident{NamePos: x.Pos(), Name: "_"}
                }
                idents[i] = ident
        }
@@ -688,7 +777,7 @@ func (p *parser) parseParameterList(scope *ast.Scope, ellipsisOk bool) (params [
                        // Go spec: The scope of an identifier denoting a function
                        // parameter or result variable is the function body.
                        p.declare(field, nil, scope, ast.Var, idents...)
-                       if p.tok != token.COMMA {
+                       if !p.atComma("parameter list") {
                                break
                        }
                        p.next()
@@ -991,19 +1080,19 @@ func (p *parser) parseOperand(lhs bool) ast.Expr {
 
        case token.FUNC:
                return p.parseFuncTypeOrLit()
+       }
 
-       default:
-               if typ := p.tryIdentOrType(true); typ != nil {
-                       // could be type for composite literal or conversion
-                       _, isIdent := typ.(*ast.Ident)
-                       assert(!isIdent, "type cannot be identifier")
-                       return typ
-               }
+       if typ := p.tryIdentOrType(true); typ != nil {
+               // could be type for composite literal or conversion
+               _, isIdent := typ.(*ast.Ident)
+               assert(!isIdent, "type cannot be identifier")
+               return typ
        }
 
+       // we have an error
        pos := p.pos
        p.errorExpected(pos, "operand")
-       p.next() // make progress
+       syncStmt(p)
        return &ast.BadExpr{From: pos, To: p.pos}
 }
 
@@ -1078,7 +1167,7 @@ func (p *parser) parseCallOrConversion(fun ast.Expr) *ast.CallExpr {
                        ellipsis = p.pos
                        p.next()
                }
-               if p.tok != token.COMMA {
+               if !p.atComma("argument list") {
                        break
                }
                p.next()
@@ -1118,7 +1207,7 @@ func (p *parser) parseElementList() (list []ast.Expr) {
 
        for p.tok != token.RBRACE && p.tok != token.EOF {
                list = append(list, p.parseElement(true))
-               if p.tok != token.COMMA {
+               if !p.atComma("composite literal") {
                        break
                }
                p.next()
@@ -1262,8 +1351,8 @@ L:
                                x = p.parseTypeAssertion(p.checkExpr(x))
                        default:
                                pos := p.pos
-                               p.next() // make progress
                                p.errorExpected(pos, "selector or type assertion")
+                               p.next() // make progress
                                x = &ast.BadExpr{From: pos, To: p.pos}
                        }
                case token.LBRACK:
@@ -1471,7 +1560,10 @@ func (p *parser) parseCallExpr() *ast.CallExpr {
        if call, isCall := x.(*ast.CallExpr); isCall {
                return call
        }
-       p.errorExpected(x.Pos(), "function/method call")
+       if _, isBad := x.(*ast.BadExpr); !isBad {
+               // only report error if it's a new one
+               p.errorExpected(x.Pos(), "function/method call")
+       }
        return nil
 }
 
@@ -1862,7 +1954,7 @@ func (p *parser) parseStmt() (s ast.Stmt) {
 
        switch p.tok {
        case token.CONST, token.TYPE, token.VAR:
-               s = &ast.DeclStmt{Decl: p.parseDecl()}
+               s = &ast.DeclStmt{Decl: p.parseDecl(syncStmt)}
        case
                // tokens that may start an expression
                token.IDENT, token.INT, token.FLOAT, token.IMAG, token.CHAR, token.STRING, token.FUNC, token.LPAREN, // operands
@@ -1904,7 +1996,7 @@ func (p *parser) parseStmt() (s ast.Stmt) {
                // no statement found
                pos := p.pos
                p.errorExpected(pos, "statement")
-               p.next() // make progress
+               syncStmt(p)
                s = &ast.BadStmt{From: pos, To: p.pos}
        }
 
@@ -2095,8 +2187,13 @@ func (p *parser) parseReceiver(scope *ast.Scope) *ast.FieldList {
        recv := par.List[0]
        base := deref(recv.Type)
        if _, isIdent := base.(*ast.Ident); !isIdent {
-               p.errorExpected(base.Pos(), "(unqualified) identifier")
-               par.List = []*ast.Field{{Type: &ast.BadExpr{From: recv.Pos(), To: recv.End()}}}
+               if _, isBad := base.(*ast.BadExpr); !isBad {
+                       // only report error if it's a new one
+                       p.errorExpected(base.Pos(), "(unqualified) identifier")
+               }
+               par.List = []*ast.Field{
+                       {Type: &ast.BadExpr{From: recv.Pos(), To: recv.End()}},
+               }
        }
 
        return par
@@ -2152,7 +2249,7 @@ func (p *parser) parseFuncDecl() *ast.FuncDecl {
        return decl
 }
 
-func (p *parser) parseDecl() ast.Decl {
+func (p *parser) parseDecl(sync func(*parser)) ast.Decl {
        if p.trace {
                defer un(trace(p, "Declaration"))
        }
@@ -2174,9 +2271,8 @@ func (p *parser) parseDecl() ast.Decl {
        default:
                pos := p.pos
                p.errorExpected(pos, "declaration")
-               p.next() // make progress
-               decl := &ast.BadDecl{From: pos, To: p.pos}
-               return decl
+               sync(p)
+               return &ast.BadDecl{From: pos, To: p.pos}
        }
 
        return p.parseGenDecl(p.tok, f)
@@ -2215,7 +2311,7 @@ func (p *parser) parseFile() *ast.File {
                if p.mode&ImportsOnly == 0 {
                        // rest of package body
                        for p.tok != token.EOF {
-                               decls = append(decls, p.parseDecl())
+                               decls = append(decls, p.parseDecl(syncDecl))
                        }
                }
        }
index 93ca3d6..5e45acd 100644 (file)
@@ -14,87 +14,14 @@ import (
 
 var fset = token.NewFileSet()
 
-var illegalInputs = []interface{}{
-       nil,
-       3.14,
-       []byte(nil),
-       "foo!",
-       `package p; func f() { if /* should have condition */ {} };`,
-       `package p; func f() { if ; /* should have condition */ {} };`,
-       `package p; func f() { if f(); /* should have condition */ {} };`,
-       `package p; const c; /* should have constant value */`,
-       `package p; func f() { if _ = range x; true {} };`,
-       `package p; func f() { switch _ = range x; true {} };`,
-       `package p; func f() { for _ = range x ; ; {} };`,
-       `package p; func f() { for ; ; _ = range x {} };`,
-       `package p; func f() { for ; _ = range x ; {} };`,
-       `package p; func f() { switch t = t.(type) {} };`,
-       `package p; func f() { switch t, t = t.(type) {} };`,
-       `package p; func f() { switch t = t.(type), t {} };`,
-       `package p; var a = [1]int; /* illegal expression */`,
-       `package p; var a = [...]int; /* illegal expression */`,
-       `package p; var a = struct{} /* illegal expression */`,
-       `package p; var a = func(); /* illegal expression */`,
-       `package p; var a = interface{} /* illegal expression */`,
-       `package p; var a = []int /* illegal expression */`,
-       `package p; var a = map[int]int /* illegal expression */`,
-       `package p; var a = chan int; /* illegal expression */`,
-       `package p; var a = []int{[]int}; /* illegal expression */`,
-       `package p; var a = ([]int); /* illegal expression */`,
-       `package p; var a = a[[]int:[]int]; /* illegal expression */`,
-       `package p; var a = <- chan int; /* illegal expression */`,
-       `package p; func f() { select { case _ <- chan int: } };`,
-}
-
-func TestParseIllegalInputs(t *testing.T) {
-       for _, src := range illegalInputs {
-               _, err := ParseFile(fset, "", src, 0)
-               if err == nil {
-                       t.Errorf("ParseFile(%v) should have failed", src)
-               }
-       }
-}
-
-var validPrograms = []string{
-       "package p\n",
-       `package p;`,
-       `package p; import "fmt"; func f() { fmt.Println("Hello, World!") };`,
-       `package p; func f() { if f(T{}) {} };`,
-       `package p; func f() { _ = (<-chan int)(x) };`,
-       `package p; func f() { _ = (<-chan <-chan int)(x) };`,
-       `package p; func f(func() func() func());`,
-       `package p; func f(...T);`,
-       `package p; func f(float, ...int);`,
-       `package p; func f(x int, a ...int) { f(0, a...); f(1, a...,) };`,
-       `package p; func f(int,) {};`,
-       `package p; func f(...int,) {};`,
-       `package p; func f(x ...int,) {};`,
-       `package p; type T []int; var a []bool; func f() { if a[T{42}[0]] {} };`,
-       `package p; type T []int; func g(int) bool { return true }; func f() { if g(T{42}[0]) {} };`,
-       `package p; type T []int; func f() { for _ = range []int{T{42}[0]} {} };`,
-       `package p; var a = T{{1, 2}, {3, 4}}`,
-       `package p; func f() { select { case <- c: case c <- d: case c <- <- d: case <-c <- d: } };`,
-       `package p; func f() { select { case x := (<-c): } };`,
-       `package p; func f() { if ; true {} };`,
-       `package p; func f() { switch ; {} };`,
-       `package p; func f() { for _ = range "foo" + "bar" {} };`,
-}
-
-func TestParseValidPrograms(t *testing.T) {
-       for _, src := range validPrograms {
-               _, err := ParseFile(fset, "", src, SpuriousErrors)
-               if err != nil {
-                       t.Errorf("ParseFile(%q): %v", src, err)
-               }
-       }
-}
-
 var validFiles = []string{
        "parser.go",
        "parser_test.go",
+       "error_test.go",
+       "short_test.go",
 }
 
-func TestParse3(t *testing.T) {
+func TestParse(t *testing.T) {
        for _, filename := range validFiles {
                _, err := ParseFile(fset, filename, nil, DeclarationErrors)
                if err != nil {
@@ -116,7 +43,7 @@ func nameFilter(filename string) bool {
 
 func dirFilter(f os.FileInfo) bool { return nameFilter(f.Name()) }
 
-func TestParse4(t *testing.T) {
+func TestParseDir(t *testing.T) {
        path := "."
        pkgs, err := ParseDir(fset, path, dirFilter, 0)
        if err != nil {
@@ -158,7 +85,7 @@ func TestParseExpr(t *testing.T) {
        }
 
        // it must not crash
-       for _, src := range validPrograms {
+       for _, src := range valids {
                ParseExpr(src)
        }
 }
diff --git a/libgo/go/go/parser/short_test.go b/libgo/go/go/parser/short_test.go
new file mode 100644 (file)
index 0000000..238492b
--- /dev/null
@@ -0,0 +1,75 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file contains test cases for short valid and invalid programs.
+
+package parser
+
+import "testing"
+
+var valids = []string{
+       "package p\n",
+       `package p;`,
+       `package p; import "fmt"; func f() { fmt.Println("Hello, World!") };`,
+       `package p; func f() { if f(T{}) {} };`,
+       `package p; func f() { _ = (<-chan int)(x) };`,
+       `package p; func f() { _ = (<-chan <-chan int)(x) };`,
+       `package p; func f(func() func() func());`,
+       `package p; func f(...T);`,
+       `package p; func f(float, ...int);`,
+       `package p; func f(x int, a ...int) { f(0, a...); f(1, a...,) };`,
+       `package p; func f(int,) {};`,
+       `package p; func f(...int,) {};`,
+       `package p; func f(x ...int,) {};`,
+       `package p; type T []int; var a []bool; func f() { if a[T{42}[0]] {} };`,
+       `package p; type T []int; func g(int) bool { return true }; func f() { if g(T{42}[0]) {} };`,
+       `package p; type T []int; func f() { for _ = range []int{T{42}[0]} {} };`,
+       `package p; var a = T{{1, 2}, {3, 4}}`,
+       `package p; func f() { select { case <- c: case c <- d: case c <- <- d: case <-c <- d: } };`,
+       `package p; func f() { select { case x := (<-c): } };`,
+       `package p; func f() { if ; true {} };`,
+       `package p; func f() { switch ; {} };`,
+       `package p; func f() { for _ = range "foo" + "bar" {} };`,
+}
+
+func TestValid(t *testing.T) {
+       for _, src := range valids {
+               checkErrors(t, src, src)
+       }
+}
+
+var invalids = []string{
+       `foo /* ERROR "expected 'package'" */ !`,
+       `package p; func f() { if { /* ERROR "expected operand" */ } };`,
+       `package p; func f() { if ; { /* ERROR "expected operand" */ } };`,
+       `package p; func f() { if f(); { /* ERROR "expected operand" */ } };`,
+       `package p; const c; /* ERROR "expected '='" */`,
+       `package p; func f() { if _ /* ERROR "expected condition" */ = range x; true {} };`,
+       `package p; func f() { switch _ /* ERROR "expected condition" */ = range x; true {} };`,
+       `package p; func f() { for _ = range x ; /* ERROR "expected '{'" */ ; {} };`,
+       `package p; func f() { for ; ; _ = range /* ERROR "expected operand" */ x {} };`,
+       `package p; func f() { for ; _ /* ERROR "expected condition" */ = range x ; {} };`,
+       `package p; func f() { switch t /* ERROR "expected condition" */ = t.(type) {} };`,
+       `package p; func f() { switch t /* ERROR "expected condition" */ , t = t.(type) {} };`,
+       `package p; func f() { switch t /* ERROR "expected condition" */ = t.(type), t {} };`,
+       `package p; var a = [ /* ERROR "expected expression" */ 1]int;`,
+       `package p; var a = [ /* ERROR "expected expression" */ ...]int;`,
+       `package p; var a = struct /* ERROR "expected expression" */ {}`,
+       `package p; var a = func /* ERROR "expected expression" */ ();`,
+       `package p; var a = interface /* ERROR "expected expression" */ {}`,
+       `package p; var a = [ /* ERROR "expected expression" */ ]int`,
+       `package p; var a = map /* ERROR "expected expression" */ [int]int`,
+       `package p; var a = chan /* ERROR "expected expression" */ int;`,
+       `package p; var a = []int{[ /* ERROR "expected expression" */ ]int};`,
+       `package p; var a = ( /* ERROR "expected expression" */ []int);`,
+       `package p; var a = a[[ /* ERROR "expected expression" */ ]int:[]int];`,
+       `package p; var a = <-  /* ERROR "expected expression" */ chan int;`,
+       `package p; func f() { select { case _ <- chan  /* ERROR "expected expression" */ int: } };`,
+}
+
+func TestInvalid(t *testing.T) {
+       for _, src := range invalids {
+               checkErrors(t, src, src)
+       }
+}
diff --git a/libgo/go/go/parser/testdata/commas.src b/libgo/go/go/parser/testdata/commas.src
new file mode 100644 (file)
index 0000000..af6e706
--- /dev/null
@@ -0,0 +1,19 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Test case for error messages/parser synchronization
+// after missing commas.
+
+package p
+
+var _ = []int{
+       0 /* ERROR "missing ','" */
+}
+
+var _ = []int{
+       0,
+       1,
+       2,
+       3 /* ERROR "missing ','" */
+}
diff --git a/libgo/go/go/parser/testdata/issue3106.src b/libgo/go/go/parser/testdata/issue3106.src
new file mode 100644 (file)
index 0000000..82796c8
--- /dev/null
@@ -0,0 +1,46 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Test case for issue 3106: Better synchronization of
+// parser after certain syntax errors.
+
+package main
+
+func f() {
+       var m Mutex
+       c := MakeCond(&m)
+       percent := 0
+       const step = 10
+       for i := 0; i < 5; i++ {
+               go func() {
+                       for {
+                               // Emulates some useful work.
+                               time.Sleep(1e8)
+                               m.Lock()
+                               defer
+                               if /* ERROR "expected operand, found 'if'" */ percent == 100 {
+                                       m.Unlock()
+                                       break
+                               }
+                               percent++
+                               if percent % step == 0 {
+                                       //c.Signal()
+                               }
+                               m.Unlock()
+                       }
+               }()
+       }
+       for {
+               m.Lock()
+               if percent == 0 || percent % step != 0 {
+                       c.Wait()
+               }
+               fmt.Print(",")
+               if percent == 100 {
+                       m.Unlock()
+                       break
+               }
+               m.Unlock()
+       }
+}
index 05b4ef5..6be3c09 100644 (file)
@@ -15,7 +15,7 @@ import (
        "unicode/utf8"
 )
 
-// Other formatting issues:
+// Formatting issues:
 // - better comment formatting for /*-style comments at the end of a line (e.g. a declaration)
 //   when the comment spans multiple lines; if such a comment is just two lines, formatting is
 //   not idempotent
@@ -964,6 +964,41 @@ func (p *printer) controlClause(isForStmt bool, init ast.Stmt, expr ast.Expr, po
        }
 }
 
+// indentList reports whether an expression list would look better if it
+// were indented wholesale (starting with the very first element, rather
+// than starting at the first line break).
+//
+func (p *printer) indentList(list []ast.Expr) bool {
+       // Heuristic: indentList returns true if there are more than one multi-
+       // line element in the list, or if there is any element that is not
+       // starting on the same line as the previous one ends.
+       if len(list) >= 2 {
+               var b = p.lineFor(list[0].Pos())
+               var e = p.lineFor(list[len(list)-1].End())
+               if 0 < b && b < e {
+                       // list spans multiple lines
+                       n := 0 // multi-line element count
+                       line := b
+                       for _, x := range list {
+                               xb := p.lineFor(x.Pos())
+                               xe := p.lineFor(x.End())
+                               if line < xb {
+                                       // x is not starting on the same
+                                       // line as the previous one ended
+                                       return true
+                               }
+                               if xb < xe {
+                                       // x is a multi-line element
+                                       n++
+                               }
+                               line = xe
+                       }
+                       return n > 1
+               }
+       }
+       return false
+}
+
 func (p *printer) stmt(stmt ast.Stmt, nextIsRBrace bool) {
        p.print(stmt.Pos())
 
@@ -1030,7 +1065,18 @@ func (p *printer) stmt(stmt ast.Stmt, nextIsRBrace bool) {
                p.print(token.RETURN)
                if s.Results != nil {
                        p.print(blank)
-                       p.exprList(s.Pos(), s.Results, 1, 0, token.NoPos)
+                       // Use indentList heuristic to make corner cases look
+                       // better (issue 1207). A more systematic approach would
+                       // always indent, but this would cause significant
+                       // reformatting of the code base and not necessarily
+                       // lead to more nicely formatted code in general.
+                       if p.indentList(s.Results) {
+                               p.print(indent)
+                               p.exprList(s.Pos(), s.Results, 1, noIndent, token.NoPos)
+                               p.print(unindent)
+                       } else {
+                               p.exprList(s.Pos(), s.Results, 1, 0, token.NoPos)
+                       }
                }
 
        case *ast.BranchStmt:
@@ -1200,9 +1246,9 @@ func keepTypeColumn(specs []ast.Spec) []bool {
        return m
 }
 
-func (p *printer) valueSpec(s *ast.ValueSpec, keepType, doIndent bool) {
+func (p *printer) valueSpec(s *ast.ValueSpec, keepType bool) {
        p.setComment(s.Doc)
-       p.identList(s.Names, doIndent) // always present
+       p.identList(s.Names, false) // always present
        extraTabs := 3
        if s.Type != nil || keepType {
                p.print(vtab)
@@ -1290,7 +1336,7 @@ func (p *printer) genDecl(d *ast.GenDecl) {
                                        if i > 0 {
                                                p.linebreak(p.lineFor(s.Pos()), 1, ignore, newSection)
                                        }
-                                       p.valueSpec(s.(*ast.ValueSpec), keepType[i], false)
+                                       p.valueSpec(s.(*ast.ValueSpec), keepType[i])
                                        newSection = p.isMultiLine(s)
                                }
                        } else {
index ffca21e..4d70617 100644 (file)
@@ -55,12 +55,24 @@ func _f() {
        return T{
                1,
                2,
-       },
+       }, nil
+       return T{
+                       1,
+                       2,
+               },
+               T{
+                       x:      3,
+                       y:      4,
+               }, nil
+       return T{
+                       1,
+                       2,
+               },
                nil
        return T{
-               1,
-               2,
-       },
+                       1,
+                       2,
+               },
                T{
                        x:      3,
                        y:      4,
@@ -70,10 +82,10 @@ func _f() {
                z
        return func() {}
        return func() {
-               _ = 0
-       }, T{
-               1, 2,
-       }
+                       _ = 0
+               }, T{
+                       1, 2,
+               }
        return func() {
                _ = 0
        }
@@ -84,6 +96,37 @@ func _f() {
        }
 }
 
+// Formatting of multi-line returns: test cases from issue 1207.
+func F() (*T, os.Error) {
+       return &T{
+                       X:      1,
+                       Y:      2,
+               },
+               nil
+}
+
+func G() (*T, *T, os.Error) {
+       return &T{
+                       X:      1,
+                       Y:      2,
+               },
+               &T{
+                       X:      3,
+                       Y:      4,
+               },
+               nil
+}
+
+func _() interface{} {
+       return &fileStat{
+               name:           basename(file.name),
+               size:           mkSize(d.FileSizeHigh, d.FileSizeLow),
+               modTime:        mkModTime(d.LastWriteTime),
+               mode:           mkMode(d.FileAttributes),
+               sys:            mkSysFromFI(&d),
+       }, nil
+}
+
 // Formatting of if-statement headers.
 func _() {
        if true {
index 99945e9..bd03bc9 100644 (file)
@@ -55,6 +55,18 @@ func _f() {
        return T{
                        1,
                        2,
+               }, nil
+       return T{
+                       1,
+                       2,
+               },
+               T{
+                       x: 3,
+                       y: 4,
+               }, nil
+       return T{
+                       1,
+                       2,
                },
                nil
        return T{
@@ -84,6 +96,37 @@ func _f() {
        }
 }
 
+// Formatting of multi-line returns: test cases from issue 1207.
+func F() (*T, os.Error) {
+       return &T{
+               X: 1,
+               Y: 2,
+       },
+               nil
+}
+
+func G() (*T, *T, os.Error) {
+       return &T{
+               X: 1,
+               Y: 2,
+       },
+               &T{
+                       X: 3,
+                       Y: 4,
+               },
+               nil
+}
+
+func _() interface{} {
+       return &fileStat{
+                       name:    basename(file.name),
+                       size:    mkSize(d.FileSizeHigh, d.FileSizeLow),
+                       modTime: mkModTime(d.LastWriteTime),
+                       mode:    mkMode(d.FileAttributes),
+                       sys:     mkSysFromFI(&d),
+               }, nil
+}
+
 // Formatting of if-statement headers.
 func _() {
        if true {}
index 2395363..da50874 100644 (file)
@@ -109,7 +109,7 @@ const (
 func (s *Scanner) Init(file *token.File, src []byte, err ErrorHandler, mode Mode) {
        // Explicitly initialize all fields since a scanner may be reused.
        if file.Size() != len(src) {
-               panic("file size does not match src len")
+               panic(fmt.Sprintf("file size (%d) does not match src len (%d)", file.Size(), len(src)))
        }
        s.file = file
        s.dir, _ = filepath.Split(file.Name())
index 3699ea1..f470fac 100644 (file)
@@ -29,7 +29,7 @@ can be safely embedded in an HTML document. The escaping is contextual, so
 actions can appear within JavaScript, CSS, and URI contexts.
 
 The security model used by this package assumes that template authors are
-trusted, while text/template Execute's data parameter is not. More details are
+trusted, while Execute's data parameter is not. More details are
 provided below.
 
 Example
index 7074834..54bf159 100644 (file)
@@ -173,6 +173,13 @@ type ReaderAt interface {
 // at offset off.  It returns the number of bytes written from p (0 <= n <= len(p))
 // and any error encountered that caused the write to stop early.
 // WriteAt must return a non-nil error if it returns n < len(p).
+//
+// If WriteAt is writing to a destination with a seek offset,
+// WriteAt should not affect nor be affected by the underlying
+// seek offset.
+//
+// Clients of WriteAt can execute parallel WriteAt calls on the same
+// destination if the ranges do not overlap.
 type WriterAt interface {
        WriteAt(p []byte, off int64) (n int, err error)
 }
index 02a407e..1d7f209 100644 (file)
@@ -13,8 +13,6 @@
 package log
 
 import (
-       "bytes"
-       _ "debug/elf"
        "fmt"
        "io"
        "os"
@@ -29,7 +27,7 @@ const (
        // order they appear (the order listed here) or the format they present (as
        // described in the comments).  A colon appears after these items:
        //      2009/0123 01:23:23.123123 /a/b/c/d.go:23: message
-       Ldate         = 1 << iota     // the date: 2009/0123
+       Ldate         = 1 << iota     // the date: 2009/01/23
        Ltime                         // the time: 01:23:23
        Lmicroseconds                 // microsecond resolution: 01:23:23.123123.  assumes Ltime.
        Llongfile                     // full file name and line number: /a/b/c/d.go:23
@@ -42,11 +40,11 @@ const (
 // the Writer's Write method.  A Logger can be used simultaneously from
 // multiple goroutines; it guarantees to serialize access to the Writer.
 type Logger struct {
-       mu     sync.Mutex   // ensures atomic writes; protects the following fields
-       prefix string       // prefix to write at beginning of each line
-       flag   int          // properties
-       out    io.Writer    // destination for output
-       buf    bytes.Buffer // for accumulating text to write
+       mu     sync.Mutex // ensures atomic writes; protects the following fields
+       prefix string     // prefix to write at beginning of each line
+       flag   int        // properties
+       out    io.Writer  // destination for output
+       buf    []byte     // for accumulating text to write
 }
 
 // New creates a new Logger.   The out variable sets the
@@ -61,10 +59,10 @@ var std = New(os.Stderr, "", LstdFlags)
 
 // Cheap integer to fixed-width decimal ASCII.  Give a negative width to avoid zero-padding.
 // Knows the buffer has capacity.
-func itoa(buf *bytes.Buffer, i int, wid int) {
+func itoa(buf *[]byte, i int, wid int) {
        var u uint = uint(i)
        if u == 0 && wid <= 1 {
-               buf.WriteByte('0')
+               *buf = append(*buf, '0')
                return
        }
 
@@ -76,38 +74,33 @@ func itoa(buf *bytes.Buffer, i int, wid int) {
                wid--
                b[bp] = byte(u%10) + '0'
        }
-
-       // avoid slicing b to avoid an allocation.
-       for bp < len(b) {
-               buf.WriteByte(b[bp])
-               bp++
-       }
+       *buf = append(*buf, b[bp:]...)
 }
 
-func (l *Logger) formatHeader(buf *bytes.Buffer, t time.Time, file string, line int) {
-       buf.WriteString(l.prefix)
+func (l *Logger) formatHeader(buf *[]byte, t time.Time, file string, line int) {
+       *buf = append(*buf, l.prefix...)
        if l.flag&(Ldate|Ltime|Lmicroseconds) != 0 {
                if l.flag&Ldate != 0 {
                        year, month, day := t.Date()
                        itoa(buf, year, 4)
-                       buf.WriteByte('/')
+                       *buf = append(*buf, '/')
                        itoa(buf, int(month), 2)
-                       buf.WriteByte('/')
+                       *buf = append(*buf, '/')
                        itoa(buf, day, 2)
-                       buf.WriteByte(' ')
+                       *buf = append(*buf, ' ')
                }
                if l.flag&(Ltime|Lmicroseconds) != 0 {
                        hour, min, sec := t.Clock()
                        itoa(buf, hour, 2)
-                       buf.WriteByte(':')
+                       *buf = append(*buf, ':')
                        itoa(buf, min, 2)
-                       buf.WriteByte(':')
+                       *buf = append(*buf, ':')
                        itoa(buf, sec, 2)
                        if l.flag&Lmicroseconds != 0 {
-                               buf.WriteByte('.')
+                               *buf = append(*buf, '.')
                                itoa(buf, t.Nanosecond()/1e3, 6)
                        }
-                       buf.WriteByte(' ')
+                       *buf = append(*buf, ' ')
                }
        }
        if l.flag&(Lshortfile|Llongfile) != 0 {
@@ -121,10 +114,10 @@ func (l *Logger) formatHeader(buf *bytes.Buffer, t time.Time, file string, line
                        }
                        file = short
                }
-               buf.WriteString(file)
-               buf.WriteByte(':')
+               *buf = append(*buf, file...)
+               *buf = append(*buf, ':')
                itoa(buf, line, -1)
-               buf.WriteString(": ")
+               *buf = append(*buf, ": "...)
        }
 }
 
@@ -151,13 +144,13 @@ func (l *Logger) Output(calldepth int, s string) error {
                }
                l.mu.Lock()
        }
-       l.buf.Reset()
+       l.buf = l.buf[:0]
        l.formatHeader(&l.buf, now, file, line)
-       l.buf.WriteString(s)
+       l.buf = append(l.buf, s...)
        if len(s) > 0 && s[len(s)-1] != '\n' {
-               l.buf.WriteByte('\n')
+               l.buf = append(l.buf, '\n')
        }
-       _, err := l.out.Write(l.buf.Bytes())
+       _, err := l.out.Write(l.buf)
        return err
 }
 
index 5f5aea1..7212087 100644 (file)
@@ -6,6 +6,7 @@ package net
 
 import (
        "flag"
+       "fmt"
        "regexp"
        "runtime"
        "testing"
@@ -32,7 +33,7 @@ func TestDialTimeout(t *testing.T) {
        numConns := listenerBacklog + 10
 
        // TODO(bradfitz): It's hard to test this in a portable
-       // way. This is unforunate, but works for now.
+       // way. This is unfortunate, but works for now.
        switch runtime.GOOS {
        case "linux":
                // The kernel will start accepting TCP connections before userspace
@@ -44,13 +45,25 @@ func TestDialTimeout(t *testing.T) {
                                errc <- err
                        }()
                }
-       case "darwin":
+       case "darwin", "windows":
                // At least OS X 10.7 seems to accept any number of
                // connections, ignoring listen's backlog, so resort
                // to connecting to a hopefully-dead 127/8 address.
                // Same for windows.
+               //
+               // Use an IANA reserved port (49151) instead of 80, because
+               // on our 386 builder, this Dial succeeds, connecting
+               // to an IIS web server somewhere.  The data center
+               // or VM or firewall must be stealing the TCP connection.
+               // 
+               // IANA Service Name and Transport Protocol Port Number Registry
+               // <http://www.iana.org/assignments/service-names-port-numbers/service-names-port-numbers.xml>
                go func() {
-                       _, err := DialTimeout("tcp", "127.0.71.111:80", 200*time.Millisecond)
+                       c, err := DialTimeout("tcp", "127.0.71.111:49151", 200*time.Millisecond)
+                       if err == nil {
+                               err = fmt.Errorf("unexpected: connected to %s!", c.RemoteAddr())
+                               c.Close()
+                       }
                        errc <- err
                }()
        default:
index f4ed8b8..e69cb31 100644 (file)
@@ -5,8 +5,6 @@
 package net
 
 import (
-       "bytes"
-       "fmt"
        "math/rand"
        "sort"
 )
@@ -45,20 +43,22 @@ func reverseaddr(addr string) (arpa string, err error) {
                return "", &DNSError{Err: "unrecognized address", Name: addr}
        }
        if ip.To4() != nil {
-               return fmt.Sprintf("%d.%d.%d.%d.in-addr.arpa.", ip[15], ip[14], ip[13], ip[12]), nil
+               return itoa(int(ip[15])) + "." + itoa(int(ip[14])) + "." + itoa(int(ip[13])) + "." +
+                       itoa(int(ip[12])) + ".in-addr.arpa.", nil
        }
        // Must be IPv6
-       var buf bytes.Buffer
+       buf := make([]byte, 0, len(ip)*4+len("ip6.arpa."))
        // Add it, in reverse, to the buffer
        for i := len(ip) - 1; i >= 0; i-- {
-               s := fmt.Sprintf("%02x", ip[i])
-               buf.WriteByte(s[1])
-               buf.WriteByte('.')
-               buf.WriteByte(s[0])
-               buf.WriteByte('.')
+               v := ip[i]
+               buf = append(buf, hexDigit[v&0xF])
+               buf = append(buf, '.')
+               buf = append(buf, hexDigit[v>>4])
+               buf = append(buf, '.')
        }
        // Append "ip6.arpa." and return (buf already has the final .)
-       return buf.String() + "ip6.arpa.", nil
+       buf = append(buf, "ip6.arpa."...)
+       return string(buf), nil
 }
 
 // Find answer for name in dns message.
index 97c5062..b6ebe11 100644 (file)
@@ -7,11 +7,10 @@
 // This is intended to support name resolution during Dial.
 // It doesn't have to be blazing fast.
 //
-// Rather than write the usual handful of routines to pack and
-// unpack every message that can appear on the wire, we use
-// reflection to write a generic pack/unpack for structs and then
-// use it.  Thus, if in the future we need to define new message
-// structs, no new pack/unpack/printing code needs to be written.
+// Each message structure has a Walk method that is used by
+// a generic pack/unpack routine. Thus, if in the future we need
+// to define new message structs, no new pack/unpack/printing code
+// needs to be written.
 //
 // The first half of this file defines the DNS message formats.
 // The second half implements the conversion to and from wire format.
 
 package net
 
-import (
-       "fmt"
-       "os"
-       "reflect"
-)
-
 // Packet formats
 
 // Wire constants.
@@ -75,6 +68,20 @@ const (
        dnsRcodeRefused        = 5
 )
 
+// A dnsStruct describes how to iterate over its fields to emulate
+// reflective marshalling.
+type dnsStruct interface {
+       // Walk iterates over fields of a structure and calls f
+       // with a reference to that field, the name of the field
+       // and a tag ("", "domain", "ipv4", "ipv6") specifying
+       // particular encodings. Possible concrete types
+       // for v are *uint16, *uint32, *string, or []byte, and
+       // *int, *bool in the case of dnsMsgHdr.
+       // Whenever f returns false, Walk must stop and return
+       // false, and otherwise return true.
+       Walk(f func(v interface{}, name, tag string) (ok bool)) (ok bool)
+}
+
 // The wire format for the DNS packet header.
 type dnsHeader struct {
        Id                                 uint16
@@ -82,6 +89,15 @@ type dnsHeader struct {
        Qdcount, Ancount, Nscount, Arcount uint16
 }
 
+func (h *dnsHeader) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return f(&h.Id, "Id", "") &&
+               f(&h.Bits, "Bits", "") &&
+               f(&h.Qdcount, "Qdcount", "") &&
+               f(&h.Ancount, "Ancount", "") &&
+               f(&h.Nscount, "Nscount", "") &&
+               f(&h.Arcount, "Arcount", "")
+}
+
 const (
        // dnsHeader.Bits
        _QR = 1 << 15 // query/response (response=1)
@@ -98,6 +114,12 @@ type dnsQuestion struct {
        Qclass uint16
 }
 
+func (q *dnsQuestion) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return f(&q.Name, "Name", "domain") &&
+               f(&q.Qtype, "Qtype", "") &&
+               f(&q.Qclass, "Qclass", "")
+}
+
 // DNS responses (resource records).
 // There are many types of messages,
 // but they all share the same header.
@@ -113,7 +135,16 @@ func (h *dnsRR_Header) Header() *dnsRR_Header {
        return h
 }
 
+func (h *dnsRR_Header) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return f(&h.Name, "Name", "domain") &&
+               f(&h.Rrtype, "Rrtype", "") &&
+               f(&h.Class, "Class", "") &&
+               f(&h.Ttl, "Ttl", "") &&
+               f(&h.Rdlength, "Rdlength", "")
+}
+
 type dnsRR interface {
+       dnsStruct
        Header() *dnsRR_Header
 }
 
@@ -128,6 +159,10 @@ func (rr *dnsRR_CNAME) Header() *dnsRR_Header {
        return &rr.Hdr
 }
 
+func (rr *dnsRR_CNAME) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return rr.Hdr.Walk(f) && f(&rr.Cname, "Cname", "domain")
+}
+
 type dnsRR_HINFO struct {
        Hdr dnsRR_Header
        Cpu string
@@ -138,6 +173,10 @@ func (rr *dnsRR_HINFO) Header() *dnsRR_Header {
        return &rr.Hdr
 }
 
+func (rr *dnsRR_HINFO) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return rr.Hdr.Walk(f) && f(&rr.Cpu, "Cpu", "") && f(&rr.Os, "Os", "")
+}
+
 type dnsRR_MB struct {
        Hdr dnsRR_Header
        Mb  string `net:"domain-name"`
@@ -147,6 +186,10 @@ func (rr *dnsRR_MB) Header() *dnsRR_Header {
        return &rr.Hdr
 }
 
+func (rr *dnsRR_MB) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return rr.Hdr.Walk(f) && f(&rr.Mb, "Mb", "domain")
+}
+
 type dnsRR_MG struct {
        Hdr dnsRR_Header
        Mg  string `net:"domain-name"`
@@ -156,6 +199,10 @@ func (rr *dnsRR_MG) Header() *dnsRR_Header {
        return &rr.Hdr
 }
 
+func (rr *dnsRR_MG) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return rr.Hdr.Walk(f) && f(&rr.Mg, "Mg", "domain")
+}
+
 type dnsRR_MINFO struct {
        Hdr   dnsRR_Header
        Rmail string `net:"domain-name"`
@@ -166,6 +213,10 @@ func (rr *dnsRR_MINFO) Header() *dnsRR_Header {
        return &rr.Hdr
 }
 
+func (rr *dnsRR_MINFO) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return rr.Hdr.Walk(f) && f(&rr.Rmail, "Rmail", "domain") && f(&rr.Email, "Email", "domain")
+}
+
 type dnsRR_MR struct {
        Hdr dnsRR_Header
        Mr  string `net:"domain-name"`
@@ -175,6 +226,10 @@ func (rr *dnsRR_MR) Header() *dnsRR_Header {
        return &rr.Hdr
 }
 
+func (rr *dnsRR_MR) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return rr.Hdr.Walk(f) && f(&rr.Mr, "Mr", "domain")
+}
+
 type dnsRR_MX struct {
        Hdr  dnsRR_Header
        Pref uint16
@@ -185,6 +240,10 @@ func (rr *dnsRR_MX) Header() *dnsRR_Header {
        return &rr.Hdr
 }
 
+func (rr *dnsRR_MX) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return rr.Hdr.Walk(f) && f(&rr.Pref, "Pref", "") && f(&rr.Mx, "Mx", "domain")
+}
+
 type dnsRR_NS struct {
        Hdr dnsRR_Header
        Ns  string `net:"domain-name"`
@@ -194,6 +253,10 @@ func (rr *dnsRR_NS) Header() *dnsRR_Header {
        return &rr.Hdr
 }
 
+func (rr *dnsRR_NS) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return rr.Hdr.Walk(f) && f(&rr.Ns, "Ns", "domain")
+}
+
 type dnsRR_PTR struct {
        Hdr dnsRR_Header
        Ptr string `net:"domain-name"`
@@ -203,6 +266,10 @@ func (rr *dnsRR_PTR) Header() *dnsRR_Header {
        return &rr.Hdr
 }
 
+func (rr *dnsRR_PTR) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return rr.Hdr.Walk(f) && f(&rr.Ptr, "Ptr", "domain")
+}
+
 type dnsRR_SOA struct {
        Hdr     dnsRR_Header
        Ns      string `net:"domain-name"`
@@ -218,6 +285,17 @@ func (rr *dnsRR_SOA) Header() *dnsRR_Header {
        return &rr.Hdr
 }
 
+func (rr *dnsRR_SOA) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return rr.Hdr.Walk(f) &&
+               f(&rr.Ns, "Ns", "domain") &&
+               f(&rr.Mbox, "Mbox", "domain") &&
+               f(&rr.Serial, "Serial", "") &&
+               f(&rr.Refresh, "Refresh", "") &&
+               f(&rr.Retry, "Retry", "") &&
+               f(&rr.Expire, "Expire", "") &&
+               f(&rr.Minttl, "Minttl", "")
+}
+
 type dnsRR_TXT struct {
        Hdr dnsRR_Header
        Txt string // not domain name
@@ -227,6 +305,10 @@ func (rr *dnsRR_TXT) Header() *dnsRR_Header {
        return &rr.Hdr
 }
 
+func (rr *dnsRR_TXT) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return rr.Hdr.Walk(f) && f(&rr.Txt, "Txt", "")
+}
+
 type dnsRR_SRV struct {
        Hdr      dnsRR_Header
        Priority uint16
@@ -239,6 +321,14 @@ func (rr *dnsRR_SRV) Header() *dnsRR_Header {
        return &rr.Hdr
 }
 
+func (rr *dnsRR_SRV) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return rr.Hdr.Walk(f) &&
+               f(&rr.Priority, "Priority", "") &&
+               f(&rr.Weight, "Weight", "") &&
+               f(&rr.Port, "Port", "") &&
+               f(&rr.Target, "Target", "domain")
+}
+
 type dnsRR_A struct {
        Hdr dnsRR_Header
        A   uint32 `net:"ipv4"`
@@ -248,6 +338,10 @@ func (rr *dnsRR_A) Header() *dnsRR_Header {
        return &rr.Hdr
 }
 
+func (rr *dnsRR_A) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return rr.Hdr.Walk(f) && f(&rr.A, "A", "ipv4")
+}
+
 type dnsRR_AAAA struct {
        Hdr  dnsRR_Header
        AAAA [16]byte `net:"ipv6"`
@@ -257,6 +351,10 @@ func (rr *dnsRR_AAAA) Header() *dnsRR_Header {
        return &rr.Hdr
 }
 
+func (rr *dnsRR_AAAA) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return rr.Hdr.Walk(f) && f(rr.AAAA[:], "AAAA", "ipv6")
+}
+
 // Packing and unpacking.
 //
 // All the packers and unpackers take a (msg []byte, off int)
@@ -386,134 +484,107 @@ Loop:
        return s, off1, true
 }
 
-// TODO(rsc): Move into generic library?
-// Pack a reflect.StructValue into msg.  Struct members can only be uint16, uint32, string,
-// [n]byte, and other (often anonymous) structs.
-func packStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) {
-       for i := 0; i < val.NumField(); i++ {
-               f := val.Type().Field(i)
-               switch fv := val.Field(i); fv.Kind() {
+// packStruct packs a structure into msg at specified offset off, and
+// returns off1 such that msg[off:off1] is the encoded data.
+func packStruct(any dnsStruct, msg []byte, off int) (off1 int, ok bool) {
+       ok = any.Walk(func(field interface{}, name, tag string) bool {
+               switch fv := field.(type) {
                default:
-                       fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type)
-                       return len(msg), false
-               case reflect.Struct:
-                       off, ok = packStructValue(fv, msg, off)
-               case reflect.Uint16:
+                       println("net: dns: unknown packing type")
+                       return false
+               case *uint16:
+                       i := *fv
                        if off+2 > len(msg) {
-                               return len(msg), false
+                               return false
                        }
-                       i := fv.Uint()
                        msg[off] = byte(i >> 8)
                        msg[off+1] = byte(i)
                        off += 2
-               case reflect.Uint32:
-                       if off+4 > len(msg) {
-                               return len(msg), false
-                       }
-                       i := fv.Uint()
+               case *uint32:
+                       i := *fv
                        msg[off] = byte(i >> 24)
                        msg[off+1] = byte(i >> 16)
                        msg[off+2] = byte(i >> 8)
                        msg[off+3] = byte(i)
                        off += 4
-               case reflect.Array:
-                       if fv.Type().Elem().Kind() != reflect.Uint8 {
-                               fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type)
-                               return len(msg), false
-                       }
-                       n := fv.Len()
+               case []byte:
+                       n := len(fv)
                        if off+n > len(msg) {
-                               return len(msg), false
+                               return false
                        }
-                       reflect.Copy(reflect.ValueOf(msg[off:off+n]), fv)
+                       copy(msg[off:off+n], fv)
                        off += n
-               case reflect.String:
-                       // There are multiple string encodings.
-                       // The tag distinguishes ordinary strings from domain names.
-                       s := fv.String()
-                       switch f.Tag {
+               case *string:
+                       s := *fv
+                       switch tag {
                        default:
-                               fmt.Fprintf(os.Stderr, "net: dns: unknown string tag %v", f.Tag)
-                               return len(msg), false
-                       case `net:"domain-name"`:
+                               println("net: dns: unknown string tag", tag)
+                               return false
+                       case "domain":
                                off, ok = packDomainName(s, msg, off)
                                if !ok {
-                                       return len(msg), false
+                                       return false
                                }
                        case "":
                                // Counted string: 1 byte length.
                                if len(s) > 255 || off+1+len(s) > len(msg) {
-                                       return len(msg), false
+                                       return false
                                }
                                msg[off] = byte(len(s))
                                off++
                                off += copy(msg[off:], s)
                        }
                }
+               return true
+       })
+       if !ok {
+               return len(msg), false
        }
        return off, true
 }
 
-func structValue(any interface{}) reflect.Value {
-       return reflect.ValueOf(any).Elem()
-}
-
-func packStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) {
-       off, ok = packStructValue(structValue(any), msg, off)
-       return off, ok
-}
-
-// TODO(rsc): Move into generic library?
-// Unpack a reflect.StructValue from msg.
-// Same restrictions as packStructValue.
-func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) {
-       for i := 0; i < val.NumField(); i++ {
-               f := val.Type().Field(i)
-               switch fv := val.Field(i); fv.Kind() {
+// unpackStruct decodes msg[off:] into the given structure, and
+// returns off1 such that msg[off:off1] is the encoded data.
+func unpackStruct(any dnsStruct, msg []byte, off int) (off1 int, ok bool) {
+       ok = any.Walk(func(field interface{}, name, tag string) bool {
+               switch fv := field.(type) {
                default:
-                       fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type)
-                       return len(msg), false
-               case reflect.Struct:
-                       off, ok = unpackStructValue(fv, msg, off)
-               case reflect.Uint16:
+                       println("net: dns: unknown packing type")
+                       return false
+               case *uint16:
                        if off+2 > len(msg) {
-                               return len(msg), false
+                               return false
                        }
-                       i := uint16(msg[off])<<8 | uint16(msg[off+1])
-                       fv.SetUint(uint64(i))
+                       *fv = uint16(msg[off])<<8 | uint16(msg[off+1])
                        off += 2
-               case reflect.Uint32:
+               case *uint32:
                        if off+4 > len(msg) {
-                               return len(msg), false
+                               return false
                        }
-                       i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3])
-                       fv.SetUint(uint64(i))
+                       *fv = uint32(msg[off])<<24 | uint32(msg[off+1])<<16 |
+                               uint32(msg[off+2])<<8 | uint32(msg[off+3])
                        off += 4
-               case reflect.Array:
-                       if fv.Type().Elem().Kind() != reflect.Uint8 {
-                               fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type)
-                               return len(msg), false
-                       }
-                       n := fv.Len()
+               case []byte:
+                       n := len(fv)
                        if off+n > len(msg) {
-                               return len(msg), false
+                               return false
                        }
-                       reflect.Copy(fv, reflect.ValueOf(msg[off:off+n]))
+                       copy(fv, msg[off:off+n])
                        off += n
-               case reflect.String:
+               case *string:
                        var s string
-                       switch f.Tag {
+                       switch tag {
                        default:
-                               fmt.Fprintf(os.Stderr, "net: dns: unknown string tag %v", f.Tag)
-                               return len(msg), false
-                       case `net:"domain-name"`:
+                               println("net: dns: unknown string tag", tag)
+                               return false
+                       case "domain":
                                s, off, ok = unpackDomainName(msg, off)
                                if !ok {
-                                       return len(msg), false
+                                       return false
                                }
                        case "":
                                if off >= len(msg) || off+1+int(msg[off]) > len(msg) {
-                                       return len(msg), false
+                                       return false
                                }
                                n := int(msg[off])
                                off++
@@ -524,51 +595,77 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
                                off += n
                                s = string(b)
                        }
-                       fv.SetString(s)
+                       *fv = s
                }
+               return true
+       })
+       if !ok {
+               return len(msg), false
        }
        return off, true
 }
 
-func unpackStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) {
-       off, ok = unpackStructValue(structValue(any), msg, off)
-       return off, ok
-}
-
-// Generic struct printer.
-// Doesn't care about the string tag `net:"domain-name"`,
-// but does look for an `net:"ipv4"` tag on uint32 variables
-// and the `net:"ipv6"` tag on array variables,
-// printing them as IP addresses.
-func printStructValue(val reflect.Value) string {
+// Generic struct printer. Prints fields with tag "ipv4" or "ipv6"
+// as IP addresses.
+func printStruct(any dnsStruct) string {
        s := "{"
-       for i := 0; i < val.NumField(); i++ {
-               if i > 0 {
+       i := 0
+       any.Walk(func(val interface{}, name, tag string) bool {
+               i++
+               if i > 1 {
                        s += ", "
                }
-               f := val.Type().Field(i)
-               if !f.Anonymous {
-                       s += f.Name + "="
-               }
-               fval := val.Field(i)
-               if fv := fval; fv.Kind() == reflect.Struct {
-                       s += printStructValue(fv)
-               } else if fv := fval; (fv.Kind() == reflect.Uint || fv.Kind() == reflect.Uint8 || fv.Kind() == reflect.Uint16 || fv.Kind() == reflect.Uint32 || fv.Kind() == reflect.Uint64 || fv.Kind() == reflect.Uintptr) && f.Tag == `net:"ipv4"` {
-                       i := fv.Uint()
+               s += name + "="
+               switch tag {
+               case "ipv4":
+                       i := val.(uint32)
                        s += IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i)).String()
-               } else if fv := fval; fv.Kind() == reflect.Array && f.Tag == `net:"ipv6"` {
-                       i := fv.Interface().([]byte)
+               case "ipv6":
+                       i := val.([]byte)
                        s += IP(i).String()
-               } else {
-                       s += fmt.Sprint(fval.Interface())
+               default:
+                       var i int64
+                       switch v := val.(type) {
+                       default:
+                               // can't really happen.
+                               s += "<unknown type>"
+                               return true
+                       case *string:
+                               s += *v
+                               return true
+                       case []byte:
+                               s += string(v)
+                               return true
+                       case *bool:
+                               if *v {
+                                       s += "true"
+                               } else {
+                                       s += "false"
+                               }
+                               return true
+                       case *int:
+                               i = int64(*v)
+                       case *uint:
+                               i = int64(*v)
+                       case *uint8:
+                               i = int64(*v)
+                       case *uint16:
+                               i = int64(*v)
+                       case *uint32:
+                               i = int64(*v)
+                       case *uint64:
+                               i = int64(*v)
+                       case *uintptr:
+                               i = int64(*v)
+                       }
+                       s += itoa(int(i))
                }
-       }
+               return true
+       })
        s += "}"
        return s
 }
 
-func printStruct(any interface{}) string { return printStructValue(structValue(any)) }
-
 // Resource record packer.
 func packRR(rr dnsRR, msg []byte, off int) (off2 int, ok bool) {
        var off1 int
@@ -627,6 +724,17 @@ type dnsMsgHdr struct {
        rcode               int
 }
 
+func