Commit

fix: actually stream events from the async client (#319)
* remove duplicated code

* fix streaming client

* use async client

* refresh the session object at each iteration
masci authored Oct 15, 2024
1 parent 36adfa3 commit d53a748
Showing 2 changed files with 38 additions and 57 deletions.
46 changes: 24 additions & 22 deletions llama_deploy/client/async_client.py
@@ -1,15 +1,16 @@
import asyncio
-import httpx
import json
import time
from typing import Any, AsyncGenerator, List, Optional

+import httpx
+
from llama_deploy.control_plane.server import ControlPlaneConfig
from llama_deploy.types import (
-    TaskDefinition,
    ServiceDefinition,
-    TaskResult,
    SessionDefinition,
+    TaskDefinition,
+    TaskResult,
)

DEFAULT_TIMEOUT = 120.0
@@ -126,26 +127,27 @@ async def get_task_result_stream(
            AsyncGenerator[str, None, None]: A generator that yields the result of the task.
        """
        start_time = time.time()
-        async with httpx.AsyncClient() as client:
-            while True:
-                try:
-                    response = await client.get(
-                        f"{self.control_plane_url}/sessions/{self.session_id}/tasks/{task_id}/result_stream"
-                    )
-                    response.raise_for_status()
-                    async for line in response.aiter_lines():
-                        json_line = json.loads(line)
-                        yield json_line
-                    break  # Exit the function if successful
-                except httpx.HTTPStatusError as e:
-                    if e.response.status_code != 404:
-                        raise  # Re-raise if it's not a 404 error
-                    if time.time() - start_time < self.timeout:
-                        await asyncio.sleep(self.poll_interval)
-                    else:
-                        raise TimeoutError(
-                            f"Task result not available after waiting for {self.timeout} seconds"
-                        )
+        while True:
+            try:
+                async with httpx.AsyncClient() as client:
+                    async with client.stream(
+                        "GET",
+                        f"{self.control_plane_url}/sessions/{self.session_id}/tasks/{task_id}/result_stream",
+                    ) as response:
+                        response.raise_for_status()
+                        async for line in response.aiter_lines():
+                            json_line = json.loads(line)
+                            yield json_line
+                break  # Exit the function if successful
+            except httpx.HTTPStatusError as e:
+                if e.response.status_code != 404:
+                    raise  # Re-raise if it's not a 404 error
+                if time.time() - start_time < self.timeout:
+                    await asyncio.sleep(self.poll_interval)
+                else:
+                    raise TimeoutError(
+                        f"Task result not available after waiting for {self.timeout} seconds"
+                    )


class AsyncLlamaDeployClient:
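For context on the client-side change: httpx.AsyncClient.get() reads the entire response body before returning, so the old code only yielded events once the task had already finished, while client.stream("GET", ...) keeps the connection open and aiter_lines() yields each newline-delimited JSON event as the control plane flushes it. A minimal standalone sketch of that pattern (the URL, timeout values, and helper name are illustrative placeholders, not llama_deploy API):

import asyncio
import json
import time

import httpx


async def stream_ndjson(url: str, timeout: float = 120.0, poll_interval: float = 0.5):
    # Yield newline-delimited JSON events as the server sends them, retrying
    # while the endpoint still returns 404 (the task has produced no output yet).
    start = time.monotonic()
    while True:
        try:
            async with httpx.AsyncClient() as client:
                # Unlike client.get(), client.stream() does not buffer the body,
                # so lines arrive as soon as the server writes them.
                async with client.stream("GET", url) as response:
                    response.raise_for_status()
                    async for line in response.aiter_lines():
                        if line:
                            yield json.loads(line)
            return  # stream finished cleanly
        except httpx.HTTPStatusError as e:
            if e.response.status_code != 404:
                raise
            if time.monotonic() - start > timeout:
                raise TimeoutError("task result stream did not appear in time")
            await asyncio.sleep(poll_interval)

A caller would consume it with "async for event in stream_ndjson(...)", which is essentially what the patched get_task_result_stream above now makes possible.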
49 changes: 14 additions & 35 deletions llama_deploy/control_plane/server.py
@@ -1,16 +1,16 @@
import asyncio
import json
import uuid
+from logging import getLogger
+from typing import Any, AsyncGenerator, Dict, List, Optional
+
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
-from logging import getLogger
-from pydantic_settings import BaseSettings, SettingsConfigDict
-from typing import AsyncGenerator, Any, Dict, List, Optional
-
-from llama_index.core.storage.kvstore.types import BaseKVStore
from llama_index.core.storage.kvstore import SimpleKVStore
+from llama_index.core.storage.kvstore.types import BaseKVStore
+from pydantic_settings import BaseSettings, SettingsConfigDict

from llama_deploy.control_plane.base import BaseControlPlane
from llama_deploy.message_consumers.base import (
@@ -25,8 +25,8 @@
from llama_deploy.orchestrators.utils import get_result_key, get_stream_key
from llama_deploy.types import (
    ActionTypes,
-    SessionDefinition,
    ServiceDefinition,
+    SessionDefinition,
    TaskDefinition,
    TaskResult,
    TaskStream,
@@ -549,41 +549,19 @@ async def event_generator(
            session: SessionDefinition, stream_key: str
        ) -> AsyncGenerator[str, None]:
            try:
-                stream_results = session.state[stream_key]
-                stream_results = sorted(stream_results, key=lambda x: x["index"])
-                for result in stream_results:
-                    if not isinstance(result, TaskStream):
-                        if isinstance(result, dict):
-                            result = TaskStream(**result)
-                        elif isinstance(result, str):
-                            result = TaskStream(**json.loads(result))
-                        else:
-                            raise ValueError("Unexpected result type in stream")
-
-                    yield json.dumps(result.data) + "\n"
-
-                # check if there is a final result
-                final_result = await self.get_task_result(task_id, session_id)
-                if final_result is not None:
-                    return
-
                # Continue to check for new results
+                last_index = 0
                while True:
-                    await asyncio.sleep(
-                        self.step_interval
-                    )  # Small delay to prevent tight loop
                    session = await self.get_session(session_id)
-                    new_results = session.state[stream_key][len(stream_results) :]
-                    new_results = sorted(new_results, key=lambda x: x["index"])
-
-                    for result in new_results:
+                    stream_results = session.state[stream_key][last_index:]
+                    stream_results = sorted(stream_results, key=lambda x: x["index"])
+                    for result in stream_results:
                        if not isinstance(result, TaskStream):
                            if isinstance(result, dict):
                                result = TaskStream(**result)
                            elif isinstance(result, str):
                                result = TaskStream(**json.loads(result))
                            else:
-                                raise ValueError("Unexpected result type")
+                                raise ValueError("Unexpected result type in stream")

                        yield json.dumps(result.data) + "\n"

@@ -592,8 +570,9 @@ async def event_generator(
                    if final_result is not None:
                        return

-                    # update results list used for indexing
-                    stream_results.extend(new_results)
+                    last_index += len(stream_results)
+                    # Small delay to prevent tight loop
+                    await asyncio.sleep(self.step_interval)
            except Exception as e:
                logger.error(
                    f"Error in event stream for session {session_id}, task {task_id}: {str(e)}"
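The server-side rewrite is what the "refresh the session object at each iteration" and "remove duplicated code" bullets refer to: instead of emitting whatever was in session.state once and then running a second, duplicated polling loop, event_generator now keeps a last_index cursor and re-reads the session on every pass, so each event is parsed in one place and yielded exactly once. A minimal sketch of that cursor pattern, with fetch_events and task_done as hypothetical stand-ins for the control plane's session-state and task-result lookups:

import asyncio
import json
from typing import Any, AsyncGenerator, Awaitable, Callable


async def event_generator(
    fetch_events: Callable[[], Awaitable[list[dict[str, Any]]]],
    task_done: Callable[[], Awaitable[bool]],
    step_interval: float = 0.1,
) -> AsyncGenerator[str, None]:
    # Poll a shared event list and yield each new event exactly once as an
    # NDJSON line, refreshing the backing store on every iteration.
    last_index = 0
    while True:
        events = await fetch_events()  # re-read state on every pass
        new_events = sorted(events[last_index:], key=lambda e: e["index"])
        for event in new_events:
            yield json.dumps(event["data"]) + "\n"
        if await task_done():  # a final result ends the stream
            return
        last_index += len(new_events)  # advance the cursor past what was sent
        await asyncio.sleep(step_interval)  # small delay to avoid a tight loop

A generator shaped like this is what the control plane hands to FastAPI's StreamingResponse, which is also why the client-side fix above matters: the events are only useful if the client actually reads them incrementally.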
