Skip to content

Commit

Permalink
Additional headers for WS accept message.
Browse files Browse the repository at this point in the history
  • Loading branch information
matiuszka committed Dec 15, 2021
1 parent c8b9581 commit 128e35f
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 4 deletions.
3 changes: 1 addition & 2 deletions docs/websockets.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

Starlette includes a `WebSocket` class that fulfils a similar role
to the HTTP request, but that allows sending and receiving data on a websocket.

Expand Down Expand Up @@ -51,7 +50,7 @@ For example: `websocket.path_params['username']`

### Accepting the connection

* `await websocket.accept(subprotocol=None)`
* `await websocket.accept(subprotocol=None, headers=None)`

### Sending data

Expand Down
2 changes: 2 additions & 0 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ def __init__(
self.app = app
self.scope = scope
self.accepted_subprotocol = None
self.additional_headers = None
self.portal_factory = portal_factory
self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue()
self._send_queue: "queue.Queue[typing.Any]" = queue.Queue()
Expand All @@ -313,6 +314,7 @@ def __enter__(self) -> "WebSocketTestSession":
self.exit_stack.close()
raise
self.accepted_subprotocol = message.get("subprotocol", None)
self.additional_headers = message.get("headers", None)
return self

def __exit__(self, *args: typing.Any) -> None:
Expand Down
10 changes: 8 additions & 2 deletions starlette/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,17 @@ async def send(self, message: Message) -> None:
else:
raise RuntimeError('Cannot call "send" once a close message has been sent.')

async def accept(self, subprotocol: str = None) -> None:
async def accept(
self,
subprotocol: str = None,
headers: typing.Iterable[typing.Tuple[bytes, bytes]] = None,
) -> None:
if self.client_state == WebSocketState.CONNECTING:
# If we haven't yet seen the 'connect' message, then wait for it first.
await self.receive()
await self.send({"type": "websocket.accept", "subprotocol": subprotocol})
await self.send(
{"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers}
)

def _raise_on_disconnect(self, message: Message) -> None:
if message["type"] == "websocket.disconnect":
Expand Down
14 changes: 14 additions & 0 deletions tests/test_websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,20 @@ async def asgi(receive, send):
assert websocket.accepted_subprotocol == "wamp"


def test_additional_headers(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept(headers=[(b"additional", b"header")])
await websocket.close()

return asgi

client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
websocket.additional_headers = [(b"additional", b"header")]


def test_websocket_exception(test_client_factory):
def app(scope):
async def asgi(receive, send):
Expand Down

0 comments on commit 128e35f

Please sign in to comment.