OSDN Git Service

libgo: Update to Go 1.0.3.
authorian <ian@138bc75d-0d04-0410-961f-82ee72b054a4>
Wed, 3 Oct 2012 05:27:22 +0000 (05:27 +0000)
committerian <ian@138bc75d-0d04-0410-961f-82ee72b054a4>
Wed, 3 Oct 2012 05:27:22 +0000 (05:27 +0000)
git-svn-id: svn+ssh://gcc.gnu.org/svn/gcc/branches/gcc-4_7-branch@192024 138bc75d-0d04-0410-961f-82ee72b054a4

124 files changed:
libgo/MERGE
libgo/Makefile.am
libgo/Makefile.in
libgo/go/bufio/bufio.go
libgo/go/builtin/builtin.go
libgo/go/bytes/bytes.go
libgo/go/compress/flate/inflate.go
libgo/go/compress/flate/reader_test.go [new file with mode: 0644]
libgo/go/crypto/elliptic/elliptic.go
libgo/go/crypto/rand/rand_test.go
libgo/go/crypto/rand/rand_windows.go
libgo/go/crypto/rsa/pkcs1v15.go
libgo/go/crypto/tls/conn.go
libgo/go/crypto/x509/verify.go
libgo/go/crypto/x509/x509.go
libgo/go/database/sql/fakedb_test.go
libgo/go/database/sql/sql.go
libgo/go/encoding/binary/varint.go
libgo/go/encoding/gob/decode.go
libgo/go/encoding/gob/doc.go
libgo/go/encoding/gob/encoder_test.go
libgo/go/encoding/gob/type.go
libgo/go/encoding/json/encode.go
libgo/go/flag/flag.go
libgo/go/fmt/fmt_test.go
libgo/go/fmt/print.go
libgo/go/go/ast/print.go
libgo/go/go/ast/print_test.go
libgo/go/go/ast/resolve.go
libgo/go/go/ast/walk.go
libgo/go/go/build/build.go
libgo/go/go/build/build_test.go
libgo/go/go/build/doc.go
libgo/go/go/doc/reader.go
libgo/go/go/doc/testdata/error2.1.golden
libgo/go/go/doc/testdata/error2.go
libgo/go/go/printer/nodes.go
libgo/go/go/printer/printer_test.go
libgo/go/go/scanner/errors.go
libgo/go/go/scanner/scanner.go
libgo/go/html/template/content.go
libgo/go/html/template/url.go
libgo/go/image/jpeg/reader.go
libgo/go/image/jpeg/writer.go
libgo/go/image/jpeg/writer_test.go
libgo/go/image/names.go
libgo/go/io/io.go
libgo/go/log/syslog/syslog.go
libgo/go/log/syslog/syslog_test.go
libgo/go/math/all_test.go
libgo/go/math/big/nat.go
libgo/go/math/bits.go
libgo/go/math/remainder.go
libgo/go/mime/grammar.go
libgo/go/mime/multipart/multipart.go
libgo/go/net/dial.go
libgo/go/net/fd.go
libgo/go/net/file.go
libgo/go/net/http/client.go
libgo/go/net/http/client_test.go
libgo/go/net/http/example_test.go
libgo/go/net/http/export_test.go
libgo/go/net/http/fs.go
libgo/go/net/http/fs_test.go
libgo/go/net/http/header.go
libgo/go/net/http/httptest/server.go
libgo/go/net/http/httputil/dump.go
libgo/go/net/http/pprof/pprof.go
libgo/go/net/http/range_test.go
libgo/go/net/http/serve_test.go
libgo/go/net/http/server.go
libgo/go/net/http/transport.go
libgo/go/net/http/transport_test.go
libgo/go/net/iprawsock.go
libgo/go/net/iprawsock_plan9.go
libgo/go/net/iprawsock_posix.go
libgo/go/net/mail/message.go
libgo/go/net/net_posix.go [new file with mode: 0644]
libgo/go/net/rpc/jsonrpc/all_test.go
libgo/go/net/rpc/server.go
libgo/go/net/sockopt.go
libgo/go/os/error_plan9.go
libgo/go/os/error_posix.go
libgo/go/os/error_test.go
libgo/go/os/error_windows.go
libgo/go/os/exec.go
libgo/go/os/exec/exec.go
libgo/go/os/exec/exec_test.go
libgo/go/os/exec_plan9.go
libgo/go/os/exec_posix.go
libgo/go/os/exec_unix.go
libgo/go/os/exec_windows.go
libgo/go/os/file_posix.go
libgo/go/os/file_unix.go
libgo/go/os/os_test.go
libgo/go/os/types.go
libgo/go/path/path.go
libgo/go/path/path_test.go
libgo/go/reflect/all_test.go
libgo/go/reflect/value.go
libgo/go/regexp/regexp.go
libgo/go/runtime/pprof/pprof.go
libgo/go/runtime/pprof/pprof_test.go
libgo/go/strconv/atoi.go
libgo/go/sync/waitgroup.go
libgo/go/sync/waitgroup_test.go
libgo/go/syscall/env_windows.go
libgo/go/syscall/exec_unix.go
libgo/go/syscall/exec_windows.go
libgo/go/syscall/security_windows.go
libgo/go/syscall/syscall.go
libgo/go/syscall/syscall_linux_386.go
libgo/go/testing/testing.go
libgo/go/text/tabwriter/tabwriter.go
libgo/go/text/template/doc.go
libgo/go/text/template/exec_test.go
libgo/go/text/template/funcs.go
libgo/go/text/template/parse/lex.go
libgo/go/text/template/parse/lex_test.go
libgo/go/time/time.go
libgo/runtime/chan.c
libgo/runtime/cpuprof.c
libgo/runtime/print.c
libgo/runtime/runtime.c

index e3e47d3..89116d1 100644 (file)
@@ -1,4 +1,4 @@
-5e806355a9e1
+2d8bc3c94ecb
 
 The first line of this file holds the Mercurial revision number of the
 last merge done from the master library sources.
index f59b004..82587ca 100644 (file)
@@ -715,6 +715,7 @@ go_net_files = \
        go/net/lookup_unix.go \
        go/net/mac.go \
        go/net/net.go \
+       go/net/net_posix.go \
        go/net/parse.go \
        go/net/pipe.go \
        go/net/port.go \
index 0931ed1..30f9274 100644 (file)
@@ -981,6 +981,7 @@ go_net_files = \
        go/net/lookup_unix.go \
        go/net/mac.go \
        go/net/net.go \
+       go/net/net_posix.go \
        go/net/parse.go \
        go/net/pipe.go \
        go/net/port.go \
index b44d0e7..0e28482 100644 (file)
@@ -272,6 +272,9 @@ func (b *Reader) ReadSlice(delim byte) (line []byte, err error) {
        panic("not reached")
 }
 
+// ReadLine is a low-level line-reading primitive. Most callers should use
+// ReadBytes('\n') or ReadString('\n') instead.
+//
 // ReadLine tries to return a single line, not including the end-of-line bytes.
 // If the line was too long for the buffer then isPrefix is set and the
 // beginning of the line is returned. The rest of the line will be returned
index e81616c..a30943b 100644 (file)
@@ -81,9 +81,8 @@ type uintptr uintptr
 // integer values.
 type byte byte
 
-// rune is an alias for int and is equivalent to int in all ways. It is
+// rune is an alias for int32 and is equivalent to int32 in all ways. It is
 // used, by convention, to distinguish character values from integer values.
-// In a future version of Go, it will change to an alias of int32.
 type rune rune
 
 // Type is here for the purposes of documentation only. It is a stand-in
index 7d1426f..09b3c1a 100644 (file)
@@ -415,7 +415,7 @@ func Repeat(b []byte, count int) []byte {
 // ToUpper returns a copy of the byte array s with all Unicode letters mapped to their upper case.
 func ToUpper(s []byte) []byte { return Map(unicode.ToUpper, s) }
 
-// ToUpper returns a copy of the byte array s with all Unicode letters mapped to their lower case.
+// ToLower returns a copy of the byte array s with all Unicode letters mapped to their lower case.
 func ToLower(s []byte) []byte { return Map(unicode.ToLower, s) }
 
 // ToTitle returns a copy of the byte array s with all Unicode letters mapped to their title case.
index 3f2042b..394c32f 100644 (file)
@@ -16,9 +16,10 @@ import (
 const (
        maxCodeLen = 16    // max length of Huffman code
        maxHist    = 32768 // max history required
-       maxLit     = 286
-       maxDist    = 32
-       numCodes   = 19 // number of codes in Huffman meta-code
+       // The next three numbers come from the RFC, section 3.2.7.
+       maxLit   = 286
+       maxDist  = 32
+       numCodes = 19 // number of codes in Huffman meta-code
 )
 
 // A CorruptInputError reports the presence of corrupt input at a given offset.
@@ -306,10 +307,15 @@ func (f *decompressor) readHuffman() error {
                }
        }
        nlit := int(f.b&0x1F) + 257
+       if nlit > maxLit {
+               return CorruptInputError(f.roffset)
+       }
        f.b >>= 5
        ndist := int(f.b&0x1F) + 1
+       // maxDist is 32, so ndist is always valid.
        f.b >>= 5
        nclen := int(f.b&0xF) + 4
+       // numCodes is 19, so nclen is always valid.
        f.b >>= 4
        f.nb -= 5 + 5 + 4
 
diff --git a/libgo/go/compress/flate/reader_test.go b/libgo/go/compress/flate/reader_test.go
new file mode 100644 (file)
index 0000000..54ed788
--- /dev/null
@@ -0,0 +1,95 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package flate
+
+import (
+       "bytes"
+       "io"
+       "io/ioutil"
+       "runtime"
+       "strings"
+       "testing"
+)
+
+func TestNlitOutOfRange(t *testing.T) {
+       // Trying to decode this bogus flate data, which has a Huffman table
+       // with nlit=288, should not panic.
+       io.Copy(ioutil.Discard, NewReader(strings.NewReader(
+               "\xfc\xfe\x36\xe7\x5e\x1c\xef\xb3\x55\x58\x77\xb6\x56\xb5\x43\xf4"+
+                       "\x6f\xf2\xd2\xe6\x3d\x99\xa0\x85\x8c\x48\xeb\xf8\xda\x83\x04\x2a"+
+                       "\x75\xc4\xf8\x0f\x12\x11\xb9\xb4\x4b\x09\xa0\xbe\x8b\x91\x4c")))
+}
+
+const (
+       digits = iota
+       twain
+)
+
+var testfiles = []string{
+       // Digits is the digits of the irrational number e. Its decimal representation
+       // does not repeat, but there are only 10 posible digits, so it should be
+       // reasonably compressible.
+       digits: "../testdata/e.txt",
+       // Twain is Project Gutenberg's edition of Mark Twain's classic English novel.
+       twain: "../testdata/Mark.Twain-Tom.Sawyer.txt",
+}
+
+func benchmarkDecode(b *testing.B, testfile, level, n int) {
+       b.StopTimer()
+       b.SetBytes(int64(n))
+       buf0, err := ioutil.ReadFile(testfiles[testfile])
+       if err != nil {
+               b.Fatal(err)
+       }
+       if len(buf0) == 0 {
+               b.Fatalf("test file %q has no data", testfiles[testfile])
+       }
+       compressed := new(bytes.Buffer)
+       w, err := NewWriter(compressed, level)
+       if err != nil {
+               b.Fatal(err)
+       }
+       for i := 0; i < n; i += len(buf0) {
+               if len(buf0) > n-i {
+                       buf0 = buf0[:n-i]
+               }
+               io.Copy(w, bytes.NewBuffer(buf0))
+       }
+       w.Close()
+       buf1 := compressed.Bytes()
+       buf0, compressed, w = nil, nil, nil
+       runtime.GC()
+       b.StartTimer()
+       for i := 0; i < b.N; i++ {
+               io.Copy(ioutil.Discard, NewReader(bytes.NewBuffer(buf1)))
+       }
+}
+
+// These short names are so that gofmt doesn't break the BenchmarkXxx function
+// bodies below over multiple lines.
+const (
+       speed    = BestSpeed
+       default_ = DefaultCompression
+       compress = BestCompression
+)
+
+func BenchmarkDecodeDigitsSpeed1e4(b *testing.B)    { benchmarkDecode(b, digits, speed, 1e4) }
+func BenchmarkDecodeDigitsSpeed1e5(b *testing.B)    { benchmarkDecode(b, digits, speed, 1e5) }
+func BenchmarkDecodeDigitsSpeed1e6(b *testing.B)    { benchmarkDecode(b, digits, speed, 1e6) }
+func BenchmarkDecodeDigitsDefault1e4(b *testing.B)  { benchmarkDecode(b, digits, default_, 1e4) }
+func BenchmarkDecodeDigitsDefault1e5(b *testing.B)  { benchmarkDecode(b, digits, default_, 1e5) }
+func BenchmarkDecodeDigitsDefault1e6(b *testing.B)  { benchmarkDecode(b, digits, default_, 1e6) }
+func BenchmarkDecodeDigitsCompress1e4(b *testing.B) { benchmarkDecode(b, digits, compress, 1e4) }
+func BenchmarkDecodeDigitsCompress1e5(b *testing.B) { benchmarkDecode(b, digits, compress, 1e5) }
+func BenchmarkDecodeDigitsCompress1e6(b *testing.B) { benchmarkDecode(b, digits, compress, 1e6) }
+func BenchmarkDecodeTwainSpeed1e4(b *testing.B)     { benchmarkDecode(b, twain, speed, 1e4) }
+func BenchmarkDecodeTwainSpeed1e5(b *testing.B)     { benchmarkDecode(b, twain, speed, 1e5) }
+func BenchmarkDecodeTwainSpeed1e6(b *testing.B)     { benchmarkDecode(b, twain, speed, 1e6) }
+func BenchmarkDecodeTwainDefault1e4(b *testing.B)   { benchmarkDecode(b, twain, default_, 1e4) }
+func BenchmarkDecodeTwainDefault1e5(b *testing.B)   { benchmarkDecode(b, twain, default_, 1e5) }
+func BenchmarkDecodeTwainDefault1e6(b *testing.B)   { benchmarkDecode(b, twain, default_, 1e6) }
+func BenchmarkDecodeTwainCompress1e4(b *testing.B)  { benchmarkDecode(b, twain, compress, 1e4) }
+func BenchmarkDecodeTwainCompress1e5(b *testing.B)  { benchmarkDecode(b, twain, compress, 1e5) }
+func BenchmarkDecodeTwainCompress1e6(b *testing.B)  { benchmarkDecode(b, twain, compress, 1e6) }
index 30835a9..a399089 100644 (file)
@@ -370,7 +370,7 @@ func P384() Curve {
        return p384
 }
 
-// P256 returns a Curve which implements P-521 (see FIPS 186-3, section D.2.5)
+// P521 returns a Curve which implements P-521 (see FIPS 186-3, section D.2.5)
 func P521() Curve {
        initonce.Do(initAll)
        return p521
index da091ba..e46e61d 100644 (file)
@@ -30,3 +30,14 @@ func TestRead(t *testing.T) {
                t.Fatalf("Compressed %d -> %d", len(b), z.Len())
        }
 }
+
+func TestReadEmpty(t *testing.T) {
+       n, err := Reader.Read(make([]byte, 0))
+       if n != 0 || err != nil {
+               t.Fatalf("Read(make([]byte, 0)) = %d, %v", n, err)
+       }
+       n, err = Reader.Read(nil)
+       if n != 0 || err != nil {
+               t.Fatalf("Read(nil) = %d, %v", n, err)
+       }
+}
index 2b2bd4b..82b39b6 100644 (file)
@@ -35,6 +35,10 @@ func (r *rngReader) Read(b []byte) (n int, err error) {
                }
        }
        r.mu.Unlock()
