Skip to content

Commit

Permalink
fix(core): consistently generated session ids (#531)
Browse files Browse the repository at this point in the history
Co-authored-by: Archento <[email protected]>
  • Loading branch information
ejfitzgerald and Archento authored Oct 8, 2024
1 parent f8117b4 commit 82e6e5b
Show file tree
Hide file tree
Showing 9 changed files with 139 additions and 87 deletions.
11 changes: 11 additions & 0 deletions python/docs/api/uagents/agent.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ An agent that interacts within a communication environment.
corresponding protocols.
- `_ctx` _Context_ - The context for agent interactions.
- `_test` _bool_ - True if the agent will register and transact on the testnet.
- `_enable_agent_inspector` _bool_ - Enable the agent inspector REST endpoints.
Properties:
- `name` _str_ - The name of the agent.
Expand Down Expand Up @@ -732,6 +733,16 @@ def start_message_receivers()
Start message receiving tasks for the agent.
<a id="src.uagents.agent.Agent.start_server"></a>
#### start`_`server
```python
async def start_server()
```
Start the agent's server.
<a id="src.uagents.agent.Agent.run_async"></a>
#### run`_`async
Expand Down
8 changes: 3 additions & 5 deletions python/docs/api/uagents/context.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ agent (AgentRepresentation): The agent representation associated with the contex
storage (KeyValueStore): The key-value store for storage operations.
ledger (LedgerClient): The client for interacting with the blockchain ledger.
logger (logging.Logger): The logger instance.
session (uuid.UUID): The session UUID associated with the context.

**Methods**:

Expand Down Expand Up @@ -101,7 +102,7 @@ Get the logger instance associated with the context.
```python
@property
@abstractmethod
def session() -> Union[uuid.UUID, None]
def session() -> uuid.UUID
```

Get the session UUID associated with the context.
Expand Down Expand Up @@ -267,7 +268,7 @@ Represents the agent internal context for proactive behaviour.

```python
@property
def session() -> Union[uuid.UUID, None]
def session() -> uuid.UUID
```

Get the session UUID associated with the context.
Expand Down Expand Up @@ -339,7 +340,6 @@ Represents the reactive context in which messages are handled and processed.

- `_queries` _Dict[str, asyncio.Future]_ - Dictionary mapping query senders to their
response Futures.
- `_session` _Optional[uuid.UUID]_ - The session UUID.
- `_replies` _Optional[Dict[str, Dict[str, Type[Model]]]]_ - Dictionary of allowed reply digests
for each type of incoming message.
- `_message_received` _Optional[MsgDigest]_ - The message digest received.
Expand All @@ -352,7 +352,6 @@ Represents the reactive context in which messages are handled and processed.

```python
def __init__(message_received: MsgDigest,
session: Optional[uuid.UUID] = None,
queries: Optional[Dict[str, asyncio.Future]] = None,
replies: Optional[Dict[str, Dict[str, Type[Model]]]] = None,
protocol: Optional[Tuple[str, Protocol]] = None,
Expand All @@ -366,7 +365,6 @@ Initialize the ExternalContext instance and attributes needed from the InternalC
- `message_received` _MsgDigest_ - The optional message digest received.
- `queries` _Dict[str, asyncio.Future]_ - Dictionary mapping query senders to their
response Futures.
- `session` _Optional[uuid.UUID]_ - The optional session UUID.
- `replies` _Optional[Dict[str, Dict[str, Type[Model]]]]_ - Dictionary of allowed replies
for each type of incoming message.
- `protocol` _Optional[Tuple[str, Protocol]]_ - The optional Tuple of protocols.
Expand Down
86 changes: 59 additions & 27 deletions python/src/uagents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
parse_agentverse_config,
parse_endpoint_config,
)
from uagents.context import Context, ExternalContext, InternalContext
from uagents.context import Context, ContextFactory, ExternalContext, InternalContext
from uagents.crypto import Identity, derive_key_from_seed, is_user_address
from uagents.dispatch import Sink, dispatcher
from uagents.envelope import EnvelopeHistory, EnvelopeHistoryEntry
Expand Down Expand Up @@ -71,24 +71,31 @@
from uagents.utils import get_logger


async def _run_interval(func: IntervalCallback, ctx: Context, period: float):
async def _run_interval(
func: IntervalCallback,
logger: logging.Logger,
context_factory: ContextFactory,
period: float,
):
"""
Run the provided interval callback function at a specified period.
Args:
func (IntervalCallback): The interval callback function to run.
ctx (Context): The context for the agent.
logger (logging.Logger): The logger instance for logging interval handler activities.
context_factory (ContextFactory): The factory function for creating the context.
period (float): The time period at which to run the callback function.
"""
while True:
try:
ctx = context_factory()
await func(ctx)
except OSError as ex:
ctx.logger.exception(f"OS Error in interval handler: {ex}")
logger.exception(f"OS Error in interval handler: {ex}")
except RuntimeError as ex:
ctx.logger.exception(f"Runtime Error in interval handler: {ex}")
logger.exception(f"Runtime Error in interval handler: {ex}")
except Exception as ex:
ctx.logger.exception(f"Exception in interval handler: {ex}")
logger.exception(f"Exception in interval handler: {ex}")

await asyncio.sleep(period)

Expand Down Expand Up @@ -377,21 +384,6 @@ def __init__(
# keep track of supported protocols
self.protocols: Dict[str, Protocol] = {}

self._ctx = InternalContext(
agent=AgentRepresentation(
address=self.address,
name=self._name,
signing_callback=self._identity.sign_digest,
),
storage=self._storage,
ledger=self._ledger,
resolver=self._resolver,
dispenser=self._dispenser,
interval_messages=self._interval_messages,
wallet_messaging_client=self._wallet_messaging_client,
logger=self._logger,
)

# register with the dispatcher
self._dispatcher.register(self.address, self)

Expand Down Expand Up @@ -426,6 +418,28 @@ async def _handle_get_messages(_ctx: Context):

self._init_done = True

def _build_context(self) -> InternalContext:
"""
An internal method to build the context for the agent.
Returns:
InternalContext: The internal context for the agent.
"""
return InternalContext(
agent=AgentRepresentation(
address=self.address,
name=self._name,
signing_callback=self._identity.sign_digest,
),
storage=self._storage,
ledger=self._ledger,
resolver=self._resolver,
dispenser=self._dispenser,
interval_messages=self._interval_messages,
wallet_messaging_client=self._wallet_messaging_client,
logger=self._logger,
)

def _initialize_wallet_and_identity(self, seed, name, wallet_key_derivation_index):
"""
Initialize the wallet and identity for the agent.
Expand Down Expand Up @@ -997,7 +1011,10 @@ async def handle_rest(
if not handler:
return None

args = (self._ctx, message) if message else (self._ctx,)
args = []
args.append(self._build_context())
if message:
args.append(message)

return await handler(*args) # type: ignore

Expand All @@ -1015,7 +1032,8 @@ async def _startup(self):
)
for handler in self._on_startup:
try:
await handler(self._ctx)
ctx = self._build_context()
await handler(ctx)
except OSError as ex:
self._logger.exception(f"OS Error in startup handler: {ex}")
except RuntimeError as ex:
Expand All @@ -1030,7 +1048,8 @@ async def _shutdown(self):
"""
for handler in self._on_shutdown:
try:
await handler(self._ctx)
ctx = self._build_context()
await handler(ctx)
except OSError as ex:
self._logger.exception(f"OS Error in shutdown handler: {ex}")
except RuntimeError as ex:
Expand Down Expand Up @@ -1061,7 +1080,9 @@ def start_interval_tasks(self):
"""
for func, period in self._interval_handlers:
self._loop.create_task(_run_interval(func, self._ctx, period))
self._loop.create_task(
_run_interval(func, self._logger, self._build_context, period)
)

def start_message_receivers(self):
"""
Expand All @@ -1075,7 +1096,9 @@ def start_message_receivers(self):
if self._wallet_messaging_client is not None:
for task in [
self._wallet_messaging_client.poll_server(),
self._wallet_messaging_client.process_message_queue(self._ctx),
self._wallet_messaging_client.process_message_queue(
self._build_context
),
]:
self._loop.create_task(task)

Expand Down Expand Up @@ -1163,7 +1186,11 @@ async def _process_message_queue(self):
)

context = ExternalContext(
agent=self._ctx.agent,
agent=AgentRepresentation(
address=self.address,
name=self._name,
signing_callback=self._identity.sign_digest,
),
storage=self._storage,
ledger=self._ledger,
resolver=self._resolver,
Expand All @@ -1179,6 +1206,11 @@ async def _process_message_queue(self):
protocol=protocol_info,
)

# sanity check
assert (
context.session == session
), "Context object should always have message session"

# parse the received message
try:
recovered = model_class.parse_raw(message)
Expand Down
19 changes: 9 additions & 10 deletions python/src/uagents/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Expand Down Expand Up @@ -59,6 +60,7 @@ class Context(ABC):
storage (KeyValueStore): The key-value store for storage operations.
ledger (LedgerClient): The client for interacting with the blockchain ledger.
logger (logging.Logger): The logger instance.
session (uuid.UUID): The session UUID associated with the context.
Methods:
get_agents_by_protocol(protocol_digest, limit, logger): Retrieve a list of agent addresses
Expand Down Expand Up @@ -116,7 +118,7 @@ def logger(self) -> logging.Logger:

@property
@abstractmethod
def session(self) -> Union[uuid.UUID, None]:
def session(self) -> uuid.UUID:
"""
Get the session UUID associated with the context.
Expand Down Expand Up @@ -256,6 +258,7 @@ def __init__(
ledger: LedgerClient,
resolver: Resolver,
dispenser: Dispenser,
session: Optional[uuid.UUID] = None,
interval_messages: Optional[Set[str]] = None,
wallet_messaging_client: Optional[Any] = None,
logger: Optional[logging.Logger] = None,
Expand All @@ -266,7 +269,7 @@ def __init__(
self._resolver = resolver
self._dispenser = dispenser
self._logger = logger
self._session: Optional[uuid.UUID] = None
self._session = session or uuid.uuid4()
self._interval_messages = interval_messages
self._wallet_messaging_client = wallet_messaging_client
self._outbound_messages: Dict[str, Tuple[JsonStr, str]] = {}
Expand All @@ -288,7 +291,7 @@ def logger(self) -> Union[logging.Logger, None]:
return self._logger

@property
def session(self) -> Union[uuid.UUID, None]:
def session(self) -> uuid.UUID:
"""
Get the session UUID associated with the context.
Expand Down Expand Up @@ -408,7 +411,6 @@ async def send(
we don't have access properties that are only necessary in re-active
contexts, like 'replies', 'message_received', or 'protocol'.
"""
self._session = None
schema_digest = Model.build_schema_digest(message)
message_body = message.model_dump_json()

Expand Down Expand Up @@ -440,8 +442,6 @@ async def send_raw(
protocol_digest: Optional[str] = None,
queries: Optional[Dict[str, asyncio.Future]] = None,
) -> MsgStatus:
self._session = self._session or uuid.uuid4()

# Extract address from destination agent identifier if present
_, parsed_name, parsed_address = parse_identifier(destination)

Expand Down Expand Up @@ -564,7 +564,6 @@ class ExternalContext(InternalContext):
Attributes:
_queries (Dict[str, asyncio.Future]): Dictionary mapping query senders to their
response Futures.
_session (Optional[uuid.UUID]): The session UUID.
_replies (Optional[Dict[str, Dict[str, Type[Model]]]]): Dictionary of allowed reply digests
for each type of incoming message.
_message_received (Optional[MsgDigest]): The message digest received.
Expand All @@ -575,7 +574,6 @@ class ExternalContext(InternalContext):
def __init__(
self,
message_received: MsgDigest,
session: Optional[uuid.UUID] = None,
queries: Optional[Dict[str, asyncio.Future]] = None,
replies: Optional[Dict[str, Dict[str, Type[Model]]]] = None,
protocol: Optional[Tuple[str, Protocol]] = None,
Expand All @@ -588,13 +586,11 @@ def __init__(
message_received (MsgDigest): The optional message digest received.
queries (Dict[str, asyncio.Future]): Dictionary mapping query senders to their
response Futures.
session (Optional[uuid.UUID]): The optional session UUID.
replies (Optional[Dict[str, Dict[str, Type[Model]]]]): Dictionary of allowed replies
for each type of incoming message.
protocol (Optional[Tuple[str, Protocol]]): The optional Tuple of protocols.
"""
super().__init__(**kwargs)
self._session = session or None
self._queries = queries or {}
self._replies = replies
self._message_received = message_received
Expand Down Expand Up @@ -674,3 +670,6 @@ async def send(
protocol_digest=self._protocol[0],
queries=self._queries,
)


ContextFactory = Callable[[], Context]
11 changes: 1 addition & 10 deletions python/src/uagents/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from uagents.crypto import Identity
from uagents.network import AlmanacContract, InsufficientFundsError, add_testnet_funds
from uagents.types import AgentEndpoint
from uagents.types import AgentEndpoint, AgentGeoLocation


class AgentRegistrationPolicy(ABC):
Expand All @@ -32,15 +32,6 @@ async def register(
pass


class AgentGeoLocation(BaseModel):
# Latitude and longitude of the agent
latitude: float
longitude: float

# Radius around the agent location, expressed in meters
radius: float


class AgentRegistrationAttestation(BaseModel):
agent_address: str
protocols: List[str]
Expand Down
9 changes: 9 additions & 0 deletions python/src/uagents/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ class AgentEndpoint(BaseModel):
weight: int


class AgentGeoLocation(BaseModel):
# Latitude and longitude of the agent
latitude: float
longitude: float

# Radius around the agent location, expressed in meters
radius: float


class AgentInfo(BaseModel):
agent_address: str
endpoints: List[AgentEndpoint]
Expand Down
Loading

0 comments on commit 82e6e5b

Please sign in to comment.