integration_test.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
  1. // +build integration
  2. /*
  3. Copyright 2018 The Kubernetes Authors.
  4. Licensed under the Apache License, Version 2.0 (the "License");
  5. you may not use this file except in compliance with the License.
  6. You may obtain a copy of the License at
  7. http://www.apache.org/licenses/LICENSE-2.0
  8. Unless required by applicable law or agreed to in writing, software
  9. distributed under the License is distributed on an "AS IS" BASIS,
  10. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  11. See the License for the specific language governing permissions and
  12. limitations under the License.
  13. */
  14. package main
  15. import (
  16. "bytes"
  17. cryptorand "crypto/rand"
  18. "crypto/rsa"
  19. "crypto/x509"
  20. "crypto/x509/pkix"
  21. "encoding/pem"
  22. "flag"
  23. "fmt"
  24. "io/ioutil"
  25. "math/big"
  26. "net"
  27. "os"
  28. "path/filepath"
  29. "strings"
  30. "sync"
  31. "testing"
  32. "time"
  33. "github.com/blang/semver"
  34. "k8s.io/klog"
  35. )
// Version fixtures shared by the tests in this file.
var (
	// testSupportedVersions is the supported-version list assigned to every
	// member config built by clusterConfig.
	testSupportedVersions = MustParseSupportedVersions("3.0.17, 3.1.12")
	// testVersionPrevious is the lower of the two versions listed above.
	testVersionPrevious = &EtcdVersion{semver.MustParse("3.0.17")}
	// testVersionLatest is the higher of the two versions listed above.
	testVersionLatest = &EtcdVersion{semver.MustParse("3.1.12")}
)
  41. func init() {
  42. // Enable klog which is used in dependencies
  43. klog.InitFlags(nil)
  44. flag.Set("logtostderr", "true")
  45. flag.Set("v", "9")
  46. }
  47. func TestMigrate(t *testing.T) {
  48. migrations := []struct {
  49. title string
  50. memberCount int
  51. startVersion string
  52. endVersion string
  53. protocol string
  54. }{
  55. // upgrades
  56. {"v3-v3-up", 1, "3.0.17/etcd3", "3.1.12/etcd3", "https"},
  57. {"oldest-newest-up", 1, "3.0.17/etcd3", "3.1.12/etcd3", "https"},
  58. // warning: v2->v3 ha upgrades not currently supported.
  59. {"ha-v3-v3-up", 3, "3.0.17/etcd3", "3.1.12/etcd3", "https"},
  60. // downgrades
  61. {"v3-v3-down", 1, "3.1.12/etcd3", "3.0.17/etcd3", "https"},
  62. // warning: ha downgrades not yet supported.
  63. }
  64. for _, m := range migrations {
  65. t.Run(m.title, func(t *testing.T) {
  66. start := MustParseEtcdVersionPair(m.startVersion)
  67. end := MustParseEtcdVersionPair(m.endVersion)
  68. testCfgs := clusterConfig(t, m.title, m.memberCount, m.protocol)
  69. servers := []*EtcdMigrateServer{}
  70. for _, cfg := range testCfgs {
  71. client, err := NewEtcdMigrateClient(cfg)
  72. if err != nil {
  73. t.Fatalf("Failed to create client: %v", err)
  74. }
  75. server := NewEtcdMigrateServer(cfg, client)
  76. servers = append(servers, server)
  77. }
  78. // Start the servers.
  79. parallel(servers, func(server *EtcdMigrateServer) {
  80. dataDir, err := OpenOrCreateDataDirectory(server.cfg.dataDirectory)
  81. if err != nil {
  82. t.Fatalf("Error opening or creating data directory %s: %v", server.cfg.dataDirectory, err)
  83. }
  84. migrator := &Migrator{server.cfg, dataDir, server.client}
  85. err = migrator.MigrateIfNeeded(start)
  86. if err != nil {
  87. t.Fatalf("Migration failed: %v", err)
  88. }
  89. err = server.Start(start.version)
  90. if err != nil {
  91. t.Fatalf("Failed to start server: %v", err)
  92. }
  93. })
  94. // Write a value to each server, read it back.
  95. parallel(servers, func(server *EtcdMigrateServer) {
  96. key := fmt.Sprintf("/registry/%s", server.cfg.name)
  97. value := fmt.Sprintf("value-%s", server.cfg.name)
  98. err := server.client.Put(start.version, key, value)
  99. if err != nil {
  100. t.Fatalf("failed to write text value: %v", err)
  101. }
  102. checkVal, err := server.client.Get(start.version, key)
  103. if err != nil {
  104. t.Errorf("Error getting %s for validation: %v", key, err)
  105. }
  106. if checkVal != value {
  107. t.Errorf("Expected %s from %s but got %s", value, key, checkVal)
  108. }
  109. })
  110. // Migrate the servers in series.
  111. serial(servers, func(server *EtcdMigrateServer) {
  112. err := server.Stop()
  113. if err != nil {
  114. t.Fatalf("Stop server failed: %v", err)
  115. }
  116. dataDir, err := OpenOrCreateDataDirectory(server.cfg.dataDirectory)
  117. if err != nil {
  118. t.Fatalf("Error opening or creating data directory %s: %v", server.cfg.dataDirectory, err)
  119. }
  120. migrator := &Migrator{server.cfg, dataDir, server.client}
  121. err = migrator.MigrateIfNeeded(end)
  122. if err != nil {
  123. t.Fatalf("Migration failed: %v", err)
  124. }
  125. err = server.Start(end.version)
  126. if err != nil {
  127. t.Fatalf("Start server failed: %v", err)
  128. }
  129. })
  130. // Check that all test values can be read back from all the servers.
  131. parallel(servers, func(server *EtcdMigrateServer) {
  132. for _, s := range servers {
  133. key := fmt.Sprintf("/registry/%s", s.cfg.name)
  134. value := fmt.Sprintf("value-%s", s.cfg.name)
  135. checkVal, err := server.client.Get(end.version, key)
  136. if err != nil {
  137. t.Errorf("Error getting %s from etcd 2.x after rollback from 3.x: %v", key, err)
  138. }
  139. if checkVal != value {
  140. t.Errorf("Expected %s from %s but got %s when reading after rollback from %s to %s", value, key, checkVal, start, end)
  141. }
  142. }
  143. })
  144. // Stop the servers.
  145. parallel(servers, func(server *EtcdMigrateServer) {
  146. err := server.Stop()
  147. if err != nil {
  148. t.Fatalf("Failed to stop server: %v", err)
  149. }
  150. })
  151. // Check that version.txt contains the correct end version.
  152. parallel(servers, func(server *EtcdMigrateServer) {
  153. dataDir, err := OpenOrCreateDataDirectory(server.cfg.dataDirectory)
  154. v, err := dataDir.versionFile.Read()
  155. if err != nil {
  156. t.Fatalf("Failed to read version.txt file: %v", err)
  157. }
  158. if !v.Equals(end) {
  159. t.Errorf("Expected version.txt to contain %s but got %s", end, v)
  160. }
  161. // Integration tests are run in a docker container with umask of 0022.
  162. checkPermissions(t, server.cfg.dataDirectory, 0755|os.ModeDir)
  163. checkPermissions(t, dataDir.versionFile.path, 0644)
  164. })
  165. })
  166. }
  167. }
  168. func parallel(servers []*EtcdMigrateServer, fn func(server *EtcdMigrateServer)) {
  169. var wg sync.WaitGroup
  170. wg.Add(len(servers))
  171. for _, server := range servers {
  172. go func(s *EtcdMigrateServer) {
  173. defer wg.Done()
  174. fn(s)
  175. }(server)
  176. }
  177. wg.Wait()
  178. }
  179. func serial(servers []*EtcdMigrateServer, fn func(server *EtcdMigrateServer)) {
  180. for _, server := range servers {
  181. fn(server)
  182. }
  183. }
  184. func checkPermissions(t *testing.T, path string, expected os.FileMode) {
  185. info, err := os.Stat(path)
  186. if err != nil {
  187. t.Fatalf("Failed to stat file %s: %v", path, err)
  188. }
  189. if info.Mode() != expected {
  190. t.Errorf("Expected permissions for file %s of %s, but got %s", path, expected, info.Mode())
  191. }
  192. }
  193. func clusterConfig(t *testing.T, name string, memberCount int, protocol string) []*EtcdMigrateCfg {
  194. peers := []string{}
  195. for i := 0; i < memberCount; i++ {
  196. memberName := fmt.Sprintf("%s-%d", name, i)
  197. peerPort := uint64(2380 + i*10000)
  198. peer := fmt.Sprintf("%s=%s://127.0.0.1:%d", memberName, protocol, peerPort)
  199. peers = append(peers, peer)
  200. }
  201. initialCluster := strings.Join(peers, ",")
  202. extraArgs := ""
  203. if protocol == "https" {
  204. extraArgs = getOrCreateTLSPeerCertArgs(t)
  205. }
  206. cfgs := []*EtcdMigrateCfg{}
  207. for i := 0; i < memberCount; i++ {
  208. memberName := fmt.Sprintf("%s-%d", name, i)
  209. peerURL := fmt.Sprintf("%s://127.0.0.1:%d", protocol, uint64(2380+i*10000))
  210. cfg := &EtcdMigrateCfg{
  211. binPath: "/usr/local/bin",
  212. name: memberName,
  213. initialCluster: initialCluster,
  214. port: uint64(2379 + i*10000),
  215. peerListenUrls: peerURL,
  216. peerAdvertiseUrls: peerURL,
  217. etcdDataPrefix: "/registry",
  218. ttlKeysDirectory: "/registry/events",
  219. supportedVersions: testSupportedVersions,
  220. dataDirectory: fmt.Sprintf("/tmp/etcd-data-dir-%s", memberName),
  221. etcdServerArgs: extraArgs,
  222. }
  223. cfgs = append(cfgs, cfg)
  224. }
  225. return cfgs
  226. }
  227. func getOrCreateTLSPeerCertArgs(t *testing.T) string {
  228. spec := TestCertSpec{
  229. host: "localhost",
  230. ips: []string{"127.0.0.1"},
  231. }
  232. certDir := "/tmp/certs"
  233. certFile := filepath.Join(certDir, "test.crt")
  234. keyFile := filepath.Join(certDir, "test.key")
  235. err := getOrCreateTestCertFiles(certFile, keyFile, spec)
  236. if err != nil {
  237. t.Fatalf("failed to create server cert: %v", err)
  238. }
  239. return fmt.Sprintf("--peer-client-cert-auth --peer-trusted-ca-file=%s --peer-cert-file=%s --peer-key-file=%s", certFile, certFile, keyFile)
  240. }
  241. type TestCertSpec struct {
  242. host string
  243. names, ips []string // in certificate
  244. }
  245. func getOrCreateTestCertFiles(certFileName, keyFileName string, spec TestCertSpec) (err error) {
  246. if _, err := os.Stat(certFileName); err == nil {
  247. if _, err := os.Stat(keyFileName); err == nil {
  248. return nil
  249. }
  250. }
  251. certPem, keyPem, err := generateSelfSignedCertKey(spec.host, parseIPList(spec.ips), spec.names)
  252. if err != nil {
  253. return err
  254. }
  255. os.MkdirAll(filepath.Dir(certFileName), os.FileMode(0777))
  256. err = ioutil.WriteFile(certFileName, certPem, os.FileMode(0777))
  257. if err != nil {
  258. return err
  259. }
  260. os.MkdirAll(filepath.Dir(keyFileName), os.FileMode(0777))
  261. err = ioutil.WriteFile(keyFileName, keyPem, os.FileMode(0777))
  262. if err != nil {
  263. return err
  264. }
  265. return nil
  266. }
  267. func parseIPList(ips []string) []net.IP {
  268. var netIPs []net.IP
  269. for _, ip := range ips {
  270. netIPs = append(netIPs, net.ParseIP(ip))
  271. }
  272. return netIPs
  273. }
  274. // generateSelfSignedCertKey creates a self-signed certificate and key for the given host.
  275. // Host may be an IP or a DNS name
  276. // You may also specify additional subject alt names (either ip or dns names) for the certificate
  277. func generateSelfSignedCertKey(host string, alternateIPs []net.IP, alternateDNS []string) ([]byte, []byte, error) {
  278. priv, err := rsa.GenerateKey(cryptorand.Reader, 2048)
  279. if err != nil {
  280. return nil, nil, err
  281. }
  282. template := x509.Certificate{
  283. SerialNumber: big.NewInt(1),
  284. Subject: pkix.Name{
  285. CommonName: fmt.Sprintf("%s@%d", host, time.Now().Unix()),
  286. },
  287. NotBefore: time.Unix(0, 0),
  288. NotAfter: time.Now().Add(time.Hour * 24 * 365 * 100),
  289. KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
  290. ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
  291. BasicConstraintsValid: true,
  292. IsCA: true,
  293. }
  294. if ip := net.ParseIP(host); ip != nil {
  295. template.IPAddresses = append(template.IPAddresses, ip)
  296. } else {
  297. template.DNSNames = append(template.DNSNames, host)
  298. }
  299. template.IPAddresses = append(template.IPAddresses, alternateIPs...)
  300. template.DNSNames = append(template.DNSNames, alternateDNS...)
  301. derBytes, err := x509.CreateCertificate(cryptorand.Reader, &template, &template, &priv.PublicKey, priv)
  302. if err != nil {
  303. return nil, nil, err
  304. }
  305. // Generate cert
  306. certBuffer := bytes.Buffer{}
  307. if err := pem.Encode(&certBuffer, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
  308. return nil, nil, err
  309. }
  310. // Generate key
  311. keyBuffer := bytes.Buffer{}
  312. if err := pem.Encode(&keyBuffer, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}); err != nil {
  313. return nil, nil, err
  314. }
  315. return certBuffer.Bytes(), keyBuffer.Bytes(), nil
  316. }