async_assertion.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. // untested sections: 2
  2. package asyncassertion
  3. import (
  4. "errors"
  5. "fmt"
  6. "reflect"
  7. "time"
  8. "github.com/onsi/gomega/internal/oraclematcher"
  9. "github.com/onsi/gomega/types"
  10. )
  11. type AsyncAssertionType uint
  12. const (
  13. AsyncAssertionTypeEventually AsyncAssertionType = iota
  14. AsyncAssertionTypeConsistently
  15. )
  16. type AsyncAssertion struct {
  17. asyncType AsyncAssertionType
  18. actualInput interface{}
  19. timeoutInterval time.Duration
  20. pollingInterval time.Duration
  21. failWrapper *types.GomegaFailWrapper
  22. offset int
  23. }
  24. func New(asyncType AsyncAssertionType, actualInput interface{}, failWrapper *types.GomegaFailWrapper, timeoutInterval time.Duration, pollingInterval time.Duration, offset int) *AsyncAssertion {
  25. actualType := reflect.TypeOf(actualInput)
  26. if actualType.Kind() == reflect.Func {
  27. if actualType.NumIn() != 0 || actualType.NumOut() == 0 {
  28. panic("Expected a function with no arguments and one or more return values.")
  29. }
  30. }
  31. return &AsyncAssertion{
  32. asyncType: asyncType,
  33. actualInput: actualInput,
  34. failWrapper: failWrapper,
  35. timeoutInterval: timeoutInterval,
  36. pollingInterval: pollingInterval,
  37. offset: offset,
  38. }
  39. }
  40. func (assertion *AsyncAssertion) Should(matcher types.GomegaMatcher, optionalDescription ...interface{}) bool {
  41. assertion.failWrapper.TWithHelper.Helper()
  42. return assertion.match(matcher, true, optionalDescription...)
  43. }
  44. func (assertion *AsyncAssertion) ShouldNot(matcher types.GomegaMatcher, optionalDescription ...interface{}) bool {
  45. assertion.failWrapper.TWithHelper.Helper()
  46. return assertion.match(matcher, false, optionalDescription...)
  47. }
  48. func (assertion *AsyncAssertion) buildDescription(optionalDescription ...interface{}) string {
  49. switch len(optionalDescription) {
  50. case 0:
  51. return ""
  52. default:
  53. return fmt.Sprintf(optionalDescription[0].(string), optionalDescription[1:]...) + "\n"
  54. }
  55. }
  56. func (assertion *AsyncAssertion) actualInputIsAFunction() bool {
  57. actualType := reflect.TypeOf(assertion.actualInput)
  58. return actualType.Kind() == reflect.Func && actualType.NumIn() == 0 && actualType.NumOut() > 0
  59. }
  60. func (assertion *AsyncAssertion) pollActual() (interface{}, error) {
  61. if assertion.actualInputIsAFunction() {
  62. values := reflect.ValueOf(assertion.actualInput).Call([]reflect.Value{})
  63. extras := []interface{}{}
  64. for _, value := range values[1:] {
  65. extras = append(extras, value.Interface())
  66. }
  67. success, message := vetExtras(extras)
  68. if !success {
  69. return nil, errors.New(message)
  70. }
  71. return values[0].Interface(), nil
  72. }
  73. return assertion.actualInput, nil
  74. }
  75. func (assertion *AsyncAssertion) matcherMayChange(matcher types.GomegaMatcher, value interface{}) bool {
  76. if assertion.actualInputIsAFunction() {
  77. return true
  78. }
  79. return oraclematcher.MatchMayChangeInTheFuture(matcher, value)
  80. }
  81. func (assertion *AsyncAssertion) match(matcher types.GomegaMatcher, desiredMatch bool, optionalDescription ...interface{}) bool {
  82. timer := time.Now()
  83. timeout := time.After(assertion.timeoutInterval)
  84. description := assertion.buildDescription(optionalDescription...)
  85. var matches bool
  86. var err error
  87. mayChange := true
  88. value, err := assertion.pollActual()
  89. if err == nil {
  90. mayChange = assertion.matcherMayChange(matcher, value)
  91. matches, err = matcher.Match(value)
  92. }
  93. assertion.failWrapper.TWithHelper.Helper()
  94. fail := func(preamble string) {
  95. errMsg := ""
  96. message := ""
  97. if err != nil {
  98. errMsg = "Error: " + err.Error()
  99. } else {
  100. if desiredMatch {
  101. message = matcher.FailureMessage(value)
  102. } else {
  103. message = matcher.NegatedFailureMessage(value)
  104. }
  105. }
  106. assertion.failWrapper.TWithHelper.Helper()
  107. assertion.failWrapper.Fail(fmt.Sprintf("%s after %.3fs.\n%s%s%s", preamble, time.Since(timer).Seconds(), description, message, errMsg), 3+assertion.offset)
  108. }
  109. if assertion.asyncType == AsyncAssertionTypeEventually {
  110. for {
  111. if err == nil && matches == desiredMatch {
  112. return true
  113. }
  114. if !mayChange {
  115. fail("No future change is possible. Bailing out early")
  116. return false
  117. }
  118. select {
  119. case <-time.After(assertion.pollingInterval):
  120. value, err = assertion.pollActual()
  121. if err == nil {
  122. mayChange = assertion.matcherMayChange(matcher, value)
  123. matches, err = matcher.Match(value)
  124. }
  125. case <-timeout:
  126. fail("Timed out")
  127. return false
  128. }
  129. }
  130. } else if assertion.asyncType == AsyncAssertionTypeConsistently {
  131. for {
  132. if !(err == nil && matches == desiredMatch) {
  133. fail("Failed")
  134. return false
  135. }
  136. if !mayChange {
  137. return true
  138. }
  139. select {
  140. case <-time.After(assertion.pollingInterval):
  141. value, err = assertion.pollActual()
  142. if err == nil {
  143. mayChange = assertion.matcherMayChange(matcher, value)
  144. matches, err = matcher.Match(value)
  145. }
  146. case <-timeout:
  147. return true
  148. }
  149. }
  150. }
  151. return false
  152. }
  153. func vetExtras(extras []interface{}) (bool, string) {
  154. for i, extra := range extras {
  155. if extra != nil {
  156. zeroValue := reflect.Zero(reflect.TypeOf(extra)).Interface()
  157. if !reflect.DeepEqual(zeroValue, extra) {
  158. message := fmt.Sprintf("Unexpected non-nil/non-zero extra argument at index %d:\n\t<%T>: %#v", i+1, extra, extra)
  159. return false, message
  160. }
  161. }
  162. }
  163. return true, ""
  164. }