Code refactoring (#889)

* Add __all__ to all module and plugin utils.

* Convert quite a few positional args to keyword args.

* Avoid Python 3.8+ syntax.
This commit is contained in:
Felix Fontein
2025-05-16 06:55:57 +02:00
committed by GitHub
parent a5a4e022ba
commit 44bcc8cebc
101 changed files with 1510 additions and 748 deletions

View File

@@ -99,7 +99,7 @@ def _restore_all_on_failure(
class OpensshModule(metaclass=abc.ABCMeta):
def __init__(self, module: AnsibleModule) -> None:
def __init__(self, *, module: AnsibleModule) -> None:
self.module = module
self.changed: bool = False
@@ -210,6 +210,7 @@ class KeygenCommand:
def generate_certificate(
self,
*,
certificate_path: str,
identifier: str,
options: list[str] | None,
@@ -247,7 +248,13 @@ class KeygenCommand:
return self._run_command(args, **kwargs)
def generate_keypair(
self, private_key_path: str, size: int, type: str, comment: str | None, **kwargs
self,
*,
private_key_path: str,
size: int,
type: str,
comment: str | None,
**kwargs,
) -> tuple[int, str, str]:
args = [
self._bin_path,
@@ -270,26 +277,29 @@ class KeygenCommand:
return self._run_command(args, data=data, **kwargs)
def get_certificate_info(
self, certificate_path: str, **kwargs
self, *, certificate_path: str, **kwargs
) -> tuple[int, str, str]:
return self._run_command(
[self._bin_path, "-L", "-f", certificate_path], **kwargs
)
def get_matching_public_key(
self, private_key_path: str, **kwargs
self, *, private_key_path: str, **kwargs
) -> tuple[int, str, str]:
return self._run_command(
[self._bin_path, "-P", "", "-y", "-f", private_key_path], **kwargs
)
def get_private_key(self, private_key_path: str, **kwargs) -> tuple[int, str, str]:
def get_private_key(
self, *, private_key_path: str, **kwargs
) -> tuple[int, str, str]:
return self._run_command(
[self._bin_path, "-l", "-f", private_key_path], **kwargs
)
def update_comment(
self,
*,
private_key_path: str,
comment: str,
force_new_format: bool = True,
@@ -317,7 +327,7 @@ _PrivateKey = t.TypeVar("_PrivateKey", bound="PrivateKey")
class PrivateKey:
def __init__(
self, size: int, key_type: str, fingerprint: str, format: str = ""
self, *, size: int, key_type: str, fingerprint: str, format: str = ""
) -> None:
self._size = size
self._type = key_type
@@ -363,7 +373,7 @@ _PublicKey = t.TypeVar("_PublicKey", bound="PublicKey")
class PublicKey:
def __init__(self, type_string: str, data: str, comment: str | None) -> None:
def __init__(self, *, type_string: str, data: str, comment: str | None) -> None:
self._type_string = type_string
self._data = data
self._comment = comment
@@ -441,6 +451,7 @@ class PublicKey:
def parse_private_key_format(
*,
path: str | os.PathLike,
) -> t.Literal["SSH", "PKCS8", "PKCS1", ""]:
with open(path, "r") as file:
@@ -454,3 +465,14 @@ def parse_private_key_format(
return "PKCS1"
return ""
__all__ = (
"restore_on_failure",
"safe_atomic_move",
"OpensshModule",
"KeygenCommand",
"PrivateKey",
"PublicKey",
"parse_private_key_format",
)

View File

@@ -53,8 +53,8 @@ if t.TYPE_CHECKING:
class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
def __init__(self, module: AnsibleModule) -> None:
super(KeypairBackend, self).__init__(module)
def __init__(self, *, module: AnsibleModule) -> None:
super(KeypairBackend, self).__init__(module=module)
self.comment: str | None = self.module.params["comment"]
self.private_key_path: str = self.module.params["path"]
@@ -296,9 +296,9 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
try:
secure_write(
temp_public_key,
existing_permissions or default_permissions,
to_bytes(content),
path=temp_public_key,
mode=existing_permissions or default_permissions,
content=to_bytes(content),
)
except (IOError, OSError) as e:
self.module.fail_json(msg=str(e))
@@ -357,8 +357,8 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
class KeypairBackendOpensshBin(KeypairBackend):
def __init__(self, module: AnsibleModule) -> None:
super(KeypairBackendOpensshBin, self).__init__(module)
def __init__(self, *, module: AnsibleModule) -> None:
super(KeypairBackendOpensshBin, self).__init__(module=module)
if self.module.params["private_key_format"] != "auto":
self.module.fail_json(
@@ -369,12 +369,16 @@ class KeypairBackendOpensshBin(KeypairBackend):
def _generate_keypair(self, private_key_path: str) -> None:
self.ssh_keygen.generate_keypair(
private_key_path, self.size, self.type, self.comment, check_rc=True
private_key_path=private_key_path,
size=self.size,
type=self.type,
comment=self.comment,
check_rc=True,
)
def _get_private_key(self) -> PrivateKey:
rc, private_key_content, err = self.ssh_keygen.get_private_key(
self.private_key_path, check_rc=False
private_key_path=self.private_key_path, check_rc=False
)
if rc != 0:
raise ValueError(err)
@@ -382,13 +386,13 @@ class KeypairBackendOpensshBin(KeypairBackend):
def _get_public_key(self) -> PublicKey | t.Literal[""]:
public_key_content = self.ssh_keygen.get_matching_public_key(
self.private_key_path, check_rc=True
private_key_path=self.private_key_path, check_rc=True
)[1]
return PublicKey.from_string(public_key_content)
def _private_key_readable(self) -> bool:
rc, stdout, stderr = self.ssh_keygen.get_matching_public_key(
self.private_key_path, check_rc=False
private_key_path=self.private_key_path, check_rc=False
)
return not (
rc == 255
@@ -407,8 +411,8 @@ class KeypairBackendOpensshBin(KeypairBackend):
LooseVersion("6.5") <= LooseVersion(ssh_version) < LooseVersion("7.8")
)
self.ssh_keygen.update_comment(
self.private_key_path,
self.comment or "",
private_key_path=self.private_key_path,
comment=self.comment or "",
force_new_format=force_new_format,
check_rc=True,
)
@@ -420,8 +424,8 @@ class KeypairBackendOpensshBin(KeypairBackend):
class KeypairBackendCryptography(KeypairBackend):
def __init__(self, module: AnsibleModule) -> None:
super(KeypairBackendCryptography, self).__init__(module)
def __init__(self, *, module: AnsibleModule) -> None:
super(KeypairBackendCryptography, self).__init__(module=module)
if self.type == "rsa1":
self.module.fail_json(
@@ -469,12 +473,12 @@ class KeypairBackendCryptography(KeypairBackend):
)
encoded_private_key = OpensshKeypair.encode_openssh_privatekey(
keypair.asymmetric_keypair, self.private_key_format
asym_keypair=keypair.asymmetric_keypair, key_format=self.private_key_format
)
secure_write(private_key_path, 0o600, encoded_private_key)
secure_write(path=private_key_path, mode=0o600, content=encoded_private_key)
public_key_path = private_key_path + ".pub"
secure_write(public_key_path, 0o644, keypair.public_key)
secure_write(path=public_key_path, mode=0o644, content=keypair.public_key)
def _get_private_key(self) -> PrivateKey:
keypair = OpensshKeypair.load(
@@ -485,7 +489,7 @@ class KeypairBackendCryptography(KeypairBackend):
size=keypair.size,
key_type=keypair.key_type,
fingerprint=keypair.fingerprint,
format=parse_private_key_format(self.private_key_path),
format=parse_private_key_format(path=self.private_key_path),
)
def _get_public_key(self) -> PublicKey | t.Literal[""]:
@@ -550,7 +554,7 @@ class KeypairBackendCryptography(KeypairBackend):
def select_backend(
module: AnsibleModule, backend: t.Literal["auto", "opensshbin", "cryptography"]
*, module: AnsibleModule, backend: t.Literal["auto", "opensshbin", "cryptography"]
) -> KeypairBackend:
can_use_cryptography = HAS_OPENSSH_SUPPORT and LooseVersion(
CRYPTOGRAPHY_VERSION
@@ -573,7 +577,7 @@ def select_backend(
if backend == "opensshbin":
if not can_use_opensshbin:
module.fail_json(msg="Cannot find the OpenSSH binary in the PATH")
return KeypairBackendOpensshBin(module)
return KeypairBackendOpensshBin(module=module)
if backend == "cryptography":
if not can_use_cryptography:
module.fail_json(
@@ -581,5 +585,8 @@ def select_backend(
f"cryptography >= {COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION}"
)
)
return KeypairBackendCryptography(module)
return KeypairBackendCryptography(module=module)
raise ValueError(f"Unsupported value for backend: {backend}")
__all__ = ("KeypairBackend", "select_backend")

View File

@@ -111,7 +111,7 @@ _EXTENSIONS = (
class OpensshCertificateTimeParameters:
def __init__(
self, valid_from: str | bytes | int, valid_to: str | bytes | int
self, *, valid_from: str | bytes | int, valid_to: str | bytes | int
) -> None:
self._valid_from = self.to_datetime(valid_from)
self._valid_to = self.to_datetime(valid_to)
@@ -149,7 +149,7 @@ class OpensshCertificateTimeParameters:
def valid_from(self, date_format: DateFormat) -> str | int: ...
def valid_from(self, date_format: DateFormat) -> str | int:
return self.format_datetime(self._valid_from, date_format)
return self.format_datetime(self._valid_from, date_format=date_format)
@t.overload
def valid_to(self, date_format: DateFormatStr) -> str: ...
@@ -161,7 +161,7 @@ class OpensshCertificateTimeParameters:
def valid_to(self, date_format: DateFormat) -> str | int: ...
def valid_to(self, date_format: DateFormat) -> str | int:
return self.format_datetime(self._valid_to, date_format)
return self.format_datetime(self._valid_to, date_format=date_format)
def within_range(self, valid_at: str | bytes | int | None) -> bool:
if valid_at is not None:
@@ -171,18 +171,18 @@ class OpensshCertificateTimeParameters:
@t.overload
@staticmethod
def format_datetime(dt: datetime, date_format: DateFormatStr) -> str: ...
def format_datetime(dt: datetime, *, date_format: DateFormatStr) -> str: ...
@t.overload
@staticmethod
def format_datetime(dt: datetime, date_format: DateFormatInt) -> int: ...
def format_datetime(dt: datetime, *, date_format: DateFormatInt) -> int: ...
@t.overload
@staticmethod
def format_datetime(dt: datetime, date_format: DateFormat) -> str | int: ...
def format_datetime(dt: datetime, *, date_format: DateFormat) -> str | int: ...
@staticmethod
def format_datetime(dt: datetime, date_format: DateFormat) -> str | int:
def format_datetime(dt: datetime, *, date_format: DateFormat) -> str | int:
if date_format in ("human_readable", "openssh"):
if dt == _ALWAYS:
return "always"
@@ -264,6 +264,7 @@ _OpensshCertificateOption = t.TypeVar(
class OpensshCertificateOption:
def __init__(
self,
*,
option_type: t.Literal["critical", "extension"],
name: str | bytes,
data: str | bytes,
@@ -350,6 +351,7 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta):
def __init__(
self,
*,
nonce: bytes | None = None,
serial: int | None = None,
cert_type: int | None = None,
@@ -409,7 +411,7 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta):
class OpensshRSACertificateInfo(OpensshCertificateInfo):
def __init__(self, e: int | None = None, n: int | None = None, **kwargs) -> None:
def __init__(self, *, e: int | None = None, n: int | None = None, **kwargs) -> None:
super(OpensshRSACertificateInfo, self).__init__(**kwargs)
self.type_string = _SSH_TYPE_STRINGS["rsa"] + _CERT_SUFFIX_V01
self.e = e
@@ -435,6 +437,7 @@ class OpensshRSACertificateInfo(OpensshCertificateInfo):
class OpensshDSACertificateInfo(OpensshCertificateInfo):
def __init__(
self,
*,
p: int | None = None,
q: int | None = None,
g: int | None = None,
@@ -471,7 +474,7 @@ class OpensshDSACertificateInfo(OpensshCertificateInfo):
class OpensshECDSACertificateInfo(OpensshCertificateInfo):
def __init__(
self, curve: bytes | None = None, public_key: bytes | None = None, **kwargs
self, *, curve: bytes | None = None, public_key: bytes | None = None, **kwargs
):
super(OpensshECDSACertificateInfo, self).__init__(**kwargs)
self._curve = None
@@ -515,7 +518,7 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo):
class OpensshED25519CertificateInfo(OpensshCertificateInfo):
def __init__(self, pk: bytes | None = None, **kwargs) -> None:
def __init__(self, *, pk: bytes | None = None, **kwargs) -> None:
super(OpensshED25519CertificateInfo, self).__init__(**kwargs)
self.type_string = _SSH_TYPE_STRINGS["ed25519"] + _CERT_SUFFIX_V01
self.pk = pk
@@ -541,8 +544,7 @@ _OpensshCertificate = t.TypeVar("_OpensshCertificate", bound="OpensshCertificate
class OpensshCertificate:
"""Encapsulates a formatted OpenSSH certificate including signature and signing key"""
def __init__(self, cert_info: OpensshCertificateInfo, signature: bytes):
def __init__(self, *, cert_info: OpensshCertificateInfo, signature: bytes):
self._cert_info = cert_info
self.signature = signature
@@ -574,7 +576,7 @@ class OpensshCertificate:
f"Invalid certificate format identifier: {format_identifier!r}"
)
parser = OpensshParser(cert)
parser = OpensshParser(data=cert)
if format_identifier != parser.string():
raise ValueError("Certificate formats do not match")
@@ -649,7 +651,9 @@ class OpensshCertificate:
if self._cert_info.critical_options is None:
raise ValueError
return [
OpensshCertificateOption("critical", to_text(n), to_text(d))
OpensshCertificateOption(
option_type="critical", name=to_text(n), data=to_text(d)
)
for n, d in self._cert_info.critical_options
]
@@ -658,7 +662,9 @@ class OpensshCertificate:
if self._cert_info.extensions is None:
raise ValueError
return [
OpensshCertificateOption("extension", to_text(n), to_text(d))
OpensshCertificateOption(
option_type="extension", name=to_text(n), data=to_text(d)
)
for n, d in self._cert_info.extensions
]
@@ -674,7 +680,7 @@ class OpensshCertificate:
@property
def signature_type(self) -> str:
signature_data = OpensshParser.signature_data(self.signature)
signature_data = OpensshParser.signature_data(signature_string=self.signature)
return to_text(signature_data["signature_type"])
@staticmethod
@@ -727,16 +733,20 @@ def apply_directives(directives: t.Iterable[str]) -> list[OpensshCertificateOpti
directive_to_option = {
"no-x11-forwarding": OpensshCertificateOption(
"extension", "permit-x11-forwarding", ""
option_type="extension", name="permit-x11-forwarding", data=""
),
"no-agent-forwarding": OpensshCertificateOption(
"extension", "permit-agent-forwarding", ""
option_type="extension", name="permit-agent-forwarding", data=""
),
"no-port-forwarding": OpensshCertificateOption(
"extension", "permit-port-forwarding", ""
option_type="extension", name="permit-port-forwarding", data=""
),
"no-pty": OpensshCertificateOption(
option_type="extension", name="permit-pty", data=""
),
"no-user-rc": OpensshCertificateOption(
option_type="extension", name="permit-user-rc", data=""
),
"no-pty": OpensshCertificateOption("extension", "permit-pty", ""),
"no-user-rc": OpensshCertificateOption("extension", "permit-user-rc", ""),
}
if "clear" in directives:
@@ -748,7 +758,10 @@ def apply_directives(directives: t.Iterable[str]) -> list[OpensshCertificateOpti
def default_options() -> list[OpensshCertificateOption]:
return [OpensshCertificateOption("extension", name, "") for name in _EXTENSIONS]
return [
OpensshCertificateOption(option_type="extension", name=name, data="")
for name in _EXTENSIONS
]
def fingerprint(public_key: bytes) -> bytes:
@@ -803,3 +816,22 @@ def parse_option_list(
extensions.append(option_object)
return critical_options, list(set(extensions + apply_directives(directives)))
__all__ = (
"OpensshCertificateTimeParameters",
"OpensshCertificateOption",
"OpensshCertificateInfo",
"OpensshRSACertificateInfo",
"OpensshDSACertificateInfo",
"OpensshECDSACertificateInfo",
"OpensshED25519CertificateInfo",
"OpensshCertificate",
"apply_directives",
"default_options",
"fingerprint",
"get_cert_info_object",
"get_option_type",
"is_relative_time_string",
"parse_option_list",
)

View File

@@ -145,6 +145,7 @@ class AsymmetricKeypair:
@classmethod
def generate(
cls: t.Type[_AsymmetricKeypair],
*,
keytype: KeyType = "rsa",
size: int | None = None,
passphrase: bytes | None = None,
@@ -208,6 +209,7 @@ class AsymmetricKeypair:
@classmethod
def load(
cls: t.Type[_AsymmetricKeypair],
*,
path: str | os.PathLike,
passphrase: bytes | None = None,
private_key_format: KeySerializationFormat = "PEM",
@@ -228,13 +230,15 @@ class AsymmetricKeypair:
else:
encryption_algorithm = serialization.NoEncryption()
privatekey = load_privatekey(path, passphrase, private_key_format)
privatekey = load_privatekey(
path=path, passphrase=passphrase, key_format=private_key_format
)
if no_public_key:
publickey = privatekey.public_key()
else:
# TODO: BUG: load_publickey() can return unsupported key types
# (Also we should check whether the public key fits the private key...)
publickey = load_publickey(path + ".pub", public_key_format) # type: ignore
publickey = load_publickey(path=path + ".pub", key_format=public_key_format) # type: ignore
# Ed25519 keys are always of size 256 and do not have a key_size attribute
if isinstance(privatekey, Ed25519PrivateKey):
@@ -264,6 +268,7 @@ class AsymmetricKeypair:
def __init__(
self,
*,
keytype: KeyType,
size: int,
privatekey: PrivateKeyTypes,
@@ -285,7 +290,7 @@ class AsymmetricKeypair:
self.__encryption_algorithm = encryption_algorithm
try:
self.verify(self.sign(b"message"), b"message")
self.verify(signature=self.sign(b"message"), data=b"message")
except InvalidSignatureError:
raise InvalidPublicKeyFileError(
"The private key and public key of this keypair do not match"
@@ -347,7 +352,7 @@ class AsymmetricKeypair:
except TypeError as e:
raise InvalidDataError(e)
def verify(self, signature: bytes, data: bytes) -> None:
def verify(self, *, signature: bytes, data: bytes) -> None:
"""Verifies that the signature associated with the provided data was signed
by the private key of this key pair.
@@ -384,6 +389,7 @@ class OpensshKeypair:
@classmethod
def generate(
cls: t.Type[_OpensshKeypair],
*,
keytype: KeyType = "rsa",
size: int | None = None,
passphrase: bytes | None = None,
@@ -400,9 +406,15 @@ class OpensshKeypair:
if comment is None:
comment = f"{getuser()}@{gethostname()}"
asym_keypair = AsymmetricKeypair.generate(keytype, size, passphrase)
openssh_privatekey = cls.encode_openssh_privatekey(asym_keypair, "SSH")
openssh_publickey = cls.encode_openssh_publickey(asym_keypair, comment)
asym_keypair = AsymmetricKeypair.generate(
keytype=keytype, size=size, passphrase=passphrase
)
openssh_privatekey = cls.encode_openssh_privatekey(
asym_keypair=asym_keypair, key_format="SSH"
)
openssh_publickey = cls.encode_openssh_publickey(
asym_keypair=asym_keypair, comment=comment
)
fingerprint = calculate_fingerprint(openssh_publickey)
return cls(
@@ -416,6 +428,7 @@ class OpensshKeypair:
@classmethod
def load(
cls: t.Type[_OpensshKeypair],
*,
path: str | os.PathLike,
passphrase: bytes | None = None,
no_public_key: bool = False,
@@ -433,10 +446,18 @@ class OpensshKeypair:
comment = extract_comment(str(path) + ".pub")
asym_keypair = AsymmetricKeypair.load(
path, passphrase, "SSH", "SSH", no_public_key
path=path,
passphrase=passphrase,
private_key_format="SSH",
public_key_format="SSH",
no_public_key=no_public_key,
)
openssh_privatekey = cls.encode_openssh_privatekey(
asym_keypair=asym_keypair, key_format="SSH"
)
openssh_publickey = cls.encode_openssh_publickey(
asym_keypair=asym_keypair, comment=comment
)
openssh_privatekey = cls.encode_openssh_privatekey(asym_keypair, "SSH")
openssh_publickey = cls.encode_openssh_publickey(asym_keypair, comment)
fingerprint = calculate_fingerprint(openssh_publickey)
return cls(
@@ -449,7 +470,7 @@ class OpensshKeypair:
@staticmethod
def encode_openssh_privatekey(
asym_keypair: AsymmetricKeypair, key_format: KeyFormat
*, asym_keypair: AsymmetricKeypair, key_format: KeyFormat
) -> bytes:
"""Returns an OpenSSH encoded private key for a given keypair
@@ -482,7 +503,7 @@ class OpensshKeypair:
@staticmethod
def encode_openssh_publickey(
asym_keypair: AsymmetricKeypair, comment: str
*, asym_keypair: AsymmetricKeypair, comment: str
) -> bytes:
"""Returns an OpenSSH encoded public key for a given keypair
@@ -504,6 +525,7 @@ class OpensshKeypair:
def __init__(
self,
*,
asym_keypair: AsymmetricKeypair,
openssh_privatekey: bytes,
openssh_publickey: bytes,
@@ -603,11 +625,12 @@ class OpensshKeypair:
self.__asym_keypair.update_passphrase(passphrase)
self.__openssh_privatekey = OpensshKeypair.encode_openssh_privatekey(
self.__asym_keypair, "SSH"
asym_keypair=self.__asym_keypair, key_format="SSH"
)
def load_privatekey(
*,
path: str | os.PathLike,
passphrase: bytes | None,
key_format: KeySerializationFormat,
@@ -662,7 +685,7 @@ def load_privatekey(
def load_publickey(
path: str | os.PathLike, key_format: KeySerializationFormat
*, path: str | os.PathLike, key_format: KeySerializationFormat
) -> AllPublicKeyTypes:
publickey_loaders = {
"PEM": serialization.load_pem_public_key,
@@ -767,3 +790,30 @@ def calculate_fingerprint(openssh_publickey: bytes) -> str:
value = b64encode(digest.finalize()).decode(encoding=_TEXT_ENCODING).rstrip("=")
return f"SHA256:{value}"
__all__ = (
"HAS_OPENSSH_SUPPORT",
"CRYPTOGRAPHY_VERSION",
"OpenSSHError",
"InvalidAlgorithmError",
"InvalidCommentError",
"InvalidDataError",
"InvalidPrivateKeyFileError",
"InvalidPublicKeyFileError",
"InvalidKeyFormatError",
"InvalidKeySizeError",
"InvalidKeyTypeError",
"InvalidPassphraseError",
"InvalidSignatureError",
"AsymmetricKeypair",
"OpensshKeypair",
"load_privatekey",
"load_publickey",
"compare_publickeys",
"compare_encryption_algorithms",
"get_encryption_algorithm",
"validate_comment",
"extract_comment",
"calculate_fingerprint",
)

View File

@@ -70,7 +70,7 @@ def parse_openssh_version(version_string: str) -> str | None:
@contextmanager
def secure_open(path: str | os.PathLike, mode: int) -> t.Iterator[int]:
def secure_open(*, path: str | os.PathLike, mode: int) -> t.Iterator[int]:
fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, mode)
try:
yield fd
@@ -78,8 +78,8 @@ def secure_open(path: str | os.PathLike, mode: int) -> t.Iterator[int]:
os.close(fd)
def secure_write(path: str | os.PathLike, mode: int, content: bytes) -> None:
with secure_open(path, mode) as fd:
def secure_write(*, path: str | os.PathLike, mode: int, content: bytes) -> None:
with secure_open(path=path, mode=mode) as fd:
os.write(fd, content)
@@ -91,7 +91,7 @@ class OpensshParser:
UINT32_OFFSET = 4
UINT64_OFFSET = 8
def __init__(self, data: bytes | bytearray) -> None:
def __init__(self, *, data: bytes | bytearray) -> None:
if not isinstance(data, (bytes, bytearray)):
raise TypeError(f"Data must be bytes-like not {type(data)}")
@@ -142,7 +142,7 @@ class OpensshParser:
raw_string = self.string()
if raw_string:
parser = OpensshParser(raw_string)
parser = OpensshParser(data=raw_string)
while parser.remaining_bytes():
result.append(parser.string())
@@ -154,14 +154,14 @@ class OpensshParser:
raw_string = self.string()
if raw_string:
parser = OpensshParser(raw_string)
parser = OpensshParser(data=raw_string)
while parser.remaining_bytes():
name = parser.string()
data = parser.string()
if data:
# data is doubly-encoded
data = OpensshParser(data).string()
data = OpensshParser(data=data).string()
result.append((name, data))
return result
@@ -183,14 +183,14 @@ class OpensshParser:
return self._pos + offset
@classmethod
def signature_data(cls, signature_string: bytes) -> dict[str, bytes | int]:
def signature_data(cls, *, signature_string: bytes) -> dict[str, bytes | int]:
signature_data: dict[str, bytes | int] = {}
parser = cls(signature_string)
parser = cls(data=signature_string)
signature_type = parser.string()
signature_blob = parser.string()
blob_parser = cls(signature_blob)
blob_parser = cls(data=signature_blob)
if signature_type in (b"ssh-rsa", b"rsa-sha2-256", b"rsa-sha2-512"):
# https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
# https://datatracker.ietf.org/doc/html/rfc8332#section-3
@@ -242,7 +242,7 @@ class _OpensshWriter:
in validating parsed material.
"""
def __init__(self, buffer: bytearray | None = None):
def __init__(self, *, buffer: bytearray | None = None):
if buffer is not None:
if not isinstance(buffer, bytearray):
raise TypeError(f"Buffer must be a bytearray, not {type(buffer)}")
@@ -347,3 +347,13 @@ class _OpensshWriter:
def bytes(self) -> bytes:
return bytes(self._buff)
__all__ = (
"any_in",
"file_mode",
"parse_openssh_version",
"secure_open",
"secure_write",
"OpensshParser",
)