# -*- coding:utf-8 -*-
import logging
import re
import hashlib
import time
import binascii
from tempfile import mkstemp
from os import unlink, system, sep, listdir, close, open as osopen, fdopen, O_WRONLY, O_CREAT
from os.path import basename, dirname, isdir, join, exists
from subprocess import getoutput, getstatusoutput
from pyeole.process import system_out
import subprocess

from creole import cert

from formencode import validators, foreach
from sqlalchemy.exc import IntegrityError
#from arv.config import id2sql_path
#from arv.config import bin2sql_path
from arv.config import ssl_dir
from arv.config import vpn_path
from arv.lib.logger import logger

# Make sure creole's default certificate configuration is loaded at import
# time; presumably required before the cert.* helpers/attributes used in this
# module (e.g. cert.ca_keyfile, cert.get_subject) are accessed -- confirm
# against creole.cert.
cert.load_default_conf_if_needed()

def trace(hide_args=None, hide_kwargs=None):
    """This is a decorator which can be used to trace functions calls

    When the module logger is enabled for DEBUG, every call of the decorated
    function is logged with its arguments. Some positional and/or keyword
    arguments (typically passwords) can be replaced with 'XXXXXXXX' in the
    log if they are present.

    @param hide_args: List of positional argument indexes to replace if present
    @type hide_args: C{list}
    @param hide_kwargs: List of keyword argument names to replace if present
    @type hide_kwargs: C{list}
    """
    import functools

    def tracedec(func):
        # functools.wraps keeps __name__, __doc__, __dict__ and also
        # __qualname__, __module__ and __wrapped__ of the decorated function.
        @functools.wraps(func)
        def newFunc(*args, **kwargs):
            # Do nothing if debug is not enabled
            if logger.isEnabledFor(logging.DEBUG):
                # Work on copies so the real call receives the real values
                args_list = list(args)
                args_dict = kwargs.copy()
                for index in hide_args or ():
                    if index < len(args_list):
                        args_list[index] = 'XXXXXXXX'
                for keyname in hide_kwargs or ():
                    if keyname in args_dict:
                        args_dict[keyname] = 'XXXXXXXX'
                logger.debug("-> entering %s(%s, %s)",
                             func.__name__, args_list, args_dict)
            return func(*args, **kwargs)

        return newFunc
    return tracedec

@trace()
def normalize_unicode(string):
    """Return *string* unchanged if it is a ``str``.

    @raise TypeError: when *string* is not a ``str`` (e.g. bytes)
    """
    if isinstance(string, str):
        return string
    raise TypeError("unsupported encoding")

@trace()
def valid(value, typ):
    """Validate and convert *value* with the formencode validator for *typ*.

    @param value: value to validate (bytes are decoded first)
    @param typ: one of 'string', 'bool', 'integer', 'ip' or 'enum'
    @return: the converted value; 'string' and 'ip' results are checked
        to be ``str``; 'enum' values are returned without validation
    @raise ValueError: when *typ* is not a supported type (previously an
        UnboundLocalError leaked out)
    """
    if isinstance(value, bytes):
        value = value.decode()
    if typ == 'enum':
        # FIXME: nothing is validated for enums yet
        # validator = foreach.ForEach()
        return value
    validator_classes = {
        'string': validators.String,
        'bool': validators.StringBoolean,
        'integer': validators.Int,
        'ip': validators.IPAddress,
    }
    if typ not in validator_classes:
        raise ValueError("unsupported validation type: {0}".format(typ))
    val = validator_classes[typ]().to_python(str(value))
    if typ in ('string', 'ip'):
        val = normalize_unicode(val)
    return val
# ____________________________________________________________
@trace()
def try_unique_column(func_name, function, **args):
    """Call ``function(**args)`` and turn database errors into readable ones.

    @param func_name: name used in error messages
    @param function: callable to run
    @raise Exception: "Name should be unique" when the 'name' column unique
        constraint is violated, "error in <func_name>: ..." otherwise; the
        original exception is chained as the cause.
    """
    try:
        return function(**args)
    except IntegrityError as e:
        message = str(e)
        # Old SQLite wording: "column name is not unique";
        # SQLite >= 3.8.2 wording: "UNIQUE constraint failed: <table>.name"
        if (" column name is not unique " in message
                or re.search(r'UNIQUE constraint failed:.*\.name\b', message)):
            raise Exception("Name should be unique") from e
        raise Exception("error in %s: %s" % (func_name, message)) from e
    except Exception as e:
        raise Exception("error in %s: %s" % (func_name, str(e))) from e

