Skip to content

Commit

Permalink
Improve documentation for address parameters
Browse files Browse the repository at this point in the history
Improve internal comments.
Add more hostport address unit test cases.
Simplify some symmetric auth code paths.
  • Loading branch information
beevik committed Jul 26, 2023
1 parent 196f2d4 commit 4137e12
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 43 deletions.
12 changes: 12 additions & 0 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ func xor(dst, src []byte) {
}

func decodeAuthKey(opt AuthOptions) ([]byte, error) {
if opt.Type == AuthNone {
return nil, nil
}

var key []byte
if len(opt.Key) > 20 {
var err error
Expand All @@ -177,6 +181,10 @@ func decodeAuthKey(opt AuthOptions) ([]byte, error) {
}

func appendMAC(buf *bytes.Buffer, opt AuthOptions, key []byte) {
if opt.Type == AuthNone {
return
}

a := algorithms[opt.Type]
payload := buf.Bytes()
digest := a.CalcDigest(payload, key)
Expand All @@ -185,6 +193,10 @@ func appendMAC(buf *bytes.Buffer, opt AuthOptions, key []byte) {
}

func verifyMAC(buf []byte, opt AuthOptions, key []byte) error {
if opt.Type == AuthNone {
return nil
}

// Validate that there are enough bytes at the end of the message to
// contain a MAC.
const headerSize = 48
Expand Down
80 changes: 37 additions & 43 deletions ntp.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,18 +406,21 @@ func (r *Response) Validate() error {
return nil
}

// Query requests time data from a remote NTP server. The server address is of
// the form "host", "host:port", "host%zone:port", "[host]:port" or
// "[host%zone]:port". If no port is included, NTP default port 123 is used.
// The response contains information from which an accurate local time can be
// determined.
// Query requests time data from a remote NTP server. The response contains
// information from which a more accurate local time can be inferred.
//
// The server address is of the form "host", "host:port", "host%zone:port",
// "[host]:port" or "[host%zone]:port". The host may contain an IPv4, IPv6 or
// domain name address. When specifying both a port and an IPv6 address, one
// of the bracket formats must be used. If no port is included, NTP default
// port 123 is used.
func Query(address string) (*Response, error) {
return QueryWithOptions(address, QueryOptions{})
}

// QueryWithOptions performs the same function as Query but allows for the
// customization of certain query behaviors. See the comment for Query for
// more information.
// customization of certain query behaviors. See the comments for Query and
// QueryOptions for further details.
func QueryWithOptions(address string, opt QueryOptions) (*Response, error) {
h, now, err := getTime(address, &opt)
if err != nil && err != ErrAuthFailed {
Expand All @@ -427,11 +430,15 @@ func QueryWithOptions(address string, opt QueryOptions) (*Response, error) {
return generateResponse(h, now, err), nil
}

// Time returns the current local time using information returned from the
// remote NTP server. The server address is of the form "host", "host:port",
// "host%zone:port", "[host]:port" or "[host%zone]:port". If no port is
// included, NTP default port 123 is used. On error, Time returns the local
// Time returns the current, corrected local time using information returned
// from the remote NTP server. On error, Time returns the uncorrected local
// system time.
//
// The server address is of the form "host", "host:port", "host%zone:port",
// "[host]:port" or "[host%zone]:port". The host may contain an IPv4, IPv6 or
// domain name address. When specifying both a port and an IPv6 address, one
// of the bracket formats must be used. If no port is included, NTP default
// port 123 is used.
func Time(address string) (time.Time, error) {
r, err := Query(address)
if err != nil {
Expand Down Expand Up @@ -472,8 +479,8 @@ func getTime(address string, opt *QueryOptions) (*header, ntpTime, error) {
opt.Dialer = defaultDialer
}

// Compose a remote "host:port" address string if the address string
// doesn't already contain a port.
// Compose a conforming host:port remote address string if the address
// string doesn't already contain a port.
remoteAddress, err := fixHostPort(address, opt.Port)
if err != nil {
return nil, 0, err
Expand All @@ -495,16 +502,6 @@ func getTime(address string, opt *QueryOptions) (*header, ntpTime, error) {
}
}

// If using symmetric key authentication, decode and validate the auth key
// string.
var decodedAuthKey []byte
if opt.Auth.Type != AuthNone {
decodedAuthKey, err = decodeAuthKey(opt.Auth)
if err != nil {
return nil, 0, err
}
}

// Set a timeout on the connection.
con.SetDeadline(time.Now().Add(opt.Timeout))

Expand Down Expand Up @@ -541,11 +538,16 @@ func getTime(address string, opt *QueryOptions) (*header, ntpTime, error) {
}
}

// Append an authentication MAC if requested.
if opt.Auth.Type != AuthNone {
appendMAC(&xmitBuf, opt.Auth, decodedAuthKey)
// If using symmetric key authentication, decode and validate the auth key
// string.
authKey, err := decodeAuthKey(opt.Auth)
if err != nil {
return nil, 0, err
}

// Append a MAC if authentication is being used.
appendMAC(&xmitBuf, opt.Auth, authKey)

// Transmit the query and keep track of when it was transmitted.
xmitTime := time.Now()
_, err = con.Write(xmitBuf.Bytes())
Expand All @@ -568,7 +570,7 @@ func getTime(address string, opt *QueryOptions) (*header, ntpTime, error) {
}
recvTime := xmitTime.Add(delta)

// Deserialize the response header.
// Parse the response header.
recvBuf = recvBuf[:recvBytes]
recvReader := bytes.NewReader(recvBuf)
err = binary.Read(recvReader, binary.BigEndian, recvHdr)
Expand All @@ -584,12 +586,6 @@ func getTime(address string, opt *QueryOptions) (*header, ntpTime, error) {
}
}

// Perform authentication of the server response.
var authErr error
if opt.Auth.Type != AuthNone {
authErr = verifyMAC(recvBuf, opt.Auth, decodedAuthKey)
}

// Check for invalid fields.
if recvHdr.getMode() != server {
return nil, 0, ErrInvalidMode
Expand All @@ -608,6 +604,9 @@ func getTime(address string, opt *QueryOptions) (*header, ntpTime, error) {
// transmit time.
recvHdr.OriginTime = toNtpTime(xmitTime)

// Perform authentication of the server response.
authErr := verifyMAC(recvBuf, opt.Auth, authKey)

return recvHdr, toNtpTime(recvTime), authErr
}

Expand All @@ -633,12 +632,6 @@ func defaultDialer(localAddress, remoteAddress string) (net.Conn, error) {
// dialWrapper is used to wrap the deprecated Dial callback in QueryOptions.
func dialWrapper(la, ra string,
dial func(la string, lp int, ra string, rp int) (net.Conn, error)) (net.Conn, error) {
var err error
ra, err = fixHostPort(ra, defaultNtpPort)
if err != nil {
return nil, err
}

rhost, rport, err := net.SplitHostPort(ra)
if err != nil {
return nil, err
Expand All @@ -655,7 +648,7 @@ func dialWrapper(la, ra string,
// fixHostPort examines an address in one of the accepted forms and fixes it
// to include a port number if necessary.
func fixHostPort(address string, defaultPort int) (fixed string, err error) {
// If address is wrapped in brackets, parse out the port (if any).
// If the address is wrapped in brackets, append a port if necessary.
if address[0] == '[' {
end := strings.IndexByte(address, ']')
switch {
Expand All @@ -670,14 +663,15 @@ func fixHostPort(address string, defaultPort int) (fixed string, err error) {
}
}

// No colons? Must be an IPv4 or domain address without a port.
// No colons? Must be a port-less IPv4 or domain address.
last := strings.LastIndexByte(address, ':')
if last < 0 {
return fmt.Sprintf("%s:%d", address, defaultPort), nil
}

// Exactly one colon? Must be an IPv4 or domain address with a port. (IPv6
// addresses are guaranteed to have more than one colon.)
// Exactly one colon? A port have been included along with an IPv4 or
// domain address. (IPv6 addresses are guaranteed to have more than one
// colon.)
prev := strings.LastIndexByte(address[:last], ':')
if prev < 0 {
return address, nil
Expand Down
10 changes: 10 additions & 0 deletions ntp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,11 @@ func TestOfflineFixHostPort(t *testing.T) {
errMsg string
}{
{"192.168.1.1", "192.168.1.1:123", ""},
{"192.168.1.1:123", "192.168.1.1:123", ""},
{"192.168.1.1:1000", "192.168.1.1:1000", ""},
{"[192.168.1.1]:1000", "[192.168.1.1]:1000", ""},
{"www.example.com", "www.example.com:123", ""},
{"www.example.com:123", "www.example.com:123", ""},
{"www.example.com:1000", "www.example.com:1000", ""},
{"[www.example.com]:1000", "[www.example.com]:1000", ""},
{"::1", "[::1]:123", ""},
Expand All @@ -249,6 +251,14 @@ func TestOfflineFixHostPort(t *testing.T) {
{"[fe80::1]:1000", "[fe80::1]:1000", ""},
{"[fe80::", "", "missing ']' in address"},
{"[fe80::]@", "", "unexpected character following ']' in address"},
{"ff06:0:0:0:0:0:0:c3", "[ff06:0:0:0:0:0:0:c3]:123", ""},
{"[ff06:0:0:0:0:0:0:c3]", "[ff06:0:0:0:0:0:0:c3]:123", ""},
{"[ff06:0:0:0:0:0:0:c3]:123", "[ff06:0:0:0:0:0:0:c3]:123", ""},
{"[ff06:0:0:0:0:0:0:c3]:1000", "[ff06:0:0:0:0:0:0:c3]:1000", ""},
{"::ffff:192.168.1.1", "[::ffff:192.168.1.1]:123", ""},
{"[::ffff:192.168.1.1]", "[::ffff:192.168.1.1]:123", ""},
{"[::ffff:192.168.1.1]:123", "[::ffff:192.168.1.1]:123", ""},
{"[::ffff:192.168.1.1]:1000", "[::ffff:192.168.1.1]:1000", ""},
}
for _, c := range cases {
fixed, err := fixHostPort(c.address, defaultPort)
Expand Down

0 comments on commit 4137e12

Please sign in to comment.