OSDN Git Service

428a747e1e0c89b424d0349d5f6ccde744d2c503
[pf3gnuchains/gcc-fork.git] / libgo / go / exp / ssh / server.go
1 // Copyright 2011 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package ssh
6
7 import (
8         "bytes"
9         "crypto"
10         "crypto/rand"
11         "crypto/rsa"
12         "crypto/x509"
13         "encoding/pem"
14         "errors"
15         "io"
16         "math/big"
17         "net"
18         "sync"
19 )
20
21 type ServerConfig struct {
22         rsa           *rsa.PrivateKey
23         rsaSerialized []byte
24
25         // Rand provides the source of entropy for key exchange. If Rand is 
26         // nil, the cryptographic random reader in package crypto/rand will 
27         // be used.
28         Rand io.Reader
29
30         // NoClientAuth is true if clients are allowed to connect without
31         // authenticating.
32         NoClientAuth bool
33
34         // PasswordCallback, if non-nil, is called when a user attempts to
35         // authenticate using a password. It may be called concurrently from
36         // several goroutines.
37         PasswordCallback func(user, password string) bool
38
39         // PubKeyCallback, if non-nil, is called when a client attempts public
40         // key authentication. It must return true iff the given public key is
41         // valid for the given user.
42         PubKeyCallback func(user, algo string, pubkey []byte) bool
43
44         // Cryptographic-related configuration.
45         Crypto CryptoConfig
46 }
47
48 func (c *ServerConfig) rand() io.Reader {
49         if c.Rand == nil {
50                 return rand.Reader
51         }
52         return c.Rand
53 }
54
55 // SetRSAPrivateKey sets the private key for a Server. A Server must have a
56 // private key configured in order to accept connections. The private key must
57 // be in the form of a PEM encoded, PKCS#1, RSA private key. The file "id_rsa"
58 // typically contains such a key.
59 func (s *ServerConfig) SetRSAPrivateKey(pemBytes []byte) error {
60         block, _ := pem.Decode(pemBytes)
61         if block == nil {
62                 return errors.New("ssh: no key found")
63         }
64         var err error
65         s.rsa, err = x509.ParsePKCS1PrivateKey(block.Bytes)
66         if err != nil {
67                 return err
68         }
69
70         s.rsaSerialized = marshalRSA(s.rsa)
71         return nil
72 }
73
74 // marshalRSA serializes an RSA private key according to RFC 4256, section 6.6.
75 func marshalRSA(priv *rsa.PrivateKey) []byte {
76         e := new(big.Int).SetInt64(int64(priv.E))
77         length := stringLength([]byte(hostAlgoRSA))
78         length += intLength(e)
79         length += intLength(priv.N)
80
81         ret := make([]byte, length)
82         r := marshalString(ret, []byte(hostAlgoRSA))
83         r = marshalInt(r, e)
84         r = marshalInt(r, priv.N)
85
86         return ret
87 }
88
89 // parseRSA parses an RSA key according to RFC 4256, section 6.6.
90 func parseRSA(in []byte) (pubKey *rsa.PublicKey, ok bool) {
91         algo, in, ok := parseString(in)
92         if !ok || string(algo) != hostAlgoRSA {
93                 return nil, false
94         }
95         bigE, in, ok := parseInt(in)
96         if !ok || bigE.BitLen() > 24 {
97                 return nil, false
98         }
99         e := bigE.Int64()
100         if e < 3 || e&1 == 0 {
101                 return nil, false
102         }
103         N, in, ok := parseInt(in)
104         if !ok || len(in) > 0 {
105                 return nil, false
106         }
107         return &rsa.PublicKey{
108                 N: N,
109                 E: int(e),
110         }, true
111 }
112
113 func parseRSASig(in []byte) (sig []byte, ok bool) {
114         algo, in, ok := parseString(in)
115         if !ok || string(algo) != hostAlgoRSA {
116                 return nil, false
117         }
118         sig, in, ok = parseString(in)
119         if len(in) > 0 {
120                 ok = false
121         }
122         return
123 }
124
// cachedPubKey records the outcome of asking whether a public key is
// acceptable for a user. Such a cache is scoped to a single ServerConn.
type cachedPubKey struct {
	user   string // user the key was presented for
	algo   string // public key algorithm name
	pubKey []byte // the key blob that was queried
	result bool   // whether the key was accepted
}
132
// maxCachedPubKeys bounds how many public key query results a single
// ServerConn remembers.
const maxCachedPubKeys = 16
134
135 // A ServerConn represents an incomming connection.
136 type ServerConn struct {
137         *transport
138         config *ServerConfig
139
140         channels   map[uint32]*channel
141         nextChanId uint32
142
143         // lock protects err and also allows Channels to serialise their writes
144         // to out.
145         lock sync.RWMutex
146         err  error
147
148         // cachedPubKeys contains the cache results of tests for public keys.
149         // Since SSH clients will query whether a public key is acceptable
150         // before attempting to authenticate with it, we end up with duplicate
151         // queries for public key validity.
152         cachedPubKeys []cachedPubKey
153 }
154
155 // Server returns a new SSH server connection
156 // using c as the underlying transport.
157 func Server(c net.Conn, config *ServerConfig) *ServerConn {
158         conn := &ServerConn{
159                 transport: newTransport(c, config.rand()),
160                 channels:  make(map[uint32]*channel),
161                 config:    config,
162         }
163         return conn
164 }
165
166 // kexDH performs Diffie-Hellman key agreement on a ServerConnection. The
167 // returned values are given the same names as in RFC 4253, section 8.
168 func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) (H, K []byte, err error) {
169         packet, err := s.readPacket()
170         if err != nil {
171                 return
172         }
173         var kexDHInit kexDHInitMsg
174         if err = unmarshal(&kexDHInit, packet, msgKexDHInit); err != nil {
175                 return
176         }
177
178         if kexDHInit.X.Sign() == 0 || kexDHInit.X.Cmp(group.p) >= 0 {
179                 return nil, nil, errors.New("client DH parameter out of bounds")
180         }
181
182         y, err := rand.Int(s.config.rand(), group.p)
183         if err != nil {
184                 return
185         }
186
187         Y := new(big.Int).Exp(group.g, y, group.p)
188         kInt := new(big.Int).Exp(kexDHInit.X, y, group.p)
189
190         var serializedHostKey []byte
191         switch hostKeyAlgo {
192         case hostAlgoRSA:
193                 serializedHostKey = s.config.rsaSerialized
194         default:
195                 return nil, nil, errors.New("internal error")
196         }
197
198         h := hashFunc.New()
199         writeString(h, magics.clientVersion)
200         writeString(h, magics.serverVersion)
201         writeString(h, magics.clientKexInit)
202         writeString(h, magics.serverKexInit)
203         writeString(h, serializedHostKey)
204         writeInt(h, kexDHInit.X)
205         writeInt(h, Y)
206         K = make([]byte, intLength(kInt))
207         marshalInt(K, kInt)
208         h.Write(K)
209
210         H = h.Sum()
211
212         h.Reset()
213         h.Write(H)
214         hh := h.Sum()
215
216         var sig []byte
217         switch hostKeyAlgo {
218         case hostAlgoRSA:
219                 sig, err = rsa.SignPKCS1v15(s.config.rand(), s.config.rsa, hashFunc, hh)
220                 if err != nil {
221                         return
222                 }
223         default:
224                 return nil, nil, errors.New("internal error")
225         }
226
227         serializedSig := serializeSignature(hostAlgoRSA, sig)
228
229         kexDHReply := kexDHReplyMsg{
230                 HostKey:   serializedHostKey,
231                 Y:         Y,
232                 Signature: serializedSig,
233         }
234         packet = marshal(msgKexDHReply, kexDHReply)
235
236         err = s.writePacket(packet)
237         return
238 }
239
// serverVersion is the fixed identification line, including its trailing
// CR LF, that a Server sends at the start of the handshake.
var serverVersion = []byte("SSH-2.0-Go\r\n")
242
243 // Handshake performs an SSH transport and client authentication on the given ServerConn.
244 func (s *ServerConn) Handshake() error {
245         var magics handshakeMagics
246         if _, err := s.Write(serverVersion); err != nil {
247                 return err
248         }
249         if err := s.Flush(); err != nil {
250                 return err
251         }
252         magics.serverVersion = serverVersion[:len(serverVersion)-2]
253
254         version, err := readVersion(s)
255         if err != nil {
256                 return err
257         }
258         magics.clientVersion = version
259
260         serverKexInit := kexInitMsg{
261                 KexAlgos:                supportedKexAlgos,
262                 ServerHostKeyAlgos:      supportedHostKeyAlgos,
263                 CiphersClientServer:     s.config.Crypto.ciphers(),
264                 CiphersServerClient:     s.config.Crypto.ciphers(),
265                 MACsClientServer:        supportedMACs,
266                 MACsServerClient:        supportedMACs,
267                 CompressionClientServer: supportedCompressions,
268                 CompressionServerClient: supportedCompressions,
269         }
270         kexInitPacket := marshal(msgKexInit, serverKexInit)
271         magics.serverKexInit = kexInitPacket
272
273         if err := s.writePacket(kexInitPacket); err != nil {
274                 return err
275         }
276
277         packet, err := s.readPacket()
278         if err != nil {
279                 return err
280         }
281
282         magics.clientKexInit = packet
283
284         var clientKexInit kexInitMsg
285         if err = unmarshal(&clientKexInit, packet, msgKexInit); err != nil {
286                 return err
287         }
288
289         kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(s.transport, &clientKexInit, &serverKexInit)
290         if !ok {
291                 return errors.New("ssh: no common algorithms")
292         }
293
294         if clientKexInit.FirstKexFollows && kexAlgo != clientKexInit.KexAlgos[0] {
295                 // The client sent a Kex message for the wrong algorithm,
296                 // which we have to ignore.
297                 if _, err := s.readPacket(); err != nil {
298                         return err
299                 }
300         }
301
302         var H, K []byte
303         var hashFunc crypto.Hash
304         switch kexAlgo {
305         case kexAlgoDH14SHA1:
306                 hashFunc = crypto.SHA1
307                 dhGroup14Once.Do(initDHGroup14)
308                 H, K, err = s.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo)
309         default:
310                 err = errors.New("ssh: unexpected key exchange algorithm " + kexAlgo)
311         }
312         if err != nil {
313                 return err
314         }
315
316         if err = s.writePacket([]byte{msgNewKeys}); err != nil {
317                 return err
318         }
319         if err = s.transport.writer.setupKeys(serverKeys, K, H, H, hashFunc); err != nil {
320                 return err
321         }
322         if packet, err = s.readPacket(); err != nil {
323                 return err
324         }
325
326         if packet[0] != msgNewKeys {
327                 return UnexpectedMessageError{msgNewKeys, packet[0]}
328         }
329         if err = s.transport.reader.setupKeys(clientKeys, K, H, H, hashFunc); err != nil {
330                 return err
331         }
332         if packet, err = s.readPacket(); err != nil {
333                 return err
334         }
335
336         var serviceRequest serviceRequestMsg
337         if err = unmarshal(&serviceRequest, packet, msgServiceRequest); err != nil {
338                 return err
339         }
340         if serviceRequest.Service != serviceUserAuth {
341                 return errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating")
342         }
343         serviceAccept := serviceAcceptMsg{
344                 Service: serviceUserAuth,
345         }
346         if err = s.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil {
347                 return err
348         }
349
350         if err = s.authenticate(H); err != nil {
351                 return err
352         }
353         return nil
354 }
355
356 func isAcceptableAlgo(algo string) bool {
357         return algo == hostAlgoRSA
358 }
359
360 // testPubKey returns true if the given public key is acceptable for the user.
361 func (s *ServerConn) testPubKey(user, algo string, pubKey []byte) bool {
362         if s.config.PubKeyCallback == nil || !isAcceptableAlgo(algo) {
363                 return false
364         }
365
366         for _, c := range s.cachedPubKeys {
367                 if c.user == user && c.algo == algo && bytes.Equal(c.pubKey, pubKey) {
368                         return c.result
369                 }
370         }
371
372         result := s.config.PubKeyCallback(user, algo, pubKey)
373         if len(s.cachedPubKeys) < maxCachedPubKeys {
374                 c := cachedPubKey{
375                         user:   user,
376                         algo:   algo,
377                         pubKey: make([]byte, len(pubKey)),
378                         result: result,
379                 }
380                 copy(c.pubKey, pubKey)
381                 s.cachedPubKeys = append(s.cachedPubKeys, c)
382         }
383
384         return result
385 }
386
387 func (s *ServerConn) authenticate(H []byte) error {
388         var userAuthReq userAuthRequestMsg
389         var err error
390         var packet []byte
391
392 userAuthLoop:
393         for {
394                 if packet, err = s.readPacket(); err != nil {
395                         return err
396                 }
397                 if err = unmarshal(&userAuthReq, packet, msgUserAuthRequest); err != nil {
398                         return err
399                 }
400
401                 if userAuthReq.Service != serviceSSH {
402                         return errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service)
403                 }
404
405                 switch userAuthReq.Method {
406                 case "none":
407                         if s.config.NoClientAuth {
408                                 break userAuthLoop
409                         }
410                 case "password":
411                         if s.config.PasswordCallback == nil {
412                                 break
413                         }
414                         payload := userAuthReq.Payload
415                         if len(payload) < 1 || payload[0] != 0 {
416                                 return ParseError{msgUserAuthRequest}
417                         }
418                         payload = payload[1:]
419                         password, payload, ok := parseString(payload)
420                         if !ok || len(payload) > 0 {
421                                 return ParseError{msgUserAuthRequest}
422                         }
423
424                         if s.config.PasswordCallback(userAuthReq.User, string(password)) {
425                                 break userAuthLoop
426                         }
427                 case "publickey":
428                         if s.config.PubKeyCallback == nil {
429                                 break
430                         }
431                         payload := userAuthReq.Payload
432                         if len(payload) < 1 {
433                                 return ParseError{msgUserAuthRequest}
434                         }
435                         isQuery := payload[0] == 0
436                         payload = payload[1:]
437                         algoBytes, payload, ok := parseString(payload)
438                         if !ok {
439                                 return ParseError{msgUserAuthRequest}
440                         }
441                         algo := string(algoBytes)
442
443                         pubKey, payload, ok := parseString(payload)
444                         if !ok {
445                                 return ParseError{msgUserAuthRequest}
446                         }
447                         if isQuery {
448                                 // The client can query if the given public key
449                                 // would be ok.
450                                 if len(payload) > 0 {
451                                         return ParseError{msgUserAuthRequest}
452                                 }
453                                 if s.testPubKey(userAuthReq.User, algo, pubKey) {
454                                         okMsg := userAuthPubKeyOkMsg{
455                                                 Algo:   algo,
456                                                 PubKey: string(pubKey),
457                                         }
458                                         if err = s.writePacket(marshal(msgUserAuthPubKeyOk, okMsg)); err != nil {
459                                                 return err
460                                         }
461                                         continue userAuthLoop
462                                 }
463                         } else {
464                                 sig, payload, ok := parseString(payload)
465                                 if !ok || len(payload) > 0 {
466                                         return ParseError{msgUserAuthRequest}
467                                 }
468                                 if !isAcceptableAlgo(algo) {
469                                         break
470                                 }
471                                 rsaSig, ok := parseRSASig(sig)
472                                 if !ok {
473                                         return ParseError{msgUserAuthRequest}
474                                 }
475                                 signedData := buildDataSignedForAuth(H, userAuthReq, algoBytes, pubKey)
476                                 switch algo {
477                                 case hostAlgoRSA:
478                                         hashFunc := crypto.SHA1
479                                         h := hashFunc.New()
480                                         h.Write(signedData)
481                                         digest := h.Sum()
482                                         rsaKey, ok := parseRSA(pubKey)
483                                         if !ok {
484                                                 return ParseError{msgUserAuthRequest}
485                                         }
486                                         if rsa.VerifyPKCS1v15(rsaKey, hashFunc, digest, rsaSig) != nil {
487                                                 return ParseError{msgUserAuthRequest}
488                                         }
489                                 default:
490                                         return errors.New("ssh: isAcceptableAlgo incorrect")
491                                 }
492                                 if s.testPubKey(userAuthReq.User, algo, pubKey) {
493                                         break userAuthLoop
494                                 }
495                         }
496                 }
497
498                 var failureMsg userAuthFailureMsg
499                 if s.config.PasswordCallback != nil {
500                         failureMsg.Methods = append(failureMsg.Methods, "password")
501                 }
502                 if s.config.PubKeyCallback != nil {
503                         failureMsg.Methods = append(failureMsg.Methods, "publickey")
504                 }
505
506                 if len(failureMsg.Methods) == 0 {
507                         return errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
508                 }
509
510                 if err = s.writePacket(marshal(msgUserAuthFailure, failureMsg)); err != nil {
511                         return err
512                 }
513         }
514
515         packet = []byte{msgUserAuthSuccess}
516         if err = s.writePacket(packet); err != nil {
517                 return err
518         }
519
520         return nil
521 }
522
// defaultWindowSize is the initial receive window, in bytes, given to each
// newly opened channel.
const defaultWindowSize = 32768
524
525 // Accept reads and processes messages on a ServerConn. It must be called
526 // in order to demultiplex messages to any resulting Channels.
527 func (s *ServerConn) Accept() (Channel, error) {
528         if s.err != nil {
529                 return nil, s.err
530         }
531
532         for {
533                 packet, err := s.readPacket()
534                 if err != nil {
535
536                         s.lock.Lock()
537                         s.err = err
538                         s.lock.Unlock()
539
540                         for _, c := range s.channels {
541                                 c.dead = true
542                                 c.handleData(nil)
543                         }
544
545                         return nil, err
546                 }
547
548                 switch packet[0] {
549                 case msgChannelData:
550                         if len(packet) < 9 {
551                                 // malformed data packet
552                                 return nil, ParseError{msgChannelData}
553                         }
554                         peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
555                         s.lock.Lock()
556                         c, ok := s.channels[peersId]
557                         if !ok {
558                                 s.lock.Unlock()
559                                 continue
560                         }
561                         if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 {
562                                 packet = packet[9:]
563                                 c.handleData(packet[:length])
564                         }
565                         s.lock.Unlock()
566                 default:
567                         switch msg := decode(packet).(type) {
568                         case *channelOpenMsg:
569                                 c := new(channel)
570                                 c.chanType = msg.ChanType
571                                 c.theirId = msg.PeersId
572                                 c.theirWindow = msg.PeersWindow
573                                 c.maxPacketSize = msg.MaxPacketSize
574                                 c.extraData = msg.TypeSpecificData
575                                 c.myWindow = defaultWindowSize
576                                 c.serverConn = s
577                                 c.cond = sync.NewCond(&c.lock)
578                                 c.pendingData = make([]byte, c.myWindow)
579
580                                 s.lock.Lock()
581                                 c.myId = s.nextChanId
582                                 s.nextChanId++
583                                 s.channels[c.myId] = c
584                                 s.lock.Unlock()
585                                 return c, nil
586
587                         case *channelRequestMsg:
588                                 s.lock.Lock()
589                                 c, ok := s.channels[msg.PeersId]
590                                 if !ok {
591                                         s.lock.Unlock()
592                                         continue
593                                 }
594                                 c.handlePacket(msg)
595                                 s.lock.Unlock()
596
597                         case *channelEOFMsg:
598                                 s.lock.Lock()
599                                 c, ok := s.channels[msg.PeersId]
600                                 if !ok {
601                                         s.lock.Unlock()
602                                         continue
603                                 }
604                                 c.handlePacket(msg)
605                                 s.lock.Unlock()
606
607                         case *channelCloseMsg:
608                                 s.lock.Lock()
609                                 c, ok := s.channels[msg.PeersId]
610                                 if !ok {
611                                         s.lock.Unlock()
612                                         continue
613                                 }
614                                 c.handlePacket(msg)
615                                 s.lock.Unlock()
616
617                         case *globalRequestMsg:
618                                 if msg.WantReply {
619                                         if err := s.writePacket([]byte{msgRequestFailure}); err != nil {
620                                                 return nil, err
621                                         }
622                                 }
623
624                         case UnexpectedMessageError:
625                                 return nil, msg
626                         case *disconnectMsg:
627                                 return nil, io.EOF
628                         default:
629                                 // Unknown message. Ignore.
630                         }
631                 }
632         }
633
634         panic("unreachable")
635 }
636
637 // A Listener implements a network listener (net.Listener) for SSH connections.
638 type Listener struct {
639         listener net.Listener
640         config   *ServerConfig
641 }
642
643 // Accept waits for and returns the next incoming SSH connection.
644 // The receiver should call Handshake() in another goroutine 
645 // to avoid blocking the accepter.
646 func (l *Listener) Accept() (*ServerConn, error) {
647         c, err := l.listener.Accept()
648         if err != nil {
649                 return nil, err
650         }
651         conn := Server(c, l.config)
652         return conn, nil
653 }
654
655 // Addr returns the listener's network address.
656 func (l *Listener) Addr() net.Addr {
657         return l.listener.Addr()
658 }
659
660 // Close closes the listener.
661 func (l *Listener) Close() error {
662         return l.listener.Close()
663 }
664
665 // Listen creates an SSH listener accepting connections on
666 // the given network address using net.Listen.
667 func Listen(network, addr string, config *ServerConfig) (*Listener, error) {
668         l, err := net.Listen(network, addr)
669         if err != nil {
670                 return nil, err
671         }
672         return &Listener{
673                 l,
674                 config,
675         }, nil
676 }