keys.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. package gstruct
  2. import (
  3. "errors"
  4. "fmt"
  5. "reflect"
  6. "runtime/debug"
  7. "strings"
  8. "github.com/onsi/gomega/format"
  9. errorsutil "github.com/onsi/gomega/gstruct/errors"
  10. "github.com/onsi/gomega/types"
  11. )
  12. func MatchAllKeys(keys Keys) types.GomegaMatcher {
  13. return &KeysMatcher{
  14. Keys: keys,
  15. }
  16. }
  17. func MatchKeys(options Options, keys Keys) types.GomegaMatcher {
  18. return &KeysMatcher{
  19. Keys: keys,
  20. IgnoreExtras: options&IgnoreExtras != 0,
  21. IgnoreMissing: options&IgnoreMissing != 0,
  22. }
  23. }
  24. type KeysMatcher struct {
  25. // Matchers for each key.
  26. Keys Keys
  27. // Whether to ignore extra keys or consider it an error.
  28. IgnoreExtras bool
  29. // Whether to ignore missing keys or consider it an error.
  30. IgnoreMissing bool
  31. // State.
  32. failures []error
  33. }
  34. type Keys map[interface{}]types.GomegaMatcher
  35. func (m *KeysMatcher) Match(actual interface{}) (success bool, err error) {
  36. if reflect.TypeOf(actual).Kind() != reflect.Map {
  37. return false, fmt.Errorf("%v is type %T, expected map", actual, actual)
  38. }
  39. m.failures = m.matchKeys(actual)
  40. if len(m.failures) > 0 {
  41. return false, nil
  42. }
  43. return true, nil
  44. }
  45. func (m *KeysMatcher) matchKeys(actual interface{}) (errs []error) {
  46. actualValue := reflect.ValueOf(actual)
  47. keys := map[interface{}]bool{}
  48. for _, keyValue := range actualValue.MapKeys() {
  49. key := keyValue.Interface()
  50. keys[key] = true
  51. err := func() (err error) {
  52. // This test relies heavily on reflect, which tends to panic.
  53. // Recover here to provide more useful error messages in that case.
  54. defer func() {
  55. if r := recover(); r != nil {
  56. err = fmt.Errorf("panic checking %+v: %v\n%s", actual, r, debug.Stack())
  57. }
  58. }()
  59. matcher, ok := m.Keys[key]
  60. if !ok {
  61. if !m.IgnoreExtras {
  62. return fmt.Errorf("unexpected key %s: %+v", key, actual)
  63. }
  64. return nil
  65. }
  66. valValue := actualValue.MapIndex(keyValue)
  67. match, err := matcher.Match(valValue.Interface())
  68. if err != nil {
  69. return err
  70. }
  71. if !match {
  72. if nesting, ok := matcher.(errorsutil.NestingMatcher); ok {
  73. return errorsutil.AggregateError(nesting.Failures())
  74. }
  75. return errors.New(matcher.FailureMessage(valValue))
  76. }
  77. return nil
  78. }()
  79. if err != nil {
  80. errs = append(errs, errorsutil.Nest(fmt.Sprintf(".%#v", key), err))
  81. }
  82. }
  83. for key := range m.Keys {
  84. if !keys[key] && !m.IgnoreMissing {
  85. errs = append(errs, fmt.Errorf("missing expected key %s", key))
  86. }
  87. }
  88. return errs
  89. }
  90. func (m *KeysMatcher) FailureMessage(actual interface{}) (message string) {
  91. failures := make([]string, len(m.failures))
  92. for i := range m.failures {
  93. failures[i] = m.failures[i].Error()
  94. }
  95. return format.Message(reflect.TypeOf(actual).Name(),
  96. fmt.Sprintf("to match keys: {\n%v\n}\n", strings.Join(failures, "\n")))
  97. }
  98. func (m *KeysMatcher) NegatedFailureMessage(actual interface{}) (message string) {
  99. return format.Message(actual, "not to match keys")
  100. }
  101. func (m *KeysMatcher) Failures() []error {
  102. return m.failures
  103. }