Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Problem: WebSockets were required to fetch logs #645

Merged
merged 14 commits into from
Jul 4, 2024
5 changes: 4 additions & 1 deletion src/aleph/vm/controllers/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from aleph.vm.controllers.firecracker.snapshots import CompressedDiskVolumeSnapshot
from aleph.vm.network.interfaces import TapInterface
from aleph.vm.utils.logs import make_logs_queue
from aleph.vm.utils.logs import get_past_vm_logs, make_logs_queue

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -118,3 +118,6 @@
@property
def _journal_stderr_name(self) -> str:
return f"vm-{self.vm_hash}-stderr"

def past_logs(self):
    """Yield the journal entries already recorded for this VM.

    Delegates to `get_past_vm_logs`, filtering on this controller's
    stdout and stderr journal identifiers.
    """
    stdout_name = self._journal_stdout_name
    stderr_name = self._journal_stderr_name
    yield from get_past_vm_logs(stdout_name, stderr_name)

Check warning on line 123 in src/aleph/vm/controllers/interface.py

View check run for this annotation

Codecov / codecov/patch

src/aleph/vm/controllers/interface.py#L123

Added line #L123 was not covered by tests
4 changes: 0 additions & 4 deletions src/aleph/vm/controllers/qemu/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,3 @@ async def teardown(self):
if self.tap_interface:
await self.tap_interface.delete()
await self.stop_guest_api()