# ____________________________________________________________
#
@trace()
def get_scndline(lines):
    """Extract the hex payload from the second line of a two-line output.

    Typical input::

      type    encoding
      2,  X'736466736466'

    @return: the text between X' and the closing quote ('736466736466')
    @raise TypeError: when the value is not of the form X'...'
    """
    logger.debug("scndline : " + lines)
    value_line = lines.split('\n')[1]
    encoded_string = value_line.split(',')[1].strip()
    is_hex_literal = encoded_string.startswith("X'") and encoded_string.endswith("'")
    if not is_hex_literal:
        raise TypeError("unexpected encoded string: {0}".format(encoded_string))
    return encoded_string[2:-1]

@trace()
def get_lastline(lines):
    """Parse the subjectPublicKeyInfo / subjectPublicKey keyids from an output.

    Typical input::

      writing RSA key
      parsed 2048 bits RSA private key.
      subjectPublicKeyInfo keyid: b0:71:fb:0a:62:f7:8d:7b:9d:35:d7:c9:4d:12:f5:d5:51:e6:db:da
      subjectPublicKey keyid:     0e:5d:33:43:15:b0:f2:a0:0b:b6:6b:f3:24:44:6a:f5:08:91:da:0d

    @return: (subjkey, keyid) taken from the penultimate and last lines
    @raise IndexError: when fewer than two lines are present
    """
    logger.debug("lastline : " + lines)
    # Drop trailing blank lines (e.g. a final newline of a command output),
    # otherwise the empty last line would shift the parsing by one line.
    split_lines = lines.rstrip().split('\n')
    last_line = split_lines[-1]
    penultimate = split_lines[-2]
    subjkey = penultimate.replace('subjectPublicKeyInfo keyid:', '')
    # The keyid itself contains colons, so splitting on ':' is not possible;
    # strip the label instead.
    keyid = last_line.replace('subjectPublicKey keyid:', '')
    return (subjkey.strip(), keyid.strip())

@trace()
def bin_encoding(clear):
    """Hex-encode *clear* for the strongSwan database.

    WARNING: this does not replace the id2sql script. It only makes building
    the ARV database possible without id2sql; strongSwan database mode is not
    expected to work with this encoding.

    @param clear: str (UTF-8 encoded first) or bytes
    @return: the hexadecimal representation as bytes
    """
    if isinstance(clear, str):
        clear = clear.encode()
    result = binascii.hexlify(clear)
    logger.debug("#-> bin_encoding result var: " + result.decode())
    return result

@trace()
def suppress_colon(keyid):
    """Convert a colon-separated hex string (str or bytes) to raw bytes.

    Example: 'ab:cd' -> b'\\xab\\xcd'
    """
    text = keyid.decode() if isinstance(keyid, bytes) else keyid
    hex_digits = text.replace(':', '')
    return bytes.fromhex(hex_digits)

@trace(hide_args=[1], hide_kwargs=['passwd'])
def get_keyid_in_certif(certif_name, passwd=None, certiftype='ca'):
    """Return the (subjkey, keyid) of a certificate from its private key.

    @param certif_name: full certificate path
        (ex: /var/lib/arv/CA/certs/CaCert.pem)
    @param passwd: private key password
    @param certiftype: 'ca' uses cert.ca_keyfile as the private key; any
        other value uses <dir of certs dir>/private/priv-<certificate name>
    """
    is_ca_certif = certiftype == 'ca'
    if is_ca_certif:
        priv_certname = cert.ca_keyfile
    else:
        # e.g. <root>/certs/foo.pem -> <root>/private/priv-foo.pem
        ssl_root = dirname(dirname(certif_name))
        priv_certname = join(ssl_root, 'private', 'priv-' + basename(certif_name))
    return get_keyid_from_keyid_in_certif(priv_certname, passwd,
                                          mode='rsa', ca=is_ca_certif)

