Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support plugins defined as inner classes #1318

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changelog-fragments.d/1318.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support plugins defined as inner classes
52 changes: 38 additions & 14 deletions proxy/common/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import logging
import importlib
import itertools
from types import ModuleType
from typing import Any, Dict, List, Tuple, Union, Optional

from .utils import text_, bytes_
Expand Down Expand Up @@ -75,31 +76,54 @@ def load(
# this plugin_ is implementing
base_klass = None
for k in mro:
if bytes_(k.__name__) in p:
if bytes_(k.__qualname__) in p:
base_klass = k
break
if base_klass is None:
raise ValueError('%s is NOT a valid plugin' % text_(plugin_))
if klass not in p[bytes_(base_klass.__name__)]:
p[bytes_(base_klass.__name__)].append(klass)
logger.info('Loaded plugin %s.%s', module_name, klass.__name__)
if klass not in p[bytes_(base_klass.__qualname__)]:
p[bytes_(base_klass.__qualname__)].append(klass)
logger.info('Loaded plugin %s.%s', module_name, klass.__qualname__)
# print(p)
return p

@staticmethod
def importer(plugin: Union[bytes, type]) -> Tuple[type, str]:
"""Import and returns the plugin."""
if isinstance(plugin, type):
return (plugin, '__main__')
if inspect.isclass(plugin):
return (plugin, plugin.__module__ or '__main__')
raise ValueError('%s is not a valid reference to a plugin class' % text_(plugin))
plugin_ = text_(plugin.strip())
assert plugin_ != ''
module_name, klass_name = plugin_.rsplit(text_(DOT), 1)
klass = getattr(
importlib.import_module(
module_name.replace(
os.path.sep, text_(DOT),
),
),
klass_name,
)
path = plugin_.split(text_(DOT))
klass = None

def locate_klass(klass_module_name: str, klass_path: List[str]) -> Union[type, None]:
klass_module_name = klass_module_name.replace(os.path.sep, text_(DOT))
try:
klass_module = importlib.import_module(klass_module_name)
except ModuleNotFoundError:
return None
klass_container: Union[ModuleType, type] = klass_module
for klass_path_part in klass_path:
try:
klass_container = getattr(klass_container, klass_path_part)
except AttributeError:
return None
if not isinstance(klass_container, type) or not inspect.isclass(klass_container):
return None
return klass_container

module_name = None
for module_name_parts in range(len(path) - 1, 0, -1):
module_name = '.'.join(path[0:module_name_parts])
klass = locate_klass(module_name, path[module_name_parts:])
if klass:
break
if klass is None:
module_name = '__main__'
klass = locate_klass(module_name, path)
if klass is None or module_name is None:
raise ValueError('%s is not resolvable as a plugin class' % text_(plugin))
return (klass, module_name)
2 changes: 1 addition & 1 deletion proxy/core/acceptor/acceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def _work(self, conn: socket.socket, addr: Optional[HostPort]) -> None:
conn,
addr,
event_queue=self.event_queue,
publisher_id=self.__class__.__name__,
publisher_id=self.__class__.__qualname__,
)
# TODO: Move me into target method
logger.debug( # pragma: no cover
Expand Down
2 changes: 1 addition & 1 deletion proxy/core/work/fd/fd.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def work(self, *args: Any) -> None:
self.works[fileno].publish_event(
event_name=eventNames.WORK_STARTED,
event_payload={'fileno': fileno, 'addr': addr},
publisher_id=self.__class__.__name__,
publisher_id=self.__class__.__qualname__,
)
try:
self.works[fileno].initialize()
Expand Down
2 changes: 1 addition & 1 deletion proxy/core/work/work.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def shutdown(self) -> None:
self.publish_event(
event_name=eventNames.WORK_FINISHED,
event_payload={},
publisher_id=self.__class__.__name__,
publisher_id=self.__class__.__qualname__,
)

def run(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion proxy/http/exception/http_request_rejected.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
self.reason: Optional[bytes] = reason
self.headers: Optional[Dict[bytes, bytes]] = headers
self.body: Optional[bytes] = body
klass_name = self.__class__.__name__
klass_name = self.__class__.__qualname__
super().__init__(
message='%s %r' % (klass_name, reason)
if reason
Expand Down
2 changes: 1 addition & 1 deletion proxy/http/exception/proxy_auth_failed.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class ProxyAuthenticationFailed(HttpProtocolException):
incoming request doesn't present necessary credentials."""

def __init__(self, **kwargs: Any) -> None:
super().__init__(self.__class__.__name__, **kwargs)
super().__init__(self.__class__.__qualname__, **kwargs)

def response(self, _request: 'HttpParser') -> memoryview:
return PROXY_AUTH_FAILED_RESPONSE_PKT
2 changes: 1 addition & 1 deletion proxy/http/exception/proxy_conn_failed.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, host: str, port: int, reason: str, **kwargs: Any):
self.host: str = host
self.port: int = port
self.reason: str = reason
super().__init__('%s %s' % (self.__class__.__name__, reason), **kwargs)
super().__init__('%s %s' % (self.__class__.__qualname__, reason), **kwargs)

def response(self, _request: 'HttpParser') -> memoryview:
return BAD_GATEWAY_RESPONSE_PKT
2 changes: 1 addition & 1 deletion proxy/http/proxy/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def name(self) -> str:

Defaults to name of the class. This helps plugin developers to directly
access a specific plugin by its name."""
return self.__class__.__name__ # pragma: no cover
return self.__class__.__qualname__ # pragma: no cover

def resolve_dns(self, host: str, port: int) -> Tuple[Optional[str], Optional['HostPort']]:
"""Resolve upstream server host to an IP address.
Expand Down
8 changes: 4 additions & 4 deletions proxy/http/proxy/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,7 @@ def emit_request_complete(self) -> None:
if self.request.method == httpMethods.POST
else None,
},
publisher_id=self.__class__.__name__,
publisher_id=self.__class__.__qualname__,
)

