From 02f993330e002927b576b7b00ea8c6c5edc2aae8 Mon Sep 17 00:00:00 2001 From: Marcin Jaworski Date: Sat, 10 Mar 2018 12:33:12 +0100 Subject: [PATCH] EnumType type for Click Code provided by skycaptain in https://github.com/pallets/click/issues/605 --- miio/click_common.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/miio/click_common.py b/miio/click_common.py index 5685d8531..1de4c0f52 100644 --- a/miio/click_common.py +++ b/miio/click_common.py @@ -12,6 +12,7 @@ import miio import logging import json +import re from typing import Union from functools import wraps from functools import partial @@ -55,6 +56,43 @@ def __call__(self, *args, **kwargs): click.echo(click.style("Error: %s" % ex, fg='red', bold=True)) +class EnumType(click.Choice): + def __init__(self, enumcls, casesensitive=True): + choices = enumcls.__members__ + + if not casesensitive: + choices = (_.lower() for _ in choices) + + self._enumcls = enumcls + self._casesensitive = casesensitive + + super().__init__(list(sorted(set(choices)))) + + def convert(self, value, param, ctx): + if not self._casesensitive: + value = value.lower() + + value = super().convert(value, param, ctx) + + if not self._casesensitive: + return next(_ for _ in self._enumcls if _.name.lower() == value.lower()) + else: + return next(_ for _ in self._enumcls if _.name == value) + + def get_metavar(self, param): + word = self._enumcls.__name__ + + # Stolen from jpvanhal/inflection + word = re.sub(r"([A-Z]+)([A-Z][a-z])", r'\1_\2', word) + word = re.sub(r"([a-z\d])([A-Z])", r'\1_\2', word) + word = word.replace("-", "_").lower().split("_") + + if word[-1] == "enum": + word.pop() + + return ("_".join(word)).upper() + + class GlobalContextObject: def __init__(self, debug: int=0, output: callable=None): self.debug = debug @@ -221,6 +259,7 @@ def wrap(*args, **kwargs): def json_output(pretty=False): indent = 2 if pretty else None + def decorator(func): @wraps(func) def wrap(*args, **kwargs):