@trace(hide_args=[1], hide_kwargs=['passwd'])
def get_keyid_from_keyid_in_certif(certif_name, passwd=None, mode='rsa', ca=False):
    """Construct a (subjkey, keyid) pair when no keyid is present in the certificate.

    @param certif_name: path of an RSA private key ('rsa' mode) or of a
        certificate ('x509' mode)
    @param passwd: private key password ('rsa' mode only)
    @param mode: 'rsa' or 'x509'
    @param ca: in 'x509' mode, read the subjectKeyIdentifier extension when
        True, the authorityKeyIdentifier extension otherwise
    @return: (subjkey, keyid) tuple
    @raise Exception: invalid password or failing openssl command
    @raise ValueError: unsupported mode
    """
    if mode == 'rsa':
        if not password_OK(certif_name, passwd):
            raise Exception('Invalid password')
        # Extract from private key: SHA-1 (colon separated) of the DER
        # SubjectPublicKeyInfo (-pubout) and of the DER RSAPublicKey
        # (-RSAPublicKey_out).
        sha1_cmd = ['openssl', 'sha1', '-c']
        rsa_cmd = ['openssl', 'rsa', '-in', certif_name,
                   '-passin', 'stdin', '-outform', 'DER']

        def _fingerprint(extra_args):
            # DER-encode the public key then hash it; fail loudly on errors.
            code, der, stderr = system_out(rsa_cmd + extra_args, stdin=passwd, to_str=False)
            if code != 0:
                raise Exception('openssl rsa {0} failed: {1}'.format(' '.join(extra_args), stderr))
            code, fpr, stderr = system_out(sha1_cmd, der)
            if code != 0:
                raise Exception('openssl sha1 failed: {0}'.format(stderr))
            return fpr.split()[-1].strip()

        subjkey = _fingerprint(['-pubout'])
        keyid = _fingerprint(['-RSAPublicKey_out'])
    elif mode == 'x509':
        # Extract from credential
        ext = 'subjectKeyIdentifier' if ca else 'authorityKeyIdentifier'
        # Argument list instead of a shell string: the path is never parsed by a shell.
        cmd = ['openssl', mode, '-in', certif_name, '-noout', '-ext', ext]
        process = subprocess.Popen(cmd, stdin=subprocess.DEVNULL, stdout=subprocess.PIPE,
                                   stderr=subprocess.PIPE, shell=False)
        stdout, stderr = process.communicate()
        if process.returncode != 0:
            raise Exception('openssl x509 -ext {0} failed: {1}'.format(
                ext, stderr.decode(errors='replace')))
        # Same as the former "| tail -n 1": the identifier is on the last line.
        output_lines = stdout.decode(errors='replace').splitlines()
        keyid = output_lines[-1].strip() if output_lines else ''
        logger.debug('{0}'.format(keyid))
        subjkey = keyid
    else:
        raise ValueError('unsupported keyid mode: {0}'.format(mode))
    return (subjkey, keyid)

@trace(hide_args=[1], hide_kwargs=['passwd'])
def get_keyid_from_certifstring(certif_string, passwd=None, mode='rsa', ca=False):
    """Construct a keyid from key/certificate content when none is in the certificate.

    The content (str or bytes) is written to a private temporary file, which
    is always removed, and handed to get_keyid_from_keyid_in_certif.

    @return: (subjkey, keyid) tuple
    @raise Exception: 'Cannot generate keyid in certificate : ...' on any error
    """
    try:
        fd, certif_name = mkstemp()
        try:
            if isinstance(certif_string, str):
                certif_string = certif_string.encode()
            # Reuse the mkstemp descriptor instead of opening the path twice
            with fdopen(fd, 'wb') as fh:
                fh.write(certif_string)
            subjkey, keyid = get_keyid_from_keyid_in_certif(certif_name, passwd, mode=mode, ca=ca)
        finally:
            unlink(certif_name)
        return (subjkey, keyid)
    except Exception as e:
        msg = 'Cannot generate keyid in certificate : %s' % str(e)
        logger.warning(msg)
        raise Exception(msg) from e

@trace(hide_args=[1], hide_kwargs=['password'])
def password_OK(private_key, password):
    """Test private_key password validity

    @param private_key: path of the RSA private key file
    @param password: password (str or bytes); falsy means "no password"
    @return: True when 'openssl rsa' can load the key, False otherwise
    """
    if isinstance(password, str):
        password = password.encode()
    if password:
        cmd = ["openssl", "rsa", "-in", private_key, "-passin", "stdin"]
        # Never write the secret itself to the logs.
        logger.debug("PASSWORD : provided (hidden)")
    else:
        cmd = ["openssl", "rsa", "-in", private_key]
        logger.debug("PAS DE PASSWORD")
    process = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE, shell=False)
    # communicate() always: it closes stdin and drains stdout/stderr, so the
    # child can never block on a full pipe.
    process.communicate(input=password if password else None)
    return process.returncode == 0

