Skip to content
Snippets Groups Projects
remote.py 5.82 KiB
Newer Older
me5na7qbjqbrp's avatar
me5na7qbjqbrp committed
"""
Utils to run commands on remote server

Copyright (C) 2010-2020 Cr@ns <roots@crans.org>
Authors : Daniel Stan <daniel.stan@crans.org>
          Vincent Le Gallic <legallic@crans.org>
          Alexandre Iooss <erdnaxe@crans.org>
SPDX-License-Identifier: GPL-3.0-or-later
"""

me5na7qbjqbrp's avatar
me5na7qbjqbrp committed
import base64
import json
import logging
from functools import lru_cache
from getpass import getpass
from hashlib import sha1, sha256
from pathlib import Path
me5na7qbjqbrp's avatar
me5na7qbjqbrp committed

from dns import flags, resolver
from paramiko.client import MissingHostKeyPolicy, SSHClient
from paramiko.config import SSHConfig
from paramiko.ssh_exception import (
    AuthenticationException,
    PasswordRequiredException,
    SSHException,
)
me5na7qbjqbrp's avatar
me5na7qbjqbrp committed

from .locale import _

# Module-level logger
log = logging.getLogger(__name__)


# Prompt shown to the user when a host key cannot be trusted automatically.
# `dnsfp` is either empty or an extra line appended after the fingerprint.
AUTHENTICITY_MSG = """
The authenticity of host '{host}' can't be established.
{ktype} key fingerprint is SHA256:{kfp}.{dnsfp}
Are you sure you want to continue connecting (yes/no)?
"""

# SSH key type name -> algorithm number, as written in the first field of an
# SSHFP record. Every ECDSA curve shares the same number.
_key_algorithms = {
    "ssh-rsa": "1",
    "ssh-dss": "2",
    **dict.fromkeys(
        (
            "ecdsa-sha2-nistp256",
            "ecdsa-sha2-nistp384",
            "ecdsa-sha2-nistp521",
        ),
        "3",
    ),
    "ssh-ed25519": "4",
}

# SSHFP fingerprint type (second field of the record) -> hash constructor
_hash_funcs = {"1": sha1, "2": sha256}


class AskUserOrDNSPolicy(MissingHostKeyPolicy):
    """
    Policy for automatically trusting DNSSEC authenticated fingerprint
    or asking the user for unknown hostname & key. This is used by `.SSHClient`
    """

    @staticmethod
    def _lookup_sshfp(hostname, key):
        """
        Look for an SSHFP record of `hostname` matching `key`.

        :param hostname: name queried for SSHFP records
        :param key: host key object providing `get_name()` and `asbytes()`
        :return: `(found, dnssec)` tuple; `found` is True when a record
                 carries the fingerprint of `key`, `dnssec` is True when the
                 DNS answer has the AD flag set
        """
        kalg = _key_algorithms.get(key.get_name())
        kres = resolver.Resolver()
        # Enable EDNS with the DNSSEC OK (DO) flag and a 1280 bytes payload
        kres.use_edns(True, flags.DO, 1280)
        try:
            kans = kres.query(hostname, "SSHFP")
        except (
            resolver.NoAnswer,
            resolver.NXDOMAIN,
            resolver.NoNameservers,
            resolver.Timeout,
        ) as err:
            # No SSHFP information available: never a reason to abort the
            # connection, the caller falls back to asking the user.
            log.debug("No usable SSHFP answer for %s: %r", hostname, err)
            return False, False

        dnssec = bool(kans.response.flags & flags.AD)
        key_blob = key.asbytes()
        for rdata in kans:
            try:
                alg, fptype, fp = rdata.to_text().split()
            except ValueError:
                # Invalid SSHFP record format: skip it, otherwise the fields
                # of a previous record (or unbound names) would be compared.
                continue
            if alg != kalg:
                continue
            hash_func = _hash_funcs.get(fptype)
            if hash_func is None:
                continue
            if hash_func(key_blob).hexdigest() == fp:
                return True, dnssec
        return False, dnssec

    def missing_host_key(self, client, hostname, key):
        """
        Decide whether an unknown host key can be trusted.

        The key is accepted without prompting when a matching SSHFP record
        is found in a DNS answer flagged as authenticated. Otherwise the
        fingerprint is displayed and the user is asked to confirm.

        :param client: SSH client requesting the check (unused)
        :param hostname: name of the host being connected to
        :param key: host key presented by the server
        :raises SSHException: when the user refuses the key
        """
        # Unpadded base64 SHA256 fingerprint, inserted after "SHA256:" in the prompt
        key_digest = sha256(key.asbytes()).digest()
        kfp = base64.b64encode(key_digest).decode("utf-8").rstrip("=")
        ktype = key.get_name().upper().replace("SSH-", "")

        found, dnssec = self._lookup_sshfp(hostname, key)
        if found and dnssec:
            log.debug(
                "Authentic %s host key for %s found in DNS", key.get_name(), hostname
            )
            return

        # A match in an unauthenticated answer is only a hint for the user
        dnsfp = "\nMatching host key fingerprint found in DNS." if found else ""
        inp = input(
            AUTHENTICITY_MSG.format(host=hostname, ktype=ktype, kfp=kfp, dnsfp=dnsfp)
        )
        # NOTE(review): an empty answer (plain Enter) is treated as "yes"
        if inp.strip().lower() not in ("yes", "y", ""):
            log.debug("Rejecting %s host key for %s: %s", ktype, hostname, kfp)
            raise SSHException(
                "Connection to {!r} rejected by the user".format(hostname)
            )
        log.debug("Accepting %s host key for %s: %s", ktype, hostname, kfp)


