#!/usr/bin/env python3
# -*- coding: utf-8 -*-

#
# Sample KeyTalk REST API client
# The following APIs are included:
# - RCDP (certificate retrieval API used by KeyTalk agents)
# - Admin API
# - Public API
# - CA API
#

# Requires Python 3.6 or higher

######################################################################################
# Typical RCDPv2 flow
# Goes over HTTPS
#
# client ---------> hello -----------------> server
# client <--------- hello <----------------- server
# client ---------> handshake -------------> server
# client <--------- handshake <------------- server
# client ---------> auth-requirements -----> server
# client <--------- auth-requirements <----- server
# client ---------> authentication --------> server
# client <--------- auth-result <----------- server
# client ---------> last-messages ---------> server
# client <--------- last-messages <--------- server
# client ---------> cert ------------------> server
# client <--------- cert <------------------ server
# client ---------> eoc -------------------> server
# client <--------- eoc <------------------- server

# eoc (end-of-communication) and error can be sent by any party at any time
##########################################################################################

import http.client
import ssl
import re
import os
import urllib.request
import urllib.parse
import urllib.error
import json
import datetime
import base64
import pprint
import hashlib
import tarfile
import socket
import subprocess
import OpenSSL

import common_config as conf


###################################################################################

#
# Global settings
#

# When True, debug() prints its messages and RCDP connections enable
# http.client wire-level debug output.
VERBOSE = False

# HACK: skip validation of the KeyTalk server certificate and hostname during
# the TLS handshake. Useful for quick testing, e.g. when the KeyTalk server is
# addressed by IP address instead of FQDN.
BYPASS_HTTPS_VALIDATION = True

#
# These settings should come from RCCD
#
# An IP address only works with validation bypassed, hence the FQDN otherwise
# ('localhost' is another handy choice while bypassing validation).
KEYTALK_SERVER = '10.100.0.200' if BYPASS_HTTPS_VALIDATION else 'demo.keytalkdemo.com'


# PEM files whose contents are concatenated by server_verification_ca_chain()
# to verify the server when BYPASS_HTTPS_VALIDATION is False.
SERVER_VERIFICATION_CA_CHAIN = [
    'commcacert.pem',
    'pcacert.pem',
]
# Optional lower bound for the last-messages request; None means "not sent".
LAST_MESSAGES_FROM_UTC = None

# Test users for which calc_responses() returns canned challenge responses.
UMTS_USERNAME = 'UMTS_2_354162120787078'
GSM_USERNAME = 'GSM_2_354162120787078'

# Web UI admin credentials. They are not referenced in the code visible here;
# presumably the Admin API client uses them.
WEBUI_ADMIN_USERNAME = 'admin'
WEBUI_ADMIN_PASSWORD = 'change!'

##############################################################################


def debug(msg):
    """Print *msg*, but only when the module-level VERBOSE flag is set."""
    if not VERBOSE:
        return
    print(msg)


def log(msg):
    """Print an informational message *msg* unconditionally to stdout."""
    print(msg, end='\n')


def is_opensslv1():
    """Return True if pyOpenSSL's environment report mentions an OpenSSL 1.x library."""
    # The private OpenSSL.debug._env_info report string is searched for the version.
    from OpenSSL import debug as openssl_debug
    return " OpenSSL: OpenSSL 1." in openssl_debug._env_info


def is_opensslv3():
    """Return True if pyOpenSSL's environment report mentions an OpenSSL 3.x library."""
    # The private OpenSSL.debug._env_info report string is searched for the version.
    from OpenSSL import debug as openssl_debug
    return " OpenSSL: OpenSSL 3." in openssl_debug._env_info


def supported_cert_formats():
    """Return the certificate/key formats usable with the local OpenSSL.

    PEM is always included. The PKCS#12 flavour depends on the OpenSSL major
    version: CERTKEY_FORMAT_P12 for 1.x, CERTKEY_FORMAT_P12_V2 for 3.x.

    Raises:
        Exception: if OpenSSL is neither 1.x nor 3.x.
    """
    pem_format = conf.CERTKEY_FORMAT_PEM
    if is_opensslv1():
        p12_format = conf.CERTKEY_FORMAT_P12
    elif is_opensslv3():
        p12_format = conf.CERTKEY_FORMAT_P12_V2
    else:
        raise Exception("Unsupported OpenSSL version")
    return [pem_format, p12_format]


