util.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. package storage
  2. // Copyright 2017 Microsoft Corporation
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. //
  8. // http://www.apache.org/licenses/LICENSE-2.0
  9. //
  10. // Unless required by applicable law or agreed to in writing, software
  11. // distributed under the License is distributed on an "AS IS" BASIS,
  12. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. // See the License for the specific language governing permissions and
  14. // limitations under the License.
  15. import (
  16. "bytes"
  17. "crypto/hmac"
  18. "crypto/rand"
  19. "crypto/sha256"
  20. "encoding/base64"
  21. "encoding/xml"
  22. "fmt"
  23. "io"
  24. "io/ioutil"
  25. "net/http"
  26. "net/url"
  27. "reflect"
  28. "strconv"
  29. "strings"
  30. "time"
  31. uuid "github.com/satori/go.uuid"
  32. )
  // Package-level values: a fixed timestamp and an account SAS option set that
  // expires at that timestamp.
  var (
  	// fixedTime is 2050-12-20 21:55:00 in a fixed zone labelled "GMT".
  	// NOTE(review): time.FixedZone takes its offset in seconds east of UTC, so
  	// -6 means six seconds west of UTC rather than six hours — confirm that
  	// this is what was intended.
  	fixedTime = time.Date(2050, time.December, 20, 21, 55, 0, 0, time.FixedZone("GMT", -6))

  	// accountSASOptions turns on the Blob service, all three resource types
  	// and every permission flag listed below, expires at fixedTime and sets
  	// UseHTTPS.
  	accountSASOptions = AccountSASTokenOptions{
  		Services: Services{
  			Blob: true,
  		},
  		ResourceTypes: ResourceTypes{
  			Service:   true,
  			Container: true,
  			Object:    true,
  		},
  		Permissions: Permissions{
  			Read:    true,
  			Write:   true,
  			Delete:  true,
  			List:    true,
  			Add:     true,
  			Create:  true,
  			Update:  true,
  			Process: true,
  		},
  		Expiry:   fixedTime,
  		UseHTTPS: true,
  	}
  )
  58. func (c Client) computeHmac256(message string) string {
  59. h := hmac.New(sha256.New, c.accountKey)
  60. h.Write([]byte(message))
  61. return base64.StdEncoding.EncodeToString(h.Sum(nil))
  62. }
  63. func currentTimeRfc1123Formatted() string {
  64. return timeRfc1123Formatted(time.Now().UTC())
  65. }
  // timeRfc1123Formatted renders t in the HTTP date layout (http.TimeFormat),
  // e.g. "Mon, 02 Jan 2006 15:04:05 GMT".
  //
  // The layout hard-codes the "GMT" suffix, so t is converted to UTC first;
  // otherwise a time carrying another zone would be emitted with its local
  // clock reading but still labelled GMT.
  func timeRfc1123Formatted(t time.Time) string {
  	return t.UTC().Format(http.TimeFormat)
  }
  // timeRFC3339Formatted renders t as "2006-01-02T15:04:05.0000000Z": an
  // RFC 3339 style timestamp with seven fractional-second digits.
  //
  // The trailing "Z" is a literal in the layout, so t is converted to UTC
  // before formatting; without that, a time in another zone would be written
  // with its local clock reading yet marked as UTC.
  func timeRFC3339Formatted(t time.Time) string {
  	return t.UTC().Format("2006-01-02T15:04:05.0000000Z")
  }
  // mergeParams returns a new url.Values holding every key of v1 and v2. For a
  // key present in both, the values from v1 come first, followed by those
  // from v2.
  //
  // The value slices are copied, so the result shares no storage with either
  // argument: appending v2's values can never write into spare capacity of a
  // slice owned by v1, and later changes to the result do not leak back into
  // v1 or v2.
  func mergeParams(v1, v2 url.Values) url.Values {
  	out := make(url.Values, len(v1)+len(v2))
  	for k, v := range v1 {
  		out[k] = append([]string(nil), v...)
  	}
  	for k, v := range v2 {
  		// out[k] is either absent (nil) or a private copy made above.
  		out[k] = append(out[k], v...)
  	}
  	return out
  }
  88. func prepareBlockListRequest(blocks []Block) string {
  89. s := `<?xml version="1.0" encoding="utf-8"?><BlockList>`
  90. for _, v := range blocks {
  91. s += fmt.Sprintf("<%s>%s</%s>", v.Status, v.ID, v.Status)
  92. }
  93. s += `</BlockList>`
  94. return s
  95. }
  // xmlUnmarshal reads body to the end and decodes the bytes as XML into v.
  // A read failure is returned as-is, before any decoding is attempted.
  func xmlUnmarshal(body io.Reader, v interface{}) error {
  	raw, readErr := ioutil.ReadAll(body)
  	if readErr == nil {
  		return xml.Unmarshal(raw, v)
  	}
  	return readErr
  }
  // xmlMarshal encodes v as XML and returns a reader over the encoded bytes
  // together with their length. If encoding fails it returns a nil reader, a
  // zero length and the error.
  func xmlMarshal(v interface{}) (io.Reader, int, error) {
  	encoded, err := xml.Marshal(v)
  	if err == nil {
  		return bytes.NewReader(encoded), len(encoded), nil
  	}
  	return nil, 0, err
  }
  110. func headersFromStruct(v interface{}) map[string]string {
  111. headers := make(map[string]string)
  112. value := reflect.ValueOf(v)
  113. for i := 0; i < value.NumField(); i++ {
  114. key := value.Type().Field(i).Tag.Get("header")
  115. if key != "" {
  116. reflectedValue := reflect.Indirect(value.Field(i))
  117. var val string
  118. if reflectedValue.IsValid() {
  119. switch reflectedValue.Type() {
  120. case reflect.TypeOf(fixedTime):
  121. val = timeRfc1123Formatted(reflectedValue.Interface().(time.Time))
  122. case reflect.TypeOf(uint64(0)), reflect.TypeOf(uint(0)):
  123. val = strconv.FormatUint(reflectedValue.Uint(), 10)
  124. case reflect.TypeOf(int(0)):
  125. val = strconv.FormatInt(reflectedValue.Int(), 10)
  126. default:
  127. val = reflectedValue.String()
  128. }
  129. }
  130. if val != "" {
  131. headers[key] = val
  132. }
  133. }
  134. }
  135. return headers
  136. }
  // mergeHeaders copies every entry of extraHeaders into headers, overwriting
  // entries with the same key, and returns headers. The map passed as headers
  // is modified in place.
  //
  // If headers is nil and there is something to copy, a new map is allocated
  // and returned instead of panicking on the write to a nil map.
  func mergeHeaders(headers, extraHeaders map[string]string) map[string]string {
  	if headers == nil && len(extraHeaders) > 0 {
  		headers = make(map[string]string, len(extraHeaders))
  	}
  	for k, v := range extraHeaders {
  		headers[k] = v
  	}
  	return headers
  }
  // addToHeaders stores value under key in h and returns h. An empty value is
  // ignored, leaving h untouched.
  //
  // If h is nil and value is non-empty, a new map is allocated and returned
  // instead of panicking on the write to a nil map.
  func addToHeaders(h map[string]string, key, value string) map[string]string {
  	if value == "" {
  		return h
  	}
  	if h == nil {
  		h = make(map[string]string)
  	}
  	h[key] = value
  	return h
  }
  150. func addTimeToHeaders(h map[string]string, key string, value *time.Time) map[string]string {
  151. if value != nil {
  152. h = addToHeaders(h, key, timeRfc1123Formatted(*value))
  153. }
  154. return h
  155. }
  // addTimeout appends a "timeout" query parameter holding timeout in base 10
  // and returns params. A zero timeout leaves params unchanged. The value is
  // appended, so any "timeout" entries already present are kept.
  //
  // If params is nil and a timeout has to be added, a new url.Values is
  // allocated and returned instead of panicking on the write to a nil map.
  func addTimeout(params url.Values, timeout uint) url.Values {
  	if timeout == 0 {
  		return params
  	}
  	if params == nil {
  		params = url.Values{}
  	}
  	params.Add("timeout", strconv.FormatUint(uint64(timeout), 10))
  	return params
  }
  162. func addSnapshot(params url.Values, snapshot *time.Time) url.Values {
  163. if snapshot != nil {
  164. params.Add("snapshot", timeRFC3339Formatted(*snapshot))
  165. }
  166. return params
  167. }
  // getTimeFromHeaders reads header key from h and parses it with the
  // time.RFC1123 layout.
  //
  // When the header is absent or empty the result is a non-nil pointer to the
  // zero time.Time and a nil error. When the header is present but does not
  // parse, the result is nil together with the parse error.
  func getTimeFromHeaders(h http.Header, key string) (*time.Time, error) {
  	var parsed time.Time
  	if raw := h.Get(key); raw != "" {
  		t, err := time.Parse(time.RFC1123, raw)
  		if err != nil {
  			return nil, err
  		}
  		parsed = t
  	}
  	return &parsed, nil
  }
  // TimeRFC1123 is a distinct type with time.Time as its underlying type. It
  // exists so that XML encoding and decoding can use the time.RFC1123 layout
  // rather than the layout encoding/xml applies to time.Time by default.
  type TimeRFC1123 time.Time

  // UnmarshalXML implements xml.Unmarshaler. It reads the element's character
  // data and parses it with the time.RFC1123 layout.
  //
  // It returns the decoder's error if the element cannot be read, or the
  // parse error if its text is not a valid RFC1123 time; *t is left unchanged
  // in both cases.
  func (t *TimeRFC1123) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
  	var value string
  	// Surface malformed XML directly rather than letting it show up as a
  	// misleading time-parse failure on an empty string.
  	if err := d.DecodeElement(&value, &start); err != nil {
  		return err
  	}
  	parsed, err := time.Parse(time.RFC1123, value)
  	if err != nil {
  		return err
  	}
  	*t = TimeRFC1123(parsed)
  	return nil
  }

  // MarshalXML implements xml.Marshaler. It writes the time as the text of
  // the start element, formatted with the time.RFC1123 layout.
  func (t *TimeRFC1123) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
  	return e.EncodeElement(time.Time(*t).Format(time.RFC1123), start)
  }
  197. // returns a map of custom metadata values from the specified HTTP header
  198. func getMetadataFromHeaders(header http.Header) map[string]string {
  199. metadata := make(map[string]string)
  200. for k, v := range header {
  201. // Can't trust CanonicalHeaderKey() to munge case
  202. // reliably. "_" is allowed in identifiers:
  203. // https://msdn.microsoft.com/en-us/library/azure/dd179414.aspx
  204. // https://msdn.microsoft.com/library/aa664670(VS.71).aspx
  205. // http://tools.ietf.org/html/rfc7230#section-3.2
  206. // ...but "_" is considered invalid by
  207. // CanonicalMIMEHeaderKey in
  208. // https://golang.org/src/net/textproto/reader.go?s=14615:14659#L542
  209. // so k can be "X-Ms-Meta-Lol" or "x-ms-meta-lol_rofl".
  210. k = strings.ToLower(k)
  211. if len(v) == 0 || !strings.HasPrefix(k, strings.ToLower(userDefinedMetadataHeaderPrefix)) {
  212. continue
  213. }
  214. // metadata["lol"] = content of the last X-Ms-Meta-Lol header
  215. k = k[len(userDefinedMetadataHeaderPrefix):]
  216. metadata[k] = v[len(v)-1]
  217. }
  218. if len(metadata) == 0 {
  219. return nil
  220. }
  221. return metadata
  222. }
  223. // newUUID returns a new uuid using RFC 4122 algorithm.
  224. func newUUID() (uuid.UUID, error) {
  225. u := [16]byte{}
  226. // Set all bits to randomly (or pseudo-randomly) chosen values.
  227. _, err := rand.Read(u[:])
  228. if err != nil {
  229. return uuid.UUID{}, err
  230. }
  231. u[8] = (u[8]&(0xff>>2) | (0x02 << 6)) // u.setVariant(ReservedRFC4122)
  232. u[6] = (u[6] & 0xF) | (uuid.V4 << 4) // u.setVersion(V4)
  233. return uuid.FromBytes(u[:])
  234. }