diff --git a/ansible_testing/modules.py b/ansible_testing/modules.py index 01fddbc956..8eeb710d2b 100644 --- a/ansible_testing/modules.py +++ b/ansible_testing/modules.py @@ -10,9 +10,6 @@ import re import sys import traceback -# We only use StringIO, since we cannot setattr on cStringIO -from StringIO import StringIO - from distutils.version import StrictVersion from fnmatch import fnmatch @@ -22,7 +19,7 @@ from ansible.module_utils import basic as module_utils_basic from ansible.plugins import module_loader from ansible.utils.module_docs import BLACKLIST_MODULES, get_docstring -from utils import find_globals +from utils import CaptureStd, find_globals import yaml @@ -359,24 +356,17 @@ class ModuleValidator(Validator): except AttributeError: self.errors.append('No DOCUMENTATION provided') else: - sys_stdout = sys.stdout - sys_stderr = sys.stderr - sys.stdout = sys.stderr = buf = StringIO() - # instead of adding noqa to the above, do something with buf - assert buf - setattr(sys.stdout, 'encoding', sys_stdout.encoding) - setattr(sys.stderr, 'encoding', sys_stderr.encoding) - try: - get_docstring(self.path, verbose=True) - except AssertionError: - fragment = doc['extends_documentation_fragment'] - self.errors.append('DOCUMENTATION fragment missing: %s' % fragment) - except Exception as e: - self.traces.append(e) - self.errors.append('Unknown DOCUMENTATION error, see TRACE') - finally: - sys.stdout = sys_stdout - sys.stderr = sys_stderr + with CaptureStd(): + try: + get_docstring(self.path, verbose=True) + except AssertionError: + fragment = doc['extends_documentation_fragment'] + self.errors.append('DOCUMENTATION fragment missing: %s' % + fragment) + except Exception as e: + self.traces.append(e) + self.errors.append('Unknown DOCUMENTATION error, see ' + 'TRACE') self._check_version_added(doc) self._check_for_new_args(doc) @@ -435,30 +425,21 @@ class ModuleValidator(Validator): if self._is_new_module(): return - sys_stdout = sys.stdout - sys_stderr = sys.stderr - sys.stdout = sys.stderr = buf = StringIO() - # instead of adding noqa to the above, do something with buf - assert buf - setattr(sys.stdout, 'encoding', sys_stdout.encoding) - setattr(sys.stderr, 'encoding', sys_stderr.encoding) - try: - existing = module_loader.find_plugin(self.name, mod_type='.py') - existing_doc, _, _ = get_docstring(existing, verbose=True) - existing_options = existing_doc.get('options', {}) - except AssertionError: - fragment = doc['extends_documentation_fragment'] - self.errors.append('Existing DOCUMENTATION fragment missing: %s' % - fragment) - return - except Exception as e: - self.traces.append(e) - self.errors.append('Unknown existing DOCUMENTATION error, see ' - 'TRACE') - return - finally: - sys.stdout = sys_stdout - sys.stderr = sys_stderr + with CaptureStd(): + try: + existing = module_loader.find_plugin(self.name, mod_type='.py') + existing_doc, _, _ = get_docstring(existing, verbose=True) + existing_options = existing_doc.get('options', {}) + except AssertionError: + fragment = doc['extends_documentation_fragment'] + self.errors.append('Existing DOCUMENTATION fragment missing: ' + '%s' % fragment) + return + except Exception as e: + self.traces.append(e) + self.errors.append('Unknown existing DOCUMENTATION error, see ' + 'TRACE') + return options = doc.get('options', {}) diff --git a/ansible_testing/utils.py b/ansible_testing/utils.py index dc5a84e95a..1962040a97 100644 --- a/ansible_testing/utils.py +++ b/ansible_testing/utils.py @@ -1,4 +1,8 @@ import ast +import sys + +# We only use StringIO, since we cannot setattr on cStringIO +from StringIO import StringIO def find_globals(g, tree): @@ -22,3 +26,25 @@ def find_globals(g, tree): if g_name == '*': continue g.add(g_name) + + +class CaptureStd(): + """Context manager to handle capturing stderr and stdout""" + + def __enter__(self): + self.sys_stdout = sys.stdout + self.sys_stderr = sys.stderr + sys.stdout = self.stdout = StringIO() + sys.stderr = self.stderr = StringIO() + setattr(sys.stdout, 'encoding', self.sys_stdout.encoding) + setattr(sys.stderr, 'encoding', self.sys_stderr.encoding) + return self + + def __exit__(self, exc_type, exc_value, traceback): + sys.stdout = self.sys_stdout + sys.stderr = self.sys_stderr + + def get(self): + """Return ``(stdout, stderr)``""" + + return self.stdout.getvalue(), self.stderr.getvalue()