Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: correctly identify infixed concats as potential SQL injections #987

Merged
merged 9 commits into from
Jul 25, 2023
56 changes: 52 additions & 4 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,48 @@ func GetChar(n ast.Node) (byte, error) {
return 0, fmt.Errorf("Unexpected AST node type: %T", n)
}

// GetStringRecursive will recursively walk down a tree of *ast.BinaryExpr. It will then concat the results, and return.
// Unlike the other getters, it does _not_ raise an error for unknown ast.Node types. At the base, the recursion will hit a non-BinaryExpr type,
// either BasicLit or other, so it's not an error case. It will only error if `strconv.Unquote` errors. This matters, because there's
// currently functionality that relies on error values being returned by GetString if and when it hits a non-basiclit string node type,
// hence for cases where recursion is needed, we use this separate function, so that we can still be backwards compatbile.
//
// This was added to handle a SQL injection concatenation case where the injected value is infixed between two strings, not at the start or end. See example below
//
// Do note that this will omit non-string values. So for example, if you were to use this node:
// ```go
// q := "SELECT * FROM foo WHERE name = '" + os.Args[0] + "' AND 1=1" // will result in "SELECT * FROM foo WHERE ” AND 1=1"
//
// ```
func GetStringRecursive(n ast.Node) (string, error) {
if node, ok := n.(*ast.BasicLit); ok && node.Kind == token.STRING {
return strconv.Unquote(node.Value)
}

if expr, ok := n.(*ast.BinaryExpr); ok {
var err error
x, xerr := GetStringRecursive(expr.X)
if xerr != nil {
err = fmt.Errorf("%w Error on X branch in recursion: %v", err, xerr)
audunmo marked this conversation as resolved.
Show resolved Hide resolved
}

y, yerr := GetStringRecursive(expr.Y)
if yerr != nil {
err = fmt.Errorf("%w Error on Y branch in recursion: %v", err, err)
audunmo marked this conversation as resolved.
Show resolved Hide resolved
}

return x + y, err
}

return "", nil
}

// GetString will read and return a string value from an ast.BasicLit
func GetString(n ast.Node) (string, error) {
if node, ok := n.(*ast.BasicLit); ok && node.Kind == token.STRING {
return strconv.Unquote(node.Value)
}

return "", fmt.Errorf("Unexpected AST node type: %T", n)
}

Expand Down Expand Up @@ -201,22 +238,21 @@ func GetCallStringArgsValues(n ast.Node, _ *Context) []string {
return values
}

