# This document describes the current stable version of Celery (5.2).
# For development docs, see the development version of the Celery documentation.
# Source code for celery.contrib.testing.manager
"""Integration testing utilities."""
import socket
import sys
from collections import defaultdict
from functools import partial
from itertools import count
from typing import Any, Callable, Dict, Sequence, TextIO, Tuple
from kombu.utils.functional import retry_over_time
from celery import states
from celery.exceptions import TimeoutError
from celery.result import AsyncResult, ResultSet
from celery.utils.text import truncate
from celery.utils.time import humanize_seconds as _humanize_seconds
# Warning template used by ManagerMixin.wait_for: {0} is the description of
# what is being waited for, {when} is the humanized delay before the next
# attempt and {exc} is the exception caught on the failed attempt.
E_STILL_WAITING = 'Still waiting for {0}. Trying again {when}: {exc!r}'
# celery.utils.time.humanize_seconds with ``microseconds=True`` pre-bound.
humanize_seconds = partial(_humanize_seconds, microseconds=True)
class Sentinel(Exception):
    """Raised by :meth:`ManagerMixin.true_or_raise` when a check is falsy.

    The ``assert_*`` helpers pass it as the ``catch`` exception so that
    :meth:`ManagerMixin.wait_for` keeps retrying until the check passes.
    """


