server.go 20 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733
  1. // DNS server implementation.
  2. package dns
  3. import (
  4. "bytes"
  5. "crypto/tls"
  6. "encoding/binary"
  7. "io"
  8. "net"
  9. "sync"
  10. "time"
  11. )
  12. // Maximum number of TCP queries before we close the socket.
  13. const maxTCPQueries = 128
  14. // Handler is implemented by any value that implements ServeDNS.
  15. type Handler interface {
  16. ServeDNS(w ResponseWriter, r *Msg)
  17. }
  18. // A ResponseWriter interface is used by an DNS handler to
  19. // construct an DNS response.
  20. type ResponseWriter interface {
  21. // LocalAddr returns the net.Addr of the server
  22. LocalAddr() net.Addr
  23. // RemoteAddr returns the net.Addr of the client that sent the current request.
  24. RemoteAddr() net.Addr
  25. // WriteMsg writes a reply back to the client.
  26. WriteMsg(*Msg) error
  27. // Write writes a raw buffer back to the client.
  28. Write([]byte) (int, error)
  29. // Close closes the connection.
  30. Close() error
  31. // TsigStatus returns the status of the Tsig.
  32. TsigStatus() error
  33. // TsigTimersOnly sets the tsig timers only boolean.
  34. TsigTimersOnly(bool)
  35. // Hijack lets the caller take over the connection.
  36. // After a call to Hijack(), the DNS package will not do anything with the connection.
  37. Hijack()
  38. }
