123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358 |
- /*
- Copyright 2015 The Kubernetes Authors.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
- package tests
- import (
- "bytes"
- "errors"
- "fmt"
- "io"
- "io/ioutil"
- "net/http"
- "net/http/httptest"
- "net/url"
- "strings"
- "testing"
- "time"
- "github.com/stretchr/testify/require"
- "k8s.io/apimachinery/pkg/runtime"
- "k8s.io/apimachinery/pkg/runtime/schema"
- "k8s.io/apimachinery/pkg/types"
- "k8s.io/apimachinery/pkg/util/httpstream"
- remotecommandconsts "k8s.io/apimachinery/pkg/util/remotecommand"
- restclient "k8s.io/client-go/rest"
- remoteclient "k8s.io/client-go/tools/remotecommand"
- "k8s.io/client-go/transport/spdy"
- "k8s.io/kubernetes/pkg/api/legacyscheme"
- api "k8s.io/kubernetes/pkg/apis/core"
- "k8s.io/kubernetes/pkg/kubelet/server/remotecommand"
- )
// fakeExecutor is a stand-in for the container executor handed to
// remotecommand.ServeExec / ServeAttach. It validates the request it receives
// against the values these tests expect, replays canned stdout/stderr/error
// data, and captures whatever arrives on stdin.
type fakeExecutor struct {
	// t and testName are used to report mismatches from inside run.
	t        *testing.T
	testName string

	// Canned responses: a non-empty errorData is returned as an error by run;
	// otherwise stdoutData and stderrData are each written messageCount times.
	errorData  string
	stdoutData string
	stderrData string

	// expectStdin makes run drain its input reader into stdinReceived.
	expectStdin   bool
	stdinReceived bytes.Buffer

	// tty records the tty flag passed to the most recent run call.
	tty bool
	// messageCount is how many times stdoutData / stderrData are written.
	messageCount int
	// command records the command passed to the most recent run call.
	command []string

	// exec selects exec mode (command must be "ls /") rather than attach
	// mode (no command expected).
	exec bool
}
- func (ex *fakeExecutor) ExecInContainer(name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan remoteclient.TerminalSize, timeout time.Duration) error {
- return ex.run(name, uid, container, cmd, in, out, err, tty)
- }
- func (ex *fakeExecutor) AttachContainer(name string, uid types.UID, container string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan remoteclient.TerminalSize) error {
- return ex.run(name, uid, container, nil, in, out, err, tty)
- }
- func (ex *fakeExecutor) run(name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error {
- ex.command = cmd
- ex.tty = tty
- if e, a := "pod", name; e != a {
- ex.t.Errorf("%s: pod: expected %q, got %q", ex.testName, e, a)
- }
- if e, a := "uid", uid; e != string(a) {
- ex.t.Errorf("%s: uid: expected %q, got %q", ex.testName, e, a)
- }
- if ex.exec {
- if e, a := "ls /", strings.Join(ex.command, " "); e != a {
- ex.t.Errorf("%s: command: expected %q, got %q", ex.testName, e, a)
- }
- } else {
- if len(ex.command) > 0 {
- ex.t.Errorf("%s: command: expected nothing, got %v", ex.testName, ex.command)
- }
- }
- if len(ex.errorData) > 0 {
- return errors.New(ex.errorData)
- }
- if len(ex.stdoutData) > 0 {
- for i := 0; i < ex.messageCount; i++ {
- fmt.Fprint(out, ex.stdoutData)
- }
- }
- if len(ex.stderrData) > 0 {
- for i := 0; i < ex.messageCount; i++ {
- fmt.Fprint(err, ex.stderrData)
- }
- }
- if ex.expectStdin {
- io.Copy(&ex.stdinReceived, in)
- }
- return nil
- }
- func fakeServer(t *testing.T, requestReceived chan struct{}, testName string, exec bool, stdinData, stdoutData, stderrData, errorData string, tty bool, messageCount int, serverProtocols []string) http.HandlerFunc {
- return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
- executor := &fakeExecutor{
- t: t,
- testName: testName,
- errorData: errorData,
- stdoutData: stdoutData,
- stderrData: stderrData,
- expectStdin: len(stdinData) > 0,
- tty: tty,
- messageCount: messageCount,
- exec: exec,
- }
- opts, err := remotecommand.NewOptions(req)
- require.NoError(t, err)
- if exec {
- cmd := req.URL.Query()[api.ExecCommandParam]
- remotecommand.ServeExec(w, req, executor, "pod", "uid", "container", cmd, opts, 0, 10*time.Second, serverProtocols)
- } else {
- remotecommand.ServeAttach(w, req, executor, "pod", "uid", "container", opts, 0, 10*time.Second, serverProtocols)
- }
- if e, a := strings.Repeat(stdinData, messageCount), executor.stdinReceived.String(); e != a {
- t.Errorf("%s: stdin: expected %q, got %q", testName, e, a)
- }
- close(requestReceived)
- })
- }
- func TestStream(t *testing.T) {
- testCases := []struct {
- TestName string
- Stdin string
- Stdout string
- Stderr string
- Error string
- Tty bool
- MessageCount int
- ClientProtocols []string
- ServerProtocols []string
- }{
- {
- TestName: "error",
- Error: "bail",
- Stdout: "a",
- ClientProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
- ServerProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
- },
- {
- TestName: "in/out/err",
- Stdin: "a",
- Stdout: "b",
- Stderr: "c",
- MessageCount: 100,
- ClientProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
- ServerProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
- },
- {
- TestName: "oversized stdin",
- Stdin: strings.Repeat("a", 20*1024*1024),
- Stdout: "b",
- Stderr: "",
- MessageCount: 1,
- ClientProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
- ServerProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
- },
- {
- TestName: "in/out/tty",
- Stdin: "a",
- Stdout: "b",
- Tty: true,
- MessageCount: 100,
- ClientProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
- ServerProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
- },
- }
- for _, testCase := range testCases {
- for _, exec := range []bool{true, false} {
- var name string
- if exec {
- name = testCase.TestName + " (exec)"
- } else {
- name = testCase.TestName + " (attach)"
- }
- var (
- streamIn io.Reader
- streamOut, streamErr io.Writer
- )
- localOut := &bytes.Buffer{}
- localErr := &bytes.Buffer{}
- requestReceived := make(chan struct{})
- server := httptest.NewServer(fakeServer(t, requestReceived, name, exec, testCase.Stdin, testCase.Stdout, testCase.Stderr, testCase.Error, testCase.Tty, testCase.MessageCount, testCase.ServerProtocols))
- url, _ := url.ParseRequestURI(server.URL)
- config := restclient.ClientContentConfig{
- GroupVersion: schema.GroupVersion{Group: "x"},
- Negotiator: runtime.NewClientNegotiator(legacyscheme.Codecs.WithoutConversion(), schema.GroupVersion{Group: "x"}),
- }
- c, err := restclient.NewRESTClient(url, "", config, nil, nil)
- if err != nil {
- t.Fatalf("failed to create a client: %v", err)
- }
- req := c.Post().Resource("testing")
- if exec {
- req.Param("command", "ls")
- req.Param("command", "/")
- }
- if len(testCase.Stdin) > 0 {
- req.Param(api.ExecStdinParam, "1")
- streamIn = strings.NewReader(strings.Repeat(testCase.Stdin, testCase.MessageCount))
- }
- if len(testCase.Stdout) > 0 {
- req.Param(api.ExecStdoutParam, "1")
- streamOut = localOut
- }
- if testCase.Tty {
- req.Param(api.ExecTTYParam, "1")
- } else if len(testCase.Stderr) > 0 {
- req.Param(api.ExecStderrParam, "1")
- streamErr = localErr
- }
- conf := &restclient.Config{
- Host: server.URL,
- }
- transport, upgradeTransport, err := spdy.RoundTripperFor(conf)
- if err != nil {
- t.Errorf("%s: unexpected error: %v", name, err)
- continue
- }
- e, err := remoteclient.NewSPDYExecutorForProtocols(transport, upgradeTransport, "POST", req.URL(), testCase.ClientProtocols...)
- if err != nil {
- t.Errorf("%s: unexpected error: %v", name, err)
- continue
- }
- err = e.Stream(remoteclient.StreamOptions{
- Stdin: streamIn,
- Stdout: streamOut,
- Stderr: streamErr,
- Tty: testCase.Tty,
- })
- hasErr := err != nil
- if len(testCase.Error) > 0 {
- if !hasErr {
- t.Errorf("%s: expected an error", name)
- } else {
- if e, a := testCase.Error, err.Error(); !strings.Contains(a, e) {
- t.Errorf("%s: expected error stream read %q, got %q", name, e, a)
- }
- }
- server.Close()
- continue
- }
- if hasErr {
- t.Errorf("%s: unexpected error: %v", name, err)
- server.Close()
- continue
- }
- if len(testCase.Stdout) > 0 {
- if e, a := strings.Repeat(testCase.Stdout, testCase.MessageCount), localOut; e != a.String() {
- t.Errorf("%s: expected stdout data %q, got %q", name, e, a)
- }
- }
- if testCase.Stderr != "" {
- if e, a := strings.Repeat(testCase.Stderr, testCase.MessageCount), localErr; e != a.String() {
- t.Errorf("%s: expected stderr data %q, got %q", name, e, a)
- }
- }
- select {
- case <-requestReceived:
- case <-time.After(time.Minute):
- t.Errorf("%s: expected fakeServerInstance to receive request", name)
- }
- server.Close()
- }
- }
- }
- type fakeUpgrader struct {
- req *http.Request
- resp *http.Response
- conn httpstream.Connection
- err, connErr error
- checkResponse bool
- called bool
- t *testing.T
- }
- func (u *fakeUpgrader) RoundTrip(req *http.Request) (*http.Response, error) {
- u.called = true
- u.req = req
- return u.resp, u.err
- }
- func (u *fakeUpgrader) NewConnection(resp *http.Response) (httpstream.Connection, error) {
- if u.checkResponse && u.resp != resp {
- u.t.Errorf("response objects passed did not match: %#v", resp)
- }
- return u.conn, u.connErr
- }
- type fakeConnection struct {
- httpstream.Connection
- }
- // Dial is the common functionality between any stream based upgrader, regardless of protocol.
- // This method ensures that someone can use a generic stream executor without being dependent
- // on the core Kube client config behavior.
- func TestDial(t *testing.T) {
- upgrader := &fakeUpgrader{
- t: t,
- checkResponse: true,
- conn: &fakeConnection{},
- resp: &http.Response{
- StatusCode: http.StatusSwitchingProtocols,
- Body: ioutil.NopCloser(&bytes.Buffer{}),
- },
- }
- dialer := spdy.NewDialer(upgrader, &http.Client{Transport: upgrader}, "POST", &url.URL{Host: "something.com", Scheme: "https"})
- conn, protocol, err := dialer.Dial("protocol1")
- if err != nil {
- t.Fatal(err)
- }
- if conn != upgrader.conn {
- t.Errorf("unexpected connection: %#v", conn)
- }
- if !upgrader.called {
- t.Errorf("request not called")
- }
- _ = protocol
- }
|