Make all module_utils and plugin_utils private (#887)

* Add leading underscore. Remove deprecated module utils.

* Document module and plugin utils as private. Add changelog fragment.

* Convert relative to absolute imports.

* Remove unnecessary imports.
This commit is contained in:
Felix Fontein
2025-05-11 19:17:58 +02:00
committed by GitHub
parent f758d94fba
commit a5a4e022ba
146 changed files with 678 additions and 465 deletions

View File

@@ -0,0 +1,456 @@
# Copyright (c) 2021, Andrew Pantuso (@ajpantuso) <ajpantuso@gmail.com>
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
import abc
import os
import stat
import traceback
import typing as t
from ansible_collections.community.crypto.plugins.module_utils._openssh.utils import (
parse_openssh_version,
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils._openssh.certificate import (
OpensshCertificateTimeParameters,
)
from cryptography.hazmat.primitives.asymmetric.types import (
CertificateIssuerPrivateKeyTypes,
PrivateKeyTypes,
)
Param = t.ParamSpec("Param")
def restore_on_failure(
f: t.Callable[t.Concatenate[AnsibleModule, str | os.PathLike, Param], None],
) -> t.Callable[t.Concatenate[AnsibleModule, str | os.PathLike, Param], None]:
def backup_and_restore(
module: AnsibleModule, path: str | os.PathLike, *args, **kwargs
) -> None:
backup_file = module.backup_local(path) if os.path.exists(path) else None
try:
f(module, path, *args, **kwargs)
except Exception:
if backup_file is not None:
module.atomic_move(os.path.abspath(backup_file), os.path.abspath(path))
raise
else:
module.add_cleanup_file(backup_file)
return backup_and_restore
@restore_on_failure
def safe_atomic_move(
module: AnsibleModule, path: str | os.PathLike, destination: str | os.PathLike
) -> None:
module.atomic_move(os.path.abspath(path), os.path.abspath(destination))
def _restore_all_on_failure(
f: t.Callable[
t.Concatenate[
OpensshModule, list[tuple[str | os.PathLike, str | os.PathLike]], Param
],
None,
],
) -> t.Callable[
t.Concatenate[
OpensshModule, list[tuple[str | os.PathLike, str | os.PathLike]], Param
],
None,
]:
def backup_and_restore(
self: OpensshModule,
sources_and_destinations: list[tuple[str | os.PathLike, str | os.PathLike]],
*args,
**kwargs,
) -> None:
backups = [
(d, self.module.backup_local(d))
for s, d in sources_and_destinations
if os.path.exists(d)
]
try:
f(self, sources_and_destinations, *args, **kwargs)
except Exception:
for destination, backup in backups:
self.module.atomic_move(
os.path.abspath(backup), os.path.abspath(destination)
)
raise
else:
for destination, backup in backups:
self.module.add_cleanup_file(backup)
return backup_and_restore
class OpensshModule(metaclass=abc.ABCMeta):
def __init__(self, module: AnsibleModule) -> None:
self.module = module
self.changed: bool = False
self.check_mode: bool = self.module.check_mode
def execute(self) -> t.NoReturn:
try:
self._execute()
except Exception as e:
self.module.fail_json(
msg=f"unexpected error occurred: {e}",
exception=traceback.format_exc(),
)
self.module.exit_json(**self.result)
@abc.abstractmethod
def _execute(self) -> None:
pass
@property
def result(self) -> dict[str, t.Any]:
result = self._result
result["changed"] = self.changed
if self.module._diff:
result["diff"] = self.diff
return result
@property
@abc.abstractmethod
def _result(self) -> dict[str, t.Any]:
pass
@property
@abc.abstractmethod
def diff(self) -> dict[str, t.Any]:
pass
@staticmethod
def skip_if_check_mode(f: t.Callable[Param, None]) -> t.Callable[Param, None]:
def wrapper(self, *args, **kwargs) -> None:
if not self.check_mode:
f(self, *args, **kwargs)
return wrapper # type: ignore
@staticmethod
def trigger_change(f: t.Callable[Param, None]) -> t.Callable[Param, None]:
def wrapper(self, *args, **kwargs) -> None:
f(self, *args, **kwargs)
self.changed = True
return wrapper # type: ignore
def _check_if_base_dir(self, path: str | os.PathLike) -> None:
base_dir = os.path.dirname(path) or "."
if not os.path.isdir(base_dir):
self.module.fail_json(
name=base_dir,
msg=f"The directory {base_dir} does not exist or the file is not a directory",
)
def _get_ssh_version(self) -> str | None:
ssh_bin = self.module.get_bin_path("ssh")
if not ssh_bin:
return None
return parse_openssh_version(
self.module.run_command([ssh_bin, "-V", "-q"], check_rc=True)[2].strip()
)
@_restore_all_on_failure
def _safe_secure_move(
self,
sources_and_destinations: list[tuple[str | os.PathLike, str | os.PathLike]],
) -> None:
"""Moves a list of files from 'source' to 'destination' and restores 'destination' from backup upon failure.
If 'destination' does not already exist, then 'source' permissions are preserved to prevent
exposing protected data ('atomic_move' uses the 'destination' base directory mask for
permissions if 'destination' does not already exists).
"""
for source, destination in sources_and_destinations:
if os.path.exists(destination):
self.module.atomic_move(
os.path.abspath(source), os.path.abspath(destination)
)
else:
self.module.preserved_copy(source, destination)
def _update_permissions(self, path: str | os.PathLike) -> None:
file_args = self.module.load_file_common_arguments(self.module.params)
file_args["path"] = path
if not self.module.check_file_absent_if_check_mode(path):
self.changed = self.module.set_fs_attributes_if_different(
file_args, self.changed
)
else:
self.changed = True
class KeygenCommand:
def __init__(self, module: AnsibleModule) -> None:
self._bin_path = module.get_bin_path("ssh-keygen", True)
self._run_command = module.run_command
def generate_certificate(
self,
certificate_path: str,
identifier: str,
options: list[str] | None,
pkcs11_provider: str | None,
principals: list[str] | None,
serial_number: int | None,
signature_algorithm: str | None,
signing_key_path: str,
type: t.Literal["host", "user"] | None,
time_parameters: OpensshCertificateTimeParameters,
use_agent: bool,
**kwargs,
) -> tuple[int, str, str]:
args = [self._bin_path, "-s", signing_key_path, "-P", "", "-I", identifier]
if options:
for option in options:
args.extend(["-O", option])
if pkcs11_provider:
args.extend(["-D", pkcs11_provider])
if principals:
args.extend(["-n", ",".join(principals)])
if serial_number is not None:
args.extend(["-z", str(serial_number)])
if type == "host":
args.extend(["-h"])
if use_agent:
args.extend(["-U"])
if time_parameters.validity_string:
args.extend(["-V", time_parameters.validity_string])
if signature_algorithm:
args.extend(["-t", signature_algorithm])
args.append(certificate_path)
return self._run_command(args, **kwargs)
def generate_keypair(
self, private_key_path: str, size: int, type: str, comment: str | None, **kwargs
) -> tuple[int, str, str]:
args = [
self._bin_path,
"-q",
"-N",
"",
"-b",
str(size),
"-t",
type,
"-f",
private_key_path,
"-C",
comment or "",
]
# "y" must be entered in response to the "overwrite" prompt
data = "y" if os.path.exists(private_key_path) else None
return self._run_command(args, data=data, **kwargs)
def get_certificate_info(
self, certificate_path: str, **kwargs
) -> tuple[int, str, str]:
return self._run_command(
[self._bin_path, "-L", "-f", certificate_path], **kwargs
)
def get_matching_public_key(
self, private_key_path: str, **kwargs
) -> tuple[int, str, str]:
return self._run_command(
[self._bin_path, "-P", "", "-y", "-f", private_key_path], **kwargs
)
def get_private_key(self, private_key_path: str, **kwargs) -> tuple[int, str, str]:
return self._run_command(
[self._bin_path, "-l", "-f", private_key_path], **kwargs
)
def update_comment(
self,
private_key_path: str,
comment: str,
force_new_format: bool = True,
**kwargs,
) -> tuple[int, str, str]:
if os.path.exists(private_key_path) and not os.access(
private_key_path, os.W_OK
):
try:
os.chmod(private_key_path, stat.S_IWUSR + stat.S_IRUSR)
except (IOError, OSError) as e:
raise ValueError(
f"The private key at {private_key_path} is not writeable preventing a comment update ({e})"
)
command = [self._bin_path, "-q"]
if force_new_format:
command.append("-o")
command.extend(["-c", "-C", comment, "-f", private_key_path])
return self._run_command(command, **kwargs)
_PrivateKey = t.TypeVar("_PrivateKey", bound="PrivateKey")
class PrivateKey:
def __init__(
self, size: int, key_type: str, fingerprint: str, format: str = ""
) -> None:
self._size = size
self._type = key_type
self._fingerprint = fingerprint
self._format = format
@property
def size(self) -> int:
return self._size
@property
def type(self) -> str:
return self._type
@property
def fingerprint(self) -> str:
return self._fingerprint
@property
def format(self) -> str:
return self._format
@classmethod
def from_string(cls: t.Type[_PrivateKey], string: str) -> _PrivateKey:
properties = string.split()
return cls(
size=int(properties[0]),
key_type=properties[-1][1:-1].lower(),
fingerprint=properties[1],
)
def to_dict(self) -> dict[str, t.Any]:
return {
"size": self._size,
"type": self._type,
"fingerprint": self._fingerprint,
"format": self._format,
}
_PublicKey = t.TypeVar("_PublicKey", bound="PublicKey")
class PublicKey:
def __init__(self, type_string: str, data: str, comment: str | None) -> None:
self._type_string = type_string
self._data = data
self._comment = comment
def __eq__(self, other: object) -> bool:
if not isinstance(other, type(self)):
return NotImplemented
return all(
[
self._type_string == other._type_string,
self._data == other._data,
(
(self._comment == other._comment)
if self._comment is not None and other._comment is not None
else True
),
]
)
def __ne__(self, other: object) -> bool:
return not self == other
def __str__(self) -> str:
return f"{self._type_string} {self._data}"
@property
def comment(self) -> str | None:
return self._comment
@comment.setter
def comment(self, value: str | None) -> None:
self._comment = value
@property
def data(self) -> str:
return self._data
@property
def type_string(self) -> str:
return self._type_string
@classmethod
def from_string(cls: t.Type[_PublicKey], string: str) -> _PublicKey:
properties = string.strip("\n").split(" ", 2)
return cls(
type_string=properties[0],
data=properties[1],
comment=properties[2] if len(properties) > 2 else "",
)
@classmethod
def load(cls: t.Type[_PublicKey], path: str | os.PathLike) -> _PublicKey | None:
try:
with open(path, "r") as f:
properties = f.read().strip(" \n").split(" ", 2)
except (IOError, OSError):
raise
if len(properties) < 2:
return None
return cls(
type_string=properties[0],
data=properties[1],
comment="" if len(properties) <= 2 else properties[2],
)
def to_dict(self) -> dict[str, t.Any]:
return {
"comment": self._comment,
"public_key": self._data,
}
def parse_private_key_format(
path: str | os.PathLike,
) -> t.Literal["SSH", "PKCS8", "PKCS1", ""]:
with open(path, "r") as file:
header = file.readline().strip()
if header == "-----BEGIN OPENSSH PRIVATE KEY-----":
return "SSH"
elif header == "-----BEGIN PRIVATE KEY-----":
return "PKCS8"
elif header == "-----BEGIN RSA PRIVATE KEY-----":
return "PKCS1"
return ""

View File

@@ -0,0 +1,585 @@
# Copyright (c) 2018, David Kainz <dkainz@mgit.at> <dave.jokain@gmx.at>
# Copyright (c) 2021, Andrew Pantuso (@ajpantuso) <ajpantuso@gmail.com>
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
import abc
import os
import typing as t
from ansible.module_utils.basic import missing_required_lib
from ansible.module_utils.common.text.converters import to_bytes, to_text
from ansible_collections.community.crypto.plugins.module_utils._cryptography_dep import (
COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
)
from ansible_collections.community.crypto.plugins.module_utils._openssh.backends.common import (
KeygenCommand,
OpensshModule,
PrivateKey,
PublicKey,
parse_private_key_format,
)
from ansible_collections.community.crypto.plugins.module_utils._openssh.cryptography import (
CRYPTOGRAPHY_VERSION,
HAS_OPENSSH_SUPPORT,
InvalidCommentError,
InvalidPassphraseError,
InvalidPrivateKeyFileError,
OpenSSHError,
OpensshKeypair,
)
from ansible_collections.community.crypto.plugins.module_utils._openssh.utils import (
any_in,
file_mode,
secure_write,
)
from ansible_collections.community.crypto.plugins.module_utils._version import (
LooseVersion,
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import (
CertificateIssuerPrivateKeyTypes,
PrivateKeyTypes,
)
class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
def __init__(self, module: AnsibleModule) -> None:
super(KeypairBackend, self).__init__(module)
self.comment: str | None = self.module.params["comment"]
self.private_key_path: str = self.module.params["path"]
self.public_key_path = self.private_key_path + ".pub"
self.regenerate: t.Literal[
"never", "fail", "partial_idempotence", "full_idempotence", "always"
] = (
self.module.params["regenerate"]
if not self.module.params["force"]
else "always"
)
self.state: t.Literal["present", "absent"] = self.module.params["state"]
self.type: t.Literal["rsa", "dsa", "rsa1", "ecdsa", "ed25519"] = (
self.module.params["type"]
)
self.size: int = self._get_size(self.module.params["size"])
self._validate_path()
self.original_private_key: PrivateKey | None = None
self.original_public_key: PublicKey | None = None
self.private_key: PrivateKey | None = None
self.public_key: PublicKey | None = None
def _get_size(self, size: int | None) -> int:
if self.type in ("rsa", "rsa1"):
result = 4096 if size is None else size
if result < 1024:
return self.module.fail_json(
msg="For RSA keys, the minimum size is 1024 bits and the default is 4096 bits. "
+ "Attempting to use bit lengths under 1024 will cause the module to fail."
)
elif self.type == "dsa":
result = 1024 if size is None else size
if result != 1024:
return self.module.fail_json(
msg="DSA keys must be exactly 1024 bits as specified by FIPS 186-2."
)
elif self.type == "ecdsa":
result = 256 if size is None else size
if result not in (256, 384, 521):
return self.module.fail_json(
msg="For ECDSA keys, size determines the key length by selecting from one of "
+ "three elliptic curve sizes: 256, 384 or 521 bits. "
+ "Attempting to use bit lengths other than these three values for ECDSA keys will "
+ "cause this module to fail."
)
elif self.type == "ed25519":
# User input is ignored for `key size` when `key type` is ed25519
result = 256
else:
return self.module.fail_json(
msg=f"{self.type} is not a valid value for key type"
)
return result
def _validate_path(self) -> None:
self._check_if_base_dir(self.private_key_path)
if os.path.isdir(self.private_key_path):
self.module.fail_json(
msg=f"{self.private_key_path} is a directory. Please specify a path to a file."
)
def _execute(self) -> None:
self.original_private_key = self._load_private_key()
self.original_public_key = self._load_public_key()
if self.state == "present":
self._validate_key_load()
if self._should_generate():
self._generate()
elif not self._public_key_valid():
self._restore_public_key()
self.private_key = self._load_private_key()
self.public_key = self._load_public_key()
for path in (self.private_key_path, self.public_key_path):
self._update_permissions(path)
else:
if self._should_remove():
self._remove()
def _load_private_key(self) -> PrivateKey | None:
result = None
if self._private_key_exists():
try:
result = self._get_private_key()
except Exception:
pass
return result
def _private_key_exists(self) -> bool:
return os.path.exists(self.private_key_path)
@abc.abstractmethod
def _get_private_key(self) -> PrivateKey:
pass
def _load_public_key(self) -> PublicKey | None:
result = None
if self._public_key_exists():
try:
result = PublicKey.load(self.public_key_path)
except (IOError, OSError):
pass
return result
def _public_key_exists(self) -> bool:
return os.path.exists(self.public_key_path)
def _validate_key_load(self) -> None:
if (
self._private_key_exists()
and self.regenerate in ("never", "fail", "partial_idempotence")
and (self.original_private_key is None or not self._private_key_readable())
):
self.module.fail_json(
msg="Unable to read the key. The key is protected with a passphrase or broken. "
+ "Will not proceed. To force regeneration, call the module with `generate` "
+ "set to `full_idempotence` or `always`, or with `force=true`."
)
@abc.abstractmethod
def _private_key_readable(self) -> bool:
pass
def _should_generate(self) -> bool:
if self.original_private_key is None:
return True
elif self.regenerate == "never":
return False
elif self.regenerate == "fail":
if not self._private_key_valid():
self.module.fail_json(
msg="Key has wrong type and/or size. Will not proceed. "
+ "To force regeneration, call the module with `generate` set to "
+ "`partial_idempotence`, `full_idempotence` or `always`, or with `force=true`."
)
return False
elif self.regenerate in ("partial_idempotence", "full_idempotence"):
return not self._private_key_valid()
else:
return True
def _private_key_valid(self) -> bool:
if self.original_private_key is None:
return False
return all(
[
self.size == self.original_private_key.size,
self.type == self.original_private_key.type,
self._private_key_valid_backend(self.original_private_key),
]
)
@abc.abstractmethod
def _private_key_valid_backend(self, original_private_key: PrivateKey) -> bool:
pass
@OpensshModule.trigger_change
@OpensshModule.skip_if_check_mode
def _generate(self) -> None:
temp_private_key, temp_public_key = self._generate_temp_keypair()
try:
self._safe_secure_move(
[
(temp_private_key, self.private_key_path),
(temp_public_key, self.public_key_path),
]
)
except OSError as e:
self.module.fail_json(msg=str(e))
def _generate_temp_keypair(self) -> tuple[str, str]:
temp_private_key = os.path.join(
self.module.tmpdir, os.path.basename(self.private_key_path)
)
temp_public_key = temp_private_key + ".pub"
try:
self._generate_keypair(temp_private_key)
except (IOError, OSError) as e:
self.module.fail_json(msg=str(e))
for f in (temp_private_key, temp_public_key):
self.module.add_cleanup_file(f)
return temp_private_key, temp_public_key
@abc.abstractmethod
def _generate_keypair(self, private_key_path: str) -> None:
pass
def _public_key_valid(self) -> bool:
if self.original_public_key is None:
return False
valid_public_key = self._get_public_key()
if valid_public_key:
valid_public_key.comment = self.comment
return self.original_public_key == valid_public_key
@abc.abstractmethod
def _get_public_key(self) -> PublicKey | t.Literal[""]:
pass
@OpensshModule.trigger_change
@OpensshModule.skip_if_check_mode
def _restore_public_key(self) -> None:
try:
temp_public_key = self._create_temp_public_key(
str(self._get_public_key()) + "\n"
)
self._safe_secure_move([(temp_public_key, self.public_key_path)])
except (IOError, OSError):
self.module.fail_json(
msg="The public key is missing or does not match the private key. "
+ "Unable to regenerate the public key."
)
if self.comment:
self._update_comment()
def _create_temp_public_key(self, content: str | bytes) -> str:
temp_public_key = os.path.join(
self.module.tmpdir, os.path.basename(self.public_key_path)
)
default_permissions = 0o644
existing_permissions = file_mode(self.public_key_path)
try:
secure_write(
temp_public_key,
existing_permissions or default_permissions,
to_bytes(content),
)
except (IOError, OSError) as e:
self.module.fail_json(msg=str(e))
self.module.add_cleanup_file(temp_public_key)
return temp_public_key
@abc.abstractmethod
def _update_comment(self) -> None:
pass
def _should_remove(self) -> bool:
return self._private_key_exists() or self._public_key_exists()
@OpensshModule.trigger_change
@OpensshModule.skip_if_check_mode
def _remove(self) -> None:
try:
if self._private_key_exists():
os.remove(self.private_key_path)
if self._public_key_exists():
os.remove(self.public_key_path)
except (IOError, OSError) as e:
self.module.fail_json(msg=str(e))
@property
def _result(self) -> dict[str, t.Any]:
private_key = self.private_key or self.original_private_key
public_key = self.public_key or self.original_public_key
return {
"size": self.size,
"type": self.type,
"filename": self.private_key_path,
"fingerprint": private_key.fingerprint if private_key else "",
"public_key": str(public_key) if public_key else "",
"comment": public_key.comment if public_key else "",
}
@property
def diff(self) -> dict[str, t.Any]:
before = (
self.original_private_key.to_dict() if self.original_private_key else {}
)
before.update(
self.original_public_key.to_dict() if self.original_public_key else {}
)
after = self.private_key.to_dict() if self.private_key else {}
after.update(self.public_key.to_dict() if self.public_key else {})
return {
"before": before,
"after": after,
}
class KeypairBackendOpensshBin(KeypairBackend):
def __init__(self, module: AnsibleModule) -> None:
super(KeypairBackendOpensshBin, self).__init__(module)
if self.module.params["private_key_format"] != "auto":
self.module.fail_json(
msg="'auto' is the only valid option for 'private_key_format' when 'backend' is not 'cryptography'"
)
self.ssh_keygen = KeygenCommand(self.module)
def _generate_keypair(self, private_key_path: str) -> None:
self.ssh_keygen.generate_keypair(
private_key_path, self.size, self.type, self.comment, check_rc=True
)
def _get_private_key(self) -> PrivateKey:
rc, private_key_content, err = self.ssh_keygen.get_private_key(
self.private_key_path, check_rc=False
)
if rc != 0:
raise ValueError(err)
return PrivateKey.from_string(private_key_content)
def _get_public_key(self) -> PublicKey | t.Literal[""]:
public_key_content = self.ssh_keygen.get_matching_public_key(
self.private_key_path, check_rc=True
)[1]
return PublicKey.from_string(public_key_content)
def _private_key_readable(self) -> bool:
rc, stdout, stderr = self.ssh_keygen.get_matching_public_key(
self.private_key_path, check_rc=False
)
return not (
rc == 255
or any_in(
stderr,
"is not a public key file",
"incorrect passphrase",
"load failed",
)
)
def _update_comment(self) -> None:
try:
ssh_version = self._get_ssh_version() or "7.8"
force_new_format = (
LooseVersion("6.5") <= LooseVersion(ssh_version) < LooseVersion("7.8")
)
self.ssh_keygen.update_comment(
self.private_key_path,
self.comment or "",
force_new_format=force_new_format,
check_rc=True,
)
except (IOError, OSError) as e:
self.module.fail_json(msg=str(e))
def _private_key_valid_backend(self, original_private_key: PrivateKey) -> bool:
return True
class KeypairBackendCryptography(KeypairBackend):
def __init__(self, module: AnsibleModule) -> None:
super(KeypairBackendCryptography, self).__init__(module)
if self.type == "rsa1":
self.module.fail_json(
msg="RSA1 keys are not supported by the cryptography backend"
)
self.passphrase = (
to_bytes(module.params["passphrase"])
if module.params["passphrase"]
else None
)
key_format: t.Literal["auto", "pkcs1", "pkcs8", "ssh"] = module.params[
"private_key_format"
]
self.private_key_format = self._get_key_format(key_format)
def _get_key_format(
self, key_format: t.Literal["auto", "pkcs1", "pkcs8", "ssh"]
) -> t.Literal["SSH", "PKCS1", "PKCS8"]:
result: t.Literal["SSH", "PKCS1", "PKCS8"] = "SSH"
if key_format == "auto":
# Default to OpenSSH 7.8 compatibility when OpenSSH is not installed
ssh_version = self._get_ssh_version() or "7.8"
if (
LooseVersion(ssh_version) < LooseVersion("7.8")
and self.type != "ed25519"
):
# OpenSSH made SSH formatted private keys available in version 6.5,
# but still defaulted to PKCS1 format with the exception of ed25519 keys
result = "PKCS1"
else:
result = key_format.upper() # type: ignore
return result
def _generate_keypair(self, private_key_path: str) -> None:
assert self.type != "rsa1"
keypair = OpensshKeypair.generate(
keytype=self.type,
size=self.size,
passphrase=self.passphrase,
comment=self.comment or "",
)
encoded_private_key = OpensshKeypair.encode_openssh_privatekey(
keypair.asymmetric_keypair, self.private_key_format
)
secure_write(private_key_path, 0o600, encoded_private_key)
public_key_path = private_key_path + ".pub"
secure_write(public_key_path, 0o644, keypair.public_key)
def _get_private_key(self) -> PrivateKey:
keypair = OpensshKeypair.load(
path=self.private_key_path, passphrase=self.passphrase, no_public_key=True
)
return PrivateKey(
size=keypair.size,
key_type=keypair.key_type,
fingerprint=keypair.fingerprint,
format=parse_private_key_format(self.private_key_path),
)
def _get_public_key(self) -> PublicKey | t.Literal[""]:
try:
keypair = OpensshKeypair.load(
path=self.private_key_path,
passphrase=self.passphrase,
no_public_key=True,
)
except OpenSSHError:
# Simulates the null output of ssh-keygen
return ""
return PublicKey.from_string(to_text(keypair.public_key))
def _private_key_readable(self) -> bool:
try:
OpensshKeypair.load(
path=self.private_key_path,
passphrase=self.passphrase,
no_public_key=True,
)
except (InvalidPrivateKeyFileError, InvalidPassphraseError):
return False
# Cryptography >= 3.0 uses a SSH key loader which does not raise an exception when a passphrase is provided
# when loading an unencrypted key
if self.passphrase:
try:
OpensshKeypair.load(
path=self.private_key_path, passphrase=None, no_public_key=True
)
except (InvalidPrivateKeyFileError, InvalidPassphraseError):
return True
else:
return False
return True
def _update_comment(self) -> None:
keypair = OpensshKeypair.load(
path=self.private_key_path, passphrase=self.passphrase, no_public_key=True
)
try:
keypair.comment = self.comment
except InvalidCommentError as e:
self.module.fail_json(msg=str(e))
try:
temp_public_key = self._create_temp_public_key(keypair.public_key + b"\n")
self._safe_secure_move([(temp_public_key, self.public_key_path)])
except (IOError, OSError) as e:
self.module.fail_json(msg=str(e))
def _private_key_valid_backend(self, original_private_key: PrivateKey) -> bool:
# avoids breaking behavior and prevents
# automatic conversions with OpenSSH upgrades
if self.module.params["private_key_format"] == "auto":
return True
return self.private_key_format == original_private_key.format
def select_backend(
module: AnsibleModule, backend: t.Literal["auto", "opensshbin", "cryptography"]
) -> KeypairBackend:
can_use_cryptography = HAS_OPENSSH_SUPPORT and LooseVersion(
CRYPTOGRAPHY_VERSION
) >= LooseVersion(COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION)
can_use_opensshbin = bool(module.get_bin_path("ssh-keygen"))
if backend == "auto":
if can_use_opensshbin and not module.params["passphrase"]:
backend = "opensshbin"
elif can_use_cryptography:
backend = "cryptography"
else:
module.fail_json(
msg=(
"Cannot find either the OpenSSH binary in the PATH "
f"or cryptography >= {COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION} installed on this system"
)
)
if backend == "opensshbin":
if not can_use_opensshbin:
module.fail_json(msg="Cannot find the OpenSSH binary in the PATH")
return KeypairBackendOpensshBin(module)
if backend == "cryptography":
if not can_use_cryptography:
module.fail_json(
msg=missing_required_lib(
f"cryptography >= {COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION}"
)
)
return KeypairBackendCryptography(module)
raise ValueError(f"Unsupported value for backend: {backend}")

View File

@@ -0,0 +1,805 @@
# Copyright (c) 2021, Andrew Pantuso (@ajpantuso) <ajpantuso@gmail.com>
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
import abc
import binascii
import datetime as _datetime
import os
import typing as t
from base64 import b64encode
from datetime import datetime
from hashlib import sha256
from ansible.module_utils.common.text.converters import to_text
from ansible_collections.community.crypto.plugins.module_utils._openssh.utils import (
OpensshParser,
_OpensshWriter,
)
from ansible_collections.community.crypto.plugins.module_utils._time import UTC as _UTC
from ansible_collections.community.crypto.plugins.module_utils._time import (
add_or_remove_timezone as _add_or_remove_timezone,
)
from ansible_collections.community.crypto.plugins.module_utils._time import (
convert_relative_to_datetime,
)
if t.TYPE_CHECKING:
from ansible_collections.community.crypto.plugins.module_utils._openssh.cryptography import (
KeyType,
)
DateFormat = t.Literal["human_readable", "openssh", "timestamp"]
DateFormatStr = t.Literal["human_readable", "openssh"]
DateFormatInt = t.Literal["timestamp"]
else:
KeyType = None
# Protocol References
# -------------------
# https://datatracker.ietf.org/doc/html/rfc4251
# https://datatracker.ietf.org/doc/html/rfc4253
# https://datatracker.ietf.org/doc/html/rfc5656
# https://datatracker.ietf.org/doc/html/rfc8032
# https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL.certkeys?annotate=HEAD
#
# Inspired by:
# ------------
# https://github.com/pyca/cryptography/blob/main/src/cryptography/hazmat/primitives/serialization/ssh.py
# https://github.com/paramiko/paramiko/blob/master/paramiko/message.py
# See https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL.certkeys?annotate=HEAD
_USER_TYPE = 1
_HOST_TYPE = 2
_SSH_TYPE_STRINGS: dict[KeyType | str, bytes] = {
"rsa": b"ssh-rsa",
"dsa": b"ssh-dss",
"ecdsa-nistp256": b"ecdsa-sha2-nistp256",
"ecdsa-nistp384": b"ecdsa-sha2-nistp384",
"ecdsa-nistp521": b"ecdsa-sha2-nistp521",
"ed25519": b"ssh-ed25519",
}
_CERT_SUFFIX_V01 = b"-cert-v01@openssh.com"
# See https://datatracker.ietf.org/doc/html/rfc5656#section-6.1
_ECDSA_CURVE_IDENTIFIERS = {
"ecdsa-nistp256": b"nistp256",
"ecdsa-nistp384": b"nistp384",
"ecdsa-nistp521": b"nistp521",
}
_ECDSA_CURVE_IDENTIFIERS_LOOKUP = {
b"nistp256": "ecdsa-nistp256",
b"nistp384": "ecdsa-nistp384",
b"nistp521": "ecdsa-nistp521",
}
_ALWAYS = _add_or_remove_timezone(datetime(1970, 1, 1), with_timezone=True)
_FOREVER = datetime(9999, 12, 31, 23, 59, 59, 999999, _UTC)
_CRITICAL_OPTIONS = (
"force-command",
"source-address",
"verify-required",
)
_DIRECTIVES = (
"clear",
"no-x11-forwarding",
"no-agent-forwarding",
"no-port-forwarding",
"no-pty",
"no-user-rc",
)
_EXTENSIONS = (
"permit-x11-forwarding",
"permit-agent-forwarding",
"permit-port-forwarding",
"permit-pty",
"permit-user-rc",
)
class OpensshCertificateTimeParameters:
def __init__(
self, valid_from: str | bytes | int, valid_to: str | bytes | int
) -> None:
self._valid_from = self.to_datetime(valid_from)
self._valid_to = self.to_datetime(valid_to)
if self._valid_from > self._valid_to:
raise ValueError(
f"Valid from: {valid_from!r} must not be greater than Valid to: {valid_to!r}"
)
def __eq__(self, other: object) -> bool:
if not isinstance(other, type(self)):
return NotImplemented
else:
return (
self._valid_from == other._valid_from
and self._valid_to == other._valid_to
)
def __ne__(self, other: object) -> bool:
return not self == other
@property
def validity_string(self) -> str:
if not (self._valid_from == _ALWAYS and self._valid_to == _FOREVER):
return f"{self.valid_from(date_format='openssh')}:{self.valid_to(date_format='openssh')}"
return ""
@t.overload
def valid_from(self, date_format: DateFormatStr) -> str: ...
@t.overload
def valid_from(self, date_format: DateFormatInt) -> int: ...
@t.overload
def valid_from(self, date_format: DateFormat) -> str | int: ...
def valid_from(self, date_format: DateFormat) -> str | int:
return self.format_datetime(self._valid_from, date_format)
@t.overload
def valid_to(self, date_format: DateFormatStr) -> str: ...
@t.overload
def valid_to(self, date_format: DateFormatInt) -> int: ...
@t.overload
def valid_to(self, date_format: DateFormat) -> str | int: ...
def valid_to(self, date_format: DateFormat) -> str | int:
return self.format_datetime(self._valid_to, date_format)
def within_range(self, valid_at: str | bytes | int | None) -> bool:
if valid_at is not None:
valid_at_datetime = self.to_datetime(valid_at)
return self._valid_from <= valid_at_datetime <= self._valid_to
return True
@t.overload
@staticmethod
def format_datetime(dt: datetime, date_format: DateFormatStr) -> str: ...
@t.overload
@staticmethod
def format_datetime(dt: datetime, date_format: DateFormatInt) -> int: ...
@t.overload
@staticmethod
def format_datetime(dt: datetime, date_format: DateFormat) -> str | int: ...
@staticmethod
def format_datetime(dt: datetime, date_format: DateFormat) -> str | int:
if date_format in ("human_readable", "openssh"):
if dt == _ALWAYS:
return "always"
if dt == _FOREVER:
return "forever"
else:
return (
dt.isoformat().replace("+00:00", "")
if date_format == "human_readable"
else dt.strftime("%Y%m%d%H%M%S")
)
if date_format == "timestamp":
td = dt - _ALWAYS
return int(
(td.microseconds + (td.seconds + td.days * 24 * 3600) * 10**6) / 10**6
)
raise ValueError(f"{date_format} is not a valid format")
@staticmethod
def to_datetime(time_string_or_timestamp: str | bytes | int) -> datetime:
try:
if isinstance(time_string_or_timestamp, (str, bytes)):
result = OpensshCertificateTimeParameters._time_string_to_datetime(
to_text(time_string_or_timestamp.strip())
)
elif isinstance(time_string_or_timestamp, int):
result = OpensshCertificateTimeParameters._timestamp_to_datetime(
time_string_or_timestamp
)
else:
raise ValueError(
f"Value must be of type (str, unicode, int) not {type(time_string_or_timestamp)}"
)
except ValueError:
raise
return result
@staticmethod
def _timestamp_to_datetime(timestamp: int) -> datetime:
if timestamp == 0x0:
return _ALWAYS
if timestamp == 0xFFFFFFFFFFFFFFFF:
return _FOREVER
try:
return datetime.fromtimestamp(timestamp, tz=_datetime.timezone.utc)
except OverflowError:
raise ValueError
@staticmethod
def _time_string_to_datetime(time_string: str) -> datetime:
if time_string == "always":
return _ALWAYS
if time_string == "forever":
return _FOREVER
if is_relative_time_string(time_string):
result = convert_relative_to_datetime(time_string, with_timezone=True)
if result is None:
raise ValueError
return result
result = None
for time_format in ("%Y-%m-%d", "%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S"):
try:
result = _add_or_remove_timezone(
datetime.strptime(time_string, time_format),
with_timezone=True,
)
except ValueError:
pass
if result is None:
raise ValueError
return result
_OpensshCertificateOption = t.TypeVar(
"_OpensshCertificateOption", bound="OpensshCertificateOption"
)
class OpensshCertificateOption:
def __init__(
self,
option_type: t.Literal["critical", "extension"],
name: str | bytes,
data: str | bytes,
):
if option_type not in ("critical", "extension"):
raise ValueError("type must be either 'critical' or 'extension'")
if not isinstance(name, (str, bytes)):
raise TypeError(f"name must be a string not {type(name)}")
if not isinstance(data, (str, bytes)):
raise TypeError(f"data must be a string not {type(data)}")
self._option_type = option_type
self._name = name.lower()
self._data = data
def __eq__(self, other: object) -> bool:
if not isinstance(other, type(self)):
return NotImplemented
return all(
[
self._option_type == other._option_type,
self._name == other._name,
self._data == other._data,
]
)
def __hash__(self) -> int:
return hash((self._option_type, self._name, self._data))
def __ne__(self, other: object) -> bool:
return not self == other
def __str__(self) -> str:
if self._data:
return f"{self._name!r}={self._data!r}"
return f"{self._name!r}"
@property
def data(self) -> str | bytes:
return self._data
@property
def name(self) -> str | bytes:
return self._name
@property
def type(self) -> t.Literal["critical", "extension"]:
return self._option_type
@classmethod
def from_string(
cls: t.Type[_OpensshCertificateOption], option_string: str
) -> _OpensshCertificateOption:
if not isinstance(option_string, str):
raise ValueError(
f"option_string must be a string not {type(option_string)}"
)
option_type = None
if ":" in option_string:
option_type, value = option_string.strip().split(":", 1)
if "=" in value:
name, data = value.split("=", 1)
else:
name, data = value, ""
elif "=" in option_string:
name, data = option_string.strip().split("=", 1)
else:
name, data = option_string.strip(), ""
return cls(
# We have str, but we're expecting a specific literal:
option_type=option_type or get_option_type(name.lower()), # type: ignore
name=name,
data=data,
)
class OpensshCertificateInfo(metaclass=abc.ABCMeta):
"""Encapsulates all certificate information which is signed by a CA key"""
def __init__(
self,
nonce: bytes | None = None,
serial: int | None = None,
cert_type: int | None = None,
key_id: bytes | None = None,
principals: list[bytes] | None = None,
valid_after: int | None = None,
valid_before: int | None = None,
critical_options: list[tuple[bytes, bytes]] | None = None,
extensions: list[tuple[bytes, bytes]] | None = None,
reserved: bytes | None = None,
signing_key: bytes | None = None,
):
self.nonce = nonce
self.serial = serial
self._cert_type: int | None = cert_type
self.key_id = key_id
self.principals = principals
self.valid_after = valid_after
self.valid_before = valid_before
self.critical_options = critical_options
self.extensions = extensions
self.reserved = reserved
self.signing_key = signing_key
self.type_string: bytes | None = None
@property
def cert_type(self) -> t.Literal["user", "host", ""]:
if self._cert_type == _USER_TYPE:
return "user"
elif self._cert_type == _HOST_TYPE:
return "host"
else:
return ""
@cert_type.setter
def cert_type(self, cert_type: t.Literal["user", "host"] | int) -> None:
if cert_type == "user" or cert_type == _USER_TYPE:
self._cert_type = _USER_TYPE
elif cert_type == "host" or cert_type == _HOST_TYPE:
self._cert_type = _HOST_TYPE
else:
raise ValueError(f"{cert_type} is not a valid certificate type")
def signing_key_fingerprint(self) -> bytes:
if self.signing_key is None:
raise ValueError("signing_key not present")
return fingerprint(self.signing_key)
@abc.abstractmethod
def public_key_fingerprint(self) -> bytes:
pass
@abc.abstractmethod
def parse_public_numbers(self, parser: OpensshParser) -> None:
pass
class OpensshRSACertificateInfo(OpensshCertificateInfo):
def __init__(self, e: int | None = None, n: int | None = None, **kwargs) -> None:
super(OpensshRSACertificateInfo, self).__init__(**kwargs)
self.type_string = _SSH_TYPE_STRINGS["rsa"] + _CERT_SUFFIX_V01
self.e = e
self.n = n
# See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
def public_key_fingerprint(self) -> bytes:
if self.e is None or self.n is None:
return b""
writer = _OpensshWriter()
writer.string(_SSH_TYPE_STRINGS["rsa"])
writer.mpint(self.e)
writer.mpint(self.n)
return fingerprint(writer.bytes())
def parse_public_numbers(self, parser: OpensshParser) -> None:
self.e = parser.mpint()
self.n = parser.mpint()
class OpensshDSACertificateInfo(OpensshCertificateInfo):
def __init__(
self,
p: int | None = None,
q: int | None = None,
g: int | None = None,
y: int | None = None,
**kwargs,
) -> None:
super(OpensshDSACertificateInfo, self).__init__(**kwargs)
self.type_string = _SSH_TYPE_STRINGS["dsa"] + _CERT_SUFFIX_V01
self.p = p
self.q = q
self.g = g
self.y = y
# See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
def public_key_fingerprint(self) -> bytes:
if self.p is None or self.q is None or self.g is None or self.y is None:
return b""
writer = _OpensshWriter()
writer.string(_SSH_TYPE_STRINGS["dsa"])
writer.mpint(self.p)
writer.mpint(self.q)
writer.mpint(self.g)
writer.mpint(self.y)
return fingerprint(writer.bytes())
def parse_public_numbers(self, parser: OpensshParser) -> None:
self.p = parser.mpint()
self.q = parser.mpint()
self.g = parser.mpint()
self.y = parser.mpint()
class OpensshECDSACertificateInfo(OpensshCertificateInfo):
def __init__(
self, curve: bytes | None = None, public_key: bytes | None = None, **kwargs
):
super(OpensshECDSACertificateInfo, self).__init__(**kwargs)
self._curve = None
if curve is not None:
self.curve = curve
self.public_key = public_key
@property
def curve(self) -> bytes | None:
return self._curve
@curve.setter
def curve(self, curve: bytes) -> None:
if curve in _ECDSA_CURVE_IDENTIFIERS.values():
self._curve = curve
self.type_string = (
_SSH_TYPE_STRINGS[_ECDSA_CURVE_IDENTIFIERS_LOOKUP[curve]]
+ _CERT_SUFFIX_V01
)
else:
raise ValueError(
"Curve must be one of {(b','.join(_ECDSA_CURVE_IDENTIFIERS.values())).decode('UTF-8')}"
)
# See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
def public_key_fingerprint(self) -> bytes:
if self.curve is None or self.public_key is None:
return b""
writer = _OpensshWriter()
writer.string(_SSH_TYPE_STRINGS[_ECDSA_CURVE_IDENTIFIERS_LOOKUP[self.curve]])
writer.string(self.curve)
writer.string(self.public_key)
return fingerprint(writer.bytes())
def parse_public_numbers(self, parser: OpensshParser) -> None:
self.curve = parser.string()
self.public_key = parser.string()
class OpensshED25519CertificateInfo(OpensshCertificateInfo):
def __init__(self, pk: bytes | None = None, **kwargs) -> None:
super(OpensshED25519CertificateInfo, self).__init__(**kwargs)
self.type_string = _SSH_TYPE_STRINGS["ed25519"] + _CERT_SUFFIX_V01
self.pk = pk
def public_key_fingerprint(self) -> bytes:
if self.pk is None:
return b""
writer = _OpensshWriter()
writer.string(_SSH_TYPE_STRINGS["ed25519"])
writer.string(self.pk)
return fingerprint(writer.bytes())
def parse_public_numbers(self, parser: OpensshParser) -> None:
self.pk = parser.string()
_OpensshCertificate = t.TypeVar("_OpensshCertificate", bound="OpensshCertificate")
# See https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL.certkeys?annotate=HEAD
class OpensshCertificate:
"""Encapsulates a formatted OpenSSH certificate including signature and signing key"""
def __init__(self, cert_info: OpensshCertificateInfo, signature: bytes):
self._cert_info = cert_info
self.signature = signature
@classmethod
def load(
cls: t.Type[_OpensshCertificate], path: str | os.PathLike
) -> _OpensshCertificate:
if not os.path.exists(path):
raise ValueError(f"{path} is not a valid path.")
try:
with open(path, "rb") as cert_file:
data = cert_file.read()
except (IOError, OSError) as e:
raise ValueError(f"{path} cannot be opened for reading: {e}")
try:
format_identifier, b64_cert = data.split(b" ")[:2]
cert = binascii.a2b_base64(b64_cert)
except (binascii.Error, ValueError):
raise ValueError("Certificate not in OpenSSH format")
for key_type, string in _SSH_TYPE_STRINGS.items():
if format_identifier == string + _CERT_SUFFIX_V01:
pub_key_type = t.cast(KeyType, key_type)
break
else:
raise ValueError(
f"Invalid certificate format identifier: {format_identifier!r}"
)
parser = OpensshParser(cert)
if format_identifier != parser.string():
raise ValueError("Certificate formats do not match")
try:
cert_info = cls._parse_cert_info(pub_key_type, parser)
signature = parser.string()
except (TypeError, ValueError) as e:
raise ValueError(f"Invalid certificate data: {e}")
if parser.remaining_bytes():
raise ValueError(
f"{parser.remaining_bytes()} bytes of additional data was not parsed while loading {path}"
)
return cls(
cert_info=cert_info,
signature=signature,
)
@property
def type_string(self) -> str:
return to_text(self._cert_info.type_string)
@property
def nonce(self) -> bytes:
if self._cert_info.nonce is None:
raise ValueError
return self._cert_info.nonce
@property
def public_key(self) -> str:
return to_text(self._cert_info.public_key_fingerprint())
@property
def serial(self) -> int:
if self._cert_info.serial is None:
raise ValueError
return self._cert_info.serial
@property
def type(self) -> t.Literal["user", "host"]:
result = self._cert_info.cert_type
if result == "":
raise ValueError
return result
@property
def key_id(self) -> str:
return to_text(self._cert_info.key_id)
@property
def principals(self) -> list[str]:
if self._cert_info.principals is None:
raise ValueError
return [to_text(p) for p in self._cert_info.principals]
@property
def valid_after(self) -> int:
if self._cert_info.valid_after is None:
raise ValueError
return self._cert_info.valid_after
@property
def valid_before(self) -> int:
if self._cert_info.valid_before is None:
raise ValueError
return self._cert_info.valid_before
@property
def critical_options(self) -> list[OpensshCertificateOption]:
if self._cert_info.critical_options is None:
raise ValueError
return [
OpensshCertificateOption("critical", to_text(n), to_text(d))
for n, d in self._cert_info.critical_options
]
@property
def extensions(self) -> list[OpensshCertificateOption]:
if self._cert_info.extensions is None:
raise ValueError
return [
OpensshCertificateOption("extension", to_text(n), to_text(d))
for n, d in self._cert_info.extensions
]
@property
def reserved(self) -> bytes:
if self._cert_info.reserved is None:
raise ValueError
return self._cert_info.reserved
@property
def signing_key(self) -> str:
return to_text(self._cert_info.signing_key_fingerprint())
@property
def signature_type(self) -> str:
signature_data = OpensshParser.signature_data(self.signature)
return to_text(signature_data["signature_type"])
@staticmethod
def _parse_cert_info(
pub_key_type: KeyType, parser: OpensshParser
) -> OpensshCertificateInfo:
cert_info = get_cert_info_object(pub_key_type)
cert_info.nonce = parser.string()
cert_info.parse_public_numbers(parser)
cert_info.serial = parser.uint64()
# mypy doesn't understand that the setter accepts other types than the getter:
cert_info.cert_type = parser.uint32() # type: ignore
cert_info.key_id = parser.string()
cert_info.principals = parser.string_list()
cert_info.valid_after = parser.uint64()
cert_info.valid_before = parser.uint64()
cert_info.critical_options = parser.option_list()
cert_info.extensions = parser.option_list()
cert_info.reserved = parser.string()
cert_info.signing_key = parser.string()
return cert_info
def to_dict(self) -> dict[str, t.Any]:
time_parameters = OpensshCertificateTimeParameters(
valid_from=self.valid_after, valid_to=self.valid_before
)
return {
"type_string": self.type_string,
"nonce": self.nonce,
"serial": self.serial,
"cert_type": self.type,
"identifier": self.key_id,
"principals": self.principals,
"valid_after": time_parameters.valid_from(date_format="human_readable"),
"valid_before": time_parameters.valid_to(date_format="human_readable"),
"critical_options": [
str(critical_option) for critical_option in self.critical_options
],
"extensions": [str(extension) for extension in self.extensions],
"reserved": self.reserved,
"public_key": self.public_key,
"signing_key": self.signing_key,
}
def apply_directives(directives: t.Iterable[str]) -> list[OpensshCertificateOption]:
if any(d not in _DIRECTIVES for d in directives):
raise ValueError(f"directives must be one of {', '.join(_DIRECTIVES)}")
directive_to_option = {
"no-x11-forwarding": OpensshCertificateOption(
"extension", "permit-x11-forwarding", ""
),
"no-agent-forwarding": OpensshCertificateOption(
"extension", "permit-agent-forwarding", ""
),
"no-port-forwarding": OpensshCertificateOption(
"extension", "permit-port-forwarding", ""
),
"no-pty": OpensshCertificateOption("extension", "permit-pty", ""),
"no-user-rc": OpensshCertificateOption("extension", "permit-user-rc", ""),
}
if "clear" in directives:
return []
else:
return list(
set(default_options()) - set(directive_to_option[d] for d in directives)
)
def default_options() -> list[OpensshCertificateOption]:
return [OpensshCertificateOption("extension", name, "") for name in _EXTENSIONS]
def fingerprint(public_key: bytes) -> bytes:
"""Generates a SHA256 hash and formats output to resemble ``ssh-keygen``"""
h = sha256()
h.update(public_key)
return b"SHA256:" + b64encode(h.digest()).rstrip(b"=")
def get_cert_info_object(key_type: KeyType) -> OpensshCertificateInfo:
if key_type == "rsa":
return OpensshRSACertificateInfo()
if key_type == "dsa":
return OpensshDSACertificateInfo()
if key_type in ("ecdsa-nistp256", "ecdsa-nistp384", "ecdsa-nistp521"):
return OpensshECDSACertificateInfo()
if key_type == "ed25519":
return OpensshED25519CertificateInfo()
raise ValueError(f"{key_type} is not a valid key type")
def get_option_type(name: str) -> t.Literal["critical", "extension"]:
if name in _CRITICAL_OPTIONS:
return "critical"
if name in _EXTENSIONS:
return "extension"
raise ValueError(
f"{name} is not a valid option. "
"Custom options must start with 'critical:' or 'extension:' to indicate type"
)
def is_relative_time_string(time_string: str) -> bool:
return time_string.startswith("+") or time_string.startswith("-")
def parse_option_list(
option_list: t.Iterable[str],
) -> tuple[list[OpensshCertificateOption], list[OpensshCertificateOption]]:
critical_options = []
directives = []
extensions = []
for option in option_list:
if option.lower() in _DIRECTIVES:
directives.append(option.lower())
else:
option_object = OpensshCertificateOption.from_string(option)
if option_object.type == "critical":
critical_options.append(option_object)
else:
extensions.append(option_object)
return critical_options, list(set(extensions + apply_directives(directives)))

View File

@@ -0,0 +1,769 @@
# Copyright (c) 2021, Andrew Pantuso (@ajpantuso) <ajpantuso@gmail.com>
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
import os
import typing as t
from base64 import b64decode, b64encode
from getpass import getuser
from socket import gethostname
try:
from cryptography import __version__ as CRYPTOGRAPHY_VERSION
from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import dsa, ec, padding, rsa
from cryptography.hazmat.primitives.asymmetric.ed25519 import (
Ed25519PrivateKey,
Ed25519PublicKey,
)
HAS_OPENSSH_SUPPORT = True
_ALGORITHM_PARAMETERS = {
"rsa": {
"default_size": 2048,
"valid_sizes": range(1024, 16384),
"signer_params": {
"padding": padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH,
),
"algorithm": hashes.SHA256(),
},
},
"dsa": {
"default_size": 1024,
"valid_sizes": [1024],
"signer_params": {
"algorithm": hashes.SHA256(),
},
},
"ed25519": {
"default_size": 256,
"valid_sizes": [256],
"signer_params": {},
},
"ecdsa": {
"default_size": 256,
"valid_sizes": [256, 384, 521],
"signer_params": {
"signature_algorithm": ec.ECDSA(hashes.SHA256()),
},
"curves": {
256: ec.SECP256R1(),
384: ec.SECP384R1(),
521: ec.SECP521R1(),
},
},
}
except ImportError:
HAS_OPENSSH_SUPPORT = False
CRYPTOGRAPHY_VERSION = "0.0"
_ALGORITHM_PARAMETERS = {}
if t.TYPE_CHECKING:
KeyFormat = t.Literal["SSH", "PKCS8", "PKCS1"]
KeySerializationFormat = t.Literal["PEM", "DER", "SSH"]
KeyType = t.Literal["rsa", "dsa", "ed25519", "ecdsa"]
PrivateKeyTypes = t.Union[
rsa.RSAPrivateKey,
dsa.DSAPrivateKey,
ec.EllipticCurvePrivateKey,
Ed25519PrivateKey,
]
PublicKeyTypes = t.Union[
rsa.RSAPublicKey, dsa.DSAPublicKey, ec.EllipticCurvePublicKey, Ed25519PublicKey
]
from cryptography.hazmat.primitives.asymmetric.types import (
PublicKeyTypes as AllPublicKeyTypes,
)
_TEXT_ENCODING = "UTF-8"
class OpenSSHError(Exception):
pass
class InvalidAlgorithmError(OpenSSHError):
pass
class InvalidCommentError(OpenSSHError):
pass
class InvalidDataError(OpenSSHError):
pass
class InvalidPrivateKeyFileError(OpenSSHError):
pass
class InvalidPublicKeyFileError(OpenSSHError):
pass
class InvalidKeyFormatError(OpenSSHError):
pass
class InvalidKeySizeError(OpenSSHError):
pass
class InvalidKeyTypeError(OpenSSHError):
pass
class InvalidPassphraseError(OpenSSHError):
pass
class InvalidSignatureError(OpenSSHError):
pass
_AsymmetricKeypair = t.TypeVar("_AsymmetricKeypair", bound="AsymmetricKeypair")
class AsymmetricKeypair:
"""Container for newly generated asymmetric key pairs or those loaded from existing files"""
@classmethod
def generate(
cls: t.Type[_AsymmetricKeypair],
keytype: KeyType = "rsa",
size: int | None = None,
passphrase: bytes | None = None,
) -> _AsymmetricKeypair:
"""Returns an Asymmetric_Keypair object generated with the supplied parameters
or defaults to an unencrypted RSA-2048 key
:keytype: One of rsa, dsa, ecdsa, ed25519
:size: The key length for newly generated keys
:passphrase: Secret of type Bytes used to encrypt the private key being generated
"""
if keytype not in _ALGORITHM_PARAMETERS:
raise InvalidKeyTypeError(
f"{keytype} is not a valid keytype. Valid keytypes are {', '.join(_ALGORITHM_PARAMETERS)}"
)
if not size:
size = _ALGORITHM_PARAMETERS[keytype]["default_size"] # type: ignore
else:
if size not in _ALGORITHM_PARAMETERS[keytype]["valid_sizes"]: # type: ignore
raise InvalidKeySizeError(
f"{size} is not a valid key size for {keytype} keys"
)
size = t.cast(int, size)
privatekey: PrivateKeyTypes
if passphrase:
encryption_algorithm = get_encryption_algorithm(passphrase)
else:
encryption_algorithm = serialization.NoEncryption()
if keytype == "rsa":
privatekey = rsa.generate_private_key(
# Public exponent should always be 65537 to prevent issues
# if improper padding is used during signing
public_exponent=65537,
key_size=size,
)
elif keytype == "dsa":
privatekey = dsa.generate_private_key(
key_size=size,
)
elif keytype == "ed25519":
privatekey = Ed25519PrivateKey.generate()
elif keytype == "ecdsa":
privatekey = ec.generate_private_key(
_ALGORITHM_PARAMETERS["ecdsa"]["curves"][size], # type: ignore
)
publickey = privatekey.public_key()
return cls(
keytype=keytype,
size=size,
privatekey=privatekey,
publickey=publickey,
encryption_algorithm=encryption_algorithm,
)
@classmethod
def load(
cls: t.Type[_AsymmetricKeypair],
path: str | os.PathLike,
passphrase: bytes | None = None,
private_key_format: KeySerializationFormat = "PEM",
public_key_format: KeySerializationFormat = "PEM",
no_public_key: bool = False,
) -> _AsymmetricKeypair:
"""Returns an Asymmetric_Keypair object loaded from the supplied file path
:path: A path to an existing private key to be loaded
:passphrase: Secret of type bytes used to decrypt the private key being loaded
:private_key_format: Format of private key to be loaded
:public_key_format: Format of public key to be loaded
:no_public_key: Set 'True' to only load a private key and automatically populate the matching public key
"""
if passphrase:
encryption_algorithm = get_encryption_algorithm(passphrase)
else:
encryption_algorithm = serialization.NoEncryption()
privatekey = load_privatekey(path, passphrase, private_key_format)
if no_public_key:
publickey = privatekey.public_key()
else:
# TODO: BUG: load_publickey() can return unsupported key types
# (Also we should check whether the public key fits the private key...)
publickey = load_publickey(path + ".pub", public_key_format) # type: ignore
# Ed25519 keys are always of size 256 and do not have a key_size attribute
if isinstance(privatekey, Ed25519PrivateKey):
size: int = _ALGORITHM_PARAMETERS["ed25519"]["default_size"] # type: ignore
else:
size = privatekey.key_size
keytype: KeyType
if isinstance(privatekey, rsa.RSAPrivateKey):
keytype = "rsa"
elif isinstance(privatekey, dsa.DSAPrivateKey):
keytype = "dsa"
elif isinstance(privatekey, ec.EllipticCurvePrivateKey):
keytype = "ecdsa"
elif isinstance(privatekey, Ed25519PrivateKey):
keytype = "ed25519"
else:
raise InvalidKeyTypeError(f"Key type '{type(privatekey)}' is not supported")
return cls(
keytype=keytype,
size=size,
privatekey=privatekey,
publickey=publickey,
encryption_algorithm=encryption_algorithm,
)
def __init__(
self,
keytype: KeyType,
size: int,
privatekey: PrivateKeyTypes,
publickey: PublicKeyTypes,
encryption_algorithm: serialization.KeySerializationEncryption,
) -> None:
"""
:keytype: One of rsa, dsa, ecdsa, ed25519
:size: The key length for the private key of this key pair
:privatekey: Private key object of this key pair
:publickey: Public key object of this key pair
:encryption_algorithm: Hashed secret used to encrypt the private key of this key pair
"""
self.__size = size
self.__keytype = keytype
self.__privatekey = privatekey
self.__publickey = publickey
self.__encryption_algorithm = encryption_algorithm
try:
self.verify(self.sign(b"message"), b"message")
except InvalidSignatureError:
raise InvalidPublicKeyFileError(
"The private key and public key of this keypair do not match"
)
def __eq__(self, other: object) -> bool:
if not isinstance(other, AsymmetricKeypair):
return NotImplemented
return compare_publickeys(
self.public_key, other.public_key
) and compare_encryption_algorithms(
self.encryption_algorithm, other.encryption_algorithm
)
def __ne__(self, other: object) -> bool:
return not self == other
@property
def private_key(self) -> PrivateKeyTypes:
"""Returns the private key of this key pair"""
return self.__privatekey
@property
def public_key(self) -> PublicKeyTypes:
"""Returns the public key of this key pair"""
return self.__publickey
@property
def size(self) -> int:
"""Returns the size of the private key of this key pair"""
return self.__size
@property
def key_type(self) -> KeyType:
"""Returns the key type of this key pair"""
return self.__keytype
@property
def encryption_algorithm(self) -> serialization.KeySerializationEncryption:
"""Returns the key encryption algorithm of this key pair"""
return self.__encryption_algorithm
def sign(self, data: bytes) -> bytes:
"""Returns signature of data signed with the private key of this key pair
:data: byteslike data to sign
"""
try:
return self.__privatekey.sign(
data, **_ALGORITHM_PARAMETERS[self.__keytype]["signer_params"] # type: ignore
)
except TypeError as e:
raise InvalidDataError(e)
def verify(self, signature: bytes, data: bytes) -> None:
"""Verifies that the signature associated with the provided data was signed
by the private key of this key pair.
:signature: signature to verify
:data: byteslike data signed by the provided signature
"""
try:
self.__publickey.verify(
signature,
data,
**_ALGORITHM_PARAMETERS[self.__keytype]["signer_params"], # type: ignore
)
except InvalidSignature:
raise InvalidSignatureError
def update_passphrase(self, passphrase: bytes | None = None) -> None:
"""Updates the encryption algorithm of this key pair
:passphrase: Byte secret used to encrypt this key pair
"""
if passphrase:
self.__encryption_algorithm = get_encryption_algorithm(passphrase)
else:
self.__encryption_algorithm = serialization.NoEncryption()
_OpensshKeypair = t.TypeVar("_OpensshKeypair", bound="OpensshKeypair")
class OpensshKeypair:
"""Container for OpenSSH encoded asymmetric key pairs"""
@classmethod
def generate(
cls: t.Type[_OpensshKeypair],
keytype: KeyType = "rsa",
size: int | None = None,
passphrase: bytes | None = None,
comment: str | None = None,
) -> _OpensshKeypair:
"""Returns an Openssh_Keypair object generated using the supplied parameters or defaults to a RSA-2048 key
:keytype: One of rsa, dsa, ecdsa, ed25519
:size: The key length for newly generated keys
:passphrase: Secret of type Bytes used to encrypt the newly generated private key
:comment: Comment for a newly generated OpenSSH public key
"""
if comment is None:
comment = f"{getuser()}@{gethostname()}"
asym_keypair = AsymmetricKeypair.generate(keytype, size, passphrase)
openssh_privatekey = cls.encode_openssh_privatekey(asym_keypair, "SSH")
openssh_publickey = cls.encode_openssh_publickey(asym_keypair, comment)
fingerprint = calculate_fingerprint(openssh_publickey)
return cls(
asym_keypair=asym_keypair,
openssh_privatekey=openssh_privatekey,
openssh_publickey=openssh_publickey,
fingerprint=fingerprint,
comment=comment,
)
@classmethod
def load(
cls: t.Type[_OpensshKeypair],
path: str | os.PathLike,
passphrase: bytes | None = None,
no_public_key: bool = False,
) -> _OpensshKeypair:
"""Returns an Openssh_Keypair object loaded from the supplied file path
:path: A path to an existing private key to be loaded
:passphrase: Secret used to decrypt the private key being loaded
:no_public_key: Set 'True' to only load a private key and automatically populate the matching public key
"""
if no_public_key:
comment = ""
else:
comment = extract_comment(str(path) + ".pub")
asym_keypair = AsymmetricKeypair.load(
path, passphrase, "SSH", "SSH", no_public_key
)
openssh_privatekey = cls.encode_openssh_privatekey(asym_keypair, "SSH")
openssh_publickey = cls.encode_openssh_publickey(asym_keypair, comment)
fingerprint = calculate_fingerprint(openssh_publickey)
return cls(
asym_keypair=asym_keypair,
openssh_privatekey=openssh_privatekey,
openssh_publickey=openssh_publickey,
fingerprint=fingerprint,
comment=comment,
)
@staticmethod
def encode_openssh_privatekey(
asym_keypair: AsymmetricKeypair, key_format: KeyFormat
) -> bytes:
"""Returns an OpenSSH encoded private key for a given keypair
:asym_keypair: Asymmetric_Keypair from the private key is extracted
:key_format: Format of the encoded private key.
"""
if key_format == "SSH":
privatekey_format = serialization.PrivateFormat.OpenSSH
elif key_format == "PKCS8":
privatekey_format = serialization.PrivateFormat.PKCS8
elif key_format == "PKCS1":
if asym_keypair.key_type == "ed25519":
raise InvalidKeyFormatError(
"ed25519 keys cannot be represented in PKCS1 format"
)
privatekey_format = serialization.PrivateFormat.TraditionalOpenSSL
else:
raise InvalidKeyFormatError(
"The accepted private key formats are SSH, PKCS8, and PKCS1"
)
encoded_privatekey = asym_keypair.private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=privatekey_format,
encryption_algorithm=asym_keypair.encryption_algorithm,
)
return encoded_privatekey
@staticmethod
def encode_openssh_publickey(
asym_keypair: AsymmetricKeypair, comment: str
) -> bytes:
"""Returns an OpenSSH encoded public key for a given keypair
:asym_keypair: Asymmetric_Keypair from the public key is extracted
:comment: Comment to apply to the end of the returned OpenSSH encoded public key
"""
encoded_publickey = asym_keypair.public_key.public_bytes(
encoding=serialization.Encoding.OpenSSH,
format=serialization.PublicFormat.OpenSSH,
)
validate_comment(comment)
encoded_publickey += (
(b" " + comment.encode(encoding=_TEXT_ENCODING)) if comment else b""
)
return encoded_publickey
def __init__(
self,
asym_keypair: AsymmetricKeypair,
openssh_privatekey: bytes,
openssh_publickey: bytes,
fingerprint: str,
comment: str | None,
) -> None:
"""
:asym_keypair: An Asymmetric_Keypair object from which the OpenSSH encoded keypair is derived
:openssh_privatekey: An OpenSSH encoded private key
:openssh_privatekey: An OpenSSH encoded public key
:fingerprint: The fingerprint of the OpenSSH encoded public key of this keypair
:comment: Comment applied to the OpenSSH public key of this keypair
"""
self.__asym_keypair = asym_keypair
self.__openssh_privatekey = openssh_privatekey
self.__openssh_publickey = openssh_publickey
self.__fingerprint = fingerprint
self.__comment = comment
def __eq__(self, other: object) -> bool:
if not isinstance(other, OpensshKeypair):
return NotImplemented
return (
self.asymmetric_keypair == other.asymmetric_keypair
and self.comment == other.comment
)
@property
def asymmetric_keypair(self) -> AsymmetricKeypair:
"""Returns the underlying asymmetric key pair of this OpenSSH encoded key pair"""
return self.__asym_keypair
@property
def private_key(self) -> bytes:
"""Returns the OpenSSH formatted private key of this key pair"""
return self.__openssh_privatekey
@property
def public_key(self) -> bytes:
"""Returns the OpenSSH formatted public key of this key pair"""
return self.__openssh_publickey
@property
def size(self) -> int:
"""Returns the size of the private key of this key pair"""
return self.__asym_keypair.size
@property
def key_type(self) -> KeyType:
"""Returns the key type of this key pair"""
return self.__asym_keypair.key_type
@property
def fingerprint(self) -> str:
"""Returns the fingerprint (SHA256 Hash) of the public key of this key pair"""
return self.__fingerprint
@property
def comment(self) -> str | None:
"""Returns the comment applied to the OpenSSH formatted public key of this key pair"""
return self.__comment
@comment.setter
def comment(self, comment: str) -> bytes:
"""Updates the comment applied to the OpenSSH formatted public key of this key pair
:comment: Text to update the OpenSSH public key comment
"""
validate_comment(comment)
self.__comment = comment
encoded_comment = (
f" {self.__comment}".encode(encoding=_TEXT_ENCODING)
if self.__comment
else b""
)
self.__openssh_publickey = (
b" ".join(self.__openssh_publickey.split(b" ", 2)[:2]) + encoded_comment
)
return self.__openssh_publickey
def update_passphrase(self, passphrase: bytes | None) -> None:
"""Updates the passphrase used to encrypt the private key of this keypair
:passphrase: Text secret used for encryption
"""
self.__asym_keypair.update_passphrase(passphrase)
self.__openssh_privatekey = OpensshKeypair.encode_openssh_privatekey(
self.__asym_keypair, "SSH"
)
def load_privatekey(
path: str | os.PathLike,
passphrase: bytes | None,
key_format: KeySerializationFormat,
) -> PrivateKeyTypes:
privatekey_loaders = {
"PEM": serialization.load_pem_private_key,
"DER": serialization.load_der_private_key,
"SSH": serialization.load_ssh_private_key,
}
try:
privatekey_loader = privatekey_loaders[key_format]
except KeyError:
raise InvalidKeyFormatError(
f"{key_format} is not a valid key format ({','.join(privatekey_loaders)})"
)
if not os.path.exists(path):
raise InvalidPrivateKeyFileError(f"No file was found at {path}")
try:
with open(path, "rb") as f:
content = f.read()
privatekey = privatekey_loader( # type: ignore
data=content,
password=passphrase,
)
except ValueError as exc:
# Revert to PEM if key could not be loaded in SSH format
if key_format == "SSH":
try:
privatekey = privatekey_loaders["PEM"]( # type: ignore
data=content,
password=passphrase,
)
except ValueError as e:
raise InvalidPrivateKeyFileError(e)
except TypeError as e:
raise InvalidPassphraseError(e)
except UnsupportedAlgorithm as e:
raise InvalidAlgorithmError(e)
else:
raise InvalidPrivateKeyFileError(exc)
except TypeError as e:
raise InvalidPassphraseError(e)
except UnsupportedAlgorithm as e:
raise InvalidAlgorithmError(e)
return privatekey
def load_publickey(
path: str | os.PathLike, key_format: KeySerializationFormat
) -> AllPublicKeyTypes:
publickey_loaders = {
"PEM": serialization.load_pem_public_key,
"DER": serialization.load_der_public_key,
"SSH": serialization.load_ssh_public_key,
}
try:
publickey_loader = publickey_loaders[key_format]
except KeyError:
raise InvalidKeyFormatError(
f"{key_format} is not a valid key format ({','.join(publickey_loaders)})"
)
if not os.path.exists(path):
raise InvalidPublicKeyFileError(f"No file was found at {path}")
try:
with open(path, "rb") as f:
content = f.read()
publickey = publickey_loader(
data=content,
)
except ValueError as e:
raise InvalidPublicKeyFileError(e)
except UnsupportedAlgorithm as e:
raise InvalidAlgorithmError(e)
return publickey
def compare_publickeys(pk1: PublicKeyTypes, pk2: PublicKeyTypes) -> bool:
a = isinstance(pk1, Ed25519PublicKey)
b = isinstance(pk2, Ed25519PublicKey)
if a or b:
if not a or not b:
return False
a_bytes = pk1.public_bytes(
serialization.Encoding.Raw, serialization.PublicFormat.Raw
)
b_bytes = pk2.public_bytes(
serialization.Encoding.Raw, serialization.PublicFormat.Raw
)
return a_bytes == b_bytes
else:
return pk1.public_numbers() == pk2.public_numbers() # type: ignore
def compare_encryption_algorithms(
ea1: serialization.KeySerializationEncryption,
ea2: serialization.KeySerializationEncryption,
) -> bool:
if isinstance(ea1, serialization.NoEncryption) and isinstance(
ea2, serialization.NoEncryption
):
return True
elif isinstance(ea1, serialization.BestAvailableEncryption) and isinstance(
ea2, serialization.BestAvailableEncryption
):
return ea1.password == ea2.password
else:
return False
def get_encryption_algorithm(
passphrase: bytes,
) -> serialization.KeySerializationEncryption:
try:
return serialization.BestAvailableEncryption(passphrase)
except ValueError as e:
raise InvalidPassphraseError(e)
def validate_comment(comment: str) -> None:
if not hasattr(comment, "encode"):
raise InvalidCommentError(f"{comment} cannot be encoded to text")
def extract_comment(path: str | os.PathLike) -> str:
if not os.path.exists(path):
raise InvalidPublicKeyFileError(f"No file was found at {path}")
try:
with open(path, "rb") as f:
fields = f.read().split(b" ", 2)
if len(fields) == 3:
comment = fields[2].decode(_TEXT_ENCODING)
else:
comment = ""
except (IOError, OSError) as e:
raise InvalidPublicKeyFileError(e)
return comment
def calculate_fingerprint(openssh_publickey: bytes) -> str:
digest = hashes.Hash(hashes.SHA256())
decoded_pubkey = b64decode(openssh_publickey.split(b" ")[1])
digest.update(decoded_pubkey)
value = b64encode(digest.finalize()).decode(encoding=_TEXT_ENCODING).rstrip("=")
return f"SHA256:{value}"

View File

@@ -0,0 +1,349 @@
# Copyright (c) 2020, Doug Stanley <doug+ansible@technologixllc.com>
# Copyright (c) 2021, Andrew Pantuso (@ajpantuso) <ajpantuso@gmail.com>
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
import os
import re
import typing as t
from contextlib import contextmanager
from struct import Struct
# Protocol References
# -------------------
# https://datatracker.ietf.org/doc/html/rfc4251
# https://datatracker.ietf.org/doc/html/rfc4253
# https://datatracker.ietf.org/doc/html/rfc5656
# https://datatracker.ietf.org/doc/html/rfc8032
#
# Inspired by:
# ------------
# https://github.com/pyca/cryptography/blob/main/src/cryptography/hazmat/primitives/serialization/ssh.py
# https://github.com/paramiko/paramiko/blob/master/paramiko/message.py
# 0 (False) or 1 (True) encoded as a single byte
_BOOLEAN = Struct(b"?")
# Unsigned 8-bit integer in network-byte-order
_UBYTE = Struct(b"!B")
_UBYTE_MAX = 0xFF
# Unsigned 32-bit integer in network-byte-order
_UINT32 = Struct(b"!I")
# Unsigned 32-bit little endian integer
_UINT32_LE = Struct(b"<I")
_UINT32_MAX = 0xFFFFFFFF
# Unsigned 64-bit integer in network-byte-order
_UINT64 = Struct(b"!Q")
_UINT64_MAX = 0xFFFFFFFFFFFFFFFF
_T = t.TypeVar("_T")
def any_in(sequence: t.Iterable[_T], *elements: _T) -> bool:
return any(e in sequence for e in elements)
def file_mode(path: str | os.PathLike) -> int:
if not os.path.exists(path):
return 0o000
return os.stat(path).st_mode & 0o777
def parse_openssh_version(version_string: str) -> str | None:
"""Parse the version output of ssh -V and return version numbers that can be compared"""
parsed_result = re.match(
r"^.*openssh_(?P<version>[0-9.]+)(p?[0-9]+)[^0-9]*.*$", version_string.lower()
)
if parsed_result is not None:
version = parsed_result.group("version").strip()
else:
version = None
return version
@contextmanager
def secure_open(path: str | os.PathLike, mode: int) -> t.Iterator[int]:
fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, mode)
try:
yield fd
finally:
os.close(fd)
def secure_write(path: str | os.PathLike, mode: int, content: bytes) -> None:
with secure_open(path, mode) as fd:
os.write(fd, content)
# See https://datatracker.ietf.org/doc/html/rfc4251#section-5 for SSH data types
class OpensshParser:
"""Parser for OpenSSH encoded objects"""
BOOLEAN_OFFSET = 1
UINT32_OFFSET = 4
UINT64_OFFSET = 8
def __init__(self, data: bytes | bytearray) -> None:
if not isinstance(data, (bytes, bytearray)):
raise TypeError(f"Data must be bytes-like not {type(data)}")
self._data = memoryview(data)
self._pos = 0
def boolean(self) -> bool:
next_pos = self._check_position(self.BOOLEAN_OFFSET)
value = _BOOLEAN.unpack(self._data[self._pos : next_pos])[0]
self._pos = next_pos
return value
def uint32(self) -> int:
next_pos = self._check_position(self.UINT32_OFFSET)
value = _UINT32.unpack(self._data[self._pos : next_pos])[0]
self._pos = next_pos
return value
def uint64(self) -> int:
next_pos = self._check_position(self.UINT64_OFFSET)
value = _UINT64.unpack(self._data[self._pos : next_pos])[0]
self._pos = next_pos
return value
def string(self) -> bytes:
length = self.uint32()
next_pos = self._check_position(length)
value = self._data[self._pos : next_pos]
self._pos = next_pos
# Cast to bytes is required as a memoryview slice is itself a memoryview
return bytes(value)
def mpint(self) -> int:
return self._big_int(self.string(), "big", signed=True)
def name_list(self) -> list[str]:
raw_string = self.string()
return raw_string.decode("ASCII").split(",")
# Convenience function, but not an official data type from SSH
def string_list(self) -> list[bytes]:
result = []
raw_string = self.string()
if raw_string:
parser = OpensshParser(raw_string)
while parser.remaining_bytes():
result.append(parser.string())
return result
# Convenience function, but not an official data type from SSH
def option_list(self) -> list[tuple[bytes, bytes]]:
result = []
raw_string = self.string()
if raw_string:
parser = OpensshParser(raw_string)
while parser.remaining_bytes():
name = parser.string()
data = parser.string()
if data:
# data is doubly-encoded
data = OpensshParser(data).string()
result.append((name, data))
return result
def seek(self, offset: int) -> int:
self._pos = self._check_position(offset)
return self._pos
def remaining_bytes(self) -> int:
return len(self._data) - self._pos
def _check_position(self, offset: int) -> int:
if self._pos + offset > len(self._data):
raise ValueError(f"Insufficient data remaining at position: {self._pos}")
elif self._pos + offset < 0:
raise ValueError("Position cannot be less than zero.")
else:
return self._pos + offset
@classmethod
def signature_data(cls, signature_string: bytes) -> dict[str, bytes | int]:
signature_data: dict[str, bytes | int] = {}
parser = cls(signature_string)
signature_type = parser.string()
signature_blob = parser.string()
blob_parser = cls(signature_blob)
if signature_type in (b"ssh-rsa", b"rsa-sha2-256", b"rsa-sha2-512"):
# https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
# https://datatracker.ietf.org/doc/html/rfc8332#section-3
signature_data["s"] = cls._big_int(signature_blob, "big")
elif signature_type == b"ssh-dss":
# https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
signature_data["r"] = cls._big_int(signature_blob[:20], "big")
signature_data["s"] = cls._big_int(signature_blob[20:], "big")
elif signature_type in (
b"ecdsa-sha2-nistp256",
b"ecdsa-sha2-nistp384",
b"ecdsa-sha2-nistp521",
):
# https://datatracker.ietf.org/doc/html/rfc5656#section-3.1.2
signature_data["r"] = blob_parser.mpint()
signature_data["s"] = blob_parser.mpint()
elif signature_type == b"ssh-ed25519":
# https://datatracker.ietf.org/doc/html/rfc8032#section-5.1.2
signature_data["R"] = cls._big_int(signature_blob[:32], "little")
signature_data["S"] = cls._big_int(signature_blob[32:], "little")
else:
raise ValueError(f"{signature_type!r} is not a valid signature type")
signature_data["signature_type"] = signature_type
return signature_data
@classmethod
def _big_int(
cls,
raw_string: bytes,
byte_order: t.Literal["big", "little"],
signed: bool = False,
) -> int:
if byte_order not in ("big", "little"):
raise ValueError(
f"Byte_order must be one of (big, little) not {byte_order}"
)
return int.from_bytes(raw_string, byte_order, signed=signed)
class _OpensshWriter:
"""Writes SSH encoded values to a bytes-like buffer
.. warning::
This class is a private API and must not be exported outside of the openssh module_utils.
It is not to be used to construct Openssh objects, but rather as a utility to assist
in validating parsed material.
"""
def __init__(self, buffer: bytearray | None = None):
if buffer is not None:
if not isinstance(buffer, bytearray):
raise TypeError(f"Buffer must be a bytearray, not {type(buffer)}")
else:
buffer = bytearray()
self._buff: bytearray = buffer
def boolean(self, value: bool) -> t.Self:
if not isinstance(value, bool):
raise TypeError(f"Value must be of type bool not {type(value)}")
self._buff.extend(_BOOLEAN.pack(value))
return self
def uint32(self, value: int) -> t.Self:
if not isinstance(value, int):
raise TypeError(f"Value must be of type int not {type(value)}")
if value < 0 or value > _UINT32_MAX:
raise ValueError(
f"Value must be a positive integer less than {_UINT32_MAX}"
)
self._buff.extend(_UINT32.pack(value))
return self
def uint64(self, value: int) -> t.Self:
if not isinstance(value, int):
raise TypeError(f"Value must be of type int not {type(value)}")
if value < 0 or value > _UINT64_MAX:
raise ValueError(
f"Value must be a positive integer less than {_UINT64_MAX}"
)
self._buff.extend(_UINT64.pack(value))
return self
def string(self, value: bytes | bytearray) -> t.Self:
if not isinstance(value, (bytes, bytearray)):
raise TypeError(f"Value must be bytes-like not {type(value)}")
self.uint32(len(value))
self._buff.extend(value)
return self
def mpint(self, value: int) -> t.Self:
if not isinstance(value, int):
raise TypeError(f"Value must be of type int not {type(value)}")
self.string(self._int_to_mpint(value))
return self
def name_list(self, value: list[str]) -> t.Self:
if not isinstance(value, list):
raise TypeError(f"Value must be a list of byte strings not {type(value)}")
try:
self.string(",".join(value).encode("ASCII"))
except UnicodeEncodeError as e:
raise ValueError(f"Name-list's must consist of US-ASCII characters: {e}")
return self
def string_list(self, value: list[bytes]) -> t.Self:
if not isinstance(value, list):
raise TypeError(f"Value must be a list of byte string not {type(value)}")
writer = _OpensshWriter()
for s in value:
writer.string(s)
self.string(writer.bytes())
return self
def option_list(self, value: list[tuple[bytes, bytes]]) -> t.Self:
if not isinstance(value, list) or (value and not isinstance(value[0], tuple)):
raise TypeError("Value must be a list of tuples")
writer = _OpensshWriter()
for name, data in value:
writer.string(name)
# SSH option data is encoded twice though this behavior is not documented
writer.string(_OpensshWriter().string(data).bytes() if data else bytes())
self.string(writer.bytes())
return self
@staticmethod
def _int_to_mpint(num: int) -> bytes:
byte_length = (num.bit_length() + 7) // 8
try:
return num.to_bytes(byte_length, "big", signed=True)
# Handles values which require \x00 or \xFF to pad sign-bit
except OverflowError:
return num.to_bytes(byte_length + 1, "big", signed=True)
def bytes(self) -> bytes:
return bytes(self._buff)