container.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. package restful
  2. // Copyright 2013 Ernest Micklei. All rights reserved.
  3. // Use of this source code is governed by a license
  4. // that can be found in the LICENSE file.
  5. import (
  6. "bytes"
  7. "errors"
  8. "fmt"
  9. "net/http"
  10. "os"
  11. "runtime"
  12. "strings"
  13. "sync"
  14. "github.com/emicklei/go-restful/log"
  15. )
  16. // Container holds a collection of WebServices and a http.ServeMux to dispatch http requests.
  17. // The requests are further dispatched to routes of WebServices using a RouteSelector
  18. type Container struct {
  19. webServicesLock sync.RWMutex
  20. webServices []*WebService
  21. ServeMux *http.ServeMux
  22. isRegisteredOnRoot bool
  23. containerFilters []FilterFunction
  24. doNotRecover bool // default is true
  25. recoverHandleFunc RecoverHandleFunction
  26. serviceErrorHandleFunc ServiceErrorHandleFunction
  27. router RouteSelector // default is a CurlyRouter (RouterJSR311 is a slower alternative)
  28. contentEncodingEnabled bool // default is false
  29. }
  30. // NewContainer creates a new Container using a new ServeMux and default router (CurlyRouter)
  31. func NewContainer() *Container {
  32. return &Container{
  33. webServices: []*WebService{},
  34. ServeMux: http.NewServeMux(),
  35. isRegisteredOnRoot: false,
  36. containerFilters: []FilterFunction{},
  37. doNotRecover: true,
  38. recoverHandleFunc: logStackOnRecover,
  39. serviceErrorHandleFunc: writeServiceError,
  40. router: CurlyRouter{},
  41. contentEncodingEnabled: false}
  42. }
  // RecoverHandleFunction is the signature of a panic handler. panicReason is the
  // value returned by recover(); httpWriter is where the error response must be written.
  type RecoverHandleFunction func(panicReason interface{}, httpWriter http.ResponseWriter)
  46. // RecoverHandler changes the default function (logStackOnRecover) to be called
  47. // when a panic is detected. DoNotRecover must be have its default value (=false).
  48. func (c *Container) RecoverHandler(handler RecoverHandleFunction) {
  49. c.recoverHandleFunc = handler
  50. }
  51. // ServiceErrorHandleFunction declares functions that can be used to handle a service error situation.
  52. // The first argument is the service error, the second is the request that resulted in the error and
  53. // the third must be used to communicate an error response.
  54. type ServiceErrorHandleFunction func(ServiceError, *Request, *Response)
  55. // ServiceErrorHandler changes the default function (writeServiceError) to be called
  56. // when a ServiceError is detected.
  57. func (c *Container) ServiceErrorHandler(handler ServiceErrorHandleFunction) {
  58. c.serviceErrorHandleFunc = handler
  59. }
  60. // DoNotRecover controls whether panics will be caught to return HTTP 500.
  61. // If set to true, Route functions are responsible for handling any error situation.
  62. // Default value is true.
  63. func (c *Container) DoNotRecover(doNot bool) {
  64. c.doNotRecover = doNot
  65. }
  66. // Router changes the default Router (currently CurlyRouter)
  67. func (c *Container) Router(aRouter RouteSelector) {
  68. c.router = aRouter
  69. }
  70. // EnableContentEncoding (default=false) allows for GZIP or DEFLATE encoding of responses.
  71. func (c *Container) EnableContentEncoding(enabled bool) {
  72. c.contentEncodingEnabled = enabled
  73. }
  74. // Add a WebService to the Container. It will detect duplicate root paths and exit in that case.
  75. func (c *Container) Add(service *WebService) *Container {
  76. c.webServicesLock.Lock()
  77. defer c.webServicesLock.Unlock()
  78. // if rootPath was not set then lazy initialize it
  79. if len(service.rootPath) == 0 {
  80. service.Path("/")
  81. }
  82. // cannot have duplicate root paths
  83. for _, each := range c.webServices {
  84. if each.RootPath() == service.RootPath() {
  85. log.Printf("WebService with duplicate root path detected:['%v']", each)
  86. os.Exit(1)
  87. }
  88. }
  89. // If not registered on root then add specific mapping
  90. if !c.isRegisteredOnRoot {
  91. c.isRegisteredOnRoot = c.addHandler(service, c.ServeMux)
  92. }
  93. c.webServices = append(c.webServices, service)
  94. return c
  95. }
  96. // addHandler may set a new HandleFunc for the serveMux
  97. // this function must run inside the critical region protected by the webServicesLock.
  98. // returns true if the function was registered on root ("/")
  99. func (c *Container) addHandler(service *WebService, serveMux *http.ServeMux) bool {
  100. pattern := fixedPrefixPath(service.RootPath())
  101. // check if root path registration is needed
  102. if "/" == pattern || "" == pattern {
  103. serveMux.HandleFunc("/", c.dispatch)
  104. return true
  105. }
  106. // detect if registration already exists
  107. alreadyMapped := false
  108. for _, each := range c.webServices {
  109. if each.RootPath() == service.RootPath() {
  110. alreadyMapped = true
  111. break
  112. }
  113. }
  114. if !alreadyMapped {
  115. serveMux.HandleFunc(pattern, c.dispatch)
  116. if !strings.HasSuffix(pattern, "/") {
  117. serveMux.HandleFunc(pattern+"/", c.dispatch)
  118. }
  119. }
  120. return false
  121. }
  122. func (c *Container) Remove(ws *WebService) error {
  123. if c.ServeMux == http.DefaultServeMux {
  124. errMsg := fmt.Sprintf("cannot remove a WebService from a Container using the DefaultServeMux: ['%v']", ws)
  125. log.Print(errMsg)
  126. return errors.New(errMsg)
  127. }
  128. c.webServicesLock.Lock()
  129. defer c.webServicesLock.Unlock()
  130. // build a new ServeMux and re-register all WebServices
  131. newServeMux := http.NewServeMux()
  132. newServices := []*WebService{}
  133. newIsRegisteredOnRoot := false
  134. for _, each := range c.webServices {
  135. if each.rootPath != ws.rootPath {
  136. // If not registered on root then add specific mapping
  137. if !newIsRegisteredOnRoot {
  138. newIsRegisteredOnRoot = c.addHandler(each, newServeMux)
  139. }
  140. newServices = append(newServices, each)
  141. }
  142. }
  143. c.webServices, c.ServeMux, c.isRegisteredOnRoot = newServices, newServeMux, newIsRegisteredOnRoot
  144. return nil
  145. }
  // logStackOnRecover is the default RecoverHandleFunction. It is used when
  // DoNotRecover is false and no other handler was installed with RecoverHandler.
  //
  // It logs the panic value followed by one "file:line" entry per stack frame, then
  // answers with HTTP 500 and the same text as the body. Sending the stack trace to
  // the client may be a security issue, because it exposes source code information.
  func logStackOnRecover(panicReason interface{}, httpWriter http.ResponseWriter) {
  	var trace bytes.Buffer
  	fmt.Fprintf(&trace, "recover from panic situation: - %v\r\n", panicReason)
  	// depth 0 is this function and depth 1 its direct caller; start above them
  	for depth := 2; ; depth++ {
  		_, file, line, ok := runtime.Caller(depth)
  		if !ok {
  			break
  		}
  		fmt.Fprintf(&trace, " %s:%d\r\n", file, line)
  	}
  	log.Print(trace.String())
  	httpWriter.WriteHeader(http.StatusInternalServerError)
  	httpWriter.Write(trace.Bytes())
  }
  164. // writeServiceError is the default ServiceErrorHandleFunction and is called
  165. // when a ServiceError is returned during route selection. Default implementation
  166. // calls resp.WriteErrorString(err.Code, err.Message)
  167. func writeServiceError(err ServiceError, req *Request, resp *Response) {
  168. resp.WriteErrorString(err.Code, err.Message)
  169. }
  170. // Dispatch the incoming Http Request to a matching WebService.
  171. func (c *Container) Dispatch(httpWriter http.ResponseWriter, httpRequest *http.Request) {
  172. if httpWriter == nil {
  173. panic("httpWriter cannot be nil")
  174. }
  175. if httpRequest == nil {
  176. panic("httpRequest cannot be nil")
  177. }
  178. c.dispatch(httpWriter, httpRequest)
  179. }
  180. // Dispatch the incoming Http Request to a matching WebService.
  181. func (c *Container) dispatch(httpWriter http.ResponseWriter, httpRequest *http.Request) {
  182. writer := httpWriter
  183. // CompressingResponseWriter should be closed after all operations are done
  184. defer func() {
  185. if compressWriter, ok := writer.(*CompressingResponseWriter); ok {
  186. compressWriter.Close()
  187. }
  188. }()
  189. // Instal panic recovery unless told otherwise
  190. if !c.doNotRecover { // catch all for 500 response
  191. defer func() {
  192. if r := recover(); r != nil {
  193. c.recoverHandleFunc(r, writer)
  194. return
  195. }
  196. }()
  197. }
  198. // Find best match Route ; err is non nil if no match was found
  199. var webService *WebService
  200. var route *Route
  201. var err error
  202. func() {
  203. c.webServicesLock.RLock()
  204. defer c.webServicesLock.RUnlock()
  205. webService, route, err = c.router.SelectRoute(
  206. c.webServices,
  207. httpRequest)
  208. }()
  209. // Detect if compression is needed
  210. // assume without compression, test for override
  211. contentEncodingEnabled := c.contentEncodingEnabled
  212. if route != nil && route.contentEncodingEnabled != nil {
  213. contentEncodingEnabled = *route.contentEncodingEnabled
  214. }
  215. if contentEncodingEnabled {
  216. doCompress, encoding := wantsCompressedResponse(httpRequest)
  217. if doCompress {
  218. var err error
  219. writer, err = NewCompressingResponseWriter(httpWriter, encoding)
  220. if err != nil {
  221. log.Print("unable to install compressor: ", err)
  222. httpWriter.WriteHeader(http.StatusInternalServerError)
  223. return
  224. }
  225. }
  226. }
  227. if err != nil {
  228. // a non-200 response has already been written
  229. // run container filters anyway ; they should not touch the response...
  230. chain := FilterChain{Filters: c.containerFilters, Target: func(req *Request, resp *Response) {
  231. switch err.(type) {
  232. case ServiceError:
  233. ser := err.(ServiceError)
  234. c.serviceErrorHandleFunc(ser, req, resp)
  235. }
  236. // TODO
  237. }}
  238. chain.ProcessFilter(NewRequest(httpRequest), NewResponse(writer))
  239. return
  240. }
  241. pathProcessor, routerProcessesPath := c.router.(PathProcessor)
  242. if !routerProcessesPath {
  243. pathProcessor = defaultPathProcessor{}
  244. }
  245. pathParams := pathProcessor.ExtractParameters(route, webService, httpRequest.URL.Path)
  246. wrappedRequest, wrappedResponse := route.wrapRequestResponse(writer, httpRequest, pathParams)
  247. // pass through filters (if any)
  248. if len(c.containerFilters)+len(webService.filters)+len(route.Filters) > 0 {
  249. // compose filter chain
  250. allFilters := []FilterFunction{}
  251. allFilters = append(allFilters, c.containerFilters...)
  252. allFilters = append(allFilters, webService.filters...)
  253. allFilters = append(allFilters, route.Filters...)
  254. chain := FilterChain{Filters: allFilters, Target: func(req *Request, resp *Response) {
  255. // handle request by route after passing all filters
  256. route.Function(wrappedRequest, wrappedResponse)
  257. }}
  258. chain.ProcessFilter(wrappedRequest, wrappedResponse)
  259. } else {
  260. // no filters, handle request by route
  261. route.Function(wrappedRequest, wrappedResponse)
  262. }
  263. }
  // fixedPrefixPath returns the part of pathspec before its first template
  // variable, i.e. everything preceding the first '{'. If pathspec has no '{',
  // it is returned unchanged.
  func fixedPrefixPath(pathspec string) string {
  	if brace := strings.IndexByte(pathspec, '{'); brace >= 0 {
  		return pathspec[:brace]
  	}
  	return pathspec
  }
  272. // ServeHTTP implements net/http.Handler therefore a Container can be a Handler in a http.Server
  273. func (c *Container) ServeHTTP(httpwriter http.ResponseWriter, httpRequest *http.Request) {
  274. c.ServeMux.ServeHTTP(httpwriter, httpRequest)
  275. }
  276. // Handle registers the handler for the given pattern. If a handler already exists for pattern, Handle panics.
  277. func (c *Container) Handle(pattern string, handler http.Handler) {
  278. c.ServeMux.Handle(pattern, handler)
  279. }
  280. // HandleWithFilter registers the handler for the given pattern.
  281. // Container's filter chain is applied for handler.
  282. // If a handler already exists for pattern, HandleWithFilter panics.
  283. func (c *Container) HandleWithFilter(pattern string, handler http.Handler) {
  284. f := func(httpResponse http.ResponseWriter, httpRequest *http.Request) {
  285. if len(c.containerFilters) == 0 {
  286. handler.ServeHTTP(httpResponse, httpRequest)
  287. return
  288. }
  289. chain := FilterChain{Filters: c.containerFilters, Target: func(req *Request, resp *Response) {
  290. handler.ServeHTTP(httpResponse, httpRequest)
  291. }}
  292. chain.ProcessFilter(NewRequest(httpRequest), NewResponse(httpResponse))
  293. }
  294. c.Handle(pattern, http.HandlerFunc(f))
  295. }
  296. // Filter appends a container FilterFunction. These are called before dispatching
  297. // a http.Request to a WebService from the container
  298. func (c *Container) Filter(filter FilterFunction) {
  299. c.containerFilters = append(c.containerFilters, filter)
  300. }
  301. // RegisteredWebServices returns the collections of added WebServices
  302. func (c *Container) RegisteredWebServices() []*WebService {
  303. c.webServicesLock.RLock()
  304. defer c.webServicesLock.RUnlock()
  305. result := make([]*WebService, len(c.webServices))
  306. for ix := range c.webServices {
  307. result[ix] = c.webServices[ix]
  308. }
  309. return result
  310. }
  311. // computeAllowedMethods returns a list of HTTP methods that are valid for a Request
  312. func (c *Container) computeAllowedMethods(req *Request) []string {
  313. // Go through all RegisteredWebServices() and all its Routes to collect the options
  314. methods := []string{}
  315. requestPath := req.Request.URL.Path
  316. for _, ws := range c.RegisteredWebServices() {
  317. matches := ws.pathExpr.Matcher.FindStringSubmatch(requestPath)
  318. if matches != nil {
  319. finalMatch := matches[len(matches)-1]
  320. for _, rt := range ws.Routes() {
  321. matches := rt.pathExpr.Matcher.FindStringSubmatch(finalMatch)
  322. if matches != nil {
  323. lastMatch := matches[len(matches)-1]
  324. if lastMatch == "" || lastMatch == "/" { // do not include if value is neither empty nor ‘/’.
  325. methods = append(methods, rt.Method)
  326. }
  327. }
  328. }
  329. }
  330. }
  331. // methods = append(methods, "OPTIONS") not sure about this
  332. return methods
  333. }
  334. // newBasicRequestResponse creates a pair of Request,Response from its http versions.
  335. // It is basic because no parameter or (produces) content-type information is given.
  336. func newBasicRequestResponse(httpWriter http.ResponseWriter, httpRequest *http.Request) (*Request, *Response) {
  337. resp := NewResponse(httpWriter)
  338. resp.requestAccept = httpRequest.Header.Get(HEADER_Accept)
  339. return NewRequest(httpRequest), resp
  340. }