deep_equal.go 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. // Copyright 2009 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
// Package reflect is a fork of Go's standard library reflection package,
// extended so that deep-equality checks can use caller-defined equality functions.
  6. package reflect
  7. import (
  8. "fmt"
  9. "reflect"
  10. "strings"
  11. )
  12. // Equalities is a map from type to a function comparing two values of
  13. // that type.
  14. type Equalities map[reflect.Type]reflect.Value
  15. // For convenience, panics on errrors
  16. func EqualitiesOrDie(funcs ...interface{}) Equalities {
  17. e := Equalities{}
  18. if err := e.AddFuncs(funcs...); err != nil {
  19. panic(err)
  20. }
  21. return e
  22. }
  23. // AddFuncs is a shortcut for multiple calls to AddFunc.
  24. func (e Equalities) AddFuncs(funcs ...interface{}) error {
  25. for _, f := range funcs {
  26. if err := e.AddFunc(f); err != nil {
  27. return err
  28. }
  29. }
  30. return nil
  31. }
  32. // AddFunc uses func as an equality function: it must take
  33. // two parameters of the same type, and return a boolean.
  34. func (e Equalities) AddFunc(eqFunc interface{}) error {
  35. fv := reflect.ValueOf(eqFunc)
  36. ft := fv.Type()
  37. if ft.Kind() != reflect.Func {
  38. return fmt.Errorf("expected func, got: %v", ft)
  39. }
  40. if ft.NumIn() != 2 {
  41. return fmt.Errorf("expected two 'in' params, got: %v", ft)
  42. }
  43. if ft.NumOut() != 1 {
  44. return fmt.Errorf("expected one 'out' param, got: %v", ft)
  45. }
  46. if ft.In(0) != ft.In(1) {
  47. return fmt.Errorf("expected arg 1 and 2 to have same type, but got %v", ft)
  48. }
  49. var forReturnType bool
  50. boolType := reflect.TypeOf(forReturnType)
  51. if ft.Out(0) != boolType {
  52. return fmt.Errorf("expected bool return, got: %v", ft)
  53. }
  54. e[ft.In(0)] = fv
  55. return nil
  56. }
  57. // Below here is forked from go's reflect/deepequal.go
  58. // During deepValueEqual, must keep track of checks that are
  59. // in progress. The comparison algorithm assumes that all
  60. // checks in progress are true when it reencounters them.
  61. // Visited comparisons are stored in a map indexed by visit.
  62. type visit struct {
  63. a1 uintptr
  64. a2 uintptr
  65. typ reflect.Type
  66. }
  67. // unexportedTypePanic is thrown when you use this DeepEqual on something that has an
  68. // unexported type. It indicates a programmer error, so should not occur at runtime,
  69. // which is why it's not public and thus impossible to catch.
  70. type unexportedTypePanic []reflect.Type
  71. func (u unexportedTypePanic) Error() string { return u.String() }
  72. func (u unexportedTypePanic) String() string {
  73. strs := make([]string, len(u))
  74. for i, t := range u {
  75. strs[i] = fmt.Sprintf("%v", t)
  76. }
  77. return "an unexported field was encountered, nested like this: " + strings.Join(strs, " -> ")
  78. }
  79. func makeUsefulPanic(v reflect.Value) {
  80. if x := recover(); x != nil {
  81. if u, ok := x.(unexportedTypePanic); ok {
  82. u = append(unexportedTypePanic{v.Type()}, u...)
  83. x = u
  84. }
  85. panic(x)
  86. }
  87. }
  88. // Tests for deep equality using reflected types. The map argument tracks
  89. // comparisons that have already been seen, which allows short circuiting on
  90. // recursive types.
  91. func (e Equalities) deepValueEqual(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool {
  92. defer makeUsefulPanic(v1)
  93. if !v1.IsValid() || !v2.IsValid() {
  94. return v1.IsValid() == v2.IsValid()
  95. }
  96. if v1.Type() != v2.Type() {
  97. return false
  98. }
  99. if fv, ok := e[v1.Type()]; ok {
  100. return fv.Call([]reflect.Value{v1, v2})[0].Bool()
  101. }
  102. hard := func(k reflect.Kind) bool {
  103. switch k {
  104. case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct:
  105. return true
  106. }
  107. return false
  108. }
  109. if v1.CanAddr() && v2.CanAddr() && hard(v1.Kind()) {
  110. addr1 := v1.UnsafeAddr()
  111. addr2 := v2.UnsafeAddr()
  112. if addr1 > addr2 {
  113. // Canonicalize order to reduce number of entries in visited.
  114. addr1, addr2 = addr2, addr1
  115. }
  116. // Short circuit if references are identical ...
  117. if addr1 == addr2 {
  118. return true
  119. }
  120. // ... or already seen
  121. typ := v1.Type()
  122. v := visit{addr1, addr2, typ}
  123. if visited[v] {
  124. return true
  125. }
  126. // Remember for later.
  127. visited[v] = true
  128. }
  129. switch v1.Kind() {
  130. case reflect.Array:
  131. // We don't need to check length here because length is part of
  132. // an array's type, which has already been filtered for.
  133. for i := 0; i < v1.Len(); i++ {
  134. if !e.deepValueEqual(v1.Index(i), v2.Index(i), visited, depth+1) {
  135. return false
  136. }
  137. }
  138. return true
  139. case reflect.Slice:
  140. if (v1.IsNil() || v1.Len() == 0) != (v2.IsNil() || v2.Len() == 0) {
  141. return false
  142. }
  143. if v1.IsNil() || v1.Len() == 0 {
  144. return true
  145. }
  146. if v1.Len() != v2.Len() {
  147. return false
  148. }
  149. if v1.Pointer() == v2.Pointer() {
  150. return true
  151. }
  152. for i := 0; i < v1.Len(); i++ {
  153. if !e.deepValueEqual(v1.Index(i), v2.Index(i), visited, depth+1) {
  154. return false
  155. }
  156. }
  157. return true
  158. case reflect.Interface:
  159. if v1.IsNil() || v2.IsNil() {
  160. return v1.IsNil() == v2.IsNil()
  161. }
  162. return e.deepValueEqual(v1.Elem(), v2.Elem(), visited, depth+1)
  163. case reflect.Ptr:
  164. return e.deepValueEqual(v1.Elem(), v2.Elem(), visited, depth+1)
  165. case reflect.Struct:
  166. for i, n := 0, v1.NumField(); i < n; i++ {
  167. if !e.deepValueEqual(v1.Field(i), v2.Field(i), visited, depth+1) {
  168. return false
  169. }
  170. }
  171. return true
  172. case reflect.Map:
  173. if (v1.IsNil() || v1.Len() == 0) != (v2.IsNil() || v2.Len() == 0) {
  174. return false
  175. }
  176. if v1.IsNil() || v1.Len() == 0 {
  177. return true
  178. }
  179. if v1.Len() != v2.Len() {
  180. return false
  181. }
  182. if v1.Pointer() == v2.Pointer() {
  183. return true
  184. }
  185. for _, k := range v1.MapKeys() {
  186. if !e.deepValueEqual(v1.MapIndex(k), v2.MapIndex(k), visited, depth+1) {
  187. return false
  188. }
  189. }
  190. return true
  191. case reflect.Func:
  192. if v1.IsNil() && v2.IsNil() {
  193. return true
  194. }
  195. // Can't do better than this:
  196. return false
  197. default:
  198. // Normal equality suffices
  199. if !v1.CanInterface() || !v2.CanInterface() {
  200. panic(unexportedTypePanic{})
  201. }
  202. return v1.Interface() == v2.Interface()
  203. }
  204. }
  205. // DeepEqual is like reflect.DeepEqual, but focused on semantic equality
  206. // instead of memory equality.
  207. //
  208. // It will use e's equality functions if it finds types that match.
  209. //
  210. // An empty slice *is* equal to a nil slice for our purposes; same for maps.
  211. //
  212. // Unexported field members cannot be compared and will cause an informative panic; you must add an Equality
  213. // function for these types.
  214. func (e Equalities) DeepEqual(a1, a2 interface{}) bool {
  215. if a1 == nil || a2 == nil {
  216. return a1 == a2
  217. }
  218. v1 := reflect.ValueOf(a1)
  219. v2 := reflect.ValueOf(a2)
  220. if v1.Type() != v2.Type() {
  221. return false
  222. }
  223. return e.deepValueEqual(v1, v2, make(map[visit]bool), 0)
  224. }
  225. func (e Equalities) deepValueDerive(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool {
  226. defer makeUsefulPanic(v1)
  227. if !v1.IsValid() || !v2.IsValid() {
  228. return v1.IsValid() == v2.IsValid()
  229. }
  230. if v1.Type() != v2.Type() {
  231. return false
  232. }
  233. if fv, ok := e[v1.Type()]; ok {
  234. return fv.Call([]reflect.Value{v1, v2})[0].Bool()
  235. }
  236. hard := func(k reflect.Kind) bool {
  237. switch k {
  238. case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct:
  239. return true
  240. }
  241. return false
  242. }
  243. if v1.CanAddr() && v2.CanAddr() && hard(v1.Kind()) {
  244. addr1 := v1.UnsafeAddr()
  245. addr2 := v2.UnsafeAddr()
  246. if addr1 > addr2 {
  247. // Canonicalize order to reduce number of entries in visited.
  248. addr1, addr2 = addr2, addr1
  249. }
  250. // Short circuit if references are identical ...
  251. if addr1 == addr2 {
  252. return true
  253. }
  254. // ... or already seen
  255. typ := v1.Type()
  256. v := visit{addr1, addr2, typ}
  257. if visited[v] {
  258. return true
  259. }
  260. // Remember for later.
  261. visited[v] = true
  262. }
  263. switch v1.Kind() {
  264. case reflect.Array:
  265. // We don't need to check length here because length is part of
  266. // an array's type, which has already been filtered for.
  267. for i := 0; i < v1.Len(); i++ {
  268. if !e.deepValueDerive(v1.Index(i), v2.Index(i), visited, depth+1) {
  269. return false
  270. }
  271. }
  272. return true
  273. case reflect.Slice:
  274. if v1.IsNil() || v1.Len() == 0 {
  275. return true
  276. }
  277. if v1.Len() > v2.Len() {
  278. return false
  279. }
  280. if v1.Pointer() == v2.Pointer() {
  281. return true
  282. }
  283. for i := 0; i < v1.Len(); i++ {
  284. if !e.deepValueDerive(v1.Index(i), v2.Index(i), visited, depth+1) {
  285. return false
  286. }
  287. }
  288. return true
  289. case reflect.String:
  290. if v1.Len() == 0 {
  291. return true
  292. }
  293. if v1.Len() > v2.Len() {
  294. return false
  295. }
  296. return v1.String() == v2.String()
  297. case reflect.Interface:
  298. if v1.IsNil() {
  299. return true
  300. }
  301. return e.deepValueDerive(v1.Elem(), v2.Elem(), visited, depth+1)
  302. case reflect.Ptr:
  303. if v1.IsNil() {
  304. return true
  305. }
  306. return e.deepValueDerive(v1.Elem(), v2.Elem(), visited, depth+1)
  307. case reflect.Struct:
  308. for i, n := 0, v1.NumField(); i < n; i++ {
  309. if !e.deepValueDerive(v1.Field(i), v2.Field(i), visited, depth+1) {
  310. return false
  311. }
  312. }
  313. return true
  314. case reflect.Map:
  315. if v1.IsNil() || v1.Len() == 0 {
  316. return true
  317. }
  318. if v1.Len() > v2.Len() {
  319. return false
  320. }
  321. if v1.Pointer() == v2.Pointer() {
  322. return true
  323. }
  324. for _, k := range v1.MapKeys() {
  325. if !e.deepValueDerive(v1.MapIndex(k), v2.MapIndex(k), visited, depth+1) {
  326. return false
  327. }
  328. }
  329. return true
  330. case reflect.Func:
  331. if v1.IsNil() && v2.IsNil() {
  332. return true
  333. }
  334. // Can't do better than this:
  335. return false
  336. default:
  337. // Normal equality suffices
  338. if !v1.CanInterface() || !v2.CanInterface() {
  339. panic(unexportedTypePanic{})
  340. }
  341. return v1.Interface() == v2.Interface()
  342. }
  343. }
  344. // DeepDerivative is similar to DeepEqual except that unset fields in a1 are
  345. // ignored (not compared). This allows us to focus on the fields that matter to
  346. // the semantic comparison.
  347. //
  348. // The unset fields include a nil pointer and an empty string.
  349. func (e Equalities) DeepDerivative(a1, a2 interface{}) bool {
  350. if a1 == nil {
  351. return true
  352. }
  353. v1 := reflect.ValueOf(a1)
  354. v2 := reflect.ValueOf(a2)
  355. if v1.Type() != v2.Type() {
  356. return false
  357. }
  358. return e.deepValueDerive(v1, v2, make(map[visit]bool), 0)
  359. }