def run_cmd(cmd, decode=True):
    """Run *cmd* through the shell and capture its output.

    Args:
        cmd: shell command line. It runs with shell=True, so the caller must
            ensure it is trusted and properly quoted.
        decode: when True, stdout and stderr are decoded from UTF-8 into str;
            otherwise they stay bytes.

    Returns:
        subprocess.CompletedProcess. The exit code is NOT checked here.
    """
    completed = subprocess.run(cmd, shell=True,
                               stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if not decode:
        return completed
    completed.stdout = completed.stdout.decode("utf-8")
    completed.stderr = completed.stderr.decode("utf-8")
    return completed


def run_cmd_checked(cmd, decode=True):
    """Run *cmd* like run_cmd(), but raise unless it exits with code 0.

    Raises:
        Exception: with the command, exit code, stdout and stderr on failure.
    """
    result = run_cmd(cmd, decode)
    if result.returncode != 0:
        raise Exception(
            "Command '{}' finished with code {}. stdout: '{}'. stderr: '{}'".format(
                cmd, result.returncode, result.stdout, result.stderr))
    return result


def fetch_url(url):
    """Download *url* and return the response body as bytes.

    This is best effort by design: any failure is logged and None is returned
    instead of raising. Callers rely on this, e.g. to check that a one-time
    download URL fails on a second fetch.
    """
    debug("Fetching URL " + url)
    try:
        with urllib.request.urlopen(url) as response:
            return response.read()
    except Exception as err:
        log("Error opening URL {}. {}".format(url, err))
        return None


def is_true(d, key, strict=False):
    """Return True if ``d[key]`` is a case-insensitive "true".

    The value is compared through its string form. Both the string "true"
    (any case) and a JSON/Python boolean True are therefore accepted, and
    non-string values no longer raise AttributeError.

    Args:
        d: mapping to inspect.
        key: key to look up.
        strict: when True, a missing key is an error; otherwise it yields False.

    Raises:
        Exception: if *strict* is set and *key* is not in *d*.
    """
    if key not in d:
        if strict:
            raise Exception("{} not found in {}".format(key, d))
        return False
    return str(d[key]).lower() == 'true'


def is_false(d, key, strict=False):
    """Return True if ``d[key]`` is a case-insensitive "false".

    The value is compared through its string form. Both the string "false"
    (any case) and a JSON/Python boolean False are therefore accepted, and
    non-string values no longer raise AttributeError. A missing key is NOT
    treated as false.

    Args:
        d: mapping to inspect.
        key: key to look up.
        strict: when True, a missing key is an error; otherwise it yields False.

    Raises:
        Exception: if *strict* is set and *key* is not in *d*.
    """
    if key not in d:
        if strict:
            raise Exception("{} not found in {}".format(key, d))
        return False
    return str(d[key]).lower() == 'false'


def get_cert_san(x509cert):
    """Return the text of the certificate's subjectAltName extension, or '' if none.

    If several extensions match, the last one wins.
    """
    san_text = ''
    for index in range(x509cert.get_extension_count()):
        extension = x509cert.get_extension(index)
        # The short name is stringified before the substring test so that a
        # bytes value (as pyOpenSSL returns) matches as well.
        if 'subjectAltName' in str(extension.get_short_name()):
            san_text = str(extension)
    return san_text


def random_string(length=8):
    """Return a random string of *length* ASCII letters (8 by default).

    Uses the non-cryptographic ``random`` module. This is fine for test data,
    but must not be used for secrets.
    """
    import random
    import string
    letters = string.ascii_letters
    return ''.join(random.choice(letters) for _ in range(length))


def isP12(fmt):
    """Return True if *fmt* is one of the PKCS#12 certificate/key formats."""
    p12_formats = (conf.CERTKEY_FORMAT_P12, conf.CERTKEY_FORMAT_P12_V2)
    return fmt in p12_formats


class BadRequestError(Exception):
    """Error type for bad-request failures.

    It is not raised in the code visible here; presumably the Admin/Public API
    clients use it.
    """


def server_verification_ca_chain(ca_files=None):
    """Return the concatenated PEM text of the CAs used to verify the KeyTalk server.

    Args:
        ca_files: iterable of PEM file paths. Defaults to
            SERVER_VERIFICATION_CA_CHAIN.

    Returns:
        str suitable for ``ssl.SSLContext.load_verify_locations(cadata=...)``.
    """
    if ca_files is None:
        ca_files = SERVER_VERIFICATION_CA_CHAIN
    chunks = []
    for path in ca_files:
        # Close each file promptly instead of leaking the handle.
        # PEM is ASCII text, so UTF-8 decoding is safe.
        with open(path, encoding='utf-8') as ca_file:
            chunks.append(ca_file.read())
    return ''.join(chunks)


class CertRetrievalApi(object):
    """Client for the KeyTalk RCDPv2 certificate retrieval API (used by KeyTalk agents).

    An instance drives one RCDP session over a single HTTPS connection. hello()
    opens the connection and stores the session cookie; the other public
    methods follow the flow sketched at the top of this file.
    """

    def __init__(self):
        self.version = conf.SERVER_SUPPORTED_RCDPV2_VERSIONS[-1]  # the latest-greatest
        self.conn = None  # HTTPS connection, created by hello()
        self.cookie = None  # Set-Cookie value received in reply to hello()

    #
    # Private API
    #

    @staticmethod
    def _parse_rcdp_response(conn, request_name, expected_status):
        """Read the pending response from *conn* and validate it as an RCDP reply.

        Args:
            conn: connection on which a request has just been sent.
            request_name: name of that request, used in log and error messages.
            expected_status: RCDP status the reply must carry.

        Returns:
            (payload, cookie): the decoded JSON payload and the Set-Cookie header
            value (None when absent).

        Raises:
            Exception: on a non-200 HTTP status, on an RCDP error reply (with a
                dedicated message for client/server clock skew), or when the
                reply status is not *expected_status*.
        """
        response = conn.getresponse()
        response_payload = response.read().decode()
        if response.status != 200:
            raise Exception(
                'Unexpected response HTTP status {} received on {} request.'.format(
                    response.status, request_name))

        payload = json.loads(response_payload)
        debug("{} -> {} {}.\n{}".format(request_name, response.status,
                                        response.reason, pprint.pformat(payload)))

        status = payload[conf.RCDPV2_RESPONSE_PARAM_NAME_STATUS]
        if status != expected_status:
            if status == conf.RCDPV2_RESPONSE_ERROR:
                code = int(payload[conf.RCDPV2_RESPONSE_PARAM_NAME_ERROR_CODE])
                info = payload[conf.RCDPV2_RESPONSE_PARAM_NAME_ERROR_DESCRIPTION]
                if code == int(conf.RCDPV2_ERR_CODE["ErrTimeOutOfSync"]):
                    # For this error the description carries the clock delta in seconds.
                    delta_seconds = int(info)
                    if delta_seconds > 0:
                        raise Exception(
                            'Client time is {} seconds ahead the server time'.format(delta_seconds))
                    else:
                        raise Exception(
                            'Client time is {} seconds behind the server time'.format(-delta_seconds))
                else:
                    raise Exception(
                        'Received error {} for response on {}. Extra info: {}'.format(
                            code, request_name, info))
            else:
                raise Exception(
                    'Expected {} response on {} request but received {} instead'.format(
                        expected_status, request_name, status))

        cookie = response.getheader('set-cookie', None)
        return (payload, cookie)

    @staticmethod
    def _calc_hwsig(formula):
        """Return the hardware signature for *formula* (currently a fixed stub value)."""
        # @todo implement for real for the given formula
        return "HWSIG-123456"

    @staticmethod
    def _get_system_hwdescription():
        """Return a description of the host hardware (currently a fixed stub value)."""
        # @todo implement for real
        return "Windows 7, BIOS s/n 1234567890"

    @staticmethod
    def _resolve_host(uri):
        """Resolve the host part of *uri* to its IPv4 and IPv6 addresses.

        Each address family is looked up best effort, because a host may have
        only IPv4 or only IPv6 addresses.

        Raises:
            Exception: if *uri* has no host part, or if neither family resolves.
        """
        hostname = urllib.parse.urlparse(uri).hostname
        if not hostname:
            raise Exception("Cannot extract host name from {}".format(uri))
        log("Resolving " + hostname)
        ips = []
        for family in (socket.AF_INET, socket.AF_INET6):
            try:
                for addr in socket.getaddrinfo(
                        hostname,
                        port=None,
                        family=family,
                        type=socket.SOCK_STREAM):
                    ips.append(addr[4][0])
            except Exception:
                # This family did not resolve; the other one still may.
                pass
        if not ips:
            raise Exception("Failed to resolve " + hostname)
        return ips

    @staticmethod
    def _calc_digest(path):
        """Return the hex SHA-256 digest of the file at *path*."""
        with open(path, 'rb') as f:
            return hashlib.sha256(f.read()).hexdigest()

    @staticmethod
    def is_cr_authentication(auth_requirements):
        """Return True if the service requires a challenge-response credential."""
        return conf.CRED_RESPONSE in auth_requirements[conf.RCDPV2_RESPONSE_PARAM_NAME_CRED_TYPES]

    @staticmethod
    def is_mfa_authentication(auth_requirements):
        """Return True if the service requires an OTP/MFA credential."""
        return conf.CRED_OTP_MFA in auth_requirements[conf.RCDPV2_RESPONSE_PARAM_NAME_CRED_TYPES]

    @staticmethod
    def is_password_expiring(password_validity_sec):
        """quick&dirty but good enough for tests; see ta::SyfInfo::isUserPasswordExpiring() for proper implementation

        Returns True when the password is still valid (>= 0 seconds) but for
        less than 7 days.
        """
        seconds_in_day = 24 * 60 * 60
        max_validity_days = 7
        return 0 <= password_validity_sec < max_validity_days * seconds_in_day

    @staticmethod
    def decrypt_seat_auth_challenge(encrypted_challenge, service, username):
        """Decrypt a DER-encoded CMS seat authentication challenge with the openssl CLI.

        The seat certificate and key are read from
        "<service>.<username>.certkey.pem" in the current directory. The
        decrypted text is encoded as "<body-size>#<body>"; only the body is
        returned.

        Raises:
            Exception: if openssl fails or the decrypted challenge is ill-formed.
        """
        signer = "{}.{}.certkey.pem".format(service, username)
        challenge_path = "./encrypted_challenge"
        with open(challenge_path, 'wb') as f:
            f.write(encrypted_challenge)
        try:
            result = run_cmd_checked(
                "openssl cms -decrypt -in ./encrypted_challenge -signer {} -inkey {} -inform DER -outform DER -print".format(
                    signer, signer))
        finally:
            # Remove the temporary file even when openssl fails.
            os.remove(challenge_path)
        plain_encoded_challenge = result.stdout
        # challenge is encoded as <body-size>#<body>
        sep_pos = plain_encoded_challenge.find('#')
        if sep_pos == -1:
            raise Exception("Ill-formed challenge '{}'".format(plain_encoded_challenge))
        length = int(plain_encoded_challenge[0:sep_pos])
        plain_challenge = plain_encoded_challenge[sep_pos + 1:sep_pos + length + 1]
        return plain_challenge

    @staticmethod
    def calc_responses(username, challenges, response_names):
        """Compute canned responses to *challenges* for the test users.

        Args:
            username: UMTS_USERNAME, GSM_USERNAME or any other user.
            challenges: dict mapping challenge name -> challenge value.
            response_names: response names requested by the server.

        Returns:
            dict with a single RESPONSES parameter holding the JSON-encoded
            list of {name, value} responses.

        Raises:
            Exception: if the GSM user receives an unknown 'GSM RANDOM' challenge.
        """
        if username == UMTS_USERNAME:
            # expect only 2 challenges and 3 responses
            responses = [
                {
                    conf.RCDPV2_REQUEST_PARAM_NAME_NAME: response_names[0],
                    conf.RCDPV2_REQUEST_PARAM_NAME_VALUE: "02020202020202020202020202020202"
                },
                {
                    conf.RCDPV2_REQUEST_PARAM_NAME_NAME: response_names[1],
                    conf.RCDPV2_REQUEST_PARAM_NAME_VALUE: "03030303030303030303030303030303"
                },
                {
                    conf.RCDPV2_REQUEST_PARAM_NAME_NAME: response_names[2],
                    conf.RCDPV2_REQUEST_PARAM_NAME_VALUE: "04040404040404040404040404040404"
                }
            ]
        elif username == GSM_USERNAME:
            # 3 rounds of request-response, each selected by the 'GSM RANDOM' challenge
            gsm_responses = {
                '101112131415161718191a1b1c1d1e1f': ("d1d2d3d4", "a0a1a2a3a4a5a6a7"),
                '202122232425262728292a2b2c2d2e2f': ("e1e2e3e4", "b0b1b2b3b4b5b6b7"),
                '303132333435363738393a3b3c3d3e3f': ("f1f2f3f4", "c0c1c2c3c4c5c6c7"),
            }
            gsm_random = challenges['GSM RANDOM']
            if gsm_random not in gsm_responses:
                # Previously this fell through to an UnboundLocalError.
                raise Exception("Unexpected GSM RANDOM challenge '{}' for user {}".format(
                    gsm_random, username))
            first_value, second_value = gsm_responses[gsm_random]
            responses = [
                {
                    conf.RCDPV2_REQUEST_PARAM_NAME_NAME: response_names[0],
                    conf.RCDPV2_REQUEST_PARAM_NAME_VALUE: first_value
                },
                {
                    conf.RCDPV2_REQUEST_PARAM_NAME_NAME: response_names[1],
                    conf.RCDPV2_REQUEST_PARAM_NAME_VALUE: second_value
                }
            ]
        else:
            # expect only one challenge and one response
            challenge_value = next(iter(challenges.values()))
            response_name = response_names[0]
            response_value = hashlib.sha1(
                (username + challenge_value).encode()).hexdigest()[:8].upper()
            responses = [
                {
                    conf.RCDPV2_REQUEST_PARAM_NAME_NAME: response_name,
                    conf.RCDPV2_REQUEST_PARAM_NAME_VALUE: response_value
                }
            ]

        return {conf.RCDPV2_REQUEST_PARAM_NAME_RESPONSES: json.dumps(responses)}

    @staticmethod
    def request_auth_credentials(auth_requirements, username, password=None, pincode=None):
        """Build the credential parameters demanded by *auth_requirements*.

        Only the credential types listed in the requirements are included.
        Service URIs are resolved and/or digested when the server asks for it;
        a missing URI list is treated as empty.

        Returns:
            dict of credential request parameters.
        """
        required_cred_types = auth_requirements[conf.RCDPV2_RESPONSE_PARAM_NAME_CRED_TYPES]
        creds = {}
        if conf.CRED_USERID in required_cred_types:
            creds[conf.CRED_USERID] = username
        if conf.CRED_PASSWD in required_cred_types:
            creds[conf.CRED_PASSWD] = password
        if conf.CRED_PIN in required_cred_types:
            creds[conf.CRED_PIN] = pincode
        if conf.CRED_HWSIG in required_cred_types:
            creds[conf.CRED_HWSIG] = CertRetrievalApi._calc_hwsig(
                auth_requirements[conf.RCDPV2_RESPONSE_PARAM_NAME_HWSIG_FORMULA])

        # A missing list would otherwise crash the loops below with TypeError.
        service_uris = auth_requirements.get(conf.RCDPV2_RESPONSE_PARAM_NAME_SERVICE_URIS) or []

        if is_true(auth_requirements,
                   conf.RCDPV2_RESPONSE_PARAM_NAME_RESOLVE_SERVICE_URIS):
            ips = []
            for service_uri in service_uris:
                ips.append({conf.RCDPV2_REQUEST_PARAM_NAME_URI: service_uri,
                            conf.RCDPV2_REQUEST_PARAM_NAME_IPS: CertRetrievalApi._resolve_host(service_uri)})
            creds[conf.RCDPV2_REQUEST_PARAM_NAME_RESOLVED] = json.dumps(ips)

        if is_true(auth_requirements,
                   conf.RCDPV2_RESPONSE_PARAM_NAME_CALC_SERVICE_URIS_DIGEST):
            digests = []
            for service_uri in service_uris:
                # the service URI is digested as a local file path
                digests.append({conf.RCDPV2_REQUEST_PARAM_NAME_URI: service_uri,
                                conf.RCDPV2_REQUEST_PARAM_NAME_DIGEST: CertRetrievalApi._calc_digest(service_uri)})
            creds[conf.RCDPV2_REQUEST_PARAM_NAME_DIGESTS] = json.dumps(digests)

        return creds

    @staticmethod
    def gen_csr(requirements):
        """Generate an RSA keypair and a PEM-encoded CSR matching *requirements*.

        Returns:
            the PKCS#10 request as PEM bytes. The private key is not returned.
        """
        key_size = int(requirements[conf.RCDPV2_RESPONSE_PARAM_NAME_KEY_SIZE])
        signing_algo = requirements[conf.RCDPV2_RESPONSE_PARAM_NAME_SIGNING_ALGO]
        subject = requirements[conf.RCDPV2_RESPONSE_PARAM_NAME_SUBJECT]
        san = requirements.get(conf.RCDPV2_RESPONSE_PARAM_NAME_SAN, [])

        log("Generating {}-bit RSA keypair".format(key_size))
        keypair = OpenSSL.crypto.PKey()
        keypair.generate_key(OpenSSL.crypto.TYPE_RSA, key_size)
        log("Creating CSR with subject {}, SAN {} and signed by {}".format(subject, san, signing_algo))
        req = OpenSSL.crypto.X509Req()
        CertRetrievalApi._set_subject_on_req(req, subject)
        CertRetrievalApi._set_san_on_req(req, san)
        req.set_pubkey(keypair)
        req.sign(keypair, signing_algo)
        pkcs10_req = OpenSSL.crypto.dump_certificate_request(OpenSSL.crypto.FILETYPE_PEM, req)
        return pkcs10_req

    @staticmethod
    def _set_subject_on_req(req, subject):
        """Copy the RCDP *subject* dict onto the subject of *req*.

        Recognised keys are cn, c, st, l, o, ous and e; any other key is ignored.
        Only the first entry of "ous" is applied (or "" when the list is empty).
        """
        attr_names = {"cn": "CN", "c": "C", "st": "ST", "l": "L", "o": "O", "e": "emailAddress"}
        subj = req.get_subject()
        for key, value in subject.items():
            if key == "ous":
                setattr(subj, "OU", value[0] if value else "")
            elif key in attr_names:
                setattr(subj, attr_names[key], value)

    @staticmethod
    def _set_san_on_req(req, san):
        """Add a subjectAltName extension built from the *san* entries (no-op when empty)."""
        if san:
            ext = OpenSSL.crypto.X509Extension(
                b'subjectAltName', False, bytes(
                    ','.join(san), 'utf-8'))
            req.add_extensions([ext])

    @staticmethod
    def _save_cert(cert, passphrase, fmt):
        """Save *cert* and its *passphrase* to the current directory.

        For PKCS#12 formats the expected key encryption of the saved package is
        also checked with the openssl CLI.

        Returns:
            (cert_path, pass_path)

        Raises:
            Exception: on an unknown *fmt* or an unexpected key encryption.
        """
        pass_path = 'certkey-password.txt'
        if fmt == conf.CERTKEY_FORMAT_PEM:
            cert_path = 'cert.pem'
        elif isP12(fmt):
            cert_path = 'cert.pfx'
        else:
            raise Exception('Unexpected certificate format {}'.format(fmt))

        with open(cert_path, 'wb') as cert_file:
            cert_file.write(cert)
        with open(pass_path, 'w') as pass_file:
            pass_file.write(passphrase)

        log("The certificate has been saved to {}, passphrase has been saved to {}".format(
            cert_path, pass_path))
        if fmt == conf.CERTKEY_FORMAT_PEM:
            log("Use the following command if you wish to decrypt the key:\nopenssl rsa -in {} -out key.pem -passin pass:{}".format(
                cert_path, passphrase))
        elif isP12(fmt):
            log("Use the following command if you wish to decrypt the PKCS#12 package:\nopenssl pkcs12 -nodes -in {} -out certkey.pem -passin pass:{}".format(
                cert_path, passphrase))
            # check Pfx key encryption
            expected_key_encryption = "AES-256-CBC" if fmt == conf.CERTKEY_FORMAT_P12_V2 else "TripleDES-CBC"
            result = run_cmd_checked(
                "openssl pkcs12 -info -in {} -passin pass:{} -noout".format(
                    cert_path, passphrase))
            if result.stderr.find(expected_key_encryption) == -1:
                raise Exception(
                    "{} not found in {}".format(
                        expected_key_encryption,
                        result.stderr))

        return cert_path, pass_path

    @staticmethod
    def _save_pem_cert_only(cert):
        """Save the PEM *cert* bytes to cert.pem and return that path."""
        cert_path = 'cert.pem'
        with open(cert_path, 'wb') as cert_file:
            cert_file.write(cert)
        log("The certificate has been saved to {}".format(cert_path))
        return cert_path

    def _request(self, action, params=None, method='GET', send_cookie=True):
        """Send an RCDP *action* request on the current connection.

        The response is not read here (see _parse_rcdp_response()). GET params
        go in the query string, POST params in a form-urlencoded body. The
        session cookie is attached when *send_cookie* is set and a cookie is
        known; a None cookie would otherwise be rejected by http.client.

        Raises:
            Exception: on an unsupported HTTP *method*.
        """
        if params is None:
            params = {}
        if VERBOSE:
            self.conn.set_debuglevel(1)

        url = "/{}/{}/{}".format(conf.RCDPV2_HTTP_REQUEST_URI_PREFIX,
                                 self.version,
                                 action)
        headers = {}
        body = None

        if method == 'GET':
            # HTTP GET params are sent in URL
            if params:
                url += '?' + urllib.parse.urlencode(params)
        elif method == 'POST':
            # HTTP POST params are sent in body
            body = urllib.parse.urlencode(params)
            headers["Content-type"] = "application/x-www-form-urlencoded"
        else:
            raise Exception('Unsupported HTTP request method {}'.format(method))

        if send_cookie and self.cookie is not None:
            headers["Cookie"] = self.cookie

        self.conn.request(method, url, body, headers)

    def _get_cert_passphrase(self):
        """Derive the certificate passphrase from the session id in the session cookie.

        The cookie is expected to be exactly "<RCDPV2_HTTP_SID_COOKIE_NAME>=<sid>".
        The passphrase is the first RCDPV2_PACKAGED_CERT_EXPORT_PASSWDSIZE
        characters of the sid.

        Raises:
            Exception: if no cookie is known or it cannot be parsed.
        """
        if self.cookie is None:
            raise Exception('No RCDP session cookie received')
        parsed_cookie = self.cookie.split('=')
        if len(parsed_cookie) != 2 or parsed_cookie[0] != conf.RCDPV2_HTTP_SID_COOKIE_NAME:
            raise Exception('Cannot parse RCDP session cookie from ' + self.cookie)
        sid = parsed_cookie[1]
        passphrase = sid[:conf.RCDPV2_PACKAGED_CERT_EXPORT_PASSWDSIZE]
        return passphrase

    #
    # Public API
    #

    def eoc(self):
        """Send the end-of-communication request.

        The server's reply is not read here.
        """
        self._request(conf.RCDPV2_REQUEST_EOC)

    def hello(self):
        """Open the HTTPS connection, exchange hello and remember the session cookie.

        Raises:
            Exception: if the server answers with a different RCDP version.
        """
        log("Connecting RCDP to KeyTalk server at " + KEYTALK_SERVER + "...")

        ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)

        if BYPASS_HTTPS_VALIDATION:
            ssl_ctx.check_hostname = False
            ssl_ctx.verify_mode = ssl.CERT_NONE
        else:
            ssl_ctx.verify_mode = ssl.CERT_REQUIRED
            ssl_ctx.load_verify_locations(cadata=server_verification_ca_chain())

        conn = http.client.HTTPSConnection(
            KEYTALK_SERVER,
            conf.RCDP_AND_PUBLIC_API_LISTEN_SSL_PORT,
            context=ssl_ctx)

        self.conn = conn
        request_params = {
            conf.RCDPV2_REQUEST_PARAM_NAME_CALLER_APP_DESCRIPTION: 'Test KeyTalk Python client'}
        self._request(conf.RCDPV2_REQUEST_HELLO, request_params, send_cookie=False)
        response_payload, cookie = CertRetrievalApi._parse_rcdp_response(
            conn, conf.RCDPV2_REQUEST_HELLO, conf.RCDPV2_RESPONSE_HELLO)
        if response_payload[conf.RCDPV2_RESPONSE_PARAM_NAME_VERSION] != self.version:
            raise Exception(
                'Unexpected response received on hello request: {}'.format(response_payload))
        self.cookie = cookie

    def handshake(self):
        """Send the client's current UTC time and return the handshake reply payload."""
        # Naive UTC timestamp in ISO 8601 form with an explicit 'Z' suffix.
        # datetime.utcnow() is deprecated since Python 3.12.
        now_utc = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        request_params = {conf.RCDPV2_REQUEST_PARAM_NAME_CALLER_UTC: now_utc.isoformat() + 'Z'}
        self._request(conf.RCDPV2_REQUEST_HANDSHAKE, request_params)
        response_payload, _ = CertRetrievalApi._parse_rcdp_response(
            self.conn, conf.RCDPV2_REQUEST_HANDSHAKE, conf.RCDPV2_RESPONSE_HANDSHAKE)
        return response_payload

    def get_service_auth_requirements(self, service):
        """Return the authentication requirements payload for *service*."""
        request_params = {conf.RCDPV2_REQUEST_PARAM_NAME_SERVICE: service}
        self._request(conf.RCDPV2_REQUEST_AUTH_REQUIREMENTS, request_params)
        response_payload, _ = CertRetrievalApi._parse_rcdp_response(
            self.conn, conf.RCDPV2_REQUEST_AUTH_REQUIREMENTS, conf.RCDPV2_RESPONSE_AUTH_REQUIREMENTS)
        return response_payload

    def get_seat_auth_requirements_step1(self, service, userid, computer_name):
        """Request seat auth requirements for *service*, *userid* and *computer_name* (step 1)."""
        request_params = {conf.RCDPV2_REQUEST_PARAM_NAME_SERVICE: service}
        request_params[conf.CRED_USERID] = userid
        request_params[conf.RCDPV2_REQUEST_PARAM_NAME_COMPUTER_NAME] = computer_name
        self._request(conf.RCDPV2_REQUEST_AUTH_REQUIREMENTS, request_params)
        response_payload, _ = CertRetrievalApi._parse_rcdp_response(
            self.conn, conf.RCDPV2_REQUEST_AUTH_REQUIREMENTS, conf.RCDPV2_RESPONSE_AUTH_REQUIREMENTS)
        return response_payload

    def get_seat_auth_requirements_step2(self, hwsig):
        """Request seat auth requirements by submitting the hardware signature *hwsig* (step 2)."""
        request_params = {conf.CRED_HWSIG: hwsig}
        self._request(conf.RCDPV2_REQUEST_AUTH_REQUIREMENTS, request_params)
        response_payload, _ = CertRetrievalApi._parse_rcdp_response(
            self.conn, conf.RCDPV2_REQUEST_AUTH_REQUIREMENTS, conf.RCDPV2_RESPONSE_AUTH_REQUIREMENTS)
        return response_payload

    def authenticate(self, creds, service=None):
        """Authenticate against *service* with the given set of credentials.

        Args:
            creds: dict of credential request parameters, e.g. from
                request_auth_credentials() or calc_responses().
            service: service name. Pass None on follow-up (challenge) rounds,
                when the service has already been submitted.

        Returns:
            - on success: the password validity in seconds, or None when the
              server does not report it;
            - -1 when the password has expired;
            - (challenges, response_names) when a challenge is received, where
              challenges maps challenge name -> value and response_names may
              be None.

        Raises:
            Exception: on any other authentication status or an RCDP error.
        """
        if service is not None:
            request_params = {
                conf.RCDPV2_REQUEST_PARAM_NAME_SERVICE: service,
                conf.RCDPV2_REQUEST_PARAM_NAME_CALLER_HW_DESCRIPTION: CertRetrievalApi._get_system_hwdescription(),
            }
            request_params.update(creds)
        else:
            # normally this means that we have already submitted service name on the first authentication
            # request and now we are on the challenge phase
            request_params = creds

        debug("Sending authentication request: " + pprint.pformat(request_params))

        self._request(conf.RCDPV2_REQUEST_AUTHENTICATION, request_params, method='POST')
        response_payload, _ = CertRetrievalApi._parse_rcdp_response(
            self.conn, conf.RCDPV2_REQUEST_AUTHENTICATION, conf.RCDPV2_RESPONSE_AUTH_RESULT)
        auth_status = response_payload[conf.RCDPV2_RESPONSE_PARAM_NAME_AUTH_STATUS]
        if auth_status == conf.AUTH_RESULT_OK:
            if conf.RCDPV2_RESPONSE_PARAM_NAME_PASSWORD_VALIDITY in response_payload:
                password_validity = int(
                    response_payload[conf.RCDPV2_RESPONSE_PARAM_NAME_PASSWORD_VALIDITY])
            else:
                password_validity = None
            log("Authenticated successfully")
            return password_validity
        elif auth_status == conf.AUTH_RESULT_CHALLENGE:
            log("Challenge received")
            challenges = {}
            for challenge in response_payload[conf.RCDPV2_RESPONSE_PARAM_NAME_CHALLENGES]:
                challenges[challenge[conf.RCDPV2_RESPONSE_PARAM_NAME_NAME]] = challenge[
                    conf.RCDPV2_RESPONSE_PARAM_NAME_VALUE]
            response_names = response_payload.get(
                conf.RCDPV2_RESPONSE_PARAM_NAME_RESPONSE_NAMES, None)
            return challenges, response_names
        elif auth_status == conf.AUTH_RESULT_EXPIRED:
            log("Password expired")
            return -1
        else:
            raise Exception(
                'Got {} trying to authenticate against service {}.'.format(auth_status, service))

    def authenticate_seat(self, service, username, response):
        """Authenticate the seat with the supplied challenge *response*.

        *service* and *username* are only used in log and error messages.

        Raises:
            Exception: unless the server reports a successful authentication.
        """
        request_params = {
            conf.CRED_RESPONSE: response,
            conf.RCDPV2_REQUEST_PARAM_NAME_CALLER_HW_DESCRIPTION: CertRetrievalApi._get_system_hwdescription(),
        }

        debug("Sending seat authentication request for user " + username +
              " and service " + service + ": " + pprint.pformat(request_params))

        self._request(conf.RCDPV2_REQUEST_AUTHENTICATION, request_params, method='POST')
        response_payload, _ = CertRetrievalApi._parse_rcdp_response(
            self.conn, conf.RCDPV2_REQUEST_AUTHENTICATION, conf.RCDPV2_RESPONSE_AUTH_RESULT)
        auth_status = response_payload[conf.RCDPV2_RESPONSE_PARAM_NAME_AUTH_STATUS]
        if auth_status == conf.AUTH_RESULT_OK:
            log("Authenticated successfully")
            return
        else:
            raise Exception(
                'Got {} trying to authenticate seat with user {} of service {}.'.format(
                    auth_status, username, service))

    def get_smb_certs(self, fmt, out_of_band=False):
        """Request the SMB certificates in format *fmt* and return the list from the reply.

        With *out_of_band*, each certificate and its historical certificates are
        downloaded from the URLs in the reply and stored back into the entry
        (None when a download fails). Each URL is also checked to be single-use.
        """
        request_params = {
            conf.RCDPV2_REQUEST_PARAM_NAME_CERTKEY_FORMAT: fmt,
            conf.RCDPV2_REQUEST_PARAM_NAME_CERT_OUT_OF_BAND: out_of_band,
        }

        self._request(conf.RCDPV2_REQUEST_SMBCERTS, request_params, method='POST')
        response_payload, _ = CertRetrievalApi._parse_rcdp_response(
            self.conn, conf.RCDPV2_REQUEST_SMBCERTS, conf.RCDPV2_RESPONSE_SMBCERTS)

        smbcerts = response_payload[conf.RCDPV2_RESPONSE_PARAM_NAME_SMBCERTS]
        if out_of_band:
            # download OOB certs given the URLs
            for smbcert in smbcerts:
                if conf.RCDPV2_RESPONSE_PARAM_NAME_CERT_URL_TEMPL in smbcert:
                    cert_url_templ = smbcert[conf.RCDPV2_RESPONSE_PARAM_NAME_CERT_URL_TEMPL]
                    cert_url = cert_url_templ.replace(
                        "$(" + conf.CERT_DOWNLOAD_URL_HOST_PLACEHOLDER + ")", KEYTALK_SERVER)
                    cert = fetch_url(cert_url)
                    smbcert[conf.RCDPV2_RESPONSE_PARAM_NAME_CERT] = cert
                    assert not fetch_url(
                        cert_url), "the given certificate can only be downloaded once"
                if conf.RCDPV2_RESPONSE_PARAM_NAME_HISTORICAL_CERTS_URL_TEMPL in smbcert:
                    historical_certs_url_templ = smbcert[conf.RCDPV2_RESPONSE_PARAM_NAME_HISTORICAL_CERTS_URL_TEMPL]
                    historical_certs_url = historical_certs_url_templ.replace(
                        "$(" + conf.CERT_DOWNLOAD_URL_HOST_PLACEHOLDER + ")", KEYTALK_SERVER)
                    historical_certs = fetch_url(historical_certs_url)
                    smbcert[conf.RCDPV2_RESPONSE_PARAM_NAME_HISTORICAL_CERTS] = historical_certs
                    assert not fetch_url(
                        historical_certs_url), "the given historical certificates can only be downloaded once"

        return smbcerts

    def change_password(self, old_password, new_password):
        """Change the user password from *old_password* to *new_password*.

        Raises:
            Exception: unless the server reports success.
        """
        request_params = {
            conf.RCDPV2_REQUEST_PARAM_NAME_OLD_PASSWORD: old_password,
            conf.RCDPV2_REQUEST_PARAM_NAME_NEW_PASSWORD: new_password
        }
        debug("Changing user password")
        self._request(conf.RCDPV2_REQUEST_CHANGE_PASSWORD, request_params, method='POST')
        response_payload, _ = CertRetrievalApi._parse_rcdp_response(
            self.conn, conf.RCDPV2_REQUEST_CHANGE_PASSWORD, conf.RCDPV2_RESPONSE_AUTH_RESULT)
        auth_status = response_payload[conf.RCDPV2_RESPONSE_PARAM_NAME_AUTH_STATUS]
        if auth_status == conf.AUTH_RESULT_OK:
            log("Password successfully changed")
        else:
            raise Exception(
                'Got {} trying to change user password.'.format(auth_status))

    def get_last_messages(self):
        """Fetch the user's last messages, log them if any, and return the reply payload.

        Messages newer than LAST_MESSAGES_FROM_UTC are requested when that
        setting is not None.
        """
        request_params = {}
        if LAST_MESSAGES_FROM_UTC is not None:
            request_params[
                conf.RCDPV2_REQUEST_PARAM_NAME_LAST_MESSAGES_FROM_UTC] = LAST_MESSAGES_FROM_UTC
        self._request(conf.RCDPV2_REQUEST_LAST_MESSAGES, request_params)
        response_payload, _ = CertRetrievalApi._parse_rcdp_response(
            self.conn, conf.RCDPV2_REQUEST_LAST_MESSAGES, conf.RCDPV2_RESPONSE_LAST_MESSAGES)
        messages = response_payload[conf.RCDPV2_RESPONSE_PARAM_NAME_MESSAGES] or []
        if len(messages) > 0:
            log("Received {} user messages:\n{}".format(len(messages), pprint.pformat(messages)))
        return response_payload

    def get_csr_requirements(self):
        """Return the CSR requirements payload (key size, signing algorithm, subject, SAN)."""
        self._request(conf.RCDPV2_REQUEST_CSR_REQUIREMENTS)
        response_payload, _ = CertRetrievalApi._parse_rcdp_response(
            self.conn, conf.RCDPV2_REQUEST_CSR_REQUIREMENTS, conf.RCDPV2_RESPONSE_CSR_REQUIREMENTS)
        return response_payload

    def get_cert(self, fmt, out_of_band=False):
        """Retrieve the certificate in format *fmt* and save it with its passphrase.

        With *out_of_band*, the certificate (plus any historical certificates)
        is downloaded from the URLs in the reply. Otherwise it is taken from the
        reply, base64-decoded for PKCS#12 formats.

        Returns:
            (cert_path, cert_pass_path)

        Raises:
            Exception: if an out-of-band download fails or saving fails.
        """
        request_params = {
            conf.RCDPV2_REQUEST_PARAM_NAME_CERTKEY_FORMAT: fmt,
            conf.RCDPV2_REQUEST_PARAM_NAME_CERT_OUT_OF_BAND: out_of_band,
        }

        self._request(conf.RCDPV2_REQUEST_CERT, request_params, method='POST')
        response_payload, _ = CertRetrievalApi._parse_rcdp_response(
            self.conn, conf.RCDPV2_REQUEST_CERT, conf.RCDPV2_RESPONSE_CERT)

        if out_of_band:
            cert_url_templ = response_payload[conf.RCDPV2_RESPONSE_PARAM_NAME_CERT_URL_TEMPL]
            cert_url = cert_url_templ.replace(
                "$(" + conf.CERT_DOWNLOAD_URL_HOST_PLACEHOLDER + ")", KEYTALK_SERVER)
            cert = fetch_url(cert_url)
            if cert is None:
                # fetch_url() returns None on failure; fail here with a clear message
                raise Exception('Failed to download certificate from ' + cert_url)
            assert not fetch_url(cert_url), "the given certificate can only be downloaded once"
            if conf.RCDPV2_RESPONSE_PARAM_NAME_HISTORICAL_CERTS_URL_TEMPL in response_payload:
                historical_certs_url_templ = response_payload[conf.RCDPV2_RESPONSE_PARAM_NAME_HISTORICAL_CERTS_URL_TEMPL]
                historical_certs_url = historical_certs_url_templ.replace(
                    "$(" + conf.CERT_DOWNLOAD_URL_HOST_PLACEHOLDER + ")", KEYTALK_SERVER)
                historical_certs = fetch_url(historical_certs_url)
                if historical_certs is None:
                    raise Exception(
                        'Failed to download historical certificates from ' + historical_certs_url)
                cert += historical_certs
                assert not fetch_url(
                    historical_certs_url), "the given historical certificates can only be downloaded once"
        else:
            cert = bytes(response_payload[conf.RCDPV2_RESPONSE_PARAM_NAME_CERT], 'utf-8')
            if isP12(fmt):
                cert = base64.b64decode(cert)

        store_cert_to_system = response_payload.get(
            conf.RCDPV2_RESPONSE_PARAM_NAME_CERT_STORE_TO_SYSTEM, False)

        cert_passphrase = self._get_cert_passphrase()
        log("Successfully received {} certificate, {} store to system store".format(
            fmt, "should" if store_cert_to_system else "shouldn't"))
        cert_path, cert_pass_path = CertRetrievalApi._save_cert(cert, cert_passphrase, fmt)
        return cert_path, cert_pass_path

    def sign_csr(self, csr, out_of_band=False):
        """Have the server sign *csr*, save the resulting PEM certificate and return its path.

        Raises:
            Exception: if an out-of-band download fails.
        """
        request_params = {
            conf.RCDPV2_REQUEST_PARAM_NAME_CSR: csr,
            conf.RCDPV2_REQUEST_PARAM_NAME_CERT_OUT_OF_BAND: out_of_band,
        }

        self._request(conf.RCDPV2_REQUEST_CERT, request_params, method='POST')
        response_payload, _ = CertRetrievalApi._parse_rcdp_response(
            self.conn, conf.RCDPV2_REQUEST_CERT, conf.RCDPV2_RESPONSE_CERT)

        if out_of_band:
            cert_url_templ = response_payload[conf.RCDPV2_RESPONSE_PARAM_NAME_CERT_URL_TEMPL]
            cert_url = cert_url_templ.replace(
                "$(" + conf.CERT_DOWNLOAD_URL_HOST_PLACEHOLDER + ")", KEYTALK_SERVER)
            cert = fetch_url(cert_url)
            if cert is None:
                raise Exception('Failed to download certificate from ' + cert_url)
            assert not fetch_url(cert_url), "the given certificate can only be downloaded once"
        else:
            cert = bytes(response_payload[conf.RCDPV2_RESPONSE_PARAM_NAME_CERT], 'utf-8')

        log("Successfully generated PEM certificate from client CSR")
        cert_path = CertRetrievalApi._save_pem_cert_only(cert)
        return cert_path

    def reset_user_password(self, password):
        """Not implemented; currently a no-op."""
        pass


