Skip to content

Commit

Permalink
chore(middleware/cors): fix v2 merge issues
Browse files Browse the repository at this point in the history
  • Loading branch information
sixcolors committed Mar 17, 2024
1 parent 334737d commit 89ae525
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 17 deletions.
31 changes: 17 additions & 14 deletions middleware/cors/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c fiber.Ctx) bool
Next func(c *fiber.Ctx) bool

// AllowOriginsFunc defines a function that will set the 'Access-Control-Allow-Origin'
// response header to the 'origin' request header when returned true. This allows for
Expand Down Expand Up @@ -110,7 +110,7 @@ func New(config ...Config) fiber.Handler {

// Validate CORS credentials configuration
if cfg.AllowCredentials && cfg.AllowOrigins == "*" {
log.Panic("[CORS] Insecure setup, 'AllowCredentials' is set to true, and 'AllowOrigins' is set to a wildcard.") //nolint:revive // we want to exit the program
panic("[CORS] Insecure setup, 'AllowCredentials' is set to true, and 'AllowOrigins' is set to a wildcard.")
}

// allowOrigins is a slice of strings that contains the allowed origins
Expand All @@ -125,7 +125,8 @@ func New(config ...Config) fiber.Handler {
trimmedOrigin := strings.TrimSpace(origin)
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
if !isValid {
log.Panicf("[CORS] Invalid origin format in configuration: %s", trimmedOrigin) //nolint:revive // we want to exit the program
log.Warnf("[CORS] Invalid origin format in configuration: %s", trimmedOrigin)
panic("[CORS] Invalid origin provided in configuration")
}
return normalizedOrigin, true
}
Expand Down Expand Up @@ -164,15 +165,16 @@ func New(config ...Config) fiber.Handler {
// Return new handler
return func(c fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
if cfg.Next != nil && cfg.Next(&c) {
return c.Next()
}

// Get originHeader header
originHeader := strings.ToLower(c.Get(fiber.HeaderOrigin))

// If the request does not have an Origin header, the request is outside the scope of CORS
if originHeader == "" {
// If the request does not have Origin and Access-Control-Request-Method
// headers, the request is outside the scope of CORS
if originHeader == "" || c.Get(fiber.HeaderAccessControlRequestMethod) == "" {
return c.Next()
}

Expand Down Expand Up @@ -210,36 +212,37 @@ func New(config ...Config) fiber.Handler {
}

// Simple request
// Ommit allowMethods and allowHeaders, only used for pre-flight requests
if c.Method() != fiber.MethodOptions {
setCORSHeaders(c, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg)
setCORSHeaders(c, allowOrigin, "", "", exposeHeaders, maxAge, cfg)
return c.Next()
}

// Preflight request
c.Vary(fiber.HeaderAccessControlRequestMethod)
c.Vary(fiber.HeaderAccessControlRequestHeaders)

setCORSHeaders(ctx, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg)
setCORSHeaders(c, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg)

// Send 204 No Content
return c.SendStatus(fiber.StatusNoContent)
}
}

// Function to set CORS headers
func setCORSHeaders(c *fiber.Ctx, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge string, cfg Config) {
func setCORSHeaders(c fiber.Ctx, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge string, cfg Config) {
c.Vary(fiber.HeaderOrigin)

if cfg.AllowCredentials {
// When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*'
if allowOrigin != "*" && allowOrigin != "" {
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
} else if allowOrigin == "*" {
if allowOrigin == "*" {
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
log.Warn("[CORS] 'AllowCredentials' is true, but 'AllowOrigins' cannot be set to '*'.")
} else if allowOrigin != "" {
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
}
} else if len(allowOrigin) > 0 {
} else if allowOrigin != "" {
// For non-credential requests, it's safe to set to '*' or specific origins
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
}
Expand Down
28 changes: 25 additions & 3 deletions middleware/cors/cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ func Test_CORS_Negative_MaxAge(t *testing.T) {

ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
app.Handler()(ctx)

Expand All @@ -48,6 +49,7 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) {
// Test default GET response headers
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
h(ctx)

Expand All @@ -58,6 +60,7 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) {
// Test default OPTIONS (preflight) response headers
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
h(ctx)

Expand Down Expand Up @@ -86,6 +89,7 @@ func Test_CORS_Wildcard(t *testing.T) {
ctx.Request.SetRequestURI("/")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)

// Perform request
handler(ctx)
Expand All @@ -100,6 +104,7 @@ func Test_CORS_Wildcard(t *testing.T) {
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
handler(ctx)

require.Equal(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
Expand Down Expand Up @@ -127,6 +132,7 @@ func Test_CORS_Origin_AllowCredentials(t *testing.T) {
ctx.Request.SetRequestURI("/")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)

// Perform request
handler(ctx)
Expand All @@ -140,6 +146,7 @@ func Test_CORS_Origin_AllowCredentials(t *testing.T) {
// Test non OPTIONS (preflight) response headers
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.SetMethod(fiber.MethodGet)
handler(ctx)

Expand Down Expand Up @@ -225,6 +232,7 @@ func Test_CORS_Subdomain(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com")

// Perform request
Expand All @@ -239,6 +247,7 @@ func Test_CORS_Subdomain(t *testing.T) {
// Make request with domain only (disallowed)
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")

handler(ctx)
Expand All @@ -251,6 +260,7 @@ func Test_CORS_Subdomain(t *testing.T) {
// Make request with allowed origin
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://test.example.com")

handler(ctx)
Expand Down Expand Up @@ -365,6 +375,7 @@ func Test_CORS_AllowOriginScheme(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, tt.reqOrigin)

handler(ctx)
Expand Down Expand Up @@ -403,15 +414,15 @@ func Test_CORS_AllowOriginHeader_NoMatch(t *testing.T) {
headerExists = true
}
})
require.Equal(t, false, headerExists, "Access-Control-Allow-Origin header should not be set")
require.False(t, headerExists, "Access-Control-Allow-Origin header should not be set")
}

// go test -run Test_CORS_Next
func Test_CORS_Next(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Next: func(_ fiber.Ctx) bool {
Next: func(_ *fiber.Ctx) bool {
return true
},
}))
Expand All @@ -426,7 +437,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{}))
app.Use(func(c *fiber.Ctx) error {
app.Use(func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})

Expand All @@ -443,6 +454,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) {
handler := app.Handler()

t.Run("Without origin and Access-Control-Request-Method headers", func(t *testing.T) {
t.Parallel()
// Make request without origin header, and without Access-Control-Request-Method
for _, method := range methods {
ctx := &fasthttp.RequestCtx{}
Expand All @@ -455,6 +467,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) {
})

t.Run("With origin but without Access-Control-Request-Method header", func(t *testing.T) {
t.Parallel()
// Make request with origin header, but without Access-Control-Request-Method
for _, method := range methods {
ctx := &fasthttp.RequestCtx{}
Expand All @@ -468,6 +481,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) {
})

t.Run("Without origin but with Access-Control-Request-Method header", func(t *testing.T) {
t.Parallel()
// Make request without origin header, but with Access-Control-Request-Method
for _, method := range methods {
ctx := &fasthttp.RequestCtx{}
Expand All @@ -481,6 +495,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) {
})

t.Run("Preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) {
t.Parallel()
// Make preflight request with origin header and with Access-Control-Request-Method
for _, method := range methods {
ctx := &fasthttp.RequestCtx{}
Expand All @@ -497,6 +512,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) {
})

t.Run("Non-preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) {
t.Parallel()
// Make non-preflight request with origin header and with Access-Control-Request-Method
for _, method := range methods {
ctx := &fasthttp.RequestCtx{}
Expand Down Expand Up @@ -531,6 +547,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com")

// Perform request
Expand All @@ -545,6 +562,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) {
// Make request with allowed origin
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-1.com")

handler(ctx)
Expand All @@ -557,6 +575,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) {
// Make request with allowed origin
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-2.com")

handler(ctx)
Expand Down Expand Up @@ -596,6 +615,7 @@ func Test_CORS_AllowOriginsFunc(t *testing.T) {
// Make request with allowed origin
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-2.com")

handler(ctx)
Expand Down Expand Up @@ -743,6 +763,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, tc.RequestOrigin)

handler(ctx)
Expand Down Expand Up @@ -833,6 +854,7 @@ func Test_CORS_AllowCredentials(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, tc.RequestOrigin)

handler(ctx)
Expand Down

0 comments on commit 89ae525

Please sign in to comment.