mirror of
https://github.com/ansible-collections/community.crypto.git
synced 2026-05-06 13:22:58 +00:00
Code refactoring (#889)
* Add __all__ to all module and plugin utils. * Convert quite a few positional args to keyword args. * Avoid Python 3.8+ syntax.
This commit is contained in:
@@ -99,7 +99,7 @@ def _restore_all_on_failure(
|
||||
|
||||
|
||||
class OpensshModule(metaclass=abc.ABCMeta):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
def __init__(self, *, module: AnsibleModule) -> None:
|
||||
self.module = module
|
||||
|
||||
self.changed: bool = False
|
||||
@@ -210,6 +210,7 @@ class KeygenCommand:
|
||||
|
||||
def generate_certificate(
|
||||
self,
|
||||
*,
|
||||
certificate_path: str,
|
||||
identifier: str,
|
||||
options: list[str] | None,
|
||||
@@ -247,7 +248,13 @@ class KeygenCommand:
|
||||
return self._run_command(args, **kwargs)
|
||||
|
||||
def generate_keypair(
|
||||
self, private_key_path: str, size: int, type: str, comment: str | None, **kwargs
|
||||
self,
|
||||
*,
|
||||
private_key_path: str,
|
||||
size: int,
|
||||
type: str,
|
||||
comment: str | None,
|
||||
**kwargs,
|
||||
) -> tuple[int, str, str]:
|
||||
args = [
|
||||
self._bin_path,
|
||||
@@ -270,26 +277,29 @@ class KeygenCommand:
|
||||
return self._run_command(args, data=data, **kwargs)
|
||||
|
||||
def get_certificate_info(
|
||||
self, certificate_path: str, **kwargs
|
||||
self, *, certificate_path: str, **kwargs
|
||||
) -> tuple[int, str, str]:
|
||||
return self._run_command(
|
||||
[self._bin_path, "-L", "-f", certificate_path], **kwargs
|
||||
)
|
||||
|
||||
def get_matching_public_key(
|
||||
self, private_key_path: str, **kwargs
|
||||
self, *, private_key_path: str, **kwargs
|
||||
) -> tuple[int, str, str]:
|
||||
return self._run_command(
|
||||
[self._bin_path, "-P", "", "-y", "-f", private_key_path], **kwargs
|
||||
)
|
||||
|
||||
def get_private_key(self, private_key_path: str, **kwargs) -> tuple[int, str, str]:
|
||||
def get_private_key(
|
||||
self, *, private_key_path: str, **kwargs
|
||||
) -> tuple[int, str, str]:
|
||||
return self._run_command(
|
||||
[self._bin_path, "-l", "-f", private_key_path], **kwargs
|
||||
)
|
||||
|
||||
def update_comment(
|
||||
self,
|
||||
*,
|
||||
private_key_path: str,
|
||||
comment: str,
|
||||
force_new_format: bool = True,
|
||||
@@ -317,7 +327,7 @@ _PrivateKey = t.TypeVar("_PrivateKey", bound="PrivateKey")
|
||||
|
||||
class PrivateKey:
|
||||
def __init__(
|
||||
self, size: int, key_type: str, fingerprint: str, format: str = ""
|
||||
self, *, size: int, key_type: str, fingerprint: str, format: str = ""
|
||||
) -> None:
|
||||
self._size = size
|
||||
self._type = key_type
|
||||
@@ -363,7 +373,7 @@ _PublicKey = t.TypeVar("_PublicKey", bound="PublicKey")
|
||||
|
||||
|
||||
class PublicKey:
|
||||
def __init__(self, type_string: str, data: str, comment: str | None) -> None:
|
||||
def __init__(self, *, type_string: str, data: str, comment: str | None) -> None:
|
||||
self._type_string = type_string
|
||||
self._data = data
|
||||
self._comment = comment
|
||||
@@ -441,6 +451,7 @@ class PublicKey:
|
||||
|
||||
|
||||
def parse_private_key_format(
|
||||
*,
|
||||
path: str | os.PathLike,
|
||||
) -> t.Literal["SSH", "PKCS8", "PKCS1", ""]:
|
||||
with open(path, "r") as file:
|
||||
@@ -454,3 +465,14 @@ def parse_private_key_format(
|
||||
return "PKCS1"
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
__all__ = (
|
||||
"restore_on_failure",
|
||||
"safe_atomic_move",
|
||||
"OpensshModule",
|
||||
"KeygenCommand",
|
||||
"PrivateKey",
|
||||
"PublicKey",
|
||||
"parse_private_key_format",
|
||||
)
|
||||
|
||||
@@ -53,8 +53,8 @@ if t.TYPE_CHECKING:
|
||||
|
||||
class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(KeypairBackend, self).__init__(module)
|
||||
def __init__(self, *, module: AnsibleModule) -> None:
|
||||
super(KeypairBackend, self).__init__(module=module)
|
||||
|
||||
self.comment: str | None = self.module.params["comment"]
|
||||
self.private_key_path: str = self.module.params["path"]
|
||||
@@ -296,9 +296,9 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
|
||||
try:
|
||||
secure_write(
|
||||
temp_public_key,
|
||||
existing_permissions or default_permissions,
|
||||
to_bytes(content),
|
||||
path=temp_public_key,
|
||||
mode=existing_permissions or default_permissions,
|
||||
content=to_bytes(content),
|
||||
)
|
||||
except (IOError, OSError) as e:
|
||||
self.module.fail_json(msg=str(e))
|
||||
@@ -357,8 +357,8 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
|
||||
|
||||
class KeypairBackendOpensshBin(KeypairBackend):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(KeypairBackendOpensshBin, self).__init__(module)
|
||||
def __init__(self, *, module: AnsibleModule) -> None:
|
||||
super(KeypairBackendOpensshBin, self).__init__(module=module)
|
||||
|
||||
if self.module.params["private_key_format"] != "auto":
|
||||
self.module.fail_json(
|
||||
@@ -369,12 +369,16 @@ class KeypairBackendOpensshBin(KeypairBackend):
|
||||
|
||||
def _generate_keypair(self, private_key_path: str) -> None:
|
||||
self.ssh_keygen.generate_keypair(
|
||||
private_key_path, self.size, self.type, self.comment, check_rc=True
|
||||
private_key_path=private_key_path,
|
||||
size=self.size,
|
||||
type=self.type,
|
||||
comment=self.comment,
|
||||
check_rc=True,
|
||||
)
|
||||
|
||||
def _get_private_key(self) -> PrivateKey:
|
||||
rc, private_key_content, err = self.ssh_keygen.get_private_key(
|
||||
self.private_key_path, check_rc=False
|
||||
private_key_path=self.private_key_path, check_rc=False
|
||||
)
|
||||
if rc != 0:
|
||||
raise ValueError(err)
|
||||
@@ -382,13 +386,13 @@ class KeypairBackendOpensshBin(KeypairBackend):
|
||||
|
||||
def _get_public_key(self) -> PublicKey | t.Literal[""]:
|
||||
public_key_content = self.ssh_keygen.get_matching_public_key(
|
||||
self.private_key_path, check_rc=True
|
||||
private_key_path=self.private_key_path, check_rc=True
|
||||
)[1]
|
||||
return PublicKey.from_string(public_key_content)
|
||||
|
||||
def _private_key_readable(self) -> bool:
|
||||
rc, stdout, stderr = self.ssh_keygen.get_matching_public_key(
|
||||
self.private_key_path, check_rc=False
|
||||
private_key_path=self.private_key_path, check_rc=False
|
||||
)
|
||||
return not (
|
||||
rc == 255
|
||||
@@ -407,8 +411,8 @@ class KeypairBackendOpensshBin(KeypairBackend):
|
||||
LooseVersion("6.5") <= LooseVersion(ssh_version) < LooseVersion("7.8")
|
||||
)
|
||||
self.ssh_keygen.update_comment(
|
||||
self.private_key_path,
|
||||
self.comment or "",
|
||||
private_key_path=self.private_key_path,
|
||||
comment=self.comment or "",
|
||||
force_new_format=force_new_format,
|
||||
check_rc=True,
|
||||
)
|
||||
@@ -420,8 +424,8 @@ class KeypairBackendOpensshBin(KeypairBackend):
|
||||
|
||||
|
||||
class KeypairBackendCryptography(KeypairBackend):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(KeypairBackendCryptography, self).__init__(module)
|
||||
def __init__(self, *, module: AnsibleModule) -> None:
|
||||
super(KeypairBackendCryptography, self).__init__(module=module)
|
||||
|
||||
if self.type == "rsa1":
|
||||
self.module.fail_json(
|
||||
@@ -469,12 +473,12 @@ class KeypairBackendCryptography(KeypairBackend):
|
||||
)
|
||||
|
||||
encoded_private_key = OpensshKeypair.encode_openssh_privatekey(
|
||||
keypair.asymmetric_keypair, self.private_key_format
|
||||
asym_keypair=keypair.asymmetric_keypair, key_format=self.private_key_format
|
||||
)
|
||||
secure_write(private_key_path, 0o600, encoded_private_key)
|
||||
secure_write(path=private_key_path, mode=0o600, content=encoded_private_key)
|
||||
|
||||
public_key_path = private_key_path + ".pub"
|
||||
secure_write(public_key_path, 0o644, keypair.public_key)
|
||||
secure_write(path=public_key_path, mode=0o644, content=keypair.public_key)
|
||||
|
||||
def _get_private_key(self) -> PrivateKey:
|
||||
keypair = OpensshKeypair.load(
|
||||
@@ -485,7 +489,7 @@ class KeypairBackendCryptography(KeypairBackend):
|
||||
size=keypair.size,
|
||||
key_type=keypair.key_type,
|
||||
fingerprint=keypair.fingerprint,
|
||||
format=parse_private_key_format(self.private_key_path),
|
||||
format=parse_private_key_format(path=self.private_key_path),
|
||||
)
|
||||
|
||||
def _get_public_key(self) -> PublicKey | t.Literal[""]:
|
||||
@@ -550,7 +554,7 @@ class KeypairBackendCryptography(KeypairBackend):
|
||||
|
||||
|
||||
def select_backend(
|
||||
module: AnsibleModule, backend: t.Literal["auto", "opensshbin", "cryptography"]
|
||||
*, module: AnsibleModule, backend: t.Literal["auto", "opensshbin", "cryptography"]
|
||||
) -> KeypairBackend:
|
||||
can_use_cryptography = HAS_OPENSSH_SUPPORT and LooseVersion(
|
||||
CRYPTOGRAPHY_VERSION
|
||||
@@ -573,7 +577,7 @@ def select_backend(
|
||||
if backend == "opensshbin":
|
||||
if not can_use_opensshbin:
|
||||
module.fail_json(msg="Cannot find the OpenSSH binary in the PATH")
|
||||
return KeypairBackendOpensshBin(module)
|
||||
return KeypairBackendOpensshBin(module=module)
|
||||
if backend == "cryptography":
|
||||
if not can_use_cryptography:
|
||||
module.fail_json(
|
||||
@@ -581,5 +585,8 @@ def select_backend(
|
||||
f"cryptography >= {COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION}"
|
||||
)
|
||||
)
|
||||
return KeypairBackendCryptography(module)
|
||||
return KeypairBackendCryptography(module=module)
|
||||
raise ValueError(f"Unsupported value for backend: {backend}")
|
||||
|
||||
|
||||
__all__ = ("KeypairBackend", "select_backend")
|
||||
|
||||
@@ -111,7 +111,7 @@ _EXTENSIONS = (
|
||||
|
||||
class OpensshCertificateTimeParameters:
|
||||
def __init__(
|
||||
self, valid_from: str | bytes | int, valid_to: str | bytes | int
|
||||
self, *, valid_from: str | bytes | int, valid_to: str | bytes | int
|
||||
) -> None:
|
||||
self._valid_from = self.to_datetime(valid_from)
|
||||
self._valid_to = self.to_datetime(valid_to)
|
||||
@@ -149,7 +149,7 @@ class OpensshCertificateTimeParameters:
|
||||
def valid_from(self, date_format: DateFormat) -> str | int: ...
|
||||
|
||||
def valid_from(self, date_format: DateFormat) -> str | int:
|
||||
return self.format_datetime(self._valid_from, date_format)
|
||||
return self.format_datetime(self._valid_from, date_format=date_format)
|
||||
|
||||
@t.overload
|
||||
def valid_to(self, date_format: DateFormatStr) -> str: ...
|
||||
@@ -161,7 +161,7 @@ class OpensshCertificateTimeParameters:
|
||||
def valid_to(self, date_format: DateFormat) -> str | int: ...
|
||||
|
||||
def valid_to(self, date_format: DateFormat) -> str | int:
|
||||
return self.format_datetime(self._valid_to, date_format)
|
||||
return self.format_datetime(self._valid_to, date_format=date_format)
|
||||
|
||||
def within_range(self, valid_at: str | bytes | int | None) -> bool:
|
||||
if valid_at is not None:
|
||||
@@ -171,18 +171,18 @@ class OpensshCertificateTimeParameters:
|
||||
|
||||
@t.overload
|
||||
@staticmethod
|
||||
def format_datetime(dt: datetime, date_format: DateFormatStr) -> str: ...
|
||||
def format_datetime(dt: datetime, *, date_format: DateFormatStr) -> str: ...
|
||||
|
||||
@t.overload
|
||||
@staticmethod
|
||||
def format_datetime(dt: datetime, date_format: DateFormatInt) -> int: ...
|
||||
def format_datetime(dt: datetime, *, date_format: DateFormatInt) -> int: ...
|
||||
|
||||
@t.overload
|
||||
@staticmethod
|
||||
def format_datetime(dt: datetime, date_format: DateFormat) -> str | int: ...
|
||||
def format_datetime(dt: datetime, *, date_format: DateFormat) -> str | int: ...
|
||||
|
||||
@staticmethod
|
||||
def format_datetime(dt: datetime, date_format: DateFormat) -> str | int:
|
||||
def format_datetime(dt: datetime, *, date_format: DateFormat) -> str | int:
|
||||
if date_format in ("human_readable", "openssh"):
|
||||
if dt == _ALWAYS:
|
||||
return "always"
|
||||
@@ -264,6 +264,7 @@ _OpensshCertificateOption = t.TypeVar(
|
||||
class OpensshCertificateOption:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
option_type: t.Literal["critical", "extension"],
|
||||
name: str | bytes,
|
||||
data: str | bytes,
|
||||
@@ -350,6 +351,7 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
nonce: bytes | None = None,
|
||||
serial: int | None = None,
|
||||
cert_type: int | None = None,
|
||||
@@ -409,7 +411,7 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta):
|
||||
|
||||
|
||||
class OpensshRSACertificateInfo(OpensshCertificateInfo):
|
||||
def __init__(self, e: int | None = None, n: int | None = None, **kwargs) -> None:
|
||||
def __init__(self, *, e: int | None = None, n: int | None = None, **kwargs) -> None:
|
||||
super(OpensshRSACertificateInfo, self).__init__(**kwargs)
|
||||
self.type_string = _SSH_TYPE_STRINGS["rsa"] + _CERT_SUFFIX_V01
|
||||
self.e = e
|
||||
@@ -435,6 +437,7 @@ class OpensshRSACertificateInfo(OpensshCertificateInfo):
|
||||
class OpensshDSACertificateInfo(OpensshCertificateInfo):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
p: int | None = None,
|
||||
q: int | None = None,
|
||||
g: int | None = None,
|
||||
@@ -471,7 +474,7 @@ class OpensshDSACertificateInfo(OpensshCertificateInfo):
|
||||
|
||||
class OpensshECDSACertificateInfo(OpensshCertificateInfo):
|
||||
def __init__(
|
||||
self, curve: bytes | None = None, public_key: bytes | None = None, **kwargs
|
||||
self, *, curve: bytes | None = None, public_key: bytes | None = None, **kwargs
|
||||
):
|
||||
super(OpensshECDSACertificateInfo, self).__init__(**kwargs)
|
||||
self._curve = None
|
||||
@@ -515,7 +518,7 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo):
|
||||
|
||||
|
||||
class OpensshED25519CertificateInfo(OpensshCertificateInfo):
|
||||
def __init__(self, pk: bytes | None = None, **kwargs) -> None:
|
||||
def __init__(self, *, pk: bytes | None = None, **kwargs) -> None:
|
||||
super(OpensshED25519CertificateInfo, self).__init__(**kwargs)
|
||||
self.type_string = _SSH_TYPE_STRINGS["ed25519"] + _CERT_SUFFIX_V01
|
||||
self.pk = pk
|
||||
@@ -541,8 +544,7 @@ _OpensshCertificate = t.TypeVar("_OpensshCertificate", bound="OpensshCertificate
|
||||
class OpensshCertificate:
|
||||
"""Encapsulates a formatted OpenSSH certificate including signature and signing key"""
|
||||
|
||||
def __init__(self, cert_info: OpensshCertificateInfo, signature: bytes):
|
||||
|
||||
def __init__(self, *, cert_info: OpensshCertificateInfo, signature: bytes):
|
||||
self._cert_info = cert_info
|
||||
self.signature = signature
|
||||
|
||||
@@ -574,7 +576,7 @@ class OpensshCertificate:
|
||||
f"Invalid certificate format identifier: {format_identifier!r}"
|
||||
)
|
||||
|
||||
parser = OpensshParser(cert)
|
||||
parser = OpensshParser(data=cert)
|
||||
|
||||
if format_identifier != parser.string():
|
||||
raise ValueError("Certificate formats do not match")
|
||||
@@ -649,7 +651,9 @@ class OpensshCertificate:
|
||||
if self._cert_info.critical_options is None:
|
||||
raise ValueError
|
||||
return [
|
||||
OpensshCertificateOption("critical", to_text(n), to_text(d))
|
||||
OpensshCertificateOption(
|
||||
option_type="critical", name=to_text(n), data=to_text(d)
|
||||
)
|
||||
for n, d in self._cert_info.critical_options
|
||||
]
|
||||
|
||||
@@ -658,7 +662,9 @@ class OpensshCertificate:
|
||||
if self._cert_info.extensions is None:
|
||||
raise ValueError
|
||||
return [
|
||||
OpensshCertificateOption("extension", to_text(n), to_text(d))
|
||||
OpensshCertificateOption(
|
||||
option_type="extension", name=to_text(n), data=to_text(d)
|
||||
)
|
||||
for n, d in self._cert_info.extensions
|
||||
]
|
||||
|
||||
@@ -674,7 +680,7 @@ class OpensshCertificate:
|
||||
|
||||
@property
|
||||
def signature_type(self) -> str:
|
||||
signature_data = OpensshParser.signature_data(self.signature)
|
||||
signature_data = OpensshParser.signature_data(signature_string=self.signature)
|
||||
return to_text(signature_data["signature_type"])
|
||||
|
||||
@staticmethod
|
||||
@@ -727,16 +733,20 @@ def apply_directives(directives: t.Iterable[str]) -> list[OpensshCertificateOpti
|
||||
|
||||
directive_to_option = {
|
||||
"no-x11-forwarding": OpensshCertificateOption(
|
||||
"extension", "permit-x11-forwarding", ""
|
||||
option_type="extension", name="permit-x11-forwarding", data=""
|
||||
),
|
||||
"no-agent-forwarding": OpensshCertificateOption(
|
||||
"extension", "permit-agent-forwarding", ""
|
||||
option_type="extension", name="permit-agent-forwarding", data=""
|
||||
),
|
||||
"no-port-forwarding": OpensshCertificateOption(
|
||||
"extension", "permit-port-forwarding", ""
|
||||
option_type="extension", name="permit-port-forwarding", data=""
|
||||
),
|
||||
"no-pty": OpensshCertificateOption(
|
||||
option_type="extension", name="permit-pty", data=""
|
||||
),
|
||||
"no-user-rc": OpensshCertificateOption(
|
||||
option_type="extension", name="permit-user-rc", data=""
|
||||
),
|
||||
"no-pty": OpensshCertificateOption("extension", "permit-pty", ""),
|
||||
"no-user-rc": OpensshCertificateOption("extension", "permit-user-rc", ""),
|
||||
}
|
||||
|
||||
if "clear" in directives:
|
||||
@@ -748,7 +758,10 @@ def apply_directives(directives: t.Iterable[str]) -> list[OpensshCertificateOpti
|
||||
|
||||
|
||||
def default_options() -> list[OpensshCertificateOption]:
|
||||
return [OpensshCertificateOption("extension", name, "") for name in _EXTENSIONS]
|
||||
return [
|
||||
OpensshCertificateOption(option_type="extension", name=name, data="")
|
||||
for name in _EXTENSIONS
|
||||
]
|
||||
|
||||
|
||||
def fingerprint(public_key: bytes) -> bytes:
|
||||
@@ -803,3 +816,22 @@ def parse_option_list(
|
||||
extensions.append(option_object)
|
||||
|
||||
return critical_options, list(set(extensions + apply_directives(directives)))
|
||||
|
||||
|
||||
__all__ = (
|
||||
"OpensshCertificateTimeParameters",
|
||||
"OpensshCertificateOption",
|
||||
"OpensshCertificateInfo",
|
||||
"OpensshRSACertificateInfo",
|
||||
"OpensshDSACertificateInfo",
|
||||
"OpensshECDSACertificateInfo",
|
||||
"OpensshED25519CertificateInfo",
|
||||
"OpensshCertificate",
|
||||
"apply_directives",
|
||||
"default_options",
|
||||
"fingerprint",
|
||||
"get_cert_info_object",
|
||||
"get_option_type",
|
||||
"is_relative_time_string",
|
||||
"parse_option_list",
|
||||
)
|
||||
|
||||
@@ -145,6 +145,7 @@ class AsymmetricKeypair:
|
||||
@classmethod
|
||||
def generate(
|
||||
cls: t.Type[_AsymmetricKeypair],
|
||||
*,
|
||||
keytype: KeyType = "rsa",
|
||||
size: int | None = None,
|
||||
passphrase: bytes | None = None,
|
||||
@@ -208,6 +209,7 @@ class AsymmetricKeypair:
|
||||
@classmethod
|
||||
def load(
|
||||
cls: t.Type[_AsymmetricKeypair],
|
||||
*,
|
||||
path: str | os.PathLike,
|
||||
passphrase: bytes | None = None,
|
||||
private_key_format: KeySerializationFormat = "PEM",
|
||||
@@ -228,13 +230,15 @@ class AsymmetricKeypair:
|
||||
else:
|
||||
encryption_algorithm = serialization.NoEncryption()
|
||||
|
||||
privatekey = load_privatekey(path, passphrase, private_key_format)
|
||||
privatekey = load_privatekey(
|
||||
path=path, passphrase=passphrase, key_format=private_key_format
|
||||
)
|
||||
if no_public_key:
|
||||
publickey = privatekey.public_key()
|
||||
else:
|
||||
# TODO: BUG: load_publickey() can return unsupported key types
|
||||
# (Also we should check whether the public key fits the private key...)
|
||||
publickey = load_publickey(path + ".pub", public_key_format) # type: ignore
|
||||
publickey = load_publickey(path=path + ".pub", key_format=public_key_format) # type: ignore
|
||||
|
||||
# Ed25519 keys are always of size 256 and do not have a key_size attribute
|
||||
if isinstance(privatekey, Ed25519PrivateKey):
|
||||
@@ -264,6 +268,7 @@ class AsymmetricKeypair:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
keytype: KeyType,
|
||||
size: int,
|
||||
privatekey: PrivateKeyTypes,
|
||||
@@ -285,7 +290,7 @@ class AsymmetricKeypair:
|
||||
self.__encryption_algorithm = encryption_algorithm
|
||||
|
||||
try:
|
||||
self.verify(self.sign(b"message"), b"message")
|
||||
self.verify(signature=self.sign(b"message"), data=b"message")
|
||||
except InvalidSignatureError:
|
||||
raise InvalidPublicKeyFileError(
|
||||
"The private key and public key of this keypair do not match"
|
||||
@@ -347,7 +352,7 @@ class AsymmetricKeypair:
|
||||
except TypeError as e:
|
||||
raise InvalidDataError(e)
|
||||
|
||||
def verify(self, signature: bytes, data: bytes) -> None:
|
||||
def verify(self, *, signature: bytes, data: bytes) -> None:
|
||||
"""Verifies that the signature associated with the provided data was signed
|
||||
by the private key of this key pair.
|
||||
|
||||
@@ -384,6 +389,7 @@ class OpensshKeypair:
|
||||
@classmethod
|
||||
def generate(
|
||||
cls: t.Type[_OpensshKeypair],
|
||||
*,
|
||||
keytype: KeyType = "rsa",
|
||||
size: int | None = None,
|
||||
passphrase: bytes | None = None,
|
||||
@@ -400,9 +406,15 @@ class OpensshKeypair:
|
||||
if comment is None:
|
||||
comment = f"{getuser()}@{gethostname()}"
|
||||
|
||||
asym_keypair = AsymmetricKeypair.generate(keytype, size, passphrase)
|
||||
openssh_privatekey = cls.encode_openssh_privatekey(asym_keypair, "SSH")
|
||||
openssh_publickey = cls.encode_openssh_publickey(asym_keypair, comment)
|
||||
asym_keypair = AsymmetricKeypair.generate(
|
||||
keytype=keytype, size=size, passphrase=passphrase
|
||||
)
|
||||
openssh_privatekey = cls.encode_openssh_privatekey(
|
||||
asym_keypair=asym_keypair, key_format="SSH"
|
||||
)
|
||||
openssh_publickey = cls.encode_openssh_publickey(
|
||||
asym_keypair=asym_keypair, comment=comment
|
||||
)
|
||||
fingerprint = calculate_fingerprint(openssh_publickey)
|
||||
|
||||
return cls(
|
||||
@@ -416,6 +428,7 @@ class OpensshKeypair:
|
||||
@classmethod
|
||||
def load(
|
||||
cls: t.Type[_OpensshKeypair],
|
||||
*,
|
||||
path: str | os.PathLike,
|
||||
passphrase: bytes | None = None,
|
||||
no_public_key: bool = False,
|
||||
@@ -433,10 +446,18 @@ class OpensshKeypair:
|
||||
comment = extract_comment(str(path) + ".pub")
|
||||
|
||||
asym_keypair = AsymmetricKeypair.load(
|
||||
path, passphrase, "SSH", "SSH", no_public_key
|
||||
path=path,
|
||||
passphrase=passphrase,
|
||||
private_key_format="SSH",
|
||||
public_key_format="SSH",
|
||||
no_public_key=no_public_key,
|
||||
)
|
||||
openssh_privatekey = cls.encode_openssh_privatekey(
|
||||
asym_keypair=asym_keypair, key_format="SSH"
|
||||
)
|
||||
openssh_publickey = cls.encode_openssh_publickey(
|
||||
asym_keypair=asym_keypair, comment=comment
|
||||
)
|
||||
openssh_privatekey = cls.encode_openssh_privatekey(asym_keypair, "SSH")
|
||||
openssh_publickey = cls.encode_openssh_publickey(asym_keypair, comment)
|
||||
fingerprint = calculate_fingerprint(openssh_publickey)
|
||||
|
||||
return cls(
|
||||
@@ -449,7 +470,7 @@ class OpensshKeypair:
|
||||
|
||||
@staticmethod
|
||||
def encode_openssh_privatekey(
|
||||
asym_keypair: AsymmetricKeypair, key_format: KeyFormat
|
||||
*, asym_keypair: AsymmetricKeypair, key_format: KeyFormat
|
||||
) -> bytes:
|
||||
"""Returns an OpenSSH encoded private key for a given keypair
|
||||
|
||||
@@ -482,7 +503,7 @@ class OpensshKeypair:
|
||||
|
||||
@staticmethod
|
||||
def encode_openssh_publickey(
|
||||
asym_keypair: AsymmetricKeypair, comment: str
|
||||
*, asym_keypair: AsymmetricKeypair, comment: str
|
||||
) -> bytes:
|
||||
"""Returns an OpenSSH encoded public key for a given keypair
|
||||
|
||||
@@ -504,6 +525,7 @@ class OpensshKeypair:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
asym_keypair: AsymmetricKeypair,
|
||||
openssh_privatekey: bytes,
|
||||
openssh_publickey: bytes,
|
||||
@@ -603,11 +625,12 @@ class OpensshKeypair:
|
||||
|
||||
self.__asym_keypair.update_passphrase(passphrase)
|
||||
self.__openssh_privatekey = OpensshKeypair.encode_openssh_privatekey(
|
||||
self.__asym_keypair, "SSH"
|
||||
asym_keypair=self.__asym_keypair, key_format="SSH"
|
||||
)
|
||||
|
||||
|
||||
def load_privatekey(
|
||||
*,
|
||||
path: str | os.PathLike,
|
||||
passphrase: bytes | None,
|
||||
key_format: KeySerializationFormat,
|
||||
@@ -662,7 +685,7 @@ def load_privatekey(
|
||||
|
||||
|
||||
def load_publickey(
|
||||
path: str | os.PathLike, key_format: KeySerializationFormat
|
||||
*, path: str | os.PathLike, key_format: KeySerializationFormat
|
||||
) -> AllPublicKeyTypes:
|
||||
publickey_loaders = {
|
||||
"PEM": serialization.load_pem_public_key,
|
||||
@@ -767,3 +790,30 @@ def calculate_fingerprint(openssh_publickey: bytes) -> str:
|
||||
|
||||
value = b64encode(digest.finalize()).decode(encoding=_TEXT_ENCODING).rstrip("=")
|
||||
return f"SHA256:{value}"
|
||||
|
||||
|
||||
__all__ = (
|
||||
"HAS_OPENSSH_SUPPORT",
|
||||
"CRYPTOGRAPHY_VERSION",
|
||||
"OpenSSHError",
|
||||
"InvalidAlgorithmError",
|
||||
"InvalidCommentError",
|
||||
"InvalidDataError",
|
||||
"InvalidPrivateKeyFileError",
|
||||
"InvalidPublicKeyFileError",
|
||||
"InvalidKeyFormatError",
|
||||
"InvalidKeySizeError",
|
||||
"InvalidKeyTypeError",
|
||||
"InvalidPassphraseError",
|
||||
"InvalidSignatureError",
|
||||
"AsymmetricKeypair",
|
||||
"OpensshKeypair",
|
||||
"load_privatekey",
|
||||
"load_publickey",
|
||||
"compare_publickeys",
|
||||
"compare_encryption_algorithms",
|
||||
"get_encryption_algorithm",
|
||||
"validate_comment",
|
||||
"extract_comment",
|
||||
"calculate_fingerprint",
|
||||
)
|
||||
|
||||
@@ -70,7 +70,7 @@ def parse_openssh_version(version_string: str) -> str | None:
|
||||
|
||||
|
||||
@contextmanager
|
||||
def secure_open(path: str | os.PathLike, mode: int) -> t.Iterator[int]:
|
||||
def secure_open(*, path: str | os.PathLike, mode: int) -> t.Iterator[int]:
|
||||
fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, mode)
|
||||
try:
|
||||
yield fd
|
||||
@@ -78,8 +78,8 @@ def secure_open(path: str | os.PathLike, mode: int) -> t.Iterator[int]:
|
||||
os.close(fd)
|
||||
|
||||
|
||||
def secure_write(path: str | os.PathLike, mode: int, content: bytes) -> None:
|
||||
with secure_open(path, mode) as fd:
|
||||
def secure_write(*, path: str | os.PathLike, mode: int, content: bytes) -> None:
|
||||
with secure_open(path=path, mode=mode) as fd:
|
||||
os.write(fd, content)
|
||||
|
||||
|
||||
@@ -91,7 +91,7 @@ class OpensshParser:
|
||||
UINT32_OFFSET = 4
|
||||
UINT64_OFFSET = 8
|
||||
|
||||
def __init__(self, data: bytes | bytearray) -> None:
|
||||
def __init__(self, *, data: bytes | bytearray) -> None:
|
||||
if not isinstance(data, (bytes, bytearray)):
|
||||
raise TypeError(f"Data must be bytes-like not {type(data)}")
|
||||
|
||||
@@ -142,7 +142,7 @@ class OpensshParser:
|
||||
raw_string = self.string()
|
||||
|
||||
if raw_string:
|
||||
parser = OpensshParser(raw_string)
|
||||
parser = OpensshParser(data=raw_string)
|
||||
while parser.remaining_bytes():
|
||||
result.append(parser.string())
|
||||
|
||||
@@ -154,14 +154,14 @@ class OpensshParser:
|
||||
raw_string = self.string()
|
||||
|
||||
if raw_string:
|
||||
parser = OpensshParser(raw_string)
|
||||
parser = OpensshParser(data=raw_string)
|
||||
|
||||
while parser.remaining_bytes():
|
||||
name = parser.string()
|
||||
data = parser.string()
|
||||
if data:
|
||||
# data is doubly-encoded
|
||||
data = OpensshParser(data).string()
|
||||
data = OpensshParser(data=data).string()
|
||||
result.append((name, data))
|
||||
|
||||
return result
|
||||
@@ -183,14 +183,14 @@ class OpensshParser:
|
||||
return self._pos + offset
|
||||
|
||||
@classmethod
|
||||
def signature_data(cls, signature_string: bytes) -> dict[str, bytes | int]:
|
||||
def signature_data(cls, *, signature_string: bytes) -> dict[str, bytes | int]:
|
||||
signature_data: dict[str, bytes | int] = {}
|
||||
|
||||
parser = cls(signature_string)
|
||||
parser = cls(data=signature_string)
|
||||
signature_type = parser.string()
|
||||
signature_blob = parser.string()
|
||||
|
||||
blob_parser = cls(signature_blob)
|
||||
blob_parser = cls(data=signature_blob)
|
||||
if signature_type in (b"ssh-rsa", b"rsa-sha2-256", b"rsa-sha2-512"):
|
||||
# https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
|
||||
# https://datatracker.ietf.org/doc/html/rfc8332#section-3
|
||||
@@ -242,7 +242,7 @@ class _OpensshWriter:
|
||||
in validating parsed material.
|
||||
"""
|
||||
|
||||
def __init__(self, buffer: bytearray | None = None):
|
||||
def __init__(self, *, buffer: bytearray | None = None):
|
||||
if buffer is not None:
|
||||
if not isinstance(buffer, bytearray):
|
||||
raise TypeError(f"Buffer must be a bytearray, not {type(buffer)}")
|
||||
@@ -347,3 +347,13 @@ class _OpensshWriter:
|
||||
|
||||
def bytes(self) -> bytes:
|
||||
return bytes(self._buff)
|
||||
|
||||
|
||||
__all__ = (
|
||||
"any_in",
|
||||
"file_mode",
|
||||
"parse_openssh_version",
|
||||
"secure_open",
|
||||
"secure_write",
|
||||
"OpensshParser",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user