Source code for ducktape.cluster.remoteaccount

# Copyright 2014 Confluent Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
import logging
import os
from paramiko import SSHClient, SSHConfig, MissingHostKeyPolicy
from paramiko.ssh_exception import SSHException, NoValidConnectionsError
import shutil
import signal
import socket
import stat
import tempfile
import warnings

from ducktape.utils.http_utils import HttpMixin
from ducktape.utils.util import wait_until
from ducktape.errors import DucktapeError


def check_ssh(method):
    def wrapper(self, *args, **kwargs):
        try:
            return method(self, *args, **kwargs)
        except (SSHException, NoValidConnectionsError, socket.error) as e:
            if self._custom_ssh_exception_checks:
                self._log(logging.DEBUG, "caught ssh error", exc_info=True)
                self._log(logging.DEBUG, "starting ssh checks:")
                self._log(logging.DEBUG, "\n".join(repr(f) for f in self._custom_ssh_exception_checks))
                for func in self._custom_ssh_exception_checks:
                    func(e, self)
            raise
    return wrapper


class RemoteAccountSSHConfig(object):
    def __init__(self, host=None, hostname=None, user=None, port=None, password=None, identityfile=None,
                 connecttimeout=None, **kwargs):
        """Wrapper for ssh configs used by ducktape to connect to remote machines.

        The fields in this class are lowercase versions of a small selection of ssh config properties
        (see man page: "man ssh_config")
        """
        self.host = host
        self.hostname = hostname or 'localhost'
        self.user = user
        self.port = port or 22
        self.port = int(self.port)
        self.password = password
        self.identityfile = identityfile
        # None is default, and it means default TCP timeout will be used.
        self.connecttimeout = int(connecttimeout) if connecttimeout is not None else None

    @staticmethod
    def from_string(config_str):
        """Construct RemoteAccountSSHConfig object from a string that looks like

        Host the-host
            Hostname the-hostname
            Port 22
            User ubuntu
            IdentityFile /path/to/key
        """
        config = SSHConfig()
        config.parse(config_str.split("\n"))

        hostnames = config.get_hostnames()
        if '*' in hostnames:
            hostnames.remove('*')
        assert len(hostnames) == 1, "Expected hostnames to have single entry: %s" % hostnames
        host = hostnames.pop()

        config_dict = config.lookup(host)
        if config_dict.get("identityfile") is not None:
            # paramiko.SSHConfig parses this in as a list, but we only want a single string
            config_dict["identityfile"] = config_dict["identityfile"][0]

        return RemoteAccountSSHConfig(host, **config_dict)

    def to_json(self):
        return self.__dict__

    def __repr__(self):
        return str(self.to_json())

    def __eq__(self, other):
        return other and other.__dict__ == self.__dict__

    def __hash__(self):
        return hash(tuple(sorted(self.__dict__.items())))


class RemoteAccountError(DucktapeError):
    """This exception is raised when an attempted action on a remote node fails.
    """

    def __init__(self, account, msg):
        self.account_str = str(account)
        self.msg = msg

    def __str__(self):
        return "%s: %s" % (self.account_str, self.msg)


class RemoteCommandError(RemoteAccountError):
    """This exception is raised when a process run by ssh*() returns a non-zero exit status.
    """

    def __init__(self, account, cmd, exit_status, msg):
        self.account_str = str(account)
        self.exit_status = exit_status
        self.cmd = cmd
        self.msg = msg

    def __str__(self):
        msg = "%s: Command '%s' returned non-zero exit status %d." % (self.account_str, self.cmd, self.exit_status)
        if self.msg:
            msg += " Remote error message: %s" % self.msg
        return msg


