# Copyright (c) Ansible Project # GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt) # SPDX-License-Identifier: GPL-3.0-or-later from __future__ import annotations import os.path import typing as t from getpass import getuser from os import remove, rmdir from socket import gethostname from tempfile import mkdtemp import pytest from ansible_collections.community.crypto.plugins.module_utils._openssh.cryptography import ( HAS_OPENSSH_SUPPORT, InvalidCommentError, InvalidKeySizeError, InvalidKeyTypeError, InvalidPassphraseError, InvalidPrivateKeyFileError, InvalidPublicKeyFileError, OpensshKeypair, ) if t.TYPE_CHECKING: from ansible_collections.community.crypto.plugins.module_utils._openssh.cryptography import ( KeyType, ) DEFAULT_KEY_PARAMS: list[tuple[KeyType, int | None, bytes | None, str | None]] = [ ( "rsa", None, None, None, ), ( "dsa", None, None, None, ), ( "ecdsa", None, None, None, ), ( "ed25519", None, None, None, ), ] VALID_USER_KEY_PARAMS: list[tuple[KeyType, int | None, bytes | None, str | None]] = [ ( "rsa", 8192, "change_me".encode("UTF-8"), "comment", ), ( "dsa", 1024, "change_me".encode("UTF-8"), "comment", ), ( "ecdsa", 521, "change_me".encode("UTF-8"), "comment", ), ( "ed25519", 256, "change_me".encode("UTF-8"), "comment", ), ] INVALID_USER_KEY_PARAMS: list[tuple[KeyType, int | None, bytes | None, str | None]] = [ ( "dne", # type: ignore None, None, None, ), ( "rsa", None, [1, 2, 3], # type: ignore "comment", ), ( "ecdsa", None, None, [1, 2, 3], # type: ignore ), ] INVALID_KEY_SIZES: list[tuple[KeyType, int | None, bytes | None, str | None]] = [ ( "rsa", 1023, None, None, ), ( "rsa", 16385, None, None, ), ( "dsa", 256, None, None, ), ( "ecdsa", 1024, None, None, ), ( "ed25519", 1024, None, None, ), ] @pytest.mark.parametrize("keytype,size,passphrase,comment", DEFAULT_KEY_PARAMS) @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") def test_default_key_params( keytype: KeyType, size: int | None, passphrase: bytes | None, comment: str | None ) -> None: result = True default_sizes = { "rsa": 2048, "dsa": 1024, "ecdsa": 256, "ed25519": 256, } default_comment = f"{getuser()}@{gethostname()}" pair = OpensshKeypair.generate( keytype=keytype, size=size, passphrase=passphrase, comment=comment ) try: pair = OpensshKeypair.generate( keytype=keytype, size=size, passphrase=passphrase, comment=comment ) if pair.size != default_sizes[pair.key_type] or pair.comment != default_comment: result = False except Exception as e: print(e) result = False assert result @pytest.mark.parametrize("keytype,size,passphrase,comment", VALID_USER_KEY_PARAMS) @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") def test_valid_user_key_params( keytype: KeyType, size: int | None, passphrase: bytes | None, comment: str | None ) -> None: result = True try: pair = OpensshKeypair.generate( keytype=keytype, size=size, passphrase=passphrase, comment=comment ) if pair.key_type != keytype or pair.size != size or pair.comment != comment: result = False except Exception as e: print(e) result = False assert result @pytest.mark.parametrize("keytype,size,passphrase,comment", INVALID_USER_KEY_PARAMS) @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") def test_invalid_user_key_params( keytype: KeyType, size: int | None, passphrase: bytes | None, comment: str | None ) -> None: result = False try: OpensshKeypair.generate( keytype=keytype, size=size, passphrase=passphrase, comment=comment ) except (InvalidCommentError, InvalidKeyTypeError, InvalidPassphraseError): result = True except Exception as e: print(e) assert result @pytest.mark.parametrize("keytype,size,passphrase,comment", INVALID_KEY_SIZES) @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") def test_invalid_key_sizes( keytype: KeyType, size: int | None, passphrase: bytes | None, comment: str | None ) -> None: result = False try: OpensshKeypair.generate( keytype=keytype, size=size, passphrase=passphrase, comment=comment ) except InvalidKeySizeError: result = True except Exception as e: print(e) assert result @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") def test_valid_comment_update() -> None: pair = OpensshKeypair.generate() new_comment = "comment" try: pair.comment = new_comment except Exception as e: print(e) assert ( pair.comment == new_comment and pair.public_key.split(b" ", 2)[2].decode() == new_comment ) @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") def test_invalid_comment_update() -> None: result = False pair = OpensshKeypair.generate() new_comment = [1, 2, 3] try: pair.comment = new_comment # type: ignore except InvalidCommentError: result = True assert result @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") def test_valid_passphrase_update() -> None: result = False passphrase = "change_me".encode("UTF-8") try: tmpdir = mkdtemp() keyfilename = os.path.join(tmpdir, "id_rsa") pair1 = OpensshKeypair.generate() pair1.update_passphrase(passphrase) with open(keyfilename, "w+b") as keyfile: keyfile.write(pair1.private_key) with open(keyfilename + ".pub", "w+b") as pubkeyfile: pubkeyfile.write(pair1.public_key) pair2 = OpensshKeypair.load(path=keyfilename, passphrase=passphrase) if pair1 == pair2: result = True finally: if os.path.exists(keyfilename): remove(keyfilename) if os.path.exists(keyfilename + ".pub"): remove(keyfilename + ".pub") if os.path.exists(tmpdir): rmdir(tmpdir) assert result @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") def test_invalid_passphrase_update() -> None: result = False passphrase = [1, 2, 3] pair = OpensshKeypair.generate() try: pair.update_passphrase(passphrase) # type: ignore except InvalidPassphraseError: result = True assert result @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") def test_invalid_privatekey() -> None: result = False try: tmpdir = mkdtemp() keyfilename = os.path.join(tmpdir, "id_rsa") pair = OpensshKeypair.generate() with open(keyfilename, "w+b") as keyfile: keyfile.write(pair.private_key[1:]) with open(keyfilename + ".pub", "w+b") as pubkeyfile: pubkeyfile.write(pair.public_key) OpensshKeypair.load(path=keyfilename) except InvalidPrivateKeyFileError: result = True finally: if os.path.exists(keyfilename): remove(keyfilename) if os.path.exists(keyfilename + ".pub"): remove(keyfilename + ".pub") if os.path.exists(tmpdir): rmdir(tmpdir) assert result @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") def test_mismatched_keypair() -> None: result = False try: tmpdir = mkdtemp() keyfilename = os.path.join(tmpdir, "id_rsa") pair1 = OpensshKeypair.generate() pair2 = OpensshKeypair.generate() with open(keyfilename, "w+b") as keyfile: keyfile.write(pair1.private_key) with open(keyfilename + ".pub", "w+b") as pubkeyfile: pubkeyfile.write(pair2.public_key) OpensshKeypair.load(path=keyfilename) except InvalidPublicKeyFileError: result = True finally: if os.path.exists(keyfilename): remove(keyfilename) if os.path.exists(keyfilename + ".pub"): remove(keyfilename + ".pub") if os.path.exists(tmpdir): rmdir(tmpdir) assert result @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") def test_keypair_comparison() -> None: assert OpensshKeypair.generate() != OpensshKeypair.generate() assert OpensshKeypair.generate() != OpensshKeypair.generate(keytype="dsa") assert OpensshKeypair.generate() != OpensshKeypair.generate(keytype="ed25519") assert OpensshKeypair.generate(keytype="ed25519") != OpensshKeypair.generate( keytype="ed25519" ) try: tmpdir = mkdtemp() keys: dict[str, dict[str, t.Any]] = { "rsa": { "pair": OpensshKeypair.generate(), "filename": os.path.join(tmpdir, "id_rsa"), }, "dsa": { "pair": OpensshKeypair.generate( keytype="dsa", passphrase="change_me".encode("UTF-8") ), "filename": os.path.join(tmpdir, "id_dsa"), }, "ed25519": { "pair": OpensshKeypair.generate(keytype="ed25519"), "filename": os.path.join(tmpdir, "id_ed25519"), }, } for v in keys.values(): with open(v["filename"], "w+b") as keyfile: keyfile.write(v["pair"].private_key) with open(v["filename"] + ".pub", "w+b") as pubkeyfile: pubkeyfile.write(v["pair"].public_key) assert keys["rsa"]["pair"] == OpensshKeypair.load(path=keys["rsa"]["filename"]) loaded_dsa_key = OpensshKeypair.load( path=keys["dsa"]["filename"], passphrase="change_me".encode("UTF-8") ) assert keys["dsa"]["pair"] == loaded_dsa_key loaded_dsa_key.update_passphrase("change_me_again".encode("UTF-8")) assert keys["dsa"]["pair"] != loaded_dsa_key loaded_dsa_key.update_passphrase("change_me".encode("UTF-8")) assert keys["dsa"]["pair"] == loaded_dsa_key loaded_dsa_key.comment = "comment" assert keys["dsa"]["pair"] != loaded_dsa_key assert keys["ed25519"]["pair"] == OpensshKeypair.load( path=keys["ed25519"]["filename"] ) finally: for v in keys.values(): if os.path.exists(v["filename"]): remove(v["filename"]) if os.path.exists(v["filename"] + ".pub"): remove(v["filename"] + ".pub") if os.path.exists(tmpdir): rmdir(tmpdir) assert OpensshKeypair.generate() != []