package dns import ( "fmt" "net" "dove/config" dnsRepo "dove/repositories/dns" "dove/utils/logger" mdns "github.com/miekg/dns" ) var activeServer *mdns.Server func Start() { address := fmt.Sprintf("%s:%d", config.BindAddress, config.DnsPort) mdns.HandleFunc(".", handleQuery) activeServer = &mdns.Server{ Addr: address, Net: "udp", } go func() { logger.Successf(LogPrefix, ServerStarting, address) if listenError := activeServer.ListenAndServe(); listenError != nil { logger.Fatalf(LogPrefix, ListenFailed, listenError) } }() } func Shutdown() { if activeServer == nil { return } if shutdownError := activeServer.Shutdown(); shutdownError != nil { logger.Errorf(LogPrefix, ShutdownFailed, shutdownError) } logger.Infof(LogPrefix, ShutdownComplete) } func handleQuery(writer mdns.ResponseWriter, request *mdns.Msg) { if len(request.Question) == 0 { return } question := request.Question[0] queryName := question.Name queryType := mdns.TypeToString[question.Qtype] logger.Debugf(LogPrefix, QueryReceived, queryType, queryName) if dnsRepo.IsLocalDomain(queryName) { response := resolveLocal(request) writer.WriteMsg(response) return } upstreamResponse := forwardToUpstream(request) if upstreamResponse != nil { writer.WriteMsg(upstreamResponse) return } refusedResponse := &mdns.Msg{} refusedResponse.SetRcode(request, mdns.RcodeServerFailure) writer.WriteMsg(refusedResponse) } func resolveLocal(request *mdns.Msg) *mdns.Msg { question := request.Question[0] queryName := question.Name response := &mdns.Msg{} response.SetReply(request) response.Authoritative = true response.RecursionAvailable = true switch question.Qtype { case mdns.TypeA: for _, record := range dnsRepo.ResolveA(queryName) { parsedIP := net.ParseIP(record.Address) if parsedIP == nil || parsedIP.To4() == nil { continue } response.Answer = append(response.Answer, &mdns.A{ Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeA, Class: mdns.ClassINET, Ttl: record.TTL}, A: parsedIP.To4(), }) } case mdns.TypeAAAA: for _, record := range dnsRepo.ResolveAAAA(queryName) { parsedIP := net.ParseIP(record.Address) if parsedIP == nil || parsedIP.To16() == nil { continue } response.Answer = append(response.Answer, &mdns.AAAA{ Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeAAAA, Class: mdns.ClassINET, Ttl: record.TTL}, AAAA: parsedIP.To16(), }) } case mdns.TypeCNAME: for _, record := range dnsRepo.ResolveCNAME(queryName) { response.Answer = append(response.Answer, &mdns.CNAME{ Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeCNAME, Class: mdns.ClassINET, Ttl: record.TTL}, Target: mdns.Fqdn(record.Target), }) } case mdns.TypeMX: for _, record := range dnsRepo.ResolveMX(queryName) { response.Answer = append(response.Answer, &mdns.MX{ Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeMX, Class: mdns.ClassINET, Ttl: record.TTL}, Preference: record.Priority, Mx: mdns.Fqdn(record.Target), }) } case mdns.TypeTXT: for _, record := range dnsRepo.ResolveTXT(queryName) { response.Answer = append(response.Answer, &mdns.TXT{ Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeTXT, Class: mdns.ClassINET, Ttl: record.TTL}, Txt: []string{record.Content}, }) } case mdns.TypeSRV: for _, record := range dnsRepo.ResolveSRV(queryName) { response.Answer = append(response.Answer, &mdns.SRV{ Hdr: mdns.RR_Header{Name: queryName, Rrtype: mdns.TypeSRV, Class: mdns.ClassINET, Ttl: record.TTL}, Priority: record.Priority, Weight: record.Weight, Port: record.Port, Target: mdns.Fqdn(record.Target), }) } } if len(response.Answer) == 0 { response.Rcode = mdns.RcodeNameError } return response }