Reformat everything with black.

I had to undo the u string prefix removals to not drop Python 2 compatibility.
That's why black isn't enabled in antsibull-nox.toml yet.
This commit is contained in:
Felix Fontein
2025-04-28 09:51:33 +02:00
parent 04a0d38e3b
commit aec1826c34
118 changed files with 11780 additions and 7565 deletions

View File

@@ -51,54 +51,56 @@ _USER_TYPE = 1
_HOST_TYPE = 2
_SSH_TYPE_STRINGS = {
'rsa': b"ssh-rsa",
'dsa': b"ssh-dss",
'ecdsa-nistp256': b"ecdsa-sha2-nistp256",
'ecdsa-nistp384': b"ecdsa-sha2-nistp384",
'ecdsa-nistp521': b"ecdsa-sha2-nistp521",
'ed25519': b"ssh-ed25519",
"rsa": b"ssh-rsa",
"dsa": b"ssh-dss",
"ecdsa-nistp256": b"ecdsa-sha2-nistp256",
"ecdsa-nistp384": b"ecdsa-sha2-nistp384",
"ecdsa-nistp521": b"ecdsa-sha2-nistp521",
"ed25519": b"ssh-ed25519",
}
_CERT_SUFFIX_V01 = b"-cert-v01@openssh.com"
# See https://datatracker.ietf.org/doc/html/rfc5656#section-6.1
_ECDSA_CURVE_IDENTIFIERS = {
'ecdsa-nistp256': b'nistp256',
'ecdsa-nistp384': b'nistp384',
'ecdsa-nistp521': b'nistp521',
"ecdsa-nistp256": b"nistp256",
"ecdsa-nistp384": b"nistp384",
"ecdsa-nistp521": b"nistp521",
}
_ECDSA_CURVE_IDENTIFIERS_LOOKUP = {
b'nistp256': 'ecdsa-nistp256',
b'nistp384': 'ecdsa-nistp384',
b'nistp521': 'ecdsa-nistp521',
b"nistp256": "ecdsa-nistp256",
b"nistp384": "ecdsa-nistp384",
b"nistp521": "ecdsa-nistp521",
}
_USE_TIMEZONE = sys.version_info >= (3, 6)
_ALWAYS = _add_or_remove_timezone(datetime(1970, 1, 1), with_timezone=_USE_TIMEZONE)
_FOREVER = datetime(9999, 12, 31, 23, 59, 59, 999999, _UTC) if _USE_TIMEZONE else datetime.max
_FOREVER = (
datetime(9999, 12, 31, 23, 59, 59, 999999, _UTC) if _USE_TIMEZONE else datetime.max
)
_CRITICAL_OPTIONS = (
'force-command',
'source-address',
'verify-required',
"force-command",
"source-address",
"verify-required",
)
_DIRECTIVES = (
'clear',
'no-x11-forwarding',
'no-agent-forwarding',
'no-port-forwarding',
'no-pty',
'no-user-rc',
"clear",
"no-x11-forwarding",
"no-agent-forwarding",
"no-port-forwarding",
"no-pty",
"no-user-rc",
)
_EXTENSIONS = (
'permit-x11-forwarding',
'permit-agent-forwarding',
'permit-port-forwarding',
'permit-pty',
'permit-user-rc'
"permit-x11-forwarding",
"permit-agent-forwarding",
"permit-port-forwarding",
"permit-pty",
"permit-user-rc",
)
if six.PY3:
@@ -111,13 +113,19 @@ class OpensshCertificateTimeParameters(object):
self._valid_to = self.to_datetime(valid_to)
if self._valid_from > self._valid_to:
raise ValueError("Valid from: %s must not be greater than Valid to: %s" % (valid_from, valid_to))
raise ValueError(
"Valid from: %s must not be greater than Valid to: %s"
% (valid_from, valid_to)
)
def __eq__(self, other):
if not isinstance(other, type(self)):
return NotImplemented
else:
return self._valid_from == other._valid_from and self._valid_to == other._valid_to
return (
self._valid_from == other._valid_from
and self._valid_to == other._valid_to
)
def __ne__(self, other):
return not self == other
@@ -126,7 +134,8 @@ class OpensshCertificateTimeParameters(object):
def validity_string(self):
if not (self._valid_from == _ALWAYS and self._valid_to == _FOREVER):
return "%s:%s" % (
self.valid_from(date_format='openssh'), self.valid_to(date_format='openssh')
self.valid_from(date_format="openssh"),
self.valid_to(date_format="openssh"),
)
return ""
@@ -144,16 +153,22 @@ class OpensshCertificateTimeParameters(object):
@staticmethod
def format_datetime(dt, date_format):
if date_format in ('human_readable', 'openssh'):
if date_format in ("human_readable", "openssh"):
if dt == _ALWAYS:
result = 'always'
result = "always"
elif dt == _FOREVER:
result = 'forever'
result = "forever"
else:
result = dt.isoformat().replace('+00:00', '') if date_format == 'human_readable' else dt.strftime("%Y%m%d%H%M%S")
elif date_format == 'timestamp':
result = (
dt.isoformat().replace("+00:00", "")
if date_format == "human_readable"
else dt.strftime("%Y%m%d%H%M%S")
)
elif date_format == "timestamp":
td = dt - _ALWAYS
result = int((td.microseconds + (td.seconds + td.days * 24 * 3600) * 10 ** 6) / 10 ** 6)
result = int(
(td.microseconds + (td.seconds + td.days * 24 * 3600) * 10**6) / 10**6
)
else:
raise ValueError("%s is not a valid format" % date_format)
return result
@@ -162,12 +177,17 @@ class OpensshCertificateTimeParameters(object):
def to_datetime(time_string_or_timestamp):
try:
if isinstance(time_string_or_timestamp, six.string_types):
result = OpensshCertificateTimeParameters._time_string_to_datetime(time_string_or_timestamp.strip())
result = OpensshCertificateTimeParameters._time_string_to_datetime(
time_string_or_timestamp.strip()
)
elif isinstance(time_string_or_timestamp, (long, int)):
result = OpensshCertificateTimeParameters._timestamp_to_datetime(time_string_or_timestamp)
result = OpensshCertificateTimeParameters._timestamp_to_datetime(
time_string_or_timestamp
)
else:
raise ValueError(
"Value must be of type (str, unicode, int, long) not %s" % type(time_string_or_timestamp)
"Value must be of type (str, unicode, int, long) not %s"
% type(time_string_or_timestamp)
)
except ValueError:
raise
@@ -182,7 +202,9 @@ class OpensshCertificateTimeParameters(object):
else:
try:
if _USE_TIMEZONE:
result = datetime.fromtimestamp(timestamp, tz=_datetime.timezone.utc)
result = datetime.fromtimestamp(
timestamp, tz=_datetime.timezone.utc
)
else:
result = datetime.utcfromtimestamp(timestamp)
except OverflowError:
@@ -192,16 +214,21 @@ class OpensshCertificateTimeParameters(object):
@staticmethod
def _time_string_to_datetime(time_string):
result = None
if time_string == 'always':
if time_string == "always":
result = _ALWAYS
elif time_string == 'forever':
elif time_string == "forever":
result = _FOREVER
elif is_relative_time_string(time_string):
result = convert_relative_to_datetime(time_string, with_timezone=_USE_TIMEZONE)
result = convert_relative_to_datetime(
time_string, with_timezone=_USE_TIMEZONE
)
else:
for time_format in ("%Y-%m-%d", "%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S"):
try:
result = _add_or_remove_timezone(datetime.strptime(time_string, time_format), with_timezone=_USE_TIMEZONE)
result = _add_or_remove_timezone(
datetime.strptime(time_string, time_format),
with_timezone=_USE_TIMEZONE,
)
except ValueError:
pass
if result is None:
@@ -211,7 +238,7 @@ class OpensshCertificateTimeParameters(object):
class OpensshCertificateOption(object):
def __init__(self, option_type, name, data):
if option_type not in ('critical', 'extension'):
if option_type not in ("critical", "extension"):
raise ValueError("type must be either 'critical' or 'extension'")
if not isinstance(name, six.string_types):
@@ -228,11 +255,13 @@ class OpensshCertificateOption(object):
if not isinstance(other, type(self)):
return NotImplemented
return all([
self._option_type == other._option_type,
self._name == other._name,
self._data == other._data,
])
return all(
[
self._option_type == other._option_type,
self._name == other._name,
self._data == other._data,
]
)
def __hash__(self):
return hash((self._option_type, self._name, self._data))
@@ -260,42 +289,47 @@ class OpensshCertificateOption(object):
@classmethod
def from_string(cls, option_string):
if not isinstance(option_string, six.string_types):
raise ValueError("option_string must be a string not %s" % type(option_string))
raise ValueError(
"option_string must be a string not %s" % type(option_string)
)
option_type = None
if ':' in option_string:
option_type, value = option_string.strip().split(':', 1)
if '=' in value:
name, data = value.split('=', 1)
if ":" in option_string:
option_type, value = option_string.strip().split(":", 1)
if "=" in value:
name, data = value.split("=", 1)
else:
name, data = value, ''
elif '=' in option_string:
name, data = option_string.strip().split('=', 1)
name, data = value, ""
elif "=" in option_string:
name, data = option_string.strip().split("=", 1)
else:
name, data = option_string.strip(), ''
name, data = option_string.strip(), ""
return cls(
option_type=option_type or get_option_type(name.lower()),
name=name,
data=data
data=data,
)
@six.add_metaclass(abc.ABCMeta)
class OpensshCertificateInfo:
"""Encapsulates all certificate information which is signed by a CA key"""
def __init__(self,
nonce=None,
serial=None,
cert_type=None,
key_id=None,
principals=None,
valid_after=None,
valid_before=None,
critical_options=None,
extensions=None,
reserved=None,
signing_key=None):
def __init__(
self,
nonce=None,
serial=None,
cert_type=None,
key_id=None,
principals=None,
valid_after=None,
valid_before=None,
critical_options=None,
extensions=None,
reserved=None,
signing_key=None,
):
self.nonce = nonce
self.serial = serial
self._cert_type = cert_type
@@ -313,17 +347,17 @@ class OpensshCertificateInfo:
@property
def cert_type(self):
if self._cert_type == _USER_TYPE:
return 'user'
return "user"
elif self._cert_type == _HOST_TYPE:
return 'host'
return "host"
else:
return ''
return ""
@cert_type.setter
def cert_type(self, cert_type):
if cert_type == 'user' or cert_type == _USER_TYPE:
if cert_type == "user" or cert_type == _USER_TYPE:
self._cert_type = _USER_TYPE
elif cert_type == 'host' or cert_type == _HOST_TYPE:
elif cert_type == "host" or cert_type == _HOST_TYPE:
self._cert_type = _HOST_TYPE
else:
raise ValueError("%s is not a valid certificate type" % cert_type)
@@ -343,17 +377,17 @@ class OpensshCertificateInfo:
class OpensshRSACertificateInfo(OpensshCertificateInfo):
def __init__(self, e=None, n=None, **kwargs):
super(OpensshRSACertificateInfo, self).__init__(**kwargs)
self.type_string = _SSH_TYPE_STRINGS['rsa'] + _CERT_SUFFIX_V01
self.type_string = _SSH_TYPE_STRINGS["rsa"] + _CERT_SUFFIX_V01
self.e = e
self.n = n
# See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
def public_key_fingerprint(self):
if any([self.e is None, self.n is None]):
return b''
return b""
writer = _OpensshWriter()
writer.string(_SSH_TYPE_STRINGS['rsa'])
writer.string(_SSH_TYPE_STRINGS["rsa"])
writer.mpint(self.e)
writer.mpint(self.n)
@@ -367,7 +401,7 @@ class OpensshRSACertificateInfo(OpensshCertificateInfo):
class OpensshDSACertificateInfo(OpensshCertificateInfo):
def __init__(self, p=None, q=None, g=None, y=None, **kwargs):
super(OpensshDSACertificateInfo, self).__init__(**kwargs)
self.type_string = _SSH_TYPE_STRINGS['dsa'] + _CERT_SUFFIX_V01
self.type_string = _SSH_TYPE_STRINGS["dsa"] + _CERT_SUFFIX_V01
self.p = p
self.q = q
self.g = g
@@ -376,10 +410,10 @@ class OpensshDSACertificateInfo(OpensshCertificateInfo):
# See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
def public_key_fingerprint(self):
if any([self.p is None, self.q is None, self.g is None, self.y is None]):
return b''
return b""
writer = _OpensshWriter()
writer.string(_SSH_TYPE_STRINGS['dsa'])
writer.string(_SSH_TYPE_STRINGS["dsa"])
writer.mpint(self.p)
writer.mpint(self.q)
writer.mpint(self.g)
@@ -411,16 +445,20 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo):
def curve(self, curve):
if curve in _ECDSA_CURVE_IDENTIFIERS.values():
self._curve = curve
self.type_string = _SSH_TYPE_STRINGS[_ECDSA_CURVE_IDENTIFIERS_LOOKUP[curve]] + _CERT_SUFFIX_V01
self.type_string = (
_SSH_TYPE_STRINGS[_ECDSA_CURVE_IDENTIFIERS_LOOKUP[curve]]
+ _CERT_SUFFIX_V01
)
else:
raise ValueError(
"Curve must be one of %s" % (b','.join(list(_ECDSA_CURVE_IDENTIFIERS.values()))).decode('UTF-8')
"Curve must be one of %s"
% (b",".join(list(_ECDSA_CURVE_IDENTIFIERS.values()))).decode("UTF-8")
)
# See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
def public_key_fingerprint(self):
if any([self.curve is None, self.public_key is None]):
return b''
return b""
writer = _OpensshWriter()
writer.string(_SSH_TYPE_STRINGS[_ECDSA_CURVE_IDENTIFIERS_LOOKUP[self.curve]])
@@ -437,15 +475,15 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo):
class OpensshED25519CertificateInfo(OpensshCertificateInfo):
def __init__(self, pk=None, **kwargs):
super(OpensshED25519CertificateInfo, self).__init__(**kwargs)
self.type_string = _SSH_TYPE_STRINGS['ed25519'] + _CERT_SUFFIX_V01
self.type_string = _SSH_TYPE_STRINGS["ed25519"] + _CERT_SUFFIX_V01
self.pk = pk
def public_key_fingerprint(self):
if self.pk is None:
return b''
return b""
writer = _OpensshWriter()
writer.string(_SSH_TYPE_STRINGS['ed25519'])
writer.string(_SSH_TYPE_STRINGS["ed25519"])
writer.string(self.pk)
return fingerprint(writer.bytes())
@@ -457,6 +495,7 @@ class OpensshED25519CertificateInfo(OpensshCertificateInfo):
# See https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL.certkeys?annotate=HEAD
class OpensshCertificate(object):
"""Encapsulates a formatted OpenSSH certificate including signature and signing key"""
def __init__(self, cert_info, signature):
self._cert_info = cert_info
@@ -468,13 +507,13 @@ class OpensshCertificate(object):
raise ValueError("%s is not a valid path." % path)
try:
with open(path, 'rb') as cert_file:
with open(path, "rb") as cert_file:
data = cert_file.read()
except (IOError, OSError) as e:
raise ValueError("%s cannot be opened for reading: %s" % (path, e))
try:
format_identifier, b64_cert = data.split(b' ')[:2]
format_identifier, b64_cert = data.split(b" ")[:2]
cert = binascii.a2b_base64(b64_cert)
except (binascii.Error, ValueError):
raise ValueError("Certificate not in OpenSSH format")
@@ -484,7 +523,9 @@ class OpensshCertificate(object):
pub_key_type = key_type
break
else:
raise ValueError("Invalid certificate format identifier: %s" % format_identifier)
raise ValueError(
"Invalid certificate format identifier: %s" % format_identifier
)
parser = OpensshParser(cert)
@@ -499,7 +540,8 @@ class OpensshCertificate(object):
if parser.remaining_bytes():
raise ValueError(
"%s bytes of additional data was not parsed while loading %s" % (parser.remaining_bytes(), path)
"%s bytes of additional data was not parsed while loading %s"
% (parser.remaining_bytes(), path)
)
return cls(
@@ -546,12 +588,16 @@ class OpensshCertificate(object):
@property
def critical_options(self):
return [
OpensshCertificateOption('critical', to_text(n), to_text(d)) for n, d in self._cert_info.critical_options
OpensshCertificateOption("critical", to_text(n), to_text(d))
for n, d in self._cert_info.critical_options
]
@property
def extensions(self):
return [OpensshCertificateOption('extension', to_text(n), to_text(d)) for n, d in self._cert_info.extensions]
return [
OpensshCertificateOption("extension", to_text(n), to_text(d))
for n, d in self._cert_info.extensions
]
@property
def reserved(self):
@@ -564,7 +610,7 @@ class OpensshCertificate(object):
@property
def signature_type(self):
signature_data = OpensshParser.signature_data(self.signature)
return to_text(signature_data['signature_type'])
return to_text(signature_data["signature_type"])
@staticmethod
def _parse_cert_info(pub_key_type, parser):
@@ -586,23 +632,24 @@ class OpensshCertificate(object):
def to_dict(self):
time_parameters = OpensshCertificateTimeParameters(
valid_from=self.valid_after,
valid_to=self.valid_before
valid_from=self.valid_after, valid_to=self.valid_before
)
return {
'type_string': self.type_string,
'nonce': self.nonce,
'serial': self.serial,
'cert_type': self.type,
'identifier': self.key_id,
'principals': self.principals,
'valid_after': time_parameters.valid_from(date_format='human_readable'),
'valid_before': time_parameters.valid_to(date_format='human_readable'),
'critical_options': [str(critical_option) for critical_option in self.critical_options],
'extensions': [str(extension) for extension in self.extensions],
'reserved': self.reserved,
'public_key': self.public_key,
'signing_key': self.signing_key,
"type_string": self.type_string,
"nonce": self.nonce,
"serial": self.serial,
"cert_type": self.type,
"identifier": self.key_id,
"principals": self.principals,
"valid_after": time_parameters.valid_from(date_format="human_readable"),
"valid_before": time_parameters.valid_to(date_format="human_readable"),
"critical_options": [
str(critical_option) for critical_option in self.critical_options
],
"extensions": [str(extension) for extension in self.extensions],
"reserved": self.reserved,
"public_key": self.public_key,
"signing_key": self.signing_key,
}
@@ -611,38 +658,46 @@ def apply_directives(directives):
raise ValueError("directives must be one of %s" % ", ".join(_DIRECTIVES))
directive_to_option = {
'no-x11-forwarding': OpensshCertificateOption('extension', 'permit-x11-forwarding', ''),
'no-agent-forwarding': OpensshCertificateOption('extension', 'permit-agent-forwarding', ''),
'no-port-forwarding': OpensshCertificateOption('extension', 'permit-port-forwarding', ''),
'no-pty': OpensshCertificateOption('extension', 'permit-pty', ''),
'no-user-rc': OpensshCertificateOption('extension', 'permit-user-rc', ''),
"no-x11-forwarding": OpensshCertificateOption(
"extension", "permit-x11-forwarding", ""
),
"no-agent-forwarding": OpensshCertificateOption(
"extension", "permit-agent-forwarding", ""
),
"no-port-forwarding": OpensshCertificateOption(
"extension", "permit-port-forwarding", ""
),
"no-pty": OpensshCertificateOption("extension", "permit-pty", ""),
"no-user-rc": OpensshCertificateOption("extension", "permit-user-rc", ""),
}
if 'clear' in directives:
if "clear" in directives:
return []
else:
return list(set(default_options()) - set(directive_to_option[d] for d in directives))
return list(
set(default_options()) - set(directive_to_option[d] for d in directives)
)
def default_options():
return [OpensshCertificateOption('extension', name, '') for name in _EXTENSIONS]
return [OpensshCertificateOption("extension", name, "") for name in _EXTENSIONS]
def fingerprint(public_key):
"""Generates a SHA256 hash and formats output to resemble ``ssh-keygen``"""
h = sha256()
h.update(public_key)
return b'SHA256:' + b64encode(h.digest()).rstrip(b'=')
return b"SHA256:" + b64encode(h.digest()).rstrip(b"=")
def get_cert_info_object(key_type):
if key_type == 'rsa':
if key_type == "rsa":
cert_info = OpensshRSACertificateInfo()
elif key_type == 'dsa':
elif key_type == "dsa":
cert_info = OpensshDSACertificateInfo()
elif key_type in ('ecdsa-nistp256', 'ecdsa-nistp384', 'ecdsa-nistp521'):
elif key_type in ("ecdsa-nistp256", "ecdsa-nistp384", "ecdsa-nistp521"):
cert_info = OpensshECDSACertificateInfo()
elif key_type == 'ed25519':
elif key_type == "ed25519":
cert_info = OpensshED25519CertificateInfo()
else:
raise ValueError("%s is not a valid key type" % key_type)
@@ -652,12 +707,14 @@ def get_cert_info_object(key_type):
def get_option_type(name):
if name in _CRITICAL_OPTIONS:
result = 'critical'
result = "critical"
elif name in _EXTENSIONS:
result = 'extension'
result = "extension"
else:
raise ValueError("%s is not a valid option. " % name +
"Custom options must start with 'critical:' or 'extension:' to indicate type")
raise ValueError(
"%s is not a valid option. " % name
+ "Custom options must start with 'critical:' or 'extension:' to indicate type"
)
return result
@@ -675,7 +732,7 @@ def parse_option_list(option_list):
directives.append(option.lower())
else:
option_object = OpensshCertificateOption.from_string(option)
if option_object.type == 'critical':
if option_object.type == "critical":
critical_options.append(option_object)
else:
extensions.append(option_object)