"""
Utilities to run commands on a remote server over SSH

Copyright (C) 2010-2020 Cr@ns <roots@crans.org>
Authors : Daniel Stan <daniel.stan@crans.org>
          Vincent Le Gallic <legallic@crans.org>
          Alexandre Iooss <erdnaxe@crans.org>
SPDX-License-Identifier: GPL-3.0-or-later
"""

import base64
import json
import logging
from functools import lru_cache
from getpass import getpass
from hashlib import sha1, sha256
from pathlib import Path

from dns import flags, resolver
from paramiko.client import MissingHostKeyPolicy, SSHClient
from paramiko.config import SSHConfig
from paramiko.ssh_exception import (
    AuthenticationException,
    PasswordRequiredException,
    SSHException,
)

from .locale import _

# Local logger
log = logging.getLogger(__name__)


AUTHENTICITY_MSG = """
The authenticity of host '{host}' can't be established.
{ktype} key fingerprint is SHA256:{kfp}.{dnsfp}
Are you sure you want to continue connecting (yes/no)?
"""

_key_algorithms = {
    "ssh-rsa": "1",
    "ssh-dss": "2",
    "ecdsa-sha2-nistp256": "3",
    "ecdsa-sha2-nistp384": "3",
    "ecdsa-sha2-nistp521": "3",
    "ssh-ed25519": "4",
}

_hash_funcs = {
    "1": sha1,
    "2": sha256,
}


class AskUserOrDNSPolicy(MissingHostKeyPolicy):
    """
    Policy for automatically trusting DNSSEC authenticated fingerprint
    or asking the user for unknown hostname & key. This is used by `.SSHClient`

    A missing host key is accepted silently only when a matching SSHFP record
    is found in a DNS answer carrying the AD (authenticated data) flag.
    Otherwise the user is prompted. The prompt adds a hint when an SSHFP
    record matched but the answer was not DNSSEC-authenticated.
    """

    def missing_host_key(self, client, hostname, key):
        """
        Decide whether the unknown host key ``key`` of ``hostname`` is trusted.

        :param client: the `.SSHClient` being connected (unused)
        :param hostname: name of the host the client is connecting to
        :param key: host key presented by the server
        :raises SSHException: if the user rejects the key
        """
        # Local import: the file only imports ``flags`` and ``resolver``.
        # DNSException is the base class of every dnspython resolution error.
        from dns.exception import DNSException

        key_bytes = key.asbytes()
        # Unpadded base64 of the SHA-256 digest, as displayed to the user
        kfp = base64.b64encode(sha256(key_bytes).digest()).decode("utf-8")
        kfp = kfp.replace("=", "")
        ktype = key.get_name().upper().replace("SSH-", "")
        # SSHFP algorithm number for this key type (None when unsupported)
        kalg = _key_algorithms.get(key.get_name())

        kres = resolver.Resolver()
        # Set the DO bit so a validating resolver reports DNSSEC status (AD)
        kres.use_edns(True, flags.DO, 1280)
        dnssec = False
        found = False
        dnsfp = ""
        try:
            kans = kres.query(hostname, "SSHFP")
        except DNSException as exc:
            # No usable SSHFP answer (no record, unknown name, unreachable
            # nameserver, timeout...): fall back to asking the user.
            log.debug("No SSHFP record usable for {}: {}".format(hostname, exc))
        else:
            dnssec = bool(kans.response.flags & flags.AD)
            for rdata in kans:
                try:
                    alg, fptype, fp = rdata.to_text().split()
                except ValueError:  # Invalid SSHFP record format, skip it.
                    continue
                if alg != kalg or fptype not in _hash_funcs:
                    continue
                expected = _hash_funcs[fptype](key_bytes).hexdigest()
                if expected == fp.lower():
                    found = True
                    break

        if found and dnssec:
            log.debug(
                "Authentic {} host key for {} found in DNS".format(
                    key.get_name(), hostname
                )
            )
            return
        if found:
            dnsfp = "\nMatching host key fingerprint found in DNS."
        inp = input(
            AUTHENTICITY_MSG.format(host=hostname, ktype=ktype, kfp=kfp, dnsfp=dnsfp)
        )
        # NOTE(review): an empty answer is accepted as "yes", unlike OpenSSH;
        # kept for compatibility, confirm this is intended.
        if inp.strip().lower() not in ("yes", "y", ""):
            log.debug("Rejecting {} host key for {}: {}".format(ktype, hostname, kfp))
            raise SSHException(
                "Connection to {!r} rejected by the user".format(hostname)
            )
        # Log the same printable fingerprint that was shown to the user
        log.debug("Accepting {} host key for {}: {}".format(ktype, hostname, kfp))