def print_logs(self) -> None:
"""Print logs to our output for debugging"""
queue = self.get_log_queue()
4 changes: 3 additions & 1 deletion src/aleph/vm/orchestrator/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from .views.operator import (
operate_erase,
operate_expire,
operate_logs,
operate_reboot,
operate_stop,
stream_logs,
Expand Down Expand Up @@ -100,7 +101,8 @@ def setup_webapp():
web.get("/about/config", about_config),
# /control APIs are used to control the VMs and access their logs
web.post("/control/allocation/notify", notify_allocation),
web.get("/control/machine/{ref}/logs", stream_logs),
web.get("/control/machine/{ref}/stream_logs", stream_logs),
web.get("/control/machine/{ref}/logs", operate_logs),
web.post("/control/machine/{ref}/expire", operate_expire),
web.post("/control/machine/{ref}/stop", operate_stop),
web.post("/control/machine/{ref}/erase", operate_erase),
Expand Down
30 changes: 29 additions & 1 deletion src/aleph/vm/orchestrator/views/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
queue = None
try:
ws = web.WebSocketResponse()
logger.info(f"starting websocket: {request.path}")
await ws.prepare(request)
try:
await authenticate_websocket_for_vm_or_403(execution, vm_hash, ws)
Expand All @@ -75,6 +76,7 @@
while True:
log_type, message = await queue.get()
assert log_type in ("stdout", "stderr")
logger.debug(message)

await ws.send_json({"type": log_type, "message": message})

Expand All @@ -87,15 +89,41 @@
execution.vm.unregister_queue(queue)


@cors_allow_all
@require_jwk_authentication
async def operate_logs(request: web.Request, authenticated_sender: str) -> web.StreamResponse:
    """Return the logs already recorded for a VM as plain text (not streaming live).

    The VM is looked up from the item hash found in the request's match info.
    A sender that is not authorized for the execution's message gets a 403
    response. Otherwise every past log entry is written to the response as
    ``<ISO timestamp>> <message>``.
    """
    vm_hash = get_itemhash_or_400(request.match_info)
    vm_pool: VmPool = request.app["vm_pool"]
    execution = get_execution_or_404(vm_hash, pool=vm_pool)

    authorized = is_sender_authorized(authenticated_sender, execution.message)
    if not authorized:
        return web.Response(status=403, body="Unauthorized sender")

    # Headers must be set before `prepare()`, which sends them to the client.
    response = web.StreamResponse()
    response.headers["Content-Type"] = "text/plain"
    await response.prepare(request)

    for entry in execution.vm.past_logs():
        timestamp = entry["__REALTIME_TIMESTAMP"].isoformat()
        line = f'{timestamp}> {entry["MESSAGE"]}'
        await response.write(line.encode())

    await response.write_eof()
    return response


async def authenticate_websocket_for_vm_or_403(execution: VmExecution, vm_hash: ItemHash, ws: web.WebSocketResponse):
"""Authenticate a websocket connection.

Web browsers do not allow setting headers in WebSocket requests, so the authentication
relies on the first message sent by the client.
"""
first_message = await ws.receive_json()
try:
first_message = await ws.receive_json()
except TypeError as error:
logging.exception(error)
raise web.HTTPForbidden(body="Invalid auth package")
credentials = first_message["auth"]
authenticated_sender = await authenticate_websocket_message(credentials)

if is_sender_authorized(authenticated_sender, execution.message):
logger.debug(f"Accepted request to access logs by {authenticated_sender} on {vm_hash}")
return True
Expand Down
26 changes: 25 additions & 1 deletion src/aleph/vm/utils/logs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import logging
from typing import Callable, TypedDict
from datetime import datetime
from typing import Callable, Generator, TypedDict

from systemd import journal

Expand All @@ -10,6 +11,7 @@
# Shape of the journald entries handled by this module.
#
# Declared with the functional TypedDict syntax on purpose: inside a ``class``
# body the ``__REALTIME_TIMESTAMP`` annotation would be name-mangled to
# ``_EntryDict__REALTIME_TIMESTAMP``, so the declared key would not match the
# key that is actually read from the entries.
EntryDict = TypedDict(
    "EntryDict",
    {
        # journald identifier of the emitter (used as the match field when reading)
        "SYSLOG_IDENTIFIER": str,
        # text of the log entry
        "MESSAGE": str,
        # timestamp of the entry, exposed as a datetime object
        "__REALTIME_TIMESTAMP": datetime,
    },
)


def make_logs_queue(stdout_identifier, stderr_identifier, skip_past=False) -> tuple[asyncio.Queue, Callable[[], None]]:
Expand Down Expand Up @@ -56,3 +58,25 @@
r.close()

return queue, do_cancel


def get_past_vm_logs(stdout_identifier, stderr_identifier) -> Generator[EntryDict, None, None]:
    """Get existing log for the VM identifiers.

    @param stdout_identifier: journald identifier for process stdout
    @param stderr_identifier: journald identifier for process stderr
    @return: an iterator of log entry

    Works by creating a journald reader matching both identifiers, seeking to
    the head of the journal and yielding each entry the reader returns.
    The reader is closed when the generator is exhausted or closed early,
    so it does not stay open after the caller is done.

    For more information refer to the sd-journal(3) manpage
    and systemd.journal module documentation.
    """
    r = journal.Reader()
    try:
        # Both matches are on the same field; per sd_journal_add_match(3)
        # such matches are alternatives, so entries from either stream pass.
        r.add_match(SYSLOG_IDENTIFIER=stdout_identifier)
        r.add_match(SYSLOG_IDENTIFIER=stderr_identifier)
        r.seek_head()
        for entry in r:
            yield entry
    finally:
        # Also runs when the consumer abandons the generator midway.
        r.close()

Check warning on line 82 in src/aleph/vm/utils/logs.py

View check run for this annotation

Codecov / codecov/patch

src/aleph/vm/utils/logs.py#L82

Added line #L82 was not covered by tests
86 changes: 86 additions & 0 deletions src/aleph/vm/utils/test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import datetime
import json

import eth_account.messages
import pytest
from eth_account.datastructures import SignedMessage
from eth_account.signers.local import LocalAccount
from jwcrypto import jwk
from jwcrypto.jwa import JWA


@pytest.fixture
def patch_datetime_now(monkeypatch):
    """Freeze `datetime.datetime.now()` and `datetime.datetime.utcnow()`.

    Builds `MockDateTime`, a `datetime.datetime` subclass whose `now()` and
    `utcnow()` class methods always return the fixed `FAKE_TIME`, and installs
    it as the `datetime` attribute of the `datetime` module through
    `monkeypatch`. The class is returned so that tests can read `FAKE_TIME`.
    """

    class MockDateTime(datetime.datetime):
        FAKE_TIME = datetime.datetime(2010, 12, 25, 17, 5, 55)

        @classmethod
        def utcnow(cls, *args, **kwargs):
            return cls.FAKE_TIME

        @classmethod
        def now(cls, tz=None, *args, **kwargs):
            # The requested tz is attached as-is; the wall-clock value is not converted.
            frozen = cls.FAKE_TIME
            return frozen.replace(tzinfo=tz)

    monkeypatch.setattr(datetime, "datetime", MockDateTime)
    return MockDateTime


async def generate_signer_and_signed_headers_for_operation(
    patch_datetime_now, operation_payload: dict
) -> tuple[LocalAccount, dict]:
    """Create a throwaway eth_account and the signed headers for one operation.

    A fresh EC P-256 key is generated. Its public part, together with the
    account address and an expiry one day after `patch_datetime_now.FAKE_TIME`,
    is signed by the account and sent as `X-SignedPubKey`. The operation
    payload is signed with the generated key and sent as `X-SignedOperation`.

    @param patch_datetime_now: object exposing `FAKE_TIME` (the patched datetime class)
    @param operation_payload: JSON-serialisable description of the operation to sign
    @return: the signing account and the dict holding the two headers
    """
    signer_account = eth_account.Account().create()
    ephemeral_key = jwk.JWK.generate(
        kty="EC",
        crv="P-256",
        # key_ops=["verify"],
    )

    # Step 1: describe the ephemeral key and have the account sign that description.
    expires_at = patch_datetime_now.FAKE_TIME + datetime.timedelta(days=1)
    pubkey = {
        "pubkey": json.loads(ephemeral_key.export_public()),
        "alg": "ECDSA",
        "domain": "localhost",
        "address": signer_account.address,
        "expires": expires_at.isoformat() + "Z",
    }
    pubkey_payload = json.dumps(pubkey).encode("utf-8").hex()
    signable_message = eth_account.messages.encode_defunct(hexstr=pubkey_payload)
    signed_message: SignedMessage = signer_account.sign_message(signable_message)
    signed_pubkey_header = json.dumps(
        {
            "payload": pubkey_payload,
            "signature": to_0x_hex(signed_message.signature),
        }
    )

    # Step 2: sign the operation itself with the ephemeral key.
    payload_as_bytes = json.dumps(operation_payload).encode("utf-8")
    payload_signature = JWA.signing_alg("ES256").sign(ephemeral_key, payload_as_bytes)
    signed_operation_header = json.dumps(
        {
            "payload": payload_as_bytes.hex(),
            "signature": payload_signature.hex(),
        }
    )

    headers = {
        "X-SignedPubKey": signed_pubkey_header,
        "X-SignedOperation": signed_operation_header,
    }
    return signer_account, headers


def to_0x_hex(b: bytes) -> str:
    """Return the hexadecimal form of *b* as a string prefixed with ``0x``."""
    # `bytes.hex` is called explicitly (rather than `b.hex()`) for compatibility
    # between different hexbytes versions, which behave differently; conflicts
    # with other packages don't allow us to pin the version we want.
    digits = bytes.hex(b)
    return f"0x{digits}"
81 changes: 7 additions & 74 deletions tests/supervisor/test_authentication.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import datetime
import json
from typing import Any

Expand All @@ -8,22 +7,16 @@
from eth_account.datastructures import SignedMessage
from jwcrypto import jwk, jws
from jwcrypto.common import base64url_decode
from jwcrypto.jwa import JWA

from aleph.vm.orchestrator.views.authentication import (
authenticate_jwk,
require_jwk_authentication,
)


def to_0x_hex(b: bytes) -> str:
"""
Convert the bytes to a 0x-prefixed hex string
"""

# force this for compat between different hexbytes versions which behave differenty
# and conflict with other package don't allow us to have the version we want
return "0x" + bytes.hex(b)
from aleph.vm.utils.test_helpers import (
generate_signer_and_signed_headers_for_operation,
patch_datetime_now,
to_0x_hex,
)


@pytest.mark.asyncio
Expand Down Expand Up @@ -67,30 +60,6 @@ async def view(request, authenticated_sender):
assert {"error": "Invalid X-SignedPubKey format"} == r


@pytest.fixture
def patch_datetime_now(monkeypatch):
"""Fixture for patching the datetime.now() and datetime.utcnow() methods
to return a fixed datetime object.
This fixture creates a subclass of `datetime.datetime` called `mydatetime`,
which overrides the `now()` and `utcnow()` class methods to return a fixed
datetime object specified by `FAKE_TIME`.
"""

class MockDateTime(datetime.datetime):
FAKE_TIME = datetime.datetime(2010, 12, 25, 17, 5, 55)

@classmethod
def now(cls, tz=None, *args, **kwargs):
return cls.FAKE_TIME.replace(tzinfo=tz)

@classmethod
def utcnow(cls, *args, **kwargs):
return cls.FAKE_TIME

monkeypatch.setattr(datetime, "datetime", MockDateTime)
return MockDateTime


@pytest.mark.asyncio
async def test_require_jwk_authentication_expired(aiohttp_client):
app = web.Application()
Expand Down Expand Up @@ -257,32 +226,8 @@ async def test_require_jwk_authentication_good_key(aiohttp_client, patch_datetim
"""An HTTP request to a view decorated by `@require_jwk_authentication`
auth correctly a temporary key signed by a wallet and an operation signed by that key"""
app = web.Application()

account = eth_account.Account()
signer_account = account.create()
key = jwk.JWK.generate(
kty="EC",
crv="P-256",
# key_ops=["verify"],
)

pubkey = {
"pubkey": json.loads(key.export_public()),
"alg": "ECDSA",
"domain": "localhost",
"address": signer_account.address,
"expires": (patch_datetime_now.FAKE_TIME + datetime.timedelta(days=1)).isoformat() + "Z",
}
pubkey_payload = json.dumps(pubkey).encode("utf-8").hex()
signable_message = eth_account.messages.encode_defunct(hexstr=pubkey_payload)
signed_message: SignedMessage = signer_account.sign_message(signable_message)
pubkey_signature = to_0x_hex(signed_message.signature)
pubkey_signature_header = json.dumps(
{
"payload": pubkey_payload,
"signature": pubkey_signature,
}
)
payload = {"time": "2010-12-25T17:05:55Z", "method": "GET", "path": "/"}
signer_account, headers = await generate_signer_and_signed_headers_for_operation(patch_datetime_now, payload)

@require_jwk_authentication
async def view(request, authenticated_sender):
Expand All @@ -292,18 +237,6 @@ async def view(request, authenticated_sender):
app.router.add_get("", view)
client = await aiohttp_client(app)

payload = {"time": "2010-12-25T17:05:55Z", "method": "GET", "path": "/"}

payload_as_bytes = json.dumps(payload).encode("utf-8")
headers = {"X-SignedPubKey": pubkey_signature_header}
payload_signature = JWA.signing_alg("ES256").sign(key, payload_as_bytes)
headers["X-SignedOperation"] = json.dumps(
{
"payload": payload_as_bytes.hex(),
"signature": payload_signature.hex(),
}
)

resp = await client.get("/", headers=headers)
assert resp.status == 200, await resp.text()

Expand Down
Loading