ssh.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513
  1. /*
  2. Copyright 2015 The Kubernetes Authors.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package ssh
  14. import (
  15. "bytes"
  16. "context"
  17. "crypto/rand"
  18. "crypto/rsa"
  19. "crypto/tls"
  20. "crypto/x509"
  21. "encoding/pem"
  22. "errors"
  23. "fmt"
  24. "io/ioutil"
  25. mathrand "math/rand"
  26. "net"
  27. "net/http"
  28. "net/url"
  29. "os"
  30. "strings"
  31. "sync"
  32. "time"
  33. "golang.org/x/crypto/ssh"
  34. utilnet "k8s.io/apimachinery/pkg/util/net"
  35. "k8s.io/apimachinery/pkg/util/runtime"
  36. "k8s.io/apimachinery/pkg/util/wait"
  37. "k8s.io/component-base/metrics"
  38. "k8s.io/component-base/metrics/legacyregistry"
  39. "k8s.io/klog"
  40. )
/*
 * By default, all the following metrics are defined as falling under
 * ALPHA stability level (https://github.com/kubernetes/enhancements/blob/master/keps/sig-instrumentation/20190404-kubernetes-control-plane-metrics-stability.md#stability-classes)
 *
 * Promoting the stability level of the metric is a responsibility of the component owner, since it
 * involves explicitly acknowledging support for the metric across multiple releases, in accordance with
 * the metric stability policy.
 */
var (
	// tunnelOpenCounter is incremented by sshTunnel.Open on every call,
	// whether or not the dial succeeds.
	tunnelOpenCounter = metrics.NewCounter(
		&metrics.CounterOpts{
			Name:           "ssh_tunnel_open_count",
			Help:           "Counter of ssh tunnel total open attempts",
			StabilityLevel: metrics.ALPHA,
		},
	)
	// tunnelOpenFailCounter is incremented by sshTunnel.Open only when the
	// dial returns an error.
	tunnelOpenFailCounter = metrics.NewCounter(
		&metrics.CounterOpts{
			Name:           "ssh_tunnel_open_fail_count",
			Help:           "Counter of ssh tunnel failed open attempts",
			StabilityLevel: metrics.ALPHA,
		},
	)
)
  65. func init() {
  66. legacyregistry.MustRegister(tunnelOpenCounter)
  67. legacyregistry.MustRegister(tunnelOpenFailCounter)
  68. }
// TODO: Unit tests for this code, we can spin up a test SSH server with instructions here:
// https://godoc.org/golang.org/x/crypto/ssh#ServerConn

// sshTunnel describes a single SSH connection target and, once Open has been
// called, holds the client used to dial through it.
type sshTunnel struct {
	// Config is the SSH client configuration handed to the dialer by Open.
	Config *ssh.ClientConfig
	// Host and SSHPort are joined into the "host:port" address that Open dials.
	Host    string
	SSHPort string
	// client is assigned by Open; Dial and Close refuse to run while it is nil.
	client *ssh.Client
}
  77. func makeSSHTunnel(user string, signer ssh.Signer, host string) (*sshTunnel, error) {
  78. config := ssh.ClientConfig{
  79. User: user,
  80. Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
  81. HostKeyCallback: ssh.InsecureIgnoreHostKey(),
  82. }
  83. return &sshTunnel{
  84. Config: &config,
  85. Host: host,
  86. SSHPort: "22",
  87. }, nil
  88. }
  89. func (s *sshTunnel) Open() error {
  90. var err error
  91. s.client, err = realTimeoutDialer.Dial("tcp", net.JoinHostPort(s.Host, s.SSHPort), s.Config)
  92. tunnelOpenCounter.Inc()
  93. if err != nil {
  94. tunnelOpenFailCounter.Inc()
  95. }
  96. return err
  97. }
  98. func (s *sshTunnel) Dial(ctx context.Context, network, address string) (net.Conn, error) {
  99. if s.client == nil {
  100. return nil, errors.New("tunnel is not opened.")
  101. }
  102. // This Dial method does not allow to pass a context unfortunately
  103. return s.client.Dial(network, address)
  104. }
  105. func (s *sshTunnel) Close() error {
  106. if s.client == nil {
  107. return errors.New("Cannot close tunnel. Tunnel was not opened.")
  108. }
  109. if err := s.client.Close(); err != nil {
  110. return err
  111. }
  112. return nil
  113. }
