chain.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. // Copyright 2016 Michal Witkowski. All Rights Reserved.
  2. // See LICENSE for licensing terms.
  3. // gRPC Server Interceptor chaining middleware.
  4. package grpc_middleware
  5. import (
  6. "golang.org/x/net/context"
  7. "google.golang.org/grpc"
  8. )
  9. // ChainUnaryServer creates a single interceptor out of a chain of many interceptors.
  10. //
  11. // Execution is done in left-to-right order, including passing of context.
  12. // For example ChainUnaryServer(one, two, three) will execute one before two before three, and three
  13. // will see context changes of one and two.
  14. func ChainUnaryServer(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
  15. n := len(interceptors)
  16. if n > 1 {
  17. lastI := n - 1
  18. return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
  19. var (
  20. chainHandler grpc.UnaryHandler
  21. curI int
  22. )
  23. chainHandler = func(currentCtx context.Context, currentReq interface{}) (interface{}, error) {
  24. if curI == lastI {
  25. return handler(currentCtx, currentReq)
  26. }
  27. curI++
  28. resp, err := interceptors[curI](currentCtx, currentReq, info, chainHandler)
  29. curI--
  30. return resp, err
  31. }
  32. return interceptors[0](ctx, req, info, chainHandler)
  33. }
  34. }
  35. if n == 1 {
  36. return interceptors[0]
  37. }
  38. // n == 0; Dummy interceptor maintained for backward compatibility to avoid returning nil.
  39. return func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
  40. return handler(ctx, req)
  41. }
  42. }
  43. // ChainStreamServer creates a single interceptor out of a chain of many interceptors.
  44. //
  45. // Execution is done in left-to-right order, including passing of context.
  46. // For example ChainUnaryServer(one, two, three) will execute one before two before three.
  47. // If you want to pass context between interceptors, use WrapServerStream.
  48. func ChainStreamServer(interceptors ...grpc.StreamServerInterceptor) grpc.StreamServerInterceptor {
  49. n := len(interceptors)
  50. if n > 1 {
  51. lastI := n - 1
  52. return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
  53. var (
  54. chainHandler grpc.StreamHandler
  55. curI int
  56. )
  57. chainHandler = func(currentSrv interface{}, currentStream grpc.ServerStream) error {
  58. if curI == lastI {
  59. return handler(currentSrv, currentStream)
  60. }
  61. curI++
  62. err := interceptors[curI](currentSrv, currentStream, info, chainHandler)
  63. curI--
  64. return err
  65. }
  66. return interceptors[0](srv, stream, info, chainHandler)
  67. }
  68. }
  69. if n == 1 {
  70. return interceptors[0]
  71. }
  72. // n == 0; Dummy interceptor maintained for backward compatibility to avoid returning nil.
  73. return func(srv interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
  74. return handler(srv, stream)
  75. }
  76. }
  77. // ChainUnaryClient creates a single interceptor out of a chain of many interceptors.
  78. //
  79. // Execution is done in left-to-right order, including passing of context.
  80. // For example ChainUnaryClient(one, two, three) will execute one before two before three.
  81. func ChainUnaryClient(interceptors ...grpc.UnaryClientInterceptor) grpc.UnaryClientInterceptor {
  82. n := len(interceptors)
  83. if n > 1 {
  84. lastI := n - 1
  85. return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
  86. var (
  87. chainHandler grpc.UnaryInvoker
  88. curI int
  89. )
  90. chainHandler = func(currentCtx context.Context, currentMethod string, currentReq, currentRepl interface{}, currentConn *grpc.ClientConn, currentOpts ...grpc.CallOption) error {
  91. if curI == lastI {
  92. return invoker(currentCtx, currentMethod, currentReq, currentRepl, currentConn, currentOpts...)
  93. }
  94. curI++
  95. err := interceptors[curI](currentCtx, currentMethod, currentReq, currentRepl, currentConn, chainHandler, currentOpts...)
  96. curI--
  97. return err
  98. }
  99. return interceptors[0](ctx, method, req, reply, cc, chainHandler, opts...)
  100. }
  101. }
  102. if n == 1 {
  103. return interceptors[0]
  104. }
  105. // n == 0; Dummy interceptor maintained for backward compatibility to avoid returning nil.
  106. return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
  107. return invoker(ctx, method, req, reply, cc, opts...)
  108. }
  109. }
  110. // ChainStreamClient creates a single interceptor out of a chain of many interceptors.
  111. //
  112. // Execution is done in left-to-right order, including passing of context.
  113. // For example ChainStreamClient(one, two, three) will execute one before two before three.
  114. func ChainStreamClient(interceptors ...grpc.StreamClientInterceptor) grpc.StreamClientInterceptor {
  115. n := len(interceptors)
  116. if n > 1 {
  117. lastI := n - 1
  118. return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
  119. var (
  120. chainHandler grpc.Streamer
  121. curI int
  122. )
  123. chainHandler = func(currentCtx context.Context, currentDesc *grpc.StreamDesc, currentConn *grpc.ClientConn, currentMethod string, currentOpts ...grpc.CallOption) (grpc.ClientStream, error) {
  124. if curI == lastI {
  125. return streamer(currentCtx, currentDesc, currentConn, currentMethod, currentOpts...)
  126. }
  127. curI++
  128. stream, err := interceptors[curI](currentCtx, currentDesc, currentConn, currentMethod, chainHandler, currentOpts...)
  129. curI--
  130. return stream, err
  131. }
  132. return interceptors[0](ctx, desc, cc, method, chainHandler, opts...)
  133. }
  134. }
  135. if n == 1 {
  136. return interceptors[0]
  137. }
  138. // n == 0; Dummy interceptor maintained for backward compatibility to avoid returning nil.
  139. return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
  140. return streamer(ctx, desc, cc, method, opts...)
  141. }
  142. }
  143. // Chain creates a single interceptor out of a chain of many interceptors.
  144. //
  145. // WithUnaryServerChain is a grpc.Server config option that accepts multiple unary interceptors.
  146. // Basically syntactic sugar.
  147. func WithUnaryServerChain(interceptors ...grpc.UnaryServerInterceptor) grpc.ServerOption {
  148. return grpc.UnaryInterceptor(ChainUnaryServer(interceptors...))
  149. }
  150. // WithStreamServerChain is a grpc.Server config option that accepts multiple stream interceptors.
  151. // Basically syntactic sugar.
  152. func WithStreamServerChain(interceptors ...grpc.StreamServerInterceptor) grpc.ServerOption {
  153. return grpc.StreamInterceptor(ChainStreamServer(interceptors...))
  154. }