msg_generate.go 9.6 KB


  1. //+build ignore
  2. // msg_generate.go is meant to run with go generate. It will use
  3. // go/{importer,types} to track down all the RR struct types. Then for each type
  4. // it will generate pack/unpack methods based on the struct tags. The generated source is
  5. // written to zmsg.go, and is meant to be checked into git.
  6. package main
  7. import (
  8. "bytes"
  9. "fmt"
  10. "go/format"
  11. "go/importer"
  12. "go/types"
  13. "log"
  14. "os"
  15. "strings"
  16. )
  // packageHdr is written at the top of the generated zmsg.go. Its comment line
  // follows the convention recognised by Go tooling for generated code
  // (a line matching `^// Code generated .* DO NOT EDIT\.$`), and it is separated
  // from the package clause by a blank line so that it is not picked up as
  // package documentation for package dns.
  var packageHdr = "\n" +
  	"// Code generated by go generate from msg_generate.go; DO NOT EDIT.\n" +
  	"\n" +
  	"package dns\n" +
  	"\n"
  // getTypeStruct takes a type and the package scope, and returns the
  // (innermost) struct if the type is considered an RR type (currently defined as
  // those structs beginning with an RR_Header; this could be redefined as
  // implementing the RR interface). The bool return value indicates whether an
  // embedded struct was resolved to get there.
  //
  // Types whose underlying type is not a struct, and structs without any fields,
  // yield (nil, false). A scope that does not declare RR_Header is treated as
  // "no match" instead of causing a nil dereference.
  func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) {
  	st, ok := t.Underlying().(*types.Struct)
  	if !ok || st.NumFields() == 0 {
  		// Field(0) below would panic on an empty struct.
  		return nil, false
  	}
  	first := st.Field(0)
  	if hdr := scope.Lookup("RR_Header"); hdr != nil && first.Type() == hdr.Type() {
  		return st, false
  	}
  	if first.Anonymous() {
  		// Follow the embedded type; the result may still be nil if it is
  		// not an RR type either.
  		inner, _ := getTypeStruct(first.Type(), scope)
  		return inner, true
  	}
  	return nil, false
  }
  41. func main() {
  42. // Import and type-check the package
  43. pkg, err := importer.Default().Import("github.com/miekg/dns")
  44. fatalIfErr(err)
  45. scope := pkg.Scope()
  46. // Collect actual types (*X)
  47. var namedTypes []string
  48. for _, name := range scope.Names() {
  49. o := scope.Lookup(name)
  50. if o == nil || !o.Exported() {
  51. continue
  52. }
  53. if st, _ := getTypeStruct(o.Type(), scope); st == nil {
  54. continue
  55. }
  56. if name == "PrivateRR" {
  57. continue
  58. }
  59. // Check if corresponding TypeX exists
  60. if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" {
  61. log.Fatalf("Constant Type%s does not exist.", o.Name())
  62. }
  63. namedTypes = append(namedTypes, o.Name())
  64. }
  65. b := &bytes.Buffer{}
  66. b.WriteString(packageHdr)
  67. fmt.Fprint(b, "// pack*() functions\n\n")
  68. for _, name := range namedTypes {
  69. o := scope.Lookup(name)
  70. st, _ := getTypeStruct(o.Type(), scope)
  71. fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {\n", name)
  72. fmt.Fprint(b, `off, err := rr.Hdr.pack(msg, off, compression, compress)
  73. if err != nil {
  74. return off, err
  75. }
  76. headerEnd := off
  77. `)
  78. for i := 1; i < st.NumFields(); i++ {
  79. o := func(s string) {
  80. fmt.Fprintf(b, s, st.Field(i).Name())
  81. fmt.Fprint(b, `if err != nil {
  82. return off, err
  83. }
  84. `)
  85. }
  86. if _, ok := st.Field(i).Type().(*types.Slice); ok {
  87. switch st.Tag(i) {
  88. case `dns:"-"`: // ignored
  89. case `dns:"txt"`:
  90. o("off, err = packStringTxt(rr.%s, msg, off)\n")
  91. case `dns:"opt"`:
  92. o("off, err = packDataOpt(rr.%s, msg, off)\n")
  93. case `dns:"nsec"`:
  94. o("off, err = packDataNsec(rr.%s, msg, off)\n")
  95. case `dns:"domain-name"`:
  96. o("off, err = packDataDomainNames(rr.%s, msg, off, compression, compress)\n")
  97. default:
  98. log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
  99. }
  100. continue
  101. }
  102. switch {
  103. case st.Tag(i) == `dns:"-"`: // ignored
  104. case st.Tag(i) == `dns:"cdomain-name"`:
  105. fallthrough
  106. case st.Tag(i) == `dns:"domain-name"`:
  107. o("off, err = PackDomainName(rr.%s, msg, off, compression, compress)\n")
  108. case st.Tag(i) == `dns:"a"`:
  109. o("off, err = packDataA(rr.%s, msg, off)\n")
  110. case st.Tag(i) == `dns:"aaaa"`:
  111. o("off, err = packDataAAAA(rr.%s, msg, off)\n")
  112. case st.Tag(i) == `dns:"uint48"`:
  113. o("off, err = packUint48(rr.%s, msg, off)\n")
  114. case st.Tag(i) == `dns:"txt"`:
  115. o("off, err = packString(rr.%s, msg, off)\n")
  116. case strings.HasPrefix(st.Tag(i), `dns:"size-base32`): // size-base32 can be packed just like base32
  117. fallthrough
  118. case st.Tag(i) == `dns:"base32"`:
  119. o("off, err = packStringBase32(rr.%s, msg, off)\n")
  120. case strings.HasPrefix(st.Tag(i), `dns:"size-base64`): // size-base64 can be packed just like base64
  121. fallthrough
  122. case st.Tag(i) == `dns:"base64"`:
  123. o("off, err = packStringBase64(rr.%s, msg, off)\n")
  124. case strings.HasPrefix(st.Tag(i), `dns:"size-hex`): // size-hex can be packed just like hex
  125. fallthrough
  126. case st.Tag(i) == `dns:"hex"`:
  127. o("off, err = packStringHex(rr.%s, msg, off)\n")
  128. case st.Tag(i) == `dns:"octet"`:
  129. o("off, err = packStringOctet(rr.%s, msg, off)\n")
  130. case st.Tag(i) == "":
  131. switch st.Field(i).Type().(*types.Basic).Kind() {
  132. case types.Uint8:
  133. o("off, err = packUint8(rr.%s, msg, off)\n")
  134. case types.Uint16:
  135. o("off, err = packUint16(rr.%s, msg, off)\n")
  136. case types.Uint32:
  137. o("off, err = packUint32(rr.%s, msg, off)\n")
  138. case types.Uint64:
  139. o("off, err = packUint64(rr.%s, msg, off)\n")
  140. case types.String:
  141. o("off, err = packString(rr.%s, msg, off)\n")
  142. default:
  143. log.Fatalln(name, st.Field(i).Name())
  144. }
  145. default:
  146. log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
  147. }
  148. }
  149. // We have packed everything, only now we know the rdlength of this RR
  150. fmt.Fprintln(b, "rr.Header().Rdlength = uint16(off- headerEnd)")
  151. fmt.Fprintln(b, "return off, nil }\n")
  152. }
  153. fmt.Fprint(b, "// unpack*() functions\n\n")
  154. for _, name := range namedTypes {
  155. o := scope.Lookup(name)
  156. st, _ := getTypeStruct(o.Type(), scope)
  157. fmt.Fprintf(b, "func unpack%s(h RR_Header, msg []byte, off int) (RR, int, error) {\n", name)
  158. fmt.Fprintf(b, "rr := new(%s)\n", name)
  159. fmt.Fprint(b, "rr.Hdr = h\n")
  160. fmt.Fprint(b, `if noRdata(h) {
  161. return rr, off, nil
  162. }
  163. var err error
  164. rdStart := off
  165. _ = rdStart
  166. `)
  167. for i := 1; i < st.NumFields(); i++ {
  168. o := func(s string) {
  169. fmt.Fprintf(b, s, st.Field(i).Name())
  170. fmt.Fprint(b, `if err != nil {
  171. return rr, off, err
  172. }
  173. `)
  174. }
  175. // size-* are special, because they reference a struct member we should use for the length.
  176. if strings.HasPrefix(st.Tag(i), `dns:"size-`) {
  177. structMember := structMember(st.Tag(i))
  178. structTag := structTag(st.Tag(i))
  179. switch structTag {
  180. case "hex":
  181. fmt.Fprintf(b, "rr.%s, off, err = unpackStringHex(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
  182. case "base32":
  183. fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase32(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
  184. case "base64":
  185. fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase64(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
  186. default:
  187. log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
  188. }
  189. fmt.Fprint(b, `if err != nil {
  190. return rr, off, err
  191. }
  192. `)
  193. continue
  194. }
  195. if _, ok := st.Field(i).Type().(*types.Slice); ok {
  196. switch st.Tag(i) {
  197. case `dns:"-"`: // ignored
  198. case `dns:"txt"`:
  199. o("rr.%s, off, err = unpackStringTxt(msg, off)\n")
  200. case `dns:"opt"`:
  201. o("rr.%s, off, err = unpackDataOpt(msg, off)\n")
  202. case `dns:"nsec"`:
  203. o("rr.%s, off, err = unpackDataNsec(msg, off)\n")
  204. case `dns:"domain-name"`:
  205. o("rr.%s, off, err = unpackDataDomainNames(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
  206. default:
  207. log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
  208. }
  209. continue
  210. }
  211. switch st.Tag(i) {
  212. case `dns:"-"`: // ignored
  213. case `dns:"cdomain-name"`:
  214. fallthrough
  215. case `dns:"domain-name"`:
  216. o("rr.%s, off, err = UnpackDomainName(msg, off)\n")
  217. case `dns:"a"`:
  218. o("rr.%s, off, err = unpackDataA(msg, off)\n")
  219. case `dns:"aaaa"`:
  220. o("rr.%s, off, err = unpackDataAAAA(msg, off)\n")
  221. case `dns:"uint48"`:
  222. o("rr.%s, off, err = unpackUint48(msg, off)\n")
  223. case `dns:"txt"`:
  224. o("rr.%s, off, err = unpackString(msg, off)\n")
  225. case `dns:"base32"`:
  226. o("rr.%s, off, err = unpackStringBase32(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
  227. case `dns:"base64"`:
  228. o("rr.%s, off, err = unpackStringBase64(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
  229. case `dns:"hex"`:
  230. o("rr.%s, off, err = unpackStringHex(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
  231. case `dns:"octet"`:
  232. o("rr.%s, off, err = unpackStringOctet(msg, off)\n")
  233. case "":
  234. switch st.Field(i).Type().(*types.Basic).Kind() {
  235. case types.Uint8:
  236. o("rr.%s, off, err = unpackUint8(msg, off)\n")
  237. case types.Uint16:
  238. o("rr.%s, off, err = unpackUint16(msg, off)\n")
  239. case types.Uint32:
  240. o("rr.%s, off, err = unpackUint32(msg, off)\n")
  241. case types.Uint64:
  242. o("rr.%s, off, err = unpackUint64(msg, off)\n")
  243. case types.String:
  244. o("rr.%s, off, err = unpackString(msg, off)\n")
  245. default:
  246. log.Fatalln(name, st.Field(i).Name())
  247. }
  248. default:
  249. log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
  250. }
  251. // If we've hit len(msg) we return without error.
  252. if i < st.NumFields()-1 {
  253. fmt.Fprintf(b, `if off == len(msg) {
  254. return rr, off, nil
  255. }
  256. `)
  257. }
  258. }
  259. fmt.Fprintf(b, "return rr, off, err }\n\n")
  260. }
  261. // Generate typeToUnpack map
  262. fmt.Fprintln(b, "var typeToUnpack = map[uint16]func(RR_Header, []byte, int) (RR, int, error){")
  263. for _, name := range namedTypes {
  264. if name == "RFC3597" {
  265. continue
  266. }
  267. fmt.Fprintf(b, "Type%s: unpack%s,\n", name, name)
  268. }
  269. fmt.Fprintln(b, "}\n")
  270. // gofmt
  271. res, err := format.Source(b.Bytes())
  272. if err != nil {
  273. b.WriteTo(os.Stderr)
  274. log.Fatal(err)
  275. }
  276. // write result
  277. f, err := os.Create("zmsg.go")
  278. fatalIfErr(err)
  279. defer f.Close()
  280. f.Write(res)
  281. }
  // structMember takes a tag like dns:"size-base32:SaltLength" and returns the
  // part after the last colon with the closing double quote removed
  // (SaltLength in the example). If s contains no colon the whole string is
  // used; a missing closing quote leaves the name untouched instead of
  // chopping off its last character.
  func structMember(s string) string {
  	// LastIndex returns -1 when there is no colon, so the slice starts at 0.
  	member := s[strings.LastIndex(s, ":")+1:]
  	return strings.TrimSuffix(member, `"`)
  }
  // structTag takes a tag like dns:"size-base32:SaltLength" and returns base32,
  // i.e. the second colon-separated field with its leading `"size-` removed.
  // It returns "" when s has no second field or when that field does not start
  // with `"size-`, rather than slicing out of range.
  func structTag(s string) string {
  	const prefix = `"size-`

  	fields := strings.Split(s, ":")
  	if len(fields) < 2 || !strings.HasPrefix(fields[1], prefix) {
  		return ""
  	}
  	return fields[1][len(prefix):]
  }
  // fatalIfErr does nothing for a nil error and otherwise hands the error to
  // log.Fatal.
  func fatalIfErr(err error) {
  	if err == nil {
  		return
  	}
  	log.Fatal(err)
  }