Code refactoring (#889)

* Add __all__ to all module and plugin utils.

* Convert quite a few positional args to keyword args.

* Avoid Python 3.8+ syntax.
This commit is contained in:
Felix Fontein
2025-05-16 06:55:57 +02:00
committed by GitHub
parent a5a4e022ba
commit 44bcc8cebc
101 changed files with 1510 additions and 748 deletions

View File

@@ -111,7 +111,7 @@ _EXTENSIONS = (
class OpensshCertificateTimeParameters:
def __init__(
self, valid_from: str | bytes | int, valid_to: str | bytes | int
self, *, valid_from: str | bytes | int, valid_to: str | bytes | int
) -> None:
self._valid_from = self.to_datetime(valid_from)
self._valid_to = self.to_datetime(valid_to)
@@ -149,7 +149,7 @@ class OpensshCertificateTimeParameters:
def valid_from(self, date_format: DateFormat) -> str | int: ...
def valid_from(self, date_format: DateFormat) -> str | int:
return self.format_datetime(self._valid_from, date_format)
return self.format_datetime(self._valid_from, date_format=date_format)
@t.overload
def valid_to(self, date_format: DateFormatStr) -> str: ...
@@ -161,7 +161,7 @@ class OpensshCertificateTimeParameters:
def valid_to(self, date_format: DateFormat) -> str | int: ...
def valid_to(self, date_format: DateFormat) -> str | int:
return self.format_datetime(self._valid_to, date_format)
return self.format_datetime(self._valid_to, date_format=date_format)
def within_range(self, valid_at: str | bytes | int | None) -> bool:
if valid_at is not None:
@@ -171,18 +171,18 @@ class OpensshCertificateTimeParameters:
@t.overload
@staticmethod
def format_datetime(dt: datetime, date_format: DateFormatStr) -> str: ...
def format_datetime(dt: datetime, *, date_format: DateFormatStr) -> str: ...
@t.overload
@staticmethod
def format_datetime(dt: datetime, date_format: DateFormatInt) -> int: ...
def format_datetime(dt: datetime, *, date_format: DateFormatInt) -> int: ...
@t.overload
@staticmethod
def format_datetime(dt: datetime, date_format: DateFormat) -> str | int: ...
def format_datetime(dt: datetime, *, date_format: DateFormat) -> str | int: ...
@staticmethod
def format_datetime(dt: datetime, date_format: DateFormat) -> str | int:
def format_datetime(dt: datetime, *, date_format: DateFormat) -> str | int:
if date_format in ("human_readable", "openssh"):
if dt == _ALWAYS:
return "always"
@@ -264,6 +264,7 @@ _OpensshCertificateOption = t.TypeVar(
class OpensshCertificateOption:
def __init__(
self,
*,
option_type: t.Literal["critical", "extension"],
name: str | bytes,
data: str | bytes,
@@ -350,6 +351,7 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta):
def __init__(
self,
*,
nonce: bytes | None = None,
serial: int | None = None,
cert_type: int | None = None,
@@ -409,7 +411,7 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta):
class OpensshRSACertificateInfo(OpensshCertificateInfo):
def __init__(self, e: int | None = None, n: int | None = None, **kwargs) -> None:
def __init__(self, *, e: int | None = None, n: int | None = None, **kwargs) -> None:
super(OpensshRSACertificateInfo, self).__init__(**kwargs)
self.type_string = _SSH_TYPE_STRINGS["rsa"] + _CERT_SUFFIX_V01
self.e = e
@@ -435,6 +437,7 @@ class OpensshRSACertificateInfo(OpensshCertificateInfo):
class OpensshDSACertificateInfo(OpensshCertificateInfo):
def __init__(
self,
*,
p: int | None = None,
q: int | None = None,
g: int | None = None,
@@ -471,7 +474,7 @@ class OpensshDSACertificateInfo(OpensshCertificateInfo):
class OpensshECDSACertificateInfo(OpensshCertificateInfo):
def __init__(
self, curve: bytes | None = None, public_key: bytes | None = None, **kwargs
self, *, curve: bytes | None = None, public_key: bytes | None = None, **kwargs
):
super(OpensshECDSACertificateInfo, self).__init__(**kwargs)
self._curve = None
@@ -515,7 +518,7 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo):
class OpensshED25519CertificateInfo(OpensshCertificateInfo):
def __init__(self, pk: bytes | None = None, **kwargs) -> None:
def __init__(self, *, pk: bytes | None = None, **kwargs) -> None:
super(OpensshED25519CertificateInfo, self).__init__(**kwargs)
self.type_string = _SSH_TYPE_STRINGS["ed25519"] + _CERT_SUFFIX_V01
self.pk = pk
@@ -541,8 +544,7 @@ _OpensshCertificate = t.TypeVar("_OpensshCertificate", bound="OpensshCertificate
class OpensshCertificate:
"""Encapsulates a formatted OpenSSH certificate including signature and signing key"""
def __init__(self, cert_info: OpensshCertificateInfo, signature: bytes):
def __init__(self, *, cert_info: OpensshCertificateInfo, signature: bytes):
self._cert_info = cert_info
self.signature = signature
@@ -574,7 +576,7 @@ class OpensshCertificate:
f"Invalid certificate format identifier: {format_identifier!r}"
)
parser = OpensshParser(cert)
parser = OpensshParser(data=cert)
if format_identifier != parser.string():
raise ValueError("Certificate formats do not match")
@@ -649,7 +651,9 @@ class OpensshCertificate:
if self._cert_info.critical_options is None:
raise ValueError
return [
OpensshCertificateOption("critical", to_text(n), to_text(d))
OpensshCertificateOption(
option_type="critical", name=to_text(n), data=to_text(d)
)
for n, d in self._cert_info.critical_options
]
@@ -658,7 +662,9 @@ class OpensshCertificate:
if self._cert_info.extensions is None:
raise ValueError
return [
OpensshCertificateOption("extension", to_text(n), to_text(d))
OpensshCertificateOption(
option_type="extension", name=to_text(n), data=to_text(d)
)
for n, d in self._cert_info.extensions
]
@@ -674,7 +680,7 @@ class OpensshCertificate:
@property
def signature_type(self) -> str:
signature_data = OpensshParser.signature_data(self.signature)
signature_data = OpensshParser.signature_data(signature_string=self.signature)
return to_text(signature_data["signature_type"])
@staticmethod
@@ -727,16 +733,20 @@ def apply_directives(directives: t.Iterable[str]) -> list[OpensshCertificateOpti
directive_to_option = {
"no-x11-forwarding": OpensshCertificateOption(
"extension", "permit-x11-forwarding", ""
option_type="extension", name="permit-x11-forwarding", data=""
),
"no-agent-forwarding": OpensshCertificateOption(
"extension", "permit-agent-forwarding", ""
option_type="extension", name="permit-agent-forwarding", data=""
),
"no-port-forwarding": OpensshCertificateOption(
"extension", "permit-port-forwarding", ""
option_type="extension", name="permit-port-forwarding", data=""
),
"no-pty": OpensshCertificateOption(
option_type="extension", name="permit-pty", data=""
),
"no-user-rc": OpensshCertificateOption(
option_type="extension", name="permit-user-rc", data=""
),
"no-pty": OpensshCertificateOption("extension", "permit-pty", ""),
"no-user-rc": OpensshCertificateOption("extension", "permit-user-rc", ""),
}
if "clear" in directives:
@@ -748,7 +758,10 @@ def apply_directives(directives: t.Iterable[str]) -> list[OpensshCertificateOpti
def default_options() -> list[OpensshCertificateOption]:
return [OpensshCertificateOption("extension", name, "") for name in _EXTENSIONS]
return [
OpensshCertificateOption(option_type="extension", name=name, data="")
for name in _EXTENSIONS
]
def fingerprint(public_key: bytes) -> bytes:
@@ -803,3 +816,22 @@ def parse_option_list(
extensions.append(option_object)
return critical_options, list(set(extensions + apply_directives(directives)))
__all__ = (
"OpensshCertificateTimeParameters",
"OpensshCertificateOption",
"OpensshCertificateInfo",
"OpensshRSACertificateInfo",
"OpensshDSACertificateInfo",
"OpensshECDSACertificateInfo",
"OpensshED25519CertificateInfo",
"OpensshCertificate",
"apply_directives",
"default_options",
"fingerprint",
"get_cert_info_object",
"get_option_type",
"is_relative_time_string",
"parse_option_list",
)