unmarshal.go 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. package rest
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "fmt"
  6. "io"
  7. "io/ioutil"
  8. "net/http"
  9. "reflect"
  10. "strconv"
  11. "strings"
  12. "time"
  13. "github.com/aws/aws-sdk-go/aws"
  14. "github.com/aws/aws-sdk-go/aws/awserr"
  15. "github.com/aws/aws-sdk-go/aws/request"
  16. awsStrings "github.com/aws/aws-sdk-go/internal/strings"
  17. "github.com/aws/aws-sdk-go/private/protocol"
  18. )
  19. // UnmarshalHandler is a named request handler for unmarshaling rest protocol requests
  20. var UnmarshalHandler = request.NamedHandler{Name: "awssdk.rest.Unmarshal", Fn: Unmarshal}
  21. // UnmarshalMetaHandler is a named request handler for unmarshaling rest protocol request metadata
  22. var UnmarshalMetaHandler = request.NamedHandler{Name: "awssdk.rest.UnmarshalMeta", Fn: UnmarshalMeta}
  23. // Unmarshal unmarshals the REST component of a response in a REST service.
  24. func Unmarshal(r *request.Request) {
  25. if r.DataFilled() {
  26. v := reflect.Indirect(reflect.ValueOf(r.Data))
  27. if err := unmarshalBody(r, v); err != nil {
  28. r.Error = err
  29. }
  30. }
  31. }
  32. // UnmarshalMeta unmarshals the REST metadata of a response in a REST service
  33. func UnmarshalMeta(r *request.Request) {
  34. r.RequestID = r.HTTPResponse.Header.Get("X-Amzn-Requestid")
  35. if r.RequestID == "" {
  36. // Alternative version of request id in the header
  37. r.RequestID = r.HTTPResponse.Header.Get("X-Amz-Request-Id")
  38. }
  39. if r.DataFilled() {
  40. if err := UnmarshalResponse(r.HTTPResponse, r.Data, aws.BoolValue(r.Config.LowerCaseHeaderMaps)); err != nil {
  41. r.Error = err
  42. }
  43. }
  44. }
  45. // UnmarshalResponse attempts to unmarshal the REST response headers to
  46. // the data type passed in. The type must be a pointer. An error is returned
  47. // with any error unmarshaling the response into the target datatype.
  48. func UnmarshalResponse(resp *http.Response, data interface{}, lowerCaseHeaderMaps bool) error {
  49. v := reflect.Indirect(reflect.ValueOf(data))
  50. return unmarshalLocationElements(resp, v, lowerCaseHeaderMaps)
  51. }
  52. func unmarshalBody(r *request.Request, v reflect.Value) error {
  53. if field, ok := v.Type().FieldByName("_"); ok {
  54. if payloadName := field.Tag.Get("payload"); payloadName != "" {
  55. pfield, _ := v.Type().FieldByName(payloadName)
  56. if ptag := pfield.Tag.Get("type"); ptag != "" && ptag != "structure" {
  57. payload := v.FieldByName(payloadName)
  58. if payload.IsValid() {
  59. switch payload.Interface().(type) {
  60. case []byte:
  61. defer r.HTTPResponse.Body.Close()
  62. b, err := ioutil.ReadAll(r.HTTPResponse.Body)
  63. if err != nil {
  64. return awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
  65. }
  66. payload.Set(reflect.ValueOf(b))
  67. case *string:
  68. defer r.HTTPResponse.Body.Close()
  69. b, err := ioutil.ReadAll(r.HTTPResponse.Body)
  70. if err != nil {
  71. return awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
  72. }
  73. str := string(b)
  74. payload.Set(reflect.ValueOf(&str))
  75. default:
  76. switch payload.Type().String() {
  77. case "io.ReadCloser":
  78. payload.Set(reflect.ValueOf(r.HTTPResponse.Body))
  79. case "io.ReadSeeker":
  80. b, err := ioutil.ReadAll(r.HTTPResponse.Body)
  81. if err != nil {
  82. return awserr.New(request.ErrCodeSerialization,
  83. "failed to read response body", err)
  84. }
  85. payload.Set(reflect.ValueOf(ioutil.NopCloser(bytes.NewReader(b))))
  86. default:
  87. io.Copy(ioutil.Discard, r.HTTPResponse.Body)
  88. r.HTTPResponse.Body.Close()
  89. return awserr.New(request.ErrCodeSerialization,
  90. "failed to decode REST response",
  91. fmt.Errorf("unknown payload type %s", payload.Type()))
  92. }
  93. }
  94. }
  95. }
  96. }
  97. }
  98. return nil
  99. }
  100. func unmarshalLocationElements(resp *http.Response, v reflect.Value, lowerCaseHeaderMaps bool) error {
  101. for i := 0; i < v.NumField(); i++ {
  102. m, field := v.Field(i), v.Type().Field(i)
  103. if n := field.Name; n[0:1] == strings.ToLower(n[0:1]) {
  104. continue
  105. }
  106. if m.IsValid() {
  107. name := field.Tag.Get("locationName")
  108. if name == "" {
  109. name = field.Name
  110. }
  111. switch field.Tag.Get("location") {
  112. case "statusCode":
  113. unmarshalStatusCode(m, resp.StatusCode)
  114. case "header":
  115. err := unmarshalHeader(m, resp.Header.Get(name), field.Tag)
  116. if err != nil {
  117. return awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
  118. }
  119. case "headers":
  120. prefix := field.Tag.Get("locationName")
  121. err := unmarshalHeaderMap(m, resp.Header, prefix, lowerCaseHeaderMaps)
  122. if err != nil {
  123. awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
  124. }
  125. }
  126. }
  127. }
  128. return nil
  129. }
  130. func unmarshalStatusCode(v reflect.Value, statusCode int) {
  131. if !v.IsValid() {
  132. return
  133. }
  134. switch v.Interface().(type) {
  135. case *int64:
  136. s := int64(statusCode)
  137. v.Set(reflect.ValueOf(&s))
  138. }
  139. }
  140. func unmarshalHeaderMap(r reflect.Value, headers http.Header, prefix string, normalize bool) error {
  141. if len(headers) == 0 {
  142. return nil
  143. }
  144. switch r.Interface().(type) {
  145. case map[string]*string: // we only support string map value types
  146. out := map[string]*string{}
  147. for k, v := range headers {
  148. if awsStrings.HasPrefixFold(k, prefix) {
  149. if normalize == true {
  150. k = strings.ToLower(k)
  151. } else {
  152. k = http.CanonicalHeaderKey(k)
  153. }
  154. out[k[len(prefix):]] = &v[0]
  155. }
  156. }
  157. if len(out) != 0 {
  158. r.Set(reflect.ValueOf(out))
  159. }
  160. }
  161. return nil
  162. }
  163. func unmarshalHeader(v reflect.Value, header string, tag reflect.StructTag) error {
  164. switch tag.Get("type") {
  165. case "jsonvalue":
  166. if len(header) == 0 {
  167. return nil
  168. }
  169. case "blob":
  170. if len(header) == 0 {
  171. return nil
  172. }
  173. default:
  174. if !v.IsValid() || (header == "" && v.Elem().Kind() != reflect.String) {
  175. return nil
  176. }
  177. }
  178. switch v.Interface().(type) {
  179. case *string:
  180. v.Set(reflect.ValueOf(&header))
  181. case []byte:
  182. b, err := base64.StdEncoding.DecodeString(header)
  183. if err != nil {
  184. return err
  185. }
  186. v.Set(reflect.ValueOf(b))
  187. case *bool:
  188. b, err := strconv.ParseBool(header)
  189. if err != nil {
  190. return err
  191. }
  192. v.Set(reflect.ValueOf(&b))
  193. case *int64:
  194. i, err := strconv.ParseInt(header, 10, 64)
  195. if err != nil {
  196. return err
  197. }
  198. v.Set(reflect.ValueOf(&i))
  199. case *float64:
  200. f, err := strconv.ParseFloat(header, 64)
  201. if err != nil {
  202. return err
  203. }
  204. v.Set(reflect.ValueOf(&f))
  205. case *time.Time:
  206. format := tag.Get("timestampFormat")
  207. if len(format) == 0 {
  208. format = protocol.RFC822TimeFormatName
  209. }
  210. t, err := protocol.ParseTime(format, header)
  211. if err != nil {
  212. return err
  213. }
  214. v.Set(reflect.ValueOf(&t))
  215. case aws.JSONValue:
  216. escaping := protocol.NoEscape
  217. if tag.Get("location") == "header" {
  218. escaping = protocol.Base64Escape
  219. }
  220. m, err := protocol.DecodeJSONValue(header, escaping)
  221. if err != nil {
  222. return err
  223. }
  224. v.Set(reflect.ValueOf(m))
  225. default:
  226. err := fmt.Errorf("Unsupported value for param %v (%s)", v.Interface(), v.Type())
  227. return err
  228. }
  229. return nil
  230. }