httpcache.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554
// Package httpcache provides an http.RoundTripper implementation that works as a
// mostly RFC-compliant cache for HTTP responses.
  3. //
  4. // It is only suitable for use as a 'private' cache (i.e. for a web-browser or an API-client
  5. // and not for a shared proxy).
  6. //
  7. package httpcache
  8. import (
  9. "bufio"
  10. "bytes"
  11. "errors"
  12. "fmt"
  13. "io"
  14. "io/ioutil"
  15. "net/http"
  16. "net/http/httputil"
  17. "strings"
  18. "sync"
  19. "time"
  20. )
  21. const (
  22. stale = iota
  23. fresh
  24. transparent
  25. // XFromCache is the header added to responses that are returned from the cache
  26. XFromCache = "X-From-Cache"
  27. )
  28. // A Cache interface is used by the Transport to store and retrieve responses.
  29. type Cache interface {
  30. // Get returns the []byte representation of a cached response and a bool
  31. // set to true if the value isn't empty
  32. Get(key string) (responseBytes []byte, ok bool)
  33. // Set stores the []byte representation of a response against a key
  34. Set(key string, responseBytes []byte)
  35. // Delete removes the value associated with the key
  36. Delete(key string)
  37. }
  38. // cacheKey returns the cache key for req.
  39. func cacheKey(req *http.Request) string {
  40. return req.URL.String()
  41. }
  42. // CachedResponse returns the cached http.Response for req if present, and nil
  43. // otherwise.
  44. func CachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) {
  45. cachedVal, ok := c.Get(cacheKey(req))
  46. if !ok {
  47. return
  48. }
  49. b := bytes.NewBuffer(cachedVal)
  50. return http.ReadResponse(bufio.NewReader(b), req)
  51. }
  52. // MemoryCache is an implemtation of Cache that stores responses in an in-memory map.
  53. type MemoryCache struct {
  54. mu sync.RWMutex
  55. items map[string][]byte
  56. }
  57. // Get returns the []byte representation of the response and true if present, false if not
  58. func (c *MemoryCache) Get(key string) (resp []byte, ok bool) {
  59. c.mu.RLock()
  60. resp, ok = c.items[key]
  61. c.mu.RUnlock()
  62. return resp, ok
  63. }
  64. // Set saves response resp to the cache with key
  65. func (c *MemoryCache) Set(key string, resp []byte) {
  66. c.mu.Lock()
  67. c.items[key] = resp
  68. c.mu.Unlock()
  69. }
  70. // Delete removes key from the cache
  71. func (c *MemoryCache) Delete(key string) {
  72. c.mu.Lock()
  73. delete(c.items, key)
  74. c.mu.Unlock()
  75. }
  76. // NewMemoryCache returns a new Cache that will store items in an in-memory map
  77. func NewMemoryCache() *MemoryCache {
  78. c := &MemoryCache{items: map[string][]byte{}}
  79. return c
  80. }
  81. // Transport is an implementation of http.RoundTripper that will return values from a cache
  82. // where possible (avoiding a network request) and will additionally add validators (etag/if-modified-since)
  83. // to repeated requests allowing servers to return 304 / Not Modified
  84. type Transport struct {
  85. // The RoundTripper interface actually used to make requests
  86. // If nil, http.DefaultTransport is used
  87. Transport http.RoundTripper
  88. Cache Cache
  89. // If true, responses returned from the cache will be given an extra header, X-From-Cache
  90. MarkCachedResponses bool
  91. }
  92. // NewTransport returns a new Transport with the
  93. // provided Cache implementation and MarkCachedResponses set to true
  94. func NewTransport(c Cache) *Transport {
  95. return &Transport{Cache: c, MarkCachedResponses: true}
  96. }
  97. // Client returns an *http.Client that caches responses.
  98. func (t *Transport) Client() *http.Client {
  99. return &http.Client{Transport: t}
  100. }
  101. // varyMatches will return false unless all of the cached values for the headers listed in Vary
  102. // match the new request
  103. func varyMatches(cachedResp *http.Response, req *http.Request) bool {
  104. for _, header := range headerAllCommaSepValues(cachedResp.Header, "vary") {
  105. header = http.CanonicalHeaderKey(header)
  106. if header != "" && req.Header.Get(header) != cachedResp.Header.Get("X-Varied-"+header) {
  107. return false
  108. }
  109. }
  110. return true
  111. }
  112. // RoundTrip takes a Request and returns a Response
  113. //
  114. // If there is a fresh Response already in cache, then it will be returned without connecting to
  115. // the server.
  116. //
  117. // If there is a stale Response, then any validators it contains will be set on the new request
  118. // to give the server a chance to respond with NotModified. If this happens, then the cached Response
  119. // will be returned.
  120. func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
  121. cacheKey := cacheKey(req)
  122. cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == ""
  123. var cachedResp *http.Response
  124. if cacheable {
  125. cachedResp, err = CachedResponse(t.Cache, req)
  126. } else {
  127. // Need to invalidate an existing value
  128. t.Cache.Delete(cacheKey)
  129. }
  130. transport := t.Transport
  131. if transport == nil {
  132. transport = http.DefaultTransport
  133. }
  134. if cacheable && cachedResp != nil && err == nil {
  135. if t.MarkCachedResponses {
  136. cachedResp.Header.Set(XFromCache, "1")
  137. }
  138. if varyMatches(cachedResp, req) {
  139. // Can only use cached value if the new request doesn't Vary significantly
  140. freshness := getFreshness(cachedResp.Header, req.Header)
  141. if freshness == fresh {
  142. return cachedResp, nil
  143. }
  144. if freshness == stale {
  145. var req2 *http.Request
  146. // Add validators if caller hasn't already done so
  147. etag := cachedResp.Header.Get("etag")
  148. if etag != "" && req.Header.Get("etag") == "" {
  149. req2 = cloneRequest(req)
  150. req2.Header.Set("if-none-match", etag)
  151. }
  152. lastModified := cachedResp.Header.Get("last-modified")
  153. if lastModified != "" && req.Header.Get("last-modified") == "" {
  154. if req2 == nil {
  155. req2 = cloneRequest(req)
  156. }
  157. req2.Header.Set("if-modified-since", lastModified)
  158. }
  159. if req2 != nil {
  160. req = req2
  161. }
  162. }
  163. }
  164. resp, err = transport.RoundTrip(req)
  165. if err == nil && req.Method == "GET" && resp.StatusCode == http.StatusNotModified {
  166. // Replace the 304 response with the one from cache, but update with some new headers
  167. endToEndHeaders := getEndToEndHeaders(resp.Header)
  168. for _, header := range endToEndHeaders {
  169. cachedResp.Header[header] = resp.Header[header]
  170. }
  171. cachedResp.Status = fmt.Sprintf("%d %s", http.StatusOK, http.StatusText(http.StatusOK))
  172. cachedResp.StatusCode = http.StatusOK
  173. resp = cachedResp
  174. } else if (err != nil || (cachedResp != nil && resp.StatusCode >= 500)) &&
  175. req.Method == "GET" && canStaleOnError(cachedResp.Header, req.Header) {
  176. // In case of transport failure and stale-if-error activated, returns cached content
  177. // when available
  178. cachedResp.Status = fmt.Sprintf("%d %s", http.StatusOK, http.StatusText(http.StatusOK))
  179. cachedResp.StatusCode = http.StatusOK
  180. return cachedResp, nil
  181. } else {
  182. if err != nil || resp.StatusCode != http.StatusOK {
  183. t.Cache.Delete(cacheKey)
  184. }
  185. if err != nil {
  186. return nil, err
  187. }
  188. }
  189. } else {
  190. reqCacheControl := parseCacheControl(req.Header)
  191. if _, ok := reqCacheControl["only-if-cached"]; ok {
  192. resp = newGatewayTimeoutResponse(req)
  193. } else {
  194. resp, err = transport.RoundTrip(req)
  195. if err != nil {
  196. return nil, err
  197. }
  198. }
  199. }
  200. if cacheable && canStore(parseCacheControl(req.Header), parseCacheControl(resp.Header)) {
  201. for _, varyKey := range headerAllCommaSepValues(resp.Header, "vary") {
  202. varyKey = http.CanonicalHeaderKey(varyKey)
  203. fakeHeader := "X-Varied-" + varyKey
  204. reqValue := req.Header.Get(varyKey)
  205. if reqValue != "" {
  206. resp.Header.Set(fakeHeader, reqValue)
  207. }
  208. }
  209. switch req.Method {
  210. case "GET":
  211. // Delay caching until EOF is reached.
  212. resp.Body = &cachingReadCloser{
  213. R: resp.Body,
  214. OnEOF: func(r io.Reader) {
  215. resp := *resp
  216. resp.Body = ioutil.NopCloser(r)
  217. respBytes, err := httputil.DumpResponse(&resp, true)
  218. if err == nil {
  219. t.Cache.Set(cacheKey, respBytes)
  220. }
  221. },
  222. }
  223. default:
  224. respBytes, err := httputil.DumpResponse(resp, true)
  225. if err == nil {
  226. t.Cache.Set(cacheKey, respBytes)
  227. }
  228. }
  229. } else {
  230. t.Cache.Delete(cacheKey)
  231. }
  232. return resp, nil
  233. }
  234. // ErrNoDateHeader indicates that the HTTP headers contained no Date header.
  235. var ErrNoDateHeader = errors.New("no Date header")
  236. // Date parses and returns the value of the Date header.
  237. func Date(respHeaders http.Header) (date time.Time, err error) {
  238. dateHeader := respHeaders.Get("date")
  239. if dateHeader == "" {
  240. err = ErrNoDateHeader
  241. return
  242. }
  243. return time.Parse(time.RFC1123, dateHeader)
  244. }
  245. type realClock struct{}
  246. func (c *realClock) since(d time.Time) time.Duration {
  247. return time.Since(d)
  248. }
  249. type timer interface {
  250. since(d time.Time) time.Duration
  251. }
  252. var clock timer = &realClock{}
  253. // getFreshness will return one of fresh/stale/transparent based on the cache-control
  254. // values of the request and the response
  255. //
  256. // fresh indicates the response can be returned
  257. // stale indicates that the response needs validating before it is returned
  258. // transparent indicates the response should not be used to fulfil the request
  259. //
  260. // Because this is only a private cache, 'public' and 'private' in cache-control aren't
  261. // signficant. Similarly, smax-age isn't used.
  262. func getFreshness(respHeaders, reqHeaders http.Header) (freshness int) {
  263. respCacheControl := parseCacheControl(respHeaders)
  264. reqCacheControl := parseCacheControl(reqHeaders)
  265. if _, ok := reqCacheControl["no-cache"]; ok {
  266. return transparent
  267. }
  268. if _, ok := respCacheControl["no-cache"]; ok {
  269. return stale
  270. }
  271. if _, ok := reqCacheControl["only-if-cached"]; ok {
  272. return fresh
  273. }
  274. date, err := Date(respHeaders)
  275. if err != nil {
  276. return stale
  277. }
  278. currentAge := clock.since(date)
  279. var lifetime time.Duration
  280. var zeroDuration time.Duration
  281. // If a response includes both an Expires header and a max-age directive,
  282. // the max-age directive overrides the Expires header, even if the Expires header is more restrictive.
  283. if maxAge, ok := respCacheControl["max-age"]; ok {
  284. lifetime, err = time.ParseDuration(maxAge + "s")
  285. if err != nil {
  286. lifetime = zeroDuration
  287. }
  288. } else {
  289. expiresHeader := respHeaders.Get("Expires")
  290. if expiresHeader != "" {
  291. expires, err := time.Parse(time.RFC1123, expiresHeader)
  292. if err != nil {
  293. lifetime = zeroDuration
  294. } else {
  295. lifetime = expires.Sub(date)
  296. }
  297. }
  298. }
  299. if maxAge, ok := reqCacheControl["max-age"]; ok {
  300. // the client is willing to accept a response whose age is no greater than the specified time in seconds
  301. lifetime, err = time.ParseDuration(maxAge + "s")
  302. if err != nil {
  303. lifetime = zeroDuration
  304. }
  305. }
  306. if minfresh, ok := reqCacheControl["min-fresh"]; ok {
  307. // the client wants a response that will still be fresh for at least the specified number of seconds.
  308. minfreshDuration, err := time.ParseDuration(minfresh + "s")
  309. if err == nil {
  310. currentAge = time.Duration(currentAge + minfreshDuration)
  311. }
  312. }
  313. if maxstale, ok := reqCacheControl["max-stale"]; ok {
  314. // Indicates that the client is willing to accept a response that has exceeded its expiration time.
  315. // If max-stale is assigned a value, then the client is willing to accept a response that has exceeded
  316. // its expiration time by no more than the specified number of seconds.
  317. // If no value is assigned to max-stale, then the client is willing to accept a stale response of any age.
  318. //
  319. // Responses served only because of a max-stale value are supposed to have a Warning header added to them,
  320. // but that seems like a hassle, and is it actually useful? If so, then there needs to be a different
  321. // return-value available here.
  322. if maxstale == "" {
  323. return fresh
  324. }
  325. maxstaleDuration, err := time.ParseDuration(maxstale + "s")
  326. if err == nil {
  327. currentAge = time.Duration(currentAge - maxstaleDuration)
  328. }
  329. }
  330. if lifetime > currentAge {
  331. return fresh
  332. }
  333. return stale
  334. }
  335. // Returns true if either the request or the response includes the stale-if-error
  336. // cache control extension: https://tools.ietf.org/html/rfc5861
  337. func canStaleOnError(respHeaders, reqHeaders http.Header) bool {
  338. respCacheControl := parseCacheControl(respHeaders)
  339. reqCacheControl := parseCacheControl(reqHeaders)
  340. var err error
  341. lifetime := time.Duration(-1)
  342. if staleMaxAge, ok := respCacheControl["stale-if-error"]; ok {
  343. if staleMaxAge != "" {
  344. lifetime, err = time.ParseDuration(staleMaxAge + "s")
  345. if err != nil {
  346. return false
  347. }
  348. } else {
  349. return true
  350. }
  351. }
  352. if staleMaxAge, ok := reqCacheControl["stale-if-error"]; ok {
  353. if staleMaxAge != "" {
  354. lifetime, err = time.ParseDuration(staleMaxAge + "s")
  355. if err != nil {
  356. return false
  357. }
  358. } else {
  359. return true
  360. }
  361. }
  362. if lifetime >= 0 {
  363. date, err := Date(respHeaders)
  364. if err != nil {
  365. return false
  366. }
  367. currentAge := clock.since(date)
  368. if lifetime > currentAge {
  369. return true
  370. }
  371. }
  372. return false
  373. }
  374. func getEndToEndHeaders(respHeaders http.Header) []string {
  375. // These headers are always hop-by-hop
  376. hopByHopHeaders := map[string]struct{}{
  377. "Connection": struct{}{},
  378. "Keep-Alive": struct{}{},
  379. "Proxy-Authenticate": struct{}{},
  380. "Proxy-Authorization": struct{}{},
  381. "Te": struct{}{},
  382. "Trailers": struct{}{},
  383. "Transfer-Encoding": struct{}{},
  384. "Upgrade": struct{}{},
  385. }
  386. for _, extra := range strings.Split(respHeaders.Get("connection"), ",") {
  387. // any header listed in connection, if present, is also considered hop-by-hop
  388. if strings.Trim(extra, " ") != "" {
  389. hopByHopHeaders[http.CanonicalHeaderKey(extra)] = struct{}{}
  390. }
  391. }
  392. endToEndHeaders := []string{}
  393. for respHeader, _ := range respHeaders {
  394. if _, ok := hopByHopHeaders[respHeader]; !ok {
  395. endToEndHeaders = append(endToEndHeaders, respHeader)
  396. }
  397. }
  398. return endToEndHeaders
  399. }
  400. func canStore(reqCacheControl, respCacheControl cacheControl) (canStore bool) {
  401. if _, ok := respCacheControl["no-store"]; ok {
  402. return false
  403. }
  404. if _, ok := reqCacheControl["no-store"]; ok {
  405. return false
  406. }
  407. return true
  408. }
  409. func newGatewayTimeoutResponse(req *http.Request) *http.Response {
  410. var braw bytes.Buffer
  411. braw.WriteString("HTTP/1.1 504 Gateway Timeout\r\n\r\n")
  412. resp, err := http.ReadResponse(bufio.NewReader(&braw), req)
  413. if err != nil {
  414. panic(err)
  415. }
  416. return resp
  417. }
  418. // cloneRequest returns a clone of the provided *http.Request.
  419. // The clone is a shallow copy of the struct and its Header map.
  420. // (This function copyright goauth2 authors: https://code.google.com/p/goauth2)
  421. func cloneRequest(r *http.Request) *http.Request {
  422. // shallow copy of the struct
  423. r2 := new(http.Request)
  424. *r2 = *r
  425. // deep copy of the Header
  426. r2.Header = make(http.Header)
  427. for k, s := range r.Header {
  428. r2.Header[k] = s
  429. }
  430. return r2
  431. }
  432. type cacheControl map[string]string
  433. func parseCacheControl(headers http.Header) cacheControl {
  434. cc := cacheControl{}
  435. ccHeader := headers.Get("Cache-Control")
  436. for _, part := range strings.Split(ccHeader, ",") {
  437. part = strings.Trim(part, " ")
  438. if part == "" {
  439. continue
  440. }
  441. if strings.ContainsRune(part, '=') {
  442. keyval := strings.Split(part, "=")
  443. cc[strings.Trim(keyval[0], " ")] = strings.Trim(keyval[1], ",")
  444. } else {
  445. cc[part] = ""
  446. }
  447. }
  448. return cc
  449. }
  450. // headerAllCommaSepValues returns all comma-separated values (each
  451. // with whitespace trimmed) for header name in headers. According to
  452. // Section 4.2 of the HTTP/1.1 spec
  453. // (http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2),
  454. // values from multiple occurrences of a header should be concatenated, if
  455. // the header's value is a comma-separated list.
  456. func headerAllCommaSepValues(headers http.Header, name string) []string {
  457. var vals []string
  458. for _, val := range headers[http.CanonicalHeaderKey(name)] {
  459. fields := strings.Split(val, ",")
  460. for i, f := range fields {
  461. fields[i] = strings.TrimSpace(f)
  462. }
  463. vals = append(vals, fields...)
  464. }
  465. return vals
  466. }
  467. // cachingReadCloser is a wrapper around ReadCloser R that calls OnEOF
  468. // handler with a full copy of the content read from R when EOF is
  469. // reached.
  470. type cachingReadCloser struct {
  471. // Underlying ReadCloser.
  472. R io.ReadCloser
  473. // OnEOF is called with a copy of the content of R when EOF is reached.
  474. OnEOF func(io.Reader)
  475. buf bytes.Buffer // buf stores a copy of the content of R.
  476. }
  477. // Read reads the next len(p) bytes from R or until R is drained. The
  478. // return value n is the number of bytes read. If R has no data to
  479. // return, err is io.EOF and OnEOF is called with a full copy of what
  480. // has been read so far.
  481. func (r *cachingReadCloser) Read(p []byte) (n int, err error) {
  482. n, err = r.R.Read(p)
  483. r.buf.Write(p[:n])
  484. if err == io.EOF {
  485. r.OnEOF(bytes.NewReader(r.buf.Bytes()))
  486. }
  487. return n, err
  488. }
  489. func (r *cachingReadCloser) Close() error {
  490. return r.R.Close()
  491. }
  492. // NewMemoryCacheTransport returns a new Transport using the in-memory cache implementation
  493. func NewMemoryCacheTransport() *Transport {
  494. c := NewMemoryCache()
  495. t := NewTransport(c)
  496. return t
  497. }