OSDN Git Service

libgo: Update to weekly.2012-01-15.
[pf3gnuchains/gcc-fork.git] / libgo / go / exp / ssh / client.go
index d89b908..8df8145 100644 (file)
@@ -187,10 +187,10 @@ func (c *ClientConn) mainLoop() {
                if err != nil {
                        break
                }
-               // TODO(dfc) A note on blocking channel use. 
-               // The msg, win, data and dataExt channels of a clientChan can 
-               // cause this loop to block indefinately if the consumer does 
-               // not service them. 
+               // TODO(dfc) A note on blocking channel use.
+               // The msg, win, data and dataExt channels of a clientChan can
+               // cause this loop to block indefinately if the consumer does
+               // not service them.
                switch packet[0] {
                case msgChannelData:
                        if len(packet) < 9 {
@@ -200,7 +200,7 @@ func (c *ClientConn) mainLoop() {
                        peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
                        if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 {
                                packet = packet[9:]
-                               c.getChan(peersId).stdout.data <- packet[:length]
+                               c.getChan(peersId).stdout.handleData(packet[:length])
                        }
                case msgChannelExtendedData:
                        if len(packet) < 13 {
@@ -211,11 +211,11 @@ func (c *ClientConn) mainLoop() {
                        datatype := uint32(packet[5])<<24 | uint32(packet[6])<<16 | uint32(packet[7])<<8 | uint32(packet[8])
                        if length := int(packet[9])<<24 | int(packet[10])<<16 | int(packet[11])<<8 | int(packet[12]); length > 0 {
                                packet = packet[13:]
-                               // RFC 4254 5.2 defines data_type_code 1 to be data destined 
+                               // RFC 4254 5.2 defines data_type_code 1 to be data destined
                                // for stderr on interactive sessions. Other data types are
                                // silently discarded.
                                if datatype == 1 {
-                                       c.getChan(peersId).stderr.data <- packet[:length]
+                                       c.getChan(peersId).stderr.handleData(packet[:length])
                                }
                        }
                default:
@@ -228,12 +228,22 @@ func (c *ClientConn) mainLoop() {
                                c.getChan(msg.PeersId).msg <- msg
                        case *channelCloseMsg:
                                ch := c.getChan(msg.PeersId)
+                               ch.theyClosed = true
                                close(ch.stdin.win)
-                               close(ch.stdout.data)
-                               close(ch.stderr.data)
+                               ch.stdout.eof()
+                               ch.stderr.eof()
+                               close(ch.msg)
+                               if !ch.weClosed {
+                                       ch.weClosed = true
+                                       ch.sendClose()
+                               }
                                c.chanlist.remove(msg.PeersId)
                        case *channelEOFMsg:
-                               c.getChan(msg.PeersId).msg <- msg
+                               ch := c.getChan(msg.PeersId)
+                               ch.stdout.eof()
+                               // RFC 4254 is mute on how EOF affects dataExt messages but
+                               // it is logical to signal EOF at the same time.
+                               ch.stderr.eof()
                        case *channelRequestSuccessMsg:
                                c.getChan(msg.PeersId).msg <- msg
                        case *channelRequestFailureMsg:
@@ -242,6 +252,8 @@ func (c *ClientConn) mainLoop() {
                                c.getChan(msg.PeersId).msg <- msg
                        case *windowAdjustMsg:
                                c.getChan(msg.PeersId).stdin.win <- int(msg.AdditionalBytes)
+                       case *disconnectMsg:
+                               break
                        default:
                                fmt.Printf("mainLoop: unhandled message %T: %v\n", msg, msg)
                        }
@@ -249,7 +261,7 @@ func (c *ClientConn) mainLoop() {
        }
 }
 
-// Dial connects to the given network address using net.Dial and 
+// Dial connects to the given network address using net.Dial and
 // then initiates a SSH handshake, returning the resulting client connection.
 func Dial(network, addr string, config *ClientConfig) (*ClientConn, error) {
        conn, err := net.Dial(network, addr)
@@ -259,18 +271,18 @@ func Dial(network, addr string, config *ClientConfig) (*ClientConn, error) {
        return Client(conn, config)
 }
 
-// A ClientConfig structure is used to configure a ClientConn. After one has 
+// A ClientConfig structure is used to configure a ClientConn. After one has
 // been passed to an SSH function it must not be modified.
 type ClientConfig struct {
-       // Rand provides the source of entropy for key exchange. If Rand is 
-       // nil, the cryptographic random reader in package crypto/rand will 
+       // Rand provides the source of entropy for key exchange. If Rand is
+       // nil, the cryptographic random reader in package crypto/rand will
        // be used.
        Rand io.Reader
 
        // The username to authenticate.
        User string
 
-       // A slice of ClientAuth methods. Only the first instance 
+       // A slice of ClientAuth methods. Only the first instance
        // of a particular RFC 4252 method will be used during authentication.
        Auth []ClientAuth
 
@@ -285,7 +297,7 @@ func (c *ClientConfig) rand() io.Reader {
        return c.Rand
 }
 
-// A clientChan represents a single RFC 4254 channel that is multiplexed 
+// A clientChan represents a single RFC 4254 channel that is multiplexed
 // over a single SSH connection.
 type clientChan struct {
        packetWriter
@@ -294,10 +306,13 @@ type clientChan struct {
        stdout      *chanReader      // receives the payload of channelData messages
        stderr      *chanReader      // receives the payload of channelExtendedData messages
        msg         chan interface{} // incoming messages
+
+       theyClosed bool // indicates the close msg has been received from the remote side
+       weClosed   bool // incidates the close msg has been sent from our side
 }
 
 // newClientChan returns a partially constructed *clientChan
-// using the local id provided. To be usable clientChan.peersId 
+// using the local id provided. To be usable clientChan.peersId
 // needs to be assigned once known.
 func newClientChan(t *transport, id uint32) *clientChan {
        c := &clientChan{
@@ -320,8 +335,8 @@ func newClientChan(t *transport, id uint32) *clientChan {
        return c
 }
 
-// waitForChannelOpenResponse, if successful, fills out 
-// the peerId and records any initial window advertisement. 
+// waitForChannelOpenResponse, if successful, fills out
+// the peerId and records any initial window advertisement.
 func (c *clientChan) waitForChannelOpenResponse() error {
        switch msg := (<-c.msg).(type) {
        case *channelOpenConfirmMsg:
@@ -335,13 +350,29 @@ func (c *clientChan) waitForChannelOpenResponse() error {
        return errors.New("unexpected packet")
 }
 
-// Close closes the channel. This does not close the underlying connection.
-func (c *clientChan) Close() error {
+// sendEOF sends EOF to the server. RFC 4254 Section 5.3
+func (c *clientChan) sendEOF() error {
+       return c.writePacket(marshal(msgChannelEOF, channelEOFMsg{
+               PeersId: c.peersId,
+       }))
+}
+
+// sendClose signals the intent to close the channel.
+func (c *clientChan) sendClose() error {
        return c.writePacket(marshal(msgChannelClose, channelCloseMsg{
                PeersId: c.peersId,
        }))
 }
 
+// Close closes the channel. This does not close the underlying connection.
+func (c *clientChan) Close() error {
+       if !c.weClosed {
+               c.weClosed = true
+               return c.sendClose()
+       }
+       return nil
+}
+
 // Thread safe channel list.
 type chanlist struct {
        // protects concurrent access to chans
@@ -389,31 +420,41 @@ type chanWriter struct {
 }
 
 // Write writes data to the remote process's standard input.
-func (w *chanWriter) Write(data []byte) (n int, err error) {
-       for {
-               if w.rwin == 0 {
+func (w *chanWriter) Write(data []byte) (written int, err error) {
+       for len(data) > 0 {
+               for w.rwin < 1 {
                        win, ok := <-w.win
                        if !ok {
                                return 0, io.EOF
                        }
                        w.rwin += win
-                       continue
                }
+               n := min(len(data), w.rwin)
                peersId := w.clientChan.peersId
-               n = len(data)
-               packet := make([]byte, 0, 9+n)
-               packet = append(packet, msgChannelData,
-                       byte(peersId>>24), byte(peersId>>16), byte(peersId>>8), byte(peersId),
-                       byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
-               err = w.clientChan.writePacket(append(packet, data...))
+               packet := []byte{
+                       msgChannelData,
+                       byte(peersId >> 24), byte(peersId >> 16), byte(peersId >> 8), byte(peersId),
+                       byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n),
+               }
+               if err = w.clientChan.writePacket(append(packet, data[:n]...)); err != nil {
+                       break
+               }
+               data = data[n:]
                w.rwin -= n
-               return
+               written += n
        }
-       panic("unreachable")
+       return
+}
+
+func min(a, b int) int {
+       if a < b {
+               return a
+       }
+       return b
 }
 
 func (w *chanWriter) Close() error {
-       return w.clientChan.writePacket(marshal(msgChannelEOF, channelEOFMsg{w.clientChan.peersId}))
+       return w.clientChan.sendEOF()
 }
 
 // A chanReader represents stdout or stderr of a remote process.
@@ -422,10 +463,27 @@ type chanReader struct {
        // If writes to this channel block, they will block mainLoop, making
        // it unable to receive new messages from the remote side.
        data       chan []byte // receives data from remote
+       dataClosed bool        // protects data from being closed twice
        clientChan *clientChan // the channel backing this reader
        buf        []byte
 }
 
+// eof signals to the consumer that there is no more data to be received.
+func (r *chanReader) eof() {
+       if !r.dataClosed {
+               r.dataClosed = true
+               close(r.data)
+       }
+}
+
+// handleData sends buf to the reader's consumer. If r.data is closed
+// the data will be silently discarded
+func (r *chanReader) handleData(buf []byte) {
+       if !r.dataClosed {
+               r.data <- buf
+       }
+}
+
 // Read reads data from the remote process's stdout or stderr.
 func (r *chanReader) Read(data []byte) (int, error) {
        var ok bool