#!/usr/bin/python3
"""Test client for ipa-custodia

The test script is expected to be executed on an IPA server with existing
Custodia server keys.
"""
from __future__ import print_function
import argparse
import logging
import os
import platform
import socket
import sys
import warnings

from custodia.message.kem import KEY_USAGE_SIG, KEY_USAGE_ENC, KEY_USAGE_MAP

from jwcrypto.common import json_decode
from jwcrypto.jwk import JWK

from ipalib import api
from ipaplatform.paths import paths
import ipapython.version
from ipaserver.install.installutils import is_ipa_configured

try:
    # FreeIPA >= 4.5
    from ipaserver.secrets.client import CustodiaClient
except ImportError:
    # FreeIPA <= 4.4
    from ipapython.secrets.client import CustodiaClient

# Ignore security warning from vendored and non-vendored urllib3.
# Each variant is optional: when its import fails there is nothing to filter
# for it and the name is simply set to None.

# Variant 1: top-level urllib3 package.
try:
    from urllib3.exceptions import SecurityWarning
except ImportError:
    SecurityWarning = None
else:
    warnings.simplefilter("ignore", SecurityWarning)

# Variant 2: urllib3 as exposed through requests.packages. This rebinds the
# module-level name SecurityWarning, so after this point it refers to the
# requests variant (or None when that import is unavailable).
try:
    from requests.packages.urllib3.exceptions import SecurityWarning
except ImportError:
    SecurityWarning = None
else:
    warnings.simplefilter("ignore", SecurityWarning)


# Key names used as the default for the positional ``keys`` argument.
KEYS = ['dm/DMHash', 'ra/ipaCert']
# The remaining entries share the 'ca/<nickname> cert-pki-ca' pattern.
# A generator expression is used so no loop variable leaks into the module
# namespace on Python 2.
KEYS.extend(
    'ca/{} cert-pki-ca'.format(nickname)
    for nickname in (
        'auditSigningCert',
        'caSigningCert',
        'ocspSigningCert',
        'subsystemCert',
    )
)

# Path of the 'server.keys' file inside IPA's Custodia configuration
# directory.
IPA_CUSTODIA_KEYFILE = os.path.join(
    paths.IPA_CUSTODIA_CONF_DIR, 'server.keys'
)


# Logger shared by all checks in this script.
logger = logging.getLogger('ipa-custodia-tester')


# Command line interface of the script.
#
# The title is passed as ``description``. The first positional parameter of
# ArgumentParser is ``prog``, so passing the title positionally would replace
# the program name in usage and error messages with "IPA Custodia check".
parser = argparse.ArgumentParser(
    description="IPA Custodia check",
)
# --store is dangerous and therefore hidden! Don't use it unless you really
# know what you are doing! Keep in mind that it might destroy your NSSDB
# unless it uses sqlite format.
parser.add_argument(
    "--store", action='store_true', dest='store',
    help=argparse.SUPPRESS
)
parser.add_argument(
    "--debug", action='store_true',
    help="Debug mode"
)
parser.add_argument(
    "--verbose", action='store_true',
    help='Verbose mode'
)
# Required: the server to test against.
parser.add_argument(
    "server",
    help="FQDN of a IPA server (can be own FQDN for self-test)"
)
# Optional: key names to fetch; falls back to the full KEYS list.
parser.add_argument(
    'keys', nargs='*', default=KEYS,
    help="Remote key ({})".format(', '.join(KEYS))
)


