Source code for coredis.patterns.pipeline

from __future__ import annotations

from contextlib import asynccontextmanager
from typing import Any, cast

from anyio import AsyncContextManagerMixin
from deprecated.sphinx import versionchanged

from coredis._telemetry import get_telemetry_provider
from coredis._utils import nativestr
from coredis.client import Client, RedisCluster
from coredis.cluster._node import ClusterNodeLocation
from coredis.commands import CommandRequest, CommandResponseT
from coredis.commands._routing import ExplicitSlotStrategy, RandomStrategy
from coredis.commands.constants import CommandName
from coredis.commands.script import Script
from coredis.connection._base import BaseConnection
from coredis.connection._request import Request
from coredis.exceptions import (
    AskError,
    ClusterCrossSlotError,
    ClusterDownError,
    ClusterTransactionError,
    ConnectionError,
    ExecAbortError,
    MovedError,
    RedisClusterError,
    RedisError,
    ResponseError,
    TryAgainError,
    WatchError,
)
from coredis.pool import ClusterConnectionPool
from coredis.response._callbacks import (
    AnyStrCallback,
    BoolsCallback,
    NoopCallback,
    SimpleStringCallback,
)
from coredis.retry import ConstantRetryPolicy, retryable
from coredis.typing import (
    AnyStr,
    AsyncGenerator,
    Awaitable,
    Callable,
    ExecutionParameters,
    Generator,
    Iterable,
    Key,
    KeyT,
    ParamSpec,
    RedisValueT,
    ResponseType,
    Self,
    T_co,
    TypeAdapter,
    TypeVar,
    ValueT,
)

# Generic placeholders used in the signatures of this module.
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")

# Error types that make a cluster pipeline command eligible for an individual
# retry after the batched attempt (see ClusterPipeline._send_cluster_commands).
ERRORS_ALLOW_RETRY = (MovedError, AskError, TryAgainError)

# NOTE(review): not referenced in the visible part of this module; presumably
# the commands after which a connection no longer has watched keys - confirm.
UNWATCH_COMMANDS = {CommandName.UNWATCH, CommandName.EXEC, CommandName.DISCARD}


class PipelineResult(Awaitable[T]):
    """
    Awaitable wrapper around an already available pipeline result.

    Awaiting an instance (or reading :attr:`value`) yields the wrapped object,
    unless the wrapped object is an :class:`Exception`, in which case it is raised.
    """

    __slots__ = ("_result",)

    def __init__(self, result: T) -> None:
        """
        :param result: The value (or exception instance) to hand back when awaited
        """
        self._result = result

    @property
    def value(self) -> T:
        """
        The wrapped result.

        :raises Exception: the wrapped object itself, if it is an exception instance
        """
        outcome = self._result
        if not isinstance(outcome, Exception):
            return outcome
        raise outcome

    def __await__(self) -> Generator[Any, None, T]:
        # Delegate to a throwaway coroutine so that ``await result`` resolves
        # through the ``value`` property (including raising stored exceptions).
        async def _resolve() -> T:
            return self.value

        pending = _resolve()
        return pending.__await__()


class PipelineCommandRequest(CommandRequest[CommandResponseT]):
    """
    Command request handed back when a command is queued on a pipeline.

    It only becomes awaitable once the pipeline has executed and attached a
    ``_response`` to it.
    """

    __slots__ = ()

    def __await__(self) -> Generator[None, None, CommandResponseT]:
        # ``_response`` is only set by the pipeline after execution.
        if not hasattr(self, "_response"):
            raise RuntimeError("You can't await a pipeline command before it completes executing")
        return self._response.__await__()


class ClusterPipelineCommandRequest(CommandRequest[CommandResponseT]):
    """
    Command request queued on a cluster pipeline.

    In addition to the base request it records the command's ``position`` in
    the pipeline and the latest ``result`` (raw response or exception)
    observed for it while the pipeline is being executed.
    """

    __slots__ = ("position", "result")

    def __init__(
        self,
        name: bytes,
        *arguments: ValueT | Key,
        callback: Callable[..., CommandResponseT],
        execution_parameters: ExecutionParameters,
        type_adapter: TypeAdapter,
    ) -> None:
        """
        :param name: The command name
        :param arguments: The command arguments
        :param callback: Callable used to transform the raw response
        :param execution_parameters: Execution parameters forwarded to the base request
        :param type_adapter: Type adapter forwarded to the base request
        """
        super().__init__(
            name,
            *arguments,
            callback=callback,
            execution_parameters=execution_parameters,
            type_adapter=type_adapter,
        )
        # Both are (re)assigned by the owning pipeline.
        self.position: int = 0
        self.result: Any = None

    def __await__(self) -> Generator[None, None, CommandResponseT]:
        # ``_response`` is only set by the pipeline after execution.
        if not hasattr(self, "_response"):
            raise RuntimeError("You can't await a pipeline command before it completes executing")
        return self._response.__await__()


