Skip to content

Commit

Permalink
feat: improve errors and logs related to DNS call (#2109)
Browse files Browse the repository at this point in the history
  • Loading branch information
ldez authored Feb 11, 2024
1 parent 7fe1796 commit ba67a26
Show file tree
Hide file tree
Showing 87 changed files with 314 additions and 182 deletions.
3 changes: 2 additions & 1 deletion challenge/dns01/dns_challenge.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"os"
"strconv"
"strings"
"time"

"github.com/go-acme/lego/v4/acme"
Expand Down Expand Up @@ -124,7 +125,7 @@ func (c *Challenge) Solve(authz acme.Authorization) error {
timeout, interval = DefaultPropagationTimeout, DefaultPollingInterval
}

log.Infof("[%s] acme: Checking DNS record propagation using %+v", domain, recursiveNameservers)
log.Infof("[%s] acme: Checking DNS record propagation. [nameservers=%s]", domain, strings.Join(recursiveNameservers, ","))

time.Sleep(interval)

Expand Down
9 changes: 6 additions & 3 deletions challenge/dns01/dns_challenge_manual.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,19 @@ func (*DNSProviderManual) Present(domain, token, keyAuth string) error {

authZone, err := FindZoneByFqdn(info.EffectiveFQDN)
if err != nil {
return err
return fmt.Errorf("manual: could not find zone: %w", err)
}

fmt.Printf("lego: Please create the following TXT record in your %s zone:\n", authZone)
fmt.Printf(dnsTemplate+"\n", info.EffectiveFQDN, DefaultTTL, info.Value)
fmt.Printf("lego: Press 'Enter' when you are done\n")

_, err = bufio.NewReader(os.Stdin).ReadBytes('\n')
if err != nil {
return fmt.Errorf("manual: %w", err)
}

return err
return nil
}

// CleanUp prints instructions for manually removing the TXT record.
Expand All @@ -43,7 +46,7 @@ func (*DNSProviderManual) CleanUp(domain, token, keyAuth string) error {

authZone, err := FindZoneByFqdn(info.EffectiveFQDN)
if err != nil {
return err
return fmt.Errorf("manual: could not find zone: %w", err)
}

fmt.Printf("lego: You can now remove this TXT record from your %s zone:\n", authZone)
Expand Down
124 changes: 91 additions & 33 deletions challenge/dns01/nameserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,12 @@ func lookupNameservers(fqdn string) ([]string, error) {

zone, err := FindZoneByFqdn(fqdn)
if err != nil {
return nil, fmt.Errorf("could not determine the zone: %w", err)
return nil, fmt.Errorf("could not find zone: %w", err)
}

r, err := dnsQuery(zone, dns.TypeNS, recursiveNameservers, true)
if err != nil {
return nil, err
return nil, fmt.Errorf("NS call failed: %w", err)
}

for _, rr := range r.Answer {
Expand All @@ -116,7 +116,8 @@ func lookupNameservers(fqdn string) ([]string, error) {
if len(authoritativeNss) > 0 {
return authoritativeNss, nil
}
return nil, errors.New("could not determine authoritative nameservers")

return nil, fmt.Errorf("[zone=%s] could not determine authoritative nameservers", zone)
}

// FindPrimaryNsByFqdn determines the primary nameserver of the zone apex for the given fqdn
Expand All @@ -130,7 +131,7 @@ func FindPrimaryNsByFqdn(fqdn string) (string, error) {
func FindPrimaryNsByFqdnCustom(fqdn string, nameservers []string) (string, error) {
soa, err := lookupSoaByFqdn(fqdn, nameservers)
if err != nil {
return "", err
return "", fmt.Errorf("[fqdn=%s] %w", fqdn, err)
}
return soa.primaryNs, nil
}
Expand All @@ -146,7 +147,7 @@ func FindZoneByFqdn(fqdn string) (string, error) {
func FindZoneByFqdnCustom(fqdn string, nameservers []string) (string, error) {
soa, err := lookupSoaByFqdn(fqdn, nameservers)
if err != nil {
return "", err
return "", fmt.Errorf("[fqdn=%s] %w", fqdn, err)
}
return soa.zone, nil
}
Expand All @@ -171,35 +172,35 @@ func lookupSoaByFqdn(fqdn string, nameservers []string) (*soaCacheEntry, error)

func fetchSoaByFqdn(fqdn string, nameservers []string) (*soaCacheEntry, error) {
var err error
var in *dns.Msg
var r *dns.Msg

labelIndexes := dns.Split(fqdn)
for _, index := range labelIndexes {
domain := fqdn[index:]

in, err = dnsQuery(domain, dns.TypeSOA, nameservers, true)
r, err = dnsQuery(domain, dns.TypeSOA, nameservers, true)
if err != nil {
continue
}

if in == nil {
if r == nil {
continue
}

switch in.Rcode {
switch r.Rcode {
case dns.RcodeSuccess:
// Check if we got a SOA RR in the answer section
if len(in.Answer) == 0 {
if len(r.Answer) == 0 {
continue
}

// CNAME records cannot/should not exist at the root of a zone.
// So we skip a domain when a CNAME is found.
if dnsMsgContainsCNAME(in) {
if dnsMsgContainsCNAME(r) {
continue
}

for _, ans := range in.Answer {
for _, ans := range r.Answer {
if soa, ok := ans.(*dns.SOA); ok {
return newSoaCacheEntry(soa), nil
}
Expand All @@ -208,11 +209,11 @@ func fetchSoaByFqdn(fqdn string, nameservers []string) (*soaCacheEntry, error) {
// NXDOMAIN
default:
// Any response code other than NOERROR and NXDOMAIN is treated as error
return nil, fmt.Errorf("unexpected response code '%s' for %s", dns.RcodeToString[in.Rcode], domain)
return nil, &DNSError{Message: fmt.Sprintf("unexpected response for '%s'", domain), MsgOut: r}
}
}

return nil, fmt.Errorf("could not find the start of authority for %s%s", fqdn, formatDNSError(in, err))
return nil, &DNSError{Message: fmt.Sprintf("could not find the start of authority for '%s'", fqdn), MsgOut: r, Err: err}
}

// dnsMsgContainsCNAME checks for a CNAME answer in msg.
Expand All @@ -226,16 +227,28 @@ func dnsMsgContainsCNAME(msg *dns.Msg) bool {
func dnsQuery(fqdn string, rtype uint16, nameservers []string, recursive bool) (*dns.Msg, error) {
m := createDNSMsg(fqdn, rtype, recursive)

var in *dns.Msg
if len(nameservers) == 0 {
return nil, &DNSError{Message: "empty list of nameservers"}
}

var r *dns.Msg
var err error
var errAll error

for _, ns := range nameservers {
in, err = sendDNSQuery(m, ns)
if err == nil && len(in.Answer) > 0 {
r, err = sendDNSQuery(m, ns)
if err == nil && len(r.Answer) > 0 {
break
}

errAll = errors.Join(errAll, err)
}

if err != nil {
return r, errAll
}
return in, err

return r, nil
}

func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg {
Expand All @@ -253,37 +266,82 @@ func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg {
func sendDNSQuery(m *dns.Msg, ns string) (*dns.Msg, error) {
if ok, _ := strconv.ParseBool(os.Getenv("LEGO_EXPERIMENTAL_DNS_TCP_ONLY")); ok {
tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout}
in, _, err := tcp.Exchange(m, ns)
r, _, err := tcp.Exchange(m, ns)
if err != nil {
return r, &DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}
}

return in, err
return r, nil
}

udp := &dns.Client{Net: "udp", Timeout: dnsTimeout}
in, _, err := udp.Exchange(m, ns)
r, _, err := udp.Exchange(m, ns)

if in != nil && in.Truncated {
if r != nil && r.Truncated {
tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout}
// If the TCP request succeeds, the "err" will reset to nil
in, _, err = tcp.Exchange(m, ns)
r, _, err = tcp.Exchange(m, ns)
}

if err != nil {
return r, &DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}
}

return in, err
return r, nil
}

func formatDNSError(msg *dns.Msg, err error) string {
var parts []string
// DNSError error related to DNS calls.
type DNSError struct {
Message string
NS string
MsgIn *dns.Msg
MsgOut *dns.Msg
Err error
}

if msg != nil {
parts = append(parts, dns.RcodeToString[msg.Rcode])
func (d *DNSError) Error() string {
var details []string
if d.NS != "" {
details = append(details, "ns="+d.NS)
}

if err != nil {
parts = append(parts, err.Error())
if d.MsgIn != nil && len(d.MsgIn.Question) > 0 {
details = append(details, fmt.Sprintf("question='%s'", formatQuestions(d.MsgIn.Question)))
}

if d.MsgOut != nil {
if d.MsgIn == nil || len(d.MsgIn.Question) == 0 {
details = append(details, fmt.Sprintf("question='%s'", formatQuestions(d.MsgOut.Question)))
}

details = append(details, "code="+dns.RcodeToString[d.MsgOut.Rcode])
}

msg := "DNS error"
if d.Message != "" {
msg = d.Message
}

if d.Err != nil {
msg += ": " + d.Err.Error()
}

if len(details) > 0 {
msg += " [" + strings.Join(details, ", ") + "]"
}

if len(parts) > 0 {
return ": " + strings.Join(parts, " ")
return msg
}

func (d *DNSError) Unwrap() error {
return d.Err
}

func formatQuestions(questions []dns.Question) string {
var parts []string
for _, question := range questions {
parts = append(parts, strings.ReplaceAll(strings.TrimPrefix(question.String(), ";"), "\t", " "))
}

return ""
return strings.Join(parts, ";")
}
Loading

0 comments on commit ba67a26

Please sign in to comment.