class CaApi(object):
    """Minimal client for the KeyTalk CA API, served over plain (non-TLS) HTTP.

    Always targets the last entry of ``conf.SERVER_SUPPORTED_CA_API_VERSIONS``.
    """

    def __init__(self):
        self.port = conf.CA_API_AND_CERT_DOWNLOAD_NOSSL_LISTEN_PORT
        self.script = conf.CA_API_REQUEST_SCRIPT_NAME
        # last listed version is the newest ("latest-greatest") one
        self.version = conf.SERVER_SUPPORTED_CA_API_VERSIONS[-1]

    def _url(self, ca_name):
        """Build the http:// URL under which CA ``ca_name`` is served."""
        return f"http://{KEYTALK_SERVER}:{self.port}/{self.script}/{self.version}/{ca_name}"

    def fetch_ca(self, ca_name):
        """Download CA ``ca_name`` and return what fetch_url() yields for its URL."""
        return fetch_url(self._url(ca_name))


class PublicApi(object):
    """Client for the KeyTalk public API.

    Every public method opens a new HTTPS connection to KEYTALK_SERVER on
    conf.RCDP_AND_PUBLIC_API_LISTEN_SSL_PORT, sends one request and parses the JSON reply.
    """

    def __init__(self):
        # Connection of the most recent request; replaced by every _request() call.
        self.conn = None

    def _request(self, action, params=None, method='GET'):
        """Open a new HTTPS connection and send one public API request on it.

        The response is left pending on ``self.conn``; read it with _parse_response().

        :param action: public API action, appended to the public API script path
        :param params: request parameters (default: none); sent in the query string for
                       GET and as a form-urlencoded body for POST
        :param method: 'GET' or 'POST'
        :raises Exception: for any other HTTP method
        """
        if params is None:
            # avoid a shared mutable default argument
            params = {}

        log("Connecting public API to KeyTalk server at " + KEYTALK_SERVER + "...")

        ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)

        if BYPASS_HTTPS_VALIDATION:
            # testing only: neither the server hostname nor its certificate is verified
            ssl_ctx.check_hostname = False
            ssl_ctx.verify_mode = ssl.CERT_NONE
        else:
            ssl_ctx.verify_mode = ssl.CERT_REQUIRED
            ssl_ctx.load_verify_locations(cadata=server_verification_ca_chain())

        # Release the previous request's connection instead of leaking it.
        if self.conn is not None:
            self.conn.close()

        self.conn = http.client.HTTPSConnection(
            KEYTALK_SERVER,
            conf.RCDP_AND_PUBLIC_API_LISTEN_SSL_PORT,
            context=ssl_ctx)

        if VERBOSE:
            self.conn.set_debuglevel(1)

        url = "/{}/{}".format(conf.PUBLIC_API_REQUEST_SCRIPT_NAME, action)
        headers = {}
        body = None

        if method == 'GET':
            # HTTP GET params are sent in URL
            if params:
                url += '?' + urllib.parse.urlencode(params)
        elif method == 'POST':
            # HTTP POST params are sent in body
            body = urllib.parse.urlencode(params)
            headers["Content-type"] = "application/x-www-form-urlencoded"
        else:
            raise Exception('Unsupported HTTP request method {}'.format(method))

        self.conn.request(method, url, body, headers)

    @staticmethod
    def _parse_response(conn, request_name, expected_status):
        """Read the pending response on ``conn`` and return its decoded JSON payload.

        :param conn: connection on which a request has just been sent
        :param request_name: request name, used only in debug and error messages
        :param expected_status: value the payload's status field must hold on HTTP 200
        :return: decoded JSON payload
        :raises BadRequestError: on HTTP 400 with an error status; carries the server's
                                 error message
        :raises Exception: on an HTTP status other than 200/400, or an unexpected status field
        """
        response = conn.getresponse()
        response_payload = response.read().decode()
        if response.status not in (200, 400):
            raise Exception(
                'Unexpected response HTTP status {} received on {} request.'.format(
                    response.status, request_name))

        payload = json.loads(response_payload)
        debug("{} -> {} {}.\n{}".format(request_name, response.status,
                                        response.reason, pprint.pformat(payload)))

        status = payload[conf.PUBLIC_API_RESPONSE_PARAM_NAME_STATUS]
        if response.status == 400:
            if status != conf.PUBLIC_API_RESPONSE_ERROR:
                raise Exception(
                    "Unexpected status {} received in HTTP 400 response received on {} request".format(
                        status, request_name))
            error_msg = payload[conf.PUBLIC_API_RESPONSE_PARAM_NAME_ERROR]
            raise BadRequestError(error_msg)

        if status != expected_status:
            raise Exception(
                'Expected {} response on {} request but received {} instead'.format(
                    expected_status, request_name, status))

        return payload

    @staticmethod
    def _parse_bool_field(payload, key):
        """Return True/False for ``payload[key]`` as judged by is_true()/is_false().

        :raises Exception: if the value is recognised as neither true nor false
        """
        if is_true(payload, key):
            return True
        if is_false(payload, key):
            return False
        raise Exception("Failed to parse boolean from {} key of the response {}".format(
            key, payload))

    def fetch_ktagent_unattended_popup_msg_template(self):
        """Return the message template field of the unattended KTAgent popup response."""
        self._request(conf.PUBLIC_API_REQUEST_KTAGENT_UNATTENDED_POPUP_MSG_TEMPLATE,
                      {},
                      method='GET')
        response_payload = PublicApi._parse_response(
            self.conn,
            conf.PUBLIC_API_REQUEST_KTAGENT_UNATTENDED_POPUP_MSG_TEMPLATE,
            conf.PUBLIC_API_RESPONSE_KTAGENT_UNATTENDED_POPUP_MSG_TEMPLATE)
        return response_payload[conf.PUBLIC_API_RESPONSE_PARAM_NAME_MSG_TEMPLATE]

    def is_self_service_available(self, cert):
        """Return whether self-service is available for ``cert`` (POSTed to the server).

        :raises Exception: if the availability field is not a recognisable boolean
        """
        # NOTE: this method used to be defined twice in this class with identical bodies;
        # the duplicate (which silently shadowed this one) has been removed.
        self._request(conf.PUBLIC_API_REQUEST_SELF_SERVICE_AVAILABILITY,
                      {conf.PUBLIC_API_REQUEST_PARAM_NAME_CERT: cert},
                      method='POST')
        response_payload = PublicApi._parse_response(
            self.conn,
            conf.PUBLIC_API_REQUEST_SELF_SERVICE_AVAILABILITY,
            conf.PUBLIC_API_RESPONSE_SELF_SERVICE_AVAILABILITY)
        return PublicApi._parse_bool_field(
            response_payload, conf.PUBLIC_API_RESPONSE_PARAM_NAME_AVAILABLE)

    def fetch_keep_alive_interval(self, tempate_name):
        """Return the keep-alive interval field the server reports for a template.

        :param tempate_name: template name (parameter name kept as-is for callers)
        """
        self._request(conf.PUBLIC_API_REQUEST_KEEP_ALIVE_INTERVAL,
                      {conf.PUBLIC_API_REQUEST_PARAM_NAME_TEMPLATE_NAME: tempate_name},
                      method='GET')
        response_payload = PublicApi._parse_response(
            self.conn,
            conf.PUBLIC_API_REQUEST_KEEP_ALIVE_INTERVAL,
            conf.PUBLIC_API_RESPONSE_KEEP_ALIVE_INTERVAL)
        return response_payload[conf.PUBLIC_API_RESPONSE_PARAM_NAME_INTERVAL]

    def send_keep_alive(self, tempate_name, user_name, computer_name):
        """POST an "I am alive" notification; raises unless the server answers with success."""
        self._request(conf.PUBLIC_API_REQUEST_I_AM_ALIVE,
                      {conf.PUBLIC_API_REQUEST_PARAM_NAME_TEMPLATE_NAME: tempate_name,
                       conf.PUBLIC_API_REQUEST_PARAM_NAME_USER: user_name,
                       conf.PUBLIC_API_REQUEST_PARAM_NAME_COMPUTER_NAME: computer_name,
                       },
                      method='POST')
        PublicApi._parse_response(
            self.conn,
            conf.PUBLIC_API_REQUEST_I_AM_ALIVE,
            conf.PUBLIC_API_RESPONSE_SUCCESS)

    def is_smime_cert_enrollment_available(self, cert):
        """Return (True, mobile_required) when S/MIME enrollment is available, else (False, reason).

        :raises Exception: if the availability field is not a recognisable boolean
        """
        self._request(conf.PUBLIC_API_REQUEST_SMIME_CERT_ENROLLMENT_AVAILABILITY,
                      {conf.PUBLIC_API_REQUEST_PARAM_NAME_CERT: cert},
                      method='POST')
        response_payload = PublicApi._parse_response(
            self.conn,
            conf.PUBLIC_API_REQUEST_SMIME_CERT_ENROLLMENT_AVAILABILITY,
            conf.PUBLIC_API_RESPONSE_SMIME_CERT_ENROLLMENT_AVAILABILITY)
        if PublicApi._parse_bool_field(response_payload, conf.PUBLIC_API_RESPONSE_PARAM_NAME_AVAILABLE):
            return (
                True,
                is_true(
                    response_payload,
                    conf.PUBLIC_API_RESPONSE_PARAM_NAME_MOBILE_REQUIRED,
                    strict=True))
        return (False, response_payload[conf.PUBLIC_API_RESPONSE_PARAM_NAME_REASON])

    def retrieve_address_books(self, service):
        """Return (address_books, apply_address_books) for ``service``.

        A missing/empty address book list becomes ``[]``; the apply flag is True only
        when the server sends the exact string "true".
        """
        self._request(conf.PUBLIC_API_REQUEST_ADDRESS_BOOK_LIST,
                      {conf.PUBLIC_API_REQUEST_PARAM_NAME_SERVICE: service},
                      method='GET')
        response_payload = PublicApi._parse_response(
            self.conn,
            conf.PUBLIC_API_REQUEST_ADDRESS_BOOK_LIST,
            conf.PUBLIC_API_RESPONSE_ADDRESS_BOOK_LIST)
        books = response_payload[conf.PUBLIC_API_RESPONSE_PARAM_NAME_ADDRESS_BOOKS] or []
        apply_address_books = response_payload[conf.PUBLIC_API_RESPONSE_PARAM_NAME_APPLY_ADDRESS_BOOKS] == "true"
        return books, apply_address_books

    def should_cert_go_to_system_store(self, service):
        """Return whether certificates of ``service`` should be installed in the system store.

        :raises Exception: if the system-store field is not a recognisable boolean
        """
        self._request(conf.PUBLIC_API_REQUEST_SHOULD_CERT_GO_TO_SYSTEM_STORE,
                      {conf.PUBLIC_API_REQUEST_PARAM_NAME_SERVICE: service},
                      method='GET')
        response_payload = PublicApi._parse_response(
            self.conn,
            conf.PUBLIC_API_REQUEST_SHOULD_CERT_GO_TO_SYSTEM_STORE,
            conf.PUBLIC_API_RESPONSE_SHOULD_CERT_GO_TO_SYSTEM_STORE)
        return PublicApi._parse_bool_field(
            response_payload, conf.PUBLIC_API_RESPONSE_PARAM_NAME_SYSTEM_STORE)

    def retrieve_cert_expiry_margin(self, service, user="", computer_name=""):
        """Return the certificate expiration margin for ``service`` as an int.

        The value is read from the response's THRESHOLD_SECONDS field.
        """
        self._request(conf.PUBLIC_API_REQUEST_CERT_EXPIRATION_MARGIN,
                      {conf.PUBLIC_API_REQUEST_PARAM_NAME_SERVICE: service,
                       conf.PUBLIC_API_REQUEST_PARAM_NAME_USER: user,
                       conf.PUBLIC_API_REQUEST_PARAM_NAME_COMPUTER_NAME: computer_name,
                       },
                      method='GET')
        response_payload = PublicApi._parse_response(
            self.conn,
            # report the request actually sent (was: the SERVICE parameter name)
            conf.PUBLIC_API_REQUEST_CERT_EXPIRATION_MARGIN,
            conf.PUBLIC_API_RESPONSE_CERT_EXPIRATION_MARGIN)
        return int(response_payload[conf.PUBLIC_API_RESPONSE_PARAM_NAME_THRESHOLD_SECONDS])

    def check_server_health(self):
        """Return the check-result field of the server health check response."""
        self._request(conf.PUBLIC_API_REQUEST_HEALTH_CHECK)
        response_payload = PublicApi._parse_response(
            self.conn,
            conf.PUBLIC_API_REQUEST_HEALTH_CHECK,
            conf.PUBLIC_API_RESPONSE_HEALTH_CHECK)
        return response_payload[conf.PUBLIC_API_RESPONSE_PARAM_NAME_CHECK_RESULT]

    def query_template_names_to_autorenew_seat_certs(self, templates):
        """Return which of ``templates`` (sent JSON-encoded) auto-renew seat certs; [] if none."""
        # NOTE(review): the action sent is the *response* constant
        # PUBLIC_API_RESPONSE_TEMPLATES_TO_AUTO_RENEW_SEAT_CERTS, unlike every other method
        # here, which uses a PUBLIC_API_REQUEST_* constant. Confirm in common_config whether a
        # REQUEST counterpart exists and should be used instead.
        self._request(conf.PUBLIC_API_RESPONSE_TEMPLATES_TO_AUTO_RENEW_SEAT_CERTS,
                      {conf.PUBLIC_API_REQUEST_PARAM_NAME_TEMPLATE_NAMES: json.dumps(templates)},
                      method='POST')
        response_payload = PublicApi._parse_response(
            self.conn,
            conf.PUBLIC_API_RESPONSE_TEMPLATES_TO_AUTO_RENEW_SEAT_CERTS,
            conf.PUBLIC_API_RESPONSE_TEMPLATES_TO_AUTO_RENEW_SEAT_CERTS)
        return response_payload[conf.PUBLIC_API_RESPONSE_PARAM_NAME_TEMPLATE_NAMES] or []

    def retrieve_server_version(self):
        """Return the version field reported by the KeyTalk server."""
        self._request(conf.PUBLIC_API_REQUEST_VERSION,
                      {},
                      method='GET')
        response_payload = PublicApi._parse_response(
            self.conn,
            conf.PUBLIC_API_REQUEST_VERSION,
            conf.PUBLIC_API_RESPONSE_VERSION)
        return response_payload[conf.PUBLIC_API_RESPONSE_PARAM_NAME_VERSION]


