OSDN Git Service

libgo: Update to weekly.2011-12-14.
[pf3gnuchains/gcc-fork.git] / libgo / go / exp / ssh / client.go
1 // Copyright 2011 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package ssh
6
7 import (
8         "crypto"
9         "crypto/rand"
10         "errors"
11         "fmt"
12         "io"
13         "math/big"
14         "net"
15         "sync"
16 )
17
18 // clientVersion is the fixed identification string that the client will use.
19 var clientVersion = []byte("SSH-2.0-Go\r\n")
20
21 // ClientConn represents the client side of an SSH connection.
22 type ClientConn struct {
23         *transport
24         config *ClientConfig
25         chanlist
26 }
27
28 // Client returns a new SSH client connection using c as the underlying transport.
29 func Client(c net.Conn, config *ClientConfig) (*ClientConn, error) {
30         conn := &ClientConn{
31                 transport: newTransport(c, config.rand()),
32                 config:    config,
33         }
34         if err := conn.handshake(); err != nil {
35                 conn.Close()
36                 return nil, err
37         }
38         go conn.mainLoop()
39         return conn, nil
40 }
41
42 // handshake performs the client side key exchange. See RFC 4253 Section 7.
43 func (c *ClientConn) handshake() error {
44         var magics handshakeMagics
45
46         if _, err := c.Write(clientVersion); err != nil {
47                 return err
48         }
49         if err := c.Flush(); err != nil {
50                 return err
51         }
52         magics.clientVersion = clientVersion[:len(clientVersion)-2]
53
54         // read remote server version
55         version, err := readVersion(c)
56         if err != nil {
57                 return err
58         }
59         magics.serverVersion = version
60         clientKexInit := kexInitMsg{
61                 KexAlgos:                supportedKexAlgos,
62                 ServerHostKeyAlgos:      supportedHostKeyAlgos,
63                 CiphersClientServer:     c.config.Crypto.ciphers(),
64                 CiphersServerClient:     c.config.Crypto.ciphers(),
65                 MACsClientServer:        supportedMACs,
66                 MACsServerClient:        supportedMACs,
67                 CompressionClientServer: supportedCompressions,
68                 CompressionServerClient: supportedCompressions,
69         }
70         kexInitPacket := marshal(msgKexInit, clientKexInit)
71         magics.clientKexInit = kexInitPacket
72
73         if err := c.writePacket(kexInitPacket); err != nil {
74                 return err
75         }
76         packet, err := c.readPacket()
77         if err != nil {
78                 return err
79         }
80
81         magics.serverKexInit = packet
82
83         var serverKexInit kexInitMsg
84         if err = unmarshal(&serverKexInit, packet, msgKexInit); err != nil {
85                 return err
86         }
87
88         kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(c.transport, &clientKexInit, &serverKexInit)
89         if !ok {
90                 return errors.New("ssh: no common algorithms")
91         }
92
93         if serverKexInit.FirstKexFollows && kexAlgo != serverKexInit.KexAlgos[0] {
94                 // The server sent a Kex message for the wrong algorithm,
95                 // which we have to ignore.
96                 if _, err := c.readPacket(); err != nil {
97                         return err
98                 }
99         }
100
101         var H, K []byte
102         var hashFunc crypto.Hash
103         switch kexAlgo {
104         case kexAlgoDH14SHA1:
105                 hashFunc = crypto.SHA1
106                 dhGroup14Once.Do(initDHGroup14)
107                 H, K, err = c.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo)
108         default:
109                 err = fmt.Errorf("ssh: unexpected key exchange algorithm %v", kexAlgo)
110         }
111         if err != nil {
112                 return err
113         }
114
115         if err = c.writePacket([]byte{msgNewKeys}); err != nil {
116                 return err
117         }
118         if err = c.transport.writer.setupKeys(clientKeys, K, H, H, hashFunc); err != nil {
119                 return err
120         }
121         if packet, err = c.readPacket(); err != nil {
122                 return err
123         }
124         if packet[0] != msgNewKeys {
125                 return UnexpectedMessageError{msgNewKeys, packet[0]}
126         }
127         if err := c.transport.reader.setupKeys(serverKeys, K, H, H, hashFunc); err != nil {
128                 return err
129         }
130         return c.authenticate(H)
131 }
132
133 // kexDH performs Diffie-Hellman key agreement on a ClientConn. The
134 // returned values are given the same names as in RFC 4253, section 8.
135 func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) ([]byte, []byte, error) {
136         x, err := rand.Int(c.config.rand(), group.p)
137         if err != nil {
138                 return nil, nil, err
139         }
140         X := new(big.Int).Exp(group.g, x, group.p)
141         kexDHInit := kexDHInitMsg{
142                 X: X,
143         }
144         if err := c.writePacket(marshal(msgKexDHInit, kexDHInit)); err != nil {
145                 return nil, nil, err
146         }
147
148         packet, err := c.readPacket()
149         if err != nil {
150                 return nil, nil, err
151         }
152
153         var kexDHReply = new(kexDHReplyMsg)
154         if err = unmarshal(kexDHReply, packet, msgKexDHReply); err != nil {
155                 return nil, nil, err
156         }
157
158         if kexDHReply.Y.Sign() == 0 || kexDHReply.Y.Cmp(group.p) >= 0 {
159                 return nil, nil, errors.New("server DH parameter out of bounds")
160         }
161
162         kInt := new(big.Int).Exp(kexDHReply.Y, x, group.p)
163         h := hashFunc.New()
164         writeString(h, magics.clientVersion)
165         writeString(h, magics.serverVersion)
166         writeString(h, magics.clientKexInit)
167         writeString(h, magics.serverKexInit)
168         writeString(h, kexDHReply.HostKey)
169         writeInt(h, X)
170         writeInt(h, kexDHReply.Y)
171         K := make([]byte, intLength(kInt))
172         marshalInt(K, kInt)
173         h.Write(K)
174
175         H := h.Sum(nil)
176
177         return H, K, nil
178 }
179
180 // mainLoop reads incoming messages and routes channel messages
181 // to their respective ClientChans.
182 func (c *ClientConn) mainLoop() {
183         // TODO(dfc) signal the underlying close to all channels
184         defer c.Close()
185         for {
186                 packet, err := c.readPacket()
187                 if err != nil {
188                         break
189                 }
190                 // TODO(dfc) A note on blocking channel use.
191                 // The msg, win, data and dataExt channels of a clientChan can
192                 // cause this loop to block indefinately if the consumer does
193                 // not service them.
194                 switch packet[0] {
195                 case msgChannelData:
196                         if len(packet) < 9 {
197                                 // malformed data packet
198                                 break
199                         }
200                         peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
201                         if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 {
202                                 packet = packet[9:]
203                                 c.getChan(peersId).stdout.handleData(packet[:length])
204                         }
205                 case msgChannelExtendedData:
206                         if len(packet) < 13 {
207                                 // malformed data packet
208                                 break
209                         }
210                         peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
211                         datatype := uint32(packet[5])<<24 | uint32(packet[6])<<16 | uint32(packet[7])<<8 | uint32(packet[8])
212                         if length := int(packet[9])<<24 | int(packet[10])<<16 | int(packet[11])<<8 | int(packet[12]); length > 0 {
213                                 packet = packet[13:]
214                                 // RFC 4254 5.2 defines data_type_code 1 to be data destined
215                                 // for stderr on interactive sessions. Other data types are
216                                 // silently discarded.
217                                 if datatype == 1 {
218                                         c.getChan(peersId).stderr.handleData(packet[:length])
219                                 }
220                         }
221                 default:
222                         switch msg := decode(packet).(type) {
223                         case *channelOpenMsg:
224                                 c.getChan(msg.PeersId).msg <- msg
225                         case *channelOpenConfirmMsg:
226                                 c.getChan(msg.PeersId).msg <- msg
227                         case *channelOpenFailureMsg:
228                                 c.getChan(msg.PeersId).msg <- msg
229                         case *channelCloseMsg:
230                                 ch := c.getChan(msg.PeersId)
231                                 ch.theyClosed = true
232                                 close(ch.stdin.win)
233                                 ch.stdout.eof()
234                                 ch.stderr.eof()
235                                 close(ch.msg)
236                                 if !ch.weClosed {
237                                         ch.weClosed = true
238                                         ch.sendClose()
239                                 }
240                                 c.chanlist.remove(msg.PeersId)
241                         case *channelEOFMsg:
242                                 ch := c.getChan(msg.PeersId)
243                                 ch.stdout.eof()
244                                 // RFC 4254 is mute on how EOF affects dataExt messages but
245                                 // it is logical to signal EOF at the same time.
246                                 ch.stderr.eof()
247                         case *channelRequestSuccessMsg:
248                                 c.getChan(msg.PeersId).msg <- msg
249                         case *channelRequestFailureMsg:
250                                 c.getChan(msg.PeersId).msg <- msg
251                         case *channelRequestMsg:
252                                 c.getChan(msg.PeersId).msg <- msg
253                         case *windowAdjustMsg:
254                                 c.getChan(msg.PeersId).stdin.win <- int(msg.AdditionalBytes)
255                         case *disconnectMsg:
256                                 break
257                         default:
258                                 fmt.Printf("mainLoop: unhandled message %T: %v\n", msg, msg)
259                         }
260                 }
261         }
262 }
263
264 // Dial connects to the given network address using net.Dial and
265 // then initiates a SSH handshake, returning the resulting client connection.
266 func Dial(network, addr string, config *ClientConfig) (*ClientConn, error) {
267         conn, err := net.Dial(network, addr)
268         if err != nil {
269                 return nil, err
270         }
271         return Client(conn, config)
272 }
273
274 // A ClientConfig structure is used to configure a ClientConn. After one has
275 // been passed to an SSH function it must not be modified.
276 type ClientConfig struct {
277         // Rand provides the source of entropy for key exchange. If Rand is
278         // nil, the cryptographic random reader in package crypto/rand will
279         // be used.
280         Rand io.Reader
281
282         // The username to authenticate.
283         User string
284
285         // A slice of ClientAuth methods. Only the first instance
286         // of a particular RFC 4252 method will be used during authentication.
287         Auth []ClientAuth
288
289         // Cryptographic-related configuration.
290         Crypto CryptoConfig
291 }
292
293 func (c *ClientConfig) rand() io.Reader {
294         if c.Rand == nil {
295                 return rand.Reader
296         }
297         return c.Rand
298 }
299
300 // A clientChan represents a single RFC 4254 channel that is multiplexed
301 // over a single SSH connection.
302 type clientChan struct {
303         packetWriter
304         id, peersId uint32
305         stdin       *chanWriter      // receives window adjustments
306         stdout      *chanReader      // receives the payload of channelData messages
307         stderr      *chanReader      // receives the payload of channelExtendedData messages
308         msg         chan interface{} // incoming messages
309
310         theyClosed bool // indicates the close msg has been received from the remote side
311         weClosed   bool // incidates the close msg has been sent from our side
312 }
313
314 // newClientChan returns a partially constructed *clientChan
315 // using the local id provided. To be usable clientChan.peersId
316 // needs to be assigned once known.
317 func newClientChan(t *transport, id uint32) *clientChan {
318         c := &clientChan{
319                 packetWriter: t,
320                 id:           id,
321                 msg:          make(chan interface{}, 16),
322         }
323         c.stdin = &chanWriter{
324                 win:        make(chan int, 16),
325                 clientChan: c,
326         }
327         c.stdout = &chanReader{
328                 data:       make(chan []byte, 16),
329                 clientChan: c,
330         }
331         c.stderr = &chanReader{
332                 data:       make(chan []byte, 16),
333                 clientChan: c,
334         }
335         return c
336 }
337
338 // waitForChannelOpenResponse, if successful, fills out
339 // the peerId and records any initial window advertisement.
340 func (c *clientChan) waitForChannelOpenResponse() error {
341         switch msg := (<-c.msg).(type) {
342         case *channelOpenConfirmMsg:
343                 // fixup peersId field
344                 c.peersId = msg.MyId
345                 c.stdin.win <- int(msg.MyWindow)
346                 return nil
347         case *channelOpenFailureMsg:
348                 return errors.New(safeString(msg.Message))
349         }
350         return errors.New("unexpected packet")
351 }
352
353 // sendEOF sends EOF to the server. RFC 4254 Section 5.3
354 func (c *clientChan) sendEOF() error {
355         return c.writePacket(marshal(msgChannelEOF, channelEOFMsg{
356                 PeersId: c.peersId,
357         }))
358 }
359
360 // sendClose signals the intent to close the channel.
361 func (c *clientChan) sendClose() error {
362         return c.writePacket(marshal(msgChannelClose, channelCloseMsg{
363                 PeersId: c.peersId,
364         }))
365 }
366
367 // Close closes the channel. This does not close the underlying connection.
368 func (c *clientChan) Close() error {
369         if !c.weClosed {
370                 c.weClosed = true
371                 return c.sendClose()
372         }
373         return nil
374 }
375
376 // Thread safe channel list.
377 type chanlist struct {
378         // protects concurrent access to chans
379         sync.Mutex
380         // chans are indexed by the local id of the channel, clientChan.id.
381         // The PeersId value of messages received by ClientConn.mainLoop is
382         // used to locate the right local clientChan in this slice.
383         chans []*clientChan
384 }
385
386 // Allocate a new ClientChan with the next avail local id.
387 func (c *chanlist) newChan(t *transport) *clientChan {
388         c.Lock()
389         defer c.Unlock()
390         for i := range c.chans {
391                 if c.chans[i] == nil {
392                         ch := newClientChan(t, uint32(i))
393                         c.chans[i] = ch
394                         return ch
395                 }
396         }
397         i := len(c.chans)
398         ch := newClientChan(t, uint32(i))
399         c.chans = append(c.chans, ch)
400         return ch
401 }
402
403 func (c *chanlist) getChan(id uint32) *clientChan {
404         c.Lock()
405         defer c.Unlock()
406         return c.chans[int(id)]
407 }
408
409 func (c *chanlist) remove(id uint32) {
410         c.Lock()
411         defer c.Unlock()
412         c.chans[int(id)] = nil
413 }
414
415 // A chanWriter represents the stdin of a remote process.
416 type chanWriter struct {
417         win        chan int    // receives window adjustments
418         rwin       int         // current rwin size
419         clientChan *clientChan // the channel backing this writer
420 }
421
422 // Write writes data to the remote process's standard input.
423 func (w *chanWriter) Write(data []byte) (n int, err error) {
424         for {
425                 if w.rwin == 0 {
426                         win, ok := <-w.win
427                         if !ok {
428                                 return 0, io.EOF
429                         }
430                         w.rwin += win
431                         continue
432                 }
433                 peersId := w.clientChan.peersId
434                 n = len(data)
435                 packet := make([]byte, 0, 9+n)
436                 packet = append(packet, msgChannelData,
437                         byte(peersId>>24), byte(peersId>>16), byte(peersId>>8), byte(peersId),
438                         byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
439                 err = w.clientChan.writePacket(append(packet, data...))
440                 w.rwin -= n
441                 return
442         }
443         panic("unreachable")
444 }
445
446 func (w *chanWriter) Close() error {
447         return w.clientChan.sendEOF()
448 }
449
450 // A chanReader represents stdout or stderr of a remote process.
451 type chanReader struct {
452         // TODO(dfc) a fixed size channel may not be the right data structure.
453         // If writes to this channel block, they will block mainLoop, making
454         // it unable to receive new messages from the remote side.
455         data       chan []byte // receives data from remote
456         dataClosed bool        // protects data from being closed twice
457         clientChan *clientChan // the channel backing this reader
458         buf        []byte
459 }
460
461 // eof signals to the consumer that there is no more data to be received.
462 func (r *chanReader) eof() {
463         if !r.dataClosed {
464                 r.dataClosed = true
465                 close(r.data)
466         }
467 }
468
469 // handleData sends buf to the reader's consumer. If r.data is closed
470 // the data will be silently discarded
471 func (r *chanReader) handleData(buf []byte) {
472         if !r.dataClosed {
473                 r.data <- buf
474         }
475 }
476
477 // Read reads data from the remote process's stdout or stderr.
478 func (r *chanReader) Read(data []byte) (int, error) {
479         var ok bool
480         for {
481                 if len(r.buf) > 0 {
482                         n := copy(data, r.buf)
483                         r.buf = r.buf[n:]
484                         msg := windowAdjustMsg{
485                                 PeersId:         r.clientChan.peersId,
486                                 AdditionalBytes: uint32(n),
487                         }
488                         return n, r.clientChan.writePacket(marshal(msgChannelWindowAdjust, msg))
489                 }
490                 r.buf, ok = <-r.data
491                 if !ok {
492                         return 0, io.EOF
493                 }
494         }
495         panic("unreachable")
496 }