@lru_cache()
def create_ssh_client(host, password=None):
    """
    Create a SSH client with paramiko module and connect it to ``host``.

    The username comes from the ``user`` entry of ``~/.ssh/config`` matching
    ``host``, if any. Results are memoized by ``lru_cache`` per
    ``(host, password)``, so repeated calls reuse the same connected client.

    :param host: host name (or ssh config alias) to connect to
    :param password: value passed to ``SSHClient.connect`` as ``password``;
        prompted for when paramiko raises ``PasswordRequiredException``
    :return: a connected `SSHClient`
    :raises SSHException: on SSH errors other than authentication failures
    """
    import sys

    # Create SSH client with system host keys and agent
    client = SSHClient()
    client.set_missing_host_key_policy(AskUserOrDNSPolicy())

    # Load config file (closing it afterwards) and use the right username
    username = None
    config_path = Path.home() / ".ssh" / "config"
    try:
        with config_path.open() as config_file:
            config = SSHConfig()
            config.parse(config_file)
    except FileNotFoundError:
        pass
    else:
        username = config.lookup(host).get("user", None)

    # Load system known host keys
    client.load_system_host_keys()

    # Connect; on any failure release the client before retrying or leaving
    try:
        client.connect(host, username=username, password=password)
    except PasswordRequiredException:
        client.close()
        password = getpass("SSH password: ")
        return create_ssh_client(host, password)
    except AuthenticationException:
        client.close()
        log.error(_("SSH authentication failed."))
        sys.exit(1)
    except SSHException:
        client.close()
        log.error(_("An error occured during SSH connection, debug with -vv"))
        raise

    return client


def remote_command(options, command, arg=None, stdin_contents=None):
    """
    Execute remote command and return output

    Runs ``<remote_cmd> <command> [arg]`` on the configured server over SSH.

    :param options: object whose ``serverdata`` mapping provides ``host`` and
        ``remote_cmd``
    :param command: sub-command appended to ``remote_cmd``
    :param arg: optional extra argument appended after ``command``
    :param stdin_contents: optional object serialized with ``json.dumps`` and
        written to the remote process's stdin
    :return: the JSON-decoded stdout of the remote command, or ``None`` when
        ``stdin_contents`` is truthy (the function returns right after writing)

    Exits the process (status 1) on missing configuration, with the remote
    return code when it is non-zero, and with status 42 on invalid JSON.
    """
    import sys

    serverdata = options.serverdata

    # Validate the configuration before connecting, so a broken config is
    # reported without first triggering host key or password prompts.
    if "host" not in serverdata:
        log.error("Missing parameter `host` in active server configuration")
        sys.exit(1)
    if "remote_cmd" not in serverdata:
        log.error("Missing parameter `remote_cmd` in active server configuration")
        sys.exit(1)

    client = create_ssh_client(str(serverdata["host"]))

    # Build command.
    # NOTE(review): the command is a plain string, presumably interpreted by
    # the remote shell; ``command``/``arg`` must not carry untrusted text.
    remote_cmd = serverdata["remote_cmd"] + " " + command
    if arg:
        remote_cmd += " " + arg

    # Run command; channel operations time out after 10s
    log.info(_("Running command `%s`") % remote_cmd)
    stdin, stdout, stderr = client.exec_command(remote_cmd, timeout=10)

    # Write
    if stdin_contents is not None:
        log.info(_("Writing to stdin: %s") % stdin_contents)
        stdin.write(json.dumps(stdin_contents))
        stdin.flush()

    # If the server is not expected to exit, then exit now
    if stdin_contents:
        return

    # Return code == 0 if success
    ret = stdout.channel.recv_exit_status()
    if ret != 0:
        err = ""
        if stderr.channel.recv_stderr_ready():
            raw_err = stderr.read()
            # Decode bytes so the log shows readable text, not a b'...' repr
            if isinstance(raw_err, bytes):
                err = raw_err.decode("utf-8", errors="replace")
            else:
                err = raw_err
        log.error(_("Wrong server return code %s, error is %s") % (ret, err))
        sys.exit(ret)

    # Decode directly read buffer
    try:
        answer = json.load(stdout)
    except ValueError:
        log.error(_("Error while parsing JSON"))
        sys.exit(42)

    log.debug(_("Server returned %s") % answer)
    return answer