diff --git a/server/events/command_runner.go b/server/events/command_runner.go index 24c717a01d..cb15e75eb7 100644 --- a/server/events/command_runner.go +++ b/server/events/command_runner.go @@ -22,6 +22,7 @@ import ( "github.com/pkg/errors" "github.com/runatlantis/atlantis/server/events/models" "github.com/runatlantis/atlantis/server/events/vcs" + "github.com/runatlantis/atlantis/server/events/yaml/valid" "github.com/runatlantis/atlantis/server/logging" "github.com/runatlantis/atlantis/server/recovery" gitlab "github.com/xanzy/go-gitlab" @@ -95,6 +96,7 @@ type DefaultCommandRunner struct { DisableAutoplan bool EventParser EventParsing Logger logging.SimpleLogging + GlobalCfg valid.GlobalCfg // AllowForkPRs controls whether we operate on pull requests from forks. AllowForkPRs bool // ParallelPoolSize controls the size of the wait group used to run @@ -320,6 +322,13 @@ func (c *DefaultCommandRunner) validateCtxAndComment(ctx *CommandContext) bool { } return false } + + repo := c.GlobalCfg.MatchingRepo(ctx.Pull.BaseRepo.ID()) + if !repo.BranchMatches(ctx.Pull.BaseBranch) { + ctx.Log.Info("command was run on a pull request which doesn't match base branches") + // just ignore it to allow us to use any git workflows without malicious intentions. + return false + } return true } diff --git a/server/events/command_runner_test.go b/server/events/command_runner_test.go index a236eb30d0..7c1fabb08c 100644 --- a/server/events/command_runner_test.go +++ b/server/events/command_runner_test.go @@ -16,6 +16,7 @@ package events_test import ( "errors" "fmt" + "regexp" "strings" "testing" @@ -182,6 +183,8 @@ func setup(t *testing.T) *vcsmocks.MockClient { When(preWorkflowHooksCommandRunner.RunPreHooks(matchers.AnyPtrToEventsCommandContext())).ThenReturn(nil) + gCfg := valid.NewGlobalCfgFromArgs(valid.GlobalCfgArgs{}) + ch = events.DefaultCommandRunner{ VCSClient: vcsClient, CommentCommandRunnerByCmd: commentCommandRunnerByCmd, @@ -190,6 +193,7 @@ func setup(t *testing.T) *vcsmocks.MockClient { GitlabMergeRequestGetter: gitlabGetter, AzureDevopsPullGetter: azuredevopsGetter, Logger: logger, + GlobalCfg: gCfg, AllowForkPRs: false, AllowForkPRsFlag: "allow-fork-prs-flag", Drainer: drainer, @@ -404,6 +408,40 @@ func TestRunCommentCommand_ClosedPull(t *testing.T) { vcsClient.VerifyWasCalledOnce().CreateComment(fixtures.GithubRepo, modelPull.Num, "Atlantis commands can't be run on closed pull requests", "") } +func TestRunCommentCommand_MatchedBranch(t *testing.T) { + t.Log("if a command is run on a pull request which matches base branches run plan successfully") + vcsClient := setup(t) + + ch.GlobalCfg.Repos = append(ch.GlobalCfg.Repos, valid.Repo{ + IDRegex: regexp.MustCompile(".*"), + BranchRegex: regexp.MustCompile("^main$"), + }) + var pull github.PullRequest + modelPull := models.PullRequest{BaseRepo: fixtures.GithubRepo, BaseBranch: "main"} + When(githubGetter.GetPullRequest(fixtures.GithubRepo, fixtures.Pull.Num)).ThenReturn(&pull, nil) + When(eventParsing.ParseGithubPull(&pull)).ThenReturn(modelPull, modelPull.BaseRepo, fixtures.GithubRepo, nil) + + ch.RunCommentCommand(fixtures.GithubRepo, nil, nil, fixtures.User, fixtures.Pull.Num, &events.CommentCommand{Name: models.PlanCommand}) + vcsClient.VerifyWasCalledOnce().CreateComment(fixtures.GithubRepo, modelPull.Num, "Ran Plan for 0 projects:\n\n\n\n", "plan") +} + +func TestRunCommentCommand_UnmatchedBranch(t *testing.T) { + t.Log("if a command is run on a pull request which doesn't match base branches do not comment with error") + vcsClient := setup(t) + + ch.GlobalCfg.Repos = append(ch.GlobalCfg.Repos, valid.Repo{ + IDRegex: regexp.MustCompile(".*"), + BranchRegex: regexp.MustCompile("^main$"), + }) + var pull github.PullRequest + modelPull := models.PullRequest{BaseRepo: fixtures.GithubRepo, BaseBranch: "foo"} + When(githubGetter.GetPullRequest(fixtures.GithubRepo, fixtures.Pull.Num)).ThenReturn(&pull, nil) + When(eventParsing.ParseGithubPull(&pull)).ThenReturn(modelPull, modelPull.BaseRepo, fixtures.GithubRepo, nil) + + ch.RunCommentCommand(fixtures.GithubRepo, nil, nil, fixtures.User, fixtures.Pull.Num, &events.CommentCommand{Name: models.PlanCommand}) + vcsClient.VerifyWasCalled(Never()).CreateComment(matchers.AnyModelsRepo(), AnyInt(), AnyString(), AnyString()) +} + func TestRunUnlockCommand_VCSComment(t *testing.T) { t.Log("if unlock PR command is run, atlantis should" + " invoke the delete command and comment on PR accordingly") diff --git a/server/events/pre_workflow_hooks_command_runner.go b/server/events/pre_workflow_hooks_command_runner.go index 1f110db1d4..7619fd9a82 100644 --- a/server/events/pre_workflow_hooks_command_runner.go +++ b/server/events/pre_workflow_hooks_command_runner.go @@ -34,7 +34,7 @@ func (w *DefaultPreWorkflowHooksCommandRunner) RunPreHooks( preWorkflowHooks := make([]*valid.PreWorkflowHook, 0) for _, repo := range w.GlobalCfg.Repos { - if repo.IDMatches(baseRepo.ID()) && repo.BranchMatches(pull.BaseBranch) && len(repo.PreWorkflowHooks) > 0 { + if repo.IDMatches(baseRepo.ID()) && len(repo.PreWorkflowHooks) > 0 { preWorkflowHooks = append(preWorkflowHooks, repo.PreWorkflowHooks...) } } diff --git a/server/events/yaml/valid/global_cfg.go b/server/events/yaml/valid/global_cfg.go index c2c4ced7db..81068d11ab 100644 --- a/server/events/yaml/valid/global_cfg.go +++ b/server/events/yaml/valid/global_cfg.go @@ -467,3 +467,15 @@ func (g GlobalCfg) getMatchingCfg(log logging.SimpleLogging, repoID string) (app } return } + +// MatchingRepo returns an instance of Repo which matches a given repoID. +// If multiple repos match, return the last one for consistency with getMatchingCfg. +func (g GlobalCfg) MatchingRepo(repoID string) *Repo { + for i := len(g.Repos) - 1; i >= 0; i-- { + repo := g.Repos[i] + if repo.IDMatches(repoID) { + return &repo + } + } + return nil +} diff --git a/server/events/yaml/valid/global_cfg_test.go b/server/events/yaml/valid/global_cfg_test.go index 475cf7dfa4..f317a81015 100644 --- a/server/events/yaml/valid/global_cfg_test.go +++ b/server/events/yaml/valid/global_cfg_test.go @@ -892,6 +892,69 @@ func TestRepo_BranchMatches(t *testing.T) { Equals(t, false, (valid.Repo{BranchRegex: regexp.MustCompile("release")}).BranchMatches("main")) } +func TestGlobalCfg_MatchingRepo(t *testing.T) { + defaultRepo := valid.Repo{ + IDRegex: regexp.MustCompile(".*"), + BranchRegex: regexp.MustCompile(".*"), + ApplyRequirements: []string{}, + } + repo1 := valid.Repo{ + IDRegex: regexp.MustCompile(".*"), + BranchRegex: regexp.MustCompile("^main$"), + ApplyRequirements: []string{"approved"}, + } + repo2 := valid.Repo{ + ID: "github.com/owner/repo", + BranchRegex: regexp.MustCompile("^master$"), + ApplyRequirements: []string{"approved", "mergeable"}, + } + + cases := map[string]struct { + gCfg valid.GlobalCfg + repoID string + exp *valid.Repo + }{ + "matches to default": { + gCfg: valid.GlobalCfg{ + Repos: []valid.Repo{ + defaultRepo, + repo2, + }, + }, + repoID: "foo", + exp: &defaultRepo, + }, + "matches to IDRegex": { + gCfg: valid.GlobalCfg{ + Repos: []valid.Repo{ + defaultRepo, + repo1, + repo2, + }, + }, + repoID: "foo", + exp: &repo1, + }, + "matches to ID": { + gCfg: valid.GlobalCfg{ + Repos: []valid.Repo{ + defaultRepo, + repo1, + repo2, + }, + }, + repoID: "github.com/owner/repo", + exp: &repo2, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + Equals(t, c.exp, c.gCfg.MatchingRepo(c.repoID)) + }) + } +} + // String is a helper routine that allocates a new string value // to store v and returns a pointer to it. func String(v string) *string { return &v } diff --git a/server/server.go b/server/server.go index fd90e15daa..1a05582843 100644 --- a/server/server.go +++ b/server/server.go @@ -599,6 +599,7 @@ func NewServer(userConfig UserConfig, config Config) (*Server, error) { CommentCommandRunnerByCmd: commentCommandRunnerByCmd, EventParser: eventParser, Logger: logger, + GlobalCfg: globalCfg, AllowForkPRs: userConfig.AllowForkPRs, AllowForkPRsFlag: config.AllowForkPRsFlag, SilenceForkPRErrors: userConfig.SilenceForkPRErrors,