From 0268d7b8e422d4ff976a74aedbbf3d9f655c7966 Mon Sep 17 00:00:00 2001 From: 1yam Date: Tue, 15 Oct 2024 14:32:18 +0200 Subject: [PATCH] Feature: balance endpoints filter on chain --- src/aleph/schemas/api/accounts.py | 9 +++- src/aleph/web/controllers/accounts.py | 27 +++++++++--- tests/api/test_balance.py | 62 +++++++++++++++++++++++++++ tests/conftest.py | 24 +++++++++++ 4 files changed, 114 insertions(+), 8 deletions(-) diff --git a/src/aleph/schemas/api/accounts.py b/src/aleph/schemas/api/accounts.py index ca904d3af..01d5f53b0 100644 --- a/src/aleph/schemas/api/accounts.py +++ b/src/aleph/schemas/api/accounts.py @@ -1,7 +1,8 @@ import datetime as dt from decimal import Decimal -from typing import List +from typing import List, Optional +from aleph_message.models import Chain from pydantic import BaseModel, Field from aleph.types.files import FileType @@ -9,6 +10,12 @@ from aleph.web.controllers.utils import DEFAULT_PAGE +class GetAccountQueryParams(BaseModel): + chain: Optional[Chain] = Field( + default=None, description="Get Balance on a specific EVM Chain" + ) + + class GetAccountBalanceResponse(BaseModel): address: str balance: Decimal diff --git a/src/aleph/web/controllers/accounts.py b/src/aleph/web/controllers/accounts.py index 141b08285..5bdd76792 100644 --- a/src/aleph/web/controllers/accounts.py +++ b/src/aleph/web/controllers/accounts.py @@ -6,7 +6,7 @@ from aleph_message.models import MessageType from pydantic import ValidationError, parse_obj_as -from aleph.db.accessors.balances import get_total_balance +from aleph.db.accessors.balances import get_balance_by_chain, get_total_balance from aleph.db.accessors.cost import get_total_cost_for_address from aleph.db.accessors.files import get_address_files_for_api, get_address_files_stats from aleph.db.accessors.messages import get_message_stats_by_address @@ -15,6 +15,7 @@ GetAccountFilesQueryParams, GetAccountFilesResponse, GetAccountFilesResponseItem, + GetAccountQueryParams, ) from aleph.types.db_session import DbSessionFactory from aleph.web.controllers.app_state_getters import get_session_factory_from_request @@ -62,15 +63,27 @@ def _get_address_from_request(request: web.Request) -> str: async def get_account_balance(request: web.Request): address = _get_address_from_request(request) + try: + query_params = GetAccountQueryParams.parse_obj(request.query) + except ValidationError as e: + raise web.HTTPUnprocessableEntity(text=e.json(indent=4)) + session_factory: DbSessionFactory = get_session_factory_from_request(request) with session_factory() as session: - balance = get_total_balance( - session=session, address=address, include_dapps=False - ) - total_cost = get_total_cost_for_address(session=session, address=address) + if query_params.chain is None: + balance = ( + get_total_balance(session=session, address=address, include_dapps=False) + or 0 + ) + else: + balance = ( + get_balance_by_chain( + session=session, address=address, chain=query_params.chain + ) + or 0 + ) - if balance is None: - raise web.HTTPNotFound() + total_cost = get_total_cost_for_address(session=session, address=address) return web.json_response( text=GetAccountBalanceResponse( diff --git a/tests/api/test_balance.py b/tests/api/test_balance.py index c37d1b44c..62ede9ff8 100644 --- a/tests/api/test_balance.py +++ b/tests/api/test_balance.py @@ -1,4 +1,5 @@ import pytest +from aleph_message.models import Chain from aleph.db.models import AlephBalanceDb from aleph.jobs.process_pending_messages import PendingMessageProcessor @@ -25,3 +26,64 @@ async def test_get_balance( data = await response.json() assert data["balance"] == user_balance.balance assert data["locked_amount"] == 2002.4666666666667 + + +@pytest.mark.asyncio +async def test_get_balance_with_chain( + ccn_api_client, + message_processor: PendingMessageProcessor, + instance_message_with_volumes_in_db, + fixture_instance_message, + user_balance_eth_avax: AlephBalanceDb, +): + pipeline = message_processor.make_pipeline() + _ = [message async for message in pipeline] + + assert fixture_instance_message.item_content + expected_locked_amount = 2002.4666666666667 + chain = Chain.AVAX.value + # Test Avax + avax_response = await ccn_api_client.get(f"{MESSAGES_URI}?chain={chain}") + + assert avax_response.status == 200, await avax_response.text() + avax_data = await avax_response.json() + avax_expected_balance = user_balance_eth_avax.balance + assert avax_data["balance"] == avax_expected_balance + assert avax_data["locked_amount"] == expected_locked_amount + + # Verify ETH Value + chain = Chain.ETH.value + eth_response = await ccn_api_client.get(f"{MESSAGES_URI}?chain={chain}") + assert eth_response.status == 200, await eth_response.text() + eth_data = await eth_response.json() + eth_expected_balance = user_balance_eth_avax.balance + assert eth_data["balance"] == eth_expected_balance + assert eth_data["locked_amount"] == expected_locked_amount + + # Verify All Chain + total_response = await ccn_api_client.get(f"{MESSAGES_URI}") + assert total_response.status == 200, await total_response.text() + total_data = await total_response.json() + total_expected_balance = user_balance_eth_avax.balance * 2 + assert total_data["balance"] == total_expected_balance + assert total_data["locked_amount"] == expected_locked_amount + + +@pytest.mark.asyncio +async def test_get_balance_with_no_balance( + ccn_api_client, +): + response = await ccn_api_client.get(f"{MESSAGES_URI}") + + assert response.status == 200, await response.text() + data = await response.json() + assert data["balance"] == 0 + assert data["locked_amount"] == 0 + + # Test Eth Case + response = await ccn_api_client.get(f"{MESSAGES_URI}?chain{Chain.ETH.value}") + + assert response.status == 200, await response.text() + data = await response.json() + assert data["balance"] == 0 + assert data["locked_amount"] == 0 diff --git a/tests/conftest.py b/tests/conftest.py index 2a9370c96..3e238f3d7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -352,3 +352,27 @@ def user_balance(session_factory: DbSessionFactory) -> AlephBalanceDb: session.add(balance) session.commit() return balance + + +@pytest.fixture +def user_balance_eth_avax(session_factory: DbSessionFactory) -> AlephBalanceDb: + balance_eth = AlephBalanceDb( + address="0x9319Ad3B7A8E0eE24f2E639c40D8eD124C5520Ba", + chain=Chain.ETH, + balance=Decimal(22_192), + eth_height=0, + ) + + balance_avax = AlephBalanceDb( + address="0x9319Ad3B7A8E0eE24f2E639c40D8eD124C5520Ba", + chain=Chain.AVAX, + balance=Decimal(22_192), + eth_height=0, + ) + + with session_factory() as session: + session.add(balance_eth) + session.add(balance_avax) + + session.commit() + return balance_avax