iptables_state: get rid of temporary files (#11258)

Get rid of temporary files.
This commit is contained in:
Felix Fontein
2025-12-06 13:40:59 +01:00
committed by GitHub
parent 3d25aac978
commit 0ef3eac0f4
2 changed files with 47 additions and 46 deletions

View File

@@ -0,0 +1,2 @@
bugfixes:
- "iptables_state - refactor code to avoid writing unnecessary temporary files (https://github.com/ansible-collections/community.general/pull/11258)."

View File

@@ -225,9 +225,6 @@ tables:
import re import re
import os import os
import time import time
import tempfile
import filecmp
import shutil
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_bytes, to_native from ansible.module_utils.common.text.converters import to_bytes, to_native
@@ -260,18 +257,28 @@ def read_state(b_path):
return [t for t in text.splitlines() if t != ""] return [t for t in text.splitlines() if t != ""]
def write_state(b_path, lines, changed): def get_file_contents(b_path: bytes) -> bytes | None:
try:
with open(b_path, "rb") as f:
return f.read()
except FileNotFoundError:
return None
def write_state(module: AnsibleModule, b_path: bytes, lines: list[str], changed: bool) -> bool:
""" """
Write given contents to the given path, and return changed status. Write given contents to the given path, and return changed status.
""" """
# Populate a temporary file joined_lines = "\n".join(lines)
tmpfd, tmpfile = tempfile.mkstemp() content = f"{joined_lines}\n".encode("utf-8")
with os.fdopen(tmpfd, "w") as f:
joined_lines = "\n".join(lines)
f.write(f"{joined_lines}\n")
# Prepare to copy temporary file to the final destination existing_contents = get_file_contents(b_path)
if not os.path.exists(b_path): if existing_contents == content:
return changed
changed = True
if existing_contents is None:
b_destdir = os.path.dirname(b_path) b_destdir = os.path.dirname(b_path)
destdir = to_native(b_destdir, errors="surrogate_or_strict") destdir = to_native(b_destdir, errors="surrogate_or_strict")
if b_destdir and not os.path.exists(b_destdir) and not module.check_mode: if b_destdir and not os.path.exists(b_destdir) and not module.check_mode:
@@ -279,15 +286,11 @@ def write_state(b_path, lines, changed):
os.makedirs(b_destdir) os.makedirs(b_destdir)
except Exception as err: except Exception as err:
module.fail_json(msg=f"Error creating {destdir}: {err}", initial_state=lines) module.fail_json(msg=f"Error creating {destdir}: {err}", initial_state=lines)
changed = True
elif not filecmp.cmp(tmpfile, b_path): if not module.check_mode:
changed = True
# Do it
if changed and not module.check_mode:
try: try:
shutil.copyfile(tmpfile, b_path) with open(b_path, "wb") as f:
f.write(content)
except Exception as err: except Exception as err:
path = to_native(b_path, errors="surrogate_or_strict") path = to_native(b_path, errors="surrogate_or_strict")
module.fail_json(msg=f"Error saving state into {path}: {err}", initial_state=lines) module.fail_json(msg=f"Error saving state into {path}: {err}", initial_state=lines)
@@ -295,7 +298,7 @@ def write_state(b_path, lines, changed):
return changed return changed
def initialize_from_null_state(initializer, initcommand, fallbackcmd, table): def initialize_from_null_state(module: AnsibleModule, initializer, initcommand, fallbackcmd, table):
""" """
This ensures iptables-state output is suitable for iptables-restore to roll This ensures iptables-state output is suitable for iptables-restore to roll
back to it, i.e. iptables-save output is not empty. This also works for the back to it, i.e. iptables-save output is not empty. This also works for the
@@ -317,7 +320,7 @@ def initialize_from_null_state(initializer, initcommand, fallbackcmd, table):
return rc, out, err return rc, out, err
def filter_and_format_state(string): def filter_and_format_state(module: AnsibleModule, string: str) -> list[str]:
""" """
Remove timestamps to ensure idempotence between runs. Also remove counters Remove timestamps to ensure idempotence between runs. Also remove counters
by default. And return the result as a list. by default. And return the result as a list.
@@ -329,15 +332,15 @@ def filter_and_format_state(string):
return lines return lines
def parse_per_table_state(all_states_dump): def parse_per_table_state(module: AnsibleModule, all_states_dump) -> dict[str, list[str]]:
""" """
Convert raw iptables-save output into usable datastructure, for reliable Convert raw iptables-save output into usable datastructure, for reliable
comparisons between initial and final states. comparisons between initial and final states.
""" """
lines = filter_and_format_state(all_states_dump) lines = filter_and_format_state(module, all_states_dump)
tables = dict() tables: dict[str, list[str]] = {}
current_table = "" current_table = ""
current_list = list() current_list: list[str] = []
for line in lines: for line in lines:
if re.match(r"^[*](filter|mangle|nat|raw|security)$", line): if re.match(r"^[*](filter|mangle|nat|raw|security)$", line):
current_table = line[1:] current_table = line[1:]
@@ -345,7 +348,7 @@ def parse_per_table_state(all_states_dump):
if line == "COMMIT": if line == "COMMIT":
tables[current_table] = current_list tables[current_table] = current_list
current_table = "" current_table = ""
current_list = list() current_list = []
continue continue
if line.startswith("# "): if line.startswith("# "):
continue continue
@@ -353,9 +356,7 @@ def parse_per_table_state(all_states_dump):
return tables return tables
def main(): def main() -> None:
global module
module = AnsibleModule( module = AnsibleModule(
argument_spec=dict( argument_spec=dict(
path=dict(type="path", required=True), path=dict(type="path", required=True),
@@ -459,28 +460,30 @@ def main():
for t in TABLES: for t in TABLES:
if f"*{t}" in state_to_restore: if f"*{t}" in state_to_restore:
if len(stdout) == 0 or f"*{t}" not in stdout.splitlines(): if len(stdout) == 0 or f"*{t}" not in stdout.splitlines():
(rc, stdout, stderr) = initialize_from_null_state(INITIALIZER, INITCOMMAND, FALLBACKCMD, t) (rc, stdout, stderr) = initialize_from_null_state(
module, INITIALIZER, INITCOMMAND, FALLBACKCMD, t
)
elif len(stdout) == 0: elif len(stdout) == 0:
(rc, stdout, stderr) = initialize_from_null_state(INITIALIZER, INITCOMMAND, FALLBACKCMD, "filter") (rc, stdout, stderr) = initialize_from_null_state(module, INITIALIZER, INITCOMMAND, FALLBACKCMD, "filter")
elif state == "restored" and f"*{table}" not in state_to_restore: elif state == "restored" and f"*{table}" not in state_to_restore:
module.fail_json(msg=f"Table {table} to restore not defined in {path}") module.fail_json(msg=f"Table {table} to restore not defined in {path}")
elif len(stdout) == 0 or f"*{table}" not in stdout.splitlines(): elif len(stdout) == 0 or f"*{table}" not in stdout.splitlines():
(rc, stdout, stderr) = initialize_from_null_state(INITIALIZER, INITCOMMAND, FALLBACKCMD, table) (rc, stdout, stderr) = initialize_from_null_state(module, INITIALIZER, INITCOMMAND, FALLBACKCMD, table)
initial_state = filter_and_format_state(stdout) initial_state = filter_and_format_state(module, stdout)
if initial_state is None: if initial_state is None:
module.fail_json(msg="Unable to initialize firewall from NULL state.") module.fail_json(msg="Unable to initialize firewall from NULL state.")
# Depending on the value of 'table', initref_state may differ from # Depending on the value of 'table', initref_state may differ from
# initial_state. # initial_state.
(rc, stdout, stderr) = module.run_command(SAVECOMMAND, check_rc=True) (rc, stdout, stderr) = module.run_command(SAVECOMMAND, check_rc=True)
tables_before = parse_per_table_state(stdout) tables_before = parse_per_table_state(module, stdout)
initref_state = filter_and_format_state(stdout) initref_state = filter_and_format_state(module, stdout)
if state == "saved": if state == "saved":
changed = write_state(b_path, initref_state, changed) changed = write_state(module, b_path, initref_state, changed)
module.exit_json( module.exit_json(
changed=changed, cmd=cmd, tables=tables_before, initial_state=initial_state, saved=initref_state changed=changed, cmd=cmd, tables=tables_before, initial_state=initial_state, saved=initref_state
) )
@@ -497,7 +500,7 @@ def main():
if _back is not None: if _back is not None:
b_back = to_bytes(_back, errors="surrogate_or_strict") b_back = to_bytes(_back, errors="surrogate_or_strict")
dummy = write_state(b_back, initref_state, changed) dummy = write_state(module, b_back, initref_state, changed)
BACKCOMMAND = list(MAINCOMMAND) BACKCOMMAND = list(MAINCOMMAND)
BACKCOMMAND.append(_back) BACKCOMMAND.append(_back)
@@ -536,12 +539,8 @@ def main():
) )
if module.check_mode: if module.check_mode:
tmpfd, tmpfile = tempfile.mkstemp() joined_initial_state = "\n".join(initial_state)
with os.fdopen(tmpfd, "w") as f: if get_file_contents(b_path) == f"{joined_initial_state}\n".encode("utf-8"):
joined_initial_state = "\n".join(initial_state)
f.write(f"{joined_initial_state}\n")
if filecmp.cmp(tmpfile, b_path):
restored_state = initial_state restored_state = initial_state
else: else:
restored_state = state_to_restore restored_state = state_to_restore
@@ -572,8 +571,8 @@ def main():
) )
(rc, stdout, stderr) = module.run_command(SAVECOMMAND, check_rc=True) (rc, stdout, stderr) = module.run_command(SAVECOMMAND, check_rc=True)
restored_state = filter_and_format_state(stdout) restored_state = filter_and_format_state(module, stdout)
tables_after = parse_per_table_state("\n".join(restored_state)) tables_after = parse_per_table_state(module, "\n".join(restored_state))
if restored_state not in (initref_state, initial_state): if restored_state not in (initref_state, initial_state):
for table_name, table_content in tables_after.items(): for table_name, table_content in tables_after.items():
if table_name not in tables_before: if table_name not in tables_before:
@@ -608,7 +607,7 @@ def main():
# timeout # timeout
# * task attribute 'poll' equals 0 # * task attribute 'poll' equals 0
# #
for dummy in range(_timeout): for dummy2 in range(_timeout):
if os.path.exists(b_back): if os.path.exists(b_back):
time.sleep(1) time.sleep(1)
continue continue
@@ -628,7 +627,7 @@ def main():
os.remove(b_back) os.remove(b_back)
(rc, stdout, stderr) = module.run_command(SAVECOMMAND, check_rc=True) (rc, stdout, stderr) = module.run_command(SAVECOMMAND, check_rc=True)
tables_rollback = parse_per_table_state(stdout) tables_rollback = parse_per_table_state(module, stdout)
msg = f"Failed to confirm state restored from {path} after {_timeout}s. Firewall has been rolled back to its initial state." msg = f"Failed to confirm state restored from {path} after {_timeout}s. Firewall has been rolled back to its initial state."