diff --git a/folium/features.py b/folium/features.py index 4d01492f5..dbe970f02 100644 --- a/folium/features.py +++ b/folium/features.py @@ -9,13 +9,14 @@ import warnings import functools import operator +from typing import Any, Dict, Callable, Sequence, Optional, Union, List, Tuple, Iterable from branca.colormap import LinearColormap, StepColormap from branca.element import (Element, Figure, JavascriptLink, MacroElement) from branca.utilities import color_brewer from folium.folium import Map -from folium.map import (FeatureGroup, Icon, Layer, Marker, Tooltip) +from folium.map import (FeatureGroup, Icon, Layer, Marker, Tooltip, Popup) from folium.utilities import ( validate_locations, _parse_size, @@ -25,7 +26,7 @@ none_min, get_obj_in_upper_tree, parse_options, - camelize + camelize, TypeJsonValue, TypePathOptions, ) from folium.vector_layers import PolyLine, path_options @@ -69,21 +70,28 @@ class RegularPolygonMarker(Marker): {% endmacro %} """) - def __init__(self, location, number_of_sides=4, rotation=0, radius=15, - popup=None, tooltip=None, **kwargs): + def __init__( + self, + location: Sequence[float], + number_of_sides: int = 4, + rotation: int = 0, + radius: int = 15, + popup: Optional[Union[str, Popup]] = None, + tooltip: Optional[Union[str, Tooltip]] = None, + **kwargs: TypePathOptions + ): super(RegularPolygonMarker, self).__init__( location, popup=popup, tooltip=tooltip ) self._name = 'RegularPolygonMarker' - self.options = path_options(**kwargs) + self.options = path_options(line=False, radius=radius, **kwargs) self.options.update(parse_options( number_of_sides=number_of_sides, rotation=rotation, - radius=radius, )) - def render(self, **kwargs): + def render(self, **kwargs) -> None: """Renders the HTML representation of the element.""" super(RegularPolygonMarker, self).render() @@ -128,8 +136,15 @@ class Vega(Element): """ _template = Template(u'') - def __init__(self, data, width=None, height=None, - left='0%', top='0%', position='relative'): + def __init__( + self, + data: Any, + width: Optional[Union[int, str]] = None, + height: Optional[Union[int, str]] = None, + left: Union[int, str] = '0%', + top: Union[int, str] = '0%', + position: str = 'relative', + ): super(Vega, self).__init__() self._name = 'Vega' self.data = data.to_json() if hasattr(data, 'to_json') else data @@ -145,7 +160,7 @@ def __init__(self, data, width=None, height=None, self.top = _parse_size(top) self.position = position - def render(self, **kwargs): + def render(self, **kwargs) -> None: """Renders the HTML representation of the element.""" self.json = json.dumps(self.data) @@ -220,8 +235,15 @@ class VegaLite(Element): """ _template = Template(u'') - def __init__(self, data, width=None, height=None, - left='0%', top='0%', position='relative'): + def __init__( + self, + data: Any, + width: Optional[Union[int, str]] = None, + height: Optional[Union[int, str]] = None, + left: Union[int, str] = '0%', + top: Union[int, str] = '0%', + position: str = 'relative', + ): super(self.__class__, self).__init__() self._name = 'VegaLite' self.data = data.to_json() if hasattr(data, 'to_json') else data @@ -239,7 +261,7 @@ def __init__(self, data, width=None, height=None, self.top = _parse_size(top) self.position = position - def render(self, **kwargs): + def render(self, **kwargs) -> None: """Renders the HTML representation of the element.""" vegalite_major_version = self._get_vegalite_major_versions(self.data) @@ -271,7 +293,7 @@ def render(self, **kwargs): # Version 2 is assumed as the default, if no version is given in the schema. self._embed_vegalite_v2(figure) - def _get_vegalite_major_versions(self, spec): + def _get_vegalite_major_versions(self, spec: dict) -> Optional[str]: try: schema = spec['$schema'] except KeyError: @@ -281,28 +303,28 @@ def _get_vegalite_major_versions(self, spec): return major_version - def _embed_vegalite_v3(self, figure): + def _embed_vegalite_v3(self, figure: Figure) -> None: self._vega_embed() figure.header.add_child(JavascriptLink('https://cdn.jsdelivr.net/npm/vega@4'), name='vega') figure.header.add_child(JavascriptLink('https://cdn.jsdelivr.net/npm/vega-lite@3'), name='vega-lite') figure.header.add_child(JavascriptLink('https://cdn.jsdelivr.net/npm/vega-embed@3'), name='vega-embed') - def _embed_vegalite_v2(self, figure): + def _embed_vegalite_v2(self, figure: Figure) -> None: self._vega_embed() figure.header.add_child(JavascriptLink('https://cdn.jsdelivr.net/npm/vega@3'), name='vega') figure.header.add_child(JavascriptLink('https://cdn.jsdelivr.net/npm/vega-lite@2'), name='vega-lite') figure.header.add_child(JavascriptLink('https://cdn.jsdelivr.net/npm/vega-embed@3'), name='vega-embed') - def _vega_embed(self): + def _vega_embed(self) -> None: self._parent.script.add_child(Element(Template(""" vegaEmbed({{this.get_name()}}, {{this.json}}) .then(function(result) {}) .catch(console.error); """).render(this=self)), name=self.get_name()) - def _embed_vegalite_v1(self, figure): + def _embed_vegalite_v1(self, figure: Figure) -> None: self._parent.script.add_child(Element(Template(""" var embedSpec = { mode: "vega-lite", @@ -354,6 +376,8 @@ class GeoJson(Layer): embed: bool, default True Whether to embed the data in the html file or not. Note that disabling embedding is only supported if you provide a file link or URL. + popup: GeoJsonPopup, optional + Add popups to each feature based on the features content. Examples -------- @@ -436,15 +460,25 @@ class GeoJson(Layer): {% endmacro %} """) # noqa - def __init__(self, data, style_function=None, highlight_function=None, # noqa - name=None, overlay=True, control=True, show=True, - smooth_factor=None, tooltip=None, embed=True, popup=None): + def __init__( + self, + data: Any, + style_function: Optional[Callable] = None, + highlight_function: Optional[Callable] = None, + name: Optional[str] = None, + overlay: bool = True, + control: bool = True, + show: bool = True, + smooth_factor: Optional[float] = None, + tooltip: Optional[Union[str, Tooltip, 'GeoJsonTooltip']] = None, + embed: bool = True, + popup: Optional['GeoJsonPopup'] = None, + ): super(GeoJson, self).__init__(name=name, overlay=overlay, control=control, show=show) self._name = 'GeoJson' self.embed = embed - self.embed_link = None - self.json = None + self.embed_link: Optional[str] = None self.parent_map = None self.smooth_factor = smooth_factor self.style = style_function is not None @@ -454,14 +488,14 @@ def __init__(self, data, style_function=None, highlight_function=None, # noqa if self.style or self.highlight: self.convert_to_feature_collection() - if self.style: - self._validate_function(style_function, 'style_function') - self.style_function = style_function - self.style_map = {} - if self.highlight: - self._validate_function(highlight_function, 'highlight_function') - self.highlight_function = highlight_function - self.highlight_map = {} + if style_function is not None: + self.style_function: Callable = style_function + self._validate_function(self.style_function, 'style_function') + self.style_map: dict = {} + if highlight_function is not None: + self.highlight_function: Callable = highlight_function + self._validate_function(self.highlight_function, 'highlight_function') + self.highlight_map: dict = {} self.feature_identifier = self.find_identifier() if isinstance(tooltip, (GeoJsonTooltip, Tooltip)): @@ -471,7 +505,7 @@ def __init__(self, data, style_function=None, highlight_function=None, # noqa if isinstance(popup, (GeoJsonPopup)): self.add_child(popup) - def process_data(self, data): + def process_data(self, data: Any) -> dict: """Convert an unknown data input into a geojson dictionary.""" if isinstance(data, dict): self.embed = True @@ -498,7 +532,7 @@ def process_data(self, data): raise ValueError('Cannot render objects with any missing geometries' ': {!r}'.format(data)) - def convert_to_feature_collection(self): + def convert_to_feature_collection(self) -> None: """Convert data into a FeatureCollection if it is not already.""" if self.data['type'] == 'FeatureCollection': return @@ -514,7 +548,7 @@ def convert_to_feature_collection(self): self.data = {'type': 'Feature', 'geometry': self.data} self.data = {'type': 'FeatureCollection', 'features': [self.data]} - def _validate_function(self, func, name): + def _validate_function(self, func: Callable, name: str) -> None: """ Tests `self.style_function` and `self.highlight_function` to ensure they are functions returning dictionaries. @@ -525,7 +559,7 @@ def _validate_function(self, func, name): 'data[\'features\'] and returns a dictionary.' .format(name)) - def find_identifier(self): + def find_identifier(self) -> str: """Find a unique identifier for each feature, create it if needed. According to the GeoJSON specs a feature: @@ -533,6 +567,9 @@ def find_identifier(self): - MUST have a 'properties' field. The content can be any json object or even null. + In this implementation the returned field name will point + to a str or int value only. + """ feats = self.data['features'] # Each feature has an 'id' field with a unique value. @@ -559,7 +596,7 @@ def find_identifier(self): 'field to your geojson data or set `embed=True`. ' ) - def _get_self_bounds(self): + def _get_self_bounds(self) -> List[List[Optional[float]]]: """ Computes the bounds of the object itself (not including it's children) in the form [[lat_min, lon_min], [lat_max, lon_max]]. @@ -567,16 +604,14 @@ def _get_self_bounds(self): """ return get_bounds(self.data, lonlat=True) - def render(self, **kwargs): + def render(self, **kwargs) -> None: self.parent_map = get_obj_in_upper_tree(self, Map) if self.style or self.highlight: - mapper = GeoJsonStyleMapper(self.data, self.feature_identifier, - self) + mapper = GeoJsonStyleMapper(self.data, self.feature_identifier, self) if self.style: self.style_map = mapper.get_style_map(self.style_function) if self.highlight: - self.highlight_map = mapper.get_highlight_map( - self.highlight_function) + self.highlight_map = mapper.get_highlight_map(self.highlight_function) super(GeoJson, self).render() @@ -585,23 +620,39 @@ class GeoJsonStyleMapper: Used in the GeoJson class. Users don't have to call this class directly. """ - - def __init__(self, data, feature_identifier, geojson_obj): + TypeStyleMapping = Dict[str, Union[List[Union[str, int]], str]] + + def __init__( + self, + data: dict, + feature_identifier: str, + geojson_obj: GeoJson, + ): self.data = data self.feature_identifier = feature_identifier self.geojson_obj = geojson_obj - def get_style_map(self, style_function): + def get_style_map( + self, + style_function: Callable, + ) -> TypeStyleMapping: """Return a dict that maps style parameters to features.""" return self._create_mapping(style_function, 'style') - def get_highlight_map(self, highlight_function): + def get_highlight_map( + self, + highlight_function: Callable, + ) -> TypeStyleMapping: """Return a dict that maps highlight parameters to features.""" return self._create_mapping(highlight_function, 'highlight') - def _create_mapping(self, func, switch): + def _create_mapping( + self, + func: Callable, + switch: str, + ) -> TypeStyleMapping: """Internal function to create the mapping.""" - mapping = {} + mapping: GeoJsonStyleMapper.TypeStyleMapping = {} for feature in self.data['features']: content = func(feature) if switch == 'style': @@ -614,26 +665,28 @@ def _create_mapping(self, func, switch): # Replace objects with their Javascript var names: content[key] = "{{'" + value.get_name() + "'}}" key = self._to_key(content) - mapping.setdefault(key, []).append(self.get_feature_id(feature)) + feature_id = self.get_feature_id(feature) + mapping.setdefault(key, []).append(feature_id) # type: ignore self._set_default_key(mapping) return mapping - def get_feature_id(self, feature): + def get_feature_id(self, feature: dict) -> Union[str, int]: """Return a value identifying the feature.""" fields = self.feature_identifier.split('.')[1:] - return functools.reduce(operator.getitem, fields, feature) + value = functools.reduce(operator.getitem, fields, feature) + assert isinstance(value, (str, int)) + return value @staticmethod - def _to_key(d): + def _to_key(d: dict) -> str: """Convert dict to str and enable Jinja2 template syntax.""" as_str = json.dumps(d, sort_keys=True) return as_str.replace('"{{', '{{').replace('}}"', '}}') @staticmethod - def _set_default_key(mapping): + def _set_default_key(mapping: TypeStyleMapping) -> None: """Replace the field with the most features with a 'default' field.""" - key_longest = sorted([(len(v), k) for k, v in mapping.items()], - reverse=True)[0][1] + key_longest = max(mapping, key=mapping.get) mapping['default'] = key_longest del (mapping[key_longest]) @@ -710,9 +763,18 @@ class TopoJson(Layer): {% endmacro %} """) # noqa - def __init__(self, data, object_path, style_function=None, - name=None, overlay=True, control=True, show=True, - smooth_factor=None, tooltip=None): + def __init__( + self, + data: Any, + object_path: str, + style_function: Optional[Callable] = None, + name: Optional[str] = None, + overlay: bool = True, + control: bool = True, + show: bool = True, + smooth_factor: Optional[float] = None, + tooltip: Optional[Union[str, Tooltip]] = None, + ): super(TopoJson, self).__init__(name=name, overlay=overlay, control=control, show=show) self._name = 'TopoJson' @@ -729,10 +791,7 @@ def __init__(self, data, object_path, style_function=None, self.object_path = object_path - if style_function is None: - def style_function(x): - return {} - self.style_function = style_function + self.style_function = style_function or (lambda x: {}) self.smooth_factor = smooth_factor @@ -741,7 +800,7 @@ def style_function(x): elif tooltip is not None: self.add_child(Tooltip(tooltip)) - def style_data(self): + def style_data(self) -> None: """Applies self.style_function to each feature of self.data.""" def recursive_get(data, keys): @@ -754,7 +813,7 @@ def recursive_get(data, keys): for feature in geometries: feature.setdefault('properties', {}).setdefault('style', {}).update(self.style_function(feature)) # noqa - def render(self, **kwargs): + def render(self, **kwargs) -> None: """Renders the HTML representation of the element.""" self.style_data() super(TopoJson, self).render(**kwargs) @@ -767,7 +826,7 @@ def render(self, **kwargs): JavascriptLink('https://cdnjs.cloudflare.com/ajax/libs/topojson/1.6.9/topojson.min.js'), # noqa name='topojson') - def get_bounds(self): + def get_bounds(self) -> List[List[float]]: """ Computes the bounds of the object itself (not including it's children) in the form [[lat_min, lon_min], [lat_max, lon_max]] @@ -829,8 +888,15 @@ class GeoJsonDetail(MacroElement): } """ - def __init__(self, fields, aliases=None, labels=True, localize=False, style=None, - class_name="geojsondetail"): + def __init__( + self, + fields: Sequence[str], + aliases: Optional[Sequence[str]] = None, + labels: bool = True, + localize: bool = False, + style: Optional[str] = None, + class_name: str = "geojsondetail", + ): super(GeoJsonDetail, self).__init__() assert isinstance(fields, (list, tuple)), 'Please pass a list or ' \ 'tuple to fields.' @@ -852,7 +918,7 @@ def __init__(self, fields, aliases=None, labels=True, localize=False, style=None # noqa outside of type checking. self.style = style - def warn_for_geometry_collections(self): + def warn_for_geometry_collections(self) -> None: """Checks for GeoJson GeometryCollection features to warn user about incompatibility.""" geom_collections = [ feature.get('properties') if feature.get('properties') is not None else key @@ -865,7 +931,7 @@ def warn_for_geometry_collections(self): "Please consider reworking these features: {} to MultiPolygon for full functionality.\n" "https://tools.ietf.org/html/rfc7946#page-9".format(self._name, geom_collections), UserWarning) - def render(self, **kwargs): + def render(self, **kwargs) -> None: """Renders the HTML representation of the element.""" figure = self.get_root() if isinstance(self._parent, GeoJson): @@ -953,8 +1019,17 @@ class GeoJsonTooltip(GeoJsonDetail): {% endmacro %} """) - def __init__(self, fields, aliases=None, labels=True, localize=False, - style=None, class_name='foliumtooltip', sticky=True, **kwargs): + def __init__( + self, + fields: Sequence[str], + aliases: Optional[Sequence[str]] = None, + labels: bool = True, + localize: bool = False, + style: Optional[str] = None, + class_name: str = 'foliumtooltip', + sticky: bool = True, + **kwargs: TypeJsonValue + ): super(GeoJsonTooltip, self).__init__( fields=fields, aliases=aliases, labels=labels, localize=localize, style=style, class_name=class_name @@ -1007,9 +1082,16 @@ class GeoJsonPopup(GeoJsonDetail): {% endmacro %} """) - def __init__(self, fields=None, aliases=None, labels=True, - style="margin: auto;", class_name='foliumpopup', localize=True, - **kwargs): + def __init__( + self, + fields: Sequence[str], + aliases: Optional[Sequence[str]] = None, + labels: bool = True, + style: str = "margin: auto;", + class_name: str = 'foliumpopup', + localize: bool = True, + **kwargs: TypeJsonValue + ): super(GeoJsonPopup, self).__init__( fields=fields, aliases=aliases, labels=labels, localize=localize, class_name=class_name, style=style) @@ -1051,7 +1133,7 @@ class Choropleth(FeatureGroup): geometries data: Pandas DataFrame or Series, default None Data to bind to the GeoJSON. - columns: dict or tuple, default None + columns: tuple with two values, default None If the data is a Pandas DataFrame, the columns of data to be bound. Must pass column 1 as the key, and column 2 the values. key_on: string, default None @@ -1125,13 +1207,30 @@ class Choropleth(FeatureGroup): ... highlight=True) """ - def __init__(self, geo_data, data=None, columns=None, key_on=None, # noqa - bins=6, fill_color='blue', nan_fill_color='black', - fill_opacity=0.6, nan_fill_opacity=None, line_color='black', - line_weight=1, line_opacity=1, name=None, legend_name='', - overlay=True, control=True, show=True, - topojson=None, smooth_factor=None, highlight=None, - **kwargs): + def __init__( + self, + geo_data: Any, + data: Optional[Any] = None, + columns: Optional[Sequence[Any]] = None, + key_on: Optional[str] = None, + bins: Union[int, Sequence[float]] = 6, + fill_color: str = 'blue', + nan_fill_color: str = 'black', + fill_opacity: float = 0.6, + nan_fill_opacity: Optional[float] = None, + line_color: str = 'black', + line_weight: float = 1, + line_opacity: float = 1, + name: Optional[str] = None, + legend_name: str = '', + overlay: bool = True, + control: bool = True, + show: bool = True, + topojson: Optional[str] = None, + smooth_factor: Optional[float] = None, + highlight: bool = False, + **kwargs + ): super(Choropleth, self).__init__(name=name, overlay=overlay, control=control, show=show) self._name = 'Choropleth' @@ -1153,10 +1252,11 @@ def __init__(self, geo_data, data=None, columns=None, key_on=None, # noqa # Create color_data dict if hasattr(data, 'set_index'): # This is a pd.DataFrame - color_data = data.set_index(columns[0])[columns[1]].to_dict() + assert columns is not None + color_data = data.set_index(columns[0])[columns[1]].to_dict() # type: ignore elif hasattr(data, 'to_dict'): # This is a pd.Series - color_data = data.to_dict() + color_data = data.to_dict() # type: ignore elif data: color_data = dict(data) else: @@ -1252,7 +1352,7 @@ def highlight_function(x): if self.color_scale: self.add_child(self.color_scale) - def render(self, **kwargs): + def render(self, **kwargs) -> None: """Render the GeoJson/TopoJson and color scale objects.""" if self.color_scale: # ColorMap needs Map as its parent @@ -1298,8 +1398,14 @@ class DivIcon(MacroElement): {% endmacro %} """) # noqa - def __init__(self, html=None, icon_size=None, icon_anchor=None, - popup_anchor=None, class_name='empty'): + def __init__( + self, + html: Optional[str] = None, + icon_size: Optional[Tuple[int, int]] = None, + icon_anchor: Optional[Tuple[int, int]] = None, + popup_anchor: Optional[Tuple[int, int]] = None, + class_name: str = 'empty', + ): super(DivIcon, self).__init__() self._name = 'DivIcon' self.options = parse_options( @@ -1362,7 +1468,7 @@ class ClickForMarker(MacroElement): {% endmacro %} """) # noqa - def __init__(self, popup=None): + def __init__(self, popup: Optional[str] = None): super(ClickForMarker, self).__init__() self._name = 'ClickForMarker' @@ -1412,9 +1518,16 @@ class CustomIcon(Icon): {% endmacro %} """) # noqa - def __init__(self, icon_image, icon_size=None, icon_anchor=None, - shadow_image=None, shadow_size=None, shadow_anchor=None, - popup_anchor=None): + def __init__( + self, + icon_image: Any, + icon_size: Optional[Tuple[int, int]] = None, + icon_anchor: Optional[Tuple[int, int]] = None, + shadow_image: Any = None, + shadow_size: Optional[Tuple[int, int]] = None, + shadow_anchor: Optional[Tuple[int, int]] = None, + popup_anchor: Optional[Tuple[int, int]] = None, + ): super(Icon, self).__init__() self._name = 'CustomIcon' self.options = parse_options( @@ -1434,17 +1547,16 @@ class ColorLine(FeatureGroup): Parameters ---------- - positions: tuple or list - The list of points latitude and longitude - colors: tuple or list - The list of segments colors. + positions: iterable of (lat, lon) pairs + The points on the line. Segments between points will be colored. + colors: iterable of float + Values that determine the color of a line segment. It must have length equal to `len(positions)-1`. colormap: branca.colormap.Colormap or list or tuple The colormap to use. If a list or tuple of colors is provided, a LinearColormap will be created from it. nb_steps: int, default 12 - To have lighter output the colormap will be discretized - to that number of colors. + The colormap will be discretized to this number of colors. opacity: float, default 1 Line opacity, scale 0-1 weight: int, default 2 @@ -1458,11 +1570,19 @@ class ColorLine(FeatureGroup): """ - def __init__(self, positions, colors, colormap=None, nb_steps=12, - weight=None, opacity=None, **kwargs): + def __init__( + self, + positions: Iterable[Sequence[float]], + colors: Iterable[float], + colormap: Optional[Union[LinearColormap, Sequence[Any]]] = None, + nb_steps: int = 12, + weight: Optional[int] = None, + opacity: Optional[float] = None, + **kwargs: Any + ): super(ColorLine, self).__init__(**kwargs) self._name = 'ColorLine' - positions = validate_locations(positions) + coords = validate_locations(positions) if colormap is None: cm = LinearColormap(['green', 'yellow', 'red'], @@ -1478,8 +1598,8 @@ def __init__(self, positions, colors, colormap=None, nb_steps=12, ).to_step(nb_steps) else: cm = colormap - out = {} - for (lat1, lng1), (lat2, lng2), color in zip(positions[:-1], positions[1:], colors): # noqa + out: Dict[str, List[List[List[float]]]] = {} + for (lat1, lng1), (lat2, lng2), color in zip(coords[:-1], coords[1:], colors): # noqa out.setdefault(cm(color), []).append([[lat1, lng1], [lat2, lng2]]) for key, val in out.items(): self.add_child(PolyLine(val, color=key, weight=weight, opacity=opacity)) # noqa diff --git a/folium/folium.py b/folium/folium.py index b770f6f55..9968dd705 100644 --- a/folium/folium.py +++ b/folium/folium.py @@ -7,16 +7,17 @@ import time import warnings +from typing import Any, Optional, Sequence, Union, List from branca.element import CssLink, Element, Figure, JavascriptLink, MacroElement -from folium.map import FitBounds +from folium.map import FitBounds, Layer from folium.raster_layers import TileLayer from folium.utilities import ( _parse_size, _tmp_html, validate_location, - parse_options, + parse_options, TypeJsonValue, ) from jinja2 import Environment, PackageLoader, Template @@ -60,7 +61,7 @@ class GlobalSwitches(Element): """) - def __init__(self, no_touch=False, disable_3d=False): + def __init__(self, no_touch: bool = False, disable_3d: bool = False): super(GlobalSwitches, self).__init__() self._name = 'GlobalSwitches' self.no_touch = no_touch @@ -209,41 +210,41 @@ class Map(MacroElement): def __init__( self, - location=None, - width='100%', - height='100%', - left='0%', - top='0%', - position='relative', - tiles='OpenStreetMap', - attr=None, - min_zoom=0, - max_zoom=18, - zoom_start=10, - min_lat=-90, - max_lat=90, - min_lon=-180, - max_lon=180, - max_bounds=False, - crs='EPSG3857', - control_scale=False, - prefer_canvas=False, - no_touch=False, - disable_3d=False, - png_enabled=False, - zoom_control=True, - **kwargs + location: Optional[Sequence[float]] = None, + width: Union[str, int] = '100%', + height: Union[str, int] = '100%', + left: Union[str, int] = '0%', + top: Union[str, int] = '0%', + position: str = 'relative', + tiles: str = 'OpenStreetMap', + attr: Optional[str] = None, + min_zoom: int = 0, + max_zoom: int = 18, + zoom_start: int = 10, + min_lat: int = -90, + max_lat: int = 90, + min_lon: int = -180, + max_lon: int = 180, + max_bounds: bool = False, + crs: str = 'EPSG3857', + control_scale: bool = False, + prefer_canvas: bool = False, + no_touch: bool = False, + disable_3d: bool = False, + png_enabled: bool = False, + zoom_control: bool = True, + **kwargs: TypeJsonValue ): super(Map, self).__init__() self._name = 'Map' self._env = ENV # Undocumented for now b/c this will be subject to a re-factor soon. - self._png_image = None + self._png_image = '' self.png_enabled = png_enabled if location is None: # If location is not passed we center and zoom out. - self.location = [0, 0] + self.location = [0.0, 0.0] zoom_start = 1 else: self.location = validate_location(location) @@ -276,24 +277,25 @@ def __init__( disable_3d ) - self.objects_to_stay_in_front = [] + self.objects_to_stay_in_front: List[Layer] = [] if tiles: tile_layer = TileLayer(tiles=tiles, attr=attr, min_zoom=min_zoom, max_zoom=max_zoom) self.add_child(tile_layer, name=tile_layer.tile_name) - def _repr_html_(self, **kwargs): + def _repr_html_(self, **kwargs) -> str: """Displays the HTML Map in a Jupyter notebook.""" if self._parent is None: self.add_to(Figure()) + self._parent: Figure out = self._parent._repr_html_(**kwargs) self._parent = None else: out = self._parent._repr_html_(**kwargs) return out - def _to_png(self, delay=3): + def _to_png(self, delay: int = 3) -> str: """Export the HTML to byte representation of a PNG image. Uses selenium to render the HTML and record a PNG. You may need to @@ -323,7 +325,7 @@ def _to_png(self, delay=3): self._png_image = png return self._png_image - def _repr_png_(self): + def _repr_png_(self) -> Optional[str]: """Displays the PNG Map in a Jupyter notebook.""" # The notebook calls all _repr_*_ by default. # We don't want that here b/c this one is quite slow. @@ -331,7 +333,7 @@ def _repr_png_(self): return None return self._to_png() - def render(self, **kwargs): + def render(self, **kwargs) -> None: """Renders the HTML representation of the element.""" figure = self.get_root() assert isinstance(figure, Figure), ('You cannot render this Element ' @@ -369,8 +371,14 @@ def render(self, **kwargs): super(Map, self).render(**kwargs) - def fit_bounds(self, bounds, padding_top_left=None, - padding_bottom_right=None, padding=None, max_zoom=None): + def fit_bounds( + self, + bounds: Sequence[Sequence[float]], + padding_top_left: Optional[Sequence[float]] = None, + padding_bottom_right: Optional[Sequence[float]] = None, + padding: Optional[Sequence[float]] = None, + max_zoom: Optional[int] = None, + ) -> None: """Fit the map to contain a bounding box with the maximum zoom level possible. @@ -403,7 +411,7 @@ def fit_bounds(self, bounds, padding_top_left=None, ) ) - def choropleth(self, *args, **kwargs): + def choropleth(self, *args, **kwargs) -> None: """Call the Choropleth class with the same arguments. This method may be deleted after a year from now (Nov 2018). @@ -417,7 +425,7 @@ def choropleth(self, *args, **kwargs): from folium.features import Choropleth self.add_child(Choropleth(*args, **kwargs)) - def keep_in_front(self, *args): + def keep_in_front(self, *args: Layer) -> None: """Pass one or multiple layers that must stay in front. The ordering matters, the last one is put on top. diff --git a/folium/map.py b/folium/map.py index 38abf71a1..e817a627a 100644 --- a/folium/map.py +++ b/folium/map.py @@ -6,12 +6,18 @@ """ from collections import OrderedDict +from typing import Dict, Sequence, Optional, Union, List, Tuple, Type import warnings from branca.element import Element, Figure, Html, MacroElement -from folium.utilities import validate_location, camelize, parse_options +from folium.utilities import ( + validate_location, + camelize, + parse_options, + TypeJsonValue, +) from jinja2 import Template @@ -32,7 +38,14 @@ class Layer(MacroElement): show: bool, default True Whether the layer will be shown on opening (only for overlays). """ - def __init__(self, name=None, overlay=False, control=True, show=True): + + def __init__( + self, + name: Optional[str] = None, + overlay: bool = False, + control: bool = True, + show: bool = True, + ): super(Layer, self).__init__() self.layer_name = name if name is not None else self.get_name() self.overlay = overlay @@ -72,8 +85,14 @@ class FeatureGroup(Layer): {% endmacro %} """) - def __init__(self, name=None, overlay=True, control=True, show=True, - **kwargs): + def __init__( + self, + name: Optional[str] = None, + overlay: bool = True, + control: bool = True, + show: bool = True, + **kwargs: TypeJsonValue + ): super(FeatureGroup, self).__init__(name=name, overlay=overlay, control=control, show=show) self._name = 'FeatureGroup' @@ -131,8 +150,13 @@ class LayerControl(MacroElement): {% endmacro %} """) - def __init__(self, position='topright', collapsed=True, autoZIndex=True, - **kwargs): + def __init__( + self, + position='topright', + collapsed=True, + autoZIndex=True, + **kwargs: TypeJsonValue + ): super(LayerControl, self).__init__() self._name = 'LayerControl' self.options = parse_options( @@ -141,16 +165,16 @@ def __init__(self, position='topright', collapsed=True, autoZIndex=True, autoZIndex=autoZIndex, **kwargs ) - self.base_layers = OrderedDict() - self.overlays = OrderedDict() - self.layers_untoggle = OrderedDict() + self.base_layers: OrderedDict[str, str] = OrderedDict() + self.overlays: OrderedDict[str, str] = OrderedDict() + self.layers_untoggle: OrderedDict[str, str] = OrderedDict() - def reset(self): + def reset(self) -> None: self.base_layers = OrderedDict() self.overlays = OrderedDict() self.layers_untoggle = OrderedDict() - def render(self, **kwargs): + def render(self, **kwargs) -> None: """Renders the HTML representation of the element.""" for item in self._parent._children.values(): if not isinstance(item, Layer) or not item.control: @@ -207,14 +231,21 @@ class Icon(MacroElement): {{ this._parent.get_name() }}.setIcon({{ this.get_name() }}); {% endmacro %} """) - color_options = {'red', 'darkred', 'lightred', 'orange', 'beige', + color_options = {'red', 'darkred', 'lightred', 'orange', 'beige', 'green', 'darkgreen', 'lightgreen', 'blue', 'darkblue', 'cadetblue', 'lightblue', - 'purple', 'darkpurple', 'pink', + 'purple', 'darkpurple', 'pink', 'white', 'gray', 'lightgray', 'black'} - def __init__(self, color='blue', icon_color='white', icon='info-sign', - angle=0, prefix='glyphicon', **kwargs): + def __init__( + self, + color: str = 'blue', + icon_color: str = 'white', + icon: str = 'info-sign', + angle: int = 0, + prefix: str = 'glyphicon', + **kwargs: TypeJsonValue + ): super(Icon, self).__init__() self._name = 'Icon' if color not in self.color_options: @@ -270,8 +301,15 @@ class Marker(MacroElement): {% endmacro %} """) - def __init__(self, location, popup=None, tooltip=None, icon=None, - draggable=False, **kwargs): + def __init__( + self, + location: Sequence[float], + popup: Optional[Union[str, 'Popup']] = None, + tooltip: Optional[Union[str, 'Tooltip']] = None, + icon: Optional[Icon] = None, + draggable: bool = False, + **kwargs: TypeJsonValue + ): super(Marker, self).__init__() self._name = 'Marker' self.location = validate_location(location) @@ -289,7 +327,7 @@ def __init__(self, location, popup=None, tooltip=None, icon=None, self.add_child(tooltip if isinstance(tooltip, Tooltip) else Tooltip(str(tooltip))) - def _get_self_bounds(self): + def _get_self_bounds(self) -> List[List[float]]: """Computes the bounds of the object itself. Because a marker has only single coordinates, we repeat them. @@ -329,8 +367,15 @@ class Popup(Element): {% endfor %} """) # noqa - def __init__(self, html=None, parse_html=False, max_width='100%', - show=False, sticky=False, **kwargs): + def __init__( + self, + html: Optional[Union[str, Element]] = None, + parse_html: bool = False, + max_width: Union[str, int] = '100%', + show: bool = False, + sticky: bool = False, + **kwargs: TypeJsonValue + ): super(Popup, self).__init__() self._name = 'Popup' self.header = Element() @@ -356,7 +401,7 @@ def __init__(self, html=None, parse_html=False, max_width='100%', **kwargs ) - def render(self, **kwargs): + def render(self, **kwargs) -> None: """Renders the HTML representation of the element.""" for name, child in self._children.items(): child.render(**kwargs) @@ -399,19 +444,25 @@ class Tooltip(MacroElement): ); {% endmacro %} """) - valid_options = { - 'pane': (str, ), - 'offset': (tuple, ), - 'direction': (str, ), - 'permanent': (bool, ), - 'sticky': (bool, ), - 'interactive': (bool, ), + valid_options: Dict[str, Tuple[Type, ...]] = { + 'pane': (str,), + 'offset': (tuple,), + 'direction': (str,), + 'permanent': (bool,), + 'sticky': (bool,), + 'interactive': (bool,), 'opacity': (float, int), - 'attribution': (str, ), - 'className': (str, ), + 'attribution': (str,), + 'className': (str,), } - def __init__(self, text, style=None, sticky=True, **kwargs): + def __init__( + self, + text: str, + style: Optional[str] = None, + sticky: bool = True, + **kwargs: TypeJsonValue + ): super(Tooltip, self).__init__() self._name = 'Tooltip' @@ -426,17 +477,20 @@ def __init__(self, text, style=None, sticky=True, **kwargs): # noqa outside of type checking. self.style = style - def parse_options(self, kwargs): + def parse_options( + self, + kwargs: Dict[str, TypeJsonValue], + ) -> Dict[str, TypeJsonValue]: """Validate the provided kwargs and return options as json string.""" kwargs = {camelize(key): value for key, value in kwargs.items()} for key in kwargs.keys(): assert key in self.valid_options, ( 'The option {} is not in the available options: {}.' - .format(key, ', '.join(self.valid_options)) + .format(key, ', '.join(self.valid_options)) ) assert isinstance(kwargs[key], self.valid_options[key]), ( 'The option {} must be one of the following types: {}.' - .format(key, self.valid_options[key]) + .format(key, self.valid_options[key]) ) return kwargs @@ -470,8 +524,14 @@ class FitBounds(MacroElement): {% endmacro %} """) - def __init__(self, bounds, padding_top_left=None, - padding_bottom_right=None, padding=None, max_zoom=None): + def __init__( + self, + bounds: Sequence[Sequence[float]], + padding_top_left: Optional[Sequence[float]] = None, + padding_bottom_right: Optional[Sequence[float]] = None, + padding: Optional[Sequence[float]] = None, + max_zoom: Optional[int] = None, + ): super(FitBounds, self).__init__() self._name = 'FitBounds' self.bounds = bounds @@ -517,7 +577,12 @@ class CustomPane(MacroElement): {% endmacro %} """) - def __init__(self, name, z_index=625, pointer_events=False): + def __init__( + self, + name: str, + z_index: Union[int, str] = 625, + pointer_events: bool = False, + ): super(CustomPane, self).__init__() self._name = 'Pane' self.name = name diff --git a/folium/plugins/beautify_icon.py b/folium/plugins/beautify_icon.py index a0c45a81e..7e102ee53 100644 --- a/folium/plugins/beautify_icon.py +++ b/folium/plugins/beautify_icon.py @@ -2,12 +2,13 @@ from branca.element import CssLink, Figure, JavascriptLink, MacroElement +from folium.map import Icon from folium.utilities import parse_options from jinja2 import Template -class BeautifyIcon(MacroElement): +class BeautifyIcon(Icon): """ Create a BeautifyIcon that can be added to a Marker diff --git a/folium/raster_layers.py b/folium/raster_layers.py index 5a00e3f77..8a789ba20 100644 --- a/folium/raster_layers.py +++ b/folium/raster_layers.py @@ -4,11 +4,17 @@ Wraps leaflet TileLayer, WmsTileLayer (TileLayer.WMS), ImageOverlay, and VideoOverlay """ +from typing import Optional, Any, Sequence, Callable from branca.element import Element, Figure from folium.map import Layer -from folium.utilities import image_to_url, mercator_transform, parse_options +from folium.utilities import ( + image_to_url, + mercator_transform, + parse_options, + TypeJsonValue, +) from jinja2 import Environment, PackageLoader, Template @@ -77,11 +83,25 @@ class TileLayer(Layer): {% endmacro %} """) - def __init__(self, tiles='OpenStreetMap', min_zoom=0, max_zoom=18, - max_native_zoom=None, attr=None, API_key=None, - detect_retina=False, name=None, overlay=False, - control=True, show=True, no_wrap=False, subdomains='abc', - tms=False, opacity=1, **kwargs): + def __init__( + self, + tiles: str = 'OpenStreetMap', + min_zoom: int = 0, + max_zoom: int = 18, + max_native_zoom: Optional[int] = None, + attr: Optional[str] = None, + API_key: Optional[str] = None, + detect_retina: bool = False, + name: Optional[str] = None, + overlay: bool = False, + control: bool = True, + show: bool = True, + no_wrap: bool = False, + subdomains: str = 'abc', + tms: bool = False, + opacity: float = 1, + **kwargs + ): self.tile_name = (name if name is not None else ''.join(tiles.lower().strip().split())) @@ -166,9 +186,21 @@ class WmsTileLayer(Layer): {% endmacro %} """) # noqa - def __init__(self, url, layers, styles='', fmt='image/jpeg', - transparent=False, version='1.1.1', attr='', - name=None, overlay=True, control=True, show=True, **kwargs): + def __init__( + self, + url: str, + layers: str, + styles: str = '', + fmt: str = 'image/jpeg', + transparent: bool = False, + version: str = '1.1.1', + attr: str = '', + name: Optional[str] = None, + overlay: bool = True, + control: bool = True, + show: bool = True, + **kwargs + ): super(WmsTileLayer, self).__init__(name=name, overlay=overlay, control=control, show=show) self.url = url @@ -238,9 +270,20 @@ class ImageOverlay(Layer): {% endmacro %} """) - def __init__(self, image, bounds, origin='upper', colormap=None, - mercator_project=False, pixelated=True, - name=None, overlay=True, control=True, show=True, **kwargs): + def __init__( + self, + image: Any, + bounds: Sequence[Sequence[float]], + origin: str = 'upper', + colormap: Optional[Callable] = None, + mercator_project: bool = False, + pixelated: bool = True, + name: Optional[str] = None, + overlay: bool = True, + control: bool = True, + show: bool = True, + **kwargs + ): super(ImageOverlay, self).__init__(name=name, overlay=overlay, control=control, show=show) self._name = 'ImageOverlay' @@ -250,13 +293,13 @@ def __init__(self, image, bounds, origin='upper', colormap=None, if mercator_project: image = mercator_transform( image, - [bounds[0][0], bounds[1][0]], + (bounds[0][0], bounds[1][0]), origin=origin ) self.url = image_to_url(image, origin=origin, colormap=colormap) - def render(self, **kwargs): + def render(self, **kwargs) -> None: super(ImageOverlay, self).render() figure = self.get_root() @@ -278,7 +321,7 @@ def render(self, **kwargs): """ figure.header.add_child(Element(pixelated), name='leaflet-image-layer') # noqa - def _get_self_bounds(self): + def _get_self_bounds(self) -> Sequence[Sequence[float]]: """ Computes the bounds of the object itself (not including it's children) in the form [[lat_min, lon_min], [lat_max, lon_max]]. @@ -323,8 +366,18 @@ class VideoOverlay(Layer): {% endmacro %} """) - def __init__(self, video_url, bounds, autoplay=True, loop=True, - name=None, overlay=True, control=True, show=True, **kwargs): + def __init__( + self, + video_url: str, + bounds: Sequence[Sequence[float]], + autoplay: bool = True, + loop: bool = True, + name: Optional[str] = None, + overlay: bool = True, + control: bool = True, + show: bool = True, + **kwargs: TypeJsonValue + ): super(VideoOverlay, self).__init__(name=name, overlay=overlay, control=control, show=show) self._name = 'VideoOverlay' @@ -337,7 +390,7 @@ def __init__(self, video_url, bounds, autoplay=True, loop=True, **kwargs ) - def _get_self_bounds(self): + def _get_self_bounds(self) -> Sequence[Sequence[float]]: """ Computes the bounds of the object itself (not including it's children) in the form [[lat_min, lon_min], [lat_max, lon_max]] diff --git a/folium/utilities.py b/folium/utilities.py index 6666dc8c0..ecf58671f 100644 --- a/folium/utilities.py +++ b/folium/utilities.py @@ -10,9 +10,24 @@ import copy import uuid import collections +from typing import ( + Iterable, + Sequence, + Union, + List, + Optional, + Callable, + Tuple, + Iterator, + Type, + Dict, Any, +) from urllib.parse import urlparse, uses_netloc, uses_params, uses_relative import numpy as np + +from branca.element import Element + try: import pandas as pd except ImportError: @@ -22,8 +37,16 @@ _VALID_URLS = set(uses_relative + uses_netloc + uses_params) _VALID_URLS.discard('') +TypeLine = Iterable[Sequence[float]] +TypeMultiLine = Iterable[TypeLine] + +TypeJsonValueNoNone = Union[str, float, bool, Sequence, dict] +TypeJsonValue = Union[TypeJsonValueNoNone, None] -def validate_location(location): # noqa: C901 +TypePathOptions = Union[bool, str, float, None] + + +def validate_location(location: Sequence[float]) -> List[float]: # noqa: C901 """Validate a single lat/lon coordinate pair and convert to a list Validate that location: @@ -66,8 +89,31 @@ def validate_location(location): # noqa: C901 return [float(x) for x in coords] -def validate_locations(locations): - """Validate an iterable with multiple lat/lon coordinate pairs. +def validate_locations(locations: TypeLine) -> List[List[float]]: + """Validate an iterable with lat/lon coordinate pairs. + + Returns + ------- + list[list[float, float]] + + """ + locations = if_pandas_df_convert_to_numpy(locations) + try: + iter(locations) + except TypeError: + raise TypeError('Locations should be an iterable with coordinate pairs,' + ' but instead got {!r}.'.format(locations)) + try: + next(iter(locations)) + except StopIteration: + raise ValueError('Locations is empty.') + return [validate_location(coord_pair) for coord_pair in locations] + + +def validate_multi_locations( + locations: Union[TypeLine, TypeMultiLine] +) -> Union[List[List[float]], List[List[List[float]]]]: + """Validate an iterable with possibly nested lists of coordinate pairs. Returns ------- @@ -85,16 +131,17 @@ def validate_locations(locations): except StopIteration: raise ValueError('Locations is empty.') try: - float(next(iter(next(iter(next(iter(locations))))))) + float(next(iter(next(iter(next(iter(locations))))))) # type: ignore except (TypeError, StopIteration): # locations is a list of coordinate pairs - return [validate_location(coord_pair) for coord_pair in locations] + return [validate_location(coord_pair) # type: ignore + for coord_pair in locations] else: # locations is a list of a list of coordinate pairs, recurse - return [validate_locations(lst) for lst in locations] + return [validate_locations(lst) for lst in locations] # type: ignore -def if_pandas_df_convert_to_numpy(obj): +def if_pandas_df_convert_to_numpy(obj: Any) -> Any: """Return a Numpy array from a Pandas dataframe. Iterating over a DataFrame has weird side effects, such as the first @@ -106,7 +153,11 @@ def if_pandas_df_convert_to_numpy(obj): return obj -def image_to_url(image, colormap=None, origin='upper'): +def image_to_url( + image: Any, + colormap: Optional[Callable] = None, + origin: str = 'upper', +) -> str: """ Infers the type of an image argument and transforms it into a URL. @@ -144,7 +195,7 @@ def image_to_url(image, colormap=None, origin='upper'): return url.replace('\n', ' ') -def _is_url(url): +def _is_url(url: str) -> bool: """Check to see if `url` has a valid protocol.""" try: return urlparse(url).scheme in _VALID_URLS @@ -152,7 +203,11 @@ def _is_url(url): return False -def write_png(data, origin='upper', colormap=None): +def write_png( + data: Any, + origin: str = 'upper', + colormap: Optional[Callable] = None, +) -> bytes: """ Transform an array of data into a PNG string. This can be written to disk using binary I/O, or encoded using base64 @@ -184,9 +239,7 @@ def write_png(data, origin='upper', colormap=None): PNG formatted byte string """ - if colormap is None: - def colormap(x): - return (x, x, x, 1) + colormap = colormap or (lambda x: (x, x, x, 1)) arr = np.atleast_3d(data) height, width, nblayers = arr.shape @@ -239,7 +292,12 @@ def png_pack(png_tag, data): png_pack(b'IEND', b'')]) -def mercator_transform(data, lat_bounds, origin='upper', height_out=None): +def mercator_transform( + data: Any, + lat_bounds: Tuple[float, float], + origin: str = 'upper', + height_out: Optional[int] = None, +) -> np.ndarray: """ Transforms an image computed in (longitude,latitude) coordinates into the a Mercator projection image. @@ -266,7 +324,6 @@ def mercator_transform(data, lat_bounds, origin='upper', height_out=None): See https://en.wikipedia.org/wiki/Web_Mercator for more details. """ - import numpy as np def mercator(x): return np.arcsinh(np.tan(x*np.pi/180.))*180./np.pi @@ -300,7 +357,7 @@ def mercator(x): return out -def none_min(x, y): +def none_min(x: Optional[float], y: Optional[float]) -> Optional[float]: if x is None: return y elif y is None: @@ -309,7 +366,7 @@ def none_min(x, y): return min(x, y) -def none_max(x, y): +def none_max(x: Optional[float], y: Optional[float]) -> Optional[float]: if x is None: return y elif y is None: @@ -318,7 +375,7 @@ def none_max(x, y): return max(x, y) -def iter_coords(obj): +def iter_coords(obj: Any) -> Iterator[Tuple[float, ...]]: """ Returns all the coordinate tuples from a geometry or feature. @@ -340,7 +397,7 @@ def iter_coords(obj): yield f -def _locations_mirror(x): +def _locations_mirror(x: Any) -> Any: """ Mirrors the points in a list-of-list-of-...-of-list-of-points. For example: @@ -357,13 +414,16 @@ def _locations_mirror(x): return x -def get_bounds(locations, lonlat=False): +def get_bounds( + locations: Any, + lonlat: bool = False, +) -> List[List[Optional[float]]]: """ Computes the bounds of the object in the form [[lat_min, lon_min], [lat_max, lon_max]] """ - bounds = [[None, None], [None, None]] + bounds: List[List[Optional[float]]] = [[None, None], [None, None]] for point in iter_coords(locations): bounds = [ [ @@ -380,7 +440,7 @@ def get_bounds(locations, lonlat=False): return bounds -def camelize(key): +def camelize(key: str) -> str: """Convert a python_style_variable_name to lowerCamelCase. Examples @@ -394,7 +454,7 @@ def camelize(key): for i, x in enumerate(key.split('_'))) -def _parse_size(value): +def _parse_size(value: Union[str, float]) -> Tuple[float, str]: try: if isinstance(value, (int, float)): value_type = 'px' @@ -410,7 +470,7 @@ def _parse_size(value): return value, value_type -def compare_rendered(obj1, obj2): +def compare_rendered(obj1: str, obj2: str) -> bool: """ Return True/False if the normalized rendered version of two folium map objects are the equal or not. @@ -419,7 +479,7 @@ def compare_rendered(obj1, obj2): return normalize(obj1) == normalize(obj2) -def normalize(rendered): +def normalize(rendered: str) -> str: """Return the input string without non-functional spaces or newlines.""" out = ''.join([line.strip() for line in rendered.splitlines() @@ -429,7 +489,7 @@ def normalize(rendered): @contextmanager -def _tmp_html(data): +def _tmp_html(data: str) -> Iterator[str]: """Yields the path of a temporary HTML file containing data.""" filepath = '' try: @@ -442,7 +502,7 @@ def _tmp_html(data): os.remove(filepath) -def deep_copy(item_original): +def deep_copy(item_original: Element) -> Element: """Return a recursive deep-copy of item where each copy has a new ID.""" item = copy.copy(item_original) item._id = uuid.uuid4().hex @@ -456,18 +516,18 @@ def deep_copy(item_original): return item -def get_obj_in_upper_tree(element, cls): +def get_obj_in_upper_tree(element: Element, cls: Type) -> Element: """Return the first object in the parent tree of class `cls`.""" - if not hasattr(element, '_parent'): + parent = element._parent + if parent is None: raise ValueError('The top of the tree was reached without finding a {}' .format(cls)) - parent = element._parent if not isinstance(parent, cls): return get_obj_in_upper_tree(parent, cls) return parent -def parse_options(**kwargs): +def parse_options(**kwargs: TypeJsonValue) -> Dict[str, TypeJsonValueNoNone]: """Return a dict with lower-camelcase keys and non-None values..""" return {camelize(key): value for key, value in kwargs.items() diff --git a/folium/vector_layers.py b/folium/vector_layers.py index f50697337..86bafee58 100644 --- a/folium/vector_layers.py +++ b/folium/vector_layers.py @@ -4,16 +4,28 @@ Wraps leaflet Polyline, Polygon, Rectangle, Circle, and CircleMarker """ +from typing import Union, Sequence, Optional, List from branca.element import MacroElement from folium.map import Marker, Popup, Tooltip -from folium.utilities import validate_locations, get_bounds +from folium.utilities import ( + validate_locations, + validate_multi_locations, + get_bounds, + TypeLine, + TypeMultiLine, + TypePathOptions, +) from jinja2 import Template -def path_options(line=False, radius=False, **kwargs): +def path_options( + line: bool = False, + radius: Optional[float] = None, + **kwargs: TypePathOptions +): """ Contains options and constants shared between vector overlays (Polygon, Polyline, Circle, CircleMarker, and Rectangle). @@ -69,16 +81,16 @@ def path_options(line=False, radius=False, **kwargs): 'smoothFactor': kwargs.pop('smooth_factor', 1.0), 'noClip': kwargs.pop('no_clip', False), } - if radius: + if radius is not None: extra_options.update({'radius': radius}) color = kwargs.pop('color', '#3388ff') fill_color = kwargs.pop('fill_color', False) if fill_color: fill = True - elif not fill_color: + else: fill_color = color - fill = kwargs.pop('fill', False) + fill = kwargs.pop('fill', False) # type: ignore default = { 'stroke': kwargs.pop('stroke', True), @@ -106,9 +118,14 @@ class BaseMultiLocation(MacroElement): """ - def __init__(self, locations, popup=None, tooltip=None): + def __init__( + self, + locations: Union[TypeLine, TypeMultiLine], + popup: Optional[Union[Popup, str]] = None, + tooltip: Optional[Union[Tooltip, str]] = None, + ): super(BaseMultiLocation, self).__init__() - self.locations = validate_locations(locations) + self.locations = validate_multi_locations(locations) if popup is not None: self.add_child(popup if isinstance(popup, Popup) else Popup(str(popup))) @@ -116,7 +133,7 @@ def __init__(self, locations, popup=None, tooltip=None): self.add_child(tooltip if isinstance(tooltip, Tooltip) else Tooltip(str(tooltip))) - def _get_self_bounds(self): + def _get_self_bounds(self) -> List[List[Optional[float]]]: """Compute the bounds of the object itself.""" return get_bounds(self.locations) @@ -130,6 +147,7 @@ class PolyLine(BaseMultiLocation): ---------- locations: list of points (latitude, longitude) Latitude and Longitude of line (Northing, Easting) + Pass multiple sequences of coordinates for a multi-polyline. popup: str or folium.Popup, default None Input text or visualization for object displayed when clicking. tooltip: str or folium.Tooltip, default None @@ -155,10 +173,16 @@ class PolyLine(BaseMultiLocation): {% endmacro %} """) - def __init__(self, locations, popup=None, tooltip=None, **kwargs): + def __init__( + self, + locations: Union[TypeLine, TypeMultiLine], + popup: Optional[Union[Popup, str]] = None, + tooltip: Optional[Union[Tooltip, str]] = None, + **kwargs: TypePathOptions + ): super(PolyLine, self).__init__(locations, popup=popup, tooltip=tooltip) self._name = 'PolyLine' - self.options = path_options(line=True, **kwargs) + self.options = path_options(line=True, radius=None, **kwargs) class Polygon(BaseMultiLocation): @@ -169,7 +193,10 @@ class Polygon(BaseMultiLocation): Parameters ---------- locations: list of points (latitude, longitude) - Latitude and Longitude of line (Northing, Easting) + - One list of coordinate pairs to define a polygon. You don't have to + add a last point equal to the first point. + - If you pass a list with multiple of those it will make a multi- + polygon. popup: string or folium.Popup, default None Input text or visualization for object displayed when clicking. tooltip: str or folium.Tooltip, default None @@ -189,21 +216,27 @@ class Polygon(BaseMultiLocation): {% endmacro %} """) - def __init__(self, locations, popup=None, tooltip=None, **kwargs): + def __init__( + self, + locations: Union[TypeLine, TypeMultiLine], + popup: Optional[Union[Popup, str]] = None, + tooltip: Optional[Union[Tooltip, str]] = None, + **kwargs: TypePathOptions + ): super(Polygon, self).__init__(locations, popup=popup, tooltip=tooltip) self._name = 'Polygon' - self.options = path_options(line=True, **kwargs) + self.options = path_options(line=True, radius=None, **kwargs) -class Rectangle(BaseMultiLocation): +class Rectangle(MacroElement): """Draw rectangle overlays on a map. See :func:`folium.vector_layers.path_options` for the `Path` options. Parameters ---------- - bounds: list of points (latitude, longitude) - Latitude and Longitude of line (Northing, Easting) + bounds: [(lat1, lon1), (lat2, lon2)] + Two lat lon pairs marking the two corners of the rectangle. popup: string or folium.Popup, default None Input text or visualization for object displayed when clicking. tooltip: str or folium.Tooltip, default None @@ -223,10 +256,28 @@ class Rectangle(BaseMultiLocation): {% endmacro %} """) - def __init__(self, bounds, popup=None, tooltip=None, **kwargs): - super(Rectangle, self).__init__(bounds, popup=popup, tooltip=tooltip) + def __init__( + self, + bounds: Sequence[Sequence[float]], + popup: Optional[Union[Popup, str]] = None, + tooltip: Optional[Union[Tooltip, str]] = None, + **kwargs: TypePathOptions + ): + super(Rectangle, self).__init__() self._name = 'rectangle' - self.options = path_options(line=True, **kwargs) + self.options = path_options(line=True, radius=None, **kwargs) + self.locations = validate_locations(bounds) + assert len(self.locations) == 2, 'Need two lat/lon pairs' + if popup is not None: + self.add_child(popup if isinstance(popup, Popup) + else Popup(str(popup))) + if tooltip is not None: + self.add_child(tooltip if isinstance(tooltip, Tooltip) + else Tooltip(str(tooltip))) + + def _get_self_bounds(self) -> List[List[Optional[float]]]: + """Compute the bounds of the object itself.""" + return get_bounds(self.locations) class Circle(Marker): @@ -263,7 +314,13 @@ class Circle(Marker): {% endmacro %} """) - def __init__(self, location, radius, popup=None, tooltip=None, **kwargs): + def __init__(self, + location: Sequence[float], + radius: float, + popup: Optional[Union[Popup, str]] = None, + tooltip: Optional[Union[Tooltip, str]] = None, + **kwargs: TypePathOptions + ): super(Circle, self).__init__(location, popup=popup, tooltip=tooltip) self._name = 'circle' self.options = path_options(line=False, radius=radius, **kwargs) @@ -300,7 +357,14 @@ class CircleMarker(Marker): {% endmacro %} """) - def __init__(self, location, radius=10, popup=None, tooltip=None, **kwargs): + def __init__( + self, + location: Sequence[float], + radius: float = 10, + popup: Optional[Union[Popup, str]] = None, + tooltip: Optional[Union[Tooltip, str]] = None, + **kwargs: TypePathOptions + ): super(CircleMarker, self).__init__(location, popup=popup, tooltip=tooltip) self._name = 'CircleMarker' diff --git a/setup.cfg b/setup.cfg index 1cf13d099..b5b36feaf 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,3 +32,9 @@ ignore = *.enc tests tests/* + +[mypy] +ignore_missing_imports = True + +[mypy-folium._version] +ignore_errors = True