// sshDialer abstracts the establishment of an SSH client connection so that
// ssh.Dial can be mocked when testing SSH code.
type sshDialer interface {
	// Dial connects to addr on network and returns an SSH client configured
	// by config.
	Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error)
}
  118. // Real implementation of sshDialer
  119. type realSSHDialer struct{}
  120. var _ sshDialer = &realSSHDialer{}
  121. func (d *realSSHDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
  122. conn, err := net.DialTimeout(network, addr, config.Timeout)
  123. if err != nil {
  124. return nil, err
  125. }
  126. conn.SetReadDeadline(time.Now().Add(30 * time.Second))
  127. c, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
  128. if err != nil {
  129. return nil, err
  130. }
  131. conn.SetReadDeadline(time.Time{})
  132. return ssh.NewClient(c, chans, reqs), nil
  133. }
// timeoutDialer wraps an sshDialer with a timeout around Dial(). The golang
// ssh library can hang indefinitely inside the Dial() call (see issue #23835).
// Wrapping all Dial() calls with a conservative timeout provides safety against
// getting stuck on that.
type timeoutDialer struct {
	// dialer is the wrapped dialer that performs the actual connection.
	dialer sshDialer
	// timeout is written into the ssh.ClientConfig before every Dial.
	timeout time.Duration
}
  142. // 150 seconds is longer than the underlying default TCP backoff delay (127
  143. // seconds). This timeout is only intended to catch otherwise uncaught hangs.
  144. const sshDialTimeout = 150 * time.Second
  145. var realTimeoutDialer sshDialer = &timeoutDialer{&realSSHDialer{}, sshDialTimeout}
  146. func (d *timeoutDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
  147. config.Timeout = d.timeout
  148. return d.dialer.Dial(network, addr, config)
  149. }
  150. // RunSSHCommand returns the stdout, stderr, and exit code from running cmd on
  151. // host as specific user, along with any SSH-level error.
  152. // If user=="", it will default (like SSH) to os.Getenv("USER")
  153. func RunSSHCommand(cmd, user, host string, signer ssh.Signer) (string, string, int, error) {
  154. return runSSHCommand(realTimeoutDialer, cmd, user, host, signer, true)
  155. }
  156. // Internal implementation of runSSHCommand, for testing
  157. func runSSHCommand(dialer sshDialer, cmd, user, host string, signer ssh.Signer, retry bool) (string, string, int, error) {
  158. if user == "" {
  159. user = os.Getenv("USER")
  160. }
  161. // Setup the config, dial the server, and open a session.
  162. config := &ssh.ClientConfig{
  163. User: user,
  164. Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
  165. HostKeyCallback: ssh.InsecureIgnoreHostKey(),
  166. }
  167. client, err := dialer.Dial("tcp", host, config)
  168. if err != nil && retry {
  169. err = wait.Poll(5*time.Second, 20*time.Second, func() (bool, error) {
  170. fmt.Printf("error dialing %s@%s: '%v', retrying\n", user, host, err)
  171. if client, err = dialer.Dial("tcp", host, config); err != nil {
  172. return false, err
  173. }
  174. return true, nil
  175. })
  176. }
  177. if err != nil {
  178. return "", "", 0, fmt.Errorf("error getting SSH client to %s@%s: '%v'", user, host, err)
  179. }
  180. defer client.Close()
  181. session, err := client.NewSession()
  182. if err != nil {
  183. return "", "", 0, fmt.Errorf("error creating session to %s@%s: '%v'", user, host, err)
  184. }
  185. defer session.Close()
  186. // Run the command.
  187. code := 0
  188. var bout, berr bytes.Buffer
  189. session.Stdout, session.Stderr = &bout, &berr
  190. if err = session.Run(cmd); err != nil {
  191. // Check whether the command failed to run or didn't complete.
  192. if exiterr, ok := err.(*ssh.ExitError); ok {
  193. // If we got an ExitError and the exit code is nonzero, we'll
  194. // consider the SSH itself successful (just that the command run
  195. // errored on the host).
  196. if code = exiterr.ExitStatus(); code != 0 {
  197. err = nil
  198. }
  199. } else {
  200. // Some other kind of error happened (e.g. an IOError); consider the
  201. // SSH unsuccessful.
  202. err = fmt.Errorf("failed running `%s` on %s@%s: '%v'", cmd, user, host, err)
  203. }
  204. }
  205. return bout.String(), berr.String(), code, err
  206. }
  207. func MakePrivateKeySignerFromFile(key string) (ssh.Signer, error) {
  208. // Create an actual signer.
  209. buffer, err := ioutil.ReadFile(key)
  210. if err != nil {
  211. return nil, fmt.Errorf("error reading SSH key %s: '%v'", key, err)
  212. }
  213. return MakePrivateKeySignerFromBytes(buffer)
  214. }
  215. func MakePrivateKeySignerFromBytes(buffer []byte) (ssh.Signer, error) {
  216. signer, err := ssh.ParsePrivateKey(buffer)
  217. if err != nil {
  218. return nil, fmt.Errorf("error parsing SSH key: '%v'", err)
  219. }
  220. return signer, nil
  221. }
  222. func ParsePublicKeyFromFile(keyFile string) (*rsa.PublicKey, error) {
  223. buffer, err := ioutil.ReadFile(keyFile)
  224. if err != nil {
  225. return nil, fmt.Errorf("error reading SSH key %s: '%v'", keyFile, err)
  226. }
  227. keyBlock, _ := pem.Decode(buffer)
  228. if keyBlock == nil {
  229. return nil, fmt.Errorf("error parsing SSH key %s: 'invalid PEM format'", keyFile)
  230. }
  231. key, err := x509.ParsePKIXPublicKey(keyBlock.Bytes)
  232. if err != nil {
  233. return nil, fmt.Errorf("error parsing SSH key %s: '%v'", keyFile, err)
  234. }
  235. rsaKey, ok := key.(*rsa.PublicKey)
  236. if !ok {
  237. return nil, fmt.Errorf("SSH key could not be parsed as rsa public key")
  238. }
  239. return rsaKey, nil
  240. }
