diff --git a/pyroute2/__init__.py b/pyroute2/__init__.py index 4979c64be..89d6e1299 100644 --- a/pyroute2/__init__.py +++ b/pyroute2/__init__.py @@ -28,6 +28,7 @@ from pyroute2.iproute import ChaoticIPRoute, IPBatch, IPRoute, RawIPRoute from pyroute2.iproute.ipmock import IPRoute as IPMock from pyroute2.ipset import IPSet +from pyroute2.ipvs import IPVS, IPVSDest, IPVSService from pyroute2.iwutil import IW from pyroute2.ndb.main import NDB from pyroute2.ndb.noipdb import NoIPDB @@ -81,6 +82,9 @@ IPRoute, IPRSocket, IPSet, + IPVS, + IPVSDest, + IPVSService, IW, GenericNetlinkSocket, L2tp, diff --git a/pyroute2/iproute/linux.py b/pyroute2/iproute/linux.py index a26589643..b993274e1 100644 --- a/pyroute2/iproute/linux.py +++ b/pyroute2/iproute/linux.py @@ -13,13 +13,9 @@ from pyroute2.lab import LAB_API from pyroute2.netlink import ( NLM_F_ACK, - NLM_F_APPEND, NLM_F_ATOMIC, NLM_F_CREATE, NLM_F_DUMP, - NLM_F_ECHO, - NLM_F_EXCL, - NLM_F_REPLACE, NLM_F_REQUEST, NLM_F_ROOT, NLMSG_ERROR, @@ -186,35 +182,6 @@ def filter_messages(*argv, **kwarg): self._genmatch = self.filter_messages self.filter_messages = filter_messages - def make_request_type(self, command, command_map): - if isinstance(command, basestring): - return (lambda x: (x[0], self.make_request_flags(x[1])))( - command_map[command] - ) - elif isinstance(command, int): - return command, self.make_request_flags('create') - elif isinstance(command, (list, tuple)): - return command - else: - raise TypeError('allowed command types: int, str, list, tuple') - - def make_request_flags(self, mode): - flags = { - 'dump': NLM_F_REQUEST | NLM_F_DUMP, - 'get': NLM_F_REQUEST | NLM_F_ACK, - 'req': NLM_F_REQUEST | NLM_F_ACK, - } - flags['create'] = flags['req'] | NLM_F_CREATE | NLM_F_EXCL - flags['append'] = flags['req'] | NLM_F_CREATE | NLM_F_APPEND - flags['change'] = flags['req'] | NLM_F_REPLACE - flags['replace'] = flags['change'] | NLM_F_CREATE - - return flags[mode] | ( - NLM_F_ECHO - if (self.config['nlm_echo'] and mode not in ('get', 'dump')) - else 0 - ) - def filter_messages(self, dump_filter, msgs): ''' Filter messages using `dump_filter`. The filter might be a diff --git a/pyroute2/ipvs.py b/pyroute2/ipvs.py new file mode 100644 index 000000000..05dcc8c43 --- /dev/null +++ b/pyroute2/ipvs.py @@ -0,0 +1,195 @@ +''' +IPVS -- IP Virtual Server +------------------------- + +IPVS configuration is done via generic netlink protocol. +At the low level one can use it with a GenericNetlinkSocket, +binding it to "IPVS" generic netlink family. + +But for the convenience the library provides utility classes: + + * IPVS -- a socket class to access the API + * IPVSService -- a class to define IPVS service records + * IPVSDest -- a class to define real server records + +Dump all the records:: + + from pyroute2 import IPVS, IPVSDest, IPVSService + + # run the socket + ipvs = IPVS() + + # iterate all the IPVS services + for s in ipvs.service("dump"): + + # create a utility object from a netlink message + service = IPVSService.from_message(s) + print("Service: ", service) + + # iterate all the real servers for this service + for d in ipvs.dest("dump", service=service): + + # create and print a utility object + dest = IPVSDest.from_message(d) + print(" Real server: ", dest) + +Create a service and a real server record:: + + from socket import IPPROTO_TCP + from pyroute2 import IPVS, IPVSDest, IPVSService + + ipvs = IPVS() + + service = IPVSService(addr="192.168.122.1", port=80, protocol=IPPROTO_TCP) + real_server = IPVSDest(addr="10.0.2.20", port=80) + + ipvs.service("add", service=service) + ipvs.dest("add", service=service, dest=real_server) + +Delete a service:: + + from pyroute2 import IPVS, IPVSService + + ipvs = IPVS() + ipvs.service("del", + service=IPVSService( + addr="192.168.122.1", + port=80, + protocol=IPPROTO_TCP + ) + ) + +''' + +from socket import AF_INET + +from pyroute2.common import get_address_family +from pyroute2.netlink.generic import ipvs +from pyroute2.requests.common import NLAKeyTransform +from pyroute2.requests.main import RequestProcessor + + +class ServiceFieldFilter(NLAKeyTransform): + _nla_prefix = 'IPVS_SVC_ATTR_' + + def set_addr(self, context, value): + ret = {"addr": value} + if "af" in context.keys(): + family = context["af"] + else: + family = ret["af"] = get_address_family(value) + if family == AF_INET and "netmask" not in context.keys(): + ret["netmask"] = "255.255.255.255" + return ret + + +class DestFieldFilter(NLAKeyTransform): + _nla_prefix = 'IPVS_DEST_ATTR_' + + def set_addr(self, context, value): + ret = {"addr": value} + if "addr_family" not in context.keys(): + ret["addr_family"] = get_address_family(value) + return ret + + +class NLAFilter(RequestProcessor): + msg = None + keys = tuple() + field_filter = None + nla = None + default_values = {} + + def __init__(self, **kwarg): + dict.update(self, self.default_values) + super().__init__(prime=kwarg) + + @classmethod + def from_message(cls, msg): + obj = cls() + for key, value in msg.get(cls.nla)["attrs"]: + obj[key] = value + obj.pop("stats", None) + obj.pop("stats64", None) + return obj + + def dump_nla(self, items=None): + if items is None: + items = self.items() + self.update(self) + self.finalize() + return { + "attrs": list( + map(lambda x: (self.msg.name2nla(x[0]), x[1]), items) + ) + } + + def dump_key(self): + return self.dump_nla( + items=filter(lambda x: x[0] in self.key_fields, self.items()) + ) + + +class IPVSService(NLAFilter): + field_filter = ServiceFieldFilter() + msg = ipvs.ipvsmsg.service + key_fields = ("af", "protocol", "addr", "port") + nla = "IPVS_CMD_ATTR_SERVICE" + default_values = { + "timeout": 0, + "sched_name": "wlc", + "flags": {"flags": 0, "mask": 0xFFFF}, + } + + +class IPVSDest(NLAFilter): + field_filter = DestFieldFilter() + msg = ipvs.ipvsmsg.dest + nla = "IPVS_CMD_ATTR_DEST" + default_values = { + "fwd_method": 3, + "weight": 1, + "tun_type": 0, + "tun_port": 0, + "tun_flags": 0, + "u_thresh": 0, + "l_thresh": 0, + } + + +class IPVS(ipvs.IPVSSocket): + + def service(self, command, service=None): + command_map = { + "add": (ipvs.IPVS_CMD_NEW_SERVICE, "create"), + "set": (ipvs.IPVS_CMD_SET_SERVICE, "change"), + "update": (ipvs.IPVS_CMD_DEL_SERVICE, "change"), + "del": (ipvs.IPVS_CMD_DEL_SERVICE, "req"), + "get": (ipvs.IPVS_CMD_GET_SERVICE, "get"), + "dump": (ipvs.IPVS_CMD_GET_SERVICE, "dump"), + } + cmd, flags = self.make_request_type(command, command_map) + msg = ipvs.ipvsmsg() + msg["cmd"] = cmd + msg["version"] = ipvs.GENL_VERSION + if service is not None: + msg["attrs"] = [("IPVS_CMD_ATTR_SERVICE", service.dump_nla())] + return self.nlm_request(msg, msg_type=self.prid, msg_flags=flags) + + def dest(self, command, service, dest=None): + command_map = { + "add": (ipvs.IPVS_CMD_NEW_DEST, "create"), + "set": (ipvs.IPVS_CMD_SET_DEST, "change"), + "update": (ipvs.IPVS_CMD_DEL_DEST, "change"), + "del": (ipvs.IPVS_CMD_DEL_DEST, "req"), + "get": (ipvs.IPVS_CMD_GET_DEST, "get"), + "dump": (ipvs.IPVS_CMD_GET_DEST, "dump"), + } + cmd, flags = self.make_request_type(command, command_map) + msg = ipvs.ipvsmsg() + msg["cmd"] = cmd + msg["version"] = 0x1 + msg["attrs"] = [("IPVS_CMD_ATTR_SERVICE", service.dump_key())] + if dest is not None: + msg["attrs"].append(("IPVS_CMD_ATTR_DEST", dest.dump_nla())) + return self.nlm_request(msg, msg_type=self.prid, msg_flags=flags) diff --git a/pyroute2/netlink/__init__.py b/pyroute2/netlink/__init__.py index 4ea5bd659..6e60a9ee5 100644 --- a/pyroute2/netlink/__init__.py +++ b/pyroute2/netlink/__init__.py @@ -473,6 +473,7 @@ class my_msg(nlmsg): import io import logging +import socket import struct import sys import threading @@ -894,7 +895,7 @@ def __init__( self.value = NotInitialized # work only on non-empty mappings if self.nla_map and not self.__class__.__compiled_nla: - self.compile_nla() + self.compile_nla_table() if self.header: self['header'] = {} @@ -1434,7 +1435,7 @@ def getvalue(self): return self - def compile_nla(self): + def compile_nla_table(self): # Bug-Url: https://github.com/svinota/pyroute2/issues/980 # Bug-Url: https://github.com/svinota/pyroute2/pull/981 if isinstance(self.nla_map, NlaMapAdapter): @@ -2051,15 +2052,37 @@ class target(nla_base_string): __slots__ = () sql_type = 'TEXT' family = None + family_attr = None own_parent = True + def __init__(self, *argv, **kwarg): + init = kwarg.get('init', None) + if init is not None: + key, value = init.split(',') + if key == 'family' and value.startswith('AF_'): + self.family = getattr(socket, value) + elif key == 'nla': + self.family_attr = value + super().__init__(*argv, **kwarg) + def get_family(self): if self.family is not None: return self.family pointer = self + if self.family_attr is not None: + nla = self.family_attr + else: + nla = 'family' while pointer.parent is not None: pointer = pointer.parent - return pointer.get('family', AF_UNSPEC) + family = pointer.get(nla) + if family is not None: + return family + return AF_UNSPEC + + @staticmethod + def get_addrlen(family): + return {AF_INET: 4, AF_INET6: 16, AF_MPLS: 4}.get(family, 4) def encode(self): family = self.get_family() @@ -2096,14 +2119,17 @@ def encode(self): def decode(self): nla_base_string.decode(self) family = self.get_family() + data = self['value'] if family in (AF_INET, AF_INET6): - self.value = inet_ntop(family, self['value']) + if family == AF_INET: + data = data[:4] + elif family == AF_INET6: + data = data[:16] + self.value = inet_ntop(family, data) elif family == AF_MPLS: self.value = [] - for i in range(len(self['value']) // 4): - label = struct.unpack( - '>I', self['value'][i * 4 : i * 4 + 4] - )[0] + for i in range(len(data) // 4): + label = struct.unpack('>I', data[i * 4 : i * 4 + 4])[0] record = { 'label': (label & 0xFFFFF000) >> 12, 'tc': (label & 0x00000E00) >> 9, diff --git a/pyroute2/netlink/generic/ipvs.py b/pyroute2/netlink/generic/ipvs.py new file mode 100644 index 000000000..38260c0cd --- /dev/null +++ b/pyroute2/netlink/generic/ipvs.py @@ -0,0 +1,123 @@ +from pyroute2.netlink import genlmsg, nla +from pyroute2.netlink.generic import GenericNetlinkSocket + +GENL_NAME = "IPVS" +GENL_VERSION = 0x1 + +IPVS_CMD_UNSPEC = 0 + +IPVS_CMD_NEW_SERVICE = 1 +IPVS_CMD_SET_SERVICE = 2 +IPVS_CMD_DEL_SERVICE = 3 +IPVS_CMD_GET_SERVICE = 4 + +IPVS_CMD_NEW_DEST = 5 +IPVS_CMD_SET_DEST = 6 +IPVS_CMD_DEL_DEST = 7 +IPVS_CMD_GET_DEST = 8 + +IPVS_CMD_NEW_DAEMON = 9 +IPVS_CMD_DEL_DAEMON = 10 +IPVS_CMD_GET_DAEMON = 11 + +IPVS_CMD_SET_CONFIG = 12 +IPVS_CMD_GET_CONFIG = 13 + +IPVS_CMD_SET_INFO = 14 +IPVS_CMD_GET_INFO = 15 + +IPVS_CMD_ZERO = 16 +IPVS_CMD_FLUSH = 17 + + +class ipvsstats: + class stats(nla): + nla_map = ( + ("IPVS_STATS_ATTR_UNSPEC", "none"), + ("IPVS_STATS_ATTR_CONNS", "uint32"), + ("IPVS_STATS_ATTR_INPKTS", "uint32"), + ("IPVS_STATS_ATTR_OUTPKTS", "uint32"), + ("IPVS_STATS_ATTR_INBYTES", "uint64"), + ("IPVS_STATS_ATTR_OUTBYTES", "uint64"), + ("IPVS_STATS_ATTR_CPS", "uint32"), + ("IPVS_STATS_ATTR_INPPS", "uint32"), + ("IPVS_STATS_ATTR_OUTPPS", "uint32"), + ("IPVS_STATS_ATTR_INBPS", "uint32"), + ("IPVS_STATS_ATTR_OUTBPS", "uint32"), + ) + + class stats64(nla): + nla_map = ( + ("IPVS_STATS_ATTR_UNSPEC", "none"), + ("IPVS_STATS_ATTR_CONNS", "uint64"), + ("IPVS_STATS_ATTR_INPKTS", "uint64"), + ("IPVS_STATS_ATTR_OUTPKTS", "uint64"), + ("IPVS_STATS_ATTR_INBYTES", "uint64"), + ("IPVS_STATS_ATTR_OUTBYTES", "uint64"), + ("IPVS_STATS_ATTR_CPS", "uint64"), + ("IPVS_STATS_ATTR_INPPS", "uint64"), + ("IPVS_STATS_ATTR_OUTPPS", "uint64"), + ("IPVS_STATS_ATTR_INBPS", "uint64"), + ("IPVS_STATS_ATTR_OUTBPS", "uint64"), + ) + + +class ipvsmsg(genlmsg): + prefix = "IPVS_CMD_ATTR_" + nla_map = ( + ("IPVS_CMD_ATTR_UNSPEC", "none"), + ("IPVS_CMD_ATTR_SERVICE", "service"), + ("IPVS_CMD_ATTR_DEST", "dest"), + ("IPVS_CMD_ATTR_DAEMON", "hex"), + ("IPVS_CMD_ATTR_TIMEOUT_TCP", "hex"), + ("IPVS_CMD_ATTR_TIMEOUT_TCP_FIN", "hex"), + ("IPVS_CMD_ATTR_TIMEOUT_UDP", "hex"), + ) + + class service(nla, ipvsstats): + prefix = "IPVS_SVC_ATTR_" + nla_map = ( + ("IPVS_SVC_ATTR_UNSPEC", "none"), + ("IPVS_SVC_ATTR_AF", "uint16"), + ("IPVS_SVC_ATTR_PROTOCOL", "uint16"), + ("IPVS_SVC_ATTR_ADDR", "target(nla,IPVS_SVC_ATTR_AF)"), + ("IPVS_SVC_ATTR_PORT", "be16"), + ("IPVS_SVC_ATTR_FWMARK", "uint32"), + ("IPVS_SVC_ATTR_SCHED_NAME", "asciiz"), + ("IPVS_SVC_ATTR_FLAGS", "flags"), + ("IPVS_SVC_ATTR_TIMEOUT", "uint32"), + ("IPVS_SVC_ATTR_NETMASK", "ip4addr"), + ("IPVS_SVC_ATTR_STATS", "stats"), + ("IPVS_SVC_ATTR_PE_NAME", "asciiz"), + ("IPVS_SVC_ATTR_STATS64", "stats64"), + ) + + class flags(nla): + fields = (("flags", "I"), ("mask", "I")) + + class dest(nla, ipvsstats): + prefix = "IPVS_DEST_ATTR_" + nla_map = ( + ("IPVS_DEST_ATTR_UNSPEC", "none"), + ("IPVS_DEST_ATTR_ADDR", "target(nla,IPVS_DEST_ATTR_ADDR_FAMILY)"), + ("IPVS_DEST_ATTR_PORT", "be16"), + ("IPVS_DEST_ATTR_FWD_METHOD", "uint32"), + ("IPVS_DEST_ATTR_WEIGHT", "uint32"), + ("IPVS_DEST_ATTR_U_THRESH", "uint32"), + ("IPVS_DEST_ATTR_L_THRESH", "uint32"), + ("IPVS_DEST_ATTR_ACTIVE_CONNS", "uint32"), + ("IPVS_DEST_ATTR_INACT_CONNS", "uint32"), + ("IPVS_DEST_ATTR_PERSIST_CONNS", "uint32"), + ("IPVS_DEST_ATTR_STATS", "stats"), + ("IPVS_DEST_ATTR_ADDR_FAMILY", "uint16"), + ("IPVS_DEST_ATTR_STATS64", "stats64"), + ("IPVS_DEST_ATTR_TUN_TYPE", "uint8"), + ("IPVS_DEST_ATTR_TUN_PORT", "uint16"), + ("IPVS_DEST_ATTR_TUN_FLAGS", "uint16"), + ) + + +class IPVSSocket(GenericNetlinkSocket): + def __init__(self, *argv, **kwargs): + super().__init__(*argv, **kwargs) + self.bind(GENL_NAME, ipvsmsg) diff --git a/pyroute2/netlink/nlsocket.py b/pyroute2/netlink/nlsocket.py index 837a17bfa..0e4f10cbe 100644 --- a/pyroute2/netlink/nlsocket.py +++ b/pyroute2/netlink/nlsocket.py @@ -103,7 +103,7 @@ ) from pyroute2 import config -from pyroute2.common import DEFAULT_RCVBUF, AddrPool +from pyroute2.common import DEFAULT_RCVBUF, AddrPool, basestring from pyroute2.config import AF_NETLINK from pyroute2.netlink import ( NETLINK_ADD_MEMBERSHIP, @@ -114,9 +114,14 @@ NETLINK_LISTEN_ALL_NSID, NLM_F_ACK, NLM_F_ACK_TLVS, + NLM_F_APPEND, + NLM_F_CREATE, NLM_F_DUMP, NLM_F_DUMP_INTR, + NLM_F_ECHO, + NLM_F_EXCL, NLM_F_MULTI, + NLM_F_REPLACE, NLM_F_REQUEST, NLMSG_DONE, NLMSG_ERROR, @@ -894,6 +899,35 @@ def post_init(self): def clone(self): return type(self)(**self.config) + def make_request_type(self, command, command_map): + if isinstance(command, basestring): + return (lambda x: (x[0], self.make_request_flags(x[1])))( + command_map[command] + ) + elif isinstance(command, int): + return command, self.make_request_flags('create') + elif isinstance(command, (list, tuple)): + return command + else: + raise TypeError('allowed command types: int, str, list, tuple') + + def make_request_flags(self, mode): + flags = { + 'dump': NLM_F_REQUEST | NLM_F_DUMP, + 'get': NLM_F_REQUEST | NLM_F_ACK, + 'req': NLM_F_REQUEST | NLM_F_ACK, + } + flags['create'] = flags['req'] | NLM_F_CREATE | NLM_F_EXCL + flags['append'] = flags['req'] | NLM_F_CREATE | NLM_F_APPEND + flags['change'] = flags['req'] | NLM_F_REPLACE + flags['replace'] = flags['change'] | NLM_F_CREATE + + return flags[mode] | ( + NLM_F_ECHO + if (self.config['nlm_echo'] and mode not in ('get', 'dump')) + else 0 + ) + def put( self, msg, diff --git a/pyroute2/requests/main.py b/pyroute2/requests/main.py index f1ea1512b..1962d3cf9 100644 --- a/pyroute2/requests/main.py +++ b/pyroute2/requests/main.py @@ -7,8 +7,12 @@ class RequestProcessor(dict): + field_filter = None + context = None + def __init__(self, field_filter=None, context=None, prime=None): - self.field_filter = field_filter + if field_filter is not None: + self.field_filter = field_filter self.context = ( context if isinstance(context, (dict, weakref.ProxyType)) else {} ) diff --git a/tests/test_linux/test_ipvs.py b/tests/test_linux/test_ipvs.py new file mode 100644 index 000000000..e37d62b53 --- /dev/null +++ b/tests/test_linux/test_ipvs.py @@ -0,0 +1,57 @@ +from socket import IPPROTO_TCP + +import pytest + +from pyroute2 import IPVS, IPVSService + + +class Context: + def __init__(self, request, tmpdir): + self.ipvs = IPVS() + self.services = [] + + def new_service(self, addr, port, protocol): + service = IPVSService(addr=addr, port=port, protocol=protocol) + self.ipvs.service("add", service=service) + self.services.append(service) + return service + + def teardown(self): + for service in self.services: + self.ipvs.service("del", service=service) + self.services = [] + + def service(self, command, service=None): + return self.ipvs.service(command, service) + + def dest(self, command, service, dest=None): + return self.ipvs.dest(command, service, dest) + + +@pytest.fixture +def ipvsadm(request, tmpdir): + ctx = Context(request, tmpdir) + yield ctx + ctx.teardown() + + +def test_basic(ipvsadm, context): + ipaddr = context.new_ipaddr + ( + context.ndb.interfaces[context.default_interface.ifname] + .add_ip(f"{ipaddr}/24") + .commit() + ) + ipvsadm.new_service(addr=ipaddr, port=6000, protocol=IPPROTO_TCP) + buffer = [] + for service in ipvsadm.service("dump"): + if ( + service.get(('service', 'addr')) == ipaddr + and service.get(('service', 'port')) == 6000 + and service.get(('service', 'protocol')) == IPPROTO_TCP + ): + break + buffer.append(service) + else: + raise KeyError('service not found') + print(buffer)