def emit_response_events(self, chunk_size: int) -> None:
Expand Down Expand Up @@ -911,7 +911,7 @@ def emit_response_headers_complete(self) -> None:
for k, v in self.response.headers.items()
},
},
publisher_id=self.__class__.__name__,
publisher_id=self.__class__.__qualname__,
)

def emit_response_chunk_received(self, chunk_size: int) -> None:
Expand All @@ -925,7 +925,7 @@ def emit_response_chunk_received(self, chunk_size: int) -> None:
'chunk_size': chunk_size,
'encoded_chunk_size': chunk_size,
},
publisher_id=self.__class__.__name__,
publisher_id=self.__class__.__qualname__,
)

def emit_response_complete(self) -> None:
Expand All @@ -938,7 +938,7 @@ def emit_response_complete(self) -> None:
event_payload={
'encoded_response_size': self.response.total_size,
},
publisher_id=self.__class__.__name__,
publisher_id=self.__class__.__qualname__,
)

#
Expand Down
2 changes: 1 addition & 1 deletion proxy/http/server/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def name(self) -> str:

Defaults to name of the class. This helps plugin developers to directly
access a specific plugin by its name."""
return self.__class__.__name__ # pragma: no cover
return self.__class__.__qualname__ # pragma: no cover

@abstractmethod
def routes(self) -> List[Tuple[int, str]]:
Expand Down
25 changes: 25 additions & 0 deletions tests/common/my_plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
"""
proxy.py
~~~~~~~~
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
Network monitoring, controls & Application development, testing, debugging.

:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
from typing import Any

from proxy.http.proxy import HttpProxyPlugin


class MyHttpProxyPlugin(HttpProxyPlugin):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)


class OuterClass:

class MyHttpProxyPlugin(HttpProxyPlugin):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
39 changes: 38 additions & 1 deletion tests/common/test_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
from typing import Dict, List
from typing import Any, Dict, List

import unittest
from unittest import mock
Expand All @@ -19,6 +19,7 @@
from proxy.common.utils import bytes_
from proxy.common.version import __version__
from proxy.common.constants import PLUGIN_HTTP_PROXY, PY2_DEPRECATION_MESSAGE
from . import my_plugins


class TestFlags(unittest.TestCase):
Expand Down Expand Up @@ -140,6 +141,42 @@ def test_unique_plugin_from_class(self) -> None:
],
})

