ssh_test.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. /*
  2. Copyright 2015 The Kubernetes Authors.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package ssh
  14. import (
  15. "context"
  16. "fmt"
  17. "io"
  18. "net"
  19. "os"
  20. "reflect"
  21. "strings"
  22. "testing"
  23. "time"
  24. "golang.org/x/crypto/ssh"
  25. "k8s.io/apimachinery/pkg/util/wait"
  26. "k8s.io/klog"
  27. )
  // testSSHServer describes an in-process SSH server started by runTestSSHServer.
  type testSSHServer struct {
  	// Host and Port are the address the server listens on.
  	Host, Port string
  	// Type and Data hold the key type and authorized_keys encoding of the
  	// public key most recently offered by a client during authentication.
  	Type string
  	Data []byte
  	// PrivateKey and PublicKey are the encoded forms of the server's host key pair.
  	PrivateKey, PublicKey []byte
  }
  // runTestSSHServer starts an SSH server on an ephemeral 127.0.0.1 port and
  // returns its address together with its freshly generated host key pair.
  //
  // Password authentication succeeds only for the given user/password pair.
  // Public-key authentication accepts any key and records that key's type and
  // authorized_keys encoding in the returned Type and Data fields. The server
  // serves a single client connection (see serveTestSSHConn).
  //
  // NOTE(review): Type and Data are written from the server goroutine without
  // synchronization; callers are expected to read them only after the client
  // handshake has completed.
  func runTestSSHServer(user, password string) (*testSSHServer, error) {
  	result := &testSSHServer{}
  	// Largely derived from https://godoc.org/golang.org/x/crypto/ssh#example-NewServerConn
  	config := &ssh.ServerConfig{
  		PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
  			if c.User() == user && string(pass) == password {
  				return nil, nil
  			}
  			return nil, fmt.Errorf("password rejected for %s", c.User())
  		},
  		PublicKeyCallback: func(c ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
  			result.Type = key.Type()
  			result.Data = ssh.MarshalAuthorizedKey(key)
  			return nil, nil
  		},
  	}
  	privateKey, publicKey, err := GenerateKey(2048)
  	if err != nil {
  		return nil, err
  	}
  	privateBytes := EncodePrivateKey(privateKey)
  	signer, err := ssh.ParsePrivateKey(privateBytes)
  	if err != nil {
  		return nil, err
  	}
  	config.AddHostKey(signer)
  	result.PrivateKey = privateBytes
  	publicBytes, err := EncodePublicKey(publicKey)
  	if err != nil {
  		return nil, err
  	}
  	result.PublicKey = publicBytes
  	listener, err := net.Listen("tcp", "127.0.0.1:0")
  	if err != nil {
  		return nil, err
  	}
  	host, port, err := net.SplitHostPort(listener.Addr().String())
  	if err != nil {
  		listener.Close()
  		return nil, err
  	}
  	result.Host = host
  	result.Port = port
  	go serveTestSSHConn(listener, config)
  	return result, nil
  }
  
  // serveTestSSHConn accepts one connection on listener, runs the SSH server
  // handshake with config, and services that connection's channels until the
  // client goes away. Only "direct-tcpip" channels are accepted; their requests
  // are logged and otherwise ignored. Any error is logged and ends (or skips)
  // the affected step instead of continuing with nil values, which previously
  // caused nil dereferences or goroutines blocked forever on nil channels.
  func serveTestSSHConn(listener net.Listener, config *ssh.ServerConfig) {
  	defer listener.Close()
  	conn, err := listener.Accept()
  	if err != nil {
  		klog.Errorf("Failed to accept: %v", err)
  		return
  	}
  	sshConn, chans, reqs, err := ssh.NewServerConn(conn, config)
  	if err != nil {
  		klog.Errorf("Failed handshake: %v", err)
  		conn.Close()
  		return
  	}
  	defer sshConn.Close()
  	go ssh.DiscardRequests(reqs)
  	for newChannel := range chans {
  		if newChannel.ChannelType() != "direct-tcpip" {
  			newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", newChannel.ChannelType()))
  			continue
  		}
  		channel, requests, err := newChannel.Accept()
  		if err != nil {
  			klog.Errorf("Failed to accept channel: %v", err)
  			continue
  		}
  		// Drain requests until the client closes the channel, then release it.
  		for req := range requests {
  			klog.Infof("Got request: %v", req)
  		}
  		channel.Close()
  	}
  }
  // TestSSHTunnel opens a tunnel to an in-process SSH server using a freshly
  // generated client key, dials through it, and verifies that the server saw the
  // client's RSA public key during authentication.
  func TestSSHTunnel(t *testing.T) {
  	private, public, err := GenerateKey(2048)
  	if err != nil {
  		t.Fatalf("unexpected error: %v", err)
  	}
  	server, err := runTestSSHServer("foo", "bar")
  	if err != nil {
  		t.Fatalf("unexpected error: %v", err)
  	}
  	privateData := EncodePrivateKey(private)
  	tunnel, err := newSSHTunnelFromBytes("foo", privateData, server.Host)
  	if err != nil {
  		t.Fatalf("unexpected error: %v", err)
  	}
  	tunnel.SSHPort = server.Port
  	if err := tunnel.Open(); err != nil {
  		t.Fatalf("unexpected error: %v", err)
  	}
  	// Close the tunnel even if a later check aborts the test via Fatalf.
  	defer func() {
  		if err := tunnel.Close(); err != nil {
  			t.Errorf("unexpected error: %v", err)
  		}
  	}()
  	conn, err := tunnel.Dial(context.Background(), "tcp", "127.0.0.1:8080")
  	if err != nil {
  		t.Errorf("unexpected error: %v", err)
  	} else {
  		// Release the forwarded connection; only the handshake is under test.
  		conn.Close()
  	}
  	if server.Type != "ssh-rsa" {
  		t.Errorf("expected %s, got %s", "ssh-rsa", server.Type)
  	}
  	publicData, err := EncodeSSHKey(public)
  	if err != nil {
  		t.Fatalf("unexpected error: %v", err)
  	}
  	if !reflect.DeepEqual(server.Data, publicData) {
  		// Report the client's public key as expected and what the server recorded as got.
  		t.Errorf("expected %s, got %s", string(publicData), string(server.Data))
  	}
  }
  147. type fakeTunnel struct{}
  148. func (*fakeTunnel) Open() error {
  149. return nil
  150. }
  151. func (*fakeTunnel) Close() error {
  152. return nil
  153. }
  154. func (*fakeTunnel) Dial(ctx context.Context, network, address string) (net.Conn, error) {
  155. return nil, nil
  156. }
  157. type fakeTunnelCreator struct{}
  158. func (*fakeTunnelCreator) newSSHTunnel(string, string, string) (tunnel, error) {
  159. return &fakeTunnel{}, nil
  160. }
  // TestSSHTunnelListUpdate feeds an SSHTunnelList backed by fake tunnels a
  // sequence of address sets through Update, and after each call checks that the
  // list converges to exactly that set.
  func TestSSHTunnelListUpdate(t *testing.T) {
  	// Begin with an empty list whose tunnels are all fakes.
  	l := &SSHTunnelList{
  		adding:        make(map[string]bool),
  		tunnelCreator: &fakeTunnelCreator{},
  	}
  	replacement := []string{"21.22.23.24", "25.26.27.28"}
  	steps := [][]string{
  		{"1.2.3.4", "5.6.7.8"},               // two tunnels
  		{"1.2.3.4", "5.6.7.8", "9.10.11.12"}, // grow by one
  		{"1.2.3.4"},                          // shrink to a single tunnel
  		replacement,                          // swap in an entirely new set
  		replacement,                          // repeat the same set unchanged
  	}
  	for _, want := range steps {
  		l.Update(want)
  		checkTunnelsCorrect(t, l, want)
  	}
  }
  // checkTunnelsCorrect polls every 100ms, for up to 2s, until hasCorrectTunnels
  // reports that tunnelList holds exactly the given addresses. If that never
  // happens, it records a test error on t. The polling exists because the list is
  // not checked synchronously after Update; presumably the tunnels are created in
  // the background (TODO confirm against SSHTunnelList.Update).
  func checkTunnelsCorrect(t *testing.T, tunnelList *SSHTunnelList, addresses []string) {
  	t.Helper()
  	if err := wait.Poll(100*time.Millisecond, 2*time.Second, func() (bool, error) {
  		return hasCorrectTunnels(tunnelList, addresses), nil
  	}); err != nil {
  		t.Errorf("Error waiting for tunnels to reach expected state: %v. Expected %v, had %v", err, addresses, snapshotTunnelAddresses(tunnelList))
  	}
  }
  
  // snapshotTunnelAddresses returns the addresses of tunnelList's entries, read
  // while holding tunnelsLock, so that the failure message does not read entries
  // unsynchronized the way formatting the whole list with %v did.
  func snapshotTunnelAddresses(tunnelList *SSHTunnelList) []string {
  	tunnelList.tunnelsLock.Lock()
  	defer tunnelList.tunnelsLock.Unlock()
  	addrs := make([]string, 0, len(tunnelList.entries))
  	for _, entry := range tunnelList.entries {
  		addrs = append(addrs, entry.Address)
  	}
  	return addrs
  }
  194. func hasCorrectTunnels(tunnelList *SSHTunnelList, addresses []string) bool {
  195. tunnelList.tunnelsLock.Lock()
  196. defer tunnelList.tunnelsLock.Unlock()
  197. wantMap := make(map[string]bool)
  198. for _, addr := range addresses {
  199. wantMap[addr] = true
  200. }
  201. haveMap := make(map[string]bool)
  202. for _, entry := range tunnelList.entries {
  203. if wantMap[entry.Address] == false {
  204. return false
  205. }
  206. haveMap[entry.Address] = true
  207. }
  208. for _, addr := range addresses {
  209. if haveMap[addr] == false {
  210. return false
  211. }
  212. }
  213. return true
  214. }
  215. type mockSSHDialer struct {
  216. network string
  217. addr string
  218. config *ssh.ClientConfig
  219. }
  220. func (d *mockSSHDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
  221. d.network = network
  222. d.addr = addr
  223. d.config = config
  224. return nil, fmt.Errorf("mock error from Dial")
  225. }
  226. type mockSigner struct {
  227. }
  228. func (s *mockSigner) PublicKey() ssh.PublicKey {
  229. panic("mockSigner.PublicKey not implemented")
  230. }
  231. func (s *mockSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) {
  232. panic("mockSigner.Sign not implemented")
  233. }
  // TestSSHUser checks which SSH user runSSHCommand passes to the dialer: the
  // explicit user when one is given, and $USER when the user is empty. The mock
  // dialer always fails, so each case also checks that the error names user@host.
  func TestSSHUser(t *testing.T) {
  	signer := &mockSigner{}
  	table := []struct {
  		title      string
  		user       string
  		host       string
  		signer     ssh.Signer
  		command    string
  		expectUser string
  	}{
  		{
  			title:      "all values provided",
  			user:       "testuser",
  			host:       "testhost",
  			signer:     signer,
  			command:    "uptime",
  			expectUser: "testuser",
  		},
  		{
  			title:      "empty user defaults to GetEnv(USER)",
  			user:       "",
  			host:       "testhost",
  			signer:     signer,
  			command:    "uptime",
  			expectUser: os.Getenv("USER"),
  		},
  	}
  	for _, item := range table {
  		t.Run(item.title, func(t *testing.T) {
  			dialer := &mockSSHDialer{}
  			_, _, _, err := runSSHCommand(dialer, item.command, item.user, item.host, item.signer, false)
  			if err == nil {
  				// Stop here: the checks below would dereference a nil error.
  				t.Fatalf("expected error (as mock returns error); did not get one")
  			}
  			errString := err.Error()
  			if !strings.HasPrefix(errString, fmt.Sprintf("error getting SSH client to %s@%s:", item.expectUser, item.host)) {
  				t.Errorf("unexpected error: %v", errString)
  			}
  			if dialer.network != "tcp" {
  				t.Errorf("unexpected network: %v", dialer.network)
  			}
  			if dialer.config == nil {
  				t.Fatalf("expected Dial to be called with a non-nil client config")
  			}
  			if dialer.config.User != item.expectUser {
  				t.Errorf("unexpected user: %v", dialer.config.User)
  			}
  			if len(dialer.config.Auth) != 1 {
  				t.Errorf("unexpected auth: %v", dialer.config.Auth)
  			}
  			// (No way to test Auth - nothing exported?)
  		})
  	}
  }
  // TestTimeoutDialer checks that timeoutDialer bounds the duration of an SSH
  // dial. The listener below is never accepted from and never speaks SSH, so with
  // a tiny timeout the dial is expected to fail with an i/o timeout.
  func TestTimeoutDialer(t *testing.T) {
  	listener, err := net.Listen("tcp", "127.0.0.1:0")
  	if err != nil {
  		t.Fatalf("unexpected error: %v", err)
  	}
  	defer listener.Close()
  	testCases := []struct {
  		timeout           time.Duration
  		expectedErrString string // empty means the dial must succeed
  	}{
  		// delay > timeout should cause ssh.Dial to timeout.
  		{time.Nanosecond, "i/o timeout"},
  	}
  	for _, tc := range testCases {
  		dialer := &timeoutDialer{&realSSHDialer{}, tc.timeout}
  		client, err := dialer.Dial("tcp", listener.Addr().String(), &ssh.ClientConfig{})
  		if client != nil {
  			client.Close()
  		}
  		switch {
  		case tc.expectedErrString == "" && err != nil:
  			t.Errorf("Expected no error; got %v", err)
  		case tc.expectedErrString != "" && !strings.Contains(fmt.Sprint(err), tc.expectedErrString):
  			t.Errorf("Expected error to contain %q; got %v", tc.expectedErrString, err)
  		}
  	}
  }
  306. func newSSHTunnelFromBytes(user string, privateKey []byte, host string) (*sshTunnel, error) {
  307. signer, err := MakePrivateKeySignerFromBytes(privateKey)
  308. if err != nil {
  309. return nil, err
  310. }
  311. return makeSSHTunnel(user, signer, host)
  312. }