summaryrefslogtreecommitdiffstats
path: root/filter_plugins/openshift_version.py
blob: df8f565f0b4e3ccff5125cc017adf0da12da76ac (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#!/usr/bin/python
# -*- coding: utf-8 -*-
# vim: expandtab:tabstop=4:shiftwidth=4
"""
Custom version comparison filters for use in openshift-ansible
"""

# pylint can't locate distutils.version within virtualenv
# https://github.com/PyCQA/pylint/issues/73
# pylint: disable=no-name-in-module, import-error
from distutils.version import LooseVersion


def gte_function_builder(name, versions):
    """
    Build and return a version comparison function.

    Ex: name = 'oo_version_gte_3_1_or_1_1'
        versions = {'enterprise': '3.1', 'origin': '1.1'}

        returns oo_version_gte_3_1_or_1_1, a function which based on the
        version and deployment type will return true if the provided
        version is greater than or equal to the function's version

    :param name: value assigned to the returned function's ``__name__``.
    :param versions: mapping with ``'enterprise'`` and ``'origin'`` keys
        holding the minimum version string for each deployment type.
    :returns: a callable ``(version, deployment_type) -> bool``.
    :raises KeyError: if ``versions`` lacks either required key.
    """
    # Parse the thresholds once, when the filter is built, instead of on
    # every invocation of the filter.
    enterprise_version = LooseVersion(versions['enterprise'])
    origin_version = LooseVersion(versions['origin'])

    def _gte_function(version, deployment_type):
        """
        Dynamic function created by gte_function_builder.

        Ex: version = '3.1'
            deployment_type = 'openshift-enterprise'
            returns True/False

        ``deployment_type`` is matched by substring. A value containing
        'enterprise' is checked against the enterprise threshold. Otherwise,
        a value containing 'origin' is checked against the origin threshold.
        Any other deployment type yields False. ``version`` is passed through
        ``str()`` first, so non-string values such as 3.1 are accepted.
        """
        if 'enterprise' in deployment_type:
            threshold = enterprise_version
        elif 'origin' in deployment_type:
            threshold = origin_version
        else:
            # Unknown deployment types never satisfy the check.
            return False
        # Wrap both operands in LooseVersion explicitly. Comparing a plain
        # str against a LooseVersion only works because Python falls back to
        # LooseVersion's reflected comparison, which is easy to miss.
        return LooseVersion(str(version)) >= threshold

    _gte_function.__name__ = name
    return _gte_function


# pylint: disable=too-few-public-methods
class FilterModule(object):
    """
    Filters for version checking.

    On construction this builds one
    ``oo_version_gte_<enterprise>_<minor>_or_<origin>_<minor>`` filter for
    every (enterprise, origin) pair in ``majors`` and every minor version in
    ``range(minor, minor_iterations)``. It also builds a few filters whose
    thresholds do not follow the generated ``X.Y.0`` pattern.
    """
    #: The major versions to start incrementing. (enterprise, origin)
    majors = [(3, 1)]

    #: The minor version to start incrementing
    minor = 3
    #: Exclusive upper bound of the generated minor versions. Despite the
    #: name this is not a count: filters are built for
    #: ``range(minor, minor_iterations)``, i.e. minors 3 through 9 by default.
    minor_iterations = 10

    def __init__(self):
        """
        Creates a new FilterModule for ose version checking.
        """
        self._filters = {}
        # For each major version
        for enterprise, origin in self.majors:
            # For each minor version in the range
            for minor in range(self.minor, self.minor_iterations):
                func_name = 'oo_version_gte_{}_{}_or_{}_{}'.format(
                    enterprise, minor, origin, minor)
                self._filters[func_name] = gte_function_builder(
                    func_name, {
                        'enterprise': '{}.{}.0'.format(enterprise, minor),
                        'origin': '{}.{}.0'.format(origin, minor)
                    })

        # Filters with special versioning requirements, as
        # (filter name, thresholds) pairs. Each name serves as both the
        # mapping key and the function's __name__, so the two cannot drift
        # apart through copy-paste.
        special_filters = (
            ('oo_version_gte_3_1_or_1_1',
             {'enterprise': '3.0.2.905', 'origin': '1.1.0'}),
            ('oo_version_gte_3_1_1_or_1_1_1',
             {'enterprise': '3.1.1', 'origin': '1.1.1'}),
            ('oo_version_gte_3_2_or_1_2',
             {'enterprise': '3.1.1.901', 'origin': '1.2.0'}),
        )
        for func_name, versions in special_filters:
            self._filters[func_name] = gte_function_builder(func_name, versions)

    def filters(self):
        """
        Return the filters mapping.

        :returns: dict mapping filter name to a callable
            ``(version, deployment_type) -> bool``.
        """
        return self._filters