// tunnel is the set of operations SSHTunnelList needs from a tunnel: it can
// be opened, closed, and used to dial connections through it.
type tunnel interface {
	// Open establishes the tunnel.
	Open() error
	// Close tears the tunnel down.
	Close() error
	// Dial opens a connection to address on network through the tunnel.
	Dial(ctx context.Context, network, address string) (net.Conn, error)
}
// sshTunnelEntry pairs a tunnel with the address it was created for.
// Field order matters: createAndAddTunnel builds entries positionally.
type sshTunnelEntry struct {
	// Address is the address the tunnel was created for; pickTunnel matches
	// dial targets against it.
	Address string
	// Tunnel is the tunnel itself.
	Tunnel tunnel
}
// sshTunnelCreator is the factory SSHTunnelList uses to build tunnels.
type sshTunnelCreator interface {
	// newSSHTunnel returns an unopened tunnel to host for user, authenticated
	// with the key stored in keyFile.
	newSSHTunnel(user, keyFile, host string) (tunnel, error)
}
  253. type realTunnelCreator struct{}
  254. func (*realTunnelCreator) newSSHTunnel(user, keyFile, host string) (tunnel, error) {
  255. signer, err := MakePrivateKeySignerFromFile(keyFile)
  256. if err != nil {
  257. return nil, err
  258. }
  259. return makeSSHTunnel(user, signer, host)
  260. }
