Skip to content

Commit

Permalink
fix: update retrieveRequestParams
Browse files Browse the repository at this point in the history
  • Loading branch information
kangmingtay committed Mar 4, 2024
1 parent 7e70ead commit e4390a2
Show file tree
Hide file tree
Showing 20 changed files with 82 additions and 56 deletions.
10 changes: 5 additions & 5 deletions internal/api/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Contex
}

func (a *API) getAdminParams(r *http.Request) (*AdminUserParams, error) {
params, err := retrieveRequestParams(r, &AdminUserParams{})
if err != nil {
params := &AdminUserParams{}
if err := retrieveRequestParams(r, params); err != nil {
return nil, err
}

Expand Down Expand Up @@ -558,12 +558,12 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro
factor := getFactor(ctx)
user := getUser(ctx)
adminUser := getAdminUser(ctx)
params, err := retrieveRequestParams(r, &adminUserUpdateFactorParams{})
if err != nil {
params := &adminUserUpdateFactorParams{}
if err := retrieveRequestParams(r, params); err != nil {
return err
}

err = a.db.Transaction(func(tx *storage.Connection) error {
err := a.db.Transaction(func(tx *storage.Connection) error {
if params.FriendlyName != "" {
if terr := factor.UpdateFriendlyName(tx, params.FriendlyName); terr != nil {
return terr
Expand Down
4 changes: 2 additions & 2 deletions internal/api/anonymous.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ func (a *API) SignupAnonymously(w http.ResponseWriter, r *http.Request) error {
return forbiddenError("Signups not allowed for this instance")
}

params, err := retrieveRequestParams(r, &SignupParams{})
if err != nil {
params := &SignupParams{}
if err := retrieveRequestParams(r, params); err != nil {
return err
}
params.Aud = aud
Expand Down
4 changes: 2 additions & 2 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati
DefaultExpirationTTL: time.Hour,
}).SetBurst(int(api.config.RateLimitAnonymousUsers)).SetMethods([]string{"POST"})
r.Post("/", func(w http.ResponseWriter, r *http.Request) error {
params, err := retrieveRequestParams(r, &SignupParams{})
if err != nil {
params := &SignupParams{}
if err := retrieveRequestParams(r, params); err != nil {
return err
}
if params.Email == "" && params.Phone == "" {
Expand Down
8 changes: 4 additions & 4 deletions internal/api/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,13 @@ type RequestParams interface {
}

// retrieveRequestParams is a generic method that unmarshals the request body into the params struct provided
func retrieveRequestParams[A RequestParams](r *http.Request, params *A) (*A, error) {
func retrieveRequestParams[A RequestParams](r *http.Request, params *A) error {
body, err := getBodyBytes(r)
if err != nil {
return nil, internalServerError("Could not read body into byte slice").WithInternalError(err)
return internalServerError("Could not read body into byte slice").WithInternalError(err)
}
if err := json.Unmarshal(body, params); err != nil {
return nil, badRequestError("Could not read request body: %v", err)
return badRequestError("Could not read request body: %v", err)
}
return params, nil
return nil
}
23 changes: 23 additions & 0 deletions internal/api/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,26 @@ func TestIsValidPKCEParmas(t *testing.T) {
})
}
}

// func BenchmarkRetrieveRequestParams(b *testing.B) {
// var buffer bytes.Buffer
// require.NoError(b, json.NewEncoder(&buffer).Encode(map[string]interface{}{}))
// req, err := http.NewRequest(http.MethodPost, "/", &buffer)
// require.NoError(b, err)

// b.Run("retrieveRequestParams", func(b *testing.B) {
// for n := 0; n < b.N; n++ {
// _, err := retrieveRequestParams(req, &SignupParams{})
// require.NoError(b, err)
// }
// })

// var params SignupParams
// b.Run("without generics", func(b *testing.B) {
// for n := 0; n < b.N; n++ {
// body, err := getBodyBytes(req)
// require.NoError(b, err)
// require.NoError(b, json.Unmarshal(body, &params))
// }
// })
// }
5 changes: 3 additions & 2 deletions internal/api/invite.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ func (a *API) Invite(w http.ResponseWriter, r *http.Request) error {
db := a.db.WithContext(ctx)
config := a.config
adminUser := getAdminUser(ctx)
params, err := retrieveRequestParams(r, &InviteParams{})
if err != nil {
params := &InviteParams{}
if err := retrieveRequestParams(r, params); err != nil {
return err
}

var err error
params.Email, err = validateEmail(params.Email)
if err != nil {
return err
Expand Down
5 changes: 3 additions & 2 deletions internal/api/mail.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,12 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error {
config := a.config
mailer := a.Mailer(ctx)
adminUser := getAdminUser(ctx)
params, err := retrieveRequestParams(r, &GenerateLinkParams{})
if err != nil {
params := &GenerateLinkParams{}
if err := retrieveRequestParams(r, params); err != nil {
return err
}

var err error
params.Email, err = validateEmail(params.Email)
if err != nil {
return err
Expand Down
8 changes: 4 additions & 4 deletions internal/api/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error {
session := getSession(ctx)
config := a.config

params, err := retrieveRequestParams(r, &EnrollFactorParams{})
if err != nil {
params := &EnrollFactorParams{}
if err := retrieveRequestParams(r, params); err != nil {
return err
}
issuer := ""
Expand Down Expand Up @@ -199,8 +199,8 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error {
factor := getFactor(ctx)
config := a.config

params, err := retrieveRequestParams(r, &VerifyFactorParams{})
if err != nil {
params := &VerifyFactorParams{}
if err := retrieveRequestParams(r, params); err != nil {
return err
}
currentIP := utilities.GetIPAddress(r)
Expand Down
7 changes: 3 additions & 4 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,12 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler {
Phone string `json:"phone"`
}

params, err := retrieveRequestParams(req, &requestBody)
if err != nil {
if err := retrieveRequestParams(req, &requestBody); err != nil {
return c, badRequestError("Error invalid request body").WithInternalError(err)
}

if shouldRateLimitEmail {
if params.Email != "" {
if requestBody.Email != "" {
if err := tollbooth.LimitByKeys(emailLimiter, []string{"email_functions"}); err != nil {
emailRateLimitCounter.Add(
req.Context(),
Expand All @@ -119,7 +118,7 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler {
}

if shouldRateLimitPhone {
if params.Phone != "" {
if requestBody.Phone != "" {
if err := tollbooth.LimitByKeys(phoneLimiter, []string{"phone_functions"}); err != nil {
return c, httpError(http.StatusTooManyRequests, "Sms rate limit exceeded")
}
Expand Down
8 changes: 3 additions & 5 deletions internal/api/otp.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,14 @@ func (p *SmsParams) Validate(smsProvider string) error {

// Otp returns the MagicLink or SmsOtp handler based on the request body params
func (a *API) Otp(w http.ResponseWriter, r *http.Request) error {
var err error
params := &OtpParams{
CreateUser: true,
}
if params.Data == nil {
params.Data = make(map[string]interface{})
}

params, err = retrieveRequestParams(r, params)
if err != nil {
if err := retrieveRequestParams(r, params); err != nil {
return err
}

Expand Down Expand Up @@ -111,8 +109,8 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error {
}
var err error

params, err := retrieveRequestParams(r, &SmsParams{})
if err != nil {
params := &SmsParams{}
if err := retrieveRequestParams(r, params); err != nil {
return err
}

Expand Down
5 changes: 3 additions & 2 deletions internal/api/recover.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
db := a.db.WithContext(ctx)
config := a.config
params, err := retrieveRequestParams(r, &RecoverParams{})
if err != nil {
params := &RecoverParams{}
if err := retrieveRequestParams(r, params); err != nil {
return err
}

Expand All @@ -46,6 +46,7 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error {
}

var user *models.User
var err error
aud := a.requestAud(ctx, r)

user, err = models.FindUserByEmailAndAudience(db, params.Email, aud)
Expand Down
5 changes: 3 additions & 2 deletions internal/api/resend.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
db := a.db.WithContext(ctx)
config := a.config
params, err := retrieveRequestParams(r, &ResendConfirmationParams{})
if err != nil {
params := &ResendConfirmationParams{}
if err := retrieveRequestParams(r, params); err != nil {
return err
}

Expand All @@ -76,6 +76,7 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error {
}

var user *models.User
var err error
aud := a.requestAud(ctx, r)
if params.Email != "" {
user, err = models.FindUserByEmailAndAudience(db, params.Email, aud)
Expand Down
5 changes: 3 additions & 2 deletions internal/api/signup.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {
return forbiddenError("Signups not allowed for this instance")
}

params, err := retrieveRequestParams(r, &SignupParams{})
if err != nil {
params := &SignupParams{}
if err := retrieveRequestParams(r, params); err != nil {
return err
}

Expand All @@ -129,6 +129,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {
}

var codeChallengeMethod models.CodeChallengeMethod
var err error
flowType := getFlowFromChallenge(params.CodeChallenge)

if isPKCEFlow(flowType) {
Expand Down
5 changes: 3 additions & 2 deletions internal/api/sso.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,12 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
db := a.db.WithContext(ctx)

params, err := retrieveRequestParams(r, &SingleSignOnParams{})
if err != nil {
params := &SingleSignOnParams{}
if err := retrieveRequestParams(r, params); err != nil {
return err
}

var err error
hasProviderID := false

if hasProviderID, err = params.validate(); err != nil {
Expand Down
8 changes: 4 additions & 4 deletions internal/api/ssoadmin.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,8 @@ func (a *API) adminSSOProvidersCreate(w http.ResponseWriter, r *http.Request) er
ctx := r.Context()
db := a.db.WithContext(ctx)

params, err := retrieveRequestParams(r, &CreateSSOProviderParams{})
if err != nil {
params := &CreateSSOProviderParams{}
if err := retrieveRequestParams(r, params); err != nil {
return err
}

Expand Down Expand Up @@ -258,8 +258,8 @@ func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) er
ctx := r.Context()
db := a.db.WithContext(ctx)

params, err := retrieveRequestParams(r, &CreateSSOProviderParams{})
if err != nil {
params := &CreateSSOProviderParams{}
if err := retrieveRequestParams(r, params); err != nil {
return err
}

Expand Down
10 changes: 6 additions & 4 deletions internal/api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ func (a *API) Token(w http.ResponseWriter, r *http.Request) error {
func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
db := a.db.WithContext(ctx)

params, err := retrieveRequestParams(r, &PasswordGrantParams{})
if err != nil {
params := &PasswordGrantParams{}
if err := retrieveRequestParams(r, params); err != nil {
return err
}

Expand All @@ -113,6 +113,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri
var user *models.User
var grantParams models.GrantParams
var provider string
var err error

grantParams.FillGrantParams(r)

Expand Down Expand Up @@ -228,8 +229,9 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request)
// can be told to at least propagate the User-Agent header.
grantParams.FillGrantParams(r)

params, err := retrieveRequestParams(r, &PKCEGrantParams{})
if err != nil {
params := &PKCEGrantParams{}

if err := retrieveRequestParams(r, params); err != nil {
return err
}

Expand Down
4 changes: 2 additions & 2 deletions internal/api/token_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R
db := a.db.WithContext(ctx)
config := a.config

params, err := retrieveRequestParams(r, &IdTokenGrantParams{})
if err != nil {
params := &IdTokenGrantParams{}
if err := retrieveRequestParams(r, params); err != nil {
return err
}

Expand Down
4 changes: 2 additions & 2 deletions internal/api/token_refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
db := a.db.WithContext(ctx)
config := a.config

params, err := retrieveRequestParams(r, &RefreshTokenGrantParams{})
if err != nil {
params := &RefreshTokenGrantParams{}
if err := retrieveRequestParams(r, params); err != nil {
return err
}

Expand Down
6 changes: 3 additions & 3 deletions internal/api/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error {
config := a.config
aud := a.requestAud(ctx, r)

params, err := retrieveRequestParams(r, &UserUpdateParams{})
if err != nil {
params := &UserUpdateParams{}
if err := retrieveRequestParams(r, params); err != nil {
return err
}

Expand Down Expand Up @@ -163,7 +163,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error {
}
}

err = db.Transaction(func(tx *storage.Connection) error {
err := db.Transaction(func(tx *storage.Connection) error {
var terr error
if params.Password != nil {
var sessionID *uuid.UUID
Expand Down
4 changes: 1 addition & 3 deletions internal/api/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,7 @@ func (a *API) Verify(w http.ResponseWriter, r *http.Request) error {
}
return a.verifyGet(w, r, params)
case http.MethodPost:
var err error
params, err = retrieveRequestParams(r, params)
if err != nil {
if err := retrieveRequestParams(r, params); err != nil {
return err
}
if err := params.Validate(r); err != nil {
Expand Down

0 comments on commit e4390a2

Please sign in to comment.