OSDN Git Service

libgo: Update to weekly.2011-11-18.
[pf3gnuchains/gcc-fork.git] / libgo / go / exp / ssh / client.go
1 // Copyright 2011 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package ssh
6
7 import (
8         "crypto"
9         "crypto/rand"
10         "errors"
11         "fmt"
12         "io"
13         "math/big"
14         "net"
15         "sync"
16 )
17
18 // clientVersion is the fixed identification string that the client will use.
19 var clientVersion = []byte("SSH-2.0-Go\r\n")
20
21 // ClientConn represents the client side of an SSH connection.
22 type ClientConn struct {
23         *transport
24         config *ClientConfig
25         chanlist
26 }
27
28 // Client returns a new SSH client connection using c as the underlying transport.
29 func Client(c net.Conn, config *ClientConfig) (*ClientConn, error) {
30         conn := &ClientConn{
31                 transport: newTransport(c, config.rand()),
32                 config:    config,
33         }
34         if err := conn.handshake(); err != nil {
35                 conn.Close()
36                 return nil, err
37         }
38         go conn.mainLoop()
39         return conn, nil
40 }
41
42 // handshake performs the client side key exchange. See RFC 4253 Section 7.
43 func (c *ClientConn) handshake() error {
44         var magics handshakeMagics
45
46         if _, err := c.Write(clientVersion); err != nil {
47                 return err
48         }
49         if err := c.Flush(); err != nil {
50                 return err
51         }
52         magics.clientVersion = clientVersion[:len(clientVersion)-2]
53
54         // read remote server version
55         version, err := readVersion(c)
56         if err != nil {
57                 return err
58         }
59         magics.serverVersion = version
60         clientKexInit := kexInitMsg{
61                 KexAlgos:                supportedKexAlgos,
62                 ServerHostKeyAlgos:      supportedHostKeyAlgos,
63                 CiphersClientServer:     c.config.Crypto.ciphers(),
64                 CiphersServerClient:     c.config.Crypto.ciphers(),
65                 MACsClientServer:        supportedMACs,
66                 MACsServerClient:        supportedMACs,
67                 CompressionClientServer: supportedCompressions,
68                 CompressionServerClient: supportedCompressions,
69         }
70         kexInitPacket := marshal(msgKexInit, clientKexInit)
71         magics.clientKexInit = kexInitPacket
72
73         if err := c.writePacket(kexInitPacket); err != nil {
74                 return err
75         }
76         packet, err := c.readPacket()
77         if err != nil {
78                 return err
79         }
80
81         magics.serverKexInit = packet
82
83         var serverKexInit kexInitMsg
84         if err = unmarshal(&serverKexInit, packet, msgKexInit); err != nil {
85                 return err
86         }
87
88         kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(c.transport, &clientKexInit, &serverKexInit)
89         if !ok {
90                 return errors.New("ssh: no common algorithms")
91         }
92
93         if serverKexInit.FirstKexFollows && kexAlgo != serverKexInit.KexAlgos[0] {
94                 // The server sent a Kex message for the wrong algorithm,
95                 // which we have to ignore.
96                 if _, err := c.readPacket(); err != nil {
97                         return err
98                 }
99         }
100
101         var H, K []byte
102         var hashFunc crypto.Hash
103         switch kexAlgo {
104         case kexAlgoDH14SHA1:
105                 hashFunc = crypto.SHA1
106                 dhGroup14Once.Do(initDHGroup14)
107                 H, K, err = c.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo)
108         default:
109                 err = fmt.Errorf("ssh: unexpected key exchange algorithm %v", kexAlgo)
110         }
111         if err != nil {
112                 return err
113         }
114
115         if err = c.writePacket([]byte{msgNewKeys}); err != nil {
116                 return err
117         }
118         if err = c.transport.writer.setupKeys(clientKeys, K, H, H, hashFunc); err != nil {
119                 return err
120         }
121         if packet, err = c.readPacket(); err != nil {
122                 return err
123         }
124         if packet[0] != msgNewKeys {
125                 return UnexpectedMessageError{msgNewKeys, packet[0]}
126         }
127         if err := c.transport.reader.setupKeys(serverKeys, K, H, H, hashFunc); err != nil {
128                 return err
129         }
130         return c.authenticate(H)
131 }
132
133 // kexDH performs Diffie-Hellman key agreement on a ClientConn. The
134 // returned values are given the same names as in RFC 4253, section 8.
135 func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) ([]byte, []byte, error) {
136         x, err := rand.Int(c.config.rand(), group.p)
137         if err != nil {
138                 return nil, nil, err
139         }
140         X := new(big.Int).Exp(group.g, x, group.p)
141         kexDHInit := kexDHInitMsg{
142                 X: X,
143         }
144         if err := c.writePacket(marshal(msgKexDHInit, kexDHInit)); err != nil {
145                 return nil, nil, err
146         }
147
148         packet, err := c.readPacket()
149         if err != nil {
150                 return nil, nil, err
151         }
152
153         var kexDHReply = new(kexDHReplyMsg)
154         if err = unmarshal(kexDHReply, packet, msgKexDHReply); err != nil {
155                 return nil, nil, err
156         }
157
158         if kexDHReply.Y.Sign() == 0 || kexDHReply.Y.Cmp(group.p) >= 0 {
159                 return nil, nil, errors.New("server DH parameter out of bounds")
160         }
161
162         kInt := new(big.Int).Exp(kexDHReply.Y, x, group.p)
163         h := hashFunc.New()
164         writeString(h, magics.clientVersion)
165         writeString(h, magics.serverVersion)
166         writeString(h, magics.clientKexInit)
167         writeString(h, magics.serverKexInit)
168         writeString(h, kexDHReply.HostKey)
169         writeInt(h, X)
170         writeInt(h, kexDHReply.Y)
171         K := make([]byte, intLength(kInt))
172         marshalInt(K, kInt)
173         h.Write(K)
174
175         H := h.Sum()
176
177         return H, K, nil
178 }
179
180 // openChan opens a new client channel. The most common session type is "session". 
181 // The full set of valid session types are listed in RFC 4250 4.9.1.
182 func (c *ClientConn) openChan(typ string) (*clientChan, error) {
183         ch := c.newChan(c.transport)
184         if err := c.writePacket(marshal(msgChannelOpen, channelOpenMsg{
185                 ChanType:      typ,
186                 PeersId:       ch.id,
187                 PeersWindow:   1 << 14,
188                 MaxPacketSize: 1 << 15, // RFC 4253 6.1
189         })); err != nil {
190                 c.chanlist.remove(ch.id)
191                 return nil, err
192         }
193         // wait for response
194         switch msg := (<-ch.msg).(type) {
195         case *channelOpenConfirmMsg:
196                 ch.peersId = msg.MyId
197                 ch.win <- int(msg.MyWindow)
198         case *channelOpenFailureMsg:
199                 c.chanlist.remove(ch.id)
200                 return nil, errors.New(msg.Message)
201         default:
202                 c.chanlist.remove(ch.id)
203                 return nil, errors.New("Unexpected packet")
204         }
205         return ch, nil
206 }
207
208 // mainloop reads incoming messages and routes channel messages
209 // to their respective ClientChans.
210 func (c *ClientConn) mainLoop() {
211         // TODO(dfc) signal the underlying close to all channels
212         defer c.Close()
213         for {
214                 packet, err := c.readPacket()
215                 if err != nil {
216                         break
217                 }
218                 // TODO(dfc) A note on blocking channel use. 
219                 // The msg, win, data and dataExt channels of a clientChan can 
220                 // cause this loop to block indefinately if the consumer does 
221                 // not service them. 
222                 switch packet[0] {
223                 case msgChannelData:
224                         if len(packet) < 9 {
225                                 // malformed data packet
226                                 break
227                         }
228                         peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
229                         if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 {
230                                 packet = packet[9:]
231                                 c.getChan(peersId).data <- packet[:length]
232                         }
233                 case msgChannelExtendedData:
234                         if len(packet) < 13 {
235                                 // malformed data packet
236                                 break
237                         }
238                         peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
239                         datatype := uint32(packet[5])<<24 | uint32(packet[6])<<16 | uint32(packet[7])<<8 | uint32(packet[8])
240                         if length := int(packet[9])<<24 | int(packet[10])<<16 | int(packet[11])<<8 | int(packet[12]); length > 0 {
241                                 packet = packet[13:]
242                                 // RFC 4254 5.2 defines data_type_code 1 to be data destined 
243                                 // for stderr on interactive sessions. Other data types are
244                                 // silently discarded.
245                                 if datatype == 1 {
246                                         c.getChan(peersId).dataExt <- packet[:length]
247                                 }
248                         }
249                 default:
250                         switch msg := decode(packet).(type) {
251                         case *channelOpenMsg:
252                                 c.getChan(msg.PeersId).msg <- msg
253                         case *channelOpenConfirmMsg:
254                                 c.getChan(msg.PeersId).msg <- msg
255                         case *channelOpenFailureMsg:
256                                 c.getChan(msg.PeersId).msg <- msg
257                         case *channelCloseMsg:
258                                 ch := c.getChan(msg.PeersId)
259                                 close(ch.win)
260                                 close(ch.data)
261                                 close(ch.dataExt)
262                                 c.chanlist.remove(msg.PeersId)
263                         case *channelEOFMsg:
264                                 c.getChan(msg.PeersId).msg <- msg
265                         case *channelRequestSuccessMsg:
266                                 c.getChan(msg.PeersId).msg <- msg
267                         case *channelRequestFailureMsg:
268                                 c.getChan(msg.PeersId).msg <- msg
269                         case *channelRequestMsg:
270                                 c.getChan(msg.PeersId).msg <- msg
271                         case *windowAdjustMsg:
272                                 c.getChan(msg.PeersId).win <- int(msg.AdditionalBytes)
273                         default:
274                                 fmt.Printf("mainLoop: unhandled %#v\n", msg)
275                         }
276                 }
277         }
278 }
279
280 // Dial connects to the given network address using net.Dial and 
281 // then initiates a SSH handshake, returning the resulting client connection.
282 func Dial(network, addr string, config *ClientConfig) (*ClientConn, error) {
283         conn, err := net.Dial(network, addr)
284         if err != nil {
285                 return nil, err
286         }
287         return Client(conn, config)
288 }
289
290 // A ClientConfig structure is used to configure a ClientConn. After one has 
291 // been passed to an SSH function it must not be modified.
292 type ClientConfig struct {
293         // Rand provides the source of entropy for key exchange. If Rand is 
294         // nil, the cryptographic random reader in package crypto/rand will 
295         // be used.
296         Rand io.Reader
297
298         // The username to authenticate.
299         User string
300
301         // A slice of ClientAuth methods. Only the first instance 
302         // of a particular RFC 4252 method will be used during authentication.
303         Auth []ClientAuth
304
305         // Cryptographic-related configuration.
306         Crypto CryptoConfig
307 }
308
309 func (c *ClientConfig) rand() io.Reader {
310         if c.Rand == nil {
311                 return rand.Reader
312         }
313         return c.Rand
314 }
315
316 // A clientChan represents a single RFC 4254 channel that is multiplexed 
317 // over a single SSH connection.
318 type clientChan struct {
319         packetWriter
320         id, peersId uint32
321         data        chan []byte      // receives the payload of channelData messages
322         dataExt     chan []byte      // receives the payload of channelExtendedData messages
323         win         chan int         // receives window adjustments
324         msg         chan interface{} // incoming messages
325 }
326
327 func newClientChan(t *transport, id uint32) *clientChan {
328         return &clientChan{
329                 packetWriter: t,
330                 id:           id,
331                 data:         make(chan []byte, 16),
332                 dataExt:      make(chan []byte, 16),
333                 win:          make(chan int, 16),
334                 msg:          make(chan interface{}, 16),
335         }
336 }
337
338 // Close closes the channel. This does not close the underlying connection.
339 func (c *clientChan) Close() error {
340         return c.writePacket(marshal(msgChannelClose, channelCloseMsg{
341                 PeersId: c.id,
342         }))
343 }
344
345 func (c *clientChan) sendChanReq(req channelRequestMsg) error {
346         if err := c.writePacket(marshal(msgChannelRequest, req)); err != nil {
347                 return err
348         }
349         msg := <-c.msg
350         if _, ok := msg.(*channelRequestSuccessMsg); ok {
351                 return nil
352         }
353         return fmt.Errorf("failed to complete request: %s, %#v", req.Request, msg)
354 }
355
356 // Thread safe channel list.
357 type chanlist struct {
358         // protects concurrent access to chans
359         sync.Mutex
360         // chans are indexed by the local id of the channel, clientChan.id.
361         // The PeersId value of messages received by ClientConn.mainloop is
362         // used to locate the right local clientChan in this slice.
363         chans []*clientChan
364 }
365
366 // Allocate a new ClientChan with the next avail local id.
367 func (c *chanlist) newChan(t *transport) *clientChan {
368         c.Lock()
369         defer c.Unlock()
370         for i := range c.chans {
371                 if c.chans[i] == nil {
372                         ch := newClientChan(t, uint32(i))
373                         c.chans[i] = ch
374                         return ch
375                 }
376         }
377         i := len(c.chans)
378         ch := newClientChan(t, uint32(i))
379         c.chans = append(c.chans, ch)
380         return ch
381 }
382
383 func (c *chanlist) getChan(id uint32) *clientChan {
384         c.Lock()
385         defer c.Unlock()
386         return c.chans[int(id)]
387 }
388
389 func (c *chanlist) remove(id uint32) {
390         c.Lock()
391         defer c.Unlock()
392         c.chans[int(id)] = nil
393 }
394
395 // A chanWriter represents the stdin of a remote process.
396 type chanWriter struct {
397         win          chan int // receives window adjustments
398         id           uint32   // this channel's id
399         rwin         int      // current rwin size
400         packetWriter          // for sending channelDataMsg
401 }
402
403 // Write writes data to the remote process's standard input.
404 func (w *chanWriter) Write(data []byte) (n int, err error) {
405         for {
406                 if w.rwin == 0 {
407                         win, ok := <-w.win
408                         if !ok {
409                                 return 0, io.EOF
410                         }
411                         w.rwin += win
412                         continue
413                 }
414                 n = len(data)
415                 packet := make([]byte, 0, 9+n)
416                 packet = append(packet, msgChannelData,
417                         byte(w.id)>>24, byte(w.id)>>16, byte(w.id)>>8, byte(w.id),
418                         byte(n)>>24, byte(n)>>16, byte(n)>>8, byte(n))
419                 err = w.writePacket(append(packet, data...))
420                 w.rwin -= n
421                 return
422         }
423         panic("unreachable")
424 }
425
426 func (w *chanWriter) Close() error {
427         return w.writePacket(marshal(msgChannelEOF, channelEOFMsg{w.id}))
428 }
429
430 // A chanReader represents stdout or stderr of a remote process.
431 type chanReader struct {
432         // TODO(dfc) a fixed size channel may not be the right data structure.
433         // If writes to this channel block, they will block mainLoop, making
434         // it unable to receive new messages from the remote side.
435         data         chan []byte // receives data from remote
436         id           uint32
437         packetWriter // for sending windowAdjustMsg
438         buf          []byte
439 }
440
441 // Read reads data from the remote process's stdout or stderr.
442 func (r *chanReader) Read(data []byte) (int, error) {
443         var ok bool
444         for {
445                 if len(r.buf) > 0 {
446                         n := copy(data, r.buf)
447                         r.buf = r.buf[n:]
448                         msg := windowAdjustMsg{
449                                 PeersId:         r.id,
450                                 AdditionalBytes: uint32(n),
451                         }
452                         return n, r.writePacket(marshal(msgChannelWindowAdjust, msg))
453                 }
454                 r.buf, ok = <-r.data
455                 if !ok {
456                         return 0, io.EOF
457                 }
458         }
459         panic("unreachable")
460 }
461
462 func (r *chanReader) Close() error {
463         return r.writePacket(marshal(msgChannelEOF, channelEOFMsg{r.id}))
464 }