@trace(hide_args=[1], hide_kwargs=['passwd'])
def decrypt_privkey(privkey_string, passwd):
    """Suppress password from private key

    @param privkey_string: encrypted private key content (str or bytes)
    @param passwd: its password (str or bytes)
    @return: the decrypted key as printed by 'openssl rsa' (bytes)
    @raise Exception: when openssl fails (e.g. wrong password)
    """
    if isinstance(privkey_string, str):
        privkey_string = privkey_string.encode()
    if isinstance(passwd, str):
        passwd = passwd.encode()
    fd, privkey_filename = mkstemp()
    try:
        with fdopen(fd, 'wb') as fh:
            fh.write(privkey_string)
        cmd = ["openssl", "rsa", "-in", privkey_filename, "-passin", "stdin"]
        process = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE,
                                   stderr=subprocess.PIPE, shell=False)
        openssl_output = process.communicate(input=passwd)[0]
    finally:
        # The temporary key file is removed whatever happens
        unlink(privkey_filename)
    if process.returncode:
        raise Exception('Unable to decrypt private key, check password')
    return openssl_output

def get_req_archive(name):
    """Download credential request to send to CA

    Moves <ssl_dir>/req/<name>.p10 and <ssl_dir>/private/priv-<name>.pem
    into ssl_dir, packs them into <name>.tgz, removes the moved files and the
    archive, and returns the archive content.

    @return: the .tgz archive content (bytes)
    """
    import shlex

    reqname = name + ".p10"
    privkeyname = "priv-" + name + ".pem"
    reqfilename = join(ssl_dir, "req", reqname)
    privkeyfilename = join(ssl_dir, "private", privkeyname)
    # Quote every interpolated value so that names containing spaces or shell
    # metacharacters can neither break nor hijack the script.
    quote = shlex.quote
    system("""cd {0}
mv {1} {0}
mv {2} {0}
tar -czf {3}.tgz {4} {5}
rm {4} {5}
""".format(quote(ssl_dir), quote(reqfilename), quote(privkeyfilename),
           quote(name), quote(reqname), quote(privkeyname)))
    tarfname = join(ssl_dir, name + ".tgz")
    with open(tarfname, 'rb') as fh:
        content = fh.read()
    unlink(tarfname)
    return content

@trace()
def gen_archive_name(uai, name):
    """Return (amonpath, archivename) for a VPN archive.

    amonpath is vpn_path/<uai>; archivename is '<uai>-<name>.tar.gz'
    (bytes names are decoded first).
    """
    uai_str = str(uai)
    if isinstance(name, bytes):
        name = name.decode()
    return vpn_path + sep + uai_str, uai_str + "-" + name + ".tar.gz"

@trace()
def der_to_pem(certificate):
    """Convert a certificate (x509 or pkcs7, DER or PEM) to PEM format.

    Tries x509/DER, x509/PEM, pkcs7/DER then pkcs7/PEM and returns the first
    conversion openssl accepts.

    @return: tuple (cert_type, pem_cert)
    @raise Exception: when no combination succeeds
    """
    if isinstance(certificate, str):
        certificate = certificate.encode()
    attempts = [(cert_type, cert_format)
                for cert_type in ("x509", "pkcs7")
                for cert_format in ("DER", "PEM")]
    for cert_type, cert_format in attempts:
        cmd = ["openssl", cert_type, "-inform", cert_format]
        with subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE, shell=False) as process:
            output, err = process.communicate(input=certificate)
        if process.returncode == 0:
            return cert_type, output
    raise Exception(f'unable to convert certificat in PEM format: {output}, {err}')

@trace()
def is_ca(certificate):
    """Tell whether a PEM certificate is a CA certificate.

    @param certificate: PEM certificate (str or bytes)
    @return: True if 'openssl x509 -text' output contains "CA:TRUE"
    @raise Exception: when openssl cannot parse the certificate
    """
    if isinstance(certificate, str):
        certificate = certificate.encode()
    cmd = ["openssl", "x509", "-noout", "-text"]
    with subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE,
                          stderr=subprocess.PIPE, shell=False) as process:
        decoded_certificate, err = process.communicate(input=certificate)
    if process.returncode != 0:
        raise Exception(f"Error in execution 'openssl x509' command: {err}")
    # Plain substring test; equivalent to the former re.match(b'.*CA:TRUE.*', re.S)
    return b'CA:TRUE' in decoded_certificate

