package dns

import (
	"context"
	"errors"
	"net"
	"strings"
	"time"

	"dove/utils/logger"

	mdns "github.com/miekg/dns"
)

// forwardToUpstream answers the first question in request by resolving it
// through a net.Resolver (the local resolver configuration, see the net
// package docs) and translating the result into DNS resource records.
//
// It returns nil when request carries no question. Otherwise it always
// returns a non-authoritative reply with RA set whose Rcode is:
//   - NOERROR with the resolved records, or with an empty answer section when
//     the lookup succeeded but yielded no record of the queried type (NODATA);
//   - NXDOMAIN when the resolver reports the name as not found;
//   - SERVFAIL for any other lookup failure (timeout, unreachable server, ...);
//   - NOTIMP for query types other than A, AAAA, CNAME, MX, TXT, SRV and NS.
//
// Every record carries the fixed UpstreamResponseTTL because the net package
// lookups do not expose upstream TTLs.
func forwardToUpstream(request *mdns.Msg) *mdns.Msg {
	if len(request.Question) == 0 {
		return nil
	}
	question := request.Question[0]
	queryName := question.Name
	lookupName := mdns.Fqdn(queryName)

	resolver := &net.Resolver{PreferGo: false}
	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(UpstreamTimeoutSeconds)*time.Second)
	defer cancel()

	response := &mdns.Msg{}
	response.SetReply(request)
	response.Authoritative = false
	response.RecursionAvailable = true

	var lookupError error
	switch question.Qtype {
	case mdns.TypeA:
		var addresses []net.IPAddr
		addresses, lookupError = resolver.LookupIPAddr(ctx, lookupName)
		for _, address := range addresses {
			if ipv4 := address.IP.To4(); ipv4 != nil {
				response.Answer = append(response.Answer, &mdns.A{
					Hdr: upstreamHeader(queryName, mdns.TypeA),
					A:   ipv4,
				})
			}
		}
	case mdns.TypeAAAA:
		var addresses []net.IPAddr
		addresses, lookupError = resolver.LookupIPAddr(ctx, lookupName)
		for _, address := range addresses {
			if address.IP.To4() == nil && address.IP.To16() != nil {
				response.Answer = append(response.Answer, &mdns.AAAA{
					Hdr:  upstreamHeader(queryName, mdns.TypeAAAA),
					AAAA: address.IP.To16(),
				})
			}
		}
	case mdns.TypeCNAME:
		var canonicalName string
		canonicalName, lookupError = resolver.LookupCNAME(ctx, lookupName)
		// LookupCNAME returns the queried name itself when the name has no
		// CNAME but does have address records. Answering with that would be a
		// self-referential CNAME loop, so leave the answer empty (NODATA).
		if lookupError == nil && !strings.EqualFold(mdns.Fqdn(canonicalName), lookupName) {
			response.Answer = append(response.Answer, &mdns.CNAME{
				Hdr:    upstreamHeader(queryName, mdns.TypeCNAME),
				Target: mdns.Fqdn(canonicalName),
			})
		}
	case mdns.TypeMX:
		var mxRecords []*net.MX
		mxRecords, lookupError = resolver.LookupMX(ctx, lookupName)
		for _, mxRecord := range mxRecords {
			response.Answer = append(response.Answer, &mdns.MX{
				Hdr:        upstreamHeader(queryName, mdns.TypeMX),
				Preference: mxRecord.Pref,
				Mx:         mdns.Fqdn(mxRecord.Host),
			})
		}
	case mdns.TypeTXT:
		var txtRecords []string
		txtRecords, lookupError = resolver.LookupTXT(ctx, lookupName)
		// Each element is one upstream TXT record. Emit one RR per record so
		// unrelated records (SPF, verification tokens, ...) are not merged.
		for _, txtRecord := range txtRecords {
			response.Answer = append(response.Answer, &mdns.TXT{
				Hdr: upstreamHeader(queryName, mdns.TypeTXT),
				Txt: splitCharacterStrings(txtRecord),
			})
		}
	case mdns.TypeSRV:
		var srvRecords []*net.SRV
		// Empty service and proto make LookupSRV query lookupName directly.
		_, srvRecords, lookupError = resolver.LookupSRV(ctx, "", "", lookupName)
		for _, srvRecord := range srvRecords {
			response.Answer = append(response.Answer, &mdns.SRV{
				Hdr:      upstreamHeader(queryName, mdns.TypeSRV),
				Priority: srvRecord.Priority,
				Weight:   srvRecord.Weight,
				Port:     srvRecord.Port,
				Target:   mdns.Fqdn(srvRecord.Target),
			})
		}
	case mdns.TypeNS:
		var nsRecords []*net.NS
		nsRecords, lookupError = resolver.LookupNS(ctx, lookupName)
		for _, nsRecord := range nsRecords {
			response.Answer = append(response.Answer, &mdns.NS{
				Hdr: upstreamHeader(queryName, mdns.TypeNS),
				Ns:  mdns.Fqdn(nsRecord.Host),
			})
		}
	default:
		response.Rcode = mdns.RcodeNotImplemented
		return response
	}

	if lookupError != nil {
		logger.Debugf(LogPrefix, ForwardFailed, queryName, lookupError)
		// Drop any partial results the resolver returned alongside the error.
		response.Answer = nil
		response.Rcode = upstreamErrorRcode(lookupError)
		return response
	}

	// An empty answer here means the lookup succeeded but produced no record
	// of the queried type: that is NODATA (NOERROR), not NXDOMAIN.
	logger.Debugf(LogPrefix, ForwardSuccess, queryName)
	return response
}

// upstreamHeader builds the RR header shared by every forwarded record: the
// question's owner name, class IN and the fixed UpstreamResponseTTL.
func upstreamHeader(name string, rrtype uint16) mdns.RR_Header {
	return mdns.RR_Header{Name: name, Rrtype: rrtype, Class: mdns.ClassINET, Ttl: UpstreamResponseTTL}
}

// upstreamErrorRcode maps a resolver error to the Rcode of the reply.
//
// Only a "not found" result becomes NXDOMAIN. Timeouts, unreachable or
// misbehaving upstreams and any other failure become SERVFAIL. Reporting
// those as NXDOMAIN would tell clients, which may cache negative answers,
// that an existing name does not exist.
//
// NOTE(review): the net package may report "no records of this type" as
// IsNotFound as well, so such NODATA cases can still surface as NXDOMAIN.
func upstreamErrorRcode(err error) int {
	var dnsError *net.DNSError
	if errors.As(err, &dnsError) && dnsError.IsNotFound {
		return mdns.RcodeNameError
	}
	return mdns.RcodeServerFailure
}

// splitCharacterStrings splits one TXT record value into pieces of at most
// 255 bytes, the largest payload a single DNS <character-string> can hold
// (RFC 1035 section 3.3). An empty value yields a single empty string.
func splitCharacterStrings(value string) []string {
	const maxLength = 255
	pieces := make([]string, 0, len(value)/maxLength+1)
	for len(value) > maxLength {
		pieces = append(pieces, value[:maxLength])
		value = value[maxLength:]
	}
	return append(pieces, value)
}