transport_unix.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. //+build !windows,!solaris
  2. package dbus
  3. import (
  4. "bytes"
  5. "encoding/binary"
  6. "errors"
  7. "io"
  8. "net"
  9. "syscall"
  10. )
  11. type oobReader struct {
  12. conn *net.UnixConn
  13. oob []byte
  14. buf [4096]byte
  15. }
  16. func (o *oobReader) Read(b []byte) (n int, err error) {
  17. n, oobn, flags, _, err := o.conn.ReadMsgUnix(b, o.buf[:])
  18. if err != nil {
  19. return n, err
  20. }
  21. if flags&syscall.MSG_CTRUNC != 0 {
  22. return n, errors.New("dbus: control data truncated (too many fds received)")
  23. }
  24. o.oob = append(o.oob, o.buf[:oobn]...)
  25. return n, nil
  26. }
  27. type unixTransport struct {
  28. *net.UnixConn
  29. rdr *oobReader
  30. hasUnixFDs bool
  31. }
  32. func newUnixTransport(keys string) (transport, error) {
  33. var err error
  34. t := new(unixTransport)
  35. abstract := getKey(keys, "abstract")
  36. path := getKey(keys, "path")
  37. switch {
  38. case abstract == "" && path == "":
  39. return nil, errors.New("dbus: invalid address (neither path nor abstract set)")
  40. case abstract != "" && path == "":
  41. t.UnixConn, err = net.DialUnix("unix", nil, &net.UnixAddr{Name: "@" + abstract, Net: "unix"})
  42. if err != nil {
  43. return nil, err
  44. }
  45. return t, nil
  46. case abstract == "" && path != "":
  47. t.UnixConn, err = net.DialUnix("unix", nil, &net.UnixAddr{Name: path, Net: "unix"})
  48. if err != nil {
  49. return nil, err
  50. }
  51. return t, nil
  52. default:
  53. return nil, errors.New("dbus: invalid address (both path and abstract set)")
  54. }
  55. }
  56. func init() {
  57. transports["unix"] = newUnixTransport
  58. }
  59. func (t *unixTransport) EnableUnixFDs() {
  60. t.hasUnixFDs = true
  61. }
  62. func (t *unixTransport) ReadMessage() (*Message, error) {
  63. var (
  64. blen, hlen uint32
  65. csheader [16]byte
  66. headers []header
  67. order binary.ByteOrder
  68. unixfds uint32
  69. )
  70. // To be sure that all bytes of out-of-band data are read, we use a special
  71. // reader that uses ReadUnix on the underlying connection instead of Read
  72. // and gathers the out-of-band data in a buffer.
  73. if t.rdr == nil {
  74. t.rdr = &oobReader{conn: t.UnixConn}
  75. } else {
  76. t.rdr.oob = nil
  77. }
  78. // read the first 16 bytes (the part of the header that has a constant size),
  79. // from which we can figure out the length of the rest of the message
  80. if _, err := io.ReadFull(t.rdr, csheader[:]); err != nil {
  81. return nil, err
  82. }
  83. switch csheader[0] {
  84. case 'l':
  85. order = binary.LittleEndian
  86. case 'B':
  87. order = binary.BigEndian
  88. default:
  89. return nil, InvalidMessageError("invalid byte order")
  90. }
  91. // csheader[4:8] -> length of message body, csheader[12:16] -> length of
  92. // header fields (without alignment)
  93. binary.Read(bytes.NewBuffer(csheader[4:8]), order, &blen)
  94. binary.Read(bytes.NewBuffer(csheader[12:]), order, &hlen)
  95. if hlen%8 != 0 {
  96. hlen += 8 - (hlen % 8)
  97. }
  98. // decode headers and look for unix fds
  99. headerdata := make([]byte, hlen+4)
  100. copy(headerdata, csheader[12:])
  101. if _, err := io.ReadFull(t.rdr, headerdata[4:]); err != nil {
  102. return nil, err
  103. }
  104. dec := newDecoder(bytes.NewBuffer(headerdata), order)
  105. dec.pos = 12
  106. vs, err := dec.Decode(Signature{"a(yv)"})
  107. if err != nil {
  108. return nil, err
  109. }
  110. Store(vs, &headers)
  111. for _, v := range headers {
  112. if v.Field == byte(FieldUnixFDs) {
  113. unixfds, _ = v.Variant.value.(uint32)
  114. }
  115. }
  116. all := make([]byte, 16+hlen+blen)
  117. copy(all, csheader[:])
  118. copy(all[16:], headerdata[4:])
  119. if _, err := io.ReadFull(t.rdr, all[16+hlen:]); err != nil {
  120. return nil, err
  121. }
  122. if unixfds != 0 {
  123. if !t.hasUnixFDs {
  124. return nil, errors.New("dbus: got unix fds on unsupported transport")
  125. }
  126. // read the fds from the OOB data
  127. scms, err := syscall.ParseSocketControlMessage(t.rdr.oob)
  128. if err != nil {
  129. return nil, err
  130. }
  131. if len(scms) != 1 {
  132. return nil, errors.New("dbus: received more than one socket control message")
  133. }
  134. fds, err := syscall.ParseUnixRights(&scms[0])
  135. if err != nil {
  136. return nil, err
  137. }
  138. msg, err := DecodeMessage(bytes.NewBuffer(all))
  139. if err != nil {
  140. return nil, err
  141. }
  142. // substitute the values in the message body (which are indices for the
  143. // array receiver via OOB) with the actual values
  144. for i, v := range msg.Body {
  145. switch v.(type) {
  146. case UnixFDIndex:
  147. j := v.(UnixFDIndex)
  148. if uint32(j) >= unixfds {
  149. return nil, InvalidMessageError("invalid index for unix fd")
  150. }
  151. msg.Body[i] = UnixFD(fds[j])
  152. case []UnixFDIndex:
  153. idxArray := v.([]UnixFDIndex)
  154. fdArray := make([]UnixFD, len(idxArray))
  155. for k, j := range idxArray {
  156. if uint32(j) >= unixfds {
  157. return nil, InvalidMessageError("invalid index for unix fd")
  158. }
  159. fdArray[k] = UnixFD(fds[j])
  160. }
  161. msg.Body[i] = fdArray
  162. }
  163. }
  164. return msg, nil
  165. }
  166. return DecodeMessage(bytes.NewBuffer(all))
  167. }
  168. func (t *unixTransport) SendMessage(msg *Message) error {
  169. fds := make([]int, 0)
  170. for i, v := range msg.Body {
  171. if fd, ok := v.(UnixFD); ok {
  172. msg.Body[i] = UnixFDIndex(len(fds))
  173. fds = append(fds, int(fd))
  174. }
  175. }
  176. if len(fds) != 0 {
  177. if !t.hasUnixFDs {
  178. return errors.New("dbus: unix fd passing not enabled")
  179. }
  180. msg.Headers[FieldUnixFDs] = MakeVariant(uint32(len(fds)))
  181. oob := syscall.UnixRights(fds...)
  182. buf := new(bytes.Buffer)
  183. msg.EncodeTo(buf, nativeEndian)
  184. n, oobn, err := t.UnixConn.WriteMsgUnix(buf.Bytes(), oob, nil)
  185. if err != nil {
  186. return err
  187. }
  188. if n != buf.Len() || oobn != len(oob) {
  189. return io.ErrShortWrite
  190. }
  191. } else {
  192. if err := msg.EncodeTo(t, nativeEndian); err != nil {
  193. return nil
  194. }
  195. }
  196. return nil
  197. }
  198. func (t *unixTransport) SupportsUnixFDs() bool {
  199. return true
  200. }