class IPACustodiaTester(object):
    """Run a series of sanity checks against an IPA Custodia setup.

    Every check logs its findings.  Problems are recorded through
    :meth:`error`; :meth:`exit` turns the collected result into the process
    exit status.
    """

    # Files that check_files() expects to exist on this host.
    files = [
        paths.IPA_DEFAULT_CONF,
        paths.KRB5_KEYTAB,
        paths.IPA_CUSTODIA_CONF,
        IPA_CUSTODIA_KEYFILE
    ]

    def __init__(self, parser, args):
        """Bootstrap the IPA API if necessary and derive principal names.

        :param parser: argparse parser; its ``exit()`` is used to end the
            process
        :param args: parsed command line arguments; ``server``, ``keys``,
            ``store`` and ``verbose`` are read by this class
        """
        self.parser = parser
        self.args = args
        if not api.isdone('bootstrap'):
            # bootstrap to initialize api.env
            api.bootstrap()
            self.debug("IPA API bootstrapped")
        self.realm = api.env.realm
        self.host = api.env.host
        self.host_spn = 'host/{}@{}'.format(self.host, self.realm)
        self.server_spn = 'host/{}@{}'.format(self.args.server, self.realm)
        # Set by check_client().
        self.client = None
        # Messages of all failed checks; a non-empty list means failure.
        self._errors = []

    def error(self, msg, fatal=False):
        """Record and log a failed check.

        :param msg: message to log and to remember
        :param fatal: when true, end the process through :meth:`exit` right
            after logging; in that case this method does not return
        """
        self._errors.append(msg)
        # A traceback is only meaningful while an exception is being
        # handled.  Requesting exc_info without one makes logging append an
        # empty traceback placeholder (e.g. "NoneType: None") to the message.
        handling_exception = sys.exc_info()[0] is not None
        logger.error(
            msg, exc_info=bool(self.args.verbose and handling_exception)
        )
        if fatal:
            self.exit()

    def exit(self):
        """Exit with status 1 if any error was recorded, otherwise 0."""
        if self._errors:
            self.parser.exit(1, "[ERROR] One or more tests have failed.\n")
        else:
            self.parser.exit(0, "All tests have passed successfully.\n")

    def warning(self, msg):
        """Log *msg* at WARNING level."""
        logger.warning(msg)

    def info(self, msg):
        """Log *msg* at INFO level."""
        logger.info(msg)

    def debug(self, msg):
        """Log *msg* at DEBUG level."""
        logger.debug(msg)

    def check(self):
        """Run all checks in order.

        The order matters: the JWK and key checks use the client created by
        :meth:`check_client`.
        """
        self.status()
        self.check_fqdn()
        self.check_files()
        self.check_client()
        self.check_jwk()
        self.check_keys()

    def status(self):
        """Log platform, IPA version, realm, host and remote server."""
        self.info("Platform: {}".format(platform.platform()))
        self.info("IPA version: {}".format(
            ipapython.version.VERSION
        ))
        self.info("IPA vendor version: {}".format(
            ipapython.version.VENDOR_VERSION
        ))
        self.info("Realm: {}".format(self.realm))
        self.info("Host: {}".format(self.host))
        self.info("Remote server: {}".format(self.args.server))
        if self.host == self.args.server:
            self.warning("Performing self-test only.")

    def check_fqdn(self):
        """Warn when ``socket.getfqdn()`` differs from the IPA host name."""
        fqdn = socket.getfqdn()
        if self.host != fqdn:
            self.warning(
                "socket.getfqdn() reports hostname '{}'".format(fqdn)
            )

    def check_files(self):
        """Check that every entry of :attr:`files` is an existing file.

        A missing file is recorded as a non-fatal error so that all files
        get reported.
        """
        for filename in self.files:
            if not os.path.isfile(filename):
                self.error("File '{0}' is missing.".format(filename))
            else:
                self.info("File '{0}' exists.".format(filename))

    def check_client(self):
        """Create the Custodia client; any failure is fatal."""
        try:
            self.client = CustodiaClient(
                server=self.args.server,
                client_service='host@{}'.format(self.host),
                keyfile=IPA_CUSTODIA_KEYFILE,
                keytab=paths.KRB5_KEYTAB,
                realm=self.realm,
            )
        except Exception as e:
            self.error("Failed to create client: {}".format(e), fatal=True)
        else:
            self.info("Custodia client created.")

    def _fetch_ldap_pubkey(self, role, spn, usage_id):
        """Look up and decode the public key for *spn* through the client.

        :param role: ``"host"`` or ``"server"``; only used in log messages
        :param spn: service principal name whose key is looked up
        :param usage_id: key usage, a key of ``KEY_USAGE_MAP``
        :return: the decoded public key; a failed lookup is fatal
        """
        usage = KEY_USAGE_MAP[usage_id]
        try:
            pubkey = json_decode(self.client.ikk.find_key(spn, usage_id))
        except Exception:
            # Fatal: error() ends the process and does not return.
            self.error(
                "Fetching {} keys {} (usage: {}) failed.".format(
                    role, spn, usage),
                fatal=True
            )
        else:
            self.info("Checked {} LDAP keys '{}' for usage {}.".format(
                role, spn, usage
            ))
            return pubkey

    def _check_jwk_single(self, usage_id):
        """Verify the local JWK for one key usage.

        The key is loaded from the Custodia key file, its KID is compared
        with the host's service principal name and its public part is
        compared with the host key found through the client.  Finally the
        remote server's public key is looked up as well.  Every failure is
        fatal.

        :param usage_id: key usage, a key of ``KEY_USAGE_MAP``
        :return: tuple ``(local_pubkey, host_pubkey, server_pubkey)``
        """
        usage = KEY_USAGE_MAP[usage_id]

        try:
            # Reading the key file belongs inside the guarded section: a
            # missing or malformed file is a failed check, not a crash.
            with open(IPA_CUSTODIA_KEYFILE) as f:
                dictkeys = json_decode(f.read())
            pkey = JWK(**dictkeys[usage_id])
            local_pubkey = json_decode(pkey.export_public())
        except Exception:
            # Fatal: error() ends the process and does not return.
            self.error("Failed to load and parse local JWK.", fatal=True)
        else:
            self.info("Loaded key for usage '{}' from '{}'.".format(
                usage, IPA_CUSTODIA_KEYFILE
            ))

        if pkey.key_id != self.host_spn:
            self.error(
                "KID '{}' != host service principal name '{}' "
                "(usage: {})".format(pkey.key_id, self.host_spn, usage),
                fatal=True
            )
        else:
            self.info(
                "JWK KID matches host's service principal name '{}'.".format(
                    self.host_spn
                ))

        # LDAP doesn't contain KID
        local_pubkey.pop("kid", None)
        host_pubkey = self._fetch_ldap_pubkey(
            "host", self.host_spn, usage_id
        )

        if host_pubkey != local_pubkey:
            self.debug("LDAP: '{}'".format(host_pubkey))
            self.debug("Local: '{}'".format(local_pubkey))
            self.error(
                "Host key in LDAP does not match local key.",
                fatal=True
            )
        else:
            self.info(
                "Local key for usage '{}' matches key in LDAP.".format(usage)
            )

        server_pubkey = self._fetch_ldap_pubkey(
            "server", self.server_spn, usage_id
        )

        return local_pubkey, host_pubkey, server_pubkey

    def check_jwk(self):
        """Verify the local signing and encryption keys."""
        self._check_jwk_single(KEY_USAGE_SIG)
        self._check_jwk_single(KEY_USAGE_ENC)

    def check_keys(self):
        """Fetch every requested key from the remote server.

        A failed fetch is recorded as a non-fatal error so the remaining
        keys are still tried.  The fetched value is only logged (at DEBUG
        level) when ``--store`` was not given.
        """
        for key in self.args.keys:
            try:
                result = self.client.fetch_key(key, store=self.args.store)
            except Exception as e:
                self.error("Failed to retrieve key '{}': {}.".format(
                    key, e
                ))
            else:
                self.info("Successfully retrieved '{}'.".format(key))
                if not self.args.store:
                    self.debug(result)


def main():
    """Parse arguments, set up logging, run all checks and exit.

    The script refuses to run (via ``parser.error``) when IPA is not
    configured on this system or when it is not executed as root.
    """
    args = parser.parse_args()
    # --debug implies --verbose.
    args.verbose = args.verbose or args.debug

    log_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(
        level=log_level,
        format='[%(asctime)s %(name)s] <%(levelname)s>: %(message)s',
        datefmt='%Y-%m-%dT%H:%M:%S',
    )

    # Preconditions, checked in this order.
    if not is_ipa_configured():
        parser.error("IPA is not configured on this system.\n")
    if os.geteuid() != 0:
        parser.error("Script must be executed as root.\n")

    tester = IPACustodiaTester(parser, args)
    tester.check()
    # Reports the overall result and sets the exit status.
    tester.exit()


if __name__ == '__main__':
    main()
