
Await socket operations + some other minor cleanup #391

Merged
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/cli_args.py

@@ -138,7 +138,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         "--disable-frontend-multiprocessing",
         action="store_true",
         help="If specified, will run the OpenAI frontend server in the same "
-        "proecss as the model servinge engine.")
+        "process as the model serving engine.")

     parser = AsyncEngineArgs.add_cli_args(parser)
8 changes: 4 additions & 4 deletions vllm/entrypoints/openai/rpc/client.py

@@ -40,12 +40,12 @@ async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
         socket.connect(self.path)

         # Ping RPC Server with request.
-        socket.send(pickle.dumps(request, pickle.HIGHEST_PROTOCOL))
+        await socket.send(pickle.dumps(request, pickle.HIGHEST_PROTOCOL))

         # Await acknowledgement from RPCServer.
         response = pickle.loads(await socket.recv())

-        if (not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR):
+        if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
             socket.close()
             raise ValueError(error_message)

@@ -80,7 +80,7 @@ async def get_model_config(self) -> ModelConfig:
         socket.connect(self.path)

         # Ping RPCServer with GET_MODEL_CONFIG request.
-        socket.send(pickle.dumps(RPCUtilityRequest.GET_MODEL_CONFIG))
+        await socket.send(pickle.dumps(RPCUtilityRequest.GET_MODEL_CONFIG))

         # Await the MODEL_CONFIG from the Server.
         model_config = pickle.loads(await socket.recv())

@@ -126,7 +126,7 @@ async def generate(
         socket.connect(self.path)

         # Send RPCGenerateRequest to the RPCServer.
-        socket.send_multipart([
+        await socket.send_multipart([
             pickle.dumps(
                 RPCGenerateRequest(
                     inputs=inputs,
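For context on why these one-line changes matter: with `zmq.asyncio`, `Socket.send` and `Socket.send_multipart` return Futures rather than sending synchronously, so a bare call leaves an unawaited Future behind, gives the coroutine no backpressure, and silently drops any send error. A minimal standalone sketch of the corrected client-side pattern (the endpoint path and request value are placeholders, not the PR's actual `RPCClient` code):

```python
import asyncio
import pickle

import zmq
import zmq.asyncio


async def send_one_way_request(path: str, request: object) -> object:
    """Send a pickled request and wait for the pickled reply."""
    ctx = zmq.asyncio.Context()
    socket = ctx.socket(zmq.DEALER)
    try:
        socket.connect(path)
        # Without `await`, send() merely returns a Future: the payload may
        # never be handed to zmq, and exceptions would go unobserved.
        await socket.send(pickle.dumps(request, pickle.HIGHEST_PROTOCOL))
        return pickle.loads(await socket.recv())
    finally:
        socket.close()
        ctx.destroy()


# Example (assumes a server bound to this hypothetical address):
# asyncio.run(send_one_way_request("ipc:///tmp/demo.sock", "PING"))
```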
17 changes: 7 additions & 10 deletions vllm/entrypoints/openai/rpc/server.py

@@ -18,9 +18,6 @@

 class RPCServer:

-    # TODO: check if opening all these sockets is an antipattern.
-    # Alternative, use a smaller number of sockets with conditioning on the
-    # data that is passed through the socket.
     def __init__(self, async_engine_args: AsyncEngineArgs,
                  usage_context: UsageContext, port: int):
         # Initialize engine first.

@@ -41,7 +38,7 @@ def cleanup(self):

     async def _send_success_message(self, identity):
         """Send message to client indicating an action was successful."""
-        self.socket.send_multipart([
+        await self.socket.send_multipart([
             identity,
             pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL),
         ])

@@ -50,20 +47,20 @@ async def get_model_config(self, identity):
         """Send the ModelConfig """
         model_config = await self.engine.get_model_config()

-        self.socket.send_multipart(
+        await self.socket.send_multipart(
             [identity,
              pickle.dumps(model_config, pickle.HIGHEST_PROTOCOL)])

     async def do_log_stats(self, identity):
         await self.engine.do_log_stats()

-        self.socket.send_multipart([
+        await self.socket.send_multipart([
             identity,
             pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL),
         ])

     async def is_server_ready(self, identity):
-        self.socket.send_multipart([
+        await self.socket.send_multipart([
             identity,
             pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL),
         ])

@@ -73,7 +70,7 @@ async def abort(self, identity, request: RPCAbortRequest):
         await self.engine.abort(request.request_id)

         # Send confirmation to the client.
-        self.socket.send_multipart([
+        await self.socket.send_multipart([
             identity,
             pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL),
         ])

@@ -86,14 +83,14 @@ async def generate(self, identity, generate_request: RPCGenerateRequest):
                 request_id=generate_request.request_id)

             async for request_output in results_generator:
-                self.socket.send_multipart([
+                await self.socket.send_multipart([
                     identity,
                     pickle.dumps(request_output, pickle.HIGHEST_PROTOCOL)
                 ])

         except Exception as e:
             ### Notify client of all failures
-            self.socket.send_multipart(
+            await self.socket.send_multipart(
                 [identity, pickle.dumps(e, pickle.HIGHEST_PROTOCOL)])

     def _make_handler_coro(self, identity,
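The same reasoning applies on the server side: `send_multipart` on a `zmq.asyncio` ROUTER socket is also a Future-returning call. A minimal sketch of the reply pattern used above, assuming a ROUTER socket and an illustrative success constant (the names are placeholders, not vLLM's actual ones):

```python
import pickle

import zmq
import zmq.asyncio

SUCCESS_STR = "SUCCESS"  # placeholder; stands in for VLLM_RPC_SUCCESS_STR


async def serve(path: str = "ipc:///tmp/demo.sock") -> None:
    ctx = zmq.asyncio.Context()
    socket = ctx.socket(zmq.ROUTER)
    socket.bind(path)
    try:
        while True:
            # ROUTER sockets prepend the sender's identity frame; echoing
            # it back routes the reply to the right DEALER client.
            identity, payload = await socket.recv_multipart()
            _request = pickle.loads(payload)
            # Awaiting the send surfaces errors and applies backpressure
            # rather than leaving an unobserved Future behind.
            await socket.send_multipart([
                identity,
                pickle.dumps(SUCCESS_STR, pickle.HIGHEST_PROTOCOL),
            ])
    finally:
        socket.close()
        ctx.destroy()
```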
7 changes: 4 additions & 3 deletions vllm/utils.py

@@ -302,15 +302,14 @@ def merge_async_iterators(
     queue: asyncio.Queue[Union[Tuple[int, T], ProducerFinished,
                                Exception]] = asyncio.Queue()

-    finished = [False] * len(iterators)
+    producers = len(iterators)

     async def producer(i: int, iterator: AsyncIterator[T]):
         try:
             async for item in iterator:
                 await queue.put((i, item))
         except Exception as e:
             await queue.put(e)
-        finished[i] = True
         # Signal to the consumer that we've finished
         await queue.put(ProducerFinished())

@@ -320,13 +319,15 @@ async def producer(i: int, iterator: AsyncIterator[T]):
     ]

     async def consumer():
+        remaining = producers
         try:
-            while not all(finished) or not queue.empty():
+            while remaining or not queue.empty():
                 # we think there is a race condition here
                 item = await queue.get()

                 if isinstance(item, ProducerFinished):
                     # Signal that a producer finished- not a real item
+                    remaining -= 1
                     continue

                 if isinstance(item, Exception):
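The `vllm/utils.py` change replaces a shared `finished` flag list, which the consumer had to re-scan on every loop iteration, with a `remaining` counter decremented once per `ProducerFinished` sentinel: the loop now terminates exactly when every producer has signalled completion and the queue has drained. A self-contained sketch of that counting pattern (standalone signature, not vLLM's exact helper, which also cancels outstanding producer tasks on early exit):

```python
import asyncio
from typing import AsyncIterator, Tuple, TypeVar, Union

T = TypeVar("T")


class ProducerFinished:
    """Sentinel a producer enqueues once its iterator is exhausted."""


async def merge_async_iterators(
        *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
    """Merge async iterators into one stream of (index, item) pairs."""
    queue: asyncio.Queue[Union[Tuple[int, T], ProducerFinished,
                               Exception]] = asyncio.Queue()

    async def producer(i: int, iterator: AsyncIterator[T]):
        try:
            async for item in iterator:
                await queue.put((i, item))
        except Exception as e:
            await queue.put(e)  # forward the failure to the consumer
        # Always signal completion, even on error or an empty iterator.
        await queue.put(ProducerFinished())

    tasks = [
        asyncio.create_task(producer(i, it))
        for i, it in enumerate(iterators)
    ]

    # Count down one sentinel per producer instead of polling a shared
    # `finished` list: the loop exits exactly when all producers are done
    # and the queue has been drained.
    remaining = len(tasks)
    while remaining or not queue.empty():
        item = await queue.get()
        if isinstance(item, ProducerFinished):
            remaining -= 1
            continue
        if isinstance(item, Exception):
            raise item
        yield item
    await asyncio.gather(*tasks)


# Example usage:
#   async def gen(n):
#       for x in range(n):
#           yield x
#
#   async for i, x in merge_async_iterators(gen(3), gen(2)):
#       print(i, x)
```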