Add type hints and type checking (#885)

* Enable basic type checking.

* Fix first errors.

* Add changelog fragment.

* Add types to module_utils and plugin_utils (without module backends).

* Add typing hints for acme_* modules.

* Add typing to X.509 certificate modules, and add more helpers.

* Add typing to remaining module backends.

* Add typing for action, filter, and lookup plugins.

* Bump ansible-core 2.19 beta requirement for typing.

* Add more typing definitions.

* Add typing to some unit tests.
This commit is contained in:
Felix Fontein
2025-05-11 18:00:11 +02:00
committed by GitHub
parent 82f0176773
commit f758d94fba
124 changed files with 4986 additions and 2662 deletions

View File

@@ -7,6 +7,7 @@ from __future__ import annotations
import base64
import datetime
import os
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.backends import (
CertificateInformation,
@@ -19,12 +20,20 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
from ..test_time import TIMEZONES, cartesian_product
def load_fixture(name):
with open(os.path.join(os.path.dirname(__file__), "fixtures", name)) as f:
if t.TYPE_CHECKING:
from ansible_collections.community.crypto.plugins.module_utils.acme.backends import (
Criterium,
)
def load_fixture(name: str) -> str:
with open(
os.path.join(os.path.dirname(__file__), "fixtures", name), encoding="utf-8"
) as f:
return f.read()
TEST_PEM_DERS = [
TEST_PEM_DERS: list[tuple[str, bytes]] = [
(
load_fixture("privatekey_1.pem"),
base64.b64decode(
@@ -36,7 +45,7 @@ TEST_PEM_DERS = [
]
TEST_KEYS = [
TEST_KEYS: list[tuple[str, dict[str, t.Any], str]] = [
(
load_fixture("privatekey_1.pem"),
{
@@ -56,7 +65,7 @@ TEST_KEYS = [
]
TEST_CSRS = [
TEST_CSRS: list[tuple[str, set[tuple[str, str]], str]] = [
(
load_fixture("csr_1.pem"),
set([("dns", "ansible.com"), ("dns", "example.com"), ("dns", "example.org")]),
@@ -87,17 +96,19 @@ TEST_CERT_OPENSSL_OUTPUT_2 = load_fixture("cert_2.txt") # OpenSSL 3.3.0 output
TEST_CERT_OPENSSL_OUTPUT_2B = load_fixture("cert_2-b.txt") # OpenSSL 1.1.1f output
TEST_CERT_DAYS = cartesian_product(
TIMEZONES,
[
(datetime.datetime(2018, 11, 15, 1, 2, 3), 11),
(datetime.datetime(2018, 11, 25, 15, 20, 0), 1),
(datetime.datetime(2018, 11, 25, 15, 30, 0), 0),
],
TEST_CERT_DAYS: list[tuple[datetime.timedelta, datetime.datetime, int]] = (
cartesian_product(
TIMEZONES,
[
(datetime.datetime(2018, 11, 15, 1, 2, 3), 11),
(datetime.datetime(2018, 11, 25, 15, 20, 0), 1),
(datetime.datetime(2018, 11, 25, 15, 30, 0), 0),
],
)
)
TEST_CERT_INFO = CertificateInformation(
TEST_CERT_INFO_1 = CertificateInformation(
not_valid_after=datetime.datetime(2018, 11, 26, 15, 28, 24),
not_valid_before=datetime.datetime(2018, 11, 25, 15, 28, 23),
serial_number=1,
@@ -115,65 +126,69 @@ TEST_CERT_INFO_2 = CertificateInformation(
)
TEST_CERT_INFO = [
(TEST_CERT, TEST_CERT_INFO, TEST_CERT_OPENSSL_OUTPUT),
TEST_CERT_INFO: list[tuple[str, CertificateInformation, str]] = [
(TEST_CERT, TEST_CERT_INFO_1, TEST_CERT_OPENSSL_OUTPUT),
(TEST_CERT_2, TEST_CERT_INFO_2, TEST_CERT_OPENSSL_OUTPUT_2),
(TEST_CERT_2, TEST_CERT_INFO_2, TEST_CERT_OPENSSL_OUTPUT_2B),
]
TEST_PARSE_ACME_TIMESTAMP = cartesian_product(
TIMEZONES,
[
(
"2024-01-01T00:11:22Z",
dict(year=2024, month=1, day=1, hour=0, minute=11, second=22),
),
(
"2024-01-01T00:11:22.123Z",
dict(
year=2024,
month=1,
day=1,
hour=0,
minute=11,
second=22,
microsecond=123000,
TEST_PARSE_ACME_TIMESTAMP: list[tuple[datetime.timedelta, str, dict[str, int]]] = (
cartesian_product(
TIMEZONES,
[
(
"2024-01-01T00:11:22Z",
dict(year=2024, month=1, day=1, hour=0, minute=11, second=22),
),
),
(
"2024-04-17T06:54:13.333333334Z",
dict(
year=2024,
month=4,
day=17,
hour=6,
minute=54,
second=13,
microsecond=333333,
(
"2024-01-01T00:11:22.123Z",
dict(
year=2024,
month=1,
day=1,
hour=0,
minute=11,
second=22,
microsecond=123000,
),
),
),
(
"2024-01-01T00:11:22+0100",
dict(year=2023, month=12, day=31, hour=23, minute=11, second=22),
),
(
"2024-01-01T00:11:22.123+0100",
dict(
year=2023,
month=12,
day=31,
hour=23,
minute=11,
second=22,
microsecond=123000,
(
"2024-04-17T06:54:13.333333334Z",
dict(
year=2024,
month=4,
day=17,
hour=6,
minute=54,
second=13,
microsecond=333333,
),
),
),
],
(
"2024-01-01T00:11:22+0100",
dict(year=2023, month=12, day=31, hour=23, minute=11, second=22),
),
(
"2024-01-01T00:11:22.123+0100",
dict(
year=2023,
month=12,
day=31,
hour=23,
minute=11,
second=22,
microsecond=123000,
),
),
],
)
)
TEST_INTERPOLATE_TIMESTAMP = cartesian_product(
TEST_INTERPOLATE_TIMESTAMP: list[
tuple[datetime.timedelta, dict[str, int], dict[str, int], float, dict[str, int]]
] = cartesian_product(
TIMEZONES,
[
(
@@ -199,26 +214,50 @@ TEST_INTERPOLATE_TIMESTAMP = cartesian_product(
class FakeBackend(CryptoBackend):
def parse_key(self, key_file=None, key_content=None, passphrase=None):
def parse_key(
self,
key_file: str | os.PathLike | None = None,
key_content: str | None = None,
passphrase=None,
) -> t.NoReturn:
raise BackendException("Not implemented in fake backend")
def sign(self, payload64, protected64, key_data):
def sign(
self, payload64: str, protected64: str, key_data: dict[str, t.Any] | None
) -> t.NoReturn:
raise BackendException("Not implemented in fake backend")
def create_mac_key(self, alg, key):
def create_mac_key(self, alg: str, key: str) -> t.NoReturn:
raise BackendException("Not implemented in fake backend")
def get_ordered_csr_identifiers(self, csr_filename=None, csr_content=None):
def get_ordered_csr_identifiers(
self,
csr_filename: str | os.PathLike | None = None,
csr_content: str | bytes | None = None,
) -> t.NoReturn:
raise BackendException("Not implemented in fake backend")
def get_csr_identifiers(self, csr_filename=None, csr_content=None):
def get_csr_identifiers(
self,
csr_filename: str | os.PathLike | None = None,
csr_content: str | bytes | None = None,
) -> t.NoReturn:
raise BackendException("Not implemented in fake backend")
def get_cert_days(self, cert_filename=None, cert_content=None, now=None):
def get_cert_days(
self,
cert_filename: str | os.PathLike | None = None,
cert_content: str | bytes | None = None,
now: datetime.datetime | None = None,
) -> t.NoReturn:
raise BackendException("Not implemented in fake backend")
def create_chain_matcher(self, criterium):
def create_chain_matcher(self, criterium: Criterium) -> t.NoReturn:
raise BackendException("Not implemented in fake backend")
def get_cert_information(self, cert_filename=None, cert_content=None):
def get_cert_information(
self,
cert_filename: str | os.PathLike | None = None,
cert_content: str | bytes | None = None,
) -> t.NoReturn:
raise BackendException("Not implemented in fake backend")

View File

@@ -5,6 +5,7 @@
from __future__ import annotations
import datetime
import typing as t
from unittest.mock import (
MagicMock,
)
@@ -35,12 +36,20 @@ from .backend_data import (
)
if t.TYPE_CHECKING:
from ansible_collections.community.crypto.plugins.module_utils.acme.backends import (
CertificateInformation,
)
if not HAS_CURRENT_CRYPTOGRAPHY:
pytest.skip("cryptography not found")
@pytest.mark.parametrize("pem, result, dummy", TEST_KEYS)
def test_eckeyparse_cryptography(pem, result, dummy, tmpdir):
def test_eckeyparse_cryptography(
pem: str, result: dict[str, t.Any], dummy: str, tmpdir
) -> None:
fn = tmpdir / "test.pem"
fn.write(pem)
module = MagicMock()
@@ -54,7 +63,9 @@ def test_eckeyparse_cryptography(pem, result, dummy, tmpdir):
@pytest.mark.parametrize("csr, result, openssl_output", TEST_CSRS)
def test_csridentifiers_cryptography(csr, result, openssl_output, tmpdir):
def test_csridentifiers_cryptography(
csr: str, result: set[tuple[str, str]], openssl_output: str, tmpdir
) -> None:
fn = tmpdir / "test.csr"
fn.write(csr)
module = MagicMock()
@@ -66,7 +77,9 @@ def test_csridentifiers_cryptography(csr, result, openssl_output, tmpdir):
@pytest.mark.parametrize("timezone, now, expected_days", TEST_CERT_DAYS)
def test_certdays_cryptography(timezone, now, expected_days, tmpdir):
def test_certdays_cryptography(
timezone: datetime.timedelta, now: datetime.datetime, expected_days: int, tmpdir
) -> None:
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
fn = tmpdir / "test-cert.pem"
fn.write(TEST_CERT)
@@ -81,7 +94,12 @@ def test_certdays_cryptography(timezone, now, expected_days, tmpdir):
@pytest.mark.parametrize(
"cert_content, expected_cert_info, openssl_output", TEST_CERT_INFO
)
def test_get_cert_information(cert_content, expected_cert_info, openssl_output, tmpdir):
def test_get_cert_information(
cert_content: str,
expected_cert_info: CertificateInformation,
openssl_output: str,
tmpdir,
) -> None:
fn = tmpdir / "test-cert.pem"
fn.write(cert_content)
module = MagicMock()
@@ -105,7 +123,7 @@ def test_get_cert_information(cert_content, expected_cert_info, openssl_output,
@pytest.mark.parametrize(
"timezone", [datetime.timedelta(hours=0)] if CRYPTOGRAPHY_TIMEZONE else TIMEZONES
)
def test_now(timezone):
def test_now(timezone: datetime.timedelta) -> None:
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
module = MagicMock()
backend = CryptographyBackend(module)
@@ -119,7 +137,9 @@ def test_now(timezone):
@pytest.mark.parametrize("timezone, input, expected", TEST_PARSE_ACME_TIMESTAMP)
def test_parse_acme_timestamp(timezone, input, expected):
def test_parse_acme_timestamp(
timezone: datetime.timedelta, input: str, expected: dict[str, int]
) -> None:
with freeze_time("2024-02-03 04:05:06 +00:00", tz_offset=timezone):
module = MagicMock()
backend = CryptographyBackend(module)
@@ -131,7 +151,13 @@ def test_parse_acme_timestamp(timezone, input, expected):
@pytest.mark.parametrize(
"timezone, start, end, percentage, expected", TEST_INTERPOLATE_TIMESTAMP
)
def test_interpolate_timestamp(timezone, start, end, percentage, expected):
def test_interpolate_timestamp(
timezone: datetime.timedelta,
start: dict[str, int],
end: dict[str, int],
percentage: float,
expected: dict[str, int],
) -> None:
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
module = MagicMock()
backend = CryptographyBackend(module)

View File

@@ -5,6 +5,7 @@
from __future__ import annotations
import datetime
import typing as t
from unittest.mock import (
MagicMock,
)
@@ -31,6 +32,12 @@ from .backend_data import (
)
if t.TYPE_CHECKING:
from ansible_collections.community.crypto.plugins.module_utils.acme.backends import (
CertificateInformation,
)
# from ..test_time import TIMEZONES
@@ -47,7 +54,9 @@ TEST_IPS = [
@pytest.mark.parametrize("pem, result, openssl_output", TEST_KEYS)
def test_eckeyparse_openssl(pem, result, openssl_output, tmpdir):
def test_eckeyparse_openssl(
pem: str, result: dict[str, t.Any], openssl_output: str, tmpdir
) -> None:
fn = tmpdir / "test.key"
fn.write(pem)
module = MagicMock()
@@ -59,7 +68,9 @@ def test_eckeyparse_openssl(pem, result, openssl_output, tmpdir):
@pytest.mark.parametrize("csr, result, openssl_output", TEST_CSRS)
def test_csridentifiers_openssl(csr, result, openssl_output, tmpdir):
def test_csridentifiers_openssl(
csr: str, result: set[tuple[str, str]], openssl_output: str, tmpdir
) -> None:
fn = tmpdir / "test.csr"
fn.write(csr)
module = MagicMock()
@@ -70,14 +81,16 @@ def test_csridentifiers_openssl(csr, result, openssl_output, tmpdir):
@pytest.mark.parametrize("ip, result", TEST_IPS)
def test_normalize_ip(ip, result):
def test_normalize_ip(ip: str, result: str) -> None:
module = MagicMock()
backend = OpenSSLCLIBackend(module, openssl_binary="openssl")
assert backend._normalize_ip(ip) == result
@pytest.mark.parametrize("timezone, now, expected_days", TEST_CERT_DAYS)
def test_certdays_cryptography(timezone, now, expected_days, tmpdir):
def test_certdays_cryptography(
timezone: datetime.timedelta, now: datetime.datetime, expected_days: int, tmpdir
) -> None:
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
fn = tmpdir / "test-cert.pem"
fn.write(TEST_CERT)
@@ -93,7 +106,12 @@ def test_certdays_cryptography(timezone, now, expected_days, tmpdir):
@pytest.mark.parametrize(
"cert_content, expected_cert_info, openssl_output", TEST_CERT_INFO
)
def test_get_cert_information(cert_content, expected_cert_info, openssl_output, tmpdir):
def test_get_cert_information(
cert_content: str,
expected_cert_info: CertificateInformation,
openssl_output: str,
tmpdir,
) -> None:
fn = tmpdir / "test-cert.pem"
fn.write(cert_content)
module = MagicMock()
@@ -115,7 +133,7 @@ def test_get_cert_information(cert_content, expected_cert_info, openssl_output,
# Due to a bug in freezegun (https://github.com/spulec/freezegun/issues/348, https://github.com/spulec/freezegun/issues/553)
# this only works with timezone = UTC if CRYPTOGRAPHY_TIMEZONE is truish
@pytest.mark.parametrize("timezone", [datetime.timedelta(hours=0)])
def test_now(timezone):
def test_now(timezone: datetime.timedelta) -> None:
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
module = MagicMock()
backend = OpenSSLCLIBackend(module, openssl_binary="openssl")
@@ -125,7 +143,9 @@ def test_now(timezone):
@pytest.mark.parametrize("timezone, input, expected", TEST_PARSE_ACME_TIMESTAMP)
def test_parse_acme_timestamp(timezone, input, expected):
def test_parse_acme_timestamp(
timezone: datetime.timedelta, input: str, expected: dict[str, int]
) -> None:
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
module = MagicMock()
backend = OpenSSLCLIBackend(module, openssl_binary="openssl")
@@ -137,7 +157,13 @@ def test_parse_acme_timestamp(timezone, input, expected):
@pytest.mark.parametrize(
"timezone, start, end, percentage, expected", TEST_INTERPOLATE_TIMESTAMP
)
def test_interpolate_timestamp(timezone, start, end, percentage, expected):
def test_interpolate_timestamp(
timezone: datetime.timedelta,
start: dict[str, int],
end: dict[str, int],
percentage: float,
expected: dict[str, int],
) -> None:
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
module = MagicMock()
backend = OpenSSLCLIBackend(module, openssl_binary="openssl")

View File

@@ -4,6 +4,7 @@
from __future__ import annotations
import typing as t
from unittest.mock import (
MagicMock,
)
@@ -21,21 +22,21 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
)
def test_combine_identifier():
def test_combine_identifier() -> None:
assert combine_identifier("", "") == ":"
assert combine_identifier("a", "b") == "a:b"
def test_split_identifier():
assert split_identifier(":") == ["", ""]
assert split_identifier("a:b") == ["a", "b"]
assert split_identifier("a:b:c") == ["a", "b:c"]
def test_split_identifier() -> None:
assert split_identifier(":") == ("", "")
assert split_identifier("a:b") == ("a", "b")
assert split_identifier("a:b:c") == ("a", "b:c")
with pytest.raises(ModuleFailException) as exc:
split_identifier("a")
assert exc.value.msg == 'Identifier "a" is not of the form <type>:<identifier>'
def test_challenge_from_to_json():
def test_challenge_from_to_json() -> None:
client = MagicMock()
data = {
@@ -57,7 +58,7 @@ def test_challenge_from_to_json():
"status": "valid",
"token": "foo",
}
challenge = Challenge.from_json(None, data, url="xxx")
challenge = Challenge.from_json(None, data, url="xxx") # type: ignore
assert challenge.data == data
assert challenge.type == "type"
assert challenge.url == "xxx"
@@ -66,10 +67,12 @@ def test_challenge_from_to_json():
assert challenge.to_json() == data
def test_authorization_from_to_json():
def test_authorization_from_to_json() -> None:
client = MagicMock()
client.version = 2
data: dict[str, t.Any]
data = {
"challenges": [],
"status": "valid",
@@ -138,7 +141,7 @@ def test_authorization_from_to_json():
}
def test_authorization_create_error():
def test_authorization_create_error() -> None:
client = MagicMock()
client.version = 2
client.directory.directory = {}
@@ -148,7 +151,7 @@ def test_authorization_create_error():
assert exc.value.msg == "ACME endpoint does not support pre-authorization."
def test_wait_for_validation_error():
def test_wait_for_validation_error() -> None:
client = MagicMock()
client.version = 2
data = {

View File

@@ -4,6 +4,7 @@
from __future__ import annotations
import typing as t
from unittest.mock import (
MagicMock,
)
@@ -15,7 +16,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
)
TEST_FORMAT_ERROR_PROBLEM = [
TEST_FORMAT_ERROR_PROBLEM: list[tuple[dict[str, t.Any], str, str]] = [
(
{
"type": "foo",
@@ -90,33 +91,37 @@ TEST_FORMAT_ERROR_PROBLEM = [
@pytest.mark.parametrize(
"problem, subproblem_prefix, result", TEST_FORMAT_ERROR_PROBLEM
)
def test_format_error_problem(problem, subproblem_prefix, result):
def test_format_error_problem(
problem: dict[str, t.Any], subproblem_prefix: str, result: str
) -> None:
res = format_error_problem(problem, subproblem_prefix)
assert res == result
def create_regular_response(response_text):
def create_regular_response(response_text: str) -> MagicMock:
response = MagicMock()
response.read = MagicMock(return_value=response_text.encode("utf-8"))
response.closed = False
return response
def create_error_response():
def create_error_response() -> MagicMock:
response = MagicMock()
response.read = MagicMock(side_effect=AttributeError("read"))
response.closed = True
return response
def create_decode_error(msg):
def f(content):
def create_decode_error(msg: str) -> t.Callable[[t.Any], t.Any]:
def f(content: t.Any) -> t.NoReturn:
raise Exception(msg)
return f
TEST_ACME_PROTOCOL_EXCEPTION = [
TEST_ACME_PROTOCOL_EXCEPTION: list[
tuple[dict[str, t.Any], t.Callable[[t.Any], t.Any] | None, str, dict[str, t.Any]]
] = [
(
{},
None,
@@ -341,14 +346,19 @@ TEST_ACME_PROTOCOL_EXCEPTION = [
@pytest.mark.parametrize("input, from_json, msg, args", TEST_ACME_PROTOCOL_EXCEPTION)
def test_acme_protocol_exception(input, from_json, msg, args):
def test_acme_protocol_exception(
input: dict[str, t.Any],
from_json: t.Callable[[t.Any], t.NoReturn] | None,
msg: str,
args: dict[str, t.Any],
) -> None:
if from_json is None:
module = None
else:
module = MagicMock()
module.from_json = from_json
with pytest.raises(ACMEProtocolException) as exc:
raise ACMEProtocolException(module, **input)
raise ACMEProtocolException(module, **input) # type: ignore
print(exc.value.msg)
print(exc.value.module_fail_args)

View File

@@ -18,14 +18,13 @@ TEST_TEXT = r"""1234
5678"""
def test_read_file(tmpdir):
def test_read_file(tmpdir) -> None:
fn = tmpdir / "test.txt"
fn.write(TEST_TEXT)
assert read_file(str(fn), "t") == TEST_TEXT
assert read_file(str(fn), "b") == TEST_TEXT.encode("utf-8")
assert read_file(str(fn)) == TEST_TEXT.encode("utf-8")
def test_write_file(tmpdir):
def test_write_file(tmpdir) -> None:
fn = tmpdir / "test.txt"
module = MagicMock()
write_file(module, str(fn), TEST_TEXT.encode("utf-8"))

View File

@@ -15,7 +15,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
from ansible_collections.community.crypto.plugins.module_utils.acme.orders import Order
def test_order_from_json():
def test_order_from_json() -> None:
client = MagicMock()
data = {
@@ -35,7 +35,7 @@ def test_order_from_json():
assert order.authorizations == {}
def test_wait_for_finalization_error():
def test_wait_for_finalization_error() -> None:
client = MagicMock()
client.version = 2

View File

@@ -5,10 +5,12 @@
from __future__ import annotations
import datetime
import typing as t
import pytest
from ansible_collections.community.crypto.plugins.module_utils.acme.backends import (
CertificateInformation,
CryptoBackend,
)
from ansible_collections.community.crypto.plugins.module_utils.acme.utils import (
compute_cert_id,
@@ -21,7 +23,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
from .backend_data import TEST_PEM_DERS
NOPAD_B64 = [
NOPAD_B64: list[tuple[str, str]] = [
("", ""),
("\n", "Cg"),
("123", "MTIz"),
@@ -29,7 +31,7 @@ NOPAD_B64 = [
]
TEST_LINKS_HEADER = [
TEST_LINKS_HEADER: list[tuple[dict[str, t.Any], list[tuple[str, str]]]] = [
(
{},
[],
@@ -60,13 +62,13 @@ TEST_LINKS_HEADER = [
]
TEST_RETRY_AFTER_HEADER = [
TEST_RETRY_AFTER_HEADER: list[tuple[str, datetime.datetime]] = [
("120", datetime.datetime(2024, 4, 29, 0, 2, 0)),
("Wed, 21 Oct 2015 07:28:00 GMT", datetime.datetime(2015, 10, 21, 7, 28, 0)),
]
TEST_COMPUTE_CERT_ID = [
TEST_COMPUTE_CERT_ID: list[tuple[CertificateInformation, str]] = [
(
CertificateInformation(
not_valid_after=datetime.datetime(2018, 11, 26, 15, 28, 24),
@@ -93,19 +95,21 @@ TEST_COMPUTE_CERT_ID = [
@pytest.mark.parametrize("value, result", NOPAD_B64)
def test_nopad_b64(value, result):
def test_nopad_b64(value: str, result: str) -> None:
assert nopad_b64(value.encode("utf-8")) == result
@pytest.mark.parametrize("pem, der", TEST_PEM_DERS)
def test_pem_to_der(pem, der, tmpdir):
def test_pem_to_der(pem: str, der: bytes, tmpdir):
fn = tmpdir / "test.pem"
fn.write(pem)
assert pem_to_der(str(fn)) == der
@pytest.mark.parametrize("value, expected_result", TEST_LINKS_HEADER)
def test_process_links(value, expected_result):
def test_process_links(
value: dict[str, t.Any], expected_result: list[tuple[str, str]]
) -> None:
data = []
def callback(url, rel):
@@ -117,12 +121,15 @@ def test_process_links(value, expected_result):
@pytest.mark.parametrize("value, expected_result", TEST_RETRY_AFTER_HEADER)
def test_parse_retry_after(value, expected_result):
def test_parse_retry_after(value: str, expected_result: datetime.datetime) -> None:
assert expected_result == parse_retry_after(
value, now=datetime.datetime(2024, 4, 29, 0, 0, 0)
)
@pytest.mark.parametrize("cert_info, expected_result", TEST_COMPUTE_CERT_ID)
def test_compute_cert_id(cert_info, expected_result):
assert expected_result == compute_cert_id(backend=None, cert_info=cert_info)
def test_compute_cert_id(
cert_info: CertificateInformation, expected_result: str
) -> None:
backend: CryptoBackend = None # type: ignore
assert expected_result == compute_cert_id(backend=backend, cert_info=cert_info)

View File

@@ -10,12 +10,11 @@ import subprocess
import pytest
from ansible_collections.community.crypto.plugins.module_utils.crypto._asn1 import (
pack_asn1,
serialize_asn1_string_as_der,
)
TEST_CASES = [
TEST_CASES: list[tuple[str, bytes]] = [
("UTF8:Hello World", b"\x0c\x0b\x48\x65\x6c\x6c\x6f\x20\x57\x6f\x72\x6c\x64"),
(
"EXPLICIT:10,UTF8:Hello World",
@@ -76,7 +75,7 @@ TEST_CASES = [
@pytest.mark.parametrize("value, expected", TEST_CASES)
def test_serialize_asn1_string_as_der(value, expected):
def test_serialize_asn1_string_as_der(value: str, expected: bytes) -> None:
actual = serialize_asn1_string_as_der(value)
print(f"{value} | {base64.b16encode(actual).decode()}")
assert actual == expected
@@ -89,7 +88,7 @@ def test_serialize_asn1_string_as_der(value, expected):
"EXPLICIT,UTF:value",
],
)
def test_serialize_asn1_string_as_der_invalid_format(value):
def test_serialize_asn1_string_as_der_invalid_format(value: str) -> None:
expected = (
"The ASN.1 serialized string must be in the format [modifier,]type[:value]"
)
@@ -97,20 +96,15 @@ def test_serialize_asn1_string_as_der_invalid_format(value):
serialize_asn1_string_as_der(value)
def test_serialize_asn1_string_as_der_invalid_type():
def test_serialize_asn1_string_as_der_invalid_type() -> None:
expected = 'The ASN.1 serialized string is not a known type "OID", only UTF8 types are supported'
with pytest.raises(ValueError, match=re.escape(expected)):
serialize_asn1_string_as_der("OID:1.2.3.4")
def test_pack_asn_invalid_class():
with pytest.raises(ValueError, match="tag_class must be between 0 and 3 not 4"):
pack_asn1(4, True, 0, b"")
@pytest.mark.skip() # This is to just to build the test case assertions and shouldn't run normally.
@pytest.mark.parametrize("value, expected", TEST_CASES)
def test_test_cases(value, expected, tmp_path):
def test_test_cases(value: str, expected: bytes, tmp_path) -> None:
test_file = tmp_path / "test.der"
subprocess.run(
["openssl", "asn1parse", "-genstr", value, "-noout", "-out", test_file],

View File

@@ -5,6 +5,7 @@
from __future__ import annotations
import re
import typing as t
import cryptography
import pytest
@@ -20,7 +21,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptograp
from ansible_collections.community.crypto.plugins.module_utils.version import (
LooseVersion,
)
from cryptography.x509 import NameAttribute, oid
from cryptography.x509 import NameAttribute, OtherName, oid
@pytest.mark.parametrize(
@@ -35,7 +36,7 @@ from cryptography.x509 import NameAttribute, oid
("*.☺.", "*.xn--74h.", None),
],
)
def test_adjust_idn(unicode, idna, cycled_unicode):
def test_adjust_idn(unicode: str, idna: str, cycled_unicode: str | None) -> None:
if cycled_unicode is None:
cycled_unicode = unicode
@@ -70,9 +71,10 @@ def test_adjust_idn(unicode, idna, cycled_unicode):
("bar", "foo", re.escape('Invalid value for idn_rewrite: "foo"')),
],
)
def test_adjust_idn_fail_valueerror(value, idn_rewrite, message):
def test_adjust_idn_fail_valueerror(value: str, idn_rewrite: str, message: str) -> None:
with pytest.raises(ValueError, match=message):
_adjust_idn(value, idn_rewrite)
idn_rewrite_: t.Literal["ignore", "idna", "unicode"] = idn_rewrite # type: ignore
_adjust_idn(value, idn_rewrite_)
@pytest.mark.parametrize(
@@ -88,27 +90,29 @@ def test_adjust_idn_fail_valueerror(value, idn_rewrite, message):
),
],
)
def test_adjust_idn_fail_user_error(value, idn_rewrite, message):
def test_adjust_idn_fail_user_error(value: str, idn_rewrite: str, message: str) -> None:
with pytest.raises(OpenSSLObjectError, match=message):
_adjust_idn(value, idn_rewrite)
idn_rewrite_: t.Literal["ignore", "idna", "unicode"] = idn_rewrite # type: ignore
_adjust_idn(value, idn_rewrite_)
def test_cryptography_get_name_invalid_prefix():
def test_cryptography_get_name_invalid_prefix() -> None:
with pytest.raises(
OpenSSLObjectError, match="^Cannot parse Subject Alternative Name"
):
cryptography_get_name("fake:value")
def test_cryptography_get_name_other_name_no_oid():
def test_cryptography_get_name_other_name_no_oid() -> None:
with pytest.raises(
OpenSSLObjectError, match="Cannot parse Subject Alternative Name otherName"
):
cryptography_get_name("otherName:value")
def test_cryptography_get_name_other_name_utfstring():
def test_cryptography_get_name_other_name_utfstring() -> None:
actual = cryptography_get_name("otherName:1.3.6.1.4.1.311.20.2.3;UTF8:Hello World")
assert isinstance(actual, OtherName)
assert actual.type_id.dotted_string == "1.3.6.1.4.1.311.20.2.3"
assert actual.value == b"\x0c\x0bHello World"
@@ -164,7 +168,9 @@ def test_cryptography_get_name_other_name_utfstring():
),
],
)
def test_parse_dn_component(name, options, expected):
def test_parse_dn_component(
name: bytes, options: dict[str, t.Any], expected: tuple[NameAttribute, bytes]
) -> None:
result = _parse_dn_component(name, **options)
print(result, expected)
assert result == expected
@@ -186,7 +192,9 @@ if (
(b"CN= ", {}, (NameAttribute(oid.NameOID.COMMON_NAME, ""), b"")),
],
)
def test_parse_dn_component_not_py26(name, options, expected):
def test_parse_dn_component_not_py26(
name: bytes, options: dict[str, t.Any], expected: tuple[NameAttribute, bytes]
) -> None:
result = _parse_dn_component(name, **options)
print(result, expected)
assert result == expected
@@ -200,7 +208,9 @@ if (
(b"CN=#0,", {}, 'Invalid hex sequence entry "0,"'),
],
)
def test_parse_dn_component_failure(name, options, message):
def test_parse_dn_component_failure(
name: bytes, options: dict[str, t.Any], message: str
) -> None:
with pytest.raises(OpenSSLObjectError, match=f"^{re.escape(message)}$"):
_parse_dn_component(name, **options)
@@ -225,7 +235,7 @@ def test_parse_dn_component_failure(name, options, message):
),
],
)
def test_parse_dn(name, expected):
def test_parse_dn(name: bytes, expected: list[NameAttribute]) -> None:
result = _parse_dn(name)
print(result, expected)
assert result == expected
@@ -236,14 +246,14 @@ def test_parse_dn(name, expected):
[
(
b"CN=\\0",
'Error while parsing distinguished name "CN=\\0": Hex escape sequence "\\0" incomplete at end of string',
"Error while parsing distinguished name 'CN=\\\\0': Hex escape sequence \"\\0\" incomplete at end of string",
),
(
b"CN=x,",
'Error while parsing distinguished name "CN=x,": unexpected end of string',
"Error while parsing distinguished name 'CN=x,': unexpected end of string",
),
],
)
def test_parse_dn_failure(name, message):
def test_parse_dn_failure(name: bytes, message: str):
with pytest.raises(OpenSSLObjectError, match=f"^{re.escape(message)}$"):
_parse_dn(name)

View File

@@ -26,7 +26,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.math impor
(2, 10, 5, 4),
],
)
def test_binary_exp_mod(f, e, m, result):
def test_binary_exp_mod(f: int, e: int, m: int, result: int) -> None:
value = binary_exp_mod(f, e, m)
print(value)
assert value == result
@@ -46,7 +46,7 @@ def test_binary_exp_mod(f, e, m, result):
(1024, 10, 2),
],
)
def test_simple_gcd(a, b, result):
def test_simple_gcd(a: int, b: int, result: int) -> None:
value = simple_gcd(a, b)
print(value)
assert value == result
@@ -70,7 +70,7 @@ def test_simple_gcd(a, b, result):
(211, False), # the smallest prime number >= 200
],
)
def test_quick_is_not_prime(n, result):
def test_quick_is_not_prime(n: int, result: bool) -> None:
value = quick_is_not_prime(n)
print(value)
assert value == result
@@ -88,7 +88,7 @@ def test_quick_is_not_prime(n, result):
(256, None, b"\x01\x00"),
],
)
def test_convert_int_to_bytes(no, count, result):
def test_convert_int_to_bytes(no: int, count: int | None, result: bytes) -> None:
value = convert_int_to_bytes(no, count=count)
print(value)
assert value == result
@@ -108,7 +108,7 @@ def test_convert_int_to_bytes(no, count, result):
(256, 4, "0100"),
],
)
def test_convert_int_to_hex(no, digits, result):
def test_convert_int_to_hex(no: int, digits: int | None, result: str) -> None:
value = convert_int_to_hex(no, digits=digits)
print(value)
assert value == result
@@ -125,7 +125,7 @@ def test_convert_int_to_hex(no, digits, result):
(b"\x01\x00", 256),
],
)
def test_convert_bytes_to_int(data, result):
def test_convert_bytes_to_int(data: bytes, result: int) -> None:
value = convert_bytes_to_int(data)
print(value)
assert value == result

View File

@@ -4,6 +4,8 @@
from __future__ import annotations
import typing as t
import pytest
from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import (
extract_first_pem,
@@ -13,7 +15,9 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import
)
PEM_TEST_CASES = [
PEM_TEST_CASES: list[
tuple[bytes, list[str], bool, t.Literal["raw", "pkcs1", "pkcs8", "unknown-pem"]]
] = [
(b"", [], False, "raw"),
(b"random stuff\nblabla", [], False, "raw"),
(b"-----BEGIN PRIVATE KEY-----", [], False, "raw"),
@@ -51,7 +55,12 @@ PEM_TEST_CASES = [
@pytest.mark.parametrize("data, pems, is_pem, private_key_type", PEM_TEST_CASES)
def test_pem_handling(data, pems, is_pem, private_key_type):
def test_pem_handling(
data: bytes,
pems: list[str],
is_pem: bool,
private_key_type: t.Literal["raw", "pkcs1", "pkcs8", "unknown-pem"],
):
assert identify_pem_format(data) == is_pem
assert identify_private_key_format(data) == private_key_type
try:

View File

@@ -132,7 +132,9 @@ VALID_EXTENSIONS = [
]
INVALID_EXTENSIONS = [OpensshCertificateOption("extension", "test", "")]
VALID_TIME_PARAMETERS = [
VALID_TIME_PARAMETERS: list[
tuple[int | str, int | str, str, int, int | str, str, str, int, str]
] = [
(
0,
"always",
@@ -223,28 +225,28 @@ VALID_TIME_PARAMETERS = [
),
]
INVALID_TIME_PARAMETERS = [
INVALID_TIME_PARAMETERS: list[tuple[int | str, int | str]] = [
(-1, 0xFFFFFFFFFFFFFFFFFF),
("never", "ever"),
("01-01-1980", "01-01-1990"),
(1, 0),
]
VALID_VALIDITY_TEST = [
VALID_VALIDITY_TEST: list[tuple[str, str, str]] = [
("always", "forever", "2000-01-01"),
("1999-12-31", "2000-01-02", "2000-01-01"),
("1999-12-31 23:59:00", "2000-01-01 00:01:00", "2000-01-01 00:00:00"),
("1999-12-31 23:59:59", "2000-01-01 00:00:01", "2000-01-01 00:00:00"),
]
INVALID_VALIDITY_TEST = [
INVALID_VALIDITY_TEST: list[tuple[str, str, str]] = [
("always", "forever", "1969-12-31"),
("always", "2000-01-01", "2000-01-02"),
("2000-01-01", "forever", "1999-12-31"),
("2000-01-01 00:00:00", "2000-01-01 00:00:01", "2000-01-01 00:00:02"),
]
VALID_OPTIONS = [
VALID_OPTIONS: list[tuple[str, OpensshCertificateOption]] = [
(
"force-command=/usr/bin/csh",
OpensshCertificateOption("critical", "force-command", "/usr/bin/csh"),
@@ -265,7 +267,7 @@ VALID_OPTIONS = [
("extension:foo", OpensshCertificateOption("extension", "foo", "")),
]
INVALID_OPTIONS = [
INVALID_OPTIONS: list[str | list] = [
"foobar",
"foo=bar",
"foo:bar=baz",
@@ -273,7 +275,7 @@ INVALID_OPTIONS = [
]
def test_rsa_certificate(tmpdir):
def test_rsa_certificate(tmpdir) -> None:
cert_file = tmpdir / "id_rsa-cert.pub"
cert_file.write(RSA_CERT_SIGNED_BY_DSA, mode="wb")
@@ -285,7 +287,7 @@ def test_rsa_certificate(tmpdir):
assert cert.signing_key == DSA_FINGERPRINT
def test_dsa_certificate(tmpdir):
def test_dsa_certificate(tmpdir) -> None:
cert_file = tmpdir / "id_dsa-cert.pub"
cert_file.write(DSA_CERT_SIGNED_BY_ECDSA_NO_OPTS)
@@ -298,7 +300,7 @@ def test_dsa_certificate(tmpdir):
assert cert.extensions == []
def test_ecdsa_certificate(tmpdir):
def test_ecdsa_certificate(tmpdir) -> None:
cert_file = tmpdir / "id_ecdsa-cert.pub"
cert_file.write(ECDSA_CERT_SIGNED_BY_ED25519_VALID_OPTS)
@@ -310,7 +312,7 @@ def test_ecdsa_certificate(tmpdir):
assert cert.extensions == VALID_EXTENSIONS
def test_ed25519_certificate(tmpdir):
def test_ed25519_certificate(tmpdir) -> None:
cert_file = tmpdir / "id_ed25519-cert.pub"
cert_file.write(ED25519_CERT_SIGNED_BY_RSA_INVALID_OPTS)
@@ -322,7 +324,7 @@ def test_ed25519_certificate(tmpdir):
assert cert.extensions == INVALID_EXTENSIONS
def test_invalid_data(tmpdir):
def test_invalid_data(tmpdir) -> None:
result = False
cert_file = tmpdir / "invalid-cert.pub"
cert_file.write(INVALID_DATA)
@@ -341,16 +343,16 @@ def test_invalid_data(tmpdir):
VALID_TIME_PARAMETERS,
)
def test_valid_time_parameters(
valid_from,
valid_from_hr,
valid_from_openssh,
valid_from_timestamp,
valid_to,
valid_to_hr,
valid_to_openssh,
valid_to_timestamp,
validity_string,
):
valid_from: int | str,
valid_from_hr: int | str,
valid_from_openssh: str,
valid_from_timestamp: int,
valid_to: int | str,
valid_to_hr: str,
valid_to_openssh: str,
valid_to_timestamp: int,
validity_string: str,
) -> None:
time_parameters = OpensshCertificateTimeParameters(
valid_from=valid_from, valid_to=valid_to
)
@@ -364,35 +366,37 @@ def test_valid_time_parameters(
@pytest.mark.parametrize("valid_from,valid_to", INVALID_TIME_PARAMETERS)
def test_invalid_time_parameters(valid_from, valid_to):
def test_invalid_time_parameters(valid_from: int | str, valid_to: int | str) -> None:
with pytest.raises(ValueError):
OpensshCertificateTimeParameters(valid_from, valid_to)
@pytest.mark.parametrize("valid_from,valid_to,valid_at", VALID_VALIDITY_TEST)
def test_valid_validity_test(valid_from, valid_to, valid_at):
def test_valid_validity_test(valid_from: str, valid_to: str, valid_at: str) -> None:
assert OpensshCertificateTimeParameters(valid_from, valid_to).within_range(valid_at)
@pytest.mark.parametrize("valid_from,valid_to,valid_at", INVALID_VALIDITY_TEST)
def test_invalid_validity_test(valid_from, valid_to, valid_at):
def test_invalid_validity_test(valid_from: str, valid_to: str, valid_at: str) -> None:
assert not OpensshCertificateTimeParameters(valid_from, valid_to).within_range(
valid_at
)
@pytest.mark.parametrize("option_string,option_object", VALID_OPTIONS)
def test_valid_options(option_string, option_object):
def test_valid_options(
option_string: str, option_object: OpensshCertificateOption
) -> None:
assert OpensshCertificateOption.from_string(option_string) == option_object
@pytest.mark.parametrize("option_string", INVALID_OPTIONS)
def test_invalid_options(option_string):
def test_invalid_options(option_string: str) -> None:
with pytest.raises(ValueError):
OpensshCertificateOption.from_string(option_string)
def test_parse_option_list():
def test_parse_option_list() -> None:
critical_options, extensions = parse_option_list(["force-command=/usr/bin/csh"])
critical_option_objects = [
@@ -411,7 +415,7 @@ def test_parse_option_list():
assert set(extensions) == set(extension_objects)
def test_parse_option_list_with_directives():
def test_parse_option_list_with_directives() -> None:
critical_options, extensions = parse_option_list(
["clear", "no-pty", "permit-pty", "permit-user-rc"]
)
@@ -425,7 +429,7 @@ def test_parse_option_list_with_directives():
assert set(extensions) == set(extension_objects)
def test_parse_option_list_case_sensitivity():
def test_parse_option_list_case_sensitivity() -> None:
critical_options, extensions = parse_option_list(
["CLEAR", "no-X11-forwarding", "permit-X11-forwarding"]
)

View File

@@ -5,6 +5,7 @@
from __future__ import annotations
import os.path
import typing as t
from getpass import getuser
from os import remove, rmdir
from socket import gethostname
@@ -23,7 +24,13 @@ from ansible_collections.community.crypto.plugins.module_utils.openssh.cryptogra
)
DEFAULT_KEY_PARAMS = [
if t.TYPE_CHECKING:
from ansible_collections.community.crypto.plugins.module_utils.openssh.cryptography import (
KeyType,
)
DEFAULT_KEY_PARAMS: list[tuple[KeyType, int | None, bytes | None, str | None]] = [
(
"rsa",
None,
@@ -50,7 +57,7 @@ DEFAULT_KEY_PARAMS = [
),
]
VALID_USER_KEY_PARAMS = [
VALID_USER_KEY_PARAMS: list[tuple[KeyType, int | None, bytes | None, str | None]] = [
(
"rsa",
8192,
@@ -77,9 +84,9 @@ VALID_USER_KEY_PARAMS = [
),
]
INVALID_USER_KEY_PARAMS = [
INVALID_USER_KEY_PARAMS: list[tuple[KeyType, int | None, bytes | None, str | None]] = [
(
"dne",
"dne", # type: ignore
None,
None,
None,
@@ -87,18 +94,18 @@ INVALID_USER_KEY_PARAMS = [
(
"rsa",
None,
[1, 2, 3],
[1, 2, 3], # type: ignore
"comment",
),
(
"ecdsa",
None,
None,
[1, 2, 3],
[1, 2, 3], # type: ignore
),
]
INVALID_KEY_SIZES = [
INVALID_KEY_SIZES: list[tuple[KeyType, int | None, bytes | None, str | None]] = [
(
"rsa",
1023,
@@ -134,7 +141,9 @@ INVALID_KEY_SIZES = [
@pytest.mark.parametrize("keytype,size,passphrase,comment", DEFAULT_KEY_PARAMS)
@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography")
def test_default_key_params(keytype, size, passphrase, comment):
def test_default_key_params(
keytype: KeyType, size: int | None, passphrase: bytes | None, comment: str | None
) -> None:
result = True
default_sizes = {
@@ -163,7 +172,9 @@ def test_default_key_params(keytype, size, passphrase, comment):
@pytest.mark.parametrize("keytype,size,passphrase,comment", VALID_USER_KEY_PARAMS)
@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography")
def test_valid_user_key_params(keytype, size, passphrase, comment):
def test_valid_user_key_params(
keytype: KeyType, size: int | None, passphrase: bytes | None, comment: str | None
) -> None:
result = True
try:
@@ -181,7 +192,9 @@ def test_valid_user_key_params(keytype, size, passphrase, comment):
@pytest.mark.parametrize("keytype,size,passphrase,comment", INVALID_USER_KEY_PARAMS)
@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography")
def test_invalid_user_key_params(keytype, size, passphrase, comment):
def test_invalid_user_key_params(
keytype: KeyType, size: int | None, passphrase: bytes | None, comment: str | None
) -> None:
result = False
try:
@@ -199,7 +212,9 @@ def test_invalid_user_key_params(keytype, size, passphrase, comment):
@pytest.mark.parametrize("keytype,size,passphrase,comment", INVALID_KEY_SIZES)
@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography")
def test_invalid_key_sizes(keytype, size, passphrase, comment):
def test_invalid_key_sizes(
keytype: KeyType, size: int | None, passphrase: bytes | None, comment: str | None
) -> None:
result = False
try:
@@ -216,7 +231,7 @@ def test_invalid_key_sizes(keytype, size, passphrase, comment):
@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography")
def test_valid_comment_update():
def test_valid_comment_update() -> None:
pair = OpensshKeypair.generate()
new_comment = "comment"
@@ -233,13 +248,13 @@ def test_valid_comment_update():
@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography")
def test_invalid_comment_update():
def test_invalid_comment_update() -> None:
result = False
pair = OpensshKeypair.generate()
new_comment = [1, 2, 3]
try:
pair.comment = new_comment
pair.comment = new_comment # type: ignore
except InvalidCommentError:
result = True
@@ -247,7 +262,7 @@ def test_invalid_comment_update():
@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography")
def test_valid_passphrase_update():
def test_valid_passphrase_update() -> None:
result = False
passphrase = "change_me".encode("UTF-8")
@@ -281,13 +296,13 @@ def test_valid_passphrase_update():
@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography")
def test_invalid_passphrase_update():
def test_invalid_passphrase_update() -> None:
result = False
passphrase = [1, 2, 3]
pair = OpensshKeypair.generate()
try:
pair.update_passphrase(passphrase)
pair.update_passphrase(passphrase) # type: ignore
except InvalidPassphraseError:
result = True
@@ -295,7 +310,7 @@ def test_invalid_passphrase_update():
@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography")
def test_invalid_privatekey():
def test_invalid_privatekey() -> None:
result = False
try:
@@ -325,7 +340,7 @@ def test_invalid_privatekey():
@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography")
def test_mismatched_keypair():
def test_mismatched_keypair() -> None:
result = False
try:
@@ -356,7 +371,7 @@ def test_mismatched_keypair():
@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography")
def test_keypair_comparison():
def test_keypair_comparison() -> None:
assert OpensshKeypair.generate() != OpensshKeypair.generate()
assert OpensshKeypair.generate() != OpensshKeypair.generate(keytype="dsa")
assert OpensshKeypair.generate() != OpensshKeypair.generate(keytype="ed25519")
@@ -366,7 +381,7 @@ def test_keypair_comparison():
try:
tmpdir = mkdtemp()
keys = {
keys: dict[str, dict[str, t.Any]] = {
"rsa": {
"pair": OpensshKeypair.generate(),
"filename": os.path.join(tmpdir, "id_rsa"),

View File

@@ -4,6 +4,8 @@
from __future__ import annotations
import typing as t
import pytest
from ansible_collections.community.crypto.plugins.module_utils.openssh.utils import (
OpensshParser,
@@ -15,36 +17,36 @@ from ansible_collections.community.crypto.plugins.module_utils.openssh.utils imp
SSH_VERSION_STRING = "OpenSSH_7.9p1, OpenSSL 1.1.0i-fips 14 Aug 2018"
SSH_VERSION_NUMBER = "7.9"
VALID_BOOLEAN = [True, False]
INVALID_BOOLEAN = [0x02]
VALID_UINT32 = [
VALID_BOOLEAN: list[bool] = [True, False]
INVALID_BOOLEAN: list[t.Any] = [0x02]
VALID_UINT32: list[int] = [
0x00,
0x01,
0x01234567,
0xFFFFFFFF,
]
INVALID_UINT32 = [
INVALID_UINT32: list[int] = [
0xFFFFFFFFF,
-1,
]
VALID_UINT64 = [
VALID_UINT64: list[int] = [
0x00,
0x01,
0x0123456789ABCDEF,
0xFFFFFFFFFFFFFFFF,
]
INVALID_UINT64 = [
INVALID_UINT64: list[int] = [
0xFFFFFFFFFFFFFFFFF,
-1,
]
VALID_STRING = [
VALID_STRING: list[bytes] = [
b"test string",
]
INVALID_STRING = [
INVALID_STRING: list[t.Any] = [
[],
]
# See https://datatracker.ietf.org/doc/html/rfc4251#section-5 for examples source
VALID_MPINT = [
VALID_MPINT: list[int] = [
0x00,
0x9A378F9B2E332A7,
0x80,
@@ -53,50 +55,50 @@ VALID_MPINT = [
# Additional large int test
0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF,
]
INVALID_MPINT = [
INVALID_MPINT: list[t.Any] = [
[],
]
def test_parse_openssh_version():
def test_parse_openssh_version() -> None:
assert parse_openssh_version(SSH_VERSION_STRING) == SSH_VERSION_NUMBER
@pytest.mark.parametrize("boolean", VALID_BOOLEAN)
def test_valid_boolean(boolean):
def test_valid_boolean(boolean: bool) -> None:
assert OpensshParser(_OpensshWriter().boolean(boolean).bytes()).boolean() == boolean
@pytest.mark.parametrize("boolean", INVALID_BOOLEAN)
def test_invalid_boolean(boolean):
def test_invalid_boolean(boolean: t.Any) -> None:
with pytest.raises(TypeError):
_OpensshWriter().boolean(boolean)
@pytest.mark.parametrize("uint32", VALID_UINT32)
def test_valid_uint32(uint32):
def test_valid_uint32(uint32: int) -> None:
assert OpensshParser(_OpensshWriter().uint32(uint32).bytes()).uint32() == uint32
@pytest.mark.parametrize("uint32", INVALID_UINT32)
def test_invalid_uint32(uint32):
def test_invalid_uint32(uint32: int) -> None:
with pytest.raises(ValueError):
_OpensshWriter().uint32(uint32)
@pytest.mark.parametrize("uint64", VALID_UINT64)
def test_valid_uint64(uint64):
def test_valid_uint64(uint64: int) -> None:
assert OpensshParser(_OpensshWriter().uint64(uint64).bytes()).uint64() == uint64
@pytest.mark.parametrize("uint64", INVALID_UINT64)
def test_invalid_uint64(uint64):
def test_invalid_uint64(uint64: int) -> None:
with pytest.raises(ValueError):
_OpensshWriter().uint64(uint64)
@pytest.mark.parametrize("ssh_string", VALID_STRING)
def test_valid_string(ssh_string):
def test_valid_string(ssh_string: bytes) -> None:
assert (
OpensshParser(_OpensshWriter().string(ssh_string).bytes()).string()
== ssh_string
@@ -104,23 +106,23 @@ def test_valid_string(ssh_string):
@pytest.mark.parametrize("ssh_string", INVALID_STRING)
def test_invalid_string(ssh_string):
def test_invalid_string(ssh_string: t.Any) -> None:
with pytest.raises(TypeError):
_OpensshWriter().string(ssh_string)
@pytest.mark.parametrize("mpint", VALID_MPINT)
def test_valid_mpint(mpint):
def test_valid_mpint(mpint: int) -> None:
assert OpensshParser(_OpensshWriter().mpint(mpint).bytes()).mpint() == mpint
@pytest.mark.parametrize("mpint", INVALID_MPINT)
def test_invalid_mpint(mpint):
def test_invalid_mpint(mpint: t.Any) -> None:
with pytest.raises(TypeError):
_OpensshWriter().mpint(mpint)
def test_valid_seek():
def test_valid_seek() -> None:
buffer = bytearray(b"buffer")
parser = OpensshParser(buffer)
parser.seek(len(buffer))
@@ -129,7 +131,7 @@ def test_valid_seek():
assert parser.remaining_bytes() == len(buffer)
def test_invalid_seek():
def test_invalid_seek() -> None:
buffer = b"buffer"
parser = OpensshParser(buffer)
@@ -140,6 +142,6 @@ def test_invalid_seek():
parser.seek(-1)
def test_writer_bytes():
def test_writer_bytes() -> None:
buffer = bytearray(b"buffer")
assert _OpensshWriter(buffer).bytes() == buffer

View File

@@ -5,9 +5,9 @@
from __future__ import annotations
import datetime
import typing as t
import pytest
from ansible.module_utils.common.collections import is_sequence
from ansible_collections.community.crypto.plugins.module_utils.time import (
UTC,
add_or_remove_timezone,
@@ -30,25 +30,27 @@ TIMEZONES = [
]
def cartesian_product(list1, list2):
result = []
if t.TYPE_CHECKING:
_S = t.TypeVar("_S")
_Ts = t.TypeVarTuple("_Ts")
def cartesian_product(
list1: list[_S], list2: "list[tuple[*_Ts]]"
) -> "list[tuple[_S, *_Ts]]":
result: "list[tuple[_S, *_Ts]]" = []
for item1 in list1:
if not is_sequence(item1):
item1 = (item1,)
elif not isinstance(item1, tuple):
item1 = tuple(item1)
item1_tuple = (item1,)
for item2 in list2:
if not is_sequence(item2):
item2 = (item2,)
elif not isinstance(item2, tuple):
item2 = tuple(item2)
result.append(item1 + item2)
result.append(item1_tuple + item2)
return result
ONE_HOUR_PLUS = datetime.timezone(datetime.timedelta(hours=1))
TEST_REMOVE_TIMEZONE = cartesian_product(
TEST_REMOVE_TIMEZONE: list[
tuple[datetime.timedelta, datetime.datetime, datetime.datetime]
] = cartesian_product(
TIMEZONES,
[
(
@@ -66,7 +68,9 @@ TEST_REMOVE_TIMEZONE = cartesian_product(
],
)
TEST_UTC_TIMEZONE = cartesian_product(
TEST_UTC_TIMEZONE: list[
tuple[datetime.timedelta, datetime.datetime, datetime.datetime]
] = cartesian_product(
TIMEZONES,
[
(
@@ -84,48 +88,67 @@ TEST_UTC_TIMEZONE = cartesian_product(
],
)
TEST_EPOCH_SECONDS = cartesian_product(
TIMEZONES,
[
(0, dict(year=1970, day=1, month=1, hour=0, minute=0, second=0, microsecond=0)),
(
1e-6,
dict(year=1970, day=1, month=1, hour=0, minute=0, second=0, microsecond=1),
),
(
1e-3,
dict(
year=1970, day=1, month=1, hour=0, minute=0, second=0, microsecond=1000
TEST_EPOCH_SECONDS: list[tuple[datetime.timedelta, float, dict[str, int]]] = (
cartesian_product(
TIMEZONES,
[
(
0,
dict(
year=1970, day=1, month=1, hour=0, minute=0, second=0, microsecond=0
),
),
),
(
3691.2,
dict(
year=1970,
day=1,
month=1,
hour=1,
minute=1,
second=31,
microsecond=200000,
(
1e-6,
dict(
year=1970, day=1, month=1, hour=0, minute=0, second=0, microsecond=1
),
),
),
],
(
1e-3,
dict(
year=1970,
day=1,
month=1,
hour=0,
minute=0,
second=0,
microsecond=1000,
),
),
(
3691.2,
dict(
year=1970,
day=1,
month=1,
hour=1,
minute=1,
second=31,
microsecond=200000,
),
),
],
)
)
TEST_EPOCH_TO_SECONDS = cartesian_product(
TIMEZONES,
[
(datetime.datetime(1970, 1, 1, 0, 1, 2, 0), 62),
(datetime.datetime(1970, 1, 1, 0, 1, 2, 0, tzinfo=UTC), 62),
(
datetime.datetime(1970, 1, 1, 0, 1, 2, 0, tzinfo=ONE_HOUR_PLUS),
62 - 3600,
),
],
TEST_EPOCH_TO_SECONDS: list[tuple[datetime.timedelta, datetime.datetime, int]] = (
cartesian_product(
TIMEZONES,
[
(datetime.datetime(1970, 1, 1, 0, 1, 2, 0), 62),
(datetime.datetime(1970, 1, 1, 0, 1, 2, 0, tzinfo=UTC), 62),
(
datetime.datetime(1970, 1, 1, 0, 1, 2, 0, tzinfo=ONE_HOUR_PLUS),
62 - 3600,
),
],
)
)
TEST_CONVERT_RELATIVE_TO_DATETIME = cartesian_product(
TEST_CONVERT_RELATIVE_TO_DATETIME: list[
tuple[datetime.timedelta, str, bool, datetime.datetime, datetime.datetime]
] = cartesian_product(
TIMEZONES,
[
(
@@ -167,7 +190,9 @@ TEST_CONVERT_RELATIVE_TO_DATETIME = cartesian_product(
],
)
TEST_GET_RELATIVE_TIME_OPTION = cartesian_product(
TEST_GET_RELATIVE_TIME_OPTION: list[
tuple[datetime.timedelta, str, str, bool, datetime.datetime, datetime.datetime]
] = cartesian_product(
TIMEZONES,
[
(
@@ -259,7 +284,9 @@ TEST_GET_RELATIVE_TIME_OPTION = cartesian_product(
@pytest.mark.parametrize("timezone, input, expected", TEST_REMOVE_TIMEZONE)
def test_remove_timezone(timezone, input, expected):
def test_remove_timezone(
timezone: datetime.timedelta, input: datetime.datetime, expected: datetime.datetime
) -> None:
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
output_1 = remove_timezone(input)
assert expected == output_1
@@ -268,7 +295,9 @@ def test_remove_timezone(timezone, input, expected):
@pytest.mark.parametrize("timezone, input, expected", TEST_UTC_TIMEZONE)
def test_utc_timezone(timezone, input, expected):
def test_utc_timezone(
timezone: datetime.timedelta, input: datetime.datetime, expected: datetime.datetime
) -> None:
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
output_1 = ensure_utc_timezone(input)
assert expected == output_1
@@ -280,7 +309,7 @@ def test_utc_timezone(timezone, input, expected):
# Due to a bug in freezegun (https://github.com/spulec/freezegun/issues/348, https://github.com/spulec/freezegun/issues/553)
# this only works with timezone = UTC
@pytest.mark.parametrize("timezone", [datetime.timedelta(hours=0)])
def test_get_now_datetime_w_timezone(timezone):
def test_get_now_datetime_w_timezone(timezone: datetime.timedelta) -> None:
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
output_2 = get_now_datetime(with_timezone=True)
assert output_2.tzinfo is not None
@@ -289,7 +318,7 @@ def test_get_now_datetime_w_timezone(timezone):
@pytest.mark.parametrize("timezone", TIMEZONES)
def test_get_now_datetime_wo_timezone(timezone):
def test_get_now_datetime_wo_timezone(timezone: datetime.timedelta) -> None:
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
output_1 = get_now_datetime(with_timezone=False)
assert output_1.tzinfo is None
@@ -297,13 +326,15 @@ def test_get_now_datetime_wo_timezone(timezone):
@pytest.mark.parametrize("timezone, seconds, timestamp", TEST_EPOCH_SECONDS)
def test_epoch_seconds(timezone, seconds, timestamp):
def test_epoch_seconds(
timezone: datetime.timedelta, seconds: float, timestamp: dict[str, int]
) -> None:
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
ts_wo_tz = datetime.datetime(**timestamp)
ts_wo_tz: datetime.datetime = datetime.datetime(**timestamp) # type: ignore
assert seconds == get_epoch_seconds(ts_wo_tz)
timestamp_w_tz = dict(timestamp)
timestamp_w_tz: dict[str, t.Any] = dict(timestamp)
timestamp_w_tz["tzinfo"] = UTC
ts_w_tz = datetime.datetime(**timestamp_w_tz)
ts_w_tz: datetime.datetime = datetime.datetime(**timestamp_w_tz) # type: ignore
assert seconds == get_epoch_seconds(ts_w_tz)
output_1 = from_epoch_seconds(seconds, with_timezone=False)
assert ts_wo_tz == output_1
@@ -312,7 +343,9 @@ def test_epoch_seconds(timezone, seconds, timestamp):
@pytest.mark.parametrize("timezone, timestamp, expected_seconds", TEST_EPOCH_TO_SECONDS)
def test_epoch_to_seconds(timezone, timestamp, expected_seconds):
def test_epoch_to_seconds(
timezone: datetime.timedelta, timestamp: datetime.datetime, expected_seconds: int
) -> None:
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
assert expected_seconds == get_epoch_seconds(timestamp)
@@ -322,8 +355,12 @@ def test_epoch_to_seconds(timezone, timestamp, expected_seconds):
TEST_CONVERT_RELATIVE_TO_DATETIME,
)
def test_convert_relative_to_datetime(
timezone, relative_time_string, with_timezone, now, expected
):
timezone: datetime.timedelta,
relative_time_string: str,
with_timezone: bool,
now: datetime.datetime,
expected: datetime.datetime,
) -> None:
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
output = convert_relative_to_datetime(
relative_time_string, with_timezone=with_timezone, now=now
@@ -336,8 +373,13 @@ def test_convert_relative_to_datetime(
TEST_GET_RELATIVE_TIME_OPTION,
)
def test_get_relative_time_option(
timezone, input_string, input_name, with_timezone, now, expected
):
timezone: datetime.timedelta,
input_string: str,
input_name: str,
with_timezone: bool,
now: datetime.datetime,
expected: datetime.datetime,
) -> None:
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
output = get_relative_time_option(
input_string,