class CertChain:
    """Index certificates by subject to rebuild issuer -> subject chains."""

    def __init__(self):
        # subject -> issuer subject
        self.signed_by = {}
        # subject -> certificate, as given to add_cert
        self.certs = {}

    def add_cert(self, certif):
        """Register *certif*; a subject that is already known is ignored."""
        cert_subject = cert.get_subject(certif)[1]
        if cert_subject not in self.certs:
            self.certs[cert_subject] = certif
            self.signed_by[cert_subject] = cert.get_issuer_subject(certif)[1]

    def add_certs(self, certifs):
        """Register every certificate of *certifs*."""
        for certif in certifs:
            self.add_cert(certif)

    def get_chain_by_subject(self, leaf_certif_subject, subject_only=False):
        """Return the chain ending at *leaf_certif_subject*, root first.

        The walk up the issuers stops at a self-signed certificate, at an
        issuer that was not registered, or when an issuer is already in the
        chain (cycle protection: this used to loop forever).

        @param subject_only: return subjects instead of certificates
        @raise Exception: 'Unknown certificate' when the leaf is not registered
        """
        if leaf_certif_subject not in self.certs:
            raise Exception('Unknown certificate')
        chain = [leaf_certif_subject]
        subject = leaf_certif_subject
        while True:
            issuer = self.signed_by.get(subject, None)
            if not issuer or issuer == subject or issuer in chain:
                break
            if not self.certs.get(issuer, None):
                break
            chain.append(issuer)
            subject = issuer
        chain.reverse()
        if subject_only:
            return chain
        return [self.certs[chain_subject] for chain_subject in chain]

    def get_chain_by_b64(self, leaf_certif, subject_only=False):
        """Same as get_chain_by_subject, starting from a certificate string."""
        # Fixed: used to reference an undefined name 'certif' (NameError)
        subject = cert.get_subject(leaf_certif)[1]
        return self.get_chain_by_subject(subject, subject_only=subject_only)

    def get_leaf_certificate(self):
        """Return the subject of the only certificate that is nobody's issuer.

        @raise Exception: when zero or several such certificates exist
        """
        issuers = set(self.signed_by.values())
        leaf_certificates = [c for c in self.certs if c not in issuers]
        if len(leaf_certificates) != 1:
            raise Exception('Ambiguous or gapped certificate chain')
        return leaf_certificates[0]


@trace()
def split_pkcs7(pkcs7_cred):
    """Split a PKCS7 bundle into CA certificates and end-user certificate.

    @param pkcs7_cred: PKCS7 content (str or bytes); treated as DER unless it
        starts with '-----BEGIN PKCS7-----'
    @return: (ca_chain, machine_certificate) where
         * ca_chain is a list of PEM CA certificates sorted from the root CA
           to the intermediate CAs
         * machine_certificate is the end-user PEM certificate, or None when
           the last certificate of the chain is itself a CA
    @raise Exception: on any decoding/parsing error (original error chained)
    """
    try:
        cmd = ["openssl", "pkcs7", "-print_certs"]
        if isinstance(pkcs7_cred, str):
            pkcs7_cred = pkcs7_cred.encode()
        if not pkcs7_cred.startswith(b'-----BEGIN PKCS7-----'):
            cmd.extend(['-inform', 'der'])
        with subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE, shell=False) as process:
            creds_chain, err = process.communicate(input=pkcs7_cred)
            if process.returncode != 0:
                raise Exception(f"Error in execution 'openssl pkcs7' command: {err}")
        # Collect every PEM certificate printed by openssl
        pem_cert_re = re.compile(r'^(-----BEGIN CERTIFICATE-----.*?-----END CERTIFICATE-----)', re.M | re.S)
        cert_list = pem_cert_re.findall(creds_chain.decode())
        cert_chain = CertChain()
        cert_chain.add_certs(cert_list)
        chain = cert_chain.get_chain_by_subject(cert_chain.get_leaf_certificate())
        if is_ca(chain[-1]):
            ca_chain = chain
            machine_certificate = None
        else:
            ca_chain = chain[:-1]
            machine_certificate = chain[-1]
        return ca_chain, machine_certificate
    except UnicodeDecodeError as e:
        raise Exception("Problème de lecture du fichier : %s" % e) from e
    except Exception as e:
        logger.error('Cannot read pkcs7: {0}'.format(e))
        raise Exception('Cannot read pkcs7: {0}'.format(e)) from e

