ssh.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502
  1. /*
  2. Copyright 2015 The Kubernetes Authors.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package ssh
  14. import (
  15. "bytes"
  16. "context"
  17. "crypto/rand"
  18. "crypto/rsa"
  19. "crypto/tls"
  20. "crypto/x509"
  21. "encoding/pem"
  22. "errors"
  23. "fmt"
  24. "io/ioutil"
  25. mathrand "math/rand"
  26. "net"
  27. "net/http"
  28. "net/url"
  29. "os"
  30. "strings"
  31. "sync"
  32. "time"
  33. "github.com/prometheus/client_golang/prometheus"
  34. "golang.org/x/crypto/ssh"
  35. utilnet "k8s.io/apimachinery/pkg/util/net"
  36. "k8s.io/apimachinery/pkg/util/runtime"
  37. "k8s.io/apimachinery/pkg/util/wait"
  38. "k8s.io/klog"
  39. )
  40. var (
  41. tunnelOpenCounter = prometheus.NewCounter(
  42. prometheus.CounterOpts{
  43. Name: "ssh_tunnel_open_count",
  44. Help: "Counter of ssh tunnel total open attempts",
  45. },
  46. )
  47. tunnelOpenFailCounter = prometheus.NewCounter(
  48. prometheus.CounterOpts{
  49. Name: "ssh_tunnel_open_fail_count",
  50. Help: "Counter of ssh tunnel failed open attempts",
  51. },
  52. )
  53. )
  54. func init() {
  55. prometheus.MustRegister(tunnelOpenCounter)
  56. prometheus.MustRegister(tunnelOpenFailCounter)
  57. }
  58. // TODO: Unit tests for this code, we can spin up a test SSH server with instructions here:
  59. // https://godoc.org/golang.org/x/crypto/ssh#ServerConn
  60. type sshTunnel struct {
  61. Config *ssh.ClientConfig
  62. Host string
  63. SSHPort string
  64. client *ssh.Client
  65. }
  66. func makeSSHTunnel(user string, signer ssh.Signer, host string) (*sshTunnel, error) {
  67. config := ssh.ClientConfig{
  68. User: user,
  69. Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
  70. HostKeyCallback: ssh.InsecureIgnoreHostKey(),
  71. }
  72. return &sshTunnel{
  73. Config: &config,
  74. Host: host,
  75. SSHPort: "22",
  76. }, nil
  77. }
  78. func (s *sshTunnel) Open() error {
  79. var err error
  80. s.client, err = realTimeoutDialer.Dial("tcp", net.JoinHostPort(s.Host, s.SSHPort), s.Config)
  81. tunnelOpenCounter.Inc()
  82. if err != nil {
  83. tunnelOpenFailCounter.Inc()
  84. }
  85. return err
  86. }
  87. func (s *sshTunnel) Dial(ctx context.Context, network, address string) (net.Conn, error) {
  88. if s.client == nil {
  89. return nil, errors.New("tunnel is not opened.")
  90. }
  91. // This Dial method does not allow to pass a context unfortunately
  92. return s.client.Dial(network, address)
  93. }
  94. func (s *sshTunnel) Close() error {
  95. if s.client == nil {
  96. return errors.New("Cannot close tunnel. Tunnel was not opened.")
  97. }
  98. if err := s.client.Close(); err != nil {
  99. return err
  100. }
  101. return nil
  102. }
  103. // Interface to allow mocking of ssh.Dial, for testing SSH
  104. type sshDialer interface {
  105. Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error)
  106. }
  107. // Real implementation of sshDialer
  108. type realSSHDialer struct{}
  109. var _ sshDialer = &realSSHDialer{}
  110. func (d *realSSHDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
  111. conn, err := net.DialTimeout(network, addr, config.Timeout)
  112. if err != nil {
  113. return nil, err
  114. }
  115. conn.SetReadDeadline(time.Now().Add(30 * time.Second))
  116. c, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
  117. if err != nil {
  118. return nil, err
  119. }
  120. conn.SetReadDeadline(time.Time{})
  121. return ssh.NewClient(c, chans, reqs), nil
  122. }
  123. // timeoutDialer wraps an sshDialer with a timeout around Dial(). The golang
  124. // ssh library can hang indefinitely inside the Dial() call (see issue #23835).
  125. // Wrapping all Dial() calls with a conservative timeout provides safety against
  126. // getting stuck on that.
  127. type timeoutDialer struct {
  128. dialer sshDialer
  129. timeout time.Duration
  130. }
  131. // 150 seconds is longer than the underlying default TCP backoff delay (127
  132. // seconds). This timeout is only intended to catch otherwise uncaught hangs.
  133. const sshDialTimeout = 150 * time.Second
  134. var realTimeoutDialer sshDialer = &timeoutDialer{&realSSHDialer{}, sshDialTimeout}
  135. func (d *timeoutDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
  136. config.Timeout = d.timeout
  137. return d.dialer.Dial(network, addr, config)
  138. }
  139. // RunSSHCommand returns the stdout, stderr, and exit code from running cmd on
  140. // host as specific user, along with any SSH-level error.
  141. // If user=="", it will default (like SSH) to os.Getenv("USER")
  142. func RunSSHCommand(cmd, user, host string, signer ssh.Signer) (string, string, int, error) {
  143. return runSSHCommand(realTimeoutDialer, cmd, user, host, signer, true)
  144. }
  145. // Internal implementation of runSSHCommand, for testing
  146. func runSSHCommand(dialer sshDialer, cmd, user, host string, signer ssh.Signer, retry bool) (string, string, int, error) {
  147. if user == "" {
  148. user = os.Getenv("USER")
  149. }
  150. // Setup the config, dial the server, and open a session.
  151. config := &ssh.ClientConfig{
  152. User: user,
  153. Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
  154. HostKeyCallback: ssh.InsecureIgnoreHostKey(),
  155. }
  156. client, err := dialer.Dial("tcp", host, config)
  157. if err != nil && retry {
  158. err = wait.Poll(5*time.Second, 20*time.Second, func() (bool, error) {
  159. fmt.Printf("error dialing %s@%s: '%v', retrying\n", user, host, err)
  160. if client, err = dialer.Dial("tcp", host, config); err != nil {
  161. return false, err
  162. }
  163. return true, nil
  164. })
  165. }
  166. if err != nil {
  167. return "", "", 0, fmt.Errorf("error getting SSH client to %s@%s: '%v'", user, host, err)
  168. }
  169. session, err := client.NewSession()
  170. if err != nil {
  171. return "", "", 0, fmt.Errorf("error creating session to %s@%s: '%v'", user, host, err)
  172. }
  173. defer session.Close()
  174. // Run the command.
  175. code := 0
  176. var bout, berr bytes.Buffer
  177. session.Stdout, session.Stderr = &bout, &berr
  178. if err = session.Run(cmd); err != nil {
  179. // Check whether the command failed to run or didn't complete.
  180. if exiterr, ok := err.(*ssh.ExitError); ok {
  181. // If we got an ExitError and the exit code is nonzero, we'll
  182. // consider the SSH itself successful (just that the command run
  183. // errored on the host).
  184. if code = exiterr.ExitStatus(); code != 0 {
  185. err = nil
  186. }
  187. } else {
  188. // Some other kind of error happened (e.g. an IOError); consider the
  189. // SSH unsuccessful.
  190. err = fmt.Errorf("failed running `%s` on %s@%s: '%v'", cmd, user, host, err)
  191. }
  192. }
  193. return bout.String(), berr.String(), code, err
  194. }
  195. func MakePrivateKeySignerFromFile(key string) (ssh.Signer, error) {
  196. // Create an actual signer.
  197. buffer, err := ioutil.ReadFile(key)
  198. if err != nil {
  199. return nil, fmt.Errorf("error reading SSH key %s: '%v'", key, err)
  200. }
  201. return MakePrivateKeySignerFromBytes(buffer)
  202. }
  203. func MakePrivateKeySignerFromBytes(buffer []byte) (ssh.Signer, error) {
  204. signer, err := ssh.ParsePrivateKey(buffer)
  205. if err != nil {
  206. return nil, fmt.Errorf("error parsing SSH key: '%v'", err)
  207. }
  208. return signer, nil
  209. }
  210. func ParsePublicKeyFromFile(keyFile string) (*rsa.PublicKey, error) {
  211. buffer, err := ioutil.ReadFile(keyFile)
  212. if err != nil {
  213. return nil, fmt.Errorf("error reading SSH key %s: '%v'", keyFile, err)
  214. }
  215. keyBlock, _ := pem.Decode(buffer)
  216. if keyBlock == nil {
  217. return nil, fmt.Errorf("error parsing SSH key %s: 'invalid PEM format'", keyFile)
  218. }
  219. key, err := x509.ParsePKIXPublicKey(keyBlock.Bytes)
  220. if err != nil {
  221. return nil, fmt.Errorf("error parsing SSH key %s: '%v'", keyFile, err)
  222. }
  223. rsaKey, ok := key.(*rsa.PublicKey)
  224. if !ok {
  225. return nil, fmt.Errorf("SSH key could not be parsed as rsa public key")
  226. }
  227. return rsaKey, nil
  228. }
  229. type tunnel interface {
  230. Open() error
  231. Close() error
  232. Dial(ctx context.Context, network, address string) (net.Conn, error)
  233. }
  234. type sshTunnelEntry struct {
  235. Address string
  236. Tunnel tunnel
  237. }
  238. type sshTunnelCreator interface {
  239. newSSHTunnel(user, keyFile, host string) (tunnel, error)
  240. }
  241. type realTunnelCreator struct{}
  242. func (*realTunnelCreator) newSSHTunnel(user, keyFile, host string) (tunnel, error) {
  243. signer, err := MakePrivateKeySignerFromFile(keyFile)
  244. if err != nil {
  245. return nil, err
  246. }
  247. return makeSSHTunnel(user, signer, host)
  248. }
  249. type SSHTunnelList struct {
  250. entries []sshTunnelEntry
  251. adding map[string]bool
  252. tunnelCreator sshTunnelCreator
  253. tunnelsLock sync.Mutex
  254. user string
  255. keyfile string
  256. healthCheckURL *url.URL
  257. }
  258. func NewSSHTunnelList(user, keyfile string, healthCheckURL *url.URL, stopChan chan struct{}) *SSHTunnelList {
  259. l := &SSHTunnelList{
  260. adding: make(map[string]bool),
  261. tunnelCreator: &realTunnelCreator{},
  262. user: user,
  263. keyfile: keyfile,
  264. healthCheckURL: healthCheckURL,
  265. }
  266. healthCheckPoll := 1 * time.Minute
  267. go wait.Until(func() {
  268. l.tunnelsLock.Lock()
  269. defer l.tunnelsLock.Unlock()
  270. // Healthcheck each tunnel every minute
  271. numTunnels := len(l.entries)
  272. for i, entry := range l.entries {
  273. // Stagger healthchecks evenly across duration of healthCheckPoll.
  274. delay := healthCheckPoll * time.Duration(i) / time.Duration(numTunnels)
  275. l.delayedHealthCheck(entry, delay)
  276. }
  277. }, healthCheckPoll, stopChan)
  278. return l
  279. }
  280. func (l *SSHTunnelList) delayedHealthCheck(e sshTunnelEntry, delay time.Duration) {
  281. go func() {
  282. defer runtime.HandleCrash()
  283. time.Sleep(delay)
  284. if err := l.healthCheck(e); err != nil {
  285. klog.Errorf("Healthcheck failed for tunnel to %q: %v", e.Address, err)
  286. klog.Infof("Attempting once to re-establish tunnel to %q", e.Address)
  287. l.removeAndReAdd(e)
  288. }
  289. }()
  290. }
  291. func (l *SSHTunnelList) healthCheck(e sshTunnelEntry) error {
  292. // GET the healthcheck path using the provided tunnel's dial function.
  293. transport := utilnet.SetTransportDefaults(&http.Transport{
  294. DialContext: e.Tunnel.Dial,
  295. // TODO(cjcullen): Plumb real TLS options through.
  296. TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
  297. // We don't reuse the clients, so disable the keep-alive to properly
  298. // close the connection.
  299. DisableKeepAlives: true,
  300. })
  301. client := &http.Client{Transport: transport}
  302. resp, err := client.Get(l.healthCheckURL.String())
  303. if err != nil {
  304. return err
  305. }
  306. resp.Body.Close()
  307. return nil
  308. }
  309. func (l *SSHTunnelList) removeAndReAdd(e sshTunnelEntry) {
  310. // Find the entry to replace.
  311. l.tunnelsLock.Lock()
  312. for i, entry := range l.entries {
  313. if entry.Tunnel == e.Tunnel {
  314. l.entries = append(l.entries[:i], l.entries[i+1:]...)
  315. l.adding[e.Address] = true
  316. break
  317. }
  318. }
  319. l.tunnelsLock.Unlock()
  320. if err := e.Tunnel.Close(); err != nil {
  321. klog.Infof("Failed to close removed tunnel: %v", err)
  322. }
  323. go l.createAndAddTunnel(e.Address)
  324. }
  325. func (l *SSHTunnelList) Dial(ctx context.Context, net, addr string) (net.Conn, error) {
  326. start := time.Now()
  327. id := mathrand.Int63() // So you can match begins/ends in the log.
  328. klog.Infof("[%x: %v] Dialing...", id, addr)
  329. defer func() {
  330. klog.Infof("[%x: %v] Dialed in %v.", id, addr, time.Since(start))
  331. }()
  332. tunnel, err := l.pickTunnel(strings.Split(addr, ":")[0])
  333. if err != nil {
  334. return nil, err
  335. }
  336. return tunnel.Dial(ctx, net, addr)
  337. }
  338. func (l *SSHTunnelList) pickTunnel(addr string) (tunnel, error) {
  339. l.tunnelsLock.Lock()
  340. defer l.tunnelsLock.Unlock()
  341. if len(l.entries) == 0 {
  342. return nil, fmt.Errorf("No SSH tunnels currently open. Were the targets able to accept an ssh-key for user %q?", l.user)
  343. }
  344. // Prefer same tunnel as kubelet
  345. // TODO: Change l.entries to a map of address->tunnel
  346. for _, entry := range l.entries {
  347. if entry.Address == addr {
  348. return entry.Tunnel, nil
  349. }
  350. }
  351. klog.Warningf("SSH tunnel not found for address %q, picking random node", addr)
  352. n := mathrand.Intn(len(l.entries))
  353. return l.entries[n].Tunnel, nil
  354. }
  355. // Update reconciles the list's entries with the specified addresses. Existing
  356. // tunnels that are not in addresses are removed from entries and closed in a
  357. // background goroutine. New tunnels specified in addresses are opened in a
  358. // background goroutine and then added to entries.
  359. func (l *SSHTunnelList) Update(addrs []string) {
  360. haveAddrsMap := make(map[string]bool)
  361. wantAddrsMap := make(map[string]bool)
  362. func() {
  363. l.tunnelsLock.Lock()
  364. defer l.tunnelsLock.Unlock()
  365. // Build a map of what we currently have.
  366. for i := range l.entries {
  367. haveAddrsMap[l.entries[i].Address] = true
  368. }
  369. // Determine any necessary additions.
  370. for i := range addrs {
  371. // Add tunnel if it is not in l.entries or l.adding
  372. if _, ok := haveAddrsMap[addrs[i]]; !ok {
  373. if _, ok := l.adding[addrs[i]]; !ok {
  374. l.adding[addrs[i]] = true
  375. addr := addrs[i]
  376. go func() {
  377. defer runtime.HandleCrash()
  378. // Actually adding tunnel to list will block until lock
  379. // is released after deletions.
  380. l.createAndAddTunnel(addr)
  381. }()
  382. }
  383. }
  384. wantAddrsMap[addrs[i]] = true
  385. }
  386. // Determine any necessary deletions.
  387. var newEntries []sshTunnelEntry
  388. for i := range l.entries {
  389. if _, ok := wantAddrsMap[l.entries[i].Address]; !ok {
  390. tunnelEntry := l.entries[i]
  391. klog.Infof("Removing tunnel to deleted node at %q", tunnelEntry.Address)
  392. go func() {
  393. defer runtime.HandleCrash()
  394. if err := tunnelEntry.Tunnel.Close(); err != nil {
  395. klog.Errorf("Failed to close tunnel to %q: %v", tunnelEntry.Address, err)
  396. }
  397. }()
  398. } else {
  399. newEntries = append(newEntries, l.entries[i])
  400. }
  401. }
  402. l.entries = newEntries
  403. }()
  404. }
  405. func (l *SSHTunnelList) createAndAddTunnel(addr string) {
  406. klog.Infof("Trying to add tunnel to %q", addr)
  407. tunnel, err := l.tunnelCreator.newSSHTunnel(l.user, l.keyfile, addr)
  408. if err != nil {
  409. klog.Errorf("Failed to create tunnel for %q: %v", addr, err)
  410. return
  411. }
  412. if err := tunnel.Open(); err != nil {
  413. klog.Errorf("Failed to open tunnel to %q: %v", addr, err)
  414. l.tunnelsLock.Lock()
  415. delete(l.adding, addr)
  416. l.tunnelsLock.Unlock()
  417. return
  418. }
  419. l.tunnelsLock.Lock()
  420. l.entries = append(l.entries, sshTunnelEntry{addr, tunnel})
  421. delete(l.adding, addr)
  422. l.tunnelsLock.Unlock()
  423. klog.Infof("Successfully added tunnel for %q", addr)
  424. }
  425. func EncodePrivateKey(private *rsa.PrivateKey) []byte {
  426. return pem.EncodeToMemory(&pem.Block{
  427. Bytes: x509.MarshalPKCS1PrivateKey(private),
  428. Type: "RSA PRIVATE KEY",
  429. })
  430. }
  431. func EncodePublicKey(public *rsa.PublicKey) ([]byte, error) {
  432. publicBytes, err := x509.MarshalPKIXPublicKey(public)
  433. if err != nil {
  434. return nil, err
  435. }
  436. return pem.EncodeToMemory(&pem.Block{
  437. Bytes: publicBytes,
  438. Type: "PUBLIC KEY",
  439. }), nil
  440. }
  441. func EncodeSSHKey(public *rsa.PublicKey) ([]byte, error) {
  442. publicKey, err := ssh.NewPublicKey(public)
  443. if err != nil {
  444. return nil, err
  445. }
  446. return ssh.MarshalAuthorizedKey(publicKey), nil
  447. }
  448. func GenerateKey(bits int) (*rsa.PrivateKey, *rsa.PublicKey, error) {
  449. private, err := rsa.GenerateKey(rand.Reader, bits)
  450. if err != nil {
  451. return nil, nil, err
  452. }
  453. return private, &private.PublicKey, nil
  454. }