Reformat everything with black.

I had to undo the u string prefix removals to not drop Python 2 compatibility.
That's why black isn't enabled in antsibull-nox.toml yet.
This commit is contained in:
Felix Fontein
2025-04-28 09:51:33 +02:00
parent 04a0d38e3b
commit aec1826c34
118 changed files with 11780 additions and 7565 deletions

View File

@@ -44,17 +44,24 @@ def safe_atomic_move(module, path, destination):
def _restore_all_on_failure(f):
def backup_and_restore(self, sources_and_destinations, *args, **kwargs):
backups = [(d, self.module.backup_local(d)) for s, d in sources_and_destinations if os.path.exists(d)]
backups = [
(d, self.module.backup_local(d))
for s, d in sources_and_destinations
if os.path.exists(d)
]
try:
f(self, sources_and_destinations, *args, **kwargs)
except Exception:
for destination, backup in backups:
self.module.atomic_move(os.path.abspath(backup), os.path.abspath(destination))
self.module.atomic_move(
os.path.abspath(backup), os.path.abspath(destination)
)
raise
else:
for destination, backup in backups:
self.module.add_cleanup_file(backup)
return backup_and_restore
@@ -85,10 +92,10 @@ class OpensshModule(object):
def result(self):
result = self._result
result['changed'] = self.changed
result["changed"] = self.changed
if self.module._diff:
result['diff'] = self.diff
result["diff"] = self.diff
return result
@@ -107,6 +114,7 @@ class OpensshModule(object):
def wrapper(self, *args, **kwargs):
if not self.check_mode:
f(self, *args, **kwargs)
return wrapper
@staticmethod
@@ -114,72 +122,92 @@ class OpensshModule(object):
def wrapper(self, *args, **kwargs):
f(self, *args, **kwargs)
self.changed = True
return wrapper
def _check_if_base_dir(self, path):
base_dir = os.path.dirname(path) or '.'
base_dir = os.path.dirname(path) or "."
if not os.path.isdir(base_dir):
self.module.fail_json(
name=base_dir,
msg='The directory %s does not exist or the file is not a directory' % base_dir
msg="The directory %s does not exist or the file is not a directory"
% base_dir,
)
def _get_ssh_version(self):
ssh_bin = self.module.get_bin_path('ssh')
ssh_bin = self.module.get_bin_path("ssh")
if not ssh_bin:
return ""
return parse_openssh_version(self.module.run_command([ssh_bin, '-V', '-q'], check_rc=True)[2].strip())
return parse_openssh_version(
self.module.run_command([ssh_bin, "-V", "-q"], check_rc=True)[2].strip()
)
@_restore_all_on_failure
def _safe_secure_move(self, sources_and_destinations):
"""Moves a list of files from 'source' to 'destination' and restores 'destination' from backup upon failure.
If 'destination' does not already exist, then 'source' permissions are preserved to prevent
exposing protected data ('atomic_move' uses the 'destination' base directory mask for
permissions if 'destination' does not already exists).
If 'destination' does not already exist, then 'source' permissions are preserved to prevent
exposing protected data ('atomic_move' uses the 'destination' base directory mask for
permissions if 'destination' does not already exists).
"""
for source, destination in sources_and_destinations:
if os.path.exists(destination):
self.module.atomic_move(os.path.abspath(source), os.path.abspath(destination))
self.module.atomic_move(
os.path.abspath(source), os.path.abspath(destination)
)
else:
self.module.preserved_copy(source, destination)
def _update_permissions(self, path):
file_args = self.module.load_file_common_arguments(self.module.params)
file_args['path'] = path
file_args["path"] = path
if not self.module.check_file_absent_if_check_mode(path):
self.changed = self.module.set_fs_attributes_if_different(file_args, self.changed)
self.changed = self.module.set_fs_attributes_if_different(
file_args, self.changed
)
else:
self.changed = True
class KeygenCommand(object):
def __init__(self, module):
self._bin_path = module.get_bin_path('ssh-keygen', True)
self._bin_path = module.get_bin_path("ssh-keygen", True)
self._run_command = module.run_command
def generate_certificate(self, certificate_path, identifier, options, pkcs11_provider, principals,
serial_number, signature_algorithm, signing_key_path, type,
time_parameters, use_agent, **kwargs):
args = [self._bin_path, '-s', signing_key_path, '-P', '', '-I', identifier]
def generate_certificate(
self,
certificate_path,
identifier,
options,
pkcs11_provider,
principals,
serial_number,
signature_algorithm,
signing_key_path,
type,
time_parameters,
use_agent,
**kwargs
):
args = [self._bin_path, "-s", signing_key_path, "-P", "", "-I", identifier]
if options:
for option in options:
args.extend(['-O', option])
args.extend(["-O", option])
if pkcs11_provider:
args.extend(['-D', pkcs11_provider])
args.extend(["-D", pkcs11_provider])
if principals:
args.extend(['-n', ','.join(principals)])
args.extend(["-n", ",".join(principals)])
if serial_number is not None:
args.extend(['-z', str(serial_number)])
if type == 'host':
args.extend(['-h'])
args.extend(["-z", str(serial_number)])
if type == "host":
args.extend(["-h"])
if use_agent:
args.extend(['-U'])
args.extend(["-U"])
if time_parameters.validity_string:
args.extend(['-V', time_parameters.validity_string])
args.extend(["-V", time_parameters.validity_string])
if signature_algorithm:
args.extend(['-t', signature_algorithm])
args.extend(["-t", signature_algorithm])
args.append(certificate_path)
return self._run_command(args, **kwargs)
@@ -187,44 +215,62 @@ class KeygenCommand(object):
def generate_keypair(self, private_key_path, size, type, comment, **kwargs):
args = [
self._bin_path,
'-q',
'-N', '',
'-b', str(size),
'-t', type,
'-f', private_key_path,
'-C', comment or ''
"-q",
"-N",
"",
"-b",
str(size),
"-t",
type,
"-f",
private_key_path,
"-C",
comment or "",
]
# "y" must be entered in response to the "overwrite" prompt
data = 'y' if os.path.exists(private_key_path) else None
data = "y" if os.path.exists(private_key_path) else None
return self._run_command(args, data=data, **kwargs)
def get_certificate_info(self, certificate_path, **kwargs):
return self._run_command([self._bin_path, '-L', '-f', certificate_path], **kwargs)
return self._run_command(
[self._bin_path, "-L", "-f", certificate_path], **kwargs
)
def get_matching_public_key(self, private_key_path, **kwargs):
return self._run_command([self._bin_path, '-P', '', '-y', '-f', private_key_path], **kwargs)
return self._run_command(
[self._bin_path, "-P", "", "-y", "-f", private_key_path], **kwargs
)
def get_private_key(self, private_key_path, **kwargs):
return self._run_command([self._bin_path, '-l', '-f', private_key_path], **kwargs)
return self._run_command(
[self._bin_path, "-l", "-f", private_key_path], **kwargs
)
def update_comment(self, private_key_path, comment, force_new_format=True, **kwargs):
if os.path.exists(private_key_path) and not os.access(private_key_path, os.W_OK):
def update_comment(
self, private_key_path, comment, force_new_format=True, **kwargs
):
if os.path.exists(private_key_path) and not os.access(
private_key_path, os.W_OK
):
try:
os.chmod(private_key_path, stat.S_IWUSR + stat.S_IRUSR)
except (IOError, OSError) as e:
raise e("The private key at %s is not writeable preventing a comment update" % private_key_path)
raise e(
"The private key at %s is not writeable preventing a comment update"
% private_key_path
)
command = [self._bin_path, '-q']
command = [self._bin_path, "-q"]
if force_new_format:
command.append('-o')
command.extend(['-c', '-C', comment, '-f', private_key_path])
command.append("-o")
command.extend(["-c", "-C", comment, "-f", private_key_path])
return self._run_command(command, **kwargs)
class PrivateKey(object):
def __init__(self, size, key_type, fingerprint, format=''):
def __init__(self, size, key_type, fingerprint, format=""):
self._size = size
self._type = key_type
self._fingerprint = fingerprint
@@ -258,10 +304,10 @@ class PrivateKey(object):
def to_dict(self):
return {
'size': self._size,
'type': self._type,
'fingerprint': self._fingerprint,
'format': self._format,
"size": self._size,
"type": self._type,
"fingerprint": self._fingerprint,
"format": self._format,
}
@@ -275,11 +321,17 @@ class PublicKey(object):
if not isinstance(other, type(self)):
return NotImplemented
return all([
self._type_string == other._type_string,
self._data == other._data,
(self._comment == other._comment) if self._comment is not None and other._comment is not None else True
])
return all(
[
self._type_string == other._type_string,
self._data == other._data,
(
(self._comment == other._comment)
if self._comment is not None and other._comment is not None
else True
),
]
)
def __ne__(self, other):
return not self == other
@@ -305,19 +357,19 @@ class PublicKey(object):
@classmethod
def from_string(cls, string):
properties = string.strip('\n').split(' ', 2)
properties = string.strip("\n").split(" ", 2)
return cls(
type_string=properties[0],
data=properties[1],
comment=properties[2] if len(properties) > 2 else ""
comment=properties[2] if len(properties) > 2 else "",
)
@classmethod
def load(cls, path):
try:
with open(path, 'r') as f:
properties = f.read().strip(' \n').split(' ', 2)
with open(path, "r") as f:
properties = f.read().strip(" \n").split(" ", 2)
except (IOError, OSError):
raise
@@ -327,25 +379,25 @@ class PublicKey(object):
return cls(
type_string=properties[0],
data=properties[1],
comment='' if len(properties) <= 2 else properties[2],
comment="" if len(properties) <= 2 else properties[2],
)
def to_dict(self):
return {
'comment': self._comment,
'public_key': self._data,
"comment": self._comment,
"public_key": self._data,
}
def parse_private_key_format(path):
with open(path, 'r') as file:
with open(path, "r") as file:
header = file.readline().strip()
if header == '-----BEGIN OPENSSH PRIVATE KEY-----':
return 'SSH'
elif header == '-----BEGIN PRIVATE KEY-----':
return 'PKCS8'
elif header == '-----BEGIN RSA PRIVATE KEY-----':
return 'PKCS1'
if header == "-----BEGIN OPENSSH PRIVATE KEY-----":
return "SSH"
elif header == "-----BEGIN PRIVATE KEY-----":
return "PKCS8"
elif header == "-----BEGIN RSA PRIVATE KEY-----":
return "PKCS1"
return ''
return ""