// SSHTunnelList maintains a set of SSH tunnels, one per address, and dials
// connections through them. Create it with NewSSHTunnelList.
type SSHTunnelList struct {
	// entries are the currently open tunnels.
	entries []sshTunnelEntry
	// adding holds the addresses for which a tunnel is being created but has
	// not yet been added to entries.
	adding map[string]bool
	// tunnelCreator builds new tunnels.
	tunnelCreator sshTunnelCreator
	// tunnelsLock is held around the accesses to entries and adding.
	tunnelsLock sync.Mutex
	// user and keyfile are passed to tunnelCreator for every new tunnel.
	user    string
	keyfile string
	// healthCheckURL is fetched through each tunnel by healthCheck.
	healthCheckURL *url.URL
}
  270. func NewSSHTunnelList(user, keyfile string, healthCheckURL *url.URL, stopChan chan struct{}) *SSHTunnelList {
  271. l := &SSHTunnelList{
  272. adding: make(map[string]bool),
  273. tunnelCreator: &realTunnelCreator{},
  274. user: user,
  275. keyfile: keyfile,
  276. healthCheckURL: healthCheckURL,
  277. }
  278. healthCheckPoll := 1 * time.Minute
  279. go wait.Until(func() {
  280. l.tunnelsLock.Lock()
  281. defer l.tunnelsLock.Unlock()
  282. // Healthcheck each tunnel every minute
  283. numTunnels := len(l.entries)
  284. for i, entry := range l.entries {
  285. // Stagger healthchecks evenly across duration of healthCheckPoll.
  286. delay := healthCheckPoll * time.Duration(i) / time.Duration(numTunnels)
  287. l.delayedHealthCheck(entry, delay)
  288. }
  289. }, healthCheckPoll, stopChan)
  290. return l
  291. }
  292. func (l *SSHTunnelList) delayedHealthCheck(e sshTunnelEntry, delay time.Duration) {
  293. go func() {
  294. defer runtime.HandleCrash()
  295. time.Sleep(delay)
  296. if err := l.healthCheck(e); err != nil {
  297. klog.Errorf("Healthcheck failed for tunnel to %q: %v", e.Address, err)
  298. klog.Infof("Attempting once to re-establish tunnel to %q", e.Address)
  299. l.removeAndReAdd(e)
  300. }
  301. }()
  302. }
  303. func (l *SSHTunnelList) healthCheck(e sshTunnelEntry) error {
  304. // GET the healthcheck path using the provided tunnel's dial function.
  305. transport := utilnet.SetTransportDefaults(&http.Transport{
  306. DialContext: e.Tunnel.Dial,
  307. // TODO(cjcullen): Plumb real TLS options through.
  308. TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
  309. // We don't reuse the clients, so disable the keep-alive to properly
  310. // close the connection.
  311. DisableKeepAlives: true,
  312. })
  313. client := &http.Client{Transport: transport}
  314. resp, err := client.Get(l.healthCheckURL.String())
  315. if err != nil {
  316. return err
  317. }
  318. resp.Body.Close()
  319. return nil
  320. }
  321. func (l *SSHTunnelList) removeAndReAdd(e sshTunnelEntry) {
  322. // Find the entry to replace.
  323. l.tunnelsLock.Lock()
  324. for i, entry := range l.entries {
  325. if entry.Tunnel == e.Tunnel {
  326. l.entries = append(l.entries[:i], l.entries[i+1:]...)
  327. l.adding[e.Address] = true
  328. break
  329. }
  330. }
  331. l.tunnelsLock.Unlock()
  332. if err := e.Tunnel.Close(); err != nil {
  333. klog.Infof("Failed to close removed tunnel: %v", err)
  334. }
  335. go l.createAndAddTunnel(e.Address)
  336. }
  337. func (l *SSHTunnelList) Dial(ctx context.Context, net, addr string) (net.Conn, error) {
  338. start := time.Now()
  339. id := mathrand.Int63() // So you can match begins/ends in the log.
  340. klog.Infof("[%x: %v] Dialing...", id, addr)
  341. defer func() {
  342. klog.Infof("[%x: %v] Dialed in %v.", id, addr, time.Since(start))
  343. }()
  344. tunnel, err := l.pickTunnel(strings.Split(addr, ":")[0])
  345. if err != nil {
  346. return nil, err
  347. }
  348. return tunnel.Dial(ctx, net, addr)
  349. }
  350. func (l *SSHTunnelList) pickTunnel(addr string) (tunnel, error) {
  351. l.tunnelsLock.Lock()
  352. defer l.tunnelsLock.Unlock()
  353. if len(l.entries) == 0 {
  354. return nil, fmt.Errorf("No SSH tunnels currently open. Were the targets able to accept an ssh-key for user %q?", l.user)
  355. }
  356. // Prefer same tunnel as kubelet
  357. for _, entry := range l.entries {
  358. if entry.Address == addr {
  359. return entry.Tunnel, nil
  360. }
  361. }
  362. klog.Warningf("SSH tunnel not found for address %q, picking random node", addr)
  363. n := mathrand.Intn(len(l.entries))
  364. return l.entries[n].Tunnel, nil
  365. }
