OSDN Git Service

libgo: Update to weekly.2012-01-15.
[pf3gnuchains/gcc-fork.git] / libgo / go / exp / ssh / client.go
index 24569ad..8df8145 100644 (file)
@@ -172,40 +172,12 @@ func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
        marshalInt(K, kInt)
        h.Write(K)
 
-       H := h.Sum()
+       H := h.Sum(nil)
 
        return H, K, nil
 }
 
-// openChan opens a new client channel. The most common session type is "session". 
-// The full set of valid session types are listed in RFC 4250 4.9.1.
-func (c *ClientConn) openChan(typ string) (*clientChan, error) {
-       ch := c.newChan(c.transport)
-       if err := c.writePacket(marshal(msgChannelOpen, channelOpenMsg{
-               ChanType:      typ,
-               PeersId:       ch.id,
-               PeersWindow:   1 << 14,
-               MaxPacketSize: 1 << 15, // RFC 4253 6.1
-       })); err != nil {
-               c.chanlist.remove(ch.id)
-               return nil, err
-       }
-       // wait for response
-       switch msg := (<-ch.msg).(type) {
-       case *channelOpenConfirmMsg:
-               ch.peersId = msg.MyId
-               ch.win <- int(msg.MyWindow)
-       case *channelOpenFailureMsg:
-               c.chanlist.remove(ch.id)
-               return nil, errors.New(msg.Message)
-       default:
-               c.chanlist.remove(ch.id)
-               return nil, errors.New("Unexpected packet")
-       }
-       return ch, nil
-}
-
-// mainloop reads incoming messages and routes channel messages
+// mainLoop reads incoming messages and routes channel messages
 // to their respective ClientChans.
 func (c *ClientConn) mainLoop() {
        // TODO(dfc) signal the underlying close to all channels
@@ -215,10 +187,10 @@ func (c *ClientConn) mainLoop() {
                if err != nil {
                        break
                }
-               // TODO(dfc) A note on blocking channel use. 
-               // The msg, win, data and dataExt channels of a clientChan can 
-               // cause this loop to block indefinately if the consumer does 
-               // not service them. 
+               // TODO(dfc) A note on blocking channel use.
+               // The msg, win, data and dataExt channels of a clientChan can
+               // cause this loop to block indefinately if the consumer does
+               // not service them.
                switch packet[0] {
                case msgChannelData:
                        if len(packet) < 9 {
@@ -228,7 +200,7 @@ func (c *ClientConn) mainLoop() {
                        peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
                        if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 {
                                packet = packet[9:]
-                               c.getChan(peersId).data <- packet[:length]
+                               c.getChan(peersId).stdout.handleData(packet[:length])
                        }
                case msgChannelExtendedData:
                        if len(packet) < 13 {
@@ -239,11 +211,11 @@ func (c *ClientConn) mainLoop() {
                        datatype := uint32(packet[5])<<24 | uint32(packet[6])<<16 | uint32(packet[7])<<8 | uint32(packet[8])
                        if length := int(packet[9])<<24 | int(packet[10])<<16 | int(packet[11])<<8 | int(packet[12]); length > 0 {
                                packet = packet[13:]
-                               // RFC 4254 5.2 defines data_type_code 1 to be data destined 
+                               // RFC 4254 5.2 defines data_type_code 1 to be data destined
                                // for stderr on interactive sessions. Other data types are
                                // silently discarded.
                                if datatype == 1 {
-                                       c.getChan(peersId).dataExt <- packet[:length]
+                                       c.getChan(peersId).stderr.handleData(packet[:length])
                                }
                        }
                default:
@@ -256,12 +228,22 @@ func (c *ClientConn) mainLoop() {
                                c.getChan(msg.PeersId).msg <- msg
                        case *channelCloseMsg:
                                ch := c.getChan(msg.PeersId)
-                               close(ch.win)
-                               close(ch.data)
-                               close(ch.dataExt)
+                               ch.theyClosed = true
+                               close(ch.stdin.win)
+                               ch.stdout.eof()
+                               ch.stderr.eof()
+                               close(ch.msg)
+                               if !ch.weClosed {
+                                       ch.weClosed = true
+                                       ch.sendClose()
+                               }
                                c.chanlist.remove(msg.PeersId)
                        case *channelEOFMsg:
-                               c.getChan(msg.PeersId).msg <- msg
+                               ch := c.getChan(msg.PeersId)
+                               ch.stdout.eof()
+                               // RFC 4254 is mute on how EOF affects dataExt messages but
+                               // it is logical to signal EOF at the same time.
+                               ch.stderr.eof()
                        case *channelRequestSuccessMsg:
                                c.getChan(msg.PeersId).msg <- msg
                        case *channelRequestFailureMsg:
@@ -269,15 +251,17 @@ func (c *ClientConn) mainLoop() {
                        case *channelRequestMsg:
                                c.getChan(msg.PeersId).msg <- msg
                        case *windowAdjustMsg:
-                               c.getChan(msg.PeersId).win <- int(msg.AdditionalBytes)
+                               c.getChan(msg.PeersId).stdin.win <- int(msg.AdditionalBytes)
+                       case *disconnectMsg:
+                               break
                        default:
-                               fmt.Printf("mainLoop: unhandled %#v\n", msg)
+                               fmt.Printf("mainLoop: unhandled message %T: %v\n", msg, msg)
                        }
                }
        }
 }
 
-// Dial connects to the given network address using net.Dial and 
+// Dial connects to the given network address using net.Dial and
 // then initiates a SSH handshake, returning the resulting client connection.
 func Dial(network, addr string, config *ClientConfig) (*ClientConn, error) {
        conn, err := net.Dial(network, addr)
@@ -287,18 +271,18 @@ func Dial(network, addr string, config *ClientConfig) (*ClientConn, error) {
        return Client(conn, config)
 }
 
-// A ClientConfig structure is used to configure a ClientConn. After one has 
+// A ClientConfig structure is used to configure a ClientConn. After one has
 // been passed to an SSH function it must not be modified.
 type ClientConfig struct {
-       // Rand provides the source of entropy for key exchange. If Rand is 
-       // nil, the cryptographic random reader in package crypto/rand will 
+       // Rand provides the source of entropy for key exchange. If Rand is
+       // nil, the cryptographic random reader in package crypto/rand will
        // be used.
        Rand io.Reader
 
        // The username to authenticate.
        User string
 
-       // A slice of ClientAuth methods. Only the first instance 
+       // A slice of ClientAuth methods. Only the first instance
        // of a particular RFC 4252 method will be used during authentication.
        Auth []ClientAuth
 
@@ -313,44 +297,80 @@ func (c *ClientConfig) rand() io.Reader {
        return c.Rand
 }
 
-// A clientChan represents a single RFC 4254 channel that is multiplexed 
+// A clientChan represents a single RFC 4254 channel that is multiplexed
 // over a single SSH connection.
 type clientChan struct {
        packetWriter
        id, peersId uint32
-       data        chan []byte      // receives the payload of channelData messages
-       dataExt     chan []byte      // receives the payload of channelExtendedData messages
-       win         chan int         // receives window adjustments
+       stdin       *chanWriter      // receives window adjustments
+       stdout      *chanReader      // receives the payload of channelData messages
+       stderr      *chanReader      // receives the payload of channelExtendedData messages
        msg         chan interface{} // incoming messages
+
+       theyClosed bool // indicates the close msg has been received from the remote side
+       weClosed   bool // incidates the close msg has been sent from our side
 }
 
+// newClientChan returns a partially constructed *clientChan
+// using the local id provided. To be usable clientChan.peersId
+// needs to be assigned once known.
 func newClientChan(t *transport, id uint32) *clientChan {
-       return &clientChan{
+       c := &clientChan{
                packetWriter: t,
                id:           id,
-               data:         make(chan []byte, 16),
-               dataExt:      make(chan []byte, 16),
-               win:          make(chan int, 16),
                msg:          make(chan interface{}, 16),
        }
+       c.stdin = &chanWriter{
+               win:        make(chan int, 16),
+               clientChan: c,
+       }
+       c.stdout = &chanReader{
+               data:       make(chan []byte, 16),
+               clientChan: c,
+       }
+       c.stderr = &chanReader{
+               data:       make(chan []byte, 16),
+               clientChan: c,
+       }
+       return c
 }
 
-// Close closes the channel. This does not close the underlying connection.
-func (c *clientChan) Close() error {
+// waitForChannelOpenResponse, if successful, fills out
+// the peerId and records any initial window advertisement.
+func (c *clientChan) waitForChannelOpenResponse() error {
+       switch msg := (<-c.msg).(type) {
+       case *channelOpenConfirmMsg:
+               // fixup peersId field
+               c.peersId = msg.MyId
+               c.stdin.win <- int(msg.MyWindow)
+               return nil
+       case *channelOpenFailureMsg:
+               return errors.New(safeString(msg.Message))
+       }
+       return errors.New("unexpected packet")
+}
+
+// sendEOF sends EOF to the server. RFC 4254 Section 5.3
+func (c *clientChan) sendEOF() error {
+       return c.writePacket(marshal(msgChannelEOF, channelEOFMsg{
+               PeersId: c.peersId,
+       }))
+}
+
+// sendClose signals the intent to close the channel.
+func (c *clientChan) sendClose() error {
        return c.writePacket(marshal(msgChannelClose, channelCloseMsg{
-               PeersId: c.id,
+               PeersId: c.peersId,
        }))
 }
 
-func (c *clientChan) sendChanReq(req channelRequestMsg) error {
-       if err := c.writePacket(marshal(msgChannelRequest, req)); err != nil {
-               return err
-       }
-       msg := <-c.msg
-       if _, ok := msg.(*channelRequestSuccessMsg); ok {
-               return nil
+// Close closes the channel. This does not close the underlying connection.
+func (c *clientChan) Close() error {
+       if !c.weClosed {
+               c.weClosed = true
+               return c.sendClose()
        }
-       return fmt.Errorf("failed to complete request: %s, %#v", req.Request, msg)
+       return nil
 }
 
 // Thread safe channel list.
@@ -358,7 +378,7 @@ type chanlist struct {
        // protects concurrent access to chans
        sync.Mutex
        // chans are indexed by the local id of the channel, clientChan.id.
-       // The PeersId value of messages received by ClientConn.mainloop is
+       // The PeersId value of messages received by ClientConn.mainLoop is
        // used to locate the right local clientChan in this slice.
        chans []*clientChan
 }
@@ -394,37 +414,47 @@ func (c *chanlist) remove(id uint32) {
 
 // A chanWriter represents the stdin of a remote process.
 type chanWriter struct {
-       win          chan int // receives window adjustments
-       id           uint32   // this channel's id
-       rwin         int      // current rwin size
-       packetWriter          // for sending channelDataMsg
+       win        chan int    // receives window adjustments
+       rwin       int         // current rwin size
+       clientChan *clientChan // the channel backing this writer
 }
 
 // Write writes data to the remote process's standard input.
-func (w *chanWriter) Write(data []byte) (n int, err error) {
-       for {
-               if w.rwin == 0 {
+func (w *chanWriter) Write(data []byte) (written int, err error) {
+       for len(data) > 0 {
+               for w.rwin < 1 {
                        win, ok := <-w.win
                        if !ok {
                                return 0, io.EOF
                        }
                        w.rwin += win
-                       continue
                }
-               n = len(data)
-               packet := make([]byte, 0, 9+n)
-               packet = append(packet, msgChannelData,
-                       byte(w.id)>>24, byte(w.id)>>16, byte(w.id)>>8, byte(w.id),
-                       byte(n)>>24, byte(n)>>16, byte(n)>>8, byte(n))
-               err = w.writePacket(append(packet, data...))
+               n := min(len(data), w.rwin)
+               peersId := w.clientChan.peersId
+               packet := []byte{
+                       msgChannelData,
+                       byte(peersId >> 24), byte(peersId >> 16), byte(peersId >> 8), byte(peersId),
+                       byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n),
+               }
+               if err = w.clientChan.writePacket(append(packet, data[:n]...)); err != nil {
+                       break
+               }
+               data = data[n:]
                w.rwin -= n
-               return
+               written += n
        }
-       panic("unreachable")
+       return
+}
+
+func min(a, b int) int {
+       if a < b {
+               return a
+       }
+       return b
 }
 
 func (w *chanWriter) Close() error {
-       return w.writePacket(marshal(msgChannelEOF, channelEOFMsg{w.id}))
+       return w.clientChan.sendEOF()
 }
 
 // A chanReader represents stdout or stderr of a remote process.
@@ -432,10 +462,26 @@ type chanReader struct {
        // TODO(dfc) a fixed size channel may not be the right data structure.
        // If writes to this channel block, they will block mainLoop, making
        // it unable to receive new messages from the remote side.
-       data         chan []byte // receives data from remote
-       id           uint32
-       packetWriter // for sending windowAdjustMsg
-       buf          []byte
+       data       chan []byte // receives data from remote
+       dataClosed bool        // protects data from being closed twice
+       clientChan *clientChan // the channel backing this reader
+       buf        []byte
+}
+
+// eof signals to the consumer that there is no more data to be received.
+func (r *chanReader) eof() {
+       if !r.dataClosed {
+               r.dataClosed = true
+               close(r.data)
+       }
+}
+
+// handleData sends buf to the reader's consumer. If r.data is closed
+// the data will be silently discarded
+func (r *chanReader) handleData(buf []byte) {
+       if !r.dataClosed {
+               r.data <- buf
+       }
 }
 
 // Read reads data from the remote process's stdout or stderr.
@@ -446,10 +492,10 @@ func (r *chanReader) Read(data []byte) (int, error) {
                        n := copy(data, r.buf)
                        r.buf = r.buf[n:]
                        msg := windowAdjustMsg{
-                               PeersId:         r.id,
+                               PeersId:         r.clientChan.peersId,
                                AdditionalBytes: uint32(n),
                        }
-                       return n, r.writePacket(marshal(msgChannelWindowAdjust, msg))
+                       return n, r.clientChan.writePacket(marshal(msgChannelWindowAdjust, msg))
                }
                r.buf, ok = <-r.data
                if !ok {
@@ -458,7 +504,3 @@ func (r *chanReader) Read(data []byte) (int, error) {
        }
        panic("unreachable")
 }
-
-func (r *chanReader) Close() error {
-       return r.writePacket(marshal(msgChannelEOF, channelEOFMsg{r.id}))
-}