fields.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. // untested sections: 6
  2. package gstruct
import (
	"errors"
	"fmt"
	"reflect"
	"runtime/debug"
	"sort"
	"strings"

	"github.com/onsi/gomega/format"
	errorsutil "github.com/onsi/gomega/gstruct/errors"
	"github.com/onsi/gomega/types"
)
  13. //MatchAllFields succeeds if every field of a struct matches the field matcher associated with
  14. //it, and every element matcher is matched.
  15. // actual := struct{
  16. // A int
  17. // B []bool
  18. // C string
  19. // }{
  20. // A: 5,
  21. // B: []bool{true, false},
  22. // C: "foo",
  23. // }
  24. //
  25. // Expect(actual).To(MatchAllFields(Fields{
  26. // "A": Equal(5),
  27. // "B": ConsistOf(true, false),
  28. // "C": Equal("foo"),
  29. // }))
  30. func MatchAllFields(fields Fields) types.GomegaMatcher {
  31. return &FieldsMatcher{
  32. Fields: fields,
  33. }
  34. }
  35. //MatchFields succeeds if each element of a struct matches the field matcher associated with
  36. //it. It can ignore extra fields and/or missing fields.
  37. // actual := struct{
  38. // A int
  39. // B []bool
  40. // C string
  41. // }{
  42. // A: 5,
  43. // B: []bool{true, false},
  44. // C: "foo",
  45. // }
  46. //
  47. // Expect(actual).To(MatchFields(IgnoreExtras, Fields{
  48. // "A": Equal(5),
  49. // "B": ConsistOf(true, false),
  50. // }))
  51. // Expect(actual).To(MatchFields(IgnoreMissing, Fields{
  52. // "A": Equal(5),
  53. // "B": ConsistOf(true, false),
  54. // "C": Equal("foo"),
  55. // "D": Equal("extra"),
  56. // }))
  57. func MatchFields(options Options, fields Fields) types.GomegaMatcher {
  58. return &FieldsMatcher{
  59. Fields: fields,
  60. IgnoreExtras: options&IgnoreExtras != 0,
  61. IgnoreMissing: options&IgnoreMissing != 0,
  62. }
  63. }
  64. type FieldsMatcher struct {
  65. // Matchers for each field.
  66. Fields Fields
  67. // Whether to ignore extra elements or consider it an error.
  68. IgnoreExtras bool
  69. // Whether to ignore missing elements or consider it an error.
  70. IgnoreMissing bool
  71. // State.
  72. failures []error
  73. }
  74. // Field name to matcher.
  75. type Fields map[string]types.GomegaMatcher
  76. func (m *FieldsMatcher) Match(actual interface{}) (success bool, err error) {
  77. if reflect.TypeOf(actual).Kind() != reflect.Struct {
  78. return false, fmt.Errorf("%v is type %T, expected struct", actual, actual)
  79. }
  80. m.failures = m.matchFields(actual)
  81. if len(m.failures) > 0 {
  82. return false, nil
  83. }
  84. return true, nil
  85. }
  86. func (m *FieldsMatcher) matchFields(actual interface{}) (errs []error) {
  87. val := reflect.ValueOf(actual)
  88. typ := val.Type()
  89. fields := map[string]bool{}
  90. for i := 0; i < val.NumField(); i++ {
  91. fieldName := typ.Field(i).Name
  92. fields[fieldName] = true
  93. err := func() (err error) {
  94. // This test relies heavily on reflect, which tends to panic.
  95. // Recover here to provide more useful error messages in that case.
  96. defer func() {
  97. if r := recover(); r != nil {
  98. err = fmt.Errorf("panic checking %+v: %v\n%s", actual, r, debug.Stack())
  99. }
  100. }()
  101. matcher, expected := m.Fields[fieldName]
  102. if !expected {
  103. if !m.IgnoreExtras {
  104. return fmt.Errorf("unexpected field %s: %+v", fieldName, actual)
  105. }
  106. return nil
  107. }
  108. field := val.Field(i).Interface()
  109. match, err := matcher.Match(field)
  110. if err != nil {
  111. return err
  112. } else if !match {
  113. if nesting, ok := matcher.(errorsutil.NestingMatcher); ok {
  114. return errorsutil.AggregateError(nesting.Failures())
  115. }
  116. return errors.New(matcher.FailureMessage(field))
  117. }
  118. return nil
  119. }()
  120. if err != nil {
  121. errs = append(errs, errorsutil.Nest("."+fieldName, err))
  122. }
  123. }
  124. for field := range m.Fields {
  125. if !fields[field] && !m.IgnoreMissing {
  126. errs = append(errs, fmt.Errorf("missing expected field %s", field))
  127. }
  128. }
  129. return errs
  130. }
  131. func (m *FieldsMatcher) FailureMessage(actual interface{}) (message string) {
  132. failures := make([]string, len(m.failures))
  133. for i := range m.failures {
  134. failures[i] = m.failures[i].Error()
  135. }
  136. return format.Message(reflect.TypeOf(actual).Name(),
  137. fmt.Sprintf("to match fields: {\n%v\n}\n", strings.Join(failures, "\n")))
  138. }
  139. func (m *FieldsMatcher) NegatedFailureMessage(actual interface{}) (message string) {
  140. return format.Message(actual, "not to match fields")
  141. }
  142. func (m *FieldsMatcher) Failures() []error {
  143. return m.failures
  144. }