diff --git a/lib/ansible/module_utils/kubevirt.py b/lib/ansible/module_utils/kubevirt.py index 9c8db11ed7..4b6f4523bc 100644 --- a/lib/ansible/module_utils/kubevirt.py +++ b/lib/ansible/module_utils/kubevirt.py @@ -7,6 +7,7 @@ from collections import defaultdict from distutils.version import Version +from ansible.module_utils.common._collections_compat import Sequence from ansible.module_utils.k8s.common import list_dict_str from ansible.module_utils.k8s.raw import KubernetesRawModule @@ -127,28 +128,25 @@ class KubeVirtRawModule(KubernetesRawModule): super(KubeVirtRawModule, self).__init__(*args, **kwargs) @staticmethod - def merge_dicts(x, yy): + def merge_dicts(base_dict, merging_dicts): + """This function merges a base dictionary with one or more other dictionaries. + The base dictionary takes precedence when there is a key collision. + merging_dicts can be a dict or a list or tuple of dicts. In the latter case, the + dictionaries at the front of the list have higher precedence over the ones at the end. """ - This function merge two dictionaries, where the first dict has - higher priority in merging two same keys. - """ - if not yy: - yy = {} + if not merging_dicts: + merging_dicts = ({},) - if not isinstance(yy, list): - yy = [yy] + if not isinstance(merging_dicts, Sequence): + merging_dicts = (merging_dicts,) - for y in yy: - for k in set(x.keys()).union(y.keys()): - if k in x and k in y: - if isinstance(x[k], dict) and isinstance(y[k], dict): - yield (k, dict(KubeVirtRawModule.merge_dicts(x[k], y[k]))) - else: - yield (k, x[k]) - elif k in x: - yield (k, x[k]) - else: - yield (k, y[k]) + new_dict = {} + for d in reversed(merging_dicts): + new_dict.update(d) + + new_dict.update(base_dict) + + return new_dict def get_resource(self, resource): try: @@ -238,7 +236,7 @@ class KubeVirtRawModule(KubernetesRawModule): spec_interfaces = [] for i in interfaces: spec_interfaces.append( - dict(self.merge_dicts(dict((k, v) for k, v in i.items() if k != 'network'), defaults['interfaces'])) + self.merge_dicts(dict((k, v) for k, v in i.items() if k != 'network'), defaults['interfaces']) ) if 'interfaces' not in template_spec['domain']['devices']: template_spec['domain']['devices']['interfaces'] = [] @@ -249,7 +247,7 @@ class KubeVirtRawModule(KubernetesRawModule): for i in interfaces: net = i['network'] net['name'] = i['name'] - spec_networks.append(dict(self.merge_dicts(net, defaults['networks']))) + spec_networks.append(self.merge_dicts(net, defaults['networks'])) if 'networks' not in template_spec: template_spec['networks'] = [] template_spec['networks'].extend(spec_networks) @@ -269,7 +267,7 @@ class KubeVirtRawModule(KubernetesRawModule): spec_disks = [] for d in disks: spec_disks.append( - dict(self.merge_dicts(dict((k, v) for k, v in d.items() if k != 'volume'), defaults['disks'])) + self.merge_dicts(dict((k, v) for k, v in d.items() if k != 'volume'), defaults['disks']) ) if 'disks' not in template_spec['domain']['devices']: template_spec['domain']['devices']['disks'] = [] @@ -280,7 +278,7 @@ class KubeVirtRawModule(KubernetesRawModule): for d in disks: volume = d['volume'] volume['name'] = d['name'] - spec_volumes.append(dict(self.merge_dicts(volume, defaults['volumes']))) + spec_volumes.append(self.merge_dicts(volume, defaults['volumes'])) if 'volumes' not in template_spec: template_spec['volumes'] = [] template_spec['volumes'].extend(spec_volumes) @@ -350,7 +348,7 @@ class KubeVirtRawModule(KubernetesRawModule): template_spec['domain']['cpu']['model'] = cpu_model if labels: - template['metadata']['labels'] = dict(self.merge_dicts(labels, template['metadata']['labels'])) + template['metadata']['labels'] = self.merge_dicts(labels, template['metadata']['labels']) if machine_type: template_spec['domain']['machine']['type'] = machine_type @@ -378,7 +376,7 @@ class KubeVirtRawModule(KubernetesRawModule): # Define datavolumes: self._define_datavolumes(datavolumes, definition['spec']) - return dict(self.merge_dicts(definition, self.resource_definitions[0])) + return self.merge_dicts(definition, self.resource_definitions[0]) def construct_vm_definition(self, kind, definition, template, defaults=None): definition = self._construct_vm_definition(kind, definition, template, self.params, defaults) diff --git a/lib/ansible/modules/cloud/kubevirt/kubevirt_vm.py b/lib/ansible/modules/cloud/kubevirt/kubevirt_vm.py index 0fa6b0c31c..78737c36ba 100644 --- a/lib/ansible/modules/cloud/kubevirt/kubevirt_vm.py +++ b/lib/ansible/modules/cloud/kubevirt/kubevirt_vm.py @@ -380,7 +380,7 @@ class KubeVirtVM(KubeVirtRawModule): template['metadata']['labels']['vm.cnv.io/name'] = self.params.get('name') dummy, definition = self.construct_vm_definition(kind, definition, template, defaults) - return dict(self.merge_dicts(definition, processedtemplate)) + return self.merge_dicts(definition, processedtemplate) def execute_module(self): # Parse parameters specific to this module: