customizations.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. package ec2
  2. import (
  3. "time"
  4. "github.com/aws/aws-sdk-go/aws"
  5. "github.com/aws/aws-sdk-go/aws/awsutil"
  6. "github.com/aws/aws-sdk-go/aws/client"
  7. "github.com/aws/aws-sdk-go/aws/endpoints"
  8. "github.com/aws/aws-sdk-go/aws/request"
  9. "github.com/aws/aws-sdk-go/internal/sdkrand"
  10. )
  11. type retryer struct {
  12. client.DefaultRetryer
  13. }
  14. func (d retryer) RetryRules(r *request.Request) time.Duration {
  15. switch r.Operation.Name {
  16. case opModifyNetworkInterfaceAttribute:
  17. fallthrough
  18. case opAssignPrivateIpAddresses:
  19. return customRetryRule(r)
  20. default:
  21. return d.DefaultRetryer.RetryRules(r)
  22. }
  23. }
  24. func customRetryRule(r *request.Request) time.Duration {
  25. retryTimes := []time.Duration{
  26. time.Second,
  27. 3 * time.Second,
  28. 5 * time.Second,
  29. }
  30. count := r.RetryCount
  31. if count >= len(retryTimes) {
  32. count = len(retryTimes) - 1
  33. }
  34. minTime := int(retryTimes[count])
  35. return time.Duration(sdkrand.SeededRand.Intn(minTime) + minTime)
  36. }
  37. func setCustomRetryer(c *client.Client) {
  38. maxRetries := aws.IntValue(c.Config.MaxRetries)
  39. if c.Config.MaxRetries == nil || maxRetries == aws.UseServiceDefaultRetries {
  40. maxRetries = 3
  41. }
  42. c.Retryer = retryer{
  43. DefaultRetryer: client.DefaultRetryer{
  44. NumMaxRetries: maxRetries,
  45. },
  46. }
  47. }
  48. func init() {
  49. initClient = func(c *client.Client) {
  50. if c.Config.Retryer == nil {
  51. // Only override the retryer with a custom one if the config
  52. // does not already contain a retryer
  53. setCustomRetryer(c)
  54. }
  55. }
  56. initRequest = func(r *request.Request) {
  57. if r.Operation.Name == opCopySnapshot { // fill the PresignedURL parameter
  58. r.Handlers.Build.PushFront(fillPresignedURL)
  59. }
  60. }
  61. }
  62. func fillPresignedURL(r *request.Request) {
  63. if !r.ParamsFilled() {
  64. return
  65. }
  66. origParams := r.Params.(*CopySnapshotInput)
  67. // Stop if PresignedURL/DestinationRegion is set
  68. if origParams.PresignedUrl != nil || origParams.DestinationRegion != nil {
  69. return
  70. }
  71. origParams.DestinationRegion = r.Config.Region
  72. newParams := awsutil.CopyOf(r.Params).(*CopySnapshotInput)
  73. // Create a new request based on the existing request. We will use this to
  74. // presign the CopySnapshot request against the source region.
  75. cfg := r.Config.Copy(aws.NewConfig().
  76. WithEndpoint("").
  77. WithRegion(aws.StringValue(origParams.SourceRegion)))
  78. clientInfo := r.ClientInfo
  79. resolved, err := r.Config.EndpointResolver.EndpointFor(
  80. clientInfo.ServiceName, aws.StringValue(cfg.Region),
  81. func(opt *endpoints.Options) {
  82. opt.DisableSSL = aws.BoolValue(cfg.DisableSSL)
  83. opt.UseDualStack = aws.BoolValue(cfg.UseDualStack)
  84. },
  85. )
  86. if err != nil {
  87. r.Error = err
  88. return
  89. }
  90. clientInfo.Endpoint = resolved.URL
  91. clientInfo.SigningRegion = resolved.SigningRegion
  92. // Presign a CopySnapshot request with modified params
  93. req := request.New(*cfg, clientInfo, r.Handlers, r.Retryer, r.Operation, newParams, r.Data)
  94. url, err := req.Presign(5 * time.Minute) // 5 minutes should be enough.
  95. if err != nil { // bubble error back up to original request
  96. r.Error = err
  97. return
  98. }
  99. // We have our URL, set it on params
  100. origParams.PresignedUrl = &url
  101. }