class ManagerMixin:
    """Mixin that adds :class:`Manager` capabilities.

    The host class must set ``self.app`` (the Celery application) and call
    :meth:`_init_manager` before using the other helpers.
    """

    def _init_manager(self,
                      block_timeout=30 * 60.0, no_join=False,
                      stdout=None, stderr=None):
        # type: (float, bool, TextIO, TextIO) -> None
        """Initialise the attributes the mixin's helpers rely on.

        Arguments:
            block_timeout: Stored as ``self.block_timeout``; the mixin itself
                does not read it.
            no_join: When true, :meth:`join` returns immediately.
            stdout: Stream used by :meth:`remark` (default ``sys.stdout``).
            stderr: Stream used by :meth:`warn` (default ``sys.stderr``).
        """
        self.stdout = sys.stdout if stdout is None else stdout
        self.stderr = sys.stderr if stderr is None else stderr
        # Exception classes that :meth:`join` treats as "connection lost".
        self.connerrors = self.app.connection().recoverable_connection_errors
        self.block_timeout = block_timeout
        self.no_join = no_join

    def remark(self, s, sep='-'):
        # type: (str, str) -> None
        """Print *s*, prefixed with *sep*, to ``self.stdout``."""
        print(f'{sep}{s}', file=self.stdout)

    def warn(self, s):
        # type: (str) -> None
        """Print the warning *s* to ``self.stderr``.

        Called by :meth:`wait_for` when ``emit_warning`` is true.
        """
        print(s, file=self.stderr)

    def missing_results(self, r):
        # type: (Sequence[AsyncResult]) -> Sequence[str]
        """Return the ids of the results in *r* absent from their backend cache."""
        return [res.id for res in r if res.id not in res.backend._cache]

    def retry_over_time(self, *args, **kwargs):
        """Delegate to :func:`kombu.utils.functional.retry_over_time`.

        Kept as a method so host classes can substitute the retry strategy
        used by :meth:`wait_for`.
        """
        return retry_over_time(*args, **kwargs)

    def wait_for(
        self,
        fun,  # type: Callable
        catch,  # type: Sequence[Any]
        desc="thing",  # type: str
        args=(),  # type: Tuple
        kwargs=None,  # type: Dict
        errback=None,  # type: Callable
        max_retries=10,  # type: int
        interval_start=0.1,  # type: float
        interval_step=0.5,  # type: float
        interval_max=5.0,  # type: float
        emit_warning=False,  # type: bool
        **options  # type: Any
    ):
        # type: (...) -> Any
        """Wait for event to happen.

        The `catch` argument specifies the exception that means the event
        has not happened yet.

        Arguments:
            fun: Callable to invoke; its return value is returned on success.
            catch: Exception class(es) signalling "not yet"; these trigger a
                retry.
            desc: Human readable description used in the warning message.
            args: Positional arguments for *fun*.
            kwargs: Keyword arguments for *fun*.
            errback: Optional callback invoked as
                ``errback(exc, interval, retries)`` after each failed attempt.
            max_retries: Retry budget handed to :meth:`retry_over_time`.
            interval_start: Initial retry delay.
            interval_step: Increment applied to the delay between retries.
            interval_max: Upper bound for the retry delay.
            emit_warning: When true, :meth:`warn` is called on every failed
                attempt.
            **options: Extra keyword arguments for :meth:`retry_over_time`.

        Returns:
            Whatever :meth:`retry_over_time` returns.
        """
        kwargs = kwargs or {}

        def on_error(exc, intervals, retries):
            # The value returned here is the delay before the next attempt.
            interval = next(intervals)
            if emit_warning:
                self.warn(E_STILL_WAITING.format(
                    desc, when=humanize_seconds(interval, 'in', ' '), exc=exc,
                ))
            if errback:
                errback(exc, interval, retries)
            return interval

        return self.retry_over_time(
            fun, catch,
            args=args, kwargs=kwargs,
            errback=on_error, max_retries=max_retries,
            interval_start=interval_start, interval_step=interval_step,
            # Previously accepted but silently dropped.
            interval_max=interval_max,
            **options
        )

    def ensure_not_for_a_while(self, fun, catch,
                               desc='thing', max_retries=20,
                               interval_start=0.1, interval_step=0.02,
                               interval_max=1.0, emit_warning=False,
                               **options):
        """Make sure something does not happen (at least for a while).

        *fun* is polled through :meth:`wait_for`.  If it keeps raising
        *catch* until the retry budget is exhausted the method returns
        ``None``.

        Raises:
            AssertionError: If *fun* returns without raising *catch*,
                i.e. the event did happen.
        """
        try:
            # No ``return`` here: returning from the ``try`` body would skip
            # the ``else`` clause and hide the failure.
            self.wait_for(
                fun, catch, desc=desc, max_retries=max_retries,
                interval_start=interval_start, interval_step=interval_step,
                interval_max=interval_max, emit_warning=emit_warning,
                **options
            )
        except catch:
            pass
        else:
            raise AssertionError(f'Should not have happened: {desc}')

    def join(self, r, propagate=False, max_retries=10, **kwargs):
        """Wait for the result(s) *r* and return what ``r.get()`` returns.

        A single result is wrapped in a ``ResultSet`` first.  Timeouts and
        recoverable connection errors are reported through :meth:`remark`
        and the fetch is retried, up to *max_retries* times (forever when
        *max_retries* is falsy).  Returns ``None`` immediately when
        ``self.no_join`` is set.

        Raises:
            AssertionError: If every attempt failed.
        """
        if self.no_join:
            return
        if not isinstance(r, ResultSet):
            r = self.app.ResultSet([r])
        received = []

        def on_result(task_id, value):
            received.append(task_id)

        attempts = range(max_retries) if max_retries else count(0)
        for _ in attempts:
            # Each attempt starts counting received results from scratch.
            received[:] = []
            try:
                return r.get(callback=on_result, propagate=propagate, **kwargs)
            except (socket.timeout, TimeoutError) as exc:
                waiting_for = self.missing_results(r)
                self.remark(
                    'Still waiting for {}/{}: [{}]: {!r}'.format(
                        len(r) - len(received), len(r),
                        truncate(', '.join(waiting_for)), exc), '!',
                )
            except self.connerrors as exc:
                self.remark(f'join: connection lost: {exc!r}', '!')
        raise AssertionError('Test failed: Missing task results')

    def inspect(self, timeout=3.0):
        """Return ``self.app.control.inspect(timeout=timeout)``."""
        return self.app.control.inspect(timeout=timeout)

    def query_tasks(self, ids, timeout=0.5):
        """Yield the ``(key, reply)`` items of ``inspect().query_task(*ids)``.

        Yields nothing when the query returns a falsy value.
        """
        tasks = self.inspect(timeout).query_task(*ids) or {}
        yield from tasks.items()

    def query_task_states(self, ids, timeout=0.5):
        """Return a ``defaultdict`` mapping each state to a set of task ids.

        Every reply from :meth:`query_tasks` is expected to map a task id to
        a ``(state, info)`` pair; the second element is ignored.
        """
        by_state = defaultdict(set)
        for _, reply in self.query_tasks(ids, timeout=timeout):
            for task_id, (state, _info) in reply.items():
                by_state[state].add(task_id)
        return by_state

    def assert_accepted(self, ids, interval=0.5,
                        desc='waiting for tasks to be accepted', **policy):
        """Wait until :meth:`is_accepted` is true for *ids*."""
        return self.assert_task_worker_state(
            self.is_accepted, ids, interval=interval, desc=desc, **policy
        )

    def assert_received(self, ids, interval=0.5,
                        desc='waiting for tasks to be received', **policy):
        """Wait until :meth:`is_received` is true for *ids*."""
        return self.assert_task_worker_state(
            self.is_received, ids, interval=interval, desc=desc, **policy
        )

    def assert_result_tasks_in_progress_or_completed(
        self,
        async_results,
        interval=0.5,
        desc='waiting for tasks to be started or completed',
        **policy
    ):
        """Wait until :meth:`is_result_task_in_progress` holds for all results."""
        return self.assert_task_state_from_result(
            self.is_result_task_in_progress,
            async_results,
            interval=interval, desc=desc, **policy
        )

    def assert_task_state_from_result(self, fun, results,
                                      interval=0.5, **policy):
        """Poll ``fun(results, timeout=interval)`` until it returns a truthy value.

        *policy* is forwarded to :meth:`wait_for`.
        """
        return self.wait_for(
            partial(self.true_or_raise, fun, results, timeout=interval),
            (Sentinel,), **policy
        )

    @staticmethod
    def is_result_task_in_progress(results, **kwargs):
        """Return true if every result's state is ``STARTED`` or ``SUCCESS``.

        Extra keyword arguments (such as the ``timeout`` supplied by
        :meth:`assert_task_state_from_result`) are accepted and ignored.
        """
        possible_states = (states.STARTED, states.SUCCESS)
        return all(result.state in possible_states for result in results)

    def assert_task_worker_state(self, fun, ids, interval=0.5, **policy):
        """Poll ``fun(ids, timeout=interval)`` until it returns a truthy value.

        *policy* is forwarded to :meth:`wait_for`.
        """
        return self.wait_for(
            partial(self.true_or_raise, fun, ids, timeout=interval),
            (Sentinel,), **policy
        )

    def is_received(self, ids, **kwargs):
        """Return true if every id is in state reserved, active or ready."""
        return self._ids_matches_state(
            ['reserved', 'active', 'ready'], ids, **kwargs)

    def is_accepted(self, ids, **kwargs):
        """Return true if every id is in state active or ready."""
        return self._ids_matches_state(['active', 'ready'], ids, **kwargs)

    def _ids_matches_state(self, expected_states, ids, timeout=0.5):
        """Return true if every id is reported in one of *expected_states*."""
        by_state = self.query_task_states(ids, timeout=timeout)
        matching = set().union(*(by_state[k] for k in expected_states))
        return all(task_id in matching for task_id in ids)

    def true_or_raise(self, fun, *args, **kwargs):
        """Return ``fun(*args, **kwargs)`` if truthy, else raise ``Sentinel``."""
        res = fun(*args, **kwargs)
        if not res:
            raise Sentinel()
        return res
class Manager(ManagerMixin):
    """Test helpers for task integration tests.

    Binds the :class:`ManagerMixin` helpers to a single application.
    """

    def __init__(self, app, **kwargs):
        # ``app`` must be stored first: ``_init_manager`` reads ``self.app``.
        self.app = app
        # Remaining keyword arguments are the ``_init_manager`` settings.
        self._init_manager(**kwargs)