123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470 |
- /*
- Copyright 2016 The Kubernetes Authors.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
- package streaming
- import (
- "crypto/tls"
- "io"
- "net/http"
- "net/http/httptest"
- "net/url"
- "strconv"
- "strings"
- "sync"
- "testing"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- restclient "k8s.io/client-go/rest"
- "k8s.io/client-go/tools/remotecommand"
- "k8s.io/client-go/transport/spdy"
- runtimeapi "k8s.io/cri-api/pkg/apis/runtime/v1alpha2"
- api "k8s.io/kubernetes/pkg/apis/core"
- kubeletportforward "k8s.io/kubernetes/pkg/kubelet/server/portforward"
- )
// Fixture identifiers shared by the tests in this file.
const (
	// testAddr is the listen address (and base-URL host) of the servers
	// created by the URL-building tests.
	testAddr = "localhost:12345"

	// testPodSandboxID is the sandbox ID placed in port-forward requests and
	// checked by fakeRuntime.PortForward.
	testPodSandboxID = "pod0987"

	// testContainerID is the container ID placed in exec/attach requests and
	// checked by fakeRuntime.Exec and fakeRuntime.Attach.
	testContainerID = "container789"
)
- func TestGetExec(t *testing.T) {
- serv, err := NewServer(Config{
- Addr: testAddr,
- }, nil)
- assert.NoError(t, err)
- tlsServer, err := NewServer(Config{
- Addr: testAddr,
- TLSConfig: &tls.Config{},
- }, nil)
- assert.NoError(t, err)
- const pathPrefix = "cri/shim"
- prefixServer, err := NewServer(Config{
- Addr: testAddr,
- BaseURL: &url.URL{
- Scheme: "http",
- Host: testAddr,
- Path: "/" + pathPrefix + "/",
- },
- }, nil)
- assert.NoError(t, err)
- assertRequestToken := func(expectedReq *runtimeapi.ExecRequest, cache *requestCache, token string) {
- req, ok := cache.Consume(token)
- require.True(t, ok, "token %s not found!", token)
- assert.Equal(t, expectedReq, req)
- }
- request := &runtimeapi.ExecRequest{
- ContainerId: testContainerID,
- Cmd: []string{"echo", "foo"},
- Tty: true,
- Stdin: true,
- }
- { // Non-TLS
- resp, err := serv.GetExec(request)
- assert.NoError(t, err)
- expectedURL := "http://" + testAddr + "/exec/"
- assert.Contains(t, resp.Url, expectedURL)
- token := strings.TrimPrefix(resp.Url, expectedURL)
- assertRequestToken(request, serv.(*server).cache, token)
- }
- { // TLS
- resp, err := tlsServer.GetExec(request)
- assert.NoError(t, err)
- expectedURL := "https://" + testAddr + "/exec/"
- assert.Contains(t, resp.Url, expectedURL)
- token := strings.TrimPrefix(resp.Url, expectedURL)
- assertRequestToken(request, tlsServer.(*server).cache, token)
- }
- { // Path prefix
- resp, err := prefixServer.GetExec(request)
- assert.NoError(t, err)
- expectedURL := "http://" + testAddr + "/" + pathPrefix + "/exec/"
- assert.Contains(t, resp.Url, expectedURL)
- token := strings.TrimPrefix(resp.Url, expectedURL)
- assertRequestToken(request, prefixServer.(*server).cache, token)
- }
- }
- func TestValidateExecAttachRequest(t *testing.T) {
- type config struct {
- tty bool
- stdin bool
- stdout bool
- stderr bool
- }
- for _, tc := range []struct {
- desc string
- configs []config
- expectErr bool
- }{
- {
- desc: "at least one stream must be true",
- expectErr: true,
- configs: []config{
- {false, false, false, false},
- {true, false, false, false}},
- },
- {
- desc: "tty and stderr cannot both be true",
- expectErr: true,
- configs: []config{
- {true, false, false, true},
- {true, false, true, true},
- {true, true, false, true},
- {true, true, true, true},
- },
- },
- {
- desc: "a valid config should pass",
- expectErr: false,
- configs: []config{
- {false, false, false, true},
- {false, false, true, false},
- {false, false, true, true},
- {false, true, false, false},
- {false, true, false, true},
- {false, true, true, false},
- {false, true, true, true},
- {true, false, true, false},
- {true, true, false, false},
- {true, true, true, false},
- },
- },
- } {
- t.Run(tc.desc, func(t *testing.T) {
- for _, c := range tc.configs {
- // validate the exec request.
- execReq := &runtimeapi.ExecRequest{
- ContainerId: testContainerID,
- Cmd: []string{"date"},
- Tty: c.tty,
- Stdin: c.stdin,
- Stdout: c.stdout,
- Stderr: c.stderr,
- }
- err := validateExecRequest(execReq)
- assert.Equal(t, tc.expectErr, err != nil, "config: %v, err: %v", c, err)
- // validate the attach request.
- attachReq := &runtimeapi.AttachRequest{
- ContainerId: testContainerID,
- Tty: c.tty,
- Stdin: c.stdin,
- Stdout: c.stdout,
- Stderr: c.stderr,
- }
- err = validateAttachRequest(attachReq)
- assert.Equal(t, tc.expectErr, err != nil, "config: %v, err: %v", c, err)
- }
- })
- }
- }
- func TestGetAttach(t *testing.T) {
- serv, err := NewServer(Config{
- Addr: testAddr,
- }, nil)
- require.NoError(t, err)
- tlsServer, err := NewServer(Config{
- Addr: testAddr,
- TLSConfig: &tls.Config{},
- }, nil)
- require.NoError(t, err)
- assertRequestToken := func(expectedReq *runtimeapi.AttachRequest, cache *requestCache, token string) {
- req, ok := cache.Consume(token)
- require.True(t, ok, "token %s not found!", token)
- assert.Equal(t, expectedReq, req)
- }
- request := &runtimeapi.AttachRequest{
- ContainerId: testContainerID,
- Stdin: true,
- Tty: true,
- }
- { // Non-TLS
- resp, err := serv.GetAttach(request)
- assert.NoError(t, err)
- expectedURL := "http://" + testAddr + "/attach/"
- assert.Contains(t, resp.Url, expectedURL)
- token := strings.TrimPrefix(resp.Url, expectedURL)
- assertRequestToken(request, serv.(*server).cache, token)
- }
- { // TLS
- resp, err := tlsServer.GetAttach(request)
- assert.NoError(t, err)
- expectedURL := "https://" + testAddr + "/attach/"
- assert.Contains(t, resp.Url, expectedURL)
- token := strings.TrimPrefix(resp.Url, expectedURL)
- assertRequestToken(request, tlsServer.(*server).cache, token)
- }
- }
- func TestGetPortForward(t *testing.T) {
- podSandboxID := testPodSandboxID
- request := &runtimeapi.PortForwardRequest{
- PodSandboxId: podSandboxID,
- Port: []int32{1, 2, 3, 4},
- }
- { // Non-TLS
- serv, err := NewServer(Config{
- Addr: testAddr,
- }, nil)
- assert.NoError(t, err)
- resp, err := serv.GetPortForward(request)
- assert.NoError(t, err)
- expectedURL := "http://" + testAddr + "/portforward/"
- assert.True(t, strings.HasPrefix(resp.Url, expectedURL))
- token := strings.TrimPrefix(resp.Url, expectedURL)
- req, ok := serv.(*server).cache.Consume(token)
- require.True(t, ok, "token %s not found!", token)
- assert.Equal(t, testPodSandboxID, req.(*runtimeapi.PortForwardRequest).PodSandboxId)
- }
- { // TLS
- tlsServer, err := NewServer(Config{
- Addr: testAddr,
- TLSConfig: &tls.Config{},
- }, nil)
- assert.NoError(t, err)
- resp, err := tlsServer.GetPortForward(request)
- assert.NoError(t, err)
- expectedURL := "https://" + testAddr + "/portforward/"
- assert.True(t, strings.HasPrefix(resp.Url, expectedURL))
- token := strings.TrimPrefix(resp.Url, expectedURL)
- req, ok := tlsServer.(*server).cache.Consume(token)
- require.True(t, ok, "token %s not found!", token)
- assert.Equal(t, testPodSandboxID, req.(*runtimeapi.PortForwardRequest).PodSandboxId)
- }
- }
- func TestServeExec(t *testing.T) {
- runRemoteCommandTest(t, "exec")
- }
- func TestServeAttach(t *testing.T) {
- runRemoteCommandTest(t, "attach")
- }
- func TestServePortForward(t *testing.T) {
- s, testServer := startTestServer(t)
- defer testServer.Close()
- resp, err := s.GetPortForward(&runtimeapi.PortForwardRequest{
- PodSandboxId: testPodSandboxID,
- })
- require.NoError(t, err)
- reqURL, err := url.Parse(resp.Url)
- require.NoError(t, err)
- transport, upgrader, err := spdy.RoundTripperFor(&restclient.Config{})
- require.NoError(t, err)
- dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", reqURL)
- streamConn, _, err := dialer.Dial(kubeletportforward.ProtocolV1Name)
- require.NoError(t, err)
- defer streamConn.Close()
- // Create the streams.
- headers := http.Header{}
- // Error stream is required, but unused in this test.
- headers.Set(api.StreamType, api.StreamTypeError)
- headers.Set(api.PortHeader, strconv.Itoa(testPort))
- _, err = streamConn.CreateStream(headers)
- require.NoError(t, err)
- // Setup the data stream.
- headers.Set(api.StreamType, api.StreamTypeData)
- headers.Set(api.PortHeader, strconv.Itoa(testPort))
- stream, err := streamConn.CreateStream(headers)
- require.NoError(t, err)
- doClientStreams(t, "portforward", stream, stream, nil)
- }
- //
- // Run the remote command test.
- // commandType is either "exec" or "attach".
- func runRemoteCommandTest(t *testing.T, commandType string) {
- s, testServer := startTestServer(t)
- defer testServer.Close()
- var reqURL *url.URL
- stdin, stdout, stderr := true, true, true
- containerID := testContainerID
- switch commandType {
- case "exec":
- resp, err := s.GetExec(&runtimeapi.ExecRequest{
- ContainerId: containerID,
- Cmd: []string{"echo"},
- Stdin: stdin,
- Stdout: stdout,
- Stderr: stderr,
- })
- require.NoError(t, err)
- reqURL, err = url.Parse(resp.Url)
- require.NoError(t, err)
- case "attach":
- resp, err := s.GetAttach(&runtimeapi.AttachRequest{
- ContainerId: containerID,
- Stdin: stdin,
- Stdout: stdout,
- Stderr: stderr,
- })
- require.NoError(t, err)
- reqURL, err = url.Parse(resp.Url)
- require.NoError(t, err)
- }
- wg := sync.WaitGroup{}
- wg.Add(2)
- stdinR, stdinW := io.Pipe()
- stdoutR, stdoutW := io.Pipe()
- stderrR, stderrW := io.Pipe()
- go func() {
- defer wg.Done()
- exec, err := remotecommand.NewSPDYExecutor(&restclient.Config{}, "POST", reqURL)
- require.NoError(t, err)
- opts := remotecommand.StreamOptions{
- Stdin: stdinR,
- Stdout: stdoutW,
- Stderr: stderrW,
- Tty: false,
- }
- require.NoError(t, exec.Stream(opts))
- }()
- go func() {
- defer wg.Done()
- doClientStreams(t, commandType, stdinW, stdoutR, stderrR)
- }()
- wg.Wait()
- // Repeat request with the same URL should be a 404.
- resp, err := http.Get(reqURL.String())
- require.NoError(t, err)
- assert.Equal(t, http.StatusNotFound, resp.StatusCode)
- }
- func startTestServer(t *testing.T) (Server, *httptest.Server) {
- var s Server
- testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- s.ServeHTTP(w, r)
- }))
- cleanup := true
- defer func() {
- if cleanup {
- testServer.Close()
- }
- }()
- testURL, err := url.Parse(testServer.URL)
- require.NoError(t, err)
- rt := newFakeRuntime(t)
- config := DefaultConfig
- config.BaseURL = testURL
- s, err = NewServer(config, rt)
- require.NoError(t, err)
- cleanup = false // Caller must close the test server.
- return s, testServer
- }
// Values exchanged during the streaming tests.
const (
	// testErr is written by the server to stderr (prefixed with the command
	// type) and read back by the client.
	testErr = "ERROR!!!"

	// testInput is written by the client to stdin (prefixed with the command
	// type) and read by the server.
	testInput = "abcdefg"

	// testOutput is written by the server to stdout (prefixed with the
	// command type) and read by the client.
	testOutput = "fooBARbaz"

	// testPort is the port requested by TestServePortForward and checked by
	// fakeRuntime.PortForward.
	testPort = 12345
)
- func newFakeRuntime(t *testing.T) *fakeRuntime {
- return &fakeRuntime{
- t: t,
- }
- }
- type fakeRuntime struct {
- t *testing.T
- }
- func (f *fakeRuntime) Exec(containerID string, cmd []string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error {
- assert.Equal(f.t, testContainerID, containerID)
- doServerStreams(f.t, "exec", stdin, stdout, stderr)
- return nil
- }
- func (f *fakeRuntime) Attach(containerID string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error {
- assert.Equal(f.t, testContainerID, containerID)
- doServerStreams(f.t, "attach", stdin, stdout, stderr)
- return nil
- }
- func (f *fakeRuntime) PortForward(podSandboxID string, port int32, stream io.ReadWriteCloser) error {
- assert.Equal(f.t, testPodSandboxID, podSandboxID)
- assert.EqualValues(f.t, testPort, port)
- doServerStreams(f.t, "portforward", stream, stream, nil)
- return nil
- }
- // Send & receive expected input/output. Must be the inverse of doClientStreams.
- // Function will block until the expected i/o is finished.
- func doServerStreams(t *testing.T, prefix string, stdin io.Reader, stdout, stderr io.Writer) {
- if stderr != nil {
- writeExpected(t, "server stderr", stderr, prefix+testErr)
- }
- readExpected(t, "server stdin", stdin, prefix+testInput)
- writeExpected(t, "server stdout", stdout, prefix+testOutput)
- }
- // Send & receive expected input/output. Must be the inverse of doServerStreams.
- // Function will block until the expected i/o is finished.
- func doClientStreams(t *testing.T, prefix string, stdin io.Writer, stdout, stderr io.Reader) {
- if stderr != nil {
- readExpected(t, "client stderr", stderr, prefix+testErr)
- }
- writeExpected(t, "client stdin", stdin, prefix+testInput)
- readExpected(t, "client stdout", stdout, prefix+testOutput)
- }
- // Read and verify the expected string from the stream.
- func readExpected(t *testing.T, streamName string, r io.Reader, expected string) {
- result := make([]byte, len(expected))
- _, err := io.ReadAtLeast(r, result, len(expected))
- assert.NoError(t, err, "stream %s", streamName)
- assert.Equal(t, expected, string(result), "stream %s", streamName)
- }
- // Write and verify success of the data over the stream.
- func writeExpected(t *testing.T, streamName string, w io.Writer, data string) {
- n, err := io.WriteString(w, data)
- assert.NoError(t, err, "stream %s", streamName)
- assert.Equal(t, len(data), n, "stream %s", streamName)
- }
|