class NodeCommands(AsyncContextManagerMixin):
    """
    Collects the pipeline commands destined for a single cluster node and
    runs them on one connection, optionally wrapped in ``MULTI``/``EXEC``.

    Used as an async context manager: entering it binds :attr:`connection`
    (either the connection supplied at construction or one acquired from the
    client's connection pool for :attr:`node`).
    """

    connection: BaseConnection

    def __init__(
        self,
        client: RedisCluster[AnyStr],
        node: ClusterNodeLocation,
        connection: BaseConnection | None = None,
        in_transaction: bool = False,
        timeout: float | None = None,
        raise_on_error: bool = True,
    ):
        """
        :param client: The cluster client whose connection pool is used
        :param node: The node the grouped commands are sent to
        :param connection: Optional existing connection to use instead of
         acquiring one from the pool
        :param in_transaction: Whether to wrap the commands in ``MULTI``/``EXEC``
        :param timeout: Timeout passed to every request created on the connection
        :param raise_on_error: Stored on the instance (not consulted within this class)
        """
        self.client: RedisCluster[Any] = client
        self.node = node
        self._connection = connection
        self.in_transaction = in_transaction
        self.timeout = timeout
        self._raise_on_error = raise_on_error
        self.commands: list[ClusterPipelineCommandRequest[Any]] = []
        # Populated by ``write`` and consumed by ``read``.
        self.multi_cmd: Request | None = None
        self.exec_cmd: Request | None = None
        self.request_batch: Awaitable[list[ResponseType | BaseException | None]] | None = None

    def extend(self, c: list[ClusterPipelineCommandRequest[Any]]) -> None:
        """Queue several commands for this node."""
        self.commands.extend(c)

    def append(self, c: ClusterPipelineCommandRequest[Any]) -> None:
        """Queue a single command for this node."""
        self.commands.append(c)

    @asynccontextmanager
    async def __asynccontextmanager__(self) -> AsyncGenerator[None]:
        # A caller supplied connection is used as is; otherwise one is
        # acquired from the pool for the lifetime of the context.
        if self._connection:
            self.connection = self._connection
            yield
            return
        async with self.client.connection_pool.acquire(node=self.node) as self.connection:
            yield

    def write(self) -> None:
        """
        Create the requests for all queued commands on the bound connection.

        If creating the requests fails with a connection or timeout error the
        error is stored as the ``result`` of every queued command instead of
        being raised.
        """
        conn = self.connection
        queued = self.commands

        # Start from a clean slate before issuing the requests.
        for cmd in queued:
            cmd.result = None

        wrap_in_transaction = self.in_transaction
        try:
            # Order matters: MULTI, then the batch, then EXEC.
            if wrap_in_transaction:
                self.multi_cmd = conn.create_request(CommandName.MULTI, timeout=self.timeout)
            self.request_batch = conn.create_request_batch(queued, timeout=self.timeout)
            if wrap_in_transaction:
                self.exec_cmd = conn.create_request(CommandName.EXEC, timeout=self.timeout)
        except (ConnectionError, TimeoutError) as failure:
            for cmd in queued:
                cmd.result = failure

    async def read(self) -> None:
        """
        Await the requests created by :meth:`write` and store the outcome on
        each command (``result`` and ``_response``).

        In a transaction the ``EXEC`` reply is only awaited here when ``MULTI``
        was acknowledged and no queued command produced an error; its items
        are then passed through each command's callback in ``position`` order.

        :raises WatchError: if ``EXEC`` returned a falsy reply
        """
        healthy = True
        multi_reply = None
        if self.multi_cmd:
            multi_reply = await self.multi_cmd
            healthy = multi_reply in {b"OK", "OK"}

        if self.request_batch:
            replies = await self.request_batch
            for cmd, reply in zip(self.commands, replies):
                cmd.result = reply
                cmd._response = PipelineResult(reply)
                if isinstance(reply, (ConnectionError, TimeoutError, RedisError)):
                    healthy = False

        if not (self.in_transaction and self.exec_cmd):
            return

        if not healthy:
            if isinstance(multi_reply, BaseException):
                raise multi_reply
            return

        exec_reply = await self.exec_cmd
        if not exec_reply:
            raise WatchError("Watched variable changed.")
        transaction_result = cast(list[ResponseType], exec_reply)

        # EXEC replies arrive in queueing order, excluding MULTI/EXEC themselves.
        in_order = [
            queued
            for queued in sorted(self.commands, key=lambda x: x.position)
            if queued.name not in {CommandName.MULTI, CommandName.EXEC}
        ]
        for idx, cmd in enumerate(in_order):
            cmd.result = cmd.callback(
                transaction_result[idx],
            )
            cmd._response = PipelineResult(cmd.result)


