#!/usr/bin/env python2 ''' Fetch and combine multiple ec2 account settings into a single json hash. ''' # vim: expandtab:tabstop=4:shiftwidth=4 from time import time import argparse import yaml import os import subprocess import json import errno import fcntl import tempfile import copy CONFIG_FILE_NAME = 'multi_ec2.yaml' DEFAULT_CACHE_PATH = os.path.expanduser('~/.ansible/tmp/multi_ec2_inventory.cache') class MultiEc2(object): ''' MultiEc2 class: Opens a yaml config file and reads aws credentials. Stores a json hash of resources in result. ''' def __init__(self, args=None): # Allow args to be passed when called as a library if not args: self.args = {} else: self.args = args self.cache_path = DEFAULT_CACHE_PATH self.config = None self.all_ec2_results = {} self.result = {} self.file_path = os.path.join(os.path.dirname(os.path.realpath(__file__))) same_dir_config_file = os.path.join(self.file_path, CONFIG_FILE_NAME) etc_dir_config_file = os.path.join(os.path.sep, 'etc', 'ansible', CONFIG_FILE_NAME) # Prefer a file in the same directory, fall back to a file in etc if os.path.isfile(same_dir_config_file): self.config_file = same_dir_config_file elif os.path.isfile(etc_dir_config_file): self.config_file = etc_dir_config_file else: self.config_file = None # expect env vars def run(self): '''This method checks to see if the local cache is valid for the inventory. if the cache is valid; return cache else the credentials are loaded from multi_ec2.yaml or from the env and we attempt to get the inventory from the provider specified. ''' # load yaml if self.config_file and os.path.isfile(self.config_file): self.config = self.load_yaml_config() elif os.environ.has_key("AWS_ACCESS_KEY_ID") and \ os.environ.has_key("AWS_SECRET_ACCESS_KEY"): # Build a default config self.config = {} self.config['accounts'] = [ { 'name': 'default', 'cache_location': DEFAULT_CACHE_PATH, 'provider': 'aws/hosts/ec2.py', 'env_vars': { 'AWS_ACCESS_KEY_ID': os.environ["AWS_ACCESS_KEY_ID"], 'AWS_SECRET_ACCESS_KEY': os.environ["AWS_SECRET_ACCESS_KEY"], } }, ] self.config['cache_max_age'] = 0 else: raise RuntimeError("Could not find valid ec2 credentials in the environment.") if self.config.has_key('cache_location'): self.cache_path = self.config['cache_location'] if self.args.get('refresh_cache', None): self.get_inventory() self.write_to_cache() # if its a host query, fetch and do not cache elif self.args.get('host', None): self.get_inventory() elif not self.is_cache_valid(): # go fetch the inventories and cache them if cache is expired self.get_inventory() self.write_to_cache() else: # get data from disk self.get_inventory_from_cache() def load_yaml_config(self, conf_file=None): """Load a yaml config file with credentials to query the respective cloud for inventory. """ config = None if not conf_file: conf_file = self.config_file with open(conf_file) as conf: config = yaml.safe_load(conf) return config def get_provider_tags(self, provider, env=None): """Call and query all of the tags that are usuable by ansible. If environment is empty use the default env. """ if not env: env = os.environ # Allow relatively path'd providers in config file if os.path.isfile(os.path.join(self.file_path, provider)): provider = os.path.join(self.file_path, provider) # check to see if provider exists if not os.path.isfile(provider) or not os.access(provider, os.X_OK): raise RuntimeError("Problem with the provider. Please check path " \ "and that it is executable. (%s)" % provider) cmds = [provider] if self.args.get('host', None): cmds.append("--host") cmds.append(self.args.get('host', None)) else: cmds.append('--list') cmds.append('--refresh-cache') return subprocess.Popen(cmds, stderr=subprocess.PIPE, \ stdout=subprocess.PIPE, env=env) @staticmethod def generate_config(config_data): """Generate the ec2.ini file in as a secure temp file. Once generated, pass it to the ec2.py as an environment variable. """ fildes, tmp_file_path = tempfile.mkstemp(prefix='multi_ec2.ini.') for section, values in config_data.items(): os.write(fildes, "[%s]\n" % section) for option, value in values.items(): os.write(fildes, "%s = %s\n" % (option, value)) os.close(fildes) return tmp_file_path def run_provider(self): '''Setup the provider call with proper variables and call self.get_provider_tags. ''' try: all_results = [] tmp_file_paths = [] processes = {} for account in self.config['accounts']: env = account['env_vars'] if account.has_key('provider_config'): tmp_file_paths.append(MultiEc2.generate_config(account['provider_config'])) env['EC2_INI_PATH'] = tmp_file_paths[-1] name = account['name'] provider = account['provider'] processes[name] = self.get_provider_tags(provider, env) # for each process collect stdout when its available for name, process in processes.items(): out, err = process.communicate() all_results.append({ "name": name, "out": out.strip(), "err": err.strip(), "code": process.returncode }) finally: # Clean up the mkstemp file for tmp_file in tmp_file_paths: os.unlink(tmp_file) return all_results def get_inventory(self): """Create the subprocess to fetch tags from a provider. Host query: Query to return a specific host. If > 1 queries have results then fail. List query: Query all of the different accounts for their tags. Once completed store all of their results into one merged updated hash. """ provider_results = self.run_provider() # process --host results # For any 0 result, return it if self.args.get('host', None): count = 0 for results in provider_results: if results['code'] == 0 and results['err'] == '' and results['out'] != '{}': self.result = json.loads(results['out']) count += 1 if count > 1: raise RuntimeError("Found > 1 results for --host %s. \ This is an invalid state." % self.args.get('host', None)) # process --list results else: # For any non-zero, raise an error on it for result in provider_results: if result['code'] != 0: err_msg = ['\nProblem fetching account: {name}', 'Error Code: {code}', 'StdErr: {err}', 'Stdout: {out}', ] raise RuntimeError('\n'.join(err_msg).format(**result)) else: self.all_ec2_results[result['name']] = json.loads(result['out']) # Check if user wants extra vars in yaml by # having hostvars and all_group defined for acc_config in self.config['accounts']: self.apply_account_config(acc_config) # Build results by merging all dictionaries values = self.all_ec2_results.values() values.insert(0, self.result) for result in values: MultiEc2.merge_destructively(self.result, result) def apply_account_config(self, acc_config): ''' Apply account config settings ''' if not acc_config.has_key('hostvars') and not acc_config.has_key('all_group'): return results = self.all_ec2_results[acc_config['name']] # Update each hostvar with the newly desired key: value for host_property, value in acc_config['hostvars'].items(): # Verify the account results look sane # by checking for these keys ('_meta' and 'hostvars' exist) if results.has_key('_meta') and results['_meta'].has_key('hostvars'): for data in results['_meta']['hostvars'].values(): data[str(host_property)] = str(value) # Add this group if results.has_key(acc_config['all_group']): results["%s_%s" % (host_property, value)] = \ copy.copy(results[acc_config['all_group']]) # store the results back into all_ec2_results self.all_ec2_results[acc_config['name']] = results @staticmethod def merge_destructively(input_a, input_b): "merges b into input_a" for key in input_b: if key in input_a: if isinstance(input_a[key], dict) and isinstance(input_b[key], dict): MultiEc2.merge_destructively(input_a[key], input_b[key]) elif input_a[key] == input_b[key]: pass # same leaf value # both lists so add each element in b to a if it does ! exist elif isinstance(input_a[key], list) and isinstance(input_b[key], list): for result in input_b[key]: if result not in input_a[key]: input_a[key].append(result) # a is a list and not b elif isinstance(input_a[key], list): if input_b[key] not in input_a[key]: input_a[key].append(input_b[key]) elif isinstance(input_b[key], list): input_a[key] = [input_a[key]] + [k for k in input_b[key] if k != input_a[key]] else: input_a[key] = [input_a[key], input_b[key]] else: input_a[key] = input_b[key] return input_a def is_cache_valid(self): ''' Determines if the cache files have expired, or if it is still valid ''' if os.path.isfile(self.cache_path): mod_time = os.path.getmtime(self.cache_path) current_time = time() if (mod_time + self.config['cache_max_age']) > current_time: return True return False def parse_cli_args(self): ''' Command line argument processing ''' parser = argparse.ArgumentParser( description='Produce an Ansible Inventory file based on a provider') parser.add_argument('--refresh-cache', action='store_true', default=False, help='Fetch cached only instances (default: False)') parser.add_argument('--list', action='store_true', default=True, help='List instances (default: True)') parser.add_argument('--host', action='store', default=False, help='Get all the variables about a specific instance') self.args = parser.parse_args().__dict__ def write_to_cache(self): ''' Writes data in JSON format to a file ''' # if it does not exist, try and create it. if not os.path.isfile(self.cache_path): path = os.path.dirname(self.cache_path) try: os.makedirs(path) except OSError as exc: if exc.errno != errno.EEXIST or not os.path.isdir(path): raise json_data = MultiEc2.json_format_dict(self.result, True) with open(self.cache_path, 'w') as cache: try: fcntl.flock(cache, fcntl.LOCK_EX) cache.write(json_data) finally: fcntl.flock(cache, fcntl.LOCK_UN) def get_inventory_from_cache(self): ''' Reads the inventory from the cache file and returns it as a JSON object ''' if not os.path.isfile(self.cache_path): return None with open(self.cache_path, 'r') as cache: self.result = json.loads(cache.read()) return True @classmethod def json_format_dict(cls, data, pretty=False): ''' Converts a dict to a JSON object and dumps it as a formatted string ''' if pretty: return json.dumps(data, sort_keys=True, indent=2) else: return json.dumps(data) def result_str(self): '''Return cache string stored in self.result''' return self.json_format_dict(self.result, True) if __name__ == "__main__": MEC2 = MultiEc2() MEC2.parse_cli_args() MEC2.run() print MEC2.result_str()