@trace(hide_args=[1, 2], hide_kwargs=['pkcs12_password', 'key_passphrase'])
def split_pkcs12(pkcs12_certificate, pkcs12_password, key_passphrase):
    """Return the encrypted private key and a PKCS7 bundle from a PKCS12 blob.

    The two passwords are handed to openssl through root-only (0400) files
    which are removed as soon as 'openssl pkcs12' has finished, whether it
    succeeded or not.

    @param pkcs12_certificate: PKCS12 content (bytes)
    @param pkcs12_password: password of the PKCS12 file (str or bytes)
    @param key_passphrase: passphrase used to re-encrypt the private key
    @return: (private_key, pkcs7_certs): the PEM encrypted private key (str)
        and the PKCS7 certificate bundle printed by 'openssl crl2pkcs7' (bytes)
    @raise Exception: 'Cannot read pkcs12: ...' on any error
    """
    pkcs12_pwd_file = "/root/p12_pwd"
    key_pwd_file = "/root/key_pwd"
    pwd_files = (pkcs12_pwd_file, key_pwd_file)
    tmp_certs_chain_file = None
    try:
        try:
            for pwd_file, secret in zip(pwd_files, (pkcs12_password, key_passphrase)):
                if exists(pwd_file):
                    unlink(pwd_file)
                if isinstance(secret, bytes):
                    secret = secret.decode()
                with fdopen(osopen(pwd_file, O_WRONLY | O_CREAT, 0o400), 'w') as fh:
                    fh.write(secret)
            cmd = ['openssl', 'pkcs12',
                   '-passin', 'file:' + pkcs12_pwd_file,
                   '-passout', 'file:' + key_pwd_file]
            process = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE,
                                       stderr=subprocess.PIPE, shell=False)
            # communicate() flushes and closes stdin; a bare stdin.write()
            # followed by wait() could leave openssl waiting forever.
            pkcs12_chain = process.communicate(input=pkcs12_certificate)[0]
            if process.returncode != 0:
                raise Exception("Error in execution 'openssl pkcs12' command")
        finally:
            # Never leave the passwords on disk
            for pwd_file in pwd_files:
                if exists(pwd_file):
                    unlink(pwd_file)
        if isinstance(pkcs12_chain, bytes):
            pkcs12_chain = pkcs12_chain.decode()
        logger.debug('PKCS12 command result' + pkcs12_chain)
        # isolate private key from pkcs12_chain
        privkey_pattern = re.compile('-----BEGIN ENCRYPTED PRIVATE KEY-----' '.*?' '-----END ENCRYPTED PRIVATE KEY-----', re.DOTALL)
        private_keys = privkey_pattern.findall(pkcs12_chain)
        if not private_keys:
            raise Exception('no encrypted private key found in PKCS12 content')
        private_key = private_keys[0]
        # isolate certificates from pkcs12_chain
        certs_pattern = re.compile('-----BEGIN CERTIFICATE-----' '.*?' '-----END CERTIFICATE-----', re.DOTALL)
        certs_chain = '\n'.join(certs_pattern.findall(pkcs12_chain))
        # generate pkcs7 certificate
        fd, tmp_certs_chain_file = mkstemp()
        with fdopen(fd, 'w') as fh:
            fh.write(certs_chain)
        cmd = ['openssl', 'crl2pkcs7', '-nocrl', '-certfile', tmp_certs_chain_file]
        process = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE,
                                   stderr=subprocess.PIPE, shell=False)
        # Read the output while the process runs (wait() first may deadlock
        # on a full pipe with a long chain)
        pkcs7_certs = process.communicate()[0]
        if process.returncode != 0:
            raise Exception("Error in execution 'openssl crl2pkcs7' command")
        return private_key, pkcs7_certs
    except Exception as e:
        raise Exception('Cannot read pkcs12: {0}'.format(e)) from e
    finally:
        if tmp_certs_chain_file is not None and exists(tmp_certs_chain_file):
            unlink(tmp_certs_chain_file)

