-2f4482b89a6b
+b4a91b693374
The first line of this file holds the Mercurial revision number of the
last merge done from the master library sources.
go_mime_files = \
go/mime/grammar.go \
go/mime/mediatype.go \
- go/mime/type.go
+ go/mime/type.go \
+ go/mime/type_unix.go
if LIBGO_IS_RTEMS
go_net_fd_os_file = go/net/fd_select.go
$(go_os_dir_file) \
go/os/dir.go \
go/os/env.go \
- go/os/env_unix.go \
go/os/error.go \
go/os/error_posix.go \
go/os/exec.go \
go/exp/sql/sql.go
go_exp_ssh_files = \
go/exp/ssh/channel.go \
+ go/exp/ssh/cipher.go \
go/exp/ssh/client.go \
go/exp/ssh/client_auth.go \
go/exp/ssh/common.go \
go/exp/ssh/server.go \
go/exp/ssh/server_shell.go \
go/exp/ssh/session.go \
+ go/exp/ssh/tcpip.go \
go/exp/ssh/transport.go
go_exp_terminal_files = \
- go/exp/terminal/shell.go \
- go/exp/terminal/terminal.go
+ go/exp/terminal/terminal.go \
+ go/exp/terminal/util.go
go_exp_types_files = \
go/exp/types/check.go \
go/exp/types/const.go \
endif
go_base_syscall_files = \
+ go/syscall/env_unix.go \
go/syscall/libcall_support.go \
go/syscall/libcall_posix.go \
go/syscall/socket.go \
go_mime_files = \
go/mime/grammar.go \
go/mime/mediatype.go \
- go/mime/type.go
+ go/mime/type.go \
+ go/mime/type_unix.go
# By default use select with pipes. Most systems should have
# something better.
$(go_os_dir_file) \
go/os/dir.go \
go/os/env.go \
- go/os/env_unix.go \
go/os/error.go \
go/os/error_posix.go \
go/os/exec.go \
go_exp_ssh_files = \
go/exp/ssh/channel.go \
+ go/exp/ssh/cipher.go \
go/exp/ssh/client.go \
go/exp/ssh/client_auth.go \
go/exp/ssh/common.go \
go/exp/ssh/server.go \
go/exp/ssh/server_shell.go \
go/exp/ssh/session.go \
+ go/exp/ssh/tcpip.go \
go/exp/ssh/transport.go
go_exp_terminal_files = \
- go/exp/terminal/shell.go \
- go/exp/terminal/terminal.go
+ go/exp/terminal/terminal.go \
+ go/exp/terminal/util.go
go_exp_types_files = \
go/exp/types/check.go \
# Support for netlink sockets and messages.
@LIBGO_IS_LINUX_TRUE@syscall_netlink_file = go/syscall/netlink_linux.go
go_base_syscall_files = \
+ go/syscall/env_unix.go \
go/syscall/libcall_support.go \
go/syscall/libcall_posix.go \
go/syscall/socket.go \
"fmt"
"io"
"io/ioutil"
- "os"
"strings"
"testing"
"testing/iotest"
{0, 1, nil, io.ErrShortWrite},
{1, 2, nil, io.ErrShortWrite},
{1, 1, nil, nil},
- {0, 1, os.EPIPE, os.EPIPE},
- {1, 2, os.EPIPE, os.EPIPE},
- {1, 1, os.EPIPE, os.EPIPE},
+ {0, 1, io.ErrClosedPipe, io.ErrClosedPipe},
+ {1, 2, io.ErrClosedPipe, io.ErrClosedPipe},
+ {1, 1, io.ErrClosedPipe, io.ErrClosedPipe},
}
func TestWriteErrors(t *testing.T) {
// invocation.
type Type int
+// Type1 is here for the purposes of documentation only. It is a stand-in
+// for any Go type, but represents the same type for any given function
+// invocation.
+type Type1 int
+
// IntegerType is here for the purposes of documentation only. It is a stand-in
// for any integer type: int, uint, int8 etc.
type IntegerType int
// len(src) and len(dst).
func copy(dst, src []Type) int
+// The delete built-in function deletes the element with the specified key
+// (m[key]) from the map. If there is no such element, delete is a no-op.
+// If m is nil, delete panics.
+func delete(m map[Type]Type1, key Type)
+
// The len built-in function returns the length of v, according to its type:
// Array: the number of elements in v.
// Pointer to array: the number of elements in *v (even if v is nil).
// The return value will be floating point type corresponding to the type of c.
func real(c ComplexType) FloatType
-// The imaginary built-in function returns the imaginary part of the complex
+// The imag built-in function returns the imaginary part of the complex
// number c. The return value will be floating point type corresponding to
// the type of c.
func imag(c ComplexType) FloatType
}
type TrimTest struct {
- f func([]byte, string) []byte
+ f string
in, cutset, out string
}
var trimTests = []TrimTest{
- {Trim, "abba", "a", "bb"},
- {Trim, "abba", "ab", ""},
- {TrimLeft, "abba", "ab", ""},
- {TrimRight, "abba", "ab", ""},
- {TrimLeft, "abba", "a", "bba"},
- {TrimRight, "abba", "a", "abb"},
- {Trim, "<tag>", "<>", "tag"},
- {Trim, "* listitem", " *", "listitem"},
- {Trim, `"quote"`, `"`, "quote"},
- {Trim, "\u2C6F\u2C6F\u0250\u0250\u2C6F\u2C6F", "\u2C6F", "\u0250\u0250"},
+ {"Trim", "abba", "a", "bb"},
+ {"Trim", "abba", "ab", ""},
+ {"TrimLeft", "abba", "ab", ""},
+ {"TrimRight", "abba", "ab", ""},
+ {"TrimLeft", "abba", "a", "bba"},
+ {"TrimRight", "abba", "a", "abb"},
+ {"Trim", "<tag>", "<>", "tag"},
+ {"Trim", "* listitem", " *", "listitem"},
+ {"Trim", `"quote"`, `"`, "quote"},
+ {"Trim", "\u2C6F\u2C6F\u0250\u0250\u2C6F\u2C6F", "\u2C6F", "\u0250\u0250"},
//empty string tests
- {Trim, "abba", "", "abba"},
- {Trim, "", "123", ""},
- {Trim, "", "", ""},
- {TrimLeft, "abba", "", "abba"},
- {TrimLeft, "", "123", ""},
- {TrimLeft, "", "", ""},
- {TrimRight, "abba", "", "abba"},
- {TrimRight, "", "123", ""},
- {TrimRight, "", "", ""},
- {TrimRight, "☺\xc0", "☺", "☺\xc0"},
+ {"Trim", "abba", "", "abba"},
+ {"Trim", "", "123", ""},
+ {"Trim", "", "", ""},
+ {"TrimLeft", "abba", "", "abba"},
+ {"TrimLeft", "", "123", ""},
+ {"TrimLeft", "", "", ""},
+ {"TrimRight", "abba", "", "abba"},
+ {"TrimRight", "", "123", ""},
+ {"TrimRight", "", "", ""},
+ {"TrimRight", "☺\xc0", "☺", "☺\xc0"},
}
func TestTrim(t *testing.T) {
for _, tc := range trimTests {
- actual := string(tc.f([]byte(tc.in), tc.cutset))
- var name string
- switch tc.f {
- case Trim:
- name = "Trim"
- case TrimLeft:
- name = "TrimLeft"
- case TrimRight:
- name = "TrimRight"
+ name := tc.f
+ var f func([]byte, string) []byte
+ switch name {
+ case "Trim":
+ f = Trim
+ case "TrimLeft":
+ f = TrimLeft
+ case "TrimRight":
+ f = TrimRight
default:
- t.Error("Undefined trim function")
+ t.Error("Undefined trim function %s", name)
}
+ actual := string(f([]byte(tc.in), tc.cutset))
if actual != tc.out {
t.Errorf("%s(%q, %q) = %q; want %q", name, tc.in, tc.cutset, actual, tc.out)
}
"errors"
"fmt"
"io"
- "os"
)
// Order specifies the bit ordering in an LZW data stream.
d.o = 0
}
+var errClosed = errors.New("compress/lzw: reader/writer is closed")
+
func (d *decoder) Close() error {
- d.err = os.EINVAL // in case any Reads come along
+ d.err = errClosed // in case any Reads come along
return nil
}
"errors"
"fmt"
"io"
- "os"
)
// A writer is a buffered, flushable writer.
type encoder struct {
// w is the writer that compressed bytes are written to.
w writer
- // write, bits, nBits and width are the state for converting a code stream
- // into a byte stream.
+ // order, write, bits, nBits and width are the state for
+ // converting a code stream into a byte stream.
+ order Order
write func(*encoder, uint32) error
bits uint32
nBits uint
// call. It is equal to invalidCode if there was no such call.
savedCode uint32
// err is the first error encountered during writing. Closing the encoder
- // will make any future Write calls return os.EINVAL.
+ // will make any future Write calls return errClosed
err error
// table is the hash table from 20-bit keys to 12-bit values. Each table
// entry contains key<<12|val and collisions resolve by linear probing.
// flush e's underlying writer.
func (e *encoder) Close() error {
if e.err != nil {
- if e.err == os.EINVAL {
+ if e.err == errClosed {
return nil
}
return e.err
}
- // Make any future calls to Write return os.EINVAL.
- e.err = os.EINVAL
+ // Make any future calls to Write return errClosed.
+ e.err = errClosed
// Write the savedCode if valid.
if e.savedCode != invalidCode {
if err := e.write(e, e.savedCode); err != nil {
}
// Write the final bits.
if e.nBits > 0 {
- if e.write == (*encoder).writeMSB {
+ if e.order == MSB {
e.bits >>= 24
}
if err := e.w.WriteByte(uint8(e.bits)); err != nil {
lw := uint(litWidth)
return &encoder{
w: bw,
+ order: order,
write: write,
width: 1 + lw,
litWidth: lw,
return
}
_, err1 := lzww.Write(b[:n])
- if err1 == os.EPIPE {
- // Fail, but do not report the error, as some other (presumably reportable) error broke the pipe.
- return
- }
if err1 != nil {
t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err1)
return
}
defer zlibw.Close()
_, err = zlibw.Write(b0)
- if err == os.EPIPE {
- // Fail, but do not report the error, as some other (presumably reported) error broke the pipe.
- return
- }
if err != nil {
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err)
return
}
// BlockSize returns the AES block size, 16 bytes.
-// It is necessary to satisfy the Cipher interface in the
+// It is necessary to satisfy the Block interface in the
// package "crypto/cipher".
func (c *Cipher) BlockSize() int { return BlockSize }
}
// BlockSize returns the Blowfish block size, 8 bytes.
-// It is necessary to satisfy the Cipher interface in the
+// It is necessary to satisfy the Block interface in the
// package "crypto/cipher".
func (c *Cipher) BlockSize() int { return BlockSize }
if r.prov == 0 {
const provType = syscall.PROV_RSA_FULL
const flags = syscall.CRYPT_VERIFYCONTEXT | syscall.CRYPT_SILENT
- errno := syscall.CryptAcquireContext(&r.prov, nil, nil, provType, flags)
- if errno != 0 {
+ err := syscall.CryptAcquireContext(&r.prov, nil, nil, provType, flags)
+ if err != nil {
r.mu.Unlock()
- return 0, os.NewSyscallError("CryptAcquireContext", errno)
+ return 0, os.NewSyscallError("CryptAcquireContext", err)
}
}
r.mu.Unlock()
- errno := syscall.CryptGenRandom(r.prov, uint32(len(b)), &b[0])
- if errno != 0 {
- return 0, os.NewSyscallError("CryptGenRandom", errno)
+ err = syscall.CryptGenRandom(r.prov, uint32(len(b)), &b[0])
+ if err != nil {
+ return 0, os.NewSyscallError("CryptGenRandom", err)
}
return len(b), nil
}
package rand
import (
+ "errors"
"io"
"math/big"
- "os"
)
// Prime returns a number, p, of the given size, such that p is prime
// with high probability.
func Prime(rand io.Reader, bits int) (p *big.Int, err error) {
if bits < 1 {
- err = os.EINVAL
+ err = errors.New("crypto/rand: prime size must be positive")
}
b := uint(bits % 8)
}
// SetReadTimeout sets the time (in nanoseconds) that
-// Read will wait for data before returning os.EAGAIN.
+// Read will wait for data before returning a net.Error
+// with Timeout() == true.
// Setting nsec == 0 (the default) disables the deadline.
func (c *Conn) SetReadTimeout(nsec int64) error {
return c.conn.SetReadTimeout(nsec)
return c.writeRecord(recordTypeApplicationData, b)
}
-// Read can be made to time out and return err == os.EAGAIN
+// Read can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetTimeout and SetReadTimeout.
func (c *Conn) Read(b []byte) (n int, err error) {
if err = c.Handshake(); err != nil {
package tls
+import "bytes"
+
type clientHelloMsg struct {
raw []byte
vers uint16
supportedPoints []uint8
}
+func (m *clientHelloMsg) equal(i interface{}) bool {
+ m1, ok := i.(*clientHelloMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ m.vers == m1.vers &&
+ bytes.Equal(m.random, m1.random) &&
+ bytes.Equal(m.sessionId, m1.sessionId) &&
+ eqUint16s(m.cipherSuites, m1.cipherSuites) &&
+ bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
+ m.nextProtoNeg == m1.nextProtoNeg &&
+ m.serverName == m1.serverName &&
+ m.ocspStapling == m1.ocspStapling &&
+ eqUint16s(m.supportedCurves, m1.supportedCurves) &&
+ bytes.Equal(m.supportedPoints, m1.supportedPoints)
+}
+
func (m *clientHelloMsg) marshal() []byte {
if m.raw != nil {
return m.raw
ocspStapling bool
}
+func (m *serverHelloMsg) equal(i interface{}) bool {
+ m1, ok := i.(*serverHelloMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ m.vers == m1.vers &&
+ bytes.Equal(m.random, m1.random) &&
+ bytes.Equal(m.sessionId, m1.sessionId) &&
+ m.cipherSuite == m1.cipherSuite &&
+ m.compressionMethod == m1.compressionMethod &&
+ m.nextProtoNeg == m1.nextProtoNeg &&
+ eqStrings(m.nextProtos, m1.nextProtos) &&
+ m.ocspStapling == m1.ocspStapling
+}
+
func (m *serverHelloMsg) marshal() []byte {
if m.raw != nil {
return m.raw
certificates [][]byte
}
+func (m *certificateMsg) equal(i interface{}) bool {
+ m1, ok := i.(*certificateMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ eqByteSlices(m.certificates, m1.certificates)
+}
+
func (m *certificateMsg) marshal() (x []byte) {
if m.raw != nil {
return m.raw
key []byte
}
+func (m *serverKeyExchangeMsg) equal(i interface{}) bool {
+ m1, ok := i.(*serverKeyExchangeMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ bytes.Equal(m.key, m1.key)
+}
+
func (m *serverKeyExchangeMsg) marshal() []byte {
if m.raw != nil {
return m.raw
response []byte
}
+func (m *certificateStatusMsg) equal(i interface{}) bool {
+ m1, ok := i.(*certificateStatusMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ m.statusType == m1.statusType &&
+ bytes.Equal(m.response, m1.response)
+}
+
func (m *certificateStatusMsg) marshal() []byte {
if m.raw != nil {
return m.raw
type serverHelloDoneMsg struct{}
+func (m *serverHelloDoneMsg) equal(i interface{}) bool {
+ _, ok := i.(*serverHelloDoneMsg)
+ return ok
+}
+
func (m *serverHelloDoneMsg) marshal() []byte {
x := make([]byte, 4)
x[0] = typeServerHelloDone
ciphertext []byte
}
+func (m *clientKeyExchangeMsg) equal(i interface{}) bool {
+ m1, ok := i.(*clientKeyExchangeMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ bytes.Equal(m.ciphertext, m1.ciphertext)
+}
+
func (m *clientKeyExchangeMsg) marshal() []byte {
if m.raw != nil {
return m.raw
verifyData []byte
}
+func (m *finishedMsg) equal(i interface{}) bool {
+ m1, ok := i.(*finishedMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ bytes.Equal(m.verifyData, m1.verifyData)
+}
+
func (m *finishedMsg) marshal() (x []byte) {
if m.raw != nil {
return m.raw
proto string
}
+func (m *nextProtoMsg) equal(i interface{}) bool {
+ m1, ok := i.(*nextProtoMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ m.proto == m1.proto
+}
+
func (m *nextProtoMsg) marshal() []byte {
if m.raw != nil {
return m.raw
certificateAuthorities [][]byte
}
+func (m *certificateRequestMsg) equal(i interface{}) bool {
+ m1, ok := i.(*certificateRequestMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
+ eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities)
+}
+
func (m *certificateRequestMsg) marshal() (x []byte) {
if m.raw != nil {
return m.raw
signature []byte
}
+func (m *certificateVerifyMsg) equal(i interface{}) bool {
+ m1, ok := i.(*certificateVerifyMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ bytes.Equal(m.signature, m1.signature)
+}
+
func (m *certificateVerifyMsg) marshal() (x []byte) {
if m.raw != nil {
return m.raw
return true
}
+
+func eqUint16s(x, y []uint16) bool {
+ if len(x) != len(y) {
+ return false
+ }
+ for i, v := range x {
+ if y[i] != v {
+ return false
+ }
+ }
+ return true
+}
+
+func eqStrings(x, y []string) bool {
+ if len(x) != len(y) {
+ return false
+ }
+ for i, v := range x {
+ if y[i] != v {
+ return false
+ }
+ }
+ return true
+}
+
+func eqByteSlices(x, y [][]byte) bool {
+ if len(x) != len(y) {
+ return false
+ }
+ for i, v := range x {
+ if !bytes.Equal(v, y[i]) {
+ return false
+ }
+ }
+ return true
+}
type testMessage interface {
marshal() []byte
unmarshal([]byte) bool
+ equal(interface{}) bool
}
func TestMarshalUnmarshal(t *testing.T) {
rand := rand.New(rand.NewSource(0))
+
for i, iface := range tests {
ty := reflect.ValueOf(iface).Type()
}
m2.marshal() // to fill any marshal cache in the message
- if !reflect.DeepEqual(m1, m2) {
+ if !m1.equal(m2) {
t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
break
}
)
func loadStore(roots *x509.CertPool, name string) {
- store, errno := syscall.CertOpenSystemStore(syscall.InvalidHandle, syscall.StringToUTF16Ptr(name))
- if errno != 0 {
+ store, err := syscall.CertOpenSystemStore(syscall.InvalidHandle, syscall.StringToUTF16Ptr(name))
+ if err != nil {
return
}
}
// BlockSize returns the XTEA block size, 8 bytes.
-// It is necessary to satisfy the Cipher interface in the
+// It is necessary to satisfy the Block interface in the
// package "crypto/cipher".
func (c *Cipher) BlockSize() int { return BlockSize }
--- /dev/null
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Large data benchmark.
+// The JSON data is a summary of agl's changes in the
+// go, webkit, and chromium open source projects.
+// We benchmark converting between the JSON form
+// and in-memory data structures.
+
+package json
+
+import (
+ "bytes"
+ "compress/gzip"
+ "io/ioutil"
+ "os"
+ "testing"
+)
+
+type codeResponse struct {
+ Tree *codeNode `json:"tree"`
+ Username string `json:"username"`
+}
+
+type codeNode struct {
+ Name string `json:"name"`
+ Kids []*codeNode `json:"kids"`
+ CLWeight float64 `json:"cl_weight"`
+ Touches int `json:"touches"`
+ MinT int64 `json:"min_t"`
+ MaxT int64 `json:"max_t"`
+ MeanT int64 `json:"mean_t"`
+}
+
+var codeJSON []byte
+var codeStruct codeResponse
+
+func codeInit() {
+ f, err := os.Open("testdata/code.json.gz")
+ if err != nil {
+ panic(err)
+ }
+ defer f.Close()
+ gz, err := gzip.NewReader(f)
+ if err != nil {
+ panic(err)
+ }
+ data, err := ioutil.ReadAll(gz)
+ if err != nil {
+ panic(err)
+ }
+
+ codeJSON = data
+
+ if err := Unmarshal(codeJSON, &codeStruct); err != nil {
+ panic("unmarshal code.json: " + err.Error())
+ }
+
+ if data, err = Marshal(&codeStruct); err != nil {
+ panic("marshal code.json: " + err.Error())
+ }
+
+ if !bytes.Equal(data, codeJSON) {
+ println("different lengths", len(data), len(codeJSON))
+ for i := 0; i < len(data) && i < len(codeJSON); i++ {
+ if data[i] != codeJSON[i] {
+ println("re-marshal: changed at byte", i)
+ println("orig: ", string(codeJSON[i-10:i+10]))
+ println("new: ", string(data[i-10:i+10]))
+ break
+ }
+ }
+ panic("re-marshal code.json: different result")
+ }
+}
+
+func BenchmarkCodeEncoder(b *testing.B) {
+ if codeJSON == nil {
+ b.StopTimer()
+ codeInit()
+ b.StartTimer()
+ }
+ enc := NewEncoder(ioutil.Discard)
+ for i := 0; i < b.N; i++ {
+ if err := enc.Encode(&codeStruct); err != nil {
+ panic(err)
+ }
+ }
+ b.SetBytes(int64(len(codeJSON)))
+}
+
+func BenchmarkCodeMarshal(b *testing.B) {
+ if codeJSON == nil {
+ b.StopTimer()
+ codeInit()
+ b.StartTimer()
+ }
+ for i := 0; i < b.N; i++ {
+ if _, err := Marshal(&codeStruct); err != nil {
+ panic(err)
+ }
+ }
+ b.SetBytes(int64(len(codeJSON)))
+}
+
+func BenchmarkCodeDecoder(b *testing.B) {
+ if codeJSON == nil {
+ b.StopTimer()
+ codeInit()
+ b.StartTimer()
+ }
+ var buf bytes.Buffer
+ dec := NewDecoder(&buf)
+ var r codeResponse
+ for i := 0; i < b.N; i++ {
+ buf.Write(codeJSON)
+ // hide EOF
+ buf.WriteByte('\n')
+ buf.WriteByte('\n')
+ buf.WriteByte('\n')
+ if err := dec.Decode(&r); err != nil {
+ panic(err)
+ }
+ }
+ b.SetBytes(int64(len(codeJSON)))
+}
+
+func BenchmarkCodeUnmarshal(b *testing.B) {
+ if codeJSON == nil {
+ b.StopTimer()
+ codeInit()
+ b.StartTimer()
+ }
+ for i := 0; i < b.N; i++ {
+ var r codeResponse
+ if err := Unmarshal(codeJSON, &r); err != nil {
+ panic(err)
+ }
+ }
+ b.SetBytes(int64(len(codeJSON)))
+}
+
+func BenchmarkCodeUnmarshalReuse(b *testing.B) {
+ if codeJSON == nil {
+ b.StopTimer()
+ codeInit()
+ b.StartTimer()
+ }
+ var r codeResponse
+ for i := 0; i < b.N; i++ {
+ if err := Unmarshal(codeJSON, &r); err != nil {
+ panic(err)
+ }
+ }
+ b.SetBytes(int64(len(codeJSON)))
+}
// d.scan thinks we're still at the beginning of the item.
// Feed in an empty string - the shortest, simplest value -
// so that it knows we got to the end of the value.
- if d.scan.step == stateRedo {
+ if d.scan.redo {
panic("redo")
}
d.scan.step(&d.scan, '"')
d.error(errPhase)
}
}
+
if i < av.Len() {
if !sv.IsValid() {
// Array. Zero the rest.
sv.SetLen(i)
}
}
+ if i == 0 && av.Kind() == reflect.Slice && sv.IsNil() {
+ sv.Set(reflect.MakeSlice(sv.Type(), 0, 0))
+ }
}
// object consumes an object from d.data[d.off-1:], decoding into the value v.
// on a 64-bit Mac Mini, and it's nicer to read.
step func(*scanner, int) int
+ // Reached end of top-level value.
+ endTop bool
+
// Stack of what we're in the middle of - array values, object keys, object values.
parseState []int
err error
// 1-byte redo (see undo method)
+ redo bool
redoCode int
redoState func(*scanner, int) int
s.step = stateBeginValue
s.parseState = s.parseState[0:0]
s.err = nil
+ s.redo = false
+ s.endTop = false
}
// eof tells the scanner that the end of input has been reached.
if s.err != nil {
return scanError
}
- if s.step == stateEndTop {
+ if s.endTop {
return scanEnd
}
s.step(s, ' ')
- if s.step == stateEndTop {
+ if s.endTop {
return scanEnd
}
if s.err == nil {
func (s *scanner) popParseState() {
n := len(s.parseState) - 1
s.parseState = s.parseState[0:n]
+ s.redo = false
if n == 0 {
s.step = stateEndTop
+ s.endTop = true
} else {
s.step = stateEndValue
}
if n == 0 {
// Completed top-level before the current byte.
s.step = stateEndTop
+ s.endTop = true
return stateEndTop(s, c)
}
if c <= ' ' && (c == ' ' || c == '\t' || c == '\r' || c == '\n') {
// undo causes the scanner to return scanCode from the next state transition.
// This gives callers a simple 1-byte undo mechanism.
func (s *scanner) undo(scanCode int) {
- if s.step == stateRedo {
- panic("invalid use of scanner")
+ if s.redo {
+ panic("json: invalid use of scanner")
}
s.redoCode = scanCode
s.redoState = s.step
s.step = stateRedo
+ s.redo = true
}
// stateRedo helps implement the scanner's 1-byte undo.
func stateRedo(s *scanner, c int) int {
+ s.redo = false
s.step = s.redoState
return s.redoCode
}
}
}
+var benchScan scanner
+
func BenchmarkSkipValue(b *testing.B) {
initBig()
- var scan scanner
for i := 0; i < b.N; i++ {
- nextValue(jsonBig, &scan)
+ nextValue(jsonBig, &benchScan)
}
b.SetBytes(int64(len(jsonBig)))
}
import (
"bytes"
"io"
- "os"
"reflect"
"strings"
"testing"
CharData([]byte("World <>'\" 白鵬翔")),
EndElement{Name{"", "hello"}},
CharData([]byte("\n ")),
- StartElement{Name{"", "goodbye"}, nil},
+ StartElement{Name{"", "goodbye"}, []Attr{}},
EndElement{Name{"", "goodbye"}},
CharData([]byte("\n ")),
StartElement{Name{"", "outer"}, []Attr{{Name{"foo", "attr"}, "value"}, {Name{"xmlns", "tag"}, "ns4"}}},
CharData([]byte("\n ")),
- StartElement{Name{"", "inner"}, nil},
+ StartElement{Name{"", "inner"}, []Attr{}},
EndElement{Name{"", "inner"}},
CharData([]byte("\n ")),
EndElement{Name{"", "outer"}},
CharData([]byte("\n ")),
- StartElement{Name{"tag", "name"}, nil},
+ StartElement{Name{"tag", "name"}, []Attr{}},
CharData([]byte("\n ")),
CharData([]byte("Some text here.")),
CharData([]byte("\n ")),
CharData([]byte("World <>'\" 白鵬翔")),
EndElement{Name{"ns2", "hello"}},
CharData([]byte("\n ")),
- StartElement{Name{"ns2", "goodbye"}, nil},
+ StartElement{Name{"ns2", "goodbye"}, []Attr{}},
EndElement{Name{"ns2", "goodbye"}},
CharData([]byte("\n ")),
StartElement{Name{"ns2", "outer"}, []Attr{{Name{"ns1", "attr"}, "value"}, {Name{"xmlns", "tag"}, "ns4"}}},
CharData([]byte("\n ")),
- StartElement{Name{"ns2", "inner"}, nil},
+ StartElement{Name{"ns2", "inner"}, []Attr{}},
EndElement{Name{"ns2", "inner"}},
CharData([]byte("\n ")),
EndElement{Name{"ns2", "outer"}},
CharData([]byte("\n ")),
- StartElement{Name{"ns3", "name"}, nil},
+ StartElement{Name{"ns3", "name"}, []Attr{}},
CharData([]byte("\n ")),
CharData([]byte("Some text here.")),
CharData([]byte("\n ")),
CharData([]byte("\n")),
ProcInst{"xml", []byte(`version="1.0" encoding="x-testing-uppercase"`)},
CharData([]byte("\n")),
- StartElement{Name{"", "tag"}, nil},
+ StartElement{Name{"", "tag"}, []Attr{}},
CharData([]byte("value")),
EndElement{Name{"", "tag"}},
}
func (d *downCaser) Read(p []byte) (int, error) {
d.t.Fatalf("unexpected Read call on downCaser reader")
- return 0, os.EINVAL
+ panic("unreachable")
}
func TestRawTokenAltEncoding(t *testing.T) {
watchEntry.flags |= flags
flags |= syscall.IN_MASK_ADD
}
- wd, errno := syscall.InotifyAddWatch(w.fd, path, flags)
- if wd == -1 {
- return &os.PathError{"inotify_add_watch", path, os.Errno(errno)}
+ wd, err := syscall.InotifyAddWatch(w.fd, path, flags)
+ if err != nil {
+ return &os.PathError{"inotify_add_watch", path, err}
}
if !found {
// readEvents reads from the inotify file descriptor, converts the
// received events into Event objects and sends them via the Event channel
func (w *Watcher) readEvents() {
- var (
- buf [syscall.SizeofInotifyEvent * 4096]byte // Buffer for a maximum of 4096 raw events
- n int // Number of bytes read with read()
- errno int // Syscall errno
- )
+ var buf [syscall.SizeofInotifyEvent * 4096]byte
for {
- n, errno = syscall.Read(w.fd, buf[0:])
+ n, err := syscall.Read(w.fd, buf[0:])
// See if there is a message on the "done" channel
var done bool
select {
// If EOF or a "done" message is received
if n == 0 || done {
- errno := syscall.Close(w.fd)
- if errno == -1 {
- w.Error <- os.NewSyscallError("close", errno)
+ err := syscall.Close(w.fd)
+ if err != nil {
+ w.Error <- os.NewSyscallError("close", err)
}
close(w.Event)
close(w.Error)
return
}
if n < 0 {
- w.Error <- os.NewSyscallError("read", errno)
+ w.Error <- os.NewSyscallError("read", err)
continue
}
if n < syscall.SizeofInotifyEvent {
"strconv"
)
+// subsetTypeArgs takes a slice of arguments from callers of the sql
+// package and converts them into a slice of the driver package's
+// "subset types".
+func subsetTypeArgs(args []interface{}) ([]interface{}, error) {
+ out := make([]interface{}, len(args))
+ for n, arg := range args {
+ var err error
+ out[n], err = driver.DefaultParameterConverter.ConvertValue(arg)
+ if err != nil {
+ return nil, fmt.Errorf("sql: converting argument #%d's type: %v", n+1, err)
+ }
+ }
+ return out, nil
+}
+
// convertAssign copies to dest the value in src, converting it if possible.
// An error is returned if the copy would result in loss of information.
// dest should be a pointer type.
Open(name string) (Conn, error)
}
-// Execer is an optional interface that may be implemented by a Driver
-// or a Conn.
-//
-// If a Driver does not implement Execer, the sql package's DB.Exec
-// method first obtains a free connection from its free pool or from
-// the driver's Open method. Execer should only be implemented by
-// drivers that can provide a more efficient implementation.
+// ErrSkip may be returned by some optional interfaces' methods to
+// indicate at runtime that the fast path is unavailable and the sql
+// package should continue as if the optional interface was not
+// implemented. ErrSkip is only supported where explicitly
+// documented.
+var ErrSkip = errors.New("driver: skip fast-path; continue as if unimplemented")
+
+// Execer is an optional interface that may be implemented by a Conn.
//
// If a Conn does not implement Execer, the db package's DB.Exec will
// first prepare a query, execute the statement, and then close the
// statement.
//
// All arguments are of a subset type as defined in the package docs.
+//
+// Exec may return ErrSkip.
type Execer interface {
Exec(query string, args []interface{}) (Result, error)
}
Close() error
// NumInput returns the number of placeholder parameters.
+ // -1 means the driver doesn't know how to count the number of
+ // placeholders, so we won't sanity check input here and instead let the
+ // driver deal with errors.
NumInput() int
// Exec executes a query that doesn't return rows, such
// The dest slice may be populated with only with values
// of subset types defined above, but excluding string.
// All string values must be converted to []byte.
+ //
+ // Next should return io.EOF when there are no more rows.
Next(dest []interface{}) error
}
return nil
}
+func checkSubsetTypes(args []interface{}) error {
+ for n, arg := range args {
+ switch arg.(type) {
+ case int64, float64, bool, nil, []byte, string:
+ default:
+ return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg)
+ }
+ }
+ return nil
+}
+
+func (c *fakeConn) Exec(query string, args []interface{}) (driver.Result, error) {
+ // This is an optional interface, but it's implemented here
+ // just to check that all the args of of the proper types.
+ // ErrSkip is returned so the caller acts as if we didn't
+ // implement this at all.
+ err := checkSubsetTypes(args)
+ if err != nil {
+ return nil, err
+ }
+ return nil, driver.ErrSkip
+}
+
func errf(msg string, args ...interface{}) error {
return errors.New("fakedb: " + fmt.Sprintf(msg, args...))
}
}
func (s *fakeStmt) Exec(args []interface{}) (driver.Result, error) {
+ err := checkSubsetTypes(args)
+ if err != nil {
+ return nil, err
+ }
+
db := s.c.db
switch s.cmd {
case "WIPE":
}
func (s *fakeStmt) Query(args []interface{}) (driver.Rows, error) {
+ err := checkSubsetTypes(args)
+ if err != nil {
+ return nil, err
+ }
+
db := s.c.db
if len(args) != s.placeholders {
panic("error in pkg db; should only get here if size is correct")
driver driver.Driver
dsn string
- mu sync.Mutex
+ mu sync.Mutex // protects freeConn and closed
freeConn []driver.Conn
+ closed bool
}
// Open opens a database specified by its database driver name and a
return &DB{driver: driver, dsn: dataSourceName}, nil
}
+// Close closes the database, releasing any open resources.
+func (db *DB) Close() error {
+ db.mu.Lock()
+ defer db.mu.Unlock()
+ var err error
+ for _, c := range db.freeConn {
+ err1 := c.Close()
+ if err1 != nil {
+ err = err1
+ }
+ }
+ db.freeConn = nil
+ db.closed = true
+ return err
+}
+
func (db *DB) maxIdleConns() int {
const defaultMaxIdleConns = 2
// TODO(bradfitz): ask driver, if supported, for its default preference
// conn returns a newly-opened or cached driver.Conn
func (db *DB) conn() (driver.Conn, error) {
db.mu.Lock()
+ if db.closed {
+ return nil, errors.New("sql: database is closed")
+ }
if n := len(db.freeConn); n > 0 {
conn := db.freeConn[n-1]
db.freeConn = db.freeConn[:n-1]
}
func (db *DB) putConn(c driver.Conn) {
- if n := len(db.freeConn); n < db.maxIdleConns() {
+ db.mu.Lock()
+ defer db.mu.Unlock()
+ if n := len(db.freeConn); !db.closed && n < db.maxIdleConns() {
db.freeConn = append(db.freeConn, c)
return
}
- db.closeConn(c)
+ db.closeConn(c) // TODO(bradfitz): release lock before calling this?
}
func (db *DB) closeConn(c driver.Conn) {
// Exec executes a query without returning any rows.
func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
- // Optional fast path, if the driver implements driver.Execer.
- if execer, ok := db.driver.(driver.Execer); ok {
- resi, err := execer.Exec(query, args)
- if err != nil {
- return nil, err
- }
- return result{resi}, nil
+ sargs, err := subsetTypeArgs(args)
+ if err != nil {
+ return nil, err
}
- // If the driver does not implement driver.Execer, we need
- // a connection.
ci, err := db.conn()
if err != nil {
return nil, err
defer db.putConn(ci)
if execer, ok := ci.(driver.Execer); ok {
- resi, err := execer.Exec(query, args)
- if err != nil {
- return nil, err
+ resi, err := execer.Exec(query, sargs)
+ if err != driver.ErrSkip {
+ if err != nil {
+ return nil, err
+ }
+ return result{resi}, nil
}
- return result{resi}, nil
}
sti, err := ci.Prepare(query)
return nil, err
}
defer sti.Close()
- resi, err := sti.Exec(args)
+
+ resi, err := sti.Exec(sargs)
if err != nil {
return nil, err
}
return nil, err
}
defer sti.Close()
- resi, err := sti.Exec(args)
+
+ sargs, err := subsetTypeArgs(args)
+ if err != nil {
+ return nil, err
+ }
+
+ resi, err := sti.Exec(sargs)
if err != nil {
return nil, err
}
}
defer releaseConn()
- if want := si.NumInput(); len(args) != want {
+ // -1 means the driver doesn't know how to count the number of
+ // placeholders, so we won't sanity check input here and instead let the
+ // driver deal with errors.
+ if want := si.NumInput(); want != -1 && len(args) != want {
return nil, fmt.Errorf("db: expected %d arguments, got %d", want, len(args))
}
if err != nil {
return nil, err
}
- if len(args) != si.NumInput() {
+
+ // -1 means the driver doesn't know how to count the number of
+ // placeholders, so we won't sanity check input here and instead let the
+ // driver deal with errors.
+ if want := si.NumInput(); want != -1 && len(args) != want {
return nil, fmt.Errorf("db: statement expects %d inputs; got %d", si.NumInput(), len(args))
}
- rowsi, err := si.Query(args)
+ sargs, err := subsetTypeArgs(args)
+ if err != nil {
+ return nil, err
+ }
+ rowsi, err := si.Query(sargs)
if err != nil {
s.db.putConn(ci)
return nil, err
}
}
+func closeDB(t *testing.T, db *DB) {
+ err := db.Close()
+ if err != nil {
+ t.Fatalf("error closing DB: %v", err)
+ }
+}
+
func TestQuery(t *testing.T) {
db := newTestDB(t, "people")
+ defer closeDB(t, db)
var name string
var age int
func TestStatementQueryRow(t *testing.T) {
db := newTestDB(t, "people")
+ defer closeDB(t, db)
stmt, err := db.Prepare("SELECT|people|age|name=?")
if err != nil {
t.Fatalf("Prepare: %v", err)
// just a test of fakedb itself
func TestBogusPreboundParameters(t *testing.T) {
db := newTestDB(t, "foo")
+ defer closeDB(t, db)
exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
_, err := db.Prepare("INSERT|t1|name=?,age=bogusconversion")
if err == nil {
func TestDb(t *testing.T) {
db := newTestDB(t, "foo")
+ defer closeDB(t, db)
exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
if err != nil {
--- /dev/null
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/rc4"
+)
+
+// streamDump is used to dump the initial keystream for stream ciphers. It is a
+// a write-only buffer, and not intended for reading so do not require a mutex.
+var streamDump [512]byte
+
+// noneCipher implements cipher.Stream and provides no encryption. It is used
+// by the transport before the first key-exchange.
+type noneCipher struct{}
+
+func (c noneCipher) XORKeyStream(dst, src []byte) {
+ copy(dst, src)
+}
+
+func newAESCTR(key, iv []byte) (cipher.Stream, error) {
+ c, err := aes.NewCipher(key)
+ if err != nil {
+ return nil, err
+ }
+ return cipher.NewCTR(c, iv), nil
+}
+
+func newRC4(key, iv []byte) (cipher.Stream, error) {
+ return rc4.NewCipher(key)
+}
+
+type cipherMode struct {
+ keySize int
+ ivSize int
+ skip int
+ createFn func(key, iv []byte) (cipher.Stream, error)
+}
+
+func (c *cipherMode) createCipher(key, iv []byte) (cipher.Stream, error) {
+ if len(key) < c.keySize {
+ panic("ssh: key length too small for cipher")
+ }
+ if len(iv) < c.ivSize {
+ panic("ssh: iv too small for cipher")
+ }
+
+ stream, err := c.createFn(key[:c.keySize], iv[:c.ivSize])
+ if err != nil {
+ return nil, err
+ }
+
+ for remainingToDump := c.skip; remainingToDump > 0; {
+ dumpThisTime := remainingToDump
+ if dumpThisTime > len(streamDump) {
+ dumpThisTime = len(streamDump)
+ }
+ stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime])
+ remainingToDump -= dumpThisTime
+ }
+
+ return stream, nil
+}
+
+// Specifies a default set of ciphers and a preference order. This is based on
+// OpenSSH's default client preference order, minus algorithms that are not
+// implemented.
+var DefaultCipherOrder = []string{
+ "aes128-ctr", "aes192-ctr", "aes256-ctr",
+ "arcfour256", "arcfour128",
+}
+
+var cipherModes = map[string]*cipherMode{
+ // Ciphers from RFC4344, which introduced many CTR-based ciphers. Algorithms
+ // are defined in the order specified in the RFC.
+ "aes128-ctr": &cipherMode{16, aes.BlockSize, 0, newAESCTR},
+ "aes192-ctr": &cipherMode{24, aes.BlockSize, 0, newAESCTR},
+ "aes256-ctr": &cipherMode{32, aes.BlockSize, 0, newAESCTR},
+
+ // Ciphers from RFC4345, which introduces security-improved arcfour ciphers.
+ // They are defined in the order specified in the RFC.
+ "arcfour128": &cipherMode{16, 0, 1536, newRC4},
+ "arcfour256": &cipherMode{32, 0, 1536, newRC4},
+}
--- /dev/null
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+ "bytes"
+ "testing"
+)
+
+// TestCipherReversal tests that each cipher factory produces ciphers that can
+// encrypt and decrypt some data successfully.
+func TestCipherReversal(t *testing.T) {
+ testData := []byte("abcdefghijklmnopqrstuvwxyz012345")
+ testKey := []byte("AbCdEfGhIjKlMnOpQrStUvWxYz012345")
+ testIv := []byte("sdflkjhsadflkjhasdflkjhsadfklhsa")
+
+ cryptBuffer := make([]byte, 32)
+
+ for name, cipherMode := range cipherModes {
+ encrypter, err := cipherMode.createCipher(testKey, testIv)
+ if err != nil {
+ t.Errorf("failed to create encrypter for %q: %s", name, err)
+ continue
+ }
+ decrypter, err := cipherMode.createCipher(testKey, testIv)
+ if err != nil {
+ t.Errorf("failed to create decrypter for %q: %s", name, err)
+ continue
+ }
+
+ copy(cryptBuffer, testData)
+
+ encrypter.XORKeyStream(cryptBuffer, cryptBuffer)
+ if name == "none" {
+ if !bytes.Equal(cryptBuffer, testData) {
+ t.Errorf("encryption made change with 'none' cipher")
+ continue
+ }
+ } else {
+ if bytes.Equal(cryptBuffer, testData) {
+ t.Errorf("encryption made no change with %q", name)
+ continue
+ }
+ }
+
+ decrypter.XORKeyStream(cryptBuffer, cryptBuffer)
+ if !bytes.Equal(cryptBuffer, testData) {
+ t.Errorf("decrypted bytes not equal to input with %q", name)
+ continue
+ }
+ }
+}
+
+func TestDefaultCiphersExist(t *testing.T) {
+ for _, cipherAlgo := range DefaultCipherOrder {
+ if _, ok := cipherModes[cipherAlgo]; !ok {
+ t.Errorf("default cipher %q is unknown", cipherAlgo)
+ }
+ }
+}
conn.Close()
return nil, err
}
- if err := conn.authenticate(); err != nil {
- conn.Close()
- return nil, err
- }
go conn.mainLoop()
return conn, nil
}
clientKexInit := kexInitMsg{
KexAlgos: supportedKexAlgos,
ServerHostKeyAlgos: supportedHostKeyAlgos,
- CiphersClientServer: supportedCiphers,
- CiphersServerClient: supportedCiphers,
+ CiphersClientServer: c.config.Crypto.ciphers(),
+ CiphersServerClient: c.config.Crypto.ciphers(),
MACsClientServer: supportedMACs,
MACsServerClient: supportedMACs,
CompressionClientServer: supportedCompressions,
if packet[0] != msgNewKeys {
return UnexpectedMessageError{msgNewKeys, packet[0]}
}
- return c.transport.reader.setupKeys(serverKeys, K, H, H, hashFunc)
+ if err := c.transport.reader.setupKeys(serverKeys, K, H, H, hashFunc); err != nil {
+ return err
+ }
+ return c.authenticate(H)
}
// kexDH performs Diffie-Hellman key agreement on a ClientConn. The
switch msg := (<-ch.msg).(type) {
case *channelOpenConfirmMsg:
ch.peersId = msg.MyId
+ ch.win <- int(msg.MyWindow)
case *channelOpenFailureMsg:
c.chanlist.remove(ch.id)
return nil, errors.New(msg.Message)
// A slice of ClientAuth methods. Only the first instance
// of a particular RFC 4252 method will be used during authentication.
Auth []ClientAuth
+
+ // Cryptographic-related configuration.
+ Crypto CryptoConfig
}
func (c *ClientConfig) rand() io.Reader {
import (
"errors"
+ "io"
)
// authenticate authenticates with the remote server. See RFC 4252.
-func (c *ClientConn) authenticate() error {
+func (c *ClientConn) authenticate(session []byte) error {
// initiate user auth session
if err := c.writePacket(marshal(msgServiceRequest, serviceRequestMsg{serviceUserAuth})); err != nil {
return err
// then any untried methods suggested by the server.
tried, remain := make(map[string]bool), make(map[string]bool)
for auth := ClientAuth(new(noneAuth)); auth != nil; {
- ok, methods, err := auth.auth(c.config.User, c.transport)
+ ok, methods, err := auth.auth(session, c.config.User, c.transport, c.config.rand())
if err != nil {
return err
}
// Returns true if authentication is successful.
// If authentication is not successful, a []string of alternative
// method names is returned.
- auth(user string, t *transport) (bool, []string, error)
+ auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error)
// method returns the RFC 4252 method name.
method() string
// "none" authentication, RFC 4252 section 5.2.
type noneAuth int
-func (n *noneAuth) auth(user string, t *transport) (bool, []string, error) {
+func (n *noneAuth) auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error) {
if err := t.writePacket(marshal(msgUserAuthRequest, userAuthRequestMsg{
User: user,
Service: serviceSSH,
ClientPassword
}
-func (p *passwordAuth) auth(user string, t *transport) (bool, []string, error) {
+func (p *passwordAuth) auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error) {
type passwordAuthMsg struct {
User string
Service string
func ClientAuthPassword(impl ClientPassword) ClientAuth {
return &passwordAuth{impl}
}
+
+// ClientKeyring implements access to a client key ring.
+type ClientKeyring interface {
+ // Key returns the i'th rsa.Publickey or dsa.Publickey, or nil if
+ // no key exists at i.
+ Key(i int) (key interface{}, err error)
+
+ // Sign returns a signature of the given data using the i'th key
+ // and the supplied random source.
+ Sign(i int, rand io.Reader, data []byte) (sig []byte, err error)
+}
+
+// "publickey" authentication, RFC 4252 Section 7.
+type publickeyAuth struct {
+ ClientKeyring
+}
+
+func (p *publickeyAuth) auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error) {
+ type publickeyAuthMsg struct {
+ User string
+ Service string
+ Method string
+ // HasSig indicates to the reciver packet that the auth request is signed and
+ // should be used for authentication of the request.
+ HasSig bool
+ Algoname string
+ Pubkey string
+ // Sig is defined as []byte so marshal will exclude it during the query phase
+ Sig []byte `ssh:"rest"`
+ }
+
+ // Authentication is performed in two stages. The first stage sends an
+ // enquiry to test if each key is acceptable to the remote. The second
+ // stage attempts to authenticate with the valid keys obtained in the
+ // first stage.
+
+ var index int
+ // a map of public keys to their index in the keyring
+ validKeys := make(map[int]interface{})
+ for {
+ key, err := p.Key(index)
+ if err != nil {
+ return false, nil, err
+ }
+ if key == nil {
+ // no more keys in the keyring
+ break
+ }
+ pubkey := serializePublickey(key)
+ algoname := algoName(key)
+ msg := publickeyAuthMsg{
+ User: user,
+ Service: serviceSSH,
+ Method: p.method(),
+ HasSig: false,
+ Algoname: algoname,
+ Pubkey: string(pubkey),
+ }
+ if err := t.writePacket(marshal(msgUserAuthRequest, msg)); err != nil {
+ return false, nil, err
+ }
+ packet, err := t.readPacket()
+ if err != nil {
+ return false, nil, err
+ }
+ switch packet[0] {
+ case msgUserAuthPubKeyOk:
+ msg := decode(packet).(*userAuthPubKeyOkMsg)
+ if msg.Algo != algoname || msg.PubKey != string(pubkey) {
+ continue
+ }
+ validKeys[index] = key
+ case msgUserAuthFailure:
+ default:
+ return false, nil, UnexpectedMessageError{msgUserAuthSuccess, packet[0]}
+ }
+ index++
+ }
+
+ // methods that may continue if this auth is not successful.
+ var methods []string
+ for i, key := range validKeys {
+ pubkey := serializePublickey(key)
+ algoname := algoName(key)
+ sign, err := p.Sign(i, rand, buildDataSignedForAuth(session, userAuthRequestMsg{
+ User: user,
+ Service: serviceSSH,
+ Method: p.method(),
+ }, []byte(algoname), pubkey))
+ if err != nil {
+ return false, nil, err
+ }
+ // manually wrap the serialized signature in a string
+ s := serializeSignature(algoname, sign)
+ sig := make([]byte, stringLength(s))
+ marshalString(sig, s)
+ msg := publickeyAuthMsg{
+ User: user,
+ Service: serviceSSH,
+ Method: p.method(),
+ HasSig: true,
+ Algoname: algoname,
+ Pubkey: string(pubkey),
+ Sig: sig,
+ }
+ p := marshal(msgUserAuthRequest, msg)
+ if err := t.writePacket(p); err != nil {
+ return false, nil, err
+ }
+ packet, err := t.readPacket()
+ if err != nil {
+ return false, nil, err
+ }
+ switch packet[0] {
+ case msgUserAuthSuccess:
+ return true, nil, nil
+ case msgUserAuthFailure:
+ msg := decode(packet).(*userAuthFailureMsg)
+ methods = msg.Methods
+ continue
+ case msgDisconnect:
+ return false, nil, io.EOF
+ default:
+ return false, nil, UnexpectedMessageError{msgUserAuthSuccess, packet[0]}
+ }
+ }
+ return false, methods, nil
+}
+
+func (p *publickeyAuth) method() string {
+ return "publickey"
+}
+
+// ClientAuthPublickey returns a ClientAuth using public key authentication.
+func ClientAuthPublickey(impl ClientKeyring) ClientAuth {
+ return &publickeyAuth{impl}
+}
--- /dev/null
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+ "bytes"
+ "crypto"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/x509"
+ "encoding/pem"
+ "errors"
+ "io"
+ "io/ioutil"
+ "testing"
+)
+
+const _pem = `-----BEGIN RSA PRIVATE KEY-----
+MIIEpAIBAAKCAQEA19lGVsTqIT5iiNYRgnoY1CwkbETW5cq+Rzk5v/kTlf31XpSU
+70HVWkbTERECjaYdXM2gGcbb+sxpq6GtXf1M3kVomycqhxwhPv4Cr6Xp4WT/jkFx
+9z+FFzpeodGJWjOH6L2H5uX1Cvr9EDdQp9t9/J32/qBFntY8GwoUI/y/1MSTmMiF
+tupdMODN064vd3gyMKTwrlQ8tZM6aYuyOPsutLlUY7M5x5FwMDYvnPDSeyT/Iw0z
+s3B+NCyqeeMd2T7YzQFnRATj0M7rM5LoSs7DVqVriOEABssFyLj31PboaoLhOKgc
+qoM9khkNzr7FHVvi+DhYM2jD0DwvqZLN6NmnLwIDAQABAoIBAQCGVj+kuSFOV1lT
++IclQYA6bM6uY5mroqcSBNegVxCNhWU03BxlW//BE9tA/+kq53vWylMeN9mpGZea
+riEMIh25KFGWXqXlOOioH8bkMsqA8S7sBmc7jljyv+0toQ9vCCtJ+sueNPhxQQxH
+D2YvUjfzBQ04I9+wn30BByDJ1QA/FoPsunxIOUCcRBE/7jxuLYcpR+JvEF68yYIh
+atXRld4W4in7T65YDR8jK1Uj9XAcNeDYNpT/M6oFLx1aPIlkG86aCWRO19S1jLPT
+b1ZAKHHxPMCVkSYW0RqvIgLXQOR62D0Zne6/2wtzJkk5UCjkSQ2z7ZzJpMkWgDgN
+ifCULFPBAoGBAPoMZ5q1w+zB+knXUD33n1J+niN6TZHJulpf2w5zsW+m2K6Zn62M
+MXndXlVAHtk6p02q9kxHdgov34Uo8VpuNjbS1+abGFTI8NZgFo+bsDxJdItemwC4
+KJ7L1iz39hRN/ZylMRLz5uTYRGddCkeIHhiG2h7zohH/MaYzUacXEEy3AoGBANz8
+e/msleB+iXC0cXKwds26N4hyMdAFE5qAqJXvV3S2W8JZnmU+sS7vPAWMYPlERPk1
+D8Q2eXqdPIkAWBhrx4RxD7rNc5qFNcQWEhCIxC9fccluH1y5g2M+4jpMX2CT8Uv+
+3z+NoJ5uDTXZTnLCfoZzgZ4nCZVZ+6iU5U1+YXFJAoGBANLPpIV920n/nJmmquMj
+orI1R/QXR9Cy56cMC65agezlGOfTYxk5Cfl5Ve+/2IJCfgzwJyjWUsFx7RviEeGw
+64o7JoUom1HX+5xxdHPsyZ96OoTJ5RqtKKoApnhRMamau0fWydH1yeOEJd+TRHhc
+XStGfhz8QNa1dVFvENczja1vAoGABGWhsd4VPVpHMc7lUvrf4kgKQtTC2PjA4xoc
+QJ96hf/642sVE76jl+N6tkGMzGjnVm4P2j+bOy1VvwQavKGoXqJBRd5Apppv727g
+/SM7hBXKFc/zH80xKBBgP/i1DR7kdjakCoeu4ngeGywvu2jTS6mQsqzkK+yWbUxJ
+I7mYBsECgYB/KNXlTEpXtz/kwWCHFSYA8U74l7zZbVD8ul0e56JDK+lLcJ0tJffk
+gqnBycHj6AhEycjda75cs+0zybZvN4x65KZHOGW/O/7OAWEcZP5TPb3zf9ned3Hl
+NsZoFj52ponUM6+99A2CmezFCN16c4mbA//luWF+k3VVqR6BpkrhKw==
+-----END RSA PRIVATE KEY-----`
+
+// reused internally by tests
+var serverConfig = new(ServerConfig)
+
+func init() {
+ if err := serverConfig.SetRSAPrivateKey([]byte(_pem)); err != nil {
+ panic("unable to set private key: " + err.Error())
+ }
+}
+
+// keychain implements the ClientPublickey interface
+type keychain struct {
+ keys []*rsa.PrivateKey
+}
+
+func (k *keychain) Key(i int) (interface{}, error) {
+ if i < 0 || i >= len(k.keys) {
+ return nil, nil
+ }
+ return k.keys[i].PublicKey, nil
+}
+
+func (k *keychain) Sign(i int, rand io.Reader, data []byte) (sig []byte, err error) {
+ hashFunc := crypto.SHA1
+ h := hashFunc.New()
+ h.Write(data)
+ digest := h.Sum()
+ return rsa.SignPKCS1v15(rand, k.keys[i], hashFunc, digest)
+}
+
+func (k *keychain) loadPEM(file string) error {
+ buf, err := ioutil.ReadFile(file)
+ if err != nil {
+ return err
+ }
+ block, _ := pem.Decode(buf)
+ if block == nil {
+ return errors.New("ssh: no key found")
+ }
+ r, err := x509.ParsePKCS1PrivateKey(block.Bytes)
+ if err != nil {
+ return err
+ }
+ k.keys = append(k.keys, r)
+ return nil
+}
+
+var pkey *rsa.PrivateKey
+
+func init() {
+ var err error
+ pkey, err = rsa.GenerateKey(rand.Reader, 512)
+ if err != nil {
+ panic("unable to generate public key")
+ }
+}
+
+func TestClientAuthPublickey(t *testing.T) {
+ k := new(keychain)
+ k.keys = append(k.keys, pkey)
+
+ serverConfig.PubKeyCallback = func(user, algo string, pubkey []byte) bool {
+ expected := []byte(serializePublickey(k.keys[0].PublicKey))
+ algoname := algoName(k.keys[0].PublicKey)
+ return user == "testuser" && algo == algoname && bytes.Equal(pubkey, expected)
+ }
+ serverConfig.PasswordCallback = nil
+
+ l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
+ if err != nil {
+ t.Fatalf("unable to listen: %s", err)
+ }
+ defer l.Close()
+
+ done := make(chan bool, 1)
+ go func() {
+ c, err := l.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ if err := c.Handshake(); err != nil {
+ t.Error(err)
+ }
+ done <- true
+ }()
+
+ config := &ClientConfig{
+ User: "testuser",
+ Auth: []ClientAuth{
+ ClientAuthPublickey(k),
+ },
+ }
+
+ c, err := Dial("tcp", l.Addr().String(), config)
+ if err != nil {
+ t.Fatalf("unable to dial remote side: %s", err)
+ }
+ defer c.Close()
+ <-done
+}
+
+// password implements the ClientPassword interface
+type password string
+
+func (p password) Password(user string) (string, error) {
+ return string(p), nil
+}
+
+func TestClientAuthPassword(t *testing.T) {
+ pw := password("tiger")
+
+ serverConfig.PasswordCallback = func(user, pass string) bool {
+ return user == "testuser" && pass == string(pw)
+ }
+ serverConfig.PubKeyCallback = nil
+
+ l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
+ if err != nil {
+ t.Fatalf("unable to listen: %s", err)
+ }
+ defer l.Close()
+
+ done := make(chan bool)
+ go func() {
+ c, err := l.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err := c.Handshake(); err != nil {
+ t.Error(err)
+ }
+ defer c.Close()
+ done <- true
+ }()
+
+ config := &ClientConfig{
+ User: "testuser",
+ Auth: []ClientAuth{
+ ClientAuthPassword(pw),
+ },
+ }
+
+ c, err := Dial("tcp", l.Addr().String(), config)
+ if err != nil {
+ t.Fatalf("unable to dial remote side: %s", err)
+ }
+ defer c.Close()
+ <-done
+}
+
+func TestClientAuthPasswordAndPublickey(t *testing.T) {
+ pw := password("tiger")
+
+ serverConfig.PasswordCallback = func(user, pass string) bool {
+ return user == "testuser" && pass == string(pw)
+ }
+
+ k := new(keychain)
+ k.keys = append(k.keys, pkey)
+
+ serverConfig.PubKeyCallback = func(user, algo string, pubkey []byte) bool {
+ expected := []byte(serializePublickey(k.keys[0].PublicKey))
+ algoname := algoName(k.keys[0].PublicKey)
+ return user == "testuser" && algo == algoname && bytes.Equal(pubkey, expected)
+ }
+
+ l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
+ if err != nil {
+ t.Fatalf("unable to listen: %s", err)
+ }
+ defer l.Close()
+
+ done := make(chan bool)
+ go func() {
+ c, err := l.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err := c.Handshake(); err != nil {
+ t.Error(err)
+ }
+ defer c.Close()
+ done <- true
+ }()
+
+ wrongPw := password("wrong")
+ config := &ClientConfig{
+ User: "testuser",
+ Auth: []ClientAuth{
+ ClientAuthPassword(wrongPw),
+ ClientAuthPublickey(k),
+ },
+ }
+
+ c, err := Dial("tcp", l.Addr().String(), config)
+ if err != nil {
+ t.Fatalf("unable to dial remote side: %s", err)
+ }
+ defer c.Close()
+ <-done
+}
--- /dev/null
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+// ClientConn functional tests.
+// These tests require a running ssh server listening on port 22
+// on the local host. Functional tests will be skipped unless
+// -ssh.user and -ssh.pass must be passed to gotest.
+
+import (
+ "flag"
+ "testing"
+)
+
+var (
+ sshuser = flag.String("ssh.user", "", "ssh username")
+ sshpass = flag.String("ssh.pass", "", "ssh password")
+ sshprivkey = flag.String("ssh.privkey", "", "ssh privkey file")
+)
+
+func TestFuncPasswordAuth(t *testing.T) {
+ if *sshuser == "" {
+ t.Log("ssh.user not defined, skipping test")
+ return
+ }
+ config := &ClientConfig{
+ User: *sshuser,
+ Auth: []ClientAuth{
+ ClientAuthPassword(password(*sshpass)),
+ },
+ }
+ conn, err := Dial("tcp", "localhost:22", config)
+ if err != nil {
+ t.Fatalf("Unable to connect: %s", err)
+ }
+ defer conn.Close()
+}
+
+func TestFuncPublickeyAuth(t *testing.T) {
+ if *sshuser == "" {
+ t.Log("ssh.user not defined, skipping test")
+ return
+ }
+ kc := new(keychain)
+ if err := kc.loadPEM(*sshprivkey); err != nil {
+ t.Fatalf("unable to load private key: %s", err)
+ }
+ config := &ClientConfig{
+ User: *sshuser,
+ Auth: []ClientAuth{
+ ClientAuthPublickey(kc),
+ },
+ }
+ conn, err := Dial("tcp", "localhost:22", config)
+ if err != nil {
+ t.Fatalf("unable to connect: %s", err)
+ }
+ defer conn.Close()
+}
package ssh
import (
+ "crypto/dsa"
+ "crypto/rsa"
"math/big"
"strconv"
"sync"
const (
kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1"
hostAlgoRSA = "ssh-rsa"
- cipherAES128CTR = "aes128-ctr"
macSHA196 = "hmac-sha1-96"
compressionNone = "none"
serviceUserAuth = "ssh-userauth"
var supportedKexAlgos = []string{kexAlgoDH14SHA1}
var supportedHostKeyAlgos = []string{hostAlgoRSA}
-var supportedCiphers = []string{cipherAES128CTR}
var supportedMACs = []string{macSHA196}
var supportedCompressions = []string{compressionNone}
ok = true
return
}
+
+// Cryptographic configuration common to both ServerConfig and ClientConfig.
+type CryptoConfig struct {
+ // The allowed cipher algorithms. If unspecified then DefaultCipherOrder is
+ // used.
+ Ciphers []string
+}
+
+func (c *CryptoConfig) ciphers() []string {
+ if c.Ciphers == nil {
+ return DefaultCipherOrder
+ }
+ return c.Ciphers
+}
+
+// serialize a signed slice according to RFC 4254 6.6.
+func serializeSignature(algoname string, sig []byte) []byte {
+ length := stringLength([]byte(algoname))
+ length += stringLength(sig)
+
+ ret := make([]byte, length)
+ r := marshalString(ret, []byte(algoname))
+ r = marshalString(r, sig)
+
+ return ret
+}
+
+// serialize an rsa.PublicKey or dsa.PublicKey according to RFC 4253 6.6.
+func serializePublickey(key interface{}) []byte {
+ algoname := algoName(key)
+ switch key := key.(type) {
+ case rsa.PublicKey:
+ e := new(big.Int).SetInt64(int64(key.E))
+ length := stringLength([]byte(algoname))
+ length += intLength(e)
+ length += intLength(key.N)
+ ret := make([]byte, length)
+ r := marshalString(ret, []byte(algoname))
+ r = marshalInt(r, e)
+ marshalInt(r, key.N)
+ return ret
+ case dsa.PublicKey:
+ length := stringLength([]byte(algoname))
+ length += intLength(key.P)
+ length += intLength(key.Q)
+ length += intLength(key.G)
+ length += intLength(key.Y)
+ ret := make([]byte, length)
+ r := marshalString(ret, []byte(algoname))
+ r = marshalInt(r, key.P)
+ r = marshalInt(r, key.Q)
+ r = marshalInt(r, key.G)
+ marshalInt(r, key.Y)
+ return ret
+ }
+ panic("unexpected key type")
+}
+
+func algoName(key interface{}) string {
+ switch key.(type) {
+ case rsa.PublicKey:
+ return "ssh-rsa"
+ case dsa.PublicKey:
+ return "ssh-dss"
+ }
+ panic("unexpected key type")
+}
+
+// buildDataSignedForAuth returns the data that is signed in order to prove
+// posession of a private key. See RFC 4252, section 7.
+func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
+ user := []byte(req.User)
+ service := []byte(req.Service)
+ method := []byte(req.Method)
+
+ length := stringLength(sessionId)
+ length += 1
+ length += stringLength(user)
+ length += stringLength(service)
+ length += stringLength(method)
+ length += 1
+ length += stringLength(algo)
+ length += stringLength(pubKey)
+
+ ret := make([]byte, length)
+ r := marshalString(ret, sessionId)
+ r[0] = msgUserAuthRequest
+ r = r[1:]
+ r = marshalString(r, user)
+ r = marshalString(r, service)
+ r = marshalString(r, method)
+ r[0] = 1
+ r = r[1:]
+ r = marshalString(r, algo)
+ r = marshalString(r, pubKey)
+ return ret
+}
return
}
-var comma = []byte{','}
+var (
+ comma = []byte{','}
+ emptyNameList = []string{}
+)
func parseNameList(in []byte) (out []string, rest []byte, ok bool) {
contents, rest, ok := parseString(in)
return
}
if len(contents) == 0 {
+ out = emptyNameList
return
}
parts := bytes.Split(contents, comma)
return
}
-const maxPacketSize = 36000
-
func nameListLength(namelist []string) int {
length := 4 /* uint32 length prefix */
for i, name := range namelist {
// key authentication. It must return true iff the given public key is
// valid for the given user.
PubKeyCallback func(user, algo string, pubkey []byte) bool
+
+ // Cryptographic-related configuration.
+ Crypto CryptoConfig
}
func (c *ServerConfig) rand() io.Reader {
return nil, nil, errors.New("internal error")
}
- serializedSig := serializeRSASignature(sig)
+ serializedSig := serializeSignature(hostAlgoRSA, sig)
kexDHReply := kexDHReplyMsg{
HostKey: serializedHostKey,
return
}
-func serializeRSASignature(sig []byte) []byte {
- length := stringLength([]byte(hostAlgoRSA))
- length += stringLength(sig)
-
- ret := make([]byte, length)
- r := marshalString(ret, []byte(hostAlgoRSA))
- r = marshalString(r, sig)
-
- return ret
-}
-
// serverVersion is the fixed identification string that Server will use.
var serverVersion = []byte("SSH-2.0-Go\r\n")
-// buildDataSignedForAuth returns the data that is signed in order to prove
-// posession of a private key. See RFC 4252, section 7.
-func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
- user := []byte(req.User)
- service := []byte(req.Service)
- method := []byte(req.Method)
-
- length := stringLength(sessionId)
- length += 1
- length += stringLength(user)
- length += stringLength(service)
- length += stringLength(method)
- length += 1
- length += stringLength(algo)
- length += stringLength(pubKey)
-
- ret := make([]byte, length)
- r := marshalString(ret, sessionId)
- r[0] = msgUserAuthRequest
- r = r[1:]
- r = marshalString(r, user)
- r = marshalString(r, service)
- r = marshalString(r, method)
- r[0] = 1
- r = r[1:]
- r = marshalString(r, algo)
- r = marshalString(r, pubKey)
- return ret
-}
-
// Handshake performs an SSH transport and client authentication on the given ServerConn.
func (s *ServerConn) Handshake() error {
var magics handshakeMagics
serverKexInit := kexInitMsg{
KexAlgos: supportedKexAlgos,
ServerHostKeyAlgos: supportedHostKeyAlgos,
- CiphersClientServer: supportedCiphers,
- CiphersServerClient: supportedCiphers,
+ CiphersClientServer: s.config.Crypto.ciphers(),
+ CiphersServerClient: s.config.Crypto.ciphers(),
MACsClientServer: supportedMACs,
MACsServerClient: supportedMACs,
CompressionClientServer: supportedCompressions,
if packet[0] != msgNewKeys {
return UnexpectedMessageError{msgNewKeys, packet[0]}
}
- s.transport.reader.setupKeys(clientKeys, K, H, H, hashFunc)
+ if err = s.transport.reader.setupKeys(clientKeys, K, H, H, hashFunc); err != nil {
+ return err
+ }
if packet, err = s.readPacket(); err != nil {
return err
}
--- /dev/null
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+ "errors"
+ "io"
+ "net"
+)
+// Dial initiates a connection to the addr from the remote host.
+// addr is resolved using net.ResolveTCPAddr before connection.
+// This could allow an observer to observe the DNS name of the
+// remote host. Consider using ssh.DialTCP to avoid this.
+func (c *ClientConn) Dial(n, addr string) (net.Conn, error) {
+ raddr, err := net.ResolveTCPAddr(n, addr)
+ if err != nil {
+ return nil, err
+ }
+ return c.DialTCP(n, nil, raddr)
+}
+
+// DialTCP connects to the remote address raddr on the network net,
+// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used
+// as the local address for the connection.
+func (c *ClientConn) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) {
+ if laddr == nil {
+ laddr = &net.TCPAddr{
+ IP: net.IPv4zero,
+ Port: 0,
+ }
+ }
+ ch, err := c.dial(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port)
+ if err != nil {
+ return nil, err
+ }
+ return &tcpchanconn{
+ tcpchan: ch,
+ laddr: laddr,
+ raddr: raddr,
+ }, nil
+}
+
+// dial opens a direct-tcpip connection to the remote server. laddr and raddr are passed as
+// strings and are expected to be resolveable at the remote end.
+func (c *ClientConn) dial(laddr string, lport int, raddr string, rport int) (*tcpchan, error) {
+ // RFC 4254 7.2
+ type channelOpenDirectMsg struct {
+ ChanType string
+ PeersId uint32
+ PeersWindow uint32
+ MaxPacketSize uint32
+ raddr string
+ rport uint32
+ laddr string
+ lport uint32
+ }
+ ch := c.newChan(c.transport)
+ if err := c.writePacket(marshal(msgChannelOpen, channelOpenDirectMsg{
+ ChanType: "direct-tcpip",
+ PeersId: ch.id,
+ PeersWindow: 1 << 14,
+ MaxPacketSize: 1 << 15, // RFC 4253 6.1
+ raddr: raddr,
+ rport: uint32(rport),
+ laddr: laddr,
+ lport: uint32(lport),
+ })); err != nil {
+ c.chanlist.remove(ch.id)
+ return nil, err
+ }
+ // wait for response
+ switch msg := (<-ch.msg).(type) {
+ case *channelOpenConfirmMsg:
+ ch.peersId = msg.MyId
+ ch.win <- int(msg.MyWindow)
+ case *channelOpenFailureMsg:
+ c.chanlist.remove(ch.id)
+ return nil, errors.New("ssh: error opening remote TCP connection: " + msg.Message)
+ default:
+ c.chanlist.remove(ch.id)
+ return nil, errors.New("ssh: unexpected packet")
+ }
+ return &tcpchan{
+ clientChan: ch,
+ Reader: &chanReader{
+ packetWriter: ch,
+ id: ch.id,
+ data: ch.data,
+ },
+ Writer: &chanWriter{
+ packetWriter: ch,
+ id: ch.id,
+ win: ch.win,
+ },
+ }, nil
+}
+
+type tcpchan struct {
+ *clientChan // the backing channel
+ io.Reader
+ io.Writer
+}
+
+// tcpchanconn fulfills the net.Conn interface without
+// the tcpchan having to hold laddr or raddr directly.
+type tcpchanconn struct {
+ *tcpchan
+ laddr, raddr net.Addr
+}
+
+// LocalAddr returns the local network address.
+func (t *tcpchanconn) LocalAddr() net.Addr {
+ return t.laddr
+}
+
+// RemoteAddr returns the remote network address.
+func (t *tcpchanconn) RemoteAddr() net.Addr {
+ return t.raddr
+}
+
+// SetTimeout sets the read and write deadlines associated
+// with the connection.
+func (t *tcpchanconn) SetTimeout(nsec int64) error {
+ if err := t.SetReadTimeout(nsec); err != nil {
+ return err
+ }
+ return t.SetWriteTimeout(nsec)
+}
+
+// SetReadTimeout sets the time (in nanoseconds) that
+// Read will wait for data before returning an error with Timeout() == true.
+// Setting nsec == 0 (the default) disables the deadline.
+func (t *tcpchanconn) SetReadTimeout(nsec int64) error {
+ return errors.New("ssh: tcpchan: timeout not supported")
+}
+
+// SetWriteTimeout sets the time (in nanoseconds) that
+// Write will wait to send its data before returning an error with Timeout() == true.
+// Setting nsec == 0 (the default) disables the deadline.
+// Even if write times out, it may return n > 0, indicating that
+// some of the data was successfully written.
+func (t *tcpchanconn) SetWriteTimeout(nsec int64) error {
+ return errors.New("ssh: tcpchan: timeout not supported")
+}
import (
"bufio"
"crypto"
- "crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/subtle"
)
const (
- paddingMultiple = 16 // TODO(dfc) does this need to be configurable?
+ packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher.
+ minPacketSize = 16
+ maxPacketSize = 36000
+ minPaddingSize = 4 // TODO(huin) should this be configurable?
)
// filteredConn reduces the set of methods exposed when embeddeding
type writer struct {
*sync.Mutex // protects writer.Writer from concurrent writes
*bufio.Writer
- paddingMultiple int
- rand io.Reader
+ rand io.Reader
common
}
func (r *reader) readOnePacket() ([]byte, error) {
var lengthBytes = make([]byte, 5)
var macSize uint32
-
if _, err := io.ReadFull(r, lengthBytes); err != nil {
return nil, err
}
- if r.cipher != nil {
- r.cipher.XORKeyStream(lengthBytes, lengthBytes)
- }
+ r.cipher.XORKeyStream(lengthBytes, lengthBytes)
if r.mac != nil {
r.mac.Reset()
w.Mutex.Lock()
defer w.Mutex.Unlock()
- paddingLength := paddingMultiple - (5+len(packet))%paddingMultiple
+ paddingLength := packetSizeMultiple - (5+len(packet))%packetSizeMultiple
if paddingLength < 4 {
- paddingLength += paddingMultiple
+ paddingLength += packetSizeMultiple
}
length := len(packet) + 1 + paddingLength
// TODO(dfc) lengthBytes, packet and padding should be
// subslices of a single buffer
- if w.cipher != nil {
- w.cipher.XORKeyStream(lengthBytes, lengthBytes)
- w.cipher.XORKeyStream(packet, packet)
- w.cipher.XORKeyStream(padding, padding)
- }
+ w.cipher.XORKeyStream(lengthBytes, lengthBytes)
+ w.cipher.XORKeyStream(packet, packet)
+ w.cipher.XORKeyStream(padding, padding)
if _, err := w.Write(lengthBytes); err != nil {
return err
return &transport{
reader: reader{
Reader: bufio.NewReader(conn),
+ common: common{
+ cipher: noneCipher{},
+ },
},
writer: writer{
Writer: bufio.NewWriter(conn),
rand: rand,
Mutex: new(sync.Mutex),
+ common: common{
+ cipher: noneCipher{},
+ },
},
filteredConn: conn,
}
clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}}
)
-// setupKeys sets the cipher and MAC keys from K, H and sessionId, as
+// setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
// described in RFC 4253, section 6.4. direction should either be serverKeys
// (to setup server->client keys) or clientKeys (for client->server keys).
func (c *common) setupKeys(d direction, K, H, sessionId []byte, hashFunc crypto.Hash) error {
- h := hashFunc.New()
+ cipherMode := cipherModes[c.cipherAlgo]
- blockSize := 16
- keySize := 16
macKeySize := 20
- iv := make([]byte, blockSize)
- key := make([]byte, keySize)
+ iv := make([]byte, cipherMode.ivSize)
+ key := make([]byte, cipherMode.keySize)
macKey := make([]byte, macKeySize)
+
+ h := hashFunc.New()
generateKeyMaterial(iv, d.ivTag, K, H, sessionId, h)
generateKeyMaterial(key, d.keyTag, K, H, sessionId, h)
generateKeyMaterial(macKey, d.macKeyTag, K, H, sessionId, h)
c.mac = truncatingMAC{12, hmac.NewSHA1(macKey)}
- aes, err := aes.NewCipher(key)
+
+ cipher, err := cipherMode.createCipher(key, iv)
if err != nil {
return err
}
- c.cipher = cipher.NewCTR(aes, iv)
+
+ c.cipher = cipher
+
return nil
}
+++ /dev/null
-// Copyright 2011 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package terminal
-
-import "io"
-
-// Shell contains the state for running a VT100 terminal that is capable of
-// reading lines of input.
-type Shell struct {
- c io.ReadWriter
- prompt string
-
- // line is the current line being entered.
- line []byte
- // pos is the logical position of the cursor in line
- pos int
-
- // cursorX contains the current X value of the cursor where the left
- // edge is 0. cursorY contains the row number where the first row of
- // the current line is 0.
- cursorX, cursorY int
- // maxLine is the greatest value of cursorY so far.
- maxLine int
-
- termWidth, termHeight int
-
- // outBuf contains the terminal data to be sent.
- outBuf []byte
- // remainder contains the remainder of any partial key sequences after
- // a read. It aliases into inBuf.
- remainder []byte
- inBuf [256]byte
-}
-
-// NewShell runs a VT100 terminal on the given ReadWriter. If the ReadWriter is
-// a local terminal, that terminal must first have been put into raw mode.
-// prompt is a string that is written at the start of each input line (i.e.
-// "> ").
-func NewShell(c io.ReadWriter, prompt string) *Shell {
- return &Shell{
- c: c,
- prompt: prompt,
- termWidth: 80,
- termHeight: 24,
- }
-}
-
-const (
- keyCtrlD = 4
- keyEnter = '\r'
- keyEscape = 27
- keyBackspace = 127
- keyUnknown = 256 + iota
- keyUp
- keyDown
- keyLeft
- keyRight
- keyAltLeft
- keyAltRight
-)
-
-// bytesToKey tries to parse a key sequence from b. If successful, it returns
-// the key and the remainder of the input. Otherwise it returns -1.
-func bytesToKey(b []byte) (int, []byte) {
- if len(b) == 0 {
- return -1, nil
- }
-
- if b[0] != keyEscape {
- return int(b[0]), b[1:]
- }
-
- if len(b) >= 3 && b[0] == keyEscape && b[1] == '[' {
- switch b[2] {
- case 'A':
- return keyUp, b[3:]
- case 'B':
- return keyDown, b[3:]
- case 'C':
- return keyRight, b[3:]
- case 'D':
- return keyLeft, b[3:]
- }
- }
-
- if len(b) >= 6 && b[0] == keyEscape && b[1] == '[' && b[2] == '1' && b[3] == ';' && b[4] == '3' {
- switch b[5] {
- case 'C':
- return keyAltRight, b[6:]
- case 'D':
- return keyAltLeft, b[6:]
- }
- }
-
- // If we get here then we have a key that we don't recognise, or a
- // partial sequence. It's not clear how one should find the end of a
- // sequence without knowing them all, but it seems that [a-zA-Z] only
- // appears at the end of a sequence.
- for i, c := range b[0:] {
- if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' {
- return keyUnknown, b[i+1:]
- }
- }
-
- return -1, b
-}
-
-// queue appends data to the end of ss.outBuf
-func (ss *Shell) queue(data []byte) {
- if len(ss.outBuf)+len(data) > cap(ss.outBuf) {
- newOutBuf := make([]byte, len(ss.outBuf), 2*(len(ss.outBuf)+len(data)))
- copy(newOutBuf, ss.outBuf)
- ss.outBuf = newOutBuf
- }
-
- oldLen := len(ss.outBuf)
- ss.outBuf = ss.outBuf[:len(ss.outBuf)+len(data)]
- copy(ss.outBuf[oldLen:], data)
-}
-
-var eraseUnderCursor = []byte{' ', keyEscape, '[', 'D'}
-
-func isPrintable(key int) bool {
- return key >= 32 && key < 127
-}
-
-// moveCursorToPos appends data to ss.outBuf which will move the cursor to the
-// given, logical position in the text.
-func (ss *Shell) moveCursorToPos(pos int) {
- x := len(ss.prompt) + pos
- y := x / ss.termWidth
- x = x % ss.termWidth
-
- up := 0
- if y < ss.cursorY {
- up = ss.cursorY - y
- }
-
- down := 0
- if y > ss.cursorY {
- down = y - ss.cursorY
- }
-
- left := 0
- if x < ss.cursorX {
- left = ss.cursorX - x
- }
-
- right := 0
- if x > ss.cursorX {
- right = x - ss.cursorX
- }
-
- movement := make([]byte, 3*(up+down+left+right))
- m := movement
- for i := 0; i < up; i++ {
- m[0] = keyEscape
- m[1] = '['
- m[2] = 'A'
- m = m[3:]
- }
- for i := 0; i < down; i++ {
- m[0] = keyEscape
- m[1] = '['
- m[2] = 'B'
- m = m[3:]
- }
- for i := 0; i < left; i++ {
- m[0] = keyEscape
- m[1] = '['
- m[2] = 'D'
- m = m[3:]
- }
- for i := 0; i < right; i++ {
- m[0] = keyEscape
- m[1] = '['
- m[2] = 'C'
- m = m[3:]
- }
-
- ss.cursorX = x
- ss.cursorY = y
- ss.queue(movement)
-}
-
-const maxLineLength = 4096
-
-// handleKey processes the given key and, optionally, returns a line of text
-// that the user has entered.
-func (ss *Shell) handleKey(key int) (line string, ok bool) {
- switch key {
- case keyBackspace:
- if ss.pos == 0 {
- return
- }
- ss.pos--
-
- copy(ss.line[ss.pos:], ss.line[1+ss.pos:])
- ss.line = ss.line[:len(ss.line)-1]
- ss.writeLine(ss.line[ss.pos:])
- ss.moveCursorToPos(ss.pos)
- ss.queue(eraseUnderCursor)
- case keyAltLeft:
- // move left by a word.
- if ss.pos == 0 {
- return
- }
- ss.pos--
- for ss.pos > 0 {
- if ss.line[ss.pos] != ' ' {
- break
- }
- ss.pos--
- }
- for ss.pos > 0 {
- if ss.line[ss.pos] == ' ' {
- ss.pos++
- break
- }
- ss.pos--
- }
- ss.moveCursorToPos(ss.pos)
- case keyAltRight:
- // move right by a word.
- for ss.pos < len(ss.line) {
- if ss.line[ss.pos] == ' ' {
- break
- }
- ss.pos++
- }
- for ss.pos < len(ss.line) {
- if ss.line[ss.pos] != ' ' {
- break
- }
- ss.pos++
- }
- ss.moveCursorToPos(ss.pos)
- case keyLeft:
- if ss.pos == 0 {
- return
- }
- ss.pos--
- ss.moveCursorToPos(ss.pos)
- case keyRight:
- if ss.pos == len(ss.line) {
- return
- }
- ss.pos++
- ss.moveCursorToPos(ss.pos)
- case keyEnter:
- ss.moveCursorToPos(len(ss.line))
- ss.queue([]byte("\r\n"))
- line = string(ss.line)
- ok = true
- ss.line = ss.line[:0]
- ss.pos = 0
- ss.cursorX = 0
- ss.cursorY = 0
- ss.maxLine = 0
- default:
- if !isPrintable(key) {
- return
- }
- if len(ss.line) == maxLineLength {
- return
- }
- if len(ss.line) == cap(ss.line) {
- newLine := make([]byte, len(ss.line), 2*(1+len(ss.line)))
- copy(newLine, ss.line)
- ss.line = newLine
- }
- ss.line = ss.line[:len(ss.line)+1]
- copy(ss.line[ss.pos+1:], ss.line[ss.pos:])
- ss.line[ss.pos] = byte(key)
- ss.writeLine(ss.line[ss.pos:])
- ss.pos++
- ss.moveCursorToPos(ss.pos)
- }
- return
-}
-
-func (ss *Shell) writeLine(line []byte) {
- for len(line) != 0 {
- if ss.cursorX == ss.termWidth {
- ss.queue([]byte("\r\n"))
- ss.cursorX = 0
- ss.cursorY++
- if ss.cursorY > ss.maxLine {
- ss.maxLine = ss.cursorY
- }
- }
-
- remainingOnLine := ss.termWidth - ss.cursorX
- todo := len(line)
- if todo > remainingOnLine {
- todo = remainingOnLine
- }
- ss.queue(line[:todo])
- ss.cursorX += todo
- line = line[todo:]
- }
-}
-
-func (ss *Shell) Write(buf []byte) (n int, err error) {
- return ss.c.Write(buf)
-}
-
-// ReadLine returns a line of input from the terminal.
-func (ss *Shell) ReadLine() (line string, err error) {
- ss.writeLine([]byte(ss.prompt))
- ss.c.Write(ss.outBuf)
- ss.outBuf = ss.outBuf[:0]
-
- for {
- // ss.remainder is a slice at the beginning of ss.inBuf
- // containing a partial key sequence
- readBuf := ss.inBuf[len(ss.remainder):]
- var n int
- n, err = ss.c.Read(readBuf)
- if err != nil {
- return
- }
-
- if err == nil {
- ss.remainder = ss.inBuf[:n+len(ss.remainder)]
- rest := ss.remainder
- lineOk := false
- for !lineOk {
- var key int
- key, rest = bytesToKey(rest)
- if key < 0 {
- break
- }
- if key == keyCtrlD {
- return "", io.EOF
- }
- line, lineOk = ss.handleKey(key)
- }
- if len(rest) > 0 {
- n := copy(ss.inBuf[:], rest)
- ss.remainder = ss.inBuf[:n]
- } else {
- ss.remainder = nil
- }
- ss.c.Write(ss.outBuf)
- ss.outBuf = ss.outBuf[:0]
- if lineOk {
- return
- }
- continue
- }
- }
- panic("unreachable")
-}
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Package terminal provides support functions for dealing with terminals, as
-// commonly found on UNIX systems.
-//
-// Putting a terminal into raw mode is the most common requirement:
-//
-// oldState, err := terminal.MakeRaw(0)
-// if err != nil {
-// panic(err.String())
-// }
-// defer terminal.Restore(0, oldState)
package terminal
-import (
- "io"
- "os"
- "syscall"
-)
+import "io"
+
+// Terminal contains the state for running a VT100 terminal that is capable of
+// reading lines of input.
+type Terminal struct {
+ c io.ReadWriter
+ prompt string
+
+ // line is the current line being entered.
+ line []byte
+ // pos is the logical position of the cursor in line
+ pos int
+
+ // cursorX contains the current X value of the cursor where the left
+ // edge is 0. cursorY contains the row number where the first row of
+ // the current line is 0.
+ cursorX, cursorY int
+ // maxLine is the greatest value of cursorY so far.
+ maxLine int
+
+ termWidth, termHeight int
-// State contains the state of a terminal.
-type State struct {
- termios syscall.Termios
+ // outBuf contains the terminal data to be sent.
+ outBuf []byte
+ // remainder contains the remainder of any partial key sequences after
+ // a read. It aliases into inBuf.
+ remainder []byte
+ inBuf [256]byte
}
-// IsTerminal returns true if the given file descriptor is a terminal.
-func IsTerminal(fd int) bool {
- var termios syscall.Termios
- e := syscall.Tcgetattr(fd, &termios)
- return e == 0
+// NewTerminal runs a VT100 terminal on the given ReadWriter. If the ReadWriter is
+// a local terminal, that terminal must first have been put into raw mode.
+// prompt is a string that is written at the start of each input line (i.e.
+// "> ").
+func NewTerminal(c io.ReadWriter, prompt string) *Terminal {
+ return &Terminal{
+ c: c,
+ prompt: prompt,
+ termWidth: 80,
+ termHeight: 24,
+ }
}
-// MakeRaw put the terminal connected to the given file descriptor into raw
-// mode and returns the previous state of the terminal so that it can be
-// restored.
-func MakeRaw(fd int) (*State, error) {
- var oldState State
- if e := syscall.Tcgetattr(fd, &oldState.termios); e != 0 {
- return nil, os.Errno(e)
+const (
+ keyCtrlD = 4
+ keyEnter = '\r'
+ keyEscape = 27
+ keyBackspace = 127
+ keyUnknown = 256 + iota
+ keyUp
+ keyDown
+ keyLeft
+ keyRight
+ keyAltLeft
+ keyAltRight
+)
+
+// bytesToKey tries to parse a key sequence from b. If successful, it returns
+// the key and the remainder of the input. Otherwise it returns -1.
+func bytesToKey(b []byte) (int, []byte) {
+ if len(b) == 0 {
+ return -1, nil
+ }
+
+ if b[0] != keyEscape {
+ return int(b[0]), b[1:]
+ }
+
+ if len(b) >= 3 && b[0] == keyEscape && b[1] == '[' {
+ switch b[2] {
+ case 'A':
+ return keyUp, b[3:]
+ case 'B':
+ return keyDown, b[3:]
+ case 'C':
+ return keyRight, b[3:]
+ case 'D':
+ return keyLeft, b[3:]
+ }
+ }
+
+ if len(b) >= 6 && b[0] == keyEscape && b[1] == '[' && b[2] == '1' && b[3] == ';' && b[4] == '3' {
+ switch b[5] {
+ case 'C':
+ return keyAltRight, b[6:]
+ case 'D':
+ return keyAltLeft, b[6:]
+ }
+ }
+
+ // If we get here then we have a key that we don't recognise, or a
+ // partial sequence. It's not clear how one should find the end of a
+ // sequence without knowing them all, but it seems that [a-zA-Z] only
+ // appears at the end of a sequence.
+ for i, c := range b[0:] {
+ if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' {
+ return keyUnknown, b[i+1:]
+ }
}
- newState := oldState.termios
- newState.Iflag &^= syscall.ISTRIP | syscall.INLCR | syscall.ICRNL | syscall.IGNCR | syscall.IXON | syscall.IXOFF
- newState.Lflag &^= syscall.ECHO | syscall.ICANON | syscall.ISIG
- if e := syscall.Tcsetattr(fd, syscall.TCSANOW, &newState); e != 0 {
- return nil, os.Errno(e)
+ return -1, b
+}
+
+// queue appends data to the end of t.outBuf
+func (t *Terminal) queue(data []byte) {
+ if len(t.outBuf)+len(data) > cap(t.outBuf) {
+ newOutBuf := make([]byte, len(t.outBuf), 2*(len(t.outBuf)+len(data)))
+ copy(newOutBuf, t.outBuf)
+ t.outBuf = newOutBuf
}
- return &oldState, nil
+ oldLen := len(t.outBuf)
+ t.outBuf = t.outBuf[:len(t.outBuf)+len(data)]
+ copy(t.outBuf[oldLen:], data)
}
-// Restore restores the terminal connected to the given file descriptor to a
-// previous state.
-func Restore(fd int, state *State) error {
- e := syscall.Tcsetattr(fd, syscall.TCSANOW, &state.termios)
- return os.Errno(e)
+var eraseUnderCursor = []byte{' ', keyEscape, '[', 'D'}
+
+func isPrintable(key int) bool {
+ return key >= 32 && key < 127
}
-// ReadPassword reads a line of input from a terminal without local echo. This
-// is commonly used for inputting passwords and other sensitive data. The slice
-// returned does not include the \n.
-func ReadPassword(fd int) ([]byte, error) {
- var oldState syscall.Termios
- if e := syscall.Tcgetattr(fd, &oldState); e != 0 {
- return nil, os.Errno(e)
+// moveCursorToPos appends data to t.outBuf which will move the cursor to the
+// given, logical position in the text.
+func (t *Terminal) moveCursorToPos(pos int) {
+ x := len(t.prompt) + pos
+ y := x / t.termWidth
+ x = x % t.termWidth
+
+ up := 0
+ if y < t.cursorY {
+ up = t.cursorY - y
}
- newState := oldState
- newState.Lflag &^= syscall.ECHO
- if e := syscall.Tcsetattr(fd, syscall.TCSANOW, &newState); e != 0 {
- return nil, os.Errno(e)
+ down := 0
+ if y > t.cursorY {
+ down = y - t.cursorY
}
- defer func() {
- syscall.Tcsetattr(fd, syscall.TCSANOW, &oldState)
- }()
+ left := 0
+ if x < t.cursorX {
+ left = t.cursorX - x
+ }
- var buf [16]byte
- var ret []byte
- for {
- n, errno := syscall.Read(fd, buf[:])
- if errno != 0 {
- return nil, os.Errno(errno)
+ right := 0
+ if x > t.cursorX {
+ right = x - t.cursorX
+ }
+
+ movement := make([]byte, 3*(up+down+left+right))
+ m := movement
+ for i := 0; i < up; i++ {
+ m[0] = keyEscape
+ m[1] = '['
+ m[2] = 'A'
+ m = m[3:]
+ }
+ for i := 0; i < down; i++ {
+ m[0] = keyEscape
+ m[1] = '['
+ m[2] = 'B'
+ m = m[3:]
+ }
+ for i := 0; i < left; i++ {
+ m[0] = keyEscape
+ m[1] = '['
+ m[2] = 'D'
+ m = m[3:]
+ }
+ for i := 0; i < right; i++ {
+ m[0] = keyEscape
+ m[1] = '['
+ m[2] = 'C'
+ m = m[3:]
+ }
+
+ t.cursorX = x
+ t.cursorY = y
+ t.queue(movement)
+}
+
+const maxLineLength = 4096
+
+// handleKey processes the given key and, optionally, returns a line of text
+// that the user has entered.
+func (t *Terminal) handleKey(key int) (line string, ok bool) {
+ switch key {
+ case keyBackspace:
+ if t.pos == 0 {
+ return
+ }
+ t.pos--
+
+ copy(t.line[t.pos:], t.line[1+t.pos:])
+ t.line = t.line[:len(t.line)-1]
+ t.writeLine(t.line[t.pos:])
+ t.moveCursorToPos(t.pos)
+ t.queue(eraseUnderCursor)
+ case keyAltLeft:
+ // move left by a word.
+ if t.pos == 0 {
+ return
+ }
+ t.pos--
+ for t.pos > 0 {
+ if t.line[t.pos] != ' ' {
+ break
+ }
+ t.pos--
+ }
+ for t.pos > 0 {
+ if t.line[t.pos] == ' ' {
+ t.pos++
+ break
+ }
+ t.pos--
+ }
+ t.moveCursorToPos(t.pos)
+ case keyAltRight:
+ // move right by a word.
+ for t.pos < len(t.line) {
+ if t.line[t.pos] == ' ' {
+ break
+ }
+ t.pos++
}
- if n == 0 {
- if len(ret) == 0 {
- return nil, io.EOF
+ for t.pos < len(t.line) {
+ if t.line[t.pos] != ' ' {
+ break
+ }
+ t.pos++
+ }
+ t.moveCursorToPos(t.pos)
+ case keyLeft:
+ if t.pos == 0 {
+ return
+ }
+ t.pos--
+ t.moveCursorToPos(t.pos)
+ case keyRight:
+ if t.pos == len(t.line) {
+ return
+ }
+ t.pos++
+ t.moveCursorToPos(t.pos)
+ case keyEnter:
+ t.moveCursorToPos(len(t.line))
+ t.queue([]byte("\r\n"))
+ line = string(t.line)
+ ok = true
+ t.line = t.line[:0]
+ t.pos = 0
+ t.cursorX = 0
+ t.cursorY = 0
+ t.maxLine = 0
+ default:
+ if !isPrintable(key) {
+ return
+ }
+ if len(t.line) == maxLineLength {
+ return
+ }
+ if len(t.line) == cap(t.line) {
+ newLine := make([]byte, len(t.line), 2*(1+len(t.line)))
+ copy(newLine, t.line)
+ t.line = newLine
+ }
+ t.line = t.line[:len(t.line)+1]
+ copy(t.line[t.pos+1:], t.line[t.pos:])
+ t.line[t.pos] = byte(key)
+ t.writeLine(t.line[t.pos:])
+ t.pos++
+ t.moveCursorToPos(t.pos)
+ }
+ return
+}
+
+func (t *Terminal) writeLine(line []byte) {
+ for len(line) != 0 {
+ if t.cursorX == t.termWidth {
+ t.queue([]byte("\r\n"))
+ t.cursorX = 0
+ t.cursorY++
+ if t.cursorY > t.maxLine {
+ t.maxLine = t.cursorY
}
- break
}
- if buf[n-1] == '\n' {
- n--
+
+ remainingOnLine := t.termWidth - t.cursorX
+ todo := len(line)
+ if todo > remainingOnLine {
+ todo = remainingOnLine
}
- ret = append(ret, buf[:n]...)
- if n < len(buf) {
- break
+ t.queue(line[:todo])
+ t.cursorX += todo
+ line = line[todo:]
+ }
+}
+
+func (t *Terminal) Write(buf []byte) (n int, err error) {
+ return t.c.Write(buf)
+}
+
+// ReadLine returns a line of input from the terminal.
+func (t *Terminal) ReadLine() (line string, err error) {
+ if t.cursorX == 0 {
+ t.writeLine([]byte(t.prompt))
+ t.c.Write(t.outBuf)
+ t.outBuf = t.outBuf[:0]
+ }
+
+ for {
+ // t.remainder is a slice at the beginning of t.inBuf
+ // containing a partial key sequence
+ readBuf := t.inBuf[len(t.remainder):]
+ var n int
+ n, err = t.c.Read(readBuf)
+ if err != nil {
+ return
+ }
+
+ if err == nil {
+ t.remainder = t.inBuf[:n+len(t.remainder)]
+ rest := t.remainder
+ lineOk := false
+ for !lineOk {
+ var key int
+ key, rest = bytesToKey(rest)
+ if key < 0 {
+ break
+ }
+ if key == keyCtrlD {
+ return "", io.EOF
+ }
+ line, lineOk = t.handleKey(key)
+ }
+ if len(rest) > 0 {
+ n := copy(t.inBuf[:], rest)
+ t.remainder = t.inBuf[:n]
+ } else {
+ t.remainder = nil
+ }
+ t.c.Write(t.outBuf)
+ t.outBuf = t.outBuf[:0]
+ if lineOk {
+ return
+ }
+ continue
}
}
+ panic("unreachable")
+}
- return ret, nil
+func (t *Terminal) SetSize(width, height int) {
+ t.termWidth, t.termHeight = width, height
}
func TestClose(t *testing.T) {
c := &MockTerminal{}
- ss := NewShell(c, "> ")
+ ss := NewTerminal(c, "> ")
line, err := ss.ReadLine()
if line != "" {
t.Errorf("Expected empty line but got: %s", line)
toSend: []byte(test.in),
bytesPerRead: j,
}
- ss := NewShell(c, "> ")
+ ss := NewTerminal(c, "> ")
line, err := ss.ReadLine()
if line != test.line {
t.Errorf("Line resulting from test %d (%d bytes per read) was '%s', expected '%s'", i, j, line, test.line)
--- /dev/null
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package terminal provides support functions for dealing with terminals, as
+// commonly found on UNIX systems.
+//
+// Putting a terminal into raw mode is the most common requirement:
+//
+// oldState, err := terminal.MakeRaw(0)
+// if err != nil {
+// panic(err.String())
+// }
+// defer terminal.Restore(0, oldState)
+package terminal
+
+import (
+ "io"
+ "syscall"
+)
+
+// State contains the state of a terminal.
+type State struct {
+ termios syscall.Termios
+}
+
+// IsTerminal returns true if the given file descriptor is a terminal.
+func IsTerminal(fd int) bool {
+ var termios syscall.Termios
+ err := syscall.Tcgetattr(fd, &termios)
+ return err == nil
+}
+
+// MakeRaw put the terminal connected to the given file descriptor into raw
+// mode and returns the previous state of the terminal so that it can be
+// restored.
+func MakeRaw(fd int) (*State, error) {
+ var oldState State
+ if err := syscall.Tcgetattr(fd, &oldState.termios); err != nil {
+ return nil, err
+ }
+
+ newState := oldState.termios
+ newState.Iflag &^= syscall.ISTRIP | syscall.INLCR | syscall.ICRNL | syscall.IGNCR | syscall.IXON | syscall.IXOFF
+ newState.Lflag &^= syscall.ECHO | syscall.ICANON | syscall.ISIG
+ if err := syscall.Tcsetattr(fd, syscall.TCSANOW, &newState); err != nil {
+ return nil, err
+ }
+
+ return &oldState, nil
+}
+
+// Restore restores the terminal connected to the given file descriptor to a
+// previous state.
+func Restore(fd int, state *State) error {
+ err := syscall.Tcsetattr(fd, syscall.TCSANOW, &state.termios)
+ return err
+}
+
+// ReadPassword reads a line of input from a terminal without local echo. This
+// is commonly used for inputting passwords and other sensitive data. The slice
+// returned does not include the \n.
+func ReadPassword(fd int) ([]byte, error) {
+ var oldState syscall.Termios
+ if err := syscall.Tcgetattr(fd, &oldState); err != nil {
+ return nil, err
+ }
+
+ newState := oldState
+ newState.Lflag &^= syscall.ECHO
+ if err := syscall.Tcsetattr(fd, syscall.TCSANOW, &newState); err != nil {
+ return nil, err
+ }
+
+ defer func() {
+ syscall.Tcsetattr(fd, syscall.TCSANOW, &oldState)
+ }()
+
+ var buf [16]byte
+ var ret []byte
+ for {
+ n, err := syscall.Read(fd, buf[:])
+ if err != nil {
+ return nil, err
+ }
+ if n == 0 {
+ if len(ret) == 0 {
+ return nil, io.EOF
+ }
+ break
+ }
+ if buf[n-1] == '\n' {
+ n--
+ }
+ ret = append(ret, buf[:n]...)
+ if n < len(buf) {
+ break
+ }
+ }
+
+ return ret, nil
+}
{"%#v", map[string]B{"a": {1, 2}}, `map[string] fmt_test.B{"a":fmt_test.B{I:1, j:2}}`},
{"%#v", []string{"a", "b"}, `[]string{"a", "b"}`},
{"%#v", SI{}, `fmt_test.SI{I:interface {}(nil)}`},
+ {"%#v", []int(nil), `[]int(nil)`},
+ {"%#v", []int{}, `[]int{}`},
+ {"%#v", map[int]byte(nil), `map[int] uint8(nil)`},
+ {"%#v", map[int]byte{}, `map[int] uint8{}`},
// slices with other formats
{"%#x", []int{1, 2, 15}, `[0x1 0x2 0xf]`},
case reflect.Map:
if goSyntax {
p.buf.WriteString(f.Type().String())
+ if f.IsNil() {
+ p.buf.WriteString("(nil)")
+ break
+ }
p.buf.WriteByte('{')
} else {
p.buf.Write(mapBytes)
}
if goSyntax {
p.buf.WriteString(value.Type().String())
+ if f.IsNil() {
+ p.buf.WriteString("(nil)")
+ break
+ }
p.buf.WriteByte('{')
} else {
p.buf.WriteByte('[')
var z IntString
var multiTests = []ScanfMultiTest{
- {"", "", nil, nil, ""},
+ {"", "", []interface{}{}, []interface{}{}, ""},
{"%d", "23", args(&i), args(23), ""},
{"%2s%3s", "22333", args(&s, &t), args("22", "333"), ""},
{"%2d%3d", "44555", args(&i, &j), args(44, 555), ""},
}
val := v.Interface()
if !reflect.DeepEqual(val, test.out) {
- t.Errorf("%s scanning %q: expected %v got %v, type %T", name, test.text, test.out, val, val)
+ t.Errorf("%s scanning %q: expected %#v got %#v, type %T", name, test.text, test.out, val, val)
}
}
}
}
val := v.Interface()
if !reflect.DeepEqual(val, test.out) {
- t.Errorf("scanning (%q, %q): expected %v got %v, type %T", test.format, test.text, test.out, val, val)
+ t.Errorf("scanning (%q, %q): expected %#v got %#v, type %T", test.format, test.text, test.out, val, val)
}
}
}
}
result := resultVal.Interface()
if !reflect.DeepEqual(result, test.out) {
- t.Errorf("scanning (%q, %q): expected %v got %v", test.format, test.text, test.out, result)
+ t.Errorf("scanning (%q, %q): expected %#v got %#v", test.format, test.text, test.out, result)
}
}
}
// exprNode() ensures that only expression/type nodes can be
// assigned to an ExprNode.
//
-func (x *BadExpr) exprNode() {}
-func (x *Ident) exprNode() {}
-func (x *Ellipsis) exprNode() {}
-func (x *BasicLit) exprNode() {}
-func (x *FuncLit) exprNode() {}
-func (x *CompositeLit) exprNode() {}
-func (x *ParenExpr) exprNode() {}
-func (x *SelectorExpr) exprNode() {}
-func (x *IndexExpr) exprNode() {}
-func (x *SliceExpr) exprNode() {}
-func (x *TypeAssertExpr) exprNode() {}
-func (x *CallExpr) exprNode() {}
-func (x *StarExpr) exprNode() {}
-func (x *UnaryExpr) exprNode() {}
-func (x *BinaryExpr) exprNode() {}
-func (x *KeyValueExpr) exprNode() {}
-
-func (x *ArrayType) exprNode() {}
-func (x *StructType) exprNode() {}
-func (x *FuncType) exprNode() {}
-func (x *InterfaceType) exprNode() {}
-func (x *MapType) exprNode() {}
-func (x *ChanType) exprNode() {}
+func (*BadExpr) exprNode() {}
+func (*Ident) exprNode() {}
+func (*Ellipsis) exprNode() {}
+func (*BasicLit) exprNode() {}
+func (*FuncLit) exprNode() {}
+func (*CompositeLit) exprNode() {}
+func (*ParenExpr) exprNode() {}
+func (*SelectorExpr) exprNode() {}
+func (*IndexExpr) exprNode() {}
+func (*SliceExpr) exprNode() {}
+func (*TypeAssertExpr) exprNode() {}
+func (*CallExpr) exprNode() {}
+func (*StarExpr) exprNode() {}
+func (*UnaryExpr) exprNode() {}
+func (*BinaryExpr) exprNode() {}
+func (*KeyValueExpr) exprNode() {}
+
+func (*ArrayType) exprNode() {}
+func (*StructType) exprNode() {}
+func (*FuncType) exprNode() {}
+func (*InterfaceType) exprNode() {}
+func (*MapType) exprNode() {}
+func (*ChanType) exprNode() {}
// ----------------------------------------------------------------------------
// Convenience functions for Idents
// stmtNode() ensures that only statement nodes can be
// assigned to a StmtNode.
//
-func (s *BadStmt) stmtNode() {}
-func (s *DeclStmt) stmtNode() {}
-func (s *EmptyStmt) stmtNode() {}
-func (s *LabeledStmt) stmtNode() {}
-func (s *ExprStmt) stmtNode() {}
-func (s *SendStmt) stmtNode() {}
-func (s *IncDecStmt) stmtNode() {}
-func (s *AssignStmt) stmtNode() {}
-func (s *GoStmt) stmtNode() {}
-func (s *DeferStmt) stmtNode() {}
-func (s *ReturnStmt) stmtNode() {}
-func (s *BranchStmt) stmtNode() {}
-func (s *BlockStmt) stmtNode() {}
-func (s *IfStmt) stmtNode() {}
-func (s *CaseClause) stmtNode() {}
-func (s *SwitchStmt) stmtNode() {}
-func (s *TypeSwitchStmt) stmtNode() {}
-func (s *CommClause) stmtNode() {}
-func (s *SelectStmt) stmtNode() {}
-func (s *ForStmt) stmtNode() {}
-func (s *RangeStmt) stmtNode() {}
+func (*BadStmt) stmtNode() {}
+func (*DeclStmt) stmtNode() {}
+func (*EmptyStmt) stmtNode() {}
+func (*LabeledStmt) stmtNode() {}
+func (*ExprStmt) stmtNode() {}
+func (*SendStmt) stmtNode() {}
+func (*IncDecStmt) stmtNode() {}
+func (*AssignStmt) stmtNode() {}
+func (*GoStmt) stmtNode() {}
+func (*DeferStmt) stmtNode() {}
+func (*ReturnStmt) stmtNode() {}
+func (*BranchStmt) stmtNode() {}
+func (*BlockStmt) stmtNode() {}
+func (*IfStmt) stmtNode() {}
+func (*CaseClause) stmtNode() {}
+func (*SwitchStmt) stmtNode() {}
+func (*TypeSwitchStmt) stmtNode() {}
+func (*CommClause) stmtNode() {}
+func (*SelectStmt) stmtNode() {}
+func (*ForStmt) stmtNode() {}
+func (*RangeStmt) stmtNode() {}
// ----------------------------------------------------------------------------
// Declarations
// specNode() ensures that only spec nodes can be
// assigned to a Spec.
//
-func (s *ImportSpec) specNode() {}
-func (s *ValueSpec) specNode() {}
-func (s *TypeSpec) specNode() {}
+func (*ImportSpec) specNode() {}
+func (*ValueSpec) specNode() {}
+func (*TypeSpec) specNode() {}
// A declaration is represented by one of the following declaration nodes.
//
// declNode() ensures that only declaration nodes can be
// assigned to a DeclNode.
//
-func (d *BadDecl) declNode() {}
-func (d *GenDecl) declNode() {}
-func (d *FuncDecl) declNode() {}
+func (*BadDecl) declNode() {}
+func (*GenDecl) declNode() {}
+func (*FuncDecl) declNode() {}
// ----------------------------------------------------------------------------
// Files and packages
// it returns false otherwise.
//
func FileExports(src *File) bool {
- return FilterFile(src, exportFilter)
+ return filterFile(src, exportFilter, true)
}
// PackageExports trims the AST for a Go package in place such that
// it returns false otherwise.
//
func PackageExports(pkg *Package) bool {
- return FilterPackage(pkg, exportFilter)
+ return filterPackage(pkg, exportFilter, true)
}
// ----------------------------------------------------------------------------
return nil
}
-func filterFieldList(fields *FieldList, filter Filter) (removedFields bool) {
+func filterFieldList(fields *FieldList, filter Filter, export bool) (removedFields bool) {
if fields == nil {
return false
}
keepField = len(f.Names) > 0
}
if keepField {
- if filter == exportFilter {
- filterType(f.Type, filter)
+ if export {
+ filterType(f.Type, filter, export)
}
list[j] = f
j++
return
}
-func filterParamList(fields *FieldList, filter Filter) bool {
+func filterParamList(fields *FieldList, filter Filter, export bool) bool {
if fields == nil {
return false
}
var b bool
for _, f := range fields.List {
- if filterType(f.Type, filter) {
+ if filterType(f.Type, filter, export) {
b = true
}
}
return b
}
-func filterType(typ Expr, f Filter) bool {
+func filterType(typ Expr, f Filter, export bool) bool {
switch t := typ.(type) {
case *Ident:
return f(t.Name)
case *ParenExpr:
- return filterType(t.X, f)
+ return filterType(t.X, f, export)
case *ArrayType:
- return filterType(t.Elt, f)
+ return filterType(t.Elt, f, export)
case *StructType:
- if filterFieldList(t.Fields, f) {
+ if filterFieldList(t.Fields, f, export) {
t.Incomplete = true
}
return len(t.Fields.List) > 0
case *FuncType:
- b1 := filterParamList(t.Params, f)
- b2 := filterParamList(t.Results, f)
+ b1 := filterParamList(t.Params, f, export)
+ b2 := filterParamList(t.Results, f, export)
return b1 || b2
case *InterfaceType:
- if filterFieldList(t.Methods, f) {
+ if filterFieldList(t.Methods, f, export) {
t.Incomplete = true
}
return len(t.Methods.List) > 0
case *MapType:
- b1 := filterType(t.Key, f)
- b2 := filterType(t.Value, f)
+ b1 := filterType(t.Key, f, export)
+ b2 := filterType(t.Value, f, export)
return b1 || b2
case *ChanType:
- return filterType(t.Value, f)
+ return filterType(t.Value, f, export)
}
return false
}
-func filterSpec(spec Spec, f Filter) bool {
+func filterSpec(spec Spec, f Filter, export bool) bool {
switch s := spec.(type) {
case *ValueSpec:
s.Names = filterIdentList(s.Names, f)
if len(s.Names) > 0 {
- if f == exportFilter {
- filterType(s.Type, f)
+ if export {
+ filterType(s.Type, f, export)
}
return true
}
case *TypeSpec:
if f(s.Name.Name) {
- if f == exportFilter {
- filterType(s.Type, f)
+ if export {
+ filterType(s.Type, f, export)
}
return true
}
- if f != exportFilter {
+ if !export {
// For general filtering (not just exports),
// filter type even if name is not filtered
// out.
// If the type contains filtered elements,
// keep the declaration.
- return filterType(s.Type, f)
+ return filterType(s.Type, f, export)
}
}
return false
}
-func filterSpecList(list []Spec, f Filter) []Spec {
+func filterSpecList(list []Spec, f Filter, export bool) []Spec {
j := 0
for _, s := range list {
- if filterSpec(s, f) {
+ if filterSpec(s, f, export) {
list[j] = s
j++
}
// filtering; it returns false otherwise.
//
func FilterDecl(decl Decl, f Filter) bool {
+ return filterDecl(decl, f, false)
+}
+
+func filterDecl(decl Decl, f Filter, export bool) bool {
switch d := decl.(type) {
case *GenDecl:
- d.Specs = filterSpecList(d.Specs, f)
+ d.Specs = filterSpecList(d.Specs, f, export)
return len(d.Specs) > 0
case *FuncDecl:
return f(d.Name.Name)
// left after filtering; it returns false otherwise.
//
func FilterFile(src *File, f Filter) bool {
+ return filterFile(src, f, false)
+}
+
+func filterFile(src *File, f Filter, export bool) bool {
j := 0
for _, d := range src.Decls {
- if FilterDecl(d, f) {
+ if filterDecl(d, f, export) {
src.Decls[j] = d
j++
}
// left after filtering; it returns false otherwise.
//
func FilterPackage(pkg *Package, f Filter) bool {
+ return filterPackage(pkg, f, false)
+}
+
+func filterPackage(pkg *Package, f Filter, export bool) bool {
hasDecls := false
for _, src := range pkg.Files {
- if FilterFile(src, f) {
+ if filterFile(src, f, export) {
hasDecls = true
}
}
{
"go/build/cmdtest",
&DirInfo{
- GoFiles: []string{"main.go"},
- Package: "main",
- Imports: []string{"go/build/pkgtest"},
+ GoFiles: []string{"main.go"},
+ Package: "main",
+ Imports: []string{"go/build/pkgtest"},
+ TestImports: []string{},
},
},
{
"go/build/cgotest",
&DirInfo{
- CgoFiles: []string{"cgotest.go"},
- CFiles: []string{"cgotest.c"},
- Imports: []string{"C", "unsafe"},
- Package: "cgotest",
+ CgoFiles: []string{"cgotest.go"},
+ CFiles: []string{"cgotest.c"},
+ Imports: []string{"C", "unsafe"},
+ TestImports: []string{},
+ Package: "cgotest",
},
},
}
"io"
"os"
"path/filepath"
+ "strconv"
+ "strings"
"text/tabwriter"
)
p.last = p.pos
}
+const linePrefix = "//line "
+
// writeCommentPrefix writes the whitespace before a comment.
// If there is any pending whitespace, it consumes as much of
// it as is likely to help position the comment nicely.
// a group of comments (or nil), and isKeyword indicates if the
// next item is a keyword.
//
-func (p *printer) writeCommentPrefix(pos, next token.Position, prev *ast.Comment, isKeyword bool) {
+func (p *printer) writeCommentPrefix(pos, next token.Position, prev, comment *ast.Comment, isKeyword bool) {
if p.written == 0 {
// the comment is the first item to be printed - don't write any whitespace
return
}
p.writeWhitespace(j)
}
+
+ // turn off indent if we're about to print a line directive.
+ indent := p.indent
+ if strings.HasPrefix(comment.Text, linePrefix) {
+ p.indent = 0
+ }
+
// use formfeeds to break columns before a comment;
// this is analogous to using formfeeds to separate
// individual lines of /*-style comments - but make
n = 1
}
p.writeNewlines(n, true)
+ p.indent = indent
}
}
func (p *printer) writeComment(comment *ast.Comment) {
text := comment.Text
+ if strings.HasPrefix(text, linePrefix) {
+ pos := strings.TrimSpace(text[len(linePrefix):])
+ i := strings.LastIndex(pos, ":")
+ if i >= 0 {
+ // The line directive we are about to print changed
+ // the Filename and Line number used by go/token
+ // as it was reading the input originally.
+ // In order to match the original input, we have to
+ // update our own idea of the file and line number
+ // accordingly, after printing the directive.
+ file := pos[:i]
+ line, _ := strconv.Atoi(string(pos[i+1:]))
+ defer func() {
+ p.pos.Filename = string(file)
+ p.pos.Line = line
+ p.pos.Column = 1
+ }()
+ }
+ }
+
// shortcut common case of //-style comments
if text[1] == '/' {
p.writeItem(p.fset.Position(comment.Pos()), p.escape(text))
var last *ast.Comment
for ; p.commentBefore(next); p.cindex++ {
for _, c := range p.comments[p.cindex].List {
- p.writeCommentPrefix(p.fset.Position(c.Pos()), next, last, tok.IsKeyword())
+ p.writeCommentPrefix(p.fset.Position(c.Pos()), next, last, c, tok.IsKeyword())
p.writeComment(c)
last = c
}
for {
if z.Next() == html.ErrorToken {
// Returning io.EOF indicates success.
- return z.Error()
+ return z.Err()
}
emitToken(z.Token())
}
tt := z.Next()
switch tt {
case ErrorToken:
- return z.Error()
+ return z.Err()
case TextToken:
if depth > 0 {
// emitBytes should copy the []byte it receives,
head, form *Node
// Other parsing state flags (section 11.2.3.5).
scripting, framesetOK bool
+ // im is the current insertion mode.
+ im insertionMode
// originalIM is the insertion mode to go back to after completing a text
// or inTableText insertion mode.
originalIM insertionMode
// An insertion mode (section 11.2.3.1) is the state transition function from
// a particular state in the HTML5 parser's state machine. It updates the
-// parser's fields depending on parser.token (where ErrorToken means EOF). In
-// addition to returning the next insertionMode state, it also returns whether
-// the token was consumed.
-type insertionMode func(*parser) (insertionMode, bool)
-
-// useTheRulesFor runs the delegate insertionMode over p, returning the actual
-// insertionMode unless the delegate caused a state transition.
-// Section 11.2.3.1, "using the rules for".
-func useTheRulesFor(p *parser, actual, delegate insertionMode) (insertionMode, bool) {
- im, consumed := delegate(p)
- if p.originalIM == delegate {
- p.originalIM = actual
- }
- if im != delegate {
- return im, consumed
- }
- return actual, consumed
-}
+// parser's fields depending on parser.tok (where ErrorToken means EOF).
+// It returns whether the token was consumed.
+type insertionMode func(*parser) bool
// setOriginalIM sets the insertion mode to return to after completing a text or
// inTableText insertion mode.
// Section 11.2.3.1, "using the rules for".
-func (p *parser) setOriginalIM(im insertionMode) {
+func (p *parser) setOriginalIM() {
if p.originalIM != nil {
panic("html: bad parser state: originalIM was set twice")
}
- p.originalIM = im
+ p.originalIM = p.im
}
// Section 11.2.3.1, "reset the insertion mode".
-func (p *parser) resetInsertionMode() insertionMode {
+func (p *parser) resetInsertionMode() {
for i := len(p.oe) - 1; i >= 0; i-- {
n := p.oe[i]
if i == 0 {
}
switch n.Data {
case "select":
- return inSelectIM
+ p.im = inSelectIM
case "td", "th":
- return inCellIM
+ p.im = inCellIM
case "tr":
- return inRowIM
+ p.im = inRowIM
case "tbody", "thead", "tfoot":
- return inTableBodyIM
+ p.im = inTableBodyIM
case "caption":
- // TODO: return inCaptionIM
+ p.im = inCaptionIM
case "colgroup":
- // TODO: return inColumnGroupIM
+ p.im = inColumnGroupIM
case "table":
- return inTableIM
+ p.im = inTableIM
case "head":
- return inBodyIM
+ p.im = inBodyIM
case "body":
- return inBodyIM
+ p.im = inBodyIM
case "frameset":
- // TODO: return inFramesetIM
+ p.im = inFramesetIM
case "html":
- return beforeHeadIM
+ p.im = beforeHeadIM
+ default:
+ continue
}
+ return
}
- return inBodyIM
+ p.im = inBodyIM
}
// Section 11.2.5.4.1.
-func initialIM(p *parser) (insertionMode, bool) {
+func initialIM(p *parser) bool {
switch p.tok.Type {
case CommentToken:
p.doc.Add(&Node{
Type: CommentNode,
Data: p.tok.Data,
})
- return initialIM, true
+ return true
case DoctypeToken:
p.doc.Add(&Node{
Type: DoctypeNode,
Data: p.tok.Data,
})
- return beforeHTMLIM, true
+ p.im = beforeHTMLIM
+ return true
}
// TODO: set "quirks mode"? It's defined in the DOM spec instead of HTML5 proper,
// and so switching on "quirks mode" might belong in a different package.
- return beforeHTMLIM, false
+ p.im = beforeHTMLIM
+ return false
}
// Section 11.2.5.4.2.
-func beforeHTMLIM(p *parser) (insertionMode, bool) {
- var (
- add bool
- attr []Attribute
- implied bool
- )
+func beforeHTMLIM(p *parser) bool {
switch p.tok.Type {
- case ErrorToken:
- implied = true
- case TextToken:
- // TODO: distinguish whitespace text from others.
- implied = true
case StartTagToken:
if p.tok.Data == "html" {
- add = true
- attr = p.tok.Attr
- } else {
- implied = true
+ p.addElement(p.tok.Data, p.tok.Attr)
+ p.im = beforeHeadIM
+ return true
}
case EndTagToken:
switch p.tok.Data {
case "head", "body", "html", "br":
- implied = true
+ // Drop down to creating an implied <html> tag.
default:
// Ignore the token.
+ return true
}
case CommentToken:
p.doc.Add(&Node{
Type: CommentNode,
Data: p.tok.Data,
})
- return beforeHTMLIM, true
- }
- if add || implied {
- p.addElement("html", attr)
+ return true
}
- return beforeHeadIM, !implied
+ // Create an implied <html> tag.
+ p.addElement("html", nil)
+ p.im = beforeHeadIM
+ return false
}
// Section 11.2.5.4.3.
-func beforeHeadIM(p *parser) (insertionMode, bool) {
+func beforeHeadIM(p *parser) bool {
var (
add bool
attr []Attribute
add = true
attr = p.tok.Attr
case "html":
- return useTheRulesFor(p, beforeHeadIM, inBodyIM)
+ return inBodyIM(p)
default:
implied = true
}
Type: CommentNode,
Data: p.tok.Data,
})
- return beforeHeadIM, true
+ return true
}
if add || implied {
p.addElement("head", attr)
p.head = p.top()
}
- return inHeadIM, !implied
+ p.im = inHeadIM
+ return !implied
}
const whitespace = " \t\r\n\f"
// Section 11.2.5.4.4.
-func inHeadIM(p *parser) (insertionMode, bool) {
+func inHeadIM(p *parser) bool {
var (
pop bool
implied bool
// Add the initial whitespace to the current node.
p.addText(p.tok.Data[:len(p.tok.Data)-len(s)])
if s == "" {
- return inHeadIM, true
+ return true
}
p.tok.Data = s
}
p.acknowledgeSelfClosingTag()
case "script", "title", "noscript", "noframes", "style":
p.addElement(p.tok.Data, p.tok.Attr)
- p.setOriginalIM(inHeadIM)
- return textIM, true
+ p.setOriginalIM()
+ p.im = textIM
+ return true
default:
implied = true
}
case EndTagToken:
- if p.tok.Data == "head" {
+ switch p.tok.Data {
+ case "head":
pop = true
+ case "body", "html", "br":
+ implied = true
+ default:
+ // Ignore the token.
+ return true
}
- // TODO.
case CommentToken:
p.addChild(&Node{
Type: CommentNode,
Data: p.tok.Data,
})
- return inHeadIM, true
+ return true
}
if pop || implied {
n := p.oe.pop()
if n.Data != "head" {
panic("html: bad parser state: <head> element not found, in the in-head insertion mode")
}
- return afterHeadIM, !implied
+ p.im = afterHeadIM
+ return !implied
}
- return inHeadIM, true
+ return true
}
// Section 11.2.5.4.6.
-func afterHeadIM(p *parser) (insertionMode, bool) {
+func afterHeadIM(p *parser) bool {
var (
add bool
attr []Attribute
attr = p.tok.Attr
framesetOK = false
case "frameset":
- // TODO.
+ p.addElement(p.tok.Data, p.tok.Attr)
+ p.im = inFramesetIM
+ return true
case "base", "basefont", "bgsound", "link", "meta", "noframes", "script", "style", "title":
p.oe = append(p.oe, p.head)
defer p.oe.pop()
- return useTheRulesFor(p, afterHeadIM, inHeadIM)
+ return inHeadIM(p)
case "head":
// TODO.
default:
framesetOK = true
}
case EndTagToken:
- // TODO.
+ switch p.tok.Data {
+ case "body", "html", "br":
+ implied = true
+ framesetOK = true
+ default:
+ // Ignore the token.
+ return true
+ }
case CommentToken:
p.addChild(&Node{
Type: CommentNode,
Data: p.tok.Data,
})
- return afterHeadIM, true
+ return true
}
if add || implied {
p.addElement("body", attr)
p.framesetOK = framesetOK
}
- return inBodyIM, !implied
+ p.im = inBodyIM
+ return !implied
}
// copyAttributes copies attributes of src not found on dst to dst.
}
// Section 11.2.5.4.7.
-func inBodyIM(p *parser) (insertionMode, bool) {
+func inBodyIM(p *parser) bool {
switch p.tok.Type {
case TextToken:
p.reconstructActiveFormattingElements()
p.popUntil(buttonScopeStopTags, "p") // TODO: skip this step in quirks mode.
p.addElement(p.tok.Data, p.tok.Attr)
p.framesetOK = false
- return inTableIM, true
+ p.im = inTableIM
+ return true
case "hr":
p.popUntil(buttonScopeStopTags, "p")
p.addElement(p.tok.Data, p.tok.Attr)
p.addElement(p.tok.Data, p.tok.Attr)
p.framesetOK = false
// TODO: detect <select> inside a table.
- return inSelectIM, true
+ p.im = inSelectIM
+ return true
+ case "form":
+ if p.form == nil {
+ p.popUntil(buttonScopeStopTags, "p")
+ p.addElement(p.tok.Data, p.tok.Attr)
+ p.form = p.top()
+ }
case "li":
p.framesetOK = false
for i := len(p.oe) - 1; i >= 0; i-- {
break
}
p.popUntil(buttonScopeStopTags, "p")
- p.addElement("li", p.tok.Attr)
+ p.addElement(p.tok.Data, p.tok.Attr)
+ case "dd", "dt":
+ p.framesetOK = false
+ for i := len(p.oe) - 1; i >= 0; i-- {
+ node := p.oe[i]
+ switch node.Data {
+ case "dd", "dt":
+ p.oe = p.oe[:i]
+ case "address", "div", "p":
+ continue
+ default:
+ if !isSpecialElement[node.Data] {
+ continue
+ }
+ }
+ break
+ }
+ p.popUntil(buttonScopeStopTags, "p")
+ p.addElement(p.tok.Data, p.tok.Attr)
+ case "plaintext":
+ p.popUntil(buttonScopeStopTags, "p")
+ p.addElement(p.tok.Data, p.tok.Attr)
case "optgroup", "option":
if p.top().Data == "option" {
p.oe.pop()
}
}
case "base", "basefont", "bgsound", "command", "link", "meta", "noframes", "script", "style", "title":
- return useTheRulesFor(p, inBodyIM, inHeadIM)
+ return inHeadIM(p)
case "image":
p.tok.Data = "img"
- return inBodyIM, false
+ return false
+ case "isindex":
+ if p.form != nil {
+ // Ignore the token.
+ return true
+ }
+ action := ""
+ prompt := "This is a searchable index. Enter search keywords: "
+ attr := []Attribute{{Key: "name", Val: "isindex"}}
+ for _, a := range p.tok.Attr {
+ switch a.Key {
+ case "action":
+ action = a.Val
+ case "name":
+ // Ignore the attribute.
+ case "prompt":
+ prompt = a.Val
+ default:
+ attr = append(attr, a)
+ }
+ }
+ p.acknowledgeSelfClosingTag()
+ p.popUntil(buttonScopeStopTags, "p")
+ p.addElement("form", nil)
+ p.form = p.top()
+ if action != "" {
+ p.form.Attr = []Attribute{{Key: "action", Val: action}}
+ }
+ p.addElement("hr", nil)
+ p.oe.pop()
+ p.addElement("label", nil)
+ p.addText(prompt)
+ p.addElement("input", attr)
+ p.oe.pop()
+ p.oe.pop()
+ p.addElement("hr", nil)
+ p.oe.pop()
+ p.oe.pop()
+ p.form = nil
+ case "caption", "col", "colgroup", "frame", "head", "tbody", "td", "tfoot", "th", "thead", "tr":
+ // Ignore the token.
default:
// TODO.
p.addElement(p.tok.Data, p.tok.Attr)
switch p.tok.Data {
case "body":
// TODO: autoclose the stack of open elements.
- return afterBodyIM, true
+ p.im = afterBodyIM
+ return true
case "p":
if !p.elementInScope(buttonScopeStopTags, "p") {
p.addElement("p", nil)
if p.popUntil(defaultScopeStopTags, p.tok.Data) {
p.clearActiveFormattingElements()
}
+ case "br":
+ p.tok.Type = StartTagToken
+ return false
default:
p.inBodyEndTagOther(p.tok.Data)
}
})
}
- return inBodyIM, true
+ return true
}
func (p *parser) inBodyEndTagFormatting(tag string) {
}
// Section 11.2.5.4.8.
-func textIM(p *parser) (insertionMode, bool) {
+func textIM(p *parser) bool {
switch p.tok.Type {
case ErrorToken:
p.oe.pop()
case TextToken:
p.addText(p.tok.Data)
- return textIM, true
+ return true
case EndTagToken:
p.oe.pop()
}
- o := p.originalIM
+ p.im = p.originalIM
p.originalIM = nil
- return o, p.tok.Type == EndTagToken
+ return p.tok.Type == EndTagToken
}
// Section 11.2.5.4.9.
-func inTableIM(p *parser) (insertionMode, bool) {
+func inTableIM(p *parser) bool {
switch p.tok.Type {
case ErrorToken:
// Stop parsing.
- return nil, true
+ return true
case TextToken:
// TODO.
case StartTagToken:
switch p.tok.Data {
+ case "caption":
+ p.clearStackToContext(tableScopeStopTags)
+ p.afe = append(p.afe, &scopeMarker)
+ p.addElement(p.tok.Data, p.tok.Attr)
+ p.im = inCaptionIM
+ return true
case "tbody", "tfoot", "thead":
p.clearStackToContext(tableScopeStopTags)
p.addElement(p.tok.Data, p.tok.Attr)
- return inTableBodyIM, true
+ p.im = inTableBodyIM
+ return true
case "td", "th", "tr":
p.clearStackToContext(tableScopeStopTags)
p.addElement("tbody", nil)
- return inTableBodyIM, false
+ p.im = inTableBodyIM
+ return false
case "table":
if p.popUntil(tableScopeStopTags, "table") {
- return p.resetInsertionMode(), false
+ p.resetInsertionMode()
+ return false
}
// Ignore the token.
- return inTableIM, true
+ return true
+ case "colgroup":
+ p.clearStackToContext(tableScopeStopTags)
+ p.addElement(p.tok.Data, p.tok.Attr)
+ p.im = inColumnGroupIM
+ return true
+ case "col":
+ p.clearStackToContext(tableScopeStopTags)
+ p.addElement("colgroup", p.tok.Attr)
+ p.im = inColumnGroupIM
+ return false
default:
// TODO.
}
switch p.tok.Data {
case "table":
if p.popUntil(tableScopeStopTags, "table") {
- return p.resetInsertionMode(), true
+ p.resetInsertionMode()
+ return true
}
// Ignore the token.
- return inTableIM, true
+ return true
case "body", "caption", "col", "colgroup", "html", "tbody", "td", "tfoot", "th", "thead", "tr":
// Ignore the token.
- return inTableIM, true
+ return true
}
case CommentToken:
p.addChild(&Node{
Type: CommentNode,
Data: p.tok.Data,
})
- return inTableIM, true
+ return true
}
switch p.top().Data {
defer func() { p.fosterParenting = false }()
}
- return useTheRulesFor(p, inTableIM, inBodyIM)
+ return inBodyIM(p)
}
// clearStackToContext pops elements off the stack of open elements
}
}
+// Section 11.2.5.4.11.
+func inCaptionIM(p *parser) bool {
+ switch p.tok.Type {
+ case StartTagToken:
+ switch p.tok.Data {
+ case "caption", "col", "colgroup", "tbody", "td", "tfoot", "thead", "tr":
+ if p.popUntil(tableScopeStopTags, "caption") {
+ p.clearActiveFormattingElements()
+ p.im = inTableIM
+ return false
+ } else {
+ // Ignore the token.
+ return true
+ }
+ }
+ case EndTagToken:
+ switch p.tok.Data {
+ case "caption":
+ if p.popUntil(tableScopeStopTags, "caption") {
+ p.clearActiveFormattingElements()
+ p.im = inTableIM
+ }
+ return true
+ case "table":
+ if p.popUntil(tableScopeStopTags, "caption") {
+ p.clearActiveFormattingElements()
+ p.im = inTableIM
+ return false
+ } else {
+ // Ignore the token.
+ return true
+ }
+ case "body", "col", "colgroup", "html", "tbody", "td", "tfoot", "th", "thead", "tr":
+ // Ignore the token.
+ return true
+ }
+ }
+ return inBodyIM(p)
+}
+
+// Section 11.2.5.4.12.
+func inColumnGroupIM(p *parser) bool {
+ switch p.tok.Type {
+ case CommentToken:
+ p.addChild(&Node{
+ Type: CommentNode,
+ Data: p.tok.Data,
+ })
+ return true
+ case DoctypeToken:
+ // Ignore the token.
+ return true
+ case StartTagToken:
+ switch p.tok.Data {
+ case "html":
+ return inBodyIM(p)
+ case "col":
+ p.addElement(p.tok.Data, p.tok.Attr)
+ p.oe.pop()
+ p.acknowledgeSelfClosingTag()
+ return true
+ }
+ case EndTagToken:
+ switch p.tok.Data {
+ case "colgroup":
+ if p.oe.top().Data != "html" {
+ p.oe.pop()
+ }
+ p.im = inTableIM
+ return true
+ case "col":
+ // Ignore the token.
+ return true
+ }
+ }
+ if p.oe.top().Data != "html" {
+ p.oe.pop()
+ }
+ p.im = inTableIM
+ return false
+}
+
// Section 11.2.5.4.13.
-func inTableBodyIM(p *parser) (insertionMode, bool) {
+func inTableBodyIM(p *parser) bool {
var (
add bool
data string
switch p.tok.Data {
case "table":
if p.popUntil(tableScopeStopTags, "tbody", "thead", "tfoot") {
- return inTableIM, false
+ p.im = inTableIM
+ return false
}
// Ignore the token.
- return inTableBodyIM, true
+ return true
case "body", "caption", "col", "colgroup", "html", "td", "th", "tr":
// Ignore the token.
- return inTableBodyIM, true
+ return true
}
case CommentToken:
p.addChild(&Node{
Type: CommentNode,
Data: p.tok.Data,
})
- return inTableBodyIM, true
+ return true
}
if add {
// TODO: clear the stack back to a table body context.
p.addElement(data, attr)
- return inRowIM, consumed
+ p.im = inRowIM
+ return consumed
}
- return useTheRulesFor(p, inTableBodyIM, inTableIM)
+ return inTableIM(p)
}
// Section 11.2.5.4.14.
-func inRowIM(p *parser) (insertionMode, bool) {
+func inRowIM(p *parser) bool {
switch p.tok.Type {
case ErrorToken:
// TODO.
p.clearStackToContext(tableRowContextStopTags)
p.addElement(p.tok.Data, p.tok.Attr)
p.afe = append(p.afe, &scopeMarker)
- return inCellIM, true
+ p.im = inCellIM
+ return true
case "caption", "col", "colgroup", "tbody", "tfoot", "thead", "tr":
if p.popUntil(tableScopeStopTags, "tr") {
- return inTableBodyIM, false
+ p.im = inTableBodyIM
+ return false
}
// Ignore the token.
- return inRowIM, true
+ return true
default:
// TODO.
}
switch p.tok.Data {
case "tr":
if p.popUntil(tableScopeStopTags, "tr") {
- return inTableBodyIM, true
+ p.im = inTableBodyIM
+ return true
}
// Ignore the token.
- return inRowIM, true
+ return true
case "table":
if p.popUntil(tableScopeStopTags, "tr") {
- return inTableBodyIM, false
+ p.im = inTableBodyIM
+ return false
}
// Ignore the token.
- return inRowIM, true
+ return true
case "tbody", "tfoot", "thead":
// TODO.
case "body", "caption", "col", "colgroup", "html", "td", "th":
// Ignore the token.
- return inRowIM, true
+ return true
default:
// TODO.
}
Type: CommentNode,
Data: p.tok.Data,
})
- return inRowIM, true
+ return true
}
- return useTheRulesFor(p, inRowIM, inTableIM)
+ return inTableIM(p)
}
// Section 11.2.5.4.15.
-func inCellIM(p *parser) (insertionMode, bool) {
+func inCellIM(p *parser) bool {
var (
closeTheCellAndReprocess bool
)
case "td", "th":
if !p.popUntil(tableScopeStopTags, p.tok.Data) {
// Ignore the token.
- return inCellIM, true
+ return true
}
p.clearActiveFormattingElements()
- return inRowIM, true
+ p.im = inRowIM
+ return true
case "body", "caption", "col", "colgroup", "html":
// TODO.
case "table", "tbody", "tfoot", "thead", "tr":
Type: CommentNode,
Data: p.tok.Data,
})
- return inCellIM, true
+ return true
}
if closeTheCellAndReprocess {
if p.popUntil(tableScopeStopTags, "td") || p.popUntil(tableScopeStopTags, "th") {
p.clearActiveFormattingElements()
- return inRowIM, false
+ p.im = inRowIM
+ return false
}
}
- return useTheRulesFor(p, inCellIM, inBodyIM)
+ return inBodyIM(p)
}
// Section 11.2.5.4.16.
-func inSelectIM(p *parser) (insertionMode, bool) {
+func inSelectIM(p *parser) bool {
endSelect := false
switch p.tok.Type {
case ErrorToken:
}
p.addElement(p.tok.Data, p.tok.Attr)
case "optgroup":
- // TODO.
+ if p.top().Data == "option" {
+ p.oe.pop()
+ }
+ if p.top().Data == "optgroup" {
+ p.oe.pop()
+ }
+ p.addElement(p.tok.Data, p.tok.Attr)
case "select":
endSelect = true
case "input", "keygen", "textarea":
case EndTagToken:
switch p.tok.Data {
case "option":
- // TODO.
+ if p.top().Data == "option" {
+ p.oe.pop()
+ }
case "optgroup":
- // TODO.
+ i := len(p.oe) - 1
+ if p.oe[i].Data == "option" {
+ i--
+ }
+ if p.oe[i].Data == "optgroup" {
+ p.oe = p.oe[:i]
+ }
case "select":
endSelect = true
default:
switch p.oe[i].Data {
case "select":
p.oe = p.oe[:i]
- return p.resetInsertionMode(), true
+ p.resetInsertionMode()
+ return true
case "option", "optgroup":
continue
default:
// Ignore the token.
- return inSelectIM, true
+ return true
}
}
}
- return inSelectIM, true
+ return true
}
// Section 11.2.5.4.18.
-func afterBodyIM(p *parser) (insertionMode, bool) {
+func afterBodyIM(p *parser) bool {
switch p.tok.Type {
case ErrorToken:
- // TODO.
- case TextToken:
- // TODO.
+ // Stop parsing.
+ return true
case StartTagToken:
- // TODO.
+ if p.tok.Data == "html" {
+ return inBodyIM(p)
+ }
case EndTagToken:
- switch p.tok.Data {
- case "html":
- // TODO: autoclose the stack of open elements.
- return afterAfterBodyIM, true
- default:
- // TODO.
+ if p.tok.Data == "html" {
+ p.im = afterAfterBodyIM
+ return true
}
case CommentToken:
// The comment is attached to the <html> element.
Type: CommentNode,
Data: p.tok.Data,
})
- return afterBodyIM, true
+ return true
+ }
+ p.im = inBodyIM
+ return false
+}
+
+// Section 11.2.5.4.19.
+func inFramesetIM(p *parser) bool {
+ switch p.tok.Type {
+ case CommentToken:
+ p.addChild(&Node{
+ Type: CommentNode,
+ Data: p.tok.Data,
+ })
+ case StartTagToken:
+ switch p.tok.Data {
+ case "html":
+ return inBodyIM(p)
+ case "frameset":
+ p.addElement(p.tok.Data, p.tok.Attr)
+ case "frame":
+ p.addElement(p.tok.Data, p.tok.Attr)
+ p.oe.pop()
+ p.acknowledgeSelfClosingTag()
+ case "noframes":
+ return inHeadIM(p)
+ }
+ case EndTagToken:
+ switch p.tok.Data {
+ case "frameset":
+ if p.oe.top().Data != "html" {
+ p.oe.pop()
+ if p.oe.top().Data != "frameset" {
+ p.im = afterFramesetIM
+ return true
+ }
+ }
+ }
+ default:
+ // Ignore the token.
+ }
+ return true
+}
+
+// Section 11.2.5.4.20.
+func afterFramesetIM(p *parser) bool {
+ switch p.tok.Type {
+ case CommentToken:
+ p.addChild(&Node{
+ Type: CommentNode,
+ Data: p.tok.Data,
+ })
+ case StartTagToken:
+ switch p.tok.Data {
+ case "html":
+ return inBodyIM(p)
+ case "noframes":
+ return inHeadIM(p)
+ }
+ case EndTagToken:
+ switch p.tok.Data {
+ case "html":
+ p.im = afterAfterFramesetIM
+ return true
+ }
+ default:
+ // Ignore the token.
}
- // TODO: should this be "return inBodyIM, true"?
- return afterBodyIM, true
+ return true
}
// Section 11.2.5.4.21.
-func afterAfterBodyIM(p *parser) (insertionMode, bool) {
+func afterAfterBodyIM(p *parser) bool {
switch p.tok.Type {
case ErrorToken:
// Stop parsing.
- return nil, true
+ return true
case TextToken:
// TODO.
case StartTagToken:
if p.tok.Data == "html" {
- return useTheRulesFor(p, afterAfterBodyIM, inBodyIM)
+ return inBodyIM(p)
}
case CommentToken:
p.doc.Add(&Node{
Type: CommentNode,
Data: p.tok.Data,
})
- return afterAfterBodyIM, true
+ return true
+ }
+ p.im = inBodyIM
+ return false
+}
+
+// Section 11.2.5.4.22.
+func afterAfterFramesetIM(p *parser) bool {
+ switch p.tok.Type {
+ case CommentToken:
+ p.addChild(&Node{
+ Type: CommentNode,
+ Data: p.tok.Data,
+ })
+ case StartTagToken:
+ switch p.tok.Data {
+ case "html":
+ return inBodyIM(p)
+ case "noframes":
+ return inHeadIM(p)
+ }
+ default:
+ // Ignore the token.
}
- return inBodyIM, false
+ return true
}
// Parse returns the parse tree for the HTML from the given Reader.
},
scripting: true,
framesetOK: true,
+ im: initialIM,
}
// Iterate until EOF. Any other error will cause an early return.
- im, consumed := initialIM, true
+ consumed := true
for {
if consumed {
if err := p.read(); err != nil {
return nil, err
}
}
- im, consumed = im(p)
+ consumed = p.im(p)
}
// Loop until the final token (the ErrorToken signifying EOF) is consumed.
for {
- if im, consumed = im(p); consumed {
+ if consumed = p.im(p); consumed {
break
}
}
n int
}{
// TODO(nigeltao): Process all the test cases from all the .dat files.
- {"tests1.dat", 92},
- {"tests2.dat", 0},
+ {"tests1.dat", -1},
+ {"tests2.dat", 43},
{"tests3.dat", 0},
}
for _, tf := range testFiles {
// More cases of <a> being reparented:
`<a href="blah">aba<table><a href="foo">br<tr><td></td></tr>x</table>aoe`: true,
`<a><table><a></table><p><a><div><a>`: true,
+ `<a><table><td><a><table></table><a></tr><a></table><a>`: true,
+ // A <plaintext> element is reparented, putting it before a table.
+ // A <plaintext> element can't have anything after it in HTML.
+ `<table><plaintext><td>`: true,
}
return buf.Flush()
}
+// plaintextAbort is returned from render1 when a <plaintext> element
+// has been rendered. No more end tags should be rendered after that.
+var plaintextAbort = errors.New("html: internal error (plaintext abort)")
+
func render(w writer, n *Node) error {
+ err := render1(w, n)
+ if err == plaintextAbort {
+ err = nil
+ }
+ return err
+}
+
+func render1(w writer, n *Node) error {
// Render non-element nodes; these are the easy cases.
switch n.Type {
case ErrorNode:
return escape(w, n.Data)
case DocumentNode:
for _, c := range n.Child {
- if err := render(w, c); err != nil {
+ if err := render1(w, c); err != nil {
return err
}
}
// Render any child nodes.
switch n.Data {
- case "noembed", "noframes", "noscript", "script", "style":
+ case "noembed", "noframes", "noscript", "plaintext", "script", "style":
for _, c := range n.Child {
if c.Type != TextNode {
return fmt.Errorf("html: raw text element <%s> has non-text child node", n.Data)
return err
}
}
+ if n.Data == "plaintext" {
+ // Don't render anything else. <plaintext> must be the
+ // last element in the file, with no closing tag.
+ return plaintextAbort
+ }
case "textarea", "title":
for _, c := range n.Child {
if c.Type != TextNode {
return fmt.Errorf("html: RCDATA element <%s> has non-text child node", n.Data)
}
- if err := render(w, c); err != nil {
+ if err := render1(w, c); err != nil {
return err
}
}
default:
for _, c := range n.Child {
- if err := render(w, c); err != nil {
+ if err := render1(w, c); err != nil {
return err
}
}
import (
"fmt"
+ "reflect"
)
// Strings of content from a trusted source.
contentTypeUnsafe
)
+// indirect returns the value, after dereferencing as many times
+// as necessary to reach the base type (or nil).
+func indirect(a interface{}) interface{} {
+ if t := reflect.TypeOf(a); t.Kind() != reflect.Ptr {
+ // Avoid creating a reflect.Value if it's not a pointer.
+ return a
+ }
+ v := reflect.ValueOf(a)
+ for v.Kind() == reflect.Ptr && !v.IsNil() {
+ v = v.Elem()
+ }
+ return v.Interface()
+}
+
// stringify converts its arguments to a string and the type of the content.
+// All pointers are dereferenced, as in the text/template package.
func stringify(args ...interface{}) (string, contentType) {
if len(args) == 1 {
- switch s := args[0].(type) {
+ switch s := indirect(args[0]).(type) {
case string:
return s, contentTypePlain
case CSS:
return string(s), contentTypeURL
}
}
+ for i, arg := range args {
+ args[i] = indirect(arg)
+ }
return fmt.Sprint(args...), contentTypePlain
}
}
func TestEscape(t *testing.T) {
- var data = struct {
+ data := struct {
F, T bool
C, G, H string
A, E []string
Z: nil,
W: HTML(`¡<b class="foo">Hello</b>, <textarea>O'World</textarea>!`),
}
+ pdata := &data
tests := []struct {
name string
t.Errorf("%s: escaped output: want\n\t%q\ngot\n\t%q", test.name, w, g)
continue
}
+ b.Reset()
+ if err := tmpl.Execute(b, pdata); err != nil {
+ t.Errorf("%s: template execution failed for pointer: %s", test.name, err)
+ continue
+ }
+ if w, g := test.output, b.String(); w != g {
+ t.Errorf("%s: escaped output for pointer: want\n\t%q\ngot\n\t%q", test.name, w, g)
+ continue
+ }
}
}
}
}
+func TestIndirectPrint(t *testing.T) {
+ a := 3
+ ap := &a
+ b := "hello"
+ bp := &b
+ bpp := &bp
+ tmpl := Must(New("t").Parse(`{{.}}`))
+ var buf bytes.Buffer
+ err := tmpl.Execute(&buf, ap)
+ if err != nil {
+ t.Errorf("Unexpected error: %s", err)
+ } else if buf.String() != "3" {
+ t.Errorf(`Expected "3"; got %q`, buf.String())
+ }
+ buf.Reset()
+ err = tmpl.Execute(&buf, bpp)
+ if err != nil {
+ t.Errorf("Unexpected error: %s", err)
+ } else if buf.String() != "hello" {
+ t.Errorf(`Expected "hello"; got %q`, buf.String())
+ }
+}
+
func BenchmarkEscapedExecute(b *testing.B) {
tmpl := Must(New("t").Parse(`<a onclick="alert('{{.}}')">{{.}}</a>`))
var buf bytes.Buffer
"bytes"
"encoding/json"
"fmt"
+ "reflect"
"strings"
"unicode/utf8"
)
"void": true,
}
+var jsonMarshalType = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
+
+// indirectToJSONMarshaler returns the value, after dereferencing as many times
+// as necessary to reach the base type (or nil) or an implementation of json.Marshal.
+func indirectToJSONMarshaler(a interface{}) interface{} {
+ v := reflect.ValueOf(a)
+ for !v.Type().Implements(jsonMarshalType) && v.Kind() == reflect.Ptr && !v.IsNil() {
+ v = v.Elem()
+ }
+ return v.Interface()
+}
+
// jsValEscaper escapes its inputs to a JS Expression (section 11.14) that has
-// nether side-effects nor free variables outside (NaN, Infinity).
+// neither side-effects nor free variables outside (NaN, Infinity).
func jsValEscaper(args ...interface{}) string {
var a interface{}
if len(args) == 1 {
- a = args[0]
+ a = indirectToJSONMarshaler(args[0])
switch t := a.(type) {
case JS:
return string(t)
a = t.String()
}
} else {
+ for i, arg := range args {
+ args[i] = indirectToJSONMarshaler(arg)
+ }
a = fmt.Sprint(args...)
}
// TODO: detect cycles before calling Marshal which loops infinitely on
break
}
}
- // Any "<noembed>", "<noframes>", "<noscript>", "<script>", "<style>",
+ // Any "<noembed>", "<noframes>", "<noscript>", "<plaintext", "<script>", "<style>",
// "<textarea>" or "<title>" tag flags the tokenizer's next token as raw.
- // The tag name lengths of these special cases ranges in [5, 8].
- if x := z.data.end - z.data.start; 5 <= x && x <= 8 {
+ // The tag name lengths of these special cases ranges in [5, 9].
+ if x := z.data.end - z.data.start; 5 <= x && x <= 9 {
switch z.buf[z.data.start] {
- case 'n', 's', 't', 'N', 'S', 'T':
+ case 'n', 'p', 's', 't', 'N', 'P', 'S', 'T':
switch s := strings.ToLower(string(z.buf[z.data.start:z.data.end])); s {
- case "noembed", "noframes", "noscript", "script", "style", "textarea", "title":
+ case "noembed", "noframes", "noscript", "plaintext", "script", "style", "textarea", "title":
z.rawTag = s
}
}
z.data.start = z.raw.end
z.data.end = z.raw.end
if z.rawTag != "" {
- z.readRawOrRCDATA()
- z.tt = TextToken
- return z.tt
+ if z.rawTag == "plaintext" {
+ // Read everything up to EOF.
+ for z.err == nil {
+ z.readByte()
+ }
+ z.textIsRaw = true
+ } else {
+ z.readRawOrRCDATA()
+ }
+ if z.data.end > z.data.start {
+ z.tt = TextToken
+ return z.tt
+ }
}
z.textIsRaw = false
package tiff
-import (
- "io"
- "os"
-)
+import "io"
// buffer buffers an io.Reader to satisfy io.ReaderAt.
type buffer struct {
o := int(off)
end := o + len(p)
if int64(end) != off+int64(len(p)) {
- return 0, os.EINVAL
+ return 0, io.ErrUnexpectedEOF
}
m := len(b.buf)
"os"
"path/filepath"
"strconv"
+ "time"
)
// Random number state, accessed without lock; racy but harmless.
var rand uint32
func reseed() uint32 {
- sec, nsec, _ := os.Time()
-