class AdminApi(object):
    """Client for the KeyTalk admin API.

    Every public method opens a new HTTPS connection to KEYTALK_SERVER on
    conf.WEBUI_AND_ADMIN_API_LISTEN_PORT and authenticates with WEBUI_ADMIN_USERNAME /
    WEBUI_ADMIN_PASSWORD.
    """

    def __init__(self):
        # Connection of the most recent request; replaced by every _request() call.
        self.conn = None

    @staticmethod
    def _auth_params():
        """Return the admin credential parameters that every admin API request carries."""
        return {
            conf.ADMIN_API_REQUEST_PARAM_NAME_AUTH_USERNAME: WEBUI_ADMIN_USERNAME,
            conf.ADMIN_API_REQUEST_PARAM_NAME_AUTH_PASSWORD: WEBUI_ADMIN_PASSWORD,
        }

    @staticmethod
    def gen_csr(requirements):
        """Generate an RSA keypair and a PEM CSR that satisfies ``requirements``.

        :param requirements: payload as returned by get_csr_enrollment_requirements(); key
                             size, signing algorithm and subject are required, SAN is optional
        :return: the PEM output of OpenSSL.crypto.dump_certificate_request()

        NOTE: the generated private key is neither returned nor stored.
        """
        key_size = int(requirements[conf.ADMIN_API_RESPONSE_PARAM_NAME_KEY_SIZE])
        signing_algo = requirements[conf.ADMIN_API_RESPONSE_PARAM_NAME_SIGNING_ALGO]
        subject = requirements[conf.ADMIN_API_RESPONSE_PARAM_NAME_SUBJECT]
        san = requirements[conf.ADMIN_API_RESPONSE_PARAM_NAME_SAN] if conf.ADMIN_API_RESPONSE_PARAM_NAME_SAN in requirements else []

        log("Generating {}-bit RSA keypair".format(key_size))
        keypair = OpenSSL.crypto.PKey()
        keypair.generate_key(OpenSSL.crypto.TYPE_RSA, key_size)
        log("Creating CSR with subject {}, SAN {} and signed by {}".format(subject, san, signing_algo))
        req = OpenSSL.crypto.X509Req()
        CertRetrievalApi._set_subject_on_req(req, subject)
        CertRetrievalApi._set_san_on_req(req, san)
        req.set_pubkey(keypair)
        req.sign(keypair, signing_algo)
        return OpenSSL.crypto.dump_certificate_request(OpenSSL.crypto.FILETYPE_PEM, req)

    def _request(self, action, params=None, method='POST'):
        """Open a new HTTPS connection and send one admin API request on it.

        The response is left pending on ``self.conn``; read it with _parse_response() or
        _parse_download_response().

        :param action: admin API action, appended to the admin API script path
        :param params: request parameters (default: none); sent in the query string for
                       GET and as a form-urlencoded body for POST
        :param method: 'GET' or 'POST' (default)
        :raises Exception: for any other HTTP method
        """
        if params is None:
            # avoid a shared mutable default argument
            params = {}

        log("Connecting admin API to KeyTalk server at " + KEYTALK_SERVER + "...")

        ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)

        if BYPASS_HTTPS_VALIDATION:
            # testing only: neither the server hostname nor its certificate is verified
            ssl_ctx.check_hostname = False
            ssl_ctx.verify_mode = ssl.CERT_NONE
        else:
            ssl_ctx.verify_mode = ssl.CERT_REQUIRED
            ssl_ctx.load_verify_locations(cadata=server_verification_ca_chain())

        # Release the previous request's connection instead of leaking it.
        if self.conn is not None:
            self.conn.close()

        self.conn = http.client.HTTPSConnection(
            KEYTALK_SERVER,
            conf.WEBUI_AND_ADMIN_API_LISTEN_PORT,
            context=ssl_ctx)

        if VERBOSE:
            self.conn.set_debuglevel(1)

        url = "/{}/{}".format(conf.ADMIN_API_REQUEST_SCRIPT_NAME, action)
        headers = {}
        body = None

        if method == 'GET':
            # HTTP GET params are sent in URL
            if params:
                url += '?' + urllib.parse.urlencode(params)
        elif method == 'POST':
            # HTTP POST params are sent in body
            body = urllib.parse.urlencode(params)
            headers["Content-type"] = "application/x-www-form-urlencoded"
        else:
            raise Exception('Unsupported HTTP request method {}'.format(method))

        self.conn.request(method, url, body, headers)

    @staticmethod
    def _parse_response(conn, request_name, expected_status):
        """Read the pending response on ``conn`` and return its decoded JSON payload.

        :param conn: connection on which a request has just been sent
        :param request_name: request name, used only in debug and error messages
        :param expected_status: value the payload's status field must hold on HTTP 200
        :return: decoded JSON payload
        :raises BadRequestError: on HTTP 400 with an error status; carries the server's
                                 error message
        :raises Exception: on an HTTP status other than 200/400, or an unexpected status field
        """
        response = conn.getresponse()
        response_payload = response.read().decode()
        if response.status not in (200, 400):
            raise Exception(
                'Unexpected response HTTP status {} received on {} request.'.format(
                    response.status, request_name))

        payload = json.loads(response_payload)
        debug("{} -> {} {}.\n{}".format(request_name, response.status,
                                        response.reason, pprint.pformat(payload)))

        status = payload[conf.ADMIN_API_RESPONSE_PARAM_NAME_STATUS]
        if response.status == 400:
            # NOTE(review): the status is compared against the error *parameter name*, while
            # PublicApi compares against a dedicated error status value. This only works if
            # both are the same string in common_config; verify.
            if status != conf.ADMIN_API_RESPONSE_PARAM_NAME_ERROR:
                raise Exception(
                    "Unexpected status {} received in HTTP 400 response received on {} request".format(
                        status, request_name))
            error_msg = payload[conf.ADMIN_API_RESPONSE_PARAM_NAME_ERROR]
            raise BadRequestError(error_msg)

        if status != expected_status:
            raise Exception(
                'Expected {} response on {} request but received {} instead'.format(
                    expected_status, request_name, status))

        return payload

    @staticmethod
    def _parse_download_response(conn, request_name):
        """Read the pending response on ``conn`` in full and return its raw body bytes.

        :raises Exception: if the HTTP status is not 200
        """
        response = conn.getresponse()

        # The download size might be quite large, which might cause up http.client to hiccup (BUG?). Hence the patch.
        # @todo try to change to Python requests library iso maintaining this lousy workaround.
        # On IncompleteRead the partial data is kept and reading is retried until read()
        # finally returns without raising.
        response_payload = b''
        while True:
            try:
                response_payload += response.read()
                break
            except http.client.IncompleteRead as icread:
                response_payload += icread.partial

        if response.status != 200:
            raise Exception(
                'Unexpected response HTTP status {} received on {} request.'.format(
                    response.status, request_name))
        debug("{} -> {} {}".format(request_name, response.status, response.reason))
        return response_payload

    def enroll_cert(self, service, user, san, is_new_seat):
        """Enroll a certificate for ``user`` under ``service`` and sanity-check the result.

        :param san: optional list of SAN entries; sent JSON-encoded and compared with the
                    issued certificate. The caller's list is not modified.
        :param is_new_seat: when True, a non-empty created-user password is expected back
        """
        request_params = {
            **self._auth_params(),
            conf.ADMIN_API_REQUEST_PARAM_NAME_SERVICE_NAME: service,
            conf.ADMIN_API_REQUEST_PARAM_NAME_DEVID_USERNAME: user,
        }
        if san:
            request_params[conf.ADMIN_API_REQUEST_PARAM_NAME_SAN] = json.dumps(san)

        self._request(conf.ADMIN_API_REQUEST_CERT_ENROLLMENT, request_params)

        response_payload = AdminApi._parse_response(
            self.conn,
            conf.ADMIN_API_REQUEST_CERT_ENROLLMENT,
            conf.ADMIN_API_RESPONSE_CERT_ENROLLMENT)
        cert = bytes(response_payload[conf.ADMIN_API_RESPONSE_PARAM_NAME_CERT], 'utf-8')

        parsed_cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
        assert parsed_cert.get_subject().commonName.lower() == user.lower()
        if san:
            # get_cert_san() is compared against "IP Address:" entries while the request uses
            # "IP:"; translate a copy so the caller's list is not mutated (it previously was).
            expected_san = [
                entry.replace("IP:", "IP Address:") if entry.startswith("IP:") else entry
                for entry in san
            ]
            assert get_cert_san(parsed_cert) == ', '.join(expected_san)
        if is_new_seat:
            created_user_auth_password = response_payload[conf.ADMIN_API_RESPONSE_PARAM_NAME_CREATED_USER_AUTH_PASSWORD]
            assert created_user_auth_password != ""
        log("Enrolled certificate with SHA1 {}".format(
            parsed_cert.digest("sha1").decode().lower().replace(':', '')))

    def get_csr_enrollment_requirements(self, service, user):
        """Return the server's CSR requirements payload for ``user`` under ``service``."""
        request_params = {
            **self._auth_params(),
            conf.ADMIN_API_REQUEST_PARAM_NAME_SERVICE_NAME: service,
            conf.ADMIN_API_REQUEST_PARAM_NAME_DEVID_USERNAME: user,
        }

        self._request(conf.ADMIN_API_REQUEST_CSR_ENROLLMENT_REQUIREMENTS, request_params)

        return AdminApi._parse_response(
            self.conn,
            conf.ADMIN_API_REQUEST_CSR_ENROLLMENT_REQUIREMENTS,
            conf.ADMIN_API_RESPONSE_CSR_ENROLLMENT_REQUIREMENTS)

    def enroll_cert_for_csr(self, service, user, csr, is_new_seat):
        """Enroll a certificate for ``csr`` and check its CN matches ``user``.

        :param is_new_seat: when True, a non-empty created-user password is expected back
        """
        request_params = {
            **self._auth_params(),
            conf.ADMIN_API_REQUEST_PARAM_NAME_SERVICE_NAME: service,
            conf.ADMIN_API_REQUEST_PARAM_NAME_DEVID_USERNAME: user,
            conf.ADMIN_API_REQUEST_PARAM_NAME_CSR: csr,
        }

        self._request(conf.ADMIN_API_REQUEST_CERT_ENROLLMENT_FOR_CSR, request_params)

        response_payload = AdminApi._parse_response(
            self.conn,
            conf.ADMIN_API_REQUEST_CERT_ENROLLMENT_FOR_CSR,
            conf.ADMIN_API_RESPONSE_CERT_ENROLLMENT)
        cert = bytes(response_payload[conf.ADMIN_API_RESPONSE_PARAM_NAME_CERT], 'utf-8')

        parsed_cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
        assert parsed_cert.get_subject().commonName.lower() == user.lower()
        if is_new_seat:
            created_user_auth_password = response_payload[conf.ADMIN_API_RESPONSE_PARAM_NAME_CREATED_USER_AUTH_PASSWORD]
            assert created_user_auth_password != ""
        log("Enrolled certificate with SHA1 {}".format(
            parsed_cert.digest("sha1").decode().lower().replace(':', '')))

    def revoke_certs(self, service, user):
        """Revoke the certificates of ``user`` under ``service``; return how many were revoked."""
        self._request(conf.ADMIN_API_REQUEST_CERT_REVOCATION,
                      {
                          **self._auth_params(),
                          conf.ADMIN_API_REQUEST_PARAM_NAME_SERVICE_NAME: service,
                          conf.ADMIN_API_REQUEST_PARAM_NAME_DEVID_USERNAME: user,
                      },
                      )
        response_payload = AdminApi._parse_response(
            self.conn,
            conf.ADMIN_API_REQUEST_CERT_REVOCATION,
            conf.ADMIN_API_RESPONSE_CERT_REVOCATION)
        num_revoked_certs = int(
            response_payload[conf.ADMIN_API_RESPONSE_PARAM_NAME_NUM_REVOKED_CERTS])
        log("{} certificates revoked".format(num_revoked_certs))
        return num_revoked_certs

    def create_internal_ra_user(self, service, username):
        """Create internal RA user ``username`` under ``service``; raises unless successful."""
        self._request(conf.ADMIN_API_REQUEST_CREATE_INTERNAL_RA_USER,
                      {
                          **self._auth_params(),
                          conf.ADMIN_API_REQUEST_PARAM_NAME_TEMPLATE_NAME: service,
                          conf.ADMIN_API_REQUEST_PARAM_NAME_USER_NAME: username,
                      },
                      )
        AdminApi._parse_response(
            self.conn,
            conf.ADMIN_API_REQUEST_CREATE_INTERNAL_RA_USER,
            conf.ADMIN_API_RESPONSE_SUCCESS)
        debug("Successfully created internal RA user {} under service {}".format(username, service))

    def download_server_settings(
            self,
            include_shared_settings,
            include_db_connection_settings,
            include_hsm_connection_settings,
            include_keytalk_cert_tree,
            include_webapi_server_settings,
            include_aad_certwriter_settings):
        """Download the selected KeyTalk server settings and return the raw response bytes.

        Each ``include_*`` flag is sent as a form parameter selecting one settings group.
        """
        self._request(
            conf.ADMIN_API_REQUEST_KEYTALK_SETTINGS,
            {
                **self._auth_params(),
                conf.ADMIN_API_REQUEST_PARAM_NAME_INCLUDE_SHARED_SETTINGS: include_shared_settings,
                conf.ADMIN_API_REQUEST_PARAM_NAME_INCLUDE_DB_CONNECTION_SETTINGS: include_db_connection_settings,
                conf.ADMIN_API_REQUEST_PARAM_NAME_INCLUDE_HSM_CONNECTION_SETTINGS: include_hsm_connection_settings,
                conf.ADMIN_API_REQUEST_PARAM_NAME_INCLUDE_KEYTALK_CERT_TREE: include_keytalk_cert_tree,
                conf.ADMIN_API_REQUEST_PARAM_NAME_INCLUDE_WEBAPI_SERVER_SETTINGS: include_webapi_server_settings,
                conf.ADMIN_API_REQUEST_PARAM_NAME_INCLUDE_AAD_CERTWRITER_SETTINGS: include_aad_certwriter_settings,
            },
        )
        return AdminApi._parse_download_response(
            self.conn,
            conf.ADMIN_API_REQUEST_KEYTALK_SETTINGS)

    def copy_template(self, src_template, new_template, digicert_settings=None):
        """Copy ``src_template`` to ``new_template``; raises unless successful.

        :param digicert_settings: optional DigiCert CertCentral settings, sent JSON-encoded
                                  when non-empty (default: none)
        """
        params = {
            **self._auth_params(),
            conf.ADMIN_API_REQUEST_PARAM_NAME_SRC_TEMPLATE_NAME: src_template,
            conf.ADMIN_API_REQUEST_PARAM_NAME_NEW_TEMPLATE_NAME: new_template,
        }
        if digicert_settings:
            params[conf.ADMIN_API_REQUEST_PARAM_NAME_DIGICERT_CENTRAL_SETTINGS] = json.dumps(
                digicert_settings)
        self._request(conf.ADMIN_API_REQUEST_COPY_TEMPLATE, params)
        AdminApi._parse_response(
            self.conn,
            conf.ADMIN_API_REQUEST_COPY_TEMPLATE,
            conf.ADMIN_API_RESPONSE_SUCCESS)
        debug("Successfully copied template {} to {}".format(src_template, new_template))

    def create_digicert_dv_acme_template(
            self,
            new_template,
            digicert_account_region,
            digicert_product,
            digicert_api_key,
            mail_fetch_proto,
            cert_validity_months,
            azure_client_id,
            azure_client_secret,
            azure_tenant_id,
            approver_email):
        """Create a DigiCert DV template for ACME and return its ACME directory URL."""
        params = {
            **self._auth_params(),
            conf.ADMIN_API_REQUEST_PARAM_NAME_TEMPLATE_NAME: new_template,
            conf.ADMIN_API_REQUEST_PARAM_NAME_DIGICERT_ACCOUNT_REGION: digicert_account_region,
            conf.ADMIN_API_REQUEST_PARAM_NAME_DIGICERT_PRODUCT: digicert_product,
            conf.ADMIN_API_REQUEST_PARAM_NAME_DIGICERT_API_KEY: digicert_api_key,
            conf.ADMIN_API_REQUEST_PARAM_NAME_MAIL_FETCH_PROTO: mail_fetch_proto,
            conf.ADMIN_API_REQUEST_PARAM_NAME_CERT_VALIDITY_MONTHS: cert_validity_months,
            conf.ADMIN_API_REQUEST_PARAM_NAME_AZURE_CLIENT_ID: azure_client_id,
            conf.ADMIN_API_REQUEST_PARAM_NAME_AZURE_CLIENT_SECRET: azure_client_secret,
            conf.ADMIN_API_REQUEST_PARAM_NAME_AZURE_TENANT_ID: azure_tenant_id,
            conf.ADMIN_API_REQUEST_PARAM_NAME_APPROVER_EMAIL: approver_email,
        }
        self._request(conf.ADMIN_API_REQUEST_CREATE_DIGICERT_DV_ACME_TEMPLATE, params)
        response_payload = AdminApi._parse_response(
            self.conn,
            conf.ADMIN_API_REQUEST_CREATE_DIGICERT_DV_ACME_TEMPLATE,
            conf.ADMIN_API_RESPONSE_SUCCESS)
        acme_url = response_payload[conf.ADMIN_API_RESPONSE_PARAM_NAME_ACME_DIRECTORY_URL]
        debug(
            f"Successfully created DigiCert DV template for ACME {new_template}. ACME URL: {acme_url}")
        return acme_url

    def list_templates(self):
        """Return the template names reported by the server."""
        self._request(conf.ADMIN_API_REQUEST_LIST_TEMPLATES, self._auth_params())
        response_payload = AdminApi._parse_response(
            self.conn,
            # report the request actually sent (was: ADMIN_API_REQUEST_COPY_TEMPLATE)
            conf.ADMIN_API_REQUEST_LIST_TEMPLATES,
            conf.ADMIN_API_RESPONSE_SUCCESS)
        templates = response_payload[conf.ADMIN_API_RESPONSE_PARAM_NAME_TEMPLATE_NAMES]
        debug("Successfully got templates list: {}".format(templates))
        return templates

    def remove_seat(self, service, seat_name):
        """Remove seat ``seat_name`` from ``service``; return whether the server removed it."""
        self._request(conf.ADMIN_API_REQUEST_REMOVE_SEAT,
                      {
                          **self._auth_params(),
                          conf.ADMIN_API_REQUEST_PARAM_NAME_TEMPLATE_NAME: service,
                          conf.ADMIN_API_REQUEST_PARAM_NAME_SEAT_NAME: seat_name,
                      },
                      )
        response_payload = AdminApi._parse_response(
            self.conn,
            conf.ADMIN_API_REQUEST_REMOVE_SEAT,
            conf.ADMIN_API_RESPONSE_REMOVE_SEAT)
        seat_removed = is_true(response_payload, conf.ADMIN_API_RESPONSE_PARAM_NAME_REMOVED)
        debug("Seat {} of service {} {} removed".format(
            seat_name, service, "successfully" if seat_removed else "not"))
        return seat_removed

    def create_seat(self, service, seat_name, cn, san):
        """Create seat ``seat_name`` under ``service`` with CN ``cn`` and SAN list ``san``.

        :return: the result field of the server response
        """
        self._request(conf.ADMIN_API_REQUEST_CREATE_SEAT,
                      {
                          **self._auth_params(),
                          conf.ADMIN_API_REQUEST_PARAM_NAME_TEMPLATE_NAME: service,
                          conf.ADMIN_API_REQUEST_PARAM_NAME_SEAT_NAME: seat_name,
                          conf.ADMIN_API_REQUEST_PARAM_NAME_CN: cn,
                          conf.ADMIN_API_REQUEST_PARAM_NAME_SAN: json.dumps(san),
                      },
                      )
        response_payload = AdminApi._parse_response(
            self.conn,
            conf.ADMIN_API_REQUEST_CREATE_SEAT,
            conf.ADMIN_API_RESPONSE_SUCCESS)
        result = response_payload[conf.ADMIN_API_RESPONSE_PARAM_NAME_RESULT]
        debug("Seat {} of service {} successfully {}".format(seat_name, service, result))
        return result