@trace()
def extract_crls_from_certifstring(certifstring):
    """Extract the CRL distribution point URIs from a PEM certificate.

    The former implementation piped the text output through 'grep "URI"',
    which removed the "CRL Distribution Points:" header it was looking for,
    so it always returned []. It also pasted the certificate into a shell
    command line.

    @param certifstring: PEM certificate (str or bytes)
    @return: list of URI strings (possibly empty); None when an unexpected
        error occurs (a warning is logged, best effort)
    """
    try:
        if isinstance(certifstring, str):
            certifstring = certifstring.encode()
        cmd = ["openssl", "x509", "-text", "-noout"]
        process = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE,
                                   stderr=subprocess.PIPE, shell=False)
        output = process.communicate(input=certifstring)[0].decode(errors='replace')
        crl = []
        header_indent = None
        for line in output.split('\n'):
            if header_indent is None:
                if "CRL Distribution Points:" in line:
                    header_indent = len(line) - len(line.lstrip())
                continue
            stripped = line.strip()
            if not stripped:
                continue
            # A line indented no deeper than the header starts the next
            # extension/section: the distribution points are over.
            if len(line) - len(line.lstrip()) <= header_indent:
                break
            if 'URI:' in stripped:
                crl.append(stripped.split('URI:', 1)[1])
        return crl
    except Exception:
        logger.warning("No crl found in credential")
        return None

def extract_CN_from_certifstring(certifstring):
    """Return element [1] of cert.get_subject(certifstring).

    Presumably the CommonName, as the function name says -- confirm against
    creole.cert.get_subject.
    """
    subject_info = cert.get_subject(certifstring)
    return subject_info[1]

@trace()
def ipsec_running():
    """Tell whether 'ipsec status' succeeds.

    @return: True when the command exits with status 0, False otherwise
    """
    status, _ = getstatusoutput('ipsec status > /dev/null')
    return status == 0

@trace()
def ipsec_stop():
    """Stop the strongswan-starter service.

    @return: exit status of the 'service' command (0 on success)
    """
    status, _ = getstatusoutput("service strongswan-starter stop > /dev/null")
    return status

@trace()
def ipsec_start():
    """Start the strongswan-starter service.

    @return: exit status of the 'service' command (0 on success)
    """
    status, _ = getstatusoutput("service strongswan-starter start > /dev/null")
    return status

@trace()
def ipsec_restart():
    """Stop then start strongswan-starter.

    @return: the stop status if stopping failed, otherwise the start status
    """
    stop_status = ipsec_stop()
    if stop_status == 0:
        return ipsec_start()
    return stop_status

@trace()
def ipsec_down(connstring):
    """Bring down the IPsec connection named *connstring*.

    The former command used the literal text "connstring" instead of the
    parameter value.

    @return: exit status of 'ipsec down' (0 on success); callers that
        ignored the previous None return value are unaffected
    """
    import shlex

    if isinstance(connstring, bytes):
        connstring = connstring.decode()
    cmd = 'ipsec down {0} >/dev/null'.format(shlex.quote(connstring))
    errcode, _ = getstatusoutput(cmd)
    return errcode

@trace()
def ipsec_up(connstring):
    """Bring up the IPsec connection named *connstring*.

    The former command used the literal text "connstring" instead of the
    parameter value.

    @return: exit status of 'ipsec up' (0 on success); callers that ignored
        the previous None return value are unaffected
    """
    import shlex

    if isinstance(connstring, bytes):
        connstring = connstring.decode()
    cmd = 'ipsec up {0} >/dev/null'.format(shlex.quote(connstring))
    errcode, _ = getstatusoutput(cmd)
    return errcode

@trace()
def purge_file(filename):
    """Truncate *filename* to zero length, creating it if it does not exist."""
    with open(filename, "w") as handle:
        handle.write('')

@trace()
def fill_file(from_filename, to_filename):
    """
    Fill content of from_filename into to_filename without suppress to_filename

    to_filename is truncated and rewritten in place, not deleted and
    recreated. The bytes are copied verbatim. Decoding as text used to raise
    on non UTF-8 content after the destination had already been emptied.
    """
    with open(from_filename, "rb") as from_fd:
        content = from_fd.read()
    # "wb" truncates the existing file itself, so no separate purge is needed
    with open(to_filename, "wb") as to_fd:
        to_fd.write(content)

