integration_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. // +build integration
  2. /*
  3. Copyright 2018 The Kubernetes Authors.
  4. Licensed under the Apache License, Version 2.0 (the "License");
  5. you may not use this file except in compliance with the License.
  6. You may obtain a copy of the License at
  7. http://www.apache.org/licenses/LICENSE-2.0
  8. Unless required by applicable law or agreed to in writing, software
  9. distributed under the License is distributed on an "AS IS" BASIS,
  10. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  11. See the License for the specific language governing permissions and
  12. limitations under the License.
  13. */
  14. package main
  15. import (
  16. "bytes"
  17. cryptorand "crypto/rand"
  18. "crypto/rsa"
  19. "crypto/x509"
  20. "crypto/x509/pkix"
  21. "encoding/pem"
  22. "fmt"
  23. "io/ioutil"
  24. "math/big"
  25. "net"
  26. "os"
  27. "path/filepath"
  28. "strings"
  29. "sync"
  30. "testing"
  31. "time"
  32. "github.com/blang/semver"
  33. )
  34. var (
  35. testSupportedVersions = MustParseSupportedVersions("2.2.1, 2.3.7, 3.0.17, 3.1.12")
  36. testVersionOldest = &EtcdVersion{semver.MustParse("2.2.1")}
  37. testVersionPrevious = &EtcdVersion{semver.MustParse("3.0.17")}
  38. testVersionLatest = &EtcdVersion{semver.MustParse("3.1.12")}
  39. )
  40. func TestMigrate(t *testing.T) {
  41. migrations := []struct {
  42. title string
  43. memberCount int
  44. startVersion string
  45. endVersion string
  46. protocol string
  47. }{
  48. // upgrades
  49. {"v2-v3-up", 1, "2.2.1/etcd2", "3.0.17/etcd3", "https"},
  50. {"v3-v3-up", 1, "3.0.17/etcd3", "3.1.12/etcd3", "https"},
  51. {"oldest-newest-up", 1, "2.2.1/etcd2", "3.1.12/etcd3", "https"},
  52. // warning: v2->v3 ha upgrades not currently supported.
  53. {"ha-v3-v3-up", 3, "3.0.17/etcd3", "3.1.12/etcd3", "https"},
  54. // downgrades
  55. {"v3-v2-down", 1, "3.0.17/etcd3", "2.2.1/etcd2", "https"},
  56. {"v3-v3-down", 1, "3.1.12/etcd3", "3.0.17/etcd3", "https"},
  57. // warning: ha downgrades not yet supported.
  58. }
  59. for _, m := range migrations {
  60. t.Run(m.title, func(t *testing.T) {
  61. start := MustParseEtcdVersionPair(m.startVersion)
  62. end := MustParseEtcdVersionPair(m.endVersion)
  63. testCfgs := clusterConfig(t, m.title, m.memberCount, m.protocol)
  64. servers := []*EtcdMigrateServer{}
  65. for _, cfg := range testCfgs {
  66. client, err := NewEtcdMigrateClient(cfg)
  67. if err != nil {
  68. t.Fatalf("Failed to create client: %v", err)
  69. }
  70. server := NewEtcdMigrateServer(cfg, client)
  71. servers = append(servers, server)
  72. }
  73. // Start the servers.
  74. parallel(servers, func(server *EtcdMigrateServer) {
  75. dataDir, err := OpenOrCreateDataDirectory(server.cfg.dataDirectory)
  76. if err != nil {
  77. t.Fatalf("Error opening or creating data directory %s: %v", server.cfg.dataDirectory, err)
  78. }
  79. migrator := &Migrator{server.cfg, dataDir, server.client}
  80. err = migrator.MigrateIfNeeded(start)
  81. if err != nil {
  82. t.Fatalf("Migration failed: %v", err)
  83. }
  84. err = server.Start(start.version)
  85. if err != nil {
  86. t.Fatalf("Failed to start server: %v", err)
  87. }
  88. })
  89. // Write a value to each server, read it back.
  90. parallel(servers, func(server *EtcdMigrateServer) {
  91. key := fmt.Sprintf("/registry/%s", server.cfg.name)
  92. value := fmt.Sprintf("value-%s", server.cfg.name)
  93. err := server.client.Put(start.version, key, value)
  94. if err != nil {
  95. t.Fatalf("failed to write text value: %v", err)
  96. }
  97. checkVal, err := server.client.Get(start.version, key)
  98. if err != nil {
  99. t.Errorf("Error getting %s for validation: %v", key, err)
  100. }
  101. if checkVal != value {
  102. t.Errorf("Expected %s from %s but got %s", value, key, checkVal)
  103. }
  104. })
  105. // Migrate the servers in series.
  106. serial(servers, func(server *EtcdMigrateServer) {
  107. err := server.Stop()
  108. if err != nil {
  109. t.Fatalf("Stop server failed: %v", err)
  110. }
  111. dataDir, err := OpenOrCreateDataDirectory(server.cfg.dataDirectory)
  112. if err != nil {
  113. t.Fatalf("Error opening or creating data directory %s: %v", server.cfg.dataDirectory, err)
  114. }
  115. migrator := &Migrator{server.cfg, dataDir, server.client}
  116. err = migrator.MigrateIfNeeded(end)
  117. if err != nil {
  118. t.Fatalf("Migration failed: %v", err)
  119. }
  120. err = server.Start(end.version)
  121. if err != nil {
  122. t.Fatalf("Start server failed: %v", err)
  123. }
  124. })
  125. // Check that all test values can be read back from all the servers.
  126. parallel(servers, func(server *EtcdMigrateServer) {
  127. for _, s := range servers {
  128. key := fmt.Sprintf("/registry/%s", s.cfg.name)
  129. value := fmt.Sprintf("value-%s", s.cfg.name)
  130. checkVal, err := server.client.Get(end.version, key)
  131. if err != nil {
  132. t.Errorf("Error getting %s from etcd 2.x after rollback from 3.x: %v", key, err)
  133. }
  134. if checkVal != value {
  135. t.Errorf("Expected %s from %s but got %s when reading after rollback from %s to %s", value, key, checkVal, start, end)
  136. }
  137. }
  138. })
  139. // Stop the servers.
  140. parallel(servers, func(server *EtcdMigrateServer) {
  141. err := server.Stop()
  142. if err != nil {
  143. t.Fatalf("Failed to stop server: %v", err)
  144. }
  145. })
  146. // Check that version.txt contains the correct end version.
  147. parallel(servers, func(server *EtcdMigrateServer) {
  148. dataDir, err := OpenOrCreateDataDirectory(server.cfg.dataDirectory)
  149. v, err := dataDir.versionFile.Read()
  150. if err != nil {
  151. t.Fatalf("Failed to read version.txt file: %v", err)
  152. }
  153. if !v.Equals(end) {
  154. t.Errorf("Expected version.txt to contain %s but got %s", end, v)
  155. }
  156. // Integration tests are run in a docker container with umask of 0022.
  157. checkPermissions(t, server.cfg.dataDirectory, 0755|os.ModeDir)
  158. checkPermissions(t, dataDir.versionFile.path, 0644)
  159. })
  160. })
  161. }
  162. }
  163. func parallel(servers []*EtcdMigrateServer, fn func(server *EtcdMigrateServer)) {
  164. var wg sync.WaitGroup
  165. wg.Add(len(servers))
  166. for _, server := range servers {
  167. go func(s *EtcdMigrateServer) {
  168. defer wg.Done()
  169. fn(s)
  170. }(server)
  171. }
  172. wg.Wait()
  173. }
  174. func serial(servers []*EtcdMigrateServer, fn func(server *EtcdMigrateServer)) {
  175. for _, server := range servers {
  176. fn(server)
  177. }
  178. }
  179. func checkPermissions(t *testing.T, path string, expected os.FileMode) {
  180. info, err := os.Stat(path)
  181. if err != nil {
  182. t.Fatalf("Failed to stat file %s: %v", path, err)
  183. }
  184. if info.Mode() != expected {
  185. t.Errorf("Expected permissions for file %s of %s, but got %s", path, expected, info.Mode())
  186. }
  187. }
  188. func clusterConfig(t *testing.T, name string, memberCount int, protocol string) []*EtcdMigrateCfg {
  189. peers := []string{}
  190. for i := 0; i < memberCount; i++ {
  191. memberName := fmt.Sprintf("%s-%d", name, i)
  192. peerPort := uint64(2380 + i*10000)
  193. peer := fmt.Sprintf("%s=%s://127.0.0.1:%d", memberName, protocol, peerPort)
  194. peers = append(peers, peer)
  195. }
  196. initialCluster := strings.Join(peers, ",")
  197. extraArgs := ""
  198. if protocol == "https" {
  199. extraArgs = getOrCreateTLSPeerCertArgs(t)
  200. }
  201. cfgs := []*EtcdMigrateCfg{}
  202. for i := 0; i < memberCount; i++ {
  203. memberName := fmt.Sprintf("%s-%d", name, i)
  204. peerURL := fmt.Sprintf("%s://127.0.0.1:%d", protocol, uint64(2380+i*10000))
  205. cfg := &EtcdMigrateCfg{
  206. binPath: "/usr/local/bin",
  207. name: memberName,
  208. initialCluster: initialCluster,
  209. port: uint64(2379 + i*10000),
  210. peerListenUrls: peerURL,
  211. peerAdvertiseUrls: peerURL,
  212. etcdDataPrefix: "/registry",
  213. ttlKeysDirectory: "/registry/events",
  214. supportedVersions: testSupportedVersions,
  215. dataDirectory: fmt.Sprintf("/tmp/etcd-data-dir-%s", memberName),
  216. etcdServerArgs: extraArgs,
  217. }
  218. cfgs = append(cfgs, cfg)
  219. }
  220. return cfgs
  221. }
  222. func getOrCreateTLSPeerCertArgs(t *testing.T) string {
  223. spec := TestCertSpec{
  224. host: "localhost",
  225. ips: []string{"127.0.0.1"},
  226. }
  227. certDir := "/tmp/certs"
  228. certFile := filepath.Join(certDir, "test.crt")
  229. keyFile := filepath.Join(certDir, "test.key")
  230. err := getOrCreateTestCertFiles(certFile, keyFile, spec)
  231. if err != nil {
  232. t.Fatalf("failed to create server cert: %v", err)
  233. }
  234. return fmt.Sprintf("--peer-client-cert-auth --peer-trusted-ca-file=%s --peer-cert-file=%s --peer-key-file=%s", certFile, certFile, keyFile)
  235. }
  236. type TestCertSpec struct {
  237. host string
  238. names, ips []string // in certificate
  239. }
  240. func getOrCreateTestCertFiles(certFileName, keyFileName string, spec TestCertSpec) (err error) {
  241. if _, err := os.Stat(certFileName); err == nil {
  242. if _, err := os.Stat(keyFileName); err == nil {
  243. return nil
  244. }
  245. }
  246. certPem, keyPem, err := generateSelfSignedCertKey(spec.host, parseIPList(spec.ips), spec.names)
  247. if err != nil {
  248. return err
  249. }
  250. os.MkdirAll(filepath.Dir(certFileName), os.FileMode(0777))
  251. err = ioutil.WriteFile(certFileName, certPem, os.FileMode(0777))
  252. if err != nil {
  253. return err
  254. }
  255. os.MkdirAll(filepath.Dir(keyFileName), os.FileMode(0777))
  256. err = ioutil.WriteFile(keyFileName, keyPem, os.FileMode(0777))
  257. if err != nil {
  258. return err
  259. }
  260. return nil
  261. }
  262. func parseIPList(ips []string) []net.IP {
  263. var netIPs []net.IP
  264. for _, ip := range ips {
  265. netIPs = append(netIPs, net.ParseIP(ip))
  266. }
  267. return netIPs
  268. }
  269. // generateSelfSignedCertKey creates a self-signed certificate and key for the given host.
  270. // Host may be an IP or a DNS name
  271. // You may also specify additional subject alt names (either ip or dns names) for the certificate
  272. func generateSelfSignedCertKey(host string, alternateIPs []net.IP, alternateDNS []string) ([]byte, []byte, error) {
  273. priv, err := rsa.GenerateKey(cryptorand.Reader, 2048)
  274. if err != nil {
  275. return nil, nil, err
  276. }
  277. template := x509.Certificate{
  278. SerialNumber: big.NewInt(1),
  279. Subject: pkix.Name{
  280. CommonName: fmt.Sprintf("%s@%d", host, time.Now().Unix()),
  281. },
  282. NotBefore: time.Unix(0, 0),
  283. NotAfter: time.Now().Add(time.Hour * 24 * 365 * 100),
  284. KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
  285. ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
  286. BasicConstraintsValid: true,
  287. IsCA: true,
  288. }
  289. if ip := net.ParseIP(host); ip != nil {
  290. template.IPAddresses = append(template.IPAddresses, ip)
  291. } else {
  292. template.DNSNames = append(template.DNSNames, host)
  293. }
  294. template.IPAddresses = append(template.IPAddresses, alternateIPs...)
  295. template.DNSNames = append(template.DNSNames, alternateDNS...)
  296. derBytes, err := x509.CreateCertificate(cryptorand.Reader, &template, &template, &priv.PublicKey, priv)
  297. if err != nil {
  298. return nil, nil, err
  299. }
  300. // Generate cert
  301. certBuffer := bytes.Buffer{}
  302. if err := pem.Encode(&certBuffer, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
  303. return nil, nil, err
  304. }
  305. // Generate key
  306. keyBuffer := bytes.Buffer{}
  307. if err := pem.Encode(&keyBuffer, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}); err != nil {
  308. return nil, nil, err
  309. }
  310. return certBuffer.Bytes(), keyBuffer.Bytes(), nil
  311. }