Source code for ducktape.services.background_thread

# Copyright 2015 Confluent Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ducktape.services.service import Service

import threading
import traceback

from six import itervalues


class BackgroundThreadService(Service):
    """Service that runs one daemon worker thread per started node.

    Each worker thread executes ``self._worker(idx, node)``. ``_worker`` is not
    defined in this class -- presumably it is supplied by subclasses (confirm
    against ``Service`` and its subclasses). Exceptions raised by workers are
    recorded and re-raised to the caller from ``wait()`` and ``stop()``.
    """

    def __init__(self, context, num_nodes=None, cluster_spec=None, *args, **kwargs):
        """Initialize the base service and the per-worker bookkeeping.

        All arguments are forwarded unchanged to ``Service.__init__``.
        """
        super().__init__(context, num_nodes, cluster_spec, *args, **kwargs)
        # Worker threads keyed by node index (see start_node).
        self.worker_threads = {}
        # Formatted tracebacks keyed by worker thread name.
        self.worker_errors = {}
        # Newline-joined "<thread name>: <traceback>" entries for all failed workers.
        self.errors = ''
        # Guards worker_errors / errors, which are written from worker threads.
        self.lock = threading.RLock()

    def _protected_worker(self, idx, node):
        """Protected worker captures exceptions and makes them available to the main thread.

        This gives us the ability to propagate exceptions thrown in background threads, if desired.
        The exception is re-raised in the worker thread after it has been recorded.
        """
        try:
            self._worker(idx, node)
        except BaseException:
            with self.lock:
                self.logger.info("BackgroundThreadService threw exception: ")
                tb = traceback.format_exc()
                self.logger.info(tb)

                # current_thread() replaces the camelCase alias currentThread(),
                # which is deprecated since Python 3.10.
                thread_name = threading.current_thread().name
                self.worker_errors[thread_name] = tb
                if self.errors:
                    self.errors += "\n"
                self.errors += "%s: %s" % (thread_name, tb)

            raise

    def start_node(self, node):
        """Start a daemon worker thread for ``node``.

        :raises RuntimeError: if a worker thread previously started for this
            node's index is still alive.
        """
        idx = self.idx(node)

        previous = self.worker_threads.get(idx)
        if previous is not None and previous.is_alive():
            raise RuntimeError("Cannot restart node since previous thread is still alive")

        self.logger.info("Running %s node %d on %s", self.service_id, idx, node.account.hostname)
        worker = threading.Thread(
            name=self.service_id + "-worker-" + str(idx),
            target=self._protected_worker,
            args=(idx, node)
        )
        worker.daemon = True
        worker.start()
        self.worker_threads[idx] = worker

    def wait(self, timeout_sec=600):
        """Wait no more than timeout_sec for all worker threads to finish.

        raise TimeoutException if all worker threads do not finish within timeout_sec

        After the base-class wait returns, any exception recorded by a worker
        thread is raised here (see ``_propagate_exceptions``).
        """
        super().wait(timeout_sec)

        self._propagate_exceptions()

    def stop(self):
        """Stop the service, then raise if any worker thread recorded an exception.

        Workers that are still alive when stop is called are only reported at
        debug level; the actual stopping is delegated to ``Service.stop``.
        """
        alive_workers = [worker for worker in itervalues(self.worker_threads) if worker.is_alive()]
        if alive_workers:
            self.logger.debug(
                "Called stop with at least one worker thread is still running: %s", alive_workers)

            self.logger.debug("%s", self.worker_threads)

        super().stop()

        self._propagate_exceptions()

    def wait_node(self, node, timeout_sec=600):
        """Join the worker thread of ``node`` for at most ``timeout_sec`` seconds.

        :return: True if the worker thread is no longer alive after the join, or
            if no worker thread exists for this node; False otherwise.
        """
        idx = self.idx(node)
        worker_thread = self.worker_threads.get(idx)

        # worker thread can be absent if this node has never been started
        if worker_thread is None:
            self.logger.debug(f"Worker thread not found for {self.who_am_i(node)}")
            return True

        worker_thread.join(timeout_sec)
        return not worker_thread.is_alive()

    def _propagate_exceptions(self):
        """Propagate exceptions thrown in background threads.

        :raises Exception: carrying the accumulated ``self.errors`` text if at
            least one worker thread has recorded an error.
        """
        with self.lock:
            if self.worker_errors:
                raise Exception(self.errors)