union.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. /*
  2. Copyright 2016 The Kubernetes Authors.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package generators
  14. import (
  15. "fmt"
  16. "sort"
  17. "k8s.io/gengo/types"
  18. )
  19. const tagUnionMember = "union"
  20. const tagUnionDeprecated = "unionDeprecated"
  21. const tagUnionDiscriminator = "unionDiscriminator"
// union holds what the parsers in this file collected about one union: an
// optional discriminator and the set of member fields.
type union struct {
	// discriminator is the name recorded by setDiscriminator, or "" when the
	// union has no discriminator.
	discriminator string
	// fieldsToDiscriminated maps the name passed to addMember as jsonName to
	// the name passed as variableName (the parsers pass the member's Go field
	// name). emit writes it under the "fields-to-discriminateBy" key.
	fieldsToDiscriminated map[string]string
}
  26. // emit prints the union, can be called on a nil union (emits nothing)
  27. func (u *union) emit(g openAPITypeWriter) {
  28. if u == nil {
  29. return
  30. }
  31. g.Do("map[string]interface{}{\n", nil)
  32. if u.discriminator != "" {
  33. g.Do("\"discriminator\": \"$.$\",\n", u.discriminator)
  34. }
  35. g.Do("\"fields-to-discriminateBy\": map[string]interface{}{\n", nil)
  36. keys := []string{}
  37. for field := range u.fieldsToDiscriminated {
  38. keys = append(keys, field)
  39. }
  40. sort.Strings(keys)
  41. for _, field := range keys {
  42. g.Do("\"$.$\": ", field)
  43. g.Do("\"$.$\",\n", u.fieldsToDiscriminated[field])
  44. }
  45. g.Do("},\n", nil)
  46. g.Do("},\n", nil)
  47. }
  48. // Sets the discriminator if it's not set yet, otherwise return an error
  49. func (u *union) setDiscriminator(value string) []error {
  50. errors := []error{}
  51. if u.discriminator != "" {
  52. errors = append(errors, fmt.Errorf("at least two discriminators found: %v and %v", value, u.discriminator))
  53. }
  54. u.discriminator = value
  55. return errors
  56. }
  57. // Add a new member to the union
  58. func (u *union) addMember(jsonName, variableName string) {
  59. if _, ok := u.fieldsToDiscriminated[jsonName]; ok {
  60. panic(fmt.Errorf("same field (%v) found multiple times", jsonName))
  61. }
  62. u.fieldsToDiscriminated[jsonName] = variableName
  63. }
  64. // Makes sure that the union is valid, specifically looking for re-used discriminated
  65. func (u *union) isValid() []error {
  66. errors := []error{}
  67. // Case 1: discriminator but no fields
  68. if u.discriminator != "" && len(u.fieldsToDiscriminated) == 0 {
  69. errors = append(errors, fmt.Errorf("discriminator set with no fields in union"))
  70. }
  71. // Case 2: two fields have the same discriminated value
  72. discriminated := map[string]struct{}{}
  73. for _, d := range u.fieldsToDiscriminated {
  74. if _, ok := discriminated[d]; ok {
  75. errors = append(errors, fmt.Errorf("discriminated value is used twice: %v", d))
  76. }
  77. discriminated[d] = struct{}{}
  78. }
  79. // Case 3: a field is both discriminator AND part of the union
  80. if u.discriminator != "" {
  81. if _, ok := u.fieldsToDiscriminated[u.discriminator]; ok {
  82. errors = append(errors, fmt.Errorf("%v can't be both discriminator and part of the union", u.discriminator))
  83. }
  84. }
  85. return errors
  86. }
  87. // Find unions either directly on the members (or inlined members, not
  88. // going across types) or on the type itself, or on embedded types.
  89. func parseUnions(t *types.Type) ([]union, []error) {
  90. errors := []error{}
  91. unions := []union{}
  92. su, err := parseUnionStruct(t)
  93. if su != nil {
  94. unions = append(unions, *su)
  95. }
  96. errors = append(errors, err...)
  97. eu, err := parseEmbeddedUnion(t)
  98. unions = append(unions, eu...)
  99. errors = append(errors, err...)
  100. mu, err := parseUnionMembers(t)
  101. if mu != nil {
  102. unions = append(unions, *mu)
  103. }
  104. errors = append(errors, err...)
  105. return unions, errors
  106. }
  107. // Find unions in embedded types, unions shouldn't go across types.
  108. func parseEmbeddedUnion(t *types.Type) ([]union, []error) {
  109. errors := []error{}
  110. unions := []union{}
  111. for _, m := range t.Members {
  112. if hasOpenAPITagValue(m.CommentLines, tagValueFalse) {
  113. continue
  114. }
  115. if !shouldInlineMembers(&m) {
  116. continue
  117. }
  118. u, err := parseUnions(m.Type)
  119. unions = append(unions, u...)
  120. errors = append(errors, err...)
  121. }
  122. return unions, errors
  123. }
  124. // Look for union tag on a struct, and then include all the fields
  125. // (except the discriminator if there is one). The struct shouldn't have
  126. // embedded types.
  127. func parseUnionStruct(t *types.Type) (*union, []error) {
  128. errors := []error{}
  129. if types.ExtractCommentTags("+", t.CommentLines)[tagUnionMember] == nil {
  130. return nil, nil
  131. }
  132. u := &union{fieldsToDiscriminated: map[string]string{}}
  133. for _, m := range t.Members {
  134. jsonName := getReferableName(&m)
  135. if jsonName == "" {
  136. continue
  137. }
  138. if shouldInlineMembers(&m) {
  139. errors = append(errors, fmt.Errorf("union structures can't have embedded fields: %v.%v", t.Name, m.Name))
  140. continue
  141. }
  142. if types.ExtractCommentTags("+", m.CommentLines)[tagUnionDeprecated] != nil {
  143. errors = append(errors, fmt.Errorf("union struct can't have unionDeprecated members: %v.%v", t.Name, m.Name))
  144. continue
  145. }
  146. if types.ExtractCommentTags("+", m.CommentLines)[tagUnionDiscriminator] != nil {
  147. errors = append(errors, u.setDiscriminator(jsonName)...)
  148. } else {
  149. if !hasOptionalTag(&m) {
  150. errors = append(errors, fmt.Errorf("union members must be optional: %v.%v", t.Name, m.Name))
  151. }
  152. u.addMember(jsonName, m.Name)
  153. }
  154. }
  155. return u, errors
  156. }
  157. // Find unions specifically on members.
  158. func parseUnionMembers(t *types.Type) (*union, []error) {
  159. errors := []error{}
  160. u := &union{fieldsToDiscriminated: map[string]string{}}
  161. for _, m := range t.Members {
  162. jsonName := getReferableName(&m)
  163. if jsonName == "" {
  164. continue
  165. }
  166. if shouldInlineMembers(&m) {
  167. continue
  168. }
  169. if types.ExtractCommentTags("+", m.CommentLines)[tagUnionDiscriminator] != nil {
  170. errors = append(errors, u.setDiscriminator(jsonName)...)
  171. }
  172. if types.ExtractCommentTags("+", m.CommentLines)[tagUnionMember] != nil {
  173. errors = append(errors, fmt.Errorf("union tag is not accepted on struct members: %v.%v", t.Name, m.Name))
  174. continue
  175. }
  176. if types.ExtractCommentTags("+", m.CommentLines)[tagUnionDeprecated] != nil {
  177. if !hasOptionalTag(&m) {
  178. errors = append(errors, fmt.Errorf("union members must be optional: %v.%v", t.Name, m.Name))
  179. }
  180. u.addMember(jsonName, m.Name)
  181. }
  182. }
  183. if len(u.fieldsToDiscriminated) == 0 {
  184. return nil, nil
  185. }
  186. return u, append(errors, u.isValid()...)
  187. }