context.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. package runtime
  2. import (
  3. "context"
  4. "encoding/base64"
  5. "fmt"
  6. "net"
  7. "net/http"
  8. "net/textproto"
  9. "strconv"
  10. "strings"
  11. "time"
  12. "google.golang.org/grpc/codes"
  13. "google.golang.org/grpc/grpclog"
  14. "google.golang.org/grpc/metadata"
  15. "google.golang.org/grpc/status"
  16. )
  17. // MetadataHeaderPrefix is the http prefix that represents custom metadata
  18. // parameters to or from a gRPC call.
  19. const MetadataHeaderPrefix = "Grpc-Metadata-"
  20. // MetadataPrefix is prepended to permanent HTTP header keys (as specified
  21. // by the IANA) when added to the gRPC context.
  22. const MetadataPrefix = "grpcgateway-"
  23. // MetadataTrailerPrefix is prepended to gRPC metadata as it is converted to
  24. // HTTP headers in a response handled by grpc-gateway
  25. const MetadataTrailerPrefix = "Grpc-Trailer-"
  26. const metadataGrpcTimeout = "Grpc-Timeout"
  27. const metadataHeaderBinarySuffix = "-Bin"
  28. const xForwardedFor = "X-Forwarded-For"
  29. const xForwardedHost = "X-Forwarded-Host"
  30. var (
  31. // DefaultContextTimeout is used for gRPC call context.WithTimeout whenever a Grpc-Timeout inbound
  32. // header isn't present. If the value is 0 the sent `context` will not have a timeout.
  33. DefaultContextTimeout = 0 * time.Second
  34. )
  35. func decodeBinHeader(v string) ([]byte, error) {
  36. if len(v)%4 == 0 {
  37. // Input was padded, or padding was not necessary.
  38. return base64.StdEncoding.DecodeString(v)
  39. }
  40. return base64.RawStdEncoding.DecodeString(v)
  41. }
  42. /*
  43. AnnotateContext adds context information such as metadata from the request.
  44. At a minimum, the RemoteAddr is included in the fashion of "X-Forwarded-For",
  45. except that the forwarded destination is not another HTTP service but rather
  46. a gRPC service.
  47. */
  48. func AnnotateContext(ctx context.Context, mux *ServeMux, req *http.Request) (context.Context, error) {
  49. var pairs []string
  50. timeout := DefaultContextTimeout
  51. if tm := req.Header.Get(metadataGrpcTimeout); tm != "" {
  52. var err error
  53. timeout, err = timeoutDecode(tm)
  54. if err != nil {
  55. return nil, status.Errorf(codes.InvalidArgument, "invalid grpc-timeout: %s", tm)
  56. }
  57. }
  58. for key, vals := range req.Header {
  59. for _, val := range vals {
  60. key = textproto.CanonicalMIMEHeaderKey(key)
  61. // For backwards-compatibility, pass through 'authorization' header with no prefix.
  62. if key == "Authorization" {
  63. pairs = append(pairs, "authorization", val)
  64. }
  65. if h, ok := mux.incomingHeaderMatcher(key); ok {
  66. // Handles "-bin" metadata in grpc, since grpc will do another base64
  67. // encode before sending to server, we need to decode it first.
  68. if strings.HasSuffix(key, metadataHeaderBinarySuffix) {
  69. b, err := decodeBinHeader(val)
  70. if err != nil {
  71. return nil, status.Errorf(codes.InvalidArgument, "invalid binary header %s: %s", key, err)
  72. }
  73. val = string(b)
  74. }
  75. pairs = append(pairs, h, val)
  76. }
  77. }
  78. }
  79. if host := req.Header.Get(xForwardedHost); host != "" {
  80. pairs = append(pairs, strings.ToLower(xForwardedHost), host)
  81. } else if req.Host != "" {
  82. pairs = append(pairs, strings.ToLower(xForwardedHost), req.Host)
  83. }
  84. if addr := req.RemoteAddr; addr != "" {
  85. if remoteIP, _, err := net.SplitHostPort(addr); err == nil {
  86. if fwd := req.Header.Get(xForwardedFor); fwd == "" {
  87. pairs = append(pairs, strings.ToLower(xForwardedFor), remoteIP)
  88. } else {
  89. pairs = append(pairs, strings.ToLower(xForwardedFor), fmt.Sprintf("%s, %s", fwd, remoteIP))
  90. }
  91. } else {
  92. grpclog.Infof("invalid remote addr: %s", addr)
  93. }
  94. }
  95. if timeout != 0 {
  96. ctx, _ = context.WithTimeout(ctx, timeout)
  97. }
  98. if len(pairs) == 0 {
  99. return ctx, nil
  100. }
  101. md := metadata.Pairs(pairs...)
  102. for _, mda := range mux.metadataAnnotators {
  103. md = metadata.Join(md, mda(ctx, req))
  104. }
  105. return metadata.NewOutgoingContext(ctx, md), nil
  106. }
  107. // ServerMetadata consists of metadata sent from gRPC server.
  108. type ServerMetadata struct {
  109. HeaderMD metadata.MD
  110. TrailerMD metadata.MD
  111. }
  112. type serverMetadataKey struct{}
  113. // NewServerMetadataContext creates a new context with ServerMetadata
  114. func NewServerMetadataContext(ctx context.Context, md ServerMetadata) context.Context {
  115. return context.WithValue(ctx, serverMetadataKey{}, md)
  116. }
  117. // ServerMetadataFromContext returns the ServerMetadata in ctx
  118. func ServerMetadataFromContext(ctx context.Context) (md ServerMetadata, ok bool) {
  119. md, ok = ctx.Value(serverMetadataKey{}).(ServerMetadata)
  120. return
  121. }
  122. func timeoutDecode(s string) (time.Duration, error) {
  123. size := len(s)
  124. if size < 2 {
  125. return 0, fmt.Errorf("timeout string is too short: %q", s)
  126. }
  127. d, ok := timeoutUnitToDuration(s[size-1])
  128. if !ok {
  129. return 0, fmt.Errorf("timeout unit is not recognized: %q", s)
  130. }
  131. t, err := strconv.ParseInt(s[:size-1], 10, 64)
  132. if err != nil {
  133. return 0, err
  134. }
  135. return d * time.Duration(t), nil
  136. }
  137. func timeoutUnitToDuration(u uint8) (d time.Duration, ok bool) {
  138. switch u {
  139. case 'H':
  140. return time.Hour, true
  141. case 'M':
  142. return time.Minute, true
  143. case 'S':
  144. return time.Second, true
  145. case 'm':
  146. return time.Millisecond, true
  147. case 'u':
  148. return time.Microsecond, true
  149. case 'n':
  150. return time.Nanosecond, true
  151. default:
  152. }
  153. return
  154. }
  155. // isPermanentHTTPHeader checks whether hdr belongs to the list of
  156. // permenant request headers maintained by IANA.
  157. // http://www.iana.org/assignments/message-headers/message-headers.xml
  158. func isPermanentHTTPHeader(hdr string) bool {
  159. switch hdr {
  160. case
  161. "Accept",
  162. "Accept-Charset",
  163. "Accept-Language",
  164. "Accept-Ranges",
  165. "Authorization",
  166. "Cache-Control",
  167. "Content-Type",
  168. "Cookie",
  169. "Date",
  170. "Expect",
  171. "From",
  172. "Host",
  173. "If-Match",
  174. "If-Modified-Since",
  175. "If-None-Match",
  176. "If-Schedule-Tag-Match",
  177. "If-Unmodified-Since",
  178. "Max-Forwards",
  179. "Origin",
  180. "Pragma",
  181. "Referer",
  182. "User-Agent",
  183. "Via",
  184. "Warning":
  185. return true
  186. }
  187. return false
  188. }