OSDN Git Service

libgo: Update to weekly.2012-01-15.
[pf3gnuchains/gcc-fork.git] / libgo / go / exp / ssh / transport.go
1 // Copyright 2011 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package ssh
6
7 import (
8         "bufio"
9         "crypto"
10         "crypto/cipher"
11         "crypto/hmac"
12         "crypto/subtle"
13         "errors"
14         "hash"
15         "io"
16         "net"
17         "sync"
18 )
19
20 const (
21         packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher.
22         minPacketSize      = 16
23         maxPacketSize      = 36000
24         minPaddingSize     = 4 // TODO(huin) should this be configurable?
25 )
26
27 // filteredConn reduces the set of methods exposed when embeddeding
28 // a net.Conn inside ssh.transport.
29 // TODO(dfc) suggestions for a better name will be warmly received.
30 type filteredConn interface {
31         // Close closes the connection.
32         Close() error
33
34         // LocalAddr returns the local network address.
35         LocalAddr() net.Addr
36
37         // RemoteAddr returns the remote network address.
38         RemoteAddr() net.Addr
39 }
40
41 // Types implementing packetWriter provide the ability to send packets to
42 // an SSH peer.
43 type packetWriter interface {
44         // Encrypt and send a packet of data to the remote peer.
45         writePacket(packet []byte) error
46 }
47
48 // transport represents the SSH connection to the remote peer.
49 type transport struct {
50         reader
51         writer
52
53         filteredConn
54 }
55
56 // reader represents the incoming connection state.
57 type reader struct {
58         io.Reader
59         common
60 }
61
62 // writer represnts the outgoing connection state.
63 type writer struct {
64         *sync.Mutex // protects writer.Writer from concurrent writes
65         *bufio.Writer
66         rand io.Reader
67         common
68 }
69
70 // common represents the cipher state needed to process messages in a single
71 // direction.
72 type common struct {
73         seqNum uint32
74         mac    hash.Hash
75         cipher cipher.Stream
76
77         cipherAlgo      string
78         macAlgo         string
79         compressionAlgo string
80 }
81
82 // Read and decrypt a single packet from the remote peer.
83 func (r *reader) readOnePacket() ([]byte, error) {
84         var lengthBytes = make([]byte, 5)
85         var macSize uint32
86         if _, err := io.ReadFull(r, lengthBytes); err != nil {
87                 return nil, err
88         }
89
90         r.cipher.XORKeyStream(lengthBytes, lengthBytes)
91
92         if r.mac != nil {
93                 r.mac.Reset()
94                 seqNumBytes := []byte{
95                         byte(r.seqNum >> 24),
96                         byte(r.seqNum >> 16),
97                         byte(r.seqNum >> 8),
98                         byte(r.seqNum),
99                 }
100                 r.mac.Write(seqNumBytes)
101                 r.mac.Write(lengthBytes)
102                 macSize = uint32(r.mac.Size())
103         }
104
105         length := uint32(lengthBytes[0])<<24 | uint32(lengthBytes[1])<<16 | uint32(lengthBytes[2])<<8 | uint32(lengthBytes[3])
106         paddingLength := uint32(lengthBytes[4])
107
108         if length <= paddingLength+1 {
109                 return nil, errors.New("invalid packet length")
110         }
111         if length > maxPacketSize {
112                 return nil, errors.New("packet too large")
113         }
114
115         packet := make([]byte, length-1+macSize)
116         if _, err := io.ReadFull(r, packet); err != nil {
117                 return nil, err
118         }
119         mac := packet[length-1:]
120         r.cipher.XORKeyStream(packet, packet[:length-1])
121
122         if r.mac != nil {
123                 r.mac.Write(packet[:length-1])
124                 if subtle.ConstantTimeCompare(r.mac.Sum(nil), mac) != 1 {
125                         return nil, errors.New("ssh: MAC failure")
126                 }
127         }
128
129         r.seqNum++
130         return packet[:length-paddingLength-1], nil
131 }
132
133 // Read and decrypt next packet discarding debug and noop messages.
134 func (t *transport) readPacket() ([]byte, error) {
135         for {
136                 packet, err := t.readOnePacket()
137                 if err != nil {
138                         return nil, err
139                 }
140                 if packet[0] != msgIgnore && packet[0] != msgDebug {
141                         return packet, nil
142                 }
143         }
144         panic("unreachable")
145 }
146
147 // Encrypt and send a packet of data to the remote peer.
148 func (w *writer) writePacket(packet []byte) error {
149         w.Mutex.Lock()
150         defer w.Mutex.Unlock()
151
152         paddingLength := packetSizeMultiple - (5+len(packet))%packetSizeMultiple
153         if paddingLength < 4 {
154                 paddingLength += packetSizeMultiple
155         }
156
157         length := len(packet) + 1 + paddingLength
158         lengthBytes := []byte{
159                 byte(length >> 24),
160                 byte(length >> 16),
161                 byte(length >> 8),
162                 byte(length),
163                 byte(paddingLength),
164         }
165         padding := make([]byte, paddingLength)
166         _, err := io.ReadFull(w.rand, padding)
167         if err != nil {
168                 return err
169         }
170
171         if w.mac != nil {
172                 w.mac.Reset()
173                 seqNumBytes := []byte{
174                         byte(w.seqNum >> 24),
175                         byte(w.seqNum >> 16),
176                         byte(w.seqNum >> 8),
177                         byte(w.seqNum),
178                 }
179                 w.mac.Write(seqNumBytes)
180                 w.mac.Write(lengthBytes)
181                 w.mac.Write(packet)
182                 w.mac.Write(padding)
183         }
184
185         // TODO(dfc) lengthBytes, packet and padding should be
186         // subslices of a single buffer
187         w.cipher.XORKeyStream(lengthBytes, lengthBytes)
188         w.cipher.XORKeyStream(packet, packet)
189         w.cipher.XORKeyStream(padding, padding)
190
191         if _, err := w.Write(lengthBytes); err != nil {
192                 return err
193         }
194         if _, err := w.Write(packet); err != nil {
195                 return err
196         }
197         if _, err := w.Write(padding); err != nil {
198                 return err
199         }
200
201         if w.mac != nil {
202                 if _, err := w.Write(w.mac.Sum(nil)); err != nil {
203                         return err
204                 }
205         }
206
207         if err := w.Flush(); err != nil {
208                 return err
209         }
210         w.seqNum++
211         return err
212 }
213
214 // Send a message to the remote peer
215 func (t *transport) sendMessage(typ uint8, msg interface{}) error {
216         packet := marshal(typ, msg)
217         return t.writePacket(packet)
218 }
219
220 func newTransport(conn net.Conn, rand io.Reader) *transport {
221         return &transport{
222                 reader: reader{
223                         Reader: bufio.NewReader(conn),
224                         common: common{
225                                 cipher: noneCipher{},
226                         },
227                 },
228                 writer: writer{
229                         Writer: bufio.NewWriter(conn),
230                         rand:   rand,
231                         Mutex:  new(sync.Mutex),
232                         common: common{
233                                 cipher: noneCipher{},
234                         },
235                 },
236                 filteredConn: conn,
237         }
238 }
239
240 type direction struct {
241         ivTag     []byte
242         keyTag    []byte
243         macKeyTag []byte
244 }
245
246 // TODO(dfc) can this be made a constant ?
247 var (
248         serverKeys = direction{[]byte{'B'}, []byte{'D'}, []byte{'F'}}
249         clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}}
250 )
251
252 // setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
253 // described in RFC 4253, section 6.4. direction should either be serverKeys
254 // (to setup server->client keys) or clientKeys (for client->server keys).
255 func (c *common) setupKeys(d direction, K, H, sessionId []byte, hashFunc crypto.Hash) error {
256         cipherMode := cipherModes[c.cipherAlgo]
257
258         macKeySize := 20
259
260         iv := make([]byte, cipherMode.ivSize)
261         key := make([]byte, cipherMode.keySize)
262         macKey := make([]byte, macKeySize)
263
264         h := hashFunc.New()
265         generateKeyMaterial(iv, d.ivTag, K, H, sessionId, h)
266         generateKeyMaterial(key, d.keyTag, K, H, sessionId, h)
267         generateKeyMaterial(macKey, d.macKeyTag, K, H, sessionId, h)
268
269         c.mac = truncatingMAC{12, hmac.NewSHA1(macKey)}
270
271         cipher, err := cipherMode.createCipher(key, iv)
272         if err != nil {
273                 return err
274         }
275
276         c.cipher = cipher
277
278         return nil
279 }
280
281 // generateKeyMaterial fills out with key material generated from tag, K, H
282 // and sessionId, as specified in RFC 4253, section 7.2.
283 func generateKeyMaterial(out, tag []byte, K, H, sessionId []byte, h hash.Hash) {
284         var digestsSoFar []byte
285
286         for len(out) > 0 {
287                 h.Reset()
288                 h.Write(K)
289                 h.Write(H)
290
291                 if len(digestsSoFar) == 0 {
292                         h.Write(tag)
293                         h.Write(sessionId)
294                 } else {
295                         h.Write(digestsSoFar)
296                 }
297
298                 digest := h.Sum(nil)
299                 n := copy(out, digest)
300                 out = out[n:]
301                 if len(out) > 0 {
302                         digestsSoFar = append(digestsSoFar, digest...)
303                 }
304         }
305 }
306
307 // truncatingMAC wraps around a hash.Hash and truncates the output digest to
308 // a given size.
309 type truncatingMAC struct {
310         length int
311         hmac   hash.Hash
312 }
313
314 func (t truncatingMAC) Write(data []byte) (int, error) {
315         return t.hmac.Write(data)
316 }
317
318 func (t truncatingMAC) Sum(in []byte) []byte {
319         out := t.hmac.Sum(in)
320         return out[:len(in)+t.length]
321 }
322
323 func (t truncatingMAC) Reset() {
324         t.hmac.Reset()
325 }
326
327 func (t truncatingMAC) Size() int {
328         return t.length
329 }
330
331 // maxVersionStringBytes is the maximum number of bytes that we'll accept as a
332 // version string. In the event that the client is talking a different protocol
333 // we need to set a limit otherwise we will keep using more and more memory
334 // while searching for the end of the version handshake.
335 const maxVersionStringBytes = 1024
336
337 // Read version string as specified by RFC 4253, section 4.2.
338 func readVersion(r io.Reader) ([]byte, error) {
339         versionString := make([]byte, 0, 64)
340         var ok, seenCR bool
341         var buf [1]byte
342 forEachByte:
343         for len(versionString) < maxVersionStringBytes {
344                 _, err := io.ReadFull(r, buf[:])
345                 if err != nil {
346                         return nil, err
347                 }
348                 b := buf[0]
349
350                 if !seenCR {
351                         if b == '\r' {
352                                 seenCR = true
353                         }
354                 } else {
355                         if b == '\n' {
356                                 ok = true
357                                 break forEachByte
358                         } else {
359                                 seenCR = false
360                         }
361                 }
362                 versionString = append(versionString, b)
363         }
364
365         if !ok {
366                 return nil, errors.New("failed to read version string")
367         }
368
369         // We need to remove the CR from versionString
370         return versionString[:len(versionString)-1], nil
371 }