diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5b73b61ae5..eb1308f7b2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -144,7 +144,7 @@ repos: rev: 3.9.2 hooks: - id: flake8 - language_version: python3 + language_version: python3.10 additional_dependencies: - flake8-2020 >= 1.6.0 - flake8-docstrings >= 1.5.0 diff --git a/README.md b/README.md index b93c88af4b..83b68c6486 100644 --- a/README.md +++ b/README.md @@ -2341,25 +2341,25 @@ To run standalone benchmark for `proxy.py`, use the following command from repo ```console ❯ proxy -h -usage: -m [-h] [--tunnel-hostname TUNNEL_HOSTNAME] [--tunnel-port TUNNEL_PORT] - [--tunnel-username TUNNEL_USERNAME] +usage: -m [-h] [--threadless] [--threaded] [--num-workers NUM_WORKERS] + [--enable-events] [--local-executor LOCAL_EXECUTOR] + [--backlog BACKLOG] [--hostname HOSTNAME] + [--hostnames HOSTNAMES [HOSTNAMES ...]] [--port PORT] + [--ports PORTS [PORTS ...]] [--port-file PORT_FILE] + [--unix-socket-path UNIX_SOCKET_PATH] + [--num-acceptors NUM_ACCEPTORS] [--tunnel-hostname TUNNEL_HOSTNAME] + [--tunnel-port TUNNEL_PORT] [--tunnel-username TUNNEL_USERNAME] [--tunnel-ssh-key TUNNEL_SSH_KEY] [--tunnel-ssh-key-passphrase TUNNEL_SSH_KEY_PASSPHRASE] - [--tunnel-remote-port TUNNEL_REMOTE_PORT] [--threadless] - [--threaded] [--num-workers NUM_WORKERS] [--enable-events] - [--local-executor LOCAL_EXECUTOR] [--backlog BACKLOG] - [--hostname HOSTNAME] [--hostnames HOSTNAMES [HOSTNAMES ...]] - [--port PORT] [--ports PORTS [PORTS ...]] [--port-file PORT_FILE] - [--unix-socket-path UNIX_SOCKET_PATH] - [--num-acceptors NUM_ACCEPTORS] [--version] [--log-level LOG_LEVEL] - [--log-file LOG_FILE] [--log-format LOG_FORMAT] - [--open-file-limit OPEN_FILE_LIMIT] + [--tunnel-remote-port TUNNEL_REMOTE_PORT] [--version] + [--log-level LOG_LEVEL] [--log-file LOG_FILE] + [--log-format LOG_FORMAT] [--open-file-limit OPEN_FILE_LIMIT] [--plugins PLUGINS [PLUGINS ...]] [--enable-dashboard] [--basic-auth BASIC_AUTH] [--enable-ssh-tunnel] [--work-klass WORK_KLASS] [--pid-file PID_FILE] [--openssl OPENSSL] - [--data-dir DATA_DIR] [--enable-proxy-protocol] [--enable-conn-pool] - [--key-file KEY_FILE] [--cert-file CERT_FILE] - [--client-recvbuf-size CLIENT_RECVBUF_SIZE] + [--data-dir DATA_DIR] [--ssh-listener-klass SSH_LISTENER_KLASS] + [--enable-proxy-protocol] [--enable-conn-pool] [--key-file KEY_FILE] + [--cert-file CERT_FILE] [--client-recvbuf-size CLIENT_RECVBUF_SIZE] [--server-recvbuf-size SERVER_RECVBUF_SIZE] [--max-sendbuf-size MAX_SENDBUF_SIZE] [--timeout TIMEOUT] [--disable-http-proxy] [--disable-headers DISABLE_HEADERS] @@ -2379,25 +2379,10 @@ usage: -m [-h] [--tunnel-hostname TUNNEL_HOSTNAME] [--tunnel-port TUNNEL_PORT] [--filtered-client-ips FILTERED_CLIENT_IPS] [--filtered-url-regex-config FILTERED_URL_REGEX_CONFIG] -proxy.py v2.4.4rc6.dev85+g9335918b +proxy.py v2.4.4rc6.dev164+g73497f30 options: -h, --help show this help message and exit - --tunnel-hostname TUNNEL_HOSTNAME - Default: None. Remote hostname or IP address to which - SSH tunnel will be established. - --tunnel-port TUNNEL_PORT - Default: 22. SSH port of the remote host. - --tunnel-username TUNNEL_USERNAME - Default: None. Username to use for establishing SSH - tunnel. - --tunnel-ssh-key TUNNEL_SSH_KEY - Default: None. Private key path in pem format - --tunnel-ssh-key-passphrase TUNNEL_SSH_KEY_PASSPHRASE - Default: None. Private key passphrase - --tunnel-remote-port TUNNEL_REMOTE_PORT - Default: 8899. Remote port which will be forwarded - locally for proxy. --threadless Default: True. Enabled by default on Python 3.8+ (mac, linux). When disabled a new thread is spawned to handle each client connection. @@ -2434,6 +2419,21 @@ options: --host and --port flags are ignored --num-acceptors NUM_ACCEPTORS Defaults to number of CPU cores. + --tunnel-hostname TUNNEL_HOSTNAME + Default: None. Remote hostname or IP address to which + SSH tunnel will be established. + --tunnel-port TUNNEL_PORT + Default: 22. SSH port of the remote host. + --tunnel-username TUNNEL_USERNAME + Default: None. Username to use for establishing SSH + tunnel. + --tunnel-ssh-key TUNNEL_SSH_KEY + Default: None. Private key path in pem format + --tunnel-ssh-key-passphrase TUNNEL_SSH_KEY_PASSPHRASE + Default: None. Private key passphrase + --tunnel-remote-port TUNNEL_REMOTE_PORT + Default: 8899. Remote port which will be forwarded + locally for proxy. --version, -v Prints proxy.py version. --log-level LOG_LEVEL Valid options: DEBUG, INFO (default), WARNING, ERROR, @@ -2461,6 +2461,9 @@ options: --openssl OPENSSL Default: openssl. Path to openssl binary. By default, assumption is that openssl is in your PATH. --data-dir DATA_DIR Default: ~/.proxypy. Path to proxypy data directory. + --ssh-listener-klass SSH_LISTENER_KLASS + Default: proxy.core.ssh.listener.SshTunnelListener. An + implementation of BaseSshTunnelListener --enable-proxy-protocol Default: False. If used, will enable proxy protocol. Only version 1 is currently supported. diff --git a/docs/conf.py b/docs/conf.py index 863864b20f..da38c54ddc 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -324,6 +324,7 @@ (_py_class_role, 're.Pattern'), (_py_class_role, 'proxy.core.base.tcp_server.T'), (_py_class_role, 'proxy.common.types.RePattern'), + (_py_class_role, 'BaseSshTunnelHandler'), (_py_obj_role, 'proxy.core.work.threadless.T'), (_py_obj_role, 'proxy.core.work.work.T'), (_py_obj_role, 'proxy.core.base.tcp_server.T'), diff --git a/proxy/core/ssh/listener.py b/proxy/core/ssh/listener.py index d851600fdd..72e0369a9f 100644 --- a/proxy/core/ssh/listener.py +++ b/proxy/core/ssh/listener.py @@ -8,21 +8,20 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. """ +import sys +import socket import logging import argparse -from typing import TYPE_CHECKING, Any, Set, Callable, Optional +from typing import TYPE_CHECKING, Any, Set, Optional, cast try: - from paramiko import SSHClient, AutoAddPolicy - from paramiko.transport import Transport - if TYPE_CHECKING: # pragma: no cover - from paramiko.channel import Channel - + if TYPE_CHECKING: # pragma: no cover from ...common.types import HostPort except ImportError: # pragma: no cover pass +from .base import BaseSshTunnelHandler, BaseSshTunnelListener from ...common.flag import flags @@ -72,18 +71,27 @@ ) -class SshTunnelListener: +class SshTunnelListener(BaseSshTunnelListener): """Connects over SSH and forwards a remote port to local host. Incoming connections are delegated to provided callback.""" def __init__( - self, - flags: argparse.Namespace, - on_connection_callback: Callable[['Channel', 'HostPort', 'HostPort'], None], + self, + flags: argparse.Namespace, + handler: BaseSshTunnelHandler, + *args: Any, + **kwargs: Any, ) -> None: + paramiko_logger = logging.getLogger('paramiko') + paramiko_logger.setLevel(logging.WARNING) + + # pylint: disable=import-outside-toplevel + from paramiko import SSHClient + from paramiko.transport import Transport + self.flags = flags - self.on_connection_callback = on_connection_callback + self.handler = handler self.ssh: Optional[SSHClient] = None self.transport: Optional[Transport] = None self.forwarded: Set['HostPort'] = set() @@ -92,24 +100,20 @@ def start_port_forward(self, remote_addr: 'HostPort') -> None: assert self.transport is not None self.transport.request_port_forward( *remote_addr, - handler=self.on_connection_callback, + handler=self.handler.on_connection, ) self.forwarded.add(remote_addr) - logger.info('%s:%d forwarding successful...' % remote_addr) + logger.debug('%s:%d forwarding successful...' % remote_addr) def stop_port_forward(self, remote_addr: 'HostPort') -> None: assert self.transport is not None self.transport.cancel_port_forward(*remote_addr) self.forwarded.remove(remote_addr) - def __enter__(self) -> 'SshTunnelListener': - self.setup() - return self - - def __exit__(self, *args: Any) -> None: - self.shutdown() - def setup(self) -> None: + # pylint: disable=import-outside-toplevel + from paramiko import SSHClient, AutoAddPolicy + self.ssh = SSHClient() self.ssh.load_system_host_keys() self.ssh.set_missing_host_key_policy(AutoAddPolicy()) @@ -119,14 +123,30 @@ def setup(self) -> None: username=self.flags.tunnel_username, key_filename=self.flags.tunnel_ssh_key, passphrase=self.flags.tunnel_ssh_key_passphrase, + compress=True, + timeout=10, + auth_timeout=7, ) - logger.info( - 'SSH connection established to %s:%d...' % ( + logger.debug( + 'SSH connection established to %s:%d...' + % ( self.flags.tunnel_hostname, self.flags.tunnel_port, ), ) self.transport = self.ssh.get_transport() + assert self.transport + sock = cast(socket.socket, self.transport.sock) # type: ignore[redundant-cast] + # Enable TCP keep-alive + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + # Keep-alive interval (in seconds) + if sys.platform != 'darwin': + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30) + # Keep-alive probe interval (in seconds) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 5) + # Number of keep-alive probes before timeout + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 5) + self.start_port_forward(('', self.flags.tunnel_remote_port)) def shutdown(self) -> None: for remote_addr in list(self.forwarded): @@ -136,3 +156,10 @@ def shutdown(self) -> None: self.transport.close() if self.ssh is not None: self.ssh.close() + self.handler.shutdown() + + def is_alive(self) -> bool: + return self.transport.is_alive() if self.transport else False + + def is_active(self) -> bool: + return self.transport.is_active() if self.transport else False diff --git a/proxy/http/proxy/server.py b/proxy/http/proxy/server.py index 70d3369ec4..4e2f44ac3f 100644 --- a/proxy/http/proxy/server.py +++ b/proxy/http/proxy/server.py @@ -726,9 +726,9 @@ def generate_upstream_certificate( ): raise HttpProtocolException( f'For certificate generation all the following flags are mandatory: ' - f'--ca-cert-file:{ self.flags.ca_cert_file }, ' - f'--ca-key-file:{ self.flags.ca_key_file }, ' - f'--ca-signing-key-file:{ self.flags.ca_signing_key_file }', + f'--ca-cert-file:{ self.flags.ca_cert_file}, ' + f'--ca-key-file:{ self.flags.ca_key_file}, ' + f'--ca-signing-key-file:{ self.flags.ca_signing_key_file}', ) cert_file_path = HttpProxyPlugin.generated_cert_file_path( self.flags.ca_cert_dir, text_(self.request.host), diff --git a/proxy/http/websocket/frame.py b/proxy/http/websocket/frame.py index af6e0c7e9e..08954f0725 100644 --- a/proxy/http/websocket/frame.py +++ b/proxy/http/websocket/frame.py @@ -128,8 +128,8 @@ def build(self) -> bytes: ) else: raise ValueError( - f'Invalid payload_length { self.payload_length },' - f'maximum allowed { 1 << 64 }', + f'Invalid payload_length { self.payload_length},' + f'maximum allowed { 1 << 64}', ) if self.masked and self.data: mask = secrets.token_bytes(4) if self.mask is None else self.mask diff --git a/proxy/proxy.py b/proxy/proxy.py index 4279c611d5..2350f4e88f 100644 --- a/proxy/proxy.py +++ b/proxy/proxy.py @@ -14,10 +14,10 @@ import pprint import signal import logging +import argparse import threading -from typing import TYPE_CHECKING, Any, List, Optional, cast +from typing import TYPE_CHECKING, Any, List, Type, Optional, cast -from .core.ssh import SshTunnelListener, SshHttpProtocolHandler from .core.work import ThreadlessPool from .core.event import EventManager from .common.flag import FlagParser, flags @@ -25,16 +25,19 @@ from .core.work.fd import RemoteFdExecutor from .core.acceptor import AcceptorPool from .core.listener import ListenerPool +from .core.ssh.base import BaseSshTunnelListener from .common.constants import ( IS_WINDOWS, DEFAULT_PLUGINS, DEFAULT_VERSION, DEFAULT_LOG_FILE, DEFAULT_PID_FILE, DEFAULT_LOG_LEVEL, DEFAULT_BASIC_AUTH, DEFAULT_LOG_FORMAT, DEFAULT_WORK_KLASS, DEFAULT_OPEN_FILE_LIMIT, DEFAULT_ENABLE_DASHBOARD, DEFAULT_ENABLE_SSH_TUNNEL, + DEFAULT_SSH_LISTENER_KLASS, ) if TYPE_CHECKING: # pragma: no cover from .core.listener import TcpSocketListener + from .core.ssh.base import BaseSshTunnelHandler logger = logging.getLogger(__name__) @@ -152,6 +155,15 @@ help='Default: ~/.proxypy. Path to proxypy data directory.', ) +flags.add_argument( + '--ssh-listener-klass', + type=str, + default=DEFAULT_SSH_LISTENER_KLASS, + help='Default: ' + + DEFAULT_SSH_LISTENER_KLASS + + '. An implementation of BaseSshTunnelListener', +) + class Proxy: """Proxy is a context manager to control proxy.py library core. @@ -175,13 +187,13 @@ class Proxy: """ def __init__(self, input_args: Optional[List[str]] = None, **opts: Any) -> None: + self.opts = opts self.flags = FlagParser.initialize(input_args, **opts) self.listeners: Optional[ListenerPool] = None self.executors: Optional[ThreadlessPool] = None self.acceptors: Optional[AcceptorPool] = None self.event_manager: Optional[EventManager] = None - self.ssh_http_protocol_handler: Optional[SshHttpProtocolHandler] = None - self.ssh_tunnel_listener: Optional[SshTunnelListener] = None + self.ssh_tunnel_listener: Optional[BaseSshTunnelListener] = None def __enter__(self) -> 'Proxy': self.setup() @@ -261,21 +273,29 @@ def setup(self) -> None: self.acceptors.setup() # Start SSH tunnel acceptor if enabled if self.flags.enable_ssh_tunnel: - self.ssh_http_protocol_handler = SshHttpProtocolHandler( - flags=self.flags, - ) - self.ssh_tunnel_listener = SshTunnelListener( + self.ssh_tunnel_listener = self._setup_tunnel( flags=self.flags, - on_connection_callback=self.ssh_http_protocol_handler.on_connection, - ) - self.ssh_tunnel_listener.setup() - self.ssh_tunnel_listener.start_port_forward( - ('', self.flags.tunnel_remote_port), + **self.opts, ) # TODO: May be close listener fd as we don't need it now if threading.current_thread() == threading.main_thread(): self._register_signals() + @staticmethod + def _setup_tunnel( + flags: argparse.Namespace, + ssh_handler_klass: Type['BaseSshTunnelHandler'], + ssh_listener_klass: Any, + **kwargs: Any, + ) -> BaseSshTunnelListener: + tunnel = cast(Type[BaseSshTunnelListener], ssh_listener_klass)( + flags=flags, + handler=ssh_handler_klass(flags=flags), + **kwargs, + ) + tunnel.setup() + return tunnel + def shutdown(self) -> None: if self.flags.enable_ssh_tunnel: assert self.ssh_tunnel_listener is not None @@ -339,14 +359,14 @@ def _register_signals(self) -> None: @staticmethod def _handle_exit_signal(signum: int, _frame: Any) -> None: - logger.info('Received signal %d' % signum) + logger.debug('Received signal %d' % signum) sys.exit(0) def _handle_siginfo(self, _signum: int, _frame: Any) -> None: pprint.pprint(self.flags.__dict__) # pragma: no cover -def sleep_loop() -> None: +def sleep_loop(p: Optional[Proxy] = None) -> None: while True: try: time.sleep(1) @@ -355,8 +375,8 @@ def sleep_loop() -> None: def main(**opts: Any) -> None: - with Proxy(sys.argv[1:], **opts): - sleep_loop() + with Proxy(sys.argv[1:], **opts) as p: + sleep_loop(p) def entry_point() -> None: diff --git a/tests/test_main.py b/tests/test_main.py index d939273cb4..de9f69e5f6 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -330,26 +330,27 @@ def test_enable_devtools( @mock.patch('proxy.proxy.AcceptorPool') @mock.patch('proxy.proxy.ThreadlessPool') @mock.patch('proxy.proxy.ListenerPool') - @mock.patch('proxy.proxy.SshHttpProtocolHandler') - @mock.patch('proxy.proxy.SshTunnelListener') def test_enable_ssh_tunnel( - self, - mock_ssh_tunnel_listener: mock.Mock, - mock_ssh_http_proto_handler: mock.Mock, - mock_listener_pool: mock.Mock, - mock_executor_pool: mock.Mock, - mock_acceptor_pool: mock.Mock, - mock_event_manager: mock.Mock, - mock_parse_args: mock.Mock, - mock_load_plugins: mock.Mock, - mock_sleep: mock.Mock, + self, + mock_listener_pool: mock.Mock, + mock_executor_pool: mock.Mock, + mock_acceptor_pool: mock.Mock, + mock_event_manager: mock.Mock, + mock_parse_args: mock.Mock, + mock_load_plugins: mock.Mock, + mock_sleep: mock.Mock, ) -> None: mock_sleep.side_effect = KeyboardInterrupt() mock_args = mock_parse_args.return_value self.mock_default_args(mock_args) mock_args.enable_ssh_tunnel = True mock_args.local_executor = 0 - main() + mock_ssh_tunnel_listener = mock.MagicMock() + mock_ssh_http_proto_handler = mock.MagicMock() + main( + ssh_listener_klass=mock_ssh_tunnel_listener, + ssh_handler_klass=mock_ssh_http_proto_handler, + ) mock_load_plugins.assert_called() self.assertEqual( mock_load_plugins.call_args_list[0][0][0], [ @@ -367,10 +368,7 @@ def test_enable_ssh_tunnel( mock_ssh_http_proto_handler.assert_called_once() mock_ssh_tunnel_listener.assert_called_once() mock_ssh_tunnel_listener.return_value.setup.assert_called_once() - mock_ssh_tunnel_listener.return_value.start_port_forward.assert_called_once() mock_ssh_tunnel_listener.return_value.shutdown.assert_called_once() - # shutdown will internally call stop port forward - mock_ssh_tunnel_listener.return_value.stop_port_forward.assert_not_called() class TestProxyContextManager(unittest.TestCase):