server_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470
  1. /*
  2. Copyright 2016 The Kubernetes Authors.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package streaming
  14. import (
  15. "crypto/tls"
  16. "io"
  17. "net/http"
  18. "net/http/httptest"
  19. "net/url"
  20. "strconv"
  21. "strings"
  22. "sync"
  23. "testing"
  24. "github.com/stretchr/testify/assert"
  25. "github.com/stretchr/testify/require"
  26. restclient "k8s.io/client-go/rest"
  27. "k8s.io/client-go/tools/remotecommand"
  28. "k8s.io/client-go/transport/spdy"
  29. runtimeapi "k8s.io/cri-api/pkg/apis/runtime/v1alpha2"
  30. api "k8s.io/kubernetes/pkg/apis/core"
  31. kubeletportforward "k8s.io/kubernetes/pkg/kubelet/server/portforward"
  32. )
  33. const (
  34. testAddr = "localhost:12345"
  35. testContainerID = "container789"
  36. testPodSandboxID = "pod0987"
  37. )
  38. func TestGetExec(t *testing.T) {
  39. serv, err := NewServer(Config{
  40. Addr: testAddr,
  41. }, nil)
  42. assert.NoError(t, err)
  43. tlsServer, err := NewServer(Config{
  44. Addr: testAddr,
  45. TLSConfig: &tls.Config{},
  46. }, nil)
  47. assert.NoError(t, err)
  48. const pathPrefix = "cri/shim"
  49. prefixServer, err := NewServer(Config{
  50. Addr: testAddr,
  51. BaseURL: &url.URL{
  52. Scheme: "http",
  53. Host: testAddr,
  54. Path: "/" + pathPrefix + "/",
  55. },
  56. }, nil)
  57. assert.NoError(t, err)
  58. assertRequestToken := func(expectedReq *runtimeapi.ExecRequest, cache *requestCache, token string) {
  59. req, ok := cache.Consume(token)
  60. require.True(t, ok, "token %s not found!", token)
  61. assert.Equal(t, expectedReq, req)
  62. }
  63. request := &runtimeapi.ExecRequest{
  64. ContainerId: testContainerID,
  65. Cmd: []string{"echo", "foo"},
  66. Tty: true,
  67. Stdin: true,
  68. }
  69. { // Non-TLS
  70. resp, err := serv.GetExec(request)
  71. assert.NoError(t, err)
  72. expectedURL := "http://" + testAddr + "/exec/"
  73. assert.Contains(t, resp.Url, expectedURL)
  74. token := strings.TrimPrefix(resp.Url, expectedURL)
  75. assertRequestToken(request, serv.(*server).cache, token)
  76. }
  77. { // TLS
  78. resp, err := tlsServer.GetExec(request)
  79. assert.NoError(t, err)
  80. expectedURL := "https://" + testAddr + "/exec/"
  81. assert.Contains(t, resp.Url, expectedURL)
  82. token := strings.TrimPrefix(resp.Url, expectedURL)
  83. assertRequestToken(request, tlsServer.(*server).cache, token)
  84. }
  85. { // Path prefix
  86. resp, err := prefixServer.GetExec(request)
  87. assert.NoError(t, err)
  88. expectedURL := "http://" + testAddr + "/" + pathPrefix + "/exec/"
  89. assert.Contains(t, resp.Url, expectedURL)
  90. token := strings.TrimPrefix(resp.Url, expectedURL)
  91. assertRequestToken(request, prefixServer.(*server).cache, token)
  92. }
  93. }
  94. func TestValidateExecAttachRequest(t *testing.T) {
  95. type config struct {
  96. tty bool
  97. stdin bool
  98. stdout bool
  99. stderr bool
  100. }
  101. for _, tc := range []struct {
  102. desc string
  103. configs []config
  104. expectErr bool
  105. }{
  106. {
  107. desc: "at least one stream must be true",
  108. expectErr: true,
  109. configs: []config{
  110. {false, false, false, false},
  111. {true, false, false, false}},
  112. },
  113. {
  114. desc: "tty and stderr cannot both be true",
  115. expectErr: true,
  116. configs: []config{
  117. {true, false, false, true},
  118. {true, false, true, true},
  119. {true, true, false, true},
  120. {true, true, true, true},
  121. },
  122. },
  123. {
  124. desc: "a valid config should pass",
  125. expectErr: false,
  126. configs: []config{
  127. {false, false, false, true},
  128. {false, false, true, false},
  129. {false, false, true, true},
  130. {false, true, false, false},
  131. {false, true, false, true},
  132. {false, true, true, false},
  133. {false, true, true, true},
  134. {true, false, true, false},
  135. {true, true, false, false},
  136. {true, true, true, false},
  137. },
  138. },
  139. } {
  140. t.Run(tc.desc, func(t *testing.T) {
  141. for _, c := range tc.configs {
  142. // validate the exec request.
  143. execReq := &runtimeapi.ExecRequest{
  144. ContainerId: testContainerID,
  145. Cmd: []string{"date"},
  146. Tty: c.tty,
  147. Stdin: c.stdin,
  148. Stdout: c.stdout,
  149. Stderr: c.stderr,
  150. }
  151. err := validateExecRequest(execReq)
  152. assert.Equal(t, tc.expectErr, err != nil, "config: %v, err: %v", c, err)
  153. // validate the attach request.
  154. attachReq := &runtimeapi.AttachRequest{
  155. ContainerId: testContainerID,
  156. Tty: c.tty,
  157. Stdin: c.stdin,
  158. Stdout: c.stdout,
  159. Stderr: c.stderr,
  160. }
  161. err = validateAttachRequest(attachReq)
  162. assert.Equal(t, tc.expectErr, err != nil, "config: %v, err: %v", c, err)
  163. }
  164. })
  165. }
  166. }
  167. func TestGetAttach(t *testing.T) {
  168. serv, err := NewServer(Config{
  169. Addr: testAddr,
  170. }, nil)
  171. require.NoError(t, err)
  172. tlsServer, err := NewServer(Config{
  173. Addr: testAddr,
  174. TLSConfig: &tls.Config{},
  175. }, nil)
  176. require.NoError(t, err)
  177. assertRequestToken := func(expectedReq *runtimeapi.AttachRequest, cache *requestCache, token string) {
  178. req, ok := cache.Consume(token)
  179. require.True(t, ok, "token %s not found!", token)
  180. assert.Equal(t, expectedReq, req)
  181. }
  182. request := &runtimeapi.AttachRequest{
  183. ContainerId: testContainerID,
  184. Stdin: true,
  185. Tty: true,
  186. }
  187. { // Non-TLS
  188. resp, err := serv.GetAttach(request)
  189. assert.NoError(t, err)
  190. expectedURL := "http://" + testAddr + "/attach/"
  191. assert.Contains(t, resp.Url, expectedURL)
  192. token := strings.TrimPrefix(resp.Url, expectedURL)
  193. assertRequestToken(request, serv.(*server).cache, token)
  194. }
  195. { // TLS
  196. resp, err := tlsServer.GetAttach(request)
  197. assert.NoError(t, err)
  198. expectedURL := "https://" + testAddr + "/attach/"
  199. assert.Contains(t, resp.Url, expectedURL)
  200. token := strings.TrimPrefix(resp.Url, expectedURL)
  201. assertRequestToken(request, tlsServer.(*server).cache, token)
  202. }
  203. }
  204. func TestGetPortForward(t *testing.T) {
  205. podSandboxID := testPodSandboxID
  206. request := &runtimeapi.PortForwardRequest{
  207. PodSandboxId: podSandboxID,
  208. Port: []int32{1, 2, 3, 4},
  209. }
  210. { // Non-TLS
  211. serv, err := NewServer(Config{
  212. Addr: testAddr,
  213. }, nil)
  214. assert.NoError(t, err)
  215. resp, err := serv.GetPortForward(request)
  216. assert.NoError(t, err)
  217. expectedURL := "http://" + testAddr + "/portforward/"
  218. assert.True(t, strings.HasPrefix(resp.Url, expectedURL))
  219. token := strings.TrimPrefix(resp.Url, expectedURL)
  220. req, ok := serv.(*server).cache.Consume(token)
  221. require.True(t, ok, "token %s not found!", token)
  222. assert.Equal(t, testPodSandboxID, req.(*runtimeapi.PortForwardRequest).PodSandboxId)
  223. }
  224. { // TLS
  225. tlsServer, err := NewServer(Config{
  226. Addr: testAddr,
  227. TLSConfig: &tls.Config{},
  228. }, nil)
  229. assert.NoError(t, err)
  230. resp, err := tlsServer.GetPortForward(request)
  231. assert.NoError(t, err)
  232. expectedURL := "https://" + testAddr + "/portforward/"
  233. assert.True(t, strings.HasPrefix(resp.Url, expectedURL))
  234. token := strings.TrimPrefix(resp.Url, expectedURL)
  235. req, ok := tlsServer.(*server).cache.Consume(token)
  236. require.True(t, ok, "token %s not found!", token)
  237. assert.Equal(t, testPodSandboxID, req.(*runtimeapi.PortForwardRequest).PodSandboxId)
  238. }
  239. }
  240. func TestServeExec(t *testing.T) {
  241. runRemoteCommandTest(t, "exec")
  242. }
  243. func TestServeAttach(t *testing.T) {
  244. runRemoteCommandTest(t, "attach")
  245. }
  246. func TestServePortForward(t *testing.T) {
  247. s, testServer := startTestServer(t)
  248. defer testServer.Close()
  249. resp, err := s.GetPortForward(&runtimeapi.PortForwardRequest{
  250. PodSandboxId: testPodSandboxID,
  251. })
  252. require.NoError(t, err)
  253. reqURL, err := url.Parse(resp.Url)
  254. require.NoError(t, err)
  255. transport, upgrader, err := spdy.RoundTripperFor(&restclient.Config{})
  256. require.NoError(t, err)
  257. dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", reqURL)
  258. streamConn, _, err := dialer.Dial(kubeletportforward.ProtocolV1Name)
  259. require.NoError(t, err)
  260. defer streamConn.Close()
  261. // Create the streams.
  262. headers := http.Header{}
  263. // Error stream is required, but unused in this test.
  264. headers.Set(api.StreamType, api.StreamTypeError)
  265. headers.Set(api.PortHeader, strconv.Itoa(testPort))
  266. _, err = streamConn.CreateStream(headers)
  267. require.NoError(t, err)
  268. // Setup the data stream.
  269. headers.Set(api.StreamType, api.StreamTypeData)
  270. headers.Set(api.PortHeader, strconv.Itoa(testPort))
  271. stream, err := streamConn.CreateStream(headers)
  272. require.NoError(t, err)
  273. doClientStreams(t, "portforward", stream, stream, nil)
  274. }
  275. //
  276. // Run the remote command test.
  277. // commandType is either "exec" or "attach".
  278. func runRemoteCommandTest(t *testing.T, commandType string) {
  279. s, testServer := startTestServer(t)
  280. defer testServer.Close()
  281. var reqURL *url.URL
  282. stdin, stdout, stderr := true, true, true
  283. containerID := testContainerID
  284. switch commandType {
  285. case "exec":
  286. resp, err := s.GetExec(&runtimeapi.ExecRequest{
  287. ContainerId: containerID,
  288. Cmd: []string{"echo"},
  289. Stdin: stdin,
  290. Stdout: stdout,
  291. Stderr: stderr,
  292. })
  293. require.NoError(t, err)
  294. reqURL, err = url.Parse(resp.Url)
  295. require.NoError(t, err)
  296. case "attach":
  297. resp, err := s.GetAttach(&runtimeapi.AttachRequest{
  298. ContainerId: containerID,
  299. Stdin: stdin,
  300. Stdout: stdout,
  301. Stderr: stderr,
  302. })
  303. require.NoError(t, err)
  304. reqURL, err = url.Parse(resp.Url)
  305. require.NoError(t, err)
  306. }
  307. wg := sync.WaitGroup{}
  308. wg.Add(2)
  309. stdinR, stdinW := io.Pipe()
  310. stdoutR, stdoutW := io.Pipe()
  311. stderrR, stderrW := io.Pipe()
  312. go func() {
  313. defer wg.Done()
  314. exec, err := remotecommand.NewSPDYExecutor(&restclient.Config{}, "POST", reqURL)
  315. require.NoError(t, err)
  316. opts := remotecommand.StreamOptions{
  317. Stdin: stdinR,
  318. Stdout: stdoutW,
  319. Stderr: stderrW,
  320. Tty: false,
  321. }
  322. require.NoError(t, exec.Stream(opts))
  323. }()
  324. go func() {
  325. defer wg.Done()
  326. doClientStreams(t, commandType, stdinW, stdoutR, stderrR)
  327. }()
  328. wg.Wait()
  329. // Repeat request with the same URL should be a 404.
  330. resp, err := http.Get(reqURL.String())
  331. require.NoError(t, err)
  332. assert.Equal(t, http.StatusNotFound, resp.StatusCode)
  333. }
  334. func startTestServer(t *testing.T) (Server, *httptest.Server) {
  335. var s Server
  336. testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  337. s.ServeHTTP(w, r)
  338. }))
  339. cleanup := true
  340. defer func() {
  341. if cleanup {
  342. testServer.Close()
  343. }
  344. }()
  345. testURL, err := url.Parse(testServer.URL)
  346. require.NoError(t, err)
  347. rt := newFakeRuntime(t)
  348. config := DefaultConfig
  349. config.BaseURL = testURL
  350. s, err = NewServer(config, rt)
  351. require.NoError(t, err)
  352. cleanup = false // Caller must close the test server.
  353. return s, testServer
  354. }
  355. const (
  356. testInput = "abcdefg"
  357. testOutput = "fooBARbaz"
  358. testErr = "ERROR!!!"
  359. testPort = 12345
  360. )
  361. func newFakeRuntime(t *testing.T) *fakeRuntime {
  362. return &fakeRuntime{
  363. t: t,
  364. }
  365. }
  366. type fakeRuntime struct {
  367. t *testing.T
  368. }
  369. func (f *fakeRuntime) Exec(containerID string, cmd []string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error {
  370. assert.Equal(f.t, testContainerID, containerID)
  371. doServerStreams(f.t, "exec", stdin, stdout, stderr)
  372. return nil
  373. }
  374. func (f *fakeRuntime) Attach(containerID string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error {
  375. assert.Equal(f.t, testContainerID, containerID)
  376. doServerStreams(f.t, "attach", stdin, stdout, stderr)
  377. return nil
  378. }
  379. func (f *fakeRuntime) PortForward(podSandboxID string, port int32, stream io.ReadWriteCloser) error {
  380. assert.Equal(f.t, testPodSandboxID, podSandboxID)
  381. assert.EqualValues(f.t, testPort, port)
  382. doServerStreams(f.t, "portforward", stream, stream, nil)
  383. return nil
  384. }
  385. // Send & receive expected input/output. Must be the inverse of doClientStreams.
  386. // Function will block until the expected i/o is finished.
  387. func doServerStreams(t *testing.T, prefix string, stdin io.Reader, stdout, stderr io.Writer) {
  388. if stderr != nil {
  389. writeExpected(t, "server stderr", stderr, prefix+testErr)
  390. }
  391. readExpected(t, "server stdin", stdin, prefix+testInput)
  392. writeExpected(t, "server stdout", stdout, prefix+testOutput)
  393. }
  394. // Send & receive expected input/output. Must be the inverse of doServerStreams.
  395. // Function will block until the expected i/o is finished.
  396. func doClientStreams(t *testing.T, prefix string, stdin io.Writer, stdout, stderr io.Reader) {
  397. if stderr != nil {
  398. readExpected(t, "client stderr", stderr, prefix+testErr)
  399. }
  400. writeExpected(t, "client stdin", stdin, prefix+testInput)
  401. readExpected(t, "client stdout", stdout, prefix+testOutput)
  402. }
  403. // Read and verify the expected string from the stream.
  404. func readExpected(t *testing.T, streamName string, r io.Reader, expected string) {
  405. result := make([]byte, len(expected))
  406. _, err := io.ReadAtLeast(r, result, len(expected))
  407. assert.NoError(t, err, "stream %s", streamName)
  408. assert.Equal(t, expected, string(result), "stream %s", streamName)
  409. }
  410. // Write and verify success of the data over the stream.
  411. func writeExpected(t *testing.T, streamName string, w io.Writer, data string) {
  412. n, err := io.WriteString(w, data)
  413. assert.NoError(t, err, "stream %s", streamName)
  414. assert.Equal(t, len(data), n, "stream %s", streamName)
  415. }