View File

@@ -48,14 +48,18 @@ class KeypairBackend(OpensshModule):
def __init__(self, module):
super(KeypairBackend, self).__init__(module)
self.comment = self.module.params['comment']
self.private_key_path = self.module.params['path']
self.public_key_path = self.private_key_path + '.pub'
self.regenerate = self.module.params['regenerate'] if not self.module.params['force'] else 'always'
self.state = self.module.params['state']
self.type = self.module.params['type']
self.comment = self.module.params["comment"]
self.private_key_path = self.module.params["path"]
self.public_key_path = self.private_key_path + ".pub"
self.regenerate = (
self.module.params["regenerate"]
if not self.module.params["force"]
else "always"
)
self.state = self.module.params["state"]
self.type = self.module.params["type"]
self.size = self._get_size(self.module.params['size'])
self.size = self._get_size(self.module.params["size"])
self._validate_path()
self.original_private_key = None
@@ -64,31 +68,35 @@ class KeypairBackend(OpensshModule):
self.public_key = None
def _get_size(self, size):
if self.type in ('rsa', 'rsa1'):
if self.type in ("rsa", "rsa1"):
result = 4096 if size is None else size
if result < 1024:
return self.module.fail_json(
msg="For RSA keys, the minimum size is 1024 bits and the default is 4096 bits. " +
"Attempting to use bit lengths under 1024 will cause the module to fail."
msg="For RSA keys, the minimum size is 1024 bits and the default is 4096 bits. "
+ "Attempting to use bit lengths under 1024 will cause the module to fail."
)
elif self.type == 'dsa':
elif self.type == "dsa":
result = 1024 if size is None else size
if result != 1024:
return self.module.fail_json(msg="DSA keys must be exactly 1024 bits as specified by FIPS 186-2.")
elif self.type == 'ecdsa':
return self.module.fail_json(
msg="DSA keys must be exactly 1024 bits as specified by FIPS 186-2."
)
elif self.type == "ecdsa":
result = 256 if size is None else size
if result not in (256, 384, 521):
return self.module.fail_json(
msg="For ECDSA keys, size determines the key length by selecting from one of " +
"three elliptic curve sizes: 256, 384 or 521 bits. " +
"Attempting to use bit lengths other than these three values for ECDSA keys will " +
"cause this module to fail."
msg="For ECDSA keys, size determines the key length by selecting from one of "
+ "three elliptic curve sizes: 256, 384 or 521 bits. "
+ "Attempting to use bit lengths other than these three values for ECDSA keys will "
+ "cause this module to fail."
)
elif self.type == 'ed25519':
elif self.type == "ed25519":
# User input is ignored for `key size` when `key type` is ed25519
result = 256
else:
return self.module.fail_json(msg="%s is not a valid value for key type" % self.type)
return self.module.fail_json(
msg="%s is not a valid value for key type" % self.type
)
return result
@@ -96,13 +104,16 @@ class KeypairBackend(OpensshModule):
self._check_if_base_dir(self.private_key_path)
if os.path.isdir(self.private_key_path):
self.module.fail_json(msg='%s is a directory. Please specify a path to a file.' % self.private_key_path)
self.module.fail_json(
msg="%s is a directory. Please specify a path to a file."
% self.private_key_path
)
def _execute(self):
self.original_private_key = self._load_private_key()
self.original_public_key = self._load_public_key()
if self.state == 'present':
if self.state == "present":
self._validate_key_load()
if self._should_generate():
@@ -149,13 +160,15 @@ class KeypairBackend(OpensshModule):
return os.path.exists(self.public_key_path)
def _validate_key_load(self):
if (self._private_key_exists()
and self.regenerate in ('never', 'fail', 'partial_idempotence')
and (self.original_private_key is None or not self._private_key_readable())):
if (
self._private_key_exists()
and self.regenerate in ("never", "fail", "partial_idempotence")
and (self.original_private_key is None or not self._private_key_readable())
):
self.module.fail_json(
msg="Unable to read the key. The key is protected with a passphrase or broken. " +
"Will not proceed. To force regeneration, call the module with `generate` " +
"set to `full_idempotence` or `always`, or with `force=true`."
msg="Unable to read the key. The key is protected with a passphrase or broken. "
+ "Will not proceed. To force regeneration, call the module with `generate` "
+ "set to `full_idempotence` or `always`, or with `force=true`."
)
@abc.abstractmethod
@@ -165,17 +178,17 @@ class KeypairBackend(OpensshModule):
def _should_generate(self):
if self.original_private_key is None:
return True
elif self.regenerate == 'never':
elif self.regenerate == "never":
return False
elif self.regenerate == 'fail':
elif self.regenerate == "fail":
if not self._private_key_valid():
self.module.fail_json(
msg="Key has wrong type and/or size. Will not proceed. " +
"To force regeneration, call the module with `generate` set to " +
"`partial_idempotence`, `full_idempotence` or `always`, or with `force=true`."
msg="Key has wrong type and/or size. Will not proceed. "
+ "To force regeneration, call the module with `generate` set to "
+ "`partial_idempotence`, `full_idempotence` or `always`, or with `force=true`."
)
return False
elif self.regenerate in ('partial_idempotence', 'full_idempotence'):
elif self.regenerate in ("partial_idempotence", "full_idempotence"):
return not self._private_key_valid()
else:
return True
@@ -184,11 +197,13 @@ class KeypairBackend(OpensshModule):
if self.original_private_key is None:
return False
return all([
self.size == self.original_private_key.size,
self.type == self.original_private_key.type,
self._private_key_valid_backend(),
])
return all(
[
self.size == self.original_private_key.size,
self.type == self.original_private_key.type,
self._private_key_valid_backend(),
]
)
@abc.abstractmethod
def _private_key_valid_backend(self):
@@ -200,13 +215,20 @@ class KeypairBackend(OpensshModule):
temp_private_key, temp_public_key = self._generate_temp_keypair()
try:
self._safe_secure_move([(temp_private_key, self.private_key_path), (temp_public_key, self.public_key_path)])
self._safe_secure_move(
[
(temp_private_key, self.private_key_path),
(temp_public_key, self.public_key_path),
]
)
except OSError as e:
self.module.fail_json(msg=to_native(e))
def _generate_temp_keypair(self):
temp_private_key = os.path.join(self.module.tmpdir, os.path.basename(self.private_key_path))
temp_public_key = temp_private_key + '.pub'
temp_private_key = os.path.join(
self.module.tmpdir, os.path.basename(self.private_key_path)
)
temp_public_key = temp_private_key + ".pub"
try:
self._generate_keypair(temp_private_key)
@@ -239,27 +261,33 @@ class KeypairBackend(OpensshModule):
@OpensshModule.skip_if_check_mode
def _restore_public_key(self):
try:
temp_public_key = self._create_temp_public_key(str(self._get_public_key()) + '\n')
self._safe_secure_move([
(temp_public_key, self.public_key_path)
])
temp_public_key = self._create_temp_public_key(
str(self._get_public_key()) + "\n"
)
self._safe_secure_move([(temp_public_key, self.public_key_path)])
except (IOError, OSError):
self.module.fail_json(
msg="The public key is missing or does not match the private key. " +
"Unable to regenerate the public key."
msg="The public key is missing or does not match the private key. "
+ "Unable to regenerate the public key."
)
if self.comment:
self._update_comment()
def _create_temp_public_key(self, content):
temp_public_key = os.path.join(self.module.tmpdir, os.path.basename(self.public_key_path))
temp_public_key = os.path.join(
self.module.tmpdir, os.path.basename(self.public_key_path)
)
default_permissions = 0o644
existing_permissions = file_mode(self.public_key_path)
try:
secure_write(temp_public_key, existing_permissions or default_permissions, to_bytes(content))
secure_write(
temp_public_key,
existing_permissions or default_permissions,
to_bytes(content),
)
except (IOError, OSError) as e:
self.module.fail_json(msg=to_native(e))
self.module.add_cleanup_file(temp_public_key)
@@ -290,25 +318,29 @@ class KeypairBackend(OpensshModule):
public_key = self.public_key or self.original_public_key
return {
'size': self.size,
'type': self.type,
'filename': self.private_key_path,
'fingerprint': private_key.fingerprint if private_key else '',
'public_key': str(public_key) if public_key else '',
'comment': public_key.comment if public_key else '',
"size": self.size,
"type": self.type,
"filename": self.private_key_path,
"fingerprint": private_key.fingerprint if private_key else "",
"public_key": str(public_key) if public_key else "",
"comment": public_key.comment if public_key else "",
}
@property
def diff(self):
before = self.original_private_key.to_dict() if self.original_private_key else {}
before.update(self.original_public_key.to_dict() if self.original_public_key else {})
before = (
self.original_private_key.to_dict() if self.original_private_key else {}
)
before.update(
self.original_public_key.to_dict() if self.original_public_key else {}
)
after = self.private_key.to_dict() if self.private_key else {}
after.update(self.public_key.to_dict() if self.public_key else {})
return {
'before': before,
'after': after,
"before": before,
"after": after,
}
@@ -316,36 +348,59 @@ class KeypairBackendOpensshBin(KeypairBackend):
def __init__(self, module):
super(KeypairBackendOpensshBin, self).__init__(module)
if self.module.params['private_key_format'] != 'auto':
if self.module.params["private_key_format"] != "auto":
self.module.fail_json(
msg="'auto' is the only valid option for " +
"'private_key_format' when 'backend' is not 'cryptography'"
msg="'auto' is the only valid option for "
+ "'private_key_format' when 'backend' is not 'cryptography'"
)
self.ssh_keygen = KeygenCommand(self.module)
def _generate_keypair(self, private_key_path):
self.ssh_keygen.generate_keypair(private_key_path, self.size, self.type, self.comment, check_rc=True)
self.ssh_keygen.generate_keypair(
private_key_path, self.size, self.type, self.comment, check_rc=True
)
def _get_private_key(self):
rc, private_key_content, err = self.ssh_keygen.get_private_key(self.private_key_path, check_rc=False)
rc, private_key_content, err = self.ssh_keygen.get_private_key(
self.private_key_path, check_rc=False
)
if rc != 0:
raise ValueError(err)
return PrivateKey.from_string(private_key_content)
def _get_public_key(self):
public_key_content = self.ssh_keygen.get_matching_public_key(self.private_key_path, check_rc=True)[1]
public_key_content = self.ssh_keygen.get_matching_public_key(
self.private_key_path, check_rc=True
)[1]
return PublicKey.from_string(public_key_content)
def _private_key_readable(self):
rc, stdout, stderr = self.ssh_keygen.get_matching_public_key(self.private_key_path, check_rc=False)
return not (rc == 255 or any_in(stderr, 'is not a public key file', 'incorrect passphrase', 'load failed'))
rc, stdout, stderr = self.ssh_keygen.get_matching_public_key(
self.private_key_path, check_rc=False
)
return not (
rc == 255
or any_in(
stderr,
"is not a public key file",
"incorrect passphrase",
"load failed",
)
)
def _update_comment(self):
try:
ssh_version = self._get_ssh_version() or "7.8"
force_new_format = LooseVersion('6.5') <= LooseVersion(ssh_version) < LooseVersion('7.8')
self.ssh_keygen.update_comment(self.private_key_path, self.comment, force_new_format=force_new_format, check_rc=True)
force_new_format = (
LooseVersion("6.5") <= LooseVersion(ssh_version) < LooseVersion("7.8")
)
self.ssh_keygen.update_comment(
self.private_key_path,
self.comment,
force_new_format=force_new_format,
check_rc=True,
)
except (IOError, OSError) as e:
self.module.fail_json(msg=to_native(e))
@@ -357,30 +412,41 @@ class KeypairBackendCryptography(KeypairBackend):
def __init__(self, module):
super(KeypairBackendCryptography, self).__init__(module)
if self.type == 'rsa1':
self.module.fail_json(msg="RSA1 keys are not supported by the cryptography backend")
if self.type == "rsa1":
self.module.fail_json(
msg="RSA1 keys are not supported by the cryptography backend"
)
self.passphrase = to_bytes(module.params['passphrase']) if module.params['passphrase'] else None
self.private_key_format = self._get_key_format(module.params['private_key_format'])
self.passphrase = (
to_bytes(module.params["passphrase"])
if module.params["passphrase"]
else None
)
self.private_key_format = self._get_key_format(
module.params["private_key_format"]
)
def _get_key_format(self, key_format):
result = 'SSH'
result = "SSH"
if key_format == 'auto':
if key_format == "auto":
# Default to OpenSSH 7.8 compatibility when OpenSSH is not installed
ssh_version = self._get_ssh_version() or "7.8"
if LooseVersion(ssh_version) < LooseVersion("7.8") and self.type != 'ed25519':
if (
LooseVersion(ssh_version) < LooseVersion("7.8")
and self.type != "ed25519"
):
# OpenSSH made SSH formatted private keys available in version 6.5,
# but still defaulted to PKCS1 format with the exception of ed25519 keys
result = 'PKCS1'
result = "PKCS1"
if result == 'SSH' and not HAS_OPENSSH_PRIVATE_FORMAT:
if result == "SSH" and not HAS_OPENSSH_PRIVATE_FORMAT:
self.module.fail_json(
msg=missing_required_lib(
'cryptography >= 3.0',
reason="to load/dump private keys in the default OpenSSH format for OpenSSH >= 7.8 " +
"or for ed25519 keys"
"cryptography >= 3.0",
reason="to load/dump private keys in the default OpenSSH format for OpenSSH >= 7.8 "
+ "or for ed25519 keys",
)
)
else:
@@ -393,7 +459,7 @@ class KeypairBackendCryptography(KeypairBackend):
keytype=self.type,
size=self.size,
passphrase=self.passphrase,
comment=self.comment or '',
comment=self.comment or "",
)
encoded_private_key = OpensshKeypair.encode_openssh_privatekey(
@@ -401,22 +467,28 @@ class KeypairBackendCryptography(KeypairBackend):
)
secure_write(private_key_path, 0o600, encoded_private_key)
public_key_path = private_key_path + '.pub'
public_key_path = private_key_path + ".pub"
secure_write(public_key_path, 0o644, keypair.public_key)
def _get_private_key(self):
keypair = OpensshKeypair.load(path=self.private_key_path, passphrase=self.passphrase, no_public_key=True)
keypair = OpensshKeypair.load(
path=self.private_key_path, passphrase=self.passphrase, no_public_key=True
)
return PrivateKey(
size=keypair.size,
key_type=keypair.key_type,
fingerprint=keypair.fingerprint,
format=parse_private_key_format(self.private_key_path)
format=parse_private_key_format(self.private_key_path),
)
def _get_public_key(self):
try:
keypair = OpensshKeypair.load(path=self.private_key_path, passphrase=self.passphrase, no_public_key=True)
keypair = OpensshKeypair.load(
path=self.private_key_path,
passphrase=self.passphrase,
no_public_key=True,
)
except OpenSSHError:
# Simulates the null output of ssh-keygen
return ""
@@ -425,7 +497,11 @@ class KeypairBackendCryptography(KeypairBackend):
def _private_key_readable(self):
try:
OpensshKeypair.load(path=self.private_key_path, passphrase=self.passphrase, no_public_key=True)
OpensshKeypair.load(
path=self.private_key_path,
passphrase=self.passphrase,
no_public_key=True,
)
except (InvalidPrivateKeyFileError, InvalidPassphraseError):
return False
@@ -433,7 +509,9 @@ class KeypairBackendCryptography(KeypairBackend):
# when loading an unencrypted key
if self.passphrase:
try:
OpensshKeypair.load(path=self.private_key_path, passphrase=None, no_public_key=True)
OpensshKeypair.load(
path=self.private_key_path, passphrase=None, no_public_key=True
)
except (InvalidPrivateKeyFileError, InvalidPassphraseError):
return True
else:
@@ -442,14 +520,16 @@ class KeypairBackendCryptography(KeypairBackend):
return True
def _update_comment(self):
keypair = OpensshKeypair.load(path=self.private_key_path, passphrase=self.passphrase, no_public_key=True)
keypair = OpensshKeypair.load(
path=self.private_key_path, passphrase=self.passphrase, no_public_key=True
)
try:
keypair.comment = self.comment
except InvalidCommentError as e:
self.module.fail_json(msg=to_native(e))
try:
temp_public_key = self._create_temp_public_key(keypair.public_key + b'\n')
temp_public_key = self._create_temp_public_key(keypair.public_key + b"\n")
self._safe_secure_move([(temp_public_key, self.public_key_path)])
except (IOError, OSError) as e:
self.module.fail_json(msg=to_native(e))
@@ -457,7 +537,7 @@ class KeypairBackendCryptography(KeypairBackend):
def _private_key_valid_backend(self):
# avoids breaking behavior and prevents
# automatic conversions with OpenSSH upgrades
if self.module.params['private_key_format'] == 'auto':
if self.module.params["private_key_format"] == "auto":
return True
return self.private_key_format == self.original_private_key.format
@@ -465,24 +545,26 @@ class KeypairBackendCryptography(KeypairBackend):
def select_backend(module, backend):
can_use_cryptography = HAS_OPENSSH_SUPPORT
can_use_opensshbin = bool(module.get_bin_path('ssh-keygen'))
can_use_opensshbin = bool(module.get_bin_path("ssh-keygen"))
if backend == 'auto':
if can_use_opensshbin and not module.params['passphrase']:
backend = 'opensshbin'
if backend == "auto":
if can_use_opensshbin and not module.params["passphrase"]:
backend = "opensshbin"
elif can_use_cryptography:
backend = 'cryptography'
backend = "cryptography"
else:
module.fail_json(msg="Cannot find either the OpenSSH binary in the PATH " +
"or cryptography >= 2.6 installed on this system")
module.fail_json(
msg="Cannot find either the OpenSSH binary in the PATH "
+ "or cryptography >= 2.6 installed on this system"
)
if backend == 'opensshbin':
if backend == "opensshbin":
if not can_use_opensshbin:
module.fail_json(msg="Cannot find the OpenSSH binary in the PATH")
return backend, KeypairBackendOpensshBin(module)
elif backend == 'cryptography':
elif backend == "cryptography":
if not can_use_cryptography:
module.fail_json(msg=missing_required_lib("cryptography >= 2.6"))
return backend, KeypairBackendCryptography(module)
else:
raise ValueError('Unsupported value for backend: {0}'.format(backend))
raise ValueError("Unsupported value for backend: {0}".format(backend))

