Skip to content

Commit

Permalink
Support SSE (#461)
Browse files Browse the repository at this point in the history
* Support SSE

* Fix merge

* Add sse everywhere

* async for

* Add accept headers

* Fix merge

* Revert poetry.lock

* Temporary .fernignore
  • Loading branch information
billytrend-cohere authored Apr 17, 2024
1 parent c56b135 commit 6954548
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 16 deletions.
2 changes: 2 additions & 0 deletions .fernignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@ src/cohere/utils.py
src/cohere/overrides.py
src/cohere/config.py
src/cohere/manually_maintained

src/cohere/base_client.py
57 changes: 41 additions & 16 deletions src/cohere/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from json.decoder import JSONDecodeError

import httpx
from httpx_sse import EventSource

from .connectors.client import AsyncConnectorsClient, ConnectorsClient
from .core.api_error import ApiError
Expand Down Expand Up @@ -345,6 +346,7 @@ def chat_stream(
headers=jsonable_encoder(
remove_none_from_dict(
{
"Accept": "*/*, text/event-stream, application/stream+json",
**self._client_wrapper.get_headers(),
**(request_options.get("additional_headers", {}) if request_options is not None else {}),
}
Expand All @@ -357,10 +359,15 @@ def chat_stream(
max_retries=request_options.get("max_retries") if request_options is not None else 0, # type: ignore
) as _response:
if 200 <= _response.status_code < 300:
for _text in _response.iter_lines():
if len(_text) == 0:
continue
yield typing.cast(StreamedChatResponse, construct_type(type_=StreamedChatResponse, object_=json.loads(_text))) # type: ignore
try:
event_source = EventSource(_response)
for sse in event_source.iter_sse():
yield typing.cast(StreamedChatResponse, construct_type(type_=StreamedChatResponse, object_=json.loads(sse.data))) # type: ignore
except Exception:
for _text in _response.iter_lines():
if len(_text) == 0:
continue
yield typing.cast(StreamedChatResponse, construct_type(type_=StreamedChatResponse, object_=json.loads(_text))) # type: ignore
return
_response.read()
if _response.status_code == 429:
Expand Down Expand Up @@ -763,6 +770,7 @@ def generate_stream(
headers=jsonable_encoder(
remove_none_from_dict(
{
"Accept": "*/*, text/event-stream, application/stream+json",
**self._client_wrapper.get_headers(),
**(request_options.get("additional_headers", {}) if request_options is not None else {}),
}
Expand All @@ -775,10 +783,15 @@ def generate_stream(
max_retries=request_options.get("max_retries") if request_options is not None else 0, # type: ignore
) as _response:
if 200 <= _response.status_code < 300:
for _text in _response.iter_lines():
if len(_text) == 0:
continue
yield typing.cast(GenerateStreamedResponse, construct_type(type_=GenerateStreamedResponse, object_=json.loads(_text))) # type: ignore
try:
event_source = EventSource(_response)
for sse in event_source.iter_sse():
yield typing.cast(GenerateStreamedResponse, construct_type(type_=GenerateStreamedResponse, object_=json.loads(sse.data))) # type: ignore
except Exception:
for _text in _response.iter_lines():
if len(_text) == 0:
continue
yield typing.cast(GenerateStreamedResponse, construct_type(type_=GenerateStreamedResponse, object_=json.loads(_text))) # type: ignore
return
_response.read()
if _response.status_code == 400:
Expand Down Expand Up @@ -1844,6 +1857,7 @@ async def chat_stream(
headers=jsonable_encoder(
remove_none_from_dict(
{
"Accept": "*/*, text/event-stream, application/stream+json",
**self._client_wrapper.get_headers(),
**(request_options.get("additional_headers", {}) if request_options is not None else {}),
}
Expand All @@ -1856,10 +1870,15 @@ async def chat_stream(
max_retries=request_options.get("max_retries") if request_options is not None else 0, # type: ignore
) as _response:
if 200 <= _response.status_code < 300:
async for _text in _response.aiter_lines():
if len(_text) == 0:
continue
yield typing.cast(StreamedChatResponse, construct_type(type_=StreamedChatResponse, object_=json.loads(_text))) # type: ignore
try:
event_source = EventSource(_response)
async for sse in event_source.aiter_sse():
yield typing.cast(StreamedChatResponse, construct_type(type_=StreamedChatResponse, object_=json.loads(sse.data))) # type: ignore
except Exception:
async for _text in _response.aiter_lines():
if len(_text) == 0:
continue
yield typing.cast(StreamedChatResponse, construct_type(type_=StreamedChatResponse, object_=json.loads(_text))) # type: ignore
return
await _response.aread()
if _response.status_code == 429:
Expand Down Expand Up @@ -2262,6 +2281,7 @@ async def generate_stream(
headers=jsonable_encoder(
remove_none_from_dict(
{
"Accept": "*/*, text/event-stream, application/stream+json",
**self._client_wrapper.get_headers(),
**(request_options.get("additional_headers", {}) if request_options is not None else {}),
}
Expand All @@ -2274,10 +2294,15 @@ async def generate_stream(
max_retries=request_options.get("max_retries") if request_options is not None else 0, # type: ignore
) as _response:
if 200 <= _response.status_code < 300:
async for _text in _response.aiter_lines():
if len(_text) == 0:
continue
yield typing.cast(GenerateStreamedResponse, construct_type(type_=GenerateStreamedResponse, object_=json.loads(_text))) # type: ignore
try:
event_source = EventSource(_response)
async for sse in event_source.aiter_sse():
yield typing.cast(GenerateStreamedResponse, construct_type(type_=GenerateStreamedResponse, object_=json.loads(sse.data))) # type: ignore
except Exception:
async for _text in _response.aiter_lines():
if len(_text) == 0:
continue
yield typing.cast(GenerateStreamedResponse, construct_type(type_=GenerateStreamedResponse, object_=json.loads(_text))) # type: ignore
return
await _response.aread()
if _response.status_code == 400:
Expand Down

0 comments on commit 6954548

Please sign in to comment.