keys.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. // untested sections: 6
  2. package gstruct
  3. import (
  4. "errors"
  5. "fmt"
  6. "reflect"
  7. "runtime/debug"
  8. "strings"
  9. "github.com/onsi/gomega/format"
  10. errorsutil "github.com/onsi/gomega/gstruct/errors"
  11. "github.com/onsi/gomega/types"
  12. )
  13. func MatchAllKeys(keys Keys) types.GomegaMatcher {
  14. return &KeysMatcher{
  15. Keys: keys,
  16. }
  17. }
  18. func MatchKeys(options Options, keys Keys) types.GomegaMatcher {
  19. return &KeysMatcher{
  20. Keys: keys,
  21. IgnoreExtras: options&IgnoreExtras != 0,
  22. IgnoreMissing: options&IgnoreMissing != 0,
  23. }
  24. }
  25. type KeysMatcher struct {
  26. // Matchers for each key.
  27. Keys Keys
  28. // Whether to ignore extra keys or consider it an error.
  29. IgnoreExtras bool
  30. // Whether to ignore missing keys or consider it an error.
  31. IgnoreMissing bool
  32. // State.
  33. failures []error
  34. }
  35. type Keys map[interface{}]types.GomegaMatcher
  36. func (m *KeysMatcher) Match(actual interface{}) (success bool, err error) {
  37. if reflect.TypeOf(actual).Kind() != reflect.Map {
  38. return false, fmt.Errorf("%v is type %T, expected map", actual, actual)
  39. }
  40. m.failures = m.matchKeys(actual)
  41. if len(m.failures) > 0 {
  42. return false, nil
  43. }
  44. return true, nil
  45. }
  46. func (m *KeysMatcher) matchKeys(actual interface{}) (errs []error) {
  47. actualValue := reflect.ValueOf(actual)
  48. keys := map[interface{}]bool{}
  49. for _, keyValue := range actualValue.MapKeys() {
  50. key := keyValue.Interface()
  51. keys[key] = true
  52. err := func() (err error) {
  53. // This test relies heavily on reflect, which tends to panic.
  54. // Recover here to provide more useful error messages in that case.
  55. defer func() {
  56. if r := recover(); r != nil {
  57. err = fmt.Errorf("panic checking %+v: %v\n%s", actual, r, debug.Stack())
  58. }
  59. }()
  60. matcher, ok := m.Keys[key]
  61. if !ok {
  62. if !m.IgnoreExtras {
  63. return fmt.Errorf("unexpected key %s: %+v", key, actual)
  64. }
  65. return nil
  66. }
  67. valInterface := actualValue.MapIndex(keyValue).Interface()
  68. match, err := matcher.Match(valInterface)
  69. if err != nil {
  70. return err
  71. }
  72. if !match {
  73. if nesting, ok := matcher.(errorsutil.NestingMatcher); ok {
  74. return errorsutil.AggregateError(nesting.Failures())
  75. }
  76. return errors.New(matcher.FailureMessage(valInterface))
  77. }
  78. return nil
  79. }()
  80. if err != nil {
  81. errs = append(errs, errorsutil.Nest(fmt.Sprintf(".%#v", key), err))
  82. }
  83. }
  84. for key := range m.Keys {
  85. if !keys[key] && !m.IgnoreMissing {
  86. errs = append(errs, fmt.Errorf("missing expected key %s", key))
  87. }
  88. }
  89. return errs
  90. }
  91. func (m *KeysMatcher) FailureMessage(actual interface{}) (message string) {
  92. failures := make([]string, len(m.failures))
  93. for i := range m.failures {
  94. failures[i] = m.failures[i].Error()
  95. }
  96. return format.Message(reflect.TypeOf(actual).Name(),
  97. fmt.Sprintf("to match keys: {\n%v\n}\n", strings.Join(failures, "\n")))
  98. }
  99. func (m *KeysMatcher) NegatedFailureMessage(actual interface{}) (message string) {
  100. return format.Message(actual, "not to match keys")
  101. }
  102. func (m *KeysMatcher) Failures() []error {
  103. return m.failures
  104. }