OSDN Git Service

9eed3157ee70deaba06b4a4a2a3cb8c1f8199ccd
[pf3gnuchains/gcc-fork.git] / libgo / go / exp / ssh / client.go
1 // Copyright 2011 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package ssh
6
7 import (
8         "big"
9         "crypto"
10         "crypto/rand"
11         "fmt"
12         "io"
13         "net"
14         "os"
15         "sync"
16 )
17
18 // clientVersion is the fixed identification string that the client will use.
19 var clientVersion = []byte("SSH-2.0-Go\r\n")
20
21 // ClientConn represents the client side of an SSH connection.
22 type ClientConn struct {
23         *transport
24         config *ClientConfig
25         chanlist
26 }
27
28 // Client returns a new SSH client connection using c as the underlying transport.
29 func Client(c net.Conn, config *ClientConfig) (*ClientConn, os.Error) {
30         conn := &ClientConn{
31                 transport: newTransport(c, config.rand()),
32                 config:    config,
33         }
34         if err := conn.handshake(); err != nil {
35                 conn.Close()
36                 return nil, err
37         }
38         if err := conn.authenticate(); err != nil {
39                 conn.Close()
40                 return nil, err
41         }
42         go conn.mainLoop()
43         return conn, nil
44 }
45
46 // handshake performs the client side key exchange. See RFC 4253 Section 7.
47 func (c *ClientConn) handshake() os.Error {
48         var magics handshakeMagics
49
50         if _, err := c.Write(clientVersion); err != nil {
51                 return err
52         }
53         if err := c.Flush(); err != nil {
54                 return err
55         }
56         magics.clientVersion = clientVersion[:len(clientVersion)-2]
57
58         // read remote server version
59         version, err := readVersion(c)
60         if err != nil {
61                 return err
62         }
63         magics.serverVersion = version
64         clientKexInit := kexInitMsg{
65                 KexAlgos:                supportedKexAlgos,
66                 ServerHostKeyAlgos:      supportedHostKeyAlgos,
67                 CiphersClientServer:     supportedCiphers,
68                 CiphersServerClient:     supportedCiphers,
69                 MACsClientServer:        supportedMACs,
70                 MACsServerClient:        supportedMACs,
71                 CompressionClientServer: supportedCompressions,
72                 CompressionServerClient: supportedCompressions,
73         }
74         kexInitPacket := marshal(msgKexInit, clientKexInit)
75         magics.clientKexInit = kexInitPacket
76
77         if err := c.writePacket(kexInitPacket); err != nil {
78                 return err
79         }
80         packet, err := c.readPacket()
81         if err != nil {
82                 return err
83         }
84
85         magics.serverKexInit = packet
86
87         var serverKexInit kexInitMsg
88         if err = unmarshal(&serverKexInit, packet, msgKexInit); err != nil {
89                 return err
90         }
91
92         kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(c.transport, &clientKexInit, &serverKexInit)
93         if !ok {
94                 return os.NewError("ssh: no common algorithms")
95         }
96
97         if serverKexInit.FirstKexFollows && kexAlgo != serverKexInit.KexAlgos[0] {
98                 // The server sent a Kex message for the wrong algorithm,
99                 // which we have to ignore.
100                 if _, err := c.readPacket(); err != nil {
101                         return err
102                 }
103         }
104
105         var H, K []byte
106         var hashFunc crypto.Hash
107         switch kexAlgo {
108         case kexAlgoDH14SHA1:
109                 hashFunc = crypto.SHA1
110                 dhGroup14Once.Do(initDHGroup14)
111                 H, K, err = c.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo)
112         default:
113                 err = fmt.Errorf("ssh: unexpected key exchange algorithm %v", kexAlgo)
114         }
115         if err != nil {
116                 return err
117         }
118
119         if err = c.writePacket([]byte{msgNewKeys}); err != nil {
120                 return err
121         }
122         if err = c.transport.writer.setupKeys(clientKeys, K, H, H, hashFunc); err != nil {
123                 return err
124         }
125         if packet, err = c.readPacket(); err != nil {
126                 return err
127         }
128         if packet[0] != msgNewKeys {
129                 return UnexpectedMessageError{msgNewKeys, packet[0]}
130         }
131         return c.transport.reader.setupKeys(serverKeys, K, H, H, hashFunc)
132 }
133
134 // authenticate authenticates with the remote server. See RFC 4252. 
135 // Only "password" authentication is supported.
136 func (c *ClientConn) authenticate() os.Error {
137         if err := c.writePacket(marshal(msgServiceRequest, serviceRequestMsg{serviceUserAuth})); err != nil {
138                 return err
139         }
140         packet, err := c.readPacket()
141         if err != nil {
142                 return err
143         }
144
145         var serviceAccept serviceAcceptMsg
146         if err = unmarshal(&serviceAccept, packet, msgServiceAccept); err != nil {
147                 return err
148         }
149
150         // TODO(dfc) support proper authentication method negotation
151         method := "none"
152         if c.config.Password != "" {
153                 method = "password"
154         }
155         if err := c.sendUserAuthReq(method); err != nil {
156                 return err
157         }
158
159         if packet, err = c.readPacket(); err != nil {
160                 return err
161         }
162
163         if packet[0] != msgUserAuthSuccess {
164                 return UnexpectedMessageError{msgUserAuthSuccess, packet[0]}
165         }
166         return nil
167 }
168
169 func (c *ClientConn) sendUserAuthReq(method string) os.Error {
170         length := stringLength([]byte(c.config.Password)) + 1
171         payload := make([]byte, length)
172         // always false for password auth, see RFC 4252 Section 8.
173         payload[0] = 0
174         marshalString(payload[1:], []byte(c.config.Password))
175
176         return c.writePacket(marshal(msgUserAuthRequest, userAuthRequestMsg{
177                 User:    c.config.User,
178                 Service: serviceSSH,
179                 Method:  method,
180                 Payload: payload,
181         }))
182 }
183
184 // kexDH performs Diffie-Hellman key agreement on a ClientConn. The
185 // returned values are given the same names as in RFC 4253, section 8.
186 func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) ([]byte, []byte, os.Error) {
187         x, err := rand.Int(c.config.rand(), group.p)
188         if err != nil {
189                 return nil, nil, err
190         }
191         X := new(big.Int).Exp(group.g, x, group.p)
192         kexDHInit := kexDHInitMsg{
193                 X: X,
194         }
195         if err := c.writePacket(marshal(msgKexDHInit, kexDHInit)); err != nil {
196                 return nil, nil, err
197         }
198
199         packet, err := c.readPacket()
200         if err != nil {
201                 return nil, nil, err
202         }
203
204         var kexDHReply = new(kexDHReplyMsg)
205         if err = unmarshal(kexDHReply, packet, msgKexDHReply); err != nil {
206                 return nil, nil, err
207         }
208
209         if kexDHReply.Y.Sign() == 0 || kexDHReply.Y.Cmp(group.p) >= 0 {
210                 return nil, nil, os.NewError("server DH parameter out of bounds")
211         }
212
213         kInt := new(big.Int).Exp(kexDHReply.Y, x, group.p)
214         h := hashFunc.New()
215         writeString(h, magics.clientVersion)
216         writeString(h, magics.serverVersion)
217         writeString(h, magics.clientKexInit)
218         writeString(h, magics.serverKexInit)
219         writeString(h, kexDHReply.HostKey)
220         writeInt(h, X)
221         writeInt(h, kexDHReply.Y)
222         K := make([]byte, intLength(kInt))
223         marshalInt(K, kInt)
224         h.Write(K)
225
226         H := h.Sum()
227
228         return H, K, nil
229 }
230
231 // openChan opens a new client channel. The most common session type is "session". 
232 // The full set of valid session types are listed in RFC 4250 4.9.1.
233 func (c *ClientConn) openChan(typ string) (*clientChan, os.Error) {
234         ch := c.newChan(c.transport)
235         if err := c.writePacket(marshal(msgChannelOpen, channelOpenMsg{
236                 ChanType:      typ,
237                 PeersId:       ch.id,
238                 PeersWindow:   1 << 14,
239                 MaxPacketSize: 1 << 15, // RFC 4253 6.1
240         })); err != nil {
241                 c.chanlist.remove(ch.id)
242                 return nil, err
243         }
244         // wait for response
245         switch msg := (<-ch.msg).(type) {
246         case *channelOpenConfirmMsg:
247                 ch.peersId = msg.MyId
248         case *channelOpenFailureMsg:
249                 c.chanlist.remove(ch.id)
250                 return nil, os.NewError(msg.Message)
251         default:
252                 c.chanlist.remove(ch.id)
253                 return nil, os.NewError("Unexpected packet")
254         }
255         return ch, nil
256 }
257
258 // mainloop reads incoming messages and routes channel messages
259 // to their respective ClientChans.
260 func (c *ClientConn) mainLoop() {
261         // TODO(dfc) signal the underlying close to all channels
262         defer c.Close()
263         for {
264                 packet, err := c.readPacket()
265                 if err != nil {
266                         break
267                 }
268                 // TODO(dfc) A note on blocking channel use. 
269                 // The msg, win, data and dataExt channels of a clientChan can 
270                 // cause this loop to block indefinately if the consumer does 
271                 // not service them. 
272                 switch packet[0] {
273                 case msgChannelData:
274                         if len(packet) < 9 {
275                                 // malformed data packet
276                                 break
277                         }
278                         peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
279                         if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 {
280                                 packet = packet[9:]
281                                 c.getChan(peersId).data <- packet[:length]
282                         }
283                 case msgChannelExtendedData:
284                         if len(packet) < 13 {
285                                 // malformed data packet
286                                 break
287                         }
288                         peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
289                         datatype := uint32(packet[5])<<24 | uint32(packet[6])<<16 | uint32(packet[7])<<8 | uint32(packet[8])
290                         if length := int(packet[9])<<24 | int(packet[10])<<16 | int(packet[11])<<8 | int(packet[12]); length > 0 {
291                                 packet = packet[13:]
292                                 // RFC 4254 5.2 defines data_type_code 1 to be data destined 
293                                 // for stderr on interactive sessions. Other data types are
294                                 // silently discarded.
295                                 if datatype == 1 {
296                                         c.getChan(peersId).dataExt <- packet[:length]
297                                 }
298                         }
299                 default:
300                         switch msg := decode(packet).(type) {
301                         case *channelOpenMsg:
302                                 c.getChan(msg.PeersId).msg <- msg
303                         case *channelOpenConfirmMsg:
304                                 c.getChan(msg.PeersId).msg <- msg
305                         case *channelOpenFailureMsg:
306                                 c.getChan(msg.PeersId).msg <- msg
307                         case *channelCloseMsg:
308                                 ch := c.getChan(msg.PeersId)
309                                 close(ch.win)
310                                 close(ch.data)
311                                 close(ch.dataExt)
312                                 c.chanlist.remove(msg.PeersId)
313                         case *channelEOFMsg:
314                                 c.getChan(msg.PeersId).msg <- msg
315                         case *channelRequestSuccessMsg:
316                                 c.getChan(msg.PeersId).msg <- msg
317                         case *channelRequestFailureMsg:
318                                 c.getChan(msg.PeersId).msg <- msg
319                         case *channelRequestMsg:
320                                 c.getChan(msg.PeersId).msg <- msg
321                         case *windowAdjustMsg:
322                                 c.getChan(msg.PeersId).win <- int(msg.AdditionalBytes)
323                         default:
324                                 fmt.Printf("mainLoop: unhandled %#v\n", msg)
325                         }
326                 }
327         }
328 }
329
330 // Dial connects to the given network address using net.Dial and 
331 // then initiates a SSH handshake, returning the resulting client connection.
332 func Dial(network, addr string, config *ClientConfig) (*ClientConn, os.Error) {
333         conn, err := net.Dial(network, addr)
334         if err != nil {
335                 return nil, err
336         }
337         return Client(conn, config)
338 }
339
340 // A ClientConfig structure is used to configure a ClientConn. After one has 
341 // been passed to an SSH function it must not be modified.
342 type ClientConfig struct {
343         // Rand provides the source of entropy for key exchange. If Rand is 
344         // nil, the cryptographic random reader in package crypto/rand will 
345         // be used.
346         Rand io.Reader
347
348         // The username to authenticate.
349         User string
350
351         // Used for "password" method authentication.
352         Password string
353 }
354
355 func (c *ClientConfig) rand() io.Reader {
356         if c.Rand == nil {
357                 return rand.Reader
358         }
359         return c.Rand
360 }
361
362 // A clientChan represents a single RFC 4254 channel that is multiplexed 
363 // over a single SSH connection.
364 type clientChan struct {
365         packetWriter
366         id, peersId uint32
367         data        chan []byte      // receives the payload of channelData messages
368         dataExt     chan []byte      // receives the payload of channelExtendedData messages
369         win         chan int         // receives window adjustments
370         msg         chan interface{} // incoming messages
371 }
372
373 func newClientChan(t *transport, id uint32) *clientChan {
374         return &clientChan{
375                 packetWriter: t,
376                 id:           id,
377                 data:         make(chan []byte, 16),
378                 dataExt:      make(chan []byte, 16),
379                 win:          make(chan int, 16),
380                 msg:          make(chan interface{}, 16),
381         }
382 }
383
384 // Close closes the channel. This does not close the underlying connection.
385 func (c *clientChan) Close() os.Error {
386         return c.writePacket(marshal(msgChannelClose, channelCloseMsg{
387                 PeersId: c.id,
388         }))
389 }
390
391 func (c *clientChan) sendChanReq(req channelRequestMsg) os.Error {
392         if err := c.writePacket(marshal(msgChannelRequest, req)); err != nil {
393                 return err
394         }
395         msg := <-c.msg
396         if _, ok := msg.(*channelRequestSuccessMsg); ok {
397                 return nil
398         }
399         return fmt.Errorf("failed to complete request: %s, %#v", req.Request, msg)
400 }
401
402 // Thread safe channel list.
403 type chanlist struct {
404         // protects concurrent access to chans
405         sync.Mutex
406         // chans are indexed by the local id of the channel, clientChan.id.
407         // The PeersId value of messages received by ClientConn.mainloop is
408         // used to locate the right local clientChan in this slice.
409         chans []*clientChan
410 }
411
412 // Allocate a new ClientChan with the next avail local id.
413 func (c *chanlist) newChan(t *transport) *clientChan {
414         c.Lock()
415         defer c.Unlock()
416         for i := range c.chans {
417                 if c.chans[i] == nil {
418                         ch := newClientChan(t, uint32(i))
419                         c.chans[i] = ch
420                         return ch
421                 }
422         }
423         i := len(c.chans)
424         ch := newClientChan(t, uint32(i))
425         c.chans = append(c.chans, ch)
426         return ch
427 }
428
429 func (c *chanlist) getChan(id uint32) *clientChan {
430         c.Lock()
431         defer c.Unlock()
432         return c.chans[int(id)]
433 }
434
435 func (c *chanlist) remove(id uint32) {
436         c.Lock()
437         defer c.Unlock()
438         c.chans[int(id)] = nil
439 }
440
441 // A chanWriter represents the stdin of a remote process.
442 type chanWriter struct {
443         win          chan int // receives window adjustments
444         id           uint32   // this channel's id
445         rwin         int      // current rwin size
446         packetWriter          // for sending channelDataMsg
447 }
448
449 // Write writes data to the remote process's standard input.
450 func (w *chanWriter) Write(data []byte) (n int, err os.Error) {
451         for {
452                 if w.rwin == 0 {
453                         win, ok := <-w.win
454                         if !ok {
455                                 return 0, os.EOF
456                         }
457                         w.rwin += win
458                         continue
459                 }
460                 n = len(data)
461                 packet := make([]byte, 0, 9+n)
462                 packet = append(packet, msgChannelData,
463                         byte(w.id)>>24, byte(w.id)>>16, byte(w.id)>>8, byte(w.id),
464                         byte(n)>>24, byte(n)>>16, byte(n)>>8, byte(n))
465                 err = w.writePacket(append(packet, data...))
466                 w.rwin -= n
467                 return
468         }
469         panic("unreachable")
470 }
471
472 func (w *chanWriter) Close() os.Error {
473         return w.writePacket(marshal(msgChannelEOF, channelEOFMsg{w.id}))
474 }
475
476 // A chanReader represents stdout or stderr of a remote process.
477 type chanReader struct {
478         // TODO(dfc) a fixed size channel may not be the right data structure.
479         // If writes to this channel block, they will block mainLoop, making
480         // it unable to receive new messages from the remote side.
481         data         chan []byte // receives data from remote
482         id           uint32
483         packetWriter // for sending windowAdjustMsg
484         buf          []byte
485 }
486
487 // Read reads data from the remote process's stdout or stderr.
488 func (r *chanReader) Read(data []byte) (int, os.Error) {
489         var ok bool
490         for {
491                 if len(r.buf) > 0 {
492                         n := copy(data, r.buf)
493                         r.buf = r.buf[n:]
494                         msg := windowAdjustMsg{
495                                 PeersId:         r.id,
496                                 AdditionalBytes: uint32(n),
497                         }
498                         return n, r.writePacket(marshal(msgChannelWindowAdjust, msg))
499                 }
500                 r.buf, ok = <-r.data
501                 if !ok {
502                         return 0, os.EOF
503                 }
504         }
505         panic("unreachable")
506 }
507
508 func (r *chanReader) Close() os.Error {
509         return r.writePacket(marshal(msgChannelEOF, channelEOFMsg{r.id}))
510 }