@versionchanged(
    version="6.0.0",
    reason="Pipelines are no longer awaitable. They support the async context manager protocol and must always be used as such",
)
class Pipeline(Client[AnyStr]):
    """
    Pipeline for batching multiple commands to a Redis server.
    Supports transactions and command stacking.

    All commands executed within a pipeline are wrapped with MULTI and EXEC
    calls when :paramref:`transaction` is ``True``.

    Any command raising an exception does **not** halt the execution of
    subsequent commands in the pipeline, however the first exception encountered
    will be raised when exiting the pipeline if :paramref:`raise_on_error` is ``True``.
    If not, the exception is captured in :attr:`results` and raised when
    awaiting the command that failed.
    """

    QUEUED_RESPONSES = {b"QUEUED", "QUEUED"}

    def __init__(
        self,
        client: Client[AnyStr],
        transaction: bool | None,
        raise_on_error: bool = True,
        timeout: float | None = None,
    ) -> None:
        """
        :param client: The client whose connection pool and type adapter are used
        :param transaction: Whether to wrap the commands in the pipeline in a ``MULTI``, ``EXEC``
        :param raise_on_error: Whether to raise the first error encountered in the pipeline after
         executing it
        :param timeout: Time in seconds to wait for the pipeline results to return
        """
        self.client: Client[AnyStr] = client
        self._connection: BaseConnection | None = None
        self._transaction = transaction
        self._raise_on_error = raise_on_error
        self.watching = False
        self.command_stack: list[PipelineCommandRequest[Any]] = []
        self.watches: list[KeyT] = []
        self.cache = None
        self.explicit_transaction = False
        self.scripts: set[Script[AnyStr]] = set()
        self.timeout = timeout
        self.type_adapter = client.type_adapter
        self._results: tuple[Any, ...] | None = None

    @asynccontextmanager
    async def __asynccontextmanager__(self) -> AsyncGenerator[Self]:
        """
        Yield the pipeline, execute the queued commands on exit and hand the
        connection (if one was obtained) back to the pool.
        """
        try:
            yield self
            await self._execute()
        finally:
            # ``_execute`` raises by design when ``raise_on_error`` is set, so
            # the release has to happen in ``finally`` or the pooled
            # connection would never be returned.
            # NOTE(review): if the body itself raised inside ``watch()`` the
            # connection is released without an UNWATCH having been sent.
            if self._connection:
                self.client.connection_pool.release(self._connection)

    def __repr__(self) -> str:
        return f"{type(self).__name__}<{repr(self._connection)}>"

    def create_request(
        self,
        name: bytes,
        *arguments: ValueT | Key,
        callback: Callable[..., T_co],
        execution_parameters: ExecutionParameters | None = None,
    ) -> CommandRequest[T_co]:
        """
        Queue a command on the pipeline instead of executing it and return
        the request object that can be awaited once the pipeline has executed.

        :meta private:
        """
        command = PipelineCommandRequest(
            name,
            *arguments,
            callback=callback,
            execution_parameters=execution_parameters or {},
            type_adapter=self.type_adapter,
        )
        self.command_stack.append(command)
        return command

    @asynccontextmanager
    async def watch(self, *keys: KeyT) -> AsyncGenerator[None]:
        """
        The given keys will be watched for changes within this context and
        the commands stacked within the context will be automatically executed
        when the context exits.

        :raises WatchError: if commands were already queued on the pipeline
        """
        if self.command_stack:
            raise WatchError("Unable to add a watch after pipeline commands have been added")
        if not self._connection:
            self._connection = await self.client.connection_pool.get_connection()
        self.watches.extend(keys)
        # WATCH is sent immediately on the pipeline's dedicated connection.
        await self._immediate_execute_command(
            self.client.create_request(
                CommandName.WATCH,
                *[Key(watch) for watch in self.watches],
                callback=SimpleStringCallback(),
            )
        )
        # Commands queued from here on are executed as a transaction.
        self.explicit_transaction = True
        yield
        await self._execute()

    @property
    def results(self) -> tuple[Any, ...] | None:
        """
        The results of the pipeline execution which can be accessed after
        the pipeline has completed.

        :raises RuntimeError: if there are still queued commands
        """
        if self.command_stack:
            raise RuntimeError("Pipeline results are not available before it completes execution")
        return self._results

    def execute_command(
        self,
        command: CommandRequest[R],
    ) -> Awaitable[R]:
        """
        Direct command execution is not supported on a pipeline.
        """
        raise NotImplementedError

    async def _clear(self) -> None:
        """
        Clear the pipeline and reset state.
        """
        self.command_stack.clear()
        self.scripts.clear()
        # Reset connection state if we were watching something.
        if self.watches and self._connection:
            await self._connection.create_request(CommandName.UNWATCH, decode=False)
        self.watches.clear()
        self.explicit_transaction = False

    async def _immediate_execute_command(
        self,
        command: CommandRequest[R],
    ) -> R:
        """
        Executes a command immediately, but don't auto-retry on a
        ConnectionError if we're already WATCHing a variable. Used when
        issuing WATCH or subsequent commands retrieving their values but before
        MULTI is called.

        :meta private:
        """
        assert self._connection
        request = self._connection.create_request(
            command.name,
            *command.serialized_arguments,
            decode=command.execution_parameters.get("decode"),
        )
        return command.callback(await request)

    async def _execute_transaction(
        self,
        connection: BaseConnection,
        commands: list[PipelineCommandRequest[Any]],
    ) -> None:
        """
        Send the commands wrapped in ``MULTI``/``EXEC`` and store the processed
        responses on the commands and in ``_results``.

        :raises WatchError: if ``EXEC`` returned ``None``
        :raises ResponseError: if the number of responses does not match the
         number of commands
        """
        multi_request = connection.create_request(CommandName.MULTI, timeout=self.timeout)
        queued_batch = connection.create_request_batch(commands, timeout=self.timeout)
        exec_request = connection.create_request(CommandName.EXEC, timeout=self.timeout)
        errors: list[tuple[int, RedisError | TimeoutError | None]] = []

        # parse off the response for MULTI
        # NOTE: we need to handle ResponseErrors here and continue
        # so that we read all the additional command messages from
        # the socket
        try:
            await multi_request
        except (RedisError, TimeoutError) as e:
            errors.append((0, e))

        # and all the other commands
        for i, (cmd, queued_response) in enumerate(zip(commands, await queued_batch)):
            if isinstance(queued_response, (RedisError, TimeoutError)):
                self._annotate_exception(queued_response, i + 1, cmd.name, cmd.serialized_arguments)
                errors.append((i + 1, queued_response))
            if isinstance(queued_response, BaseException):
                raise queued_response
            if queued_response not in self.QUEUED_RESPONSES:
                raise Exception(
                    f"Abnormal response in pipeline for command {cmd.name!r}: {queued_response!r}"
                )

        try:
            response = cast(list[ResponseType] | None, await exec_request)
        except (ExecAbortError, ResponseError, TimeoutError) as e:
            # Prefer the first recorded error as the root cause.
            if errors and errors[0][1]:
                raise errors[0][1] from e
            raise
        if response is None:
            raise WatchError("Watched variable changed.")

        # put any parse errors into the response
        for i, e in errors:  # type: ignore
            response.insert(i, cast(ResponseType, e))

        if len(response) != len(commands):
            raise ResponseError("Wrong number of response items from pipeline execution")

        # We have to run response callbacks manually
        data: list[Any] = []
        for r, cmd in zip(response, commands):
            if not isinstance(r, Exception):
                r = cmd.callback(r, **cmd.execution_parameters)
            cmd._response = PipelineResult(r)
            data.append(r)
        self._results = tuple(data)

        # find any errors in the response and raise if necessary
        if self._raise_on_error:
            self._raise_first_error(commands, response)

    async def _execute_pipeline(
        self, connection: BaseConnection, commands: list[PipelineCommandRequest[Any]]
    ) -> None:
        """
        Send the commands as a plain (non transactional) batch and store the
        processed responses on the commands and in ``_results``.

        Response and timeout errors are recorded as the result of the command
        that caused them; any other exception in the batch propagates.
        """
        request_batch = connection.create_request_batch(commands, timeout=self.timeout)
        results: list[Any] = []
        for cmd, response in zip(commands, await request_batch):
            try:
                if isinstance(response, BaseException):
                    raise response
                resp = cmd.callback(
                    response,
                    **cmd.execution_parameters,
                )
                cmd._response = PipelineResult(resp)
                results.append(resp)
            except (ResponseError, TimeoutError) as err:
                cmd._response = PipelineResult(err)
                results.append(err)
        self._results = tuple(results)
        if self._raise_on_error:
            self._raise_first_error(commands, results)

    def _raise_first_error(
        self, commands: list[PipelineCommandRequest[Any]], response: ResponseType
    ) -> None:
        """
        Raise the first redis or timeout error found in ``response`` after
        annotating it with the position and arguments of the failed command.
        """
        assert isinstance(response, list)
        for i, r in enumerate(response):
            if isinstance(r, (RedisError, TimeoutError)):
                self._annotate_exception(
                    r, i + 1, commands[i].name, commands[i].serialized_arguments
                )
                raise r

    def _annotate_exception(
        self,
        exception: RedisError | TimeoutError | None,
        number: int,
        command: bytes,
        args: Iterable[RedisValueT],
    ) -> None:
        """
        Prefix the message of ``exception`` with the (1-based) position, name
        and arguments of the pipeline command that caused it.
        """
        if exception:
            cmd = command.decode("latin-1")
            rendered_args = " ".join(map(str, args))
            # Exceptions raised without arguments (e.g. a bare
            # ``TimeoutError()``) have empty ``args``; indexing them would
            # raise IndexError and hide the real failure.
            detail = str(exception.args[0]) if exception.args else ""
            msg = f"Command # {number} ({cmd} {rendered_args}) of pipeline caused error: {detail}"
            exception.args = (msg,) + exception.args[1:]

    async def _load_scripts(self) -> None:
        """
        Ensure every script registered on the pipeline exists on the server,
        loading the missing ones and updating their ``sha``.
        """
        # make sure all scripts that are about to be run on this pipeline exist
        scripts = list(self.scripts)
        shas = [s.sha for s in scripts]
        exists = await self._immediate_execute_command(
            self.client.create_request(CommandName.SCRIPT_EXISTS, *shas, callback=BoolsCallback())
        )
        if not all(exists):
            for s, exist in zip(scripts, exists):
                if not exist:
                    s.sha = await self._immediate_execute_command(
                        self.client.create_request(
                            CommandName.SCRIPT_LOAD,
                            s.script,
                            callback=AnyStrCallback[AnyStr](),
                        )
                    )

    async def _execute(self) -> None:
        """
        Execute all queued commands in the pipeline.

        The pipeline state is cleared once the commands have been sent,
        whether or not execution succeeded.

        :raises WatchError: if a connection or timeout error occurred while
         keys were being watched
        """
        if not self.command_stack:
            return
        if not self._connection:
            self._connection = await self.client.connection_pool.get_connection()
        with get_telemetry_provider().start_span(
            tuple(self.command_stack),
            self._connection,
            name="MULTI" if self._transaction else "PIPELINE",
        ):
            if self.scripts:
                await self._load_scripts()
            if self._transaction or self.explicit_transaction:
                runner = self._execute_transaction
            else:
                runner = self._execute_pipeline
            try:
                await runner(self._connection, self.command_stack)
            except (ConnectionError, TimeoutError) as e:
                if self.watches:
                    raise WatchError(
                        "A connection error occurred while watching one or more keys"
                    ) from e
                raise
            finally:
                await self._clear()