[docs]class RemoteAccount(HttpMixin): """RemoteAccount is the heart of interaction with cluster nodes, and every allocated cluster node has a reference to an instance of RemoteAccount. It wraps metadata such as ssh configs, and provides methods for file system manipulation and shell commands. Each operating system has its own RemoteAccount implementation. """
[docs] def __init__(self, ssh_config, externally_routable_ip=None, logger=None, ssh_exception_checks=[]): # Instance of RemoteAccountSSHConfig - use this instead of a dict, because we need the entire object to # be hashable self.ssh_config = ssh_config # We don't want to rely on the hostname (e.g. 'worker1') having been added to the driver host's /etc/hosts file. # But that means we need to distinguish between the hostname and the value of hostname we use for SSH commands. # We try to satisfy all use cases and keep things simple by # a) storing the hostname the user probably expects (the "Host" value in .ssh/config) # b) saving the real value we use for running the SSH command self.hostname = ssh_config.host self.ssh_hostname = ssh_config.hostname self.user = ssh_config.user self.externally_routable_ip = externally_routable_ip self._logger = logger self.os = None self._ssh_client = None self._sftp_client = None self._custom_ssh_exception_checks = ssh_exception_checks
@property def operating_system(self): return self.os @property def logger(self): if self._logger: return self._logger else: return logging.getLogger(__name__) @logger.setter def logger(self, logger): self._logger = logger def _log(self, level, msg, *args, **kwargs): msg = "%s: %s" % (str(self), msg) self.logger.log(level, msg, *args, **kwargs) @check_ssh def _set_ssh_client(self): client = SSHClient() client.set_missing_host_key_policy(IgnoreMissingHostKeyPolicy()) self._log(logging.DEBUG, "ssh_config: %s" % str(self.ssh_config)) client.connect( hostname=self.ssh_config.hostname, port=self.ssh_config.port, username=self.ssh_config.user, password=self.ssh_config.password, key_filename=self.ssh_config.identityfile, look_for_keys=False, timeout=self.ssh_config.connecttimeout) if self._ssh_client: self._ssh_client.close() self._ssh_client = client self._set_sftp_client() @property def ssh_client(self): if (self._ssh_client and self._ssh_client.get_transport() and self._ssh_client.get_transport().is_active()): try: transport = self._ssh_client.get_transport() transport.send_ignore() except Exception as e: self._log(logging.DEBUG, "exception getting ssh_client (creating new client): %s" % str(e)) self._set_ssh_client() else: self._set_ssh_client() return self._ssh_client def _set_sftp_client(self): if self._sftp_client: self._sftp_client.close() self._sftp_client = self.ssh_client.open_sftp() @property def sftp_client(self): if not self._sftp_client: self._set_sftp_client() else: self.ssh_client # test connection return self._sftp_client
[docs] def close(self): """Close/release any outstanding network connections to remote account.""" if self._ssh_client: self._ssh_client.close() self._ssh_client = None if self._sftp_client: self._sftp_client.close() self._sftp_client = None
def __str__(self): r = "" if self.user: r += self.user + "@" r += self.hostname return r def __repr__(self): return str(self.__dict__) def __eq__(self, other): return other is not None and self.__dict__ == other.__dict__ def __hash__(self): return hash(tuple(sorted(self.__dict__.items())))
[docs] def wait_for_http_service(self, port, headers, timeout=20, path='/'): """Wait until this service node is available/awake.""" url = "http://%s:%s%s" % (self.externally_routable_ip, str(port), path) err_msg = "Timed out trying to contact service on %s. " % url + \ "Either the service failed to start, or there is a problem with the url." wait_until(lambda: self._can_ping_url(url, headers), timeout_sec=timeout, backoff_sec=.25, err_msg=err_msg)
def _can_ping_url(self, url, headers): """See if we can successfully issue a GET request to the given url.""" try: self.http_request(url, "GET", None, headers, timeout=.75) return True except Exception: return False def available(self): # TODO: https://github.com/confluentinc/ducktape/issues/339 # try: # self.ssh_client # except Exception: # return False # else: # return True # finally: # self.close() return True @check_ssh def ssh(self, cmd, allow_fail=False): """Run the given command on the remote host, and block until the command has finished running. :param cmd: The remote ssh command :param allow_fail: If True, ignore nonzero exit status of the remote command, else raise an ``RemoteCommandError`` :return: The exit status of the command. :raise RemoteCommandError: If allow_fail is False and the command returns a non-zero exit status """ self._log(logging.DEBUG, "Running ssh command: %s" % cmd) client = self.ssh_client stdin, stdout, stderr = client.exec_command(cmd) # Unfortunately we need to read over the channel to ensure that recv_exit_status won't hang. See: # http://docs.paramiko.org/en/2.0/api/channel.html#paramiko.channel.Channel.recv_exit_status stdout.read() exit_status = stdout.channel.recv_exit_status() try: if exit_status != 0: if not allow_fail: raise RemoteCommandError(self, cmd, exit_status, stderr.read()) else: self._log(logging.DEBUG, "Running ssh command '%s' exited with status %d and message: %s" % (cmd, exit_status, stderr.read())) finally: stdin.close() stdout.close() stderr.close() return exit_status @check_ssh def ssh_capture(self, cmd, allow_fail=False, callback=None, combine_stderr=True, timeout_sec=None): """Run the given command asynchronously via ssh, and return an SSHOutputIter object. Does *not* block :param cmd: The remote ssh command :param allow_fail: If True, ignore nonzero exit status of the remote command, else raise an ``RemoteCommandError`` :param callback: If set, the iterator returns ``callback(line)`` for each line of output instead of the raw output :param combine_stderr: If True, return output from both stderr and stdout of the remote process. :param timeout_sec: Set timeout on blocking reads/writes. Default None. For more details see http://docs.paramiko.org/en/2.0/api/channel.html#paramiko.channel.Channel.settimeout :return SSHOutputIter: object which allows iteration through each line of output. :raise RemoteCommandError: If ``allow_fail`` is False and the command returns a non-zero exit status """ self._log(logging.DEBUG, "Running ssh command: %s" % cmd) client = self.ssh_client chan = client.get_transport().open_session(timeout=timeout_sec) chan.settimeout(timeout_sec) chan.exec_command(cmd) chan.set_combine_stderr(combine_stderr) stdin = chan.makefile('wb', -1) # set bufsize to -1 stdout = chan.makefile('r', -1) stderr = chan.makefile_stderr('r', -1) def output_generator(): for line in iter(stdout.readline, ''): if callback is None: yield line else: yield callback(line) try: exit_status = stdout.channel.recv_exit_status() if exit_status != 0: if not allow_fail: raise RemoteCommandError(self, cmd, exit_status, stderr.read()) else: self._log(logging.DEBUG, "Running ssh command '%s' exited with status %d and message: %s" % (cmd, exit_status, stderr.read())) finally: stdin.close() stdout.close() stderr.close() return SSHOutputIter(output_generator, stdout) @check_ssh def ssh_output(self, cmd, allow_fail=False, combine_stderr=True, timeout_sec=None): """Runs the command via SSH and captures the output, returning it as a string. :param cmd: The remote ssh command. :param allow_fail: If True, ignore nonzero exit status of the remote command, else raise an ``RemoteCommandError`` :param combine_stderr: If True, return output from both stderr and stdout of the remote process. :param timeout_sec: Set timeout on blocking reads/writes. Default None. For more details see http://docs.paramiko.org/en/2.0/api/channel.html#paramiko.channel.Channel.settimeout :return: The stdout output from the ssh command. :raise RemoteCommandError: If ``allow_fail`` is False and the command returns a non-zero exit status """ self._log(logging.DEBUG, "Running ssh command: %s" % cmd) client = self.ssh_client chan = client.get_transport().open_session(timeout=timeout_sec) chan.settimeout(timeout_sec) chan.exec_command(cmd) chan.set_combine_stderr(combine_stderr) stdin = chan.makefile('wb', -1) # set bufsize to -1 stdout = chan.makefile('r', -1) stderr = chan.makefile_stderr('r', -1) try: stdoutdata = stdout.read() exit_status = stdin.channel.recv_exit_status() if exit_status != 0: if not allow_fail: raise RemoteCommandError(self, cmd, exit_status, stderr.read()) else: self._log(logging.DEBUG, "Running ssh command '%s' exited with status %d and message: %s" % (cmd, exit_status, stderr.read())) finally: stdin.close() stdout.close() stderr.close() self._log(logging.DEBUG, "Returning ssh command output:\n%s" % stdoutdata) return stdoutdata
[docs] def alive(self, pid): """Return True if and only if process with given pid is alive.""" try: self.ssh("kill -0 %s" % str(pid), allow_fail=False) return True except Exception: return False
def signal(self, pid, sig, allow_fail=False): cmd = "kill -%d %s" % (int(sig), str(pid)) self.ssh(cmd, allow_fail=allow_fail) def kill_process(self, process_grep_str, clean_shutdown=True, allow_fail=False): cmd = """ps ax | grep -i """ + process_grep_str + """ | grep -v grep | awk '{print $1}'""" pids = [pid for pid in self.ssh_capture(cmd, allow_fail=True)] if clean_shutdown: sig = signal.SIGTERM else: sig = signal.SIGKILL for pid in pids: self.signal(pid, sig, allow_fail=allow_fail)
[docs] def java_pids(self, match): """ Get all the Java process IDs matching 'match'. :param match: The AWK expression to match """ cmd = """jcmd | awk '/%s/ { print $1 }'""" % match return [int(pid) for pid in self.ssh_capture(cmd, allow_fail=True)]
[docs] def kill_java_processes(self, match, clean_shutdown=True, allow_fail=False): """ Kill all the java processes matching 'match'. :param match: The AWK expression to match :param clean_shutdown: True if we should shut down cleanly with SIGTERM; false if we should shut down with SIGKILL. :param allow_fail: True if we should throw exceptions if the ssh commands fail. """ cmd = """jcmd | awk '/%s/ { print $1 }'""" % match pids = [pid for pid in self.ssh_capture(cmd, allow_fail=True)] if clean_shutdown: sig = signal.SIGTERM else: sig = signal.SIGKILL for pid in pids: self.signal(pid, sig, allow_fail=allow_fail)
[docs] def copy_between(self, src, dest, dest_node): """Copy src to dest on dest_node :param src: Path to the file or directory we want to copy :param dest: The destination path :param dest_node: The node to which we want to copy the file/directory Note that if src is a directory, this will automatically copy recursively. """ # TODO: if dest is an existing file, what is the behavior? temp_dir = tempfile.mkdtemp() try: # TODO: deal with very unlikely case that src_name matches temp_dir name? # TODO: I think this actually works local_dest = self._re_anchor_basename(src, temp_dir) self.copy_from(src, local_dest) dest_node.account.copy_to(local_dest, dest) finally: if os.path.isdir(temp_dir): shutil.rmtree(temp_dir)
def scp_from(self, src, dest, recursive=False): warnings.warn("scp_from is now deprecated. Please use copy_from") self.copy_from(src, dest) def _re_anchor_basename(self, path, directory): """Anchor the basename of path onto the given directory Helper for the various copy_* methods. :param path: Path to a file or directory. Could be on the driver machine or a worker machine. :param directory: Path to a directory. Could be on the driver machine or a worker machine. Example:: path/to/the_basename, another/path/ -> another/path/the_basename """ path_basename = path # trim off path separator from end of path # this is necessary because os.path.basename of a path ending in a separator is an empty string # For example: # os.path.basename("the/path/") == "" # os.path.basename("the/path") == "path" if path_basename.endswith(os.path.sep): path_basename = path_basename[:-len(os.path.sep)] path_basename = os.path.basename(path_basename) return os.path.join(directory, path_basename) @check_ssh def copy_from(self, src, dest): if os.path.isdir(dest): # dest is an existing directory, so assuming src looks like path/to/src_name, # in this case we'll copy as: # path/to/src_name -> dest/src_name dest = self._re_anchor_basename(src, dest) if self.isfile(src): self.sftp_client.get(src, dest) elif self.isdir(src): # we can now assume dest path looks like: path_that_exists/new_directory os.mkdir(dest) # for obj in `ls src`, if it's a file, copy with copy_file_from, elif its a directory, call again for obj in self.sftp_client.listdir(src): obj_path = os.path.join(src, obj) if self.isfile(obj_path) or self.isdir(obj_path): self.copy_from(obj_path, dest) else: # TODO what about uncopyable file types? pass def scp_to(self, src, dest, recursive=False): warnings.warn("scp_to is now deprecated. Please use copy_to") self.copy_to(src, dest) @check_ssh def copy_to(self, src, dest): if self.isdir(dest): # dest is an existing directory, so assuming src looks like path/to/src_name, # in this case we'll copy as: # path/to/src_name -> dest/src_name dest = self._re_anchor_basename(src, dest) if os.path.isfile(src): # local to remote self.sftp_client.put(src, dest) elif os.path.isdir(src): # we can now assume dest path looks like: path_that_exists/new_directory self.mkdir(dest) # for obj in `ls src`, if it's a file, copy with copy_file_from, elif its a directory, call again for obj in os.listdir(src): obj_path = os.path.join(src, obj) if os.path.isfile(obj_path) or os.path.isdir(obj_path): self.copy_to(obj_path, dest) else: # TODO what about uncopyable file types? pass @check_ssh def islink(self, path): try: # stat should follow symlinks path_stat = self.sftp_client.lstat(path) return stat.S_ISLNK(path_stat.st_mode) except Exception: return False @check_ssh def isdir(self, path): try: # stat should follow symlinks path_stat = self.sftp_client.stat(path) return stat.S_ISDIR(path_stat.st_mode) except Exception: return False @check_ssh def exists(self, path): """Test that the path exists, but don't follow symlinks.""" try: # stat follows symlinks and tries to stat the actual file self.sftp_client.lstat(path) return True except IOError: return False @check_ssh def isfile(self, path): """Imitates semantics of os.path.isfile :param path: Path to the thing to check :return: True if path is a file or a symlink to a file, else False. Note False can mean path does not exist. """ try: # stat should follow symlinks path_stat = self.sftp_client.stat(path) return stat.S_ISREG(path_stat.st_mode) except Exception: return False def open(self, path, mode='r'): return self.sftp_client.open(path, mode) @check_ssh def create_file(self, path, contents): """Create file at path, with the given contents. If the path already exists, it will be overwritten. """ # TODO: what should semantics be if path exists? what actually happens if it already exists? # TODO: what happens if the base part of the path does not exist? with self.sftp_client.open(path, "w") as f: f.write(contents) _DEFAULT_PERMISSIONS = int('755', 8) @check_ssh def mkdir(self, path, mode=_DEFAULT_PERMISSIONS): self.sftp_client.mkdir(path, mode) def mkdirs(self, path, mode=_DEFAULT_PERMISSIONS): self.ssh("mkdir -p %s && chmod %o %s" % (path, mode, path))
[docs] def remove(self, path, allow_fail=False): """Remove the given file or directory""" if allow_fail: cmd = "rm -rf %s" % path else: cmd = "rm -r %s" % path self.ssh(cmd, allow_fail=allow_fail)
[docs] @contextmanager def monitor_log(self, log): """ Context manager that returns an object that helps you wait for events to occur in a log. This checks the size of the log at the beginning of the block and makes a helper object available with convenience methods for checking or waiting for a pattern to appear in the log. This will commonly be used to start a process, then wait for a log message indicating the process is in a ready state. See ``LogMonitor`` for more usage information. """ try: offset = int(self.ssh_output("wc -c %s" % log).split()[0]) except Exception: offset = 0 yield LogMonitor(self, log, offset)
class SSHOutputIter(object): """Helper class that wraps around an iterable object to provide has_next() in addition to next() """ def __init__(self, iter_obj_func, channel_file=None): """ :param iter_obj_func: A generator that returns an iterator over stdout from the remote process :param channel_file: A paramiko ``ChannelFile`` object """ self.iter_obj_func = iter_obj_func self.iter_obj = iter_obj_func() self.channel_file = channel_file # sentinel is used as an indicator that there is currently nothing cached # If self.cached is self.sentinel, then next object from ier_obj is not yet cached. self.sentinel = object() self.cached = self.sentinel def __iter__(self): return self def next(self): if self.cached is self.sentinel: return next(self.iter_obj) next_obj = self.cached self.cached = self.sentinel return next_obj __next__ = next def has_next(self, timeout_sec=None): """Return True if next(iter_obj) would return another object within timeout_sec, else False. If timeout_sec is None, next(iter_obj) may block indefinitely. """ assert timeout_sec is None or self.channel_file is not None, "should have descriptor to enforce timeout" prev_timeout = None if self.cached is self.sentinel: if self.channel_file is not None: prev_timeout = self.channel_file.channel.gettimeout() # when timeout_sec is None, next(iter_obj) will block indefinitely self.channel_file.channel.settimeout(timeout_sec) try: self.cached = next(self.iter_obj, self.sentinel) except socket.timeout: self.iter_obj = self.iter_obj_func() self.cached = self.sentinel finally: if self.channel_file is not None: # restore preexisting timeout self.channel_file.channel.settimeout(prev_timeout) return self.cached is not self.sentinel
[docs]class LogMonitor(object): """ Helper class returned by monitor_log. Should be used as:: with remote_account.monitor_log("/path/to/log") as monitor: remote_account.ssh("/command/to/start") monitor.wait_until("pattern.*to.*grep.*for", timeout_sec=5) to run the command and then wait for the pattern to appear in the log. """
[docs] def __init__(self, acct, log, offset): self.acct = acct self.log = log self.offset = offset
[docs] def wait_until(self, pattern, **kwargs): """ Wait until the specified pattern is found in the log, after the initial offset recorded when the LogMonitor was created. Additional keyword args are passed directly to ``ducktape.utils.util.wait_until`` """ return wait_until(lambda: self.acct.ssh("tail -c +%d %s | grep '%s'" % (self.offset + 1, self.log, pattern), allow_fail=True) == 0, **kwargs)
class IgnoreMissingHostKeyPolicy(MissingHostKeyPolicy): """Policy for ignoring missing host keys. Many examples show use of AutoAddPolicy, but this clutters up the known_hosts file unnecessarily. """ def missing_host_key(self, client, hostname, key): return