Skip to content

Commit

Permalink
Fix magic label tagged job not picked up by runner
Browse files Browse the repository at this point in the history
  • Loading branch information
Tereius committed Aug 6, 2024
1 parent 7829b77 commit 625e5bd
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 36 deletions.
2 changes: 1 addition & 1 deletion compute.tf
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ resource "google_compute_project_metadata_item" "startup_scripts_register_jit_ru
agent_name=$(hostname)
echo "Setup of agent '$agent_name' started"
apt-get update && apt-get -y install docker.io docker-buildx curl jq ${local.github_runner_package_install}
useradd -d /home/agent -u ${var.github_runner_uid} -g ${var.github_runner_gid} agent
useradd -d /home/agent -u ${var.github_runner_uid} agent
usermod -aG docker agent
newgrp docker
curl -s -o /tmp/agent.tar.gz -L '${var.github_runner_download_url}'
Expand Down
83 changes: 55 additions & 28 deletions runner-autoscaler/pkg/srv.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"math/rand"
"net/http"
"net/url"
"regexp"
"strings"

cloudtasks "cloud.google.com/go/cloudtasks/apiv2"
Expand Down Expand Up @@ -92,25 +93,45 @@ func (j Job) hasLabel(label string) bool {
return false
}

func (j Job) getMagicLabel(key string) *string {
type MagicLabel string

labelKey := "@" + key + ":"
const (
MagicLabelMachine MagicLabel = "machine"
)

var magicLabels = []string{string(MagicLabelMachine)}
var matchMagicLabels = regexp.MustCompile(`@(` + strings.Join(magicLabels, "|") + `):`)

func IsMagicLabel(label string) bool {

if matches := matchMagicLabels.FindStringSubmatch(label); len(matches) >= 2 {
return true
}
return false
}

func (j Job) GetMagicLabelValue(key MagicLabel) *string {

matchMagicLabel := regexp.MustCompile("@(" + string(key) + "):(.+)")
for _, l := range j.Labels {
if strings.HasPrefix(l, labelKey) {
ret := l[len(labelKey):]
matches := matchMagicLabel.FindStringSubmatch(l)
if len(matches) >= 3 {
ret := matches[2]
return &ret
}
}
return nil
}

// returns true if all labels were found. false otherwise. Returns also all labels that were missing
func (j Job) hasAllLabels(labels []string) (bool, []string) {
// returns true if all labels were found (excluding magic labels). false otherwise. Returns also all labels that were missing
func (j Job) HasAllLabels(labels []string) (bool, []string) {

missingLabels := []string{}
for _, label := range labels {
if !j.hasLabel(label) {
missingLabels = append(missingLabels, label)
if !IsMagicLabel(label) {
if !j.hasLabel(label) {
missingLabels = append(missingLabels, label)
}
}
}
return len(missingLabels) <= 0, missingLabels
Expand Down Expand Up @@ -429,7 +450,8 @@ func (s *Autoscaler) GenerateRunnerRegistrationToken(ctx context.Context) (strin
}
}*/

func (s *Autoscaler) GenerateRunnerJitConfig(ctx context.Context, url string, runnerName string, runnerGroupId int64) (string, error) {
// A jit-config needs: RunnerName, RunnerGroupId, Labels, WorkFolder
func (s *Autoscaler) GenerateRunnerJitConfig(ctx context.Context, url string, runnerName string, runnerGroupId int64, labels []string) (string, error) {

log.Debugf("About to request GitHub runner %s jit config from %s (runner group %d) using PAT from secret version: %s", runnerName, url, runnerGroupId, s.conf.SecretVersion)
secretAccessClient := newSecretAccessClient(ctx)
Expand All @@ -440,7 +462,7 @@ func (s *Autoscaler) GenerateRunnerJitConfig(ctx context.Context, url string, ru
reqPayload := map[string]any{}
reqPayload["name"] = runnerName
reqPayload["runner_group_id"] = runnerGroupId
reqPayload["labels"] = s.conf.RunnerLabels
reqPayload["labels"] = labels
reqPayload["work_folder"] = "_work"
data, _ := json.Marshal(reqPayload)
if req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(data)); err != nil {
Expand Down Expand Up @@ -476,9 +498,9 @@ func (s *Autoscaler) GenerateRunnerJitConfig(ctx context.Context, url string, ru
}
}

func (s *Autoscaler) createCallbackTaskWithToken(ctx context.Context, url string, secret string, settings VmSettings) error {
func (s *Autoscaler) createCallbackTaskWithToken(ctx context.Context, url string, secret string, job Job) error {

data, _ := json.Marshal(settings)
data, _ := json.Marshal(job)
now := timestamppb.Now()
now.Seconds += 1 // delay the callback a little bit
req := &taskspb.CreateTaskRequest{
Expand Down Expand Up @@ -545,9 +567,9 @@ func (s *Autoscaler) createVmWithRegistrationToken(ctx *gin.Context, instanceNam
}
}*/

func (s *Autoscaler) createVmWithJitConfig(ctx *gin.Context, url string, runnerGroupId int64, settings VmSettings) {
func (s *Autoscaler) createVmWithJitConfig(ctx *gin.Context, url string, runnerGroupId int64, settings VmSettings, labels []string) {

if jitConfig, err := s.GenerateRunnerJitConfig(ctx, url, settings.Name, runnerGroupId); err != nil {
if jitConfig, err := s.GenerateRunnerJitConfig(ctx, url, settings.Name, runnerGroupId, labels); err != nil {
ctx.AbortWithError(http.StatusInternalServerError, err)
} else {
jit_config_attr := fmt.Sprintf("%s_%s", RUNNER_JIT_CONFIG_ATTR, RandStringRunes(16))
Expand All @@ -569,19 +591,29 @@ func (s *Autoscaler) handleCreateVm(ctx *gin.Context) {

log.Info("Received create-vm cloud task callback")
if data, src, err := s.verifySignature(ctx); err == nil {
vmSettings := VmSettings{}
json.Unmarshal(data, &vmSettings)
job := Job{}
json.Unmarshal(data, &job)
// use jit config
switch src.SourceType {
case TypeEnterprise:
log.Infof("Using jit config for runner registration for enterprise: %s", src.Name)
s.createVmWithJitConfig(ctx, fmt.Sprintf(RUNNER_ENTERPRISE_JIT_CONFIG_ENDPOINT, src.Name), s.conf.RunnerGroupId, vmSettings)
s.createVmWithJitConfig(ctx, fmt.Sprintf(RUNNER_ENTERPRISE_JIT_CONFIG_ENDPOINT, src.Name), s.conf.RunnerGroupId, VmSettings{
Name: fmt.Sprintf("%s-%s", s.conf.RunnerPrefix, RandStringRunes(10)),
MachineType: job.GetMagicLabelValue(MagicLabelMachine),
}, job.Labels)
case TypeOrganization:
log.Infof("Using jit config for runner registration for organization: %s", src.Name)
s.createVmWithJitConfig(ctx, fmt.Sprintf(RUNNER_ORG_JIT_CONFIG_ENDPOINT, src.Name), s.conf.RunnerGroupId, vmSettings)
s.createVmWithJitConfig(ctx, fmt.Sprintf(RUNNER_ORG_JIT_CONFIG_ENDPOINT, src.Name), s.conf.RunnerGroupId, VmSettings{
Name: fmt.Sprintf("%s-%s", s.conf.RunnerPrefix, RandStringRunes(10)),
MachineType: job.GetMagicLabelValue(MagicLabelMachine),
}, job.Labels)
case TypeRepository:
log.Infof("Using jit config for runner registration for repository: %s", src.Name)
s.createVmWithJitConfig(ctx, fmt.Sprintf(RUNNER_REPO_JIT_CONFIG_ENDPOINT, src.Name), 1, vmSettings) // For repositories there is an implicit runner group with id 1
// For repositories there is an implicit runner group with id 1
s.createVmWithJitConfig(ctx, fmt.Sprintf(RUNNER_REPO_JIT_CONFIG_ENDPOINT, src.Name), 1, VmSettings{
Name: fmt.Sprintf("%s-%s", s.conf.RunnerPrefix, RandStringRunes(10)),
MachineType: job.GetMagicLabelValue(MagicLabelMachine),
}, job.Labels)
default:
log.Errorf("Missing source type for %s", src.Name)
ctx.Status(http.StatusBadRequest)
Expand Down Expand Up @@ -619,12 +651,9 @@ func (s *Autoscaler) handleWebhook(ctx *gin.Context) {
ctx.AbortWithError(http.StatusBadRequest, err)
} else {
if payload.Action == QUEUED {
if ok, missingLabels := payload.Job.hasAllLabels(s.conf.RunnerLabels); ok {
if ok, missingLabels := payload.Job.HasAllLabels(s.conf.RunnerLabels); ok {
createUrl := createCallbackUrl(ctx, s.conf.RouteCreateVm, s.conf.SourceQueryParam, src.Name)
if err := s.createCallbackTaskWithToken(ctx, createUrl, src.Secret, VmSettings{
Name: fmt.Sprintf("%s-%s", s.conf.RunnerPrefix, RandStringRunes(10)),
MachineType: payload.Job.getMagicLabel("machine"),
}); err != nil {
if err := s.createCallbackTaskWithToken(ctx, createUrl, src.Secret, payload.Job); err != nil {
log.Errorf("Can not enqueue create-vm cloud task callback: %s", err.Error())
ctx.AbortWithError(http.StatusInternalServerError, err)
return
Expand All @@ -638,11 +667,9 @@ func (s *Autoscaler) handleWebhook(ctx *gin.Context) {
runnerGroupId = 1
}
if payload.Job.RunnerGroupId == runnerGroupId {
if ok, missingLabels := payload.Job.hasAllLabels(s.conf.RunnerLabels); ok {
if ok, missingLabels := payload.Job.HasAllLabels(s.conf.RunnerLabels); ok {
deleteUrl := createCallbackUrl(ctx, s.conf.RouteDeleteVm, s.conf.SourceQueryParam, src.Name)
if err := s.createCallbackTaskWithToken(ctx, deleteUrl, src.Secret, VmSettings{
Name: payload.Job.RunnerName,
}); err != nil {
if err := s.createCallbackTaskWithToken(ctx, deleteUrl, src.Secret, payload.Job); err != nil {
log.Errorf("Can not enqueue delete-vm cloud task callback: %s", err.Error())
ctx.AbortWithError(http.StatusInternalServerError, err)
return
Expand Down
26 changes: 25 additions & 1 deletion runner-autoscaler/test/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,31 @@ func TestGenerateRunnerJitConfig(t *testing.T) {

ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
jitConfig, err := scaler.GenerateRunnerJitConfig(ctx, fmt.Sprintf(pkg.RUNNER_REPO_JIT_CONFIG_ENDPOINT, TEST_REPO), "unit_test_runner_"+pkg.RandStringRunes(10), 1)
jitConfig, err := scaler.GenerateRunnerJitConfig(ctx, fmt.Sprintf(pkg.RUNNER_REPO_JIT_CONFIG_ENDPOINT, TEST_REPO), "unit_test_runner_"+pkg.RandStringRunes(10), 1, []string{"self-hosted"})
assert.Nil(t, err)
assert.NotEmpty(t, jitConfig)
}

func TestGetMagicLabelValue(t *testing.T) {

job := pkg.Job{
Labels: []string{"test", "@foo:bar", "@machine:test"},
}
result := job.GetMagicLabelValue(pkg.MagicLabelMachine)
assert.NotNil(t, result)
assert.Equal(t, "test", *result)
}

func TestHasAllLabels(t *testing.T) {

job := pkg.Job{
Labels: []string{"test", "@foo:bar", "@machine:test"},
}
result, missing := job.HasAllLabels([]string{"test"})
assert.True(t, result)
assert.Empty(t, missing)
result, missing = job.HasAllLabels([]string{"test", "foo"})
assert.False(t, result)
assert.NotEmpty(t, missing)
assert.Len(t, missing, 1)
}
6 changes: 0 additions & 6 deletions variables.tf
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,6 @@ variable "github_runner_uid" {
default = 10000
}

variable "github_runner_gid" {
type = number
description = "The gid the runner will be run with."
default = 10000
}

variable "github_runner_packages" {
type = list(string)
description = "Additional packages that will be installed in the runner with apt."
Expand Down

0 comments on commit 625e5bd

Please sign in to comment.