// response is the ResponseWriter implementation that Server.serve hands to
// handlers. For a UDP request udp and udpSession are set; for a TCP request
// tcp is set, and Close resets tcp to nil.
type response struct {
	hijacked bool // connection has been hijacked by handler
	tsigStatus error // outcome of checking the request's TSIG record; nil when none was checked
	tsigTimersOnly bool // passed to TsigGenerate when WriteMsg signs a reply
	tsigRequestMAC string // request MAC fed to TsigGenerate; WriteMsg replaces it with the value TsigGenerate returns
	tsigSecret map[string]string // the tsig secrets, looked up by the TSIG record's owner name
	udp *net.UDPConn // i/o connection if UDP was used
	tcp net.Conn // i/o connection if TCP was used
	udpSession *SessionUDP // oob data to get egress interface right
	remoteAddr net.Addr // address of the client
	writer Writer // writer to output the raw DNS bits (the response itself unless DecorateWriter wrapped it)
}
  51. // ServeMux is an DNS request multiplexer. It matches the
  52. // zone name of each incoming request against a list of
  53. // registered patterns add calls the handler for the pattern
  54. // that most closely matches the zone name. ServeMux is DNSSEC aware, meaning
  55. // that queries for the DS record are redirected to the parent zone (if that
  56. // is also registered), otherwise the child gets the query.
  57. // ServeMux is also safe for concurrent access from multiple goroutines.
  58. type ServeMux struct {
  59. z map[string]Handler
  60. m *sync.RWMutex
  61. }
  62. // NewServeMux allocates and returns a new ServeMux.
  63. func NewServeMux() *ServeMux { return &ServeMux{z: make(map[string]Handler), m: new(sync.RWMutex)} }
  64. // DefaultServeMux is the default ServeMux used by Serve.
  65. var DefaultServeMux = NewServeMux()
  66. // The HandlerFunc type is an adapter to allow the use of
  67. // ordinary functions as DNS handlers. If f is a function
  68. // with the appropriate signature, HandlerFunc(f) is a
  69. // Handler object that calls f.
  70. type HandlerFunc func(ResponseWriter, *Msg)
  71. // ServeDNS calls f(w, r).
  72. func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) {
  73. f(w, r)
  74. }
  75. // HandleFailed returns a HandlerFunc that returns SERVFAIL for every request it gets.
  76. func HandleFailed(w ResponseWriter, r *Msg) {
  77. m := new(Msg)
  78. m.SetRcode(r, RcodeServerFailure)
  79. // does not matter if this write fails
  80. w.WriteMsg(m)
  81. }
  82. func failedHandler() Handler { return HandlerFunc(HandleFailed) }
  83. // ListenAndServe Starts a server on address and network specified Invoke handler
  84. // for incoming queries.
  85. func ListenAndServe(addr string, network string, handler Handler) error {
  86. server := &Server{Addr: addr, Net: network, Handler: handler}
  87. return server.ListenAndServe()
  88. }
  89. // ListenAndServeTLS acts like http.ListenAndServeTLS, more information in
  90. // http://golang.org/pkg/net/http/#ListenAndServeTLS
  91. func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error {
  92. cert, err := tls.LoadX509KeyPair(certFile, keyFile)
  93. if err != nil {
  94. return err
  95. }
  96. config := tls.Config{
  97. Certificates: []tls.Certificate{cert},
  98. }
  99. server := &Server{
  100. Addr: addr,
  101. Net: "tcp-tls",
  102. TLSConfig: &config,
  103. Handler: handler,
  104. }
  105. return server.ListenAndServe()
  106. }
  107. // ActivateAndServe activates a server with a listener from systemd,
  108. // l and p should not both be non-nil.
  109. // If both l and p are not nil only p will be used.
  110. // Invoke handler for incoming queries.
  111. func ActivateAndServe(l net.Listener, p net.PacketConn, handler Handler) error {
  112. server := &Server{Listener: l, PacketConn: p, Handler: handler}
  113. return server.ActivateAndServe()
  114. }
  115. func (mux *ServeMux) match(q string, t uint16) Handler {
  116. mux.m.RLock()
  117. defer mux.m.RUnlock()
  118. var handler Handler
  119. b := make([]byte, len(q)) // worst case, one label of length q
  120. off := 0
  121. end := false
  122. for {
  123. l := len(q[off:])
  124. for i := 0; i < l; i++ {
  125. b[i] = q[off+i]
  126. if b[i] >= 'A' && b[i] <= 'Z' {
  127. b[i] |= ('a' - 'A')
  128. }
  129. }
  130. if h, ok := mux.z[string(b[:l])]; ok { // 'causes garbage, might want to change the map key
  131. if t != TypeDS {
  132. return h
  133. }
  134. // Continue for DS to see if we have a parent too, if so delegeate to the parent
  135. handler = h
  136. }
  137. off, end = NextLabel(q, off)
  138. if end {
  139. break
  140. }
  141. }
  142. // Wildcard match, if we have found nothing try the root zone as a last resort.
  143. if h, ok := mux.z["."]; ok {
  144. return h
  145. }
  146. return handler
  147. }
  148. // Handle adds a handler to the ServeMux for pattern.
  149. func (mux *ServeMux) Handle(pattern string, handler Handler) {
  150. if pattern == "" {
  151. panic("dns: invalid pattern " + pattern)
  152. }
  153. mux.m.Lock()
  154. mux.z[Fqdn(pattern)] = handler
  155. mux.m.Unlock()
  156. }
  157. // HandleFunc adds a handler function to the ServeMux for pattern.
  158. func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
  159. mux.Handle(pattern, HandlerFunc(handler))
  160. }
  161. // HandleRemove deregistrars the handler specific for pattern from the ServeMux.
  162. func (mux *ServeMux) HandleRemove(pattern string) {
  163. if pattern == "" {
  164. panic("dns: invalid pattern " + pattern)
  165. }
  166. mux.m.Lock()
  167. delete(mux.z, Fqdn(pattern))
  168. mux.m.Unlock()
  169. }
  170. // ServeDNS dispatches the request to the handler whose
  171. // pattern most closely matches the request message. If DefaultServeMux
  172. // is used the correct thing for DS queries is done: a possible parent
  173. // is sought.
  174. // If no handler is found a standard SERVFAIL message is returned
  175. // If the request message does not have exactly one question in the
  176. // question section a SERVFAIL is returned, unlesss Unsafe is true.
  177. func (mux *ServeMux) ServeDNS(w ResponseWriter, request *Msg) {
  178. var h Handler
  179. if len(request.Question) < 1 { // allow more than one question
  180. h = failedHandler()
  181. } else {
  182. if h = mux.match(request.Question[0].Name, request.Question[0].Qtype); h == nil {
  183. h = failedHandler()
  184. }
  185. }
  186. h.ServeDNS(w, request)
  187. }
  188. // Handle registers the handler with the given pattern
  189. // in the DefaultServeMux. The documentation for
  190. // ServeMux explains how patterns are matched.
  191. func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) }
  192. // HandleRemove deregisters the handle with the given pattern
  193. // in the DefaultServeMux.
  194. func HandleRemove(pattern string) { DefaultServeMux.HandleRemove(pattern) }
  195. // HandleFunc registers the handler function with the given pattern
  196. // in the DefaultServeMux.
  197. func HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
  198. DefaultServeMux.HandleFunc(pattern, handler)
  199. }
  200. // Writer writes raw DNS messages; each call to Write should send an entire message.
  201. type Writer interface {
  202. io.Writer
  203. }
  204. // Reader reads raw DNS messages; each call to ReadTCP or ReadUDP should return an entire message.
  205. type Reader interface {
  206. // ReadTCP reads a raw message from a TCP connection. Implementations may alter
  207. // connection properties, for example the read-deadline.
  208. ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error)
  209. // ReadUDP reads a raw message from a UDP connection. Implementations may alter
  210. // connection properties, for example the read-deadline.
  211. ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error)
  212. }
  213. // defaultReader is an adapter for the Server struct that implements the Reader interface
  214. // using the readTCP and readUDP func of the embedded Server.
  215. type defaultReader struct {
  216. *Server
  217. }
  218. func (dr *defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
  219. return dr.readTCP(conn, timeout)
  220. }
  221. func (dr *defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
  222. return dr.readUDP(conn, timeout)
  223. }
  224. // DecorateReader is a decorator hook for extending or supplanting the functionality of a Reader.
  225. // Implementations should never return a nil Reader.
  226. type DecorateReader func(Reader) Reader
  227. // DecorateWriter is a decorator hook for extending or supplanting the functionality of a Writer.
  228. // Implementations should never return a nil Writer.
  229. type DecorateWriter func(Writer) Writer
