mirror of
https://github.com/ansible-collections/community.crypto.git
synced 2026-05-07 13:53:06 +00:00
Reformat everything with black.
I had to undo the u string prefix removals to not drop Python 2 compatibility. That's why black isn't enabled in antsibull-nox.toml yet.
This commit is contained in:
@@ -44,17 +44,24 @@ def safe_atomic_move(module, path, destination):
|
||||
|
||||
def _restore_all_on_failure(f):
|
||||
def backup_and_restore(self, sources_and_destinations, *args, **kwargs):
|
||||
backups = [(d, self.module.backup_local(d)) for s, d in sources_and_destinations if os.path.exists(d)]
|
||||
backups = [
|
||||
(d, self.module.backup_local(d))
|
||||
for s, d in sources_and_destinations
|
||||
if os.path.exists(d)
|
||||
]
|
||||
|
||||
try:
|
||||
f(self, sources_and_destinations, *args, **kwargs)
|
||||
except Exception:
|
||||
for destination, backup in backups:
|
||||
self.module.atomic_move(os.path.abspath(backup), os.path.abspath(destination))
|
||||
self.module.atomic_move(
|
||||
os.path.abspath(backup), os.path.abspath(destination)
|
||||
)
|
||||
raise
|
||||
else:
|
||||
for destination, backup in backups:
|
||||
self.module.add_cleanup_file(backup)
|
||||
|
||||
return backup_and_restore
|
||||
|
||||
|
||||
@@ -85,10 +92,10 @@ class OpensshModule(object):
|
||||
def result(self):
|
||||
result = self._result
|
||||
|
||||
result['changed'] = self.changed
|
||||
result["changed"] = self.changed
|
||||
|
||||
if self.module._diff:
|
||||
result['diff'] = self.diff
|
||||
result["diff"] = self.diff
|
||||
|
||||
return result
|
||||
|
||||
@@ -107,6 +114,7 @@ class OpensshModule(object):
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if not self.check_mode:
|
||||
f(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@staticmethod
|
||||
@@ -114,72 +122,92 @@ class OpensshModule(object):
|
||||
def wrapper(self, *args, **kwargs):
|
||||
f(self, *args, **kwargs)
|
||||
self.changed = True
|
||||
|
||||
return wrapper
|
||||
|
||||
def _check_if_base_dir(self, path):
|
||||
base_dir = os.path.dirname(path) or '.'
|
||||
base_dir = os.path.dirname(path) or "."
|
||||
if not os.path.isdir(base_dir):
|
||||
self.module.fail_json(
|
||||
name=base_dir,
|
||||
msg='The directory %s does not exist or the file is not a directory' % base_dir
|
||||
msg="The directory %s does not exist or the file is not a directory"
|
||||
% base_dir,
|
||||
)
|
||||
|
||||
def _get_ssh_version(self):
|
||||
ssh_bin = self.module.get_bin_path('ssh')
|
||||
ssh_bin = self.module.get_bin_path("ssh")
|
||||
if not ssh_bin:
|
||||
return ""
|
||||
return parse_openssh_version(self.module.run_command([ssh_bin, '-V', '-q'], check_rc=True)[2].strip())
|
||||
return parse_openssh_version(
|
||||
self.module.run_command([ssh_bin, "-V", "-q"], check_rc=True)[2].strip()
|
||||
)
|
||||
|
||||
@_restore_all_on_failure
|
||||
def _safe_secure_move(self, sources_and_destinations):
|
||||
"""Moves a list of files from 'source' to 'destination' and restores 'destination' from backup upon failure.
|
||||
If 'destination' does not already exist, then 'source' permissions are preserved to prevent
|
||||
exposing protected data ('atomic_move' uses the 'destination' base directory mask for
|
||||
permissions if 'destination' does not already exists).
|
||||
If 'destination' does not already exist, then 'source' permissions are preserved to prevent
|
||||
exposing protected data ('atomic_move' uses the 'destination' base directory mask for
|
||||
permissions if 'destination' does not already exists).
|
||||
"""
|
||||
for source, destination in sources_and_destinations:
|
||||
if os.path.exists(destination):
|
||||
self.module.atomic_move(os.path.abspath(source), os.path.abspath(destination))
|
||||
self.module.atomic_move(
|
||||
os.path.abspath(source), os.path.abspath(destination)
|
||||
)
|
||||
else:
|
||||
self.module.preserved_copy(source, destination)
|
||||
|
||||
def _update_permissions(self, path):
|
||||
file_args = self.module.load_file_common_arguments(self.module.params)
|
||||
file_args['path'] = path
|
||||
file_args["path"] = path
|
||||
|
||||
if not self.module.check_file_absent_if_check_mode(path):
|
||||
self.changed = self.module.set_fs_attributes_if_different(file_args, self.changed)
|
||||
self.changed = self.module.set_fs_attributes_if_different(
|
||||
file_args, self.changed
|
||||
)
|
||||
else:
|
||||
self.changed = True
|
||||
|
||||
|
||||
class KeygenCommand(object):
|
||||
def __init__(self, module):
|
||||
self._bin_path = module.get_bin_path('ssh-keygen', True)
|
||||
self._bin_path = module.get_bin_path("ssh-keygen", True)
|
||||
self._run_command = module.run_command
|
||||
|
||||
def generate_certificate(self, certificate_path, identifier, options, pkcs11_provider, principals,
|
||||
serial_number, signature_algorithm, signing_key_path, type,
|
||||
time_parameters, use_agent, **kwargs):
|
||||
args = [self._bin_path, '-s', signing_key_path, '-P', '', '-I', identifier]
|
||||
def generate_certificate(
|
||||
self,
|
||||
certificate_path,
|
||||
identifier,
|
||||
options,
|
||||
pkcs11_provider,
|
||||
principals,
|
||||
serial_number,
|
||||
signature_algorithm,
|
||||
signing_key_path,
|
||||
type,
|
||||
time_parameters,
|
||||
use_agent,
|
||||
**kwargs
|
||||
):
|
||||
args = [self._bin_path, "-s", signing_key_path, "-P", "", "-I", identifier]
|
||||
|
||||
if options:
|
||||
for option in options:
|
||||
args.extend(['-O', option])
|
||||
args.extend(["-O", option])
|
||||
if pkcs11_provider:
|
||||
args.extend(['-D', pkcs11_provider])
|
||||
args.extend(["-D", pkcs11_provider])
|
||||
if principals:
|
||||
args.extend(['-n', ','.join(principals)])
|
||||
args.extend(["-n", ",".join(principals)])
|
||||
if serial_number is not None:
|
||||
args.extend(['-z', str(serial_number)])
|
||||
if type == 'host':
|
||||
args.extend(['-h'])
|
||||
args.extend(["-z", str(serial_number)])
|
||||
if type == "host":
|
||||
args.extend(["-h"])
|
||||
if use_agent:
|
||||
args.extend(['-U'])
|
||||
args.extend(["-U"])
|
||||
if time_parameters.validity_string:
|
||||
args.extend(['-V', time_parameters.validity_string])
|
||||
args.extend(["-V", time_parameters.validity_string])
|
||||
if signature_algorithm:
|
||||
args.extend(['-t', signature_algorithm])
|
||||
args.extend(["-t", signature_algorithm])
|
||||
args.append(certificate_path)
|
||||
|
||||
return self._run_command(args, **kwargs)
|
||||
@@ -187,44 +215,62 @@ class KeygenCommand(object):
|
||||
def generate_keypair(self, private_key_path, size, type, comment, **kwargs):
|
||||
args = [
|
||||
self._bin_path,
|
||||
'-q',
|
||||
'-N', '',
|
||||
'-b', str(size),
|
||||
'-t', type,
|
||||
'-f', private_key_path,
|
||||
'-C', comment or ''
|
||||
"-q",
|
||||
"-N",
|
||||
"",
|
||||
"-b",
|
||||
str(size),
|
||||
"-t",
|
||||
type,
|
||||
"-f",
|
||||
private_key_path,
|
||||
"-C",
|
||||
comment or "",
|
||||
]
|
||||
|
||||
# "y" must be entered in response to the "overwrite" prompt
|
||||
data = 'y' if os.path.exists(private_key_path) else None
|
||||
data = "y" if os.path.exists(private_key_path) else None
|
||||
|
||||
return self._run_command(args, data=data, **kwargs)
|
||||
|
||||
def get_certificate_info(self, certificate_path, **kwargs):
|
||||
return self._run_command([self._bin_path, '-L', '-f', certificate_path], **kwargs)
|
||||
return self._run_command(
|
||||
[self._bin_path, "-L", "-f", certificate_path], **kwargs
|
||||
)
|
||||
|
||||
def get_matching_public_key(self, private_key_path, **kwargs):
|
||||
return self._run_command([self._bin_path, '-P', '', '-y', '-f', private_key_path], **kwargs)
|
||||
return self._run_command(
|
||||
[self._bin_path, "-P", "", "-y", "-f", private_key_path], **kwargs
|
||||
)
|
||||
|
||||
def get_private_key(self, private_key_path, **kwargs):
|
||||
return self._run_command([self._bin_path, '-l', '-f', private_key_path], **kwargs)
|
||||
return self._run_command(
|
||||
[self._bin_path, "-l", "-f", private_key_path], **kwargs
|
||||
)
|
||||
|
||||
def update_comment(self, private_key_path, comment, force_new_format=True, **kwargs):
|
||||
if os.path.exists(private_key_path) and not os.access(private_key_path, os.W_OK):
|
||||
def update_comment(
|
||||
self, private_key_path, comment, force_new_format=True, **kwargs
|
||||
):
|
||||
if os.path.exists(private_key_path) and not os.access(
|
||||
private_key_path, os.W_OK
|
||||
):
|
||||
try:
|
||||
os.chmod(private_key_path, stat.S_IWUSR + stat.S_IRUSR)
|
||||
except (IOError, OSError) as e:
|
||||
raise e("The private key at %s is not writeable preventing a comment update" % private_key_path)
|
||||
raise e(
|
||||
"The private key at %s is not writeable preventing a comment update"
|
||||
% private_key_path
|
||||
)
|
||||
|
||||
command = [self._bin_path, '-q']
|
||||
command = [self._bin_path, "-q"]
|
||||
if force_new_format:
|
||||
command.append('-o')
|
||||
command.extend(['-c', '-C', comment, '-f', private_key_path])
|
||||
command.append("-o")
|
||||
command.extend(["-c", "-C", comment, "-f", private_key_path])
|
||||
return self._run_command(command, **kwargs)
|
||||
|
||||
|
||||
class PrivateKey(object):
|
||||
def __init__(self, size, key_type, fingerprint, format=''):
|
||||
def __init__(self, size, key_type, fingerprint, format=""):
|
||||
self._size = size
|
||||
self._type = key_type
|
||||
self._fingerprint = fingerprint
|
||||
@@ -258,10 +304,10 @@ class PrivateKey(object):
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'size': self._size,
|
||||
'type': self._type,
|
||||
'fingerprint': self._fingerprint,
|
||||
'format': self._format,
|
||||
"size": self._size,
|
||||
"type": self._type,
|
||||
"fingerprint": self._fingerprint,
|
||||
"format": self._format,
|
||||
}
|
||||
|
||||
|
||||
@@ -275,11 +321,17 @@ class PublicKey(object):
|
||||
if not isinstance(other, type(self)):
|
||||
return NotImplemented
|
||||
|
||||
return all([
|
||||
self._type_string == other._type_string,
|
||||
self._data == other._data,
|
||||
(self._comment == other._comment) if self._comment is not None and other._comment is not None else True
|
||||
])
|
||||
return all(
|
||||
[
|
||||
self._type_string == other._type_string,
|
||||
self._data == other._data,
|
||||
(
|
||||
(self._comment == other._comment)
|
||||
if self._comment is not None and other._comment is not None
|
||||
else True
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self == other
|
||||
@@ -305,19 +357,19 @@ class PublicKey(object):
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, string):
|
||||
properties = string.strip('\n').split(' ', 2)
|
||||
properties = string.strip("\n").split(" ", 2)
|
||||
|
||||
return cls(
|
||||
type_string=properties[0],
|
||||
data=properties[1],
|
||||
comment=properties[2] if len(properties) > 2 else ""
|
||||
comment=properties[2] if len(properties) > 2 else "",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load(cls, path):
|
||||
try:
|
||||
with open(path, 'r') as f:
|
||||
properties = f.read().strip(' \n').split(' ', 2)
|
||||
with open(path, "r") as f:
|
||||
properties = f.read().strip(" \n").split(" ", 2)
|
||||
except (IOError, OSError):
|
||||
raise
|
||||
|
||||
@@ -327,25 +379,25 @@ class PublicKey(object):
|
||||
return cls(
|
||||
type_string=properties[0],
|
||||
data=properties[1],
|
||||
comment='' if len(properties) <= 2 else properties[2],
|
||||
comment="" if len(properties) <= 2 else properties[2],
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'comment': self._comment,
|
||||
'public_key': self._data,
|
||||
"comment": self._comment,
|
||||
"public_key": self._data,
|
||||
}
|
||||
|
||||
|
||||
def parse_private_key_format(path):
|
||||
with open(path, 'r') as file:
|
||||
with open(path, "r") as file:
|
||||
header = file.readline().strip()
|
||||
|
||||
if header == '-----BEGIN OPENSSH PRIVATE KEY-----':
|
||||
return 'SSH'
|
||||
elif header == '-----BEGIN PRIVATE KEY-----':
|
||||
return 'PKCS8'
|
||||
elif header == '-----BEGIN RSA PRIVATE KEY-----':
|
||||
return 'PKCS1'
|
||||
if header == "-----BEGIN OPENSSH PRIVATE KEY-----":
|
||||
return "SSH"
|
||||
elif header == "-----BEGIN PRIVATE KEY-----":
|
||||
return "PKCS8"
|
||||
elif header == "-----BEGIN RSA PRIVATE KEY-----":
|
||||
return "PKCS1"
|
||||
|
||||
return ''
|
||||
return ""
|
||||
|
||||
@@ -48,14 +48,18 @@ class KeypairBackend(OpensshModule):
|
||||
def __init__(self, module):
|
||||
super(KeypairBackend, self).__init__(module)
|
||||
|
||||
self.comment = self.module.params['comment']
|
||||
self.private_key_path = self.module.params['path']
|
||||
self.public_key_path = self.private_key_path + '.pub'
|
||||
self.regenerate = self.module.params['regenerate'] if not self.module.params['force'] else 'always'
|
||||
self.state = self.module.params['state']
|
||||
self.type = self.module.params['type']
|
||||
self.comment = self.module.params["comment"]
|
||||
self.private_key_path = self.module.params["path"]
|
||||
self.public_key_path = self.private_key_path + ".pub"
|
||||
self.regenerate = (
|
||||
self.module.params["regenerate"]
|
||||
if not self.module.params["force"]
|
||||
else "always"
|
||||
)
|
||||
self.state = self.module.params["state"]
|
||||
self.type = self.module.params["type"]
|
||||
|
||||
self.size = self._get_size(self.module.params['size'])
|
||||
self.size = self._get_size(self.module.params["size"])
|
||||
self._validate_path()
|
||||
|
||||
self.original_private_key = None
|
||||
@@ -64,31 +68,35 @@ class KeypairBackend(OpensshModule):
|
||||
self.public_key = None
|
||||
|
||||
def _get_size(self, size):
|
||||
if self.type in ('rsa', 'rsa1'):
|
||||
if self.type in ("rsa", "rsa1"):
|
||||
result = 4096 if size is None else size
|
||||
if result < 1024:
|
||||
return self.module.fail_json(
|
||||
msg="For RSA keys, the minimum size is 1024 bits and the default is 4096 bits. " +
|
||||
"Attempting to use bit lengths under 1024 will cause the module to fail."
|
||||
msg="For RSA keys, the minimum size is 1024 bits and the default is 4096 bits. "
|
||||
+ "Attempting to use bit lengths under 1024 will cause the module to fail."
|
||||
)
|
||||
elif self.type == 'dsa':
|
||||
elif self.type == "dsa":
|
||||
result = 1024 if size is None else size
|
||||
if result != 1024:
|
||||
return self.module.fail_json(msg="DSA keys must be exactly 1024 bits as specified by FIPS 186-2.")
|
||||
elif self.type == 'ecdsa':
|
||||
return self.module.fail_json(
|
||||
msg="DSA keys must be exactly 1024 bits as specified by FIPS 186-2."
|
||||
)
|
||||
elif self.type == "ecdsa":
|
||||
result = 256 if size is None else size
|
||||
if result not in (256, 384, 521):
|
||||
return self.module.fail_json(
|
||||
msg="For ECDSA keys, size determines the key length by selecting from one of " +
|
||||
"three elliptic curve sizes: 256, 384 or 521 bits. " +
|
||||
"Attempting to use bit lengths other than these three values for ECDSA keys will " +
|
||||
"cause this module to fail."
|
||||
msg="For ECDSA keys, size determines the key length by selecting from one of "
|
||||
+ "three elliptic curve sizes: 256, 384 or 521 bits. "
|
||||
+ "Attempting to use bit lengths other than these three values for ECDSA keys will "
|
||||
+ "cause this module to fail."
|
||||
)
|
||||
elif self.type == 'ed25519':
|
||||
elif self.type == "ed25519":
|
||||
# User input is ignored for `key size` when `key type` is ed25519
|
||||
result = 256
|
||||
else:
|
||||
return self.module.fail_json(msg="%s is not a valid value for key type" % self.type)
|
||||
return self.module.fail_json(
|
||||
msg="%s is not a valid value for key type" % self.type
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -96,13 +104,16 @@ class KeypairBackend(OpensshModule):
|
||||
self._check_if_base_dir(self.private_key_path)
|
||||
|
||||
if os.path.isdir(self.private_key_path):
|
||||
self.module.fail_json(msg='%s is a directory. Please specify a path to a file.' % self.private_key_path)
|
||||
self.module.fail_json(
|
||||
msg="%s is a directory. Please specify a path to a file."
|
||||
% self.private_key_path
|
||||
)
|
||||
|
||||
def _execute(self):
|
||||
self.original_private_key = self._load_private_key()
|
||||
self.original_public_key = self._load_public_key()
|
||||
|
||||
if self.state == 'present':
|
||||
if self.state == "present":
|
||||
self._validate_key_load()
|
||||
|
||||
if self._should_generate():
|
||||
@@ -149,13 +160,15 @@ class KeypairBackend(OpensshModule):
|
||||
return os.path.exists(self.public_key_path)
|
||||
|
||||
def _validate_key_load(self):
|
||||
if (self._private_key_exists()
|
||||
and self.regenerate in ('never', 'fail', 'partial_idempotence')
|
||||
and (self.original_private_key is None or not self._private_key_readable())):
|
||||
if (
|
||||
self._private_key_exists()
|
||||
and self.regenerate in ("never", "fail", "partial_idempotence")
|
||||
and (self.original_private_key is None or not self._private_key_readable())
|
||||
):
|
||||
self.module.fail_json(
|
||||
msg="Unable to read the key. The key is protected with a passphrase or broken. " +
|
||||
"Will not proceed. To force regeneration, call the module with `generate` " +
|
||||
"set to `full_idempotence` or `always`, or with `force=true`."
|
||||
msg="Unable to read the key. The key is protected with a passphrase or broken. "
|
||||
+ "Will not proceed. To force regeneration, call the module with `generate` "
|
||||
+ "set to `full_idempotence` or `always`, or with `force=true`."
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -165,17 +178,17 @@ class KeypairBackend(OpensshModule):
|
||||
def _should_generate(self):
|
||||
if self.original_private_key is None:
|
||||
return True
|
||||
elif self.regenerate == 'never':
|
||||
elif self.regenerate == "never":
|
||||
return False
|
||||
elif self.regenerate == 'fail':
|
||||
elif self.regenerate == "fail":
|
||||
if not self._private_key_valid():
|
||||
self.module.fail_json(
|
||||
msg="Key has wrong type and/or size. Will not proceed. " +
|
||||
"To force regeneration, call the module with `generate` set to " +
|
||||
"`partial_idempotence`, `full_idempotence` or `always`, or with `force=true`."
|
||||
msg="Key has wrong type and/or size. Will not proceed. "
|
||||
+ "To force regeneration, call the module with `generate` set to "
|
||||
+ "`partial_idempotence`, `full_idempotence` or `always`, or with `force=true`."
|
||||
)
|
||||
return False
|
||||
elif self.regenerate in ('partial_idempotence', 'full_idempotence'):
|
||||
elif self.regenerate in ("partial_idempotence", "full_idempotence"):
|
||||
return not self._private_key_valid()
|
||||
else:
|
||||
return True
|
||||
@@ -184,11 +197,13 @@ class KeypairBackend(OpensshModule):
|
||||
if self.original_private_key is None:
|
||||
return False
|
||||
|
||||
return all([
|
||||
self.size == self.original_private_key.size,
|
||||
self.type == self.original_private_key.type,
|
||||
self._private_key_valid_backend(),
|
||||
])
|
||||
return all(
|
||||
[
|
||||
self.size == self.original_private_key.size,
|
||||
self.type == self.original_private_key.type,
|
||||
self._private_key_valid_backend(),
|
||||
]
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _private_key_valid_backend(self):
|
||||
@@ -200,13 +215,20 @@ class KeypairBackend(OpensshModule):
|
||||
temp_private_key, temp_public_key = self._generate_temp_keypair()
|
||||
|
||||
try:
|
||||
self._safe_secure_move([(temp_private_key, self.private_key_path), (temp_public_key, self.public_key_path)])
|
||||
self._safe_secure_move(
|
||||
[
|
||||
(temp_private_key, self.private_key_path),
|
||||
(temp_public_key, self.public_key_path),
|
||||
]
|
||||
)
|
||||
except OSError as e:
|
||||
self.module.fail_json(msg=to_native(e))
|
||||
|
||||
def _generate_temp_keypair(self):
|
||||
temp_private_key = os.path.join(self.module.tmpdir, os.path.basename(self.private_key_path))
|
||||
temp_public_key = temp_private_key + '.pub'
|
||||
temp_private_key = os.path.join(
|
||||
self.module.tmpdir, os.path.basename(self.private_key_path)
|
||||
)
|
||||
temp_public_key = temp_private_key + ".pub"
|
||||
|
||||
try:
|
||||
self._generate_keypair(temp_private_key)
|
||||
@@ -239,27 +261,33 @@ class KeypairBackend(OpensshModule):
|
||||
@OpensshModule.skip_if_check_mode
|
||||
def _restore_public_key(self):
|
||||
try:
|
||||
temp_public_key = self._create_temp_public_key(str(self._get_public_key()) + '\n')
|
||||
self._safe_secure_move([
|
||||
(temp_public_key, self.public_key_path)
|
||||
])
|
||||
temp_public_key = self._create_temp_public_key(
|
||||
str(self._get_public_key()) + "\n"
|
||||
)
|
||||
self._safe_secure_move([(temp_public_key, self.public_key_path)])
|
||||
except (IOError, OSError):
|
||||
self.module.fail_json(
|
||||
msg="The public key is missing or does not match the private key. " +
|
||||
"Unable to regenerate the public key."
|
||||
msg="The public key is missing or does not match the private key. "
|
||||
+ "Unable to regenerate the public key."
|
||||
)
|
||||
|
||||
if self.comment:
|
||||
self._update_comment()
|
||||
|
||||
def _create_temp_public_key(self, content):
|
||||
temp_public_key = os.path.join(self.module.tmpdir, os.path.basename(self.public_key_path))
|
||||
temp_public_key = os.path.join(
|
||||
self.module.tmpdir, os.path.basename(self.public_key_path)
|
||||
)
|
||||
|
||||
default_permissions = 0o644
|
||||
existing_permissions = file_mode(self.public_key_path)
|
||||
|
||||
try:
|
||||
secure_write(temp_public_key, existing_permissions or default_permissions, to_bytes(content))
|
||||
secure_write(
|
||||
temp_public_key,
|
||||
existing_permissions or default_permissions,
|
||||
to_bytes(content),
|
||||
)
|
||||
except (IOError, OSError) as e:
|
||||
self.module.fail_json(msg=to_native(e))
|
||||
self.module.add_cleanup_file(temp_public_key)
|
||||
@@ -290,25 +318,29 @@ class KeypairBackend(OpensshModule):
|
||||
public_key = self.public_key or self.original_public_key
|
||||
|
||||
return {
|
||||
'size': self.size,
|
||||
'type': self.type,
|
||||
'filename': self.private_key_path,
|
||||
'fingerprint': private_key.fingerprint if private_key else '',
|
||||
'public_key': str(public_key) if public_key else '',
|
||||
'comment': public_key.comment if public_key else '',
|
||||
"size": self.size,
|
||||
"type": self.type,
|
||||
"filename": self.private_key_path,
|
||||
"fingerprint": private_key.fingerprint if private_key else "",
|
||||
"public_key": str(public_key) if public_key else "",
|
||||
"comment": public_key.comment if public_key else "",
|
||||
}
|
||||
|
||||
@property
|
||||
def diff(self):
|
||||
before = self.original_private_key.to_dict() if self.original_private_key else {}
|
||||
before.update(self.original_public_key.to_dict() if self.original_public_key else {})
|
||||
before = (
|
||||
self.original_private_key.to_dict() if self.original_private_key else {}
|
||||
)
|
||||
before.update(
|
||||
self.original_public_key.to_dict() if self.original_public_key else {}
|
||||
)
|
||||
|
||||
after = self.private_key.to_dict() if self.private_key else {}
|
||||
after.update(self.public_key.to_dict() if self.public_key else {})
|
||||
|
||||
return {
|
||||
'before': before,
|
||||
'after': after,
|
||||
"before": before,
|
||||
"after": after,
|
||||
}
|
||||
|
||||
|
||||
@@ -316,36 +348,59 @@ class KeypairBackendOpensshBin(KeypairBackend):
|
||||
def __init__(self, module):
|
||||
super(KeypairBackendOpensshBin, self).__init__(module)
|
||||
|
||||
if self.module.params['private_key_format'] != 'auto':
|
||||
if self.module.params["private_key_format"] != "auto":
|
||||
self.module.fail_json(
|
||||
msg="'auto' is the only valid option for " +
|
||||
"'private_key_format' when 'backend' is not 'cryptography'"
|
||||
msg="'auto' is the only valid option for "
|
||||
+ "'private_key_format' when 'backend' is not 'cryptography'"
|
||||
)
|
||||
|
||||
self.ssh_keygen = KeygenCommand(self.module)
|
||||
|
||||
def _generate_keypair(self, private_key_path):
|
||||
self.ssh_keygen.generate_keypair(private_key_path, self.size, self.type, self.comment, check_rc=True)
|
||||
self.ssh_keygen.generate_keypair(
|
||||
private_key_path, self.size, self.type, self.comment, check_rc=True
|
||||
)
|
||||
|
||||
def _get_private_key(self):
|
||||
rc, private_key_content, err = self.ssh_keygen.get_private_key(self.private_key_path, check_rc=False)
|
||||
rc, private_key_content, err = self.ssh_keygen.get_private_key(
|
||||
self.private_key_path, check_rc=False
|
||||
)
|
||||
if rc != 0:
|
||||
raise ValueError(err)
|
||||
return PrivateKey.from_string(private_key_content)
|
||||
|
||||
def _get_public_key(self):
|
||||
public_key_content = self.ssh_keygen.get_matching_public_key(self.private_key_path, check_rc=True)[1]
|
||||
public_key_content = self.ssh_keygen.get_matching_public_key(
|
||||
self.private_key_path, check_rc=True
|
||||
)[1]
|
||||
return PublicKey.from_string(public_key_content)
|
||||
|
||||
def _private_key_readable(self):
|
||||
rc, stdout, stderr = self.ssh_keygen.get_matching_public_key(self.private_key_path, check_rc=False)
|
||||
return not (rc == 255 or any_in(stderr, 'is not a public key file', 'incorrect passphrase', 'load failed'))
|
||||
rc, stdout, stderr = self.ssh_keygen.get_matching_public_key(
|
||||
self.private_key_path, check_rc=False
|
||||
)
|
||||
return not (
|
||||
rc == 255
|
||||
or any_in(
|
||||
stderr,
|
||||
"is not a public key file",
|
||||
"incorrect passphrase",
|
||||
"load failed",
|
||||
)
|
||||
)
|
||||
|
||||
def _update_comment(self):
|
||||
try:
|
||||
ssh_version = self._get_ssh_version() or "7.8"
|
||||
force_new_format = LooseVersion('6.5') <= LooseVersion(ssh_version) < LooseVersion('7.8')
|
||||
self.ssh_keygen.update_comment(self.private_key_path, self.comment, force_new_format=force_new_format, check_rc=True)
|
||||
force_new_format = (
|
||||
LooseVersion("6.5") <= LooseVersion(ssh_version) < LooseVersion("7.8")
|
||||
)
|
||||
self.ssh_keygen.update_comment(
|
||||
self.private_key_path,
|
||||
self.comment,
|
||||
force_new_format=force_new_format,
|
||||
check_rc=True,
|
||||
)
|
||||
except (IOError, OSError) as e:
|
||||
self.module.fail_json(msg=to_native(e))
|
||||
|
||||
@@ -357,30 +412,41 @@ class KeypairBackendCryptography(KeypairBackend):
|
||||
def __init__(self, module):
|
||||
super(KeypairBackendCryptography, self).__init__(module)
|
||||
|
||||
if self.type == 'rsa1':
|
||||
self.module.fail_json(msg="RSA1 keys are not supported by the cryptography backend")
|
||||
if self.type == "rsa1":
|
||||
self.module.fail_json(
|
||||
msg="RSA1 keys are not supported by the cryptography backend"
|
||||
)
|
||||
|
||||
self.passphrase = to_bytes(module.params['passphrase']) if module.params['passphrase'] else None
|
||||
self.private_key_format = self._get_key_format(module.params['private_key_format'])
|
||||
self.passphrase = (
|
||||
to_bytes(module.params["passphrase"])
|
||||
if module.params["passphrase"]
|
||||
else None
|
||||
)
|
||||
self.private_key_format = self._get_key_format(
|
||||
module.params["private_key_format"]
|
||||
)
|
||||
|
||||
def _get_key_format(self, key_format):
|
||||
result = 'SSH'
|
||||
result = "SSH"
|
||||
|
||||
if key_format == 'auto':
|
||||
if key_format == "auto":
|
||||
# Default to OpenSSH 7.8 compatibility when OpenSSH is not installed
|
||||
ssh_version = self._get_ssh_version() or "7.8"
|
||||
|
||||
if LooseVersion(ssh_version) < LooseVersion("7.8") and self.type != 'ed25519':
|
||||
if (
|
||||
LooseVersion(ssh_version) < LooseVersion("7.8")
|
||||
and self.type != "ed25519"
|
||||
):
|
||||
# OpenSSH made SSH formatted private keys available in version 6.5,
|
||||
# but still defaulted to PKCS1 format with the exception of ed25519 keys
|
||||
result = 'PKCS1'
|
||||
result = "PKCS1"
|
||||
|
||||
if result == 'SSH' and not HAS_OPENSSH_PRIVATE_FORMAT:
|
||||
if result == "SSH" and not HAS_OPENSSH_PRIVATE_FORMAT:
|
||||
self.module.fail_json(
|
||||
msg=missing_required_lib(
|
||||
'cryptography >= 3.0',
|
||||
reason="to load/dump private keys in the default OpenSSH format for OpenSSH >= 7.8 " +
|
||||
"or for ed25519 keys"
|
||||
"cryptography >= 3.0",
|
||||
reason="to load/dump private keys in the default OpenSSH format for OpenSSH >= 7.8 "
|
||||
+ "or for ed25519 keys",
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -393,7 +459,7 @@ class KeypairBackendCryptography(KeypairBackend):
|
||||
keytype=self.type,
|
||||
size=self.size,
|
||||
passphrase=self.passphrase,
|
||||
comment=self.comment or '',
|
||||
comment=self.comment or "",
|
||||
)
|
||||
|
||||
encoded_private_key = OpensshKeypair.encode_openssh_privatekey(
|
||||
@@ -401,22 +467,28 @@ class KeypairBackendCryptography(KeypairBackend):
|
||||
)
|
||||
secure_write(private_key_path, 0o600, encoded_private_key)
|
||||
|
||||
public_key_path = private_key_path + '.pub'
|
||||
public_key_path = private_key_path + ".pub"
|
||||
secure_write(public_key_path, 0o644, keypair.public_key)
|
||||
|
||||
def _get_private_key(self):
|
||||
keypair = OpensshKeypair.load(path=self.private_key_path, passphrase=self.passphrase, no_public_key=True)
|
||||
keypair = OpensshKeypair.load(
|
||||
path=self.private_key_path, passphrase=self.passphrase, no_public_key=True
|
||||
)
|
||||
|
||||
return PrivateKey(
|
||||
size=keypair.size,
|
||||
key_type=keypair.key_type,
|
||||
fingerprint=keypair.fingerprint,
|
||||
format=parse_private_key_format(self.private_key_path)
|
||||
format=parse_private_key_format(self.private_key_path),
|
||||
)
|
||||
|
||||
def _get_public_key(self):
|
||||
try:
|
||||
keypair = OpensshKeypair.load(path=self.private_key_path, passphrase=self.passphrase, no_public_key=True)
|
||||
keypair = OpensshKeypair.load(
|
||||
path=self.private_key_path,
|
||||
passphrase=self.passphrase,
|
||||
no_public_key=True,
|
||||
)
|
||||
except OpenSSHError:
|
||||
# Simulates the null output of ssh-keygen
|
||||
return ""
|
||||
@@ -425,7 +497,11 @@ class KeypairBackendCryptography(KeypairBackend):
|
||||
|
||||
def _private_key_readable(self):
|
||||
try:
|
||||
OpensshKeypair.load(path=self.private_key_path, passphrase=self.passphrase, no_public_key=True)
|
||||
OpensshKeypair.load(
|
||||
path=self.private_key_path,
|
||||
passphrase=self.passphrase,
|
||||
no_public_key=True,
|
||||
)
|
||||
except (InvalidPrivateKeyFileError, InvalidPassphraseError):
|
||||
return False
|
||||
|
||||
@@ -433,7 +509,9 @@ class KeypairBackendCryptography(KeypairBackend):
|
||||
# when loading an unencrypted key
|
||||
if self.passphrase:
|
||||
try:
|
||||
OpensshKeypair.load(path=self.private_key_path, passphrase=None, no_public_key=True)
|
||||
OpensshKeypair.load(
|
||||
path=self.private_key_path, passphrase=None, no_public_key=True
|
||||
)
|
||||
except (InvalidPrivateKeyFileError, InvalidPassphraseError):
|
||||
return True
|
||||
else:
|
||||
@@ -442,14 +520,16 @@ class KeypairBackendCryptography(KeypairBackend):
|
||||
return True
|
||||
|
||||
def _update_comment(self):
|
||||
keypair = OpensshKeypair.load(path=self.private_key_path, passphrase=self.passphrase, no_public_key=True)
|
||||
keypair = OpensshKeypair.load(
|
||||
path=self.private_key_path, passphrase=self.passphrase, no_public_key=True
|
||||
)
|
||||
try:
|
||||
keypair.comment = self.comment
|
||||
except InvalidCommentError as e:
|
||||
self.module.fail_json(msg=to_native(e))
|
||||
|
||||
try:
|
||||
temp_public_key = self._create_temp_public_key(keypair.public_key + b'\n')
|
||||
temp_public_key = self._create_temp_public_key(keypair.public_key + b"\n")
|
||||
self._safe_secure_move([(temp_public_key, self.public_key_path)])
|
||||
except (IOError, OSError) as e:
|
||||
self.module.fail_json(msg=to_native(e))
|
||||
@@ -457,7 +537,7 @@ class KeypairBackendCryptography(KeypairBackend):
|
||||
def _private_key_valid_backend(self):
|
||||
# avoids breaking behavior and prevents
|
||||
# automatic conversions with OpenSSH upgrades
|
||||
if self.module.params['private_key_format'] == 'auto':
|
||||
if self.module.params["private_key_format"] == "auto":
|
||||
return True
|
||||
|
||||
return self.private_key_format == self.original_private_key.format
|
||||
@@ -465,24 +545,26 @@ class KeypairBackendCryptography(KeypairBackend):
|
||||
|
||||
def select_backend(module, backend):
|
||||
can_use_cryptography = HAS_OPENSSH_SUPPORT
|
||||
can_use_opensshbin = bool(module.get_bin_path('ssh-keygen'))
|
||||
can_use_opensshbin = bool(module.get_bin_path("ssh-keygen"))
|
||||
|
||||
if backend == 'auto':
|
||||
if can_use_opensshbin and not module.params['passphrase']:
|
||||
backend = 'opensshbin'
|
||||
if backend == "auto":
|
||||
if can_use_opensshbin and not module.params["passphrase"]:
|
||||
backend = "opensshbin"
|
||||
elif can_use_cryptography:
|
||||
backend = 'cryptography'
|
||||
backend = "cryptography"
|
||||
else:
|
||||
module.fail_json(msg="Cannot find either the OpenSSH binary in the PATH " +
|
||||
"or cryptography >= 2.6 installed on this system")
|
||||
module.fail_json(
|
||||
msg="Cannot find either the OpenSSH binary in the PATH "
|
||||
+ "or cryptography >= 2.6 installed on this system"
|
||||
)
|
||||
|
||||
if backend == 'opensshbin':
|
||||
if backend == "opensshbin":
|
||||
if not can_use_opensshbin:
|
||||
module.fail_json(msg="Cannot find the OpenSSH binary in the PATH")
|
||||
return backend, KeypairBackendOpensshBin(module)
|
||||
elif backend == 'cryptography':
|
||||
elif backend == "cryptography":
|
||||
if not can_use_cryptography:
|
||||
module.fail_json(msg=missing_required_lib("cryptography >= 2.6"))
|
||||
return backend, KeypairBackendCryptography(module)
|
||||
else:
|
||||
raise ValueError('Unsupported value for backend: {0}'.format(backend))
|
||||
raise ValueError("Unsupported value for backend: {0}".format(backend))
|
||||
|
||||
@@ -51,54 +51,56 @@ _USER_TYPE = 1
|
||||
_HOST_TYPE = 2
|
||||
|
||||
_SSH_TYPE_STRINGS = {
|
||||
'rsa': b"ssh-rsa",
|
||||
'dsa': b"ssh-dss",
|
||||
'ecdsa-nistp256': b"ecdsa-sha2-nistp256",
|
||||
'ecdsa-nistp384': b"ecdsa-sha2-nistp384",
|
||||
'ecdsa-nistp521': b"ecdsa-sha2-nistp521",
|
||||
'ed25519': b"ssh-ed25519",
|
||||
"rsa": b"ssh-rsa",
|
||||
"dsa": b"ssh-dss",
|
||||
"ecdsa-nistp256": b"ecdsa-sha2-nistp256",
|
||||
"ecdsa-nistp384": b"ecdsa-sha2-nistp384",
|
||||
"ecdsa-nistp521": b"ecdsa-sha2-nistp521",
|
||||
"ed25519": b"ssh-ed25519",
|
||||
}
|
||||
_CERT_SUFFIX_V01 = b"-cert-v01@openssh.com"
|
||||
|
||||
# See https://datatracker.ietf.org/doc/html/rfc5656#section-6.1
|
||||
_ECDSA_CURVE_IDENTIFIERS = {
|
||||
'ecdsa-nistp256': b'nistp256',
|
||||
'ecdsa-nistp384': b'nistp384',
|
||||
'ecdsa-nistp521': b'nistp521',
|
||||
"ecdsa-nistp256": b"nistp256",
|
||||
"ecdsa-nistp384": b"nistp384",
|
||||
"ecdsa-nistp521": b"nistp521",
|
||||
}
|
||||
_ECDSA_CURVE_IDENTIFIERS_LOOKUP = {
|
||||
b'nistp256': 'ecdsa-nistp256',
|
||||
b'nistp384': 'ecdsa-nistp384',
|
||||
b'nistp521': 'ecdsa-nistp521',
|
||||
b"nistp256": "ecdsa-nistp256",
|
||||
b"nistp384": "ecdsa-nistp384",
|
||||
b"nistp521": "ecdsa-nistp521",
|
||||
}
|
||||
|
||||
_USE_TIMEZONE = sys.version_info >= (3, 6)
|
||||
|
||||
|
||||
_ALWAYS = _add_or_remove_timezone(datetime(1970, 1, 1), with_timezone=_USE_TIMEZONE)
|
||||
_FOREVER = datetime(9999, 12, 31, 23, 59, 59, 999999, _UTC) if _USE_TIMEZONE else datetime.max
|
||||
_FOREVER = (
|
||||
datetime(9999, 12, 31, 23, 59, 59, 999999, _UTC) if _USE_TIMEZONE else datetime.max
|
||||
)
|
||||
|
||||
_CRITICAL_OPTIONS = (
|
||||
'force-command',
|
||||
'source-address',
|
||||
'verify-required',
|
||||
"force-command",
|
||||
"source-address",
|
||||
"verify-required",
|
||||
)
|
||||
|
||||
_DIRECTIVES = (
|
||||
'clear',
|
||||
'no-x11-forwarding',
|
||||
'no-agent-forwarding',
|
||||
'no-port-forwarding',
|
||||
'no-pty',
|
||||
'no-user-rc',
|
||||
"clear",
|
||||
"no-x11-forwarding",
|
||||
"no-agent-forwarding",
|
||||
"no-port-forwarding",
|
||||
"no-pty",
|
||||
"no-user-rc",
|
||||
)
|
||||
|
||||
_EXTENSIONS = (
|
||||
'permit-x11-forwarding',
|
||||
'permit-agent-forwarding',
|
||||
'permit-port-forwarding',
|
||||
'permit-pty',
|
||||
'permit-user-rc'
|
||||
"permit-x11-forwarding",
|
||||
"permit-agent-forwarding",
|
||||
"permit-port-forwarding",
|
||||
"permit-pty",
|
||||
"permit-user-rc",
|
||||
)
|
||||
|
||||
if six.PY3:
|
||||
@@ -111,13 +113,19 @@ class OpensshCertificateTimeParameters(object):
|
||||
self._valid_to = self.to_datetime(valid_to)
|
||||
|
||||
if self._valid_from > self._valid_to:
|
||||
raise ValueError("Valid from: %s must not be greater than Valid to: %s" % (valid_from, valid_to))
|
||||
raise ValueError(
|
||||
"Valid from: %s must not be greater than Valid to: %s"
|
||||
% (valid_from, valid_to)
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, type(self)):
|
||||
return NotImplemented
|
||||
else:
|
||||
return self._valid_from == other._valid_from and self._valid_to == other._valid_to
|
||||
return (
|
||||
self._valid_from == other._valid_from
|
||||
and self._valid_to == other._valid_to
|
||||
)
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self == other
|
||||
@@ -126,7 +134,8 @@ class OpensshCertificateTimeParameters(object):
|
||||
def validity_string(self):
|
||||
if not (self._valid_from == _ALWAYS and self._valid_to == _FOREVER):
|
||||
return "%s:%s" % (
|
||||
self.valid_from(date_format='openssh'), self.valid_to(date_format='openssh')
|
||||
self.valid_from(date_format="openssh"),
|
||||
self.valid_to(date_format="openssh"),
|
||||
)
|
||||
return ""
|
||||
|
||||
@@ -144,16 +153,22 @@ class OpensshCertificateTimeParameters(object):
|
||||
|
||||
@staticmethod
|
||||
def format_datetime(dt, date_format):
|
||||
if date_format in ('human_readable', 'openssh'):
|
||||
if date_format in ("human_readable", "openssh"):
|
||||
if dt == _ALWAYS:
|
||||
result = 'always'
|
||||
result = "always"
|
||||
elif dt == _FOREVER:
|
||||
result = 'forever'
|
||||
result = "forever"
|
||||
else:
|
||||
result = dt.isoformat().replace('+00:00', '') if date_format == 'human_readable' else dt.strftime("%Y%m%d%H%M%S")
|
||||
elif date_format == 'timestamp':
|
||||
result = (
|
||||
dt.isoformat().replace("+00:00", "")
|
||||
if date_format == "human_readable"
|
||||
else dt.strftime("%Y%m%d%H%M%S")
|
||||
)
|
||||
elif date_format == "timestamp":
|
||||
td = dt - _ALWAYS
|
||||
result = int((td.microseconds + (td.seconds + td.days * 24 * 3600) * 10 ** 6) / 10 ** 6)
|
||||
result = int(
|
||||
(td.microseconds + (td.seconds + td.days * 24 * 3600) * 10**6) / 10**6
|
||||
)
|
||||
else:
|
||||
raise ValueError("%s is not a valid format" % date_format)
|
||||
return result
|
||||
@@ -162,12 +177,17 @@ class OpensshCertificateTimeParameters(object):
|
||||
def to_datetime(time_string_or_timestamp):
|
||||
try:
|
||||
if isinstance(time_string_or_timestamp, six.string_types):
|
||||
result = OpensshCertificateTimeParameters._time_string_to_datetime(time_string_or_timestamp.strip())
|
||||
result = OpensshCertificateTimeParameters._time_string_to_datetime(
|
||||
time_string_or_timestamp.strip()
|
||||
)
|
||||
elif isinstance(time_string_or_timestamp, (long, int)):
|
||||
result = OpensshCertificateTimeParameters._timestamp_to_datetime(time_string_or_timestamp)
|
||||
result = OpensshCertificateTimeParameters._timestamp_to_datetime(
|
||||
time_string_or_timestamp
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Value must be of type (str, unicode, int, long) not %s" % type(time_string_or_timestamp)
|
||||
"Value must be of type (str, unicode, int, long) not %s"
|
||||
% type(time_string_or_timestamp)
|
||||
)
|
||||
except ValueError:
|
||||
raise
|
||||
@@ -182,7 +202,9 @@ class OpensshCertificateTimeParameters(object):
|
||||
else:
|
||||
try:
|
||||
if _USE_TIMEZONE:
|
||||
result = datetime.fromtimestamp(timestamp, tz=_datetime.timezone.utc)
|
||||
result = datetime.fromtimestamp(
|
||||
timestamp, tz=_datetime.timezone.utc
|
||||
)
|
||||
else:
|
||||
result = datetime.utcfromtimestamp(timestamp)
|
||||
except OverflowError:
|
||||
@@ -192,16 +214,21 @@ class OpensshCertificateTimeParameters(object):
|
||||
@staticmethod
|
||||
def _time_string_to_datetime(time_string):
|
||||
result = None
|
||||
if time_string == 'always':
|
||||
if time_string == "always":
|
||||
result = _ALWAYS
|
||||
elif time_string == 'forever':
|
||||
elif time_string == "forever":
|
||||
result = _FOREVER
|
||||
elif is_relative_time_string(time_string):
|
||||
result = convert_relative_to_datetime(time_string, with_timezone=_USE_TIMEZONE)
|
||||
result = convert_relative_to_datetime(
|
||||
time_string, with_timezone=_USE_TIMEZONE
|
||||
)
|
||||
else:
|
||||
for time_format in ("%Y-%m-%d", "%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S"):
|
||||
try:
|
||||
result = _add_or_remove_timezone(datetime.strptime(time_string, time_format), with_timezone=_USE_TIMEZONE)
|
||||
result = _add_or_remove_timezone(
|
||||
datetime.strptime(time_string, time_format),
|
||||
with_timezone=_USE_TIMEZONE,
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
if result is None:
|
||||
@@ -211,7 +238,7 @@ class OpensshCertificateTimeParameters(object):
|
||||
|
||||
class OpensshCertificateOption(object):
|
||||
def __init__(self, option_type, name, data):
|
||||
if option_type not in ('critical', 'extension'):
|
||||
if option_type not in ("critical", "extension"):
|
||||
raise ValueError("type must be either 'critical' or 'extension'")
|
||||
|
||||
if not isinstance(name, six.string_types):
|
||||
@@ -228,11 +255,13 @@ class OpensshCertificateOption(object):
|
||||
if not isinstance(other, type(self)):
|
||||
return NotImplemented
|
||||
|
||||
return all([
|
||||
self._option_type == other._option_type,
|
||||
self._name == other._name,
|
||||
self._data == other._data,
|
||||
])
|
||||
return all(
|
||||
[
|
||||
self._option_type == other._option_type,
|
||||
self._name == other._name,
|
||||
self._data == other._data,
|
||||
]
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self._option_type, self._name, self._data))
|
||||
@@ -260,42 +289,47 @@ class OpensshCertificateOption(object):
|
||||
@classmethod
|
||||
def from_string(cls, option_string):
|
||||
if not isinstance(option_string, six.string_types):
|
||||
raise ValueError("option_string must be a string not %s" % type(option_string))
|
||||
raise ValueError(
|
||||
"option_string must be a string not %s" % type(option_string)
|
||||
)
|
||||
option_type = None
|
||||
|
||||
if ':' in option_string:
|
||||
option_type, value = option_string.strip().split(':', 1)
|
||||
if '=' in value:
|
||||
name, data = value.split('=', 1)
|
||||
if ":" in option_string:
|
||||
option_type, value = option_string.strip().split(":", 1)
|
||||
if "=" in value:
|
||||
name, data = value.split("=", 1)
|
||||
else:
|
||||
name, data = value, ''
|
||||
elif '=' in option_string:
|
||||
name, data = option_string.strip().split('=', 1)
|
||||
name, data = value, ""
|
||||
elif "=" in option_string:
|
||||
name, data = option_string.strip().split("=", 1)
|
||||
else:
|
||||
name, data = option_string.strip(), ''
|
||||
name, data = option_string.strip(), ""
|
||||
|
||||
return cls(
|
||||
option_type=option_type or get_option_type(name.lower()),
|
||||
name=name,
|
||||
data=data
|
||||
data=data,
|
||||
)
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class OpensshCertificateInfo:
|
||||
"""Encapsulates all certificate information which is signed by a CA key"""
|
||||
def __init__(self,
|
||||
nonce=None,
|
||||
serial=None,
|
||||
cert_type=None,
|
||||
key_id=None,
|
||||
principals=None,
|
||||
valid_after=None,
|
||||
valid_before=None,
|
||||
critical_options=None,
|
||||
extensions=None,
|
||||
reserved=None,
|
||||
signing_key=None):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
nonce=None,
|
||||
serial=None,
|
||||
cert_type=None,
|
||||
key_id=None,
|
||||
principals=None,
|
||||
valid_after=None,
|
||||
valid_before=None,
|
||||
critical_options=None,
|
||||
extensions=None,
|
||||
reserved=None,
|
||||
signing_key=None,
|
||||
):
|
||||
self.nonce = nonce
|
||||
self.serial = serial
|
||||
self._cert_type = cert_type
|
||||
@@ -313,17 +347,17 @@ class OpensshCertificateInfo:
|
||||
@property
|
||||
def cert_type(self):
|
||||
if self._cert_type == _USER_TYPE:
|
||||
return 'user'
|
||||
return "user"
|
||||
elif self._cert_type == _HOST_TYPE:
|
||||
return 'host'
|
||||
return "host"
|
||||
else:
|
||||
return ''
|
||||
return ""
|
||||
|
||||
@cert_type.setter
|
||||
def cert_type(self, cert_type):
|
||||
if cert_type == 'user' or cert_type == _USER_TYPE:
|
||||
if cert_type == "user" or cert_type == _USER_TYPE:
|
||||
self._cert_type = _USER_TYPE
|
||||
elif cert_type == 'host' or cert_type == _HOST_TYPE:
|
||||
elif cert_type == "host" or cert_type == _HOST_TYPE:
|
||||
self._cert_type = _HOST_TYPE
|
||||
else:
|
||||
raise ValueError("%s is not a valid certificate type" % cert_type)
|
||||
@@ -343,17 +377,17 @@ class OpensshCertificateInfo:
|
||||
class OpensshRSACertificateInfo(OpensshCertificateInfo):
|
||||
def __init__(self, e=None, n=None, **kwargs):
|
||||
super(OpensshRSACertificateInfo, self).__init__(**kwargs)
|
||||
self.type_string = _SSH_TYPE_STRINGS['rsa'] + _CERT_SUFFIX_V01
|
||||
self.type_string = _SSH_TYPE_STRINGS["rsa"] + _CERT_SUFFIX_V01
|
||||
self.e = e
|
||||
self.n = n
|
||||
|
||||
# See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
|
||||
def public_key_fingerprint(self):
|
||||
if any([self.e is None, self.n is None]):
|
||||
return b''
|
||||
return b""
|
||||
|
||||
writer = _OpensshWriter()
|
||||
writer.string(_SSH_TYPE_STRINGS['rsa'])
|
||||
writer.string(_SSH_TYPE_STRINGS["rsa"])
|
||||
writer.mpint(self.e)
|
||||
writer.mpint(self.n)
|
||||
|
||||
@@ -367,7 +401,7 @@ class OpensshRSACertificateInfo(OpensshCertificateInfo):
|
||||
class OpensshDSACertificateInfo(OpensshCertificateInfo):
|
||||
def __init__(self, p=None, q=None, g=None, y=None, **kwargs):
|
||||
super(OpensshDSACertificateInfo, self).__init__(**kwargs)
|
||||
self.type_string = _SSH_TYPE_STRINGS['dsa'] + _CERT_SUFFIX_V01
|
||||
self.type_string = _SSH_TYPE_STRINGS["dsa"] + _CERT_SUFFIX_V01
|
||||
self.p = p
|
||||
self.q = q
|
||||
self.g = g
|
||||
@@ -376,10 +410,10 @@ class OpensshDSACertificateInfo(OpensshCertificateInfo):
|
||||
# See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
|
||||
def public_key_fingerprint(self):
|
||||
if any([self.p is None, self.q is None, self.g is None, self.y is None]):
|
||||
return b''
|
||||
return b""
|
||||
|
||||
writer = _OpensshWriter()
|
||||
writer.string(_SSH_TYPE_STRINGS['dsa'])
|
||||
writer.string(_SSH_TYPE_STRINGS["dsa"])
|
||||
writer.mpint(self.p)
|
||||
writer.mpint(self.q)
|
||||
writer.mpint(self.g)
|
||||
@@ -411,16 +445,20 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo):
|
||||
def curve(self, curve):
|
||||
if curve in _ECDSA_CURVE_IDENTIFIERS.values():
|
||||
self._curve = curve
|
||||
self.type_string = _SSH_TYPE_STRINGS[_ECDSA_CURVE_IDENTIFIERS_LOOKUP[curve]] + _CERT_SUFFIX_V01
|
||||
self.type_string = (
|
||||
_SSH_TYPE_STRINGS[_ECDSA_CURVE_IDENTIFIERS_LOOKUP[curve]]
|
||||
+ _CERT_SUFFIX_V01
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Curve must be one of %s" % (b','.join(list(_ECDSA_CURVE_IDENTIFIERS.values()))).decode('UTF-8')
|
||||
"Curve must be one of %s"
|
||||
% (b",".join(list(_ECDSA_CURVE_IDENTIFIERS.values()))).decode("UTF-8")
|
||||
)
|
||||
|
||||
# See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
|
||||
def public_key_fingerprint(self):
|
||||
if any([self.curve is None, self.public_key is None]):
|
||||
return b''
|
||||
return b""
|
||||
|
||||
writer = _OpensshWriter()
|
||||
writer.string(_SSH_TYPE_STRINGS[_ECDSA_CURVE_IDENTIFIERS_LOOKUP[self.curve]])
|
||||
@@ -437,15 +475,15 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo):
|
||||
class OpensshED25519CertificateInfo(OpensshCertificateInfo):
|
||||
def __init__(self, pk=None, **kwargs):
|
||||
super(OpensshED25519CertificateInfo, self).__init__(**kwargs)
|
||||
self.type_string = _SSH_TYPE_STRINGS['ed25519'] + _CERT_SUFFIX_V01
|
||||
self.type_string = _SSH_TYPE_STRINGS["ed25519"] + _CERT_SUFFIX_V01
|
||||
self.pk = pk
|
||||
|
||||
def public_key_fingerprint(self):
|
||||
if self.pk is None:
|
||||
return b''
|
||||
return b""
|
||||
|
||||
writer = _OpensshWriter()
|
||||
writer.string(_SSH_TYPE_STRINGS['ed25519'])
|
||||
writer.string(_SSH_TYPE_STRINGS["ed25519"])
|
||||
writer.string(self.pk)
|
||||
|
||||
return fingerprint(writer.bytes())
|
||||
@@ -457,6 +495,7 @@ class OpensshED25519CertificateInfo(OpensshCertificateInfo):
|
||||
# See https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL.certkeys?annotate=HEAD
|
||||
class OpensshCertificate(object):
|
||||
"""Encapsulates a formatted OpenSSH certificate including signature and signing key"""
|
||||
|
||||
def __init__(self, cert_info, signature):
|
||||
|
||||
self._cert_info = cert_info
|
||||
@@ -468,13 +507,13 @@ class OpensshCertificate(object):
|
||||
raise ValueError("%s is not a valid path." % path)
|
||||
|
||||
try:
|
||||
with open(path, 'rb') as cert_file:
|
||||
with open(path, "rb") as cert_file:
|
||||
data = cert_file.read()
|
||||
except (IOError, OSError) as e:
|
||||
raise ValueError("%s cannot be opened for reading: %s" % (path, e))
|
||||
|
||||
try:
|
||||
format_identifier, b64_cert = data.split(b' ')[:2]
|
||||
format_identifier, b64_cert = data.split(b" ")[:2]
|
||||
cert = binascii.a2b_base64(b64_cert)
|
||||
except (binascii.Error, ValueError):
|
||||
raise ValueError("Certificate not in OpenSSH format")
|
||||
@@ -484,7 +523,9 @@ class OpensshCertificate(object):
|
||||
pub_key_type = key_type
|
||||
break
|
||||
else:
|
||||
raise ValueError("Invalid certificate format identifier: %s" % format_identifier)
|
||||
raise ValueError(
|
||||
"Invalid certificate format identifier: %s" % format_identifier
|
||||
)
|
||||
|
||||
parser = OpensshParser(cert)
|
||||
|
||||
@@ -499,7 +540,8 @@ class OpensshCertificate(object):
|
||||
|
||||
if parser.remaining_bytes():
|
||||
raise ValueError(
|
||||
"%s bytes of additional data was not parsed while loading %s" % (parser.remaining_bytes(), path)
|
||||
"%s bytes of additional data was not parsed while loading %s"
|
||||
% (parser.remaining_bytes(), path)
|
||||
)
|
||||
|
||||
return cls(
|
||||
@@ -546,12 +588,16 @@ class OpensshCertificate(object):
|
||||
@property
|
||||
def critical_options(self):
|
||||
return [
|
||||
OpensshCertificateOption('critical', to_text(n), to_text(d)) for n, d in self._cert_info.critical_options
|
||||
OpensshCertificateOption("critical", to_text(n), to_text(d))
|
||||
for n, d in self._cert_info.critical_options
|
||||
]
|
||||
|
||||
@property
|
||||
def extensions(self):
|
||||
return [OpensshCertificateOption('extension', to_text(n), to_text(d)) for n, d in self._cert_info.extensions]
|
||||
return [
|
||||
OpensshCertificateOption("extension", to_text(n), to_text(d))
|
||||
for n, d in self._cert_info.extensions
|
||||
]
|
||||
|
||||
@property
|
||||
def reserved(self):
|
||||
@@ -564,7 +610,7 @@ class OpensshCertificate(object):
|
||||
@property
|
||||
def signature_type(self):
|
||||
signature_data = OpensshParser.signature_data(self.signature)
|
||||
return to_text(signature_data['signature_type'])
|
||||
return to_text(signature_data["signature_type"])
|
||||
|
||||
@staticmethod
|
||||
def _parse_cert_info(pub_key_type, parser):
|
||||
@@ -586,23 +632,24 @@ class OpensshCertificate(object):
|
||||
|
||||
def to_dict(self):
|
||||
time_parameters = OpensshCertificateTimeParameters(
|
||||
valid_from=self.valid_after,
|
||||
valid_to=self.valid_before
|
||||
valid_from=self.valid_after, valid_to=self.valid_before
|
||||
)
|
||||
return {
|
||||
'type_string': self.type_string,
|
||||
'nonce': self.nonce,
|
||||
'serial': self.serial,
|
||||
'cert_type': self.type,
|
||||
'identifier': self.key_id,
|
||||
'principals': self.principals,
|
||||
'valid_after': time_parameters.valid_from(date_format='human_readable'),
|
||||
'valid_before': time_parameters.valid_to(date_format='human_readable'),
|
||||
'critical_options': [str(critical_option) for critical_option in self.critical_options],
|
||||
'extensions': [str(extension) for extension in self.extensions],
|
||||
'reserved': self.reserved,
|
||||
'public_key': self.public_key,
|
||||
'signing_key': self.signing_key,
|
||||
"type_string": self.type_string,
|
||||
"nonce": self.nonce,
|
||||
"serial": self.serial,
|
||||
"cert_type": self.type,
|
||||
"identifier": self.key_id,
|
||||
"principals": self.principals,
|
||||
"valid_after": time_parameters.valid_from(date_format="human_readable"),
|
||||
"valid_before": time_parameters.valid_to(date_format="human_readable"),
|
||||
"critical_options": [
|
||||
str(critical_option) for critical_option in self.critical_options
|
||||
],
|
||||
"extensions": [str(extension) for extension in self.extensions],
|
||||
"reserved": self.reserved,
|
||||
"public_key": self.public_key,
|
||||
"signing_key": self.signing_key,
|
||||
}
|
||||
|
||||
|
||||
@@ -611,38 +658,46 @@ def apply_directives(directives):
|
||||
raise ValueError("directives must be one of %s" % ", ".join(_DIRECTIVES))
|
||||
|
||||
directive_to_option = {
|
||||
'no-x11-forwarding': OpensshCertificateOption('extension', 'permit-x11-forwarding', ''),
|
||||
'no-agent-forwarding': OpensshCertificateOption('extension', 'permit-agent-forwarding', ''),
|
||||
'no-port-forwarding': OpensshCertificateOption('extension', 'permit-port-forwarding', ''),
|
||||
'no-pty': OpensshCertificateOption('extension', 'permit-pty', ''),
|
||||
'no-user-rc': OpensshCertificateOption('extension', 'permit-user-rc', ''),
|
||||
"no-x11-forwarding": OpensshCertificateOption(
|
||||
"extension", "permit-x11-forwarding", ""
|
||||
),
|
||||
"no-agent-forwarding": OpensshCertificateOption(
|
||||
"extension", "permit-agent-forwarding", ""
|
||||
),
|
||||
"no-port-forwarding": OpensshCertificateOption(
|
||||
"extension", "permit-port-forwarding", ""
|
||||
),
|
||||
"no-pty": OpensshCertificateOption("extension", "permit-pty", ""),
|
||||
"no-user-rc": OpensshCertificateOption("extension", "permit-user-rc", ""),
|
||||
}
|
||||
|
||||
if 'clear' in directives:
|
||||
if "clear" in directives:
|
||||
return []
|
||||
else:
|
||||
return list(set(default_options()) - set(directive_to_option[d] for d in directives))
|
||||
return list(
|
||||
set(default_options()) - set(directive_to_option[d] for d in directives)
|
||||
)
|
||||
|
||||
|
||||
def default_options():
|
||||
return [OpensshCertificateOption('extension', name, '') for name in _EXTENSIONS]
|
||||
return [OpensshCertificateOption("extension", name, "") for name in _EXTENSIONS]
|
||||
|
||||
|
||||
def fingerprint(public_key):
|
||||
"""Generates a SHA256 hash and formats output to resemble ``ssh-keygen``"""
|
||||
h = sha256()
|
||||
h.update(public_key)
|
||||
return b'SHA256:' + b64encode(h.digest()).rstrip(b'=')
|
||||
return b"SHA256:" + b64encode(h.digest()).rstrip(b"=")
|
||||
|
||||
|
||||
def get_cert_info_object(key_type):
|
||||
if key_type == 'rsa':
|
||||
if key_type == "rsa":
|
||||
cert_info = OpensshRSACertificateInfo()
|
||||
elif key_type == 'dsa':
|
||||
elif key_type == "dsa":
|
||||
cert_info = OpensshDSACertificateInfo()
|
||||
elif key_type in ('ecdsa-nistp256', 'ecdsa-nistp384', 'ecdsa-nistp521'):
|
||||
elif key_type in ("ecdsa-nistp256", "ecdsa-nistp384", "ecdsa-nistp521"):
|
||||
cert_info = OpensshECDSACertificateInfo()
|
||||
elif key_type == 'ed25519':
|
||||
elif key_type == "ed25519":
|
||||
cert_info = OpensshED25519CertificateInfo()
|
||||
else:
|
||||
raise ValueError("%s is not a valid key type" % key_type)
|
||||
@@ -652,12 +707,14 @@ def get_cert_info_object(key_type):
|
||||
|
||||
def get_option_type(name):
|
||||
if name in _CRITICAL_OPTIONS:
|
||||
result = 'critical'
|
||||
result = "critical"
|
||||
elif name in _EXTENSIONS:
|
||||
result = 'extension'
|
||||
result = "extension"
|
||||
else:
|
||||
raise ValueError("%s is not a valid option. " % name +
|
||||
"Custom options must start with 'critical:' or 'extension:' to indicate type")
|
||||
raise ValueError(
|
||||
"%s is not a valid option. " % name
|
||||
+ "Custom options must start with 'critical:' or 'extension:' to indicate type"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@@ -675,7 +732,7 @@ def parse_option_list(option_list):
|
||||
directives.append(option.lower())
|
||||
else:
|
||||
option_object = OpensshCertificateOption.from_string(option)
|
||||
if option_object.type == 'critical':
|
||||
if option_object.type == "critical":
|
||||
critical_options.append(option_object)
|
||||
else:
|
||||
extensions.append(option_object)
|
||||
|
||||
@@ -38,41 +38,41 @@ try:
|
||||
HAS_OPENSSH_SUPPORT = True
|
||||
|
||||
_ALGORITHM_PARAMETERS = {
|
||||
'rsa': {
|
||||
'default_size': 2048,
|
||||
'valid_sizes': range(1024, 16384),
|
||||
'signer_params': {
|
||||
'padding': padding.PSS(
|
||||
"rsa": {
|
||||
"default_size": 2048,
|
||||
"valid_sizes": range(1024, 16384),
|
||||
"signer_params": {
|
||||
"padding": padding.PSS(
|
||||
mgf=padding.MGF1(hashes.SHA256()),
|
||||
salt_length=padding.PSS.MAX_LENGTH,
|
||||
),
|
||||
'algorithm': hashes.SHA256(),
|
||||
"algorithm": hashes.SHA256(),
|
||||
},
|
||||
},
|
||||
'dsa': {
|
||||
'default_size': 1024,
|
||||
'valid_sizes': [1024],
|
||||
'signer_params': {
|
||||
'algorithm': hashes.SHA256(),
|
||||
"dsa": {
|
||||
"default_size": 1024,
|
||||
"valid_sizes": [1024],
|
||||
"signer_params": {
|
||||
"algorithm": hashes.SHA256(),
|
||||
},
|
||||
},
|
||||
'ed25519': {
|
||||
'default_size': 256,
|
||||
'valid_sizes': [256],
|
||||
'signer_params': {},
|
||||
"ed25519": {
|
||||
"default_size": 256,
|
||||
"valid_sizes": [256],
|
||||
"signer_params": {},
|
||||
},
|
||||
'ecdsa': {
|
||||
'default_size': 256,
|
||||
'valid_sizes': [256, 384, 521],
|
||||
'signer_params': {
|
||||
'signature_algorithm': ec.ECDSA(hashes.SHA256()),
|
||||
"ecdsa": {
|
||||
"default_size": 256,
|
||||
"valid_sizes": [256, 384, 521],
|
||||
"signer_params": {
|
||||
"signature_algorithm": ec.ECDSA(hashes.SHA256()),
|
||||
},
|
||||
'curves': {
|
||||
"curves": {
|
||||
256: ec.SECP256R1(),
|
||||
384: ec.SECP384R1(),
|
||||
521: ec.SECP521R1(),
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
except ImportError:
|
||||
HAS_OPENSSH_PRIVATE_FORMAT = False
|
||||
@@ -80,7 +80,7 @@ except ImportError:
|
||||
CRYPTOGRAPHY_VERSION = "0.0"
|
||||
_ALGORITHM_PARAMETERS = {}
|
||||
|
||||
_TEXT_ENCODING = 'UTF-8'
|
||||
_TEXT_ENCODING = "UTF-8"
|
||||
|
||||
|
||||
class OpenSSHError(Exception):
|
||||
@@ -131,26 +131,25 @@ class AsymmetricKeypair(object):
|
||||
"""Container for newly generated asymmetric key pairs or those loaded from existing files"""
|
||||
|
||||
@classmethod
|
||||
def generate(cls, keytype='rsa', size=None, passphrase=None):
|
||||
def generate(cls, keytype="rsa", size=None, passphrase=None):
|
||||
"""Returns an Asymmetric_Keypair object generated with the supplied parameters
|
||||
or defaults to an unencrypted RSA-2048 key
|
||||
or defaults to an unencrypted RSA-2048 key
|
||||
|
||||
:keytype: One of rsa, dsa, ecdsa, ed25519
|
||||
:size: The key length for newly generated keys
|
||||
:passphrase: Secret of type Bytes used to encrypt the private key being generated
|
||||
:keytype: One of rsa, dsa, ecdsa, ed25519
|
||||
:size: The key length for newly generated keys
|
||||
:passphrase: Secret of type Bytes used to encrypt the private key being generated
|
||||
"""
|
||||
|
||||
if keytype not in _ALGORITHM_PARAMETERS.keys():
|
||||
raise InvalidKeyTypeError(
|
||||
"%s is not a valid keytype. Valid keytypes are %s" % (
|
||||
keytype, ", ".join(_ALGORITHM_PARAMETERS.keys())
|
||||
)
|
||||
"%s is not a valid keytype. Valid keytypes are %s"
|
||||
% (keytype, ", ".join(_ALGORITHM_PARAMETERS.keys()))
|
||||
)
|
||||
|
||||
if not size:
|
||||
size = _ALGORITHM_PARAMETERS[keytype]['default_size']
|
||||
size = _ALGORITHM_PARAMETERS[keytype]["default_size"]
|
||||
else:
|
||||
if size not in _ALGORITHM_PARAMETERS[keytype]['valid_sizes']:
|
||||
if size not in _ALGORITHM_PARAMETERS[keytype]["valid_sizes"]:
|
||||
raise InvalidKeySizeError(
|
||||
"%s is not a valid key size for %s keys" % (size, keytype)
|
||||
)
|
||||
@@ -160,7 +159,7 @@ class AsymmetricKeypair(object):
|
||||
else:
|
||||
encryption_algorithm = serialization.NoEncryption()
|
||||
|
||||
if keytype == 'rsa':
|
||||
if keytype == "rsa":
|
||||
privatekey = rsa.generate_private_key(
|
||||
# Public exponent should always be 65537 to prevent issues
|
||||
# if improper padding is used during signing
|
||||
@@ -168,16 +167,16 @@ class AsymmetricKeypair(object):
|
||||
key_size=size,
|
||||
backend=backend,
|
||||
)
|
||||
elif keytype == 'dsa':
|
||||
elif keytype == "dsa":
|
||||
privatekey = dsa.generate_private_key(
|
||||
key_size=size,
|
||||
backend=backend,
|
||||
)
|
||||
elif keytype == 'ed25519':
|
||||
elif keytype == "ed25519":
|
||||
privatekey = Ed25519PrivateKey.generate()
|
||||
elif keytype == 'ecdsa':
|
||||
elif keytype == "ecdsa":
|
||||
privatekey = ec.generate_private_key(
|
||||
_ALGORITHM_PARAMETERS['ecdsa']['curves'][size],
|
||||
_ALGORITHM_PARAMETERS["ecdsa"]["curves"][size],
|
||||
backend=backend,
|
||||
)
|
||||
|
||||
@@ -188,18 +187,25 @@ class AsymmetricKeypair(object):
|
||||
size=size,
|
||||
privatekey=privatekey,
|
||||
publickey=publickey,
|
||||
encryption_algorithm=encryption_algorithm
|
||||
encryption_algorithm=encryption_algorithm,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load(cls, path, passphrase=None, private_key_format='PEM', public_key_format='PEM', no_public_key=False):
|
||||
def load(
|
||||
cls,
|
||||
path,
|
||||
passphrase=None,
|
||||
private_key_format="PEM",
|
||||
public_key_format="PEM",
|
||||
no_public_key=False,
|
||||
):
|
||||
"""Returns an Asymmetric_Keypair object loaded from the supplied file path
|
||||
|
||||
:path: A path to an existing private key to be loaded
|
||||
:passphrase: Secret of type bytes used to decrypt the private key being loaded
|
||||
:private_key_format: Format of private key to be loaded
|
||||
:public_key_format: Format of public key to be loaded
|
||||
:no_public_key: Set 'True' to only load a private key and automatically populate the matching public key
|
||||
:path: A path to an existing private key to be loaded
|
||||
:passphrase: Secret of type bytes used to decrypt the private key being loaded
|
||||
:private_key_format: Format of private key to be loaded
|
||||
:public_key_format: Format of public key to be loaded
|
||||
:no_public_key: Set 'True' to only load a private key and automatically populate the matching public key
|
||||
"""
|
||||
|
||||
if passphrase:
|
||||
@@ -211,40 +217,42 @@ class AsymmetricKeypair(object):
|
||||
if no_public_key:
|
||||
publickey = privatekey.public_key()
|
||||
else:
|
||||
publickey = load_publickey(path + '.pub', public_key_format)
|
||||
publickey = load_publickey(path + ".pub", public_key_format)
|
||||
|
||||
# Ed25519 keys are always of size 256 and do not have a key_size attribute
|
||||
if isinstance(privatekey, Ed25519PrivateKey):
|
||||
size = _ALGORITHM_PARAMETERS['ed25519']['default_size']
|
||||
size = _ALGORITHM_PARAMETERS["ed25519"]["default_size"]
|
||||
else:
|
||||
size = privatekey.key_size
|
||||
|
||||
if isinstance(privatekey, rsa.RSAPrivateKey):
|
||||
keytype = 'rsa'
|
||||
keytype = "rsa"
|
||||
elif isinstance(privatekey, dsa.DSAPrivateKey):
|
||||
keytype = 'dsa'
|
||||
keytype = "dsa"
|
||||
elif isinstance(privatekey, ec.EllipticCurvePrivateKey):
|
||||
keytype = 'ecdsa'
|
||||
keytype = "ecdsa"
|
||||
elif isinstance(privatekey, Ed25519PrivateKey):
|
||||
keytype = 'ed25519'
|
||||
keytype = "ed25519"
|
||||
else:
|
||||
raise InvalidKeyTypeError("Key type '%s' is not supported" % type(privatekey))
|
||||
raise InvalidKeyTypeError(
|
||||
"Key type '%s' is not supported" % type(privatekey)
|
||||
)
|
||||
|
||||
return cls(
|
||||
keytype=keytype,
|
||||
size=size,
|
||||
privatekey=privatekey,
|
||||
publickey=publickey,
|
||||
encryption_algorithm=encryption_algorithm
|
||||
encryption_algorithm=encryption_algorithm,
|
||||
)
|
||||
|
||||
def __init__(self, keytype, size, privatekey, publickey, encryption_algorithm):
|
||||
"""
|
||||
:keytype: One of rsa, dsa, ecdsa, ed25519
|
||||
:size: The key length for the private key of this key pair
|
||||
:privatekey: Private key object of this key pair
|
||||
:publickey: Public key object of this key pair
|
||||
:encryption_algorithm: Hashed secret used to encrypt the private key of this key pair
|
||||
:keytype: One of rsa, dsa, ecdsa, ed25519
|
||||
:size: The key length for the private key of this key pair
|
||||
:privatekey: Private key object of this key pair
|
||||
:publickey: Public key object of this key pair
|
||||
:encryption_algorithm: Hashed secret used to encrypt the private key of this key pair
|
||||
"""
|
||||
|
||||
self.__size = size
|
||||
@@ -254,7 +262,7 @@ class AsymmetricKeypair(object):
|
||||
self.__encryption_algorithm = encryption_algorithm
|
||||
|
||||
try:
|
||||
self.verify(self.sign(b'message'), b'message')
|
||||
self.verify(self.sign(b"message"), b"message")
|
||||
except InvalidSignatureError:
|
||||
raise InvalidPublicKeyFileError(
|
||||
"The private key and public key of this keypair do not match"
|
||||
@@ -264,8 +272,11 @@ class AsymmetricKeypair(object):
|
||||
if not isinstance(other, AsymmetricKeypair):
|
||||
return NotImplemented
|
||||
|
||||
return (compare_publickeys(self.public_key, other.public_key) and
|
||||
compare_encryption_algorithms(self.encryption_algorithm, other.encryption_algorithm))
|
||||
return compare_publickeys(
|
||||
self.public_key, other.public_key
|
||||
) and compare_encryption_algorithms(
|
||||
self.encryption_algorithm, other.encryption_algorithm
|
||||
)
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self == other
|
||||
@@ -303,13 +314,12 @@ class AsymmetricKeypair(object):
|
||||
def sign(self, data):
|
||||
"""Returns signature of data signed with the private key of this key pair
|
||||
|
||||
:data: byteslike data to sign
|
||||
:data: byteslike data to sign
|
||||
"""
|
||||
|
||||
try:
|
||||
signature = self.__privatekey.sign(
|
||||
data,
|
||||
**_ALGORITHM_PARAMETERS[self.__keytype]['signer_params']
|
||||
data, **_ALGORITHM_PARAMETERS[self.__keytype]["signer_params"]
|
||||
)
|
||||
except TypeError as e:
|
||||
raise InvalidDataError(e)
|
||||
@@ -318,16 +328,16 @@ class AsymmetricKeypair(object):
|
||||
|
||||
def verify(self, signature, data):
|
||||
"""Verifies that the signature associated with the provided data was signed
|
||||
by the private key of this key pair.
|
||||
by the private key of this key pair.
|
||||
|
||||
:signature: signature to verify
|
||||
:data: byteslike data signed by the provided signature
|
||||
:signature: signature to verify
|
||||
:data: byteslike data signed by the provided signature
|
||||
"""
|
||||
try:
|
||||
return self.__publickey.verify(
|
||||
signature,
|
||||
data,
|
||||
**_ALGORITHM_PARAMETERS[self.__keytype]['signer_params']
|
||||
**_ALGORITHM_PARAMETERS[self.__keytype]["signer_params"]
|
||||
)
|
||||
except InvalidSignature:
|
||||
raise InvalidSignatureError
|
||||
@@ -335,7 +345,7 @@ class AsymmetricKeypair(object):
|
||||
def update_passphrase(self, passphrase=None):
|
||||
"""Updates the encryption algorithm of this key pair
|
||||
|
||||
:passphrase: Byte secret used to encrypt this key pair
|
||||
:passphrase: Byte secret used to encrypt this key pair
|
||||
"""
|
||||
|
||||
if passphrase:
|
||||
@@ -348,20 +358,20 @@ class OpensshKeypair(object):
|
||||
"""Container for OpenSSH encoded asymmetric key pairs"""
|
||||
|
||||
@classmethod
|
||||
def generate(cls, keytype='rsa', size=None, passphrase=None, comment=None):
|
||||
def generate(cls, keytype="rsa", size=None, passphrase=None, comment=None):
|
||||
"""Returns an Openssh_Keypair object generated using the supplied parameters or defaults to a RSA-2048 key
|
||||
|
||||
:keytype: One of rsa, dsa, ecdsa, ed25519
|
||||
:size: The key length for newly generated keys
|
||||
:passphrase: Secret of type Bytes used to encrypt the newly generated private key
|
||||
:comment: Comment for a newly generated OpenSSH public key
|
||||
:keytype: One of rsa, dsa, ecdsa, ed25519
|
||||
:size: The key length for newly generated keys
|
||||
:passphrase: Secret of type Bytes used to encrypt the newly generated private key
|
||||
:comment: Comment for a newly generated OpenSSH public key
|
||||
"""
|
||||
|
||||
if comment is None:
|
||||
comment = "%s@%s" % (getuser(), gethostname())
|
||||
|
||||
asym_keypair = AsymmetricKeypair.generate(keytype, size, passphrase)
|
||||
openssh_privatekey = cls.encode_openssh_privatekey(asym_keypair, 'SSH')
|
||||
openssh_privatekey = cls.encode_openssh_privatekey(asym_keypair, "SSH")
|
||||
openssh_publickey = cls.encode_openssh_publickey(asym_keypair, comment)
|
||||
fingerprint = calculate_fingerprint(openssh_publickey)
|
||||
|
||||
@@ -377,18 +387,20 @@ class OpensshKeypair(object):
|
||||
def load(cls, path, passphrase=None, no_public_key=False):
|
||||
"""Returns an Openssh_Keypair object loaded from the supplied file path
|
||||
|
||||
:path: A path to an existing private key to be loaded
|
||||
:passphrase: Secret used to decrypt the private key being loaded
|
||||
:no_public_key: Set 'True' to only load a private key and automatically populate the matching public key
|
||||
:path: A path to an existing private key to be loaded
|
||||
:passphrase: Secret used to decrypt the private key being loaded
|
||||
:no_public_key: Set 'True' to only load a private key and automatically populate the matching public key
|
||||
"""
|
||||
|
||||
if no_public_key:
|
||||
comment = ""
|
||||
else:
|
||||
comment = extract_comment(path + '.pub')
|
||||
comment = extract_comment(path + ".pub")
|
||||
|
||||
asym_keypair = AsymmetricKeypair.load(path, passphrase, 'SSH', 'SSH', no_public_key)
|
||||
openssh_privatekey = cls.encode_openssh_privatekey(asym_keypair, 'SSH')
|
||||
asym_keypair = AsymmetricKeypair.load(
|
||||
path, passphrase, "SSH", "SSH", no_public_key
|
||||
)
|
||||
openssh_privatekey = cls.encode_openssh_privatekey(asym_keypair, "SSH")
|
||||
openssh_publickey = cls.encode_openssh_publickey(asym_keypair, comment)
|
||||
fingerprint = calculate_fingerprint(openssh_publickey)
|
||||
|
||||
@@ -404,29 +416,33 @@ class OpensshKeypair(object):
|
||||
def encode_openssh_privatekey(asym_keypair, key_format):
|
||||
"""Returns an OpenSSH encoded private key for a given keypair
|
||||
|
||||
:asym_keypair: Asymmetric_Keypair from the private key is extracted
|
||||
:key_format: Format of the encoded private key.
|
||||
:asym_keypair: Asymmetric_Keypair from the private key is extracted
|
||||
:key_format: Format of the encoded private key.
|
||||
"""
|
||||
|
||||
if key_format == 'SSH':
|
||||
if key_format == "SSH":
|
||||
# Default to PEM format if SSH not available
|
||||
if not HAS_OPENSSH_PRIVATE_FORMAT:
|
||||
privatekey_format = serialization.PrivateFormat.PKCS8
|
||||
else:
|
||||
privatekey_format = serialization.PrivateFormat.OpenSSH
|
||||
elif key_format == 'PKCS8':
|
||||
elif key_format == "PKCS8":
|
||||
privatekey_format = serialization.PrivateFormat.PKCS8
|
||||
elif key_format == 'PKCS1':
|
||||
if asym_keypair.key_type == 'ed25519':
|
||||
raise InvalidKeyFormatError("ed25519 keys cannot be represented in PKCS1 format")
|
||||
elif key_format == "PKCS1":
|
||||
if asym_keypair.key_type == "ed25519":
|
||||
raise InvalidKeyFormatError(
|
||||
"ed25519 keys cannot be represented in PKCS1 format"
|
||||
)
|
||||
privatekey_format = serialization.PrivateFormat.TraditionalOpenSSL
|
||||
else:
|
||||
raise InvalidKeyFormatError("The accepted private key formats are SSH, PKCS8, and PKCS1")
|
||||
raise InvalidKeyFormatError(
|
||||
"The accepted private key formats are SSH, PKCS8, and PKCS1"
|
||||
)
|
||||
|
||||
encoded_privatekey = asym_keypair.private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=privatekey_format,
|
||||
encryption_algorithm=asym_keypair.encryption_algorithm
|
||||
encryption_algorithm=asym_keypair.encryption_algorithm,
|
||||
)
|
||||
|
||||
return encoded_privatekey
|
||||
@@ -435,8 +451,8 @@ class OpensshKeypair(object):
|
||||
def encode_openssh_publickey(asym_keypair, comment):
|
||||
"""Returns an OpenSSH encoded public key for a given keypair
|
||||
|
||||
:asym_keypair: Asymmetric_Keypair from the public key is extracted
|
||||
:comment: Comment to apply to the end of the returned OpenSSH encoded public key
|
||||
:asym_keypair: Asymmetric_Keypair from the public key is extracted
|
||||
:comment: Comment to apply to the end of the returned OpenSSH encoded public key
|
||||
"""
|
||||
encoded_publickey = asym_keypair.public_key.public_bytes(
|
||||
encoding=serialization.Encoding.OpenSSH,
|
||||
@@ -445,17 +461,21 @@ class OpensshKeypair(object):
|
||||
|
||||
validate_comment(comment)
|
||||
|
||||
encoded_publickey += (" %s" % comment).encode(encoding=_TEXT_ENCODING) if comment else b''
|
||||
encoded_publickey += (
|
||||
(" %s" % comment).encode(encoding=_TEXT_ENCODING) if comment else b""
|
||||
)
|
||||
|
||||
return encoded_publickey
|
||||
|
||||
def __init__(self, asym_keypair, openssh_privatekey, openssh_publickey, fingerprint, comment):
|
||||
def __init__(
|
||||
self, asym_keypair, openssh_privatekey, openssh_publickey, fingerprint, comment
|
||||
):
|
||||
"""
|
||||
:asym_keypair: An Asymmetric_Keypair object from which the OpenSSH encoded keypair is derived
|
||||
:openssh_privatekey: An OpenSSH encoded private key
|
||||
:openssh_privatekey: An OpenSSH encoded public key
|
||||
:fingerprint: The fingerprint of the OpenSSH encoded public key of this keypair
|
||||
:comment: Comment applied to the OpenSSH public key of this keypair
|
||||
:asym_keypair: An Asymmetric_Keypair object from which the OpenSSH encoded keypair is derived
|
||||
:openssh_privatekey: An OpenSSH encoded private key
|
||||
:openssh_privatekey: An OpenSSH encoded public key
|
||||
:fingerprint: The fingerprint of the OpenSSH encoded public key of this keypair
|
||||
:comment: Comment applied to the OpenSSH public key of this keypair
|
||||
"""
|
||||
|
||||
self.__asym_keypair = asym_keypair
|
||||
@@ -468,7 +488,10 @@ class OpensshKeypair(object):
|
||||
if not isinstance(other, OpensshKeypair):
|
||||
return NotImplemented
|
||||
|
||||
return self.asymmetric_keypair == other.asymmetric_keypair and self.comment == other.comment
|
||||
return (
|
||||
self.asymmetric_keypair == other.asymmetric_keypair
|
||||
and self.comment == other.comment
|
||||
)
|
||||
|
||||
@property
|
||||
def asymmetric_keypair(self):
|
||||
@@ -516,53 +539,59 @@ class OpensshKeypair(object):
|
||||
def comment(self, comment):
|
||||
"""Updates the comment applied to the OpenSSH formatted public key of this key pair
|
||||
|
||||
:comment: Text to update the OpenSSH public key comment
|
||||
:comment: Text to update the OpenSSH public key comment
|
||||
"""
|
||||
|
||||
validate_comment(comment)
|
||||
|
||||
self.__comment = comment
|
||||
encoded_comment = (" %s" % self.__comment).encode(encoding=_TEXT_ENCODING) if self.__comment else b''
|
||||
self.__openssh_publickey = b' '.join(self.__openssh_publickey.split(b' ', 2)[:2]) + encoded_comment
|
||||
encoded_comment = (
|
||||
(" %s" % self.__comment).encode(encoding=_TEXT_ENCODING)
|
||||
if self.__comment
|
||||
else b""
|
||||
)
|
||||
self.__openssh_publickey = (
|
||||
b" ".join(self.__openssh_publickey.split(b" ", 2)[:2]) + encoded_comment
|
||||
)
|
||||
return self.__openssh_publickey
|
||||
|
||||
def update_passphrase(self, passphrase):
|
||||
"""Updates the passphrase used to encrypt the private key of this keypair
|
||||
|
||||
:passphrase: Text secret used for encryption
|
||||
:passphrase: Text secret used for encryption
|
||||
"""
|
||||
|
||||
self.__asym_keypair.update_passphrase(passphrase)
|
||||
self.__openssh_privatekey = OpensshKeypair.encode_openssh_privatekey(self.__asym_keypair, 'SSH')
|
||||
self.__openssh_privatekey = OpensshKeypair.encode_openssh_privatekey(
|
||||
self.__asym_keypair, "SSH"
|
||||
)
|
||||
|
||||
|
||||
def load_privatekey(path, passphrase, key_format):
|
||||
privatekey_loaders = {
|
||||
'PEM': serialization.load_pem_private_key,
|
||||
'DER': serialization.load_der_private_key,
|
||||
"PEM": serialization.load_pem_private_key,
|
||||
"DER": serialization.load_der_private_key,
|
||||
}
|
||||
|
||||
# OpenSSH formatted private keys are not available in Cryptography <3.0
|
||||
if hasattr(serialization, 'load_ssh_private_key'):
|
||||
privatekey_loaders['SSH'] = serialization.load_ssh_private_key
|
||||
if hasattr(serialization, "load_ssh_private_key"):
|
||||
privatekey_loaders["SSH"] = serialization.load_ssh_private_key
|
||||
else:
|
||||
privatekey_loaders['SSH'] = serialization.load_pem_private_key
|
||||
privatekey_loaders["SSH"] = serialization.load_pem_private_key
|
||||
|
||||
try:
|
||||
privatekey_loader = privatekey_loaders[key_format]
|
||||
except KeyError:
|
||||
raise InvalidKeyFormatError(
|
||||
"%s is not a valid key format (%s)" % (
|
||||
key_format,
|
||||
','.join(privatekey_loaders.keys())
|
||||
)
|
||||
"%s is not a valid key format (%s)"
|
||||
% (key_format, ",".join(privatekey_loaders.keys()))
|
||||
)
|
||||
|
||||
if not os.path.exists(path):
|
||||
raise InvalidPrivateKeyFileError("No file was found at %s" % path)
|
||||
|
||||
try:
|
||||
with open(path, 'rb') as f:
|
||||
with open(path, "rb") as f:
|
||||
content = f.read()
|
||||
|
||||
privatekey = privatekey_loader(
|
||||
@@ -573,9 +602,9 @@ def load_privatekey(path, passphrase, key_format):
|
||||
|
||||
except ValueError as e:
|
||||
# Revert to PEM if key could not be loaded in SSH format
|
||||
if key_format == 'SSH':
|
||||
if key_format == "SSH":
|
||||
try:
|
||||
privatekey = privatekey_loaders['PEM'](
|
||||
privatekey = privatekey_loaders["PEM"](
|
||||
data=content,
|
||||
password=passphrase,
|
||||
backend=backend,
|
||||
@@ -598,26 +627,24 @@ def load_privatekey(path, passphrase, key_format):
|
||||
|
||||
def load_publickey(path, key_format):
|
||||
publickey_loaders = {
|
||||
'PEM': serialization.load_pem_public_key,
|
||||
'DER': serialization.load_der_public_key,
|
||||
'SSH': serialization.load_ssh_public_key,
|
||||
"PEM": serialization.load_pem_public_key,
|
||||
"DER": serialization.load_der_public_key,
|
||||
"SSH": serialization.load_ssh_public_key,
|
||||
}
|
||||
|
||||
try:
|
||||
publickey_loader = publickey_loaders[key_format]
|
||||
except KeyError:
|
||||
raise InvalidKeyFormatError(
|
||||
"%s is not a valid key format (%s)" % (
|
||||
key_format,
|
||||
','.join(publickey_loaders.keys())
|
||||
)
|
||||
"%s is not a valid key format (%s)"
|
||||
% (key_format, ",".join(publickey_loaders.keys()))
|
||||
)
|
||||
|
||||
if not os.path.exists(path):
|
||||
raise InvalidPublicKeyFileError("No file was found at %s" % path)
|
||||
|
||||
try:
|
||||
with open(path, 'rb') as f:
|
||||
with open(path, "rb") as f:
|
||||
content = f.read()
|
||||
|
||||
publickey = publickey_loader(
|
||||
@@ -646,10 +673,13 @@ def compare_publickeys(pk1, pk2):
|
||||
|
||||
|
||||
def compare_encryption_algorithms(ea1, ea2):
|
||||
if isinstance(ea1, serialization.NoEncryption) and isinstance(ea2, serialization.NoEncryption):
|
||||
if isinstance(ea1, serialization.NoEncryption) and isinstance(
|
||||
ea2, serialization.NoEncryption
|
||||
):
|
||||
return True
|
||||
elif (isinstance(ea1, serialization.BestAvailableEncryption) and
|
||||
isinstance(ea2, serialization.BestAvailableEncryption)):
|
||||
elif isinstance(ea1, serialization.BestAvailableEncryption) and isinstance(
|
||||
ea2, serialization.BestAvailableEncryption
|
||||
):
|
||||
return ea1.password == ea2.password
|
||||
else:
|
||||
return False
|
||||
@@ -663,7 +693,7 @@ def get_encryption_algorithm(passphrase):
|
||||
|
||||
|
||||
def validate_comment(comment):
|
||||
if not hasattr(comment, 'encode'):
|
||||
if not hasattr(comment, "encode"):
|
||||
raise InvalidCommentError("%s cannot be encoded to text" % comment)
|
||||
|
||||
|
||||
@@ -673,8 +703,8 @@ def extract_comment(path):
|
||||
raise InvalidPublicKeyFileError("No file was found at %s" % path)
|
||||
|
||||
try:
|
||||
with open(path, 'rb') as f:
|
||||
fields = f.read().split(b' ', 2)
|
||||
with open(path, "rb") as f:
|
||||
fields = f.read().split(b" ", 2)
|
||||
if len(fields) == 3:
|
||||
comment = fields[2].decode(_TEXT_ENCODING)
|
||||
else:
|
||||
@@ -687,7 +717,9 @@ def extract_comment(path):
|
||||
|
||||
def calculate_fingerprint(openssh_publickey):
|
||||
digest = hashes.Hash(hashes.SHA256(), backend=backend)
|
||||
decoded_pubkey = b64decode(openssh_publickey.split(b' ')[1])
|
||||
decoded_pubkey = b64decode(openssh_publickey.split(b" ")[1])
|
||||
digest.update(decoded_pubkey)
|
||||
|
||||
return 'SHA256:%s' % b64encode(digest.finalize()).decode(encoding=_TEXT_ENCODING).rstrip('=')
|
||||
return "SHA256:%s" % b64encode(digest.finalize()).decode(
|
||||
encoding=_TEXT_ENCODING
|
||||
).rstrip("=")
|
||||
|
||||
@@ -34,17 +34,17 @@ if PY3:
|
||||
long = int
|
||||
|
||||
# 0 (False) or 1 (True) encoded as a single byte
|
||||
_BOOLEAN = Struct(b'?')
|
||||
_BOOLEAN = Struct(b"?")
|
||||
# Unsigned 8-bit integer in network-byte-order
|
||||
_UBYTE = Struct(b'!B')
|
||||
_UBYTE = Struct(b"!B")
|
||||
_UBYTE_MAX = 0xFF
|
||||
# Unsigned 32-bit integer in network-byte-order
|
||||
_UINT32 = Struct(b'!I')
|
||||
_UINT32 = Struct(b"!I")
|
||||
# Unsigned 32-bit little endian integer
|
||||
_UINT32_LE = Struct(b'<I')
|
||||
_UINT32_LE = Struct(b"<I")
|
||||
_UINT32_MAX = 0xFFFFFFFF
|
||||
# Unsigned 64-bit integer in network-byte-order
|
||||
_UINT64 = Struct(b'!Q')
|
||||
_UINT64 = Struct(b"!Q")
|
||||
_UINT64_MAX = 0xFFFFFFFFFFFFFFFF
|
||||
|
||||
|
||||
@@ -89,6 +89,7 @@ def secure_write(path, mode, content):
|
||||
# See https://datatracker.ietf.org/doc/html/rfc4251#section-5 for SSH data types
|
||||
class OpensshParser(object):
|
||||
"""Parser for OpenSSH encoded objects"""
|
||||
|
||||
BOOLEAN_OFFSET = 1
|
||||
UINT32_OFFSET = 4
|
||||
UINT64_OFFSET = 8
|
||||
@@ -103,21 +104,21 @@ class OpensshParser(object):
|
||||
def boolean(self):
|
||||
next_pos = self._check_position(self.BOOLEAN_OFFSET)
|
||||
|
||||
value = _BOOLEAN.unpack(self._data[self._pos:next_pos])[0]
|
||||
value = _BOOLEAN.unpack(self._data[self._pos : next_pos])[0]
|
||||
self._pos = next_pos
|
||||
return value
|
||||
|
||||
def uint32(self):
|
||||
next_pos = self._check_position(self.UINT32_OFFSET)
|
||||
|
||||
value = _UINT32.unpack(self._data[self._pos:next_pos])[0]
|
||||
value = _UINT32.unpack(self._data[self._pos : next_pos])[0]
|
||||
self._pos = next_pos
|
||||
return value
|
||||
|
||||
def uint64(self):
|
||||
next_pos = self._check_position(self.UINT64_OFFSET)
|
||||
|
||||
value = _UINT64.unpack(self._data[self._pos:next_pos])[0]
|
||||
value = _UINT64.unpack(self._data[self._pos : next_pos])[0]
|
||||
self._pos = next_pos
|
||||
return value
|
||||
|
||||
@@ -126,7 +127,7 @@ class OpensshParser(object):
|
||||
|
||||
next_pos = self._check_position(length)
|
||||
|
||||
value = self._data[self._pos:next_pos]
|
||||
value = self._data[self._pos : next_pos]
|
||||
self._pos = next_pos
|
||||
# Cast to bytes is required as a memoryview slice is itself a memoryview
|
||||
return value if not PY3 else bytes(value)
|
||||
@@ -136,7 +137,7 @@ class OpensshParser(object):
|
||||
|
||||
def name_list(self):
|
||||
raw_string = self.string()
|
||||
return raw_string.decode('ASCII').split(',')
|
||||
return raw_string.decode("ASCII").split(",")
|
||||
|
||||
# Convenience function, but not an official data type from SSH
|
||||
def string_list(self):
|
||||
@@ -193,33 +194,39 @@ class OpensshParser(object):
|
||||
signature_blob = parser.string()
|
||||
|
||||
blob_parser = cls(signature_blob)
|
||||
if signature_type in (b'ssh-rsa', b'rsa-sha2-256', b'rsa-sha2-512'):
|
||||
if signature_type in (b"ssh-rsa", b"rsa-sha2-256", b"rsa-sha2-512"):
|
||||
# https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
|
||||
# https://datatracker.ietf.org/doc/html/rfc8332#section-3
|
||||
signature_data['s'] = cls._big_int(signature_blob, "big")
|
||||
elif signature_type == b'ssh-dss':
|
||||
signature_data["s"] = cls._big_int(signature_blob, "big")
|
||||
elif signature_type == b"ssh-dss":
|
||||
# https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
|
||||
signature_data['r'] = cls._big_int(signature_blob[:20], "big")
|
||||
signature_data['s'] = cls._big_int(signature_blob[20:], "big")
|
||||
elif signature_type in (b'ecdsa-sha2-nistp256', b'ecdsa-sha2-nistp384', b'ecdsa-sha2-nistp521'):
|
||||
signature_data["r"] = cls._big_int(signature_blob[:20], "big")
|
||||
signature_data["s"] = cls._big_int(signature_blob[20:], "big")
|
||||
elif signature_type in (
|
||||
b"ecdsa-sha2-nistp256",
|
||||
b"ecdsa-sha2-nistp384",
|
||||
b"ecdsa-sha2-nistp521",
|
||||
):
|
||||
# https://datatracker.ietf.org/doc/html/rfc5656#section-3.1.2
|
||||
signature_data['r'] = blob_parser.mpint()
|
||||
signature_data['s'] = blob_parser.mpint()
|
||||
elif signature_type == b'ssh-ed25519':
|
||||
signature_data["r"] = blob_parser.mpint()
|
||||
signature_data["s"] = blob_parser.mpint()
|
||||
elif signature_type == b"ssh-ed25519":
|
||||
# https://datatracker.ietf.org/doc/html/rfc8032#section-5.1.2
|
||||
signature_data['R'] = cls._big_int(signature_blob[:32], "little")
|
||||
signature_data['S'] = cls._big_int(signature_blob[32:], "little")
|
||||
signature_data["R"] = cls._big_int(signature_blob[:32], "little")
|
||||
signature_data["S"] = cls._big_int(signature_blob[32:], "little")
|
||||
else:
|
||||
raise ValueError("%s is not a valid signature type" % signature_type)
|
||||
|
||||
signature_data['signature_type'] = signature_type
|
||||
signature_data["signature_type"] = signature_type
|
||||
|
||||
return signature_data
|
||||
|
||||
@classmethod
|
||||
def _big_int(cls, raw_string, byte_order, signed=False):
|
||||
if byte_order not in ("big", "little"):
|
||||
raise ValueError("Byte_order must be one of (big, little) not %s" % byte_order)
|
||||
raise ValueError(
|
||||
"Byte_order must be one of (big, little) not %s" % byte_order
|
||||
)
|
||||
|
||||
if PY3:
|
||||
return int.from_bytes(raw_string, byte_order, signed=signed)
|
||||
@@ -232,21 +239,31 @@ class OpensshParser(object):
|
||||
msb = raw_string[0] if byte_order == "big" else raw_string[-1]
|
||||
negative = bool(ord(msb) & 0x80)
|
||||
# Match pad value for two's complement
|
||||
pad = b'\xFF' if signed and negative else b'\x00'
|
||||
pad = b"\xff" if signed and negative else b"\x00"
|
||||
# The definition of ``mpint`` enforces that unnecessary bytes are not encoded so they are added back
|
||||
pad_length = (4 - byte_length % 4)
|
||||
pad_length = 4 - byte_length % 4
|
||||
if pad_length < 4:
|
||||
raw_string = pad * pad_length + raw_string if byte_order == "big" else raw_string + pad * pad_length
|
||||
raw_string = (
|
||||
pad * pad_length + raw_string
|
||||
if byte_order == "big"
|
||||
else raw_string + pad * pad_length
|
||||
)
|
||||
byte_length += pad_length
|
||||
# Accumulate arbitrary precision integer bytes in the appropriate order
|
||||
if byte_order == "big":
|
||||
for i in range(0, byte_length, cls.UINT32_OFFSET):
|
||||
left_shift = result << cls.UINT32_OFFSET * 8
|
||||
result = left_shift + _UINT32.unpack(raw_string[i:i + cls.UINT32_OFFSET])[0]
|
||||
result = (
|
||||
left_shift
|
||||
+ _UINT32.unpack(raw_string[i : i + cls.UINT32_OFFSET])[0]
|
||||
)
|
||||
else:
|
||||
for i in range(byte_length, 0, -cls.UINT32_OFFSET):
|
||||
left_shift = result << cls.UINT32_OFFSET * 8
|
||||
result = left_shift + _UINT32_LE.unpack(raw_string[i - cls.UINT32_OFFSET:i])[0]
|
||||
result = (
|
||||
left_shift
|
||||
+ _UINT32_LE.unpack(raw_string[i - cls.UINT32_OFFSET : i])[0]
|
||||
)
|
||||
# Adjust for two's complement
|
||||
if signed and negative:
|
||||
result -= 1 << (8 * byte_length)
|
||||
@@ -262,10 +279,13 @@ class _OpensshWriter(object):
|
||||
It is not to be used to construct Openssh objects, but rather as a utility to assist
|
||||
in validating parsed material.
|
||||
"""
|
||||
|
||||
def __init__(self, buffer=None):
|
||||
if buffer is not None:
|
||||
if not isinstance(buffer, (bytes, bytearray)):
|
||||
raise TypeError("Buffer must be a bytes-like object not %s" % type(buffer))
|
||||
raise TypeError(
|
||||
"Buffer must be a bytes-like object not %s" % type(buffer)
|
||||
)
|
||||
else:
|
||||
buffer = bytearray()
|
||||
|
||||
@@ -283,7 +303,9 @@ class _OpensshWriter(object):
|
||||
if not isinstance(value, int):
|
||||
raise TypeError("Value must be of type int not %s" % type(value))
|
||||
if value < 0 or value > _UINT32_MAX:
|
||||
raise ValueError("Value must be a positive integer less than %s" % _UINT32_MAX)
|
||||
raise ValueError(
|
||||
"Value must be a positive integer less than %s" % _UINT32_MAX
|
||||
)
|
||||
|
||||
self._buff.extend(_UINT32.pack(value))
|
||||
|
||||
@@ -293,7 +315,9 @@ class _OpensshWriter(object):
|
||||
if not isinstance(value, (long, int)):
|
||||
raise TypeError("Value must be of type (long, int) not %s" % type(value))
|
||||
if value < 0 or value > _UINT64_MAX:
|
||||
raise ValueError("Value must be a positive integer less than %s" % _UINT64_MAX)
|
||||
raise ValueError(
|
||||
"Value must be a positive integer less than %s" % _UINT64_MAX
|
||||
)
|
||||
|
||||
self._buff.extend(_UINT64.pack(value))
|
||||
|
||||
@@ -320,7 +344,7 @@ class _OpensshWriter(object):
|
||||
raise TypeError("Value must be a list of byte strings not %s" % type(value))
|
||||
|
||||
try:
|
||||
self.string(','.join(value).encode('ASCII'))
|
||||
self.string(",".join(value).encode("ASCII"))
|
||||
except UnicodeEncodeError as e:
|
||||
raise ValueError("Name-list's must consist of US-ASCII characters: %s" % e)
|
||||
|
||||
@@ -365,9 +389,9 @@ class _OpensshWriter(object):
|
||||
result = bytes()
|
||||
# 0 and -1 are treated as special cases since they are used as sentinels for all other values
|
||||
if num == 0:
|
||||
result += b'\x00'
|
||||
result += b"\x00"
|
||||
elif num == -1:
|
||||
result += b'\xFF'
|
||||
result += b"\xff"
|
||||
elif num > 0:
|
||||
while num >> 32:
|
||||
result = _UINT32.pack(num & _UINT32_MAX) + result
|
||||
@@ -378,7 +402,7 @@ class _OpensshWriter(object):
|
||||
num = num >> 8
|
||||
# Zero pad final byte if most-significant bit is 1 as per mpint definition
|
||||
if ord(result[0]) & 0x80:
|
||||
result = b'\x00' + result
|
||||
result = b"\x00" + result
|
||||
else:
|
||||
while (num >> 32) < -1:
|
||||
result = _UINT32.pack(num & _UINT32_MAX) + result
|
||||
@@ -387,7 +411,7 @@ class _OpensshWriter(object):
|
||||
result = _UBYTE.pack(num & _UBYTE_MAX) + result
|
||||
num = num >> 8
|
||||
if not ord(result[0]) & 0x80:
|
||||
result = b'\xFF' + result
|
||||
result = b"\xff" + result
|
||||
|
||||
return result
|
||||
|
||||
|
||||
Reference in New Issue
Block a user