View File

@@ -51,54 +51,56 @@ _USER_TYPE = 1
_HOST_TYPE = 2
_SSH_TYPE_STRINGS = {
'rsa': b"ssh-rsa",
'dsa': b"ssh-dss",
'ecdsa-nistp256': b"ecdsa-sha2-nistp256",
'ecdsa-nistp384': b"ecdsa-sha2-nistp384",
'ecdsa-nistp521': b"ecdsa-sha2-nistp521",
'ed25519': b"ssh-ed25519",
"rsa": b"ssh-rsa",
"dsa": b"ssh-dss",
"ecdsa-nistp256": b"ecdsa-sha2-nistp256",
"ecdsa-nistp384": b"ecdsa-sha2-nistp384",
"ecdsa-nistp521": b"ecdsa-sha2-nistp521",
"ed25519": b"ssh-ed25519",
}
_CERT_SUFFIX_V01 = b"-cert-v01@openssh.com"
# See https://datatracker.ietf.org/doc/html/rfc5656#section-6.1
_ECDSA_CURVE_IDENTIFIERS = {
'ecdsa-nistp256': b'nistp256',
'ecdsa-nistp384': b'nistp384',
'ecdsa-nistp521': b'nistp521',
"ecdsa-nistp256": b"nistp256",
"ecdsa-nistp384": b"nistp384",
"ecdsa-nistp521": b"nistp521",
}
_ECDSA_CURVE_IDENTIFIERS_LOOKUP = {
b'nistp256': 'ecdsa-nistp256',
b'nistp384': 'ecdsa-nistp384',
b'nistp521': 'ecdsa-nistp521',
b"nistp256": "ecdsa-nistp256",
b"nistp384": "ecdsa-nistp384",
b"nistp521": "ecdsa-nistp521",
}
_USE_TIMEZONE = sys.version_info >= (3, 6)
_ALWAYS = _add_or_remove_timezone(datetime(1970, 1, 1), with_timezone=_USE_TIMEZONE)
_FOREVER = datetime(9999, 12, 31, 23, 59, 59, 999999, _UTC) if _USE_TIMEZONE else datetime.max
_FOREVER = (
datetime(9999, 12, 31, 23, 59, 59, 999999, _UTC) if _USE_TIMEZONE else datetime.max
)
_CRITICAL_OPTIONS = (
'force-command',
'source-address',
'verify-required',
"force-command",
"source-address",
"verify-required",
)
_DIRECTIVES = (
'clear',
'no-x11-forwarding',
'no-agent-forwarding',
'no-port-forwarding',
'no-pty',
'no-user-rc',
"clear",
"no-x11-forwarding",
"no-agent-forwarding",
"no-port-forwarding",
"no-pty",
"no-user-rc",
)
_EXTENSIONS = (
'permit-x11-forwarding',
'permit-agent-forwarding',
'permit-port-forwarding',
'permit-pty',
'permit-user-rc'
"permit-x11-forwarding",
"permit-agent-forwarding",
"permit-port-forwarding",
"permit-pty",
"permit-user-rc",
)
if six.PY3:
@@ -111,13 +113,19 @@ class OpensshCertificateTimeParameters(object):
self._valid_to = self.to_datetime(valid_to)
if self._valid_from > self._valid_to:
raise ValueError("Valid from: %s must not be greater than Valid to: %s" % (valid_from, valid_to))
raise ValueError(
"Valid from: %s must not be greater than Valid to: %s"
% (valid_from, valid_to)
)
def __eq__(self, other):
if not isinstance(other, type(self)):
return NotImplemented
else:
return self._valid_from == other._valid_from and self._valid_to == other._valid_to
return (
self._valid_from == other._valid_from
and self._valid_to == other._valid_to
)
def __ne__(self, other):
return not self == other
@@ -126,7 +134,8 @@ class OpensshCertificateTimeParameters(object):
def validity_string(self):
if not (self._valid_from == _ALWAYS and self._valid_to == _FOREVER):
return "%s:%s" % (
self.valid_from(date_format='openssh'), self.valid_to(date_format='openssh')
self.valid_from(date_format="openssh"),
self.valid_to(date_format="openssh"),
)
return ""
@@ -144,16 +153,22 @@ class OpensshCertificateTimeParameters(object):
@staticmethod
def format_datetime(dt, date_format):
if date_format in ('human_readable', 'openssh'):
if date_format in ("human_readable", "openssh"):
if dt == _ALWAYS:
result = 'always'
result = "always"
elif dt == _FOREVER:
result = 'forever'
result = "forever"
else:
result = dt.isoformat().replace('+00:00', '') if date_format == 'human_readable' else dt.strftime("%Y%m%d%H%M%S")
elif date_format == 'timestamp':
result = (
dt.isoformat().replace("+00:00", "")
if date_format == "human_readable"
else dt.strftime("%Y%m%d%H%M%S")
)
elif date_format == "timestamp":
td = dt - _ALWAYS
result = int((td.microseconds + (td.seconds + td.days * 24 * 3600) * 10 ** 6) / 10 ** 6)
result = int(
(td.microseconds + (td.seconds + td.days * 24 * 3600) * 10**6) / 10**6
)
else:
raise ValueError("%s is not a valid format" % date_format)
return result
@@ -162,12 +177,17 @@ class OpensshCertificateTimeParameters(object):
def to_datetime(time_string_or_timestamp):
try:
if isinstance(time_string_or_timestamp, six.string_types):
result = OpensshCertificateTimeParameters._time_string_to_datetime(time_string_or_timestamp.strip())
result = OpensshCertificateTimeParameters._time_string_to_datetime(
time_string_or_timestamp.strip()
)
elif isinstance(time_string_or_timestamp, (long, int)):
result = OpensshCertificateTimeParameters._timestamp_to_datetime(time_string_or_timestamp)
result = OpensshCertificateTimeParameters._timestamp_to_datetime(
time_string_or_timestamp
)
else:
raise ValueError(
"Value must be of type (str, unicode, int, long) not %s" % type(time_string_or_timestamp)
"Value must be of type (str, unicode, int, long) not %s"
% type(time_string_or_timestamp)
)
except ValueError:
raise
@@ -182,7 +202,9 @@ class OpensshCertificateTimeParameters(object):
else:
try:
if _USE_TIMEZONE:
result = datetime.fromtimestamp(timestamp, tz=_datetime.timezone.utc)
result = datetime.fromtimestamp(
timestamp, tz=_datetime.timezone.utc
)
else:
result = datetime.utcfromtimestamp(timestamp)
except OverflowError:
@@ -192,16 +214,21 @@ class OpensshCertificateTimeParameters(object):
@staticmethod
def _time_string_to_datetime(time_string):
result = None
if time_string == 'always':
if time_string == "always":
result = _ALWAYS
elif time_string == 'forever':
elif time_string == "forever":
result = _FOREVER
elif is_relative_time_string(time_string):
result = convert_relative_to_datetime(time_string, with_timezone=_USE_TIMEZONE)
result = convert_relative_to_datetime(
time_string, with_timezone=_USE_TIMEZONE
)
else:
for time_format in ("%Y-%m-%d", "%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S"):
try:
result = _add_or_remove_timezone(datetime.strptime(time_string, time_format), with_timezone=_USE_TIMEZONE)
result = _add_or_remove_timezone(
datetime.strptime(time_string, time_format),
with_timezone=_USE_TIMEZONE,
)
except ValueError:
pass
if result is None:
@@ -211,7 +238,7 @@ class OpensshCertificateTimeParameters(object):
class OpensshCertificateOption(object):
def __init__(self, option_type, name, data):
if option_type not in ('critical', 'extension'):
if option_type not in ("critical", "extension"):
raise ValueError("type must be either 'critical' or 'extension'")
if not isinstance(name, six.string_types):
@@ -228,11 +255,13 @@ class OpensshCertificateOption(object):
if not isinstance(other, type(self)):
return NotImplemented
return all([
self._option_type == other._option_type,
self._name == other._name,
self._data == other._data,
])
return all(
[
self._option_type == other._option_type,
self._name == other._name,
self._data == other._data,
]
)
def __hash__(self):
return hash((self._option_type, self._name, self._data))
@@ -260,42 +289,47 @@ class OpensshCertificateOption(object):
@classmethod
def from_string(cls, option_string):
if not isinstance(option_string, six.string_types):
raise ValueError("option_string must be a string not %s" % type(option_string))
raise ValueError(
"option_string must be a string not %s" % type(option_string)
)
option_type = None
if ':' in option_string:
option_type, value = option_string.strip().split(':', 1)
if '=' in value:
name, data = value.split('=', 1)
if ":" in option_string:
option_type, value = option_string.strip().split(":", 1)
if "=" in value:
name, data = value.split("=", 1)
else:
name, data = value, ''
elif '=' in option_string:
name, data = option_string.strip().split('=', 1)
name, data = value, ""
elif "=" in option_string:
name, data = option_string.strip().split("=", 1)
else:
name, data = option_string.strip(), ''
name, data = option_string.strip(), ""
return cls(
option_type=option_type or get_option_type(name.lower()),
name=name,
data=data
data=data,
)
@six.add_metaclass(abc.ABCMeta)
class OpensshCertificateInfo:
"""Encapsulates all certificate information which is signed by a CA key"""
def __init__(self,
nonce=None,
serial=None,
cert_type=None,
key_id=None,
principals=None,
valid_after=None,
valid_before=None,
critical_options=None,
extensions=None,
reserved=None,
signing_key=None):
def __init__(
self,
nonce=None,
serial=None,
cert_type=None,
key_id=None,
principals=None,
valid_after=None,
valid_before=None,
critical_options=None,
extensions=None,
reserved=None,
signing_key=None,
):
self.nonce = nonce
self.serial = serial
self._cert_type = cert_type
@@ -313,17 +347,17 @@ class OpensshCertificateInfo:
@property
def cert_type(self):
if self._cert_type == _USER_TYPE:
return 'user'
return "user"
elif self._cert_type == _HOST_TYPE:
return 'host'
return "host"
else:
return ''
return ""
@cert_type.setter
def cert_type(self, cert_type):
if cert_type == 'user' or cert_type == _USER_TYPE:
if cert_type == "user" or cert_type == _USER_TYPE:
self._cert_type = _USER_TYPE
elif cert_type == 'host' or cert_type == _HOST_TYPE:
elif cert_type == "host" or cert_type == _HOST_TYPE:
self._cert_type = _HOST_TYPE
else:
raise ValueError("%s is not a valid certificate type" % cert_type)
@@ -343,17 +377,17 @@ class OpensshCertificateInfo:
class OpensshRSACertificateInfo(OpensshCertificateInfo):
def __init__(self, e=None, n=None, **kwargs):
super(OpensshRSACertificateInfo, self).__init__(**kwargs)
self.type_string = _SSH_TYPE_STRINGS['rsa'] + _CERT_SUFFIX_V01
self.type_string = _SSH_TYPE_STRINGS["rsa"] + _CERT_SUFFIX_V01
self.e = e
self.n = n
# See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
def public_key_fingerprint(self):
if any([self.e is None, self.n is None]):
return b''
return b""
writer = _OpensshWriter()
writer.string(_SSH_TYPE_STRINGS['rsa'])
writer.string(_SSH_TYPE_STRINGS["rsa"])
writer.mpint(self.e)
writer.mpint(self.n)
@@ -367,7 +401,7 @@ class OpensshRSACertificateInfo(OpensshCertificateInfo):
class OpensshDSACertificateInfo(OpensshCertificateInfo):
def __init__(self, p=None, q=None, g=None, y=None, **kwargs):
super(OpensshDSACertificateInfo, self).__init__(**kwargs)
self.type_string = _SSH_TYPE_STRINGS['dsa'] + _CERT_SUFFIX_V01
self.type_string = _SSH_TYPE_STRINGS["dsa"] + _CERT_SUFFIX_V01
self.p = p
self.q = q
self.g = g
@@ -376,10 +410,10 @@ class OpensshDSACertificateInfo(OpensshCertificateInfo):
# See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
def public_key_fingerprint(self):
if any([self.p is None, self.q is None, self.g is None, self.y is None]):
return b''
return b""
writer = _OpensshWriter()
writer.string(_SSH_TYPE_STRINGS['dsa'])
writer.string(_SSH_TYPE_STRINGS["dsa"])
writer.mpint(self.p)
writer.mpint(self.q)
writer.mpint(self.g)
@@ -411,16 +445,20 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo):
def curve(self, curve):
if curve in _ECDSA_CURVE_IDENTIFIERS.values():
self._curve = curve
self.type_string = _SSH_TYPE_STRINGS[_ECDSA_CURVE_IDENTIFIERS_LOOKUP[curve]] + _CERT_SUFFIX_V01
self.type_string = (
_SSH_TYPE_STRINGS[_ECDSA_CURVE_IDENTIFIERS_LOOKUP[curve]]
+ _CERT_SUFFIX_V01
)
else:
raise ValueError(
"Curve must be one of %s" % (b','.join(list(_ECDSA_CURVE_IDENTIFIERS.values()))).decode('UTF-8')
"Curve must be one of %s"
% (b",".join(list(_ECDSA_CURVE_IDENTIFIERS.values()))).decode("UTF-8")
)
# See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
def public_key_fingerprint(self):
if any([self.curve is None, self.public_key is None]):
return b''
return b""
writer = _OpensshWriter()
writer.string(_SSH_TYPE_STRINGS[_ECDSA_CURVE_IDENTIFIERS_LOOKUP[self.curve]])
@@ -437,15 +475,15 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo):
class OpensshED25519CertificateInfo(OpensshCertificateInfo):
def __init__(self, pk=None, **kwargs):
super(OpensshED25519CertificateInfo, self).__init__(**kwargs)
self.type_string = _SSH_TYPE_STRINGS['ed25519'] + _CERT_SUFFIX_V01
self.type_string = _SSH_TYPE_STRINGS["ed25519"] + _CERT_SUFFIX_V01
self.pk = pk
def public_key_fingerprint(self):
if self.pk is None:
return b''
return b""
writer = _OpensshWriter()
writer.string(_SSH_TYPE_STRINGS['ed25519'])
writer.string(_SSH_TYPE_STRINGS["ed25519"])
writer.string(self.pk)
return fingerprint(writer.bytes())
@@ -457,6 +495,7 @@ class OpensshED25519CertificateInfo(OpensshCertificateInfo):
# See https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL.certkeys?annotate=HEAD
class OpensshCertificate(object):
"""Encapsulates a formatted OpenSSH certificate including signature and signing key"""
def __init__(self, cert_info, signature):
self._cert_info = cert_info
@@ -468,13 +507,13 @@ class OpensshCertificate(object):
raise ValueError("%s is not a valid path." % path)
try:
with open(path, 'rb') as cert_file:
with open(path, "rb") as cert_file:
data = cert_file.read()
except (IOError, OSError) as e:
raise ValueError("%s cannot be opened for reading: %s" % (path, e))
try:
format_identifier, b64_cert = data.split(b' ')[:2]
format_identifier, b64_cert = data.split(b" ")[:2]
cert = binascii.a2b_base64(b64_cert)
except (binascii.Error, ValueError):
raise ValueError("Certificate not in OpenSSH format")
@@ -484,7 +523,9 @@ class OpensshCertificate(object):
pub_key_type = key_type
break
else:
raise ValueError("Invalid certificate format identifier: %s" % format_identifier)
raise ValueError(
"Invalid certificate format identifier: %s" % format_identifier
)
parser = OpensshParser(cert)
@@ -499,7 +540,8 @@ class OpensshCertificate(object):
if parser.remaining_bytes():
raise ValueError(
"%s bytes of additional data was not parsed while loading %s" % (parser.remaining_bytes(), path)
"%s bytes of additional data was not parsed while loading %s"
% (parser.remaining_bytes(), path)
)
return cls(
@@ -546,12 +588,16 @@ class OpensshCertificate(object):
@property
def critical_options(self):
return [
OpensshCertificateOption('critical', to_text(n), to_text(d)) for n, d in self._cert_info.critical_options
OpensshCertificateOption("critical", to_text(n), to_text(d))
for n, d in self._cert_info.critical_options
]
@property
def extensions(self):
return [OpensshCertificateOption('extension', to_text(n), to_text(d)) for n, d in self._cert_info.extensions]
return [
OpensshCertificateOption("extension", to_text(n), to_text(d))
for n, d in self._cert_info.extensions
]
@property
def reserved(self):
@@ -564,7 +610,7 @@ class OpensshCertificate(object):
@property
def signature_type(self):
signature_data = OpensshParser.signature_data(self.signature)
return to_text(signature_data['signature_type'])
return to_text(signature_data["signature_type"])
@staticmethod
def _parse_cert_info(pub_key_type, parser):
@@ -586,23 +632,24 @@ class OpensshCertificate(object):
def to_dict(self):
time_parameters = OpensshCertificateTimeParameters(
valid_from=self.valid_after,
valid_to=self.valid_before
valid_from=self.valid_after, valid_to=self.valid_before
)
return {
'type_string': self.type_string,
'nonce': self.nonce,
'serial': self.serial,
'cert_type': self.type,
'identifier': self.key_id,
'principals': self.principals,
'valid_after': time_parameters.valid_from(date_format='human_readable'),
'valid_before': time_parameters.valid_to(date_format='human_readable'),
'critical_options': [str(critical_option) for critical_option in self.critical_options],
'extensions': [str(extension) for extension in self.extensions],
'reserved': self.reserved,
'public_key': self.public_key,
'signing_key': self.signing_key,
"type_string": self.type_string,
"nonce": self.nonce,
"serial": self.serial,
"cert_type": self.type,
"identifier": self.key_id,
"principals": self.principals,
"valid_after": time_parameters.valid_from(date_format="human_readable"),
"valid_before": time_parameters.valid_to(date_format="human_readable"),
"critical_options": [
str(critical_option) for critical_option in self.critical_options
],
"extensions": [str(extension) for extension in self.extensions],
"reserved": self.reserved,
"public_key": self.public_key,
"signing_key": self.signing_key,
}
@@ -611,38 +658,46 @@ def apply_directives(directives):
raise ValueError("directives must be one of %s" % ", ".join(_DIRECTIVES))
directive_to_option = {
'no-x11-forwarding': OpensshCertificateOption('extension', 'permit-x11-forwarding', ''),
'no-agent-forwarding': OpensshCertificateOption('extension', 'permit-agent-forwarding', ''),
'no-port-forwarding': OpensshCertificateOption('extension', 'permit-port-forwarding', ''),
'no-pty': OpensshCertificateOption('extension', 'permit-pty', ''),
'no-user-rc': OpensshCertificateOption('extension', 'permit-user-rc', ''),
"no-x11-forwarding": OpensshCertificateOption(
"extension", "permit-x11-forwarding", ""
),
"no-agent-forwarding": OpensshCertificateOption(
"extension", "permit-agent-forwarding", ""
),
"no-port-forwarding": OpensshCertificateOption(
"extension", "permit-port-forwarding", ""
),
"no-pty": OpensshCertificateOption("extension", "permit-pty", ""),
"no-user-rc": OpensshCertificateOption("extension", "permit-user-rc", ""),
}
if 'clear' in directives:
if "clear" in directives:
return []
else:
return list(set(default_options()) - set(directive_to_option[d] for d in directives))
return list(
set(default_options()) - set(directive_to_option[d] for d in directives)
)
def default_options():
return [OpensshCertificateOption('extension', name, '') for name in _EXTENSIONS]
return [OpensshCertificateOption("extension", name, "") for name in _EXTENSIONS]
def fingerprint(public_key):
"""Generates a SHA256 hash and formats output to resemble ``ssh-keygen``"""
h = sha256()
h.update(public_key)
return b'SHA256:' + b64encode(h.digest()).rstrip(b'=')
return b"SHA256:" + b64encode(h.digest()).rstrip(b"=")
def get_cert_info_object(key_type):
if key_type == 'rsa':
if key_type == "rsa":
cert_info = OpensshRSACertificateInfo()
elif key_type == 'dsa':
elif key_type == "dsa":
cert_info = OpensshDSACertificateInfo()
elif key_type in ('ecdsa-nistp256', 'ecdsa-nistp384', 'ecdsa-nistp521'):
elif key_type in ("ecdsa-nistp256", "ecdsa-nistp384", "ecdsa-nistp521"):
cert_info = OpensshECDSACertificateInfo()
elif key_type == 'ed25519':
elif key_type == "ed25519":
cert_info = OpensshED25519CertificateInfo()
else:
raise ValueError("%s is not a valid key type" % key_type)
@@ -652,12 +707,14 @@ def get_cert_info_object(key_type):
def get_option_type(name):
if name in _CRITICAL_OPTIONS:
result = 'critical'
result = "critical"
elif name in _EXTENSIONS:
result = 'extension'
result = "extension"
else:
raise ValueError("%s is not a valid option. " % name +
"Custom options must start with 'critical:' or 'extension:' to indicate type")
raise ValueError(
"%s is not a valid option. " % name
+ "Custom options must start with 'critical:' or 'extension:' to indicate type"
)
return result
@@ -675,7 +732,7 @@ def parse_option_list(option_list):
directives.append(option.lower())
else:
option_object = OpensshCertificateOption.from_string(option)
if option_object.type == 'critical':
if option_object.type == "critical":
critical_options.append(option_object)
else:
extensions.append(option_object)

View File

@@ -38,41 +38,41 @@ try:
HAS_OPENSSH_SUPPORT = True
_ALGORITHM_PARAMETERS = {
'rsa': {
'default_size': 2048,
'valid_sizes': range(1024, 16384),
'signer_params': {
'padding': padding.PSS(
"rsa": {
"default_size": 2048,
"valid_sizes": range(1024, 16384),
"signer_params": {
"padding": padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH,
),
'algorithm': hashes.SHA256(),
"algorithm": hashes.SHA256(),
},
},
'dsa': {
'default_size': 1024,
'valid_sizes': [1024],
'signer_params': {
'algorithm': hashes.SHA256(),
"dsa": {
"default_size": 1024,
"valid_sizes": [1024],
"signer_params": {
"algorithm": hashes.SHA256(),
},
},
'ed25519': {
'default_size': 256,
'valid_sizes': [256],
'signer_params': {},
"ed25519": {
"default_size": 256,
"valid_sizes": [256],
"signer_params": {},
},
'ecdsa': {
'default_size': 256,
'valid_sizes': [256, 384, 521],
'signer_params': {
'signature_algorithm': ec.ECDSA(hashes.SHA256()),
"ecdsa": {
"default_size": 256,
"valid_sizes": [256, 384, 521],
"signer_params": {
"signature_algorithm": ec.ECDSA(hashes.SHA256()),
},
'curves': {
"curves": {
256: ec.SECP256R1(),
384: ec.SECP384R1(),
521: ec.SECP521R1(),
}
}
},
},
}
except ImportError:
HAS_OPENSSH_PRIVATE_FORMAT = False
@@ -80,7 +80,7 @@ except ImportError:
CRYPTOGRAPHY_VERSION = "0.0"
_ALGORITHM_PARAMETERS = {}
_TEXT_ENCODING = 'UTF-8'
_TEXT_ENCODING = "UTF-8"
class OpenSSHError(Exception):
@@ -131,26 +131,25 @@ class AsymmetricKeypair(object):
"""Container for newly generated asymmetric key pairs or those loaded from existing files"""
@classmethod
def generate(cls, keytype='rsa', size=None, passphrase=None):
def generate(cls, keytype="rsa", size=None, passphrase=None):
"""Returns an Asymmetric_Keypair object generated with the supplied parameters
or defaults to an unencrypted RSA-2048 key
or defaults to an unencrypted RSA-2048 key
:keytype: One of rsa, dsa, ecdsa, ed25519
:size: The key length for newly generated keys
:passphrase: Secret of type Bytes used to encrypt the private key being generated
:keytype: One of rsa, dsa, ecdsa, ed25519
:size: The key length for newly generated keys
:passphrase: Secret of type Bytes used to encrypt the private key being generated
"""
if keytype not in _ALGORITHM_PARAMETERS.keys():
raise InvalidKeyTypeError(
"%s is not a valid keytype. Valid keytypes are %s" % (
keytype, ", ".join(_ALGORITHM_PARAMETERS.keys())
)
"%s is not a valid keytype. Valid keytypes are %s"
% (keytype, ", ".join(_ALGORITHM_PARAMETERS.keys()))
)
if not size:
size = _ALGORITHM_PARAMETERS[keytype]['default_size']
size = _ALGORITHM_PARAMETERS[keytype]["default_size"]
else:
if size not in _ALGORITHM_PARAMETERS[keytype]['valid_sizes']:
if size not in _ALGORITHM_PARAMETERS[keytype]["valid_sizes"]:
raise InvalidKeySizeError(
"%s is not a valid key size for %s keys" % (size, keytype)
)
@@ -160,7 +159,7 @@ class AsymmetricKeypair(object):
else:
encryption_algorithm = serialization.NoEncryption()
if keytype == 'rsa':
if keytype == "rsa":
privatekey = rsa.generate_private_key(
# Public exponent should always be 65537 to prevent issues
# if improper padding is used during signing
@@ -168,16 +167,16 @@ class AsymmetricKeypair(object):
key_size=size,
backend=backend,
)
elif keytype == 'dsa':
elif keytype == "dsa":
privatekey = dsa.generate_private_key(
key_size=size,
backend=backend,
)
elif keytype == 'ed25519':
elif keytype == "ed25519":
privatekey = Ed25519PrivateKey.generate()
elif keytype == 'ecdsa':
elif keytype == "ecdsa":
privatekey = ec.generate_private_key(
_ALGORITHM_PARAMETERS['ecdsa']['curves'][size],
_ALGORITHM_PARAMETERS["ecdsa"]["curves"][size],
backend=backend,
)
@@ -188,18 +187,25 @@ class AsymmetricKeypair(object):
size=size,
privatekey=privatekey,
publickey=publickey,
encryption_algorithm=encryption_algorithm
encryption_algorithm=encryption_algorithm,
)
@classmethod
def load(cls, path, passphrase=None, private_key_format='PEM', public_key_format='PEM', no_public_key=False):
def load(
cls,
path,
passphrase=None,
private_key_format="PEM",
public_key_format="PEM",
no_public_key=False,
):
"""Returns an Asymmetric_Keypair object loaded from the supplied file path
:path: A path to an existing private key to be loaded
:passphrase: Secret of type bytes used to decrypt the private key being loaded
:private_key_format: Format of private key to be loaded
:public_key_format: Format of public key to be loaded
:no_public_key: Set 'True' to only load a private key and automatically populate the matching public key
:path: A path to an existing private key to be loaded
:passphrase: Secret of type bytes used to decrypt the private key being loaded
:private_key_format: Format of private key to be loaded
:public_key_format: Format of public key to be loaded
:no_public_key: Set 'True' to only load a private key and automatically populate the matching public key
"""
if passphrase:
@@ -211,40 +217,42 @@ class AsymmetricKeypair(object):
if no_public_key:
publickey = privatekey.public_key()
else:
publickey = load_publickey(path + '.pub', public_key_format)
publickey = load_publickey(path + ".pub", public_key_format)
# Ed25519 keys are always of size 256 and do not have a key_size attribute
if isinstance(privatekey, Ed25519PrivateKey):
size = _ALGORITHM_PARAMETERS['ed25519']['default_size']
size = _ALGORITHM_PARAMETERS["ed25519"]["default_size"]
else:
size = privatekey.key_size
if isinstance(privatekey, rsa.RSAPrivateKey):
keytype = 'rsa'
keytype = "rsa"
elif isinstance(privatekey, dsa.DSAPrivateKey):
keytype = 'dsa'
keytype = "dsa"
elif isinstance(privatekey, ec.EllipticCurvePrivateKey):
keytype = 'ecdsa'
keytype = "ecdsa"
elif isinstance(privatekey, Ed25519PrivateKey):
keytype = 'ed25519'
keytype = "ed25519"
else:
raise InvalidKeyTypeError("Key type '%s' is not supported" % type(privatekey))
raise InvalidKeyTypeError(
"Key type '%s' is not supported" % type(privatekey)
)
return cls(
keytype=keytype,
size=size,
privatekey=privatekey,
publickey=publickey,
encryption_algorithm=encryption_algorithm
encryption_algorithm=encryption_algorithm,
)
def __init__(self, keytype, size, privatekey, publickey, encryption_algorithm):
"""
:keytype: One of rsa, dsa, ecdsa, ed25519
:size: The key length for the private key of this key pair
:privatekey: Private key object of this key pair
:publickey: Public key object of this key pair
:encryption_algorithm: Hashed secret used to encrypt the private key of this key pair
:keytype: One of rsa, dsa, ecdsa, ed25519
:size: The key length for the private key of this key pair
:privatekey: Private key object of this key pair
:publickey: Public key object of this key pair
:encryption_algorithm: Hashed secret used to encrypt the private key of this key pair
"""
self.__size = size
@@ -254,7 +262,7 @@ class AsymmetricKeypair(object):
self.__encryption_algorithm = encryption_algorithm
try:
self.verify(self.sign(b'message'), b'message')
self.verify(self.sign(b"message"), b"message")
except InvalidSignatureError:
raise InvalidPublicKeyFileError(
"The private key and public key of this keypair do not match"
@@ -264,8 +272,11 @@ class AsymmetricKeypair(object):
if not isinstance(other, AsymmetricKeypair):
return NotImplemented
return (compare_publickeys(self.public_key, other.public_key) and
compare_encryption_algorithms(self.encryption_algorithm, other.encryption_algorithm))
return compare_publickeys(
self.public_key, other.public_key
) and compare_encryption_algorithms(
self.encryption_algorithm, other.encryption_algorithm
)
def __ne__(self, other):
return not self == other
@@ -303,13 +314,12 @@ class AsymmetricKeypair(object):
def sign(self, data):
"""Returns signature of data signed with the private key of this key pair
:data: byteslike data to sign
:data: byteslike data to sign
"""
try:
signature = self.__privatekey.sign(
data,
**_ALGORITHM_PARAMETERS[self.__keytype]['signer_params']
data, **_ALGORITHM_PARAMETERS[self.__keytype]["signer_params"]
)
except TypeError as e:
raise InvalidDataError(e)
@@ -318,16 +328,16 @@ class AsymmetricKeypair(object):
def verify(self, signature, data):
"""Verifies that the signature associated with the provided data was signed
by the private key of this key pair.
by the private key of this key pair.
:signature: signature to verify
:data: byteslike data signed by the provided signature
:signature: signature to verify
:data: byteslike data signed by the provided signature
"""
try:
return self.__publickey.verify(
signature,
data,
**_ALGORITHM_PARAMETERS[self.__keytype]['signer_params']
**_ALGORITHM_PARAMETERS[self.__keytype]["signer_params"]
)
except InvalidSignature:
raise InvalidSignatureError
@@ -335,7 +345,7 @@ class AsymmetricKeypair(object):
def update_passphrase(self, passphrase=None):
"""Updates the encryption algorithm of this key pair
:passphrase: Byte secret used to encrypt this key pair
:passphrase: Byte secret used to encrypt this key pair
"""
if passphrase:
@@ -348,20 +358,20 @@ class OpensshKeypair(object):
"""Container for OpenSSH encoded asymmetric key pairs"""
@classmethod
def generate(cls, keytype='rsa', size=None, passphrase=None, comment=None):
def generate(cls, keytype="rsa", size=None, passphrase=None, comment=None):
"""Returns an Openssh_Keypair object generated using the supplied parameters or defaults to a RSA-2048 key
:keytype: One of rsa, dsa, ecdsa, ed25519
:size: The key length for newly generated keys
:passphrase: Secret of type Bytes used to encrypt the newly generated private key
:comment: Comment for a newly generated OpenSSH public key
:keytype: One of rsa, dsa, ecdsa, ed25519
:size: The key length for newly generated keys
:passphrase: Secret of type Bytes used to encrypt the newly generated private key
:comment: Comment for a newly generated OpenSSH public key
"""
if comment is None:
comment = "%s@%s" % (getuser(), gethostname())
asym_keypair = AsymmetricKeypair.generate(keytype, size, passphrase)
openssh_privatekey = cls.encode_openssh_privatekey(asym_keypair, 'SSH')
openssh_privatekey = cls.encode_openssh_privatekey(asym_keypair, "SSH")
openssh_publickey = cls.encode_openssh_publickey(asym_keypair, comment)
fingerprint = calculate_fingerprint(openssh_publickey)
@@ -377,18 +387,20 @@ class OpensshKeypair(object):
def load(cls, path, passphrase=None, no_public_key=False):
"""Returns an Openssh_Keypair object loaded from the supplied file path
:path: A path to an existing private key to be loaded
:passphrase: Secret used to decrypt the private key being loaded
:no_public_key: Set 'True' to only load a private key and automatically populate the matching public key
:path: A path to an existing private key to be loaded
:passphrase: Secret used to decrypt the private key being loaded
:no_public_key: Set 'True' to only load a private key and automatically populate the matching public key
"""
if no_public_key:
comment = ""
else:
comment = extract_comment(path + '.pub')
comment = extract_comment(path + ".pub")
asym_keypair = AsymmetricKeypair.load(path, passphrase, 'SSH', 'SSH', no_public_key)
openssh_privatekey = cls.encode_openssh_privatekey(asym_keypair, 'SSH')
asym_keypair = AsymmetricKeypair.load(
path, passphrase, "SSH", "SSH", no_public_key
)
openssh_privatekey = cls.encode_openssh_privatekey(asym_keypair, "SSH")
openssh_publickey = cls.encode_openssh_publickey(asym_keypair, comment)
fingerprint = calculate_fingerprint(openssh_publickey)
@@ -404,29 +416,33 @@ class OpensshKeypair(object):
def encode_openssh_privatekey(asym_keypair, key_format):
"""Returns an OpenSSH encoded private key for a given keypair
:asym_keypair: Asymmetric_Keypair from the private key is extracted
:key_format: Format of the encoded private key.
:asym_keypair: Asymmetric_Keypair from the private key is extracted
:key_format: Format of the encoded private key.
"""
if key_format == 'SSH':
if key_format == "SSH":
# Default to PEM format if SSH not available
if not HAS_OPENSSH_PRIVATE_FORMAT:
privatekey_format = serialization.PrivateFormat.PKCS8
else:
privatekey_format = serialization.PrivateFormat.OpenSSH
elif key_format == 'PKCS8':
elif key_format == "PKCS8":
privatekey_format = serialization.PrivateFormat.PKCS8
elif key_format == 'PKCS1':
if asym_keypair.key_type == 'ed25519':
raise InvalidKeyFormatError("ed25519 keys cannot be represented in PKCS1 format")
elif key_format == "PKCS1":
if asym_keypair.key_type == "ed25519":
raise InvalidKeyFormatError(
"ed25519 keys cannot be represented in PKCS1 format"
)
privatekey_format = serialization.PrivateFormat.TraditionalOpenSSL
else:
raise InvalidKeyFormatError("The accepted private key formats are SSH, PKCS8, and PKCS1")
raise InvalidKeyFormatError(
"The accepted private key formats are SSH, PKCS8, and PKCS1"
)
encoded_privatekey = asym_keypair.private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=privatekey_format,
encryption_algorithm=asym_keypair.encryption_algorithm
encryption_algorithm=asym_keypair.encryption_algorithm,
)
return encoded_privatekey
@@ -435,8 +451,8 @@ class OpensshKeypair(object):
def encode_openssh_publickey(asym_keypair, comment):
"""Returns an OpenSSH encoded public key for a given keypair
:asym_keypair: Asymmetric_Keypair from the public key is extracted
:comment: Comment to apply to the end of the returned OpenSSH encoded public key
:asym_keypair: Asymmetric_Keypair from the public key is extracted
:comment: Comment to apply to the end of the returned OpenSSH encoded public key
"""
encoded_publickey = asym_keypair.public_key.public_bytes(
encoding=serialization.Encoding.OpenSSH,
@@ -445,17 +461,21 @@ class OpensshKeypair(object):
validate_comment(comment)
encoded_publickey += (" %s" % comment).encode(encoding=_TEXT_ENCODING) if comment else b''
encoded_publickey += (
(" %s" % comment).encode(encoding=_TEXT_ENCODING) if comment else b""
)
return encoded_publickey
def __init__(self, asym_keypair, openssh_privatekey, openssh_publickey, fingerprint, comment):
def __init__(
self, asym_keypair, openssh_privatekey, openssh_publickey, fingerprint, comment
):
"""
:asym_keypair: An Asymmetric_Keypair object from which the OpenSSH encoded keypair is derived
:openssh_privatekey: An OpenSSH encoded private key
:openssh_privatekey: An OpenSSH encoded public key
:fingerprint: The fingerprint of the OpenSSH encoded public key of this keypair
:comment: Comment applied to the OpenSSH public key of this keypair
:asym_keypair: An Asymmetric_Keypair object from which the OpenSSH encoded keypair is derived
:openssh_privatekey: An OpenSSH encoded private key
:openssh_privatekey: An OpenSSH encoded public key
:fingerprint: The fingerprint of the OpenSSH encoded public key of this keypair
:comment: Comment applied to the OpenSSH public key of this keypair
"""
self.__asym_keypair = asym_keypair
@@ -468,7 +488,10 @@ class OpensshKeypair(object):
if not isinstance(other, OpensshKeypair):
return NotImplemented
return self.asymmetric_keypair == other.asymmetric_keypair and self.comment == other.comment
return (
self.asymmetric_keypair == other.asymmetric_keypair
and self.comment == other.comment
)
@property
def asymmetric_keypair(self):
@@ -516,53 +539,59 @@ class OpensshKeypair(object):
def comment(self, comment):
"""Updates the comment applied to the OpenSSH formatted public key of this key pair
:comment: Text to update the OpenSSH public key comment
:comment: Text to update the OpenSSH public key comment
"""
validate_comment(comment)
self.__comment = comment
encoded_comment = (" %s" % self.__comment).encode(encoding=_TEXT_ENCODING) if self.__comment else b''
self.__openssh_publickey = b' '.join(self.__openssh_publickey.split(b' ', 2)[:2]) + encoded_comment
encoded_comment = (
(" %s" % self.__comment).encode(encoding=_TEXT_ENCODING)
if self.__comment
else b""
)
self.__openssh_publickey = (
b" ".join(self.__openssh_publickey.split(b" ", 2)[:2]) + encoded_comment
)
return self.__openssh_publickey
def update_passphrase(self, passphrase):
"""Updates the passphrase used to encrypt the private key of this keypair
:passphrase: Text secret used for encryption
:passphrase: Text secret used for encryption
"""
self.__asym_keypair.update_passphrase(passphrase)
self.__openssh_privatekey = OpensshKeypair.encode_openssh_privatekey(self.__asym_keypair, 'SSH')
self.__openssh_privatekey = OpensshKeypair.encode_openssh_privatekey(
self.__asym_keypair, "SSH"
)
def load_privatekey(path, passphrase, key_format):
privatekey_loaders = {
'PEM': serialization.load_pem_private_key,
'DER': serialization.load_der_private_key,
"PEM": serialization.load_pem_private_key,
"DER": serialization.load_der_private_key,
}
# OpenSSH formatted private keys are not available in Cryptography <3.0
if hasattr(serialization, 'load_ssh_private_key'):
privatekey_loaders['SSH'] = serialization.load_ssh_private_key
if hasattr(serialization, "load_ssh_private_key"):
privatekey_loaders["SSH"] = serialization.load_ssh_private_key
else:
privatekey_loaders['SSH'] = serialization.load_pem_private_key
privatekey_loaders["SSH"] = serialization.load_pem_private_key
try:
privatekey_loader = privatekey_loaders[key_format]
except KeyError:
raise InvalidKeyFormatError(
"%s is not a valid key format (%s)" % (
key_format,
','.join(privatekey_loaders.keys())
)
"%s is not a valid key format (%s)"
% (key_format, ",".join(privatekey_loaders.keys()))
)
if not os.path.exists(path):
raise InvalidPrivateKeyFileError("No file was found at %s" % path)
try:
with open(path, 'rb') as f:
with open(path, "rb") as f:
content = f.read()
privatekey = privatekey_loader(
@@ -573,9 +602,9 @@ def load_privatekey(path, passphrase, key_format):
except ValueError as e:
# Revert to PEM if key could not be loaded in SSH format
if key_format == 'SSH':
if key_format == "SSH":
try:
privatekey = privatekey_loaders['PEM'](
privatekey = privatekey_loaders["PEM"](
data=content,
password=passphrase,
backend=backend,
@@ -598,26 +627,24 @@ def load_privatekey(path, passphrase, key_format):
def load_publickey(path, key_format):
publickey_loaders = {
'PEM': serialization.load_pem_public_key,
'DER': serialization.load_der_public_key,
'SSH': serialization.load_ssh_public_key,
"PEM": serialization.load_pem_public_key,
"DER": serialization.load_der_public_key,
"SSH": serialization.load_ssh_public_key,
}
try:
publickey_loader = publickey_loaders[key_format]
except KeyError:
raise InvalidKeyFormatError(
"%s is not a valid key format (%s)" % (
key_format,
','.join(publickey_loaders.keys())
)
"%s is not a valid key format (%s)"
% (key_format, ",".join(publickey_loaders.keys()))
)
if not os.path.exists(path):
raise InvalidPublicKeyFileError("No file was found at %s" % path)
try:
with open(path, 'rb') as f:
with open(path, "rb") as f:
content = f.read()
publickey = publickey_loader(
@@ -646,10 +673,13 @@ def compare_publickeys(pk1, pk2):
def compare_encryption_algorithms(ea1, ea2):
if isinstance(ea1, serialization.NoEncryption) and isinstance(ea2, serialization.NoEncryption):
if isinstance(ea1, serialization.NoEncryption) and isinstance(
ea2, serialization.NoEncryption
):
return True
elif (isinstance(ea1, serialization.BestAvailableEncryption) and
isinstance(ea2, serialization.BestAvailableEncryption)):
elif isinstance(ea1, serialization.BestAvailableEncryption) and isinstance(
ea2, serialization.BestAvailableEncryption
):
return ea1.password == ea2.password
else:
return False
@@ -663,7 +693,7 @@ def get_encryption_algorithm(passphrase):
def validate_comment(comment):
if not hasattr(comment, 'encode'):
if not hasattr(comment, "encode"):
raise InvalidCommentError("%s cannot be encoded to text" % comment)
@@ -673,8 +703,8 @@ def extract_comment(path):
raise InvalidPublicKeyFileError("No file was found at %s" % path)
try:
with open(path, 'rb') as f:
fields = f.read().split(b' ', 2)
with open(path, "rb") as f:
fields = f.read().split(b" ", 2)
if len(fields) == 3:
comment = fields[2].decode(_TEXT_ENCODING)
else:
@@ -687,7 +717,9 @@ def extract_comment(path):
def calculate_fingerprint(openssh_publickey):
digest = hashes.Hash(hashes.SHA256(), backend=backend)
decoded_pubkey = b64decode(openssh_publickey.split(b' ')[1])
decoded_pubkey = b64decode(openssh_publickey.split(b" ")[1])
digest.update(decoded_pubkey)
return 'SHA256:%s' % b64encode(digest.finalize()).decode(encoding=_TEXT_ENCODING).rstrip('=')
return "SHA256:%s" % b64encode(digest.finalize()).decode(
encoding=_TEXT_ENCODING
).rstrip("=")

View File

@@ -34,17 +34,17 @@ if PY3:
long = int
# 0 (False) or 1 (True) encoded as a single byte
_BOOLEAN = Struct(b'?')
_BOOLEAN = Struct(b"?")
# Unsigned 8-bit integer in network-byte-order
_UBYTE = Struct(b'!B')
_UBYTE = Struct(b"!B")
_UBYTE_MAX = 0xFF
# Unsigned 32-bit integer in network-byte-order
_UINT32 = Struct(b'!I')
_UINT32 = Struct(b"!I")
# Unsigned 32-bit little endian integer
_UINT32_LE = Struct(b'<I')
_UINT32_LE = Struct(b"<I")
_UINT32_MAX = 0xFFFFFFFF
# Unsigned 64-bit integer in network-byte-order
_UINT64 = Struct(b'!Q')
_UINT64 = Struct(b"!Q")
_UINT64_MAX = 0xFFFFFFFFFFFFFFFF
@@ -89,6 +89,7 @@ def secure_write(path, mode, content):
# See https://datatracker.ietf.org/doc/html/rfc4251#section-5 for SSH data types
class OpensshParser(object):
"""Parser for OpenSSH encoded objects"""
BOOLEAN_OFFSET = 1
UINT32_OFFSET = 4
UINT64_OFFSET = 8
@@ -103,21 +104,21 @@ class OpensshParser(object):
def boolean(self):
next_pos = self._check_position(self.BOOLEAN_OFFSET)
value = _BOOLEAN.unpack(self._data[self._pos:next_pos])[0]
value = _BOOLEAN.unpack(self._data[self._pos : next_pos])[0]
self._pos = next_pos
return value
def uint32(self):
next_pos = self._check_position(self.UINT32_OFFSET)
value = _UINT32.unpack(self._data[self._pos:next_pos])[0]
value = _UINT32.unpack(self._data[self._pos : next_pos])[0]
self._pos = next_pos
return value
def uint64(self):
next_pos = self._check_position(self.UINT64_OFFSET)
value = _UINT64.unpack(self._data[self._pos:next_pos])[0]
value = _UINT64.unpack(self._data[self._pos : next_pos])[0]
self._pos = next_pos
return value
@@ -126,7 +127,7 @@ class OpensshParser(object):
next_pos = self._check_position(length)
value = self._data[self._pos:next_pos]
value = self._data[self._pos : next_pos]
self._pos = next_pos
# Cast to bytes is required as a memoryview slice is itself a memoryview
return value if not PY3 else bytes(value)
@@ -136,7 +137,7 @@ class OpensshParser(object):
def name_list(self):
raw_string = self.string()
return raw_string.decode('ASCII').split(',')
return raw_string.decode("ASCII").split(",")
# Convenience function, but not an official data type from SSH
def string_list(self):
@@ -193,33 +194,39 @@ class OpensshParser(object):
signature_blob = parser.string()
blob_parser = cls(signature_blob)
if signature_type in (b'ssh-rsa', b'rsa-sha2-256', b'rsa-sha2-512'):
if signature_type in (b"ssh-rsa", b"rsa-sha2-256", b"rsa-sha2-512"):
# https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
# https://datatracker.ietf.org/doc/html/rfc8332#section-3
signature_data['s'] = cls._big_int(signature_blob, "big")
elif signature_type == b'ssh-dss':
signature_data["s"] = cls._big_int(signature_blob, "big")
elif signature_type == b"ssh-dss":
# https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
signature_data['r'] = cls._big_int(signature_blob[:20], "big")
signature_data['s'] = cls._big_int(signature_blob[20:], "big")
elif signature_type in (b'ecdsa-sha2-nistp256', b'ecdsa-sha2-nistp384', b'ecdsa-sha2-nistp521'):
signature_data["r"] = cls._big_int(signature_blob[:20], "big")
signature_data["s"] = cls._big_int(signature_blob[20:], "big")
elif signature_type in (
b"ecdsa-sha2-nistp256",
b"ecdsa-sha2-nistp384",
b"ecdsa-sha2-nistp521",
):
# https://datatracker.ietf.org/doc/html/rfc5656#section-3.1.2
signature_data['r'] = blob_parser.mpint()
signature_data['s'] = blob_parser.mpint()
elif signature_type == b'ssh-ed25519':
signature_data["r"] = blob_parser.mpint()
signature_data["s"] = blob_parser.mpint()
elif signature_type == b"ssh-ed25519":
# https://datatracker.ietf.org/doc/html/rfc8032#section-5.1.2
signature_data['R'] = cls._big_int(signature_blob[:32], "little")
signature_data['S'] = cls._big_int(signature_blob[32:], "little")
signature_data["R"] = cls._big_int(signature_blob[:32], "little")
signature_data["S"] = cls._big_int(signature_blob[32:], "little")
else:
raise ValueError("%s is not a valid signature type" % signature_type)
signature_data['signature_type'] = signature_type
signature_data["signature_type"] = signature_type
return signature_data
@classmethod
def _big_int(cls, raw_string, byte_order, signed=False):
if byte_order not in ("big", "little"):
raise ValueError("Byte_order must be one of (big, little) not %s" % byte_order)
raise ValueError(
"Byte_order must be one of (big, little) not %s" % byte_order
)
if PY3:
return int.from_bytes(raw_string, byte_order, signed=signed)
@@ -232,21 +239,31 @@ class OpensshParser(object):
msb = raw_string[0] if byte_order == "big" else raw_string[-1]
negative = bool(ord(msb) & 0x80)
# Match pad value for two's complement
pad = b'\xFF' if signed and negative else b'\x00'
pad = b"\xff" if signed and negative else b"\x00"
# The definition of ``mpint`` enforces that unnecessary bytes are not encoded so they are added back
pad_length = (4 - byte_length % 4)
pad_length = 4 - byte_length % 4
if pad_length < 4:
raw_string = pad * pad_length + raw_string if byte_order == "big" else raw_string + pad * pad_length
raw_string = (
pad * pad_length + raw_string
if byte_order == "big"
else raw_string + pad * pad_length
)
byte_length += pad_length
# Accumulate arbitrary precision integer bytes in the appropriate order
if byte_order == "big":
for i in range(0, byte_length, cls.UINT32_OFFSET):
left_shift = result << cls.UINT32_OFFSET * 8
result = left_shift + _UINT32.unpack(raw_string[i:i + cls.UINT32_OFFSET])[0]
result = (
left_shift
+ _UINT32.unpack(raw_string[i : i + cls.UINT32_OFFSET])[0]
)
else:
for i in range(byte_length, 0, -cls.UINT32_OFFSET):
left_shift = result << cls.UINT32_OFFSET * 8
result = left_shift + _UINT32_LE.unpack(raw_string[i - cls.UINT32_OFFSET:i])[0]
result = (
left_shift
+ _UINT32_LE.unpack(raw_string[i - cls.UINT32_OFFSET : i])[0]
)
# Adjust for two's complement
if signed and negative:
result -= 1 << (8 * byte_length)
@@ -262,10 +279,13 @@ class _OpensshWriter(object):
It is not to be used to construct Openssh objects, but rather as a utility to assist
in validating parsed material.
"""
def __init__(self, buffer=None):
if buffer is not None:
if not isinstance(buffer, (bytes, bytearray)):
raise TypeError("Buffer must be a bytes-like object not %s" % type(buffer))
raise TypeError(
"Buffer must be a bytes-like object not %s" % type(buffer)
)
else:
buffer = bytearray()
@@ -283,7 +303,9 @@ class _OpensshWriter(object):
if not isinstance(value, int):
raise TypeError("Value must be of type int not %s" % type(value))
if value < 0 or value > _UINT32_MAX:
raise ValueError("Value must be a positive integer less than %s" % _UINT32_MAX)
raise ValueError(
"Value must be a positive integer less than %s" % _UINT32_MAX
)
self._buff.extend(_UINT32.pack(value))
@@ -293,7 +315,9 @@ class _OpensshWriter(object):
if not isinstance(value, (long, int)):
raise TypeError("Value must be of type (long, int) not %s" % type(value))
if value < 0 or value > _UINT64_MAX:
raise ValueError("Value must be a positive integer less than %s" % _UINT64_MAX)
raise ValueError(
"Value must be a positive integer less than %s" % _UINT64_MAX
)
self._buff.extend(_UINT64.pack(value))
@@ -320,7 +344,7 @@ class _OpensshWriter(object):
raise TypeError("Value must be a list of byte strings not %s" % type(value))
try:
self.string(','.join(value).encode('ASCII'))
self.string(",".join(value).encode("ASCII"))
except UnicodeEncodeError as e:
raise ValueError("Name-list's must consist of US-ASCII characters: %s" % e)
@@ -365,9 +389,9 @@ class _OpensshWriter(object):
result = bytes()
# 0 and -1 are treated as special cases since they are used as sentinels for all other values
if num == 0:
result += b'\x00'
result += b"\x00"
elif num == -1:
result += b'\xFF'
result += b"\xff"
elif num > 0:
while num >> 32:
result = _UINT32.pack(num & _UINT32_MAX) + result
@@ -378,7 +402,7 @@ class _OpensshWriter(object):
num = num >> 8
# Zero pad final byte if most-significant bit is 1 as per mpint definition
if ord(result[0]) & 0x80:
result = b'\x00' + result
result = b"\x00" + result
else:
while (num >> 32) < -1:
result = _UINT32.pack(num & _UINT32_MAX) + result
@@ -387,7 +411,7 @@ class _OpensshWriter(object):
result = _UBYTE.pack(num & _UBYTE_MAX) + result
num = num >> 8
if not ord(result[0]) & 0x80:
result = b'\xFF' + result
result = b"\xff" + result
return result