1 // Copyright 2011 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
21 packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher.
24 minPaddingSize = 4 // TODO(huin) should this be configurable?
27 // filteredConn reduces the set of methods exposed when embedding
28 // a net.Conn inside ssh.transport.
29 // TODO(dfc) suggestions for a better name will be warmly received.
30 type filteredConn interface {
31 // Close closes the connection.
34 // LocalAddr returns the local network address.
37 // RemoteAddr returns the remote network address.
41 // Types implementing packetWriter provide the ability to send packets to the remote peer.
43 type packetWriter interface {
44 // Encrypt and send a packet of data to the remote peer.
45 writePacket(packet []byte) error
48 // transport represents the SSH connection to the remote peer.
49 type transport struct {
56 // reader represents the incoming connection state.
62 // writer represents the outgoing connection state.
64 *sync.Mutex // protects writer.Writer from concurrent writes
70 // common represents the cipher state needed to process messages in a single direction.
79 compressionAlgo string
82 // Read and decrypt a single packet from the remote peer.
83 func (r *reader) readOnePacket() ([]byte, error) {
84 var lengthBytes = make([]byte, 5)
86 if _, err := io.ReadFull(r, lengthBytes); err != nil {
90 r.cipher.XORKeyStream(lengthBytes, lengthBytes)
94 seqNumBytes := []byte{
100 r.mac.Write(seqNumBytes)
101 r.mac.Write(lengthBytes)
102 macSize = uint32(r.mac.Size())
105 length := uint32(lengthBytes[0])<<24 | uint32(lengthBytes[1])<<16 | uint32(lengthBytes[2])<<8 | uint32(lengthBytes[3])
106 paddingLength := uint32(lengthBytes[4])
108 if length <= paddingLength+1 {
109 return nil, errors.New("invalid packet length")
111 if length > maxPacketSize {
112 return nil, errors.New("packet too large")
115 packet := make([]byte, length-1+macSize)
116 if _, err := io.ReadFull(r, packet); err != nil {
119 mac := packet[length-1:]
120 r.cipher.XORKeyStream(packet, packet[:length-1])
123 r.mac.Write(packet[:length-1])
124 if subtle.ConstantTimeCompare(r.mac.Sum(nil), mac) != 1 {
125 return nil, errors.New("ssh: MAC failure")
130 return packet[:length-paddingLength-1], nil
133 // Read and decrypt next packet discarding debug and noop messages.
134 func (t *transport) readPacket() ([]byte, error) {
136 packet, err := t.readOnePacket()
140 if packet[0] != msgIgnore && packet[0] != msgDebug {
147 // Encrypt and send a packet of data to the remote peer.
148 func (w *writer) writePacket(packet []byte) error {
150 defer w.Mutex.Unlock()
152 paddingLength := packetSizeMultiple - (5+len(packet))%packetSizeMultiple
153 if paddingLength < 4 {
154 paddingLength += packetSizeMultiple
157 length := len(packet) + 1 + paddingLength
158 lengthBytes := []byte{
165 padding := make([]byte, paddingLength)
166 _, err := io.ReadFull(w.rand, padding)
173 seqNumBytes := []byte{
174 byte(w.seqNum >> 24),
175 byte(w.seqNum >> 16),
179 w.mac.Write(seqNumBytes)
180 w.mac.Write(lengthBytes)
185 // TODO(dfc) lengthBytes, packet and padding should be
186 // subslices of a single buffer
187 w.cipher.XORKeyStream(lengthBytes, lengthBytes)
188 w.cipher.XORKeyStream(packet, packet)
189 w.cipher.XORKeyStream(padding, padding)
191 if _, err := w.Write(lengthBytes); err != nil {
194 if _, err := w.Write(packet); err != nil {
197 if _, err := w.Write(padding); err != nil {
202 if _, err := w.Write(w.mac.Sum(nil)); err != nil {
207 if err := w.Flush(); err != nil {
214 // Send a message to the remote peer
215 func (t *transport) sendMessage(typ uint8, msg interface{}) error {
216 packet := marshal(typ, msg)
217 return t.writePacket(packet)
220 func newTransport(conn net.Conn, rand io.Reader) *transport {
223 Reader: bufio.NewReader(conn),
225 cipher: noneCipher{},
229 Writer: bufio.NewWriter(conn),
231 Mutex: new(sync.Mutex),
233 cipher: noneCipher{},
240 type direction struct {
246 // TODO(dfc) can this be made a constant ?
248 serverKeys = direction{[]byte{'B'}, []byte{'D'}, []byte{'F'}}
249 clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}}
252 // setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
253 // described in RFC 4253, section 6.4. direction should either be serverKeys
254 // (to setup server->client keys) or clientKeys (for client->server keys).
255 func (c *common) setupKeys(d direction, K, H, sessionId []byte, hashFunc crypto.Hash) error {
256 cipherMode := cipherModes[c.cipherAlgo]
260 iv := make([]byte, cipherMode.ivSize)
261 key := make([]byte, cipherMode.keySize)
262 macKey := make([]byte, macKeySize)
265 generateKeyMaterial(iv, d.ivTag, K, H, sessionId, h)
266 generateKeyMaterial(key, d.keyTag, K, H, sessionId, h)
267 generateKeyMaterial(macKey, d.macKeyTag, K, H, sessionId, h)
269 c.mac = truncatingMAC{12, hmac.NewSHA1(macKey)}
271 cipher, err := cipherMode.createCipher(key, iv)
281 // generateKeyMaterial fills out with key material generated from tag, K, H
282 // and sessionId, as specified in RFC 4253, section 7.2.
283 func generateKeyMaterial(out, tag []byte, K, H, sessionId []byte, h hash.Hash) {
284 var digestsSoFar []byte
291 if len(digestsSoFar) == 0 {
295 h.Write(digestsSoFar)
299 n := copy(out, digest)
302 digestsSoFar = append(digestsSoFar, digest...)
307 // truncatingMAC wraps around a hash.Hash and truncates the output digest to length bytes.
309 type truncatingMAC struct {
314 func (t truncatingMAC) Write(data []byte) (int, error) {
315 return t.hmac.Write(data)
318 func (t truncatingMAC) Sum(in []byte) []byte {
319 out := t.hmac.Sum(in)
320 return out[:len(in)+t.length]
323 func (t truncatingMAC) Reset() {
327 func (t truncatingMAC) Size() int {
331 // maxVersionStringBytes is the maximum number of bytes that we'll accept as a
332 // version string. In the event that the client is talking a different protocol,
333 // we need to set a limit; otherwise we would keep using more and more memory
334 // while searching for the end of the version handshake.
335 const maxVersionStringBytes = 1024
337 // Read version string as specified by RFC 4253, section 4.2.
338 func readVersion(r io.Reader) ([]byte, error) {
339 versionString := make([]byte, 0, 64)
343 for len(versionString) < maxVersionStringBytes {
344 _, err := io.ReadFull(r, buf[:])
362 versionString = append(versionString, b)
366 return nil, errors.New("failed to read version string")
369 // We need to remove the CR from versionString
370 return versionString[:len(versionString)-1], nil