me5na7qbjqbrp's avatar
me5na7qbjqbrp committed
@lru_cache()
def create_ssh_client(host, password=None):
    """
    Create a SSH client with paramiko module

    Results are memoised by `lru_cache`, so calling again with the same
    arguments returns the client that was already connected.

    :param host: name of the server to connect to
    :param password: optional password handed to paramiko; when paramiko
                     reports that one is required, it is asked interactively
    :return: connected `SSHClient`
    :raises SystemExit: with code 1 when authentication fails
    :raises SSHException: re-raised after logging for any other SSH error
    """
    client = SSHClient()
    client.set_missing_host_key_policy(AskUserOrDNSPolicy)

    # Use the username configured for this host in ~/.ssh/config, if any
    try:
        with Path.home().joinpath(".ssh/config").open() as config_file:
            config = SSHConfig()
            config.parse(config_file)
        username = config.lookup(host).get("user", None)
    except FileNotFoundError:
        username = None

    # Load known host keys from the system
    client.load_system_host_keys()
    try:
        client.connect(host, username=username, password=password)
    except PasswordRequiredException:
        # Must stay before AuthenticationException, which is its parent class
        password = getpass("SSH password: ")
        return create_ssh_client(host, password)
    except AuthenticationException:
        log.error(_("SSH authentication failed."))
        # Raise directly: the `exit` builtin is only injected by `site`
        raise SystemExit(1)
    except SSHException:
        log.error(_("An error occured during SSH connection, debug with -vv"))
        raise

    return client


def remote_command(options, command, arg=None, stdin_contents=None):
    """
    Execute remote command and return output

    :param options: object whose `serverdata` mapping provides the `host`
                    and `remote_cmd` entries of the active server
    :param command: sub-command appended to `remote_cmd`
    :param arg: optional argument appended after `command`
    :param stdin_contents: optional value serialised as JSON and written to
                           the standard input of the remote command
    :return: the JSON document printed by the server, decoded; `None` when
             a truthy `stdin_contents` was sent
    :raises SystemExit: with code 1 on incomplete configuration, with the
                        remote return code when it is not 0, with code 42
                        when the output is not valid JSON
    """
    serverdata = options.serverdata

    # Validate the whole configuration before opening any connection
    for param in ("host", "remote_cmd"):
        if param not in serverdata:
            log.error("Missing parameter `%s` in active server configuration", param)
            raise SystemExit(1)

    client = create_ssh_client(str(serverdata["host"]))

    # Build command
    remote_cmd = serverdata["remote_cmd"] + " " + command
    if arg:
        remote_cmd += " " + arg

    # Run command and timeout after 10s
    log.info(_("Running command `%s`") % remote_cmd)
    stdin, stdout, stderr = client.exec_command(remote_cmd, timeout=10)

    # Write
    if stdin_contents is not None:
        log.info(_("Writing to stdin: %s") % stdin_contents)
        stdin.write(json.dumps(stdin_contents))
        stdin.flush()

    # If the server is not expected to exit, then exit now
    if stdin_contents:
        return None

    # Return code == 0 if success
    # NOTE(review): paramiko documents that waiting for the exit status
    # before reading a large output may stall; confirm answers stay small.
    ret = stdout.channel.recv_exit_status()
    if ret != 0:
        err = ""
        if stderr.channel.recv_stderr_ready():
            err = stderr.read()
            if isinstance(err, bytes):
                # Log readable text instead of a bytes repr
                err = err.decode("utf-8", errors="replace")
        log.error(_("Wrong server return code %s, error is %s") % (ret, err))
        raise SystemExit(ret)

    # Decode directly read buffer
    try:
        answer = json.load(stdout)
    except ValueError:
        log.error(_("Error while parsing JSON"))
        raise SystemExit(42)

    log.debug(_("Server returned %s") % answer)
    return answer