aboutsummaryrefslogtreecommitdiff
path: root/utils/dns
diff options
context:
space:
mode:
Diffstat (limited to 'utils/dns')
-rw-r--r--utils/dns/defaults.go7
-rw-r--r--utils/dns/messages.go11
-rw-r--r--utils/dns/server.go153
-rw-r--r--utils/dns/upstream.go153
4 files changed, 324 insertions, 0 deletions
diff --git a/utils/dns/defaults.go b/utils/dns/defaults.go
new file mode 100644
index 0000000..dc9824a
--- /dev/null
+++ b/utils/dns/defaults.go
@@ -0,0 +1,7 @@
+package dns
+
+const (
+ LogPrefix = "DNS"
+ UpstreamTimeoutSeconds = 5
+ UpstreamResponseTTL = 60
+)
diff --git a/utils/dns/messages.go b/utils/dns/messages.go
new file mode 100644
index 0000000..a73c8de
--- /dev/null
+++ b/utils/dns/messages.go
@@ -0,0 +1,11 @@
+package dns
+
+const (
+ ForwardFailed = "Failed to forward query for %s to upstream: %v"
+ ForwardSuccess = "Forwarded query for %s to upstream."
+ ListenFailed = "Failed to start DNS listener: %v"
+ QueryReceived = "Query: %s %s"
+ ServerStarting = "DNS server started on %s (UDP)."
+ ShutdownComplete = "DNS server stopped."
+ ShutdownFailed = "Failed to shutdown DNS server: %v"
+)
diff --git a/utils/dns/server.go b/utils/dns/server.go
new file mode 100644
index 0000000..7ae8770
--- /dev/null
+++ b/utils/dns/server.go
@@ -0,0 +1,153 @@
+package dns
+
+import (
+ "fmt"
+ "net"
+
+ "dove/config"
+ dnsRepo "dove/repositories/dns"
+ "dove/utils/logger"
+
+ mdns "github.com/miekg/dns"
+)
+
+var activeServer *mdns.Server
+
+func Start() {
+ address := fmt.Sprintf("%s:%d", config.DNS.Host, config.DNS.Port)
+
+ mdns.HandleFunc(".", handleQuery)
+
+ activeServer = &mdns.Server{
+ Addr: address,
+ Net: "udp",
+ }
+
+ go func() {
+ logger.Successf(LogPrefix, ServerStarting, address)
+
+ if listenError := activeServer.ListenAndServe(); listenError != nil {
+ logger.Fatalf(LogPrefix, ListenFailed, listenError)
+ }
+ }()
+}
+
+func Shutdown() {
+ if activeServer == nil {
+ return
+ }
+
+ if shutdownError := activeServer.Shutdown(); shutdownError != nil {
+ logger.Errorf(LogPrefix, ShutdownFailed, shutdownError)
+ }
+
+ logger.Infof(LogPrefix, ShutdownComplete)
+}
+
+func handleQuery(writer mdns.ResponseWriter, request *mdns.Msg) {
+ if len(request.Question) == 0 {
+ return
+ }
+
+ question := request.Question[0]
+ queryName := question.Name
+ queryType := mdns.TypeToString[question.Qtype]
+
+ logger.Debugf(LogPrefix, QueryReceived, queryType, queryName)
+
+ if dnsRepo.IsLocalDomain(queryName) {
+ response := resolveLocal(request)
+ writer.WriteMsg(response)
+ return
+ }
+
+ upstreamResponse := forwardToUpstream(request)
+ if upstreamResponse != nil {
+ writer.WriteMsg(upstreamResponse)
+ return
+ }
+
+ refusedResponse := &mdns.Msg{}
+ refusedResponse.SetRcode(request, mdns.RcodeServerFailure)
+ writer.WriteMsg(refusedResponse)
+}
+
+func resolveLocal(request *mdns.Msg) *mdns.Msg {
+ question := request.Question[0]
+ queryName := question.Name
+
+ response := &mdns.Msg{}
+ response.SetReply(request)
+ response.Authoritative = true
+ response.RecursionAvailable = true
+
+ switch question.Qtype {
+ case mdns.TypeA:
+ for _, record := range dnsRepo.ResolveA(queryName) {
+ parsedIP := net.ParseIP(record.Address)
+ if parsedIP == nil || parsedIP.To4() == nil {
+ continue
+ }
+
+ response.Answer = append(response.Answer, &mdns.A{
+ Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeA, Class: mdns.ClassINET, Ttl: record.TTL},
+ A: parsedIP.To4(),
+ })
+ }
+
+ case mdns.TypeAAAA:
+ for _, record := range dnsRepo.ResolveAAAA(queryName) {
+ parsedIP := net.ParseIP(record.Address)
+ if parsedIP == nil || parsedIP.To16() == nil {
+ continue
+ }
+
+ response.Answer = append(response.Answer, &mdns.AAAA{
+ Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeAAAA, Class: mdns.ClassINET, Ttl: record.TTL},
+ AAAA: parsedIP.To16(),
+ })
+ }
+
+ case mdns.TypeCNAME:
+ for _, record := range dnsRepo.ResolveCNAME(queryName) {
+ response.Answer = append(response.Answer, &mdns.CNAME{
+ Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeCNAME, Class: mdns.ClassINET, Ttl: record.TTL},
+ Target: mdns.Fqdn(record.Target),
+ })
+ }
+
+ case mdns.TypeMX:
+ for _, record := range dnsRepo.ResolveMX(queryName) {
+ response.Answer = append(response.Answer, &mdns.MX{
+ Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeMX, Class: mdns.ClassINET, Ttl: record.TTL},
+ Preference: record.Priority,
+ Mx: mdns.Fqdn(record.Target),
+ })
+ }
+
+ case mdns.TypeTXT:
+ for _, record := range dnsRepo.ResolveTXT(queryName) {
+ response.Answer = append(response.Answer, &mdns.TXT{
+ Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeTXT, Class: mdns.ClassINET, Ttl: record.TTL},
+ Txt: []string{record.Content},
+ })
+ }
+
+ case mdns.TypeSRV:
+ for _, record := range dnsRepo.ResolveSRV(queryName) {
+ response.Answer = append(response.Answer, &mdns.SRV{
+ Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeSRV, Class: mdns.ClassINET, Ttl: record.TTL},
+ Priority: record.Priority,
+ Weight: record.Weight,
+ Port: record.Port,
+ Target: mdns.Fqdn(record.Target),
+ })
+ }
+ }
+
+ if len(response.Answer) == 0 {
+ response.Rcode = mdns.RcodeNameError
+ }
+
+ return response
+}
diff --git a/utils/dns/upstream.go b/utils/dns/upstream.go
new file mode 100644
index 0000000..e2af583
--- /dev/null
+++ b/utils/dns/upstream.go
@@ -0,0 +1,153 @@
+package dns
+
+import (
+ "context"
+ "net"
+ "time"
+
+ "dove/utils/logger"
+
+ mdns "github.com/miekg/dns"
+)
+
+func forwardToUpstream(request *mdns.Msg) *mdns.Msg {
+ if len(request.Question) == 0 {
+ return nil
+ }
+
+ question := request.Question[0]
+ queryName := question.Name
+
+ resolver := &net.Resolver{PreferGo: false}
+ ctx, cancel := context.WithTimeout(context.Background(), time.Duration(UpstreamTimeoutSeconds)*time.Second)
+ defer cancel()
+
+ response := &mdns.Msg{}
+ response.SetReply(request)
+ response.Authoritative = false
+ response.RecursionAvailable = true
+
+ switch question.Qtype {
+ case mdns.TypeA:
+ addresses, lookupError := resolver.LookupIPAddr(ctx, mdns.Fqdn(queryName))
+ if lookupError != nil {
+ logger.Debugf(LogPrefix, ForwardFailed, queryName, lookupError)
+ response.Rcode = mdns.RcodeNameError
+ return response
+ }
+
+ for _, address := range addresses {
+ if ipv4 := address.IP.To4(); ipv4 != nil {
+ response.Answer = append(response.Answer, &mdns.A{
+ Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeA, Class: mdns.ClassINET, Ttl: UpstreamResponseTTL},
+ A: ipv4,
+ })
+ }
+ }
+
+ case mdns.TypeAAAA:
+ addresses, lookupError := resolver.LookupIPAddr(ctx, mdns.Fqdn(queryName))
+ if lookupError != nil {
+ logger.Debugf(LogPrefix, ForwardFailed, queryName, lookupError)
+ response.Rcode = mdns.RcodeNameError
+ return response
+ }
+
+ for _, address := range addresses {
+ if address.IP.To4() == nil && address.IP.To16() != nil {
+ response.Answer = append(response.Answer, &mdns.AAAA{
+ Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeAAAA, Class: mdns.ClassINET, Ttl: UpstreamResponseTTL},
+ AAAA: address.IP.To16(),
+ })
+ }
+ }
+
+ case mdns.TypeCNAME:
+ canonicalName, lookupError := resolver.LookupCNAME(ctx, mdns.Fqdn(queryName))
+ if lookupError != nil {
+ logger.Debugf(LogPrefix, ForwardFailed, queryName, lookupError)
+ response.Rcode = mdns.RcodeNameError
+ return response
+ }
+
+ response.Answer = append(response.Answer, &mdns.CNAME{
+ Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeCNAME, Class: mdns.ClassINET, Ttl: UpstreamResponseTTL},
+ Target: mdns.Fqdn(canonicalName),
+ })
+
+ case mdns.TypeMX:
+ mxRecords, lookupError := resolver.LookupMX(ctx, mdns.Fqdn(queryName))
+ if lookupError != nil {
+ logger.Debugf(LogPrefix, ForwardFailed, queryName, lookupError)
+ response.Rcode = mdns.RcodeNameError
+ return response
+ }
+
+ for _, mxRecord := range mxRecords {
+ response.Answer = append(response.Answer, &mdns.MX{
+ Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeMX, Class: mdns.ClassINET, Ttl: UpstreamResponseTTL},
+ Preference: mxRecord.Pref,
+ Mx: mdns.Fqdn(mxRecord.Host),
+ })
+ }
+
+ case mdns.TypeTXT:
+ txtRecords, lookupError := resolver.LookupTXT(ctx, mdns.Fqdn(queryName))
+ if lookupError != nil {
+ logger.Debugf(LogPrefix, ForwardFailed, queryName, lookupError)
+ response.Rcode = mdns.RcodeNameError
+ return response
+ }
+
+ if len(txtRecords) > 0 {
+ response.Answer = append(response.Answer, &mdns.TXT{
+ Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeTXT, Class: mdns.ClassINET, Ttl: UpstreamResponseTTL},
+ Txt: txtRecords,
+ })
+ }
+
+ case mdns.TypeSRV:
+ _, srvRecords, lookupError := resolver.LookupSRV(ctx, "", "", mdns.Fqdn(queryName))
+ if lookupError != nil {
+ logger.Debugf(LogPrefix, ForwardFailed, queryName, lookupError)
+ response.Rcode = mdns.RcodeNameError
+ return response
+ }
+
+ for _, srvRecord := range srvRecords {
+ response.Answer = append(response.Answer, &mdns.SRV{
+ Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeSRV, Class: mdns.ClassINET, Ttl: UpstreamResponseTTL},
+ Priority: srvRecord.Priority,
+ Weight: srvRecord.Weight,
+ Port: srvRecord.Port,
+ Target: mdns.Fqdn(srvRecord.Target),
+ })
+ }
+
+ case mdns.TypeNS:
+ nsRecords, lookupError := resolver.LookupNS(ctx, mdns.Fqdn(queryName))
+ if lookupError != nil {
+ logger.Debugf(LogPrefix, ForwardFailed, queryName, lookupError)
+ response.Rcode = mdns.RcodeNameError
+ return response
+ }
+
+ for _, nsRecord := range nsRecords {
+ response.Answer = append(response.Answer, &mdns.NS{
+ Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeNS, Class: mdns.ClassINET, Ttl: UpstreamResponseTTL},
+ Ns: mdns.Fqdn(nsRecord.Host),
+ })
+ }
+
+ default:
+ response.Rcode = mdns.RcodeNotImplemented
+ return response
+ }
+
+ if len(response.Answer) == 0 {
+ response.Rcode = mdns.RcodeNameError
+ }
+
+ logger.Debugf(LogPrefix, ForwardSuccess, queryName)
+ return response
+}