handler.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. package runtime
  2. import (
  3. "fmt"
  4. "io"
  5. "net/http"
  6. "net/textproto"
  7. "github.com/golang/protobuf/proto"
  8. "github.com/grpc-ecosystem/grpc-gateway/runtime/internal"
  9. "golang.org/x/net/context"
  10. "google.golang.org/grpc/codes"
  11. "google.golang.org/grpc/grpclog"
  12. "google.golang.org/grpc/status"
  13. )
  14. // ForwardResponseStream forwards the stream from gRPC server to REST client.
  15. func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, recv func() (proto.Message, error), opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
  16. f, ok := w.(http.Flusher)
  17. if !ok {
  18. grpclog.Printf("Flush not supported in %T", w)
  19. http.Error(w, "unexpected type of web server", http.StatusInternalServerError)
  20. return
  21. }
  22. md, ok := ServerMetadataFromContext(ctx)
  23. if !ok {
  24. grpclog.Printf("Failed to extract ServerMetadata from context")
  25. http.Error(w, "unexpected error", http.StatusInternalServerError)
  26. return
  27. }
  28. handleForwardResponseServerMetadata(w, mux, md)
  29. w.Header().Set("Transfer-Encoding", "chunked")
  30. w.Header().Set("Content-Type", marshaler.ContentType())
  31. if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil {
  32. http.Error(w, err.Error(), http.StatusInternalServerError)
  33. return
  34. }
  35. w.WriteHeader(http.StatusOK)
  36. f.Flush()
  37. for {
  38. resp, err := recv()
  39. if err == io.EOF {
  40. return
  41. }
  42. if err != nil {
  43. handleForwardResponseStreamError(marshaler, w, err)
  44. return
  45. }
  46. if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
  47. handleForwardResponseStreamError(marshaler, w, err)
  48. return
  49. }
  50. buf, err := marshaler.Marshal(streamChunk(resp, nil))
  51. if err != nil {
  52. grpclog.Printf("Failed to marshal response chunk: %v", err)
  53. return
  54. }
  55. if _, err = w.Write(buf); err != nil {
  56. grpclog.Printf("Failed to send response chunk: %v", err)
  57. return
  58. }
  59. f.Flush()
  60. }
  61. }
  62. func handleForwardResponseServerMetadata(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
  63. for k, vs := range md.HeaderMD {
  64. if h, ok := mux.outgoingHeaderMatcher(k); ok {
  65. for _, v := range vs {
  66. w.Header().Add(h, v)
  67. }
  68. }
  69. }
  70. }
  71. func handleForwardResponseTrailerHeader(w http.ResponseWriter, md ServerMetadata) {
  72. for k := range md.TrailerMD {
  73. tKey := textproto.CanonicalMIMEHeaderKey(fmt.Sprintf("%s%s", MetadataTrailerPrefix, k))
  74. w.Header().Add("Trailer", tKey)
  75. }
  76. }
  77. func handleForwardResponseTrailer(w http.ResponseWriter, md ServerMetadata) {
  78. for k, vs := range md.TrailerMD {
  79. tKey := fmt.Sprintf("%s%s", MetadataTrailerPrefix, k)
  80. for _, v := range vs {
  81. w.Header().Add(tKey, v)
  82. }
  83. }
  84. }
  85. // ForwardResponseMessage forwards the message "resp" from gRPC server to REST client.
  86. func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
  87. md, ok := ServerMetadataFromContext(ctx)
  88. if !ok {
  89. grpclog.Printf("Failed to extract ServerMetadata from context")
  90. }
  91. handleForwardResponseServerMetadata(w, mux, md)
  92. handleForwardResponseTrailerHeader(w, md)
  93. w.Header().Set("Content-Type", marshaler.ContentType())
  94. if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
  95. HTTPError(ctx, mux, marshaler, w, req, err)
  96. return
  97. }
  98. buf, err := marshaler.Marshal(resp)
  99. if err != nil {
  100. grpclog.Printf("Marshal error: %v", err)
  101. HTTPError(ctx, mux, marshaler, w, req, err)
  102. return
  103. }
  104. if _, err = w.Write(buf); err != nil {
  105. grpclog.Printf("Failed to write response: %v", err)
  106. }
  107. handleForwardResponseTrailer(w, md)
  108. }
  109. func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, resp proto.Message, opts []func(context.Context, http.ResponseWriter, proto.Message) error) error {
  110. if len(opts) == 0 {
  111. return nil
  112. }
  113. for _, opt := range opts {
  114. if err := opt(ctx, w, resp); err != nil {
  115. grpclog.Printf("Error handling ForwardResponseOptions: %v", err)
  116. return err
  117. }
  118. }
  119. return nil
  120. }
  121. func handleForwardResponseStreamError(marshaler Marshaler, w http.ResponseWriter, err error) {
  122. buf, merr := marshaler.Marshal(streamChunk(nil, err))
  123. if merr != nil {
  124. grpclog.Printf("Failed to marshal an error: %v", merr)
  125. return
  126. }
  127. if _, werr := fmt.Fprintf(w, "%s\n", buf); werr != nil {
  128. grpclog.Printf("Failed to notify error to client: %v", werr)
  129. return
  130. }
  131. }
  132. func streamChunk(result proto.Message, err error) map[string]proto.Message {
  133. if err != nil {
  134. grpcCode := codes.Unknown
  135. if s, ok := status.FromError(err); ok {
  136. grpcCode = s.Code()
  137. }
  138. httpCode := HTTPStatusFromCode(grpcCode)
  139. return map[string]proto.Message{
  140. "error": &internal.StreamError{
  141. GrpcCode: int32(grpcCode),
  142. HttpCode: int32(httpCode),
  143. Message: err.Error(),
  144. HttpStatus: http.StatusText(httpCode),
  145. },
  146. }
  147. }
  148. if result == nil {
  149. return streamChunk(nil, fmt.Errorf("empty response"))
  150. }
  151. return map[string]proto.Message{"result": result}
  152. }