conn.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. package sftp
  2. import (
  3. "encoding"
  4. "io"
  5. "sync"
  6. "github.com/pkg/errors"
  7. )
  8. // conn implements a bidirectional channel on which client and server
  9. // connections are multiplexed.
  10. type conn struct {
  11. io.Reader
  12. io.WriteCloser
  13. sync.Mutex // used to serialise writes to sendPacket
  14. }
  15. func (c *conn) recvPacket() (uint8, []byte, error) {
  16. return recvPacket(c)
  17. }
  18. func (c *conn) sendPacket(m encoding.BinaryMarshaler) error {
  19. c.Lock()
  20. defer c.Unlock()
  21. return sendPacket(c, m)
  22. }
  23. type clientConn struct {
  24. conn
  25. wg sync.WaitGroup
  26. sync.Mutex // protects inflight
  27. inflight map[uint32]chan<- result // outstanding requests
  28. }
  29. // Close closes the SFTP session.
  30. func (c *clientConn) Close() error {
  31. defer c.wg.Wait()
  32. return c.conn.Close()
  33. }
  34. func (c *clientConn) loop() {
  35. defer c.wg.Done()
  36. err := c.recv()
  37. if err != nil {
  38. c.broadcastErr(err)
  39. }
  40. }
  41. // recv continuously reads from the server and forwards responses to the
  42. // appropriate channel.
  43. func (c *clientConn) recv() error {
  44. defer c.conn.Close()
  45. for {
  46. typ, data, err := c.recvPacket()
  47. if err != nil {
  48. return err
  49. }
  50. sid, _ := unmarshalUint32(data)
  51. c.Lock()
  52. ch, ok := c.inflight[sid]
  53. delete(c.inflight, sid)
  54. c.Unlock()
  55. if !ok {
  56. // This is an unexpected occurrence. Send the error
  57. // back to all listeners so that they terminate
  58. // gracefully.
  59. return errors.Errorf("sid: %v not fond", sid)
  60. }
  61. ch <- result{typ: typ, data: data}
  62. }
  63. }
  64. // result captures the result of receiving the a packet from the server
  65. type result struct {
  66. typ byte
  67. data []byte
  68. err error
  69. }
  70. type idmarshaler interface {
  71. id() uint32
  72. encoding.BinaryMarshaler
  73. }
  74. func (c *clientConn) sendPacket(p idmarshaler) (byte, []byte, error) {
  75. ch := make(chan result, 1)
  76. c.dispatchRequest(ch, p)
  77. s := <-ch
  78. return s.typ, s.data, s.err
  79. }
  80. func (c *clientConn) dispatchRequest(ch chan<- result, p idmarshaler) {
  81. c.Lock()
  82. c.inflight[p.id()] = ch
  83. if err := c.conn.sendPacket(p); err != nil {
  84. delete(c.inflight, p.id())
  85. ch <- result{err: err}
  86. }
  87. c.Unlock()
  88. }
  89. // broadcastErr sends an error to all goroutines waiting for a response.
  90. func (c *clientConn) broadcastErr(err error) {
  91. c.Lock()
  92. listeners := make([]chan<- result, 0, len(c.inflight))
  93. for _, ch := range c.inflight {
  94. listeners = append(listeners, ch)
  95. }
  96. c.Unlock()
  97. for _, ch := range listeners {
  98. ch <- result{err: err}
  99. }
  100. }
  101. type serverConn struct {
  102. conn
  103. }
  104. func (s *serverConn) sendError(p id, err error) error {
  105. return s.sendPacket(statusFromError(p, err))
  106. }