#
# Test cases
#


def request_cert_with_password_authentication(cert_format):
    """Run the RCDP flow with username/password credentials and retrieve a ``cert_format`` cert.

    Steps: hello, handshake, auth requirements, authenticate, last messages, cert, eoc.
    """
    service, username, password = "CUST_PASSWD_INTERNAL_TESTUI", 'DemoUser', 'secret'

    rcdp = CertRetrievalApi()
    rcdp.hello()
    rcdp.handshake()

    requirements = rcdp.get_service_auth_requirements(service)
    # this service must use static credentials rather than challenge-response
    assert not CertRetrievalApi.is_cr_authentication(requirements), \
        "Non-CR authentication is expected for service {}".format(service)
    rcdp.authenticate(
        CertRetrievalApi.request_auth_credentials(requirements, username, password), service)

    rcdp.get_last_messages()
    rcdp.get_cert(cert_format)
    rcdp.eoc()


def request_cert_from_csr_with_password_authentication():
    """Authenticate with username/password, then have the server sign a locally generated CSR."""
    service, username, password = "CUST_PASSWD_INTERNAL", 'DemoUser', 'secret'

    rcdp = CertRetrievalApi()
    rcdp.hello()
    rcdp.handshake()

    requirements = rcdp.get_service_auth_requirements(service)
    # this service must use static credentials rather than challenge-response
    assert not CertRetrievalApi.is_cr_authentication(requirements), \
        "Non-CR authentication is expected for service {}".format(service)
    rcdp.authenticate(
        CertRetrievalApi.request_auth_credentials(requirements, username, password), service)

    rcdp.get_last_messages()
    # build the CSR to the server's requirements and submit it for signing
    rcdp.sign_csr(CertRetrievalApi.gen_csr(rcdp.get_csr_requirements()))
    rcdp.eoc()


