preferredimports.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. /*
  2. Copyright 2019 The Kubernetes Authors.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
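// This program verifies that every import in the selected Go files uses its
// preferred alias, as listed in the -import-aliases JSON file; with -confirm it
// rewrites the offending files to use those aliases instead.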
  14. package main
  15. import (
  16. "bytes"
  17. "encoding/json"
  18. "flag"
  19. "fmt"
  20. "go/ast"
  21. "go/build"
  22. "go/format"
  23. "go/parser"
  24. "go/token"
  25. "io/ioutil"
  26. "log"
  27. "os"
  28. "path/filepath"
  29. "regexp"
  30. "sort"
  31. "strings"
  32. "golang.org/x/crypto/ssh/terminal"
  33. )
  34. var (
  35. importAliases = flag.String("import-aliases", "hack/.import-aliases", "json file with import aliases")
  36. confirm = flag.Bool("confirm", false, "update file with the preferred aliases for imports")
  37. regex = flag.String("include-path", "(test/e2e/|test/e2e_node)", "only files with paths matching this regex is touched")
  38. isTerminal = terminal.IsTerminal(int(os.Stdout.Fd()))
  39. logPrefix = ""
  40. aliases map[string]string
  41. )
  42. type analyzer struct {
  43. fset *token.FileSet // positions are relative to fset
  44. ctx build.Context
  45. failed bool
  46. donePaths map[string]interface{}
  47. }
  48. func newAnalyzer() *analyzer {
  49. ctx := build.Default
  50. ctx.CgoEnabled = true
  51. a := &analyzer{
  52. fset: token.NewFileSet(),
  53. ctx: ctx,
  54. donePaths: make(map[string]interface{}),
  55. }
  56. return a
  57. }
  58. // collect extracts test metadata from a file.
  59. func (a *analyzer) collect(dir string) {
  60. if _, ok := a.donePaths[dir]; ok {
  61. return
  62. }
  63. a.donePaths[dir] = nil
  64. // Create the AST by parsing src.
  65. fs, err := parser.ParseDir(a.fset, dir, nil, parser.AllErrors|parser.ParseComments)
  66. if err != nil {
  67. fmt.Fprintln(os.Stderr, "ERROR(syntax)", logPrefix, err)
  68. a.failed = true
  69. return
  70. }
  71. for _, p := range fs {
  72. // returns first error, but a.handleError deals with it
  73. files := a.filterFiles(p.Files)
  74. for _, file := range files {
  75. replacements := make(map[string]string)
  76. pathToFile := a.fset.File(file.Pos()).Name()
  77. for _, imp := range file.Imports {
  78. importPath := strings.Replace(imp.Path.Value, "\"", "", -1)
  79. pathSegments := strings.Split(importPath, "/")
  80. importName := pathSegments[len(pathSegments)-1]
  81. if imp.Name != nil {
  82. importName = imp.Name.Name
  83. }
  84. if alias, ok := aliases[importPath]; ok {
  85. if alias != importName {
  86. if !*confirm {
  87. fmt.Fprintf(os.Stderr, "%sERROR wrong alias for import \"%s\" should be %s in file %s\n", logPrefix, importPath, alias, pathToFile)
  88. a.failed = true
  89. }
  90. replacements[importName] = alias
  91. if imp.Name != nil {
  92. imp.Name.Name = alias
  93. } else {
  94. imp.Name = ast.NewIdent(alias)
  95. }
  96. }
  97. }
  98. }
  99. if len(replacements) > 0 {
  100. if *confirm {
  101. fmt.Printf("%sReplacing imports with aliases in file %s\n", logPrefix, pathToFile)
  102. for key, value := range replacements {
  103. renameImportUsages(file, key, value)
  104. }
  105. ast.SortImports(a.fset, file)
  106. var buffer bytes.Buffer
  107. if err = format.Node(&buffer, a.fset, file); err != nil {
  108. panic(fmt.Sprintf("Error formatting ast node after rewriting import.\n%s\n", err.Error()))
  109. }
  110. fileInfo, err := os.Stat(pathToFile)
  111. if err != nil {
  112. panic(fmt.Sprintf("Error stat'ing file: %s\n%s\n", pathToFile, err.Error()))
  113. }
  114. err = ioutil.WriteFile(pathToFile, buffer.Bytes(), fileInfo.Mode())
  115. if err != nil {
  116. panic(fmt.Sprintf("Error writing file: %s\n%s\n", pathToFile, err.Error()))
  117. }
  118. }
  119. }
  120. }
  121. }
  122. }
  123. func renameImportUsages(f *ast.File, old, new string) {
  124. // use this to avoid renaming the package declaration, eg:
  125. // given: package foo; import foo "bar"; foo.Baz, rename foo->qux
  126. // yield: package foo; import qux "bar"; qux.Baz
  127. var pkg *ast.Ident
  128. // Rename top-level old to new, both unresolved names
  129. // (probably defined in another file) and names that resolve
  130. // to a declaration we renamed.
  131. ast.Inspect(f, func(node ast.Node) bool {
  132. if node == nil {
  133. return false
  134. }
  135. switch id := node.(type) {
  136. case *ast.File:
  137. pkg = id.Name
  138. case *ast.Ident:
  139. if pkg != nil && id == pkg {
  140. return false
  141. }
  142. if id.Name == old {
  143. id.Name = new
  144. }
  145. }
  146. return true
  147. })
  148. }
  149. func (a *analyzer) filterFiles(fs map[string]*ast.File) []*ast.File {
  150. var files []*ast.File
  151. for _, f := range fs {
  152. files = append(files, f)
  153. }
  154. return files
  155. }
  156. type collector struct {
  157. dirs []string
  158. regex *regexp.Regexp
  159. }
  160. // handlePath walks the filesystem recursively, collecting directories,
  161. // ignoring some unneeded directories (hidden/vendored) that are handled
  162. // specially later.
  163. func (c *collector) handlePath(path string, info os.FileInfo, err error) error {
  164. if err != nil {
  165. return err
  166. }
  167. if info.IsDir() {
  168. // Ignore hidden directories (.git, .cache, etc)
  169. if len(path) > 1 && path[0] == '.' ||
  170. // Staging code is symlinked from vendor/k8s.io, and uses import
  171. // paths as if it were inside of vendor/. It fails typechecking
  172. // inside of staging/, but works when typechecked as part of vendor/.
  173. path == "staging" ||
  174. // OS-specific vendor code tends to be imported by OS-specific
  175. // packages. We recursively typecheck imported vendored packages for
  176. // each OS, but don't typecheck everything for every OS.
  177. path == "vendor" ||
  178. path == "_output" ||
  179. // This is a weird one. /testdata/ is *mostly* ignored by Go,
  180. // and this translates to kubernetes/vendor not working.
  181. // edit/record.go doesn't compile without gopkg.in/yaml.v2
  182. // in $GOSRC/$GOROOT (both typecheck and the shell script).
  183. path == "pkg/kubectl/cmd/testdata/edit" {
  184. return filepath.SkipDir
  185. }
  186. if c.regex.MatchString(path) {
  187. c.dirs = append(c.dirs, path)
  188. }
  189. }
  190. return nil
  191. }
  192. func main() {
  193. flag.Parse()
  194. args := flag.Args()
  195. if len(args) == 0 {
  196. args = append(args, ".")
  197. }
  198. regex, err := regexp.Compile(*regex)
  199. if err != nil {
  200. log.Fatalf("Error compiling regex: %v", err)
  201. }
  202. c := collector{regex: regex}
  203. for _, arg := range args {
  204. err := filepath.Walk(arg, c.handlePath)
  205. if err != nil {
  206. log.Fatalf("Error walking: %v", err)
  207. }
  208. }
  209. sort.Strings(c.dirs)
  210. if len(*importAliases) > 0 {
  211. bytes, err := ioutil.ReadFile(*importAliases)
  212. if err != nil {
  213. log.Fatalf("Error reading import aliases: %v", err)
  214. }
  215. err = json.Unmarshal(bytes, &aliases)
  216. if err != nil {
  217. log.Fatalf("Error loading aliases: %v", err)
  218. }
  219. }
  220. if isTerminal {
  221. logPrefix = "\r" // clear status bar when printing
  222. }
  223. fmt.Println("checking-imports: ")
  224. a := newAnalyzer()
  225. for _, dir := range c.dirs {
  226. if isTerminal {
  227. fmt.Printf("\r\033[0m %-80s", dir)
  228. }
  229. a.collect(dir)
  230. }
  231. fmt.Println()
  232. if a.failed {
  233. os.Exit(1)
  234. }
  235. }