Skip to content

Commit

Permalink
feat: allow user-defined client context (#110)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcieno authored Apr 29, 2024
1 parent fe11d78 commit ba56ed4
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 4 deletions.
9 changes: 9 additions & 0 deletions cmd/aws-lambda-rie/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package main

import (
"bytes"
"encoding/base64"
"fmt"
"io/ioutil"
"math"
Expand Down Expand Up @@ -81,6 +82,13 @@ func InvokeHandler(w http.ResponseWriter, r *http.Request, sandbox Sandbox, bs i
return
}

rawClientContext, err := base64.StdEncoding.DecodeString(r.Header.Get("X-Amz-Client-Context"))
if err != nil {
log.Errorf("Failed to decode X-Amz-Client-Context: %s", err)
w.WriteHeader(500)
return
}

initDuration := ""
inv := GetenvWithDefault("AWS_LAMBDA_FUNCTION_TIMEOUT", "300")
timeoutDuration, _ := time.ParseDuration(inv + "s")
Expand Down Expand Up @@ -114,6 +122,7 @@ func InvokeHandler(w http.ResponseWriter, r *http.Request, sandbox Sandbox, bs i
TraceID: r.Header.Get("X-Amzn-Trace-Id"),
LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"),
Payload: bytes.NewReader(bodyBytes),
ClientContext: string(rawClientContext),
}
fmt.Println("START RequestId: " + invokePayload.ID + " Version: " + functionVersion)

Expand Down
30 changes: 26 additions & 4 deletions test/integration/local_lambda/test_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from subprocess import Popen, PIPE
from unittest import TestCase, main
from pathlib import Path
import base64
import json
import time
import os
import requests
Expand Down Expand Up @@ -62,12 +64,14 @@ def run_command(self, cmd):

def sleep_1s(self):
time.sleep(SLEEP_TIME)
def invoke_function(self):

def invoke_function(self, json={}, headers={}):
return requests.post(
f"http://localhost:{self.PORT}/2015-03-31/functions/function/invocations", json={}
f"http://localhost:{self.PORT}/2015-03-31/functions/function/invocations",
json=json,
headers=headers,
)

@contextmanager
def create_container(self, param, image):
try:
Expand Down Expand Up @@ -234,6 +238,24 @@ def test_port_override(self):
self.assertEqual(b'"My lambda ran succesfully"', r.content)


def test_custom_client_context(self):
image, rie, image_name = self.tagged_name("custom_client_context")

params = f"--name {image} -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {self.PORT}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.custom_client_context_handler"

with self.create_container(params, image):
r = self.invoke_function(headers={
"X-Amz-Client-Context": base64.b64encode(json.dumps({
"custom": {
"foo": "bar",
"baz": 123,
}
}).encode('utf8')).decode('utf8'),
})
content = json.loads(r.content)
self.assertEqual("bar", content["foo"])
self.assertEqual(123, content["baz"])


if __name__ == "__main__":
main()
4 changes: 4 additions & 0 deletions test/integration/testdata/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,7 @@ def check_remaining_time_handler(event, context):
# Wait 1s to see if the remaining time changes
time.sleep(1)
return context.get_remaining_time_in_millis()


def custom_client_context_handler(event, context):
return context.client_context.custom

0 comments on commit ba56ed4

Please sign in to comment.