def request_out_of_band_cert_with_password_authentication(cert_format):
    """Like the password flow, but the certificate is retrieved via out-of-band download."""
    service, username, password = "CUST_PASSWD_INTERNAL_TESTUI", 'DemoUser', 'secret'

    rcdp = CertRetrievalApi()
    rcdp.hello()
    rcdp.handshake()

    requirements = rcdp.get_service_auth_requirements(service)
    # this service must use static credentials rather than challenge-response
    assert not CertRetrievalApi.is_cr_authentication(requirements), \
        "Non-CR authentication is expected for service {}".format(service)
    rcdp.authenticate(
        CertRetrievalApi.request_auth_credentials(requirements, username, password), service)

    rcdp.get_last_messages()
    rcdp.get_cert(cert_format, out_of_band=True)
    rcdp.eoc()


def request_out_of_band_smb_certs_with_seat_authentication(cert_format):
    """Seat-authenticate, then download the SMB certificates out of band.

    Seat authentication runs in two steps: step 1 yields a hardware-signature formula,
    step 2 (given the computed signature) yields an encrypted challenge that is decrypted
    locally and returned to the server. Exactly two SMB entries are expected.
    """
    service = "CUST_SMB_SMIME"
    username = "DemoUser"

    rcdp = CertRetrievalApi()
    rcdp.hello()
    rcdp.handshake()

    step1 = rcdp.get_seat_auth_requirements_step1(service, username, "demouser.test.com")
    hwsig = CertRetrievalApi._calc_hwsig(step1[conf.RCDPV2_RESPONSE_PARAM_NAME_HWSIG_FORMULA])
    step2 = rcdp.get_seat_auth_requirements_step2(hwsig)
    challenge = base64.b64decode(step2[conf.RCDPV2_RESPONSE_PARAM_NAME_ENCRYPTED_CHALLENGE])
    rcdp.authenticate_seat(
        service, username,
        CertRetrievalApi.decrypt_seat_auth_challenge(challenge, service, username))

    smbs = rcdp.get_smb_certs(cert_format, out_of_band=True)
    smb_count = len(smbs)
    assert smb_count == 2, "Received {} SMB entries for service {} and username {}, expected 2".format(
        smb_count, service, username)
    # @todo check certs
    rcdp.eoc()


def request_cert_with_password_and_pincode_authentication(cert_format):
    """Fetch a certificate after password + PIN code authentication.

    Targets the CUST_PIN_PASSWD_INTERNAL_TESTUI service, whose auth
    requirements must NOT be challenge-response.

    :param cert_format: certificate format forwarded to ``get_cert``
    """
    service = "CUST_PIN_PASSWD_INTERNAL_TESTUI"
    username, password, pincode = 'DemoUser', 'secret', '1234'

    rcdp = CertRetrievalApi()
    rcdp.hello()
    rcdp.handshake()

    requirements = rcdp.get_service_auth_requirements(service)
    uses_cr = CertRetrievalApi.is_cr_authentication(requirements)
    assert not uses_cr, "Non-CR authentication is expected for service {}".format(service)
    credentials = CertRetrievalApi.request_auth_credentials(
        requirements, username, password, pincode)
    rcdp.authenticate(credentials, service)

    rcdp.get_last_messages()
    rcdp.get_cert(cert_format)
    rcdp.eoc()


