OSDN Git Service

libgo: Update to weekly.2011-11-18.
[pf3gnuchains/gcc-fork.git] / libgo / go / crypto / tls / conn.go
index fac65af..b8fa273 100644 (file)
@@ -11,10 +11,9 @@ import (
        "crypto/cipher"
        "crypto/subtle"
        "crypto/x509"
-       "hash"
+       "errors"
        "io"
        "net"
-       "os"
        "sync"
 )
 
@@ -37,13 +36,15 @@ type Conn struct {
        // verifiedChains contains the certificate chains that we built, as
        // opposed to the ones presented by the server.
        verifiedChains [][]*x509.Certificate
+       // serverName contains the server name indicated by the client, if any.
+       serverName string
 
        clientProtocol         string
        clientProtocolFallback bool
 
        // first permanent error
        errMutex sync.Mutex
-       err      os.Error
+       err      error
 
        // input/output
        in, out  halfConn     // in.Mutex < out.Mutex
@@ -54,7 +55,7 @@ type Conn struct {
        tmp [16]byte
 }
 
-func (c *Conn) setError(err os.Error) os.Error {
+func (c *Conn) setError(err error) error {
        c.errMutex.Lock()
        defer c.errMutex.Unlock()
 
@@ -64,7 +65,7 @@ func (c *Conn) setError(err os.Error) os.Error {
        return err
 }
 
-func (c *Conn) error() os.Error {
+func (c *Conn) error() error {
        c.errMutex.Lock()
        defer c.errMutex.Unlock()
 
@@ -87,46 +88,49 @@ func (c *Conn) RemoteAddr() net.Addr {
 
 // SetTimeout sets the read deadline associated with the connection.
 // There is no write deadline.
-func (c *Conn) SetTimeout(nsec int64) os.Error {
+func (c *Conn) SetTimeout(nsec int64) error {
        return c.conn.SetTimeout(nsec)
 }
 
 // SetReadTimeout sets the time (in nanoseconds) that
-// Read will wait for data before returning os.EAGAIN.
+// Read will wait for data before returning a net.Error
+// with Timeout() == true.
 // Setting nsec == 0 (the default) disables the deadline.
-func (c *Conn) SetReadTimeout(nsec int64) os.Error {
+func (c *Conn) SetReadTimeout(nsec int64) error {
        return c.conn.SetReadTimeout(nsec)
 }
 
 // SetWriteTimeout exists to satisfy the net.Conn interface
 // but is not implemented by TLS.  It always returns an error.
-func (c *Conn) SetWriteTimeout(nsec int64) os.Error {
-       return os.NewError("TLS does not support SetWriteTimeout")
+func (c *Conn) SetWriteTimeout(nsec int64) error {
+       return errors.New("TLS does not support SetWriteTimeout")
 }
 
 // A halfConn represents one direction of the record layer
 // connection, either sending or receiving.
 type halfConn struct {
        sync.Mutex
-       cipher interface{} // cipher algorithm
-       mac    hash.Hash   // MAC algorithm
-       seq    [8]byte     // 64-bit sequence number
-       bfree  *block      // list of free blocks
+       version uint16      // protocol version
+       cipher  interface{} // cipher algorithm
+       mac     macFunction
+       seq     [8]byte // 64-bit sequence number
+       bfree   *block  // list of free blocks
 
        nextCipher interface{} // next encryption state
-       nextMac    hash.Hash   // next MAC algorithm
+       nextMac    macFunction // next MAC algorithm
 }
 
 // prepareCipherSpec sets the encryption and MAC states
 // that a subsequent changeCipherSpec will use.
-func (hc *halfConn) prepareCipherSpec(cipher interface{}, mac hash.Hash) {
+func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac macFunction) {
+       hc.version = version
        hc.nextCipher = cipher
        hc.nextMac = mac
 }
 
 // changeCipherSpec changes the encryption and MAC states
 // to the ones previously passed to prepareCipherSpec.
-func (hc *halfConn) changeCipherSpec() os.Error {
+func (hc *halfConn) changeCipherSpec() error {
        if hc.nextCipher == nil {
                return alertInternalError
        }
@@ -197,6 +201,22 @@ func removePadding(payload []byte) ([]byte, byte) {
        return payload[:len(payload)-int(toRemove)], good
 }
 
+// removePaddingSSL30 is a replacement for removePadding in the case that the
+// protocol version is SSLv3. In this version, the contents of the padding
+// are random and cannot be checked.
+func removePaddingSSL30(payload []byte) ([]byte, byte) {
+       if len(payload) < 1 {
+               return payload, 0
+       }
+
+       paddingLen := int(payload[len(payload)-1]) + 1
+       if paddingLen > len(payload) {
+               return payload, 0
+       }
+
+       return payload[:len(payload)-paddingLen], 255
+}
+
 func roundUp(a, b int) int {
        return a + (b-a%b)%b
 }
@@ -226,7 +246,11 @@ func (hc *halfConn) decrypt(b *block) (bool, alert) {
                        }
 
                        c.CryptBlocks(payload, payload)
-                       payload, paddingGood = removePadding(payload)
+                       if hc.version == versionSSL30 {
+                               payload, paddingGood = removePaddingSSL30(payload)
+                       } else {
+                               payload, paddingGood = removePadding(payload)
+                       }
                        b.resize(recordHeaderLen + len(payload))
 
                        // note that we still have a timing side-channel in the
@@ -256,13 +280,10 @@ func (hc *halfConn) decrypt(b *block) (bool, alert) {
                b.data[4] = byte(n)
                b.resize(recordHeaderLen + n)
                remoteMAC := payload[n:]
-
-               hc.mac.Reset()
-               hc.mac.Write(hc.seq[0:])
+               localMAC := hc.mac.MAC(hc.seq[0:], b.data)
                hc.incSeq()
-               hc.mac.Write(b.data)
 
-               if subtle.ConstantTimeCompare(hc.mac.Sum(), remoteMAC) != 1 || paddingGood != 255 {
+               if subtle.ConstantTimeCompare(localMAC, remoteMAC) != 1 || paddingGood != 255 {
                        return false, alertBadRecordMAC
                }
        }
@@ -291,11 +312,9 @@ func padToBlockSize(payload []byte, blockSize int) (prefix, finalBlock []byte) {
 func (hc *halfConn) encrypt(b *block) (bool, alert) {
        // mac
        if hc.mac != nil {
-               hc.mac.Reset()
-               hc.mac.Write(hc.seq[0:])
+               mac := hc.mac.MAC(hc.seq[0:], b.data)
                hc.incSeq()
-               hc.mac.Write(b.data)
-               mac := hc.mac.Sum()
+
                n := len(b.data)
                b.resize(n + len(mac))
                copy(b.data[n:], mac)
@@ -360,7 +379,7 @@ func (b *block) reserve(n int) {
 
 // readFromUntil reads from r into b until b contains at least n bytes
 // or else returns an error.
-func (b *block) readFromUntil(r io.Reader, n int) os.Error {
+func (b *block) readFromUntil(r io.Reader, n int) error {
        // quick case
        if len(b.data) >= n {
                return nil
@@ -381,7 +400,7 @@ func (b *block) readFromUntil(r io.Reader, n int) os.Error {
        return nil
 }
 
-func (b *block) Read(p []byte) (n int, err os.Error) {
+func (b *block) Read(p []byte) (n int, err error) {
        n = copy(p, b.data[b.off:])
        b.off += n
        return
@@ -425,7 +444,7 @@ func (hc *halfConn) splitBlock(b *block, n int) (*block, *block) {
 // readRecord reads the next TLS record from the connection
 // and updates the record layer state.
 // c.in.Mutex <= L; c.input == nil.
-func (c *Conn) readRecord(want recordType) os.Error {
+func (c *Conn) readRecord(want recordType) error {
        // Caller must be in sync with connection:
        // handshake data if handshake not yet completed,
        // else application data.  (We don't support renegotiation.)
@@ -453,7 +472,7 @@ Again:
                // RFC suggests that EOF without an alertCloseNotify is
                // an error, but popular web sites seem to do this,
                // so we can't make it an error.
-               // if err == os.EOF {
+               // if err == io.EOF {
                //      err = io.ErrUnexpectedEOF
                // }
                if e, ok := err.(net.Error); !ok || !e.Temporary() {
@@ -470,8 +489,21 @@ Again:
        if n > maxCiphertext {
                return c.sendAlert(alertRecordOverflow)
        }
+       if !c.haveVers {
+               // First message, be extra suspicious:
+               // this might not be a TLS client.
+               // Bail out before reading a full 'body', if possible.
+               // The current max version is 3.1. 
+               // If the version is >= 16.0, it's probably not real.
+               // Similarly, a clientHello message encodes in
+               // well under a kilobyte.  If the length is >= 12 kB,
+               // it's probably not real.
+               if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 || n >= 0x3000 {
+                       return c.sendAlert(alertUnexpectedMessage)
+               }
+       }
        if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
-               if err == os.EOF {
+               if err == io.EOF {
                        err = io.ErrUnexpectedEOF
                }
                if e, ok := err.(net.Error); !ok || !e.Temporary() {
@@ -503,7 +535,7 @@ Again:
                        break
                }
                if alert(data[1]) == alertCloseNotify {
-                       c.setError(os.EOF)
+                       c.setError(io.EOF)
                        break
                }
                switch data[0] {
@@ -512,7 +544,7 @@ Again:
                        c.in.freeBlock(b)
                        goto Again
                case alertLevelError:
-                       c.setError(&net.OpError{Op: "remote error", Error: alert(data[1])})
+                       c.setError(&net.OpError{Op: "remote error", Err: alert(data[1])})
                default:
                        c.sendAlert(alertUnexpectedMessage)
                }
@@ -551,7 +583,7 @@ Again:
 
 // sendAlert sends a TLS alert message.
 // c.out.Mutex <= L.
-func (c *Conn) sendAlertLocked(err alert) os.Error {
+func (c *Conn) sendAlertLocked(err alert) error {
        c.tmp[0] = alertLevelError
        if err == alertNoRenegotiation {
                c.tmp[0] = alertLevelWarning
@@ -560,14 +592,14 @@ func (c *Conn) sendAlertLocked(err alert) os.Error {
        c.writeRecord(recordTypeAlert, c.tmp[0:2])
        // closeNotify is a special case in that it isn't an error:
        if err != alertCloseNotify {
-               return c.setError(&net.OpError{Op: "local error", Error: err})
+               return c.setError(&net.OpError{Op: "local error", Err: err})
        }
        return nil
 }
 
 // sendAlert sends a TLS alert message.
 // L < c.out.Mutex.
-func (c *Conn) sendAlert(err alert) os.Error {
+func (c *Conn) sendAlert(err alert) error {
        c.out.Lock()
        defer c.out.Unlock()
        return c.sendAlertLocked(err)
@@ -576,7 +608,7 @@ func (c *Conn) sendAlert(err alert) os.Error {
 // writeRecord writes a TLS record with the given type and payload
 // to the connection and updates the record layer state.
 // c.out.Mutex <= L.
-func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err os.Error) {
+func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
        b := c.out.newBlock()
        for len(data) > 0 {
                m := len(data)
@@ -612,7 +644,7 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err os.Error) {
                        c.tmp[0] = alertLevelError
                        c.tmp[1] = byte(err.(alert))
                        c.writeRecord(recordTypeAlert, c.tmp[0:2])
-                       c.err = &net.OpError{Op: "local error", Error: err}
+                       c.err = &net.OpError{Op: "local error", Err: err}
                        return n, c.err
                }
        }
@@ -622,12 +654,14 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err os.Error) {
 // readHandshake reads the next handshake message from
 // the record layer.
 // c.in.Mutex < L; c.out.Mutex < L.
-func (c *Conn) readHandshake() (interface{}, os.Error) {
+func (c *Conn) readHandshake() (interface{}, error) {
        for c.hand.Len() < 4 {
                if c.err != nil {
                        return nil, c.err
                }
-               c.readRecord(recordTypeHandshake)
+               if err := c.readRecord(recordTypeHandshake); err != nil {
+                       return nil, err
+               }
        }
 
        data := c.hand.Bytes()
@@ -640,7 +674,9 @@ func (c *Conn) readHandshake() (interface{}, os.Error) {
                if c.err != nil {
                        return nil, c.err
                }
-               c.readRecord(recordTypeHandshake)
+               if err := c.readRecord(recordTypeHandshake); err != nil {
+                       return nil, err
+               }
        }
        data = c.hand.Next(4 + n)
        var m handshakeMessage
@@ -685,7 +721,7 @@ func (c *Conn) readHandshake() (interface{}, os.Error) {
 }
 
 // Write writes data to the connection.
-func (c *Conn) Write(b []byte) (n int, err os.Error) {
+func (c *Conn) Write(b []byte) (n int, err error) {
        if err = c.Handshake(); err != nil {
                return
        }
@@ -702,9 +738,9 @@ func (c *Conn) Write(b []byte) (n int, err os.Error) {
        return c.writeRecord(recordTypeApplicationData, b)
 }
 
-// Read can be made to time out and return err == os.EAGAIN
+// Read can be made to time out and return a net.Error with Timeout() == true
 // after a fixed time limit; see SetTimeout and SetReadTimeout.
-func (c *Conn) Read(b []byte) (n int, err os.Error) {
+func (c *Conn) Read(b []byte) (n int, err error) {
        if err = c.Handshake(); err != nil {
                return
        }
@@ -730,18 +766,26 @@ func (c *Conn) Read(b []byte) (n int, err os.Error) {
 }
 
 // Close closes the connection.
-func (c *Conn) Close() os.Error {
-       if err := c.Handshake(); err != nil {
+func (c *Conn) Close() error {
+       var alertErr error
+
+       c.handshakeMutex.Lock()
+       defer c.handshakeMutex.Unlock()
+       if c.handshakeComplete {
+               alertErr = c.sendAlert(alertCloseNotify)
+       }
+
+       if err := c.conn.Close(); err != nil {
                return err
        }
-       return c.sendAlert(alertCloseNotify)
+       return alertErr
 }
 
 // Handshake runs the client or server handshake
 // protocol if it has not yet been run.
 // Most uses of this package need not call Handshake
 // explicitly: the first Read or Write will call it automatically.
-func (c *Conn) Handshake() os.Error {
+func (c *Conn) Handshake() error {
        c.handshakeMutex.Lock()
        defer c.handshakeMutex.Unlock()
        if err := c.error(); err != nil {
@@ -769,6 +813,7 @@ func (c *Conn) ConnectionState() ConnectionState {
                state.CipherSuite = c.cipherSuite
                state.PeerCertificates = c.peerCertificates
                state.VerifiedChains = c.verifiedChains
+               state.ServerName = c.serverName
        }
 
        return state
@@ -784,16 +829,16 @@ func (c *Conn) OCSPResponse() []byte {
 }
 
 // VerifyHostname checks that the peer certificate chain is valid for
-// connecting to host.  If so, it returns nil; if not, it returns an os.Error
+// connecting to host.  If so, it returns nil; if not, it returns an error
 // describing the problem.
-func (c *Conn) VerifyHostname(host string) os.Error {
+func (c *Conn) VerifyHostname(host string) error {
        c.handshakeMutex.Lock()
        defer c.handshakeMutex.Unlock()
        if !c.isClient {
-               return os.NewError("VerifyHostname called on TLS server connection")
+               return errors.New("VerifyHostname called on TLS server connection")
        }
        if !c.handshakeComplete {
-               return os.NewError("TLS handshake has not yet been performed")
+               return errors.New("TLS handshake has not yet been performed")
        }
        return c.peerCertificates[0].VerifyHostname(host)
 }