[docs] @versionchanged( version="6.0.0", reason="Cluster Pipelines are no longer awaitable. They support the async context manager protocol and must always be used as such", ) class ClusterPipeline(Client[AnyStr]): """ Pipeline for batching multiple commands to a Redis Cluster Supports transactions only when all keys map to the same shard, and therefore :paramref:`transactions` is set to ``False`` by default due to the limited scope. Any command raising an exception does **not** halt the execution of subsequent commands in the pipeline, however the first exception encountered will be raised when exiting the pipeline if :paramref:`raise_on_error` is ``True``. If not the exception is caught and will be returned when awaiting the command that failed. """ client: RedisCluster[AnyStr] connection_pool: ClusterConnectionPool command_stack: list[ClusterPipelineCommandRequest[Any]] def __init__( self, client: RedisCluster[AnyStr], raise_on_error: bool = True, transaction: bool = False, timeout: float | None = None, ) -> None: """ :param transaction: Whether to wrap the commands in the pipeline in a ``MULTI``, ``EXEC`` :param raise_on_error: Whether to raise the first error encounterd in the pipeline after executing it :param timeout: Time in seconds to wait for the pipeline results to return """ self.command_stack = [] self.client = client self.connection_pool = client.connection_pool self._raise_on_error = raise_on_error self._transaction = transaction self._watched_node: ClusterNodeLocation | None = None self._watched_connection: BaseConnection | None = None self.watches: list[KeyT] = [] self.explicit_transaction = False self.cache = None self.scripts: set[Script[AnyStr]] = set() self.timeout = timeout self.type_adapter = client.type_adapter self._results: tuple[Any] | None = None @asynccontextmanager async def __asynccontextmanager__(self) -> AsyncGenerator[Self]: yield self await self._execute() def create_request( self, name: bytes, *arguments: ValueT | Key, callback: 
Callable[..., T_co], execution_parameters: ExecutionParameters | None = None, ) -> CommandRequest[T_co]: """ :meta private: """ command = ClusterPipelineCommandRequest( name, *arguments, callback=callback, execution_parameters=execution_parameters or {}, type_adapter=self.type_adapter, ) command.position = len(self.command_stack) self.command_stack.append(command) return command
[docs] @asynccontextmanager async def watch(self, *keys: KeyT) -> AsyncGenerator[None]: """ The given keys will be watched for changes within this context and the commands stacked within the context will be automatically executed when the context exits. """ if self.command_stack: raise WatchError("Unable to add a watch after pipeline commands have been added") self._watched_node = self.connection_pool.cluster_layout.node_for_request( self.client.create_request( b"WATCH", *[Key(key) for key in keys], callback=NoopCallback() ) ) self.watches.extend(keys) async with self.connection_pool.acquire( node=self._watched_node ) as self._watched_connection: await self._watched_connection.create_request(CommandName.WATCH, *keys) self.explicit_transaction = True yield await self._execute() await self._watched_connection.create_request(CommandName.UNWATCH, decode=False)
@property def results(self) -> tuple[Any, ...] | None: """ The results of the pipeline execution which can be accessed after the pipeline has completed. """ if self.command_stack: raise RuntimeError("Pipeline results are not available before it completes") return self._results def execute_command( self, command: CommandRequest[R], ) -> Awaitable[R]: raise NotImplementedError async def _clear(self) -> None: """ Clear the pipeline and reset state. """ self.command_stack = [] self.scripts.clear() self.watches.clear() self.explicit_transaction = False def _raise_first_error(self) -> None: for c in self.command_stack: r = c.result if isinstance(r, (RedisError, TimeoutError)): self._annotate_exception(r, c.position + 1, c.name, c.serialized_arguments) raise r def _annotate_exception( self, exception: RedisError | TimeoutError | None, number: int, command: bytes, args: Iterable[RedisValueT], ) -> None: if exception: cmd = command.decode("latin-1") args = " ".join(str(x) for x in args) msg = f"Command # {number} ({cmd} {args}) of pipeline caused error: {exception.args[0]}" exception.args = (msg,) + exception.args[1:] async def _execute(self) -> None: """ Execute all queued commands in the cluster pipeline. Returns a tuple of results. 
""" if not self.command_stack: return with get_telemetry_provider().start_span( tuple(self.command_stack), self.client.connection_pool, name="MULTI" if self._transaction else "PIPELINE", ): if self.scripts: await self._load_scripts() use_primary = not ( self.client.connection_pool.read_from_replicas and all(cmd.readonly for cmd in self.command_stack) ) if self._transaction or self.explicit_transaction: execute = self._send_cluster_transaction else: execute = self._send_cluster_commands try: await execute(self._raise_on_error, use_primary) finally: await self._clear() def _get_slot_for_command(self, command: ClusterPipelineCommandRequest[Any]) -> int | None: affected_slots = command.affected_slots if len(affected_slots) == 1: return affected_slots[0] match command.routing_strategy: case ExplicitSlotStrategy(): return command.routing_strategy.slot case RandomStrategy(): return None if not affected_slots: raise RedisClusterError( f"No way to dispatch {nativestr(command.name)} to Redis Cluster. 
Missing key" ) else: raise ClusterCrossSlotError(command=command.name, keys=command.keys) @retryable(policy=ConstantRetryPolicy((ClusterDownError,), retries=3, delay=0.1)) async def _send_cluster_transaction( self, raise_on_error: bool = True, use_primary: bool = False ) -> None: """ :meta private: """ attempt = sorted(self.command_stack, key=lambda x: x.position) slots: set[int] = set() for c in attempt: if (slot := self._get_slot_for_command(c)) is not None: slots.add(slot) if len(slots) > 1: raise ClusterTransactionError("Multiple slots involved in transaction") if not slots: raise ClusterTransactionError("No slots found for transaction") node = self.connection_pool.cluster_layout.node_for_slot(slots.pop(), use_primary) if self._watched_node and node != self._watched_node: raise ClusterTransactionError("Multiple slots involved in transaction") node_commands = NodeCommands( self.client, node, in_transaction=True, timeout=self.timeout, connection=self._watched_connection, raise_on_error=self._raise_on_error, ) node_commands.extend(attempt) async with node_commands: node_commands.write() try: await node_commands.read() except ExecAbortError: await node_commands.connection.create_request(CommandName.DISCARD) # If at least one watched key is modified before EXEC, the transaction aborts and EXEC returns null. if node_commands.exec_cmd: exec_result = await node_commands.exec_cmd if exec_result is None: raise WatchError("Watched variable changed.") self._results = tuple( n.result for n in node_commands.commands if n.name not in {CommandName.MULTI, CommandName.EXEC} ) if raise_on_error: self._raise_first_error() @retryable(policy=ConstantRetryPolicy((ClusterDownError,), retries=3, delay=0.1)) async def _send_cluster_commands( self, raise_on_error: bool = True, use_primary: bool = False ) -> None: """ Execute all queued commands in the cluster pipeline, handling redirections and retries as needed. 
:meta private: """ attempt: dict[ClusterPipelineCommandRequest[Any], ClusterNodeLocation] = {} nodes: dict[str, NodeCommands] = {} for c in sorted(self.command_stack, key=lambda x: x.position): if (slot := self._get_slot_for_command(c)) is not None: node = self.connection_pool.cluster_layout.node_for_slot(slot, use_primary) else: node = node or self.connection_pool.cluster_layout.random_node(use_primary) if node.name not in nodes: nodes[node.name] = NodeCommands( self.client, node, timeout=self.timeout, raise_on_error=self._raise_on_error, ) nodes[node.name].append(c) attempt[c] = node # Write to all nodes, then read from all nodes in sequence. for n in nodes.values(): async with n: n.write() await n.read() # Retry MOVED/ASK/connection errors one by one if allowed. attempt = dict( sorted( (c for c in attempt.items() if isinstance(c[0].result, ERRORS_ALLOW_RETRY)), key=lambda x: x[0].position, ) ) if attempt: for c, node in attempt.items(): self.connection_pool.cluster_layout.report_errors(node, c.result) try: c.result = await self.client.execute_command(c.raw()) except (RedisError, TimeoutError) as e: c.result = e # Flatten results to match the original command order. response = [] for c in sorted(self.command_stack, key=lambda x: x.position): r = c.result if not isinstance(c.result, (RedisError, TimeoutError)): r = c.callback(c.result) c._response = PipelineResult(r) response.append(r) self._results = tuple(response) if raise_on_error: self._raise_first_error() async def _load_scripts(self) -> None: shas = [s.sha for s in self.scripts] exists = await self.client.script_exists(shas) if not all(exists): for s, exist in zip(self.scripts, exists): if not exist: s.sha = await self.client.script_load(s.script)