OSDN Git Service

Add Go frontend, libgo library, and Go testsuite.
[pf3gnuchains/gcc-fork.git] / libgo / go / crypto / tls / conn.go
1 // TLS low level connection and record layer
2
3 package tls
4
5 import (
6         "bytes"
7         "crypto/subtle"
8         "crypto/x509"
9         "hash"
10         "io"
11         "net"
12         "os"
13         "sync"
14 )
15
16 // A Conn represents a secured connection.
17 // It implements the net.Conn interface.
18 type Conn struct {
19         // constant
20         conn     net.Conn
21         isClient bool
22
23         // constant after handshake; protected by handshakeMutex
24         handshakeMutex    sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex
25         vers              uint16     // TLS version
26         haveVers          bool       // version has been negotiated
27         config            *Config    // configuration passed to constructor
28         handshakeComplete bool
29         cipherSuite       uint16
30         ocspResponse      []byte // stapled OCSP response
31         peerCertificates  []*x509.Certificate
32
33         clientProtocol string
34
35         // first permanent error
36         errMutex sync.Mutex
37         err      os.Error
38
39         // input/output
40         in, out  halfConn     // in.Mutex < out.Mutex
41         rawInput *block       // raw input, right off the wire
42         input    *block       // application data waiting to be read
43         hand     bytes.Buffer // handshake data waiting to be read
44
45         tmp [16]byte
46 }
47
48 func (c *Conn) setError(err os.Error) os.Error {
49         c.errMutex.Lock()
50         defer c.errMutex.Unlock()
51
52         if c.err == nil {
53                 c.err = err
54         }
55         return err
56 }
57
58 func (c *Conn) error() os.Error {
59         c.errMutex.Lock()
60         defer c.errMutex.Unlock()
61
62         return c.err
63 }
64
65 // Access to net.Conn methods.
66 // Cannot just embed net.Conn because that would
67 // export the struct field too.
68
69 // LocalAddr returns the local network address.
70 func (c *Conn) LocalAddr() net.Addr {
71         return c.conn.LocalAddr()
72 }
73
74 // RemoteAddr returns the remote network address.
75 func (c *Conn) RemoteAddr() net.Addr {
76         return c.conn.RemoteAddr()
77 }
78
79 // SetTimeout sets the read deadline associated with the connection.
80 // There is no write deadline.
81 func (c *Conn) SetTimeout(nsec int64) os.Error {
82         return c.conn.SetTimeout(nsec)
83 }
84
85 // SetReadTimeout sets the time (in nanoseconds) that
86 // Read will wait for data before returning os.EAGAIN.
87 // Setting nsec == 0 (the default) disables the deadline.
88 func (c *Conn) SetReadTimeout(nsec int64) os.Error {
89         return c.conn.SetReadTimeout(nsec)
90 }
91
92 // SetWriteTimeout exists to satisfy the net.Conn interface
93 // but is not implemented by TLS.  It always returns an error.
94 func (c *Conn) SetWriteTimeout(nsec int64) os.Error {
95         return os.NewError("TLS does not support SetWriteTimeout")
96 }
97
98 // A halfConn represents one direction of the record layer
99 // connection, either sending or receiving.
100 type halfConn struct {
101         sync.Mutex
102         crypt encryptor // encryption state
103         mac   hash.Hash // MAC algorithm
104         seq   [8]byte   // 64-bit sequence number
105         bfree *block    // list of free blocks
106
107         nextCrypt encryptor // next encryption state
108         nextMac   hash.Hash // next MAC algorithm
109 }
110
111 // prepareCipherSpec sets the encryption and MAC states
112 // that a subsequent changeCipherSpec will use.
113 func (hc *halfConn) prepareCipherSpec(crypt encryptor, mac hash.Hash) {
114         hc.nextCrypt = crypt
115         hc.nextMac = mac
116 }
117
118 // changeCipherSpec changes the encryption and MAC states
119 // to the ones previously passed to prepareCipherSpec.
120 func (hc *halfConn) changeCipherSpec() os.Error {
121         if hc.nextCrypt == nil {
122                 return alertInternalError
123         }
124         hc.crypt = hc.nextCrypt
125         hc.mac = hc.nextMac
126         hc.nextCrypt = nil
127         hc.nextMac = nil
128         return nil
129 }
130
131 // incSeq increments the sequence number.
132 func (hc *halfConn) incSeq() {
133         for i := 7; i >= 0; i-- {
134                 hc.seq[i]++
135                 if hc.seq[i] != 0 {
136                         return
137                 }
138         }
139
140         // Not allowed to let sequence number wrap.
141         // Instead, must renegotiate before it does.
142         // Not likely enough to bother.
143         panic("TLS: sequence number wraparound")
144 }
145
146 // resetSeq resets the sequence number to zero.
147 func (hc *halfConn) resetSeq() {
148         for i := range hc.seq {
149                 hc.seq[i] = 0
150         }
151 }
152
153 // decrypt checks and strips the mac and decrypts the data in b.
154 func (hc *halfConn) decrypt(b *block) (bool, alert) {
155         // pull out payload
156         payload := b.data[recordHeaderLen:]
157
158         // decrypt
159         if hc.crypt != nil {
160                 hc.crypt.XORKeyStream(payload)
161         }
162
163         // check, strip mac
164         if hc.mac != nil {
165                 if len(payload) < hc.mac.Size() {
166                         return false, alertBadRecordMAC
167                 }
168
169                 // strip mac off payload, b.data
170                 n := len(payload) - hc.mac.Size()
171                 b.data[3] = byte(n >> 8)
172                 b.data[4] = byte(n)
173                 b.data = b.data[0 : recordHeaderLen+n]
174                 remoteMAC := payload[n:]
175
176                 hc.mac.Reset()
177                 hc.mac.Write(hc.seq[0:])
178                 hc.incSeq()
179                 hc.mac.Write(b.data)
180
181                 if subtle.ConstantTimeCompare(hc.mac.Sum(), remoteMAC) != 1 {
182                         return false, alertBadRecordMAC
183                 }
184         }
185
186         return true, 0
187 }
188
189 // encrypt encrypts and macs the data in b.
190 func (hc *halfConn) encrypt(b *block) (bool, alert) {
191         // mac
192         if hc.mac != nil {
193                 hc.mac.Reset()
194                 hc.mac.Write(hc.seq[0:])
195                 hc.incSeq()
196                 hc.mac.Write(b.data)
197                 mac := hc.mac.Sum()
198                 n := len(b.data)
199                 b.resize(n + len(mac))
200                 copy(b.data[n:], mac)
201
202                 // update length to include mac
203                 n = len(b.data) - recordHeaderLen
204                 b.data[3] = byte(n >> 8)
205                 b.data[4] = byte(n)
206         }
207
208         // encrypt
209         if hc.crypt != nil {
210                 hc.crypt.XORKeyStream(b.data[recordHeaderLen:])
211         }
212
213         return true, 0
214 }
215
216 // A block is a simple data buffer.
217 type block struct {
218         data []byte
219         off  int // index for Read
220         link *block
221 }
222
223 // resize resizes block to be n bytes, growing if necessary.
224 func (b *block) resize(n int) {
225         if n > cap(b.data) {
226                 b.reserve(n)
227         }
228         b.data = b.data[0:n]
229 }
230
231 // reserve makes sure that block contains a capacity of at least n bytes.
232 func (b *block) reserve(n int) {
233         if cap(b.data) >= n {
234                 return
235         }
236         m := cap(b.data)
237         if m == 0 {
238                 m = 1024
239         }
240         for m < n {
241                 m *= 2
242         }
243         data := make([]byte, len(b.data), m)
244         copy(data, b.data)
245         b.data = data
246 }
247
248 // readFromUntil reads from r into b until b contains at least n bytes
249 // or else returns an error.
250 func (b *block) readFromUntil(r io.Reader, n int) os.Error {
251         // quick case
252         if len(b.data) >= n {
253                 return nil
254         }
255
256         // read until have enough.
257         b.reserve(n)
258         for {
259                 m, err := r.Read(b.data[len(b.data):cap(b.data)])
260                 b.data = b.data[0 : len(b.data)+m]
261                 if len(b.data) >= n {
262                         break
263                 }
264                 if err != nil {
265                         return err
266                 }
267         }
268         return nil
269 }
270
271 func (b *block) Read(p []byte) (n int, err os.Error) {
272         n = copy(p, b.data[b.off:])
273         b.off += n
274         return
275 }
276
277 // newBlock allocates a new block, from hc's free list if possible.
278 func (hc *halfConn) newBlock() *block {
279         b := hc.bfree
280         if b == nil {
281                 return new(block)
282         }
283         hc.bfree = b.link
284         b.link = nil
285         b.resize(0)
286         return b
287 }
288
289 // freeBlock returns a block to hc's free list.
290 // The protocol is such that each side only has a block or two on
291 // its free list at a time, so there's no need to worry about
292 // trimming the list, etc.
293 func (hc *halfConn) freeBlock(b *block) {
294         b.link = hc.bfree
295         hc.bfree = b
296 }
297
298 // splitBlock splits a block after the first n bytes,
299 // returning a block with those n bytes and a
300 // block with the remaindec.  the latter may be nil.
301 func (hc *halfConn) splitBlock(b *block, n int) (*block, *block) {
302         if len(b.data) <= n {
303                 return b, nil
304         }
305         bb := hc.newBlock()
306         bb.resize(len(b.data) - n)
307         copy(bb.data, b.data[n:])
308         b.data = b.data[0:n]
309         return b, bb
310 }
311
312 // readRecord reads the next TLS record from the connection
313 // and updates the record layer state.
314 // c.in.Mutex <= L; c.input == nil.
315 func (c *Conn) readRecord(want recordType) os.Error {
316         // Caller must be in sync with connection:
317         // handshake data if handshake not yet completed,
318         // else application data.  (We don't support renegotiation.)
319         switch want {
320         default:
321                 return c.sendAlert(alertInternalError)
322         case recordTypeHandshake, recordTypeChangeCipherSpec:
323                 if c.handshakeComplete {
324                         return c.sendAlert(alertInternalError)
325                 }
326         case recordTypeApplicationData:
327                 if !c.handshakeComplete {
328                         return c.sendAlert(alertInternalError)
329                 }
330         }
331
332 Again:
333         if c.rawInput == nil {
334                 c.rawInput = c.in.newBlock()
335         }
336         b := c.rawInput
337
338         // Read header, payload.
339         if err := b.readFromUntil(c.conn, recordHeaderLen); err != nil {
340                 // RFC suggests that EOF without an alertCloseNotify is
341                 // an error, but popular web sites seem to do this,
342                 // so we can't make it an error.
343                 // if err == os.EOF {
344                 //      err = io.ErrUnexpectedEOF
345                 // }
346                 if e, ok := err.(net.Error); !ok || !e.Temporary() {
347                         c.setError(err)
348                 }
349                 return err
350         }
351         typ := recordType(b.data[0])
352         vers := uint16(b.data[1])<<8 | uint16(b.data[2])
353         n := int(b.data[3])<<8 | int(b.data[4])
354         if c.haveVers && vers != c.vers {
355                 return c.sendAlert(alertProtocolVersion)
356         }
357         if n > maxCiphertext {
358                 return c.sendAlert(alertRecordOverflow)
359         }
360         if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
361                 if err == os.EOF {
362                         err = io.ErrUnexpectedEOF
363                 }
364                 if e, ok := err.(net.Error); !ok || !e.Temporary() {
365                         c.setError(err)
366                 }
367                 return err
368         }
369
370         // Process message.
371         b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
372         b.off = recordHeaderLen
373         if ok, err := c.in.decrypt(b); !ok {
374                 return c.sendAlert(err)
375         }
376         data := b.data[b.off:]
377         if len(data) > maxPlaintext {
378                 c.sendAlert(alertRecordOverflow)
379                 c.in.freeBlock(b)
380                 return c.error()
381         }
382
383         switch typ {
384         default:
385                 c.sendAlert(alertUnexpectedMessage)
386
387         case recordTypeAlert:
388                 if len(data) != 2 {
389                         c.sendAlert(alertUnexpectedMessage)
390                         break
391                 }
392                 if alert(data[1]) == alertCloseNotify {
393                         c.setError(os.EOF)
394                         break
395                 }
396                 switch data[0] {
397                 case alertLevelWarning:
398                         // drop on the floor
399                         c.in.freeBlock(b)
400                         goto Again
401                 case alertLevelError:
402                         c.setError(&net.OpError{Op: "remote error", Error: alert(data[1])})
403                 default:
404                         c.sendAlert(alertUnexpectedMessage)
405                 }
406
407         case recordTypeChangeCipherSpec:
408                 if typ != want || len(data) != 1 || data[0] != 1 {
409                         c.sendAlert(alertUnexpectedMessage)
410                         break
411                 }
412                 err := c.in.changeCipherSpec()
413                 if err != nil {
414                         c.sendAlert(err.(alert))
415                 }
416
417         case recordTypeApplicationData:
418                 if typ != want {
419                         c.sendAlert(alertUnexpectedMessage)
420                         break
421                 }
422                 c.input = b
423                 b = nil
424
425         case recordTypeHandshake:
426                 // TODO(rsc): Should at least pick off connection close.
427                 if typ != want {
428                         return c.sendAlert(alertNoRenegotiation)
429                 }
430                 c.hand.Write(data)
431         }
432
433         if b != nil {
434                 c.in.freeBlock(b)
435         }
436         return c.error()
437 }
438
439 // sendAlert sends a TLS alert message.
440 // c.out.Mutex <= L.
441 func (c *Conn) sendAlertLocked(err alert) os.Error {
442         c.tmp[0] = alertLevelError
443         if err == alertNoRenegotiation {
444                 c.tmp[0] = alertLevelWarning
445         }
446         c.tmp[1] = byte(err)
447         c.writeRecord(recordTypeAlert, c.tmp[0:2])
448         // closeNotify is a special case in that it isn't an error:
449         if err != alertCloseNotify {
450                 return c.setError(&net.OpError{Op: "local error", Error: err})
451         }
452         return nil
453 }
454
455 // sendAlert sends a TLS alert message.
456 // L < c.out.Mutex.
457 func (c *Conn) sendAlert(err alert) os.Error {
458         c.out.Lock()
459         defer c.out.Unlock()
460         return c.sendAlertLocked(err)
461 }
462
463 // writeRecord writes a TLS record with the given type and payload
464 // to the connection and updates the record layer state.
465 // c.out.Mutex <= L.
466 func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err os.Error) {
467         b := c.out.newBlock()
468         for len(data) > 0 {
469                 m := len(data)
470                 if m > maxPlaintext {
471                         m = maxPlaintext
472                 }
473                 b.resize(recordHeaderLen + m)
474                 b.data[0] = byte(typ)
475                 vers := c.vers
476                 if vers == 0 {
477                         vers = maxVersion
478                 }
479                 b.data[1] = byte(vers >> 8)
480                 b.data[2] = byte(vers)
481                 b.data[3] = byte(m >> 8)
482                 b.data[4] = byte(m)
483                 copy(b.data[recordHeaderLen:], data)
484                 c.out.encrypt(b)
485                 _, err = c.conn.Write(b.data)
486                 if err != nil {
487                         break
488                 }
489                 n += m
490                 data = data[m:]
491         }
492         c.out.freeBlock(b)
493
494         if typ == recordTypeChangeCipherSpec {
495                 err = c.out.changeCipherSpec()
496                 if err != nil {
497                         // Cannot call sendAlert directly,
498                         // because we already hold c.out.Mutex.
499                         c.tmp[0] = alertLevelError
500                         c.tmp[1] = byte(err.(alert))
501                         c.writeRecord(recordTypeAlert, c.tmp[0:2])
502                         c.err = &net.OpError{Op: "local error", Error: err}
503                         return n, c.err
504                 }
505         }
506         return
507 }
508
509 // readHandshake reads the next handshake message from
510 // the record layer.
511 // c.in.Mutex < L; c.out.Mutex < L.
512 func (c *Conn) readHandshake() (interface{}, os.Error) {
513         for c.hand.Len() < 4 {
514                 if c.err != nil {
515                         return nil, c.err
516                 }
517                 c.readRecord(recordTypeHandshake)
518         }
519
520         data := c.hand.Bytes()
521         n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
522         if n > maxHandshake {
523                 c.sendAlert(alertInternalError)
524                 return nil, c.err
525         }
526         for c.hand.Len() < 4+n {
527                 if c.err != nil {
528                         return nil, c.err
529                 }
530                 c.readRecord(recordTypeHandshake)
531         }
532         data = c.hand.Next(4 + n)
533         var m handshakeMessage
534         switch data[0] {
535         case typeClientHello:
536                 m = new(clientHelloMsg)
537         case typeServerHello:
538                 m = new(serverHelloMsg)
539         case typeCertificate:
540                 m = new(certificateMsg)
541         case typeCertificateRequest:
542                 m = new(certificateRequestMsg)
543         case typeCertificateStatus:
544                 m = new(certificateStatusMsg)
545         case typeServerHelloDone:
546                 m = new(serverHelloDoneMsg)
547         case typeClientKeyExchange:
548                 m = new(clientKeyExchangeMsg)
549         case typeCertificateVerify:
550                 m = new(certificateVerifyMsg)
551         case typeNextProtocol:
552                 m = new(nextProtoMsg)
553         case typeFinished:
554                 m = new(finishedMsg)
555         default:
556                 c.sendAlert(alertUnexpectedMessage)
557                 return nil, alertUnexpectedMessage
558         }
559
560         // The handshake message unmarshallers
561         // expect to be able to keep references to data,
562         // so pass in a fresh copy that won't be overwritten.
563         data = bytes.Add(nil, data)
564
565         if !m.unmarshal(data) {
566                 c.sendAlert(alertUnexpectedMessage)
567                 return nil, alertUnexpectedMessage
568         }
569         return m, nil
570 }
571
572 // Write writes data to the connection.
573 func (c *Conn) Write(b []byte) (n int, err os.Error) {
574         if err = c.Handshake(); err != nil {
575                 return
576         }
577
578         c.out.Lock()
579         defer c.out.Unlock()
580
581         if !c.handshakeComplete {
582                 return 0, alertInternalError
583         }
584         if c.err != nil {
585                 return 0, c.err
586         }
587         return c.writeRecord(recordTypeApplicationData, b)
588 }
589
590 // Read can be made to time out and return err == os.EAGAIN
591 // after a fixed time limit; see SetTimeout and SetReadTimeout.
592 func (c *Conn) Read(b []byte) (n int, err os.Error) {
593         if err = c.Handshake(); err != nil {
594                 return
595         }
596
597         c.in.Lock()
598         defer c.in.Unlock()
599
600         for c.input == nil && c.err == nil {
601                 if err := c.readRecord(recordTypeApplicationData); err != nil {
602                         // Soft error, like EAGAIN
603                         return 0, err
604                 }
605         }
606         if c.err != nil {
607                 return 0, c.err
608         }
609         n, err = c.input.Read(b)
610         if c.input.off >= len(c.input.data) {
611                 c.in.freeBlock(c.input)
612                 c.input = nil
613         }
614         return n, nil
615 }
616
617 // Close closes the connection.
618 func (c *Conn) Close() os.Error {
619         if err := c.Handshake(); err != nil {
620                 return err
621         }
622         return c.sendAlert(alertCloseNotify)
623 }
624
625 // Handshake runs the client or server handshake
626 // protocol if it has not yet been run.
627 // Most uses of this package need not call Handshake
628 // explicitly: the first Read or Write will call it automatically.
629 func (c *Conn) Handshake() os.Error {
630         c.handshakeMutex.Lock()
631         defer c.handshakeMutex.Unlock()
632         if err := c.error(); err != nil {
633                 return err
634         }
635         if c.handshakeComplete {
636                 return nil
637         }
638         if c.isClient {
639                 return c.clientHandshake()
640         }
641         return c.serverHandshake()
642 }
643
644 // ConnectionState returns basic TLS details about the connection.
645 func (c *Conn) ConnectionState() ConnectionState {
646         c.handshakeMutex.Lock()
647         defer c.handshakeMutex.Unlock()
648
649         var state ConnectionState
650         state.HandshakeComplete = c.handshakeComplete
651         if c.handshakeComplete {
652                 state.NegotiatedProtocol = c.clientProtocol
653                 state.CipherSuite = c.cipherSuite
654         }
655
656         return state
657 }
658
659 // OCSPResponse returns the stapled OCSP response from the TLS server, if
660 // any. (Only valid for client connections.)
661 func (c *Conn) OCSPResponse() []byte {
662         c.handshakeMutex.Lock()
663         defer c.handshakeMutex.Unlock()
664
665         return c.ocspResponse
666 }
667
668 // PeerCertificates returns the certificate chain that was presented by the
669 // other side.
670 func (c *Conn) PeerCertificates() []*x509.Certificate {
671         c.handshakeMutex.Lock()
672         defer c.handshakeMutex.Unlock()
673
674         return c.peerCertificates
675 }
676
677 // VerifyHostname checks that the peer certificate chain is valid for
678 // connecting to host.  If so, it returns nil; if not, it returns an os.Error
679 // describing the problem.
680 func (c *Conn) VerifyHostname(host string) os.Error {
681         c.handshakeMutex.Lock()
682         defer c.handshakeMutex.Unlock()
683         if !c.isClient {
684                 return os.ErrorString("VerifyHostname called on TLS server connection")
685         }
686         if !c.handshakeComplete {
687                 return os.ErrorString("TLS handshake has not yet been performed")
688         }
689         return c.peerCertificates[0].VerifyHostname(host)
690 }