httpstream.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. /*
  2. Copyright 2016 The Kubernetes Authors.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package portforward
  14. import (
  15. "errors"
  16. "fmt"
  17. "net/http"
  18. "strconv"
  19. "sync"
  20. "time"
  21. "k8s.io/apimachinery/pkg/types"
  22. "k8s.io/apimachinery/pkg/util/httpstream"
  23. "k8s.io/apimachinery/pkg/util/httpstream/spdy"
  24. utilruntime "k8s.io/apimachinery/pkg/util/runtime"
  25. api "k8s.io/kubernetes/pkg/apis/core"
  26. "k8s.io/klog"
  27. )
  28. func handleHTTPStreams(req *http.Request, w http.ResponseWriter, portForwarder PortForwarder, podName string, uid types.UID, supportedPortForwardProtocols []string, idleTimeout, streamCreationTimeout time.Duration) error {
  29. _, err := httpstream.Handshake(req, w, supportedPortForwardProtocols)
  30. // negotiated protocol isn't currently used server side, but could be in the future
  31. if err != nil {
  32. // Handshake writes the error to the client
  33. return err
  34. }
  35. streamChan := make(chan httpstream.Stream, 1)
  36. klog.V(5).Infof("Upgrading port forward response")
  37. upgrader := spdy.NewResponseUpgrader()
  38. conn := upgrader.UpgradeResponse(w, req, httpStreamReceived(streamChan))
  39. if conn == nil {
  40. return errors.New("Unable to upgrade httpstream connection")
  41. }
  42. defer conn.Close()
  43. klog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout)
  44. conn.SetIdleTimeout(idleTimeout)
  45. h := &httpStreamHandler{
  46. conn: conn,
  47. streamChan: streamChan,
  48. streamPairs: make(map[string]*httpStreamPair),
  49. streamCreationTimeout: streamCreationTimeout,
  50. pod: podName,
  51. uid: uid,
  52. forwarder: portForwarder,
  53. }
  54. h.run()
  55. return nil
  56. }
  57. // httpStreamReceived is the httpstream.NewStreamHandler for port
  58. // forward streams. It checks each stream's port and stream type headers,
  59. // rejecting any streams that with missing or invalid values. Each valid
  60. // stream is sent to the streams channel.
  61. func httpStreamReceived(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error {
  62. return func(stream httpstream.Stream, replySent <-chan struct{}) error {
  63. // make sure it has a valid port header
  64. portString := stream.Headers().Get(api.PortHeader)
  65. if len(portString) == 0 {
  66. return fmt.Errorf("%q header is required", api.PortHeader)
  67. }
  68. port, err := strconv.ParseUint(portString, 10, 16)
  69. if err != nil {
  70. return fmt.Errorf("unable to parse %q as a port: %v", portString, err)
  71. }
  72. if port < 1 {
  73. return fmt.Errorf("port %q must be > 0", portString)
  74. }
  75. // make sure it has a valid stream type header
  76. streamType := stream.Headers().Get(api.StreamType)
  77. if len(streamType) == 0 {
  78. return fmt.Errorf("%q header is required", api.StreamType)
  79. }
  80. if streamType != api.StreamTypeError && streamType != api.StreamTypeData {
  81. return fmt.Errorf("invalid stream type %q", streamType)
  82. }
  83. streams <- stream
  84. return nil
  85. }
  86. }
  87. // httpStreamHandler is capable of processing multiple port forward
  88. // requests over a single httpstream.Connection.
  89. type httpStreamHandler struct {
  90. conn httpstream.Connection
  91. streamChan chan httpstream.Stream
  92. streamPairsLock sync.RWMutex
  93. streamPairs map[string]*httpStreamPair
  94. streamCreationTimeout time.Duration
  95. pod string
  96. uid types.UID
  97. forwarder PortForwarder
  98. }
  99. // getStreamPair returns a httpStreamPair for requestID. This creates a
  100. // new pair if one does not yet exist for the requestID. The returned bool is
  101. // true if the pair was created.
  102. func (h *httpStreamHandler) getStreamPair(requestID string) (*httpStreamPair, bool) {
  103. h.streamPairsLock.Lock()
  104. defer h.streamPairsLock.Unlock()
  105. if p, ok := h.streamPairs[requestID]; ok {
  106. klog.V(5).Infof("(conn=%p, request=%s) found existing stream pair", h.conn, requestID)
  107. return p, false
  108. }
  109. klog.V(5).Infof("(conn=%p, request=%s) creating new stream pair", h.conn, requestID)
  110. p := newPortForwardPair(requestID)
  111. h.streamPairs[requestID] = p
  112. return p, true
  113. }
  114. // monitorStreamPair waits for the pair to receive both its error and data
  115. // streams, or for the timeout to expire (whichever happens first), and then
  116. // removes the pair.
  117. func (h *httpStreamHandler) monitorStreamPair(p *httpStreamPair, timeout <-chan time.Time) {
  118. select {
  119. case <-timeout:
  120. err := fmt.Errorf("(conn=%v, request=%s) timed out waiting for streams", h.conn, p.requestID)
  121. utilruntime.HandleError(err)
  122. p.printError(err.Error())
  123. case <-p.complete:
  124. klog.V(5).Infof("(conn=%v, request=%s) successfully received error and data streams", h.conn, p.requestID)
  125. }
  126. h.removeStreamPair(p.requestID)
  127. }
  128. // hasStreamPair returns a bool indicating if a stream pair for requestID
  129. // exists.
  130. func (h *httpStreamHandler) hasStreamPair(requestID string) bool {
  131. h.streamPairsLock.RLock()
  132. defer h.streamPairsLock.RUnlock()
  133. _, ok := h.streamPairs[requestID]
  134. return ok
  135. }
  136. // removeStreamPair removes the stream pair identified by requestID from streamPairs.
  137. func (h *httpStreamHandler) removeStreamPair(requestID string) {
  138. h.streamPairsLock.Lock()
  139. defer h.streamPairsLock.Unlock()
  140. delete(h.streamPairs, requestID)
  141. }
  142. // requestID returns the request id for stream.
  143. func (h *httpStreamHandler) requestID(stream httpstream.Stream) string {
  144. requestID := stream.Headers().Get(api.PortForwardRequestIDHeader)
  145. if len(requestID) == 0 {
  146. klog.V(5).Infof("(conn=%p) stream received without %s header", h.conn, api.PortForwardRequestIDHeader)
  147. // If we get here, it's because the connection came from an older client
  148. // that isn't generating the request id header
  149. // (https://github.com/kubernetes/kubernetes/blob/843134885e7e0b360eb5441e85b1410a8b1a7a0c/pkg/client/unversioned/portforward/portforward.go#L258-L287)
  150. //
  151. // This is a best-effort attempt at supporting older clients.
  152. //
  153. // When there aren't concurrent new forwarded connections, each connection
  154. // will have a pair of streams (data, error), and the stream IDs will be
  155. // consecutive odd numbers, e.g. 1 and 3 for the first connection. Convert
  156. // the stream ID into a pseudo-request id by taking the stream type and
  157. // using id = stream.Identifier() when the stream type is error,
  158. // and id = stream.Identifier() - 2 when it's data.
  159. //
  160. // NOTE: this only works when there are not concurrent new streams from
  161. // multiple forwarded connections; it's a best-effort attempt at supporting
  162. // old clients that don't generate request ids. If there are concurrent
  163. // new connections, it's possible that 1 connection gets streams whose IDs
  164. // are not consecutive (e.g. 5 and 9 instead of 5 and 7).
  165. streamType := stream.Headers().Get(api.StreamType)
  166. switch streamType {
  167. case api.StreamTypeError:
  168. requestID = strconv.Itoa(int(stream.Identifier()))
  169. case api.StreamTypeData:
  170. requestID = strconv.Itoa(int(stream.Identifier()) - 2)
  171. }
  172. klog.V(5).Infof("(conn=%p) automatically assigning request ID=%q from stream type=%s, stream ID=%d", h.conn, requestID, streamType, stream.Identifier())
  173. }
  174. return requestID
  175. }
  176. // run is the main loop for the httpStreamHandler. It processes new
  177. // streams, invoking portForward for each complete stream pair. The loop exits
  178. // when the httpstream.Connection is closed.
  179. func (h *httpStreamHandler) run() {
  180. klog.V(5).Infof("(conn=%p) waiting for port forward streams", h.conn)
  181. Loop:
  182. for {
  183. select {
  184. case <-h.conn.CloseChan():
  185. klog.V(5).Infof("(conn=%p) upgraded connection closed", h.conn)
  186. break Loop
  187. case stream := <-h.streamChan:
  188. requestID := h.requestID(stream)
  189. streamType := stream.Headers().Get(api.StreamType)
  190. klog.V(5).Infof("(conn=%p, request=%s) received new stream of type %s", h.conn, requestID, streamType)
  191. p, created := h.getStreamPair(requestID)
  192. if created {
  193. go h.monitorStreamPair(p, time.After(h.streamCreationTimeout))
  194. }
  195. if complete, err := p.add(stream); err != nil {
  196. msg := fmt.Sprintf("error processing stream for request %s: %v", requestID, err)
  197. utilruntime.HandleError(errors.New(msg))
  198. p.printError(msg)
  199. } else if complete {
  200. go h.portForward(p)
  201. }
  202. }
  203. }
  204. }
  205. // portForward invokes the httpStreamHandler's forwarder.PortForward
  206. // function for the given stream pair.
  207. func (h *httpStreamHandler) portForward(p *httpStreamPair) {
  208. defer p.dataStream.Close()
  209. defer p.errorStream.Close()
  210. portString := p.dataStream.Headers().Get(api.PortHeader)
  211. port, _ := strconv.ParseInt(portString, 10, 32)
  212. klog.V(5).Infof("(conn=%p, request=%s) invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)
  213. err := h.forwarder.PortForward(h.pod, h.uid, int32(port), p.dataStream)
  214. klog.V(5).Infof("(conn=%p, request=%s) done invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)
  215. if err != nil {
  216. msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", port, h.pod, h.uid, err)
  217. utilruntime.HandleError(msg)
  218. fmt.Fprint(p.errorStream, msg.Error())
  219. }
  220. }
  221. // httpStreamPair represents the error and data streams for a port
  222. // forwarding request.
  223. type httpStreamPair struct {
  224. lock sync.RWMutex
  225. requestID string
  226. dataStream httpstream.Stream
  227. errorStream httpstream.Stream
  228. complete chan struct{}
  229. }
  230. // newPortForwardPair creates a new httpStreamPair.
  231. func newPortForwardPair(requestID string) *httpStreamPair {
  232. return &httpStreamPair{
  233. requestID: requestID,
  234. complete: make(chan struct{}),
  235. }
  236. }
  237. // add adds the stream to the httpStreamPair. If the pair already
  238. // contains a stream for the new stream's type, an error is returned. add
  239. // returns true if both the data and error streams for this pair have been
  240. // received.
  241. func (p *httpStreamPair) add(stream httpstream.Stream) (bool, error) {
  242. p.lock.Lock()
  243. defer p.lock.Unlock()
  244. switch stream.Headers().Get(api.StreamType) {
  245. case api.StreamTypeError:
  246. if p.errorStream != nil {
  247. return false, errors.New("error stream already assigned")
  248. }
  249. p.errorStream = stream
  250. case api.StreamTypeData:
  251. if p.dataStream != nil {
  252. return false, errors.New("data stream already assigned")
  253. }
  254. p.dataStream = stream
  255. }
  256. complete := p.errorStream != nil && p.dataStream != nil
  257. if complete {
  258. close(p.complete)
  259. }
  260. return complete, nil
  261. }
  262. // printError writes s to p.errorStream if p.errorStream has been set.
  263. func (p *httpStreamPair) printError(s string) {
  264. p.lock.RLock()
  265. defer p.lock.RUnlock()
  266. if p.errorStream != nil {
  267. fmt.Fprint(p.errorStream, s)
  268. }
  269. }