remotecommand_test.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358
  1. /*
  2. Copyright 2015 The Kubernetes Authors.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package tests
  14. import (
  15. "bytes"
  16. "errors"
  17. "fmt"
  18. "io"
  19. "io/ioutil"
  20. "net/http"
  21. "net/http/httptest"
  22. "net/url"
  23. "strings"
  24. "testing"
  25. "time"
  26. "github.com/stretchr/testify/require"
  27. "k8s.io/apimachinery/pkg/runtime"
  28. "k8s.io/apimachinery/pkg/runtime/schema"
  29. "k8s.io/apimachinery/pkg/types"
  30. "k8s.io/apimachinery/pkg/util/httpstream"
  31. remotecommandconsts "k8s.io/apimachinery/pkg/util/remotecommand"
  32. restclient "k8s.io/client-go/rest"
  33. remoteclient "k8s.io/client-go/tools/remotecommand"
  34. "k8s.io/client-go/transport/spdy"
  35. "k8s.io/kubernetes/pkg/api/legacyscheme"
  36. api "k8s.io/kubernetes/pkg/apis/core"
  37. "k8s.io/kubernetes/pkg/kubelet/server/remotecommand"
  38. )
  39. type fakeExecutor struct {
  40. t *testing.T
  41. testName string
  42. errorData string
  43. stdoutData string
  44. stderrData string
  45. expectStdin bool
  46. stdinReceived bytes.Buffer
  47. tty bool
  48. messageCount int
  49. command []string
  50. exec bool
  51. }
  52. func (ex *fakeExecutor) ExecInContainer(name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan remoteclient.TerminalSize, timeout time.Duration) error {
  53. return ex.run(name, uid, container, cmd, in, out, err, tty)
  54. }
  55. func (ex *fakeExecutor) AttachContainer(name string, uid types.UID, container string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan remoteclient.TerminalSize) error {
  56. return ex.run(name, uid, container, nil, in, out, err, tty)
  57. }
  58. func (ex *fakeExecutor) run(name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error {
  59. ex.command = cmd
  60. ex.tty = tty
  61. if e, a := "pod", name; e != a {
  62. ex.t.Errorf("%s: pod: expected %q, got %q", ex.testName, e, a)
  63. }
  64. if e, a := "uid", uid; e != string(a) {
  65. ex.t.Errorf("%s: uid: expected %q, got %q", ex.testName, e, a)
  66. }
  67. if ex.exec {
  68. if e, a := "ls /", strings.Join(ex.command, " "); e != a {
  69. ex.t.Errorf("%s: command: expected %q, got %q", ex.testName, e, a)
  70. }
  71. } else {
  72. if len(ex.command) > 0 {
  73. ex.t.Errorf("%s: command: expected nothing, got %v", ex.testName, ex.command)
  74. }
  75. }
  76. if len(ex.errorData) > 0 {
  77. return errors.New(ex.errorData)
  78. }
  79. if len(ex.stdoutData) > 0 {
  80. for i := 0; i < ex.messageCount; i++ {
  81. fmt.Fprint(out, ex.stdoutData)
  82. }
  83. }
  84. if len(ex.stderrData) > 0 {
  85. for i := 0; i < ex.messageCount; i++ {
  86. fmt.Fprint(err, ex.stderrData)
  87. }
  88. }
  89. if ex.expectStdin {
  90. io.Copy(&ex.stdinReceived, in)
  91. }
  92. return nil
  93. }
  94. func fakeServer(t *testing.T, requestReceived chan struct{}, testName string, exec bool, stdinData, stdoutData, stderrData, errorData string, tty bool, messageCount int, serverProtocols []string) http.HandlerFunc {
  95. return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
  96. executor := &fakeExecutor{
  97. t: t,
  98. testName: testName,
  99. errorData: errorData,
  100. stdoutData: stdoutData,
  101. stderrData: stderrData,
  102. expectStdin: len(stdinData) > 0,
  103. tty: tty,
  104. messageCount: messageCount,
  105. exec: exec,
  106. }
  107. opts, err := remotecommand.NewOptions(req)
  108. require.NoError(t, err)
  109. if exec {
  110. cmd := req.URL.Query()[api.ExecCommandParam]
  111. remotecommand.ServeExec(w, req, executor, "pod", "uid", "container", cmd, opts, 0, 10*time.Second, serverProtocols)
  112. } else {
  113. remotecommand.ServeAttach(w, req, executor, "pod", "uid", "container", opts, 0, 10*time.Second, serverProtocols)
  114. }
  115. if e, a := strings.Repeat(stdinData, messageCount), executor.stdinReceived.String(); e != a {
  116. t.Errorf("%s: stdin: expected %q, got %q", testName, e, a)
  117. }
  118. close(requestReceived)
  119. })
  120. }
  121. func TestStream(t *testing.T) {
  122. testCases := []struct {
  123. TestName string
  124. Stdin string
  125. Stdout string
  126. Stderr string
  127. Error string
  128. Tty bool
  129. MessageCount int
  130. ClientProtocols []string
  131. ServerProtocols []string
  132. }{
  133. {
  134. TestName: "error",
  135. Error: "bail",
  136. Stdout: "a",
  137. ClientProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
  138. ServerProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
  139. },
  140. {
  141. TestName: "in/out/err",
  142. Stdin: "a",
  143. Stdout: "b",
  144. Stderr: "c",
  145. MessageCount: 100,
  146. ClientProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
  147. ServerProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
  148. },
  149. {
  150. TestName: "oversized stdin",
  151. Stdin: strings.Repeat("a", 20*1024*1024),
  152. Stdout: "b",
  153. Stderr: "",
  154. MessageCount: 1,
  155. ClientProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
  156. ServerProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
  157. },
  158. {
  159. TestName: "in/out/tty",
  160. Stdin: "a",
  161. Stdout: "b",
  162. Tty: true,
  163. MessageCount: 100,
  164. ClientProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
  165. ServerProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
  166. },
  167. }
  168. for _, testCase := range testCases {
  169. for _, exec := range []bool{true, false} {
  170. var name string
  171. if exec {
  172. name = testCase.TestName + " (exec)"
  173. } else {
  174. name = testCase.TestName + " (attach)"
  175. }
  176. var (
  177. streamIn io.Reader
  178. streamOut, streamErr io.Writer
  179. )
  180. localOut := &bytes.Buffer{}
  181. localErr := &bytes.Buffer{}
  182. requestReceived := make(chan struct{})
  183. server := httptest.NewServer(fakeServer(t, requestReceived, name, exec, testCase.Stdin, testCase.Stdout, testCase.Stderr, testCase.Error, testCase.Tty, testCase.MessageCount, testCase.ServerProtocols))
  184. url, _ := url.ParseRequestURI(server.URL)
  185. config := restclient.ClientContentConfig{
  186. GroupVersion: schema.GroupVersion{Group: "x"},
  187. Negotiator: runtime.NewClientNegotiator(legacyscheme.Codecs.WithoutConversion(), schema.GroupVersion{Group: "x"}),
  188. }
  189. c, err := restclient.NewRESTClient(url, "", config, nil, nil)
  190. if err != nil {
  191. t.Fatalf("failed to create a client: %v", err)
  192. }
  193. req := c.Post().Resource("testing")
  194. if exec {
  195. req.Param("command", "ls")
  196. req.Param("command", "/")
  197. }
  198. if len(testCase.Stdin) > 0 {
  199. req.Param(api.ExecStdinParam, "1")
  200. streamIn = strings.NewReader(strings.Repeat(testCase.Stdin, testCase.MessageCount))
  201. }
  202. if len(testCase.Stdout) > 0 {
  203. req.Param(api.ExecStdoutParam, "1")
  204. streamOut = localOut
  205. }
  206. if testCase.Tty {
  207. req.Param(api.ExecTTYParam, "1")
  208. } else if len(testCase.Stderr) > 0 {
  209. req.Param(api.ExecStderrParam, "1")
  210. streamErr = localErr
  211. }
  212. conf := &restclient.Config{
  213. Host: server.URL,
  214. }
  215. transport, upgradeTransport, err := spdy.RoundTripperFor(conf)
  216. if err != nil {
  217. t.Errorf("%s: unexpected error: %v", name, err)
  218. continue
  219. }
  220. e, err := remoteclient.NewSPDYExecutorForProtocols(transport, upgradeTransport, "POST", req.URL(), testCase.ClientProtocols...)
  221. if err != nil {
  222. t.Errorf("%s: unexpected error: %v", name, err)
  223. continue
  224. }
  225. err = e.Stream(remoteclient.StreamOptions{
  226. Stdin: streamIn,
  227. Stdout: streamOut,
  228. Stderr: streamErr,
  229. Tty: testCase.Tty,
  230. })
  231. hasErr := err != nil
  232. if len(testCase.Error) > 0 {
  233. if !hasErr {
  234. t.Errorf("%s: expected an error", name)
  235. } else {
  236. if e, a := testCase.Error, err.Error(); !strings.Contains(a, e) {
  237. t.Errorf("%s: expected error stream read %q, got %q", name, e, a)
  238. }
  239. }
  240. server.Close()
  241. continue
  242. }
  243. if hasErr {
  244. t.Errorf("%s: unexpected error: %v", name, err)
  245. server.Close()
  246. continue
  247. }
  248. if len(testCase.Stdout) > 0 {
  249. if e, a := strings.Repeat(testCase.Stdout, testCase.MessageCount), localOut; e != a.String() {
  250. t.Errorf("%s: expected stdout data %q, got %q", name, e, a)
  251. }
  252. }
  253. if testCase.Stderr != "" {
  254. if e, a := strings.Repeat(testCase.Stderr, testCase.MessageCount), localErr; e != a.String() {
  255. t.Errorf("%s: expected stderr data %q, got %q", name, e, a)
  256. }
  257. }
  258. select {
  259. case <-requestReceived:
  260. case <-time.After(time.Minute):
  261. t.Errorf("%s: expected fakeServerInstance to receive request", name)
  262. }
  263. server.Close()
  264. }
  265. }
  266. }
  267. type fakeUpgrader struct {
  268. req *http.Request
  269. resp *http.Response
  270. conn httpstream.Connection
  271. err, connErr error
  272. checkResponse bool
  273. called bool
  274. t *testing.T
  275. }
  276. func (u *fakeUpgrader) RoundTrip(req *http.Request) (*http.Response, error) {
  277. u.called = true
  278. u.req = req
  279. return u.resp, u.err
  280. }
  281. func (u *fakeUpgrader) NewConnection(resp *http.Response) (httpstream.Connection, error) {
  282. if u.checkResponse && u.resp != resp {
  283. u.t.Errorf("response objects passed did not match: %#v", resp)
  284. }
  285. return u.conn, u.connErr
  286. }
  287. type fakeConnection struct {
  288. httpstream.Connection
  289. }
  290. // Dial is the common functionality between any stream based upgrader, regardless of protocol.
  291. // This method ensures that someone can use a generic stream executor without being dependent
  292. // on the core Kube client config behavior.
  293. func TestDial(t *testing.T) {
  294. upgrader := &fakeUpgrader{
  295. t: t,
  296. checkResponse: true,
  297. conn: &fakeConnection{},
  298. resp: &http.Response{
  299. StatusCode: http.StatusSwitchingProtocols,
  300. Body: ioutil.NopCloser(&bytes.Buffer{}),
  301. },
  302. }
  303. dialer := spdy.NewDialer(upgrader, &http.Client{Transport: upgrader}, "POST", &url.URL{Host: "something.com", Scheme: "https"})
  304. conn, protocol, err := dialer.Dial("protocol1")
  305. if err != nil {
  306. t.Fatal(err)
  307. }
  308. if conn != upgrader.conn {
  309. t.Errorf("unexpected connection: %#v", conn)
  310. }
  311. if !upgrader.called {
  312. t.Errorf("request not called")
  313. }
  314. _ = protocol
  315. }