# -*- coding: utf-8 -*- # Copyright (c) Ansible Project # GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt) # SPDX-License-Identifier: GPL-3.0-or-later from __future__ import absolute_import, division, print_function __metaclass__ = type import os.path from getpass import getuser from os import remove, rmdir from socket import gethostname from tempfile import mkdtemp import pytest from ansible_collections.community.crypto.plugins.module_utils.openssh.cryptography import ( HAS_OPENSSH_SUPPORT, InvalidCommentError, InvalidKeySizeError, InvalidKeyTypeError, InvalidPassphraseError, InvalidPrivateKeyFileError, InvalidPublicKeyFileError, OpensshKeypair, ) DEFAULT_KEY_PARAMS = [ ( "rsa", None, None, None, ), ( "dsa", None, None, None, ), ( "ecdsa", None, None, None, ), ( "ed25519", None, None, None, ), ] VALID_USER_KEY_PARAMS = [ ( "rsa", 8192, "change_me".encode("UTF-8"), "comment", ), ( "dsa", 1024, "change_me".encode("UTF-8"), "comment", ), ( "ecdsa", 521, "change_me".encode("UTF-8"), "comment", ), ( "ed25519", 256, "change_me".encode("UTF-8"), "comment", ), ] INVALID_USER_KEY_PARAMS = [ ( "dne", None, None, None, ), ( "rsa", None, [1, 2, 3], "comment", ), ( "ecdsa", None, None, [1, 2, 3], ), ] INVALID_KEY_SIZES = [ ( "rsa", 1023, None, None, ), ( "rsa", 16385, None, None, ), ( "dsa", 256, None, None, ), ( "ecdsa", 1024, None, None, ), ( "ed25519", 1024, None, None, ), ] @pytest.mark.parametrize("keytype,size,passphrase,comment", DEFAULT_KEY_PARAMS) @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") def test_default_key_params(keytype, size, passphrase, comment): result = True default_sizes = { "rsa": 2048, "dsa": 1024, "ecdsa": 256, "ed25519": 256, } default_comment = "%s@%s" % (getuser(), gethostname()) pair = OpensshKeypair.generate( keytype=keytype, size=size, passphrase=passphrase, comment=comment ) try: pair = OpensshKeypair.generate( keytype=keytype, size=size, passphrase=passphrase, comment=comment ) if pair.size != default_sizes[pair.key_type] or pair.comment != default_comment: result = False except Exception as e: print(e) result = False assert result @pytest.mark.parametrize("keytype,size,passphrase,comment", VALID_USER_KEY_PARAMS) @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") def test_valid_user_key_params(keytype, size, passphrase, comment): result = True try: pair = OpensshKeypair.generate( keytype=keytype, size=size, passphrase=passphrase, comment=comment ) if pair.key_type != keytype or pair.size != size or pair.comment != comment: result = False except Exception as e: print(e) result = False assert result @pytest.mark.parametrize("keytype,size,passphrase,comment", INVALID_USER_KEY_PARAMS) @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") def test_invalid_user_key_params(keytype, size, passphrase, comment): result = False try: OpensshKeypair.generate( keytype=keytype, size=size, passphrase=passphrase, comment=comment ) except (InvalidCommentError, InvalidKeyTypeError, InvalidPassphraseError): result = True except Exception as e: print(e) pass assert result @pytest.mark.parametrize("keytype,size,passphrase,comment", INVALID_KEY_SIZES) @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") def test_invalid_key_sizes(keytype, size, passphrase, comment): result = False try: OpensshKeypair.generate( keytype=keytype, size=size, passphrase=passphrase, comment=comment ) except InvalidKeySizeError: result = True except Exception as e: print(e) pass assert result @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") def test_valid_comment_update(): pair = OpensshKeypair.generate() new_comment = "comment" try: pair.comment = new_comment except Exception as e: print(e) pass assert ( pair.comment == new_comment and pair.public_key.split(b" ", 2)[2].decode() == new_comment ) @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") def test_invalid_comment_update(): result = False pair = OpensshKeypair.generate() new_comment = [1, 2, 3] try: pair.comment = new_comment except InvalidCommentError: result = True assert result @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") def test_valid_passphrase_update(): result = False passphrase = "change_me".encode("UTF-8") try: tmpdir = mkdtemp() keyfilename = os.path.join(tmpdir, "id_rsa") pair1 = OpensshKeypair.generate() pair1.update_passphrase(passphrase) with open(keyfilename, "w+b") as keyfile: keyfile.write(pair1.private_key) with open(keyfilename + ".pub", "w+b") as pubkeyfile: pubkeyfile.write(pair1.public_key) pair2 = OpensshKeypair.load(path=keyfilename, passphrase=passphrase) if pair1 == pair2: result = True finally: if os.path.exists(keyfilename): remove(keyfilename) if os.path.exists(keyfilename + ".pub"): remove(keyfilename + ".pub") if os.path.exists(tmpdir): rmdir(tmpdir) assert result @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") def test_invalid_passphrase_update(): result = False passphrase = [1, 2, 3] pair = OpensshKeypair.generate() try: pair.update_passphrase(passphrase) except InvalidPassphraseError: result = True assert result @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") def test_invalid_privatekey(): result = False try: tmpdir = mkdtemp() keyfilename = os.path.join(tmpdir, "id_rsa") pair = OpensshKeypair.generate() with open(keyfilename, "w+b") as keyfile: keyfile.write(pair.private_key[1:]) with open(keyfilename + ".pub", "w+b") as pubkeyfile: pubkeyfile.write(pair.public_key) OpensshKeypair.load(path=keyfilename) except InvalidPrivateKeyFileError: result = True finally: if os.path.exists(keyfilename): remove(keyfilename) if os.path.exists(keyfilename + ".pub"): remove(keyfilename + ".pub") if os.path.exists(tmpdir): rmdir(tmpdir) assert result @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") def test_mismatched_keypair(): result = False try: tmpdir = mkdtemp() keyfilename = os.path.join(tmpdir, "id_rsa") pair1 = OpensshKeypair.generate() pair2 = OpensshKeypair.generate() with open(keyfilename, "w+b") as keyfile: keyfile.write(pair1.private_key) with open(keyfilename + ".pub", "w+b") as pubkeyfile: pubkeyfile.write(pair2.public_key) OpensshKeypair.load(path=keyfilename) except InvalidPublicKeyFileError: result = True finally: if os.path.exists(keyfilename): remove(keyfilename) if os.path.exists(keyfilename + ".pub"): remove(keyfilename + ".pub") if os.path.exists(tmpdir): rmdir(tmpdir) assert result @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") def test_keypair_comparison(): assert OpensshKeypair.generate() != OpensshKeypair.generate() assert OpensshKeypair.generate() != OpensshKeypair.generate(keytype="dsa") assert OpensshKeypair.generate() != OpensshKeypair.generate(keytype="ed25519") assert OpensshKeypair.generate(keytype="ed25519") != OpensshKeypair.generate( keytype="ed25519" ) try: tmpdir = mkdtemp() keys = { "rsa": { "pair": OpensshKeypair.generate(), "filename": os.path.join(tmpdir, "id_rsa"), }, "dsa": { "pair": OpensshKeypair.generate( keytype="dsa", passphrase="change_me".encode("UTF-8") ), "filename": os.path.join(tmpdir, "id_dsa"), }, "ed25519": { "pair": OpensshKeypair.generate(keytype="ed25519"), "filename": os.path.join(tmpdir, "id_ed25519"), }, } for v in keys.values(): with open(v["filename"], "w+b") as keyfile: keyfile.write(v["pair"].private_key) with open(v["filename"] + ".pub", "w+b") as pubkeyfile: pubkeyfile.write(v["pair"].public_key) assert keys["rsa"]["pair"] == OpensshKeypair.load(path=keys["rsa"]["filename"]) loaded_dsa_key = OpensshKeypair.load( path=keys["dsa"]["filename"], passphrase="change_me".encode("UTF-8") ) assert keys["dsa"]["pair"] == loaded_dsa_key loaded_dsa_key.update_passphrase("change_me_again".encode("UTF-8")) assert keys["dsa"]["pair"] != loaded_dsa_key loaded_dsa_key.update_passphrase("change_me".encode("UTF-8")) assert keys["dsa"]["pair"] == loaded_dsa_key loaded_dsa_key.comment = "comment" assert keys["dsa"]["pair"] != loaded_dsa_key assert keys["ed25519"]["pair"] == OpensshKeypair.load( path=keys["ed25519"]["filename"] ) finally: for v in keys.values(): if os.path.exists(v["filename"]): remove(v["filename"]) if os.path.exists(v["filename"] + ".pub"): remove(v["filename"] + ".pub") if os.path.exists(tmpdir): rmdir(tmpdir) assert OpensshKeypair.generate() != []