Skip to content

Commit

Permalink
capability support
Browse files Browse the repository at this point in the history
  • Loading branch information
johnabass committed Aug 20, 2024
1 parent fb457ec commit e19d904
Show file tree
Hide file tree
Showing 3 changed files with 278 additions and 12 deletions.
243 changes: 243 additions & 0 deletions basculehttp/capabilities.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC
// SPDX-License-Identifier: Apache-2.0

package basculehttp

import (
"context"
"errors"
"fmt"
"net/http"
"regexp"
"strings"

"github.com/xmidt-org/bascule"
"go.uber.org/multierr"
)

const (
// DefaultAllMethod is one of the default method strings that will match any HTTP method.
DefaultAllMethod = "all"

// DefaultWildcardMethod is one of the default method strings that will match any HTTP method.
DefaultWildcardMethod = "*"
)

var (
// ErrMissingCapabilities indicates that a token had no capabilities
// and thus is unauthorized.
ErrMissingCapabilities = errors.New("no capabilities in token")
)

// urlPathNormalization ensures that the given URL has a leading slash.
func urlPathNormalization(url string) string {
if url[0] == '/' {
return url
}

return "/" + url
}

// CapabilityUnauthorizedError indicates that a given capability was
// rejected and the token is unauthorized.
type CapabilityUnauthorizedError struct {
// Match is the match string in <prefix><url pattern>:<method> format
// that matched the capability but did not match the resource request.
Match string

// Capability is the capability string from the token that was rejected.
Capability string

// Err is any error that occurred. This will be returned from Unwrap.
Err error
}

func (cue *CapabilityUnauthorizedError) Unwrap() error {
return cue.Err
}

func (cue *CapabilityUnauthorizedError) StatusCode() int {
return http.StatusForbidden
}

func (cue *CapabilityUnauthorizedError) Error() string {
var o strings.Builder
o.WriteString(`Capability [`)
o.WriteString(cue.Capability)
o.WriteString(`] was rejected due to [`)
o.WriteString(cue.Match)
o.WriteRune(']')

if cue.Err != nil {
o.WriteString(`: `)
o.WriteString(cue.Err.Error())
}

return o.String()
}

// CapabilityApproverOption is a configurable option used to create a CapabilityApprover.
type CapabilityApproverOption interface {
apply(*CapabilityApprover) error
}

type capabilityApproverOptionFunc func(*CapabilityApprover) error

func (caof capabilityApproverOptionFunc) apply(ca *CapabilityApprover) error { return caof(ca) }

// WithCapabilityPrefixes adds several prefixes used to match capabilities, e.g. x1:webpa:foo:. Only
// the first prefix found during matching is considered for authorization. If no prefixes
// are set via this option, the resulting approver will not authorize any requests.
//
// Note that a prefix can itself be a regular expression, but may not have any subexpressions.
func WithCapabilityPrefixes(prefixes ...string) CapabilityApproverOption {
return capabilityApproverOptionFunc(func(ca *CapabilityApprover) error {
for _, p := range prefixes {
re, err := regexp.Compile("^" + p + "(.+):(.+?)$")
switch {
case err != nil:
return fmt.Errorf("Unable to compile capability prefix [%s]: %s", p, err)

case re.NumSubexp() != 2:
return fmt.Errorf("The prefix [%s] cannot have subexpressions", p)

default:
ca.matchers = append(ca.matchers, re)
}
}

return nil
})
}

// WithCapabilityAllMethods changes the values used to signal a match of all HTTP methods.
// By default, both DefaultAllMethod and DefaultWildcardMethod, if present in a capability,
// will match any HTTP method. This option overwrites the default, and is cumulative.
// However, a caller can add values to the default by using
// WithCapabilityAllMethods(DefaultAllMethod, DefaultWildcardMethod, "myvalue", ...).
func WithCapabilityAllMethods(v ...string) CapabilityApproverOption {
return capabilityApproverOptionFunc(func(ca *CapabilityApprover) error {
if ca.allMethods == nil {
ca.allMethods = make(map[string]bool, len(v))
}

for _, matchAll := range v {
ca.allMethods[matchAll] = true
}

return nil
})
}

// CapabilityApprover is a bascule HTTP approver that authorizes tokens
// with capabilities against requests.
//
// This approver expects capabilities in tokens to be of the form <prefix><endpoing regex>:<method>.
//
// The allowed prefixes must be set via one or more WithCapabilityPrefixes options. Prefixes
// may themselves contain colon delimiters and can be regular expressions without subexpressions.
type CapabilityApprover struct {
matchers []*regexp.Regexp
allMethods map[string]bool
}

// NewCapabilityApprover creates a CapabilityApprover using the supplied options.
// At least (1) of the configured prefixes must match an HTTP request's URL in
// ordered for a token to be authorized.
//
// If no prefixes are added via WithCapabilityPrefix, then the returned approver
// will not authorize any requests.
func NewCapabilityApprover(opts ...CapabilityApproverOption) (ca *CapabilityApprover, err error) {
ca = new(CapabilityApprover)
for _, o := range opts {
err = multierr.Append(err, o.apply(ca))
}

switch {
case err != nil:
ca = nil

default:
if len(ca.allMethods) == 0 {
// enforce the defaults
ca.allMethods = map[string]bool{
DefaultAllMethod: true,
DefaultWildcardMethod: true,
}
}
}

return
}