@trace()
def md5(filename, sw_database_mode='True'):
    """Compute md5 hash of the specified file

    @param filename: file to hash
    @param sw_database_mode: the string 'True' means *filename* is an SQLite
        database and the hash is computed on its '.dump' text output; any
        other value hashes the raw file bytes
    @return: hex digest, or None when the SQLite dump fails
    """
    if sw_database_mode == 'True':
        logger.debug('database mode True')
        cmd = '/usr/bin/sqlite3 "%s" ".dump"' % filename
        errcode, content = getstatusoutput(cmd)
        if errcode != 0:
            logger.debug("Unable to open the file in readmode:" + filename)
            return None
        # getstatusoutput returns text; hashlib needs bytes (this used to
        # raise TypeError)
        content = content.encode()
    else:
        logger.debug('database mode False')
        with open(filename, "rb") as fd:
            content = fd.read()
    return hashlib.md5(content).hexdigest()

@trace(hide_args=[2], hide_kwargs=['passwd'])
def valid_priv_and_cred(private_key, credential, passwd):
    """Test private key and credential compatibility

    Compares the RSA modulus of the private key with the one of the x509
    credential.

    @param private_key: private key content (str or bytes)
    @param credential: certificate content (str or bytes)
    @param passwd: private key password (str or bytes)
    @return: True when both moduli are equal, False otherwise (including
        when openssl fails on either input)
    """
    def _write_tmp(data):
        # Write data to a fresh private temp file and return its path.
        if isinstance(data, str):
            data = data.encode()
        fd, path = mkstemp()
        try:
            with fdopen(fd, 'wb') as fh:
                fh.write(data)
        except BaseException:
            unlink(path)
            raise
        return path

    if isinstance(passwd, str):
        passwd = passwd.encode()
    privkey_filename = _write_tmp(private_key)
    try:
        cmd = ["openssl", "rsa", "-in", privkey_filename, "-noout", "-modulus", "-passin", "stdin"]
        process = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE,
                                   stderr=subprocess.PIPE, shell=False)
        privkey_modulus = process.communicate(input=passwd)[0]
    finally:
        unlink(privkey_filename)
    if process.returncode != 0:
        return False

    certif_filename = _write_tmp(credential)
    try:
        cmd = ["openssl", "x509", "-in", certif_filename, "-noout", "-modulus"]
        process = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE,
                                   stderr=subprocess.PIPE, shell=False)
        cred_modulus = process.communicate()[0]
    finally:
        unlink(certif_filename)
    if process.returncode != 0:
        return False
    return privkey_modulus == cred_modulus

@trace()
def cred_end_validity_date(credential):
    """Return credential end validity date

    Converts the openssl "Mar 22 09:32:39 2015 GMT" format to the
    "2015/03/22" (%Y/%m/%d) format.

    @param credential: PEM certificate (str or bytes)
    @return: the formatted date; the raw "notAfter=..." line when it cannot
        be parsed; None when openssl prints no notAfter line
    """
    if isinstance(credential, str):
        credential = credential.encode()
    # Feed the certificate through stdin instead of an "echo" shell command
    cmd = ["openssl", "x509", "-noout", "-dates"]
    process = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE, shell=False)
    output = process.communicate(input=credential)[0].decode(errors='replace')
    not_after_lines = [line for line in output.splitlines() if line.startswith("notAfter=")]
    if not not_after_lines:
        return None
    not_after_date = not_after_lines[0]
    try:
        exp_date = not_after_date.split("=")[1]
        conv = time.strptime(exp_date, "%b %d %H:%M:%S %Y GMT")
        return time.strftime("%Y/%m/%d", conv)
    except (IndexError, ValueError):
        return not_after_date

def escape_special_characters(text, characters='\\\'"'):
    """Prefix each occurrence of the given characters with a backslash.

    Characters are processed in the order given; with the default
    (backslash first) already-inserted backslashes are not escaped again.
    Bytes input is decoded first.
    """
    escaped = text.decode() if isinstance(text, bytes) else text
    for special in characters:
        escaped = escaped.replace(special, '\\' + special)
    return escaped

def clean_directory(d):
    """Delete every file below directory *d*, keeping the directory tree."""
    for path in (join(d, entry) for entry in listdir(d)):
        if not isdir(path):
            unlink(path)
            continue
        clean_directory(path)

def remove_special_characters(text, characters='\\\'"'):
    """Remove special characters

    Returns *text* with every character listed in *characters* removed.
    Bytes input is decoded first; previously it raised TypeError.
    """
    if isinstance(text, bytes):
        text = text.decode()
    for character in characters:
        text = text.replace(character, '')
    return text