def request_cert_with_challenge_response_authentication(cert_format):
    """Fetch a certificate after two-round challenge-response authentication.

    Targets the CUST_CR_INTERNAL_TESTUI service, whose auth requirements must
    be challenge-response.
    - Round one submits credentials built from the username and yields the
      challenges.
    - Round two submits the calculated responses.

    :param cert_format: certificate format forwarded to ``get_cert``
    """
    service = "CUST_CR_INTERNAL_TESTUI"
    username = 'DemoUser'

    rcdp = CertRetrievalApi()
    rcdp.hello()
    rcdp.handshake()

    requirements = rcdp.get_service_auth_requirements(service)
    uses_cr = CertRetrievalApi.is_cr_authentication(requirements)
    assert uses_cr, "CR authentication is expected for service {}".format(service)
    # Round 1: initial credentials -> challenges
    initial_credentials = CertRetrievalApi.request_auth_credentials(requirements, username)
    challenges, response_names = rcdp.authenticate(initial_credentials, service)
    # Round 2: responses to those challenges
    rcdp.authenticate(CertRetrievalApi.calc_responses(username, challenges, response_names))

    rcdp.get_last_messages()
    rcdp.get_cert(cert_format)
    rcdp.eoc()


def request_smb_certs_with_seat_authentication(cert_format):
    """Fetch SMB certificates in-band after two-step seat authentication.

    Seat "DemoUser" with seat host "test.keytalk.com" authenticates against
    the CUST_SMB_SMIME service. Exactly two SMB entries must be returned.

    :param cert_format: certificate format forwarded to ``get_smb_certs``
    """
    service, username = "CUST_SMB_SMIME", "DemoUser"

    rcdp = CertRetrievalApi()
    rcdp.hello()
    rcdp.handshake()

    # Seat auth, step 1: the response carries the hardware-signature formula
    step1 = rcdp.get_seat_auth_requirements_step1(service, username, "test.keytalk.com")
    formula = step1[conf.RCDPV2_RESPONSE_PARAM_NAME_HWSIG_FORMULA]
    # Step 2: send the computed hwsig, receive a base64-encoded encrypted challenge
    step2 = rcdp.get_seat_auth_requirements_step2(CertRetrievalApi._calc_hwsig(formula))
    challenge = base64.b64decode(step2[conf.RCDPV2_RESPONSE_PARAM_NAME_ENCRYPTED_CHALLENGE])
    # Answer with the decrypted challenge
    rcdp.authenticate_seat(
        service, username,
        CertRetrievalApi.decrypt_seat_auth_challenge(challenge, service, username))

    smbs = rcdp.get_smb_certs(cert_format, out_of_band=False)
    assert len(smbs) == 2, "Received {} SMB entries for service {} and user {}, expected 2".format(
        len(smbs), service, username)
    # TODO: check the content of the received certs
    rcdp.eoc()


def renew_cert_with_seat_authentication(cert_format):
    """Fetch a certificate for a seat after two-step seat authentication.

    Seat "DemoUser" with seat host "test.keytalk.com" authenticates against
    the CUST_SMB_SMIME service. The flow then continues with last messages
    and a regular ``get_cert`` call.

    :param cert_format: certificate format forwarded to ``get_cert``
    """
    service, username = "CUST_SMB_SMIME", "DemoUser"

    rcdp = CertRetrievalApi()
    rcdp.hello()
    rcdp.handshake()

    # Seat auth, step 1: the response carries the hardware-signature formula
    step1 = rcdp.get_seat_auth_requirements_step1(service, username, "test.keytalk.com")
    formula = step1[conf.RCDPV2_RESPONSE_PARAM_NAME_HWSIG_FORMULA]
    # Step 2: send the computed hwsig, receive a base64-encoded encrypted challenge
    step2 = rcdp.get_seat_auth_requirements_step2(CertRetrievalApi._calc_hwsig(formula))
    challenge = base64.b64decode(step2[conf.RCDPV2_RESPONSE_PARAM_NAME_ENCRYPTED_CHALLENGE])
    # Answer with the decrypted challenge
    rcdp.authenticate_seat(
        service, username,
        CertRetrievalApi.decrypt_seat_auth_challenge(challenge, service, username))

    rcdp.get_last_messages()
    rcdp.get_cert(cert_format)
    rcdp.eoc()


def change_password_and_request_cert(cert_format, skip_ad_password_change=True):
    """Change an AD user's password over RCDP, fetch a cert, then restore it.

    Flow against the CUST_PASSWD_AD service:
    - authenticate with the old password, which must be reported as expiring;
    - change the password and re-authenticate with the new one;
    - fetch a certificate;
    - change the password back, re-authenticate with the old one, and end
      the communication.

    The test AD is not configured to allow password changes, so by default
    this only logs that the test is skipped.

    :param cert_format: certificate format forwarded to ``get_cert``
    :param skip_ad_password_change: skip the test (default); pass False to run
        it against an AD that permits password changes
    """
    if skip_ad_password_change:
        log("SKIPPING AD password change test because our test AD is not configured to allow that")
        return

    service = "CUST_PASSWD_AD"
    username = 'TestUser'
    old_password = 'Sioux2010'
    new_password = 'Sioux2011'

    api = CertRetrievalApi()
    # handshake
    api.hello()
    api.handshake()
    # authenticate
    auth_requirements = api.get_service_auth_requirements(service)
    assert not CertRetrievalApi.is_cr_authentication(
        auth_requirements), "Non-CR authentication is expected for service {}".format(service)
    creds = CertRetrievalApi.request_auth_credentials(auth_requirements, username, old_password)
    password_validity_sec = api.authenticate(creds, service)
    assert CertRetrievalApi.is_password_expiring(
        password_validity_sec), "Password for user {} and service {} is not yet expiring (still valid for {} seconds)".format(username, service, password_validity_sec)
    api.change_password(old_password, new_password)
    creds[conf.CRED_PASSWD] = new_password
    api.authenticate(creds, service)
    # get service
    api.get_last_messages()
    api.get_cert(cert_format)
    # reset password back so the next run starts from the same state
    api.change_password(new_password, old_password)
    creds[conf.CRED_PASSWD] = old_password
    api.authenticate(creds, service)
    # close connection
    api.eoc()


def fetch_ca_certs():
    """Fetch the KeyTalk CA certificates through the CA API.

    The signing, communication and primary CAs are fetched as PEM, parsed
    with pyOpenSSL, and their subjects passed to ``debug``. The root CA must
    not be returned.
    """
    ca_api = CaApi()

    expected_cas = (
        ("Signing", conf.SIGNING_CA),
        ("Communication", conf.COMMUNICATION_CA),
        ("Primary", conf.PRIMARY_CA),
    )
    for label, ca_kind in expected_cas:
        pem = ca_api.fetch_ca(ca_kind)
        x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pem)
        debug("Fetched {} CA. Subject: {}".format(label, x509.get_subject()))

    assert not ca_api.fetch_ca(conf.ROOT_CA), "Root CA is not expected to be present"


def check_self_service_availability():
    """Check the Public API self-service availability query.

    - A certificate just retrieved for a seat (written to ``cert.pem`` by
      ``request_cert_with_password_authentication``) must report self-service
      as available.
    - The first server-verification CA certificate must report it as
      unavailable.
    - A malformed certificate must be rejected with ``BadRequestError``.
    """
    api = PublicApi()

    # given, get the cert for which self-service should be available
    request_cert_with_password_authentication(conf.CERTKEY_FORMAT_PEM)
    # read via a context manager so the file handle is closed promptly
    with open('cert.pem') as cert_file:
        cert = cert_file.read()
    # when-then
    assert api.is_self_service_available(cert)

    # given, cert that doesn't correspond to a seat hence no self-service
    with open(SERVER_VERIFICATION_CA_CHAIN[0]) as ca_file:
        cert = ca_file.read()
    # when-then
    assert not api.is_self_service_available(cert)

    try:
        api.is_self_service_available("invalid-cert")
        assert False, "Invalid certificate is not reported as bad request"
    except BadRequestError:
        pass


def retrieve_address_books():
    """Check address-book retrieval through the Public API.

    - CUST_ANO_INTERNAL: no address books, and the apply flag equals False.
    - CUST_PASSWD_AD: its two LDAP address books, with the apply flag set.
    - An unknown service must be rejected with ``BadRequestError``.
    """
    public_api = PublicApi()

    books, apply_books = public_api.retrieve_address_books(service="CUST_ANO_INTERNAL")
    assert books == []
    # equality (not truthiness) on purpose: e.g. None must not pass
    assert apply_books == False

    expected_ad_books = [
        {'ldap_svr_url': 'ldap://dc01-dev.keytalkdemo.local',
         'search_base': 'dc=keytalkdemo,dc=local'},
        {'ldap_svr_url': 'ldap://addressbook.example.com',
         'search_base': 'people,dc=example,dc=com'},
    ]
    books, apply_books = public_api.retrieve_address_books(service="CUST_PASSWD_AD")
    assert books == expected_ad_books
    assert apply_books

    try:
        public_api.retrieve_address_books(service="NON_EXISTING_SERVICE")
        assert False, "Non-existing service is not reported as bad request"
    except BadRequestError:
        pass


# Expected shape of the server version string: a single-digit major version
# followed by one- or two-digit minor and patch components, dot-separated
# (e.g. "5.10.2"). Dots are escaped so they only match a literal '.'.
SERVER_VERSION_REGEX = re.compile(r"\d\.\d{1,2}\.\d{1,2}")


def retrieve_server_version():
    """Check that the Public API reports a well-formed server version.

    The whole version string must match ``SERVER_VERSION_REGEX``.
    ``fullmatch`` is used so that trailing text, including a newline, is
    rejected.
    """
    api = PublicApi()

    version = api.retrieve_server_version()
    assert SERVER_VERSION_REGEX.fullmatch(version) is not None, \
        "Unexpected server version format: {!r}".format(version)


def check_smime_cert_enrollment_availability():
    """Check the Public API S/MIME certificate enrollment availability query.

    - A certificate just retrieved for a seat (written to ``cert.pem``) must
      report enrollment as unavailable, with a non-empty reason.
    - A malformed certificate must be rejected with ``BadRequestError``.
    """
    api = PublicApi()

    # given, get a cert issued to a seat
    request_cert_with_password_authentication(conf.CERTKEY_FORMAT_PEM)
    # read via a context manager so the file handle is closed promptly
    with open('cert.pem') as cert_file:
        cert = cert_file.read()
    # when-then (default KeyTalk build lacks S/MIME cert enrollment setup and
    # that's pretty worksome to set it up by hand here, so just test for
    # error)
    available, reason = api.is_smime_cert_enrollment_available(cert)
    assert not available
    assert reason
    debug(reason)

    try:
        api.is_smime_cert_enrollment_available("invalid-cert")
        assert False, "Invalid certificate is not reported as bad request"
    except BadRequestError:
        pass


def should_cert_go_to_system_store():
    """Check the Public API system-store placement query.

    - CUST_ANO_INTERNAL certs must not go to the system store.
    - CUST_ANO_REST certs must go there.
    - An unknown service must be rejected with ``BadRequestError``.
    """
    public_api = PublicApi()

    for service, expected in (("CUST_ANO_INTERNAL", False), ("CUST_ANO_REST", True)):
        to_system_store = public_api.should_cert_go_to_system_store(service=service)
        assert bool(to_system_store) is expected

    try:
        public_api.should_cert_go_to_system_store("NON_EXISTING_SERVICE")
        assert False, "Non-existing service is not reported as bad request"
    except BadRequestError:
        pass


def retrieve_cert_expiry_margin():
    """Check the certificate expiry margin reported by the Public API.

    - CUST_ANO_REST without a user must report 7 days.
    - CUST_ANO_REST for user DemoUser must report 14 days.
    - An unknown service must be rejected with ``BadRequestError``.
    """
    public_api = PublicApi()

    day = 24 * 60 * 60  # seconds
    cases = (
        ({"service": "CUST_ANO_REST"}, 7),
        ({"service": "CUST_ANO_REST", "user": "DemoUser"}, 14),
    )
    for query, expected_days in cases:
        margin = public_api.retrieve_cert_expiry_margin(**query)
        assert margin == expected_days * day, "Actual: {}".format(margin)

    try:
        public_api.retrieve_cert_expiry_margin("NON_EXISTING_SERVICE")
        assert False, "Non-existing service is not reported as bad request"
    except BadRequestError:
        pass


def check_server_health():
    """Assert that the Public API health check reports the operational result."""
    health = PublicApi().check_server_health()
    assert health == conf.PUBLIC_API_HEALTH_CHECK_RESULT_OPERATIONAL, health


def create_internal_ra_user():
    """Create an internal RA user with a random name for CUST_ANO_INTERNAL via the Admin API."""
    admin_api = AdminApi()
    admin_api.create_internal_ra_user("CUST_ANO_INTERNAL", random_string())


def query_template_names_to_autorenew_seat_certs():
    """Check which templates the Public API selects for seat-cert auto-renewal.

    - Of four known templates, exactly the GlobalSign Atlas and Azure ones
      must be returned, in that order.
    - An empty query yields an empty list.
    - A query containing an unknown template must be rejected with
      ``BadRequestError``.
    """
    public_api = PublicApi()

    candidates = [
        "CUST_ANO_INTERNAL",
        "CUST_ANO_INTERNAL_GLOBALSIGN_ATLAS",
        "CUST_PASSWD_AZURE",
        "CUST_PASSWD_INTERNAL",
    ]
    expected = ["CUST_ANO_INTERNAL_GLOBALSIGN_ATLAS", "CUST_PASSWD_AZURE"]
    assert public_api.query_template_names_to_autorenew_seat_certs(candidates) == expected

    assert public_api.query_template_names_to_autorenew_seat_certs([]) == []

    try:
        public_api.query_template_names_to_autorenew_seat_certs(
            ["CUST_ANO_INTERNAL", "NON_EXISTING_SERVICE"])
        assert False, "Non-existing template is not reported as bad request"
    except BadRequestError:
        pass


def enroll_and_revoke_cert():
    """Enroll certs via the Admin API and check that revocation is one-shot.

    Two seats of CUST_ANO_INTERNAL are exercised:
    - the existing seat "DemoUser", enrolled without SANs;
    - a freshly named new seat, enrolled with DNS and IP SANs.

    After enrollment, the first revocation must revoke exactly one cert and a
    repeated revocation must revoke none.
    """
    service = "CUST_ANO_INTERNAL"
    existing_seat = "DemoUser"
    # random suffix so every run works with a never-seen-before seat
    new_seat = "enroll_and_revoke_cert_test-" + random_string()

    admin_api = AdminApi()

    def revoke_once_then_nothing(seat):
        # first revocation hits the just-enrolled cert, the second finds nothing
        revoked = admin_api.revoke_certs(service, seat)
        assert revoked == 1, "Error revoking seat {} of service {}".format(seat, service)
        revoked = admin_api.revoke_certs(service, seat)
        assert revoked == 0, "Error repeat revoking seat {} of service {}".format(seat, service)

    # start from a clean slate for both seats
    admin_api.revoke_certs(service, existing_seat)
    admin_api.revoke_certs(service, new_seat)

    admin_api.enroll_cert(service, existing_seat, san=[], is_new_seat=False)
    revoke_once_then_nothing(existing_seat)

    new_seat_sans = ["DNS:test-keytalk-client.server.com", "IP:10.100.0.200"]
    admin_api.enroll_cert(service, new_seat, san=new_seat_sans, is_new_seat=True)
    revoke_once_then_nothing(new_seat)


