mirror of
https://github.com/ansible-collections/community.general.git
synced 2026-03-26 21:33:12 +00:00
iptables_state: get rid of temporary files (#11258)
Get rid of temporary files.
This commit is contained in:
2
changelogs/fragments/11258-iptables_state.yml
Normal file
2
changelogs/fragments/11258-iptables_state.yml
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
bugfixes:
|
||||||
|
- "iptables_state - refactor code to avoid writing unnecessary temporary files (https://github.com/ansible-collections/community.general/pull/11258)."
|
||||||
@@ -225,9 +225,6 @@ tables:
|
|||||||
import re
|
import re
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import tempfile
|
|
||||||
import filecmp
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
from ansible.module_utils.basic import AnsibleModule
|
from ansible.module_utils.basic import AnsibleModule
|
||||||
from ansible.module_utils.common.text.converters import to_bytes, to_native
|
from ansible.module_utils.common.text.converters import to_bytes, to_native
|
||||||
@@ -260,18 +257,28 @@ def read_state(b_path):
|
|||||||
return [t for t in text.splitlines() if t != ""]
|
return [t for t in text.splitlines() if t != ""]
|
||||||
|
|
||||||
|
|
||||||
def write_state(b_path, lines, changed):
|
def get_file_contents(b_path: bytes) -> bytes | None:
|
||||||
|
try:
|
||||||
|
with open(b_path, "rb") as f:
|
||||||
|
return f.read()
|
||||||
|
except FileNotFoundError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def write_state(module: AnsibleModule, b_path: bytes, lines: list[str], changed: bool) -> bool:
|
||||||
"""
|
"""
|
||||||
Write given contents to the given path, and return changed status.
|
Write given contents to the given path, and return changed status.
|
||||||
"""
|
"""
|
||||||
# Populate a temporary file
|
joined_lines = "\n".join(lines)
|
||||||
tmpfd, tmpfile = tempfile.mkstemp()
|
content = f"{joined_lines}\n".encode("utf-8")
|
||||||
with os.fdopen(tmpfd, "w") as f:
|
|
||||||
joined_lines = "\n".join(lines)
|
|
||||||
f.write(f"{joined_lines}\n")
|
|
||||||
|
|
||||||
# Prepare to copy temporary file to the final destination
|
existing_contents = get_file_contents(b_path)
|
||||||
if not os.path.exists(b_path):
|
if existing_contents == content:
|
||||||
|
return changed
|
||||||
|
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
if existing_contents is None:
|
||||||
b_destdir = os.path.dirname(b_path)
|
b_destdir = os.path.dirname(b_path)
|
||||||
destdir = to_native(b_destdir, errors="surrogate_or_strict")
|
destdir = to_native(b_destdir, errors="surrogate_or_strict")
|
||||||
if b_destdir and not os.path.exists(b_destdir) and not module.check_mode:
|
if b_destdir and not os.path.exists(b_destdir) and not module.check_mode:
|
||||||
@@ -279,15 +286,11 @@ def write_state(b_path, lines, changed):
|
|||||||
os.makedirs(b_destdir)
|
os.makedirs(b_destdir)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
module.fail_json(msg=f"Error creating {destdir}: {err}", initial_state=lines)
|
module.fail_json(msg=f"Error creating {destdir}: {err}", initial_state=lines)
|
||||||
changed = True
|
|
||||||
|
|
||||||
elif not filecmp.cmp(tmpfile, b_path):
|
if not module.check_mode:
|
||||||
changed = True
|
|
||||||
|
|
||||||
# Do it
|
|
||||||
if changed and not module.check_mode:
|
|
||||||
try:
|
try:
|
||||||
shutil.copyfile(tmpfile, b_path)
|
with open(b_path, "wb") as f:
|
||||||
|
f.write(content)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
path = to_native(b_path, errors="surrogate_or_strict")
|
path = to_native(b_path, errors="surrogate_or_strict")
|
||||||
module.fail_json(msg=f"Error saving state into {path}: {err}", initial_state=lines)
|
module.fail_json(msg=f"Error saving state into {path}: {err}", initial_state=lines)
|
||||||
@@ -295,7 +298,7 @@ def write_state(b_path, lines, changed):
|
|||||||
return changed
|
return changed
|
||||||
|
|
||||||
|
|
||||||
def initialize_from_null_state(initializer, initcommand, fallbackcmd, table):
|
def initialize_from_null_state(module: AnsibleModule, initializer, initcommand, fallbackcmd, table):
|
||||||
"""
|
"""
|
||||||
This ensures iptables-state output is suitable for iptables-restore to roll
|
This ensures iptables-state output is suitable for iptables-restore to roll
|
||||||
back to it, i.e. iptables-save output is not empty. This also works for the
|
back to it, i.e. iptables-save output is not empty. This also works for the
|
||||||
@@ -317,7 +320,7 @@ def initialize_from_null_state(initializer, initcommand, fallbackcmd, table):
|
|||||||
return rc, out, err
|
return rc, out, err
|
||||||
|
|
||||||
|
|
||||||
def filter_and_format_state(string):
|
def filter_and_format_state(module: AnsibleModule, string: str) -> list[str]:
|
||||||
"""
|
"""
|
||||||
Remove timestamps to ensure idempotence between runs. Also remove counters
|
Remove timestamps to ensure idempotence between runs. Also remove counters
|
||||||
by default. And return the result as a list.
|
by default. And return the result as a list.
|
||||||
@@ -329,15 +332,15 @@ def filter_and_format_state(string):
|
|||||||
return lines
|
return lines
|
||||||
|
|
||||||
|
|
||||||
def parse_per_table_state(all_states_dump):
|
def parse_per_table_state(module: AnsibleModule, all_states_dump) -> dict[str, list[str]]:
|
||||||
"""
|
"""
|
||||||
Convert raw iptables-save output into usable datastructure, for reliable
|
Convert raw iptables-save output into usable datastructure, for reliable
|
||||||
comparisons between initial and final states.
|
comparisons between initial and final states.
|
||||||
"""
|
"""
|
||||||
lines = filter_and_format_state(all_states_dump)
|
lines = filter_and_format_state(module, all_states_dump)
|
||||||
tables = dict()
|
tables: dict[str, list[str]] = {}
|
||||||
current_table = ""
|
current_table = ""
|
||||||
current_list = list()
|
current_list: list[str] = []
|
||||||
for line in lines:
|
for line in lines:
|
||||||
if re.match(r"^[*](filter|mangle|nat|raw|security)$", line):
|
if re.match(r"^[*](filter|mangle|nat|raw|security)$", line):
|
||||||
current_table = line[1:]
|
current_table = line[1:]
|
||||||
@@ -345,7 +348,7 @@ def parse_per_table_state(all_states_dump):
|
|||||||
if line == "COMMIT":
|
if line == "COMMIT":
|
||||||
tables[current_table] = current_list
|
tables[current_table] = current_list
|
||||||
current_table = ""
|
current_table = ""
|
||||||
current_list = list()
|
current_list = []
|
||||||
continue
|
continue
|
||||||
if line.startswith("# "):
|
if line.startswith("# "):
|
||||||
continue
|
continue
|
||||||
@@ -353,9 +356,7 @@ def parse_per_table_state(all_states_dump):
|
|||||||
return tables
|
return tables
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
global module
|
|
||||||
|
|
||||||
module = AnsibleModule(
|
module = AnsibleModule(
|
||||||
argument_spec=dict(
|
argument_spec=dict(
|
||||||
path=dict(type="path", required=True),
|
path=dict(type="path", required=True),
|
||||||
@@ -459,28 +460,30 @@ def main():
|
|||||||
for t in TABLES:
|
for t in TABLES:
|
||||||
if f"*{t}" in state_to_restore:
|
if f"*{t}" in state_to_restore:
|
||||||
if len(stdout) == 0 or f"*{t}" not in stdout.splitlines():
|
if len(stdout) == 0 or f"*{t}" not in stdout.splitlines():
|
||||||
(rc, stdout, stderr) = initialize_from_null_state(INITIALIZER, INITCOMMAND, FALLBACKCMD, t)
|
(rc, stdout, stderr) = initialize_from_null_state(
|
||||||
|
module, INITIALIZER, INITCOMMAND, FALLBACKCMD, t
|
||||||
|
)
|
||||||
elif len(stdout) == 0:
|
elif len(stdout) == 0:
|
||||||
(rc, stdout, stderr) = initialize_from_null_state(INITIALIZER, INITCOMMAND, FALLBACKCMD, "filter")
|
(rc, stdout, stderr) = initialize_from_null_state(module, INITIALIZER, INITCOMMAND, FALLBACKCMD, "filter")
|
||||||
|
|
||||||
elif state == "restored" and f"*{table}" not in state_to_restore:
|
elif state == "restored" and f"*{table}" not in state_to_restore:
|
||||||
module.fail_json(msg=f"Table {table} to restore not defined in {path}")
|
module.fail_json(msg=f"Table {table} to restore not defined in {path}")
|
||||||
|
|
||||||
elif len(stdout) == 0 or f"*{table}" not in stdout.splitlines():
|
elif len(stdout) == 0 or f"*{table}" not in stdout.splitlines():
|
||||||
(rc, stdout, stderr) = initialize_from_null_state(INITIALIZER, INITCOMMAND, FALLBACKCMD, table)
|
(rc, stdout, stderr) = initialize_from_null_state(module, INITIALIZER, INITCOMMAND, FALLBACKCMD, table)
|
||||||
|
|
||||||
initial_state = filter_and_format_state(stdout)
|
initial_state = filter_and_format_state(module, stdout)
|
||||||
if initial_state is None:
|
if initial_state is None:
|
||||||
module.fail_json(msg="Unable to initialize firewall from NULL state.")
|
module.fail_json(msg="Unable to initialize firewall from NULL state.")
|
||||||
|
|
||||||
# Depending on the value of 'table', initref_state may differ from
|
# Depending on the value of 'table', initref_state may differ from
|
||||||
# initial_state.
|
# initial_state.
|
||||||
(rc, stdout, stderr) = module.run_command(SAVECOMMAND, check_rc=True)
|
(rc, stdout, stderr) = module.run_command(SAVECOMMAND, check_rc=True)
|
||||||
tables_before = parse_per_table_state(stdout)
|
tables_before = parse_per_table_state(module, stdout)
|
||||||
initref_state = filter_and_format_state(stdout)
|
initref_state = filter_and_format_state(module, stdout)
|
||||||
|
|
||||||
if state == "saved":
|
if state == "saved":
|
||||||
changed = write_state(b_path, initref_state, changed)
|
changed = write_state(module, b_path, initref_state, changed)
|
||||||
module.exit_json(
|
module.exit_json(
|
||||||
changed=changed, cmd=cmd, tables=tables_before, initial_state=initial_state, saved=initref_state
|
changed=changed, cmd=cmd, tables=tables_before, initial_state=initial_state, saved=initref_state
|
||||||
)
|
)
|
||||||
@@ -497,7 +500,7 @@ def main():
|
|||||||
|
|
||||||
if _back is not None:
|
if _back is not None:
|
||||||
b_back = to_bytes(_back, errors="surrogate_or_strict")
|
b_back = to_bytes(_back, errors="surrogate_or_strict")
|
||||||
dummy = write_state(b_back, initref_state, changed)
|
dummy = write_state(module, b_back, initref_state, changed)
|
||||||
BACKCOMMAND = list(MAINCOMMAND)
|
BACKCOMMAND = list(MAINCOMMAND)
|
||||||
BACKCOMMAND.append(_back)
|
BACKCOMMAND.append(_back)
|
||||||
|
|
||||||
@@ -536,12 +539,8 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
if module.check_mode:
|
if module.check_mode:
|
||||||
tmpfd, tmpfile = tempfile.mkstemp()
|
joined_initial_state = "\n".join(initial_state)
|
||||||
with os.fdopen(tmpfd, "w") as f:
|
if get_file_contents(b_path) == f"{joined_initial_state}\n".encode("utf-8"):
|
||||||
joined_initial_state = "\n".join(initial_state)
|
|
||||||
f.write(f"{joined_initial_state}\n")
|
|
||||||
|
|
||||||
if filecmp.cmp(tmpfile, b_path):
|
|
||||||
restored_state = initial_state
|
restored_state = initial_state
|
||||||
else:
|
else:
|
||||||
restored_state = state_to_restore
|
restored_state = state_to_restore
|
||||||
@@ -572,8 +571,8 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
(rc, stdout, stderr) = module.run_command(SAVECOMMAND, check_rc=True)
|
(rc, stdout, stderr) = module.run_command(SAVECOMMAND, check_rc=True)
|
||||||
restored_state = filter_and_format_state(stdout)
|
restored_state = filter_and_format_state(module, stdout)
|
||||||
tables_after = parse_per_table_state("\n".join(restored_state))
|
tables_after = parse_per_table_state(module, "\n".join(restored_state))
|
||||||
if restored_state not in (initref_state, initial_state):
|
if restored_state not in (initref_state, initial_state):
|
||||||
for table_name, table_content in tables_after.items():
|
for table_name, table_content in tables_after.items():
|
||||||
if table_name not in tables_before:
|
if table_name not in tables_before:
|
||||||
@@ -608,7 +607,7 @@ def main():
|
|||||||
# timeout
|
# timeout
|
||||||
# * task attribute 'poll' equals 0
|
# * task attribute 'poll' equals 0
|
||||||
#
|
#
|
||||||
for dummy in range(_timeout):
|
for dummy2 in range(_timeout):
|
||||||
if os.path.exists(b_back):
|
if os.path.exists(b_back):
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
continue
|
continue
|
||||||
@@ -628,7 +627,7 @@ def main():
|
|||||||
os.remove(b_back)
|
os.remove(b_back)
|
||||||
|
|
||||||
(rc, stdout, stderr) = module.run_command(SAVECOMMAND, check_rc=True)
|
(rc, stdout, stderr) = module.run_command(SAVECOMMAND, check_rc=True)
|
||||||
tables_rollback = parse_per_table_state(stdout)
|
tables_rollback = parse_per_table_state(module, stdout)
|
||||||
|
|
||||||
msg = f"Failed to confirm state restored from {path} after {_timeout}s. Firewall has been rolled back to its initial state."
|
msg = f"Failed to confirm state restored from {path} after {_timeout}s. Firewall has been rolled back to its initial state."
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user