Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat (config): add support for a http.RoundTripper #137

Merged
merged 1 commit into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions oidc/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"fmt"
"hash"
"hash/fnv"
"net/http"
"net/url"
"reflect"
"runtime"
Expand Down Expand Up @@ -89,9 +90,15 @@ type Config struct {

// ProviderCA is an optional CA certs (PEM encoded) to use when sending
// requests to the provider. If you have a list of *x509.Certificates, then
// see EncodeCertificates(...) to PEM encode them.
// see EncodeCertificates(...) to PEM encode them. Note: specifying both
// ProviderCA and RoundTripper is an error.
ProviderCA string

// RoundTripper is an optional http.RoundTripper to use when sending requests
// to the provider. Note: specifying both ProviderCA and RoundTripper is an
// error.
RoundTripper http.RoundTripper

// NowFunc is a time func that returns the current time.
NowFunc func() time.Time `json:"-"`

Expand All @@ -118,6 +125,7 @@ func NewConfig(issuer string, clientID string, clientSecret ClientSecret, suppor
SupportedSigningAlgs: supported,
Scopes: opts.withScopes,
ProviderCA: opts.withProviderCA,
RoundTripper: opts.withRoundTripper,
Audiences: opts.withAudiences,
NowFunc: opts.withNowFunc,
AllowedRedirectURLs: allowedRedirectURLs,
Expand Down Expand Up @@ -168,6 +176,16 @@ func (c *Config) Hash() (uint64, error) {
args = append(args, audiences...)
args = append(args, redirects...)

if c.RoundTripper != nil {
v := reflect.ValueOf(c.RoundTripper)
switch {
case v.CanAddr():
args = append(args, v.Addr().String())
default:
args = append(args, v.String())
}
}

if c.ProviderConfig != nil {
args = append(
args,
Expand Down Expand Up @@ -269,6 +287,9 @@ func (c *Config) Validate() error {
return fmt.Errorf("%s: %w", op, ErrInvalidCACert)
}
}
if c.ProviderCA != "" && c.RoundTripper != nil {
return fmt.Errorf("%s: you cannot specify both a ProviderCA and RoundTripper: %w", op, ErrInvalidParameter)
}

if c.ProviderConfig != nil {
switch {
Expand Down Expand Up @@ -300,6 +321,7 @@ type configOptions struct {
withProviderCA string
withNowFunc func() time.Time
withProviderConfig *ProviderConfig
withRoundTripper http.RoundTripper
}

// configDefaults is a handy way to get the defaults at runtime and
Expand All @@ -319,12 +341,14 @@ func getConfigOpts(opt ...Option) configOptions {
}

// WithProviderCA provides optional CA certs (PEM encoded) for the provider's
// config. These certs will can be used when making http requests to the
// config. These certs will be used when making http requests to the
// provider.
//
// Valid for: Config
//
// See EncodeCertificates(...) to PEM encode a number of certs.
//
// Note: specifying both WithProviderCA and WithRoundTripper is a error.
func WithProviderCA(cert string) Option {
return func(o interface{}) {
if o, ok := o.(*configOptions); ok {
Expand All @@ -333,6 +357,18 @@ func WithProviderCA(cert string) Option {
}
}

// WithRoundTripper provides and optional RoundTripper for the provider's
// config. This RoundTripper will be used when making http requests to the
// provider. Note: specifying both WithProviderCA and WithRoundTripper is a
// error.
func WithRoundTripper(rt http.RoundTripper) Option {
return func(o interface{}) {
if o, ok := o.(*configOptions); ok {
o.withRoundTripper = rt
}
}
}

// EncodeCertificates will encode a number of x509 certificates to PEM. It will
// help encode certs for use with the WithProviderCA(...) option.
func EncodeCertificates(certs ...*x509.Certificate) (string, error) {
Expand Down
141 changes: 140 additions & 1 deletion oidc/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"crypto/x509"
"errors"
"fmt"
"net/http"
"testing"
"time"

Expand Down Expand Up @@ -44,6 +45,8 @@ func TestNewConfig(t *testing.T) {
return time.Now().Add(-1 * time.Minute)
}

testRt := newTestRoundTripper(t)

type args struct {
issuer string
clientID string
Expand All @@ -61,7 +64,7 @@ func TestNewConfig(t *testing.T) {
wantErrContains string
}{
{
name: "valid-with-all-valid-opts",
name: "valid-with-all-valid-opts-except-with-round-tripper",
args: args{
issuer: "http://your_issuer/",
clientID: "your_client_id",
Expand Down Expand Up @@ -103,6 +106,49 @@ func TestNewConfig(t *testing.T) {
},
},
},
{
name: "with-round-tripper",
args: args{
issuer: "http://your_issuer/",
clientID: "your_client_id",
clientSecret: "your_client_secret",
supported: []Alg{RS512},
allowedRedirectURLs: []string{"http://your_redirect_url", "http://redirect_url_two", "http://redirect_url_three"},
opt: []Option{
WithAudiences("your_aud1", "your_aud2"),
WithScopes("email", "profile"),
WithRoundTripper(testRt),
WithNow(testNow),
WithProviderConfig(&ProviderConfig{
AuthURL: "https://auth-endpoint",
JWKSURL: "https://jwks-endpoint",
TokenURL: "https://token-endpoint",
UserInfoURL: "https://userinfo-endpoint",
}),
},
},
want: &Config{
Issuer: "http://your_issuer/",
ClientID: "your_client_id",
ClientSecret: "your_client_secret",
SupportedSigningAlgs: []Alg{RS512},
Audiences: []string{"your_aud1", "your_aud2"},
Scopes: []string{oidc.ScopeOpenID, "email", "profile"},
RoundTripper: testRt,
NowFunc: testNow,
AllowedRedirectURLs: []string{
"http://your_redirect_url",
"http://redirect_url_two",
"http://redirect_url_three",
},
ProviderConfig: &ProviderConfig{
AuthURL: "https://auth-endpoint",
JWKSURL: "https://jwks-endpoint",
TokenURL: "https://token-endpoint",
UserInfoURL: "https://userinfo-endpoint",
},
},
},
{
name: "missing-provider-config-auth-url",
args: args{
Expand Down Expand Up @@ -282,6 +328,22 @@ func TestNewConfig(t *testing.T) {
wantErr: true,
wantIsErr: ErrInvalidCACert,
},
{
name: "invalid-both-cert-and-round-tripper",
args: args{
issuer: "http://your_issuer/",
clientID: "your_client_id",
clientSecret: "your_client_secret",
supported: []Alg{RS512},
allowedRedirectURLs: []string{"http://your_redirect_url"},
opt: []Option{
WithProviderCA(testCaPem),
WithRoundTripper(testRt),
},
},
wantErr: true,
wantIsErr: ErrInvalidParameter,
},
{
name: "invalid-alg",
args: args{
Expand Down Expand Up @@ -430,6 +492,7 @@ func TestConfig_Hash(t *testing.T) {
require.NoError(t, err)
return c
}
testRt := newTestRoundTripper(t)
tests := []struct {
name string
c1 *Config
Expand Down Expand Up @@ -473,6 +536,42 @@ func TestConfig_Hash(t *testing.T) {
),
wantEqual: true,
},
{
name: "equal-with-round-tripper",
c1: newCfg(
"https://www.alice.com",
"client-id", "client-secret",
[]Alg{RS256},
[]string{"www.alice.com/callback", "www.bob.com/callback"},
WithScopes("email", "profile"),
WithAudiences("alice.com", "bob.com"),
WithRoundTripper(testRt),
WithNow(time.Now),
WithProviderConfig(&ProviderConfig{
AuthURL: "https://auth-endpoint",
JWKSURL: "https://jwks-endpoint",
TokenURL: "https://token-endpoint",
UserInfoURL: "https://userinfo-endpoint",
}),
),
c2: newCfg(
"https://www.alice.com",
"client-id", "client-secret",
[]Alg{RS256},
[]string{"www.bob.com/callback", "www.alice.com/callback"},
WithScopes("profile", "email"),
WithAudiences("bob.com", "alice.com"),
WithRoundTripper(testRt),
WithNow(time.Now),
WithProviderConfig(&ProviderConfig{
AuthURL: "https://auth-endpoint",
JWKSURL: "https://jwks-endpoint",
TokenURL: "https://token-endpoint",
UserInfoURL: "https://userinfo-endpoint",
}),
),
wantEqual: true,
},
{
name: "diff-issuer",
c1: newCfg(
Expand Down Expand Up @@ -664,6 +763,29 @@ func TestConfig_Hash(t *testing.T) {
),
wantEqual: false,
},
{
name: "diff-round-trippers",
c1: newCfg(
"https://www.alice.com",
"client-id", "client-secret",
[]Alg{RS256},
[]string{"www.alice.com/callback"},
WithScopes("email", "profile"),
WithAudiences("alice.com", "bob.com"),
WithRoundTripper(newTestRoundTripper(t)),
WithNow(time.Now),
),
c2: newCfg(
"https://www.alice.com",
"client-id", "client-secret",
[]Alg{RS256},
[]string{"www.alice.com/callback"},
WithScopes("email", "profile"),
WithAudiences("alice.com", "bob.com"),
WithNow(time.Now),
),
wantEqual: false,
},
{
name: "diff-now-func",
c1: newCfg(
Expand Down Expand Up @@ -855,3 +977,20 @@ func TestConfig_Hash(t *testing.T) {
})
}
}

type testRoundTripper struct {
transport http.RoundTripper
called int
}

func newTestRoundTripper(t *testing.T) *testRoundTripper {
t.Helper()
return &testRoundTripper{
transport: http.DefaultTransport,
}
}

func (rt *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
rt.called++
return rt.transport.RoundTrip(req)
}
4 changes: 2 additions & 2 deletions oidc/docs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func ExampleNewConfig() {
fmt.Println(pc)

// Output:
// &{your_client_id [REDACTED: client secret] [openid] http://your_issuer/ [RS256] [http://your_redirect_url/callback] [] <nil> <nil>}
// &{your_client_id [REDACTED: client secret] [openid] http://your_issuer/ [RS256] [http://your_redirect_url/callback] [] <nil> <nil> <nil>}
}

func ExampleWithProviderConfig() {
Expand All @@ -120,7 +120,7 @@ func ExampleWithProviderConfig() {
fmt.Println(string(val))

// Output:
// {"ClientID":"your_client_id","ClientSecret":"[REDACTED: client secret]","Scopes":["openid"],"Issuer":"https://your_issuer/","SupportedSigningAlgs":["RS256"],"AllowedRedirectURLs":["https://your_redirect_url/callback"],"Audiences":null,"ProviderCA":"","ProviderConfig":{"AuthURL":"https://your_issuer/authorize","TokenURL":"https://your_issuer/token","UserInfoURL":"https://your_issuer/userinfo","JWKSURL":"https://your_issuer/.well-known/jwks.json"}}
// {"ClientID":"your_client_id","ClientSecret":"[REDACTED: client secret]","Scopes":["openid"],"Issuer":"https://your_issuer/","SupportedSigningAlgs":["RS256"],"AllowedRedirectURLs":["https://your_redirect_url/callback"],"Audiences":null,"ProviderCA":"","RoundTripper":null,"ProviderConfig":{"AuthURL":"https://your_issuer/authorize","TokenURL":"https://your_issuer/token","UserInfoURL":"https://your_issuer/userinfo","JWKSURL":"https://your_issuer/.well-known/jwks.json"}}
}

func ExampleNewProvider() {
Expand Down
13 changes: 9 additions & 4 deletions oidc/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -635,17 +635,22 @@ func (p *Provider) HTTPClient() (*http.Client, error) {
// to the same host. On the downside, this transport can leak file
// descriptors over time, so we'll be sure to call
// client.CloseIdleConnections() in the Provider.Done() to stave that off.
tr := cleanhttp.DefaultPooledTransport()
var tr http.RoundTripper

if p.config.ProviderCA != "" {
switch {
case p.config.RoundTripper != nil && p.config.ProviderCA != "":
return nil, fmt.Errorf("%s: you cannot specify config for both a ProviderCA and RoundTripper: %w", op, ErrInvalidParameter)
case p.config.ProviderCA != "":
certPool := x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM([]byte(p.config.ProviderCA)); !ok {
return nil, fmt.Errorf("%s: %w", op, ErrInvalidCACert)
}

tr.TLSClientConfig = &tls.Config{
tr = cleanhttp.DefaultPooledTransport()
tr.(*http.Transport).TLSClientConfig = &tls.Config{
RootCAs: certPool,
}
case p.config.RoundTripper != nil:
tr = p.config.RoundTripper
}

c := &http.Client{
Expand Down
25 changes: 25 additions & 0 deletions oidc/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,31 @@ func TestHTTPClient(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, c.Transport, p.client.Transport)
})
t.Run("check-transport-with-round-tripper", func(t *testing.T) {
testRt := newTestRoundTripper(t)
p := &Provider{
config: &Config{
RoundTripper: testRt,
},
}
c, err := p.HTTPClient()
require.NoError(t, err)
assert.Equal(t, c.Transport, p.client.Transport)
})
t.Run("err-both-ca-and-round-trippe", func(t *testing.T) {
_, testCaPem := TestGenerateCA(t, []string{"localhost"})

p := &Provider{
config: &Config{
ProviderCA: testCaPem,
RoundTripper: newTestRoundTripper(t),
},
}
_, err := p.HTTPClient()
require.Error(t, err)
assert.ErrorIs(t, err, ErrInvalidParameter)
assert.ErrorContains(t, err, "you cannot specify config for both a ProviderCA and RoundTripper")
})
}

func TestProvider_UserInfo(t *testing.T) {
Expand Down
Loading