// Approve attempts to match each capability to a configured prefix. Then, for any matched prefix,
// the URL regexp and method in the capability must match the resource. URLs are normalized
// with a leading '/'.
//
// This method returns success (i.e. a nil error) when the first matching capability is found.
func (ca *CapabilityApprover) Approve(_ context.Context, resource *http.Request, token bascule.Token) error {
capabilities, ok := bascule.GetCapabilities(token)
if len(capabilities) == 0 || !ok {
return ErrMissingCapabilities
}

for _, matcher := range ca.matchers {
for _, capability := range capabilities {
substrings := matcher.FindStringSubmatch(capability)
if len(substrings) < 2 {
// no match
continue
}

// the format of capabilities is <prefix><url pattern>:<method>
// <url pattern> and <method> will be substrings
err := ca.approveURL(resource, substrings[1])
if err == nil {
err = ca.approveMethod(resource, substrings[2])
}

if err != nil {
err = &CapabilityUnauthorizedError{
Match: matcher.String(),
Capability: capability,
Err: err,
}
}

// stop at the first match, regardless of result
return err
}
}

// none of the matchers matched any capability, OR there were no matchers configured
return bascule.ErrUnauthorized
}

func (ca *CapabilityApprover) approveMethod(resource *http.Request, capabilityMethod string) error {
switch {
case ca.allMethods[capabilityMethod]:
return nil

case capabilityMethod == strings.ToLower(resource.Method):
return nil

default:
return fmt.Errorf("method does not match request method [%s]", resource.Method)
}
}

func (ca *CapabilityApprover) approveURL(resource *http.Request, capabilityURL string) error {
resourcePath := resource.URL.EscapedPath()

re, err := regexp.Compile(urlPathNormalization(capabilityURL))
if err != nil {
return err
}

indices := re.FindStringIndex(urlPathNormalization(resourcePath))
if len(indices) < 1 || indices[0] != 0 {
return fmt.Errorf("url does not match request URL [%s]", resourcePath)
}

return nil
}
31 changes: 21 additions & 10 deletions basculehttp/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,23 @@ import (
type ErrorStatusCoder func(request *http.Request, err error) int

// DefaultErrorStatusCoder is the strategy used when no ErrorStatusCoder is supplied.
// The following tests are done in order:
//
// If err has bascule.ErrMissingCredentials in its chain, this function returns
// (1) First, if err is nil, this method returns 0.
//
// (2) If any error in the chain provides a 'StatusCode() int' method, the result
// from that method is returned.
//
// (3) If err has bascule.ErrMissingCredentials in its chain, this function returns
// http.StatusUnauthorized.
//
// If err has bascule.ErrInvalidCredentials in its chain, this function returns
// http.StatusBadRequest.
// (4) If err has bascule.ErrUnauthorized in its chain, this function returns
// http.StatusForbidden.
//
// Failing the previous two checks, if the error provides a StatusCode() method,
// the return value from that method is used.
// (5) If err has bascule.ErrInvalidCredentials in its chain, this function returns
// http.StatusBadRequest.
//
// Otherwise, this method returns 0 to indicate that it doesn't know how to
// (6) Otherwise, this method returns 0 to indicate that it doesn't know how to
// produce a status code from the error.
func DefaultErrorStatusCoder(_ *http.Request, err error) int {
type statusCoder interface {
Expand All @@ -37,19 +43,24 @@ func DefaultErrorStatusCoder(_ *http.Request, err error) int {
var sc statusCoder

switch {
// check if it's a status coder first, so that we can
// override status codes for built-in errors.
case err == nil:
return 0

case errors.As(err, &sc):
return sc.StatusCode()

case errors.Is(err, bascule.ErrMissingCredentials):
return http.StatusUnauthorized

case errors.Is(err, bascule.ErrUnauthorized):
return http.StatusForbidden

case errors.Is(err, bascule.ErrInvalidCredentials):
return http.StatusBadRequest
}

return 0
default:
return 0
}
}

// ErrorMarshaler is a strategy for marshaling an error's contents, particularly to
Expand Down
16 changes: 14 additions & 2 deletions basculehttp/error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,26 @@ type ErrorTestSuite struct {
}

func (suite *ErrorTestSuite) TestDefaultErrorStatusCoder() {
suite.Run("Nil", func() {
suite.Zero(
DefaultErrorStatusCoder(nil, nil),
)
})

suite.Run("ErrMissingCredentials", func() {
suite.Equal(
http.StatusUnauthorized,
DefaultErrorStatusCoder(nil, bascule.ErrMissingCredentials),
)
})

suite.Run("ErrUnauthorized", func() {
suite.Equal(
http.StatusForbidden,
DefaultErrorStatusCoder(nil, bascule.ErrUnauthorized),
)
})

suite.Run("ErrInvalidCredentials", func() {
suite.Equal(
http.StatusBadRequest,
Expand Down Expand Up @@ -53,8 +66,7 @@ func (suite *ErrorTestSuite) TestDefaultErrorStatusCoder() {
})

suite.Run("Unrecognized", func() {
suite.Equal(
0,
suite.Zero(
DefaultErrorStatusCoder(nil, errors.New("unrecognized error")),
)
})
Expand Down

0 comments on commit e19d904

Please sign in to comment.