Skip to content

Commit

Permalink
Fix mypy issues in missed packages
Browse files Browse the repository at this point in the history
  • Loading branch information
ulope committed May 15, 2019
1 parent f4d1d65 commit 4fa9a13
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 34 deletions.
2 changes: 2 additions & 0 deletions raiden/network/resolver/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from eth_utils import to_bytes, to_hex

from raiden.raiden_service import RaidenService
from raiden.storage.wal import WriteAheadLog
from raiden.transfer.mediated_transfer.events import SendSecretRequest
from raiden.transfer.mediated_transfer.state_change import ReceiveSecretReveal

Expand All @@ -15,6 +16,7 @@ def reveal_secret_with_resolver(
if "resolver_endpoint" not in raiden.config:
return False

assert isinstance(raiden.wal, WriteAheadLog), "RaidenService has not been started"
current_state = raiden.wal.state_manager.current_state
task = current_state.payment_mapping.secrethashes_to_task[secret_request_event.secrethash]
token = task.target_state.transfer.token
Expand Down
8 changes: 6 additions & 2 deletions raiden/storage/migrations/v17_to_v18.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def get_token_network_by_identifier(
return None


def _transform_snapshot(raw_snapshot: Dict[Any, Any]) -> str:
def _transform_snapshot(raw_snapshot: str) -> str:
"""
This migration upgrades the object:
- `MediatorTransferState` such that a list of routes is added
Expand Down Expand Up @@ -48,7 +48,11 @@ def _transform_snapshot(raw_snapshot: Dict[Any, Any]) -> str:
token_network_identifier = transfer["balance_proof"]["token_network_identifier"]
token_network = get_token_network_by_identifier(snapshot, token_network_identifier)
channel_identifier = transfer["balance_proof"]["channel_identifier"]
channel = token_network.get("channelidentifiers_to_channels").get(channel_identifier)
channel = None
if token_network is not None:
channel = token_network.get("channelidentifiers_to_channels", {}).get(
channel_identifier
)
if not channel:
raise ChannelNotFound(
f"Upgrading to v18 failed. "
Expand Down
12 changes: 6 additions & 6 deletions raiden/storage/migrations/v18_to_v19.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class BlockHashCache:

def __init__(self, web3: Web3):
self.web3 = web3
self.mapping = {}
self.mapping: Dict[BlockNumber, str] = {}

def get(self, block_number: BlockNumber) -> str:
"""Given a block number returns the hex representation of the blockhash"""
Expand All @@ -47,7 +47,7 @@ def _query_blocknumber_and_update_statechange_data(
) -> Tuple[str, int]:
data = record.data
data["block_hash"] = record.cache.get(record.block_number)
return (json.dumps(data), record.state_change_identifier)
return json.dumps(data), record.state_change_identifier


def _add_blockhash_to_state_changes(storage: SQLiteStorage, cache: BlockHashCache) -> None:
Expand All @@ -69,7 +69,7 @@ def _add_blockhash_to_state_changes(storage: SQLiteStorage, cache: BlockHashCach
data = json.loads(state_change.data)
assert "block_hash" not in data, "v18 state changes cant contain blockhash"
record = BlockQueryAndUpdateRecord(
block_number=int(data["block_number"]),
block_number=BlockNumber(int(data["block_number"])),
data=data,
state_change_identifier=state_change.state_change_identifier,
cache=cache,
Expand Down Expand Up @@ -112,7 +112,7 @@ def _add_blockhash_to_events(storage: SQLiteStorage, cache: BlockHashCache) -> N
if "block_hash" in statechange_data:
data["triggered_by_block_hash"] = statechange_data["block_hash"]
elif "block_number" in statechange_data:
block_number = int(statechange_data["block_number"])
block_number = BlockNumber(int(statechange_data["block_number"]))
data["triggered_by_block_hash"] = cache.get(block_number)

updated_events.append((json.dumps(data), event.event_identifier))
Expand All @@ -123,7 +123,7 @@ def _add_blockhash_to_events(storage: SQLiteStorage, cache: BlockHashCache) -> N
def _transform_snapshot(raw_snapshot: str, storage: SQLiteStorage, cache: BlockHashCache) -> str:
"""Upgrades a single snapshot by adding the blockhash to it and to any pending transactions"""
snapshot = json.loads(raw_snapshot)
block_number = int(snapshot["block_number"])
block_number = BlockNumber(int(snapshot["block_number"]))
snapshot["block_hash"] = cache.get(block_number)

pending_transactions = snapshot["pending_transactions"]
Expand Down Expand Up @@ -158,7 +158,7 @@ class TransformSnapshotRecord(NamedTuple):
cache: BlockHashCache


def _do_transform_snapshot(record: TransformSnapshotRecord) -> Tuple[Dict[str, Any], int]:
def _do_transform_snapshot(record: TransformSnapshotRecord) -> Tuple[str, int]:
new_snapshot = _transform_snapshot(
raw_snapshot=record.data, storage=record.storage, cache=record.cache
)
Expand Down
45 changes: 29 additions & 16 deletions raiden/storage/migrations/v19_to_v20.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,22 @@
from eth_utils import to_canonical_address
from gevent.pool import Pool

from raiden import raiden_service # pylint: disable=unused-import
from raiden.constants import EMPTY_MERKLE_ROOT
from raiden.exceptions import RaidenUnrecoverableError
from raiden.network.proxies.utils import get_onchain_locksroots
from raiden.storage.sqlite import SQLiteStorage, StateChangeRecord
from raiden.transfer.identifiers import CanonicalIdentifier
from raiden.utils.serialization import serialize_bytes
from raiden.utils.typing import Any, Dict, Locksroot, Tuple

RaidenService = "RaidenService"
from raiden.utils.typing import (
Any,
ChainID,
ChannelID,
Dict,
Locksroot,
TokenNetworkAddress,
Tuple,
)

SOURCE_VERSION = 19
TARGET_VERSION = 20
Expand All @@ -31,7 +38,7 @@ def _find_channel_new_state_change(


def _get_onchain_locksroots(
raiden: RaidenService,
raiden: "raiden_service.RaidenService",
storage: SQLiteStorage,
token_network: Dict[str, Any],
channel: Dict[str, Any],
Expand All @@ -49,9 +56,9 @@ def _get_onchain_locksroots(
)

canonical_identifier = CanonicalIdentifier(
chain_identifier=-1,
token_network_address=to_canonical_address(token_network["address"]),
channel_identifier=int(channel["identifier"]),
chain_identifier=ChainID(-1),
token_network_address=TokenNetworkAddress(to_canonical_address(token_network["address"])),
channel_identifier=ChannelID(int(channel["identifier"])),
)

our_locksroot, partner_locksroot = get_onchain_locksroots(
Expand Down Expand Up @@ -96,7 +103,7 @@ def _add_onchain_locksroot_to_channel_new_state_changes(storage: SQLiteStorage,)


def _add_onchain_locksroot_to_channel_settled_state_changes(
raiden: RaidenService, storage: SQLiteStorage
raiden: "raiden_service.RaidenService", storage: SQLiteStorage
) -> None:
""" Adds `our_onchain_locksroot` and `partner_onchain_locksroot` to
ContractReceiveChannelSettled. """
Expand Down Expand Up @@ -134,9 +141,11 @@ def _add_onchain_locksroot_to_channel_settled_state_changes(
new_channel_state = channel_state_data["channel_state"]

canonical_identifier = CanonicalIdentifier(
chain_identifier=-1,
token_network_address=to_canonical_address(token_network_identifier),
channel_identifier=int(channel_identifier),
chain_identifier=ChainID(-1),
token_network_address=TokenNetworkAddress(
to_canonical_address(token_network_identifier)
),
channel_identifier=ChannelID(int(channel_identifier)),
)
our_locksroot, partner_locksroot = get_onchain_locksroots(
chain=raiden.chain,
Expand All @@ -156,8 +165,10 @@ def _add_onchain_locksroot_to_channel_settled_state_changes(


def _add_onchain_locksroot_to_snapshot(
raiden: RaidenService, storage: SQLiteStorage, snapshot_record: StateChangeRecord
) -> str:
raiden: "raiden_service.RaidenService",
storage: SQLiteStorage,
snapshot_record: StateChangeRecord,
) -> Tuple[str, int]:
"""
Add `onchain_locksroot` to each NettingChannelEndState
"""
Expand All @@ -175,10 +186,12 @@ def _add_onchain_locksroot_to_snapshot(
channel["our_state"]["onchain_locksroot"] = serialize_bytes(our_locksroot)
channel["partner_state"]["onchain_locksroot"] = serialize_bytes(partner_locksroot)

return json.dumps(snapshot, indent=4), snapshot_record.identifier
return json.dumps(snapshot, indent=4), snapshot_record.state_change_identifier


def _add_onchain_locksroot_to_snapshots(raiden: RaidenService, storage: SQLiteStorage) -> None:
def _add_onchain_locksroot_to_snapshots(
raiden: "raiden_service.RaidenService", storage: SQLiteStorage
) -> None:
snapshots = storage.get_snapshots()

transform_func = partial(_add_onchain_locksroot_to_snapshot, raiden, storage)
Expand All @@ -196,7 +209,7 @@ def upgrade_v19_to_v20( # pylint: disable=unused-argument
storage: SQLiteStorage,
old_version: int,
current_version: int, # pylint: disable=unused-argument
raiden: "RaidenService",
raiden: "raiden_service.RaidenService",
**kwargs, # pylint: disable=unused-argument
) -> int:
if old_version == SOURCE_VERSION:
Expand Down
20 changes: 10 additions & 10 deletions raiden/storage/migrations/v21_to_v22.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import TYPE_CHECKING, TypeVar
from typing import TYPE_CHECKING, Tuple, TypeVar

from eth_utils import to_checksum_address

Expand All @@ -16,15 +16,15 @@

BATCH_UNLOCK = "raiden.transfer.state_change.ContractReceiveChannelBatchUnlock"

SPELLING_VARS_TOKEN_NETWORK = (
SPELLING_VARS_TOKEN_NETWORK = [
"token_network_address",
"token_network_id",
"token_network_identifier",
)
]

SPELLING_VARS_CHANNEL = ("channel_identifier", "channel_id", "identifier")
SPELLING_VARS_CHANNEL = ["channel_identifier", "channel_id", "identifier"]

SPELLING_VARS_CHAIN = ("chain_id", "chain_identifier")
SPELLING_VARS_CHAIN = ["chain_id", "chain_identifier"]


# these are missing the chain-id
Expand Down Expand Up @@ -224,8 +224,8 @@ def _add_canonical_identifier_to_statechanges(
our_address = str(to_checksum_address(raiden.address)).lower()

for state_change_batch in storage.batch_query_state_changes(batch_size=500):
updated_state_changes = list()
delete_state_changes = list()
updated_state_changes: List[Tuple[str, int]] = list()
delete_state_changes: List[int] = list()

for state_change_record in state_change_batch:
state_change_obj = json.loads(state_change_record.data)
Expand All @@ -236,16 +236,16 @@ def _add_canonical_identifier_to_statechanges(
)

if should_delete:
delete_state_changes.append(state_change_record.identifier)
delete_state_changes.append(state_change_record.state_change_identifier)
else:
channel_id = None
channel_id: Optional[int] = None
if is_unlock:
channel_id = resolve_channel_id_for_unlock(
storage, state_change_obj, our_address
)
walk_dicts(
state_change_obj,
lambda obj, channel_id=channel_id: upgrade_object(obj, chain_id, channel_id),
lambda obj, channel_id_=channel_id: upgrade_object(obj, chain_id, channel_id_),
)

walk_dicts(state_change_obj, constraint_has_canonical_identifier_or_values_removed)
Expand Down

0 comments on commit 4fa9a13

Please sign in to comment.