1 // Copyright 2011 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
22 paddingMultiple = 16 // TODO(dfc) does this need to be configurable?
25 // filteredConn reduces the set of methods exposed when embedding
26 // a net.Conn inside ssh.transport.
27 // TODO(dfc) suggestions for a better name will be warmly received.
28 type filteredConn interface {
29 // Close closes the connection.
32 // LocalAddr returns the local network address.
35 // RemoteAddr returns the remote network address.
39 // Types implementing packetWriter provide the ability to send packets to
41 type packetWriter interface {
42 // Encrypt and send a packet of data to the remote peer.
43 writePacket(packet []byte) os.Error
46 // transport represents the SSH connection to the remote peer.
47 type transport struct {
54 // reader represents the incoming connection state.
60 // writer represents the outgoing connection state.
62 *sync.Mutex // protects writer.Writer from concurrent writes
69 // common represents the cipher state needed to process messages in a single
78 compressionAlgo string
81 // readOnePacket reads and decrypts a single packet from the remote peer.
82 func (r *reader) readOnePacket() ([]byte, os.Error) {
83 var lengthBytes = make([]byte, 5)
86 if _, err := io.ReadFull(r, lengthBytes); err != nil {
91 r.cipher.XORKeyStream(lengthBytes, lengthBytes)
96 seqNumBytes := []byte{
102 r.mac.Write(seqNumBytes)
103 r.mac.Write(lengthBytes)
104 macSize = uint32(r.mac.Size())
107 length := uint32(lengthBytes[0])<<24 | uint32(lengthBytes[1])<<16 | uint32(lengthBytes[2])<<8 | uint32(lengthBytes[3])
108 paddingLength := uint32(lengthBytes[4])
110 if length <= paddingLength+1 {
111 return nil, os.NewError("invalid packet length")
113 if length > maxPacketSize {
114 return nil, os.NewError("packet too large")
117 packet := make([]byte, length-1+macSize)
118 if _, err := io.ReadFull(r, packet); err != nil {
121 mac := packet[length-1:]
123 r.cipher.XORKeyStream(packet, packet[:length-1])
127 r.mac.Write(packet[:length-1])
128 if subtle.ConstantTimeCompare(r.mac.Sum(), mac) != 1 {
129 return nil, os.NewError("ssh: MAC failure")
134 return packet[:length-paddingLength-1], nil
137 // readPacket reads and decrypts the next packet, discarding ignore and debug messages.
138 func (t *transport) readPacket() ([]byte, os.Error) {
140 packet, err := t.readOnePacket()
144 if packet[0] != msgIgnore && packet[0] != msgDebug {
151 // writePacket encrypts and sends a packet of data to the remote peer.
152 func (w *writer) writePacket(packet []byte) os.Error {
154 defer w.Mutex.Unlock()
156 paddingLength := paddingMultiple - (5+len(packet))%paddingMultiple
157 if paddingLength < 4 {
158 paddingLength += paddingMultiple
161 length := len(packet) + 1 + paddingLength
162 lengthBytes := []byte{
169 padding := make([]byte, paddingLength)
170 _, err := io.ReadFull(w.rand, padding)
177 seqNumBytes := []byte{
178 byte(w.seqNum >> 24),
179 byte(w.seqNum >> 16),
183 w.mac.Write(seqNumBytes)
184 w.mac.Write(lengthBytes)
189 // TODO(dfc) lengthBytes, packet and padding should be
190 // subslices of a single buffer
192 w.cipher.XORKeyStream(lengthBytes, lengthBytes)
193 w.cipher.XORKeyStream(packet, packet)
194 w.cipher.XORKeyStream(padding, padding)
197 if _, err := w.Write(lengthBytes); err != nil {
200 if _, err := w.Write(packet); err != nil {
203 if _, err := w.Write(padding); err != nil {
208 if _, err := w.Write(w.mac.Sum()); err != nil {
213 if err := w.Flush(); err != nil {
220 // sendMessage marshals msg with message type typ and sends it to the remote peer.
221 func (t *transport) sendMessage(typ uint8, msg interface{}) os.Error {
222 packet := marshal(typ, msg)
223 return t.writePacket(packet)
226 func newTransport(conn net.Conn, rand io.Reader) *transport {
229 Reader: bufio.NewReader(conn),
232 Writer: bufio.NewWriter(conn),
234 Mutex: new(sync.Mutex),
240 type direction struct {
246 // TODO(dfc) can this be made a constant ?
248 serverKeys = direction{[]byte{'B'}, []byte{'D'}, []byte{'F'}}
249 clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}}
252 // setupKeys sets the cipher and MAC keys from K, H and sessionId, as
253 // described in RFC 4253, section 6.4. direction should either be serverKeys
254 // (to set up server->client keys) or clientKeys (to set up client->server keys).
255 func (c *common) setupKeys(d direction, K, H, sessionId []byte, hashFunc crypto.Hash) os.Error {
262 iv := make([]byte, blockSize)
263 key := make([]byte, keySize)
264 macKey := make([]byte, macKeySize)
265 generateKeyMaterial(iv, d.ivTag, K, H, sessionId, h)
266 generateKeyMaterial(key, d.keyTag, K, H, sessionId, h)
267 generateKeyMaterial(macKey, d.macKeyTag, K, H, sessionId, h)
269 c.mac = truncatingMAC{12, hmac.NewSHA1(macKey)}
270 aes, err := aes.NewCipher(key)
274 c.cipher = cipher.NewCTR(aes, iv)
278 // generateKeyMaterial fills out with key material generated from tag, K, H
279 // and sessionId, as specified in RFC 4253, section 7.2.
280 func generateKeyMaterial(out, tag []byte, K, H, sessionId []byte, h hash.Hash) {
281 var digestsSoFar []byte
288 if len(digestsSoFar) == 0 {
292 h.Write(digestsSoFar)
296 n := copy(out, digest)
299 digestsSoFar = append(digestsSoFar, digest...)
304 // truncatingMAC wraps around a hash.Hash and truncates the output digest to
306 type truncatingMAC struct {
311 func (t truncatingMAC) Write(data []byte) (int, os.Error) {
312 return t.hmac.Write(data)
315 func (t truncatingMAC) Sum() []byte {
316 digest := t.hmac.Sum()
317 return digest[:t.length]
320 func (t truncatingMAC) Reset() {
324 func (t truncatingMAC) Size() int {
328 // maxVersionStringBytes is the maximum number of bytes that we'll accept as a
329 // version string. In the event that the client is talking a different protocol
330 // we need to set a limit otherwise we will keep using more and more memory
331 // while searching for the end of the version handshake.
332 const maxVersionStringBytes = 1024
334 // readVersion reads the version string as specified by RFC 4253, section 4.2.
335 func readVersion(r io.Reader) ([]byte, os.Error) {
336 versionString := make([]byte, 0, 64)
340 for len(versionString) < maxVersionStringBytes {
341 _, err := io.ReadFull(r, buf[:])
359 versionString = append(versionString, b)
363 return nil, os.NewError("failed to read version string")
366 // We need to remove the CR from versionString
367 return versionString[:len(versionString)-1], nil