Skip to content

Commit

Permalink
Feature: get_program_price (#143)
Browse files Browse the repository at this point in the history
* Feature: get_program_price functions

* Fix: add also to abstract AlehpClient

* Fix: black issue

* Fix: add superfuild to pyproject.toml

* Revert "Fix: add superfuild to pyproject.toml"

This reverts commit d206c0c.

* Fix: isort issue

* Fix: type

* Fix: unit test

* Fix: style issue

* Update src/aleph/sdk/client/http.py

Co-authored-by: Hugo Herter <[email protected]>

---------

Co-authored-by: Hugo Herter <[email protected]>
  • Loading branch information
1yam and hoh authored Aug 16, 2024
1 parent 0de453d commit 2173274
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 3 deletions.
15 changes: 14 additions & 1 deletion src/aleph/sdk/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import (
Any,
AsyncIterable,
Coroutine,
Dict,
Iterable,
List,
Expand Down Expand Up @@ -40,7 +41,7 @@
from aleph.sdk.utils import extended_json_encoder

from ..query.filters import MessageFilter, PostFilter
from ..query.responses import PostsResponse
from ..query.responses import PostsResponse, PriceResponse
from ..types import GenericMessage, StorageEnum
from ..utils import Writable, compute_sha256

Expand Down Expand Up @@ -241,6 +242,18 @@ def watch_messages(
"""
raise NotImplementedError("Did you mean to import `AlephHttpClient`?")

@abstractmethod
def get_program_price(
self,
item_hash: str,
) -> Coroutine[Any, Any, PriceResponse]:
"""
Get Program message Price
:param item_hash: item_hash of executable message
"""
raise NotImplementedError("Did you mean to import `AlephHttpClient`?")


class AuthenticatedAlephClient(AlephClient):
account: Account
Expand Down
23 changes: 21 additions & 2 deletions src/aleph/sdk/client/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,14 @@
from pydantic import ValidationError

from ..conf import settings
from ..exceptions import FileTooLarge, ForgottenMessageError, MessageNotFoundError
from ..exceptions import (
FileTooLarge,
ForgottenMessageError,
InvalidHashError,
MessageNotFoundError,
)
from ..query.filters import MessageFilter, PostFilter
from ..query.responses import MessagesResponse, Post, PostsResponse
from ..query.responses import MessagesResponse, Post, PostsResponse, PriceResponse
from ..types import GenericMessage
from ..utils import (
Writable,
Expand Down Expand Up @@ -409,3 +414,17 @@ async def watch_messages(
yield parse_message(data)
elif msg.type == aiohttp.WSMsgType.ERROR:
break

async def get_program_price(self, item_hash: str) -> PriceResponse:
async with self.http_session.get(f"/api/v0/price/{item_hash}") as resp:
try:
resp.raise_for_status()
response_json = await resp.json()
return PriceResponse(
required_tokens=response_json["required_tokens"],
payment_type=response_json["payment_type"],
)
except aiohttp.ClientResponseError as e:
if e.status == 400:
raise InvalidHashError(f"Bad request or no such hash {item_hash}")
raise e
6 changes: 6 additions & 0 deletions src/aleph/sdk/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,9 @@ def __init__(self, required_funds: float, available_funds: float):
super().__init__(
f"Insufficient funds: required {required_funds}, available {available_funds}"
)


class InvalidHashError(QueryError):
"""The Hash is not valid"""

pass
7 changes: 7 additions & 0 deletions src/aleph/sdk/query/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,10 @@ class MessagesResponse(PaginationResponse):

messages: List[AlephMessage]
pagination_item = "messages"


class PriceResponse(BaseModel):
"""Response from an aleph.im node API on the path /api/v0/price/{item_hash}"""

required_tokens: float
payment_type: str
20 changes: 20 additions & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from unittest.mock import AsyncMock, MagicMock

import pytest as pytest
from aiohttp import ClientResponseError
from aleph_message.models import AggregateMessage, AlephMessage, PostMessage

import aleph.sdk.chains.ethereum as ethereum
Expand Down Expand Up @@ -230,6 +231,10 @@ class CustomMockResponse(MockResponse):
async def json(self):
return resp

def raise_for_status(self):
if status >= 400:
raise ClientResponseError(None, None, status=status)

@property
def status(self):
return status
Expand Down Expand Up @@ -259,6 +264,21 @@ def get(self, *_args, **_kwargs):
return client


def make_mock_get_session_400(
get_return_value: Union[Dict[str, Any], bytes]
) -> AlephHttpClient:
class MockHttpSession(AsyncMock):
def get(self, *_args, **_kwargs):
return make_custom_mock_response(get_return_value, 400)

http_session = MockHttpSession()

client = AlephHttpClient(api_server="http://localhost")
client.http_session = http_session

return client


@pytest.fixture
def mock_session_with_rejected_message(
ethereum_account, rejected_message
Expand Down
33 changes: 33 additions & 0 deletions tests/unit/test_price.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pytest

from aleph.sdk.exceptions import InvalidHashError
from aleph.sdk.query.responses import PriceResponse
from tests.unit.conftest import make_mock_get_session, make_mock_get_session_400


@pytest.mark.asyncio
async def test_get_program_price_valid():
"""
Test that the get_program_price method returns the correct PriceResponse
when given a valid item hash.
"""
expected_response = {
"required_tokens": 3.0555555555555556e-06,
"payment_type": "superfluid",
}
mock_session = make_mock_get_session(expected_response)
async with mock_session:
response = await mock_session.get_program_price("cacacacacacaca")
assert response == PriceResponse(**expected_response)


@pytest.mark.asyncio
async def test_get_program_price_invalid():
"""
Test that the get_program_price method raises an InvalidHashError
when given an invalid item hash.
"""
mock_session = make_mock_get_session_400({"error": "Invalid hash"})
async with mock_session:
with pytest.raises(InvalidHashError):
await mock_session.get_program_price("invalid_item_hash")

0 comments on commit 2173274

Please sign in to comment.