// A Server defines parameters for running a DNS server. Start it with
// ListenAndServe or ActivateAndServe and stop it with Shutdown.
type Server struct {
	// Address to listen on, ":domain" if empty.
	Addr string
	// Network for ListenAndServe: "tcp", "tcp4" or "tcp6" for plain TCP,
	// "tcp-tls", "tcp4-tls" or "tcp6-tls" for DNS over TLS, or "udp", "udp4"
	// or "udp6" for UDP. Any other value makes ListenAndServe fail with
	// "bad network".
	Net string
	// TCP Listener to use, this is to aid in systemd's socket activation.
	Listener net.Listener
	// TLS connection configuration, used by the "*-tls" networks.
	TLSConfig *tls.Config
	// UDP "Listener" to use, this is to aid in systemd's socket activation.
	PacketConn net.PacketConn
	// Handler to invoke, dns.DefaultServeMux if nil.
	Handler Handler
	// Default buffer size to use to read incoming UDP messages. If not set
	// it defaults to MinMsgSize (512 B).
	UDPSize int
	// Read deadline (applied with SetReadDeadline) for reading a new request.
	// If zero, dnsTimeout is used (documented as 2 * time.Second; confirm).
	// Shutdown also waits at most this long for in-flight queries.
	ReadTimeout time.Duration
	// Intended write deadline for new connections, documented as defaulting to 2 * time.Second.
	// NOTE(review): no code in this file reads WriteTimeout; confirm whether
	// it has any effect.
	WriteTimeout time.Duration
	// Returns the read deadline used while waiting for the next query on an
	// open TCP connection. If nil, tcpIdleTimeout is used (documented as
	// 8 * time.Second per RFC 5966; confirm).
	IdleTimeout func() time.Duration
	// Secret(s) for Tsig map[<zonename>]<base64 secret>.
	TsigSecret map[string]string
	// Unsafe instructs the server to disregard any sanity checks and directly hand the message to
	// the handler. It will specifically not check if the query has the QR bit not set.
	Unsafe bool
	// If NotifyStartedFunc is set it is called once the server has started listening.
	NotifyStartedFunc func()
	// DecorateReader is optional, allows customization of the process that reads raw DNS messages.
	DecorateReader DecorateReader
	// DecorateWriter is optional, allows customization of the process that writes raw DNS messages.
	DecorateWriter DecorateWriter
	// Graceful shutdown handling
	inFlight sync.WaitGroup // counts requests being served; Shutdown waits on it
	lock sync.RWMutex // guards started
	started bool // set by ListenAndServe/ActivateAndServe, cleared by Shutdown
}
  269. // ListenAndServe starts a nameserver on the configured address in *Server.
  270. func (srv *Server) ListenAndServe() error {
  271. srv.lock.Lock()
  272. defer srv.lock.Unlock()
  273. if srv.started {
  274. return &Error{err: "server already started"}
  275. }
  276. addr := srv.Addr
  277. if addr == "" {
  278. addr = ":domain"
  279. }
  280. if srv.UDPSize == 0 {
  281. srv.UDPSize = MinMsgSize
  282. }
  283. switch srv.Net {
  284. case "tcp", "tcp4", "tcp6":
  285. a, err := net.ResolveTCPAddr(srv.Net, addr)
  286. if err != nil {
  287. return err
  288. }
  289. l, err := net.ListenTCP(srv.Net, a)
  290. if err != nil {
  291. return err
  292. }
  293. srv.Listener = l
  294. srv.started = true
  295. srv.lock.Unlock()
  296. err = srv.serveTCP(l)
  297. srv.lock.Lock() // to satisfy the defer at the top
  298. return err
  299. case "tcp-tls", "tcp4-tls", "tcp6-tls":
  300. network := "tcp"
  301. if srv.Net == "tcp4-tls" {
  302. network = "tcp4"
  303. } else if srv.Net == "tcp6" {
  304. network = "tcp6"
  305. }
  306. l, err := tls.Listen(network, addr, srv.TLSConfig)
  307. if err != nil {
  308. return err
  309. }
  310. srv.Listener = l
  311. srv.started = true
  312. srv.lock.Unlock()
  313. err = srv.serveTCP(l)
  314. srv.lock.Lock() // to satisfy the defer at the top
  315. return err
  316. case "udp", "udp4", "udp6":
  317. a, err := net.ResolveUDPAddr(srv.Net, addr)
  318. if err != nil {
  319. return err
  320. }
  321. l, err := net.ListenUDP(srv.Net, a)
  322. if err != nil {
  323. return err
  324. }
  325. if e := setUDPSocketOptions(l); e != nil {
  326. return e
  327. }
  328. srv.PacketConn = l
  329. srv.started = true
  330. srv.lock.Unlock()
  331. err = srv.serveUDP(l)
  332. srv.lock.Lock() // to satisfy the defer at the top
  333. return err
  334. }
  335. return &Error{err: "bad network"}
  336. }
