mux.go 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. package runtime
  2. import (
  3. "fmt"
  4. "net/http"
  5. "net/textproto"
  6. "strings"
  7. "github.com/golang/protobuf/proto"
  8. "golang.org/x/net/context"
  9. "google.golang.org/grpc/codes"
  10. "google.golang.org/grpc/metadata"
  11. "google.golang.org/grpc/status"
  12. )
  13. // A HandlerFunc handles a specific pair of path pattern and HTTP method.
  14. type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string)
  15. // ServeMux is a request multiplexer for grpc-gateway.
  16. // It matches http requests to patterns and invokes the corresponding handler.
  17. type ServeMux struct {
  18. // handlers maps HTTP method to a list of handlers.
  19. handlers map[string][]handler
  20. forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error
  21. marshalers marshalerRegistry
  22. incomingHeaderMatcher HeaderMatcherFunc
  23. outgoingHeaderMatcher HeaderMatcherFunc
  24. metadataAnnotator func(context.Context, *http.Request) metadata.MD
  25. protoErrorHandler ProtoErrorHandlerFunc
  26. }
  27. // ServeMuxOption is an option that can be given to a ServeMux on construction.
  28. type ServeMuxOption func(*ServeMux)
  29. // WithForwardResponseOption returns a ServeMuxOption representing the forwardResponseOption.
  30. //
  31. // forwardResponseOption is an option that will be called on the relevant context.Context,
  32. // http.ResponseWriter, and proto.Message before every forwarded response.
  33. //
  34. // The message may be nil in the case where just a header is being sent.
  35. func WithForwardResponseOption(forwardResponseOption func(context.Context, http.ResponseWriter, proto.Message) error) ServeMuxOption {
  36. return func(serveMux *ServeMux) {
  37. serveMux.forwardResponseOptions = append(serveMux.forwardResponseOptions, forwardResponseOption)
  38. }
  39. }
  40. // HeaderMatcherFunc checks whether a header key should be forwarded to/from gRPC context.
  41. type HeaderMatcherFunc func(string) (string, bool)
  42. // DefaultHeaderMatcher is used to pass http request headers to/from gRPC context. This adds permanent HTTP header
  43. // keys (as specified by the IANA) to gRPC context with grpcgateway- prefix. HTTP headers that start with
  44. // 'Grpc-Metadata-' are mapped to gRPC metadata after removing prefix 'Grpc-Metadata-'.
  45. func DefaultHeaderMatcher(key string) (string, bool) {
  46. key = textproto.CanonicalMIMEHeaderKey(key)
  47. if isPermanentHTTPHeader(key) {
  48. return MetadataPrefix + key, true
  49. } else if strings.HasPrefix(key, MetadataHeaderPrefix) {
  50. return key[len(MetadataHeaderPrefix):], true
  51. }
  52. return "", false
  53. }
  54. // WithIncomingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for incoming request to gateway.
  55. //
  56. // This matcher will be called with each header in http.Request. If matcher returns true, that header will be
  57. // passed to gRPC context. To transform the header before passing to gRPC context, matcher should return modified header.
  58. func WithIncomingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
  59. return func(mux *ServeMux) {
  60. mux.incomingHeaderMatcher = fn
  61. }
  62. }
  63. // WithOutgoingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for outgoing response from gateway.
  64. //
  65. // This matcher will be called with each header in response header metadata. If matcher returns true, that header will be
  66. // passed to http response returned from gateway. To transform the header before passing to response,
  67. // matcher should return modified header.
  68. func WithOutgoingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
  69. return func(mux *ServeMux) {
  70. mux.outgoingHeaderMatcher = fn
  71. }
  72. }
  73. // WithMetadata returns a ServeMuxOption for passing metadata to a gRPC context.
  74. //
  75. // This can be used by services that need to read from http.Request and modify gRPC context. A common use case
  76. // is reading token from cookie and adding it in gRPC context.
  77. func WithMetadata(annotator func(context.Context, *http.Request) metadata.MD) ServeMuxOption {
  78. return func(serveMux *ServeMux) {
  79. serveMux.metadataAnnotator = annotator
  80. }
  81. }
  82. // WithProtoErrorHandler returns a ServeMuxOption for passing metadata to a gRPC context.
  83. //
  84. // This can be used to handle an error as general proto message defined by gRPC.
  85. // The response including body and status is not backward compatible with the default error handler.
  86. // When this option is used, HTTPError and OtherErrorHandler are overwritten on initialization.
  87. func WithProtoErrorHandler(fn ProtoErrorHandlerFunc) ServeMuxOption {
  88. return func(serveMux *ServeMux) {
  89. serveMux.protoErrorHandler = fn
  90. }
  91. }
  92. // NewServeMux returns a new ServeMux whose internal mapping is empty.
  93. func NewServeMux(opts ...ServeMuxOption) *ServeMux {
  94. serveMux := &ServeMux{
  95. handlers: make(map[string][]handler),
  96. forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0),
  97. marshalers: makeMarshalerMIMERegistry(),
  98. }
  99. for _, opt := range opts {
  100. opt(serveMux)
  101. }
  102. if serveMux.protoErrorHandler != nil {
  103. HTTPError = serveMux.protoErrorHandler
  104. // OtherErrorHandler is no longer used when protoErrorHandler is set.
  105. // Overwritten by a special error handler to return Unknown.
  106. OtherErrorHandler = func(w http.ResponseWriter, r *http.Request, _ string, _ int) {
  107. ctx := context.Background()
  108. _, outboundMarshaler := MarshalerForRequest(serveMux, r)
  109. sterr := status.Error(codes.Unknown, "unexpected use of OtherErrorHandler")
  110. serveMux.protoErrorHandler(ctx, serveMux, outboundMarshaler, w, r, sterr)
  111. }
  112. }
  113. if serveMux.incomingHeaderMatcher == nil {
  114. serveMux.incomingHeaderMatcher = DefaultHeaderMatcher
  115. }
  116. if serveMux.outgoingHeaderMatcher == nil {
  117. serveMux.outgoingHeaderMatcher = func(key string) (string, bool) {
  118. return fmt.Sprintf("%s%s", MetadataHeaderPrefix, key), true
  119. }
  120. }
  121. return serveMux
  122. }
  123. // Handle associates "h" to the pair of HTTP method and path pattern.
  124. func (s *ServeMux) Handle(meth string, pat Pattern, h HandlerFunc) {
  125. s.handlers[meth] = append(s.handlers[meth], handler{pat: pat, h: h})
  126. }
  127. // ServeHTTP dispatches the request to the first handler whose pattern matches to r.Method and r.Path.
  128. func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  129. ctx := r.Context()
  130. path := r.URL.Path
  131. if !strings.HasPrefix(path, "/") {
  132. if s.protoErrorHandler != nil {
  133. _, outboundMarshaler := MarshalerForRequest(s, r)
  134. sterr := status.Error(codes.InvalidArgument, http.StatusText(http.StatusBadRequest))
  135. s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
  136. } else {
  137. OtherErrorHandler(w, r, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
  138. }
  139. return
  140. }
  141. components := strings.Split(path[1:], "/")
  142. l := len(components)
  143. var verb string
  144. if idx := strings.LastIndex(components[l-1], ":"); idx == 0 {
  145. if s.protoErrorHandler != nil {
  146. _, outboundMarshaler := MarshalerForRequest(s, r)
  147. sterr := status.Error(codes.Unimplemented, http.StatusText(http.StatusNotImplemented))
  148. s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
  149. } else {
  150. OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound)
  151. }
  152. return
  153. } else if idx > 0 {
  154. c := components[l-1]
  155. components[l-1], verb = c[:idx], c[idx+1:]
  156. }
  157. if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && isPathLengthFallback(r) {
  158. r.Method = strings.ToUpper(override)
  159. if err := r.ParseForm(); err != nil {
  160. if s.protoErrorHandler != nil {
  161. _, outboundMarshaler := MarshalerForRequest(s, r)
  162. sterr := status.Error(codes.InvalidArgument, err.Error())
  163. s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
  164. } else {
  165. OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest)
  166. }
  167. return
  168. }
  169. }
  170. for _, h := range s.handlers[r.Method] {
  171. pathParams, err := h.pat.Match(components, verb)
  172. if err != nil {
  173. continue
  174. }
  175. h.h(w, r, pathParams)
  176. return
  177. }
  178. // lookup other methods to handle fallback from GET to POST and
  179. // to determine if it is MethodNotAllowed or NotFound.
  180. for m, handlers := range s.handlers {
  181. if m == r.Method {
  182. continue
  183. }
  184. for _, h := range handlers {
  185. pathParams, err := h.pat.Match(components, verb)
  186. if err != nil {
  187. continue
  188. }
  189. // X-HTTP-Method-Override is optional. Always allow fallback to POST.
  190. if isPathLengthFallback(r) {
  191. if err := r.ParseForm(); err != nil {
  192. if s.protoErrorHandler != nil {
  193. _, outboundMarshaler := MarshalerForRequest(s, r)
  194. sterr := status.Error(codes.InvalidArgument, err.Error())
  195. s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
  196. } else {
  197. OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest)
  198. }
  199. return
  200. }
  201. h.h(w, r, pathParams)
  202. return
  203. }
  204. if s.protoErrorHandler != nil {
  205. _, outboundMarshaler := MarshalerForRequest(s, r)
  206. sterr := status.Error(codes.Unimplemented, http.StatusText(http.StatusMethodNotAllowed))
  207. s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
  208. } else {
  209. OtherErrorHandler(w, r, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
  210. }
  211. return
  212. }
  213. }
  214. if s.protoErrorHandler != nil {
  215. _, outboundMarshaler := MarshalerForRequest(s, r)
  216. sterr := status.Error(codes.Unimplemented, http.StatusText(http.StatusNotImplemented))
  217. s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
  218. } else {
  219. OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound)
  220. }
  221. }
  222. // GetForwardResponseOptions returns the ForwardResponseOptions associated with this ServeMux.
  223. func (s *ServeMux) GetForwardResponseOptions() []func(context.Context, http.ResponseWriter, proto.Message) error {
  224. return s.forwardResponseOptions
  225. }
  226. func isPathLengthFallback(r *http.Request) bool {
  227. return r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded"
  228. }
  229. type handler struct {
  230. pat Pattern
  231. h HandlerFunc
  232. }