// GetIdentStringValues return the string values of an Ident if they can be resolved
func GetIdentStringValues(ident *ast.Ident) []string {
func getIdentStringValues(ident *ast.Ident, stringFinder func(ast.Node) (string, error)) []string {
values := []string{}
obj := ident.Obj
if obj != nil {
switch decl := obj.Decl.(type) {
case *ast.ValueSpec:
for _, v := range decl.Values {
value, err := GetString(v)
value, err := stringFinder(v)
if err == nil {
values = append(values, value)
}
}
case *ast.AssignStmt:
for _, v := range decl.Rhs {
value, err := GetString(v)
value, err := stringFinder(v)
if err == nil {
values = append(values, value)
}
Expand All @@ -226,6 +262,18 @@ func GetIdentStringValues(ident *ast.Ident) []string {
return values
}

// getIdentStringRecursive returns the string of values of an Ident if they can be resolved
// The difference between this and GetIdentStringValues is that it will attempt to resolve the strings recursively,
// if it is passed a *ast.BinaryExpr. See GetStringRecursive for details
func GetIdentStringValuesRecursive(ident *ast.Ident) []string {
return getIdentStringValues(ident, GetStringRecursive)
}

// GetIdentStringValues return the string values of an Ident if they can be resolved
func GetIdentStringValues(ident *ast.Ident) []string {
return getIdentStringValues(ident, GetString)
}

// GetBinaryExprOperands returns all operands of a binary expression by traversing
// the expression tree
func GetBinaryExprOperands(be *ast.BinaryExpr) []ast.Node {
Expand Down
51 changes: 50 additions & 1 deletion rules/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,32 @@ func (s *sqlStrConcat) ID() string {
return s.MetaData.ID
}

// findInjectionInBranch walks diwb a set if expressions, and will create new issues if it finds SQL injections
// This method assumes you've already verified that the branch contains SQL syntax
func (s *sqlStrConcat) findInjectionInBranch(ctx *gosec.Context, branch []ast.Expr) *ast.BinaryExpr {
for _, node := range branch {
be, ok := node.(*ast.BinaryExpr)
if !ok {
continue
}

operands := gosec.GetBinaryExprOperands(be)

for _, op := range operands {
if _, ok := op.(*ast.BasicLit); ok {
continue
}

if ident, ok := op.(*ast.Ident); ok && s.checkObject(ident, ctx) {
continue
}

return be
}
}
return nil
}

// see if we can figure out what it is
func (s *sqlStrConcat) checkObject(n *ast.Ident, c *gosec.Context) bool {
if n.Obj != nil {
Expand Down Expand Up @@ -140,6 +166,28 @@ func (s *sqlStrConcat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*issu
}
}

// Handle the case where an injection occurs as an infixed string concatenation, ie "SELECT * FROM foo WHERE name = '" + os.Args[0] + "' AND 1=1"
if id, ok := query.(*ast.Ident); ok {
var match bool
for _, str := range gosec.GetIdentStringValuesRecursive(id) {
if s.MatchPatterns(str) {
match = true
break
}
}

if !match {
return nil, nil
}

switch decl := id.Obj.Decl.(type) {
case *ast.AssignStmt:
if injection := s.findInjectionInBranch(ctx, decl.Rhs); injection != nil {
return ctx.NewIssue(injection, s.ID(), s.What, s.Severity, s.Confidence), nil
}
}
}

return nil, nil
}

Expand All @@ -157,6 +205,7 @@ func (s *sqlStrConcat) Match(n ast.Node, ctx *gosec.Context) (*issue.Issue, erro
return s.checkQuery(sqlQueryCall, ctx)
}
}

return nil, nil
}

Expand All @@ -165,7 +214,7 @@ func NewSQLStrConcat(id string, _ gosec.Config) (gosec.Rule, []ast.Node) {
rule := &sqlStrConcat{
sqlStatement: sqlStatement{
patterns: []*regexp.Regexp{
regexp.MustCompile(`(?i)(SELECT|DELETE|INSERT|UPDATE|INTO|FROM|WHERE) `),
regexp.MustCompile("(?i)(SELECT|DELETE|INSERT|UPDATE|INTO|FROM|WHERE)( |\n|\r|\t)"),
},
MetaData: issue.MetaData{
ID: id,
Expand Down
52 changes: 42 additions & 10 deletions testutils/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -1596,6 +1596,28 @@ func main() {
// SampleCodeG202 - SQL query string building via string concatenation
SampleCodeG202 = []CodeSample{
{[]string{`
// infixed concatenation
package main

import (
"database/sql"
"os"
)

func main(){
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
panic(err)
}

q := "INSERT INTO foo (name) VALUES ('" + os.Args[0] + "')"
rows, err := db.Query(q)
if err != nil {
panic(err)
}
defer rows.Close()
}`}, 1, gosec.NewConfig()},
{[]string{`
package main

import (
Expand All @@ -1613,7 +1635,8 @@ func main(){
panic(err)
}
defer rows.Close()
}`}, 1, gosec.NewConfig()}, {[]string{`
}`}, 1, gosec.NewConfig()},
{[]string{`
// case insensitive match
package main

Expand All @@ -1632,7 +1655,8 @@ func main(){
panic(err)
}
defer rows.Close()
}`}, 1, gosec.NewConfig()}, {[]string{`
}`}, 1, gosec.NewConfig()},
{[]string{`
// context match
package main

Expand All @@ -1652,7 +1676,8 @@ func main(){
panic(err)
}
defer rows.Close()
}`}, 1, gosec.NewConfig()}, {[]string{`
}`}, 1, gosec.NewConfig()},
{[]string{`
// DB transaction check
package main

Expand Down Expand Up @@ -1680,7 +1705,8 @@ func main(){
if err := tx.Commit(); err != nil {
panic(err)
}
}`}, 1, gosec.NewConfig()}, {[]string{`
}`}, 1, gosec.NewConfig()},
{[]string{`
// multiple string concatenation
package main

Expand All @@ -1699,7 +1725,8 @@ func main(){
panic(err)
}
defer rows.Close()
}`}, 1, gosec.NewConfig()}, {[]string{`
}`}, 1, gosec.NewConfig()},
{[]string{`
// false positive
package main

Expand All @@ -1718,7 +1745,8 @@ func main(){
panic(err)
}
defer rows.Close()
}`}, 0, gosec.NewConfig()}, {[]string{`
}`}, 0, gosec.NewConfig()},
{[]string{`
package main

import (
Expand All @@ -1740,7 +1768,8 @@ func main(){
}
defer rows.Close()
}
`}, 0, gosec.NewConfig()}, {[]string{`
`}, 0, gosec.NewConfig()},
{[]string{`
package main

const gender = "M"
Expand All @@ -1766,7 +1795,8 @@ func main(){
}
defer rows.Close()
}
`}, 0, gosec.NewConfig()}, {[]string{`
`}, 0, gosec.NewConfig()},
{[]string{`
// ExecContext match
package main

Expand All @@ -1787,7 +1817,8 @@ func main() {
panic(err)
}
fmt.Println(result)
}`}, 1, gosec.NewConfig()}, {[]string{`
}`}, 1, gosec.NewConfig()},
{[]string{`
// Exec match
package main

Expand All @@ -1807,7 +1838,8 @@ func main() {
panic(err)
}
fmt.Println(result)
}`}, 1, gosec.NewConfig()}, {[]string{`
}`}, 1, gosec.NewConfig()},
audunmo marked this conversation as resolved.
Show resolved Hide resolved
{[]string{`
package main

import (
Expand Down
Loading