diff options
Diffstat (limited to 'utils/dns')
| -rw-r--r-- | utils/dns/defaults.go | 7 | ||||
| -rw-r--r-- | utils/dns/messages.go | 11 | ||||
| -rw-r--r-- | utils/dns/server.go | 153 | ||||
| -rw-r--r-- | utils/dns/upstream.go | 153 |
4 files changed, 324 insertions, 0 deletions
diff --git a/utils/dns/defaults.go b/utils/dns/defaults.go new file mode 100644 index 0000000..dc9824a --- /dev/null +++ b/utils/dns/defaults.go @@ -0,0 +1,7 @@ +package dns + +const ( + LogPrefix = "DNS" + UpstreamTimeoutSeconds = 5 + UpstreamResponseTTL = 60 +) diff --git a/utils/dns/messages.go b/utils/dns/messages.go new file mode 100644 index 0000000..a73c8de --- /dev/null +++ b/utils/dns/messages.go @@ -0,0 +1,11 @@ +package dns + +const ( + ForwardFailed = "Failed to forward query for %s to upstream: %v" + ForwardSuccess = "Forwarded query for %s to upstream." + ListenFailed = "Failed to start DNS listener: %v" + QueryReceived = "Query: %s %s" + ServerStarting = "DNS server started on %s (UDP)." + ShutdownComplete = "DNS server stopped." + ShutdownFailed = "Failed to shutdown DNS server: %v" +) diff --git a/utils/dns/server.go b/utils/dns/server.go new file mode 100644 index 0000000..7ae8770 --- /dev/null +++ b/utils/dns/server.go @@ -0,0 +1,153 @@ +package dns + +import ( + "fmt" + "net" + + "dove/config" + dnsRepo "dove/repositories/dns" + "dove/utils/logger" + + mdns "github.com/miekg/dns" +) + +var activeServer *mdns.Server + +func Start() { + address := fmt.Sprintf("%s:%d", config.DNS.Host, config.DNS.Port) + + mdns.HandleFunc(".", handleQuery) + + activeServer = &mdns.Server{ + Addr: address, + Net: "udp", + } + + go func() { + logger.Successf(LogPrefix, ServerStarting, address) + + if listenError := activeServer.ListenAndServe(); listenError != nil { + logger.Fatalf(LogPrefix, ListenFailed, listenError) + } + }() +} + +func Shutdown() { + if activeServer == nil { + return + } + + if shutdownError := activeServer.Shutdown(); shutdownError != nil { + logger.Errorf(LogPrefix, ShutdownFailed, shutdownError) + } + + logger.Infof(LogPrefix, ShutdownComplete) +} + +func handleQuery(writer mdns.ResponseWriter, request *mdns.Msg) { + if len(request.Question) == 0 { + return + } + + question := request.Question[0] + queryName := question.Name + queryType := mdns.TypeToString[question.Qtype] + + logger.Debugf(LogPrefix, QueryReceived, queryType, queryName) + + if dnsRepo.IsLocalDomain(queryName) { + response := resolveLocal(request) + writer.WriteMsg(response) + return + } + + upstreamResponse := forwardToUpstream(request) + if upstreamResponse != nil { + writer.WriteMsg(upstreamResponse) + return + } + + refusedResponse := &mdns.Msg{} + refusedResponse.SetRcode(request, mdns.RcodeServerFailure) + writer.WriteMsg(refusedResponse) +} + +func resolveLocal(request *mdns.Msg) *mdns.Msg { + question := request.Question[0] + queryName := question.Name + + response := &mdns.Msg{} + response.SetReply(request) + response.Authoritative = true + response.RecursionAvailable = true + + switch question.Qtype { + case mdns.TypeA: + for _, record := range dnsRepo.ResolveA(queryName) { + parsedIP := net.ParseIP(record.Address) + if parsedIP == nil || parsedIP.To4() == nil { + continue + } + + response.Answer = append(response.Answer, &mdns.A{ + Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeA, Class: mdns.ClassINET, Ttl: record.TTL}, + A: parsedIP.To4(), + }) + } + + case mdns.TypeAAAA: + for _, record := range dnsRepo.ResolveAAAA(queryName) { + parsedIP := net.ParseIP(record.Address) + if parsedIP == nil || parsedIP.To16() == nil { + continue + } + + response.Answer = append(response.Answer, &mdns.AAAA{ + Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeAAAA, Class: mdns.ClassINET, Ttl: record.TTL}, + AAAA: parsedIP.To16(), + }) + } + + case mdns.TypeCNAME: + for _, record := range dnsRepo.ResolveCNAME(queryName) { + response.Answer = append(response.Answer, &mdns.CNAME{ + Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeCNAME, Class: mdns.ClassINET, Ttl: record.TTL}, + Target: mdns.Fqdn(record.Target), + }) + } + + case mdns.TypeMX: + for _, record := range dnsRepo.ResolveMX(queryName) { + response.Answer = append(response.Answer, &mdns.MX{ + Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeMX, Class: mdns.ClassINET, Ttl: record.TTL}, + Preference: record.Priority, + Mx: mdns.Fqdn(record.Target), + }) + } + + case mdns.TypeTXT: + for _, record := range dnsRepo.ResolveTXT(queryName) { + response.Answer = append(response.Answer, &mdns.TXT{ + Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeTXT, Class: mdns.ClassINET, Ttl: record.TTL}, + Txt: []string{record.Content}, + }) + } + + case mdns.TypeSRV: + for _, record := range dnsRepo.ResolveSRV(queryName) { + response.Answer = append(response.Answer, &mdns.SRV{ + Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeSRV, Class: mdns.ClassINET, Ttl: record.TTL}, + Priority: record.Priority, + Weight: record.Weight, + Port: record.Port, + Target: mdns.Fqdn(record.Target), + }) + } + } + + if len(response.Answer) == 0 { + response.Rcode = mdns.RcodeNameError + } + + return response +} diff --git a/utils/dns/upstream.go b/utils/dns/upstream.go new file mode 100644 index 0000000..e2af583 --- /dev/null +++ b/utils/dns/upstream.go @@ -0,0 +1,153 @@ +package dns + +import ( + "context" + "net" + "time" + + "dove/utils/logger" + + mdns "github.com/miekg/dns" +) + +func forwardToUpstream(request *mdns.Msg) *mdns.Msg { + if len(request.Question) == 0 { + return nil + } + + question := request.Question[0] + queryName := question.Name + + resolver := &net.Resolver{PreferGo: false} + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(UpstreamTimeoutSeconds)*time.Second) + defer cancel() + + response := &mdns.Msg{} + response.SetReply(request) + response.Authoritative = false + response.RecursionAvailable = true + + switch question.Qtype { + case mdns.TypeA: + addresses, lookupError := resolver.LookupIPAddr(ctx, mdns.Fqdn(queryName)) + if lookupError != nil { + logger.Debugf(LogPrefix, ForwardFailed, queryName, lookupError) + response.Rcode = mdns.RcodeNameError + return response + } + + for _, address := range addresses { + if ipv4 := address.IP.To4(); ipv4 != nil { + response.Answer = append(response.Answer, &mdns.A{ + Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeA, Class: mdns.ClassINET, Ttl: UpstreamResponseTTL}, + A: ipv4, + }) + } + } + + case mdns.TypeAAAA: + addresses, lookupError := resolver.LookupIPAddr(ctx, mdns.Fqdn(queryName)) + if lookupError != nil { + logger.Debugf(LogPrefix, ForwardFailed, queryName, lookupError) + response.Rcode = mdns.RcodeNameError + return response + } + + for _, address := range addresses { + if address.IP.To4() == nil && address.IP.To16() != nil { + response.Answer = append(response.Answer, &mdns.AAAA{ + Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeAAAA, Class: mdns.ClassINET, Ttl: UpstreamResponseTTL}, + AAAA: address.IP.To16(), + }) + } + } + + case mdns.TypeCNAME: + canonicalName, lookupError := resolver.LookupCNAME(ctx, mdns.Fqdn(queryName)) + if lookupError != nil { + logger.Debugf(LogPrefix, ForwardFailed, queryName, lookupError) + response.Rcode = mdns.RcodeNameError + return response + } + + response.Answer = append(response.Answer, &mdns.CNAME{ + Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeCNAME, Class: mdns.ClassINET, Ttl: UpstreamResponseTTL}, + Target: mdns.Fqdn(canonicalName), + }) + + case mdns.TypeMX: + mxRecords, lookupError := resolver.LookupMX(ctx, mdns.Fqdn(queryName)) + if lookupError != nil { + logger.Debugf(LogPrefix, ForwardFailed, queryName, lookupError) + response.Rcode = mdns.RcodeNameError + return response + } + + for _, mxRecord := range mxRecords { + response.Answer = append(response.Answer, &mdns.MX{ + Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeMX, Class: mdns.ClassINET, Ttl: UpstreamResponseTTL}, + Preference: mxRecord.Pref, + Mx: mdns.Fqdn(mxRecord.Host), + }) + } + + case mdns.TypeTXT: + txtRecords, lookupError := resolver.LookupTXT(ctx, mdns.Fqdn(queryName)) + if lookupError != nil { + logger.Debugf(LogPrefix, ForwardFailed, queryName, lookupError) + response.Rcode = mdns.RcodeNameError + return response + } + + if len(txtRecords) > 0 { + response.Answer = append(response.Answer, &mdns.TXT{ + Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeTXT, Class: mdns.ClassINET, Ttl: UpstreamResponseTTL}, + Txt: txtRecords, + }) + } + + case mdns.TypeSRV: + _, srvRecords, lookupError := resolver.LookupSRV(ctx, "", "", mdns.Fqdn(queryName)) + if lookupError != nil { + logger.Debugf(LogPrefix, ForwardFailed, queryName, lookupError) + response.Rcode = mdns.RcodeNameError + return response + } + + for _, srvRecord := range srvRecords { + response.Answer = append(response.Answer, &mdns.SRV{ + Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeSRV, Class: mdns.ClassINET, Ttl: UpstreamResponseTTL}, + Priority: srvRecord.Priority, + Weight: srvRecord.Weight, + Port: srvRecord.Port, + Target: mdns.Fqdn(srvRecord.Target), + }) + } + + case mdns.TypeNS: + nsRecords, lookupError := resolver.LookupNS(ctx, mdns.Fqdn(queryName)) + if lookupError != nil { + logger.Debugf(LogPrefix, ForwardFailed, queryName, lookupError) + response.Rcode = mdns.RcodeNameError + return response + } + + for _, nsRecord := range nsRecords { + response.Answer = append(response.Answer, &mdns.NS{ + Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeNS, Class: mdns.ClassINET, Ttl: UpstreamResponseTTL}, + Ns: mdns.Fqdn(nsRecord.Host), + }) + } + + default: + response.Rcode = mdns.RcodeNotImplemented + return response + } + + if len(response.Answer) == 0 { + response.Rcode = mdns.RcodeNameError + } + + logger.Debugf(LogPrefix, ForwardSuccess, queryName) + return response +} |