// Update reconciles the list's entries with the specified addresses. Existing
// tunnels that are not in addresses are removed from entries and closed in a
// background goroutine. New tunnels specified in addresses are opened in a
// background goroutine and then added to entries.
//
// Update itself does not wait for any tunnel to be opened or closed; it only
// holds the list lock while computing the differences.
func (l *SSHTunnelList) Update(addrs []string) {
	haveAddrsMap := make(map[string]bool)
	wantAddrsMap := make(map[string]bool)
	// The closure scopes the deferred unlock to the reconciliation work.
	func() {
		l.tunnelsLock.Lock()
		defer l.tunnelsLock.Unlock()
		// Build a map of what we currently have.
		for i := range l.entries {
			haveAddrsMap[l.entries[i].Address] = true
		}
		// Determine any necessary additions.
		for i := range addrs {
			// Add tunnel if it is not in l.entries or l.adding
			if _, ok := haveAddrsMap[addrs[i]]; !ok {
				if _, ok := l.adding[addrs[i]]; !ok {
					// Mark the address before starting the goroutine so a
					// concurrent Update does not start a second tunnel to it.
					l.adding[addrs[i]] = true
					// Copy the address so the goroutine does not depend on
					// the loop index or the caller's slice.
					addr := addrs[i]
					go func() {
						defer runtime.HandleCrash()
						// Actually adding tunnel to list will block until lock
						// is released after deletions.
						l.createAndAddTunnel(addr)
					}()
				}
			}
			wantAddrsMap[addrs[i]] = true
		}
		// Determine any necessary deletions.
		var newEntries []sshTunnelEntry
		for i := range l.entries {
			if _, ok := wantAddrsMap[l.entries[i].Address]; !ok {
				// Copy the entry for the closing goroutine below.
				tunnelEntry := l.entries[i]
				klog.Infof("Removing tunnel to deleted node at %q", tunnelEntry.Address)
				// Close in the background so Update does not wait on the
				// tunnel's Close while holding the lock.
				go func() {
					defer runtime.HandleCrash()
					if err := tunnelEntry.Tunnel.Close(); err != nil {
						klog.Errorf("Failed to close tunnel to %q: %v", tunnelEntry.Address, err)
					}
				}()
			} else {
				newEntries = append(newEntries, l.entries[i])
			}
		}
		// Only the entries that are still wanted are kept.
		l.entries = newEntries
	}()
}
  416. func (l *SSHTunnelList) createAndAddTunnel(addr string) {
  417. klog.Infof("Trying to add tunnel to %q", addr)
  418. tunnel, err := l.tunnelCreator.newSSHTunnel(l.user, l.keyfile, addr)
  419. if err != nil {
  420. klog.Errorf("Failed to create tunnel for %q: %v", addr, err)
  421. return
  422. }
  423. if err := tunnel.Open(); err != nil {
  424. klog.Errorf("Failed to open tunnel to %q: %v", addr, err)
  425. l.tunnelsLock.Lock()
  426. delete(l.adding, addr)
  427. l.tunnelsLock.Unlock()
  428. return
  429. }
  430. l.tunnelsLock.Lock()
  431. l.entries = append(l.entries, sshTunnelEntry{addr, tunnel})
  432. delete(l.adding, addr)
  433. l.tunnelsLock.Unlock()
  434. klog.Infof("Successfully added tunnel for %q", addr)
  435. }
  436. func EncodePrivateKey(private *rsa.PrivateKey) []byte {
  437. return pem.EncodeToMemory(&pem.Block{
  438. Bytes: x509.MarshalPKCS1PrivateKey(private),
  439. Type: "RSA PRIVATE KEY",
  440. })
  441. }
  442. func EncodePublicKey(public *rsa.PublicKey) ([]byte, error) {
  443. publicBytes, err := x509.MarshalPKIXPublicKey(public)
  444. if err != nil {
  445. return nil, err
  446. }
  447. return pem.EncodeToMemory(&pem.Block{
  448. Bytes: publicBytes,
  449. Type: "PUBLIC KEY",
  450. }), nil
  451. }
  452. func EncodeSSHKey(public *rsa.PublicKey) ([]byte, error) {
  453. publicKey, err := ssh.NewPublicKey(public)
  454. if err != nil {
  455. return nil, err
  456. }
  457. return ssh.MarshalAuthorizedKey(publicKey), nil
  458. }
  459. func GenerateKey(bits int) (*rsa.PrivateKey, *rsa.PublicKey, error) {
  460. private, err := rsa.GenerateKey(rand.Reader, bits)
  461. if err != nil {
  462. return nil, nil, err
  463. }
  464. return private, &private.PublicKey, nil
  465. }