mirror of
https://github.com/ansible-collections/community.crypto.git
synced 2026-05-06 21:33:00 +00:00
Add type hints and type checking (#885)
* Enable basic type checking. * Fix first errors. * Add changelog fragment. * Add types to module_utils and plugin_utils (without module backends). * Add typing hints for acme_* modules. * Add typing to X.509 certificate modules, and add more helpers. * Add typing to remaining module backends. * Add typing for action, filter, and lookup plugins. * Bump ansible-core 2.19 beta requirement for typing. * Add more typing definitions. * Add typing to some unit tests.
This commit is contained in:
@@ -7,6 +7,7 @@ from __future__ import annotations
|
||||
import base64
|
||||
import datetime
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.backends import (
|
||||
CertificateInformation,
|
||||
@@ -19,12 +20,20 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
|
||||
from ..test_time import TIMEZONES, cartesian_product
|
||||
|
||||
|
||||
def load_fixture(name):
|
||||
with open(os.path.join(os.path.dirname(__file__), "fixtures", name)) as f:
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.backends import (
|
||||
Criterium,
|
||||
)
|
||||
|
||||
|
||||
def load_fixture(name: str) -> str:
|
||||
with open(
|
||||
os.path.join(os.path.dirname(__file__), "fixtures", name), encoding="utf-8"
|
||||
) as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
TEST_PEM_DERS = [
|
||||
TEST_PEM_DERS: list[tuple[str, bytes]] = [
|
||||
(
|
||||
load_fixture("privatekey_1.pem"),
|
||||
base64.b64decode(
|
||||
@@ -36,7 +45,7 @@ TEST_PEM_DERS = [
|
||||
]
|
||||
|
||||
|
||||
TEST_KEYS = [
|
||||
TEST_KEYS: list[tuple[str, dict[str, t.Any], str]] = [
|
||||
(
|
||||
load_fixture("privatekey_1.pem"),
|
||||
{
|
||||
@@ -56,7 +65,7 @@ TEST_KEYS = [
|
||||
]
|
||||
|
||||
|
||||
TEST_CSRS = [
|
||||
TEST_CSRS: list[tuple[str, set[tuple[str, str]], str]] = [
|
||||
(
|
||||
load_fixture("csr_1.pem"),
|
||||
set([("dns", "ansible.com"), ("dns", "example.com"), ("dns", "example.org")]),
|
||||
@@ -87,17 +96,19 @@ TEST_CERT_OPENSSL_OUTPUT_2 = load_fixture("cert_2.txt") # OpenSSL 3.3.0 output
|
||||
TEST_CERT_OPENSSL_OUTPUT_2B = load_fixture("cert_2-b.txt") # OpenSSL 1.1.1f output
|
||||
|
||||
|
||||
TEST_CERT_DAYS = cartesian_product(
|
||||
TIMEZONES,
|
||||
[
|
||||
(datetime.datetime(2018, 11, 15, 1, 2, 3), 11),
|
||||
(datetime.datetime(2018, 11, 25, 15, 20, 0), 1),
|
||||
(datetime.datetime(2018, 11, 25, 15, 30, 0), 0),
|
||||
],
|
||||
TEST_CERT_DAYS: list[tuple[datetime.timedelta, datetime.datetime, int]] = (
|
||||
cartesian_product(
|
||||
TIMEZONES,
|
||||
[
|
||||
(datetime.datetime(2018, 11, 15, 1, 2, 3), 11),
|
||||
(datetime.datetime(2018, 11, 25, 15, 20, 0), 1),
|
||||
(datetime.datetime(2018, 11, 25, 15, 30, 0), 0),
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
TEST_CERT_INFO = CertificateInformation(
|
||||
TEST_CERT_INFO_1 = CertificateInformation(
|
||||
not_valid_after=datetime.datetime(2018, 11, 26, 15, 28, 24),
|
||||
not_valid_before=datetime.datetime(2018, 11, 25, 15, 28, 23),
|
||||
serial_number=1,
|
||||
@@ -115,65 +126,69 @@ TEST_CERT_INFO_2 = CertificateInformation(
|
||||
)
|
||||
|
||||
|
||||
TEST_CERT_INFO = [
|
||||
(TEST_CERT, TEST_CERT_INFO, TEST_CERT_OPENSSL_OUTPUT),
|
||||
TEST_CERT_INFO: list[tuple[str, CertificateInformation, str]] = [
|
||||
(TEST_CERT, TEST_CERT_INFO_1, TEST_CERT_OPENSSL_OUTPUT),
|
||||
(TEST_CERT_2, TEST_CERT_INFO_2, TEST_CERT_OPENSSL_OUTPUT_2),
|
||||
(TEST_CERT_2, TEST_CERT_INFO_2, TEST_CERT_OPENSSL_OUTPUT_2B),
|
||||
]
|
||||
|
||||
|
||||
TEST_PARSE_ACME_TIMESTAMP = cartesian_product(
|
||||
TIMEZONES,
|
||||
[
|
||||
(
|
||||
"2024-01-01T00:11:22Z",
|
||||
dict(year=2024, month=1, day=1, hour=0, minute=11, second=22),
|
||||
),
|
||||
(
|
||||
"2024-01-01T00:11:22.123Z",
|
||||
dict(
|
||||
year=2024,
|
||||
month=1,
|
||||
day=1,
|
||||
hour=0,
|
||||
minute=11,
|
||||
second=22,
|
||||
microsecond=123000,
|
||||
TEST_PARSE_ACME_TIMESTAMP: list[tuple[datetime.timedelta, str, dict[str, int]]] = (
|
||||
cartesian_product(
|
||||
TIMEZONES,
|
||||
[
|
||||
(
|
||||
"2024-01-01T00:11:22Z",
|
||||
dict(year=2024, month=1, day=1, hour=0, minute=11, second=22),
|
||||
),
|
||||
),
|
||||
(
|
||||
"2024-04-17T06:54:13.333333334Z",
|
||||
dict(
|
||||
year=2024,
|
||||
month=4,
|
||||
day=17,
|
||||
hour=6,
|
||||
minute=54,
|
||||
second=13,
|
||||
microsecond=333333,
|
||||
(
|
||||
"2024-01-01T00:11:22.123Z",
|
||||
dict(
|
||||
year=2024,
|
||||
month=1,
|
||||
day=1,
|
||||
hour=0,
|
||||
minute=11,
|
||||
second=22,
|
||||
microsecond=123000,
|
||||
),
|
||||
),
|
||||
),
|
||||
(
|
||||
"2024-01-01T00:11:22+0100",
|
||||
dict(year=2023, month=12, day=31, hour=23, minute=11, second=22),
|
||||
),
|
||||
(
|
||||
"2024-01-01T00:11:22.123+0100",
|
||||
dict(
|
||||
year=2023,
|
||||
month=12,
|
||||
day=31,
|
||||
hour=23,
|
||||
minute=11,
|
||||
second=22,
|
||||
microsecond=123000,
|
||||
(
|
||||
"2024-04-17T06:54:13.333333334Z",
|
||||
dict(
|
||||
year=2024,
|
||||
month=4,
|
||||
day=17,
|
||||
hour=6,
|
||||
minute=54,
|
||||
second=13,
|
||||
microsecond=333333,
|
||||
),
|
||||
),
|
||||
),
|
||||
],
|
||||
(
|
||||
"2024-01-01T00:11:22+0100",
|
||||
dict(year=2023, month=12, day=31, hour=23, minute=11, second=22),
|
||||
),
|
||||
(
|
||||
"2024-01-01T00:11:22.123+0100",
|
||||
dict(
|
||||
year=2023,
|
||||
month=12,
|
||||
day=31,
|
||||
hour=23,
|
||||
minute=11,
|
||||
second=22,
|
||||
microsecond=123000,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
TEST_INTERPOLATE_TIMESTAMP = cartesian_product(
|
||||
TEST_INTERPOLATE_TIMESTAMP: list[
|
||||
tuple[datetime.timedelta, dict[str, int], dict[str, int], float, dict[str, int]]
|
||||
] = cartesian_product(
|
||||
TIMEZONES,
|
||||
[
|
||||
(
|
||||
@@ -199,26 +214,50 @@ TEST_INTERPOLATE_TIMESTAMP = cartesian_product(
|
||||
|
||||
|
||||
class FakeBackend(CryptoBackend):
|
||||
def parse_key(self, key_file=None, key_content=None, passphrase=None):
|
||||
def parse_key(
|
||||
self,
|
||||
key_file: str | os.PathLike | None = None,
|
||||
key_content: str | None = None,
|
||||
passphrase=None,
|
||||
) -> t.NoReturn:
|
||||
raise BackendException("Not implemented in fake backend")
|
||||
|
||||
def sign(self, payload64, protected64, key_data):
|
||||
def sign(
|
||||
self, payload64: str, protected64: str, key_data: dict[str, t.Any] | None
|
||||
) -> t.NoReturn:
|
||||
raise BackendException("Not implemented in fake backend")
|
||||
|
||||
def create_mac_key(self, alg, key):
|
||||
def create_mac_key(self, alg: str, key: str) -> t.NoReturn:
|
||||
raise BackendException("Not implemented in fake backend")
|
||||
|
||||
def get_ordered_csr_identifiers(self, csr_filename=None, csr_content=None):
|
||||
def get_ordered_csr_identifiers(
|
||||
self,
|
||||
csr_filename: str | os.PathLike | None = None,
|
||||
csr_content: str | bytes | None = None,
|
||||
) -> t.NoReturn:
|
||||
raise BackendException("Not implemented in fake backend")
|
||||
|
||||
def get_csr_identifiers(self, csr_filename=None, csr_content=None):
|
||||
def get_csr_identifiers(
|
||||
self,
|
||||
csr_filename: str | os.PathLike | None = None,
|
||||
csr_content: str | bytes | None = None,
|
||||
) -> t.NoReturn:
|
||||
raise BackendException("Not implemented in fake backend")
|
||||
|
||||
def get_cert_days(self, cert_filename=None, cert_content=None, now=None):
|
||||
def get_cert_days(
|
||||
self,
|
||||
cert_filename: str | os.PathLike | None = None,
|
||||
cert_content: str | bytes | None = None,
|
||||
now: datetime.datetime | None = None,
|
||||
) -> t.NoReturn:
|
||||
raise BackendException("Not implemented in fake backend")
|
||||
|
||||
def create_chain_matcher(self, criterium):
|
||||
def create_chain_matcher(self, criterium: Criterium) -> t.NoReturn:
|
||||
raise BackendException("Not implemented in fake backend")
|
||||
|
||||
def get_cert_information(self, cert_filename=None, cert_content=None):
|
||||
def get_cert_information(
|
||||
self,
|
||||
cert_filename: str | os.PathLike | None = None,
|
||||
cert_content: str | bytes | None = None,
|
||||
) -> t.NoReturn:
|
||||
raise BackendException("Not implemented in fake backend")
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import typing as t
|
||||
from unittest.mock import (
|
||||
MagicMock,
|
||||
)
|
||||
@@ -35,12 +36,20 @@ from .backend_data import (
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.backends import (
|
||||
CertificateInformation,
|
||||
)
|
||||
|
||||
|
||||
if not HAS_CURRENT_CRYPTOGRAPHY:
|
||||
pytest.skip("cryptography not found")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pem, result, dummy", TEST_KEYS)
|
||||
def test_eckeyparse_cryptography(pem, result, dummy, tmpdir):
|
||||
def test_eckeyparse_cryptography(
|
||||
pem: str, result: dict[str, t.Any], dummy: str, tmpdir
|
||||
) -> None:
|
||||
fn = tmpdir / "test.pem"
|
||||
fn.write(pem)
|
||||
module = MagicMock()
|
||||
@@ -54,7 +63,9 @@ def test_eckeyparse_cryptography(pem, result, dummy, tmpdir):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("csr, result, openssl_output", TEST_CSRS)
|
||||
def test_csridentifiers_cryptography(csr, result, openssl_output, tmpdir):
|
||||
def test_csridentifiers_cryptography(
|
||||
csr: str, result: set[tuple[str, str]], openssl_output: str, tmpdir
|
||||
) -> None:
|
||||
fn = tmpdir / "test.csr"
|
||||
fn.write(csr)
|
||||
module = MagicMock()
|
||||
@@ -66,7 +77,9 @@ def test_csridentifiers_cryptography(csr, result, openssl_output, tmpdir):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("timezone, now, expected_days", TEST_CERT_DAYS)
|
||||
def test_certdays_cryptography(timezone, now, expected_days, tmpdir):
|
||||
def test_certdays_cryptography(
|
||||
timezone: datetime.timedelta, now: datetime.datetime, expected_days: int, tmpdir
|
||||
) -> None:
|
||||
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
|
||||
fn = tmpdir / "test-cert.pem"
|
||||
fn.write(TEST_CERT)
|
||||
@@ -81,7 +94,12 @@ def test_certdays_cryptography(timezone, now, expected_days, tmpdir):
|
||||
@pytest.mark.parametrize(
|
||||
"cert_content, expected_cert_info, openssl_output", TEST_CERT_INFO
|
||||
)
|
||||
def test_get_cert_information(cert_content, expected_cert_info, openssl_output, tmpdir):
|
||||
def test_get_cert_information(
|
||||
cert_content: str,
|
||||
expected_cert_info: CertificateInformation,
|
||||
openssl_output: str,
|
||||
tmpdir,
|
||||
) -> None:
|
||||
fn = tmpdir / "test-cert.pem"
|
||||
fn.write(cert_content)
|
||||
module = MagicMock()
|
||||
@@ -105,7 +123,7 @@ def test_get_cert_information(cert_content, expected_cert_info, openssl_output,
|
||||
@pytest.mark.parametrize(
|
||||
"timezone", [datetime.timedelta(hours=0)] if CRYPTOGRAPHY_TIMEZONE else TIMEZONES
|
||||
)
|
||||
def test_now(timezone):
|
||||
def test_now(timezone: datetime.timedelta) -> None:
|
||||
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
|
||||
module = MagicMock()
|
||||
backend = CryptographyBackend(module)
|
||||
@@ -119,7 +137,9 @@ def test_now(timezone):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("timezone, input, expected", TEST_PARSE_ACME_TIMESTAMP)
|
||||
def test_parse_acme_timestamp(timezone, input, expected):
|
||||
def test_parse_acme_timestamp(
|
||||
timezone: datetime.timedelta, input: str, expected: dict[str, int]
|
||||
) -> None:
|
||||
with freeze_time("2024-02-03 04:05:06 +00:00", tz_offset=timezone):
|
||||
module = MagicMock()
|
||||
backend = CryptographyBackend(module)
|
||||
@@ -131,7 +151,13 @@ def test_parse_acme_timestamp(timezone, input, expected):
|
||||
@pytest.mark.parametrize(
|
||||
"timezone, start, end, percentage, expected", TEST_INTERPOLATE_TIMESTAMP
|
||||
)
|
||||
def test_interpolate_timestamp(timezone, start, end, percentage, expected):
|
||||
def test_interpolate_timestamp(
|
||||
timezone: datetime.timedelta,
|
||||
start: dict[str, int],
|
||||
end: dict[str, int],
|
||||
percentage: float,
|
||||
expected: dict[str, int],
|
||||
) -> None:
|
||||
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
|
||||
module = MagicMock()
|
||||
backend = CryptographyBackend(module)
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import typing as t
|
||||
from unittest.mock import (
|
||||
MagicMock,
|
||||
)
|
||||
@@ -31,6 +32,12 @@ from .backend_data import (
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.backends import (
|
||||
CertificateInformation,
|
||||
)
|
||||
|
||||
|
||||
# from ..test_time import TIMEZONES
|
||||
|
||||
|
||||
@@ -47,7 +54,9 @@ TEST_IPS = [
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pem, result, openssl_output", TEST_KEYS)
|
||||
def test_eckeyparse_openssl(pem, result, openssl_output, tmpdir):
|
||||
def test_eckeyparse_openssl(
|
||||
pem: str, result: dict[str, t.Any], openssl_output: str, tmpdir
|
||||
) -> None:
|
||||
fn = tmpdir / "test.key"
|
||||
fn.write(pem)
|
||||
module = MagicMock()
|
||||
@@ -59,7 +68,9 @@ def test_eckeyparse_openssl(pem, result, openssl_output, tmpdir):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("csr, result, openssl_output", TEST_CSRS)
|
||||
def test_csridentifiers_openssl(csr, result, openssl_output, tmpdir):
|
||||
def test_csridentifiers_openssl(
|
||||
csr: str, result: set[tuple[str, str]], openssl_output: str, tmpdir
|
||||
) -> None:
|
||||
fn = tmpdir / "test.csr"
|
||||
fn.write(csr)
|
||||
module = MagicMock()
|
||||
@@ -70,14 +81,16 @@ def test_csridentifiers_openssl(csr, result, openssl_output, tmpdir):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ip, result", TEST_IPS)
|
||||
def test_normalize_ip(ip, result):
|
||||
def test_normalize_ip(ip: str, result: str) -> None:
|
||||
module = MagicMock()
|
||||
backend = OpenSSLCLIBackend(module, openssl_binary="openssl")
|
||||
assert backend._normalize_ip(ip) == result
|
||||
|
||||
|
||||
@pytest.mark.parametrize("timezone, now, expected_days", TEST_CERT_DAYS)
|
||||
def test_certdays_cryptography(timezone, now, expected_days, tmpdir):
|
||||
def test_certdays_cryptography(
|
||||
timezone: datetime.timedelta, now: datetime.datetime, expected_days: int, tmpdir
|
||||
) -> None:
|
||||
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
|
||||
fn = tmpdir / "test-cert.pem"
|
||||
fn.write(TEST_CERT)
|
||||
@@ -93,7 +106,12 @@ def test_certdays_cryptography(timezone, now, expected_days, tmpdir):
|
||||
@pytest.mark.parametrize(
|
||||
"cert_content, expected_cert_info, openssl_output", TEST_CERT_INFO
|
||||
)
|
||||
def test_get_cert_information(cert_content, expected_cert_info, openssl_output, tmpdir):
|
||||
def test_get_cert_information(
|
||||
cert_content: str,
|
||||
expected_cert_info: CertificateInformation,
|
||||
openssl_output: str,
|
||||
tmpdir,
|
||||
) -> None:
|
||||
fn = tmpdir / "test-cert.pem"
|
||||
fn.write(cert_content)
|
||||
module = MagicMock()
|
||||
@@ -115,7 +133,7 @@ def test_get_cert_information(cert_content, expected_cert_info, openssl_output,
|
||||
# Due to a bug in freezegun (https://github.com/spulec/freezegun/issues/348, https://github.com/spulec/freezegun/issues/553)
|
||||
# this only works with timezone = UTC if CRYPTOGRAPHY_TIMEZONE is truish
|
||||
@pytest.mark.parametrize("timezone", [datetime.timedelta(hours=0)])
|
||||
def test_now(timezone):
|
||||
def test_now(timezone: datetime.timedelta) -> None:
|
||||
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
|
||||
module = MagicMock()
|
||||
backend = OpenSSLCLIBackend(module, openssl_binary="openssl")
|
||||
@@ -125,7 +143,9 @@ def test_now(timezone):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("timezone, input, expected", TEST_PARSE_ACME_TIMESTAMP)
|
||||
def test_parse_acme_timestamp(timezone, input, expected):
|
||||
def test_parse_acme_timestamp(
|
||||
timezone: datetime.timedelta, input: str, expected: dict[str, int]
|
||||
) -> None:
|
||||
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
|
||||
module = MagicMock()
|
||||
backend = OpenSSLCLIBackend(module, openssl_binary="openssl")
|
||||
@@ -137,7 +157,13 @@ def test_parse_acme_timestamp(timezone, input, expected):
|
||||
@pytest.mark.parametrize(
|
||||
"timezone, start, end, percentage, expected", TEST_INTERPOLATE_TIMESTAMP
|
||||
)
|
||||
def test_interpolate_timestamp(timezone, start, end, percentage, expected):
|
||||
def test_interpolate_timestamp(
|
||||
timezone: datetime.timedelta,
|
||||
start: dict[str, int],
|
||||
end: dict[str, int],
|
||||
percentage: float,
|
||||
expected: dict[str, int],
|
||||
) -> None:
|
||||
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
|
||||
module = MagicMock()
|
||||
backend = OpenSSLCLIBackend(module, openssl_binary="openssl")
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
from unittest.mock import (
|
||||
MagicMock,
|
||||
)
|
||||
@@ -21,21 +22,21 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
|
||||
)
|
||||
|
||||
|
||||
def test_combine_identifier():
|
||||
def test_combine_identifier() -> None:
|
||||
assert combine_identifier("", "") == ":"
|
||||
assert combine_identifier("a", "b") == "a:b"
|
||||
|
||||
|
||||
def test_split_identifier():
|
||||
assert split_identifier(":") == ["", ""]
|
||||
assert split_identifier("a:b") == ["a", "b"]
|
||||
assert split_identifier("a:b:c") == ["a", "b:c"]
|
||||
def test_split_identifier() -> None:
|
||||
assert split_identifier(":") == ("", "")
|
||||
assert split_identifier("a:b") == ("a", "b")
|
||||
assert split_identifier("a:b:c") == ("a", "b:c")
|
||||
with pytest.raises(ModuleFailException) as exc:
|
||||
split_identifier("a")
|
||||
assert exc.value.msg == 'Identifier "a" is not of the form <type>:<identifier>'
|
||||
|
||||
|
||||
def test_challenge_from_to_json():
|
||||
def test_challenge_from_to_json() -> None:
|
||||
client = MagicMock()
|
||||
|
||||
data = {
|
||||
@@ -57,7 +58,7 @@ def test_challenge_from_to_json():
|
||||
"status": "valid",
|
||||
"token": "foo",
|
||||
}
|
||||
challenge = Challenge.from_json(None, data, url="xxx")
|
||||
challenge = Challenge.from_json(None, data, url="xxx") # type: ignore
|
||||
assert challenge.data == data
|
||||
assert challenge.type == "type"
|
||||
assert challenge.url == "xxx"
|
||||
@@ -66,10 +67,12 @@ def test_challenge_from_to_json():
|
||||
assert challenge.to_json() == data
|
||||
|
||||
|
||||
def test_authorization_from_to_json():
|
||||
def test_authorization_from_to_json() -> None:
|
||||
client = MagicMock()
|
||||
client.version = 2
|
||||
|
||||
data: dict[str, t.Any]
|
||||
|
||||
data = {
|
||||
"challenges": [],
|
||||
"status": "valid",
|
||||
@@ -138,7 +141,7 @@ def test_authorization_from_to_json():
|
||||
}
|
||||
|
||||
|
||||
def test_authorization_create_error():
|
||||
def test_authorization_create_error() -> None:
|
||||
client = MagicMock()
|
||||
client.version = 2
|
||||
client.directory.directory = {}
|
||||
@@ -148,7 +151,7 @@ def test_authorization_create_error():
|
||||
assert exc.value.msg == "ACME endpoint does not support pre-authorization."
|
||||
|
||||
|
||||
def test_wait_for_validation_error():
|
||||
def test_wait_for_validation_error() -> None:
|
||||
client = MagicMock()
|
||||
client.version = 2
|
||||
data = {
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
from unittest.mock import (
|
||||
MagicMock,
|
||||
)
|
||||
@@ -15,7 +16,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
|
||||
)
|
||||
|
||||
|
||||
TEST_FORMAT_ERROR_PROBLEM = [
|
||||
TEST_FORMAT_ERROR_PROBLEM: list[tuple[dict[str, t.Any], str, str]] = [
|
||||
(
|
||||
{
|
||||
"type": "foo",
|
||||
@@ -90,33 +91,37 @@ TEST_FORMAT_ERROR_PROBLEM = [
|
||||
@pytest.mark.parametrize(
|
||||
"problem, subproblem_prefix, result", TEST_FORMAT_ERROR_PROBLEM
|
||||
)
|
||||
def test_format_error_problem(problem, subproblem_prefix, result):
|
||||
def test_format_error_problem(
|
||||
problem: dict[str, t.Any], subproblem_prefix: str, result: str
|
||||
) -> None:
|
||||
res = format_error_problem(problem, subproblem_prefix)
|
||||
assert res == result
|
||||
|
||||
|
||||
def create_regular_response(response_text):
|
||||
def create_regular_response(response_text: str) -> MagicMock:
|
||||
response = MagicMock()
|
||||
response.read = MagicMock(return_value=response_text.encode("utf-8"))
|
||||
response.closed = False
|
||||
return response
|
||||
|
||||
|
||||
def create_error_response():
|
||||
def create_error_response() -> MagicMock:
|
||||
response = MagicMock()
|
||||
response.read = MagicMock(side_effect=AttributeError("read"))
|
||||
response.closed = True
|
||||
return response
|
||||
|
||||
|
||||
def create_decode_error(msg):
|
||||
def f(content):
|
||||
def create_decode_error(msg: str) -> t.Callable[[t.Any], t.Any]:
|
||||
def f(content: t.Any) -> t.NoReturn:
|
||||
raise Exception(msg)
|
||||
|
||||
return f
|
||||
|
||||
|
||||
TEST_ACME_PROTOCOL_EXCEPTION = [
|
||||
TEST_ACME_PROTOCOL_EXCEPTION: list[
|
||||
tuple[dict[str, t.Any], t.Callable[[t.Any], t.Any] | None, str, dict[str, t.Any]]
|
||||
] = [
|
||||
(
|
||||
{},
|
||||
None,
|
||||
@@ -341,14 +346,19 @@ TEST_ACME_PROTOCOL_EXCEPTION = [
|
||||
|
||||
|
||||
@pytest.mark.parametrize("input, from_json, msg, args", TEST_ACME_PROTOCOL_EXCEPTION)
|
||||
def test_acme_protocol_exception(input, from_json, msg, args):
|
||||
def test_acme_protocol_exception(
|
||||
input: dict[str, t.Any],
|
||||
from_json: t.Callable[[t.Any], t.NoReturn] | None,
|
||||
msg: str,
|
||||
args: dict[str, t.Any],
|
||||
) -> None:
|
||||
if from_json is None:
|
||||
module = None
|
||||
else:
|
||||
module = MagicMock()
|
||||
module.from_json = from_json
|
||||
with pytest.raises(ACMEProtocolException) as exc:
|
||||
raise ACMEProtocolException(module, **input)
|
||||
raise ACMEProtocolException(module, **input) # type: ignore
|
||||
|
||||
print(exc.value.msg)
|
||||
print(exc.value.module_fail_args)
|
||||
|
||||
@@ -18,14 +18,13 @@ TEST_TEXT = r"""1234
|
||||
5678"""
|
||||
|
||||
|
||||
def test_read_file(tmpdir):
|
||||
def test_read_file(tmpdir) -> None:
|
||||
fn = tmpdir / "test.txt"
|
||||
fn.write(TEST_TEXT)
|
||||
assert read_file(str(fn), "t") == TEST_TEXT
|
||||
assert read_file(str(fn), "b") == TEST_TEXT.encode("utf-8")
|
||||
assert read_file(str(fn)) == TEST_TEXT.encode("utf-8")
|
||||
|
||||
|
||||
def test_write_file(tmpdir):
|
||||
def test_write_file(tmpdir) -> None:
|
||||
fn = tmpdir / "test.txt"
|
||||
module = MagicMock()
|
||||
write_file(module, str(fn), TEST_TEXT.encode("utf-8"))
|
||||
|
||||
@@ -15,7 +15,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.orders import Order
|
||||
|
||||
|
||||
def test_order_from_json():
|
||||
def test_order_from_json() -> None:
|
||||
client = MagicMock()
|
||||
|
||||
data = {
|
||||
@@ -35,7 +35,7 @@ def test_order_from_json():
|
||||
assert order.authorizations == {}
|
||||
|
||||
|
||||
def test_wait_for_finalization_error():
|
||||
def test_wait_for_finalization_error() -> None:
|
||||
client = MagicMock()
|
||||
client.version = 2
|
||||
|
||||
|
||||
@@ -5,10 +5,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import typing as t
|
||||
|
||||
import pytest
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.backends import (
|
||||
CertificateInformation,
|
||||
CryptoBackend,
|
||||
)
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.utils import (
|
||||
compute_cert_id,
|
||||
@@ -21,7 +23,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
|
||||
from .backend_data import TEST_PEM_DERS
|
||||
|
||||
|
||||
NOPAD_B64 = [
|
||||
NOPAD_B64: list[tuple[str, str]] = [
|
||||
("", ""),
|
||||
("\n", "Cg"),
|
||||
("123", "MTIz"),
|
||||
@@ -29,7 +31,7 @@ NOPAD_B64 = [
|
||||
]
|
||||
|
||||
|
||||
TEST_LINKS_HEADER = [
|
||||
TEST_LINKS_HEADER: list[tuple[dict[str, t.Any], list[tuple[str, str]]]] = [
|
||||
(
|
||||
{},
|
||||
[],
|
||||
@@ -60,13 +62,13 @@ TEST_LINKS_HEADER = [
|
||||
]
|
||||
|
||||
|
||||
TEST_RETRY_AFTER_HEADER = [
|
||||
TEST_RETRY_AFTER_HEADER: list[tuple[str, datetime.datetime]] = [
|
||||
("120", datetime.datetime(2024, 4, 29, 0, 2, 0)),
|
||||
("Wed, 21 Oct 2015 07:28:00 GMT", datetime.datetime(2015, 10, 21, 7, 28, 0)),
|
||||
]
|
||||
|
||||
|
||||
TEST_COMPUTE_CERT_ID = [
|
||||
TEST_COMPUTE_CERT_ID: list[tuple[CertificateInformation, str]] = [
|
||||
(
|
||||
CertificateInformation(
|
||||
not_valid_after=datetime.datetime(2018, 11, 26, 15, 28, 24),
|
||||
@@ -93,19 +95,21 @@ TEST_COMPUTE_CERT_ID = [
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value, result", NOPAD_B64)
|
||||
def test_nopad_b64(value, result):
|
||||
def test_nopad_b64(value: str, result: str) -> None:
|
||||
assert nopad_b64(value.encode("utf-8")) == result
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pem, der", TEST_PEM_DERS)
|
||||
def test_pem_to_der(pem, der, tmpdir):
|
||||
def test_pem_to_der(pem: str, der: bytes, tmpdir):
|
||||
fn = tmpdir / "test.pem"
|
||||
fn.write(pem)
|
||||
assert pem_to_der(str(fn)) == der
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value, expected_result", TEST_LINKS_HEADER)
|
||||
def test_process_links(value, expected_result):
|
||||
def test_process_links(
|
||||
value: dict[str, t.Any], expected_result: list[tuple[str, str]]
|
||||
) -> None:
|
||||
data = []
|
||||
|
||||
def callback(url, rel):
|
||||
@@ -117,12 +121,15 @@ def test_process_links(value, expected_result):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value, expected_result", TEST_RETRY_AFTER_HEADER)
|
||||
def test_parse_retry_after(value, expected_result):
|
||||
def test_parse_retry_after(value: str, expected_result: datetime.datetime) -> None:
|
||||
assert expected_result == parse_retry_after(
|
||||
value, now=datetime.datetime(2024, 4, 29, 0, 0, 0)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("cert_info, expected_result", TEST_COMPUTE_CERT_ID)
|
||||
def test_compute_cert_id(cert_info, expected_result):
|
||||
assert expected_result == compute_cert_id(backend=None, cert_info=cert_info)
|
||||
def test_compute_cert_id(
|
||||
cert_info: CertificateInformation, expected_result: str
|
||||
) -> None:
|
||||
backend: CryptoBackend = None # type: ignore
|
||||
assert expected_result == compute_cert_id(backend=backend, cert_info=cert_info)
|
||||
|
||||
@@ -10,12 +10,11 @@ import subprocess
|
||||
|
||||
import pytest
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto._asn1 import (
|
||||
pack_asn1,
|
||||
serialize_asn1_string_as_der,
|
||||
)
|
||||
|
||||
|
||||
TEST_CASES = [
|
||||
TEST_CASES: list[tuple[str, bytes]] = [
|
||||
("UTF8:Hello World", b"\x0c\x0b\x48\x65\x6c\x6c\x6f\x20\x57\x6f\x72\x6c\x64"),
|
||||
(
|
||||
"EXPLICIT:10,UTF8:Hello World",
|
||||
@@ -76,7 +75,7 @@ TEST_CASES = [
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value, expected", TEST_CASES)
|
||||
def test_serialize_asn1_string_as_der(value, expected):
|
||||
def test_serialize_asn1_string_as_der(value: str, expected: bytes) -> None:
|
||||
actual = serialize_asn1_string_as_der(value)
|
||||
print(f"{value} | {base64.b16encode(actual).decode()}")
|
||||
assert actual == expected
|
||||
@@ -89,7 +88,7 @@ def test_serialize_asn1_string_as_der(value, expected):
|
||||
"EXPLICIT,UTF:value",
|
||||
],
|
||||
)
|
||||
def test_serialize_asn1_string_as_der_invalid_format(value):
|
||||
def test_serialize_asn1_string_as_der_invalid_format(value: str) -> None:
|
||||
expected = (
|
||||
"The ASN.1 serialized string must be in the format [modifier,]type[:value]"
|
||||
)
|
||||
@@ -97,20 +96,15 @@ def test_serialize_asn1_string_as_der_invalid_format(value):
|
||||
serialize_asn1_string_as_der(value)
|
||||
|
||||
|
||||
def test_serialize_asn1_string_as_der_invalid_type():
|
||||
def test_serialize_asn1_string_as_der_invalid_type() -> None:
|
||||
expected = 'The ASN.1 serialized string is not a known type "OID", only UTF8 types are supported'
|
||||
with pytest.raises(ValueError, match=re.escape(expected)):
|
||||
serialize_asn1_string_as_der("OID:1.2.3.4")
|
||||
|
||||
|
||||
def test_pack_asn_invalid_class():
|
||||
with pytest.raises(ValueError, match="tag_class must be between 0 and 3 not 4"):
|
||||
pack_asn1(4, True, 0, b"")
|
||||
|
||||
|
||||
@pytest.mark.skip() # This is to just to build the test case assertions and shouldn't run normally.
|
||||
@pytest.mark.parametrize("value, expected", TEST_CASES)
|
||||
def test_test_cases(value, expected, tmp_path):
|
||||
def test_test_cases(value: str, expected: bytes, tmp_path) -> None:
|
||||
test_file = tmp_path / "test.der"
|
||||
subprocess.run(
|
||||
["openssl", "asn1parse", "-genstr", value, "-noout", "-out", test_file],
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import typing as t
|
||||
|
||||
import cryptography
|
||||
import pytest
|
||||
@@ -20,7 +21,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptograp
|
||||
from ansible_collections.community.crypto.plugins.module_utils.version import (
|
||||
LooseVersion,
|
||||
)
|
||||
from cryptography.x509 import NameAttribute, oid
|
||||
from cryptography.x509 import NameAttribute, OtherName, oid
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -35,7 +36,7 @@ from cryptography.x509 import NameAttribute, oid
|
||||
("*.☺.", "*.xn--74h.", None),
|
||||
],
|
||||
)
|
||||
def test_adjust_idn(unicode, idna, cycled_unicode):
|
||||
def test_adjust_idn(unicode: str, idna: str, cycled_unicode: str | None) -> None:
|
||||
if cycled_unicode is None:
|
||||
cycled_unicode = unicode
|
||||
|
||||
@@ -70,9 +71,10 @@ def test_adjust_idn(unicode, idna, cycled_unicode):
|
||||
("bar", "foo", re.escape('Invalid value for idn_rewrite: "foo"')),
|
||||
],
|
||||
)
|
||||
def test_adjust_idn_fail_valueerror(value, idn_rewrite, message):
|
||||
def test_adjust_idn_fail_valueerror(value: str, idn_rewrite: str, message: str) -> None:
|
||||
with pytest.raises(ValueError, match=message):
|
||||
_adjust_idn(value, idn_rewrite)
|
||||
idn_rewrite_: t.Literal["ignore", "idna", "unicode"] = idn_rewrite # type: ignore
|
||||
_adjust_idn(value, idn_rewrite_)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -88,27 +90,29 @@ def test_adjust_idn_fail_valueerror(value, idn_rewrite, message):
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_adjust_idn_fail_user_error(value, idn_rewrite, message):
|
||||
def test_adjust_idn_fail_user_error(value: str, idn_rewrite: str, message: str) -> None:
|
||||
with pytest.raises(OpenSSLObjectError, match=message):
|
||||
_adjust_idn(value, idn_rewrite)
|
||||
idn_rewrite_: t.Literal["ignore", "idna", "unicode"] = idn_rewrite # type: ignore
|
||||
_adjust_idn(value, idn_rewrite_)
|
||||
|
||||
|
||||
def test_cryptography_get_name_invalid_prefix():
|
||||
def test_cryptography_get_name_invalid_prefix() -> None:
|
||||
with pytest.raises(
|
||||
OpenSSLObjectError, match="^Cannot parse Subject Alternative Name"
|
||||
):
|
||||
cryptography_get_name("fake:value")
|
||||
|
||||
|
||||
def test_cryptography_get_name_other_name_no_oid():
|
||||
def test_cryptography_get_name_other_name_no_oid() -> None:
|
||||
with pytest.raises(
|
||||
OpenSSLObjectError, match="Cannot parse Subject Alternative Name otherName"
|
||||
):
|
||||
cryptography_get_name("otherName:value")
|
||||
|
||||
|
||||
def test_cryptography_get_name_other_name_utfstring():
|
||||
def test_cryptography_get_name_other_name_utfstring() -> None:
|
||||
actual = cryptography_get_name("otherName:1.3.6.1.4.1.311.20.2.3;UTF8:Hello World")
|
||||
assert isinstance(actual, OtherName)
|
||||
assert actual.type_id.dotted_string == "1.3.6.1.4.1.311.20.2.3"
|
||||
assert actual.value == b"\x0c\x0bHello World"
|
||||
|
||||
@@ -164,7 +168,9 @@ def test_cryptography_get_name_other_name_utfstring():
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_parse_dn_component(name, options, expected):
|
||||
def test_parse_dn_component(
|
||||
name: bytes, options: dict[str, t.Any], expected: tuple[NameAttribute, bytes]
|
||||
) -> None:
|
||||
result = _parse_dn_component(name, **options)
|
||||
print(result, expected)
|
||||
assert result == expected
|
||||
@@ -186,7 +192,9 @@ if (
|
||||
(b"CN= ", {}, (NameAttribute(oid.NameOID.COMMON_NAME, ""), b"")),
|
||||
],
|
||||
)
|
||||
def test_parse_dn_component_not_py26(name, options, expected):
|
||||
def test_parse_dn_component_not_py26(
|
||||
name: bytes, options: dict[str, t.Any], expected: tuple[NameAttribute, bytes]
|
||||
) -> None:
|
||||
result = _parse_dn_component(name, **options)
|
||||
print(result, expected)
|
||||
assert result == expected
|
||||
@@ -200,7 +208,9 @@ if (
|
||||
(b"CN=#0,", {}, 'Invalid hex sequence entry "0,"'),
|
||||
],
|
||||
)
|
||||
def test_parse_dn_component_failure(name, options, message):
|
||||
def test_parse_dn_component_failure(
|
||||
name: bytes, options: dict[str, t.Any], message: str
|
||||
) -> None:
|
||||
with pytest.raises(OpenSSLObjectError, match=f"^{re.escape(message)}$"):
|
||||
_parse_dn_component(name, **options)
|
||||
|
||||
@@ -225,7 +235,7 @@ def test_parse_dn_component_failure(name, options, message):
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_parse_dn(name, expected):
|
||||
def test_parse_dn(name: bytes, expected: list[NameAttribute]) -> None:
|
||||
result = _parse_dn(name)
|
||||
print(result, expected)
|
||||
assert result == expected
|
||||
@@ -236,14 +246,14 @@ def test_parse_dn(name, expected):
|
||||
[
|
||||
(
|
||||
b"CN=\\0",
|
||||
'Error while parsing distinguished name "CN=\\0": Hex escape sequence "\\0" incomplete at end of string',
|
||||
"Error while parsing distinguished name 'CN=\\\\0': Hex escape sequence \"\\0\" incomplete at end of string",
|
||||
),
|
||||
(
|
||||
b"CN=x,",
|
||||
'Error while parsing distinguished name "CN=x,": unexpected end of string',
|
||||
"Error while parsing distinguished name 'CN=x,': unexpected end of string",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_parse_dn_failure(name, message):
|
||||
def test_parse_dn_failure(name: bytes, message: str):
|
||||
with pytest.raises(OpenSSLObjectError, match=f"^{re.escape(message)}$"):
|
||||
_parse_dn(name)
|
||||
|
||||
@@ -26,7 +26,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.math impor
|
||||
(2, 10, 5, 4),
|
||||
],
|
||||
)
|
||||
def test_binary_exp_mod(f, e, m, result):
|
||||
def test_binary_exp_mod(f: int, e: int, m: int, result: int) -> None:
|
||||
value = binary_exp_mod(f, e, m)
|
||||
print(value)
|
||||
assert value == result
|
||||
@@ -46,7 +46,7 @@ def test_binary_exp_mod(f, e, m, result):
|
||||
(1024, 10, 2),
|
||||
],
|
||||
)
|
||||
def test_simple_gcd(a, b, result):
|
||||
def test_simple_gcd(a: int, b: int, result: int) -> None:
|
||||
value = simple_gcd(a, b)
|
||||
print(value)
|
||||
assert value == result
|
||||
@@ -70,7 +70,7 @@ def test_simple_gcd(a, b, result):
|
||||
(211, False), # the smallest prime number >= 200
|
||||
],
|
||||
)
|
||||
def test_quick_is_not_prime(n, result):
|
||||
def test_quick_is_not_prime(n: int, result: bool) -> None:
|
||||
value = quick_is_not_prime(n)
|
||||
print(value)
|
||||
assert value == result
|
||||
@@ -88,7 +88,7 @@ def test_quick_is_not_prime(n, result):
|
||||
(256, None, b"\x01\x00"),
|
||||
],
|
||||
)
|
||||
def test_convert_int_to_bytes(no, count, result):
|
||||
def test_convert_int_to_bytes(no: int, count: int | None, result: bytes) -> None:
|
||||
value = convert_int_to_bytes(no, count=count)
|
||||
print(value)
|
||||
assert value == result
|
||||
@@ -108,7 +108,7 @@ def test_convert_int_to_bytes(no, count, result):
|
||||
(256, 4, "0100"),
|
||||
],
|
||||
)
|
||||
def test_convert_int_to_hex(no, digits, result):
|
||||
def test_convert_int_to_hex(no: int, digits: int | None, result: str) -> None:
|
||||
value = convert_int_to_hex(no, digits=digits)
|
||||
print(value)
|
||||
assert value == result
|
||||
@@ -125,7 +125,7 @@ def test_convert_int_to_hex(no, digits, result):
|
||||
(b"\x01\x00", 256),
|
||||
],
|
||||
)
|
||||
def test_convert_bytes_to_int(data, result):
|
||||
def test_convert_bytes_to_int(data: bytes, result: int) -> None:
|
||||
value = convert_bytes_to_int(data)
|
||||
print(value)
|
||||
assert value == result
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
import pytest
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import (
|
||||
extract_first_pem,
|
||||
@@ -13,7 +15,9 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import
|
||||
)
|
||||
|
||||
|
||||
PEM_TEST_CASES = [
|
||||
PEM_TEST_CASES: list[
|
||||
tuple[bytes, list[str], bool, t.Literal["raw", "pkcs1", "pkcs8", "unknown-pem"]]
|
||||
] = [
|
||||
(b"", [], False, "raw"),
|
||||
(b"random stuff\nblabla", [], False, "raw"),
|
||||
(b"-----BEGIN PRIVATE KEY-----", [], False, "raw"),
|
||||
@@ -51,7 +55,12 @@ PEM_TEST_CASES = [
|
||||
|
||||
|
||||
@pytest.mark.parametrize("data, pems, is_pem, private_key_type", PEM_TEST_CASES)
|
||||
def test_pem_handling(data, pems, is_pem, private_key_type):
|
||||
def test_pem_handling(
|
||||
data: bytes,
|
||||
pems: list[str],
|
||||
is_pem: bool,
|
||||
private_key_type: t.Literal["raw", "pkcs1", "pkcs8", "unknown-pem"],
|
||||
):
|
||||
assert identify_pem_format(data) == is_pem
|
||||
assert identify_private_key_format(data) == private_key_type
|
||||
try:
|
||||
|
||||
@@ -132,7 +132,9 @@ VALID_EXTENSIONS = [
|
||||
]
|
||||
INVALID_EXTENSIONS = [OpensshCertificateOption("extension", "test", "")]
|
||||
|
||||
VALID_TIME_PARAMETERS = [
|
||||
VALID_TIME_PARAMETERS: list[
|
||||
tuple[int | str, int | str, str, int, int | str, str, str, int, str]
|
||||
] = [
|
||||
(
|
||||
0,
|
||||
"always",
|
||||
@@ -223,28 +225,28 @@ VALID_TIME_PARAMETERS = [
|
||||
),
|
||||
]
|
||||
|
||||
INVALID_TIME_PARAMETERS = [
|
||||
INVALID_TIME_PARAMETERS: list[tuple[int | str, int | str]] = [
|
||||
(-1, 0xFFFFFFFFFFFFFFFFFF),
|
||||
("never", "ever"),
|
||||
("01-01-1980", "01-01-1990"),
|
||||
(1, 0),
|
||||
]
|
||||
|
||||
VALID_VALIDITY_TEST = [
|
||||
VALID_VALIDITY_TEST: list[tuple[str, str, str]] = [
|
||||
("always", "forever", "2000-01-01"),
|
||||
("1999-12-31", "2000-01-02", "2000-01-01"),
|
||||
("1999-12-31 23:59:00", "2000-01-01 00:01:00", "2000-01-01 00:00:00"),
|
||||
("1999-12-31 23:59:59", "2000-01-01 00:00:01", "2000-01-01 00:00:00"),
|
||||
]
|
||||
|
||||
INVALID_VALIDITY_TEST = [
|
||||
INVALID_VALIDITY_TEST: list[tuple[str, str, str]] = [
|
||||
("always", "forever", "1969-12-31"),
|
||||
("always", "2000-01-01", "2000-01-02"),
|
||||
("2000-01-01", "forever", "1999-12-31"),
|
||||
("2000-01-01 00:00:00", "2000-01-01 00:00:01", "2000-01-01 00:00:02"),
|
||||
]
|
||||
|
||||
VALID_OPTIONS = [
|
||||
VALID_OPTIONS: list[tuple[str, OpensshCertificateOption]] = [
|
||||
(
|
||||
"force-command=/usr/bin/csh",
|
||||
OpensshCertificateOption("critical", "force-command", "/usr/bin/csh"),
|
||||
@@ -265,7 +267,7 @@ VALID_OPTIONS = [
|
||||
("extension:foo", OpensshCertificateOption("extension", "foo", "")),
|
||||
]
|
||||
|
||||
INVALID_OPTIONS = [
|
||||
INVALID_OPTIONS: list[str | list] = [
|
||||
"foobar",
|
||||
"foo=bar",
|
||||
"foo:bar=baz",
|
||||
@@ -273,7 +275,7 @@ INVALID_OPTIONS = [
|
||||
]
|
||||
|
||||
|
||||
def test_rsa_certificate(tmpdir):
|
||||
def test_rsa_certificate(tmpdir) -> None:
|
||||
cert_file = tmpdir / "id_rsa-cert.pub"
|
||||
cert_file.write(RSA_CERT_SIGNED_BY_DSA, mode="wb")
|
||||
|
||||
@@ -285,7 +287,7 @@ def test_rsa_certificate(tmpdir):
|
||||
assert cert.signing_key == DSA_FINGERPRINT
|
||||
|
||||
|
||||
def test_dsa_certificate(tmpdir):
|
||||
def test_dsa_certificate(tmpdir) -> None:
|
||||
cert_file = tmpdir / "id_dsa-cert.pub"
|
||||
cert_file.write(DSA_CERT_SIGNED_BY_ECDSA_NO_OPTS)
|
||||
|
||||
@@ -298,7 +300,7 @@ def test_dsa_certificate(tmpdir):
|
||||
assert cert.extensions == []
|
||||
|
||||
|
||||
def test_ecdsa_certificate(tmpdir):
|
||||
def test_ecdsa_certificate(tmpdir) -> None:
|
||||
cert_file = tmpdir / "id_ecdsa-cert.pub"
|
||||
cert_file.write(ECDSA_CERT_SIGNED_BY_ED25519_VALID_OPTS)
|
||||
|
||||
@@ -310,7 +312,7 @@ def test_ecdsa_certificate(tmpdir):
|
||||
assert cert.extensions == VALID_EXTENSIONS
|
||||
|
||||
|
||||
def test_ed25519_certificate(tmpdir):
|
||||
def test_ed25519_certificate(tmpdir) -> None:
|
||||
cert_file = tmpdir / "id_ed25519-cert.pub"
|
||||
cert_file.write(ED25519_CERT_SIGNED_BY_RSA_INVALID_OPTS)
|
||||
|
||||
@@ -322,7 +324,7 @@ def test_ed25519_certificate(tmpdir):
|
||||
assert cert.extensions == INVALID_EXTENSIONS
|
||||
|
||||
|
||||
def test_invalid_data(tmpdir):
|
||||
def test_invalid_data(tmpdir) -> None:
|
||||
result = False
|
||||
cert_file = tmpdir / "invalid-cert.pub"
|
||||
cert_file.write(INVALID_DATA)
|
||||
@@ -341,16 +343,16 @@ def test_invalid_data(tmpdir):
|
||||
VALID_TIME_PARAMETERS,
|
||||
)
|
||||
def test_valid_time_parameters(
|
||||
valid_from,
|
||||
valid_from_hr,
|
||||
valid_from_openssh,
|
||||
valid_from_timestamp,
|
||||
valid_to,
|
||||
valid_to_hr,
|
||||
valid_to_openssh,
|
||||
valid_to_timestamp,
|
||||
validity_string,
|
||||
):
|
||||
valid_from: int | str,
|
||||
valid_from_hr: int | str,
|
||||
valid_from_openssh: str,
|
||||
valid_from_timestamp: int,
|
||||
valid_to: int | str,
|
||||
valid_to_hr: str,
|
||||
valid_to_openssh: str,
|
||||
valid_to_timestamp: int,
|
||||
validity_string: str,
|
||||
) -> None:
|
||||
time_parameters = OpensshCertificateTimeParameters(
|
||||
valid_from=valid_from, valid_to=valid_to
|
||||
)
|
||||
@@ -364,35 +366,37 @@ def test_valid_time_parameters(
|
||||
|
||||
|
||||
@pytest.mark.parametrize("valid_from,valid_to", INVALID_TIME_PARAMETERS)
|
||||
def test_invalid_time_parameters(valid_from, valid_to):
|
||||
def test_invalid_time_parameters(valid_from: int | str, valid_to: int | str) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
OpensshCertificateTimeParameters(valid_from, valid_to)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("valid_from,valid_to,valid_at", VALID_VALIDITY_TEST)
|
||||
def test_valid_validity_test(valid_from, valid_to, valid_at):
|
||||
def test_valid_validity_test(valid_from: str, valid_to: str, valid_at: str) -> None:
|
||||
assert OpensshCertificateTimeParameters(valid_from, valid_to).within_range(valid_at)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("valid_from,valid_to,valid_at", INVALID_VALIDITY_TEST)
|
||||
def test_invalid_validity_test(valid_from, valid_to, valid_at):
|
||||
def test_invalid_validity_test(valid_from: str, valid_to: str, valid_at: str) -> None:
|
||||
assert not OpensshCertificateTimeParameters(valid_from, valid_to).within_range(
|
||||
valid_at
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("option_string,option_object", VALID_OPTIONS)
|
||||
def test_valid_options(option_string, option_object):
|
||||
def test_valid_options(
|
||||
option_string: str, option_object: OpensshCertificateOption
|
||||
) -> None:
|
||||
assert OpensshCertificateOption.from_string(option_string) == option_object
|
||||
|
||||
|
||||
@pytest.mark.parametrize("option_string", INVALID_OPTIONS)
|
||||
def test_invalid_options(option_string):
|
||||
def test_invalid_options(option_string: str) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
OpensshCertificateOption.from_string(option_string)
|
||||
|
||||
|
||||
def test_parse_option_list():
|
||||
def test_parse_option_list() -> None:
|
||||
critical_options, extensions = parse_option_list(["force-command=/usr/bin/csh"])
|
||||
|
||||
critical_option_objects = [
|
||||
@@ -411,7 +415,7 @@ def test_parse_option_list():
|
||||
assert set(extensions) == set(extension_objects)
|
||||
|
||||
|
||||
def test_parse_option_list_with_directives():
|
||||
def test_parse_option_list_with_directives() -> None:
|
||||
critical_options, extensions = parse_option_list(
|
||||
["clear", "no-pty", "permit-pty", "permit-user-rc"]
|
||||
)
|
||||
@@ -425,7 +429,7 @@ def test_parse_option_list_with_directives():
|
||||
assert set(extensions) == set(extension_objects)
|
||||
|
||||
|
||||
def test_parse_option_list_case_sensitivity():
|
||||
def test_parse_option_list_case_sensitivity() -> None:
|
||||
critical_options, extensions = parse_option_list(
|
||||
["CLEAR", "no-X11-forwarding", "permit-X11-forwarding"]
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os.path
|
||||
import typing as t
|
||||
from getpass import getuser
|
||||
from os import remove, rmdir
|
||||
from socket import gethostname
|
||||
@@ -23,7 +24,13 @@ from ansible_collections.community.crypto.plugins.module_utils.openssh.cryptogra
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_KEY_PARAMS = [
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible_collections.community.crypto.plugins.module_utils.openssh.cryptography import (
|
||||
KeyType,
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_KEY_PARAMS: list[tuple[KeyType, int | None, bytes | None, str | None]] = [
|
||||
(
|
||||
"rsa",
|
||||
None,
|
||||
@@ -50,7 +57,7 @@ DEFAULT_KEY_PARAMS = [
|
||||
),
|
||||
]
|
||||
|
||||
VALID_USER_KEY_PARAMS = [
|
||||
VALID_USER_KEY_PARAMS: list[tuple[KeyType, int | None, bytes | None, str | None]] = [
|
||||
(
|
||||
"rsa",
|
||||
8192,
|
||||
@@ -77,9 +84,9 @@ VALID_USER_KEY_PARAMS = [
|
||||
),
|
||||
]
|
||||
|
||||
INVALID_USER_KEY_PARAMS = [
|
||||
INVALID_USER_KEY_PARAMS: list[tuple[KeyType, int | None, bytes | None, str | None]] = [
|
||||
(
|
||||
"dne",
|
||||
"dne", # type: ignore
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
@@ -87,18 +94,18 @@ INVALID_USER_KEY_PARAMS = [
|
||||
(
|
||||
"rsa",
|
||||
None,
|
||||
[1, 2, 3],
|
||||
[1, 2, 3], # type: ignore
|
||||
"comment",
|
||||
),
|
||||
(
|
||||
"ecdsa",
|
||||
None,
|
||||
None,
|
||||
[1, 2, 3],
|
||||
[1, 2, 3], # type: ignore
|
||||
),
|
||||
]
|
||||
|
||||
INVALID_KEY_SIZES = [
|
||||
INVALID_KEY_SIZES: list[tuple[KeyType, int | None, bytes | None, str | None]] = [
|
||||
(
|
||||
"rsa",
|
||||
1023,
|
||||
@@ -134,7 +141,9 @@ INVALID_KEY_SIZES = [
|
||||
|
||||
@pytest.mark.parametrize("keytype,size,passphrase,comment", DEFAULT_KEY_PARAMS)
|
||||
@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography")
|
||||
def test_default_key_params(keytype, size, passphrase, comment):
|
||||
def test_default_key_params(
|
||||
keytype: KeyType, size: int | None, passphrase: bytes | None, comment: str | None
|
||||
) -> None:
|
||||
result = True
|
||||
|
||||
default_sizes = {
|
||||
@@ -163,7 +172,9 @@ def test_default_key_params(keytype, size, passphrase, comment):
|
||||
|
||||
@pytest.mark.parametrize("keytype,size,passphrase,comment", VALID_USER_KEY_PARAMS)
|
||||
@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography")
|
||||
def test_valid_user_key_params(keytype, size, passphrase, comment):
|
||||
def test_valid_user_key_params(
|
||||
keytype: KeyType, size: int | None, passphrase: bytes | None, comment: str | None
|
||||
) -> None:
|
||||
result = True
|
||||
|
||||
try:
|
||||
@@ -181,7 +192,9 @@ def test_valid_user_key_params(keytype, size, passphrase, comment):
|
||||
|
||||
@pytest.mark.parametrize("keytype,size,passphrase,comment", INVALID_USER_KEY_PARAMS)
|
||||
@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography")
|
||||
def test_invalid_user_key_params(keytype, size, passphrase, comment):
|
||||
def test_invalid_user_key_params(
|
||||
keytype: KeyType, size: int | None, passphrase: bytes | None, comment: str | None
|
||||
) -> None:
|
||||
result = False
|
||||
|
||||
try:
|
||||
@@ -199,7 +212,9 @@ def test_invalid_user_key_params(keytype, size, passphrase, comment):
|
||||
|
||||
@pytest.mark.parametrize("keytype,size,passphrase,comment", INVALID_KEY_SIZES)
|
||||
@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography")
|
||||
def test_invalid_key_sizes(keytype, size, passphrase, comment):
|
||||
def test_invalid_key_sizes(
|
||||
keytype: KeyType, size: int | None, passphrase: bytes | None, comment: str | None
|
||||
) -> None:
|
||||
result = False
|
||||
|
||||
try:
|
||||
@@ -216,7 +231,7 @@ def test_invalid_key_sizes(keytype, size, passphrase, comment):
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography")
|
||||
def test_valid_comment_update():
|
||||
def test_valid_comment_update() -> None:
|
||||
|
||||
pair = OpensshKeypair.generate()
|
||||
new_comment = "comment"
|
||||
@@ -233,13 +248,13 @@ def test_valid_comment_update():
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography")
|
||||
def test_invalid_comment_update():
|
||||
def test_invalid_comment_update() -> None:
|
||||
result = False
|
||||
|
||||
pair = OpensshKeypair.generate()
|
||||
new_comment = [1, 2, 3]
|
||||
try:
|
||||
pair.comment = new_comment
|
||||
pair.comment = new_comment # type: ignore
|
||||
except InvalidCommentError:
|
||||
result = True
|
||||
|
||||
@@ -247,7 +262,7 @@ def test_invalid_comment_update():
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography")
|
||||
def test_valid_passphrase_update():
|
||||
def test_valid_passphrase_update() -> None:
|
||||
result = False
|
||||
|
||||
passphrase = "change_me".encode("UTF-8")
|
||||
@@ -281,13 +296,13 @@ def test_valid_passphrase_update():
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography")
|
||||
def test_invalid_passphrase_update():
|
||||
def test_invalid_passphrase_update() -> None:
|
||||
result = False
|
||||
|
||||
passphrase = [1, 2, 3]
|
||||
pair = OpensshKeypair.generate()
|
||||
try:
|
||||
pair.update_passphrase(passphrase)
|
||||
pair.update_passphrase(passphrase) # type: ignore
|
||||
except InvalidPassphraseError:
|
||||
result = True
|
||||
|
||||
@@ -295,7 +310,7 @@ def test_invalid_passphrase_update():
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography")
|
||||
def test_invalid_privatekey():
|
||||
def test_invalid_privatekey() -> None:
|
||||
result = False
|
||||
|
||||
try:
|
||||
@@ -325,7 +340,7 @@ def test_invalid_privatekey():
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography")
|
||||
def test_mismatched_keypair():
|
||||
def test_mismatched_keypair() -> None:
|
||||
result = False
|
||||
|
||||
try:
|
||||
@@ -356,7 +371,7 @@ def test_mismatched_keypair():
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography")
|
||||
def test_keypair_comparison():
|
||||
def test_keypair_comparison() -> None:
|
||||
assert OpensshKeypair.generate() != OpensshKeypair.generate()
|
||||
assert OpensshKeypair.generate() != OpensshKeypair.generate(keytype="dsa")
|
||||
assert OpensshKeypair.generate() != OpensshKeypair.generate(keytype="ed25519")
|
||||
@@ -366,7 +381,7 @@ def test_keypair_comparison():
|
||||
try:
|
||||
tmpdir = mkdtemp()
|
||||
|
||||
keys = {
|
||||
keys: dict[str, dict[str, t.Any]] = {
|
||||
"rsa": {
|
||||
"pair": OpensshKeypair.generate(),
|
||||
"filename": os.path.join(tmpdir, "id_rsa"),
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
import pytest
|
||||
from ansible_collections.community.crypto.plugins.module_utils.openssh.utils import (
|
||||
OpensshParser,
|
||||
@@ -15,36 +17,36 @@ from ansible_collections.community.crypto.plugins.module_utils.openssh.utils imp
|
||||
SSH_VERSION_STRING = "OpenSSH_7.9p1, OpenSSL 1.1.0i-fips 14 Aug 2018"
|
||||
SSH_VERSION_NUMBER = "7.9"
|
||||
|
||||
VALID_BOOLEAN = [True, False]
|
||||
INVALID_BOOLEAN = [0x02]
|
||||
VALID_UINT32 = [
|
||||
VALID_BOOLEAN: list[bool] = [True, False]
|
||||
INVALID_BOOLEAN: list[t.Any] = [0x02]
|
||||
VALID_UINT32: list[int] = [
|
||||
0x00,
|
||||
0x01,
|
||||
0x01234567,
|
||||
0xFFFFFFFF,
|
||||
]
|
||||
INVALID_UINT32 = [
|
||||
INVALID_UINT32: list[int] = [
|
||||
0xFFFFFFFFF,
|
||||
-1,
|
||||
]
|
||||
VALID_UINT64 = [
|
||||
VALID_UINT64: list[int] = [
|
||||
0x00,
|
||||
0x01,
|
||||
0x0123456789ABCDEF,
|
||||
0xFFFFFFFFFFFFFFFF,
|
||||
]
|
||||
INVALID_UINT64 = [
|
||||
INVALID_UINT64: list[int] = [
|
||||
0xFFFFFFFFFFFFFFFFF,
|
||||
-1,
|
||||
]
|
||||
VALID_STRING = [
|
||||
VALID_STRING: list[bytes] = [
|
||||
b"test string",
|
||||
]
|
||||
INVALID_STRING = [
|
||||
INVALID_STRING: list[t.Any] = [
|
||||
[],
|
||||
]
|
||||
# See https://datatracker.ietf.org/doc/html/rfc4251#section-5 for examples source
|
||||
VALID_MPINT = [
|
||||
VALID_MPINT: list[int] = [
|
||||
0x00,
|
||||
0x9A378F9B2E332A7,
|
||||
0x80,
|
||||
@@ -53,50 +55,50 @@ VALID_MPINT = [
|
||||
# Additional large int test
|
||||
0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF,
|
||||
]
|
||||
INVALID_MPINT = [
|
||||
INVALID_MPINT: list[t.Any] = [
|
||||
[],
|
||||
]
|
||||
|
||||
|
||||
def test_parse_openssh_version():
|
||||
def test_parse_openssh_version() -> None:
|
||||
assert parse_openssh_version(SSH_VERSION_STRING) == SSH_VERSION_NUMBER
|
||||
|
||||
|
||||
@pytest.mark.parametrize("boolean", VALID_BOOLEAN)
|
||||
def test_valid_boolean(boolean):
|
||||
def test_valid_boolean(boolean: bool) -> None:
|
||||
assert OpensshParser(_OpensshWriter().boolean(boolean).bytes()).boolean() == boolean
|
||||
|
||||
|
||||
@pytest.mark.parametrize("boolean", INVALID_BOOLEAN)
|
||||
def test_invalid_boolean(boolean):
|
||||
def test_invalid_boolean(boolean: t.Any) -> None:
|
||||
with pytest.raises(TypeError):
|
||||
_OpensshWriter().boolean(boolean)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("uint32", VALID_UINT32)
|
||||
def test_valid_uint32(uint32):
|
||||
def test_valid_uint32(uint32: int) -> None:
|
||||
assert OpensshParser(_OpensshWriter().uint32(uint32).bytes()).uint32() == uint32
|
||||
|
||||
|
||||
@pytest.mark.parametrize("uint32", INVALID_UINT32)
|
||||
def test_invalid_uint32(uint32):
|
||||
def test_invalid_uint32(uint32: int) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
_OpensshWriter().uint32(uint32)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("uint64", VALID_UINT64)
|
||||
def test_valid_uint64(uint64):
|
||||
def test_valid_uint64(uint64: int) -> None:
|
||||
assert OpensshParser(_OpensshWriter().uint64(uint64).bytes()).uint64() == uint64
|
||||
|
||||
|
||||
@pytest.mark.parametrize("uint64", INVALID_UINT64)
|
||||
def test_invalid_uint64(uint64):
|
||||
def test_invalid_uint64(uint64: int) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
_OpensshWriter().uint64(uint64)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ssh_string", VALID_STRING)
|
||||
def test_valid_string(ssh_string):
|
||||
def test_valid_string(ssh_string: bytes) -> None:
|
||||
assert (
|
||||
OpensshParser(_OpensshWriter().string(ssh_string).bytes()).string()
|
||||
== ssh_string
|
||||
@@ -104,23 +106,23 @@ def test_valid_string(ssh_string):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ssh_string", INVALID_STRING)
|
||||
def test_invalid_string(ssh_string):
|
||||
def test_invalid_string(ssh_string: t.Any) -> None:
|
||||
with pytest.raises(TypeError):
|
||||
_OpensshWriter().string(ssh_string)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mpint", VALID_MPINT)
|
||||
def test_valid_mpint(mpint):
|
||||
def test_valid_mpint(mpint: int) -> None:
|
||||
assert OpensshParser(_OpensshWriter().mpint(mpint).bytes()).mpint() == mpint
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mpint", INVALID_MPINT)
|
||||
def test_invalid_mpint(mpint):
|
||||
def test_invalid_mpint(mpint: t.Any) -> None:
|
||||
with pytest.raises(TypeError):
|
||||
_OpensshWriter().mpint(mpint)
|
||||
|
||||
|
||||
def test_valid_seek():
|
||||
def test_valid_seek() -> None:
|
||||
buffer = bytearray(b"buffer")
|
||||
parser = OpensshParser(buffer)
|
||||
parser.seek(len(buffer))
|
||||
@@ -129,7 +131,7 @@ def test_valid_seek():
|
||||
assert parser.remaining_bytes() == len(buffer)
|
||||
|
||||
|
||||
def test_invalid_seek():
|
||||
def test_invalid_seek() -> None:
|
||||
buffer = b"buffer"
|
||||
parser = OpensshParser(buffer)
|
||||
|
||||
@@ -140,6 +142,6 @@ def test_invalid_seek():
|
||||
parser.seek(-1)
|
||||
|
||||
|
||||
def test_writer_bytes():
|
||||
def test_writer_bytes() -> None:
|
||||
buffer = bytearray(b"buffer")
|
||||
assert _OpensshWriter(buffer).bytes() == buffer
|
||||
|
||||
@@ -5,9 +5,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import typing as t
|
||||
|
||||
import pytest
|
||||
from ansible.module_utils.common.collections import is_sequence
|
||||
from ansible_collections.community.crypto.plugins.module_utils.time import (
|
||||
UTC,
|
||||
add_or_remove_timezone,
|
||||
@@ -30,25 +30,27 @@ TIMEZONES = [
|
||||
]
|
||||
|
||||
|
||||
def cartesian_product(list1, list2):
|
||||
result = []
|
||||
if t.TYPE_CHECKING:
|
||||
_S = t.TypeVar("_S")
|
||||
_Ts = t.TypeVarTuple("_Ts")
|
||||
|
||||
|
||||
def cartesian_product(
|
||||
list1: list[_S], list2: "list[tuple[*_Ts]]"
|
||||
) -> "list[tuple[_S, *_Ts]]":
|
||||
result: "list[tuple[_S, *_Ts]]" = []
|
||||
for item1 in list1:
|
||||
if not is_sequence(item1):
|
||||
item1 = (item1,)
|
||||
elif not isinstance(item1, tuple):
|
||||
item1 = tuple(item1)
|
||||
item1_tuple = (item1,)
|
||||
for item2 in list2:
|
||||
if not is_sequence(item2):
|
||||
item2 = (item2,)
|
||||
elif not isinstance(item2, tuple):
|
||||
item2 = tuple(item2)
|
||||
result.append(item1 + item2)
|
||||
result.append(item1_tuple + item2)
|
||||
return result
|
||||
|
||||
|
||||
ONE_HOUR_PLUS = datetime.timezone(datetime.timedelta(hours=1))
|
||||
|
||||
TEST_REMOVE_TIMEZONE = cartesian_product(
|
||||
TEST_REMOVE_TIMEZONE: list[
|
||||
tuple[datetime.timedelta, datetime.datetime, datetime.datetime]
|
||||
] = cartesian_product(
|
||||
TIMEZONES,
|
||||
[
|
||||
(
|
||||
@@ -66,7 +68,9 @@ TEST_REMOVE_TIMEZONE = cartesian_product(
|
||||
],
|
||||
)
|
||||
|
||||
TEST_UTC_TIMEZONE = cartesian_product(
|
||||
TEST_UTC_TIMEZONE: list[
|
||||
tuple[datetime.timedelta, datetime.datetime, datetime.datetime]
|
||||
] = cartesian_product(
|
||||
TIMEZONES,
|
||||
[
|
||||
(
|
||||
@@ -84,48 +88,67 @@ TEST_UTC_TIMEZONE = cartesian_product(
|
||||
],
|
||||
)
|
||||
|
||||
TEST_EPOCH_SECONDS = cartesian_product(
|
||||
TIMEZONES,
|
||||
[
|
||||
(0, dict(year=1970, day=1, month=1, hour=0, minute=0, second=0, microsecond=0)),
|
||||
(
|
||||
1e-6,
|
||||
dict(year=1970, day=1, month=1, hour=0, minute=0, second=0, microsecond=1),
|
||||
),
|
||||
(
|
||||
1e-3,
|
||||
dict(
|
||||
year=1970, day=1, month=1, hour=0, minute=0, second=0, microsecond=1000
|
||||
TEST_EPOCH_SECONDS: list[tuple[datetime.timedelta, float, dict[str, int]]] = (
|
||||
cartesian_product(
|
||||
TIMEZONES,
|
||||
[
|
||||
(
|
||||
0,
|
||||
dict(
|
||||
year=1970, day=1, month=1, hour=0, minute=0, second=0, microsecond=0
|
||||
),
|
||||
),
|
||||
),
|
||||
(
|
||||
3691.2,
|
||||
dict(
|
||||
year=1970,
|
||||
day=1,
|
||||
month=1,
|
||||
hour=1,
|
||||
minute=1,
|
||||
second=31,
|
||||
microsecond=200000,
|
||||
(
|
||||
1e-6,
|
||||
dict(
|
||||
year=1970, day=1, month=1, hour=0, minute=0, second=0, microsecond=1
|
||||
),
|
||||
),
|
||||
),
|
||||
],
|
||||
(
|
||||
1e-3,
|
||||
dict(
|
||||
year=1970,
|
||||
day=1,
|
||||
month=1,
|
||||
hour=0,
|
||||
minute=0,
|
||||
second=0,
|
||||
microsecond=1000,
|
||||
),
|
||||
),
|
||||
(
|
||||
3691.2,
|
||||
dict(
|
||||
year=1970,
|
||||
day=1,
|
||||
month=1,
|
||||
hour=1,
|
||||
minute=1,
|
||||
second=31,
|
||||
microsecond=200000,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
TEST_EPOCH_TO_SECONDS = cartesian_product(
|
||||
TIMEZONES,
|
||||
[
|
||||
(datetime.datetime(1970, 1, 1, 0, 1, 2, 0), 62),
|
||||
(datetime.datetime(1970, 1, 1, 0, 1, 2, 0, tzinfo=UTC), 62),
|
||||
(
|
||||
datetime.datetime(1970, 1, 1, 0, 1, 2, 0, tzinfo=ONE_HOUR_PLUS),
|
||||
62 - 3600,
|
||||
),
|
||||
],
|
||||
TEST_EPOCH_TO_SECONDS: list[tuple[datetime.timedelta, datetime.datetime, int]] = (
|
||||
cartesian_product(
|
||||
TIMEZONES,
|
||||
[
|
||||
(datetime.datetime(1970, 1, 1, 0, 1, 2, 0), 62),
|
||||
(datetime.datetime(1970, 1, 1, 0, 1, 2, 0, tzinfo=UTC), 62),
|
||||
(
|
||||
datetime.datetime(1970, 1, 1, 0, 1, 2, 0, tzinfo=ONE_HOUR_PLUS),
|
||||
62 - 3600,
|
||||
),
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
TEST_CONVERT_RELATIVE_TO_DATETIME = cartesian_product(
|
||||
TEST_CONVERT_RELATIVE_TO_DATETIME: list[
|
||||
tuple[datetime.timedelta, str, bool, datetime.datetime, datetime.datetime]
|
||||
] = cartesian_product(
|
||||
TIMEZONES,
|
||||
[
|
||||
(
|
||||
@@ -167,7 +190,9 @@ TEST_CONVERT_RELATIVE_TO_DATETIME = cartesian_product(
|
||||
],
|
||||
)
|
||||
|
||||
TEST_GET_RELATIVE_TIME_OPTION = cartesian_product(
|
||||
TEST_GET_RELATIVE_TIME_OPTION: list[
|
||||
tuple[datetime.timedelta, str, str, bool, datetime.datetime, datetime.datetime]
|
||||
] = cartesian_product(
|
||||
TIMEZONES,
|
||||
[
|
||||
(
|
||||
@@ -259,7 +284,9 @@ TEST_GET_RELATIVE_TIME_OPTION = cartesian_product(
|
||||
|
||||
|
||||
@pytest.mark.parametrize("timezone, input, expected", TEST_REMOVE_TIMEZONE)
|
||||
def test_remove_timezone(timezone, input, expected):
|
||||
def test_remove_timezone(
|
||||
timezone: datetime.timedelta, input: datetime.datetime, expected: datetime.datetime
|
||||
) -> None:
|
||||
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
|
||||
output_1 = remove_timezone(input)
|
||||
assert expected == output_1
|
||||
@@ -268,7 +295,9 @@ def test_remove_timezone(timezone, input, expected):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("timezone, input, expected", TEST_UTC_TIMEZONE)
|
||||
def test_utc_timezone(timezone, input, expected):
|
||||
def test_utc_timezone(
|
||||
timezone: datetime.timedelta, input: datetime.datetime, expected: datetime.datetime
|
||||
) -> None:
|
||||
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
|
||||
output_1 = ensure_utc_timezone(input)
|
||||
assert expected == output_1
|
||||
@@ -280,7 +309,7 @@ def test_utc_timezone(timezone, input, expected):
|
||||
# Due to a bug in freezegun (https://github.com/spulec/freezegun/issues/348, https://github.com/spulec/freezegun/issues/553)
|
||||
# this only works with timezone = UTC
|
||||
@pytest.mark.parametrize("timezone", [datetime.timedelta(hours=0)])
|
||||
def test_get_now_datetime_w_timezone(timezone):
|
||||
def test_get_now_datetime_w_timezone(timezone: datetime.timedelta) -> None:
|
||||
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
|
||||
output_2 = get_now_datetime(with_timezone=True)
|
||||
assert output_2.tzinfo is not None
|
||||
@@ -289,7 +318,7 @@ def test_get_now_datetime_w_timezone(timezone):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("timezone", TIMEZONES)
|
||||
def test_get_now_datetime_wo_timezone(timezone):
|
||||
def test_get_now_datetime_wo_timezone(timezone: datetime.timedelta) -> None:
|
||||
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
|
||||
output_1 = get_now_datetime(with_timezone=False)
|
||||
assert output_1.tzinfo is None
|
||||
@@ -297,13 +326,15 @@ def test_get_now_datetime_wo_timezone(timezone):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("timezone, seconds, timestamp", TEST_EPOCH_SECONDS)
|
||||
def test_epoch_seconds(timezone, seconds, timestamp):
|
||||
def test_epoch_seconds(
|
||||
timezone: datetime.timedelta, seconds: float, timestamp: dict[str, int]
|
||||
) -> None:
|
||||
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
|
||||
ts_wo_tz = datetime.datetime(**timestamp)
|
||||
ts_wo_tz: datetime.datetime = datetime.datetime(**timestamp) # type: ignore
|
||||
assert seconds == get_epoch_seconds(ts_wo_tz)
|
||||
timestamp_w_tz = dict(timestamp)
|
||||
timestamp_w_tz: dict[str, t.Any] = dict(timestamp)
|
||||
timestamp_w_tz["tzinfo"] = UTC
|
||||
ts_w_tz = datetime.datetime(**timestamp_w_tz)
|
||||
ts_w_tz: datetime.datetime = datetime.datetime(**timestamp_w_tz) # type: ignore
|
||||
assert seconds == get_epoch_seconds(ts_w_tz)
|
||||
output_1 = from_epoch_seconds(seconds, with_timezone=False)
|
||||
assert ts_wo_tz == output_1
|
||||
@@ -312,7 +343,9 @@ def test_epoch_seconds(timezone, seconds, timestamp):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("timezone, timestamp, expected_seconds", TEST_EPOCH_TO_SECONDS)
|
||||
def test_epoch_to_seconds(timezone, timestamp, expected_seconds):
|
||||
def test_epoch_to_seconds(
|
||||
timezone: datetime.timedelta, timestamp: datetime.datetime, expected_seconds: int
|
||||
) -> None:
|
||||
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
|
||||
assert expected_seconds == get_epoch_seconds(timestamp)
|
||||
|
||||
@@ -322,8 +355,12 @@ def test_epoch_to_seconds(timezone, timestamp, expected_seconds):
|
||||
TEST_CONVERT_RELATIVE_TO_DATETIME,
|
||||
)
|
||||
def test_convert_relative_to_datetime(
|
||||
timezone, relative_time_string, with_timezone, now, expected
|
||||
):
|
||||
timezone: datetime.timedelta,
|
||||
relative_time_string: str,
|
||||
with_timezone: bool,
|
||||
now: datetime.datetime,
|
||||
expected: datetime.datetime,
|
||||
) -> None:
|
||||
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
|
||||
output = convert_relative_to_datetime(
|
||||
relative_time_string, with_timezone=with_timezone, now=now
|
||||
@@ -336,8 +373,13 @@ def test_convert_relative_to_datetime(
|
||||
TEST_GET_RELATIVE_TIME_OPTION,
|
||||
)
|
||||
def test_get_relative_time_option(
|
||||
timezone, input_string, input_name, with_timezone, now, expected
|
||||
):
|
||||
timezone: datetime.timedelta,
|
||||
input_string: str,
|
||||
input_name: str,
|
||||
with_timezone: bool,
|
||||
now: datetime.datetime,
|
||||
expected: datetime.datetime,
|
||||
) -> None:
|
||||
with freeze_time("2024-02-03 04:05:06", tz_offset=timezone):
|
||||
output = get_relative_time_option(
|
||||
input_string,
|
||||
|
||||
Reference in New Issue
Block a user