def test_plugin_from_inner_class_by_type(self) -> None:
self.flags = FlagParser.initialize(
[], plugins=[
TestFlags.MyHttpProxyPlugin,
my_plugins.MyHttpProxyPlugin,
my_plugins.OuterClass.MyHttpProxyPlugin,
],
)
self.assert_plugins({
'HttpProtocolHandlerPlugin': [
TestFlags.MyHttpProxyPlugin,
my_plugins.MyHttpProxyPlugin,
my_plugins.OuterClass.MyHttpProxyPlugin,
],
})

def test_plugin_from_inner_class_by_name(self) -> None:
self.flags = FlagParser.initialize(
[], plugins=[
b'tests.common.test_flags.TestFlags.MyHttpProxyPlugin',
b'tests.common.my_plugins.MyHttpProxyPlugin',
b'tests.common.my_plugins.OuterClass.MyHttpProxyPlugin',
],
)
self.assert_plugins({
'HttpProtocolHandlerPlugin': [
TestFlags.MyHttpProxyPlugin,
my_plugins.MyHttpProxyPlugin,
my_plugins.OuterClass.MyHttpProxyPlugin,
],
})

class MyHttpProxyPlugin(HttpProxyPlugin):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

def test_basic_auth_flag_is_base64_encoded(self) -> None:
flags = FlagParser.initialize(['--basic-auth', 'user:pass'])
self.assertEqual(flags.auth_code, b'dXNlcjpwYXNz')
Expand Down
8 changes: 4 additions & 4 deletions tests/core/test_event_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_empties_queue(self) -> None:
request_id='1234',
event_name=eventNames.WORK_STARTED,
event_payload={'hello': 'events'},
publisher_id=self.__class__.__name__,
publisher_id=self.__class__.__qualname__,
)
self.dispatcher.run_once()
with self.assertRaises(queue.Empty):
Expand All @@ -64,7 +64,7 @@ def subscribe(self, mock_time: mock.Mock) -> connection.Connection:
request_id='1234',
event_name=eventNames.WORK_STARTED,
event_payload={'hello': 'events'},
publisher_id=self.__class__.__name__,
publisher_id=self.__class__.__qualname__,
)
# consume
self.dispatcher.run_once()
Expand All @@ -79,7 +79,7 @@ def subscribe(self, mock_time: mock.Mock) -> connection.Connection:
'event_timestamp': 1234567,
'event_name': eventNames.WORK_STARTED,
'event_payload': {'hello': 'events'},
'publisher_id': self.__class__.__name__,
'publisher_id': self.__class__.__qualname__,
},
)
return relay_recv
Expand All @@ -101,7 +101,7 @@ def test_unsubscribe(self) -> None:
request_id='1234',
event_name=eventNames.WORK_STARTED,
event_payload={'hello': 'events'},
publisher_id=self.__class__.__name__,
publisher_id=self.__class__.__qualname__,
)
self.dispatcher.run_once()
with self.assertRaises(EOFError):
Expand Down
4 changes: 2 additions & 2 deletions tests/core/test_event_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_publish(self, mock_time: mock.Mock) -> None:
request_id='1234',
event_name=eventNames.WORK_STARTED,
event_payload={'hello': 'events'},
publisher_id=self.__class__.__name__,
publisher_id=self.__class__.__qualname__,
)
self.assertEqual(
evq.queue.get(), {
Expand All @@ -44,7 +44,7 @@ def test_publish(self, mock_time: mock.Mock) -> None:
'event_timestamp': 1234567,
'event_name': eventNames.WORK_STARTED,
'event_payload': {'hello': 'events'},
'publisher_id': self.__class__.__name__,
'publisher_id': self.__class__.__qualname__,
},
)

Expand Down
4 changes: 2 additions & 2 deletions tests/core/test_event_subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_event_subscriber(self, mock_time: mock.Mock) -> None:
request_id='1234',
event_name=eventNames.WORK_STARTED,
event_payload={'hello': 'events'},
publisher_id=self.__class__.__name__,
publisher_id=self.__class__.__qualname__,
)
self.dispatcher.run_once()
self.subscriber.unsubscribe()
Expand All @@ -69,6 +69,6 @@ def callback(self, ev: Dict[str, Any]) -> None:
'event_timestamp': 1234567,
'event_name': eventNames.WORK_STARTED,
'event_payload': {'hello': 'events'},
'publisher_id': self.__class__.__name__,
'publisher_id': self.__class__.__qualname__,
},
)