diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 994aa52..83909b5 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -23,9 +23,10 @@ jobs: python -m pip install --upgrade pip pip install -r requirements-dev.txt pip install -e . - - name: Lint with flake8 + - name: Lint and check formatting with ruff run: | - flake8 . --count --show-source --statistics + ruff check src/grpc_requests/*.py src/tests/*.py --statistics + ruff format src/grpc_requests/*.py src/tests/*.py --check - name: Test with pytest run: | pytest --cov-report=xml --cov=src/grpc_requests diff --git a/README.md b/README.md index 9303f18..7204a53 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@ # grpc_requests +[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) [![PyPI](https://img.shields.io/pypi/v/grpc-requests?style=flat-square)](https://pypi.org/project/grpc-requests) [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/grpc-requests?style=flat-square)](https://pypi.org/project/grpc-requests) [![PyPI download month](https://img.shields.io/pypi/dm/grpc-requests?style=flat-square)](https://pypi.org/project/grpc-requests) diff --git a/requirements-dev.txt b/requirements-dev.txt index 98025cd..8b622b1 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -7,3 +7,4 @@ pytest-cov>=4.0.0 pytest-asyncio>=0.15.1 aiounittest>=1.4.2 grpc-interceptor>=0.15.4 +ruff>=0.1.7 \ No newline at end of file diff --git a/src/grpc_requests/__init__.py b/src/grpc_requests/__init__.py index de38559..ae143a1 100644 --- a/src/grpc_requests/__init__.py +++ b/src/grpc_requests/__init__.py @@ -1,4 +1,10 @@ -from .aio import AsyncClient, ReflectionAsyncClient, StubAsyncClient, get_by_endpoint as async_get_by_endpoint +# ruff: noqa: F401 +from .aio import ( + AsyncClient, + ReflectionAsyncClient, + StubAsyncClient, + get_by_endpoint as async_get_by_endpoint, +) from .client import Client, ReflectionClient, StubClient, get_by_endpoint __version__ = "0.1.13" diff --git a/src/grpc_requests/aio.py b/src/grpc_requests/aio.py index e8e7c04..d1063b4 100644 --- a/src/grpc_requests/aio.py +++ b/src/grpc_requests/aio.py @@ -2,10 +2,25 @@ import sys from enum import Enum from functools import partial -from typing import Any, AsyncIterable, Dict, Iterable, List, NamedTuple, Optional, Tuple, TypeVar +from typing import ( + Any, + AsyncIterable, + Dict, + Iterable, + List, + NamedTuple, + Optional, + Tuple, + TypeVar, +) import grpc -from google.protobuf import descriptor_pb2, descriptor_pool as _descriptor_pool, symbol_database as _symbol_database, message_factory # noqa: E501 +from google.protobuf import ( + descriptor_pb2, + descriptor_pool as _descriptor_pool, + symbol_database as _symbol_database, + message_factory, +) # noqa: E501 from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor from google.protobuf.descriptor_pb2 import ServiceDescriptorProto from google.protobuf.json_format import MessageToDict, ParseDict @@ -27,15 +42,20 @@ def get_metadata(package_name: str): def get_metadata(package_name: str): return pkg_resources.get_distribution(package_name).version + # Import GetMessageClass if protobuf version supports it -protobuf_version = get_metadata('protobuf').split('.') -get_message_class_supported = int(protobuf_version[0]) >= 4 and int(protobuf_version[1]) >= 22 +protobuf_version = get_metadata("protobuf").split(".") +get_message_class_supported = ( + int(protobuf_version[0]) >= 4 and int(protobuf_version[1]) >= 22 +) if get_message_class_supported: from google.protobuf.message_factory import GetMessageClass class DescriptorImport: - def __init__(self, ): + def __init__( + self, + ): pass @@ -55,8 +75,18 @@ def reflection_request(channel, requests): class BaseAsyncClient: - def __init__(self, endpoint, symbol_db=None, descriptor_pool=None, channel_options=None, ssl=False, - compression=None, credentials: Optional[CredentialsInfo] = None, interceptors=None, **kwargs): + def __init__( + self, + endpoint, + symbol_db=None, + descriptor_pool=None, + channel_options=None, + ssl=False, + compression=None, + credentials: Optional[CredentialsInfo] = None, + interceptors=None, + **kwargs, + ): self.endpoint = endpoint self._symbol_db = symbol_db or _symbol_database.Default() self._desc_pool = descriptor_pool or _descriptor_pool.Default() @@ -70,15 +100,21 @@ def __init__(self, endpoint, symbol_db=None, descriptor_pool=None, channel_optio for k, v in credentials.items() } - self._channel = grpc.aio.secure_channel(endpoint, grpc.ssl_channel_credentials(**_credentials), - options=self.channel_options, - compression=self.compression, - interceptors=interceptors) + self._channel = grpc.aio.secure_channel( + endpoint, + grpc.ssl_channel_credentials(**_credentials), + options=self.channel_options, + compression=self.compression, + interceptors=interceptors, + ) else: - self._channel = grpc.aio.insecure_channel(endpoint, options=self.channel_options, - compression=self.compression, - interceptors=interceptors) + self._channel = grpc.aio.insecure_channel( + endpoint, + options=self.channel_options, + compression=self.compression, + interceptors=interceptors, + ) @property def channel(self): @@ -133,14 +169,14 @@ async def parse_stream_responses(responses: AsyncIterable): class MethodType(Enum): - UNARY_UNARY = 'unary_unary' - STREAM_UNARY = 'stream_unary' - UNARY_STREAM = 'unary_stream' - STREAM_STREAM = 'stream_stream' + UNARY_UNARY = "unary_unary" + STREAM_UNARY = "stream_unary" + UNARY_STREAM = "unary_stream" + STREAM_STREAM = "stream_stream" @property def is_unary_request(self): - return 'unary_' in self.value + return "unary_" in self.value @property def request_parser(self): @@ -148,7 +184,7 @@ def request_parser(self): @property def is_unary_response(self): - return '_unary' in self.value + return "_unary" in self.value @property def response_parser(self): @@ -174,10 +210,23 @@ class MethodMetaData(NamedTuple): class BaseAsyncGrpcClient(BaseAsyncClient): - - def __init__(self, endpoint, symbol_db=None, descriptor_pool=None, ssl=False, compression=None, - **kwargs): - super().__init__(endpoint, symbol_db, descriptor_pool, ssl=ssl, compression=compression, **kwargs) + def __init__( + self, + endpoint, + symbol_db=None, + descriptor_pool=None, + ssl=False, + compression=None, + **kwargs, + ): + super().__init__( + endpoint, + symbol_db, + descriptor_pool, + ssl=ssl, + compression=compression, + **kwargs, + ) self._service_names: list = None self.has_server_registered = False self._services_module_name = {} @@ -191,31 +240,44 @@ def __init__(self, endpoint, symbol_db=None, descriptor_pool=None, ssl=False, co async def _get_service_names(self): raise NotImplementedError() - async def check_method_available(self, service, method, method_type: MethodType = None): + async def check_method_available( + self, service, method, method_type: MethodType = None + ): if not self.has_server_registered: await self.register_all_service() methods_meta = self._service_methods_meta.get(service) if not methods_meta: service_names = await self.service_names() raise ValueError( - self.endpoint + " server doesn't support " + service + ". Available services " + str(service_names)) + self.endpoint + + " server doesn't support " + + service + + ". Available services " + + str(service_names) + ) if method not in methods_meta: raise ValueError( - f"{service} doesn't support {method} method. Available methods {methods_meta.keys()}") + f"{service} doesn't support {method} method. Available methods {methods_meta.keys()}" + ) if method_type and method_type != methods_meta[method].method_type: raise ValueError( - f"{method} is {methods_meta[method].method_type.value} not {method_type.value}") + f"{method} is {methods_meta[method].method_type.value} not {method_type.value}" + ) return True - def _register_methods(self, service_descriptor: ServiceDescriptor) -> Dict[str, MethodMetaData]: + def _register_methods( + self, service_descriptor: ServiceDescriptor + ) -> Dict[str, MethodMetaData]: svc_desc_proto = ServiceDescriptorProto() service_descriptor.CopyToProto(svc_desc_proto) service_full_name = service_descriptor.full_name metadata: Dict[str, MethodMetaData] = {} for method_proto in svc_desc_proto.method: method_name = method_proto.name - method_desc: MethodDescriptor = service_descriptor.methods_by_name[method_name] + method_desc: MethodDescriptor = service_descriptor.methods_by_name[ + method_name + ] if get_message_class_supported: input_type = GetMessageClass(method_desc.input_type) @@ -225,19 +287,21 @@ def _register_methods(self, service_descriptor: ServiceDescriptor) -> Dict[str, input_type = msg_factory.GetPrototype(method_desc.input_type) output_type = msg_factory.GetPrototype(method_desc.output_type) - method_type = MethodTypeMatch[(method_proto.client_streaming, method_proto.server_streaming)] + method_type = MethodTypeMatch[ + (method_proto.client_streaming, method_proto.server_streaming) + ] method_register_func = getattr(self.channel, method_type.value) handler = method_register_func( method=self._make_method_full_name(service_full_name, method_name), request_serializer=input_type.SerializeToString, - response_deserializer=output_type.FromString + response_deserializer=output_type.FromString, ) metadata[method_name] = MethodMetaData( method_type=method_type, input_type=input_type, output_type=output_type, - handler=handler + handler=handler, ) return metadata @@ -258,7 +322,10 @@ async def service_names(self): return self._service_names async def get_methods_meta(self, service_name: str): - if service_name in await self.service_names() and service_name not in self._service_methods_meta: + if ( + service_name in await self.service_names() + and service_name not in self._service_methods_meta + ): await self.register_service(service_name) try: @@ -274,7 +341,9 @@ async def _request(self, service, method, request, raw_output=False, **kwargs): # does not check request is available method_meta = self.get_method_meta(service, method) - _request = method_meta.method_type.request_parser(request, method_meta.input_type) + _request = method_meta.method_type.request_parser( + request, method_meta.input_type + ) if method_meta.method_type.is_unary_response: result = await method_meta.handler(_request, **kwargs) @@ -290,11 +359,15 @@ async def request(self, service, method, request=None, raw_output=False, **kwarg await self.check_method_available(service, method) return await self._request(service, method, request, raw_output, **kwargs) - async def unary_unary(self, service, method, request=None, raw_output=False, **kwargs): + async def unary_unary( + self, service, method, request=None, raw_output=False, **kwargs + ): await self.check_method_available(service, method, MethodType.UNARY_UNARY) return await self._request(service, method, request, raw_output, **kwargs) - async def unary_stream(self, service, method, request=None, raw_output=False, **kwargs): + async def unary_stream( + self, service, method, request=None, raw_output=False, **kwargs + ): await self.check_method_available(service, method, MethodType.UNARY_STREAM) return await self._request(service, method, request, raw_output, **kwargs) @@ -302,7 +375,9 @@ async def stream_unary(self, service, method, requests, raw_output=False, **kwar await self.check_method_available(service, method, MethodType.STREAM_UNARY) return await self._request(service, method, requests, raw_output, **kwargs) - async def stream_stream(self, service, method, requests, raw_output=False, **kwargs): + async def stream_stream( + self, service, method, requests, raw_output=False, **kwargs + ): await self.check_method_available(service, method, MethodType.STREAM_STREAM) return await self._request(service, method, requests, raw_output, **kwargs) @@ -320,9 +395,9 @@ def get_method_meta(self, service: str, method: str) -> MethodMetaData: def make_handler_argument(self, service: str, method: str): data_type = self.get_method_meta(service, method) return { - 'method': self._make_method_full_name(service, method), - 'request_serializer': data_type.input_type.SerializeToString, - 'response_deserializer': data_type.output_type.FromString, + "method": self._make_method_full_name(service, method), + "request_serializer": data_type.input_type.SerializeToString, + "response_deserializer": data_type.output_type.FromString, } async def service(self, name): @@ -330,14 +405,31 @@ async def service(self, name): if name in available_services: return await ServiceClient.create(client=self, service_name=name) else: - raise ValueError(name + " is not supported. Available services are: " + str(available_services)) + raise ValueError( + name + + " is not supported. Available services are: " + + str(available_services) + ) class ReflectionAsyncClient(BaseAsyncGrpcClient): - - def __init__(self, endpoint, symbol_db=None, descriptor_pool=None, ssl=False, compression=None, - **kwargs): - super().__init__(endpoint, symbol_db, descriptor_pool, ssl=ssl, compression=compression, **kwargs) + def __init__( + self, + endpoint, + symbol_db=None, + descriptor_pool=None, + ssl=False, + compression=None, + **kwargs, + ): + super().__init__( + endpoint, + symbol_db, + descriptor_pool, + ssl=ssl, + compression=compression, + **kwargs, + ) self.reflection_stub = reflection_pb2_grpc.ServerReflectionStub(self.channel) def _reflection_request(self, *requests): @@ -369,7 +461,7 @@ async def _get_file_descriptor_by_symbol(self, symbol): def _is_descriptor_registered(self, filename): try: self._desc_pool.FindFileByName(filename) - logger.debug(f'{filename} already registered') + logger.debug(f"{filename} already registered") return True except KeyError: return False @@ -378,7 +470,9 @@ async def _register_file_descriptor(self, file_descriptor): if not self._is_descriptor_registered(file_descriptor.name): logger.debug(f"start {file_descriptor.name} register") dependencies = list(file_descriptor.dependency) - logger.debug(f"found {len(dependencies)} dependencies for {file_descriptor.name}") + logger.debug( + f"found {len(dependencies)} dependencies for {file_descriptor.name}" + ) for dep_file_name in dependencies: if not self._is_descriptor_registered(dep_file_name): dep_desc = await self._get_file_descriptor_by_name(dep_file_name) @@ -386,13 +480,15 @@ async def _register_file_descriptor(self, file_descriptor): try: self._desc_pool.Add(file_descriptor) except TypeError: - logger.debug(f"{file_descriptor.name} already present in pool. Skipping.") + logger.debug( + f"{file_descriptor.name} already present in pool. Skipping." + ) logger.debug(f"{file_descriptor.name} registration complete") def _is_service_registered(self, service_name): try: self._desc_pool.FindServiceByName(service_name) - logger.debug(f'{service_name} already registered') + logger.debug(f"{service_name} already registered") return True except KeyError: return False @@ -407,11 +503,24 @@ async def register_service(self, service_name): class StubAsyncClient(BaseAsyncGrpcClient): - - def __init__(self, endpoint, service_descriptors: List[ServiceDescriptor], symbol_db=None, - descriptor_pool=None, ssl=False, compression=None, - **kwargs): - super().__init__(endpoint, symbol_db, descriptor_pool, ssl=ssl, compression=compression, **kwargs) + def __init__( + self, + endpoint, + service_descriptors: List[ServiceDescriptor], + symbol_db=None, + descriptor_pool=None, + ssl=False, + compression=None, + **kwargs, + ): + super().__init__( + endpoint, + symbol_db, + descriptor_pool, + ssl=ssl, + compression=compression, + **kwargs, + ) self.service_descriptors = service_descriptors async def _get_service_names(self): @@ -462,7 +571,9 @@ def get_by_endpoint(endpoint, service_descriptors=None, **kwargs) -> AsyncClient global _cached_clients if endpoint not in _cached_clients: if service_descriptors: - _cached_clients[endpoint] = StubAsyncClient(endpoint, service_descriptors=service_descriptors, **kwargs) + _cached_clients[endpoint] = StubAsyncClient( + endpoint, service_descriptors=service_descriptors, **kwargs + ) else: _cached_clients[endpoint] = AsyncClient(endpoint, **kwargs) return _cached_clients[endpoint] diff --git a/src/grpc_requests/client.py b/src/grpc_requests/client.py index 9b237f2..30afd06 100644 --- a/src/grpc_requests/client.py +++ b/src/grpc_requests/client.py @@ -3,8 +3,17 @@ import warnings from enum import Enum from functools import partial -from typing import (Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, - TypeVar, Union) +from typing import ( + Any, + Dict, + Iterable, + List, + NamedTuple, + Optional, + Tuple, + TypeVar, + Union, +) import grpc from google.protobuf import descriptor_pb2 @@ -30,9 +39,12 @@ def get_metadata(package_name: str): def get_metadata(package_name: str): return pkg_resources.get_distribution(package_name).version + # Import GetMessageClass if protobuf version supports it -protobuf_version = get_metadata('protobuf').split('.') -get_message_class_supported = int(protobuf_version[0]) >= 4 and int(protobuf_version[1]) >= 22 +protobuf_version = get_metadata("protobuf").split(".") +get_message_class_supported = ( + int(protobuf_version[0]) >= 4 and int(protobuf_version[1]) >= 22 +) if get_message_class_supported: from google.protobuf.message_factory import GetMessageClass @@ -40,7 +52,9 @@ def get_metadata(package_name: str): class DescriptorImport: - def __init__(self, ): + def __init__( + self, + ): pass @@ -69,8 +83,18 @@ class CredentialsInfo(TypedDict): class BaseClient: - def __init__(self, endpoint, symbol_db=None, descriptor_pool=None, channel_options=None, ssl=False, - compression=None, credentials: Optional[CredentialsInfo] = None, interceptors=None, **kwargs): + def __init__( + self, + endpoint, + symbol_db=None, + descriptor_pool=None, + channel_options=None, + ssl=False, + compression=None, + credentials: Optional[CredentialsInfo] = None, + interceptors=None, + **kwargs, + ): self.endpoint = endpoint self._desc_pool = descriptor_pool or _descriptor_pool.Default() self.compression = compression @@ -83,11 +107,16 @@ def __init__(self, endpoint, symbol_db=None, descriptor_pool=None, channel_optio for k, v in credentials.items() } - self._channel = grpc.secure_channel(endpoint, grpc.ssl_channel_credentials(**_credentials), - options=self.channel_options, - compression=self.compression) + self._channel = grpc.secure_channel( + endpoint, + grpc.ssl_channel_credentials(**_credentials), + options=self.channel_options, + compression=self.compression, + ) else: - self._channel = grpc.insecure_channel(endpoint, options=self.channel_options, compression=self.compression) + self._channel = grpc.insecure_channel( + endpoint, options=self.channel_options, compression=self.compression + ) if interceptors: self._channel = grpc.intercept_channel(self._channel, *interceptors) @@ -107,7 +136,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): try: self._channel._close() except Exception as e: # pylint: disable=bare-except - logger.warning('can not close channel', exc_info=e) + logger.warning("can not close channel", exc_info=e) return False def __del__(self): @@ -115,7 +144,7 @@ def __del__(self): try: del self._channel except Exception as e: # pylint: disable=bare-except - logger.warning('can not delete channel', exc_info=e) + logger.warning("can not delete channel", exc_info=e) def parse_request_data(request_data, input_type): @@ -139,14 +168,14 @@ def parse_stream_responses(responses: Iterable): class MethodType(Enum): - UNARY_UNARY = 'unary_unary' - STREAM_UNARY = 'stream_unary' - UNARY_STREAM = 'unary_stream' - STREAM_STREAM = 'stream_stream' + UNARY_UNARY = "unary_unary" + STREAM_UNARY = "stream_unary" + UNARY_STREAM = "unary_stream" + STREAM_STREAM = "stream_stream" @property def is_unary_request(self): - return 'unary_' in self.value + return "unary_" in self.value @property def request_parser(self): @@ -154,7 +183,7 @@ def request_parser(self): @property def is_unary_response(self): - return '_unary' in self.value + return "_unary" in self.value @property def response_parser(self): @@ -180,10 +209,24 @@ class MethodMetaData(NamedTuple): class BaseGrpcClient(BaseClient): - - def __init__(self, endpoint, symbol_db=None, descriptor_pool=None, lazy=False, ssl=False, compression=None, - **kwargs): - super().__init__(endpoint, symbol_db, descriptor_pool, ssl=ssl, compression=compression, **kwargs) + def __init__( + self, + endpoint, + symbol_db=None, + descriptor_pool=None, + lazy=False, + ssl=False, + compression=None, + **kwargs, + ): + super().__init__( + endpoint, + symbol_db, + descriptor_pool, + ssl=ssl, + compression=compression, + **kwargs, + ) self._service_names: list = None self._lazy = lazy self.has_server_registered = False @@ -206,24 +249,31 @@ def check_method_available(self, service, method, method_type: MethodType = None print(methods_meta) if not methods_meta: raise ValueError( - f"{self.endpoint} server doesn't support {service}. Available services {self.service_names}") + f"{self.endpoint} server doesn't support {service}. Available services {self.service_names}" + ) if method not in methods_meta: raise ValueError( - f"{service} doesn't support {method} method. Available methods {methods_meta.keys()}") + f"{service} doesn't support {method} method. Available methods {methods_meta.keys()}" + ) if method_type and method_type != methods_meta[method].method_type: raise ValueError( - f"{method} is {methods_meta[method].method_type.value} not {method_type.value}") + f"{method} is {methods_meta[method].method_type.value} not {method_type.value}" + ) return True - def _register_methods(self, service_descriptor: ServiceDescriptor) -> Dict[str, MethodMetaData]: + def _register_methods( + self, service_descriptor: ServiceDescriptor + ) -> Dict[str, MethodMetaData]: svc_desc_proto = ServiceDescriptorProto() service_descriptor.CopyToProto(svc_desc_proto) service_full_name = service_descriptor.full_name metadata: Dict[str, MethodMetaData] = {} for method_proto in svc_desc_proto.method: method_name = method_proto.name - method_desc: MethodDescriptor = service_descriptor.methods_by_name[method_name] + method_desc: MethodDescriptor = service_descriptor.methods_by_name[ + method_name + ] if get_message_class_supported: input_type = GetMessageClass(method_desc.input_type) @@ -233,19 +283,21 @@ def _register_methods(self, service_descriptor: ServiceDescriptor) -> Dict[str, input_type = msg_factory.GetPrototype(method_desc.input_type) output_type = msg_factory.GetPrototype(method_desc.output_type) - method_type = MethodTypeMatch[(method_proto.client_streaming, method_proto.server_streaming)] + method_type = MethodTypeMatch[ + (method_proto.client_streaming, method_proto.server_streaming) + ] method_register_func = getattr(self.channel, method_type.value) handler = method_register_func( method=self._make_method_full_name(service_full_name, method_name), request_serializer=input_type.SerializeToString, - response_deserializer=output_type.FromString + response_deserializer=output_type.FromString, ) metadata[method_name] = MethodMetaData( method_type=method_type, input_type=input_type, output_type=output_type, - handler=handler + handler=handler, ) return metadata @@ -255,7 +307,9 @@ def register_service(self, service_name): svc_desc = self._desc_pool.FindServiceByName(service_name) self._service_methods_meta[service_name] = self._register_methods(svc_desc) except KeyError: - logger.debug(f"{service_name} not found in descriptor pool, methods will not be registered") + logger.debug( + f"{service_name} not found in descriptor pool, methods will not be registered" + ) logger.debug(f"end {service_name} registration") def register_all_service(self): @@ -270,8 +324,11 @@ def service_names(self): return self._service_names def get_methods_meta(self, service_name: str): - - if self._lazy and service_name in self.service_names and service_name not in self._service_methods_meta: + if ( + self._lazy + and service_name in self.service_names + and service_name not in self._service_methods_meta + ): self.register_service(service_name) try: @@ -287,7 +344,9 @@ def _request(self, service, method, request, raw_output=False, **kwargs): # does not check request is available method_meta = self.get_method_meta(service, method) - _request = method_meta.method_type.request_parser(request, method_meta.input_type) + _request = method_meta.method_type.request_parser( + request, method_meta.input_type + ) result = method_meta.handler(_request, **kwargs) if raw_output: @@ -321,15 +380,19 @@ def get_service_descriptor(self, service): def describe_method_request(self, service, method): warnings.warn( "This function is deprecated, and will be removed in a future release. Use describe_request() instead.", - DeprecationWarning + DeprecationWarning, ) return describe_request(self.get_method_descriptor(service, method)) def describe_request(self, service, method): - return describe_descriptor(self.get_method_descriptor(service, method).input_type) + return describe_descriptor( + self.get_method_descriptor(service, method).input_type + ) def describe_response(self, service, method): - return describe_descriptor(self.get_method_descriptor(service, method).output_type) + return describe_descriptor( + self.get_method_descriptor(service, method).output_type + ) def get_method_descriptor(self, service, method): svc_desc = self.get_service_descriptor(service) @@ -342,23 +405,40 @@ def get_method_meta(self, service: str, method: str) -> MethodMetaData: def make_handler_argument(self, service: str, method: str): data_type = self.get_method_meta(service, method) return { - 'method': self._make_method_full_name(service, method), - 'request_serializer': data_type.input_type.SerializeToString, - 'response_deserializer': data_type.output_type.FromString, + "method": self._make_method_full_name(service, method), + "request_serializer": data_type.input_type.SerializeToString, + "response_deserializer": data_type.output_type.FromString, } def service(self, name): if name in self.service_names: return ServiceClient(client=self, service_name=name) else: - raise ValueError(f"{name} is not a supported service. Available services are {self.service_names}") + raise ValueError( + f"{name} is not a supported service. Available services are {self.service_names}" + ) class ReflectionClient(BaseGrpcClient): - - def __init__(self, endpoint, symbol_db=None, descriptor_pool=None, lazy=False, ssl=False, compression=None, - **kwargs): - super().__init__(endpoint, symbol_db, descriptor_pool, ssl=ssl, lazy=lazy, compression=compression, **kwargs) + def __init__( + self, + endpoint, + symbol_db=None, + descriptor_pool=None, + lazy=False, + ssl=False, + compression=None, + **kwargs, + ): + super().__init__( + endpoint, + symbol_db, + descriptor_pool, + ssl=ssl, + lazy=lazy, + compression=compression, + **kwargs, + ) self.reflection_stub = reflection_pb2_grpc.ServerReflectionStub(self.channel) if not self._lazy: self.register_all_service() @@ -370,7 +450,7 @@ def _reflection_request(self, *requests): def _reflection_single_request(self, request): results = list(self._reflection_request(request)) if len(results) > 1: - raise ValueError('response has more than one result') + raise ValueError("response has more than one result") return results[0] def _get_service_names(self): @@ -394,7 +474,7 @@ def _get_file_descriptor_by_symbol(self, symbol): def _is_descriptor_registered(self, filename): try: self._desc_pool.FindFileByName(filename) - logger.debug(f'{filename} already registered') + logger.debug(f"{filename} already registered") return True except KeyError: return False @@ -403,20 +483,24 @@ def _register_file_descriptor(self, file_descriptor): if not self._is_descriptor_registered(file_descriptor.name): logger.debug(f"start {file_descriptor.name} register") dependencies = list(file_descriptor.dependency) - logger.debug(f"found {len(dependencies)} dependencies for {file_descriptor.name}") + logger.debug( + f"found {len(dependencies)} dependencies for {file_descriptor.name}" + ) for dep_file_name in dependencies: dep_desc = self._get_file_descriptor_by_name(dep_file_name) self._register_file_descriptor(dep_desc) try: self._desc_pool.Add(file_descriptor) except TypeError: - logger.debug(f"{file_descriptor.name} already present in pool. Skipping.") + logger.debug( + f"{file_descriptor.name} already present in pool. Skipping." + ) logger.debug(f"end {file_descriptor.name} registration complete") def _is_service_registered(self, service_name): try: self._desc_pool.FindServiceByName(service_name) - logger.debug(f'{service_name} already registered') + logger.debug(f"{service_name} already registered") return True except KeyError: return False @@ -431,11 +515,26 @@ def register_service(self, service_name): class StubClient(BaseGrpcClient): - - def __init__(self, endpoint, service_descriptors: List[ServiceDescriptor], symbol_db=None, lazy=False, - descriptor_pool=None, ssl=False, compression=None, - **kwargs): - super().__init__(endpoint, symbol_db, descriptor_pool, ssl=ssl, compression=compression, lazy=lazy, **kwargs) + def __init__( + self, + endpoint, + service_descriptors: List[ServiceDescriptor], + symbol_db=None, + lazy=False, + descriptor_pool=None, + ssl=False, + compression=None, + **kwargs, + ): + super().__init__( + endpoint, + symbol_db, + descriptor_pool, + ssl=ssl, + compression=compression, + lazy=lazy, + **kwargs, + ) self.service_descriptors = service_descriptors if not self._lazy: @@ -481,7 +580,9 @@ def get_by_endpoint(endpoint, service_descriptors=None, **kwargs) -> Client: global _cached_clients if endpoint not in _cached_clients: if service_descriptors: - _cached_clients[endpoint] = StubClient(endpoint, service_descriptors=service_descriptors, **kwargs) + _cached_clients[endpoint] = StubClient( + endpoint, service_descriptors=service_descriptors, **kwargs + ) else: _cached_clients[endpoint] = Client(endpoint, **kwargs) return _cached_clients[endpoint] diff --git a/src/grpc_requests/utils.py b/src/grpc_requests/utils.py index 41e3607..2502741 100644 --- a/src/grpc_requests/utils.py +++ b/src/grpc_requests/utils.py @@ -1,35 +1,42 @@ from pathlib import Path -from google.protobuf.descriptor import Descriptor, EnumDescriptor, MethodDescriptor, OneofDescriptor +from google.protobuf.descriptor import ( + Descriptor, + EnumDescriptor, + MethodDescriptor, + OneofDescriptor, +) import warnings # String descriptions of protobuf field types FIELD_TYPES = [ - 'DOUBLE', - 'FLOAT', - 'INT64', - 'UINT64', - 'INT32', - 'FIXED64', - 'FIXED32', - 'BOOL', - 'STRING', - 'GROUP', - 'MESSAGE', - 'BYTES', - 'UINT32', - 'ENUM', - 'SFIXED32', - 'SFIXED64', - 'SINT32', - 'SINT64' + "DOUBLE", + "FLOAT", + "INT64", + "UINT64", + "INT32", + "FIXED64", + "FIXED32", + "BOOL", + "STRING", + "GROUP", + "MESSAGE", + "BYTES", + "UINT32", + "ENUM", + "SFIXED32", + "SFIXED64", + "SINT32", + "SINT64", ] + def load_data(_path): - with open(Path(_path).expanduser(), 'rb') as f: + with open(Path(_path).expanduser(), "rb") as f: data = f.read() return data + def describe_request(method_descriptor: MethodDescriptor) -> dict: """ Provide a dictionary that describes the fields of a Method request @@ -39,13 +46,14 @@ def describe_request(method_descriptor: MethodDescriptor) -> dict: """ warnings.warn( "This function is deprecated, and will be removed in a future release. Use describe_descriptor() instead.", - DeprecationWarning + DeprecationWarning, ) description = {} for field in method_descriptor.input_type.fields: - description[field.name] = FIELD_TYPES[field.type-1] + description[field.name] = FIELD_TYPES[field.type - 1] return description + def describe_descriptor(descriptor: Descriptor, indent: int = 0) -> str: """ Prints a human readable description of a protobuf descriptor. @@ -58,7 +66,7 @@ def describe_descriptor(descriptor: Descriptor, indent: int = 0) -> str: if descriptor.enum_types: description += f"\n{padding}Enums:" for enum in descriptor.enum_types: - description += describe_enum_descriptor(enum, indent+1) + description += describe_enum_descriptor(enum, indent + 1) if descriptor.fields: description += f"\n{padding}Fields:" @@ -68,10 +76,11 @@ def describe_descriptor(descriptor: Descriptor, indent: int = 0) -> str: if descriptor.oneofs: description += f"\n{padding}Oneofs:" for oneof in descriptor.oneofs: - description += describe_oneof_descriptor(oneof, indent+1) + description += describe_oneof_descriptor(oneof, indent + 1) return description + def describe_enum_descriptor(enum_descriptor: EnumDescriptor, indent: int = 0) -> str: """ Prints a human readable description of a protobuf enum descriptor. @@ -84,7 +93,10 @@ def describe_enum_descriptor(enum_descriptor: EnumDescriptor, indent: int = 0) - description += f"\n{padding}{value.name} = {value.number}" return description -def describe_oneof_descriptor(oneof_descriptor: OneofDescriptor, indent: int = 0) -> str: + +def describe_oneof_descriptor( + oneof_descriptor: OneofDescriptor, indent: int = 0 +) -> str: """ Prints a human readable description of a protobuf oneof descriptor. :param oneof_descriptor: OneofDescriptor - a protobuf oneof descriptor diff --git a/src/tests/async_reflection_client_test.py b/src/tests/async_reflection_client_test.py index eb12bd9..83398f8 100644 --- a/src/tests/async_reflection_client_test.py +++ b/src/tests/async_reflection_client_test.py @@ -10,92 +10,124 @@ Test cases for async reflection based client """ -logger = logging.getLogger('name') +logger = logging.getLogger("name") + @pytest.mark.asyncio async def test_unary_unary(): - client = AsyncClient('localhost:50051') - greeter_service = await client.service('helloworld.Greeter') + client = AsyncClient("localhost:50051") + greeter_service = await client.service("helloworld.Greeter") response = await greeter_service.SayHello({"name": "sinsky"}) assert isinstance(response, dict) assert response == {"message": "Hello, sinsky!"} + @pytest.mark.asyncio async def test_unary_unary_interceptor(): - client = AsyncClient('localhost:50051', interceptors=[AsyncMetadataClientInterceptor()]) - greeter_service = await client.service('helloworld.Greeter') + client = AsyncClient( + "localhost:50051", interceptors=[AsyncMetadataClientInterceptor()] + ) + greeter_service = await client.service("helloworld.Greeter") response = await greeter_service.SayHello({"name": "sinsky"}) assert isinstance(response, dict) assert response == {"message": "Hello, sinsky, interceptor accepted!"} + @pytest.mark.asyncio async def test_methods_meta(): - client = AsyncClient('localhost:50051', interceptors=[AsyncMetadataClientInterceptor()]) - greeter_service = await client.service('helloworld.Greeter') + client = AsyncClient( + "localhost:50051", interceptors=[AsyncMetadataClientInterceptor()] + ) + greeter_service = await client.service("helloworld.Greeter") meta = greeter_service.methods_meta - assert meta['HelloEveryone'].method_type == MethodType.STREAM_UNARY + assert meta["HelloEveryone"].method_type == MethodType.STREAM_UNARY + @pytest.mark.asyncio async def test_empty_body_request(): - client = AsyncClient('localhost:50051') - greeter_service = await client.service('helloworld.Greeter') + client = AsyncClient("localhost:50051") + greeter_service = await client.service("helloworld.Greeter") response = await greeter_service.SayHello({}) assert isinstance(response, dict) + @pytest.mark.asyncio async def test_nonexistent_method(): - client = AsyncClient('localhost:50051') - greeter_service = await client.service('helloworld.Greeter') + client = AsyncClient("localhost:50051") + greeter_service = await client.service("helloworld.Greeter") with pytest.raises(AttributeError): await greeter_service.SayGoodbye({}) + @pytest.mark.asyncio async def test_unsupported_argument(): - client = AsyncClient('localhost:50051') - greeter_service = await client.service('helloworld.Greeter') + client = AsyncClient("localhost:50051") + greeter_service = await client.service("helloworld.Greeter") with pytest.raises(ParseError): await greeter_service.SayHello({"foo": "bar"}) + @pytest.mark.asyncio async def test_unary_stream(): - client = AsyncClient('localhost:50051') - greeter_service = await client.service('helloworld.Greeter') + client = AsyncClient("localhost:50051") + greeter_service = await client.service("helloworld.Greeter") name_list = ["sinsky", "viridianforge", "jack", "harry"] - responses = [x async for x in await greeter_service.SayHelloGroup( - [{"name": name} for name in name_list] - )] + responses = [ + x + async for x in await greeter_service.SayHelloGroup( + [{"name": name} for name in name_list] + ) + ] assert all(isinstance(response, dict) for response in responses) for response, name in zip(responses, name_list): assert response == {"message": f"Hello, {name}!"} + @pytest.mark.asyncio async def test_stream_unary(): - client = AsyncClient('localhost:50051') - greeter_service = await client.service('helloworld.Greeter') + client = AsyncClient("localhost:50051") + greeter_service = await client.service("helloworld.Greeter") name_list = ["sinsky", "viridianforge", "jack", "harry"] - response = await greeter_service.HelloEveryone([{"name": name} for name in name_list]) + response = await greeter_service.HelloEveryone( + [{"name": name} for name in name_list] + ) assert isinstance(response, dict) - assert response == {'message': f'Hello, {" ".join(["sinsky", "viridianforge", "jack", "harry"])}!'} + assert response == { + "message": f'Hello, {" ".join(["sinsky", "viridianforge", "jack", "harry"])}!' + } + @pytest.mark.asyncio async def test_stream_stream(): - client = AsyncClient('localhost:50051') - greeter_service = await client.service('helloworld.Greeter') + client = AsyncClient("localhost:50051") + greeter_service = await client.service("helloworld.Greeter") name_list = ["sinsky", "viridianforge", "jack", "harry"] - responses = [x async for x in await greeter_service.SayHelloOneByOne([{"name": name} for name in name_list])] + responses = [ + x + async for x in await greeter_service.SayHelloOneByOne( + [{"name": name} for name in name_list] + ) + ] assert all(isinstance(response, dict) for response in responses) for response, name in zip(responses, name_list): assert response == {"message": f"Hello {name}"} + @pytest.mark.asyncio async def test_reflection_service_client(): - client = AsyncClient('localhost:50051') - greeter_service = await client.service('helloworld.Greeter') + client = AsyncClient("localhost:50051") + greeter_service = await client.service("helloworld.Greeter") method_names = greeter_service.method_names - assert method_names == ('SayHello', 'SayHelloGroup', 'HelloEveryone', 'SayHelloOneByOne') + assert method_names == ( + "SayHello", + "SayHelloGroup", + "HelloEveryone", + "SayHelloOneByOne", + ) + @pytest.mark.asyncio async def test_reflection_service_client_invalid_service(): - client = AsyncClient('localhost:50051') + client = AsyncClient("localhost:50051") with pytest.raises(ValueError): - await client.service('helloWorld.Singer') + await client.service("helloWorld.Singer") diff --git a/src/tests/common.py b/src/tests/common.py index b011763..8dff9e5 100644 --- a/src/tests/common.py +++ b/src/tests/common.py @@ -23,8 +23,8 @@ def intercept_unary_unary( return continuation(new_details, request) -class AsyncMetadataClientInterceptor(grpc.aio.UnaryUnaryClientInterceptor): +class AsyncMetadataClientInterceptor(grpc.aio.UnaryUnaryClientInterceptor): async def intercept_unary_unary(self, continuation, client_call_details, request): new_details = grpc.aio.ClientCallDetails( client_call_details.method, diff --git a/src/tests/conftest.py b/src/tests/conftest.py index efc6c81..de0e7ed 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -7,24 +7,31 @@ def helloworld_server_starter(): - server = HelloWorldServer('50051') + server = HelloWorldServer("50051") server.serve() + def client_tester_server_starter(): - server = ClientTesterServer('50052') + server = ClientTesterServer("50052") server.serve() + @pytest.fixture(scope="session", autouse=True) def helloworld_server(): - helloworld_server_process = multiprocessing.Process(target=helloworld_server_starter) + helloworld_server_process = multiprocessing.Process( + target=helloworld_server_starter + ) helloworld_server_process.start() time.sleep(1) yield helloworld_server_process.terminate() + @pytest.fixture(scope="session", autouse=True) def client_tester_server(): - client_tester_server_process = multiprocessing.Process(target=client_tester_server_starter) + client_tester_server_process = multiprocessing.Process( + target=client_tester_server_starter + ) client_tester_server_process.start() time.sleep(1) yield diff --git a/src/tests/reflection_client_test.py b/src/tests/reflection_client_test.py index 89cfcd7..3b3b7e3 100644 --- a/src/tests/reflection_client_test.py +++ b/src/tests/reflection_client_test.py @@ -10,78 +10,93 @@ Test cases for reflection based client """ -logger = logging.getLogger('name') +logger = logging.getLogger("name") + @pytest.fixture(scope="module") def helloworld_reflection_client(): try: - client = Client.get_by_endpoint('localhost:50051') + client = Client.get_by_endpoint("localhost:50051") yield client except: # noqa: E722 pytest.fail("Could not connect to local HelloWorld server") + @pytest.fixture(scope="module") def helloworld_reflection_client_with_interceptor(): try: # Don't use get_by_endpoint here, because interceptors are not cached. Consider caching kwargs too - client = Client('localhost:50051', interceptors=[MetadataClientInterceptor()]) + client = Client("localhost:50051", interceptors=[MetadataClientInterceptor()]) yield client except: # noqa: E722 pytest.fail("Could not connect to local HelloWorld server") + @pytest.fixture(scope="module") def client_tester_reflection_client(): try: - client = Client.get_by_endpoint('localhost:50051') + client = Client.get_by_endpoint("localhost:50051") yield client except: # noqa: E722 pytest.fail("Could not connect to local Test server") + def test_metadata_usage(helloworld_reflection_client): response = helloworld_reflection_client.request( - 'helloworld.Greeter', 'SayHello', + "helloworld.Greeter", + "SayHello", {"name": "sinsky"}, - metadata=[('password', '12345')] + metadata=[("password", "12345")], ) assert isinstance(response, dict) assert response == {"message": "Hello, sinsky, password accepted!"} + def test_interceptor_usage(helloworld_reflection_client_with_interceptor): response = helloworld_reflection_client_with_interceptor.request( - 'helloworld.Greeter', 'SayHello', + "helloworld.Greeter", + "SayHello", {"name": "sinsky"}, ) assert isinstance(response, dict) assert response == {"message": "Hello, sinsky, interceptor accepted!"} + def test_methods_meta(helloworld_reflection_client): - service = helloworld_reflection_client.service('helloworld.Greeter') + service = helloworld_reflection_client.service("helloworld.Greeter") meta = service.methods_meta - assert meta['HelloEveryone'].method_type == MethodType.STREAM_UNARY + assert meta["HelloEveryone"].method_type == MethodType.STREAM_UNARY + def test_unary_unary(helloworld_reflection_client): - response = helloworld_reflection_client.request('helloworld.Greeter', 'SayHello', {"name": "sinsky"}) + response = helloworld_reflection_client.request( + "helloworld.Greeter", "SayHello", {"name": "sinsky"} + ) assert isinstance(response, dict) assert response == {"message": "Hello, sinsky!"} + def test_describe_method_request(client_tester_reflection_client): - request_description = \ - client_tester_reflection_client.describe_method_request('client_tester.ClientTester', 'TestUnaryUnary') + request_description = client_tester_reflection_client.describe_method_request( + "client_tester.ClientTester", "TestUnaryUnary" + ) expected_request_description = { - 'factor': 'INT32', - 'readings': 'FLOAT', - 'uuid': 'UINT64', - 'sample_flag': 'BOOL', - 'request_name': 'STRING', - 'extra_data': 'BYTES' + "factor": "INT32", + "readings": "FLOAT", + "uuid": "UINT64", + "sample_flag": "BOOL", + "request_name": "STRING", + "extra_data": "BYTES", } assert ( request_description == expected_request_description ), f"Expected: {expected_request_description}, Actual: {request_description}" + def test_describe_request(client_tester_reflection_client): - request_description = \ - client_tester_reflection_client.describe_request('client_tester.ClientTester', 'TestUnaryUnary') + request_description = client_tester_reflection_client.describe_request( + "client_tester.ClientTester", "TestUnaryUnary" + ) expected_request_description = """TestRequest Fields: \tfactor: INT32 @@ -92,68 +107,82 @@ def test_describe_request(client_tester_reflection_client): \textra_data: BYTES""" assert request_description == expected_request_description + def test_describe_response(client_tester_reflection_client): - request_description = \ - client_tester_reflection_client.describe_response('client_tester.ClientTester', 'TestUnaryUnary') + request_description = client_tester_reflection_client.describe_response( + "client_tester.ClientTester", "TestUnaryUnary" + ) expected_response_description = """TestResponse Fields: \taverage: DOUBLE \tfeedback: STRING""" assert request_description == expected_response_description + def test_empty_body_request(helloworld_reflection_client): - response = helloworld_reflection_client.request('helloworld.Greeter', 'SayHello', {}) + response = helloworld_reflection_client.request( + "helloworld.Greeter", "SayHello", {} + ) assert isinstance(response, dict) + def test_nonexistent_service(helloworld_reflection_client): with pytest.raises(ValueError): - helloworld_reflection_client.request('helloworld.Speaker', 'SingHello', {}) + helloworld_reflection_client.request("helloworld.Speaker", "SingHello", {}) + def test_nonexistent_method(helloworld_reflection_client): with pytest.raises(ValueError): - helloworld_reflection_client.request('helloworld.Greeter', 'SayGoodbye', {}) + helloworld_reflection_client.request("helloworld.Greeter", "SayGoodbye", {}) + def test_unsupported_argument(helloworld_reflection_client): with pytest.raises(ParseError): - helloworld_reflection_client.request('helloworld.Greeter', 'SayHello', {"foo": "bar"}) + helloworld_reflection_client.request( + "helloworld.Greeter", "SayHello", {"foo": "bar"} + ) + def test_unary_stream(helloworld_reflection_client): name_list = ["sinsky", "viridianforge", "jack", "harry"] responses = helloworld_reflection_client.request( - 'helloworld.Greeter', - 'SayHelloGroup', - {"name": "".join(name_list)} + "helloworld.Greeter", "SayHelloGroup", {"name": "".join(name_list)} ) assert all(isinstance(response, dict) for response in responses) for response, name in zip(responses, name_list): assert response == {"message": f"Hello, {name}!"} + def test_stream_unary(helloworld_reflection_client): name_list = ["sinsky", "viridianforge", "jack", "harry"] response = helloworld_reflection_client.request( - 'helloworld.Greeter', - 'HelloEveryone', - [{"name": name} for name in name_list] + "helloworld.Greeter", "HelloEveryone", [{"name": name} for name in name_list] ) assert isinstance(response, dict) - assert response == {'message': f'Hello, {" ".join(name_list)}!'} + assert response == {"message": f'Hello, {" ".join(name_list)}!'} + def test_stream_stream(helloworld_reflection_client): name_list = ["sinsky", "viridianforge", "jack", "harry"] responses = helloworld_reflection_client.request( - 'helloworld.Greeter', - 'SayHelloOneByOne', - [{"name": name} for name in name_list] + "helloworld.Greeter", "SayHelloOneByOne", [{"name": name} for name in name_list] ) assert all(isinstance(response, dict) for response in responses) for response, name in zip(responses, name_list): assert response == {"message": f"Hello, {name}!"} + def test_reflection_service_client(helloworld_reflection_client): - svc_client = helloworld_reflection_client.service('helloworld.Greeter') + svc_client = helloworld_reflection_client.service("helloworld.Greeter") method_names = svc_client.method_names - assert method_names == ('SayHello', 'SayHelloGroup', 'HelloEveryone', 'SayHelloOneByOne') + assert method_names == ( + "SayHello", + "SayHelloGroup", + "HelloEveryone", + "SayHelloOneByOne", + ) + def test_reflection_service_client_invalid_service(helloworld_reflection_client): with pytest.raises(ValueError): - helloworld_reflection_client.service('helloWorld.Singer') + helloworld_reflection_client.service("helloWorld.Singer") diff --git a/src/tests/service_client_test.py b/src/tests/service_client_test.py index f0257b6..a7204a8 100644 --- a/src/tests/service_client_test.py +++ b/src/tests/service_client_test.py @@ -7,16 +7,23 @@ Test cases for ServiceClient """ -logger = logging.getLogger('name') +logger = logging.getLogger("name") + @pytest.fixture(scope="module") def helloworld_service_client(): try: - client = ServiceClient(Client('localhost:50051'), "helloworld.Greeter") + client = ServiceClient(Client("localhost:50051"), "helloworld.Greeter") yield client except: # noqa: E722 pytest.fail("Could not connect to local HelloWorld server") + def test_method_names(helloworld_service_client): method_names = helloworld_service_client.method_names - assert method_names == ('SayHello', 'SayHelloGroup', 'HelloEveryone', 'SayHelloOneByOne') + assert method_names == ( + "SayHello", + "SayHelloGroup", + "HelloEveryone", + "SayHelloOneByOne", + ) diff --git a/src/tests/stub_client_test.py b/src/tests/stub_client_test.py index 68e8dbb..0c37390 100644 --- a/src/tests/stub_client_test.py +++ b/src/tests/stub_client_test.py @@ -9,66 +9,72 @@ Test cases for reflection based client """ -logger = logging.getLogger('name') +logger = logging.getLogger("name") + @pytest.fixture(scope="module") def helloworld_stub_client(): try: - client = StubClient('localhost:50051', [_GREETER]) + client = StubClient("localhost:50051", [_GREETER]) yield client except: # noqa: E722 pytest.fail("Could not connect to local HelloWorld server") def test_unary_unary(helloworld_stub_client): - response = helloworld_stub_client.unary_unary('helloworld.Greeter', 'SayHello', {"name": "sinsky"}) + response = helloworld_stub_client.unary_unary( + "helloworld.Greeter", "SayHello", {"name": "sinsky"} + ) assert isinstance(response, dict) assert response == {"message": "Hello, sinsky!"} + def test_empty_body_request(helloworld_stub_client): - response = helloworld_stub_client.unary_unary('helloworld.Greeter', 'SayHello', {}) + response = helloworld_stub_client.unary_unary("helloworld.Greeter", "SayHello", {}) logger.warning(f"Response: {response}") assert isinstance(response, dict) + def test_nonexistent_service(helloworld_stub_client): with pytest.raises(ValueError): - helloworld_stub_client.unary_unary('helloworld.Speaker', 'SingHello', {}) + helloworld_stub_client.unary_unary("helloworld.Speaker", "SingHello", {}) + def test_nonexistent_method(helloworld_stub_client): with pytest.raises(ValueError): - helloworld_stub_client.unary_unary('helloworld.Greeter', 'SayGoodbye', {}) + helloworld_stub_client.unary_unary("helloworld.Greeter", "SayGoodbye", {}) + def test_unsupported_argument(helloworld_stub_client): with pytest.raises(ParseError): - helloworld_stub_client.unary_unary('helloworld.Greeter', 'SayHello', {"foo": "bar"}) + helloworld_stub_client.unary_unary( + "helloworld.Greeter", "SayHello", {"foo": "bar"} + ) + def test_unary_stream(helloworld_stub_client): name_list = ["sinsky", "viridianforge", "jack", "harry"] responses = helloworld_stub_client.unary_stream( - 'helloworld.Greeter', - 'SayHelloGroup', - {"name": "".join(name_list)} + "helloworld.Greeter", "SayHelloGroup", {"name": "".join(name_list)} ) assert all(isinstance(response, dict) for response in responses) for response, name in zip(responses, name_list): assert response == {"message": f"Hello, {name}!"} + def test_stream_unary(helloworld_stub_client): name_list = ["sinsky", "viridianforge", "jack", "harry"] response = helloworld_stub_client.stream_unary( - 'helloworld.Greeter', - 'HelloEveryone', - [{"name": name} for name in name_list] + "helloworld.Greeter", "HelloEveryone", [{"name": name} for name in name_list] ) assert isinstance(response, dict) - assert response == {'message': f'Hello, {" ".join(name_list)}!'} + assert response == {"message": f'Hello, {" ".join(name_list)}!'} + def test_stream_stream(helloworld_stub_client): name_list = ["sinsky", "viridianforge", "jack", "harry"] responses = helloworld_stub_client.stream_stream( - 'helloworld.Greeter', - 'SayHelloOneByOne', - [{"name": name} for name in name_list] + "helloworld.Greeter", "SayHelloOneByOne", [{"name": name} for name in name_list] ) assert all(isinstance(response, dict) for response in responses) for response, name in zip(responses, name_list): diff --git a/tests.sh b/tests.sh index 9b88bc9..c51241b 100644 --- a/tests.sh +++ b/tests.sh @@ -3,6 +3,8 @@ # Run this script before commits to count the number of flake8 errors and # and ensure tests are passing. -flake8 . --count --show-source --statistics +ruff check src/grpc_requests/*.py src/tests/*.py --statistics + +ruff format src/grpc_requests/*.py src/tests/*.py --check pytest --cov-report=xml --cov=src/grpc_requests \ No newline at end of file