def enroll_and_revoke_cert_from_csr():
    """Enroll certs from locally generated CSRs and check one-shot revocation.

    For both the existing seat "DemoUser" and a freshly named new seat of
    CUST_ANO_INTERNAL:
    - fetch the CSR enrollment requirements and generate a CSR;
    - enroll it;
    - revoke: the first revocation must revoke exactly one cert, and a
      repeated revocation must revoke none.
    """
    service = "CUST_ANO_INTERNAL"
    existing_seat = "DemoUser"
    # add a random string to make sure the seat is fresh for each test run
    new_seat = "enroll_and_revoke_cert_from_csr_test-" + random_string()

    api = AdminApi()

    def check_revocations(seat):
        # Same diagnostics as enroll_and_revoke_cert so failures name the seat/service
        num_revoked_certs = api.revoke_certs(service, seat)
        assert num_revoked_certs == 1, "Error revoking seat {} of service {}".format(
            seat, service)
        num_revoked_certs = api.revoke_certs(service, seat)
        assert num_revoked_certs == 0, "Error repeat revoking seat {} of service {}".format(
            seat, service)

    # start clean
    api.revoke_certs(service, existing_seat)
    api.revoke_certs(service, new_seat)

    csr_requirements = api.get_csr_enrollment_requirements(service, existing_seat)
    csr = AdminApi.gen_csr(csr_requirements)
    api.enroll_cert_for_csr(service, existing_seat, csr, is_new_seat=False)
    check_revocations(existing_seat)

    csr_requirements = api.get_csr_enrollment_requirements(service, new_seat)
    csr = AdminApi.gen_csr(csr_requirements)
    api.enroll_cert_for_csr(service, new_seat, csr, is_new_seat=True)
    check_revocations(new_seat)


def download_server_settings():
    """Download all server settings via the Admin API and check the result.

    Every ``include_*`` option is enabled. The returned bytes are written to
    ``settings.dat`` in the current directory, which must then be recognized
    as a tar archive.
    """
    archive_path = 'settings.dat'
    options = dict(include_shared_settings=True,
                   include_db_connection_settings=True,
                   include_hsm_connection_settings=True,
                   include_keytalk_cert_tree=True,
                   include_webapi_server_settings=True,
                   include_aad_certwriter_settings=True)

    blob = AdminApi().download_server_settings(**options)
    with open(archive_path, 'wb') as archive:
        archive.write(blob)

    assert tarfile.is_tarfile(archive_path)


def copy_template():
    """Copy templates via the Admin API and check the copies get listed.

    - A plain copy of CUST_ANO_INTERNAL is made.
    - A copy of CUST_ANO_INTERNAL_DIGICERT_CERTCENTRAL is made with overridden
      DigiCert settings.
    Each copy gets a randomized name that must be absent before copying and
    present afterwards.
    """
    admin_api = AdminApi()

    plain_copy = "CUST_ANO_INTERNAL-" + random_string()
    assert plain_copy not in admin_api.list_templates()
    admin_api.copy_template("CUST_ANO_INTERNAL", plain_copy)
    assert plain_copy in admin_api.list_templates()

    digicert_copy = "CUST_ANO_INTERNAL_DIGICERT_CERTCENTRAL-" + random_string()
    assert digicert_copy not in admin_api.list_templates()
    overrides = {
        conf.ADMIN_API_REQUEST_PARAM_NAME_PRODUCT: "ssl_ev_basic",
        conf.ADMIN_API_REQUEST_PARAM_NAME_API_KEY: "AABBCCDD",
        conf.ADMIN_API_REQUEST_PARAM_NAME_CERT_VALIDITY_MONTHS: 12,
        conf.ADMIN_API_REQUEST_PARAM_NAME_ORGANIZATION_ID: 123456,
        conf.ADMIN_API_REQUEST_PARAM_NAME_APPROVER_USER_ID: 234567,
    }
    admin_api.copy_template("CUST_ANO_INTERNAL_DIGICERT_CERTCENTRAL",
                            digicert_copy,
                            digicert_settings=overrides)
    assert digicert_copy in admin_api.list_templates()


def create_digicert_dv_acme_template():
    """Create a DigiCert DV ACME template via the Admin API.

    The returned ACME URL must be the demo server's ACME directory for the new
    template. The template must be listed afterwards, and must not have been
    listed before.
    """
    admin_api = AdminApi()
    template = "CUST_DIGICERT_DV_ACME_-" + random_string()
    assert template not in admin_api.list_templates()

    acme_url = admin_api.create_digicert_dv_acme_template(
        template,
        digicert_account_region="US",
        digicert_product="ssl_dv_thawte",
        digicert_api_key="123456",
        mail_fetch_proto="o365-imap",
        cert_validity_months=6,
        azure_client_id="111222",
        azure_client_secret="secret",
        azure_tenant_id="333444",
        approver_email="approver@mydomain.com")

    expected_url = "http://demo.keytalkdemo.com/acme/directory?keytalk-template-name=" + template
    assert acme_url == expected_url
    assert template in admin_api.list_templates()


def list_templates():
    """Check that the Admin API lists more than 20 templates, including two well-known ones."""
    names = AdminApi().list_templates()
    assert len(names) > 20
    for required in ("CUST_ANO_INTERNAL", "CUST_PASSWD_INTERNAL"):
        assert required in names


def create_and_remove_seat():
    """Create, update and remove a seat of CUST_ANO_INTERNAL via the Admin API.

    For a randomly named seat:
    - the first ``create_seat`` call must report "created";
    - a second call with a different CN/SANs must report "updated";
    - ``remove_seat`` must succeed once and report failure the second time.
    """
    service = "CUST_ANO_INTERNAL"
    admin_api = AdminApi()
    seat = "NewSeat-" + random_string()

    outcome = admin_api.create_seat(
        service, seat,
        cn='original-cn',
        san=['DNS:original.test.com', 'DNS:original2.test.com'])
    assert outcome == "created"

    outcome = admin_api.create_seat(
        service, seat,
        cn='updated-cn',
        san=['DNS:updated.test.com', 'DNS:updated2.test.com'])
    assert outcome == "updated"

    assert admin_api.remove_seat(service, seat)
    assert not admin_api.remove_seat(service, seat)


def fetch_mfa_web_popups_settings():
    """Check the MFA settings advertised over RCDP for the MFA_WEB_POPUPS service.

    Fetches the service's auth requirements and verifies that MFA is required.
    It then checks the MFA client id, kind (popups), authority, scopes,
    redirect URI and post-logout URI, and closes the RCDP session.
    """
    service = "MFA_WEB_POPUPS"
    client_id = "ba6c5e63-1e29-4037-a209-a7a74724ad68"
    authority = "https://login.microsoftonline.com/2fe4e8d6-472e-4a31-9269-3c7fa639c9d1"
    scopes = ["User.Read.All", "Mail.Read"]
    redirect_uri = "https://demo.keytalkdemo.com/provisioncert?template-name=3093ac9b97e53a8044d1f994fbb8ded53befbcd63d21992b92d74cf8f85a4eaf"
    # the post-logout URI is expected to equal the redirect URI
    post_log_out_uri = redirect_uri

    api = CertRetrievalApi()
    # handshake
    api.hello()
    api.handshake()
    # fetch auth requirements
    auth_requirements = api.get_service_auth_requirements(service)
    # Check MFA settings
    assert CertRetrievalApi.is_mfa_authentication(
        auth_requirements), "OTP/MFA authentication is expected for service {}".format(service)
    mfa = auth_requirements[conf.RCDPV2_RESPONSE_PARAM_NAME_MFA_SETTINGS]
    assert mfa[conf.RCDPV2_RESPONSE_PARAM_NAME_MFA_CLIENT_ID] == client_id
    assert mfa[conf.RCDPV2_RESPONSE_PARAM_NAME_MFA_KIND] == conf.MFA_KIND_POPUPS
    assert mfa[conf.RCDPV2_RESPONSE_PARAM_NAME_MFA_AUTHORITY] == authority

    for scope in scopes:
        assert scope in mfa[conf.RCDPV2_RESPONSE_PARAM_NAME_MFA_SCOPES], \
            "MFA scope {} is missing for service {}".format(scope, service)

    assert mfa[conf.RCDPV2_RESPONSE_PARAM_NAME_MFA_REDIRECT_URI] == redirect_uri
    assert mfa[conf.RCDPV2_RESPONSE_PARAM_NAME_MFA_POST_LOG_OUT_URI] == post_log_out_uri
    # close connection, like every other RCDP flow in this script
    api.eoc()


def fetch_mfa_web_redirection_settings():
    """Check the MFA settings advertised over RCDP for the MFA_WEB_REDIRECTION service.

    Fetches the service's auth requirements and verifies that MFA is required.
    It then checks the MFA client id, kind (redirection), authority, scopes,
    redirect URI and post-logout URI, and closes the RCDP session.
    """
    service = "MFA_WEB_REDIRECTION"
    client_id = "ba6c5e63-1e29-4037-a209-a7a74724ad68"
    authority = "https://login.microsoftonline.com/2fe4e8d6-472e-4a31-9269-3c7fa639c9d1"
    scopes = ["User.Read.All", "Mail.Read"]
    redirect_uri = "https://demo.keytalkdemo.com/provisioncert?template-name=de1d87ed866da6f5a4dbace085ab8d09418d3b88a5fb6d4ae8c8a87664d5942b"
    # the post-logout URI is expected to equal the redirect URI
    post_log_out_uri = redirect_uri

    api = CertRetrievalApi()
    # handshake
    api.hello()
    api.handshake()
    # fetch auth requirements
    auth_requirements = api.get_service_auth_requirements(service)
    # Check MFA settings
    assert CertRetrievalApi.is_mfa_authentication(
        auth_requirements), "OTP/MFA authentication is expected for service {}".format(service)
    mfa = auth_requirements[conf.RCDPV2_RESPONSE_PARAM_NAME_MFA_SETTINGS]
    assert mfa[conf.RCDPV2_RESPONSE_PARAM_NAME_MFA_CLIENT_ID] == client_id
    assert mfa[conf.RCDPV2_RESPONSE_PARAM_NAME_MFA_KIND] == conf.MFA_KIND_REDIRECTION
    assert mfa[conf.RCDPV2_RESPONSE_PARAM_NAME_MFA_AUTHORITY] == authority

    for scope in scopes:
        assert scope in mfa[conf.RCDPV2_RESPONSE_PARAM_NAME_MFA_SCOPES], \
            "MFA scope {} is missing for service {}".format(scope, service)

    assert mfa[conf.RCDPV2_RESPONSE_PARAM_NAME_MFA_REDIRECT_URI] == redirect_uri
    assert mfa[conf.RCDPV2_RESPONSE_PARAM_NAME_MFA_POST_LOG_OUT_URI] == post_log_out_uri
    # close connection, like every other RCDP flow in this script
    api.eoc()


def fetch_ktagent_unattended_popup_msg_template():
    """Assert the Public API returns a truthy KTAgent unattended-popup message template."""
    template = PublicApi().fetch_ktagent_unattended_popup_msg_template()
    assert template


def fetch_keep_alive_interval():
    """Check the keep-alive interval reported by the Public API per service.

    - CUST_ANO_INTERNAL must report "3600s".
    - CUST_PASSWD_INTERNAL must report an empty string.
    """
    public_api = PublicApi()
    expectations = (
        ("CUST_ANO_INTERNAL", "3600s"),
        ("CUST_PASSWD_INTERNAL", ""),
    )
    for service, expected_interval in expectations:
        assert public_api.fetch_keep_alive_interval(service) == expected_interval


def send_keep_alive():
    """Send keep-alives through the Public API.

    - A keep-alive for DemoUser of CUST_ANO_INTERNAL must be accepted.
    - One for a non-existing user must be rejected with ``BadRequestError``.
    """
    public_api = PublicApi()
    service, host = "CUST_ANO_INTERNAL", "demouser.test.com"

    public_api.send_keep_alive(service, "DemoUser", host)

    try:
        public_api.send_keep_alive(service, "non-existing-user", host)
        assert False, "Sending keep-alive for non-existing user is not reported as bad request"
    except BadRequestError:
        pass


#
# Entry point
#
if __name__ == "__main__":
    # RCDP scenarios repeated for every supported certificate format
    per_format_scenarios = (
        request_cert_with_password_authentication,
        request_cert_with_password_and_pincode_authentication,
        request_cert_with_challenge_response_authentication,
        request_smb_certs_with_seat_authentication,
        renew_cert_with_seat_authentication,
        request_out_of_band_cert_with_password_authentication,
        request_out_of_band_smb_certs_with_seat_authentication,
        change_password_and_request_cert,
    )
    for cert_format in supported_cert_formats():
        for scenario in per_format_scenarios:
            scenario(cert_format)

    # Format-independent scenarios, run once each in this fixed order
    standalone_scenarios = (
        request_cert_from_csr_with_password_authentication,
        fetch_ca_certs,
        check_self_service_availability,
        retrieve_address_books,
        retrieve_server_version,
        check_smime_cert_enrollment_availability,
        should_cert_go_to_system_store,
        query_template_names_to_autorenew_seat_certs,
        enroll_and_revoke_cert,
        enroll_and_revoke_cert_from_csr,
        download_server_settings,
        list_templates,
        copy_template,
        create_digicert_dv_acme_template,
        retrieve_cert_expiry_margin,
        check_server_health,
        create_internal_ra_user,
        create_and_remove_seat,
        fetch_mfa_web_popups_settings,
        fetch_mfa_web_redirection_settings,
        fetch_ktagent_unattended_popup_msg_template,
        fetch_keep_alive_interval,
        send_keep_alive,
    )
    for scenario in standalone_scenarios:
        scenario()
