mirror of
https://github.com/ansible-collections/community.general.git
synced 2026-03-26 21:33:12 +00:00
Add type hints to action and test plugins and to plugin utils; fix some bugs, and improve input validation (#11167)
* Add type hints to action and test plugins and to plugin utils. Also fix some bugs and add proper input validation. * Combine lines. Co-authored-by: Alexei Znamensky <103110+russoz@users.noreply.github.com> * Extend changelog fragment. * Move task_vars initialization up. --------- Co-authored-by: Alexei Znamensky <103110+russoz@users.noreply.github.com>
This commit is contained in:
@@ -4,9 +4,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ansible.errors import AnsibleFilterError
|
||||
import typing as t
|
||||
from collections.abc import Mapping
|
||||
|
||||
from ansible.errors import AnsibleFilterError
|
||||
|
||||
try:
|
||||
# Introduced with Data Tagging (https://github.com/ansible/ansible/pull/84621):
|
||||
from ansible.module_utils.datatag import native_type_name as _native_type_name
|
||||
@@ -16,7 +18,7 @@ except ImportError:
|
||||
HAS_NATIVE_TYPE_NAME = False
|
||||
|
||||
|
||||
def _atype(data, alias, *, use_native_type: bool = False):
|
||||
def _atype(data: t.Any, alias: Mapping, *, use_native_type: bool = False) -> str:
|
||||
"""
|
||||
Returns the name of the type class.
|
||||
"""
|
||||
@@ -30,10 +32,10 @@ def _atype(data, alias, *, use_native_type: bool = False):
|
||||
data_type = "dict"
|
||||
elif data_type == "_AnsibleLazyTemplateList":
|
||||
data_type = "list"
|
||||
return alias.get(data_type, data_type)
|
||||
return str(alias.get(data_type, data_type))
|
||||
|
||||
|
||||
def _ansible_type(data, alias, *, use_native_type: bool = False):
|
||||
def _ansible_type(data: t.Any, alias: t.Any, *, use_native_type: bool = False) -> str:
|
||||
"""
|
||||
Returns the Ansible data type.
|
||||
"""
|
||||
@@ -42,21 +44,20 @@ def _ansible_type(data, alias, *, use_native_type: bool = False):
|
||||
alias = {}
|
||||
|
||||
if not isinstance(alias, Mapping):
|
||||
msg = "The argument alias must be a dictionary. %s is %s"
|
||||
raise AnsibleFilterError(msg % (alias, type(alias)))
|
||||
raise AnsibleFilterError(f"The argument alias must be a dictionary. {alias!r} is {type(alias)}")
|
||||
|
||||
data_type = _atype(data, alias, use_native_type=use_native_type)
|
||||
|
||||
if data_type == "list" and len(data) > 0:
|
||||
items = [_atype(i, alias, use_native_type=use_native_type) for i in data]
|
||||
items_type = "|".join(sorted(set(items)))
|
||||
items = {_atype(i, alias, use_native_type=use_native_type) for i in data}
|
||||
items_type = "|".join(sorted(items))
|
||||
return f"{data_type}[{items_type}]"
|
||||
|
||||
if data_type == "dict" and len(data) > 0:
|
||||
keys = [_atype(i, alias, use_native_type=use_native_type) for i in data.keys()]
|
||||
vals = [_atype(i, alias, use_native_type=use_native_type) for i in data.values()]
|
||||
keys_type = "|".join(sorted(set(keys)))
|
||||
vals_type = "|".join(sorted(set(vals)))
|
||||
keys = {_atype(i, alias, use_native_type=use_native_type) for i in data.keys()}
|
||||
vals = {_atype(i, alias, use_native_type=use_native_type) for i in data.values()}
|
||||
keys_type = "|".join(sorted(keys))
|
||||
vals_type = "|".join(sorted(vals))
|
||||
return f"{data_type}[{keys_type}, {vals_type}]"
|
||||
|
||||
return data_type
|
||||
|
||||
@@ -6,12 +6,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import typing as t
|
||||
|
||||
from ansible.errors import AnsibleFilterError
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from ansible.errors import AnsibleFilterError
|
||||
from ansible.module_utils.common.collections import is_sequence
|
||||
|
||||
def _keys_filter_params(data, matching_parameter):
|
||||
|
||||
def _keys_filter_params(
|
||||
data: t.Any, matching_parameter: t.Any
|
||||
) -> tuple[Sequence[Mapping[str, t.Any]], t.Literal["equal", "starts_with", "ends_with", "regex"]]:
|
||||
"""test parameters:
|
||||
* data must be a list of dictionaries. All keys must be strings.
|
||||
* matching_parameter is member of a list.
|
||||
@@ -21,27 +26,27 @@ def _keys_filter_params(data, matching_parameter):
|
||||
ml = ["equal", "starts_with", "ends_with", "regex"]
|
||||
|
||||
if not isinstance(data, Sequence):
|
||||
msg = "First argument must be a list. %s is %s"
|
||||
raise AnsibleFilterError(msg % (data, type(data)))
|
||||
msg = f"First argument must be a list. {data!r} is {type(data)}"
|
||||
raise AnsibleFilterError(msg)
|
||||
|
||||
for elem in data:
|
||||
if not isinstance(elem, Mapping):
|
||||
msg = "The data items must be dictionaries. %s is %s"
|
||||
raise AnsibleFilterError(msg % (elem, type(elem)))
|
||||
msg = f"The data items must be dictionaries. {elem} is {type(elem)}"
|
||||
raise AnsibleFilterError(msg)
|
||||
|
||||
for elem in data:
|
||||
if not all(isinstance(item, str) for item in elem.keys()):
|
||||
msg = "Top level keys must be strings. keys: %s"
|
||||
raise AnsibleFilterError(msg % elem.keys())
|
||||
msg = f"Top level keys must be strings. keys: {list(elem.keys())}"
|
||||
raise AnsibleFilterError(msg)
|
||||
|
||||
if mp not in ml:
|
||||
msg = "The matching_parameter must be one of %s. matching_parameter=%s"
|
||||
raise AnsibleFilterError(msg % (ml, mp))
|
||||
msg = f"The matching_parameter must be one of {ml}. matching_parameter={mp!r}"
|
||||
raise AnsibleFilterError(msg)
|
||||
|
||||
return
|
||||
return data, mp
|
||||
|
||||
|
||||
def _keys_filter_target_str(target, matching_parameter):
|
||||
def _keys_filter_target_str(target: t.Any, matching_parameter: t.Any) -> tuple[str, ...] | re.Pattern:
|
||||
"""
|
||||
Test:
|
||||
* target is a non-empty string or list.
|
||||
@@ -54,18 +59,18 @@ def _keys_filter_target_str(target, matching_parameter):
|
||||
"""
|
||||
|
||||
if not isinstance(target, Sequence):
|
||||
msg = "The target must be a string or a list. target is %s."
|
||||
raise AnsibleFilterError(msg % type(target))
|
||||
msg = f"The target must be a string or a list. target is {type(target)}."
|
||||
raise AnsibleFilterError(msg)
|
||||
|
||||
if len(target) == 0:
|
||||
msg = "The target can't be empty."
|
||||
raise AnsibleFilterError(msg)
|
||||
|
||||
if isinstance(target, list):
|
||||
if is_sequence(target):
|
||||
for elem in target:
|
||||
if not isinstance(elem, str):
|
||||
msg = "The target items must be strings. %s is %s"
|
||||
raise AnsibleFilterError(msg % (elem, type(elem)))
|
||||
msg = f"The target items must be strings. {elem!r} is {type(elem)}"
|
||||
raise AnsibleFilterError(msg)
|
||||
|
||||
if matching_parameter == "regex":
|
||||
if isinstance(target, str):
|
||||
@@ -77,19 +82,19 @@ def _keys_filter_target_str(target, matching_parameter):
|
||||
else:
|
||||
r = target[0]
|
||||
try:
|
||||
tt = re.compile(r)
|
||||
return re.compile(r)
|
||||
except re.error as e:
|
||||
msg = "The target must be a valid regex if matching_parameter=regex. target is %s"
|
||||
raise AnsibleFilterError(msg % r) from e
|
||||
msg = f"The target must be a valid regex if matching_parameter=regex. target is {r}"
|
||||
raise AnsibleFilterError(msg) from e
|
||||
elif isinstance(target, str):
|
||||
tt = (target,)
|
||||
return (target,)
|
||||
else:
|
||||
tt = tuple(set(target))
|
||||
|
||||
return tt
|
||||
return tuple(set(target))
|
||||
|
||||
|
||||
def _keys_filter_target_dict(target, matching_parameter):
|
||||
def _keys_filter_target_dict(
|
||||
target: t.Any, matching_parameter: t.Any
|
||||
) -> list[tuple[str, str]] | list[tuple[re.Pattern, str]]:
|
||||
"""
|
||||
Test:
|
||||
* target is a list of dictionaries with attributes 'after' and 'before'.
|
||||
@@ -101,8 +106,8 @@ def _keys_filter_target_dict(target, matching_parameter):
|
||||
"""
|
||||
|
||||
if not isinstance(target, list):
|
||||
msg = "The target must be a list. target is %s."
|
||||
raise AnsibleFilterError(msg % (target, type(target)))
|
||||
msg = f"The target must be a list. target is {target!r} of type {type(target)}."
|
||||
raise AnsibleFilterError(msg)
|
||||
|
||||
if len(target) == 0:
|
||||
msg = "The target can't be empty."
|
||||
@@ -110,25 +115,25 @@ def _keys_filter_target_dict(target, matching_parameter):
|
||||
|
||||
for elem in target:
|
||||
if not isinstance(elem, Mapping):
|
||||
msg = "The target items must be dictionaries. %s is %s"
|
||||
raise AnsibleFilterError(msg % (elem, type(elem)))
|
||||
msg = f"The target items must be dictionaries. {elem!r}%s is {type(elem)}"
|
||||
raise AnsibleFilterError(msg)
|
||||
if not all(k in elem for k in ("before", "after")):
|
||||
msg = "All dictionaries in target must include attributes: after, before."
|
||||
raise AnsibleFilterError(msg)
|
||||
if not isinstance(elem["before"], str):
|
||||
msg = "The attributes before must be strings. %s is %s"
|
||||
raise AnsibleFilterError(msg % (elem["before"], type(elem["before"])))
|
||||
msg = f"The attributes before must be strings. {elem['before']!r} is {type(elem['before'])}"
|
||||
raise AnsibleFilterError(msg)
|
||||
if not isinstance(elem["after"], str):
|
||||
msg = "The attributes after must be strings. %s is %s"
|
||||
raise AnsibleFilterError(msg % (elem["after"], type(elem["after"])))
|
||||
msg = f"The attributes after must be strings. {elem['after']!r} is {type(elem['after'])}"
|
||||
raise AnsibleFilterError(msg)
|
||||
|
||||
before = [d["before"] for d in target]
|
||||
after = [d["after"] for d in target]
|
||||
before: list[str] = [d["before"] for d in target]
|
||||
after: list[str] = [d["after"] for d in target]
|
||||
|
||||
if matching_parameter == "regex":
|
||||
try:
|
||||
tr = map(re.compile, before)
|
||||
tz = list(zip(tr, after))
|
||||
return list(zip(tr, after))
|
||||
except re.error as e:
|
||||
msg = (
|
||||
"The attributes before must be valid regex if matching_parameter=regex."
|
||||
@@ -136,6 +141,4 @@ def _keys_filter_target_dict(target, matching_parameter):
|
||||
)
|
||||
raise AnsibleFilterError(msg % before) from e
|
||||
else:
|
||||
tz = list(zip(before, after))
|
||||
|
||||
return tz
|
||||
return list(zip(before, after))
|
||||
|
||||
@@ -5,8 +5,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import typing as t
|
||||
|
||||
from collections.abc import Mapping, Set
|
||||
from collections.abc import Mapping, Sequence, Set
|
||||
from ansible.module_utils.common.collections import is_sequence
|
||||
from ansible.utils.unsafe_proxy import (
|
||||
AnsibleUnsafe,
|
||||
@@ -17,14 +18,54 @@ _RE_TEMPLATE_CHARS = re.compile("[{}]")
|
||||
_RE_TEMPLATE_CHARS_BYTES = re.compile(b"[{}]")
|
||||
|
||||
|
||||
def make_unsafe(value):
|
||||
@t.overload
|
||||
def make_unsafe(value: None) -> None: ...
|
||||
|
||||
|
||||
@t.overload
|
||||
def make_unsafe(value: Mapping) -> dict: ...
|
||||
|
||||
|
||||
@t.overload
|
||||
def make_unsafe(value: Set) -> set: ...
|
||||
|
||||
|
||||
@t.overload
|
||||
def make_unsafe(value: tuple) -> tuple: ...
|
||||
|
||||
|
||||
@t.overload
|
||||
def make_unsafe(value: list) -> list: ...
|
||||
|
||||
|
||||
@t.overload
|
||||
def make_unsafe(value: Sequence) -> Sequence: ...
|
||||
|
||||
|
||||
@t.overload
|
||||
def make_unsafe(value: str) -> str: ...
|
||||
|
||||
|
||||
@t.overload
|
||||
def make_unsafe(value: bool) -> bool: ...
|
||||
|
||||
|
||||
@t.overload
|
||||
def make_unsafe(value: int) -> int: ...
|
||||
|
||||
|
||||
@t.overload
|
||||
def make_unsafe(value: float) -> float: ...
|
||||
|
||||
|
||||
def make_unsafe(value: t.Any) -> t.Any:
|
||||
if value is None or isinstance(value, AnsibleUnsafe):
|
||||
return value
|
||||
|
||||
if isinstance(value, Mapping):
|
||||
return {make_unsafe(key): make_unsafe(val) for key, val in value.items()}
|
||||
elif isinstance(value, Set):
|
||||
return set(make_unsafe(elt) for elt in value)
|
||||
return {make_unsafe(elt) for elt in value}
|
||||
elif is_sequence(value):
|
||||
return type(value)(make_unsafe(elt) for elt in value)
|
||||
elif isinstance(value, bytes):
|
||||
|
||||
Reference in New Issue
Block a user