// ActivateAndServe starts a nameserver with the PacketConn or Listener
// configured in *Server. Its main use is to start a server from systemd.
//
// If srv.PacketConn is a *net.UDPConn it is served as UDP (after
// setUDPSocketOptions) and srv.Listener is ignored. Otherwise srv.Listener,
// if set, is served as TCP. A zero UDPSize is set to MinMsgSize whenever
// PacketConn is non-nil. An error is returned if the server is already
// started or neither connection can be used ("bad listeners").
func (srv *Server) ActivateAndServe() error {
	srv.lock.Lock()
	defer srv.lock.Unlock()
	if srv.started {
		return &Error{err: "server already started"}
	}
	pConn := srv.PacketConn
	l := srv.Listener
	if pConn != nil {
		if srv.UDPSize == 0 {
			srv.UDPSize = MinMsgSize
		}
		// Only *net.UDPConn is supported; any other PacketConn falls through
		// to the Listener check below.
		if t, ok := pConn.(*net.UDPConn); ok {
			if e := setUDPSocketOptions(t); e != nil {
				return e
			}
			srv.started = true
			// Release the lock while serving so Shutdown can acquire it.
			srv.lock.Unlock()
			e := srv.serveUDP(t)
			srv.lock.Lock() // to satisfy the defer at the top
			return e
		}
	}
	if l != nil {
		srv.started = true
		srv.lock.Unlock()
		e := srv.serveTCP(l)
		srv.lock.Lock() // to satisfy the defer at the top
		return e
	}
	return &Error{err: "bad listeners"}
}
  371. // Shutdown gracefully shuts down a server. After a call to Shutdown, ListenAndServe and
  372. // ActivateAndServe will return. All in progress queries are completed before the server
  373. // is taken down. If the Shutdown is taking longer than the reading timeout an error
  374. // is returned.
  375. func (srv *Server) Shutdown() error {
  376. srv.lock.Lock()
  377. if !srv.started {
  378. srv.lock.Unlock()
  379. return &Error{err: "server not started"}
  380. }
  381. srv.started = false
  382. srv.lock.Unlock()
  383. if srv.PacketConn != nil {
  384. srv.PacketConn.Close()
  385. }
  386. if srv.Listener != nil {
  387. srv.Listener.Close()
  388. }
  389. fin := make(chan bool)
  390. go func() {
  391. srv.inFlight.Wait()
  392. fin <- true
  393. }()
  394. select {
  395. case <-time.After(srv.getReadTimeout()):
  396. return &Error{err: "server shutdown is pending"}
  397. case <-fin:
  398. return nil
  399. }
  400. }
  401. // getReadTimeout is a helper func to use system timeout if server did not intend to change it.
  402. func (srv *Server) getReadTimeout() time.Duration {
  403. rtimeout := dnsTimeout
  404. if srv.ReadTimeout != 0 {
  405. rtimeout = srv.ReadTimeout
  406. }
  407. return rtimeout
  408. }
  409. // serveTCP starts a TCP listener for the server.
  410. // Each request is handled in a separate goroutine.
  411. func (srv *Server) serveTCP(l net.Listener) error {
  412. defer l.Close()
  413. if srv.NotifyStartedFunc != nil {
  414. srv.NotifyStartedFunc()
  415. }
  416. reader := Reader(&defaultReader{srv})
  417. if srv.DecorateReader != nil {
  418. reader = srv.DecorateReader(reader)
  419. }
  420. handler := srv.Handler
  421. if handler == nil {
  422. handler = DefaultServeMux
  423. }
  424. rtimeout := srv.getReadTimeout()
  425. // deadline is not used here
  426. for {
  427. rw, err := l.Accept()
  428. if err != nil {
  429. if neterr, ok := err.(net.Error); ok && neterr.Temporary() {
  430. continue
  431. }
  432. return err
  433. }
  434. m, err := reader.ReadTCP(rw, rtimeout)
  435. srv.lock.RLock()
  436. if !srv.started {
  437. srv.lock.RUnlock()
  438. return nil
  439. }
  440. srv.lock.RUnlock()
  441. if err != nil {
  442. continue
  443. }
  444. srv.inFlight.Add(1)
  445. go srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw)
  446. }
  447. }
  448. // serveUDP starts a UDP listener for the server.
  449. // Each request is handled in a separate goroutine.
  450. func (srv *Server) serveUDP(l *net.UDPConn) error {
  451. defer l.Close()
  452. if srv.NotifyStartedFunc != nil {
  453. srv.NotifyStartedFunc()
  454. }
  455. reader := Reader(&defaultReader{srv})
  456. if srv.DecorateReader != nil {
  457. reader = srv.DecorateReader(reader)
  458. }
  459. handler := srv.Handler
  460. if handler == nil {
  461. handler = DefaultServeMux
  462. }
  463. rtimeout := srv.getReadTimeout()
  464. // deadline is not used here
  465. for {
  466. m, s, err := reader.ReadUDP(l, rtimeout)
  467. srv.lock.RLock()
  468. if !srv.started {
  469. srv.lock.RUnlock()
  470. return nil
  471. }
  472. srv.lock.RUnlock()
  473. if err != nil {
  474. continue
  475. }
  476. srv.inFlight.Add(1)
  477. go srv.serve(s.RemoteAddr(), handler, m, l, s, nil)
  478. }
  479. }
  480. // Serve a new connection.
  481. func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t net.Conn) {
  482. defer srv.inFlight.Done()
  483. w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s}
  484. if srv.DecorateWriter != nil {
  485. w.writer = srv.DecorateWriter(w)
  486. } else {
  487. w.writer = w
  488. }
  489. q := 0 // counter for the amount of TCP queries we get
  490. reader := Reader(&defaultReader{srv})
  491. if srv.DecorateReader != nil {
  492. reader = srv.DecorateReader(reader)
  493. }
  494. Redo:
  495. req := new(Msg)
  496. err := req.Unpack(m)
  497. if err != nil { // Send a FormatError back
  498. x := new(Msg)
  499. x.SetRcodeFormatError(req)
  500. w.WriteMsg(x)
  501. goto Exit
  502. }
  503. if !srv.Unsafe && req.Response {
  504. goto Exit
  505. }
  506. w.tsigStatus = nil
  507. if w.tsigSecret != nil {
  508. if t := req.IsTsig(); t != nil {
  509. secret := t.Hdr.Name
  510. if _, ok := w.tsigSecret[secret]; !ok {
  511. w.tsigStatus = ErrKeyAlg
  512. }
  513. w.tsigStatus = TsigVerify(m, w.tsigSecret[secret], "", false)
  514. w.tsigTimersOnly = false
  515. w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC
  516. }
  517. }
  518. h.ServeDNS(w, req) // Writes back to the client
  519. Exit:
  520. if w.tcp == nil {
  521. return
  522. }
  523. // TODO(miek): make this number configurable?
  524. if q > maxTCPQueries { // close socket after this many queries
  525. w.Close()
  526. return
  527. }
  528. if w.hijacked {
  529. return // client calls Close()
  530. }
  531. if u != nil { // UDP, "close" and return
  532. w.Close()
  533. return
  534. }
  535. idleTimeout := tcpIdleTimeout
  536. if srv.IdleTimeout != nil {
  537. idleTimeout = srv.IdleTimeout()
  538. }
  539. m, err = reader.ReadTCP(w.tcp, idleTimeout)
  540. if err == nil {
  541. q++
  542. goto Redo
  543. }
  544. w.Close()
  545. return
  546. }
  547. func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
  548. conn.SetReadDeadline(time.Now().Add(timeout))
  549. l := make([]byte, 2)
  550. n, err := conn.Read(l)
  551. if err != nil || n != 2 {
  552. if err != nil {
  553. return nil, err
  554. }
  555. return nil, ErrShortRead
  556. }
  557. length := binary.BigEndian.Uint16(l)
  558. if length == 0 {
  559. return nil, ErrShortRead
  560. }
  561. m := make([]byte, int(length))
  562. n, err = conn.Read(m[:int(length)])
  563. if err != nil || n == 0 {
  564. if err != nil {
  565. return nil, err
  566. }
  567. return nil, ErrShortRead
  568. }
  569. i := n
  570. for i < int(length) {
  571. j, err := conn.Read(m[i:int(length)])
  572. if err != nil {
  573. return nil, err
  574. }
  575. i += j
  576. }
  577. n = i
  578. m = m[:n]
  579. return m, nil
  580. }
  581. func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
  582. conn.SetReadDeadline(time.Now().Add(timeout))
  583. m := make([]byte, srv.UDPSize)
  584. n, s, err := ReadFromSessionUDP(conn, m)
  585. if err != nil || n == 0 {
  586. if err != nil {
  587. return nil, nil, err
  588. }
  589. return nil, nil, ErrShortRead
  590. }
  591. m = m[:n]
  592. return m, s, nil
  593. }
  594. // WriteMsg implements the ResponseWriter.WriteMsg method.
  595. func (w *response) WriteMsg(m *Msg) (err error) {
  596. var data []byte
  597. if w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check)
  598. if t := m.IsTsig(); t != nil {
  599. data, w.tsigRequestMAC, err = TsigGenerate(m, w.tsigSecret[t.Hdr.Name], w.tsigRequestMAC, w.tsigTimersOnly)
  600. if err != nil {
  601. return err
  602. }
  603. _, err = w.writer.Write(data)
  604. return err
  605. }
  606. }
  607. data, err = m.Pack()
  608. if err != nil {
  609. return err
  610. }
  611. _, err = w.writer.Write(data)
  612. return err
  613. }
  614. // Write implements the ResponseWriter.Write method.
  615. func (w *response) Write(m []byte) (int, error) {
  616. switch {
  617. case w.udp != nil:
  618. n, err := WriteToSessionUDP(w.udp, m, w.udpSession)
  619. return n, err
  620. case w.tcp != nil:
  621. lm := len(m)
  622. if lm < 2 {
  623. return 0, io.ErrShortBuffer
  624. }
  625. if lm > MaxMsgSize {
  626. return 0, &Error{err: "message too large"}
  627. }
  628. l := make([]byte, 2, 2+lm)
  629. binary.BigEndian.PutUint16(l, uint16(lm))
  630. m = append(l, m...)
  631. n, err := io.Copy(w.tcp, bytes.NewReader(m))
  632. return int(n), err
  633. }
  634. panic("not reached")
  635. }
  636. // LocalAddr implements the ResponseWriter.LocalAddr method.
  637. func (w *response) LocalAddr() net.Addr {
  638. if w.tcp != nil {
  639. return w.tcp.LocalAddr()
  640. }
  641. return w.udp.LocalAddr()
  642. }
  643. // RemoteAddr implements the ResponseWriter.RemoteAddr method.
  644. func (w *response) RemoteAddr() net.Addr { return w.remoteAddr }
  645. // TsigStatus implements the ResponseWriter.TsigStatus method.
  646. func (w *response) TsigStatus() error { return w.tsigStatus }
  647. // TsigTimersOnly implements the ResponseWriter.TsigTimersOnly method.
  648. func (w *response) TsigTimersOnly(b bool) { w.tsigTimersOnly = b }
  649. // Hijack implements the ResponseWriter.Hijack method.
  650. func (w *response) Hijack() { w.hijacked = true }
  651. // Close implements the ResponseWriter.Close method
  652. func (w *response) Close() error {
  653. // Can't close the udp conn, as that is actually the listener.
  654. if w.tcp != nil {
  655. e := w.tcp.Close()
  656. w.tcp = nil
  657. return e
  658. }
  659. return nil
  660. }