+
+       if len(b) == 0 {
+               return 0, nil
+       }
        err = syscall.CryptGenRandom(r.prov, uint32(len(b)), &b[0])
        if err != nil {
                return 0, os.NewSyscallError("CryptGenRandom", err)
index a32236e..f39a48a 100644 (file)
@@ -25,10 +25,10 @@ func EncryptPKCS1v15(rand io.Reader, pub *PublicKey, msg []byte) (out []byte, er
                return
        }
 
-       // EM = 0x02 || PS || 0x00 || M
-       em := make([]byte, k-1)
-       em[0] = 2
-       ps, mm := em[1:len(em)-len(msg)-1], em[len(em)-len(msg):]
+       // EM = 0x00 || 0x02 || PS || 0x00 || M
+       em := make([]byte, k)
+       em[1] = 2
+       ps, mm := em[2:len(em)-len(msg)-1], em[len(em)-len(msg):]
        err = nonZeroRandomBytes(ps, rand)
        if err != nil {
                return
@@ -38,7 +38,9 @@ func EncryptPKCS1v15(rand io.Reader, pub *PublicKey, msg []byte) (out []byte, er
 
        m := new(big.Int).SetBytes(em)
        c := encrypt(new(big.Int), pub, m)
-       out = c.Bytes()
+
+       copyWithLeftPad(em, c.Bytes())
+       out = em
        return
 }
 
@@ -185,9 +187,12 @@ func SignPKCS1v15(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed []b
 
        m := new(big.Int).SetBytes(em)
        c, err := decrypt(rand, priv, m)
-       if err == nil {
-               s = c.Bytes()
+       if err != nil {
+               return
        }
+
+       copyWithLeftPad(em, c.Bytes())
+       s = em
        return
 }
 
@@ -241,3 +246,13 @@ func pkcs1v15HashInfo(hash crypto.Hash, inLen int) (hashLen int, prefix []byte,
        }
        return
 }
+
+// copyWithLeftPad copies src to the end of dest, padding with zero bytes as
+// needed.
+func copyWithLeftPad(dest, src []byte) {
+       numPaddingBytes := len(dest) - len(src)
+       for i := 0; i < numPaddingBytes; i++ {
+               dest[i] = 0
+       }
+       copy(dest[numPaddingBytes:], src)
+}
index 2a5115d..455910a 100644 (file)
@@ -487,6 +487,16 @@ Again:
                return err
        }
        typ := recordType(b.data[0])
+
+       // No valid TLS record has a type of 0x80, however SSLv2 handshakes
+       // start with a uint16 length where the MSB is set and the first record
+       // is always < 256 bytes long. Therefore typ == 0x80 strongly suggests
+       // an SSLv2 client.
+       if want == recordTypeHandshake && typ == 0x80 {
+               c.sendAlert(alertProtocolVersion)
+               return errors.New("tls: unsupported SSLv2 handshake received")
+       }
+
        vers := uint16(b.data[1])<<8 | uint16(b.data[2])
        n := int(b.data[3])<<8 | int(b.data[4])
        if c.haveVers && vers != c.vers {
index 307c5ef..2881453 100644 (file)
@@ -39,7 +39,7 @@ type CertificateInvalidError struct {
 func (e CertificateInvalidError) Error() string {
        switch e.Reason {
        case NotAuthorizedToSign:
-               return "x509: certificate is not authorized to sign other other certificates"
+               return "x509: certificate is not authorized to sign other certificates"
        case Expired:
                return "x509: certificate has expired or is not yet valid"
        case CANotAuthorizedForThisName:
index c4d85e6..e6b0c58 100644 (file)
@@ -344,6 +344,55 @@ func (c *Certificate) Equal(other *Certificate) bool {
        return bytes.Equal(c.Raw, other.Raw)
 }
 
+// Entrust have a broken root certificate (CN=Entrust.net Certification
+// Authority (2048)) which isn't marked as a CA certificate and is thus invalid
+// according to PKIX.
+// We recognise this certificate by its SubjectPublicKeyInfo and exempt it
+// from the Basic Constraints requirement.
+// See http://www.entrust.net/knowledge-base/technote.cfm?tn=7869
+//
+// TODO(agl): remove this hack once their reissued root is sufficiently
+// widespread.
+var entrustBrokenSPKI = []byte{
+       0x30, 0x82, 0x01, 0x22, 0x30, 0x0d, 0x06, 0x09,
+       0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01,
+       0x01, 0x05, 0x00, 0x03, 0x82, 0x01, 0x0f, 0x00,
+       0x30, 0x82, 0x01, 0x0a, 0x02, 0x82, 0x01, 0x01,
+       0x00, 0x97, 0xa3, 0x2d, 0x3c, 0x9e, 0xde, 0x05,
+       0xda, 0x13, 0xc2, 0x11, 0x8d, 0x9d, 0x8e, 0xe3,
+       0x7f, 0xc7, 0x4b, 0x7e, 0x5a, 0x9f, 0xb3, 0xff,
+       0x62, 0xab, 0x73, 0xc8, 0x28, 0x6b, 0xba, 0x10,
+       0x64, 0x82, 0x87, 0x13, 0xcd, 0x57, 0x18, 0xff,
+       0x28, 0xce, 0xc0, 0xe6, 0x0e, 0x06, 0x91, 0x50,
+       0x29, 0x83, 0xd1, 0xf2, 0xc3, 0x2a, 0xdb, 0xd8,
+       0xdb, 0x4e, 0x04, 0xcc, 0x00, 0xeb, 0x8b, 0xb6,
+       0x96, 0xdc, 0xbc, 0xaa, 0xfa, 0x52, 0x77, 0x04,
+       0xc1, 0xdb, 0x19, 0xe4, 0xae, 0x9c, 0xfd, 0x3c,
+       0x8b, 0x03, 0xef, 0x4d, 0xbc, 0x1a, 0x03, 0x65,
+       0xf9, 0xc1, 0xb1, 0x3f, 0x72, 0x86, 0xf2, 0x38,
+       0xaa, 0x19, 0xae, 0x10, 0x88, 0x78, 0x28, 0xda,
+       0x75, 0xc3, 0x3d, 0x02, 0x82, 0x02, 0x9c, 0xb9,
+       0xc1, 0x65, 0x77, 0x76, 0x24, 0x4c, 0x98, 0xf7,
+       0x6d, 0x31, 0x38, 0xfb, 0xdb, 0xfe, 0xdb, 0x37,
+       0x02, 0x76, 0xa1, 0x18, 0x97, 0xa6, 0xcc, 0xde,
+       0x20, 0x09, 0x49, 0x36, 0x24, 0x69, 0x42, 0xf6,
+       0xe4, 0x37, 0x62, 0xf1, 0x59, 0x6d, 0xa9, 0x3c,
+       0xed, 0x34, 0x9c, 0xa3, 0x8e, 0xdb, 0xdc, 0x3a,
+       0xd7, 0xf7, 0x0a, 0x6f, 0xef, 0x2e, 0xd8, 0xd5,
+       0x93, 0x5a, 0x7a, 0xed, 0x08, 0x49, 0x68, 0xe2,
+       0x41, 0xe3, 0x5a, 0x90, 0xc1, 0x86, 0x55, 0xfc,
+       0x51, 0x43, 0x9d, 0xe0, 0xb2, 0xc4, 0x67, 0xb4,
+       0xcb, 0x32, 0x31, 0x25, 0xf0, 0x54, 0x9f, 0x4b,
+       0xd1, 0x6f, 0xdb, 0xd4, 0xdd, 0xfc, 0xaf, 0x5e,
+       0x6c, 0x78, 0x90, 0x95, 0xde, 0xca, 0x3a, 0x48,
+       0xb9, 0x79, 0x3c, 0x9b, 0x19, 0xd6, 0x75, 0x05,
+       0xa0, 0xf9, 0x88, 0xd7, 0xc1, 0xe8, 0xa5, 0x09,
+       0xe4, 0x1a, 0x15, 0xdc, 0x87, 0x23, 0xaa, 0xb2,
+       0x75, 0x8c, 0x63, 0x25, 0x87, 0xd8, 0xf8, 0x3d,
+       0xa6, 0xc2, 0xcc, 0x66, 0xff, 0xa5, 0x66, 0x68,
+       0x55, 0x02, 0x03, 0x01, 0x00, 0x01,
+}
+
 // CheckSignatureFrom verifies that the signature on c is a valid signature
 // from parent.
 func (c *Certificate) CheckSignatureFrom(parent *Certificate) (err error) {
@@ -352,8 +401,10 @@ func (c *Certificate) CheckSignatureFrom(parent *Certificate) (err error) {
        // certificate, or the extension is present but the cA boolean is not
        // asserted, then the certified public key MUST NOT be used to verify
        // certificate signatures."
-       if parent.Version == 3 && !parent.BasicConstraintsValid ||
-               parent.BasicConstraintsValid && !parent.IsCA {
+       // (except for Entrust, see comment above entrustBrokenSPKI)
+       if (parent.Version == 3 && !parent.BasicConstraintsValid ||
+               parent.BasicConstraintsValid && !parent.IsCA) &&
+               !bytes.Equal(c.RawSubjectPublicKeyInfo, entrustBrokenSPKI) {
                return ConstraintViolationError{}
        }
 
index 184e775..a11fb78 100644 (file)
@@ -31,7 +31,7 @@ var _ = log.Printf
 //   INSERT|<tablename>|col=val,col2=val2,col3=?
 //   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
 //
-// When opening a fakeDriver's database, it starts empty with no
+// When opening a fakeDriver's database, it starts empty with no
 // tables.  All tables and data are stored in memory only.
 type fakeDriver struct {
        mu        sync.Mutex
@@ -234,7 +234,7 @@ func checkSubsetTypes(args []driver.Value) error {
 
 func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
        // This is an optional interface, but it's implemented here
-       // just to check that all the args of of the proper types.
+       // just to check that all the args are of the proper types.
        // ErrSkip is returned so the caller acts as if we didn't
        // implement this at all.
        err := checkSubsetTypes(args)
@@ -249,7 +249,7 @@ func errf(msg string, args ...interface{}) error {
 }
 
 // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
-// (note that where where columns must always contain ? marks,
+// (note that where columns must always contain ? marks,
 //  just a limitation for fakedb)
 func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
        if len(parts) != 3 {
index 89136ef..d557fc8 100644 (file)
@@ -311,7 +311,10 @@ func (db *DB) prepare(query string) (stmt *Stmt, err error) {
        if err != nil {
                return nil, err
        }
-       defer db.putConn(ci, err)
+       defer func() {
+               db.putConn(ci, err)
+       }()
+
        si, err := ci.Prepare(query)
        if err != nil {
                return nil, err
@@ -345,7 +348,9 @@ func (db *DB) exec(query string, sargs []driver.Value) (res Result, err error) {
        if err != nil {
                return nil, err
        }
-       defer db.putConn(ci, err)
+       defer func() {
+               db.putConn(ci, err)
+       }()
 
        if execer, ok := ci.(driver.Execer); ok {
                resi, err := execer.Exec(query, sargs)
index b756afd..719018b 100644 (file)
@@ -123,7 +123,7 @@ func ReadUvarint(r io.ByteReader) (uint64, error) {
        panic("unreachable")
 }
 
-// ReadVarint reads an encoded unsigned integer from r and returns it as a uint64.
+// ReadVarint reads an encoded unsigned integer from r and returns it as aint64.
 func ReadVarint(r io.ByteReader) (int64, error) {
        ux, err := ReadUvarint(r) // ok to continue in presence of error
        x := int64(ux >> 1)
index e32a178..8690b35 100644 (file)
@@ -562,6 +562,9 @@ func (dec *Decoder) ignoreSingle(engine *decEngine) {
 func (dec *Decoder) decodeArrayHelper(state *decoderState, p uintptr, elemOp decOp, elemWid uintptr, length, elemIndir int, ovfl error) {
        instr := &decInstr{elemOp, 0, elemIndir, 0, ovfl}
        for i := 0; i < length; i++ {
+               if state.b.Len() == 0 {
+                       errorf("decoding array or slice: length exceeds input size (%d elements)", length)
+               }
                up := unsafe.Pointer(p)
                if elemIndir > 1 {
                        up = decIndirect(up, elemIndir)
@@ -652,9 +655,6 @@ func (dec *Decoder) ignoreMap(state *decoderState, keyOp, elemOp decOp) {
 // Slices are encoded as an unsigned length followed by the elements.
 func (dec *Decoder) decodeSlice(atyp reflect.Type, state *decoderState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int, ovfl error) {
        nr := state.decodeUint()
-       if nr > uint64(state.b.Len()) {
-               errorf("length of slice exceeds input size (%d elements)", nr)
-       }
        n := int(nr)
        if indir > 0 {
                up := unsafe.Pointer(p)
index 821d9a3..6d77c17 100644 (file)
@@ -118,7 +118,7 @@ elements using the standard gob encoding for their type, recursively.
 
 Maps are sent as an unsigned count followed by that many key, element
 pairs. Empty but non-nil maps are sent, so if the sender has allocated
-a map, the receiver will allocate a map even no elements are
+a map, the receiver will allocate a map even if no elements are
 transmitted.
 
 Structs are sent as a sequence of (field number, field value) pairs.  The field
index c4947cb..db824d9 100644 (file)
@@ -736,3 +736,32 @@ func TestPtrToMapOfMap(t *testing.T) {
                t.Fatalf("expected %v got %v", data, newData)
        }
 }
+
+// There was an error check comparing the length of the input with the
+// length of the slice being decoded. It was wrong because the next
+// thing in the input might be a type definition, which would lead to
+// an incorrect length check.  This test reproduces the corner case.
+
+type Z struct {
+}
+
+func Test29ElementSlice(t *testing.T) {
+       Register(Z{})
+       src := make([]interface{}, 100) // Size needs to be bigger than size of type definition.
+       for i := range src {
+               src[i] = Z{}
+       }
+       buf := new(bytes.Buffer)
+       err := NewEncoder(buf).Encode(src)
+       if err != nil {
+               t.Fatalf("encode: %v", err)
+               return
+       }
+
+       var dst []interface{}
+       err = NewDecoder(buf).Decode(&dst)
+       if err != nil {
+               t.Errorf("decode: %v", err)
+               return
+       }
+}
index 0dd7a0a..a8ee2fa 100644 (file)
@@ -749,12 +749,28 @@ func Register(value interface{}) {
        rt := reflect.TypeOf(value)
        name := rt.String()
 
-       // But for named types (or pointers to them), qualify with import path.
+       // But for named types (or pointers to them), qualify with import path (but see inner comment).
        // Dereference one pointer looking for a named type.
        star := ""
        if rt.Name() == "" {
                if pt := rt; pt.Kind() == reflect.Ptr {
                        star = "*"
+                       // NOTE: The following line should be rt = pt.Elem() to implement
+                       // what the comment above claims, but fixing it would break compatibility
+                       // with existing gobs.
+                       //
+                       // Given package p imported as "full/p" with these definitions:
+                       //     package p
+                       //     type T1 struct { ... }
+                       // this table shows the intended and actual strings used by gob to
+                       // name the types:
+                       //
+                       // Type      Correct string     Actual string
+                       //
+                       // T1        full/p.T1          full/p.T1
+                       // *T1       *full/p.T1         *p.T1
+                       //
+                       // The missing full path cannot be fixed without breaking existing gob decoders.
                        rt = pt
                }
        }
index b6e1cb1..d2c1c44 100644 (file)
@@ -55,7 +55,7 @@ import (
 // nil pointer or interface value, and any array, slice, map, or string of
 // length zero. The object's default key string is the struct field name
 // but can be specified in the struct field's tag value. The "json" key in
-// struct field's tag value is the key name, followed by an optional comma
+// the struct field's tag value is the key name, followed by an optional comma
 // and options. Examples:
 //
 //   // Field is ignored by this package.
index 5444ad1..bbabd88 100644 (file)
@@ -33,7 +33,7 @@
 
        After parsing, the arguments after the flag are available as the
        slice flag.Args() or individually as flag.Arg(i).
-       The arguments are indexed from 0 up to flag.NArg().
+       The arguments are indexed from 0 through flag.NArg()-1.
 
        Command line flag syntax:
                -flag
@@ -707,7 +707,7 @@ func (f *FlagSet) parseOne() (bool, error) {
        if fv, ok := flag.Value.(*boolValue); ok { // special case: doesn't need an arg
                if has_value {
                        if err := fv.Set(value); err != nil {
-                               f.failf("invalid boolean value %q for  -%s: %v", value, name, err)
+                               return false, f.failf("invalid boolean value %q for  -%s: %v", value, name, err)
                        }
                } else {
                        fv.Set("true")
index 500a459..98ebfb7 100644 (file)
@@ -844,3 +844,15 @@ func TestIsSpace(t *testing.T) {
                }
        }
 }
+
+func TestNilDoesNotBecomeTyped(t *testing.T) {
+       type A struct{}
+       type B struct{}
+       var a *A = nil
+       var b B = B{}
+       got := Sprintf("%s %s %s %s %s", nil, a, nil, b, nil)
+       const expect = "%!s(<nil>) %!s(*fmt_test.A=<nil>) %!s(<nil>) {} %!s(<nil>)"
+       if got != expect {
+               t.Errorf("expected:\n\t%q\ngot:\n\t%q", expect, got)
+       }
+}
index 1343824..f29e8c8 100644 (file)
@@ -712,6 +712,9 @@ func (p *pp) handleMethods(verb rune, plus, goSyntax bool, depth int) (wasString
 }
 
 func (p *pp) printField(field interface{}, verb rune, plus, goSyntax bool, depth int) (wasString bool) {
+       p.field = field
+       p.value = reflect.Value{}
+
        if field == nil {
                if verb == 'T' || verb == 'v' {
                        p.buf.Write(nilAngleBytes)
@@ -721,8 +724,6 @@ func (p *pp) printField(field interface{}, verb rune, plus, goSyntax bool, depth
                return false
        }
 
-       p.field = field
-       p.value = reflect.Value{}
        // Special processing considerations.
        // %T (the value's type) and %p (its address) are special; we always do them first.
        switch verb {
index 02cf9e0..2de9af2 100644 (file)
@@ -34,7 +34,8 @@ func NotNilFilter(_ string, v reflect.Value) bool {
 //
 // A non-nil FieldFilter f may be provided to control the output:
 // struct fields for which f(fieldname, fieldvalue) is true are
-// are printed; all others are filtered from the output.
+// are printed; all others are filtered from the output. Unexported
+// struct fields are never printed.
 //
 func Fprint(w io.Writer, fset *token.FileSet, x interface{}, f FieldFilter) (err error) {
        // setup printer
@@ -145,15 +146,18 @@ func (p *printer) print(x reflect.Value) {
                p.print(x.Elem())
 
        case reflect.Map:
-               p.printf("%s (len = %d) {\n", x.Type(), x.Len())
-               p.indent++
-               for _, key := range x.MapKeys() {
-                       p.print(key)
-                       p.printf(": ")
-                       p.print(x.MapIndex(key))
+               p.printf("%s (len = %d) {", x.Type(), x.Len())
+               if x.Len() > 0 {
+                       p.indent++
                        p.printf("\n")
+                       for _, key := range x.MapKeys() {
+                               p.print(key)
+                               p.printf(": ")
+                               p.print(x.MapIndex(key))
+                               p.printf("\n")
+                       }
+                       p.indent--
                }
-               p.indent--
                p.printf("}")
 
        case reflect.Ptr:
@@ -169,32 +173,57 @@ func (p *printer) print(x reflect.Value) {
                        p.print(x.Elem())
                }
 
+       case reflect.Array:
+               p.printf("%s {", x.Type())
+               if x.Len() > 0 {
+                       p.indent++
+                       p.printf("\n")
+                       for i, n := 0, x.Len(); i < n; i++ {
+                               p.printf("%d: ", i)
+                               p.print(x.Index(i))
+                               p.printf("\n")
+                       }
+                       p.indent--
+               }
+               p.printf("}")
+
        case reflect.Slice:
                if s, ok := x.Interface().([]byte); ok {
                        p.printf("%#q", s)
                        return
                }
-               p.printf("%s (len = %d) {\n", x.Type(), x.Len())
-               p.indent++
-               for i, n := 0, x.Len(); i < n; i++ {
-                       p.printf("%d: ", i)
-                       p.print(x.Index(i))
+               p.printf("%s (len = %d) {", x.Type(), x.Len())
+               if x.Len() > 0 {
+                       p.indent++
                        p.printf("\n")
+                       for i, n := 0, x.Len(); i < n; i++ {
+                               p.printf("%d: ", i)
+                               p.print(x.Index(i))
+                               p.printf("\n")
+                       }
+                       p.indent--
                }
-               p.indent--
                p.printf("}")
 
        case reflect.Struct:
-               p.printf("%s {\n", x.Type())
-               p.indent++
                t := x.Type()
+               p.printf("%s {", t)
+               p.indent++
+               first := true
                for i, n := 0, t.NumField(); i < n; i++ {
-                       name := t.Field(i).Name
-                       value := x.Field(i)
-                       if p.filter == nil || p.filter(name, value) {
-                               p.printf("%s: ", name)
-                               p.print(value)
-                               p.printf("\n")
+                       // exclude non-exported fields because their
+                       // values cannot be accessed via reflection
+                       if name := t.Field(i).Name; IsExported(name) {
+                               value := x.Field(i)
+                               if p.filter == nil || p.filter(name, value) {
+                                       if first {
+                                               p.printf("\n")
+                                               first = false
+                                       }
+                                       p.printf("%s: ", name)
+                                       p.print(value)
+                                       p.printf("\n")
+                               }
                        }
                }
                p.indent--
index 71c028e..210f164 100644 (file)
@@ -23,6 +23,7 @@ var tests = []struct {
        {"foobar", "0  \"foobar\""},
 
        // maps
+       {map[Expr]string{}, `0  map[ast.Expr]string (len = 0) {}`},
        {map[string]int{"a": 1},
                `0  map[string]int (len = 1) {
                1  .  "a": 1
@@ -31,7 +32,21 @@ var tests = []struct {
        // pointers
        {new(int), "0  *0"},
 
+       // arrays
+       {[0]int{}, `0  [0]int {}`},
+       {[3]int{1, 2, 3},
+               `0  [3]int {
+               1  .  0: 1
+               2  .  1: 2
+               3  .  2: 3
+               4  }`},
+       {[...]int{42},
+               `0  [1]int {
+               1  .  0: 42
+               2  }`},
+
        // slices
+       {[]int{}, `0  []int (len = 0) {}`},
        {[]int{1, 2, 3},
                `0  []int (len = 3) {
                1  .  0: 1
@@ -40,6 +55,12 @@ var tests = []struct {
                4  }`},
 
        // structs
+       {struct{}{}, `0  struct {} {}`},
+       {struct{ x int }{007}, `0  struct { x int } {}`},
+       {struct{ X, y int }{42, 991},
+               `0  struct { X int; y int } {
+               1  .  X: 42
+               2  }`},
        {struct{ X, Y int }{42, 991},
                `0  struct { X int; Y int } {
                1  .  X: 42
index 908e61c..54b5d73 100644 (file)
@@ -136,7 +136,7 @@ func NewPackage(fset *token.FileSet, files map[string]*File, importer Importer,
                                for _, obj := range pkg.Data.(*Scope).Objects {
                                        p.declare(fileScope, pkgScope, obj)
                                }
-                       } else {
+                       } else if name != "_" {
                                // declare imported package object in file scope
                                // (do not re-use pkg in the file scope but create
                                // a new object instead; the Decl field is different
index 181cfd1..66b1dc2 100644 (file)
@@ -344,9 +344,6 @@ func Walk(v Visitor, node Node) {
                }
                Walk(v, n.Name)
                walkDeclList(v, n.Decls)
-               for _, g := range n.Comments {
-                       Walk(v, g)
-               }
                // don't walk n.Comments - they have been
                // visited already through the individual
                // nodes
index 7a81d50..67e73c5 100644 (file)
@@ -536,7 +536,7 @@ Found:
                        return p, err
                }
 
-               pkg := string(pf.Name.Name)
+               pkg := pf.Name.Name
                if pkg == "documentation" {
                        continue
                }
@@ -570,7 +570,7 @@ Found:
                                if !ok {
                                        continue
                                }
-                               quoted := string(spec.Path.Value)
+                               quoted := spec.Path.Value
                                path, err := strconv.Unquote(quoted)
                                if err != nil {
                                        log.Panicf("%s: parser returned invalid quoted string: <%s>", filename, quoted)
@@ -678,7 +678,7 @@ func (ctxt *Context) shouldBuild(content []byte) bool {
                }
                line = bytes.TrimSpace(line)
                if len(line) == 0 { // Blank line
-                       end = cap(content) - cap(line) // &line[0] - &content[0]
+                       end = len(content) - len(p)
                        continue
                }
                if !bytes.HasPrefix(line, slashslash) { // Not comment line
index 560ebad..caa4f26 100644 (file)
@@ -75,3 +75,32 @@ func TestLocalDirectory(t *testing.T) {
                t.Fatalf("ImportPath=%q, want %q", p.ImportPath, "go/build")
        }
 }
+
+func TestShouldBuild(t *testing.T) {
+       const file1 = "// +build tag1\n\n" +
+               "package main\n"
+
+       const file2 = "// +build cgo\n\n" +
+               "// This package implements parsing of tags like\n" +
+               "// +build tag1\n" +
+               "package build"
+
+       const file3 = "// Copyright The Go Authors.\n\n" +
+               "package build\n\n" +
+               "// shouldBuild checks tags given by lines of the form\n" +
+               "// +build tag\n" +
+               "func shouldBuild(content []byte)\n"
+
+       ctx := &Context{BuildTags: []string{"tag1"}}
+       if !ctx.shouldBuild([]byte(file1)) {
+               t.Errorf("should not build file1, expected the contrary")
+       }
+       if ctx.shouldBuild([]byte(file2)) {
+               t.Errorf("should build file2, expected the contrary")
+       }
+
+       ctx = &Context{BuildTags: nil}
+       if !ctx.shouldBuild([]byte(file3)) {
+               t.Errorf("should not build file3, expected the contrary")
+       }
+}
index 67c26ac..9b7a946 100644 (file)
@@ -60,7 +60,7 @@
 // A build constraint is a line comment beginning with the directive +build
 // that lists the conditions under which a file should be included in the package.
 // Constraints may appear in any kind of source file (not just Go), but
-// they must be appear near the top of the file, preceded
+// they must appear near the top of the file, preceded
 // only by blank lines and other line comments.
 //
 // A build constraint is evaluated as the OR of space-separated options;
index 5eaae37..60b174f 100644 (file)
@@ -494,7 +494,7 @@ func (r *reader) readPackage(pkg *ast.Package, mode Mode) {
        r.funcs = make(methodSet)
 
        // sort package files before reading them so that the
-       // result result does not depend on map iteration order
+       // result does not depend on map iteration order
        i := 0
        for filename := range pkg.Files {
                r.filenames[i] = filename
index 776bd1b..dbcc1b0 100644 (file)
@@ -10,7 +10,7 @@ FILENAMES
 TYPES
        // 
        type I0 interface {
-               // When embedded, the the locally declared error interface
+               // When embedded, the locally-declared error interface
                // is only visible if all declarations are shown.
                error
        }
index 6cc36fe..6ee96c2 100644 (file)
@@ -5,7 +5,7 @@
 package error2
 
 type I0 interface {
-       // When embedded, the the locally declared error interface
+       // When embedded, the locally-declared error interface
        // is only visible if all declarations are shown.
        error
 }
index f13f9a5..e346b93 100644 (file)
@@ -325,9 +325,14 @@ func (p *printer) parameters(fields *ast.FieldList) {
 }
 
 func (p *printer) signature(params, result *ast.FieldList) {
-       p.parameters(params)
+       if params != nil {
+               p.parameters(params)
+       } else {
+               p.print(token.LPAREN, token.RPAREN)
+       }
        n := result.NumFields()
        if n > 0 {
+               // result != nil
                p.print(blank)
                if n == 1 && result.List[0].Names == nil {
                        // single anonymous result; no ()'s
index 497d671..ab9e9b2 100644 (file)
@@ -385,6 +385,35 @@ func (t *t) foo(a, b, c int) int {
        }
 }
 
+// TestFuncType tests that an ast.FuncType with a nil Params field
+// can be printed (per go/ast specification). Test case for issue 3870.
+func TestFuncType(t *testing.T) {
+       src := &ast.File{
+               Name: &ast.Ident{Name: "p"},
+               Decls: []ast.Decl{
+                       &ast.FuncDecl{
+                               Name: &ast.Ident{Name: "f"},
+                               Type: &ast.FuncType{},
+                       },
+               },
+       }
+
+       var buf bytes.Buffer
+       if err := Fprint(&buf, fset, src); err != nil {
+               t.Fatal(err)
+       }
+       got := buf.String()
+
+       const want = `package p
+
+func f()
+`
+
+       if got != want {
+               t.Fatalf("got:\n%s\nwant:\n%s\n", got, want)
+       }
+}
+
 // TextX is a skeleton test that can be filled in for debugging one-off cases.
 // Do not remove.
 func TestX(t *testing.T) {
index 8a75a96..22de69c 100644 (file)
@@ -120,7 +120,7 @@ func PrintError(w io.Writer, err error) {
                for _, e := range list {
                        fmt.Fprintf(w, "%s\n", e)
                }
-       } else {
+       } else if err != nil {
                fmt.Fprintf(w, "%s\n", err)
        }
 }
index da50874..6ef3e14 100644 (file)
@@ -81,7 +81,7 @@ func (s *Scanner) next() {
        }
 }
 
-// A mode value is set of flags (or 0).
+// A mode value is set of flags (or 0).
 // They control scanner behavior.
 //
 type Mode uint
index c1bd2e4..42ea793 100644 (file)
@@ -47,7 +47,7 @@ type (
        // JSStr("foo\\nbar") is fine, but JSStr("foo\\\nbar") is not.
        JSStr string
 
-       // URL encapsulates a known safe URL as defined in RFC 3896.
+       // URL encapsulates a known safe URL or URL substring (see RFC 3986).
        // A URL like `javascript:checkThatFormNotEditedBeforeLeavingPage()`
        // from a trusted source should go in the page, but by default dynamic
        // `javascript:` URLs are filtered out since they are a frequently
index 454c791..2ca76bf 100644 (file)
@@ -60,7 +60,7 @@ func urlProcessor(norm bool, args ...interface{}) string {
                c := s[i]
                switch c {
                // Single quote and parens are sub-delims in RFC 3986, but we
-               // escape them so the output can be embedded in in single
+               // escape them so the output can be embedded in single
                // quoted attributes and unquoted CSS url(...) constructs.
                // Single quotes are reserved in URLs, but are only used in
                // the obsolete "mark" rule in an appendix in RFC 3986
index d9adf6e..8da3611 100644 (file)
@@ -74,7 +74,9 @@ const (
        comMarker   = 0xfe // COMment.
 )
 
-// Maps from the zig-zag ordering to the natural ordering.
+// unzig maps from the zig-zag ordering to the natural ordering. For example,
+// unzig[3] is the column and row of the fourth element in zig-zag order. The
+// value is 16, which means first column (16%8 == 0) and third row (16/8 == 2).
 var unzig = [blockSize]int{
        0, 1, 8, 16, 9, 2, 3, 10,
        17, 24, 32, 25, 18, 11, 4, 5,
@@ -101,7 +103,7 @@ type decoder struct {
        nComp         int
        comp          [nColorComponent]component
        huff          [maxTc + 1][maxTh + 1]huffman
-       quant         [maxTq + 1]block
+       quant         [maxTq + 1]block // Quantization tables, in zig-zag order.
        b             bits
        tmp           [1024]byte
 }
@@ -264,6 +266,7 @@ func (d *decoder) processSOS(n int) error {
                                for j := 0; j < d.comp[i].h*d.comp[i].v; j++ {
                                        // TODO(nigeltao): make this a "var b block" once the compiler's escape
                                        // analysis is good enough to allocate it on the stack, not the heap.
+                                       // b is in natural (not zig-zag) order.
                                        b = block{}
 
                                        // Decode the DC coefficient, as specified in section F.2.2.1.
@@ -282,7 +285,7 @@ func (d *decoder) processSOS(n int) error {
                                        b[0] = dc[i] * qt[0]
 
                                        // Decode the AC coefficients, as specified in section F.2.2.2.
-                                       for k := 1; k < blockSize; k++ {
+                                       for zig := 1; zig < blockSize; zig++ {
                                                value, err := d.decodeHuffman(&d.huff[acTable][scan[i].ta])
                                                if err != nil {
                                                        return err
@@ -290,20 +293,20 @@ func (d *decoder) processSOS(n int) error {
                                                val0 := value >> 4
                                                val1 := value & 0x0f
                                                if val1 != 0 {
-                                                       k += int(val0)
-                                                       if k > blockSize {
+                                                       zig += int(val0)
+                                                       if zig > blockSize {
                                                                return FormatError("bad DCT index")
                                                        }
                                                        ac, err := d.receiveExtend(val1)
                                                        if err != nil {
                                                                return err
                                                        }
-                                                       b[unzig[k]] = ac * qt[k]
+                                                       b[unzig[zig]] = ac * qt[zig]
                                                } else {
                                                        if val0 != 0x0f {
                                                                break
                                                        }
-                                                       k += 0x0f
+                                                       zig += 0x0f
                                                }
                                        }
 
@@ -393,6 +396,15 @@ func (d *decoder) decode(r io.Reader, configOnly bool) (image.Image, error) {
                if marker == eoiMarker { // End Of Image.
                        break
                }
+               if rst0Marker <= marker && marker <= rst7Marker {
+                       // Figures B.2 and B.16 of the specification suggest that restart markers should
+                       // only occur between Entropy Coded Segments and not after the final ECS.
+                       // However, some encoders may generate incorrect JPEGs with a final restart
+                       // marker. That restart marker will be seen here instead of inside the processSOS
+                       // method, and is ignored as a harmless error. Restart markers have no extra data,
+                       // so we check for this before we read the 16-bit length of the segment.
+                       continue
+               }
 
                // Read the 16-bit length of the segment. The value includes the 2 bytes for the
                // length itself, so we subtract 2 to get the number of remaining bytes.
@@ -421,7 +433,7 @@ func (d *decoder) decode(r io.Reader, configOnly bool) (image.Image, error) {
                        err = d.processSOS(n)
                case marker == driMarker: // Define Restart Interval.
                        err = d.processDRI(n)
-               case marker >= app0Marker && marker <= app15Marker || marker == comMarker: // APPlication specific, or COMment.
+               case app0Marker <= marker && marker <= app15Marker || marker == comMarker: // APPlication specific, or COMment.
                        err = d.ignore(n)
                default:
                        err = UnsupportedError("unknown marker")
index 3322c09..099298e 100644 (file)
@@ -56,26 +56,28 @@ const (
        nQuantIndex
 )
 
-// unscaledQuant are the unscaled quantization tables. Each encoder copies and
-// scales the tables according to its quality parameter.
+// unscaledQuant are the unscaled quantization tables in zig-zag order. Each
+// encoder copies and scales the tables according to its quality parameter.
+// The values are derived from section K.1 after converting from natural to
+// zig-zag order.
 var unscaledQuant = [nQuantIndex][blockSize]byte{
        // Luminance.
        {
-               16, 11, 10, 16, 24, 40, 51, 61,
-               12, 12, 14, 19, 26, 58, 60, 55,
-               14, 13, 16, 24, 40, 57, 69, 56,
-               14, 17, 22, 29, 51, 87, 80, 62,
-               18, 22, 37, 56, 68, 109, 103, 77,
-               24, 35, 55, 64, 81, 104, 113, 92,
-               49, 64, 78, 87, 103, 121, 120, 101,
-               72, 92, 95, 98, 112, 100, 103, 99,
+               16, 11, 12, 14, 12, 10, 16, 14,
+               13, 14, 18, 17, 16, 19, 24, 40,
+               26, 24, 22, 22, 24, 49, 35, 37,
+               29, 40, 58, 51, 61, 60, 57, 51,
+               56, 55, 64, 72, 92, 78, 64, 68,
+               87, 69, 55, 56, 80, 109, 81, 87,
+               95, 98, 103, 104, 103, 62, 77, 113,
+               121, 112, 100, 120, 92, 101, 103, 99,
        },
        // Chrominance.
        {
-               17, 18, 24, 47, 99, 99, 99, 99,
-               18, 21, 26, 66, 99, 99, 99, 99,
-               24, 26, 56, 99, 99, 99, 99, 99,
-               47, 66, 99, 99, 99, 99, 99, 99,
+               17, 18, 18, 24, 21, 24, 47, 26,
+               26, 47, 99, 66, 56, 66, 99, 99,
+               99, 99, 99, 99, 99, 99, 99, 99,
+               99, 99, 99, 99, 99, 99, 99, 99,
                99, 99, 99, 99, 99, 99, 99, 99,
                99, 99, 99, 99, 99, 99, 99, 99,
                99, 99, 99, 99, 99, 99, 99, 99,
@@ -222,7 +224,7 @@ type encoder struct {
        buf [16]byte
        // bits and nBits are accumulated bits to write to w.
        bits, nBits uint32
-       // quant is the scaled quantization tables.
+       // quant is the scaled quantization tables, in zig-zag order.
        quant [nQuantIndex][blockSize]byte
 }
 
@@ -301,7 +303,7 @@ func (e *encoder) writeMarkerHeader(marker uint8, markerlen int) {
 
 // writeDQT writes the Define Quantization Table marker.
 func (e *encoder) writeDQT() {
-       markerlen := 2 + int(nQuantIndex)*(1+blockSize)
+       const markerlen = 2 + int(nQuantIndex)*(1+blockSize)
        e.writeMarkerHeader(dqtMarker, markerlen)
        for i := range e.quant {
                e.writeByte(uint8(i))
@@ -311,7 +313,7 @@ func (e *encoder) writeDQT() {
 
 // writeSOF0 writes the Start Of Frame (Baseline) marker.
 func (e *encoder) writeSOF0(size image.Point) {
-       markerlen := 8 + 3*nColorComponent
+       const markerlen = 8 + 3*nColorComponent
        e.writeMarkerHeader(sof0Marker, markerlen)
        e.buf[0] = 8 // 8-bit color.
        e.buf[1] = uint8(size.Y >> 8)
@@ -344,6 +346,7 @@ func (e *encoder) writeDHT() {
 
 // writeBlock writes a block of pixel data using the given quantization table,
 // returning the post-quantized DC value of the DCT-transformed block.
+// b is in natural (not zig-zag) order.
 func (e *encoder) writeBlock(b *block, q quantIndex, prevDC int) int {
        fdct(b)
        // Emit the DC delta.
@@ -351,8 +354,8 @@ func (e *encoder) writeBlock(b *block, q quantIndex, prevDC int) int {
        e.emitHuffRLE(huffIndex(2*q+0), 0, dc-prevDC)
        // Emit the AC components.
        h, runLength := huffIndex(2*q+1), 0
-       for k := 1; k < blockSize; k++ {
-               ac := div(b[unzig[k]], (8 * int(e.quant[q][k])))
+       for zig := 1; zig < blockSize; zig++ {
+               ac := div(b[unzig[zig]], (8 * int(e.quant[q][zig])))
                if ac == 0 {
                        runLength++
                } else {
@@ -433,10 +436,12 @@ func scale(dst *block, src *[4]block) {
 //     - component 1 uses DC table 0 and AC table 0 "\x01\x00",
 //     - component 2 uses DC table 1 and AC table 1 "\x02\x11",
 //     - component 3 uses DC table 1 and AC table 1 "\x03\x11",
-//     - padding "\x00\x00\x00".
+//     - the bytes "\x00\x3f\x00". Section B.2.3 of the spec says that for
+//       sequential DCTs, those bytes (8-bit Ss, 8-bit Se, 4-bit Ah, 4-bit Al)
+//       should be 0x00, 0x3f, 0x00<<4 | 0x00.
 var sosHeader = []byte{
        0xff, 0xda, 0x00, 0x0c, 0x03, 0x01, 0x00, 0x02,
-       0x11, 0x03, 0x11, 0x00, 0x00, 0x00,
+       0x11, 0x03, 0x11, 0x00, 0x3f, 0x00,
 }
 
 // writeSOS writes the StartOfScan marker.
@@ -444,6 +449,7 @@ func (e *encoder) writeSOS(m image.Image) {
        e.write(sosHeader)
        var (
                // Scratch buffers to hold the YCbCr values.
+               // The blocks are in natural (not zig-zag) order.
                yBlock  block
                cbBlock [4]block
                crBlock [4]block
index b8e8fa3..8732df8 100644 (file)
@@ -6,6 +6,7 @@ package jpeg
 
 import (
        "bytes"
+       "fmt"
        "image"
        "image/color"
        "image/png"
@@ -15,6 +16,87 @@ import (
        "testing"
 )
 
+// zigzag maps from the natural ordering to the zig-zag ordering. For example,
+// zigzag[0*8 + 3] is the zig-zag sequence number of the element in the fourth
+// column and first row.
+var zigzag = [blockSize]int{
+       0, 1, 5, 6, 14, 15, 27, 28,
+       2, 4, 7, 13, 16, 26, 29, 42,
+       3, 8, 12, 17, 25, 30, 41, 43,
+       9, 11, 18, 24, 31, 40, 44, 53,
+       10, 19, 23, 32, 39, 45, 52, 54,
+       20, 22, 33, 38, 46, 51, 55, 60,
+       21, 34, 37, 47, 50, 56, 59, 61,
+       35, 36, 48, 49, 57, 58, 62, 63,
+}
+
+func TestZigUnzig(t *testing.T) {
+       for i := 0; i < blockSize; i++ {
+               if unzig[zigzag[i]] != i {
+                       t.Errorf("unzig[zigzag[%d]] == %d", i, unzig[zigzag[i]])
+               }
+               if zigzag[unzig[i]] != i {
+                       t.Errorf("zigzag[unzig[%d]] == %d", i, zigzag[unzig[i]])
+               }
+       }
+}
+
+// unscaledQuantInNaturalOrder are the unscaled quantization tables in
+// natural (not zig-zag) order, as specified in section K.1.
+var unscaledQuantInNaturalOrder = [nQuantIndex][blockSize]byte{
+       // Luminance.
+       {
+               16, 11, 10, 16, 24, 40, 51, 61,
+               12, 12, 14, 19, 26, 58, 60, 55,
+               14, 13, 16, 24, 40, 57, 69, 56,
+               14, 17, 22, 29, 51, 87, 80, 62,
+               18, 22, 37, 56, 68, 109, 103, 77,
+               24, 35, 55, 64, 81, 104, 113, 92,
+               49, 64, 78, 87, 103, 121, 120, 101,
+               72, 92, 95, 98, 112, 100, 103, 99,
+       },
+       // Chrominance.
+       {
+               17, 18, 24, 47, 99, 99, 99, 99,
+               18, 21, 26, 66, 99, 99, 99, 99,
+               24, 26, 56, 99, 99, 99, 99, 99,
+               47, 66, 99, 99, 99, 99, 99, 99,
+               99, 99, 99, 99, 99, 99, 99, 99,
+               99, 99, 99, 99, 99, 99, 99, 99,
+               99, 99, 99, 99, 99, 99, 99, 99,
+               99, 99, 99, 99, 99, 99, 99, 99,
+       },
+}
+
+func TestUnscaledQuant(t *testing.T) {
+       bad := false
+       for i := quantIndex(0); i < nQuantIndex; i++ {
+               for zig := 0; zig < blockSize; zig++ {
+                       got := unscaledQuant[i][zig]
+                       want := unscaledQuantInNaturalOrder[i][unzig[zig]]
+                       if got != want {
+                               t.Errorf("i=%d, zig=%d: got %d, want %d", i, zig, got, want)
+                               bad = true
+                       }
+               }
+       }
+       if bad {
+               names := [nQuantIndex]string{"Luminance", "Chrominance"}
+               buf := &bytes.Buffer{}
+               for i, name := range names {
+                       fmt.Fprintf(buf, "// %s.\n{\n", name)
+                       for zig := 0; zig < blockSize; zig++ {
+                               fmt.Fprintf(buf, "%d, ", unscaledQuantInNaturalOrder[i][unzig[zig]])
+                               if zig%8 == 7 {
+                                       buf.WriteString("\n")
+                               }
+                       }
+                       buf.WriteString("},\n")
+               }
+               t.Logf("expected unscaledQuant values:\n%s", buf.String())
+       }
+}
+
 var testCase = []struct {
        filename  string
        quality   int
index 55f634c..04ee2cf 100644 (file)
@@ -20,7 +20,7 @@ var (
 )
 
 // Uniform is an infinite-sized Image of uniform color.
-// It implements the color.Color, color.ColorModel, and Image interfaces.
+// It implements the color.Color, color.Model, and Image interfaces.
 type Uniform struct {
        C color.Color
 }
index 54bf159..5187eff 100644 (file)
@@ -130,11 +130,23 @@ type ReadWriteSeeker interface {
 }
 
 // ReaderFrom is the interface that wraps the ReadFrom method.
+//
+// ReadFrom reads data from r until EOF or error.
+// The return value n is the number of bytes read.
+// Any error except io.EOF encountered during the read is also returned.
+//
+// The Copy function uses ReaderFrom if available.
 type ReaderFrom interface {
        ReadFrom(r Reader) (n int64, err error)
 }
 
 // WriterTo is the interface that wraps the WriteTo method.
+//
+// WriteTo writes data to w until there's no more data to write or
+// when an error occurs. The return value n is the number of bytes
+// written. Any error encountered during the write is also returned.
+//
+// The Copy function uses WriterTo if available.
 type WriterTo interface {
        WriteTo(w Writer) (n int64, err error)
 }
index f53310c..e5620e1 100644 (file)
@@ -138,7 +138,11 @@ func (w *Writer) Debug(m string) (err error) {
 }
 
 func (n netConn) writeBytes(p Priority, prefix string, b []byte) (int, error) {
-       _, err := fmt.Fprintf(n.conn, "<%d>%s: %s\n", p, prefix, b)
+       nl := ""
+       if len(b) == 0 || b[len(b)-1] != '\n' {
+               nl = "\n"
+       }
+       _, err := fmt.Fprintf(n.conn, "<%d>%s: %s%s", p, prefix, b, nl)
        if err != nil {
                return 0, err
        }
@@ -146,7 +150,11 @@ func (n netConn) writeBytes(p Priority, prefix string, b []byte) (int, error) {
 }
 
 func (n netConn) writeString(p Priority, prefix string, s string) (int, error) {
-       _, err := fmt.Fprintf(n.conn, "<%d>%s: %s\n", p, prefix, s)
+       nl := ""
+       if len(s) == 0 || s[len(s)-1] != '\n' {
+               nl = "\n"
+       }
+       _, err := fmt.Fprintf(n.conn, "<%d>%s: %s%s", p, prefix, s, nl)
        if err != nil {
                return 0, err
        }
index 0fd6239..b7579c3 100644 (file)
@@ -98,20 +98,32 @@ func TestUDPDial(t *testing.T) {
 }
 
 func TestWrite(t *testing.T) {
-       done := make(chan string)
-       startServer(done)
-       l, err := Dial("udp", serverAddr, LOG_ERR, "syslog_test")
-       if err != nil {
-               t.Fatalf("syslog.Dial() failed: %s", err)
+       tests := []struct {
+               pri Priority
+               pre string
+               msg string
+               exp string
+       }{
+               {LOG_ERR, "syslog_test", "", "<3>syslog_test: \n"},
+               {LOG_ERR, "syslog_test", "write test", "<3>syslog_test: write test\n"},
+               // Write should not add \n if there already is one
+               {LOG_ERR, "syslog_test", "write test 2\n", "<3>syslog_test: write test 2\n"},
        }
-       msg := "write test"
-       _, err = io.WriteString(l, msg)
-       if err != nil {
-               t.Fatalf("WriteString() failed: %s", err)
-       }
-       expected := "<3>syslog_test: write test\n"
-       rcvd := <-done
-       if rcvd != expected {
-               t.Fatalf("s.Info() = '%q', but wanted '%q'", rcvd, expected)
+
+       for _, test := range tests {
+               done := make(chan string)
+               startServer(done)
+               l, err := Dial("udp", serverAddr, test.pri, test.pre)
+               if err != nil {
+                       t.Fatalf("syslog.Dial() failed: %s", err)
+               }
+               _, err = io.WriteString(l, test.msg)
+               if err != nil {
+                       t.Fatalf("WriteString() failed: %s", err)
+               }
+               rcvd := <-done
+               if rcvd != test.exp {
+                       t.Fatalf("s.Info() = '%q', but wanted '%q'", rcvd, test.exp)
+               }
        }
 }
index ed66a42..35c33ce 100644 (file)
@@ -1693,6 +1693,17 @@ func alike(a, b float64) bool {
        return false
 }
 
+func TestNaN(t *testing.T) {
+       f64 := NaN()
+       if f64 == f64 {
+               t.Fatalf("NaN() returns %g, expected NaN", f64)
+       }
+       f32 := float32(f64)
+       if f32 == f32 {
+               t.Fatalf("float32(NaN()) is %g, expected NaN", f32)
+       }
+}
+
 func TestAcos(t *testing.T) {
        for i := 0; i < len(vf); i++ {
                a := vf[i] / 10
index eaa6ff0..6d81823 100644 (file)
@@ -396,7 +396,7 @@ func (z nat) mul(x, y nat) nat {
        }
 
        // use basic multiplication if the numbers are small
-       if n < karatsubaThreshold || n < 2 {
+       if n < karatsubaThreshold {
                z = z.make(m + n)
                basicMul(z, x, y)
                return z.norm()
index 1cf60ce..0df0b1c 100644 (file)
@@ -5,7 +5,7 @@
 package math
 
 const (
-       uvnan    = 0x7FF0000000000001
+       uvnan    = 0x7FF8000000000001
        uvinf    = 0x7FF0000000000000
        uvneginf = 0xFFF0000000000000
        mask     = 0x7FF
index a233e8e..98bb04d 100644 (file)
@@ -4,7 +4,7 @@
 
 package math
 
-// The original C code and the the comment below are from
+// The original C code and the comment below are from
 // FreeBSD's /usr/src/lib/msun/src/e_remainder.c and came
 // with this notice.  The go code is a simplified version of
 // the original C.
index 83cc411..09e941e 100644 (file)
@@ -22,7 +22,7 @@ func isTokenChar(r rune) bool {
        return r > 0x20 && r < 0x7f && !isTSpecial(r)
 }
 
-// isToken returns true if s is a 'token' as as defined by RFC 1521
+// isToken returns true if s is a 'token' as defined by RFC 1521
 // and RFC 2045.
 func isToken(s string) bool {
        if s == "" {
index e9e337b..fb07e1a 100644 (file)
@@ -71,7 +71,7 @@ func (p *Part) parseContentDisposition() {
        }
 }
 
-// NewReader creates a new multipart Reader reading from r using the
+// NewReader creates a new multipart Reader reading from reader using the
 // given MIME boundary.
 func NewReader(reader io.Reader, boundary string) *Reader {
        b := []byte("\r\n--" + boundary + "--")
index 10ca5fa..5191239 100644 (file)
@@ -173,7 +173,7 @@ func (a stringAddr) String() string  { return a.addr }
 
 // Listen announces on the local network address laddr.
 // The network string net must be a stream-oriented network:
-// "tcp", "tcp4", "tcp6", or "unix", or "unixpacket".
+// "tcp", "tcp4", "tcp6", "unix" or "unixpacket".
 func Listen(net, laddr string) (Listener, error) {
        afnet, a, err := resolveNetAddr("listen", net, laddr)
        if err != nil {
index 76c953b..ff4f4f8 100644 (file)
@@ -645,10 +645,14 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (netfd *netFD, err e
 }
 
 func (fd *netFD) dup() (f *os.File, err error) {
+       syscall.ForkLock.RLock()
        ns, err := syscall.Dup(fd.sysfd)
        if err != nil {
+               syscall.ForkLock.RUnlock()
                return nil, &OpError{"dup", fd.net, fd.laddr, err}
        }
+       syscall.CloseOnExec(ns)
+       syscall.ForkLock.RUnlock()
 
        // We want blocking mode for the new fd, hence the double negative.
        if err = syscall.SetNonblock(ns, false); err != nil {
index fc6c6fa..837326e 100644 (file)
@@ -12,13 +12,18 @@ import (
 )
 
 func newFileFD(f *os.File) (*netFD, error) {
+       syscall.ForkLock.RLock()
        fd, err := syscall.Dup(int(f.Fd()))
        if err != nil {
+               syscall.ForkLock.RUnlock()
                return nil, os.NewSyscallError("dup", err)
        }
+       syscall.CloseOnExec(fd)
+       syscall.ForkLock.RUnlock()
 
-       proto, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_TYPE)
+       sotype, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_TYPE)
        if err != nil {
+               closesocket(fd)
                return nil, os.NewSyscallError("getsockopt", err)
        }
 
@@ -31,24 +36,24 @@ func newFileFD(f *os.File) (*netFD, error) {
                return nil, syscall.EINVAL
        case *syscall.SockaddrInet4:
                family = syscall.AF_INET
-               if proto == syscall.SOCK_DGRAM {
+               if sotype == syscall.SOCK_DGRAM {
                        toAddr = sockaddrToUDP
-               } else if proto == syscall.SOCK_RAW {
+               } else if sotype == syscall.SOCK_RAW {
                        toAddr = sockaddrToIP
                }
        case *syscall.SockaddrInet6:
                family = syscall.AF_INET6
-               if proto == syscall.SOCK_DGRAM {
+               if sotype == syscall.SOCK_DGRAM {
                        toAddr = sockaddrToUDP
-               } else if proto == syscall.SOCK_RAW {
+               } else if sotype == syscall.SOCK_RAW {
                        toAddr = sockaddrToIP
                }
        case *syscall.SockaddrUnix:
                family = syscall.AF_UNIX
                toAddr = sockaddrToUnix
-               if proto == syscall.SOCK_DGRAM {
+               if sotype == syscall.SOCK_DGRAM {
                        toAddr = sockaddrToUnixgram
-               } else if proto == syscall.SOCK_SEQPACKET {
+               } else if sotype == syscall.SOCK_SEQPACKET {
                        toAddr = sockaddrToUnixpacket
                }
        }
@@ -56,8 +61,9 @@ func newFileFD(f *os.File) (*netFD, error) {
        sa, _ = syscall.Getpeername(fd)
        raddr := toAddr(sa)
 
-       netfd, err := newFD(fd, family, proto, laddr.Network())
+       netfd, err := newFD(fd, family, sotype, laddr.Network())
        if err != nil {
+               closesocket(fd)
                return nil, err
        }
        netfd.setAddr(laddr, raddr)
index 54564e0..8944142 100644 (file)
@@ -14,6 +14,7 @@ import (
        "errors"
        "fmt"
        "io"
+       "log"
        "net/url"
        "strings"
 )
@@ -35,7 +36,8 @@ type Client struct {
        // following an HTTP redirect. The arguments req and via
        // are the upcoming request and the requests made already,
        // oldest first. If CheckRedirect returns an error, the client
-       // returns that error instead of issue the Request req.
+       // returns that error (wrapped in a url.Error) instead of
+       // issuing the Request req.
        //
        // If CheckRedirect is nil, the Client uses its default policy,
        // which is to stop after 10 consecutive requests.
@@ -87,9 +89,13 @@ type readClose struct {
 // Do sends an HTTP request and returns an HTTP response, following
 // policy (e.g. redirects, cookies, auth) as configured on the client.
 //
-// A non-nil response always contains a non-nil resp.Body.
+// An error is returned if caused by client policy (such as
+// CheckRedirect), or if there was an HTTP protocol error.
+// A non-2xx response doesn't cause an error.
 //
-// Callers should close resp.Body when done reading from it. If
+// When err is nil, resp always contains a non-nil resp.Body.
+//
+// Callers should close res.Body when done reading from it. If
 // resp.Body is not closed, the Client's underlying RoundTripper
 // (typically Transport) may not be able to re-use a persistent TCP
 // connection to the server for a subsequent "keep-alive" request.
@@ -102,7 +108,8 @@ func (c *Client) Do(req *Request) (resp *Response, err error) {
        return send(req, c.Transport)
 }
 
-// send issues an HTTP request.  Caller should close resp.Body when done reading from it.
+// send issues an HTTP request.
+// Caller should close resp.Body when done reading from it.
 func send(req *Request, t RoundTripper) (resp *Response, err error) {
        if t == nil {
                t = DefaultTransport
@@ -130,7 +137,14 @@ func send(req *Request, t RoundTripper) (resp *Response, err error) {
        if u := req.URL.User; u != nil {
                req.Header.Set("Authorization", "Basic "+base64.URLEncoding.EncodeToString([]byte(u.String())))
        }
-       return t.RoundTrip(req)
+       resp, err = t.RoundTrip(req)
+       if err != nil {
+               if resp != nil {
+                       log.Printf("RoundTripper returned a response & error; ignoring response")
+               }
+               return nil, err
+       }
+       return resp, nil
 }
 
 // True if the specified HTTP status code is one for which the Get utility should
@@ -151,10 +165,15 @@ func shouldRedirect(statusCode int) bool {
 //    303 (See Other)
 //    307 (Temporary Redirect)
 //
-// Caller should close r.Body when done reading from it.
+// An error is returned if there were too many redirects or if there
+// was an HTTP protocol error. A non-2xx response doesn't cause an
+// error.
+//
+// When err is nil, resp always contains a non-nil resp.Body.
+// Caller should close resp.Body when done reading from it.
 //
 // Get is a wrapper around DefaultClient.Get.
-func Get(url string) (r *Response, err error) {
+func Get(url string) (resp *Response, err error) {
        return DefaultClient.Get(url)
 }
 
@@ -167,8 +186,13 @@ func Get(url string) (r *Response, err error) {
 //    303 (See Other)
 //    307 (Temporary Redirect)
 //
-// Caller should close r.Body when done reading from it.
-func (c *Client) Get(url string) (r *Response, err error) {
+// An error is returned if the Client's CheckRedirect function fails
+// or if there was an HTTP protocol error. A non-2xx response doesn't
+// cause an error.
+//
+// When err is nil, resp always contains a non-nil resp.Body.
+// Caller should close resp.Body when done reading from it.
+func (c *Client) Get(url string) (resp *Response, err error) {
        req, err := NewRequest("GET", url, nil)
        if err != nil {
                return nil, err
@@ -176,7 +200,7 @@ func (c *Client) Get(url string) (r *Response, err error) {
        return c.doFollowingRedirects(req)
 }
 
-func (c *Client) doFollowingRedirects(ireq *Request) (r *Response, err error) {
+func (c *Client) doFollowingRedirects(ireq *Request) (resp *Response, err error) {
        // TODO: if/when we add cookie support, the redirected request shouldn't
        // necessarily supply the same cookies as the original.
        var base *url.URL
@@ -224,17 +248,17 @@ func (c *Client) doFollowingRedirects(ireq *Request) (r *Response, err error) {
                        req.AddCookie(cookie)
                }
                urlStr = req.URL.String()
-               if r, err = send(req, c.Transport); err != nil {
+               if resp, err = send(req, c.Transport); err != nil {
                        break
                }
-               if c := r.Cookies(); len(c) > 0 {
+               if c := resp.Cookies(); len(c) > 0 {
                        jar.SetCookies(req.URL, c)
                }
 
-               if shouldRedirect(r.StatusCode) {
-                       r.Body.Close()
-                       if urlStr = r.Header.Get("Location"); urlStr == "" {
-                               err = errors.New(fmt.Sprintf("%d response missing Location header", r.StatusCode))
+               if shouldRedirect(resp.StatusCode) {
+                       resp.Body.Close()
+                       if urlStr = resp.Header.Get("Location"); urlStr == "" {
+                               err = errors.New(fmt.Sprintf("%d response missing Location header", resp.StatusCode))
                                break
                        }
                        base = req.URL
@@ -244,13 +268,16 @@ func (c *Client) doFollowingRedirects(ireq *Request) (r *Response, err error) {
                return
        }
 
+       if resp != nil {
+               resp.Body.Close()
+       }
+
        method := ireq.Method
-       err = &url.Error{
+       return nil, &url.Error{
                Op:  method[0:1] + strings.ToLower(method[1:]),
                URL: urlStr,
                Err: err,
        }
-       return
 }
 
 func defaultCheckRedirect(req *Request, via []*Request) error {
@@ -262,17 +289,17 @@ func defaultCheckRedirect(req *Request, via []*Request) error {
 
 // Post issues a POST to the specified URL.
 //
-// Caller should close r.Body when done reading from it.
+// Caller should close resp.Body when done reading from it.
 //
 // Post is a wrapper around DefaultClient.Post
-func Post(url string, bodyType string, body io.Reader) (r *Response, err error) {
+func Post(url string, bodyType string, body io.Reader) (resp *Response, err error) {
        return DefaultClient.Post(url, bodyType, body)
 }
 
 // Post issues a POST to the specified URL.
 //
-// Caller should close r.Body when done reading from it.
-func (c *Client) Post(url string, bodyType string, body io.Reader) (r *Response, err error) {
+// Caller should close resp.Body when done reading from it.
+func (c *Client) Post(url string, bodyType string, body io.Reader) (resp *Response, err error) {
        req, err := NewRequest("POST", url, body)
        if err != nil {
                return nil, err
@@ -283,28 +310,30 @@ func (c *Client) Post(url string, bodyType string, body io.Reader) (r *Response,
                        req.AddCookie(cookie)
                }
        }
-       r, err = send(req, c.Transport)
+       resp, err = send(req, c.Transport)
        if err == nil && c.Jar != nil {
-               c.Jar.SetCookies(req.URL, r.Cookies())
+               c.Jar.SetCookies(req.URL, resp.Cookies())
        }
-       return r, err
+       return
 }
 
-// PostForm issues a POST to the specified URL, 
-// with data's keys and values urlencoded as the request body.
+// PostForm issues a POST to the specified URL, with data's keys and
+// values URL-encoded as the request body.
 //
-// Caller should close r.Body when done reading from it.
+// When err is nil, resp always contains a non-nil resp.Body.
+// Caller should close resp.Body when done reading from it.
 //
 // PostForm is a wrapper around DefaultClient.PostForm
-func PostForm(url string, data url.Values) (r *Response, err error) {
+func PostForm(url string, data url.Values) (resp *Response, err error) {
        return DefaultClient.PostForm(url, data)
 }
 
 // PostForm issues a POST to the specified URL, 
 // with data's keys and values urlencoded as the request body.
 //
-// Caller should close r.Body when done reading from it.
-func (c *Client) PostForm(url string, data url.Values) (r *Response, err error) {
+// When err is nil, resp always contains a non-nil resp.Body.
+// Caller should close resp.Body when done reading from it.
+func (c *Client) PostForm(url string, data url.Values) (resp *Response, err error) {
        return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
 }
 
@@ -318,7 +347,7 @@ func (c *Client) PostForm(url string, data url.Values) (r *Response, err error)
 //    307 (Temporary Redirect)
 //
 // Head is a wrapper around DefaultClient.Head
-func Head(url string) (r *Response, err error) {
+func Head(url string) (resp *Response, err error) {
        return DefaultClient.Head(url)
 }
 
@@ -330,7 +359,7 @@ func Head(url string) (r *Response, err error) {
 //    302 (Found)
 //    303 (See Other)
 //    307 (Temporary Redirect)
-func (c *Client) Head(url string) (r *Response, err error) {
+func (c *Client) Head(url string) (resp *Response, err error) {
        req, err := NewRequest("HEAD", url, nil)
        if err != nil {
                return nil, err
index 9b4261b..09fcc1c 100644 (file)
@@ -8,6 +8,7 @@ package http_test
 
 import (
        "crypto/tls"
+       "crypto/x509"
        "errors"
        "fmt"
        "io"
@@ -231,9 +232,8 @@ func TestRedirects(t *testing.T) {
 
        checkErr = errors.New("no redirects allowed")
        res, err = c.Get(ts.URL)
-       finalUrl = res.Request.URL.String()
-       if e, g := "Get /?n=1: no redirects allowed", fmt.Sprintf("%v", err); e != g {
-               t.Errorf("with redirects forbidden, expected error %q, got %q", e, g)
+       if urlError, ok := err.(*url.Error); !ok || urlError.Err != checkErr {
+               t.Errorf("with redirects forbidden, expected a *url.Error with our 'no redirects allowed' error inside; got %#v (%q)", err, err)
        }
 }
 
@@ -465,3 +465,49 @@ func TestClientErrorWithRequestURI(t *testing.T) {
                t.Errorf("wanted error mentioning RequestURI; got error: %v", err)
        }
 }
+
+func newTLSTransport(t *testing.T, ts *httptest.Server) *Transport {
+       certs := x509.NewCertPool()
+       for _, c := range ts.TLS.Certificates {
+               roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
+               if err != nil {
+                       t.Fatalf("error parsing server's root cert: %v", err)
+               }
+               for _, root := range roots {
+                       certs.AddCert(root)
+               }
+       }
+       return &Transport{
+               TLSClientConfig: &tls.Config{RootCAs: certs},
+       }
+}
+
+func TestClientWithCorrectTLSServerName(t *testing.T) {
+       ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               if r.TLS.ServerName != "127.0.0.1" {
+                       t.Errorf("expected client to set ServerName 127.0.0.1, got: %q", r.TLS.ServerName)
+               }
+       }))
+       defer ts.Close()
+
+       c := &Client{Transport: newTLSTransport(t, ts)}
+       if _, err := c.Get(ts.URL); err != nil {
+               t.Fatalf("expected successful TLS connection, got error: %v", err)
+       }
+}
+
+func TestClientWithIncorrectTLSServerName(t *testing.T) {
+       ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
+       defer ts.Close()
+
+       trans := newTLSTransport(t, ts)
+       trans.TLSClientConfig.ServerName = "badserver"
+       c := &Client{Transport: trans}
+       _, err := c.Get(ts.URL)
+       if err == nil {
+               t.Fatalf("expected an error")
+       }
+       if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") {
+               t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err)
+       }
+}
index ec81440..22073ea 100644 (file)
@@ -43,10 +43,10 @@ func ExampleGet() {
                log.Fatal(err)
        }
        robots, err := ioutil.ReadAll(res.Body)
+       res.Body.Close()
        if err != nil {
                log.Fatal(err)
        }
-       res.Body.Close()
        fmt.Printf("%s", robots)
 }
 
index 13640ca..313c6af 100644 (file)
@@ -11,8 +11,8 @@ import "time"
 
 func (t *Transport) IdleConnKeysForTesting() (keys []string) {
        keys = make([]string, 0)
-       t.lk.Lock()
-       defer t.lk.Unlock()
+       t.idleLk.Lock()
+       defer t.idleLk.Unlock()
        if t.idleConn == nil {
                return
        }
@@ -23,8 +23,8 @@ func (t *Transport) IdleConnKeysForTesting() (keys []string) {
 }
 
 func (t *Transport) IdleConnCountForTesting(cacheKey string) int {
-       t.lk.Lock()
-       defer t.lk.Unlock()
+       t.idleLk.Lock()
+       defer t.idleLk.Unlock()
        if t.idleConn == nil {
                return 0
        }
index f35dd32..208d6ca 100644 (file)
@@ -11,6 +11,8 @@ import (
        "fmt"
        "io"
        "mime"
+       "mime/multipart"
+       "net/textproto"
        "os"
        "path"
        "path/filepath"
@@ -26,7 +28,8 @@ import (
 type Dir string
 
 func (d Dir) Open(name string) (File, error) {
-       if filepath.Separator != '/' && strings.IndexRune(name, filepath.Separator) >= 0 {
+       if filepath.Separator != '/' && strings.IndexRune(name, filepath.Separator) >= 0 ||
+               strings.Contains(name, "\x00") {
                return nil, errors.New("http: invalid character in file path")
        }
        dir := string(d)
@@ -123,8 +126,9 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time,
        code := StatusOK
 
        // If Content-Type isn't set, use the file's extension to find it.
-       if w.Header().Get("Content-Type") == "" {
-               ctype := mime.TypeByExtension(filepath.Ext(name))
+       ctype := w.Header().Get("Content-Type")
+       if ctype == "" {
+               ctype = mime.TypeByExtension(filepath.Ext(name))
                if ctype == "" {
                        // read a chunk to decide between utf-8 text and binary
                        var buf [1024]byte
@@ -141,18 +145,34 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time,
        }
 
        // handle Content-Range header.
-       // TODO(adg): handle multiple ranges
        sendSize := size
+       var sendContent io.Reader = content
        if size >= 0 {
                ranges, err := parseRange(r.Header.Get("Range"), size)
-               if err == nil && len(ranges) > 1 {
-                       err = errors.New("multiple ranges not supported")
-               }
                if err != nil {
                        Error(w, err.Error(), StatusRequestedRangeNotSatisfiable)
                        return
                }
-               if len(ranges) == 1 {
+               if sumRangesSize(ranges) >= size {
+                       // The total number of bytes in all the ranges
+                       // is larger than the size of the file by
+                       // itself, so this is probably an attack, or a
+                       // dumb client.  Ignore the range request.
+                       ranges = nil
+               }
+               switch {
+               case len(ranges) == 1:
+                       // RFC 2616, Section 14.16:
+                       // "When an HTTP message includes the content of a single
+                       // range (for example, a response to a request for a
+                       // single range, or to a request for a set of ranges
+                       // that overlap without any holes), this content is
+                       // transmitted with a Content-Range header, and a
+                       // Content-Length header showing the number of bytes
+                       // actually transferred.
+                       // ...
+                       // A response to a request for a single range MUST NOT
+                       // be sent using the multipart/byteranges media type."
                        ra := ranges[0]
                        if _, err := content.Seek(ra.start, os.SEEK_SET); err != nil {
                                Error(w, err.Error(), StatusRequestedRangeNotSatisfiable)
@@ -160,7 +180,41 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time,
                        }
                        sendSize = ra.length
                        code = StatusPartialContent
-                       w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", ra.start, ra.start+ra.length-1, size))
+                       w.Header().Set("Content-Range", ra.contentRange(size))
+               case len(ranges) > 1:
+                       for _, ra := range ranges {
+                               if ra.start > size {
+                                       Error(w, err.Error(), StatusRequestedRangeNotSatisfiable)
+                                       return
+                               }
+                       }
+                       sendSize = rangesMIMESize(ranges, ctype, size)
+                       code = StatusPartialContent
+
+                       pr, pw := io.Pipe()
+                       mw := multipart.NewWriter(pw)
+                       w.Header().Set("Content-Type", "multipart/byteranges; boundary="+mw.Boundary())
+                       sendContent = pr
+                       defer pr.Close() // cause writing goroutine to fail and exit if CopyN doesn't finish.
+                       go func() {
+                               for _, ra := range ranges {
+                                       part, err := mw.CreatePart(ra.mimeHeader(ctype, size))
+                                       if err != nil {
+                                               pw.CloseWithError(err)
+                                               return
+                                       }
+                                       if _, err := content.Seek(ra.start, os.SEEK_SET); err != nil {
+                                               pw.CloseWithError(err)
+                                               return
+                                       }
+                                       if _, err := io.CopyN(part, content, ra.length); err != nil {
+                                               pw.CloseWithError(err)
+                                               return
+                                       }
+                               }
+                               mw.Close()
+                               pw.Close()
+                       }()
                }
 
                w.Header().Set("Accept-Ranges", "bytes")
@@ -172,11 +226,7 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time,
        w.WriteHeader(code)
 
        if r.Method != "HEAD" {
-               if sendSize == -1 {
-                       io.Copy(w, content)
-               } else {
-                       io.CopyN(w, content, sendSize)
-               }
+               io.CopyN(w, sendContent, sendSize)
        }
 }
 
@@ -243,9 +293,6 @@ func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirec
 
        // use contents of index.html for directory, if present
        if d.IsDir() {
-               if checkLastModified(w, r, d.ModTime()) {
-                       return
-               }
                index := name + indexPage
                ff, err := fs.Open(index)
                if err == nil {
@@ -259,11 +306,16 @@ func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirec
                }
        }
 
+       // Still a directory? (we didn't find an index.html file)
        if d.IsDir() {
+               if checkLastModified(w, r, d.ModTime()) {
+                       return
+               }
                dirList(w, f)
                return
        }
 
+       // serverContent will check modification time
        serveContent(w, r, d.Name(), d.ModTime(), d.Size(), f)
 }
 
@@ -312,6 +364,17 @@ type httpRange struct {
        start, length int64
 }
 
+func (r httpRange) contentRange(size int64) string {
+       return fmt.Sprintf("bytes %d-%d/%d", r.start, r.start+r.length-1, size)
+}
+
+func (r httpRange) mimeHeader(contentType string, size int64) textproto.MIMEHeader {
+       return textproto.MIMEHeader{
+               "Content-Range": {r.contentRange(size)},
+               "Content-Type":  {contentType},
+       }
+}
+
 // parseRange parses a Range header string as per RFC 2616.
 func parseRange(s string, size int64) ([]httpRange, error) {
        if s == "" {
@@ -323,11 +386,15 @@ func parseRange(s string, size int64) ([]httpRange, error) {
        }
        var ranges []httpRange
        for _, ra := range strings.Split(s[len(b):], ",") {
+               ra = strings.TrimSpace(ra)
+               if ra == "" {
+                       continue
+               }
                i := strings.Index(ra, "-")
                if i < 0 {
                        return nil, errors.New("invalid range")
                }
-               start, end := ra[:i], ra[i+1:]
+               start, end := strings.TrimSpace(ra[:i]), strings.TrimSpace(ra[i+1:])
                var r httpRange
                if start == "" {
                        // If no start is specified, end specifies the
@@ -365,3 +432,32 @@ func parseRange(s string, size int64) ([]httpRange, error) {
        }
        return ranges, nil
 }
+
+// countingWriter counts how many bytes have been written to it.
+type countingWriter int64
+
+func (w *countingWriter) Write(p []byte) (n int, err error) {
+       *w += countingWriter(len(p))
+       return len(p), nil
+}
+
+// rangesMIMESize returns the nunber of bytes it takes to encode the
+// provided ranges as a multipart response.
+func rangesMIMESize(ranges []httpRange, contentType string, contentSize int64) (encSize int64) {
+       var w countingWriter
+       mw := multipart.NewWriter(&w)
+       for _, ra := range ranges {
+               mw.CreatePart(ra.mimeHeader(contentType, contentSize))
+               encSize += ra.length
+       }
+       mw.Close()
+       encSize += int64(w)
+       return
+}
+
+func sumRangesSize(ranges []httpRange) (size int64) {
+       for _, ra := range ranges {
+               size += ra.length
+       }
+       return
+}
index ffba6a7..17329fb 100644 (file)
@@ -10,12 +10,15 @@ import (
        "fmt"
        "io"
        "io/ioutil"
+       "mime"
+       "mime/multipart"
        "net"
        . "net/http"
        "net/http/httptest"
        "net/url"
        "os"
        "os/exec"
+       "path"
        "path/filepath"
        "regexp"
        "runtime"
@@ -25,21 +28,29 @@ import (
 )
 
 const (
-       testFile       = "testdata/file"
-       testFileLength = 11
+       testFile    = "testdata/file"
+       testFileLen = 11
 )
 
+type wantRange struct {
+       start, end int64 // range [start,end)
+}
+
 var ServeFileRangeTests = []struct {
-       start, end int
-       r          string
-       code       int
+       r      string
+       code   int
+       ranges []wantRange
 }{
-       {0, testFileLength, "", StatusOK},
-       {0, 5, "0-4", StatusPartialContent},
-       {2, testFileLength, "2-", StatusPartialContent},
-       {testFileLength - 5, testFileLength, "-5", StatusPartialContent},
-       {3, 8, "3-7", StatusPartialContent},
-       {0, 0, "20-", StatusRequestedRangeNotSatisfiable},
+       {r: "", code: StatusOK},
+       {r: "bytes=0-4", code: StatusPartialContent, ranges: []wantRange{{0, 5}}},
+       {r: "bytes=2-", code: StatusPartialContent, ranges: []wantRange{{2, testFileLen}}},
+       {r: "bytes=-5", code: StatusPartialContent, ranges: []wantRange{{testFileLen - 5, testFileLen}}},
+       {r: "bytes=3-7", code: StatusPartialContent, ranges: []wantRange{{3, 8}}},
+       {r: "bytes=20-", code: StatusRequestedRangeNotSatisfiable},
+       {r: "bytes=0-0,-2", code: StatusPartialContent, ranges: []wantRange{{0, 1}, {testFileLen - 2, testFileLen}}},
+       {r: "bytes=0-1,5-8", code: StatusPartialContent, ranges: []wantRange{{0, 2}, {5, 9}}},
+       {r: "bytes=0-1,5-", code: StatusPartialContent, ranges: []wantRange{{0, 2}, {5, testFileLen}}},
+       {r: "bytes=0-,1-,2-,3-,4-", code: StatusOK}, // ignore wasteful range request
 }
 
 func TestServeFile(t *testing.T) {
@@ -65,33 +76,81 @@ func TestServeFile(t *testing.T) {
 
        // straight GET
        _, body := getBody(t, "straight get", req)
-       if !equal(body, file) {
+       if !bytes.Equal(body, file) {
                t.Fatalf("body mismatch: got %q, want %q", body, file)
        }
 
        // Range tests
-       for i, rt := range ServeFileRangeTests {
-               req.Header.Set("Range", "bytes="+rt.r)
-               if rt.r == "" {
-                       req.Header["Range"] = nil
+       for _, rt := range ServeFileRangeTests {
+               if rt.r != "" {
+                       req.Header.Set("Range", rt.r)
                }
-               r, body := getBody(t, fmt.Sprintf("test %d", i), req)
-               if r.StatusCode != rt.code {
-                       t.Errorf("range=%q: StatusCode=%d, want %d", rt.r, r.StatusCode, rt.code)
+               resp, body := getBody(t, fmt.Sprintf("range test %q", rt.r), req)
+               if resp.StatusCode != rt.code {
+                       t.Errorf("range=%q: StatusCode=%d, want %d", rt.r, resp.StatusCode, rt.code)
                }
                if rt.code == StatusRequestedRangeNotSatisfiable {
                        continue
                }
-               h := fmt.Sprintf("bytes %d-%d/%d", rt.start, rt.end-1, testFileLength)
-               if rt.r == "" {
-                       h = ""
+               wantContentRange := ""
+               if len(rt.ranges) == 1 {
+                       rng := rt.ranges[0]
+                       wantContentRange = fmt.Sprintf("bytes %d-%d/%d", rng.start, rng.end-1, testFileLen)
+               }
+               cr := resp.Header.Get("Content-Range")
+               if cr != wantContentRange {
+                       t.Errorf("range=%q: Content-Range = %q, want %q", rt.r, cr, wantContentRange)
                }
-               cr := r.Header.Get("Content-Range")
-               if cr != h {
-                       t.Errorf("header mismatch: range=%q: got %q, want %q", rt.r, cr, h)
+               ct := resp.Header.Get("Content-Type")
+               if len(rt.ranges) == 1 {
+                       rng := rt.ranges[0]
+                       wantBody := file[rng.start:rng.end]
+                       if !bytes.Equal(body, wantBody) {
+                               t.Errorf("range=%q: body = %q, want %q", rt.r, body, wantBody)
+                       }
+                       if strings.HasPrefix(ct, "multipart/byteranges") {
+                               t.Errorf("range=%q content-type = %q; unexpected multipart/byteranges", rt.r)
+                       }
                }
-               if !equal(body, file[rt.start:rt.end]) {
-                       t.Errorf("body mismatch: range=%q: got %q, want %q", rt.r, body, file[rt.start:rt.end])
+               if len(rt.ranges) > 1 {
+                       typ, params, err := mime.ParseMediaType(ct)
+                       if err != nil {
+                               t.Errorf("range=%q content-type = %q; %v", rt.r, ct, err)
+                               continue
+                       }
+                       if typ != "multipart/byteranges" {
+                               t.Errorf("range=%q content-type = %q; want multipart/byteranges", rt.r)
+                               continue
+                       }
+                       if params["boundary"] == "" {
+                               t.Errorf("range=%q content-type = %q; lacks boundary", rt.r, ct)
+                       }
+                       if g, w := resp.ContentLength, int64(len(body)); g != w {
+                               t.Errorf("range=%q Content-Length = %d; want %d", rt.r, g, w)
+                       }
+                       mr := multipart.NewReader(bytes.NewReader(body), params["boundary"])
+                       for ri, rng := range rt.ranges {
+                               part, err := mr.NextPart()
+                               if err != nil {
+                                       t.Fatalf("range=%q, reading part index %d: %v", rt.r, ri, err)
+                               }
+                               body, err := ioutil.ReadAll(part)
+                               if err != nil {
+                                       t.Fatalf("range=%q, reading part index %d body: %v", rt.r, ri, err)
+                               }
+                               wantContentRange = fmt.Sprintf("bytes %d-%d/%d", rng.start, rng.end-1, testFileLen)
+                               wantBody := file[rng.start:rng.end]
+                               if !bytes.Equal(body, wantBody) {
+                                       t.Errorf("range=%q: body = %q, want %q", rt.r, body, wantBody)
+                               }
+                               if g, w := part.Header.Get("Content-Range"), wantContentRange; g != w {
+                                       t.Errorf("range=%q: part Content-Range = %q; want %q", rt.r, g, w)
+                               }
+                       }
+                       _, err = mr.NextPart()
+                       if err != io.EOF {
+                               t.Errorf("range=%q; expected final error io.EOF; got %v", err)
+                       }
                }
        }
 }
@@ -276,6 +335,11 @@ func TestServeFileMimeType(t *testing.T) {
 }
 
 func TestServeFileFromCWD(t *testing.T) {
+       if runtime.GOOS == "windows" {
+               // TODO(brainman): find out why this test is broken
+               t.Logf("Temporarily skipping test on Windows; see http://golang.org/issue/3917")
+               return
+       }
        ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
                ServeFile(w, r, "fs_test.go")
        }))
@@ -325,6 +389,139 @@ func TestServeIndexHtml(t *testing.T) {
        }
 }
 
+func TestFileServerZeroByte(t *testing.T) {
+       ts := httptest.NewServer(FileServer(Dir(".")))
+       defer ts.Close()
+
+       res, err := Get(ts.URL + "/..\x00")
+       if err != nil {
+               t.Fatal(err)
+       }
+       b, err := ioutil.ReadAll(res.Body)
+       if err != nil {
+               t.Fatal("reading Body:", err)
+       }
+       if res.StatusCode == 200 {
+               t.Errorf("got status 200; want an error. Body is:\n%s", string(b))
+       }
+}
+
+type fakeFileInfo struct {
+       dir      bool
+       basename string
+       modtime  time.Time
+       ents     []*fakeFileInfo
+       contents string
+}
+
+func (f *fakeFileInfo) Name() string       { return f.basename }
+func (f *fakeFileInfo) Sys() interface{}   { return nil }
+func (f *fakeFileInfo) ModTime() time.Time { return f.modtime }
+func (f *fakeFileInfo) IsDir() bool        { return f.dir }
+func (f *fakeFileInfo) Size() int64        { return int64(len(f.contents)) }
+func (f *fakeFileInfo) Mode() os.FileMode {
+       if f.dir {
+               return 0755 | os.ModeDir
+       }
+       return 0644
+}
+
+type fakeFile struct {
+       io.ReadSeeker
+       fi   *fakeFileInfo
+       path string // as opened
+}
+
+func (f *fakeFile) Close() error               { return nil }
+func (f *fakeFile) Stat() (os.FileInfo, error) { return f.fi, nil }
+func (f *fakeFile) Readdir(count int) ([]os.FileInfo, error) {
+       if !f.fi.dir {
+               return nil, os.ErrInvalid
+       }
+       var fis []os.FileInfo
+       for _, fi := range f.fi.ents {
+               fis = append(fis, fi)
+       }
+       return fis, nil
+}
+
+type fakeFS map[string]*fakeFileInfo
+
+func (fs fakeFS) Open(name string) (File, error) {
+       name = path.Clean(name)
+       f, ok := fs[name]
+       if !ok {
+               println("fake filesystem didn't find file", name)
+               return nil, os.ErrNotExist
+       }
+       return &fakeFile{ReadSeeker: strings.NewReader(f.contents), fi: f, path: name}, nil
+}
+
+func TestDirectoryIfNotModified(t *testing.T) {
+       const indexContents = "I am a fake index.html file"
+       fileMod := time.Unix(1000000000, 0).UTC()
+       fileModStr := fileMod.Format(TimeFormat)
+       dirMod := time.Unix(123, 0).UTC()
+       indexFile := &fakeFileInfo{
+               basename: "index.html",
+               modtime:  fileMod,
+               contents: indexContents,
+       }
+       fs := fakeFS{
+               "/": &fakeFileInfo{
+                       dir:     true,
+                       modtime: dirMod,
+                       ents:    []*fakeFileInfo{indexFile},
+               },
+               "/index.html": indexFile,
+       }
+
+       ts := httptest.NewServer(FileServer(fs))
+       defer ts.Close()
+
+       res, err := Get(ts.URL)
+       if err != nil {
+               t.Fatal(err)
+       }
+       b, err := ioutil.ReadAll(res.Body)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if string(b) != indexContents {
+               t.Fatalf("Got body %q; want %q", b, indexContents)
+       }
+       res.Body.Close()
+
+       lastMod := res.Header.Get("Last-Modified")
+       if lastMod != fileModStr {
+               t.Fatalf("initial Last-Modified = %q; want %q", lastMod, fileModStr)
+       }
+
+       req, _ := NewRequest("GET", ts.URL, nil)
+       req.Header.Set("If-Modified-Since", lastMod)
+
+       res, err = DefaultClient.Do(req)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if res.StatusCode != 304 {
+               t.Fatalf("Code after If-Modified-Since request = %v; want 304", res.StatusCode)
+       }
+       res.Body.Close()
+
+       // Advance the index.html file's modtime, but not the directory's.
+       indexFile.modtime = indexFile.modtime.Add(1 * time.Hour)
+
+       res, err = DefaultClient.Do(req)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if res.StatusCode != 200 {
+               t.Fatalf("Code after second If-Modified-Since request = %v; want 200; res is %#v", res.StatusCode, res)
+       }
+       res.Body.Close()
+}
+
 func TestServeContent(t *testing.T) {
        type req struct {
                name    string
@@ -464,15 +661,3 @@ func TestLinuxSendfileChild(*testing.T) {
                panic(err)
        }
 }
-
-func equal(a, b []byte) bool {
-       if len(a) != len(b) {
-               return false
-       }
-       for i := range a {
-               if a[i] != b[i] {
-                       return false
-               }
-       }
-       return true
-}
index b107c31..6be94f9 100644 (file)
@@ -76,3 +76,43 @@ func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error {
 // the rest are converted to lowercase.  For example, the
 // canonical key for "accept-encoding" is "Accept-Encoding".
 func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) }
+
+// hasToken returns whether token appears with v, ASCII
+// case-insensitive, with space or comma boundaries.
+// token must be all lowercase.
+// v may contain mixed cased.
+func hasToken(v, token string) bool {
+       if len(token) > len(v) || token == "" {
+               return false
+       }
+       if v == token {
+               return true
+       }
+       for sp := 0; sp <= len(v)-len(token); sp++ {
+               // Check that first character is good.
+               // The token is ASCII, so checking only a single byte
+               // is sufficient.  We skip this potential starting
+               // position if both the first byte and its potential
+               // ASCII uppercase equivalent (b|0x20) don't match.
+               // False positives ('^' => '~') are caught by EqualFold.
+               if b := v[sp]; b != token[0] && b|0x20 != token[0] {
+                       continue
+               }
+               // Check that start pos is on a valid token boundary.
+               if sp > 0 && !isTokenBoundary(v[sp-1]) {
+                       continue
+               }
+               // Check that end pos is on a valid token boundary.
+               if endPos := sp + len(token); endPos != len(v) && !isTokenBoundary(v[endPos]) {
+                       continue
+               }
+               if strings.EqualFold(v[sp:sp+len(token)], token) {
+                       return true
+               }
+       }
+       return false
+}
+
+func isTokenBoundary(b byte) bool {
+       return b == ' ' || b == ',' || b == '\t'
+}
index 57cf0c9..165600e 100644 (file)
@@ -184,15 +184,15 @@ func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 // "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end
 // of ASN.1 time).
 var localhostCert = []byte(`-----BEGIN CERTIFICATE-----
-MIIBOTCB5qADAgECAgEAMAsGCSqGSIb3DQEBBTAAMB4XDTcwMDEwMTAwMDAwMFoX
+MIIBTTCB+qADAgECAgEAMAsGCSqGSIb3DQEBBTAAMB4XDTcwMDEwMTAwMDAwMFoX
 DTQ5MTIzMTIzNTk1OVowADBaMAsGCSqGSIb3DQEBAQNLADBIAkEAsuA5mAFMj6Q7
 qoBzcvKzIq4kzuT5epSp2AkcQfyBHm7K13Ws7u+0b5Vb9gqTf5cAiIKcrtrXVqkL
-8i1UQF6AzwIDAQABo08wTTAOBgNVHQ8BAf8EBAMCACQwDQYDVR0OBAYEBAECAwQw
-DwYDVR0jBAgwBoAEAQIDBDAbBgNVHREEFDASggkxMjcuMC4wLjGCBVs6OjFdMAsG
-CSqGSIb3DQEBBQNBAJH30zjLWRztrWpOCgJL8RQWLaKzhK79pVhAx6q/3NrF16C7
-+l1BRZstTwIGdoGId8BRpErK1TXkniFb95ZMynM=
------END CERTIFICATE-----
-`)
+8i1UQF6AzwIDAQABo2MwYTAOBgNVHQ8BAf8EBAMCACQwEgYDVR0TAQH/BAgwBgEB
+/wIBATANBgNVHQ4EBgQEAQIDBDAPBgNVHSMECDAGgAQBAgMEMBsGA1UdEQQUMBKC
+CTEyNy4wLjAuMYIFWzo6MV0wCwYJKoZIhvcNAQEFA0EAj1Jsn/h2KHy7dgqutZNB
+nCGlNN+8vw263Bax9MklR85Ti6a0VWSvp/fDQZUADvmFTDkcXeA24pqmdUxeQDWw
+Pg==
+-----END CERTIFICATE-----`)
 
 // localhostKey is the private key for localhostCert.
 var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
index 892ef4e..0fb2eeb 100644 (file)
@@ -89,7 +89,7 @@ func DumpRequestOut(req *http.Request, body bool) ([]byte, error) {
 
        t := &http.Transport{
                Dial: func(net, addr string) (net.Conn, error) {
-                       return &dumpConn{io.MultiWriter(pw, &buf), dr}, nil
+                       return &dumpConn{io.MultiWriter(&buf, pw), dr}, nil
                },
        }
 
index b8874f3..f578725 100644 (file)
 // To use pprof, link this package into your program:
 //     import _ "net/http/pprof"
 //
+// If your application is not already running an http server, you
+// need to start one.  Add "net/http" and "log" to your imports and
+// the following code to your main function:
+//
+//     go func() {
+//             log.Println(http.ListenAndServe("localhost:6060", nil))
+//     }()
+//
 // Then use the pprof tool to look at the heap profile:
 //
 //     go tool pprof http://localhost:6060/debug/pprof/heap
index 5274a81..ef911af 100644 (file)
@@ -14,15 +14,34 @@ var ParseRangeTests = []struct {
        r      []httpRange
 }{
        {"", 0, nil},
+       {"", 1000, nil},
        {"foo", 0, nil},
        {"bytes=", 0, nil},
+       {"bytes=7", 10, nil},
+       {"bytes= 7 ", 10, nil},
+       {"bytes=1-", 0, nil},
        {"bytes=5-4", 10, nil},
        {"bytes=0-2,5-4", 10, nil},
+       {"bytes=2-5,4-3", 10, nil},
+       {"bytes=--5,4--3", 10, nil},
+       {"bytes=A-", 10, nil},
+       {"bytes=A- ", 10, nil},
+       {"bytes=A-Z", 10, nil},
+       {"bytes= -Z", 10, nil},
+       {"bytes=5-Z", 10, nil},
+       {"bytes=Ran-dom, garbage", 10, nil},
+       {"bytes=0x01-0x02", 10, nil},
+       {"bytes=         ", 10, nil},
+       {"bytes= , , ,   ", 10, nil},
+
        {"bytes=0-9", 10, []httpRange{{0, 10}}},
        {"bytes=0-", 10, []httpRange{{0, 10}}},
        {"bytes=5-", 10, []httpRange{{5, 5}}},
        {"bytes=0-20", 10, []httpRange{{0, 10}}},
        {"bytes=15-,0-5", 10, nil},
+       {"bytes=1-2,5-", 10, []httpRange{{1, 2}, {5, 5}}},
+       {"bytes=-2 , 7-", 11, []httpRange{{9, 2}, {7, 4}}},
+       {"bytes=0-0 ,2-2, 7-", 11, []httpRange{{0, 1}, {2, 1}, {7, 4}}},
        {"bytes=-5", 10, []httpRange{{5, 5}}},
        {"bytes=-15", 10, []httpRange{{0, 10}}},
        {"bytes=0-499", 10000, []httpRange{{0, 500}}},
@@ -32,6 +51,9 @@ var ParseRangeTests = []struct {
        {"bytes=0-0,-1", 10000, []httpRange{{0, 1}, {9999, 1}}},
        {"bytes=500-600,601-999", 10000, []httpRange{{500, 101}, {601, 399}}},
        {"bytes=500-700,601-999", 10000, []httpRange{{500, 201}, {601, 399}}},
+
+       // Match Apache laxity:
+       {"bytes=   1 -2   ,  4- 5, 7 - 8 , ,,", 11, []httpRange{{1, 2}, {4, 2}, {7, 2}}},
 }
 
 func TestParseRange(t *testing.T) {
index b6a6b4c..c9d7393 100644 (file)
@@ -386,17 +386,18 @@ func testTcpConnectionCloses(t *testing.T, req string, h Handler) {
        }
 
        r := bufio.NewReader(conn)
-       _, err = ReadResponse(r, &Request{Method: "GET"})
+       res, err := ReadResponse(r, &Request{Method: "GET"})
        if err != nil {
                t.Fatal("ReadResponse error:", err)
        }
 
-       success := make(chan bool)
+       didReadAll := make(chan bool, 1)
        go func() {
                select {
                case <-time.After(5 * time.Second):
-                       t.Fatal("body not closed after 5s")
-               case <-success:
+                       t.Error("body not closed after 5s")
+                       return
+               case <-didReadAll:
                }
        }()
 
@@ -404,8 +405,11 @@ func testTcpConnectionCloses(t *testing.T, req string, h Handler) {
        if err != nil {
                t.Fatal("read error:", err)
        }
+       didReadAll <- true
 
-       success <- true
+       if !res.Close {
+               t.Errorf("Response.Close = false; want true")
+       }
 }
 
 // TestServeHTTP10Close verifies that HTTP/1.0 requests won't be kept alive.
@@ -1108,6 +1112,38 @@ func TestServerBufferedChunking(t *testing.T) {
        }
 }
 
+// TestContentLengthZero tests that for both an HTTP/1.0 and HTTP/1.1
+// request (both keep-alive), when a Handler never writes any
+// response, the net/http package adds a "Content-Length: 0" response
+// header.
+func TestContentLengthZero(t *testing.T) {
+       ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {}))
+       defer ts.Close()
+
+       for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} {
+               conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+               if err != nil {
+                       t.Fatalf("error dialing: %v", err)
+               }
+               _, err = fmt.Fprintf(conn, "GET / %v\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n", version)
+               if err != nil {
+                       t.Fatalf("error writing: %v", err)
+               }
+               req, _ := NewRequest("GET", "/", nil)
+               res, err := ReadResponse(bufio.NewReader(conn), req)
+               if err != nil {
+                       t.Fatalf("error reading response: %v", err)
+               }
+               if te := res.TransferEncoding; len(te) > 0 {
+                       t.Errorf("For version %q, Transfer-Encoding = %q; want none", version, te)
+               }
+               if cl := res.ContentLength; cl != 0 {
+                       t.Errorf("For version %q, Content-Length = %v; want 0", version, cl)
+               }
+               conn.Close()
+       }
+}
+
 // goTimeout runs f, failing t if f takes more than ns to complete.
 func goTimeout(t *testing.T, d time.Duration, f func()) {
        ch := make(chan bool, 2)
index 0572b4a..b74b762 100644 (file)
@@ -390,6 +390,11 @@ func (w *response) WriteHeader(code int) {
        if !w.req.ProtoAtLeast(1, 0) {
                return
        }
+
+       if w.closeAfterReply && !hasToken(w.header.Get("Connection"), "close") {
+               w.header.Set("Connection", "close")
+       }
+
        proto := "HTTP/1.0"
        if w.req.ProtoAtLeast(1, 1) {
                proto = "HTTP/1.1"
@@ -508,8 +513,16 @@ func (w *response) Write(data []byte) (n int, err error) {
 }
 
 func (w *response) finishRequest() {
-       // If this was an HTTP/1.0 request with keep-alive and we sent a Content-Length
-       // back, we can make this a keep-alive response ...
+       // If the handler never wrote any bytes and never sent a Content-Length
+       // response header, set the length explicitly to zero. This helps
+       // HTTP/1.0 clients keep their "keep-alive" connections alive, and for
+       // HTTP/1.1 clients is just as good as the alternative: sending a
+       // chunked response and immediately sending the zero-length EOF chunk.
+       if w.written == 0 && w.header.Get("Content-Length") == "" {
+               w.header.Set("Content-Length", "0")
+       }
+       // If this was an HTTP/1.0 request with keep-alive and we sent a
+       // Content-Length back, we can make this a keep-alive response ...
        if w.req.wantsHttp10KeepAlive() {
                sentLength := w.header.Get("Content-Length") != ""
                if sentLength && w.header.Get("Connection") == "keep-alive" {
@@ -817,13 +830,13 @@ func RedirectHandler(url string, code int) Handler {
 // patterns and calls the handler for the pattern that
 // most closely matches the URL.
 //
-// Patterns named fixed, rooted paths, like "/favicon.ico",
+// Patterns name fixed, rooted paths, like "/favicon.ico",
 // or rooted subtrees, like "/images/" (note the trailing slash).
 // Longer patterns take precedence over shorter ones, so that
 // if there are handlers registered for both "/images/"
 // and "/images/thumbnails/", the latter handler will be
 // called for paths beginning "/images/thumbnails/" and the
-// former will receiver requests for any other paths in the
+// former will receive requests for any other paths in the
 // "/images/" subtree.
 //
 // Patterns may optionally begin with a host name, restricting matches to
@@ -917,11 +930,13 @@ func (mux *ServeMux) handler(r *Request) Handler {
 // ServeHTTP dispatches the request to the handler whose
 // pattern most closely matches the request URL.
 func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) {
-       // Clean path to canonical form and redirect.
-       if p := cleanPath(r.URL.Path); p != r.URL.Path {
-               w.Header().Set("Location", p)
-               w.WriteHeader(StatusMovedPermanently)
-               return
+       if r.Method != "CONNECT" {
+               // Clean path to canonical form and redirect.
+               if p := cleanPath(r.URL.Path); p != r.URL.Path {
+                       w.Header().Set("Location", p)
+                       w.WriteHeader(StatusMovedPermanently)
+                       return
+               }
        }
        mux.handler(r).ServeHTTP(w, r)
 }
index 6efe191..6131d0d 100644 (file)
@@ -41,8 +41,9 @@ const DefaultMaxIdleConnsPerHost = 2
 // https, and http proxies (for either http or https with CONNECT).
 // Transport can also cache connections for future re-use.
 type Transport struct {
-       lk       sync.Mutex
+       idleLk   sync.Mutex
        idleConn map[string][]*persistConn
+       altLk    sync.RWMutex
        altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper
 
        // TODO: tunable on global max cached connections
@@ -131,12 +132,12 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) {
                return nil, errors.New("http: nil Request.Header")
        }
        if req.URL.Scheme != "http" && req.URL.Scheme != "https" {
-               t.lk.Lock()
+               t.altLk.RLock()
                var rt RoundTripper
                if t.altProto != nil {
                        rt = t.altProto[req.URL.Scheme]
                }
-               t.lk.Unlock()
+               t.altLk.RUnlock()
                if rt == nil {
                        return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme}
                }
@@ -170,8 +171,8 @@ func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) {
        if scheme == "http" || scheme == "https" {
                panic("protocol " + scheme + " already registered")
        }
-       t.lk.Lock()
-       defer t.lk.Unlock()
+       t.altLk.Lock()
+       defer t.altLk.Unlock()
        if t.altProto == nil {
                t.altProto = make(map[string]RoundTripper)
        }
@@ -186,17 +187,18 @@ func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) {
 // a "keep-alive" state. It does not interrupt any connections currently
 // in use.
 func (t *Transport) CloseIdleConnections() {
-       t.lk.Lock()
-       defer t.lk.Unlock()
-       if t.idleConn == nil {
+       t.idleLk.Lock()
+       m := t.idleConn
+       t.idleConn = nil
+       t.idleLk.Unlock()
+       if m == nil {
                return
        }
-       for _, conns := range t.idleConn {
+       for _, conns := range m {
                for _, pconn := range conns {
                        pconn.close()
                }
        }
-       t.idleConn = make(map[string][]*persistConn)
 }
 
 //
@@ -242,8 +244,6 @@ func (cm *connectMethod) proxyAuth() string {
 // If pconn is no longer needed or not in a good state, putIdleConn
 // returns false.
 func (t *Transport) putIdleConn(pconn *persistConn) bool {
-       t.lk.Lock()
-       defer t.lk.Unlock()
        if t.DisableKeepAlives || t.MaxIdleConnsPerHost < 0 {
                pconn.close()
                return false
@@ -256,21 +256,27 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool {
        if max == 0 {
                max = DefaultMaxIdleConnsPerHost
        }
+       t.idleLk.Lock()
+       if t.idleConn == nil {
+               t.idleConn = make(map[string][]*persistConn)
+       }
        if len(t.idleConn[key]) >= max {
+               t.idleLk.Unlock()
                pconn.close()
                return false
        }
        t.idleConn[key] = append(t.idleConn[key], pconn)
+       t.idleLk.Unlock()
        return true
 }
 
 func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) {
-       t.lk.Lock()
-       defer t.lk.Unlock()
+       key := cm.String()
+       t.idleLk.Lock()
+       defer t.idleLk.Unlock()
        if t.idleConn == nil {
-               t.idleConn = make(map[string][]*persistConn)
+               return nil
        }
-       key := cm.String()
        for {
                pconns, ok := t.idleConn[key]
                if !ok {
@@ -365,7 +371,18 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) {
 
        if cm.targetScheme == "https" {
                // Initiate TLS and check remote host name against certificate.
-               conn = tls.Client(conn, t.TLSClientConfig)
+               cfg := t.TLSClientConfig
+               if cfg == nil || cfg.ServerName == "" {
+                       host, _, _ := net.SplitHostPort(cm.addr())
+                       if cfg == nil {
+                               cfg = &tls.Config{ServerName: host}
+                       } else {
+                               clone := *cfg // shallow clone
+                               clone.ServerName = host
+                               cfg = &clone
+                       }
+               }
+               conn = tls.Client(conn, cfg)
                if err = conn.(*tls.Conn).Handshake(); err != nil {
                        return nil, err
                }
@@ -484,6 +501,7 @@ type persistConn struct {
        t        *Transport
        cacheKey string // its connectMethod.String()
        conn     net.Conn
+       closed   bool                // whether conn has been closed
        br       *bufio.Reader       // from conn
        bw       *bufio.Writer       // to conn
        reqch    chan requestAndChan // written by roundTrip(); read by readLoop()
@@ -501,8 +519,9 @@ type persistConn struct {
 
 func (pc *persistConn) isBroken() bool {
        pc.lk.Lock()
-       defer pc.lk.Unlock()
-       return pc.broken
+       b := pc.broken
+       pc.lk.Unlock()
+       return b
 }
 
 var remoteSideClosedFunc func(error) bool // or nil to use default
@@ -571,29 +590,32 @@ func (pc *persistConn) readLoop() {
 
                hasBody := resp != nil && resp.ContentLength != 0
                var waitForBodyRead chan bool
-               if alive {
-                       if hasBody {
-                               lastbody = resp.Body
-                               waitForBodyRead = make(chan bool)
-                               resp.Body.(*bodyEOFSignal).fn = func() {
-                                       if !pc.t.putIdleConn(pc) {
-                                               alive = false
-                                       }
-                                       waitForBodyRead <- true
-                               }
-                       } else {
-                               // When there's no response body, we immediately
-                               // reuse the TCP connection (putIdleConn), but
-                               // we need to prevent ClientConn.Read from
-                               // closing the Response.Body on the next
-                               // loop, otherwise it might close the body
-                               // before the client code has had a chance to
-                               // read it (even though it'll just be 0, EOF).
-                               lastbody = nil
-
-                               if !pc.t.putIdleConn(pc) {
+               if hasBody {
+                       lastbody = resp.Body
+                       waitForBodyRead = make(chan bool)
+                       resp.Body.(*bodyEOFSignal).fn = func() {
+                               if alive && !pc.t.putIdleConn(pc) {
                                        alive = false
                                }
+                               if !alive {
+                                       pc.close()
+                               }
+                               waitForBodyRead <- true
+                       }
+               }
+
+               if alive && !hasBody {
+                       // When there's no response body, we immediately
+                       // reuse the TCP connection (putIdleConn), but
+                       // we need to prevent ClientConn.Read from
+                       // closing the Response.Body on the next
+                       // loop, otherwise it might close the body
+                       // before the client code has had a chance to
+                       // read it (even though it'll just be 0, EOF).
+                       lastbody = nil
+
+                       if !pc.t.putIdleConn(pc) {
+                               alive = false
                        }
                }
 
@@ -604,6 +626,10 @@ func (pc *persistConn) readLoop() {
                if waitForBodyRead != nil {
                        <-waitForBodyRead
                }
+
+               if !alive {
+                       pc.close()
+               }
        }
 }
 
@@ -669,7 +695,10 @@ func (pc *persistConn) close() {
 
 func (pc *persistConn) closeLocked() {
        pc.broken = true
-       pc.conn.Close()
+       if !pc.closed {
+               pc.conn.Close()
+               pc.closed = true
+       }
        pc.mutateHeaderFunc = nil
 }
 
index a9e401d..e676bf6 100644 (file)
@@ -13,6 +13,7 @@ import (
        "fmt"
        "io"
        "io/ioutil"
+       "net"
        . "net/http"
        "net/http/httptest"
        "net/url"
@@ -20,6 +21,7 @@ import (
        "runtime"
        "strconv"
        "strings"
+       "sync"
        "testing"
        "time"
 )
@@ -35,6 +37,68 @@ var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
        w.Write([]byte(r.RemoteAddr))
 })
 
+// testCloseConn is a net.Conn tracked by a testConnSet.
+type testCloseConn struct {
+       net.Conn
+       set *testConnSet
+}
+
+func (c *testCloseConn) Close() error {
+       c.set.remove(c)
+       return c.Conn.Close()
+}
+
+// testConnSet tracks a set of TCP connections and whether they've
+// been closed.
+type testConnSet struct {
+       t      *testing.T
+       closed map[net.Conn]bool
+       list   []net.Conn // in order created
+       mutex  sync.Mutex
+}
+
+func (tcs *testConnSet) insert(c net.Conn) {
+       tcs.mutex.Lock()
+       defer tcs.mutex.Unlock()
+       tcs.closed[c] = false
+       tcs.list = append(tcs.list, c)
+}
+
+func (tcs *testConnSet) remove(c net.Conn) {
+       tcs.mutex.Lock()
+       defer tcs.mutex.Unlock()
+       tcs.closed[c] = true
+}
+
+// some tests use this to manage raw tcp connections for later inspection
+func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) {
+       connSet := &testConnSet{
+               t:      t,
+               closed: make(map[net.Conn]bool),
+       }
+       dial := func(n, addr string) (net.Conn, error) {
+               c, err := net.Dial(n, addr)
+               if err != nil {
+                       return nil, err
+               }
+               tc := &testCloseConn{c, connSet}
+               connSet.insert(tc)
+               return tc, nil
+       }
+       return connSet, dial
+}
+
+func (tcs *testConnSet) check(t *testing.T) {
+       tcs.mutex.Lock()
+       defer tcs.mutex.Unlock()
+
+       for i, c := range tcs.list {
+               if !tcs.closed[c] {
+                       t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list))
+               }
+       }
+}
+
 // Two subsequent requests and verify their response is the same.
 // The response from the server is our own IP:port
 func TestTransportKeepAlives(t *testing.T) {
@@ -72,8 +136,12 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) {
        ts := httptest.NewServer(hostPortHandler)
        defer ts.Close()
 
+       connSet, testDial := makeTestDial(t)
+
        for _, connectionClose := range []bool{false, true} {
-               tr := &Transport{}
+               tr := &Transport{
+                       Dial: testDial,
+               }
                c := &Client{Transport: tr}
 
                fetch := func(n int) string {
@@ -92,8 +160,8 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) {
                        if err != nil {
                                t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
                        }
-                       body, err := ioutil.ReadAll(res.Body)
                        defer res.Body.Close()
+                       body, err := ioutil.ReadAll(res.Body)
                        if err != nil {
                                t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
                        }
@@ -107,15 +175,23 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) {
                        t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
                                connectionClose, bodiesDiffer, body1, body2)
                }
+
+               tr.CloseIdleConnections()
        }
+
+       connSet.check(t)
 }
 
 func TestTransportConnectionCloseOnRequest(t *testing.T) {
        ts := httptest.NewServer(hostPortHandler)
        defer ts.Close()
 
+       connSet, testDial := makeTestDial(t)
+
        for _, connectionClose := range []bool{false, true} {
-               tr := &Transport{}
+               tr := &Transport{
+                       Dial: testDial,
+               }
                c := &Client{Transport: tr}
 
                fetch := func(n int) string {
@@ -149,7 +225,11 @@ func TestTransportConnectionCloseOnRequest(t *testing.T) {
                        t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
                                connectionClose, bodiesDiffer, body1, body2)
                }
+
+               tr.CloseIdleConnections()
        }
+
+       connSet.check(t)
 }
 
 func TestTransportIdleCacheKeys(t *testing.T) {
@@ -724,6 +804,35 @@ func TestTransportIdleConnCrash(t *testing.T) {
        <-didreq
 }
 
+// Test that the transport doesn't close the TCP connection early,
+// before the response body has been read.  This was a regression
+// which sadly lacked a triggering test.  The large response body made
+// the old race easier to trigger.
+func TestIssue3644(t *testing.T) {
+       const numFoos = 5000
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               w.Header().Set("Connection", "close")
+               for i := 0; i < numFoos; i++ {
+                       w.Write([]byte("foo "))
+               }
+       }))
+       defer ts.Close()
+       tr := &Transport{}
+       c := &Client{Transport: tr}
+       res, err := c.Get(ts.URL)
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer res.Body.Close()
+       bs, err := ioutil.ReadAll(res.Body)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if len(bs) != numFoos*len("foo ") {
+               t.Errorf("unexpected response length")
+       }
+}
+
 type fooProto struct{}
 
 func (fooProto) RoundTrip(req *Request) (*Response, error) {
index b23213e..ae21b3c 100644 (file)
@@ -6,7 +6,7 @@
 
 package net
 
-// IPAddr represents the address of a IP end point.
+// IPAddr represents the address of an IP end point.
 type IPAddr struct {
        IP IP
 }
@@ -21,7 +21,7 @@ func (a *IPAddr) String() string {
        return a.IP.String()
 }
 
-// ResolveIPAddr parses addr as a IP address and resolves domain
+// ResolveIPAddr parses addr as an IP address and resolves domain
 // names to numeric addresses on the network net, which must be
 // "ip", "ip4" or "ip6".  A literal IPv6 host address must be
 // enclosed in square brackets, as in "[::]".
index 43719fc..ea3321b 100644 (file)
@@ -59,7 +59,7 @@ func (c *IPConn) RemoteAddr() Addr {
 
 // IP-specific methods.
 
-// ReadFromIP reads a IP packet from c, copying the payload into b.
+// ReadFromIP reads an IP packet from c, copying the payload into b.
 // It returns the number of bytes copied into b and the return address
 // that was on the packet.
 //
@@ -75,7 +75,7 @@ func (c *IPConn) ReadFrom(b []byte) (int, Addr, error) {
        return 0, nil, syscall.EPLAN9
 }
 
-// WriteToIP writes a IP packet to addr via c, copying the payload from b.
+// WriteToIP writes an IP packet to addr via c, copying the payload from b.
 //
 // WriteToIP can be made to time out and return
 // an error with Timeout() == true after a fixed time limit;
index 9fc7ecd..dda81dd 100644 (file)
@@ -146,7 +146,7 @@ func (c *IPConn) SetWriteBuffer(bytes int) error {
 
 // IP-specific methods.
 
-// ReadFromIP reads a IP packet from c, copying the payload into b.
+// ReadFromIP reads an IP packet from c, copying the payload into b.
 // It returns the number of bytes copied into b and the return address
 // that was on the packet.
 //
@@ -184,7 +184,7 @@ func (c *IPConn) ReadFrom(b []byte) (int, Addr, error) {
        return n, uaddr.toAddr(), err
 }
 
-// WriteToIP writes a IP packet to addr via c, copying the payload from b.
+// WriteToIP writes an IP packet to addr via c, copying the payload from b.
 //
 // WriteToIP can be made to time out and return
 // an error with Timeout() == true after a fixed time limit;
index b610ccf..93cc4d1 100644 (file)
@@ -47,7 +47,8 @@ type Message struct {
 }
 
 // ReadMessage reads a message from r.
-// The headers are parsed, and the body of the message will be reading from r.
+// The headers are parsed, and the body of the message will be available
+// for reading from r.
 func ReadMessage(r io.Reader) (msg *Message, err error) {
        tp := textproto.NewReader(bufio.NewReader(r))
 
diff --git a/libgo/go/net/net_posix.go b/libgo/go/net/net_posix.go
new file mode 100644 (file)
index 0000000..3bcc54f
--- /dev/null
@@ -0,0 +1,110 @@
+// Copyright 2012 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin freebsd linux netbsd openbsd windows
+
+// Base posix socket functions.
+
+package net
+
+import (
+       "os"
+       "syscall"
+       "time"
+)
+
+type conn struct {
+       fd *netFD
+}
+
+func (c *conn) ok() bool { return c != nil && c.fd != nil }
+
+// Implementation of the Conn interface - see Conn for documentation.
+
+// Read implements the Conn Read method.
+func (c *conn) Read(b []byte) (int, error) {
+       if !c.ok() {
+               return 0, syscall.EINVAL
+       }
+       return c.fd.Read(b)
+}
+
+// Write implements the Conn Write method.
+func (c *conn) Write(b []byte) (int, error) {
+       if !c.ok() {
+               return 0, syscall.EINVAL
+       }
+       return c.fd.Write(b)
+}
+
+// LocalAddr returns the local network address.
+func (c *conn) LocalAddr() Addr {
+       if !c.ok() {
+               return nil
+       }
+       return c.fd.laddr
+}
+
+// RemoteAddr returns the remote network address.
+func (c *conn) RemoteAddr() Addr {
+       if !c.ok() {
+               return nil
+       }
+       return c.fd.raddr
+}
+
+// SetDeadline implements the Conn SetDeadline method.
+func (c *conn) SetDeadline(t time.Time) error {
+       if !c.ok() {
+               return syscall.EINVAL
+       }
+       return setDeadline(c.fd, t)
+}
+
+// SetReadDeadline implements the Conn SetReadDeadline method.
+func (c *conn) SetReadDeadline(t time.Time) error {
+       if !c.ok() {
+               return syscall.EINVAL
+       }
+       return setReadDeadline(c.fd, t)
+}
+
+// SetWriteDeadline implements the Conn SetWriteDeadline method.
+func (c *conn) SetWriteDeadline(t time.Time) error {
+       if !c.ok() {
+               return syscall.EINVAL
+       }
+       return setWriteDeadline(c.fd, t)
+}
+
+// SetReadBuffer sets the size of the operating system's
+// receive buffer associated with the connection.
+func (c *conn) SetReadBuffer(bytes int) error {
+       if !c.ok() {
+               return syscall.EINVAL
+       }
+       return setReadBuffer(c.fd, bytes)
+}
+
+// SetWriteBuffer sets the size of the operating system's
+// transmit buffer associated with the connection.
+func (c *conn) SetWriteBuffer(bytes int) error {
+       if !c.ok() {
+               return syscall.EINVAL
+       }
+       return setWriteBuffer(c.fd, bytes)
+}
+
+// File returns a copy of the underlying os.File, set to blocking mode.
+// It is the caller's responsibility to close f when finished.
+// Closing c does not affect f, and closing f does not affect c.
+func (c *conn) File() (f *os.File, err error) { return c.fd.dup() }
+
+// Close closes the connection.
+func (c *conn) Close() error {
+       if !c.ok() {
+               return syscall.EINVAL
+       }
+       return c.fd.Close()
+}
index e6c7441..adc29d5 100644 (file)
@@ -108,7 +108,7 @@ func TestClient(t *testing.T) {
                t.Errorf("Add: expected no error but got string %q", err.Error())
        }
        if reply.C != args.A+args.B {
-               t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
+               t.Errorf("Add: got %d expected %d", reply.C, args.A+args.B)
        }
 
        args = &Args{7, 8}
@@ -118,7 +118,7 @@ func TestClient(t *testing.T) {
                t.Errorf("Mul: expected no error but got string %q", err.Error())
        }
        if reply.C != args.A*args.B {
-               t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
+               t.Errorf("Mul: got %d expected %d", reply.C, args.A*args.B)
        }
 
        // Out of order.
@@ -133,7 +133,7 @@ func TestClient(t *testing.T) {
                t.Errorf("Add: expected no error but got string %q", addCall.Error.Error())
        }
        if addReply.C != args.A+args.B {
-               t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B)
+               t.Errorf("Add: got %d expected %d", addReply.C, args.A+args.B)
        }
 
        mulCall = <-mulCall.Done
@@ -141,7 +141,7 @@ func TestClient(t *testing.T) {
                t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error())
        }
        if mulReply.C != args.A*args.B {
-               t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B)
+               t.Errorf("Mul: got %d expected %d", mulReply.C, args.A*args.B)
        }
 
        // Error test
index 1680e2f..e528220 100644 (file)
 
        where T, T1 and T2 can be marshaled by encoding/gob.
        These requirements apply even if a different codec is used.
-       (In future, these requirements may soften for custom codecs.)
+       (In the future, these requirements may soften for custom codecs.)
 
        The method's first argument represents the arguments provided by the caller; the
        second argument represents the result parameters to be returned to the caller.
        The method's return value, if non-nil, is passed back as a string that the client
-       sees as if created by errors.New.
+       sees as if created by errors.New.  If an error is returned, the reply parameter
+       will not be sent back to the client.
 
        The server may handle requests on a single connection by calling ServeConn.  More
        typically it will create a network listener and call Accept or, for an HTTP
@@ -181,7 +182,7 @@ type Response struct {
 
 // Server represents an RPC Server.
 type Server struct {
-       mu         sync.Mutex // protects the serviceMap
+       mu         sync.RWMutex // protects the serviceMap
        serviceMap map[string]*service
        reqLock    sync.Mutex // protects freeReq
        freeReq    *Request
@@ -538,9 +539,9 @@ func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mt
                return
        }
        // Look up the request.
-       server.mu.Lock()
+       server.mu.RLock()
        service = server.serviceMap[serviceMethod[0]]
-       server.mu.Unlock()
+       server.mu.RUnlock()
        if service == nil {
                err = errors.New("rpc: can't find service " + req.ServiceMethod)
                return
index 0cd1926..b139c42 100644 (file)
@@ -144,22 +144,6 @@ func setDeadline(fd *netFD, t time.Time) error {
        return setWriteDeadline(fd, t)
 }
 
-func setReuseAddr(fd *netFD, reuse bool) error {
-       if err := fd.incref(false); err != nil {
-               return err
-       }
-       defer fd.decref()
-       return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, boolint(reuse)))
-}
-
-func setDontRoute(fd *netFD, dontroute bool) error {
-       if err := fd.incref(false); err != nil {
-               return err
-       }
-       defer fd.decref()
-       return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_DONTROUTE, boolint(dontroute)))
-}
-
 func setKeepAlive(fd *netFD, keepalive bool) error {
        if err := fd.incref(false); err != nil {
                return err
index 3c9dfb0..85260c8 100644 (file)
@@ -5,21 +5,36 @@
 package os
 
 func isExist(err error) bool {
-       if pe, ok := err.(*PathError); ok {
+       switch pe := err.(type) {
+       case nil:
+               return false
+       case *PathError:
+               err = pe.Err
+       case *LinkError:
                err = pe.Err
        }
        return contains(err.Error(), " exists")
 }
 
 func isNotExist(err error) bool {
-       if pe, ok := err.(*PathError); ok {
+       switch pe := err.(type) {
+       case nil:
+               return false
+       case *PathError:
+               err = pe.Err
+       case *LinkError:
                err = pe.Err
        }
        return contains(err.Error(), "does not exist")
 }
 
 func isPermission(err error) bool {
-       if pe, ok := err.(*PathError); ok {
+       switch pe := err.(type) {
+       case nil:
+               return false
+       case *PathError:
+               err = pe.Err
+       case *LinkError:
                err = pe.Err
        }
        return contains(err.Error(), "permission denied")
index 1685c1f..81b626a 100644 (file)
@@ -9,21 +9,36 @@ package os
 import "syscall"
 
 func isExist(err error) bool {
-       if pe, ok := err.(*PathError); ok {
+       switch pe := err.(type) {
+       case nil:
+               return false
+       case *PathError:
+               err = pe.Err
+       case *LinkError:
                err = pe.Err
        }
        return err == syscall.EEXIST || err == ErrExist
 }
 
 func isNotExist(err error) bool {
-       if pe, ok := err.(*PathError); ok {
+       switch pe := err.(type) {
+       case nil:
+               return false
+       case *PathError:
+               err = pe.Err
+       case *LinkError:
                err = pe.Err
        }
        return err == syscall.ENOENT || err == ErrNotExist
 }
 
 func isPermission(err error) bool {
-       if pe, ok := err.(*PathError); ok {
+       switch pe := err.(type) {
+       case nil:
+               return false
+       case *PathError:
+               err = pe.Err
+       case *LinkError:
                err = pe.Err
        }
        return err == syscall.EACCES || err == syscall.EPERM || err == ErrPermission
index 42f846f..054bb3f 100644 (file)
@@ -79,3 +79,30 @@ func checkErrorPredicate(predName string, pred func(error) bool, err error) stri
        }
        return ""
 }
+
+var isExistTests = []struct {
+       err   error
+       is    bool
+       isnot bool
+}{
+       {&os.PathError{Err: os.ErrInvalid}, false, false},
+       {&os.PathError{Err: os.ErrPermission}, false, false},
+       {&os.PathError{Err: os.ErrExist}, true, false},
+       {&os.PathError{Err: os.ErrNotExist}, false, true},
+       {&os.LinkError{Err: os.ErrInvalid}, false, false},
+       {&os.LinkError{Err: os.ErrPermission}, false, false},
+       {&os.LinkError{Err: os.ErrExist}, true, false},
+       {&os.LinkError{Err: os.ErrNotExist}, false, true},
+       {nil, false, false},
+}
+
+func TestIsExist(t *testing.T) {
+       for _, tt := range isExistTests {
+               if is := os.IsExist(tt.err); is != tt.is {
+                       t.Errorf("os.IsExist(%T %v) = %v, want %v", tt.err, tt.err, is, tt.is)
+               }
+               if isnot := os.IsNotExist(tt.err); isnot != tt.isnot {
+                       t.Errorf("os.IsNotExist(%T %v) = %v, want %v", tt.err, tt.err, isnot, tt.isnot)
+               }
+       }
+}
index fbb0d4f..83db6c0 100644 (file)
@@ -7,7 +7,12 @@ package os
 import "syscall"
 
 func isExist(err error) bool {
-       if pe, ok := err.(*PathError); ok {
+       switch pe := err.(type) {
+       case nil:
+               return false
+       case *PathError:
+               err = pe.Err
+       case *LinkError:
                err = pe.Err
        }
        return err == syscall.ERROR_ALREADY_EXISTS ||
@@ -15,7 +20,12 @@ func isExist(err error) bool {
 }
 
 func isNotExist(err error) bool {
-       if pe, ok := err.(*PathError); ok {
+       switch pe := err.(type) {
+       case nil:
+               return false
+       case *PathError:
+               err = pe.Err
+       case *LinkError:
                err = pe.Err
        }
        return err == syscall.ERROR_FILE_NOT_FOUND ||
@@ -23,7 +33,12 @@ func isNotExist(err error) bool {
 }
 
 func isPermission(err error) bool {
-       if pe, ok := err.(*PathError); ok {
+       switch pe := err.(type) {
+       case nil:
+               return false
+       case *PathError:
+               err = pe.Err
+       case *LinkError:
                err = pe.Err
        }
        return err == syscall.ERROR_ACCESS_DENIED || err == ErrPermission
index 531b87c..6681acf 100644 (file)
@@ -6,6 +6,7 @@ package os
 
 import (
        "runtime"
+       "sync/atomic"
        "syscall"
 )
 
@@ -13,7 +14,7 @@ import (
 type Process struct {
        Pid    int
        handle uintptr
-       done   bool // process has been successfully waited on
+       isdone uint32 // process has been successfully waited on, non zero if true
 }
 
 func newProcess(pid int, handle uintptr) *Process {
@@ -22,6 +23,14 @@ func newProcess(pid int, handle uintptr) *Process {
        return p
 }
 
+func (p *Process) setDone() {
+       atomic.StoreUint32(&p.isdone, 1)
+}
+
+func (p *Process) done() bool {
+       return atomic.LoadUint32(&p.isdone) > 0
+}
+
 // ProcAttr holds the attributes that will be applied to a new process
 // started by StartProcess.
 type ProcAttr struct {
index 9a8e181..c4907cd 100644 (file)
@@ -16,7 +16,7 @@ import (
        "syscall"
 )
 
-// Error records the name of a binary that failed to be be executed
+// Error records the name of a binary that failed to be executed
 // and the reason it failed.
 type Error struct {
        Name string
@@ -143,6 +143,9 @@ func (c *Cmd) argv() []string {
 func (c *Cmd) stdin() (f *os.File, err error) {
        if c.Stdin == nil {
                f, err = os.Open(os.DevNull)
+               if err != nil {
+                       return
+               }
                c.closeAfterStart = append(c.closeAfterStart, f)
                return
        }
@@ -182,6 +185,9 @@ func (c *Cmd) stderr() (f *os.File, err error) {
 func (c *Cmd) writerDescriptor(w io.Writer) (f *os.File, err error) {
        if w == nil {
                f, err = os.OpenFile(os.DevNull, os.O_WRONLY, 0)
+               if err != nil {
+                       return
+               }
                c.closeAfterStart = append(c.closeAfterStart, f)
                return
        }
index 52f4bce..27ebb60 100644 (file)
@@ -167,6 +167,18 @@ func TestExtraFiles(t *testing.T) {
        }
        defer ln.Close()
 
+       // Make sure duplicated fds don't leak to the child.
+       f, err := ln.(*net.TCPListener).File()
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer f.Close()
+       ln2, err := net.FileListener(f)
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer ln2.Close()
+
        // Force TLS root certs to be loaded (which might involve
        // cgo), to make sure none of that potential C code leaks fds.
        ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -203,6 +215,56 @@ func TestExtraFiles(t *testing.T) {
        }
 }
 
+func TestExtraFilesRace(t *testing.T) {
+       if runtime.GOOS == "windows" {
+               t.Logf("no operating system support; skipping")
+               return
+       }
+       listen := func() net.Listener {
+               ln, err := net.Listen("tcp", "127.0.0.1:0")
+               if err != nil {
+                       t.Fatal(err)
+               }
+               return ln
+       }
+       listenerFile := func(ln net.Listener) *os.File {
+               f, err := ln.(*net.TCPListener).File()
+               if err != nil {
+                       t.Fatal(err)
+               }
+               return f
+       }
+       runCommand := func(c *Cmd, out chan<- string) {
+               bout, err := c.CombinedOutput()
+               if err != nil {
+                       out <- "ERROR:" + err.Error()
+               } else {
+                       out <- string(bout)
+               }
+       }
+
+       for i := 0; i < 10; i++ {
+               la := listen()
+               ca := helperCommand("describefiles")
+               ca.ExtraFiles = []*os.File{listenerFile(la)}
+               lb := listen()
+               cb := helperCommand("describefiles")
+               cb.ExtraFiles = []*os.File{listenerFile(lb)}
+               ares := make(chan string)
+               bres := make(chan string)
+               go runCommand(ca, ares)
+               go runCommand(cb, bres)
+               if got, want := <-ares, fmt.Sprintf("fd3: listener %s\n", la.Addr()); got != want {
+                       t.Errorf("iteration %d, process A got:\n%s\nwant:\n%s\n", i, got, want)
+               }
+               if got, want := <-bres, fmt.Sprintf("fd3: listener %s\n", lb.Addr()); got != want {
+                       t.Errorf("iteration %d, process B got:\n%s\nwant:\n%s\n", i, got, want)
+               }
+               la.Close()
+               lb.Close()
+       }
+}
+
 // TestHelperProcess isn't a real test. It's used as a helper process
 // for TestParameterRun.
 func TestHelperProcess(*testing.T) {
@@ -318,6 +380,16 @@ func TestHelperProcess(*testing.T) {
        case "exit":
                n, _ := strconv.Atoi(args[0])
                os.Exit(n)
+       case "describefiles":
+               for fd := uintptr(3); fd < 25; fd++ {
+                       f := os.NewFile(fd, fmt.Sprintf("fd-%d", fd))
+                       ln, err := net.FileListener(f)
+                       if err == nil {
+                               fmt.Printf("fd%d: listener %s\n", fd, ln.Addr())
+                               ln.Close()
+                       }
+               }
+               os.Exit(0)
        default:
                fmt.Fprintf(os.Stderr, "Unknown command %q\n", cmd)
                os.Exit(2)
index 41cc8c2..01f06e2 100644 (file)
@@ -38,7 +38,7 @@ func (note Plan9Note) String() string {
 }
 
 func (p *Process) signal(sig Signal) error {
-       if p.done {
+       if p.done() {
                return errors.New("os: process already finished")
        }
 
@@ -76,7 +76,7 @@ func (p *Process) wait() (ps *ProcessState, err error) {
                }
 
                if waitmsg.Pid == p.Pid {
-                       p.done = true
+                       p.setDone()
                        break
                }
        }
index 70351cf..40fd0fd 100644 (file)
@@ -11,9 +11,10 @@ import (
 )
 
 func startProcess(name string, argv []string, attr *ProcAttr) (p *Process, err error) {
-       // Double-check existence of the directory we want
+       // If there is no SysProcAttr (ie. no Chroot or changed
+       // UID/GID), double-check existence of the directory we want
        // to chdir into.  We can make the error clearer this way.
-       if attr != nil && attr.Dir != "" {
+       if attr != nil && attr.Sys == nil && attr.Dir != "" {
                if _, err := Stat(attr.Dir); err != nil {
                        pe := err.(*PathError)
                        pe.Op = "chdir"
index ecfe535..fa3ba8a 100644 (file)
@@ -24,7 +24,7 @@ func (p *Process) wait() (ps *ProcessState, err error) {
                return nil, NewSyscallError("wait", e)
        }
        if pid1 != 0 {
-               p.done = true
+               p.setDone()
        }
        ps = &ProcessState{
                pid:    pid1,
@@ -35,7 +35,7 @@ func (p *Process) wait() (ps *ProcessState, err error) {
 }
 
 func (p *Process) signal(sig Signal) error {
-       if p.done {
+       if p.done() {
                return errors.New("os: process already finished")
        }
        s, ok := sig.(syscall.Signal)
index 5beca4a..4aa2ade 100644 (file)
@@ -32,7 +32,7 @@ func (p *Process) wait() (ps *ProcessState, err error) {
        if e != nil {
                return nil, NewSyscallError("GetProcessTimes", e)
        }
-       p.done = true
+       p.setDone()
        // NOTE(brainman): It seems that sometimes process is not dead
        // when WaitForSingleObject returns. But we do not know any
        // other way to wait for it. Sleeping for a while seems to do
@@ -43,7 +43,7 @@ func (p *Process) wait() (ps *ProcessState, err error) {
 }
 
 func (p *Process) signal(sig Signal) error {
-       if p.done {
+       if p.done() {
                return errors.New("os: process already finished")
        }
        if sig == Kill {
index 073bd56..1ba3293 100644 (file)
@@ -13,17 +13,6 @@ import (
 
 func sigpipe() // implemented in package runtime
 
-func epipecheck(file *File, e error) {
-       if e == syscall.EPIPE {
-               file.nepipe++
-               if file.nepipe >= 10 {
-                       sigpipe()
-               }
-       } else {
-               file.nepipe = 0
-       }
-}
-
 // Link creates newname as a hard link to the oldname file.
 // If there is an error, it will be of type *LinkError.
 func Link(oldname, newname string) error {
index b8fb2e2..f677dbb 100644 (file)
@@ -8,6 +8,7 @@ package os
 
 import (
        "runtime"
+       "sync/atomic"
        "syscall"
 )
 
@@ -24,7 +25,7 @@ type file struct {
        fd      int
        name    string
        dirinfo *dirInfo // nil unless directory being read
-       nepipe  int      // number of consecutive EPIPE in Write
+       nepipe  int32    // number of consecutive EPIPE in Write
 }
 
 // Fd returns the integer Unix file descriptor referencing the open file.
@@ -52,6 +53,16 @@ type dirInfo struct {
        dir *syscall.DIR // from opendir
 }
 
+func epipecheck(file *File, e error) {
+       if e == syscall.EPIPE {
+               if atomic.AddInt32(&file.nepipe, 1) >= 10 {
+                       sigpipe()
+               }
+       } else {
+               atomic.StoreInt32(&file.nepipe, 0)
+       }
+}
+
 // DevNull is the name of the operating system's ``null device.''
 // On Unix-like systems, it is "/dev/null"; on Windows, "NUL".
 const DevNull = "/dev/null"
index 8d3f677..5046e60 100644 (file)
@@ -67,10 +67,10 @@ var sysdir = func() (sd *sysDir) {
 
 func size(name string, t *testing.T) int64 {
        file, err := Open(name)
-       defer file.Close()
        if err != nil {
                t.Fatal("open failed:", err)
        }
+       defer file.Close()
        var buf [100]byte
        len := 0
        for {
@@ -132,10 +132,10 @@ func TestStat(t *testing.T) {
 func TestFstat(t *testing.T) {
        path := sfdir + "/" + sfname
        file, err1 := Open(path)
-       defer file.Close()
        if err1 != nil {
                t.Fatal("open failed:", err1)
        }
+       defer file.Close()
        dir, err2 := file.Stat()
        if err2 != nil {
                t.Fatal("fstat failed:", err2)
@@ -187,10 +187,10 @@ func TestRead0(t *testing.T) {
 
 func testReaddirnames(dir string, contents []string, t *testing.T) {
        file, err := Open(dir)
-       defer file.Close()
        if err != nil {
                t.Fatalf("open %q failed: %v", dir, err)
        }
+       defer file.Close()
        s, err2 := file.Readdirnames(-1)
        if err2 != nil {
                t.Fatalf("readdirnames %q failed: %v", dir, err2)
@@ -216,10 +216,10 @@ func testReaddirnames(dir string, contents []string, t *testing.T) {
 
 func testReaddir(dir string, contents []string, t *testing.T) {
        file, err := Open(dir)
-       defer file.Close()
        if err != nil {
                t.Fatalf("open %q failed: %v", dir, err)
        }
+       defer file.Close()
        s, err2 := file.Readdir(-1)
        if err2 != nil {
                t.Fatalf("readdir %q failed: %v", dir, err2)
@@ -283,10 +283,10 @@ func TestReaddirnamesOneAtATime(t *testing.T) {
                dir = "/bin"
        }
        file, err := Open(dir)
-       defer file.Close()
        if err != nil {
                t.Fatalf("open %q failed: %v", dir, err)
        }
+       defer file.Close()
        all, err1 := file.Readdirnames(-1)
        if err1 != nil {
                t.Fatalf("readdirnames %q failed: %v", dir, err1)
index 0c95c9c..ecb5787 100644 (file)
@@ -12,7 +12,7 @@ import (
 // Getpagesize returns the underlying system's memory page size.
 func Getpagesize() int { return syscall.Getpagesize() }
 
-// A FileInfo describes a file and is returned by Stat and Lstat
+// A FileInfo describes a file and is returned by Stat and Lstat.
 type FileInfo interface {
        Name() string       // base name of the file
        Size() int64        // length in bytes for regular files; system-dependent for others
index a7e0415..b07534b 100644 (file)
@@ -166,7 +166,8 @@ func IsAbs(path string) bool {
 }
 
 // Dir returns all but the last element of path, typically the path's directory.
-// The path is Cleaned and trailing slashes are removed before processing.
+// After dropping the final element using Split, the path is Cleaned and trailing
+// slashes are removed.
 // If the path is empty, Dir returns ".".
 // If the path consists entirely of slashes followed by non-slash bytes, Dir
 // returns a single slash. In any other case, the returned path does not end in a
index 77f0804..65be550 100644 (file)
@@ -181,6 +181,7 @@ var dirtests = []PathTest{
        {"x/", "x"},
        {"abc", "."},
        {"abc/def", "abc"},
+       {"abc////def", "abc"},
        {"a/b/.x", "a/b"},
        {"a/b/c.", "a/b"},
        {"a/b/c.x", "a/b"},
index e946c0a..56ba8a8 100644 (file)
@@ -1384,7 +1384,30 @@ func TestImportPath(t *testing.T) {
                path string
        }{
                {TypeOf(&base64.Encoding{}).Elem(), "encoding/base64"},
+               {TypeOf(int(0)), ""},
+               {TypeOf(int8(0)), ""},
+               {TypeOf(int16(0)), ""},
+               {TypeOf(int32(0)), ""},
+               {TypeOf(int64(0)), ""},
                {TypeOf(uint(0)), ""},
+               {TypeOf(uint8(0)), ""},
+               {TypeOf(uint16(0)), ""},
+               {TypeOf(uint32(0)), ""},
+               {TypeOf(uint64(0)), ""},
+               {TypeOf(uintptr(0)), ""},
+               {TypeOf(float32(0)), ""},
+               {TypeOf(float64(0)), ""},
+               {TypeOf(complex64(0)), ""},
+               {TypeOf(complex128(0)), ""},
+               {TypeOf(byte(0)), ""},
+               {TypeOf(rune(0)), ""},
+               {TypeOf([]byte(nil)), ""},
+               {TypeOf([]rune(nil)), ""},
+               {TypeOf(string("")), ""},
+               {TypeOf((*interface{})(nil)).Elem(), ""},
+               {TypeOf((*byte)(nil)), ""},
+               {TypeOf((*rune)(nil)), ""},
+               {TypeOf((*int64)(nil)), ""},
                {TypeOf(map[string]int{}), ""},
                {TypeOf((*error)(nil)).Elem(), ""},
        }
index a12fcb2..b25e5c7 100644 (file)
@@ -1705,10 +1705,11 @@ func ValueOf(i interface{}) Value {
        return Value{typ, unsafe.Pointer(eface.word), fl}
 }
 
-// Zero returns a Value representing a zero value for the specified type.
+// Zero returns a Value representing the zero value for the specified type.
 // The result is different from the zero value of the Value struct,
 // which represents no value at all.
 // For example, Zero(TypeOf(42)) returns a Value with Kind Int and value 0.
+// The returned value is neither addressable nor settable.
 func Zero(typ Type) Value {
        if typ == nil {
                panic("reflect: Zero(nil)")
index 87e6b1c..e4896a1 100644 (file)
@@ -441,7 +441,7 @@ func (re *Regexp) ReplaceAllLiteralString(src, repl string) string {
 }
 
 // ReplaceAllStringFunc returns a copy of src in which all matches of the
-// Regexp have been replaced by the return value of of function repl applied
+// Regexp have been replaced by the return value of function repl applied
 // to the matched substring.  The replacement returned by repl is substituted
 // directly, without using Expand.
 func (re *Regexp) ReplaceAllStringFunc(src string, repl func(string) string) string {
@@ -539,7 +539,7 @@ func (re *Regexp) ReplaceAllLiteral(src, repl []byte) []byte {
 }
 
 // ReplaceAllFunc returns a copy of src in which all matches of the
-// Regexp have been replaced by the return value of of function repl applied
+// Regexp have been replaced by the return value of function repl applied
 // to the matched byte slice.  The replacement returned by repl is substituted
 // directly, without using Expand.
 func (re *Regexp) ReplaceAllFunc(src []byte, repl func([]byte) []byte) []byte {
@@ -686,8 +686,9 @@ func (re *Regexp) FindStringIndex(s string) (loc []int) {
 
 // FindReaderIndex returns a two-element slice of integers defining the
 // location of the leftmost match of the regular expression in text read from
-// the RuneReader.  The match itself is at s[loc[0]:loc[1]].  A return
-// value of nil indicates no match.
+// the RuneReader.  The match text was found in the input stream at
+// byte offset loc[0] through loc[1]-1.
+// A return value of nil indicates no match.
 func (re *Regexp) FindReaderIndex(r io.RuneReader) (loc []int) {
        a := re.doExecute(r, nil, "", 0, 2)
        if a == nil {
index 87f17d2..592c4a2 100644 (file)
@@ -357,7 +357,7 @@ func countHeap() int {
        return n
 }
 
-// writeHeapProfile writes the current runtime heap profile to w.
+// writeHeap writes the current runtime heap profile to w.
 func writeHeap(w io.Writer, debug int) error {
        // Find out how many records there are (MemProfile(nil, false)),
        // allocate that many records, and get the data.
index e933058..4740115 100644 (file)
@@ -6,6 +6,7 @@ package pprof_test
 
 import (
        "bytes"
+       "fmt"
        "hash/crc32"
        "os/exec"
        "runtime"
@@ -49,19 +50,27 @@ func TestCPUProfile(t *testing.T) {
 
        // Convert []byte to []uintptr.
        bytes := prof.Bytes()
+       l := len(bytes) / int(unsafe.Sizeof(uintptr(0)))
        val := *(*[]uintptr)(unsafe.Pointer(&bytes))
-       val = val[:len(bytes)/int(unsafe.Sizeof(uintptr(0)))]
+       val = val[:l]
 
-       if len(val) < 10 {
+       if l < 13 {
                t.Fatalf("profile too short: %#x", val)
        }
-       if val[0] != 0 || val[1] != 3 || val[2] != 0 || val[3] != 1e6/100 || val[4] != 0 {
-               t.Fatalf("unexpected header %#x", val[:5])
+
+       fmt.Println(val, l)
+       hd, val, tl := val[:5], val[5:l-3], val[l-3:]
+       fmt.Println(hd, val, tl)
+       if hd[0] != 0 || hd[1] != 3 || hd[2] != 0 || hd[3] != 1e6/100 || hd[4] != 0 {
+               t.Fatalf("unexpected header %#x", hd)
+       }
+
+       if tl[0] != 0 || tl[1] != 1 || tl[2] != 0 {
+               t.Fatalf("malformed end-of-data marker %#x", tl)
        }
 
        // Check that profile is well formed and contains ChecksumIEEE.
        found := false
-       val = val[5:]
        for len(val) > 0 {
                if len(val) < 2 || val[0] < 1 || val[1] < 1 || uintptr(len(val)) < 2+val[1] {
                        t.Fatalf("malformed profile.  leftover: %#x", val)
index 59ef264..bdd5d71 100644 (file)
@@ -44,7 +44,7 @@ func cutoff64(base int) uint64 {
 }
 
 // ParseUint is like ParseInt but for unsigned numbers.
-func ParseUint(s string, b int, bitSize int) (n uint64, err error) {
+func ParseUint(s string, base int, bitSize int) (n uint64, err error) {
        var cutoff, maxVal uint64
 
        if bitSize == 0 {
@@ -57,32 +57,32 @@ func ParseUint(s string, b int, bitSize int) (n uint64, err error) {
                err = ErrSyntax
                goto Error
 
-       case 2 <= b && b <= 36:
+       case 2 <= base && base <= 36:
                // valid base; nothing to do
 
-       case b == 0:
+       case base == 0:
                // Look for octal, hex prefix.
                switch {
                case s[0] == '0' && len(s) > 1 && (s[1] == 'x' || s[1] == 'X'):
-                       b = 16
+                       base = 16
                        s = s[2:]
                        if len(s) < 1 {
                                err = ErrSyntax
                                goto Error
                        }
                case s[0] == '0':
-                       b = 8
+                       base = 8
                default:
-                       b = 10
+                       base = 10
                }
 
        default:
-               err = errors.New("invalid base " + Itoa(b))
+               err = errors.New("invalid base " + Itoa(base))
                goto Error
        }
 
        n = 0
-       cutoff = cutoff64(b)
+       cutoff = cutoff64(base)
        maxVal = 1<<uint(bitSize) - 1
 
        for i := 0; i < len(s); i++ {
@@ -100,19 +100,19 @@ func ParseUint(s string, b int, bitSize int) (n uint64, err error) {
                        err = ErrSyntax
                        goto Error
                }
-               if int(v) >= b {
+               if int(v) >= base {
                        n = 0
                        err = ErrSyntax
                        goto Error
                }
 
                if n >= cutoff {
-                       // n*b overflows
+                       // n*base overflows
                        n = 1<<64 - 1
                        err = ErrRange
                        goto Error
                }
-               n *= uint64(b)
+               n *= uint64(base)
 
                n1 := n + uint64(v)
                if n1 < n || n1 > maxVal {
index 0165b1f..bc9e738 100644 (file)
@@ -32,10 +32,11 @@ type WaitGroup struct {
 
 // Add adds delta, which may be negative, to the WaitGroup counter.
 // If the counter becomes zero, all goroutines blocked on Wait() are released.
+// If the counter goes negative, Add panics.
 func (wg *WaitGroup) Add(delta int) {
        v := atomic.AddInt32(&wg.counter, int32(delta))
        if v < 0 {
-               panic("sync: negative WaitGroup count")
+               panic("sync: negative WaitGroup counter")
        }
        if v > 0 || atomic.LoadInt32(&wg.waiters) == 0 {
                return
index 34430fc..84c4cfc 100644 (file)
@@ -50,7 +50,7 @@ func TestWaitGroup(t *testing.T) {
 func TestWaitGroupMisuse(t *testing.T) {
        defer func() {
                err := recover()
-               if err != "sync: negative WaitGroup count" {
+               if err != "sync: negative WaitGroup counter" {
                        t.Fatalf("Unexpected panic: %#v", err)
                }
        }()
index 8308f10..3107ae5 100644 (file)
@@ -12,14 +12,18 @@ import (
 )
 
 func Getenv(key string) (value string, found bool) {
+       keyp, err := utf16PtrFromString(key)
+       if err != nil {
+               return "", false
+       }
        b := make([]uint16, 100)
-       n, e := GetEnvironmentVariable(StringToUTF16Ptr(key), &b[0], uint32(len(b)))
+       n, e := GetEnvironmentVariable(keyp, &b[0], uint32(len(b)))
        if n == 0 && e == ERROR_ENVVAR_NOT_FOUND {
                return "", false
        }
        if n > uint32(len(b)) {
                b = make([]uint16, n)
-               n, e = GetEnvironmentVariable(StringToUTF16Ptr(key), &b[0], uint32(len(b)))
+               n, e = GetEnvironmentVariable(keyp, &b[0], uint32(len(b)))
                if n > uint32(len(b)) {
                        n = 0
                }
@@ -32,10 +36,18 @@ func Getenv(key string) (value string, found bool) {
 
 func Setenv(key, value string) error {
        var v *uint16
+       var err error
        if len(value) > 0 {
-               v = StringToUTF16Ptr(value)
+               v, err = utf16PtrFromString(value)
+               if err != nil {
+                       return err
+               }
+       }
+       keyp, err := utf16PtrFromString(key)
+       if err != nil {
+               return err
        }
-       e := SetEnvironmentVariable(StringToUTF16Ptr(key), v)
+       e := SetEnvironmentVariable(keyp, v)
        if e != nil {
                return e
        }
index 664908d..b34ee1b 100644 (file)
@@ -103,8 +103,9 @@ import (
 
 var ForkLock sync.RWMutex
 
-// Convert array of string to array
-// of NUL-terminated byte pointer.
+// Convert array of string to array of NUL-terminated byte pointer.
+// If any string contains a NUL byte this function panics instead
+// of returning an error.
 func StringSlicePtr(ss []string) []*byte {
        bb := make([]*byte, len(ss)+1)
        for i := 0; i < len(ss); i++ {
@@ -114,6 +115,22 @@ func StringSlicePtr(ss []string) []*byte {
        return bb
 }
 
+// slicePtrFromStrings converts a slice of strings to a slice of
+// pointers to NUL-terminated byte slices. If any string contains
+// a NUL byte, it returns (nil, EINVAL).
+func slicePtrFromStrings(ss []string) ([]*byte, error) {
+       var err error
+       bb := make([]*byte, len(ss)+1)
+       for i := 0; i < len(ss); i++ {
+               bb[i], err = bytePtrFromString(ss[i])
+               if err != nil {
+                       return nil, err
+               }
+       }
+       bb[len(ss)] = nil
+       return bb, nil
+}
+
 func CloseOnExec(fd int) { fcntl(fd, F_SETFD, FD_CLOEXEC) }
 
 func SetNonblock(fd int, nonblocking bool) (err error) {
@@ -168,9 +185,18 @@ func forkExec(argv0 string, argv []string, attr *ProcAttr) (pid int, err error)
        p[1] = -1
 
        // Convert args to C form.
-       argv0p := StringBytePtr(argv0)
-       argvp := StringSlicePtr(argv)
-       envvp := StringSlicePtr(attr.Env)
+       argv0p, err := bytePtrFromString(argv0)
+       if err != nil {
+               return 0, err
+       }
+       argvp, err := slicePtrFromStrings(argv)
+       if err != nil {
+               return 0, err
+       }
+       envvp, err := slicePtrFromStrings(attr.Env)
+       if err != nil {
+               return 0, err
+       }
 
        if runtime.GOOS == "freebsd" && len(argv[0]) > len(argv0) {
                argvp[0] = argv0p
@@ -178,11 +204,17 @@ func forkExec(argv0 string, argv []string, attr *ProcAttr) (pid int, err error)
 
        var chroot *byte
        if sys.Chroot != "" {
-               chroot = StringBytePtr(sys.Chroot)
+               chroot, err = bytePtrFromString(sys.Chroot)
+               if err != nil {
+                       return 0, err
+               }
        }
        var dir *byte
        if attr.Dir != "" {
-               dir = StringBytePtr(attr.Dir)
+               dir, err = bytePtrFromString(attr.Dir)
+               if err != nil {
+                       return 0, err
+               }
        }
 
        // Acquire the fork lock so that no other threads
@@ -254,8 +286,18 @@ func StartProcess(argv0 string, argv []string, attr *ProcAttr) (pid int, handle
 
 // Ordinary exec.
 func Exec(argv0 string, argv []string, envv []string) (err error) {
-       err1 := raw_execve(StringBytePtr(argv0),
-               &StringSlicePtr(argv)[0],
-               &StringSlicePtr(envv)[0])
+       argv0p, err := bytePtrFromString(argv0)
+       if err != nil {
+               return err
+       }
+       argvp, err := slicePtrFromStrings(argv)
+       if err != nil {
+               return err
+       }
+       envvp, err := slicePtrFromStrings(envv)
+       if err != nil {
+               return err
+       }
+       err1 := raw_execve(argv0p, &argvp[0], &envvp[0])
        return Errno(err1)
 }
index 4dc4d05..68779c4 100644 (file)
@@ -132,7 +132,10 @@ func SetNonblock(fd Handle, nonblocking bool) (err error) {
 // getFullPath retrieves the full path of the specified file.
 // Just a wrapper for Windows GetFullPathName api.
 func getFullPath(name string) (path string, err error) {
-       p := StringToUTF16Ptr(name)
+       p, err := utf16PtrFromString(name)
+       if err != nil {
+               return "", err
+       }
        buf := make([]uint16, 100)
        n, err := GetFullPathName(p, uint32(len(buf)), &buf[0], nil)
        if err != nil {
@@ -261,7 +264,10 @@ func StartProcess(argv0 string, argv []string, attr *ProcAttr) (pid int, handle
                        return 0, 0, err
                }
        }
-       argv0p := StringToUTF16Ptr(argv0)
+       argv0p, err := utf16PtrFromString(argv0)
+       if err != nil {
+               return 0, 0, err
+       }
 
        var cmdline string
        // Windows CreateProcess takes the command line as a single string:
@@ -275,12 +281,18 @@ func StartProcess(argv0 string, argv []string, attr *ProcAttr) (pid int, handle
 
        var argvp *uint16
        if len(cmdline) != 0 {
-               argvp = StringToUTF16Ptr(cmdline)
+               argvp, err = utf16PtrFromString(cmdline)
+               if err != nil {
+                       return 0, 0, err
+               }
        }
 
        var dirp *uint16
        if len(attr.Dir) != 0 {
-               dirp = StringToUTF16Ptr(attr.Dir)
+               dirp, err = utf16PtrFromString(attr.Dir)
+               if err != nil {
+                       return 0, 0, err
+               }
        }
 
        // Acquire the fork lock so that no other threads
index bd40fe5..4353af4 100644 (file)
@@ -37,10 +37,13 @@ const (
 // TranslateAccountName converts a directory service
 // object name from one format to another.
 func TranslateAccountName(username string, from, to uint32, initSize int) (string, error) {
-       u := StringToUTF16Ptr(username)
+       u, e := utf16PtrFromString(username)
+       if e != nil {
+               return "", e
+       }
        b := make([]uint16, 50)
        n := uint32(len(b))
-       e := TranslateName(u, from, to, &b[0], &n)
+       e = TranslateName(u, from, to, &b[0], &n)
        if e != nil {
                if e != ERROR_INSUFFICIENT_BUFFER {
                        return "", e
@@ -94,7 +97,11 @@ type SID struct{}
 // sid into a valid, functional sid.
 func StringToSid(s string) (*SID, error) {
        var sid *SID
-       e := ConvertStringSidToSid(StringToUTF16Ptr(s), &sid)
+       p, e := utf16PtrFromString(s)
+       if e != nil {
+               return nil, e
+       }
+       e = ConvertStringSidToSid(p, &sid)
        if e != nil {
                return nil, e
        }
@@ -109,17 +116,23 @@ func LookupSID(system, account string) (sid *SID, domain string, accType uint32,
        if len(account) == 0 {
                return nil, "", 0, EINVAL
        }
-       acc := StringToUTF16Ptr(account)
+       acc, e := utf16PtrFromString(account)
+       if e != nil {
+               return nil, "", 0, e
+       }
        var sys *uint16
        if len(system) > 0 {
-               sys = StringToUTF16Ptr(system)
+               sys, e = utf16PtrFromString(system)
+               if e != nil {
+                       return nil, "", 0, e
+               }
        }
        db := make([]uint16, 50)
        dn := uint32(len(db))
        b := make([]byte, 50)
        n := uint32(len(b))
        sid = (*SID)(unsafe.Pointer(&b[0]))
-       e := LookupAccountName(sys, acc, sid, &n, &db[0], &dn, &accType)
+       e = LookupAccountName(sys, acc, sid, &n, &db[0], &dn, &accType)
        if e != nil {
                if e != ERROR_INSUFFICIENT_BUFFER {
                        return nil, "", 0, e
@@ -170,7 +183,10 @@ func (sid *SID) Copy() (*SID, error) {
 func (sid *SID) LookupAccount(system string) (account, domain string, accType uint32, err error) {
        var sys *uint16
        if len(system) > 0 {
-               sys = StringToUTF16Ptr(system)
+               sys, err = utf16PtrFromString(system)
+               if err != nil {
+                       return "", "", 0, err
+               }
        }
        b := make([]uint16, 50)
        n := uint32(len(b))
index 4efaaec..3090a5e 100644 (file)
@@ -16,18 +16,47 @@ package syscall
 
 import "unsafe"
 
-// StringByteSlice returns a NUL-terminated slice of bytes
-// containing the text of s.
+// StringByteSlice returns a NUL-terminated slice of bytes containing the text of s.
+// If s contains a NUL byte this function panics instead of
+// returning an error.
 func StringByteSlice(s string) []byte {
+       a, err := byteSliceFromString(s)
+       if err != nil {
+               panic("syscall: string with NUL passed to StringByteSlice")
+       }
+       return a
+}
+
+// byteSliceFromString returns a NUL-terminated slice of bytes
+// containing the text of s. If s contains a NUL byte at any
+// location, it returns (nil, EINVAL).
+func byteSliceFromString(s string) ([]byte, error) {
+       for i := 0; i < len(s); i++ {
+               if s[i] == 0 {
+                       return nil, EINVAL
+               }
+       }
        a := make([]byte, len(s)+1)
        copy(a, s)
-       return a
+       return a, nil
 }
 
-// StringBytePtr returns a pointer to a NUL-terminated array of bytes
-// containing the text of s.
+// StringBytePtr returns a pointer to a NUL-terminated array of bytes containing the text of s.
+// If s contains a NUL byte this function panics instead of
+// returning an error.
 func StringBytePtr(s string) *byte { return &StringByteSlice(s)[0] }
 
+// bytePtrFromString returns a pointer to a NUL-terminated array of
+// bytes containing the text of s. If s contains a NUL byte at any
+// location, it returns (nil, EINVAL).
+func bytePtrFromString(s string) (*byte, error) {
+       a, err := byteSliceFromString(s)
+       if err != nil {
+               return nil, err
+       }
+       return &a[0], nil
+}
+
 // Single-word zero for use when we need a valid pointer to 0 bytes.
 // See mksyscall.pl.
 var _zero uintptr
index 9a988a5..08422de 100644 (file)
@@ -8,13 +8,9 @@ package syscall
 
 import "unsafe"
 
-func (r *PtraceRegs) PC() uint64 {
-       return uint64(uint32(r.Eip))
-}
+func (r *PtraceRegs) PC() uint64 { return uint64(uint32(r.Eip)) }
 
-func (r *PtraceRegs) SetPC(pc uint64) {
-       r.Eip = int32(pc)
-}
+func (r *PtraceRegs) SetPC(pc uint64) { r.Eip = int32(pc) }
 
 func PtraceGetRegs(pid int, regsout *PtraceRegs) (err error) {
        return ptrace(PTRACE_GETREGS, pid, 0, uintptr(unsafe.Pointer(regsout)))
index 1cb8a07..2d2f45e 100644 (file)
@@ -80,6 +80,7 @@ package testing
 
 import (
        _ "debug/elf"
+       "bytes"
        "flag"
        "fmt"
        "os"
@@ -87,6 +88,7 @@ import (
        "runtime/pprof"
        "strconv"
        "strings"
+       "sync"
        "time"
 )
 
@@ -116,8 +118,10 @@ var (
 // common holds the elements common between T and B and
 // captures common methods such as Errorf.
 type common struct {
-       output   []byte    // Output generated by test or benchmark.
-       failed   bool      // Test or benchmark has failed.
+       mu     sync.RWMutex // guards output and failed
+       output []byte       // Output generated by test or benchmark.
+       failed bool         // Test or benchmark has failed.
+
        start    time.Time // Time test or benchmark started
        duration time.Duration
        self     interface{}      // To be sent on signal channel when done.
@@ -129,37 +133,42 @@ func Short() bool {
        return *short
 }
 
-// decorate inserts the final newline if needed and indentation tabs for formatting.
-// If addFileLine is true, it also prefixes the string with the file and line of the call site.
-func decorate(s string, addFileLine bool) string {
-       if addFileLine {
-               _, file, line, ok := runtime.Caller(3) // decorate + log + public function.
-               if ok {
-                       // Truncate file name at last file name separator.
-                       if index := strings.LastIndex(file, "/"); index >= 0 {
-                               file = file[index+1:]
-                       } else if index = strings.LastIndex(file, "\\"); index >= 0 {
-                               file = file[index+1:]
-                       }
-               } else {
-                       file = "???"
-                       line = 1
+// decorate prefixes the string with the file and line of the call site
+// and inserts the final newline if needed and indentation tabs for formatting.
+func decorate(s string) string {
+       _, file, line, ok := runtime.Caller(3) // decorate + log + public function.
+       if ok {
+               // Truncate file name at last file name separator.
+               if index := strings.LastIndex(file, "/"); index >= 0 {
+                       file = file[index+1:]
+               } else if index = strings.LastIndex(file, "\\"); index >= 0 {
+                       file = file[index+1:]
                }
-               s = fmt.Sprintf("%s:%d: %s", file, line, s)
-       }
-       s = "\t" + s // Every line is indented at least one tab.
-       n := len(s)
-       if n > 0 && s[n-1] != '\n' {
-               s += "\n"
-               n++
+       } else {
+               file = "???"
+               line = 1
        }
-       for i := 0; i < n-1; i++ { // -1 to avoid final newline
-               if s[i] == '\n' {
+       buf := new(bytes.Buffer)
+       fmt.Fprintf(buf, "%s:%d: ", file, line)
+
+       lines := strings.Split(s, "\n")
+       for i, line := range lines {
+               if i > 0 {
+                       buf.WriteByte('\n')
+               }
+               // Every line is indented at least one tab.
+               buf.WriteByte('\t')
+               if i > 0 {
                        // Second and subsequent lines are indented an extra tab.
-                       return s[0:i+1] + "\t" + decorate(s[i+1:n], false)
+                       buf.WriteByte('\t')
                }
+               buf.WriteString(line)
+       }
+       if l := len(s); l > 0 && s[len(s)-1] != '\n' {
+               // Add final new line if needed.
+               buf.WriteByte('\n')
        }
-       return s
+       return buf.String()
 }
 
 // T is a type passed to Test functions to manage test state and support formatted test logs.
@@ -171,10 +180,18 @@ type T struct {
 }
 
 // Fail marks the function as having failed but continues execution.
-func (c *common) Fail() { c.failed = true }
+func (c *common) Fail() {
+       c.mu.Lock()
+       defer c.mu.Unlock()
+       c.failed = true
+}
 
 // Failed returns whether the function has failed.
-func (c *common) Failed() bool { return c.failed }
+func (c *common) Failed() bool {
+       c.mu.RLock()
+       defer c.mu.RUnlock()
+       return c.failed
+}
 
 // FailNow marks the function as having failed and stops its execution.
 // Execution will continue at the next test or benchmark.
@@ -205,7 +222,9 @@ func (c *common) FailNow() {
 
 // log generates the output. It's always at the same stack depth.
 func (c *common) log(s string) {
-       c.output = append(c.output, decorate(s, true)...)
+       c.mu.Lock()
+       defer c.mu.Unlock()
+       c.output = append(c.output, decorate(s)...)
 }
 
 // Log formats its arguments using default formatting, analogous to Println(),
@@ -298,7 +317,7 @@ func Main(matchString func(pat, str string) (bool, error), tests []InternalTest,
 func (t *T) report() {
        tstr := fmt.Sprintf("(%.2f seconds)", t.duration.Seconds())
        format := "--- %s: %s %s\n%s"
-       if t.failed {
+       if t.Failed() {
                fmt.Printf(format, "FAIL", t.name, tstr, t.output)
        } else if *chatty {
                fmt.Printf(format, "PASS", t.name, tstr, t.output)
@@ -357,7 +376,7 @@ func RunTests(matchString func(pat, str string) (bool, error), tests []InternalT
                                continue
                        }
                        t.report()
-                       ok = ok && !out.failed
+                       ok = ok && !out.Failed()
                }
 
                running := 0
@@ -370,7 +389,7 @@ func RunTests(matchString func(pat, str string) (bool, error), tests []InternalT
                        }
                        t := (<-collector).(*T)
                        t.report()
-                       ok = ok && !t.failed
+                       ok = ok && !t.Failed()
                        running--
                }
        }
index ce84600..722ac8d 100644 (file)
@@ -547,7 +547,7 @@ func (b *Writer) Write(buf []byte) (n int, err error) {
 }
 
 // NewWriter allocates and initializes a new tabwriter.Writer.
-// The parameters are the same as for the the Init function.
+// The parameters are the same as for the Init function.
 //
 func NewWriter(output io.Writer, minwidth, tabwidth, padding int, padchar byte, flags uint) *Writer {
        return new(Writer).Init(output, minwidth, tabwidth, padding, padchar, flags)
index aa50ab9..4a1682d 100644 (file)
@@ -198,7 +198,7 @@ If a "range" action initializes a variable, the variable is set to the
 successive elements of the iteration.  Also, a "range" may declare two
 variables, separated by a comma:
 
-       $index, $element := pipeline
+       range $index, $element := pipeline
 
 in which case $index and $element are set to the successive values of the
 array/slice index or map key and element, respectively.  Note that if there is
index f4ae50f..6414953 100644 (file)
@@ -387,7 +387,7 @@ var execTests = []execTest{
        {"slice[WRONG]", "{{index .SI `hello`}}", "", tVal, false},
        {"map[one]", "{{index .MSI `one`}}", "1", tVal, true},
        {"map[two]", "{{index .MSI `two`}}", "2", tVal, true},
-       {"map[NO]", "{{index .MSI `XXX`}}", "", tVal, true},
+       {"map[NO]", "{{index .MSI `XXX`}}", "0", tVal, true},
        {"map[WRONG]", "{{index .MSI 10}}", "", tVal, false},
        {"double index", "{{index .SMSI 1 `eleven`}}", "11", tVal, true},
 
index 8fbf0ef..e6fa0fb 100644 (file)
@@ -128,7 +128,7 @@ func index(item interface{}, indices ...interface{}) (interface{}, error) {
                        if x := v.MapIndex(index); x.IsValid() {
                                v = x
                        } else {
-                               v = reflect.Zero(v.Type().Key())
+                               v = reflect.Zero(v.Type().Elem())
                        }
                default:
                        return nil, fmt.Errorf("can't index item of type %s", index.Type())
@@ -154,7 +154,7 @@ func length(item interface{}) (int, error) {
 
 // Function invocation
 
-// call returns the result of evaluating the the first argument as a function.
+// call returns the result of evaluating the first argument as a function.
 // The function must return 1 result, or 2 results, the second of which is an error.
 func call(fn interface{}, args ...interface{}) (interface{}, error) {
        v := reflect.ValueOf(fn)
index 7705c0b..c4e1a56 100644 (file)
@@ -257,16 +257,17 @@ func lexText(l *lexer) stateFn {
 
 // lexLeftDelim scans the left delimiter, which is known to be present.
 func lexLeftDelim(l *lexer) stateFn {
-       if strings.HasPrefix(l.input[l.pos:], l.leftDelim+leftComment) {
+       l.pos += len(l.leftDelim)
+       if strings.HasPrefix(l.input[l.pos:], leftComment) {
                return lexComment
        }
-       l.pos += len(l.leftDelim)
        l.emit(itemLeftDelim)
        return lexInsideAction
 }
 
 // lexComment scans a comment. The left comment marker is known to be present.
 func lexComment(l *lexer) stateFn {
+       l.pos += len(leftComment)
        i := strings.Index(l.input[l.pos:], rightComment+l.rightDelim)
        if i < 0 {
                return l.errorf("unclosed comment")
index 6ee1b47..f3b23c9 100644 (file)
@@ -198,6 +198,10 @@ var lexTests = []lexTest{
                tRight,
                tEOF,
        }},
+       {"text with bad comment", "hello-{{/*/}}-world", []item{
+               {itemText, "hello-"},
+               {itemError, `unclosed comment`},
+       }},
 }
 
 // collect gathers the emitted items into a slice.
index 2461dac..d48ca0c 100644 (file)
@@ -241,10 +241,10 @@ func (t Time) IsZero() bool {
 // It is called when computing a presentation property like Month or Hour.
 func (t Time) abs() uint64 {
        l := t.loc
-       if l == nil {
-               l = &utcLoc
+       // Avoid function calls when possible.
+       if l == nil || l == &localLoc {
+               l = l.get()
        }
-       // Avoid function call if we hit the local time cache.
        sec := t.sec + internalToUnix
        if l != &utcLoc {
                if l.cacheZone != nil && l.cacheStart <= sec && sec < l.cacheEnd {
index d0a1612..1e389a2 100644 (file)
@@ -87,7 +87,7 @@ runtime_makechan_c(ChanType *t, int64 hint)
        Hchan *c;
        int32 n;
        const Type *elem;
-       
+
        elem = t->__element_type;
 
        if(hint < 0 || (int32)hint != hint || (elem->__size > 0 && (uintptr)hint > MaxMem / elem->__size))
@@ -191,7 +191,7 @@ runtime_chansend(ChanType *t, Hchan *c, byte *ep, bool *pres)
        sg = dequeue(&c->recvq);
        if(sg != nil) {
                runtime_unlock(c);
-               
+
                gp = sg->g;
                gp->param = sg;
                if(sg->elem != nil)
@@ -530,7 +530,7 @@ runtime_selectnbrecv(ChanType *t, byte *v, Hchan *c)
 
        runtime_chanrecv(t, c, v, &selected, nil);
        return selected;
-}      
+}
 
 // func selectnbrecv2(elem *any, ok *bool, c chan any) bool
 //
@@ -562,7 +562,7 @@ runtime_selectnbrecv2(ChanType *t, byte *v, _Bool *received, Hchan *c)
        if(received != nil)
                *received = r;
        return selected;
-}      
+}
 
 // For reflect:
 //     func chansend(c chan, val iword, nb bool) (selected bool)
@@ -578,7 +578,7 @@ reflect_chansend(ChanType *t, Hchan *c, uintptr val, _Bool nb)
        bool selected;
        bool *sp;
        byte *vp;
-       
+
        if(nb) {
                selected = false;
                sp = (bool*)&selected;
@@ -697,7 +697,7 @@ runtime_selectsend(Select *sel, Hchan *c, void *elem, int index)
        // nil cases do not compete
        if(c == nil)
                return;
-       
+
        selectsend(sel, c, index, elem);
 }
 
@@ -706,7 +706,7 @@ selectsend(Select *sel, Hchan *c, int index, void *elem)
 {
        int32 i;
        Scase *cas;
-       
+
        i = sel->ncase;
        if(i >= sel->tcase)
                runtime_throw("selectsend: too many cases");
@@ -977,7 +977,7 @@ loop:
                case CaseRecv:
                        enqueue(&c->recvq, sg);
                        break;
-               
+
                case CaseSend:
                        enqueue(&c->sendq, sg);
                        break;
index 252948d..9bf5d11 100644 (file)
@@ -105,6 +105,7 @@ struct Profile {
        uint32 wtoggle;
        bool wholding;  // holding & need to release a log half
        bool flushing;  // flushing hash table - profile is over
+       bool eod_sent;  // special end-of-data record sent; => flushing
 };
 
 static Lock lk;
@@ -115,6 +116,8 @@ static void add(Profile*, uintptr*, int32);
 static bool evict(Profile*, Entry*);
 static bool flushlog(Profile*);
 
+static uintptr eod[3] = {0, 1, 0};
+
 // LostProfileData is a no-op function used in profiles
 // to mark the number of profiling stack traces that were
 // discarded due to slow data writers.
@@ -168,6 +171,7 @@ runtime_SetCPUProfileRate(int32 hz)
                prof->wholding = false;
                prof->wtoggle = 0;
                prof->flushing = false;
+               prof->eod_sent = false;
                runtime_noteclear(&prof->wait);
 
                runtime_setcpuprofilerate(tick, hz);
@@ -414,6 +418,16 @@ breakflush:
        }
 
        // Made it through the table without finding anything to log.
+       if(!p->eod_sent) {
+               // We may not have space to append this to the partial log buf,
+               // so we always return a new slice for the end-of-data marker.
+               p->eod_sent = true;
+               ret.array = (byte*)eod;
+               ret.len = sizeof eod;
+               ret.cap = ret.len;
+               return ret;
+       }
+
        // Finally done.  Clean up and return nil.
        p->flushing = false;
        if(!runtime_cas(&p->handoff, p->handoff, 0))
index c24304e..5a8e47e 100644 (file)
@@ -20,10 +20,10 @@ gwrite(const void *v, int32 n)
                runtime_write(2, v, n);
                return;
        }
-       
+
        if(g->writenbuf == 0)
                return;
-       
+
        if(n > g->writenbuf)
                n = g->writenbuf;
        runtime_memmove(g->writebuf, v, n);
index 72875fd..e0a7925 100644 (file)
@@ -106,11 +106,15 @@ static byte**     argv;
 extern Slice os_Args asm ("os.Args");
 extern Slice syscall_Envs asm ("syscall.Envs");
 
+void (*runtime_sysargs)(int32, uint8**);
+
 void
 runtime_args(int32 c, byte **v)
 {
        argc = c;
        argv = v;
+       if(runtime_sysargs != nil)
+               runtime_sysargs(c, v);
 }
 
 void
@@ -234,7 +238,7 @@ runtime_showframe(const unsigned char *s)
        
        if(traceback < 0)
                traceback = runtime_gotraceback();
-       return traceback > 1 || (__builtin_strchr((const char*)s, '.') != nil && __builtin_memcmp(s, "runtime.", 7) != 0);
+       return traceback > 1 || (s != nil && __builtin_strchr((const char*)s, '.') != nil && __builtin_memcmp(s, "runtime.", 7) != 0);
 }
 
 bool