aboutsummaryrefslogtreecommitdiff
path: root/utils/dns/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'utils/dns/server.go')
-rw-r--r--utils/dns/server.go153
1 files changed, 153 insertions, 0 deletions
diff --git a/utils/dns/server.go b/utils/dns/server.go
new file mode 100644
index 0000000..7ae8770
--- /dev/null
+++ b/utils/dns/server.go
@@ -0,0 +1,153 @@
+package dns
+
+import (
+ "fmt"
+ "net"
+
+ "dove/config"
+ dnsRepo "dove/repositories/dns"
+ "dove/utils/logger"
+
+ mdns "github.com/miekg/dns"
+)
+
+var activeServer *mdns.Server
+
+func Start() {
+ address := fmt.Sprintf("%s:%d", config.DNS.Host, config.DNS.Port)
+
+ mdns.HandleFunc(".", handleQuery)
+
+ activeServer = &mdns.Server{
+ Addr: address,
+ Net: "udp",
+ }
+
+ go func() {
+ logger.Successf(LogPrefix, ServerStarting, address)
+
+ if listenError := activeServer.ListenAndServe(); listenError != nil {
+ logger.Fatalf(LogPrefix, ListenFailed, listenError)
+ }
+ }()
+}
+
+func Shutdown() {
+ if activeServer == nil {
+ return
+ }
+
+ if shutdownError := activeServer.Shutdown(); shutdownError != nil {
+ logger.Errorf(LogPrefix, ShutdownFailed, shutdownError)
+ }
+
+ logger.Infof(LogPrefix, ShutdownComplete)
+}
+
+func handleQuery(writer mdns.ResponseWriter, request *mdns.Msg) {
+ if len(request.Question) == 0 {
+ return
+ }
+
+ question := request.Question[0]
+ queryName := question.Name
+ queryType := mdns.TypeToString[question.Qtype]
+
+ logger.Debugf(LogPrefix, QueryReceived, queryType, queryName)
+
+ if dnsRepo.IsLocalDomain(queryName) {
+ response := resolveLocal(request)
+ writer.WriteMsg(response)
+ return
+ }
+
+ upstreamResponse := forwardToUpstream(request)
+ if upstreamResponse != nil {
+ writer.WriteMsg(upstreamResponse)
+ return
+ }
+
+ refusedResponse := &mdns.Msg{}
+ refusedResponse.SetRcode(request, mdns.RcodeServerFailure)
+ writer.WriteMsg(refusedResponse)
+}
+
+func resolveLocal(request *mdns.Msg) *mdns.Msg {
+ question := request.Question[0]
+ queryName := question.Name
+
+ response := &mdns.Msg{}
+ response.SetReply(request)
+ response.Authoritative = true
+ response.RecursionAvailable = true
+
+ switch question.Qtype {
+ case mdns.TypeA:
+ for _, record := range dnsRepo.ResolveA(queryName) {
+ parsedIP := net.ParseIP(record.Address)
+ if parsedIP == nil || parsedIP.To4() == nil {
+ continue
+ }
+
+ response.Answer = append(response.Answer, &mdns.A{
+ Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeA, Class: mdns.ClassINET, Ttl: record.TTL},
+ A: parsedIP.To4(),
+ })
+ }
+
+ case mdns.TypeAAAA:
+ for _, record := range dnsRepo.ResolveAAAA(queryName) {
+ parsedIP := net.ParseIP(record.Address)
+ if parsedIP == nil || parsedIP.To16() == nil {
+ continue
+ }
+
+ response.Answer = append(response.Answer, &mdns.AAAA{
+ Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeAAAA, Class: mdns.ClassINET, Ttl: record.TTL},
+ AAAA: parsedIP.To16(),
+ })
+ }
+
+ case mdns.TypeCNAME:
+ for _, record := range dnsRepo.ResolveCNAME(queryName) {
+ response.Answer = append(response.Answer, &mdns.CNAME{
+ Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeCNAME, Class: mdns.ClassINET, Ttl: record.TTL},
+ Target: mdns.Fqdn(record.Target),
+ })
+ }
+
+ case mdns.TypeMX:
+ for _, record := range dnsRepo.ResolveMX(queryName) {
+ response.Answer = append(response.Answer, &mdns.MX{
+ Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeMX, Class: mdns.ClassINET, Ttl: record.TTL},
+ Preference: record.Priority,
+ Mx: mdns.Fqdn(record.Target),
+ })
+ }
+
+ case mdns.TypeTXT:
+ for _, record := range dnsRepo.ResolveTXT(queryName) {
+ response.Answer = append(response.Answer, &mdns.TXT{
+ Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeTXT, Class: mdns.ClassINET, Ttl: record.TTL},
+ Txt: []string{record.Content},
+ })
+ }
+
+ case mdns.TypeSRV:
+ for _, record := range dnsRepo.ResolveSRV(queryName) {
+ response.Answer = append(response.Answer, &mdns.SRV{
+ Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeSRV, Class: mdns.ClassINET, Ttl: record.TTL},
+ Priority: record.Priority,
+ Weight: record.Weight,
+ Port: record.Port,
+ Target: mdns.Fqdn(record.Target),
+ })
+ }
+ }
+
+ if len(response.Answer) == 0 {
+ response.Rcode = mdns.RcodeNameError
+ }
+
+ return response
+}