mirror of
https://github.com/ansible-collections/community.general.git
synced 2026-05-06 13:22:48 +00:00
Allow AnsibleModules to be instantiated more than once in a module
Fix SELINUX monkeypatch in test_basic
This commit is contained in:
@@ -223,6 +223,12 @@ from ansible import __version__
|
||||
# Backwards compat. New code should just import and use __version__
|
||||
ANSIBLE_VERSION = __version__
|
||||
|
||||
# Internal global holding passed in params and constants. This is consulted
|
||||
# in case multiple AnsibleModules are created. Otherwise each AnsibleModule
|
||||
# would attempt to read from stdin. Other code should not use this directly
|
||||
# as it is an internal implementation detail
|
||||
_ANSIBLE_ARGS = None
|
||||
|
||||
FILE_COMMON_ARGUMENTS=dict(
|
||||
src = dict(),
|
||||
mode = dict(type='raw'),
|
||||
@@ -1457,23 +1463,28 @@ class AnsibleModule(object):
|
||||
''' read the input and set the params attribute. Sets the constants as well.'''
|
||||
# debug overrides to read args from file or cmdline
|
||||
|
||||
# Avoid tracebacks when locale is non-utf8
|
||||
# We control the args and we pass them as utf8
|
||||
if len(sys.argv) > 1:
|
||||
if os.path.isfile(sys.argv[1]):
|
||||
fd = open(sys.argv[1], 'rb')
|
||||
buffer = fd.read()
|
||||
fd.close()
|
||||
else:
|
||||
buffer = sys.argv[1]
|
||||
if sys.version_info >= (3,):
|
||||
buffer = buffer.encode('utf-8', errors='surrogateescape')
|
||||
# default case, read from stdin
|
||||
global _ANSIBLE_ARGS
|
||||
if _ANSIBLE_ARGS is not None:
|
||||
buffer = _ANSIBLE_ARGS
|
||||
else:
|
||||
if sys.version_info < (3,):
|
||||
buffer = sys.stdin.read()
|
||||
# Avoid tracebacks when locale is non-utf8
|
||||
# We control the args and we pass them as utf8
|
||||
if len(sys.argv) > 1:
|
||||
if os.path.isfile(sys.argv[1]):
|
||||
fd = open(sys.argv[1], 'rb')
|
||||
buffer = fd.read()
|
||||
fd.close()
|
||||
else:
|
||||
buffer = sys.argv[1]
|
||||
if sys.version_info >= (3,):
|
||||
buffer = buffer.encode('utf-8', errors='surrogateescape')
|
||||
# default case, read from stdin
|
||||
else:
|
||||
buffer = sys.stdin.buffer.read()
|
||||
if sys.version_info < (3,):
|
||||
buffer = sys.stdin.read()
|
||||
else:
|
||||
buffer = sys.stdin.buffer.read()
|
||||
_ANSIBLE_ARGS = buffer
|
||||
|
||||
try:
|
||||
params = json.loads(buffer.decode('utf-8'))
|
||||
|
||||
@@ -45,6 +45,7 @@ class TestAnsibleModuleExitJson(unittest.TestCase):
|
||||
self.stdout_swap_ctx = swap_stdout()
|
||||
self.fake_stream = self.stdout_swap_ctx.__enter__()
|
||||
|
||||
reload(basic)
|
||||
self.module = basic.AnsibleModule(argument_spec=dict())
|
||||
|
||||
def tearDown(self):
|
||||
@@ -125,6 +126,7 @@ class TestAnsibleModuleExitValuesRemoved(unittest.TestCase):
|
||||
params = json.dumps(params)
|
||||
|
||||
with swap_stdin_and_argv(stdin_data=params):
|
||||
reload(basic)
|
||||
with swap_stdout():
|
||||
module = basic.AnsibleModule(
|
||||
argument_spec = dict(
|
||||
@@ -146,6 +148,7 @@ class TestAnsibleModuleExitValuesRemoved(unittest.TestCase):
|
||||
params = dict(ANSIBLE_MODULE_ARGS=args, ANSIBLE_MODULE_CONSTANTS={})
|
||||
params = json.dumps(params)
|
||||
with swap_stdin_and_argv(stdin_data=params):
|
||||
reload(basic)
|
||||
with swap_stdout():
|
||||
module = basic.AnsibleModule(
|
||||
argument_spec = dict(
|
||||
|
||||
@@ -49,6 +49,7 @@ class TestAnsibleModuleSysLogSmokeTest(unittest.TestCase):
|
||||
self.stdin_swap = swap_stdin_and_argv(stdin_data=args)
|
||||
self.stdin_swap.__enter__()
|
||||
|
||||
reload(basic)
|
||||
self.am = basic.AnsibleModule(
|
||||
argument_spec = dict(),
|
||||
)
|
||||
@@ -85,6 +86,7 @@ class TestAnsibleModuleJournaldSmokeTest(unittest.TestCase):
|
||||
self.stdin_swap = swap_stdin_and_argv(stdin_data=args)
|
||||
self.stdin_swap.__enter__()
|
||||
|
||||
reload(basic)
|
||||
self.am = basic.AnsibleModule(
|
||||
argument_spec = dict(),
|
||||
)
|
||||
@@ -132,6 +134,7 @@ class TestAnsibleModuleLogSyslog(unittest.TestCase):
|
||||
self.stdin_swap = swap_stdin_and_argv(stdin_data=args)
|
||||
self.stdin_swap.__enter__()
|
||||
|
||||
reload(basic)
|
||||
self.am = basic.AnsibleModule(
|
||||
argument_spec = dict(),
|
||||
)
|
||||
@@ -192,6 +195,7 @@ class TestAnsibleModuleLogJournal(unittest.TestCase):
|
||||
self.stdin_swap = swap_stdin_and_argv(stdin_data=args)
|
||||
self.stdin_swap.__enter__()
|
||||
|
||||
reload(basic)
|
||||
self.am = basic.AnsibleModule(
|
||||
argument_spec = dict(),
|
||||
)
|
||||
|
||||
@@ -67,6 +67,7 @@ class TestAnsibleModuleRunCommand(unittest.TestCase):
|
||||
self.stdin_swap = swap_stdin_and_argv(stdin_data=args)
|
||||
self.stdin_swap.__enter__()
|
||||
|
||||
reload(basic)
|
||||
self.module = AnsibleModule(argument_spec=dict())
|
||||
self.module.fail_json = MagicMock(side_effect=SystemExit)
|
||||
|
||||
|
||||
@@ -26,6 +26,12 @@ import json
|
||||
from ansible.compat.tests import unittest
|
||||
from units.mock.procenv import swap_stdin_and_argv
|
||||
|
||||
try:
|
||||
from importlib import reload
|
||||
except:
|
||||
# Py2 has reload as a builtin
|
||||
pass
|
||||
|
||||
class TestAnsibleModuleExitJson(unittest.TestCase):
|
||||
|
||||
def test_module_utils_basic_safe_eval(self):
|
||||
@@ -34,6 +40,7 @@ class TestAnsibleModuleExitJson(unittest.TestCase):
|
||||
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
|
||||
|
||||
with swap_stdin_and_argv(stdin_data=args):
|
||||
reload(basic)
|
||||
am = basic.AnsibleModule(
|
||||
argument_spec=dict(),
|
||||
)
|
||||
|
||||
@@ -31,6 +31,12 @@ try:
|
||||
except ImportError:
|
||||
import __builtin__ as builtins
|
||||
|
||||
try:
|
||||
from importlib import reload
|
||||
except:
|
||||
# Py2 has reload as a builtin
|
||||
pass
|
||||
|
||||
from units.mock.procenv import swap_stdin_and_argv
|
||||
|
||||
from ansible.compat.tests import unittest
|
||||
@@ -291,6 +297,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
|
||||
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"foo": "hello"}, ANSIBLE_MODULE_CONSTANTS={}))
|
||||
|
||||
with swap_stdin_and_argv(stdin_data=args):
|
||||
reload(basic)
|
||||
am = basic.AnsibleModule(
|
||||
argument_spec = arg_spec,
|
||||
mutually_exclusive = mut_ex,
|
||||
@@ -307,6 +314,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
|
||||
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
|
||||
|
||||
with swap_stdin_and_argv(stdin_data=args):
|
||||
reload(basic)
|
||||
self.assertRaises(
|
||||
SystemExit,
|
||||
basic.AnsibleModule,
|
||||
@@ -353,6 +361,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
|
||||
|
||||
def test_module_utils_basic_ansible_module_load_file_common_arguments(self):
|
||||
from ansible.module_utils import basic
|
||||
reload(basic)
|
||||
|
||||
am = basic.AnsibleModule(
|
||||
argument_spec = dict(),
|
||||
@@ -401,6 +410,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
|
||||
|
||||
def test_module_utils_basic_ansible_module_selinux_mls_enabled(self):
|
||||
from ansible.module_utils import basic
|
||||
reload(basic)
|
||||
|
||||
am = basic.AnsibleModule(
|
||||
argument_spec = dict(),
|
||||
@@ -420,6 +430,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
|
||||
|
||||
def test_module_utils_basic_ansible_module_selinux_initial_context(self):
|
||||
from ansible.module_utils import basic
|
||||
reload(basic)
|
||||
|
||||
am = basic.AnsibleModule(
|
||||
argument_spec = dict(),
|
||||
@@ -433,6 +444,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
|
||||
|
||||
def test_module_utils_basic_ansible_module_selinux_enabled(self):
|
||||
from ansible.module_utils import basic
|
||||
reload(basic)
|
||||
|
||||
am = basic.AnsibleModule(
|
||||
argument_spec = dict(),
|
||||
@@ -464,6 +476,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
|
||||
|
||||
def test_module_utils_basic_ansible_module_selinux_default_context(self):
|
||||
from ansible.module_utils import basic
|
||||
reload(basic)
|
||||
|
||||
am = basic.AnsibleModule(
|
||||
argument_spec = dict(),
|
||||
@@ -499,6 +512,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
|
||||
|
||||
def test_module_utils_basic_ansible_module_selinux_context(self):
|
||||
from ansible.module_utils import basic
|
||||
reload(basic)
|
||||
|
||||
am = basic.AnsibleModule(
|
||||
argument_spec = dict(),
|
||||
@@ -540,6 +554,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
|
||||
|
||||
def test_module_utils_basic_ansible_module_is_special_selinux_path(self):
|
||||
from ansible.module_utils import basic
|
||||
reload(basic)
|
||||
|
||||
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={"SELINUX_SPECIAL_FS": "nfs,nfsd,foos"}))
|
||||
|
||||
@@ -584,6 +599,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
|
||||
|
||||
def test_module_utils_basic_ansible_module_to_filesystem_str(self):
|
||||
from ansible.module_utils import basic
|
||||
reload(basic)
|
||||
|
||||
am = basic.AnsibleModule(
|
||||
argument_spec = dict(),
|
||||
@@ -608,6 +624,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
|
||||
|
||||
def test_module_utils_basic_ansible_module_find_mount_point(self):
|
||||
from ansible.module_utils import basic
|
||||
reload(basic)
|
||||
|
||||
am = basic.AnsibleModule(
|
||||
argument_spec = dict(),
|
||||
@@ -631,18 +648,19 @@ class TestModuleUtilsBasic(unittest.TestCase):
|
||||
|
||||
def test_module_utils_basic_ansible_module_set_context_if_different(self):
|
||||
from ansible.module_utils import basic
|
||||
reload(basic)
|
||||
|
||||
am = basic.AnsibleModule(
|
||||
argument_spec = dict(),
|
||||
)
|
||||
|
||||
basic.HAS_SELINUX = False
|
||||
basic.HAVE_SELINUX = False
|
||||
|
||||
am.selinux_enabled = MagicMock(return_value=False)
|
||||
self.assertEqual(am.set_context_if_different('/path/to/file', ['foo_u', 'foo_r', 'foo_t', 's0'], True), True)
|
||||
self.assertEqual(am.set_context_if_different('/path/to/file', ['foo_u', 'foo_r', 'foo_t', 's0'], False), False)
|
||||
|
||||
basic.HAS_SELINUX = True
|
||||
basic.HAVE_SELINUX = True
|
||||
|
||||
am.selinux_enabled = MagicMock(return_value=True)
|
||||
am.selinux_context = MagicMock(return_value=['bar_u', 'bar_r', None, None])
|
||||
@@ -675,6 +693,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
|
||||
|
||||
def test_module_utils_basic_ansible_module_set_owner_if_different(self):
|
||||
from ansible.module_utils import basic
|
||||
reload(basic)
|
||||
|
||||
am = basic.AnsibleModule(
|
||||
argument_spec = dict(),
|
||||
@@ -713,6 +732,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
|
||||
|
||||
def test_module_utils_basic_ansible_module_set_group_if_different(self):
|
||||
from ansible.module_utils import basic
|
||||
reload(basic)
|
||||
|
||||
am = basic.AnsibleModule(
|
||||
argument_spec = dict(),
|
||||
@@ -751,6 +771,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
|
||||
|
||||
def test_module_utils_basic_ansible_module_set_mode_if_different(self):
|
||||
from ansible.module_utils import basic
|
||||
reload(basic)
|
||||
|
||||
am = basic.AnsibleModule(
|
||||
argument_spec = dict(),
|
||||
@@ -838,6 +859,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
|
||||
):
|
||||
|
||||
from ansible.module_utils import basic
|
||||
reload(basic)
|
||||
|
||||
am = basic.AnsibleModule(
|
||||
argument_spec = dict(),
|
||||
@@ -1015,6 +1037,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
|
||||
def test_module_utils_basic_ansible_module__symbolic_mode_to_octal(self):
|
||||
|
||||
from ansible.module_utils import basic
|
||||
reload(basic)
|
||||
|
||||
am = basic.AnsibleModule(
|
||||
argument_spec = dict(),
|
||||
|
||||
Reference in New Issue
Block a user