OSDN Git Service

libgo: update to weekly.2011-10-25
[pf3gnuchains/gcc-fork.git] / libgo / go / exp / ssh / transport.go
1 // Copyright 2011 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package ssh
6
7 import (
8         "bufio"
9         "crypto"
10         "crypto/aes"
11         "crypto/cipher"
12         "crypto/hmac"
13         "crypto/subtle"
14         "hash"
15         "io"
16         "net"
17         "os"
18         "sync"
19 )
20
21 const (
22         paddingMultiple = 16 // TODO(dfc) does this need to be configurable?
23 )
24
25 // filteredConn reduces the set of methods exposed when embeddeding
26 // a net.Conn inside ssh.transport.
27 // TODO(dfc) suggestions for a better name will be warmly received.
28 type filteredConn interface {
29         // Close closes the connection.
30         Close() os.Error
31
32         // LocalAddr returns the local network address.
33         LocalAddr() net.Addr
34
35         // RemoteAddr returns the remote network address.
36         RemoteAddr() net.Addr
37 }
38
39 // Types implementing packetWriter provide the ability to send packets to
40 // an SSH peer.
41 type packetWriter interface {
42         // Encrypt and send a packet of data to the remote peer.
43         writePacket(packet []byte) os.Error
44 }
45
46 // transport represents the SSH connection to the remote peer.
47 type transport struct {
48         reader
49         writer
50
51         filteredConn
52 }
53
54 // reader represents the incoming connection state.
55 type reader struct {
56         io.Reader
57         common
58 }
59
60 // writer represnts the outgoing connection state.
61 type writer struct {
62         *sync.Mutex // protects writer.Writer from concurrent writes
63         *bufio.Writer
64         paddingMultiple int
65         rand            io.Reader
66         common
67 }
68
69 // common represents the cipher state needed to process messages in a single
70 // direction.
71 type common struct {
72         seqNum uint32
73         mac    hash.Hash
74         cipher cipher.Stream
75
76         cipherAlgo      string
77         macAlgo         string
78         compressionAlgo string
79 }
80
81 // Read and decrypt a single packet from the remote peer.
82 func (r *reader) readOnePacket() ([]byte, os.Error) {
83         var lengthBytes = make([]byte, 5)
84         var macSize uint32
85
86         if _, err := io.ReadFull(r, lengthBytes); err != nil {
87                 return nil, err
88         }
89
90         if r.cipher != nil {
91                 r.cipher.XORKeyStream(lengthBytes, lengthBytes)
92         }
93
94         if r.mac != nil {
95                 r.mac.Reset()
96                 seqNumBytes := []byte{
97                         byte(r.seqNum >> 24),
98                         byte(r.seqNum >> 16),
99                         byte(r.seqNum >> 8),
100                         byte(r.seqNum),
101                 }
102                 r.mac.Write(seqNumBytes)
103                 r.mac.Write(lengthBytes)
104                 macSize = uint32(r.mac.Size())
105         }
106
107         length := uint32(lengthBytes[0])<<24 | uint32(lengthBytes[1])<<16 | uint32(lengthBytes[2])<<8 | uint32(lengthBytes[3])
108         paddingLength := uint32(lengthBytes[4])
109
110         if length <= paddingLength+1 {
111                 return nil, os.NewError("invalid packet length")
112         }
113         if length > maxPacketSize {
114                 return nil, os.NewError("packet too large")
115         }
116
117         packet := make([]byte, length-1+macSize)
118         if _, err := io.ReadFull(r, packet); err != nil {
119                 return nil, err
120         }
121         mac := packet[length-1:]
122         if r.cipher != nil {
123                 r.cipher.XORKeyStream(packet, packet[:length-1])
124         }
125
126         if r.mac != nil {
127                 r.mac.Write(packet[:length-1])
128                 if subtle.ConstantTimeCompare(r.mac.Sum(), mac) != 1 {
129                         return nil, os.NewError("ssh: MAC failure")
130                 }
131         }
132
133         r.seqNum++
134         return packet[:length-paddingLength-1], nil
135 }
136
137 // Read and decrypt next packet discarding debug and noop messages.
138 func (t *transport) readPacket() ([]byte, os.Error) {
139         for {
140                 packet, err := t.readOnePacket()
141                 if err != nil {
142                         return nil, err
143                 }
144                 if packet[0] != msgIgnore && packet[0] != msgDebug {
145                         return packet, nil
146                 }
147         }
148         panic("unreachable")
149 }
150
151 // Encrypt and send a packet of data to the remote peer.
152 func (w *writer) writePacket(packet []byte) os.Error {
153         w.Mutex.Lock()
154         defer w.Mutex.Unlock()
155
156         paddingLength := paddingMultiple - (5+len(packet))%paddingMultiple
157         if paddingLength < 4 {
158                 paddingLength += paddingMultiple
159         }
160
161         length := len(packet) + 1 + paddingLength
162         lengthBytes := []byte{
163                 byte(length >> 24),
164                 byte(length >> 16),
165                 byte(length >> 8),
166                 byte(length),
167                 byte(paddingLength),
168         }
169         padding := make([]byte, paddingLength)
170         _, err := io.ReadFull(w.rand, padding)
171         if err != nil {
172                 return err
173         }
174
175         if w.mac != nil {
176                 w.mac.Reset()
177                 seqNumBytes := []byte{
178                         byte(w.seqNum >> 24),
179                         byte(w.seqNum >> 16),
180                         byte(w.seqNum >> 8),
181                         byte(w.seqNum),
182                 }
183                 w.mac.Write(seqNumBytes)
184                 w.mac.Write(lengthBytes)
185                 w.mac.Write(packet)
186                 w.mac.Write(padding)
187         }
188
189         // TODO(dfc) lengthBytes, packet and padding should be
190         // subslices of a single buffer
191         if w.cipher != nil {
192                 w.cipher.XORKeyStream(lengthBytes, lengthBytes)
193                 w.cipher.XORKeyStream(packet, packet)
194                 w.cipher.XORKeyStream(padding, padding)
195         }
196
197         if _, err := w.Write(lengthBytes); err != nil {
198                 return err
199         }
200         if _, err := w.Write(packet); err != nil {
201                 return err
202         }
203         if _, err := w.Write(padding); err != nil {
204                 return err
205         }
206
207         if w.mac != nil {
208                 if _, err := w.Write(w.mac.Sum()); err != nil {
209                         return err
210                 }
211         }
212
213         if err := w.Flush(); err != nil {
214                 return err
215         }
216         w.seqNum++
217         return err
218 }
219
220 // Send a message to the remote peer
221 func (t *transport) sendMessage(typ uint8, msg interface{}) os.Error {
222         packet := marshal(typ, msg)
223         return t.writePacket(packet)
224 }
225
226 func newTransport(conn net.Conn, rand io.Reader) *transport {
227         return &transport{
228                 reader: reader{
229                         Reader: bufio.NewReader(conn),
230                 },
231                 writer: writer{
232                         Writer: bufio.NewWriter(conn),
233                         rand:   rand,
234                         Mutex:  new(sync.Mutex),
235                 },
236                 filteredConn: conn,
237         }
238 }
239
240 type direction struct {
241         ivTag     []byte
242         keyTag    []byte
243         macKeyTag []byte
244 }
245
246 // TODO(dfc) can this be made a constant ?
247 var (
248         serverKeys = direction{[]byte{'B'}, []byte{'D'}, []byte{'F'}}
249         clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}}
250 )
251
252 // setupKeys sets the cipher and MAC keys from K, H and sessionId, as
253 // described in RFC 4253, section 6.4. direction should either be serverKeys
254 // (to setup server->client keys) or clientKeys (for client->server keys).
255 func (c *common) setupKeys(d direction, K, H, sessionId []byte, hashFunc crypto.Hash) os.Error {
256         h := hashFunc.New()
257
258         blockSize := 16
259         keySize := 16
260         macKeySize := 20
261
262         iv := make([]byte, blockSize)
263         key := make([]byte, keySize)
264         macKey := make([]byte, macKeySize)
265         generateKeyMaterial(iv, d.ivTag, K, H, sessionId, h)
266         generateKeyMaterial(key, d.keyTag, K, H, sessionId, h)
267         generateKeyMaterial(macKey, d.macKeyTag, K, H, sessionId, h)
268
269         c.mac = truncatingMAC{12, hmac.NewSHA1(macKey)}
270         aes, err := aes.NewCipher(key)
271         if err != nil {
272                 return err
273         }
274         c.cipher = cipher.NewCTR(aes, iv)
275         return nil
276 }
277
278 // generateKeyMaterial fills out with key material generated from tag, K, H
279 // and sessionId, as specified in RFC 4253, section 7.2.
280 func generateKeyMaterial(out, tag []byte, K, H, sessionId []byte, h hash.Hash) {
281         var digestsSoFar []byte
282
283         for len(out) > 0 {
284                 h.Reset()
285                 h.Write(K)
286                 h.Write(H)
287
288                 if len(digestsSoFar) == 0 {
289                         h.Write(tag)
290                         h.Write(sessionId)
291                 } else {
292                         h.Write(digestsSoFar)
293                 }
294
295                 digest := h.Sum()
296                 n := copy(out, digest)
297                 out = out[n:]
298                 if len(out) > 0 {
299                         digestsSoFar = append(digestsSoFar, digest...)
300                 }
301         }
302 }
303
304 // truncatingMAC wraps around a hash.Hash and truncates the output digest to
305 // a given size.
306 type truncatingMAC struct {
307         length int
308         hmac   hash.Hash
309 }
310
311 func (t truncatingMAC) Write(data []byte) (int, os.Error) {
312         return t.hmac.Write(data)
313 }
314
315 func (t truncatingMAC) Sum() []byte {
316         digest := t.hmac.Sum()
317         return digest[:t.length]
318 }
319
320 func (t truncatingMAC) Reset() {
321         t.hmac.Reset()
322 }
323
324 func (t truncatingMAC) Size() int {
325         return t.length
326 }
327
328 // maxVersionStringBytes is the maximum number of bytes that we'll accept as a
329 // version string. In the event that the client is talking a different protocol
330 // we need to set a limit otherwise we will keep using more and more memory
331 // while searching for the end of the version handshake.
332 const maxVersionStringBytes = 1024
333
334 // Read version string as specified by RFC 4253, section 4.2.
335 func readVersion(r io.Reader) ([]byte, os.Error) {
336         versionString := make([]byte, 0, 64)
337         var ok, seenCR bool
338         var buf [1]byte
339 forEachByte:
340         for len(versionString) < maxVersionStringBytes {
341                 _, err := io.ReadFull(r, buf[:])
342                 if err != nil {
343                         return nil, err
344                 }
345                 b := buf[0]
346
347                 if !seenCR {
348                         if b == '\r' {
349                                 seenCR = true
350                         }
351                 } else {
352                         if b == '\n' {
353                                 ok = true
354                                 break forEachByte
355                         } else {
356                                 seenCR = false
357                         }
358                 }
359                 versionString = append(versionString, b)
360         }
361
362         if !ok {
363                 return nil, os.NewError("failed to read version string")
364         }
365
366         // We need to remove the CR from versionString
367         return versionString[:len(versionString)-1], nil
368 }