diff --git a/plugins/module_utils/android_sdkmanager.py b/plugins/module_utils/android_sdkmanager.py index 9fd51d044c..8396179f2d 100644 --- a/plugins/module_utils/android_sdkmanager.py +++ b/plugins/module_utils/android_sdkmanager.py @@ -20,7 +20,7 @@ __state_map = {"present": "--install", "absent": "--uninstall"} __channel_map = {"stable": 0, "beta": 1, "dev": 2, "canary": 3} -def __map_channel(channel_name): +def __map_channel(channel_name: str) -> int: if channel_name not in __channel_map: raise ValueError(f"Unknown channel name '{channel_name}'") return __channel_map[channel_name] @@ -45,18 +45,18 @@ def sdkmanager_runner(module: AnsibleModule, **kwargs) -> CmdRunner: class Package: - def __init__(self, name): + def __init__(self, name: str) -> None: self.name = name - def __hash__(self): + def __hash__(self) -> int: return hash(self.name) - def __ne__(self, other): + def __ne__(self, other: object) -> bool: if not isinstance(other, Package): return True return self.name != other.name - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, Package): return False @@ -86,17 +86,17 @@ class AndroidSdkManager: def __init__(self, module: AnsibleModule) -> None: self.runner = sdkmanager_runner(module) - def get_installed_packages(self): + def get_installed_packages(self) -> set[Package]: with self.runner("installed sdk_root channel") as ctx: rc, stdout, stderr = ctx.run() return self._parse_packages(stdout, self._RE_INSTALLED_PACKAGES_HEADER, self._RE_INSTALLED_PACKAGE) - def get_updatable_packages(self): + def get_updatable_packages(self) -> set[Package]: with self.runner("list newer sdk_root channel") as ctx: rc, stdout, stderr = ctx.run() return self._parse_packages(stdout, self._RE_UPDATABLE_PACKAGES_HEADER, self._RE_UPDATABLE_PACKAGE) - def apply_packages_changes(self, packages, accept_licenses=False): + def apply_packages_changes(self, packages: list[Package], accept_licenses: bool = False) -> tuple[int, str, str]: 
"""Install or delete packages, depending on the `module.vars.state` parameter""" if len(packages) == 0: return 0, "", "" @@ -118,7 +118,7 @@ class AndroidSdkManager: return rc, stdout, stderr return 0, "", "" - def _try_parse_stderr(self, stderr): + def _try_parse_stderr(self, stderr: str) -> None: data = stderr.splitlines() for line in data: unknown_package_regex = self._RE_UNKNOWN_PACKAGE.match(line) @@ -127,15 +127,15 @@ class AndroidSdkManager: raise SdkManagerException(f"Unknown package {package}") @staticmethod - def _parse_packages(stdout, header_regexp, row_regexp): + def _parse_packages(stdout: str, header_regexp: re.Pattern, row_regexp: re.Pattern) -> set[Package]: data = stdout.splitlines() - section_found = False + section_found: bool = False packages = set() for line in data: if not section_found: - section_found = header_regexp.match(line) + section_found = bool(header_regexp.match(line)) continue else: p = row_regexp.match(line) diff --git a/plugins/module_utils/btrfs.py b/plugins/module_utils/btrfs.py index 612e912ef9..ef5f53c870 100644 --- a/plugins/module_utils/btrfs.py +++ b/plugins/module_utils/btrfs.py @@ -72,37 +72,37 @@ class BtrfsCommands: command = f"{self.__btrfs} subvolume list -tap {filesystem_path}" result = self.__module.run_command(command, check_rc=True) stdout = [x.split("\t") for x in result[1].splitlines()] - subvolumes = [{"id": 5, "parent": None, "path": "/"}] + subvolumes: list[dict[str, t.Any]] = [{"id": 5, "parent": None, "path": "/"}] if len(stdout) > 2: subvolumes.extend([self.__parse_subvolume_list_record(x) for x in stdout[2:]]) return subvolumes - def __parse_subvolume_list_record(self, item): + def __parse_subvolume_list_record(self, item: list[str]) -> dict[str, t.Any]: return { "id": int(item[0]), "parent": int(item[2]), "path": normalize_subvolume_path(item[5]), } - def subvolume_get_default(self, filesystem_path): + def subvolume_get_default(self, filesystem_path: str) -> int: command = [self.__btrfs, "subvolume", 
"get-default", to_bytes(filesystem_path)] result = self.__module.run_command(command, check_rc=True) # ID [n] ... return int(result[1].strip().split()[1]) - def subvolume_set_default(self, filesystem_path, subvolume_id): + def subvolume_set_default(self, filesystem_path: str, subvolume_id: int) -> None: command = [self.__btrfs, "subvolume", "set-default", str(subvolume_id), to_bytes(filesystem_path)] self.__module.run_command(command, check_rc=True) - def subvolume_create(self, subvolume_path): + def subvolume_create(self, subvolume_path: str) -> None: command = [self.__btrfs, "subvolume", "create", to_bytes(subvolume_path)] self.__module.run_command(command, check_rc=True) - def subvolume_snapshot(self, snapshot_source, snapshot_destination): + def subvolume_snapshot(self, snapshot_source: str, snapshot_destination: str) -> None: command = [self.__btrfs, "subvolume", "snapshot", to_bytes(snapshot_source), to_bytes(snapshot_destination)] self.__module.run_command(command, check_rc=True) - def subvolume_delete(self, subvolume_path): + def subvolume_delete(self, subvolume_path: str) -> None: command = [self.__btrfs, "subvolume", "delete", to_bytes(subvolume_path)] self.__module.run_command(command, check_rc=True) @@ -117,7 +117,7 @@ class BtrfsInfoProvider: self.__btrfs_api = BtrfsCommands(module) self.__findmnt_path: str = self.__module.get_bin_path("findmnt", required=True) - def get_filesystems(self): + def get_filesystems(self) -> list[dict[str, t.Any]]: filesystems = self.__btrfs_api.filesystem_show() mountpoints = self.__find_mountpoints() for filesystem in filesystems: @@ -132,20 +132,22 @@ class BtrfsInfoProvider: return filesystems - def get_mountpoints(self, filesystem_devices): + def get_mountpoints(self, filesystem_devices: list[str]) -> list[dict[str, t.Any]]: mountpoints = self.__find_mountpoints() return self.__filter_mountpoints_for_devices(mountpoints, filesystem_devices) - def get_subvolumes(self, filesystem_path): + def get_subvolumes(self, 
filesystem_path) -> list[dict[str, t.Any]]: return self.__btrfs_api.subvolumes_list(filesystem_path) - def get_default_subvolume_id(self, filesystem_path): + def get_default_subvolume_id(self, filesystem_path) -> int: return self.__btrfs_api.subvolume_get_default(filesystem_path) - def __filter_mountpoints_for_devices(self, mountpoints, devices): + def __filter_mountpoints_for_devices( + self, mountpoints: list[dict[str, t.Any]], devices: list[str] + ) -> list[dict[str, t.Any]]: return [m for m in mountpoints if (m["device"] in devices)] - def __find_mountpoints(self): + def __find_mountpoints(self) -> list[dict[str, t.Any]]: command = f"{self.__findmnt_path} -t btrfs -nvP" result = self.__module.run_command(command) mountpoints = [] @@ -156,7 +158,7 @@ class BtrfsInfoProvider: mountpoints.append(mountpoint) return mountpoints - def __parse_mountpoint_pairs(self, line): + def __parse_mountpoint_pairs(self, line) -> dict[str, t.Any]: pattern = re.compile( r'^TARGET="(?P.*)"\s+SOURCE="(?P.*)"\s+FSTYPE="(?P.*)"\s+OPTIONS="(?P.*)"\s*$' ) @@ -170,13 +172,13 @@ class BtrfsInfoProvider: "subvolid": self.__extract_mount_subvolid(groups["options"]), } else: - raise BtrfsModuleException(f"Failed to parse findmnt result for line: '{line}'") + raise BtrfsModuleException(f"Failed to parse findmnt result for line: {line!r}") - def __extract_mount_subvolid(self, mount_options): + def __extract_mount_subvolid(self, mount_options: str) -> int: for option in mount_options.split(","): if option.startswith("subvolid="): return int(option[len("subvolid=") :]) - raise BtrfsModuleException(f"Failed to find subvolid for mountpoint in options '{mount_options}'") + raise BtrfsModuleException(f"Failed to find subvolid for mountpoint in options {mount_options!r}") class BtrfsSubvolume: @@ -184,39 +186,38 @@ class BtrfsSubvolume: Wrapper class providing convenience methods for inspection of a btrfs subvolume """ - def __init__(self, filesystem, subvolume_id): + def __init__(self, filesystem: 
BtrfsFilesystem, subvolume_id: int): self.__filesystem = filesystem self.__subvolume_id = subvolume_id - def get_filesystem(self): + def get_filesystem(self) -> BtrfsFilesystem: return self.__filesystem - def is_mounted(self): + def is_mounted(self) -> bool: mountpoints = self.get_mountpoints() return mountpoints is not None and len(mountpoints) > 0 - def is_filesystem_root(self): + def is_filesystem_root(self) -> bool: return self.__subvolume_id == 5 - def is_filesystem_default(self): + def is_filesystem_default(self) -> bool: return self.__filesystem.default_subvolid == self.__subvolume_id - def get_mounted_path(self): + def get_mounted_path(self) -> str | None: mountpoints = self.get_mountpoints() if mountpoints is not None and len(mountpoints) > 0: return mountpoints[0] - elif self.parent is not None: + if self.parent is not None: parent = self.__filesystem.get_subvolume_by_id(self.parent) - parent_path = parent.get_mounted_path() + parent_path = parent.get_mounted_path() if parent else None if parent_path is not None: - return parent_path + os.path.sep + self.name - else: - return None + return f"{parent_path}{os.path.sep}{self.name}" + return None - def get_mountpoints(self): + def get_mountpoints(self) -> list[str]: return self.__filesystem.get_mountpoints_by_subvolume_id(self.__subvolume_id) - def get_child_relative_path(self, absolute_child_path): + def get_child_relative_path(self, absolute_child_path: str) -> str: """ Get the relative path from this subvolume to the named child subvolume. The provided parameter is expected to be normalized as by normalize_subvolume_path. 
@@ -228,19 +229,21 @@ class BtrfsSubvolume: else: raise BtrfsModuleException(f"Path '{absolute_child_path}' doesn't start with '{path}'") - def get_parent_subvolume(self): + def get_parent_subvolume(self) -> BtrfsSubvolume | None: parent_id = self.parent return self.__filesystem.get_subvolume_by_id(parent_id) if parent_id is not None else None - def get_child_subvolumes(self): + def get_child_subvolumes(self) -> list[BtrfsSubvolume]: return self.__filesystem.get_subvolume_children(self.__subvolume_id) @property - def __info(self): - return self.__filesystem.get_subvolume_info_for_id(self.__subvolume_id) + def __info(self) -> dict[str, t.Any]: + result = self.__filesystem.get_subvolume_info_for_id(self.__subvolume_id) + # assert result is not None + return result # type: ignore @property - def id(self): + def id(self) -> int: return self.__subvolume_id @property @@ -248,7 +251,7 @@ class BtrfsSubvolume: return self.path.split("/").pop() @property - def path(self): + def path(self) -> str: return self.__info["path"] @property @@ -261,105 +264,105 @@ class BtrfsFilesystem: Wrapper class providing convenience methods for inspection of a btrfs filesystem """ - def __init__(self, info, provider, module: AnsibleModule) -> None: + def __init__(self, info: dict[str, t.Any], provider: BtrfsInfoProvider, module: AnsibleModule) -> None: self.__provider = provider # constant for module execution - self.__uuid = info["uuid"] - self.__label = info["label"] - self.__devices = info["devices"] + self.__uuid: str = info["uuid"] + self.__label: str = info["label"] + self.__devices: list[str] = info["devices"] # refreshable - self.__default_subvolid = info["default_subvolid"] if "default_subvolid" in info else None + self.__default_subvolid: int | None = info["default_subvolid"] if "default_subvolid" in info else None self.__update_mountpoints(info["mountpoints"] if "mountpoints" in info else []) self.__update_subvolumes(info["subvolumes"] if "subvolumes" in info else []) @property - 
def uuid(self): + def uuid(self) -> str: return self.__uuid @property - def label(self): + def label(self) -> str: return self.__label @property - def default_subvolid(self): + def default_subvolid(self) -> int | None: return self.__default_subvolid @property - def devices(self): + def devices(self) -> list[str]: return list(self.__devices) - def refresh(self): + def refresh(self) -> None: self.refresh_mountpoints() self.refresh_subvolumes() self.refresh_default_subvolume() - def refresh_mountpoints(self): + def refresh_mountpoints(self) -> None: mountpoints = self.__provider.get_mountpoints(list(self.__devices)) self.__update_mountpoints(mountpoints) - def __update_mountpoints(self, mountpoints): - self.__mountpoints = dict() + def __update_mountpoints(self, mountpoints: list[dict[str, t.Any]]) -> None: + self.__mountpoints: dict[int, list[str]] = dict() for i in mountpoints: - subvolid = i["subvolid"] - mountpoint = i["mountpoint"] + subvolid: int = i["subvolid"] + mountpoint: str = i["mountpoint"] if subvolid not in self.__mountpoints: self.__mountpoints[subvolid] = [] self.__mountpoints[subvolid].append(mountpoint) - def refresh_subvolumes(self): + def refresh_subvolumes(self) -> None: filesystem_path = self.get_any_mountpoint() if filesystem_path is not None: subvolumes = self.__provider.get_subvolumes(filesystem_path) self.__update_subvolumes(subvolumes) - def __update_subvolumes(self, subvolumes): + def __update_subvolumes(self, subvolumes: list[dict[str, t.Any]]) -> None: # TODO strategy for retaining information on deleted subvolumes? 
- self.__subvolumes = dict() + self.__subvolumes: dict[int, dict[str, t.Any]] = dict() for subvolume in subvolumes: self.__subvolumes[subvolume["id"]] = subvolume - def refresh_default_subvolume(self): + def refresh_default_subvolume(self) -> None: filesystem_path = self.get_any_mountpoint() if filesystem_path is not None: self.__default_subvolid = self.__provider.get_default_subvolume_id(filesystem_path) - def contains_device(self, device): + def contains_device(self, device: str) -> bool: return device in self.__devices - def contains_subvolume(self, subvolume): + def contains_subvolume(self, subvolume: str) -> bool: return self.get_subvolume_by_name(subvolume) is not None - def get_subvolume_by_id(self, subvolume_id): + def get_subvolume_by_id(self, subvolume_id: int) -> BtrfsSubvolume | None: return BtrfsSubvolume(self, subvolume_id) if subvolume_id in self.__subvolumes else None - def get_subvolume_info_for_id(self, subvolume_id): + def get_subvolume_info_for_id(self, subvolume_id: int) -> dict[str, t.Any] | None: return self.__subvolumes[subvolume_id] if subvolume_id in self.__subvolumes else None - def get_subvolume_by_name(self, subvolume): + def get_subvolume_by_name(self, subvolume: str) -> BtrfsSubvolume | None: for subvolume_info in self.__subvolumes.values(): if subvolume_info["path"] == subvolume: return BtrfsSubvolume(self, subvolume_info["id"]) return None - def get_any_mountpoint(self): + def get_any_mountpoint(self) -> str | None: for subvol_mountpoints in self.__mountpoints.values(): if len(subvol_mountpoints) > 0: return subvol_mountpoints[0] # maybe error? 
return None - def get_any_mounted_subvolume(self): + def get_any_mounted_subvolume(self) -> BtrfsSubvolume | None: for subvolid, subvol_mountpoints in self.__mountpoints.items(): if len(subvol_mountpoints) > 0: return self.get_subvolume_by_id(subvolid) return None - def get_mountpoints_by_subvolume_id(self, subvolume_id): + def get_mountpoints_by_subvolume_id(self, subvolume_id: int) -> list[str]: return self.__mountpoints[subvolume_id] if subvolume_id in self.__mountpoints else [] - def get_nearest_subvolume(self, subvolume): + def get_nearest_subvolume(self, subvolume: str) -> BtrfsSubvolume: """Return the identified subvolume if existing, else the closest matching parent""" subvolumes_by_path = self.__get_subvolumes_by_path() while len(subvolume) > 1: @@ -370,30 +373,31 @@ class BtrfsFilesystem: return BtrfsSubvolume(self, 5) - def get_mountpath_as_child(self, subvolume_name): + def get_mountpath_as_child(self, subvolume_name: str) -> str: """Find a path to the target subvolume through a mounted ancestor""" nearest = self.get_nearest_subvolume(subvolume_name) + nearest_or_none: BtrfsSubvolume | None = nearest if nearest.path == subvolume_name: - nearest = nearest.get_parent_subvolume() - if nearest is None or nearest.get_mounted_path() is None: + nearest_or_none = nearest.get_parent_subvolume() + if nearest_or_none is None or nearest_or_none.get_mounted_path() is None: raise BtrfsModuleException(f"Failed to find a path '{subvolume_name}' through a mounted parent subvolume") else: - return nearest.get_mounted_path() + os.path.sep + nearest.get_child_relative_path(subvolume_name) + return f"{nearest_or_none.get_mounted_path()}{os.path.sep}{nearest_or_none.get_child_relative_path(subvolume_name)}" - def get_subvolume_children(self, subvolume_id): + def get_subvolume_children(self, subvolume_id: int) -> list[BtrfsSubvolume]: return [BtrfsSubvolume(self, x["id"]) for x in self.__subvolumes.values() if x["parent"] == subvolume_id] - def __get_subvolumes_by_path(self): 
+ def __get_subvolumes_by_path(self) -> dict[str, dict[str, t.Any]]: result = {} for s in self.__subvolumes.values(): path = s["path"] result[path] = s return result - def is_mounted(self): + def is_mounted(self) -> bool: return self.__mountpoints is not None and len(self.__mountpoints) > 0 - def get_summary(self): + def get_summary(self) -> dict[str, t.Any]: subvolumes = [] sources = self.__subvolumes.values() if self.__subvolumes is not None else [] for subvolume in sources: @@ -424,14 +428,16 @@ class BtrfsFilesystemsProvider: def __init__(self, module: AnsibleModule) -> None: self.__module = module self.__provider = BtrfsInfoProvider(module) - self.__filesystems = None + self.__filesystems: dict[str, BtrfsFilesystem] | None = None - def get_matching_filesystem(self, criteria): + def get_matching_filesystem(self, criteria: dict[str, t.Any]) -> BtrfsFilesystem: if criteria["device"] is not None: criteria["device"] = os.path.realpath(criteria["device"]) self.__check_init() - matching = [f for f in self.__filesystems.values() if self.__filesystem_matches_criteria(f, criteria)] + # assert self.__filesystems is not None # TODO + self_filesystems: dict[str, BtrfsFilesystem] = self.__filesystems # type: ignore + matching = [f for f in self_filesystems.values() if self.__filesystem_matches_criteria(f, criteria)] if len(matching) == 1: return matching[0] else: @@ -439,26 +445,30 @@ class BtrfsFilesystemsProvider: f"Found {len(matching)} filesystems matching criteria uuid={criteria['uuid']} label={criteria['label']} device={criteria['device']}" ) - def __filesystem_matches_criteria(self, filesystem, criteria): + def __filesystem_matches_criteria(self, filesystem: BtrfsFilesystem, criteria: dict[str, t.Any]): return ( (criteria["uuid"] is None or filesystem.uuid == criteria["uuid"]) and (criteria["label"] is None or filesystem.label == criteria["label"]) and (criteria["device"] is None or filesystem.contains_device(criteria["device"])) ) - def 
get_filesystem_for_device(self, device): + def get_filesystem_for_device(self, device: str) -> BtrfsFilesystem | None: real_device = os.path.realpath(device) self.__check_init() - for fs in self.__filesystems.values(): + # assert self.__filesystems is not None # TODO + self_filesystems: dict[str, BtrfsFilesystem] = self.__filesystems # type: ignore + for fs in self_filesystems.values(): if fs.contains_device(real_device): return fs return None - def get_filesystems(self): + def get_filesystems(self) -> list[BtrfsFilesystem]: self.__check_init() - return list(self.__filesystems.values()) + # assert self.__filesystems is not None # TODO + self_filesystems: dict[str, BtrfsFilesystem] = self.__filesystems # type: ignore + return list(self_filesystems.values()) - def __check_init(self): + def __check_init(self) -> None: if self.__filesystems is None: self.__filesystems = dict() for f in self.__provider.get_filesystems(): diff --git a/plugins/module_utils/cmd_runner.py b/plugins/module_utils/cmd_runner.py index 419d6c276f..ead3c86f8b 100644 --- a/plugins/module_utils/cmd_runner.py +++ b/plugins/module_utils/cmd_runner.py @@ -12,7 +12,7 @@ from ansible.module_utils.common.locale import get_best_parsable_locale from ansible_collections.community.general.plugins.module_utils import cmd_runner_fmt if t.TYPE_CHECKING: - from collections.abc import Callable, Sequence + from collections.abc import Callable, Mapping, Sequence from ansible.module_utils.basic import AnsibleModule @@ -84,7 +84,7 @@ class CmdRunner: self, module: AnsibleModule, command, - arg_formats: dict[str, Callable] | None = None, + arg_formats: Mapping[str, Callable | cmd_runner_fmt._ArgFormat] | None = None, default_args_order: str | Sequence[str] = (), check_rc: bool = False, force_lang: str = "C", diff --git a/plugins/module_utils/cmd_runner_fmt.py b/plugins/module_utils/cmd_runner_fmt.py index 535a012947..0000cc81a8 100644 --- a/plugins/module_utils/cmd_runner_fmt.py +++ 
b/plugins/module_utils/cmd_runner_fmt.py @@ -11,36 +11,46 @@ from functools import wraps from ansible.module_utils.common.collections import is_sequence if t.TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Callable, Mapping, Sequence ArgFormatType = Callable[[t.Any], list[str]] + _T = t.TypeVar("_T") -def _ensure_list(value): - return list(value) if is_sequence(value) else [value] +def _ensure_list(value: _T | Sequence[_T]) -> list[_T]: + return list(value) if is_sequence(value) else [value] # type: ignore # TODO need type assertion for is_sequence class _ArgFormat: - def __init__(self, func, ignore_none=True, ignore_missing_value=False): + def __init__( + self, + func: Callable[[t.Any], Sequence[t.Any]], + ignore_none: bool | None = True, + ignore_missing_value: bool = False, + ) -> None: self.func = func self.ignore_none = ignore_none self.ignore_missing_value = ignore_missing_value - def __call__(self, value): + def __call__(self, value: t.Any | None) -> list[str]: ignore_none = self.ignore_none if self.ignore_none is not None else True if value is None and ignore_none: return [] f = self.func return [str(x) for x in f(value)] - def __str__(self): + def __str__(self) -> str: return f"" - def __repr__(self): + def __repr__(self) -> str: return str(self) -def as_bool(args_true, args_false=None, ignore_none=None): +def as_bool( + args_true: Sequence[t.Any] | t.Any, + args_false: Sequence[t.Any] | t.Any | None = None, + ignore_none: bool | None = None, +) -> _ArgFormat: if args_false is not None: if ignore_none is None: ignore_none = False @@ -51,24 +61,24 @@ def as_bool(args_true, args_false=None, ignore_none=None): ) -def as_bool_not(args): +def as_bool_not(args: Sequence[t.Any] | t.Any) -> _ArgFormat: return as_bool([], args, ignore_none=False) -def as_optval(arg, ignore_none=None): +def as_optval(arg, ignore_none: bool | None = None) -> _ArgFormat: return _ArgFormat(lambda value: [f"{arg}{value}"], ignore_none=ignore_none) 
-def as_opt_val(arg, ignore_none=None): +def as_opt_val(arg: str, ignore_none: bool | None = None) -> _ArgFormat: return _ArgFormat(lambda value: [arg, value], ignore_none=ignore_none) -def as_opt_eq_val(arg, ignore_none=None): +def as_opt_eq_val(arg: str, ignore_none: bool | None = None) -> _ArgFormat: return _ArgFormat(lambda value: [f"{arg}={value}"], ignore_none=ignore_none) -def as_list(ignore_none=None, min_len=0, max_len=None): - def func(value): +def as_list(ignore_none: bool | None = None, min_len: int = 0, max_len: int | None = None) -> _ArgFormat: + def func(value: t.Any) -> list[t.Any]: value = _ensure_list(value) if len(value) < min_len: raise ValueError(f"Parameter must have at least {min_len} element(s)") @@ -79,17 +89,21 @@ def as_list(ignore_none=None, min_len=0, max_len=None): return _ArgFormat(func, ignore_none=ignore_none) -def as_fixed(*args): +def as_fixed(*args: t.Any) -> _ArgFormat: if len(args) == 1 and is_sequence(args[0]): args = args[0] return _ArgFormat(lambda value: _ensure_list(args), ignore_none=False, ignore_missing_value=True) -def as_func(func, ignore_none=None): +def as_func(func: Callable[[t.Any], Sequence[t.Any]], ignore_none: bool | None = None) -> _ArgFormat: return _ArgFormat(func, ignore_none=ignore_none) -def as_map(_map, default=None, ignore_none=None): +def as_map( + _map: Mapping[t.Any, Sequence[t.Any] | t.Any], + default: Sequence[t.Any] | t.Any | None = None, + ignore_none: bool | None = None, +) -> _ArgFormat: if default is None: default = [] return _ArgFormat(lambda value: _ensure_list(_map.get(value, default)), ignore_none=ignore_none) @@ -126,5 +140,5 @@ def stack(fmt): return wrapper -def is_argformat(fmt): +def is_argformat(fmt: object) -> t.TypeGuard[_ArgFormat]: return isinstance(fmt, _ArgFormat) diff --git a/plugins/module_utils/deps.py b/plugins/module_utils/deps.py index 81f2ce4d39..31e6dd7f2f 100644 --- a/plugins/module_utils/deps.py +++ b/plugins/module_utils/deps.py @@ -16,51 +16,51 @@ if 
t.TYPE_CHECKING: from ansible.module_utils.basic import AnsibleModule -_deps = dict() +_deps: dict[str, _Dependency] = dict() class _Dependency: _states = ["pending", "failure", "success"] - def __init__(self, name, reason=None, url=None, msg=None): + def __init__(self, name: str, reason: str | None = None, url: str | None = None, msg: str | None = None) -> None: self.name = name self.reason = reason self.url = url self.msg = msg self.state = 0 - self.trace = None - self.exc = None + self.trace: str | None = None + self.exc: Exception | None = None - def succeed(self): + def succeed(self) -> None: self.state = 2 - def fail(self, exc, trace): + def fail(self, exc: Exception, trace: str) -> None: self.state = 1 self.exc = exc self.trace = trace @property - def message(self): + def message(self) -> str: if self.msg: return str(self.msg) else: return missing_required_lib(self.name, reason=self.reason, url=self.url) @property - def failed(self): + def failed(self) -> bool: return self.state == 1 def validate(self, module: AnsibleModule) -> None: if self.failed: module.fail_json(msg=self.message, exception=self.trace) - def __str__(self): + def __str__(self) -> str: return f"" @contextmanager -def declare(name, *args, **kwargs): +def declare(name: str, *args, **kwargs) -> t.Generator[_Dependency]: dep = _Dependency(name, *args, **kwargs) try: yield dep @@ -72,7 +72,7 @@ def declare(name, *args, **kwargs): _deps[name] = dep -def _select_names(spec): +def _select_names(spec: str | None) -> list[str]: dep_names = sorted(_deps) if spec: @@ -90,12 +90,12 @@ def _select_names(spec): return dep_names -def validate(module: AnsibleModule, spec=None) -> None: +def validate(module: AnsibleModule, spec: str | None = None) -> None: for dep in _select_names(spec): _deps[dep].validate(module) -def failed(spec=None) -> bool: +def failed(spec: str | None = None) -> bool: return any(_deps[d].failed for d in _select_names(spec)) diff --git a/plugins/module_utils/gio_mime.py 
b/plugins/module_utils/gio_mime.py index 09dcc75903..932a10e5ec 100644 --- a/plugins/module_utils/gio_mime.py +++ b/plugins/module_utils/gio_mime.py @@ -26,7 +26,7 @@ def gio_mime_runner(module: AnsibleModule, **kwargs) -> CmdRunner: ) -def gio_mime_get(runner, mime_type): +def gio_mime_get(runner: CmdRunner, mime_type): def process(rc, out, err): if err.startswith("No default applications for"): return None diff --git a/plugins/module_utils/gitlab.py b/plugins/module_utils/gitlab.py index 1753b82787..16528dce57 100644 --- a/plugins/module_utils/gitlab.py +++ b/plugins/module_utils/gitlab.py @@ -17,7 +17,7 @@ if t.TYPE_CHECKING: from ansible.module_utils.basic import AnsibleModule -def _determine_list_all_kwargs(version) -> dict[str, t.Any]: +def _determine_list_all_kwargs(version: str) -> dict[str, t.Any]: gitlab_version = LooseVersion(version) if gitlab_version >= LooseVersion("4.0.0"): # 4.0.0 removed 'as_list' @@ -43,7 +43,7 @@ except Exception: list_all_kwargs = {} -def auth_argument_spec(spec=None): +def auth_argument_spec(spec: dict[str, t.Any] | None = None) -> dict[str, t.Any]: arg_spec = dict( ca_path=dict(type="str"), api_token=dict(type="str", no_log=True), @@ -138,7 +138,7 @@ def gitlab_authentication(module: AnsibleModule, min_version=None) -> gitlab.Git return gitlab_instance -def filter_returned_variables(gitlab_variables): +def filter_returned_variables(gitlab_variables) -> list[dict[str, t.Any]]: # pop properties we don't know existing_variables = [dict(x.attributes) for x in gitlab_variables] KNOWN = [ @@ -159,9 +159,11 @@ def filter_returned_variables(gitlab_variables): return existing_variables -def vars_to_variables(vars, module: AnsibleModule): +def vars_to_variables( + vars: dict[str, str | int | float | dict[str, t.Any]], module: AnsibleModule +) -> list[dict[str, t.Any]]: # transform old vars to new variables structure - variables = list() + variables = [] for item, value in vars.items(): if isinstance(value, (str, int, float)): 
variables.append( diff --git a/plugins/module_utils/heroku.py b/plugins/module_utils/heroku.py index ec5c50fdad..1ac213872e 100644 --- a/plugins/module_utils/heroku.py +++ b/plugins/module_utils/heroku.py @@ -28,12 +28,12 @@ class HerokuHelper: self.check_lib() self.api_key = module.params["api_key"] - def check_lib(self): + def check_lib(self) -> None: if not HAS_HEROKU: self.module.fail_json(msg=missing_required_lib("heroku3"), exception=HEROKU_IMP_ERR) @staticmethod - def heroku_argument_spec(): + def heroku_argument_spec() -> dict[str, t.Any]: return dict( api_key=dict(fallback=(env_fallback, ["HEROKU_API_KEY", "TF_VAR_HEROKU_API_KEY"]), type="str", no_log=True) ) diff --git a/plugins/module_utils/hwc_utils.py b/plugins/module_utils/hwc_utils.py index 9a21b7269b..64ce4b59ea 100644 --- a/plugins/module_utils/hwc_utils.py +++ b/plugins/module_utils/hwc_utils.py @@ -25,37 +25,37 @@ from ansible.module_utils.common.text.converters import to_text class HwcModuleException(Exception): - def __init__(self, message): + def __init__(self, message: str) -> None: super().__init__() self._message = message - def __str__(self): + def __str__(self) -> str: return f"[HwcClientException] message={self._message}" class HwcClientException(Exception): - def __init__(self, code, message): + def __init__(self, code: int, message: str) -> None: super().__init__() self._code = code self._message = message - def __str__(self): + def __str__(self) -> str: msg = f" code={self._code}," if self._code != 0 else "" return f"[HwcClientException]{msg} message={self._message}" class HwcClientException404(HwcClientException): - def __init__(self, message): + def __init__(self, message: str) -> None: super().__init__(404, message) - def __str__(self): + def __str__(self) -> str: return f"[HwcClientException404] message={self._message}" def session_method_wrapper(f): - def _wrap(self, url, *args, **kwargs): + def _wrap(self, url: str, *args, **kwargs): try: url = self.endpoint + url r = f(self, 
url, *args, **kwargs) @@ -92,7 +92,7 @@ def session_method_wrapper(f): class _ServiceClient: - def __init__(self, client, endpoint, product): + def __init__(self, client, endpoint: str, product): self._client = client self._endpoint = endpoint self._default_header = { @@ -101,30 +101,30 @@ class _ServiceClient: } @property - def endpoint(self): + def endpoint(self) -> str: return self._endpoint @endpoint.setter - def endpoint(self, e): + def endpoint(self, e: str) -> None: self._endpoint = e @session_method_wrapper - def get(self, url, body=None, header=None, timeout=None): + def get(self, url: str, body=None, header: dict[str, t.Any] | None = None, timeout=None): return self._client.get(url, json=body, timeout=timeout, headers=self._header(header)) @session_method_wrapper - def post(self, url, body=None, header=None, timeout=None): + def post(self, url: str, body=None, header: dict[str, t.Any] | None = None, timeout=None): return self._client.post(url, json=body, timeout=timeout, headers=self._header(header)) @session_method_wrapper - def delete(self, url, body=None, header=None, timeout=None): + def delete(self, url: str, body=None, header: dict[str, t.Any] | None = None, timeout=None): return self._client.delete(url, json=body, timeout=timeout, headers=self._header(header)) @session_method_wrapper - def put(self, url, body=None, header=None, timeout=None): + def put(self, url: str, body=None, header: dict[str, t.Any] | None = None, timeout=None): return self._client.put(url, json=body, timeout=timeout, headers=self._header(header)) - def _header(self, header): + def _header(self, header: dict[str, t.Any] | None) -> dict[str, t.Any]: if header and isinstance(header, dict): for k, v in self._default_header.items(): if k not in header: @@ -136,7 +136,7 @@ class _ServiceClient: class Config: - def __init__(self, module: AnsibleModule, product): + def __init__(self, module: AnsibleModule, product) -> None: self._project_client = None self._domain_client = None 
self._module = module diff --git a/plugins/module_utils/ibm_sa_utils.py b/plugins/module_utils/ibm_sa_utils.py index a4108c7374..df339ec6fb 100644 --- a/plugins/module_utils/ibm_sa_utils.py +++ b/plugins/module_utils/ibm_sa_utils.py @@ -76,7 +76,7 @@ def connect_ssl(module: AnsibleModule): module.fail_json(msg=f"Connection with Spectrum Accelerate system has failed: {e}.") -def spectrum_accelerate_spec(): +def spectrum_accelerate_spec() -> dict[str, t.Any]: """Return arguments spec for AnsibleModule""" return dict( endpoints=dict(required=True), @@ -103,6 +103,6 @@ def build_pyxcli_command(fields): return pyxcli_args -def is_pyxcli_installed(module: AnsibleModule): +def is_pyxcli_installed(module: AnsibleModule) -> None: if not PYXCLI_INSTALLED: module.fail_json(msg=missing_required_lib("pyxcli"), exception=PYXCLI_IMP_ERR) diff --git a/plugins/module_utils/ilo_redfish_utils.py b/plugins/module_utils/ilo_redfish_utils.py index c76477d3e0..64f746989e 100644 --- a/plugins/module_utils/ilo_redfish_utils.py +++ b/plugins/module_utils/ilo_redfish_utils.py @@ -4,13 +4,15 @@ from __future__ import annotations -from ansible_collections.community.general.plugins.module_utils.redfish_utils import RedfishUtils import time +import typing as t + +from ansible_collections.community.general.plugins.module_utils.redfish_utils import RedfishUtils class iLORedfishUtils(RedfishUtils): - def get_ilo_sessions(self): - result = {} + def get_ilo_sessions(self) -> dict[str, t.Any]: + result: dict[str, t.Any] = {} # listing all users has always been slower than other operations, why? 
session_list = [] sessions_results = [] @@ -48,8 +50,8 @@ class iLORedfishUtils(RedfishUtils): result["ret"] = True return result - def set_ntp_server(self, mgr_attributes): - result = {} + def set_ntp_server(self, mgr_attributes) -> dict[str, t.Any]: + result: dict[str, t.Any] = {} setkey = mgr_attributes["mgr_attr_name"] nic_info = self.get_manager_ethernet_uri() @@ -60,7 +62,7 @@ class iLORedfishUtils(RedfishUtils): return response result["ret"] = True data = response["data"] - payload = {"DHCPv4": {"UseNTPServers": ""}} + payload: dict[str, t.Any] = {"DHCPv4": {"UseNTPServers": ""}} if data["DHCPv4"]["UseNTPServers"]: payload["DHCPv4"]["UseNTPServers"] = False @@ -97,7 +99,7 @@ class iLORedfishUtils(RedfishUtils): return {"ret": True, "changed": True, "msg": f"Modified {mgr_attributes['mgr_attr_name']}"} - def set_time_zone(self, attr): + def set_time_zone(self, attr) -> dict[str, t.Any]: key = attr["mgr_attr_name"] uri = f"{self.manager_uri}DateTime/" @@ -124,7 +126,7 @@ class iLORedfishUtils(RedfishUtils): return {"ret": True, "changed": True, "msg": f"Modified {attr['mgr_attr_name']}"} - def set_dns_server(self, attr): + def set_dns_server(self, attr) -> dict[str, t.Any]: key = attr["mgr_attr_name"] nic_info = self.get_manager_ethernet_uri() uri = nic_info["nic_addr"] @@ -148,7 +150,7 @@ class iLORedfishUtils(RedfishUtils): return {"ret": True, "changed": True, "msg": f"Modified {attr['mgr_attr_name']}"} - def set_domain_name(self, attr): + def set_domain_name(self, attr) -> dict[str, t.Any]: key = attr["mgr_attr_name"] nic_info = self.get_manager_ethernet_uri() @@ -160,7 +162,7 @@ class iLORedfishUtils(RedfishUtils): data = response["data"] - payload = {"DHCPv4": {"UseDomainName": ""}} + payload: dict[str, t.Any] = {"DHCPv4": {"UseDomainName": ""}} if data["DHCPv4"]["UseDomainName"]: payload["DHCPv4"]["UseDomainName"] = False @@ -185,7 +187,7 @@ class iLORedfishUtils(RedfishUtils): return response return {"ret": True, "changed": True, "msg": f"Modified 
{attr['mgr_attr_name']}"} - def set_wins_registration(self, mgrattr): + def set_wins_registration(self, mgrattr) -> dict[str, t.Any]: Key = mgrattr["mgr_attr_name"] nic_info = self.get_manager_ethernet_uri() @@ -198,7 +200,7 @@ class iLORedfishUtils(RedfishUtils): return response return {"ret": True, "changed": True, "msg": f"Modified {mgrattr['mgr_attr_name']}"} - def get_server_poststate(self): + def get_server_poststate(self) -> dict[str, t.Any]: # Get server details response = self.get_request(self.root_uri + self.systems_uri) if not response["ret"]: @@ -210,7 +212,7 @@ class iLORedfishUtils(RedfishUtils): else: return {"ret": True, "server_poststate": server_data["Oem"]["Hp"]["PostState"]} - def wait_for_ilo_reboot_completion(self, polling_interval=60, max_polling_time=1800): + def wait_for_ilo_reboot_completion(self, polling_interval=60, max_polling_time=1800) -> dict[str, t.Any]: # This method checks if OOB controller reboot is completed time.sleep(10) diff --git a/plugins/module_utils/influxdb.py b/plugins/module_utils/influxdb.py index a66c64af80..3a9677f7c4 100644 --- a/plugins/module_utils/influxdb.py +++ b/plugins/module_utils/influxdb.py @@ -47,7 +47,7 @@ class InfluxDb: self.password = self.params["password"] self.database_name = self.params.get("database_name") - def check_lib(self): + def check_lib(self) -> None: if not HAS_REQUESTS: self.module.fail_json(msg=missing_required_lib("requests"), exception=REQUESTS_IMP_ERR) @@ -55,7 +55,7 @@ class InfluxDb: self.module.fail_json(msg=missing_required_lib("influxdb"), exception=INFLUXDB_IMP_ERR) @staticmethod - def influxdb_argument_spec(): + def influxdb_argument_spec() -> dict[str, t.Any]: return dict( hostname=dict(type="str", default="localhost"), port=dict(type="int", default=8086), @@ -71,7 +71,7 @@ class InfluxDb: udp_port=dict(type="int", default=4444), ) - def connect_to_influxdb(self): + def connect_to_influxdb(self) -> InfluxDBClient: args = dict( host=self.hostname, port=self.port, diff --git 
a/plugins/module_utils/ipa.py b/plugins/module_utils/ipa.py index c7ef98604b..ed8c180f8c 100644 --- a/plugins/module_utils/ipa.py +++ b/plugins/module_utils/ipa.py @@ -27,7 +27,7 @@ if t.TYPE_CHECKING: from ansible.module_utils.basic import AnsibleModule -def _env_then_dns_fallback(*args, **kwargs): +def _env_then_dns_fallback(*args, **kwargs) -> str: """Load value from environment or DNS in that order""" try: result = env_fallback(*args, **kwargs) @@ -54,10 +54,10 @@ class IPAClient: self.timeout = module.params.get("ipa_timeout") self.use_gssapi = False - def get_base_url(self): + def get_base_url(self) -> str: return f"{self.protocol}://{self.host}/ipa" - def get_json_url(self): + def get_json_url(self) -> str: return f"{self.get_base_url()}/session/json" def login(self, username, password): @@ -102,7 +102,7 @@ class IPAClient: {"referer": self.get_base_url(), "Content-Type": "application/json", "Accept": "application/json"} ) - def _fail(self, msg, e): + def _fail(self, msg: str, e) -> t.NoReturn: if "message" in e: err_string = e.get("message") else: @@ -209,7 +209,7 @@ class IPAClient: return changed -def ipa_argument_spec(): +def ipa_argument_spec() -> dict[str, t.Any]: return dict( ipa_prot=dict(type="str", default="https", choices=["http", "https"], fallback=(env_fallback, ["IPA_PROT"])), ipa_host=dict(type="str", default="ipa.example.com", fallback=(_env_then_dns_fallback, ["IPA_HOST"])), diff --git a/plugins/module_utils/jenkins.py b/plugins/module_utils/jenkins.py index 9c9f16c969..810128dab8 100644 --- a/plugins/module_utils/jenkins.py +++ b/plugins/module_utils/jenkins.py @@ -10,7 +10,7 @@ import os import time -def download_updates_file(updates_expiration): +def download_updates_file(updates_expiration: int | float) -> tuple[str, bool]: updates_filename = "jenkins-plugin-cache.json" updates_dir = os.path.expanduser("~/.ansible/tmp") updates_file = os.path.join(updates_dir, updates_filename) diff --git a/plugins/module_utils/ldap.py 
b/plugins/module_utils/ldap.py index d226f9ab2b..41139d1882 100644 --- a/plugins/module_utils/ldap.py +++ b/plugins/module_utils/ldap.py @@ -30,7 +30,7 @@ except ImportError: HAS_LDAP = False -def gen_specs(**specs): +def gen_specs(**specs: t.Any) -> dict[str, t.Any]: specs.update( { "bind_dn": dict(), @@ -51,7 +51,7 @@ def gen_specs(**specs): return specs -def ldap_required_together(): +def ldap_required_together() -> list[list[str]]: return [["client_cert", "client_key"]] @@ -80,11 +80,11 @@ class LdapGeneric: else: self.dn = self.module.params["dn"] - def fail(self, msg, exn): + def fail(self, msg: str, exn: str | Exception) -> t.NoReturn: self.module.fail_json(msg=msg, details=f"{exn}", exception=traceback.format_exc()) - def _find_dn(self): - dn = self.module.params["dn"] + def _find_dn(self) -> str: + dn: str = self.module.params["dn"] explode_dn = ldap.dn.explode_dn(dn) @@ -134,7 +134,7 @@ class LdapGeneric: return connection - def _xorder_dn(self): + def _xorder_dn(self) -> bool: # match X_ORDERed DNs regex = r".+\{\d+\}.+" explode_dn = ldap.dn.explode_dn(self.module.params["dn"]) diff --git a/plugins/module_utils/lxd.py b/plugins/module_utils/lxd.py index 95644ff860..7513acd69e 100644 --- a/plugins/module_utils/lxd.py +++ b/plugins/module_utils/lxd.py @@ -6,10 +6,11 @@ from __future__ import annotations import http.client as http_client +import json import os import socket import ssl -import json +import typing as t from urllib.parse import urlparse from ansible.module_utils.urls import generic_urlparse @@ -20,7 +21,7 @@ HTTPSConnection = http_client.HTTPSConnection class UnixHTTPConnection(HTTPConnection): - def __init__(self, path): + def __init__(self, path: str) -> None: HTTPConnection.__init__(self, "localhost") self.path = path @@ -31,33 +32,34 @@ class UnixHTTPConnection(HTTPConnection): class LXDClientException(Exception): - def __init__(self, msg, **kwargs): + def __init__(self, msg: str, **kwargs) -> None: self.msg = msg self.kwargs = kwargs 
class LXDClient: def __init__( - self, url, key_file=None, cert_file=None, debug=False, server_cert_file=None, server_check_hostname=True - ): + self, + url: str, + key_file: str | None = None, + cert_file: str | None = None, + debug: bool = False, + server_cert_file: str | None = None, + server_check_hostname: bool = True, + ) -> None: """LXD Client. :param url: The URL of the LXD server. (e.g. unix:/var/lib/lxd/unix.socket or https://127.0.0.1) - :type url: ``str`` :param key_file: The path of the client certificate key file. - :type key_file: ``str`` :param cert_file: The path of the client certificate file. - :type cert_file: ``str`` :param debug: The debug flag. The request and response are stored in logs when debug is true. - :type debug: ``bool`` :param server_cert_file: The path of the server certificate file. - :type server_cert_file: ``str`` :param server_check_hostname: Whether to check the server's hostname as part of TLS verification. - :type debug: ``bool`` """ self.url = url self.debug = debug - self.logs = [] + self.logs: list[dict[str, t.Any]] = [] + self.connection: UnixHTTPConnection | HTTPSConnection if url.startswith("https:"): self.cert_file = cert_file self.key_file = key_file @@ -67,7 +69,7 @@ class LXDClient: # Check that the received cert is signed by the provided server_cert_file ctx.load_verify_locations(cafile=server_cert_file) ctx.check_hostname = server_check_hostname - ctx.load_cert_chain(cert_file, keyfile=key_file) + ctx.load_cert_chain(cert_file, keyfile=key_file) # type: ignore # TODO! 
self.connection = HTTPSConnection(parts.get("netloc"), context=ctx) elif url.startswith("unix:"): unix_socket_path = url[len("unix:") :] @@ -75,7 +77,7 @@ class LXDClient: else: raise LXDClientException("URL scheme must be unix: or https:") - def do(self, method, url, body_json=None, ok_error_codes=None, timeout=None, wait_for_container=None): + def do(self, method: str, url: str, body_json=None, ok_error_codes=None, timeout=None, wait_for_container=None): resp_json = self._send_request(method, url, body_json=body_json, ok_error_codes=ok_error_codes, timeout=timeout) if resp_json["type"] == "async": url = f"{resp_json['operation']}/wait" @@ -91,7 +93,7 @@ class LXDClient: body_json = {"type": "client", "password": trust_password} return self._send_request("POST", "/1.0/certificates", body_json=body_json) - def _send_request(self, method, url, body_json=None, ok_error_codes=None, timeout=None): + def _send_request(self, method: str, url: str, body_json=None, ok_error_codes=None, timeout=None): try: body = json.dumps(body_json) self.connection.request(method, url, body=body) @@ -133,9 +135,9 @@ class LXDClient: return err -def default_key_file(): +def default_key_file() -> str: return os.path.expanduser("~/.config/lxc/client.key") -def default_cert_file(): +def default_cert_file() -> str: return os.path.expanduser("~/.config/lxc/client.crt") diff --git a/plugins/module_utils/manageiq.py b/plugins/module_utils/manageiq.py index 5214251429..3936079f21 100644 --- a/plugins/module_utils/manageiq.py +++ b/plugins/module_utils/manageiq.py @@ -32,7 +32,7 @@ except ImportError: HAS_CLIENT = False -def manageiq_argument_spec(): +def manageiq_argument_spec() -> dict[str, t.Any]: options = dict( url=dict(default=os.environ.get("MIQ_URL", None)), username=dict(default=os.environ.get("MIQ_USERNAME", None)), @@ -68,7 +68,7 @@ def validate_connection_params(module: AnsibleModule) -> dict[str, t.Any]: raise AssertionError("should be unreachable") -def manageiq_entities(): +def 
manageiq_entities() -> dict[str, str]: return { "provider": "providers", "host": "hosts", @@ -125,7 +125,7 @@ class ManageIQ: return self._module @property - def api_url(self): + def api_url(self) -> str: """Base ManageIQ API Returns: diff --git a/plugins/module_utils/online.py b/plugins/module_utils/online.py index 95351d494c..adf5f66ae4 100644 --- a/plugins/module_utils/online.py +++ b/plugins/module_utils/online.py @@ -15,7 +15,7 @@ if t.TYPE_CHECKING: from ansible.module_utils.basic import AnsibleModule -def online_argument_spec(): +def online_argument_spec() -> dict[str, t.Any]: return dict( api_token=dict( required=True, @@ -32,7 +32,7 @@ def online_argument_spec(): class OnlineException(Exception): - def __init__(self, message): + def __init__(self, message: str) -> None: self.message = message diff --git a/plugins/module_utils/pkg_req.py b/plugins/module_utils/pkg_req.py index 360fb36fe8..06a44cd5cf 100644 --- a/plugins/module_utils/pkg_req.py +++ b/plugins/module_utils/pkg_req.py @@ -18,11 +18,11 @@ with deps.declare("packaging"): class PackageRequirement: - def __init__(self, module: AnsibleModule, name) -> None: + def __init__(self, module: AnsibleModule, name: str) -> None: self.module = module self.parsed_name, self.requirement = self._parse_spec(name) - def _parse_spec(self, name): + def _parse_spec(self, name: str) -> tuple[str, Requirement | None]: """ Parse a package name that may include version specifiers using PEP 508. Returns a tuple of (name, requirement) where requirement is of type packaging.requirements.Requirement and it may be None. @@ -54,7 +54,7 @@ class PackageRequirement: except Exception as e: raise ValueError(f"Invalid package specification for '{name}': {e}") from e - def matches_version(self, version): + def matches_version(self, version: str): """ Check if a version string fulfills a version specifier. 
diff --git a/plugins/module_utils/python_runner.py b/plugins/module_utils/python_runner.py index 12826add1c..c8780174db 100644 --- a/plugins/module_utils/python_runner.py +++ b/plugins/module_utils/python_runner.py @@ -20,20 +20,20 @@ class PythonRunner(CmdRunner): command, arg_formats=None, default_args_order=(), - check_rc=False, - force_lang="C", - path_prefix=None, - environ_update=None, - python="python", - venv=None, - ): + check_rc: bool = False, + force_lang: str = "C", + path_prefix: list[str] | None = None, + environ_update: dict[str, str] | None = None, + python: str = "python", + venv: str | None = None, + ) -> None: self.python = python self.venv = venv self.has_venv = venv is not None if os.path.isabs(python) or "/" in python: self.python = python - elif self.has_venv: + elif venv is not None: if path_prefix is None: path_prefix = [] path_prefix.append(os.path.join(venv, "bin")) diff --git a/plugins/module_utils/redfish_utils.py b/plugins/module_utils/redfish_utils.py index 80653d4c99..5904aa9eb4 100644 --- a/plugins/module_utils/redfish_utils.py +++ b/plugins/module_utils/redfish_utils.py @@ -55,15 +55,15 @@ REDFISH_COMMON_ARGUMENT_SPEC = { class RedfishUtils: def __init__( self, - creds, - root_uri, + creds: dict[str, str], + root_uri: str, timeout, module: AnsibleModule, resource_id=None, - data_modification=False, - strip_etag_quotes=False, - ciphers=None, - ): + data_modification: bool = False, + strip_etag_quotes: bool = False, + ciphers: str | None = None, + ) -> None: self.root_uri = root_uri self.creds = creds self.timeout = timeout @@ -79,7 +79,7 @@ class RedfishUtils: self.validate_certs = module.params.get("validate_certs", False) self.ca_path = module.params.get("ca_path") - def _auth_params(self, headers): + def _auth_params(self, headers: dict[str, str]) -> tuple[str | None, str | None, bool]: """ Return tuple of required authentication params based on the presence of a token in the self.creds dict. 
If using a token, set the @@ -157,7 +157,7 @@ class RedfishUtils: resp["msg"] = f"Properties in {uri} are already set" return resp - def _request(self, uri, **kwargs): + def _request(self, uri: str, **kwargs): kwargs.setdefault("validate_certs", self.validate_certs) kwargs.setdefault("follow_redirects", "all") kwargs.setdefault("use_proxy", True) @@ -169,7 +169,9 @@ class RedfishUtils: return resp, headers # The following functions are to send GET/POST/PATCH/DELETE requests - def get_request(self, uri, override_headers=None, allow_no_resp=False, timeout=None): + def get_request( + self, uri: str, override_headers: dict[str, str] | None = None, allow_no_resp: bool = False, timeout=None + ): req_headers = dict(GET_HEADERS) if override_headers: req_headers.update(override_headers) @@ -212,7 +214,7 @@ class RedfishUtils: return {"ret": False, "msg": f"Failed GET request to '{uri}': '{e}'"} return {"ret": True, "data": data, "headers": headers, "resp": resp} - def post_request(self, uri, pyld, multipart=False): + def post_request(self, uri: str, pyld, multipart: bool = False): req_headers = dict(POST_HEADERS) username, password, basic_auth = self._auth_params(req_headers) try: @@ -257,7 +259,7 @@ class RedfishUtils: return {"ret": False, "msg": f"Failed POST request to '{uri}': '{e}'"} return {"ret": True, "data": data, "headers": headers, "resp": resp} - def patch_request(self, uri, pyld, check_pyld=False): + def patch_request(self, uri: str, pyld, check_pyld: bool = False): req_headers = dict(PATCH_HEADERS) r = self.get_request(uri) if r["ret"]: @@ -309,7 +311,7 @@ class RedfishUtils: return {"ret": False, "changed": False, "msg": f"Failed PATCH request to '{uri}': '{e}'"} return {"ret": True, "changed": True, "resp": resp, "msg": f"Modified {uri}"} - def put_request(self, uri, pyld): + def put_request(self, uri: str, pyld): req_headers = dict(PUT_HEADERS) r = self.get_request(uri) if r["ret"]: @@ -347,7 +349,7 @@ class RedfishUtils: return {"ret": False, "msg": 
f"Failed PUT request to '{uri}': '{e}'"} return {"ret": True, "resp": resp} - def delete_request(self, uri, pyld=None): + def delete_request(self, uri: str, pyld=None): req_headers = dict(DELETE_HEADERS) username, password, basic_auth = self._auth_params(req_headers) try: diff --git a/plugins/module_utils/redis.py b/plugins/module_utils/redis.py index 7240cc320a..615dcd48cd 100644 --- a/plugins/module_utils/redis.py +++ b/plugins/module_utils/redis.py @@ -47,7 +47,7 @@ def fail_imports(module: AnsibleModule, needs_certifi: bool = True) -> None: module.fail_json(msg="\n".join(errors), traceback="\n".join(traceback)) -def redis_auth_argument_spec(tls_default=True): +def redis_auth_argument_spec(tls_default: bool = True) -> dict[str, t.Any]: return dict( login_host=dict( type="str", @@ -64,7 +64,7 @@ def redis_auth_argument_spec(tls_default=True): ) -def redis_auth_params(module: AnsibleModule): +def redis_auth_params(module: AnsibleModule) -> dict[str, t.Any]: login_host = module.params["login_host"] login_user = module.params["login_user"] login_password = module.params["login_password"] @@ -100,9 +100,8 @@ class RedisAnsible: self.module = module self.connection = self._connect() - def _connect(self): + def _connect(self) -> Redis: try: return Redis(**redis_auth_params(self.module)) except Exception as e: self.module.fail_json(msg=f"{e}") - return None diff --git a/plugins/module_utils/rundeck.py b/plugins/module_utils/rundeck.py index d089b77bf0..996ab698aa 100644 --- a/plugins/module_utils/rundeck.py +++ b/plugins/module_utils/rundeck.py @@ -14,7 +14,7 @@ if t.TYPE_CHECKING: from ansible.module_utils.basic import AnsibleModule -def api_argument_spec(): +def api_argument_spec() -> dict[str, t.Any]: """ Creates an argument spec that can be used with any module that will be requesting content via Rundeck API @@ -31,7 +31,13 @@ def api_argument_spec(): return api_argument_spec -def api_request(module: AnsibleModule, endpoint, data=None, method="GET", 
content_type="application/json"): +def api_request( + module: AnsibleModule, + endpoint: str, + data: t.Any | None = None, + method: str = "GET", + content_type: str = "application/json", +) -> tuple[t.Any, dict[str, t.Any]]: """Manages Rundeck API requests via HTTP(S) :arg module: The AnsibleModule (used to get url, api_version, api_token, etc). diff --git a/plugins/module_utils/scaleway.py b/plugins/module_utils/scaleway.py index 924ca9e537..c7663d7b16 100644 --- a/plugins/module_utils/scaleway.py +++ b/plugins/module_utils/scaleway.py @@ -21,6 +21,7 @@ from ansible_collections.community.general.plugins.module_utils.datetime import ) if t.TYPE_CHECKING: + from collections.abc import Iterable from ansible.module_utils.basic import AnsibleModule SCALEWAY_SECRET_IMP_ERR: str | None = None @@ -33,7 +34,7 @@ except Exception: HAS_SCALEWAY_SECRET_PACKAGE = False -def scaleway_argument_spec(): +def scaleway_argument_spec() -> dict[str, t.Any]: return dict( api_token=dict( required=True, @@ -63,7 +64,7 @@ def payload_from_object(scw_object): class ScalewayException(Exception): - def __init__(self, message): + def __init__(self, message: str) -> None: self.message = message @@ -74,7 +75,7 @@ R_LINK_HEADER = r"""<[^>]+>;\srel="(first|previous|next|last)" R_RELATION = r'<(?P<target_IP>[^>]+)>; rel="(?P<relation>first|previous|next|last)"' -def parse_pagination_link(header): +def parse_pagination_link(header: str) -> dict[str, str]: if not re.match(R_LINK_HEADER, header, re.VERBOSE): raise ScalewayException("Scaleway API answered with an invalid Link pagination header") else: @@ -90,7 +91,7 @@ def parse_pagination_link(header): return parsed_relations -def filter_sensitive_attributes(container, attributes): +def filter_sensitive_attributes(container: dict[str, t.Any], attributes: Iterable[str]) -> dict[str, t.Any]: """ WARNING: This function is effectively private, **do not use it**! It will be removed or renamed once changing its name no longer triggers a pylint bug. 
@@ -103,7 +104,7 @@ def filter_sensitive_attributes(container, attributes): class SecretVariables: @staticmethod - def ensure_scaleway_secret_package(module): + def ensure_scaleway_secret_package(module: AnsibleModule) -> None: if not HAS_SCALEWAY_SECRET_PACKAGE: module.fail_json( msg=missing_required_lib("passlib[argon2]", url="https://passlib.readthedocs.io/en/stable/"), @@ -228,8 +229,9 @@ class Scaleway: return Response(resp, info) @staticmethod - def get_user_agent_string(module): - return f"ansible {module.ansible_version} Python {sys.version.split(' ', 1)[0]}" + def get_user_agent_string(module: AnsibleModule) -> str: + ansible_version = module.ansible_version # type: ignore # For some reason this isn't documented in AnsibleModule + return f"ansible {ansible_version} Python {sys.version.split(' ', 1)[0]}" def get(self, path, data=None, headers=None, params=None): return self.send(method="GET", path=path, data=data, headers=headers, params=params) @@ -249,7 +251,7 @@ class Scaleway: def update(self, path, data=None, headers=None, params=None): return self.send(method="UPDATE", path=path, data=data, headers=headers, params=params) - def warn(self, x): + def warn(self, x) -> None: self.module.warn(str(x)) def fetch_state(self, resource): diff --git a/plugins/module_utils/snap.py b/plugins/module_utils/snap.py index 31f869c08b..f14b7a5315 100644 --- a/plugins/module_utils/snap.py +++ b/plugins/module_utils/snap.py @@ -52,7 +52,7 @@ def snap_runner(module: AnsibleModule, **kwargs) -> CmdRunner: return runner -def get_version(runner: CmdRunner): +def get_version(runner: CmdRunner) -> dict[str, list[str]]: with runner("version") as ctx: rc, out, err = ctx.run() return dict(x.split() for x in out.splitlines() if len(x.split()) == 2) diff --git a/plugins/module_utils/utm_utils.py b/plugins/module_utils/utm_utils.py index 07c3bac75f..443ee1a178 100644 --- a/plugins/module_utils/utm_utils.py +++ b/plugins/module_utils/utm_utils.py @@ -12,18 +12,19 @@ from __future__ 
import annotations import json +import typing as t from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.urls import fetch_url class UTMModuleConfigurationError(Exception): - def __init__(self, msg, **args): + def __init__(self, msg: str, **args): super().__init__(self, msg) self.msg = msg self.module_fail_args = args - def do_fail(self, module): + def do_fail(self, module: AnsibleModule) -> t.NoReturn: module.fail_json(msg=self.msg, other=self.module_fail_args) diff --git a/plugins/module_utils/vardict.py b/plugins/module_utils/vardict.py index 195ec4d847..7ada2021e0 100644 --- a/plugins/module_utils/vardict.py +++ b/plugins/module_utils/vardict.py @@ -6,33 +6,41 @@ from __future__ import annotations import copy +import typing as t class _Variable: NOTHING = object() - def __init__(self, diff=False, output=True, change=None, fact=False, verbosity=0): + def __init__( + self, + diff: bool = False, + output: bool = True, + change: bool | None = None, + fact: bool = False, + verbosity: int = 0, + ): self.init = False - self.initial_value = None - self.value = None + self.initial_value: t.Any = None + self.value: t.Any = None - self.diff = None - self._change = None - self.output = None - self.fact = None - self._verbosity = None + self.diff: bool = None # type: ignore # will be changed in set_meta() call + self._change: bool | None = None + self.output: bool = None # type: ignore # will be changed in set_meta() call + self.fact: bool = None # type: ignore # will be changed in set_meta() call + self._verbosity: int = None # type: ignore # will be changed in set_meta() call self.set_meta(output=output, diff=diff, change=change, fact=fact, verbosity=verbosity) - def getchange(self): + def getchange(self) -> bool: return self.diff if self._change is None else self._change - def setchange(self, value): + def setchange(self, value: bool | None) -> None: self._change = value - def getverbosity(self): + def getverbosity(self) -> int: return 
self._verbosity - def setverbosity(self, v): + def setverbosity(self, v: int) -> None: if not (0 <= v <= 4): raise ValueError("verbosity must be an int in the range 0 to 4") self._verbosity = v @@ -40,7 +48,15 @@ class _Variable: change = property(getchange, setchange) verbosity = property(getverbosity, setverbosity) - def set_meta(self, output=None, diff=None, change=None, fact=None, initial_value=NOTHING, verbosity=None): + def set_meta( + self, + output: bool | None = None, + diff: bool | None = None, + change: bool | None = None, + fact: bool | None = None, + initial_value: t.Any = NOTHING, + verbosity: int | None = None, + ) -> None: """Set the metadata for the variable Args: @@ -64,7 +80,7 @@ class _Variable: if verbosity is not None: self.verbosity = verbosity - def as_dict(self, meta_only=False): + def as_dict(self, meta_only: bool = False) -> dict[str, t.Any]: d = { "diff": self.diff, "change": self.change, @@ -77,27 +93,27 @@ class _Variable: d["value"] = self.value return d - def set_value(self, value): + def set_value(self, value: t.Any) -> t.Self: if not self.init: self.initial_value = copy.deepcopy(value) self.init = True self.value = value return self - def is_visible(self, verbosity): + def is_visible(self, verbosity: int) -> bool: return self.verbosity <= verbosity @property - def has_changed(self): + def has_changed(self) -> bool: return self.change and (self.initial_value != self.value) @property - def diff_result(self): + def diff_result(self) -> dict[str, t.Any] | None: if self.diff and self.has_changed: return {"before": self.initial_value, "after": self.value} - return + return None - def __str__(self): + def __str__(self) -> str: return ( f"" @@ -119,34 +135,34 @@ class VarDict: "as_dict", ) - def __init__(self): - self.__vars__ = dict() + def __init__(self) -> None: + self.__vars__: dict[str, _Variable] = dict() - def __getitem__(self, item): + def __getitem__(self, item: str): return self.__vars__[item].value - def __setitem__(self, key, 
value): + def __setitem__(self, key: str, value) -> None: + self.set(key, value) - def __getattr__(self, item): + def __getattr__(self, item: str): try: return self.__vars__[item].value except KeyError: return getattr(super(), item) - def __setattr__(self, key, value): + def __setattr__(self, key: str, value) -> None: if key == "__vars__": super().__setattr__(key, value) else: self.set(key, value) - def _var(self, name): + def _var(self, name: str) -> _Variable: return self.__vars__[name] - def var(self, name): + def var(self, name: str) -> dict[str, t.Any]: return self._var(name).as_dict() - def set_meta(self, name, **kwargs): + def set_meta(self, name: str, **kwargs): """Set the metadata for the variable Args: @@ -160,10 +176,10 @@ class VarDict: """ self._var(name).set_meta(**kwargs) - def get_meta(self, name): + def get_meta(self, name: str) -> dict[str, t.Any]: return self._var(name).as_dict(meta_only=True) - def set(self, name, value, **kwargs): + def set(self, name: str, value, **kwargs) -> None: """Set the value and optionally metadata for a variable. The variable is not required to exist prior to calling `set`. For details on the accepted metadata see the documentation for method `set_meta`. 
@@ -185,10 +201,10 @@ class VarDict: var.set_value(value) self.__vars__[name] = var - def output(self, verbosity=0): + def output(self, verbosity: int = 0) -> dict[str, t.Any]: return {n: v.value for n, v in self.__vars__.items() if v.output and v.is_visible(verbosity)} - def diff(self, verbosity=0): + def diff(self, verbosity: int = 0) -> dict[str, t.Any] | None: diff_results = [ (n, v.diff_result) for n, v in self.__vars__.items() if v.diff_result and v.is_visible(verbosity) ] @@ -198,13 +214,13 @@ class VarDict: return {"before": before, "after": after} return None - def facts(self, verbosity=0): + def facts(self, verbosity: int = 0) -> dict[str, t.Any] | None: facts_result = {n: v.value for n, v in self.__vars__.items() if v.fact and v.is_visible(verbosity)} return facts_result if facts_result else None @property - def has_changed(self): + def has_changed(self) -> bool: return any(var.has_changed for var in self.__vars__.values()) - def as_dict(self): + def as_dict(self) -> dict[str, t.Any]: return {name: var.value for name, var in self.__vars__.items()} diff --git a/plugins/module_utils/vexata.py b/plugins/module_utils/vexata.py index d6d81cf64e..b55b67f113 100644 --- a/plugins/module_utils/vexata.py +++ b/plugins/module_utils/vexata.py @@ -22,7 +22,7 @@ if t.TYPE_CHECKING: VXOS_VERSION = None -def get_version(iocs_json): +def get_version(iocs_json) -> tuple[int, ...]: if not iocs_json: raise Exception("Invalid IOC json") active = next((x for x in iocs_json if x["mgmtRole"]), None) @@ -65,7 +65,7 @@ def get_array(module: AnsibleModule): module.fail_json(msg=f"Vexata API access failed: {e}") -def argument_spec(): +def argument_spec() -> dict[str, t.Any]: """Return standard base dictionary used for the argument_spec argument in AnsibleModule""" return dict( array=dict(type="str", required=True), @@ -75,20 +75,20 @@ def argument_spec(): ) -def required_together(): +def required_together() -> list[list[str]]: """Return the default list used for the 
required_together argument to AnsibleModule""" return [["user", "password"]] -def size_to_MiB(size): +def size_to_MiB(size: str) -> int: """Convert a '[MGT]' string to MiB, return -1 on error.""" quant = size[:-1] exponent = size[-1] if not quant.isdigit() or exponent not in "MGT": return -1 - quant = int(quant) + quant_int = int(quant) if exponent == "G": - quant <<= 10 + quant_int <<= 10 elif exponent == "T": - quant <<= 20 - return quant + quant_int <<= 20 + return quant_int diff --git a/plugins/module_utils/wdc_redfish_utils.py b/plugins/module_utils/wdc_redfish_utils.py index 39a055652c..56ec7537d6 100644 --- a/plugins/module_utils/wdc_redfish_utils.py +++ b/plugins/module_utils/wdc_redfish_utils.py @@ -45,7 +45,7 @@ class WdcRedfishUtils(RedfishUtils): CHASSIS_LOCATE = "#Chassis.Locate" CHASSIS_POWER_MODE = "#Chassis.PowerMode" - def __init__(self, creds, root_uris, timeout, module: AnsibleModule, resource_id, data_modification): + def __init__(self, creds, root_uris, timeout, module: AnsibleModule, resource_id, data_modification) -> None: super().__init__( creds=creds, root_uri=root_uris[0], diff --git a/plugins/module_utils/xenserver.py b/plugins/module_utils/xenserver.py index 24368b10fd..bf9b4f0a9b 100644 --- a/plugins/module_utils/xenserver.py +++ b/plugins/module_utils/xenserver.py @@ -27,7 +27,7 @@ if t.TYPE_CHECKING: from ansible.module_utils.basic import AnsibleModule -def xenserver_common_argument_spec(): +def xenserver_common_argument_spec() -> dict[str, t.Any]: return dict( hostname=dict( type="str", @@ -45,7 +45,7 @@ def xenserver_common_argument_spec(): ) -def xapi_to_module_vm_power_state(power_state): +def xapi_to_module_vm_power_state(power_state: str) -> str | None: """Maps XAPI VM power states to module VM power states.""" module_power_state_map = { "running": "poweredon", @@ -57,7 +57,7 @@ def xapi_to_module_vm_power_state(power_state): return module_power_state_map.get(power_state) -def module_to_xapi_vm_power_state(power_state): +def 
module_to_xapi_vm_power_state(power_state: str) -> str | None: """Maps module VM power states to XAPI VM power states.""" vm_power_state_map = { "poweredon": "running", @@ -71,7 +71,7 @@ def module_to_xapi_vm_power_state(power_state): return vm_power_state_map.get(power_state) -def is_valid_ip_addr(ip_addr): +def is_valid_ip_addr(ip_addr: str) -> bool: """Validates given string as IPv4 address for given string. Args: @@ -97,7 +97,7 @@ def is_valid_ip_addr(ip_addr): return True -def is_valid_ip_netmask(ip_netmask): +def is_valid_ip_netmask(ip_netmask: str) -> bool: """Validates given string as IPv4 netmask. Args: @@ -129,7 +129,7 @@ def is_valid_ip_netmask(ip_netmask): return True -def is_valid_ip_prefix(ip_prefix): +def is_valid_ip_prefix(ip_prefix: str) -> bool: """Validates given string as IPv4 prefix. Args: @@ -146,7 +146,7 @@ def is_valid_ip_prefix(ip_prefix): return not (ip_prefix_int < 0 or ip_prefix_int > 32) -def ip_prefix_to_netmask(ip_prefix, skip_check=False): +def ip_prefix_to_netmask(ip_prefix: str, skip_check: bool = False) -> str: """Converts IPv4 prefix to netmask. Args: @@ -169,7 +169,7 @@ def ip_prefix_to_netmask(ip_prefix, skip_check=False): return "" -def ip_netmask_to_prefix(ip_netmask, skip_check=False): +def ip_netmask_to_prefix(ip_netmask: str, skip_check: bool = False) -> str: """Converts IPv4 netmask to prefix. Args: @@ -192,7 +192,7 @@ def ip_netmask_to_prefix(ip_netmask, skip_check=False): return "" -def is_valid_ip6_addr(ip6_addr): +def is_valid_ip6_addr(ip6_addr: str) -> bool: """Validates given string as IPv6 address. Args: @@ -226,7 +226,7 @@ def is_valid_ip6_addr(ip6_addr): return all(ip6_addr_hextet_regex.match(ip6_addr_hextet) for ip6_addr_hextet in ip6_addr_split) -def is_valid_ip6_prefix(ip6_prefix): +def is_valid_ip6_prefix(ip6_prefix: str) -> bool: """Validates given string as IPv6 prefix. Args: