from __future__ import annotations
import asyncio
import functools
import inspect
import sys
import textwrap
from abc import ABCMeta
from concurrent.futures import CancelledError
from dataclasses import dataclass, field
from itertools import chain
from types import TracebackType
from typing import Any, cast
from wrapt import ObjectProxy # type: ignore
from coredis._utils import b, hash_slot
from coredis.client import Client, Redis, RedisCluster
from coredis.commands._key_spec import KeySpec
from coredis.commands.constants import CommandName, NodeFlag
from coredis.commands.script import Script
from coredis.connection import BaseConnection, ClusterConnection, CommandInvocation
from coredis.exceptions import (
AskError,
ClusterCrossSlotError,
ClusterDownError,
ClusterTransactionError,
ConnectionError,
ExecAbortError,
MovedError,
RedisClusterException,
RedisError,
ResponseError,
TimeoutError,
TryAgainError,
WatchError,
)
from coredis.pool import ClusterConnectionPool, ConnectionPool
from coredis.pool.nodemanager import ManagedNode
from coredis.response._callbacks import (
AnyStrCallback,
AsyncPreProcessingCallback,
BoolCallback,
BoolsCallback,
NoopCallback,
SimpleStringCallback,
)
from coredis.retry import ConstantRetryPolicy, retryable
from coredis.typing import (
AnyStr,
Callable,
Coroutine,
Dict,
Generic,
Iterable,
KeyT,
List,
Optional,
Parameters,
ParamSpec,
ResponseType,
Set,
StringT,
Tuple,
Type,
TypeVar,
ValueT,
)
# Generic parameters used to type the pipeline command wrappers.
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")

# Per-command errors after which send_cluster_commands re-sends the command
# individually through the cluster client.
ERRORS_ALLOW_RETRY = (MovedError, AskError, TryAgainError)

# Commands after which the pipeline no longer considers itself WATCHing keys.
UNWATCH_COMMANDS = {CommandName.UNWATCH, CommandName.EXEC, CommandName.DISCARD}
def wrap_pipeline_method(
    kls: PipelineMeta, func: Callable[P, Coroutine[Any, Any, R]]
) -> Callable[P, Coroutine[Any, Any, R]]:
    """
    Build the pipeline flavour of the client command ``func`` for class ``kls``.

    The returned coroutine function just awaits ``func``. Only its metadata
    differs: the return annotation is ``kls`` and the docstring is prefixed
    with an explanation of the batched (pipeline) semantics.
    """

    @functools.wraps(func)
    async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        return await func(*args, **kwargs)

    # Work on a copy so the annotations of the wrapped client method are not
    # mutated (functools.wraps shares the same dict object).
    patched_annotations = dict(wrapper.__annotations__)
    patched_annotations["return"] = kls
    wrapper.__annotations__ = patched_annotations

    base_doc = textwrap.dedent(wrapper.__doc__ or "")
    wrapper.__doc__ = f"""
Pipeline variant of :meth:`coredis.Redis.{func.__name__}` that does not execute
immediately and instead pushes the command into a stack for batch send
and returns the instance of :class:`{kls.__name__}` itself.
To fetch the return values call :meth:`{kls.__name__}.execute` to process the pipeline
and retrieve responses for the commands executed in the pipeline.
{base_doc}
"""
    return wrapper
@dataclass
class PipelineCommand:
    """A command queued on a pipeline together with its execution bookkeeping."""

    # command name sent to the server
    command: bytes
    # positional arguments sent after the command name
    args: Tuple[ValueT, ...]
    # response callback applied to the raw reply (one shared default instance)
    callback: Callable[..., Any] = field(default=NoopCallback())  # type: ignore
    # keyword options forwarded to the callback; ``decode`` is also read from here
    options: Dict[str, Optional[ValueT]] = field(default_factory=dict)
    # future for the reply, assigned once the command has been written
    request: Optional[asyncio.Future[ResponseType]] = field(default=None)
@dataclass
class ClusterPipelineCommand(PipelineCommand):
    """Pipeline command carrying the extra state used by the cluster pipeline."""

    # index in the pipeline's command stack; used to restore queue order
    position: int = field(default=0)
    # reply (after callbacks) or the exception recorded for this command
    result: Optional[Any] = field(default=None)  # type: ignore
    # NOTE(review): not read anywhere in this module; presumably an ASK-redirect flag — confirm
    asking: bool = field(default=False)
class NodeCommands:
    """
    Commands destined for a single cluster node, sent over one connection.

    :meth:`write` sends every command in one batch and :meth:`read` collects the
    replies into each command's ``result``. With ``in_transaction=True`` the
    batch is expected to be ``MULTI``, the queued commands and then ``EXEC``;
    the ``EXEC`` reply is unpacked back onto the queued commands.
    """

    def __init__(
        self,
        client: RedisCluster[AnyStr],
        connection: ClusterConnection,
        in_transaction: bool = False,
        timeout: Optional[float] = None,
    ):
        self.client = client
        self.connection = connection
        self.commands: List[ClusterPipelineCommand] = []
        self.in_transaction = in_transaction
        self.timeout = timeout

    def extend(self, c: List[ClusterPipelineCommand]) -> None:
        self.commands.extend(c)

    def append(self, c: ClusterPipelineCommand) -> None:
        self.commands.append(c)

    async def write(self) -> None:
        """
        Sends all commands as a single batch of requests.

        Connection and timeout errors are not raised here: the exception is
        stored as the ``result`` of every command in the batch.
        """
        connection = self.connection
        commands = self.commands

        # We are going to clobber the commands with the write, so go ahead
        # and ensure that nothing is sitting there from a previous run.
        for c in commands:
            c.result = None

        # build up all commands into a single request to increase network perf
        # send all the commands and catch connection and timeout errors.
        try:
            requests = await connection.create_requests(
                [
                    CommandInvocation(
                        cmd.command,
                        cmd.args,
                        (
                            bool(cmd.options.get("decode"))
                            if cmd.options.get("decode")
                            else None
                        ),
                        None,
                    )
                    for cmd in commands
                ],
                timeout=self.timeout,
            )
            for i, cmd in enumerate(commands):
                cmd.request = requests[i]
        except (ConnectionError, TimeoutError) as e:
            for c in commands:
                c.result = e

    async def read(self) -> None:
        """
        Collects the replies for every written command into ``result``.

        :raises ExecAbortError: if the server aborted the transaction.
        :raises WatchError: if ``EXEC`` returned no reply in a transaction.
        :raises BaseException: in a transaction, the error recorded for the
         leading ``MULTI`` command when the batch did not succeed.
        """
        connection = self.connection
        success = True

        for c in self.commands:
            if c.result is None:
                try:
                    c.result = await c.request if c.request else None
                except ExecAbortError:
                    raise
                except (ConnectionError, TimeoutError, RedisError) as e:
                    success = False
                    c.result = e
            elif isinstance(c.result, BaseException):
                # write() cleared every result before sending, so a non-None
                # result here is the error it recorded for a failed write.
                # Treating that as success would unpack the exception as if it
                # were the EXEC reply below.
                success = False

        if self.in_transaction:
            transaction_result: List[ResponseType] = []
            if success:
                for c in self.commands:
                    if c.command == CommandName.EXEC:
                        if c.result:
                            transaction_result = cast(List[ResponseType], c.result)
                        else:
                            raise WatchError("Watched variable changed.")
                queued = [
                    _c
                    for _c in sorted(self.commands, key=lambda x: x.position)
                    if _c.command not in {CommandName.MULTI, CommandName.EXEC}
                ]
                for idx, c in enumerate(queued):
                    item = transaction_result[idx]
                    if isinstance(item, Exception):
                        # Error replies inside EXEC are kept as-is (callbacks
                        # are not applied to them, as in PipelineImpl).
                        c.result = item
                        continue
                    if isinstance(c.callback, AsyncPreProcessingCallback):
                        await c.callback.pre_process(self.client, item, **c.options)
                    c.result = c.callback(
                        item,
                        version=connection.protocol_version,
                        **c.options,
                    )
            elif isinstance(self.commands[0].result, BaseException):
                raise self.commands[0].result
class PipelineMeta(ABCMeta):
    """
    Metaclass that replaces every coredis command method of a pipeline class
    with its pipeline variant (see ``wrap_pipeline_method``).
    """

    RESULT_CALLBACKS: Dict[str, Callable[..., Any]]
    NODES_FLAGS: Dict[str, NodeFlag]

    def __new__(cls, name: str, bases: Tuple[type, ...], namespace: Dict[str, object]):
        new_class = super().__new__(cls, name, bases, namespace)
        # only functions tagged with ``__coredis_command`` are redis commands
        command_methods = {
            attr: fn
            for attr, fn in PipelineMeta.get_methods(new_class).items()
            if getattr(fn, "__coredis_command", None)
        }
        for attr, fn in command_methods.items():
            setattr(new_class, attr, wrap_pipeline_method(new_class, fn))
        return new_class

    @staticmethod
    def get_methods(kls: PipelineMeta) -> Dict[str, Callable[..., Any]]:
        """Returns every plain function reachable on ``kls``, keyed by attribute name."""
        return {
            member_name: member
            for member_name, member in inspect.getmembers(kls)
            if inspect.isfunction(member)
        }
class ClusterPipelineMeta(PipelineMeta):
    """
    Pipeline metaclass that additionally records, per command, the cluster
    routing flag (``NODES_FLAGS``) and how multi-node replies are combined
    (``RESULT_CALLBACKS``).
    """

    def __new__(cls, name: str, bases: Tuple[type, ...], namespace: Dict[str, object]):
        new_class = super().__new__(cls, name, bases, namespace)
        for fn in ClusterPipelineMeta.get_methods(new_class).values():
            spec = getattr(fn, "__coredis_command", None)
            if not spec:
                continue
            if spec.cluster.route:
                new_class.NODES_FLAGS[spec.command] = spec.cluster.route
            if not spec.cluster.multi_node:
                # single-node command: keep the last reply of the per-node mapping
                new_class.RESULT_CALLBACKS[spec.command] = lambda response, **_: list(
                    response.values()
                ).pop()
            else:
                new_class.RESULT_CALLBACKS[spec.command] = spec.cluster.combine or (
                    lambda r, **_: r
                )
        return new_class
class PipelineImpl(Client[AnyStr], metaclass=PipelineMeta):
    """
    Pipeline for the Redis class.

    Pipelines provide a way to transmit multiple commands to the Redis server
    in one transmission. This is convenient for batch processing, such as
    saving all the values in a list to Redis.

    When the pipeline was created with ``transaction=True`` or :meth:`multi`
    has been called, the queued commands are wrapped with MULTI and EXEC
    calls, which guarantees they are executed atomically. Otherwise they are
    simply sent as one batch.

    A command failing with an error reply does *not* halt the execution of
    subsequent commands in the pipeline. Instead, the exception is caught and
    its instance is placed into the response tuple returned by execute()
    (with ``raise_on_error=False``; by default the first such error is raised
    once all replies have been read). Code iterating over the response should
    be able to deal with an instance of an exception as a potential value. In
    general, these will be ResponseError exceptions, such as those raised
    when issuing a command on a key of a different datatype.
    """

    command_stack: List[PipelineCommand]
    connection_pool: ConnectionPool

    def __init__(
        self,
        client: Client[AnyStr],
        transaction: Optional[bool],
        watches: Optional[Parameters[KeyT]] = None,
        timeout: Optional[float] = None,
    ) -> None:
        self.client = client
        self.connection_pool = client.connection_pool
        # acquired lazily on first use and handed back by reset_pipeline()
        self.connection = None
        self._transaction = transaction
        self.watching = False
        self.watches: Optional[Parameters[KeyT]] = watches or None
        self.command_stack = []
        self.cache = None  # not implemented.
        self.explicit_transaction = False
        self.scripts: Set[Script[AnyStr]] = set()
        self.timeout = timeout

    async def __aenter__(self) -> "PipelineImpl[AnyStr]":
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        await self.reset_pipeline()

    def __len__(self) -> int:
        return len(self.command_stack)

    def __bool__(self) -> bool:
        # keep an empty pipeline truthy (otherwise __len__ would decide)
        return True

    async def reset_pipeline(self) -> None:
        """
        Clears queued commands and WATCH state and returns the connection
        (if any) to the pool.
        """
        self.command_stack.clear()
        self.scripts: Set[Script[AnyStr]] = set()
        # make sure to reset the connection state in the event that we were
        # watching something
        if self.watching and self.connection:
            try:
                # call this manually since our unwatch or
                # immediate_execute_command methods can call reset_pipeline()
                request = await self.connection.create_request(
                    CommandName.UNWATCH, decode=False
                )
                await request
            except ConnectionError:
                # disconnect will also remove any previous WATCHes
                self.connection.disconnect()
        # clean up the other instance attributes
        self.watching = False
        self.watches = []
        self.explicit_transaction = False
        # we can safely return the connection to the pool here since we're
        # sure we're no longer WATCHing anything
        if self.connection:
            self.connection_pool.release(self.connection)
        self.connection = None

    def multi(self) -> None:
        """
        Starts a transactional block of the pipeline after WATCH commands
        are issued. End the transactional block with `execute`.
        """
        if self.explicit_transaction:
            raise RedisError("Cannot issue nested calls to MULTI")
        if self.command_stack:
            raise RedisError(
                "Commands without an initial WATCH have already been issued"
            )
        self.explicit_transaction = True

    async def execute_command(
        self,
        command: bytes,
        *args: ValueT,
        callback: Callable[..., Any] = NoopCallback(),  # type: ignore
        **options: Optional[ValueT],
    ) -> PipelineImpl[AnyStr]:  # type: ignore
        # While WATCHing (and before MULTI) commands run immediately so the
        # caller can read the watched values; otherwise they are queued.
        if (
            self.watching or command == CommandName.WATCH
        ) and not self.explicit_transaction:
            return await self.immediate_execute_command(
                command, *args, callback=callback, **options
            )  # type: ignore
        return self.pipeline_execute_command(
            command, *args, callback=callback, **options
        )

    async def immediate_execute_command(
        self,
        command: bytes,
        *args: ValueT,
        callback: Callable[..., Any] = NoopCallback(),  # type: ignore
        **kwargs: Optional[ValueT],
    ) -> Any:  # type: ignore
        """
        Executes a command immediately, but don't auto-retry on a
        ConnectionError if we're already WATCHing a variable. Used when
        issuing WATCH or subsequent commands retrieving their values but before
        MULTI is called.

        :raises WatchError: if the connection fails while keys are WATCHed
         (the WATCH died with the connection; retry the whole transaction).

        :meta private:
        """
        conn = self.connection
        # if this is the first call, we need a connection
        if not conn:
            conn = await self.connection_pool.get_connection()
            self.connection = conn
        try:
            request = await conn.create_request(
                command, *args, decode=kwargs.get("decode")
            )
            return callback(
                await request,
                version=conn.protocol_version,
                **kwargs,
            )
        except (ConnectionError, TimeoutError) as e:
            conn.disconnect()
            if self.watching:
                # The disconnect dropped the server-side WATCH, so the command
                # must not be retried (and must not silently return None, as
                # this path used to). Mark the WATCH as gone so that
                # reset_pipeline() skips the UNWATCH round trip.
                self.watching = False
                await self.reset_pipeline()
                raise WatchError(
                    "A ConnectionError occurred while watching one or more keys"
                ) from e
            # if we're not already watching, we can safely retry the command
            try:
                request = await conn.create_request(
                    command, *args, decode=kwargs.get("decode")
                )
                return callback(
                    await request, version=conn.protocol_version, **kwargs
                )
            except ConnectionError:
                # the retry failed so cleanup.
                conn.disconnect()
                await self.reset_pipeline()
                raise
        finally:
            if command in UNWATCH_COMMANDS:
                self.watching = False
            elif command == CommandName.WATCH and self.connection is not None:
                # A failed WATCH that reset the pipeline leaves no connection
                # holding the WATCH, so only flag watching when one remains.
                self.watching = True

    def pipeline_execute_command(
        self,
        command: bytes,
        *args: ValueT,
        callback: Callable[..., Any],
        **options: Optional[ValueT],
    ) -> PipelineImpl[AnyStr]:
        """
        Stages a command to be executed next execute() invocation
        Returns the current Pipeline object back so commands can be
        chained together, such as:
        pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')
        At some other point, you can then run: pipe.execute(),
        which will execute all commands queued in the pipe.

        :meta private:
        """
        self.command_stack.append(
            PipelineCommand(
                command=command, args=args, options=options, callback=callback
            )
        )
        return self

    async def _execute_transaction(
        self,
        connection: BaseConnection,
        commands: List[PipelineCommand],
        raise_on_error: bool,
    ) -> Tuple[Any, ...]:
        """Sends ``commands`` wrapped in MULTI/EXEC and returns the processed replies."""
        cmds = list(
            chain(
                [PipelineCommand(command=CommandName.MULTI, args=())],
                commands,
                [PipelineCommand(command=CommandName.EXEC, args=())],
            )
        )
        if self.watches:
            await self.watch(*self.watches)
        requests = await connection.create_requests(
            [
                CommandInvocation(
                    cmd.command,
                    cmd.args,
                    (
                        bool(cmd.options.get("decode"))
                        if cmd.options.get("decode")
                        else None
                    ),
                    None,
                )
                for cmd in cmds
            ],
            timeout=self.timeout,
        )
        for i, cmd in enumerate(cmds):
            cmd.request = requests[i]

        errors: List[Tuple[int, Optional[RedisError]]] = []
        multi_failed = False

        # parse off the response for MULTI
        # NOTE: we need to handle ResponseErrors here and continue
        # so that we read all the additional command messages from
        # the socket
        try:
            if cmds[0].request:
                await cmds[0].request
        except RedisError as e:
            multi_failed = True
            errors.append((0, e))

        # and all the other commands
        for i, cmd in enumerate(cmds[1:-1]):
            try:
                if cmd.request:
                    # Await outside the assert: ``python -O`` strips asserts,
                    # which previously skipped reading the reply entirely.
                    queued = await cmd.request
                    assert queued in {b"QUEUED", "QUEUED"}
            except RedisError as e:
                self.annotate_exception(e, i + 1, cmd.command, cmd.args)
                errors.append((i, e))

        response: List[ResponseType]
        try:
            response = cast(
                List[ResponseType],
                await cmds[-1].request if cmds[-1].request else None,
            )
        except (ExecAbortError, ResponseError):
            if self.explicit_transaction and not multi_failed:
                await self.immediate_execute_command(
                    CommandName.DISCARD, callback=BoolCallback()
                )
            if errors and errors[0][1]:
                raise errors[0][1]
            raise

        if response is None:
            raise WatchError("Watched variable changed.")

        # put any parse errors into the response
        for i, e in errors:
            response.insert(i, cast(ResponseType, e))

        if len(response) != len(commands):
            if self.connection:
                self.connection.disconnect()
            raise ResponseError(
                "Wrong number of response items from pipeline execution"
            )

        # find any errors in the response and raise if necessary
        if raise_on_error:
            self.raise_first_error(commands, response)

        # We have to run response callbacks manually
        data: List[Any] = []
        for r, cmd in zip(response, commands):
            if not isinstance(r, Exception):
                if isinstance(cmd.callback, AsyncPreProcessingCallback):
                    await cmd.callback.pre_process(self.client, r, **cmd.options)
                r = cmd.callback(r, version=connection.protocol_version, **cmd.options)
            data.append(r)
        return tuple(data)

    async def _execute_pipeline(
        self,
        connection: BaseConnection,
        commands: List[PipelineCommand],
        raise_on_error: bool,
    ) -> Tuple[Any, ...]:
        """Sends ``commands`` as one batch (no MULTI/EXEC) and returns the processed replies."""
        # build up all commands into a single request to increase network perf
        requests = await connection.create_requests(
            [
                CommandInvocation(
                    cmd.command,
                    cmd.args,
                    (
                        bool(cmd.options.get("decode"))
                        if cmd.options.get("decode")
                        else None
                    ),
                    None,
                )
                for cmd in commands
            ],
            timeout=self.timeout,
        )
        for i, cmd in enumerate(commands):
            cmd.request = requests[i]

        response: List[Any] = []
        for cmd in commands:
            try:
                res = await cmd.request if cmd.request else None
                if isinstance(cmd.callback, AsyncPreProcessingCallback):
                    await cmd.callback.pre_process(self.client, res, **cmd.options)
                response.append(
                    cmd.callback(
                        res,
                        version=connection.protocol_version,
                        **cmd.options,
                    )
                )
            except ResponseError as e:
                # keep going so every reply is drained from the connection
                response.append(e)

        if raise_on_error:
            self.raise_first_error(commands, response)

        return tuple(response)

    def raise_first_error(
        self, commands: List[PipelineCommand], response: ResponseType
    ) -> None:
        """Raises the first RedisError in ``response``, annotated with its command."""
        assert isinstance(response, list)
        for i, r in enumerate(response):
            if isinstance(r, RedisError):
                self.annotate_exception(r, i + 1, commands[i].command, commands[i].args)
                raise r

    def annotate_exception(
        self,
        exception: Optional[RedisError],
        number: int,
        command: bytes,
        args: Iterable[ValueT],
    ) -> None:
        """Prefixes the exception message with the failing command and its position."""
        if exception:
            cmd = command.decode("latin-1")
            args = " ".join(map(str, args))
            # guard against exceptions raised without any message
            detail = exception.args[0] if exception.args else ""
            msg = "Command # {} ({} {}) of pipeline caused error: {}".format(
                number,
                cmd,
                args,
                str(detail),
            )
            exception.args = (msg,) + exception.args[1:]

    async def load_scripts(self) -> None:
        # make sure all scripts that are about to be run on this pipeline exist
        scripts = list(self.scripts)
        immediate = self.immediate_execute_command
        shas = [s.sha for s in scripts]
        # we can't use the normal script_* methods because they would just
        # get buffered in the pipeline.
        exists = await immediate(
            CommandName.SCRIPT_EXISTS, *shas, callback=BoolsCallback()
        )
        if not all(exists):
            for s, exist in zip(scripts, exists):
                if not exist:
                    s.sha = await immediate(
                        CommandName.SCRIPT_LOAD,
                        s.script,
                        callback=AnyStrCallback[AnyStr](),
                    )

    async def execute(self, raise_on_error: bool = True) -> Tuple[Any, ...]:
        """
        Executes all the commands in the current pipeline.

        :param raise_on_error: raise the first error reply (annotated with the
         offending command) instead of returning it in the result tuple.
        :return: one entry per queued command, in queue order.
        :raises WatchError: if the connection fails while keys are WATCHed.
        """
        stack = self.command_stack
        if not stack:
            return ()
        if self.scripts:
            await self.load_scripts()
        if self._transaction or self.explicit_transaction:
            executor = self._execute_transaction
        else:
            executor = self._execute_pipeline

        conn = self.connection
        if not conn:
            conn = await self.connection_pool.get_connection()
            # assign to self.connection so reset_pipeline() releases the connection
            # back to the pool after we're done
            self.connection = conn

        try:
            return await executor(conn, stack, raise_on_error)
        except (ConnectionError, TimeoutError, CancelledError):
            conn.disconnect()
            # if we were watching a variable, the watch is no longer valid
            # since this connection has died. raise a WatchError, which
            # indicates the user should retry the transaction. If this is more
            # than a temporary failure, the WATCH that the user next issues
            # will fail, propagating the real ConnectionError
            if self.watching:
                raise WatchError(
                    "A ConnectionError occured on while watching one or more keys"
                )
            # otherwise, it's safe to retry since the transaction isn't
            # predicated on any state
            return await executor(conn, stack, raise_on_error)
        finally:
            await self.reset_pipeline()

    async def watch(self, *keys: KeyT) -> bool:
        """
        Watches the values at keys ``keys``
        """
        if self.explicit_transaction:
            raise RedisError("Cannot issue a WATCH after a MULTI")
        return await self.immediate_execute_command(
            CommandName.WATCH, *keys, callback=SimpleStringCallback()
        )

    async def unwatch(self) -> bool:
        """Unwatches all previously specified keys"""
        return (
            await self.immediate_execute_command(
                CommandName.UNWATCH, callback=SimpleStringCallback()
            )
            if self.watching
            else True
        )
class ClusterPipelineImpl(Client[AnyStr], metaclass=ClusterPipelineMeta):
    """
    Pipeline for the RedisCluster class.

    Queued commands are routed to the node owning each command's hash slot.
    Without a transaction they are batched per node
    (:meth:`send_cluster_commands`). With ``transaction=True`` or after
    :meth:`multi`, all commands must map to a single slot and are wrapped in
    MULTI/EXEC on that node (:meth:`send_cluster_transaction`).
    """

    client: RedisCluster[AnyStr]
    connection_pool: ClusterConnectionPool
    command_stack: List[ClusterPipelineCommand]

    RESULT_CALLBACKS: Dict[str, Callable[..., Any]] = {}
    NODES_FLAGS: Dict[str, NodeFlag] = {}

    def __init__(
        self,
        client: RedisCluster[AnyStr],
        transaction: Optional[bool] = False,
        watches: Optional[Parameters[KeyT]] = None,
        timeout: Optional[float] = None,
    ) -> None:
        self.command_stack = []
        self.refresh_table_asap = False
        self.client = client
        self.connection_pool = client.connection_pool
        self.result_callbacks = client.result_callbacks
        self._transaction = transaction
        # node and connection pinned by WATCH until unwatch()/reset_pipeline()
        self._watched_node: Optional[ManagedNode] = None
        self._watched_connection: Optional[ClusterConnection] = None
        self.watches: Optional[Parameters[KeyT]] = watches or None
        self.watching = False
        self.explicit_transaction = False
        self.cache = None  # not implemented.
        self.timeout = timeout

    async def watch(self, *keys: KeyT) -> bool:
        """Watches the values at keys ``keys`` (all must live on one node)."""
        if self.explicit_transaction:
            raise RedisError("Cannot issue a WATCH after a MULTI")
        return await self.immediate_execute_command(
            CommandName.WATCH, *keys, callback=SimpleStringCallback()
        )

    async def unwatch(self) -> bool:
        """Unwatches all previously specified keys and releases the watched connection."""
        if self._watched_connection:
            try:
                return await self._unwatch(self._watched_connection)
            finally:
                if self._watched_connection:
                    self.connection_pool.release(self._watched_connection)
                    self.watching = False
                    self._watched_node = None
                    self._watched_connection = None
        return True

    def __repr__(self) -> str:
        return f"{type(self).__name__}"

    def __del__(self) -> None:
        if self._watched_connection:
            self.connection_pool.release(self._watched_connection)

    def __len__(self) -> int:
        return len(self.command_stack)

    def __bool__(self) -> bool:
        # keep an empty pipeline truthy (otherwise __len__ would decide)
        return True

    async def __aenter__(self) -> "ClusterPipelineImpl[AnyStr]":
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        await self.reset_pipeline()

    async def execute_command(
        self,
        command: bytes,
        *args: ValueT,
        callback: Callable[..., Any] = NoopCallback(),  # type: ignore
        **options: Optional[ValueT],
    ) -> ClusterPipelineImpl[AnyStr]:  # type: ignore
        # While WATCHing (and before MULTI) commands run immediately so the
        # caller can read the watched values; otherwise they are queued.
        if (
            self.watching or command == CommandName.WATCH
        ) and not self.explicit_transaction:
            return await self.immediate_execute_command(
                command, *args, callback=callback, **options
            )  # type: ignore
        return self.pipeline_execute_command(
            command, *args, callback=callback, **options
        )

    def pipeline_execute_command(
        self,
        command: bytes,
        *args: ValueT,
        callback: Callable[..., Any],
        **options: Optional[ValueT],
    ) -> ClusterPipelineImpl[AnyStr]:
        """Queues a command, remembering its position in the stack."""
        self.command_stack.append(
            ClusterPipelineCommand(
                command=command,
                args=args,
                options=options,
                callback=callback,
                position=len(self.command_stack),
            )
        )
        return self

    def raise_first_error(self) -> None:
        """Raises the first RedisError result in the stack, annotated with its command."""
        for c in self.command_stack:
            r = c.result
            if isinstance(r, RedisError):
                self.annotate_exception(r, c.position + 1, c.command, c.args)
                raise r

    def annotate_exception(
        self,
        exception: Optional[RedisError],
        number: int,
        command: bytes,
        args: Iterable[ValueT],
    ) -> None:
        """Prefixes the exception message with the failing command and its position."""
        if exception:
            cmd = command.decode("latin-1")
            args = " ".join(str(x) for x in args)
            # guard against exceptions raised without any message
            detail = exception.args[0] if exception.args else ""
            msg = "Command # {} ({} {}) of pipeline caused error: {}".format(
                number, cmd, args, detail
            )
            exception.args = (msg,) + exception.args[1:]

    async def execute(self, raise_on_error: bool = True) -> Tuple[object, ...]:
        """Executes all queued commands and always resets the pipeline afterwards."""
        await self.connection_pool.initialize()
        if not self.command_stack:
            return ()
        if self._transaction or self.explicit_transaction:
            sender = self.send_cluster_transaction
        else:
            sender = self.send_cluster_commands
        try:
            return await sender(raise_on_error)
        finally:
            await self.reset_pipeline()

    async def reset_pipeline(self) -> None:
        """Empties pipeline"""
        self.command_stack = []
        self.scripts: Set[Script[AnyStr]] = set()
        # Like PipelineImpl.reset_pipeline: never hand a connection that still
        # holds a WATCH back to the pool.
        if self.watching and self._watched_connection:
            try:
                await self._unwatch(self._watched_connection)
            except ConnectionError:
                # disconnect will also remove any previous WATCHes
                self._watched_connection.disconnect()
        # clean up the other instance attributes
        self.watching = False
        self.explicit_transaction = False
        self._watched_node = None
        if self._watched_connection:
            self.connection_pool.release(self._watched_connection)
        self._watched_connection = None

    @retryable(policy=ConstantRetryPolicy((ClusterDownError,), 3, 0.1))
    async def send_cluster_transaction(
        self, raise_on_error: bool = True
    ) -> Tuple[object, ...]:
        """
        Runs the queued commands as one MULTI/EXEC block on a single node.

        :raises ClusterTransactionError: if the commands map to more than one
         slot, or to a node other than the one holding the WATCH.
        :raises WatchError: if EXEC did not produce a result.
        """
        attempt = sorted(self.command_stack, key=lambda x: x.position)
        slots: Set[int] = set()
        for c in attempt:
            # 0 is a valid hash slot, so every determined slot is recorded
            # (the old truthiness check dropped slot 0).
            slots.add(self._determine_slot(c.command, *c.args, **c.options))
        if len(slots) > 1:
            raise ClusterTransactionError("Multiple nodes involved in transaction")
        if not slots:
            raise ClusterTransactionError("No slots found for transaction")
        node = self.connection_pool.get_node_by_slot(slots.pop())

        if self._watched_node and node.name != self._watched_node.name:
            raise ClusterTransactionError("Multiple nodes involved in transaction")

        conn = (
            self._watched_connection
            or await self.connection_pool.get_connection_by_node(node)
        )

        if self.watches:
            await self._watch(node, conn, self.watches)
        node_commands = NodeCommands(
            self.client, conn, in_transaction=True, timeout=self.timeout
        )
        node_commands.append(ClusterPipelineCommand(CommandName.MULTI, ()))
        node_commands.extend(attempt)
        node_commands.append(ClusterPipelineCommand(CommandName.EXEC, ()))
        self.explicit_transaction = True
        await node_commands.write()
        try:
            await node_commands.read()
        except ExecAbortError:
            # NOTE(review): after EXECABORT the server has presumably already
            # discarded the transaction, so this DISCARD may itself fail with
            # an error reply — confirm.
            if self.explicit_transaction:
                request = await conn.create_request(CommandName.DISCARD)
                await request

        # If at least one watched key is modified before the EXEC command,
        # the whole transaction aborts,
        # and EXEC returns a Null reply to notify that the transaction failed.
        exec_failed = node_commands.commands[-1].result is None

        # UNWATCH *before* giving the connection up; previously it was
        # released first and then used again.
        if self.watching:
            await self._unwatch(conn)
            self.watching = False
        # The watched connection is released by reset_pipeline(); releasing it
        # here as well put the same connection back into the pool twice.
        if conn is not self._watched_connection:
            self.connection_pool.release(conn)

        if exec_failed:
            raise WatchError

        if raise_on_error:
            self.raise_first_error()

        return tuple(
            n.result
            for n in node_commands.commands
            if n.command not in {CommandName.MULTI, CommandName.EXEC}
        )

    @retryable(policy=ConstantRetryPolicy((ClusterDownError,), 3, 0.1))
    async def send_cluster_commands(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> Tuple[object, ...]:
        """
        Sends a bunch of cluster commands to the redis cluster.

        `allow_redirections` If the pipeline should follow `ASK`, `MOVED` (and
        `TRYAGAIN`) responses automatically by re-sending the affected commands
        one at a time. If set to false those commands are not retried and their
        errors are left in the results (raised when ``raise_on_error`` is set).
        """
        # the first time sending the commands we send all of the commands that were queued up.
        # if we have to run through it again, we only retry the commands that failed.
        attempt = sorted(self.command_stack, key=lambda x: x.position)
        protocol_version: int = 3
        # build a list of node objects based on node names we need to
        nodes: Dict[str, NodeCommands] = {}

        # as we move through each command that still needs to be processed,
        # we figure out the slot number that command maps to, then from the slot determine the node.
        for c in attempt:
            # refer to our internal node -> slot table that tells us where a given
            # command should route to. Options are forwarded (as in
            # send_cluster_transaction) so an explicit ``keys`` option is honoured.
            slot = self._determine_slot(c.command, *c.args, **c.options)
            node = self.connection_pool.get_node_by_slot(slot)
            if node.name not in nodes:
                nodes[node.name] = NodeCommands(
                    self.client,
                    await self.connection_pool.get_connection_by_node(node),
                    timeout=self.timeout,
                )
            nodes[node.name].append(c)

        # send the commands in sequence.
        # we write to all the open sockets for each node first, before reading anything
        # this allows us to flush all the requests out across the network essentially in parallel
        # so that we can read them all in parallel as they come back.
        # we don't multiplex on the sockets as they come available, but that shouldn't make
        # too much difference.
        node_commands = nodes.values()
        for n in node_commands:
            await n.write()
        for n in node_commands:
            await n.read()

        # release all of the redis connections we allocated earlier back into the connection pool.
        # we used to do this step as part of a try/finally block, but it is really dangerous to
        # release connections back into the pool if for some reason the socket has data still left
        # in it from a previous operation. The write and read operations already have try/catch
        # around them for all known types of errors including connection and socket level errors.
        # So if we hit an exception, something really bad happened and putting any of
        # these connections back into the pool is a very bad idea.
        # the socket might have unread buffer still sitting in it, and then the
        # next time we read from it we pass the buffered result back from a previous
        # command and every single request after to that connection will always get
        # a mismatched result.
        for n in nodes.values():
            protocol_version = n.connection.protocol_version
            self.connection_pool.release(n.connection)

        # if the response isn't an exception it is a valid response from the node
        # we're all done with that command, YAY!
        # if we have more commands to attempt, we've run into problems.
        # collect all the commands we are allowed to retry
        # (those whose result is one of ERRORS_ALLOW_RETRY).
        attempt = sorted(
            (c for c in attempt if isinstance(c.result, ERRORS_ALLOW_RETRY)),
            key=lambda x: x.position,
        )
        if attempt and allow_redirections:
            # RETRY MAGIC HAPPENS HERE!
            # send these remaining commands one at a time using `execute_command`
            # in the main client. This keeps our retry logic in one place mostly,
            # and allows us to be more confident in correctness of behavior.
            # at this point any speed gains from pipelining have been lost
            # anyway, so we might as well make the best attempt to get the correct
            # behavior.
            #
            # The client command will handle retries for each individual command
            # sequentially as we pass each one into `execute_command`. Any exceptions
            # that bubble out should only appear once all retries have been exhausted.
            #
            # If a lot of commands have failed, we'll be setting the
            # flag to rebuild the slots table from scratch. So MOVED errors should
            # correct themselves fairly quickly.
            await self.connection_pool.nodes.increment_reinitialize_counter(
                len(attempt)
            )
            for c in attempt:
                try:
                    # send each command individually like we do in the main client.
                    c.result = await self.client.execute_command(
                        c.command, *c.args, **c.options
                    )
                except RedisError as e:
                    c.result = e

        # turn the response back into a simple flat array that corresponds
        # to the sequence of commands issued in the stack in pipeline.execute()
        response = []
        for c in sorted(self.command_stack, key=lambda x: x.position):
            r = c.result
            if not isinstance(c.result, RedisError):
                if isinstance(c.callback, AsyncPreProcessingCallback):
                    await c.callback.pre_process(self.client, c.result, **c.options)
                r = c.callback(c.result, version=protocol_version, **c.options)
            response.append(r)

        if raise_on_error:
            self.raise_first_error()

        return tuple(response)

    def _determine_slot(self, command: bytes, *args: ValueT, **options: ValueT) -> int:
        """Figure out what slot based on command and args"""
        keys: Tuple[ValueT, ...] = cast(
            Tuple[ValueT, ...], options.get("keys")
        ) or KeySpec.extract_keys(command, *args)

        if not keys:
            raise RedisClusterException(
                f"No way to dispatch {command} to Redis Cluster. Missing key"
            )
        slots = {hash_slot(b(key)) for key in keys}
        if len(slots) != 1:
            raise ClusterCrossSlotError(command=command, keys=keys)
        return slots.pop()

    def _fail_on_redirect(self, allow_redirections: bool) -> None:
        if not allow_redirections:
            raise RedisClusterException(
                "ASK & MOVED redirection not allowed in this pipeline"
            )

    def multi(self) -> None:
        """Starts a transactional block; end it with :meth:`execute`."""
        if self.explicit_transaction:
            raise RedisError("Cannot issue nested calls to MULTI")
        if self.command_stack:
            raise RedisError(
                "Commands without an initial WATCH have already been issued"
            )
        self.explicit_transaction = True

    async def immediate_execute_command(
        self,
        command: bytes,
        *args: ValueT,
        callback: Callable[..., Any] = NoopCallback(),
        **kwargs: Optional[ValueT],
    ) -> Any:
        """
        Executes a command immediately on the node owning its slot.

        ``WATCH`` pins both the node and the connection (kept until
        :meth:`unwatch` / :meth:`reset_pipeline`). Every other command borrows
        a connection that is released once the command completes.

        :raises ClusterTransactionError: for a WATCH on a different node than
         an earlier WATCH.
        :raises WatchError: if a connection fails while keys are WATCHed.
        """
        slot = self._determine_slot(command, *args)
        node = self.connection_pool.get_node_by_slot(slot)
        if command == CommandName.WATCH:
            if self._watched_node and node.name != self._watched_node.name:
                raise ClusterTransactionError(
                    "Cannot issue a watch on a different node in the same transaction"
                )
            self._watched_node = node
            self._watched_connection = conn = (
                self._watched_connection
                or await self.connection_pool.get_connection_by_node(node)
            )
        else:
            conn = await self.connection_pool.get_connection_by_node(node)

        try:
            request = await conn.create_request(
                command, *args, decode=kwargs.get("decode")
            )
            return callback(
                await request,
                version=conn.protocol_version,
                **kwargs,
            )
        except (ConnectionError, TimeoutError) as e:
            conn.disconnect()
            if self.watching:
                # Not safe to retry, and must not silently return None as this
                # path used to. If the failed connection held the WATCH, the
                # disconnect already dropped it, so skip the UNWATCH.
                if conn is self._watched_connection:
                    self.watching = False
                await self.reset_pipeline()
                raise WatchError(
                    "A ConnectionError occurred while watching one or more keys"
                ) from e
            # if we're not already watching, we can safely retry the command
            try:
                request = await conn.create_request(
                    command, *args, decode=kwargs.get("decode")
                )
                return callback(
                    await request, version=conn.protocol_version, **kwargs
                )
            except ConnectionError:
                # the retry failed so cleanup.
                conn.disconnect()
                await self.reset_pipeline()
                raise
        finally:
            # No ``return`` in this block: it used to override the command's
            # result (watch() always returned None) and swallow exceptions.
            if command in UNWATCH_COMMANDS:
                self.watching = False
            elif command == CommandName.WATCH and self._watched_connection is not None:
                # a failed WATCH that reset the pipeline leaves nothing watched
                self.watching = True
            # the watched connection stays pinned until unwatch()/reset_pipeline()
            if command != CommandName.WATCH:
                self.connection_pool.release(conn)

    def load_scripts(self):
        raise RedisClusterException("method load_scripts() is not implemented")

    async def _watch(
        self, node: ManagedNode, conn: BaseConnection, keys: Parameters[KeyT]
    ) -> bool:
        "Watches the values at keys ``keys``"
        for key in keys:
            slot = self._determine_slot(CommandName.WATCH, key)
            dist_node = self.connection_pool.get_node_by_slot(slot)
            if node.name != dist_node.name:
                raise ClusterTransactionError(
                    "Keys in request don't hash to the same node"
                )

        if self.explicit_transaction:
            raise RedisError("Cannot issue a WATCH after a MULTI")
        request = await conn.create_request(CommandName.WATCH, *keys)
        return SimpleStringCallback()(
            cast(StringT, await request),
            version=conn.protocol_version,
        )

    async def _unwatch(self, conn: BaseConnection) -> bool:
        """Unwatches all previously specified keys"""
        request = await conn.create_request(CommandName.UNWATCH, decode=False)
        res = cast(StringT, await request)
        return res == b"OK" if self.watching else True
[docs]
class Pipeline(ObjectProxy, Generic[AnyStr]):  # type: ignore
    """
    Class returned by :meth:`coredis.Redis.pipeline`

    The class exposes the redis command methods available in
    :class:`~coredis.Redis`, however each of those methods returns
    the instance itself and the results of the batched commands
    can be retrieved by calling :meth:`execute`.
    """

    # The proxied implementation. Attribute access not handled by this class
    # falls through to it via wrapt's ObjectProxy.
    __wrapped__: PipelineImpl[AnyStr]

    async def __aenter__(self) -> Pipeline[AnyStr]:
        # Returns whatever the wrapped implementation's ``__aenter__``
        # returns, typed as the proxy.
        return cast(Pipeline[AnyStr], await self.__wrapped__.__aenter__())

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        # Exit handling is delegated entirely to the wrapped implementation.
        await self.__wrapped__.__aexit__(exc_type, exc_value, traceback)

    @classmethod
    def proxy(
        cls,
        client: Redis[AnyStr],
        transaction: Optional[bool] = None,
        watches: Optional[Parameters[KeyT]] = None,
        timeout: Optional[float] = None,
    ) -> Pipeline[AnyStr]:
        """
        Create a :class:`PipelineImpl` for ``client`` with the given
        ``transaction``, ``watches`` and ``timeout`` options and wrap it
        in this proxy.
        """
        return cls(
            PipelineImpl(
                client,
                transaction=transaction,
                watches=watches,
                timeout=timeout,
            )
        )

    # The methods below only forward to the wrapped implementation. They are
    # declared explicitly so they appear in the generated API documentation.

    def multi(self) -> None:
        """
        Starts a transactional block of the pipeline after WATCH commands
        are issued. End the transactional block with :meth:`execute`
        """
        self.__wrapped__.multi()

    async def watch(self, *keys: KeyT) -> bool:  # noqa
        """
        Watches the values at keys ``keys``
        """
        return await self.__wrapped__.watch(*keys)

    async def unwatch(self) -> bool:  # noqa
        """
        Unwatches all previously specified keys
        """
        return await self.__wrapped__.unwatch()

    async def execute(self, raise_on_error: bool = True) -> Tuple[object, ...]:
        """
        Executes all the commands in the current pipeline
        and returns the results of the individual batched commands
        """
        return await self.__wrapped__.execute(raise_on_error=raise_on_error)

    async def reset(self) -> None:
        """
        Resets the command stack and releases any connections acquired from the
        pool
        """
        await self.__wrapped__.reset_pipeline()
[docs]
class ClusterPipeline(ObjectProxy, Generic[AnyStr]):  # type: ignore
    """
    Class returned by :meth:`coredis.RedisCluster.pipeline`

    The class exposes the redis command methods available in
    :class:`~coredis.RedisCluster`, however each of those methods returns
    the instance itself and the results of the batched commands
    can be retrieved by calling :meth:`execute`.
    """

    # The proxied implementation. Attribute access not handled by this class
    # falls through to it via wrapt's ObjectProxy.
    __wrapped__: ClusterPipelineImpl[AnyStr]

    async def __aenter__(self) -> ClusterPipeline[AnyStr]:
        # Returns whatever the wrapped implementation's ``__aenter__``
        # returns, typed as the proxy.
        return cast(ClusterPipeline[AnyStr], await self.__wrapped__.__aenter__())

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        # Exit handling is delegated entirely to the wrapped implementation.
        await self.__wrapped__.__aexit__(exc_type, exc_value, traceback)

    @classmethod
    def proxy(
        cls,
        client: RedisCluster[AnyStr],
        transaction: Optional[bool] = False,
        watches: Optional[Parameters[KeyT]] = None,
        timeout: Optional[float] = None,
    ) -> ClusterPipeline[AnyStr]:
        """
        Create a :class:`ClusterPipelineImpl` for ``client`` with the given
        ``transaction``, ``watches`` and ``timeout`` options and wrap it
        in this proxy.
        """
        return cls(
            ClusterPipelineImpl(
                client,
                transaction=transaction,
                watches=watches,
                timeout=timeout,
            )
        )

    # The methods below only forward to the wrapped implementation. They are
    # declared explicitly so they appear in the generated API documentation.

    def multi(self) -> None:
        """
        Starts a transactional block of the pipeline after WATCH commands
        are issued. End the transactional block with :meth:`execute`
        """
        self.__wrapped__.multi()

    async def watch(self, *keys: KeyT) -> bool:  # noqa
        """
        Watches the values at keys ``keys``

        :raises: :exc:`~coredis.exceptions.ClusterTransactionError`
         if a watch is issued on a key that resides on a different
         cluster node than a previous watch.
        """
        return await self.__wrapped__.watch(*keys)

    async def unwatch(self) -> bool:  # noqa
        """
        Unwatches all previously specified keys
        """
        return await self.__wrapped__.unwatch()

    async def execute(self, raise_on_error: bool = True) -> Tuple[object, ...]:
        """
        Executes all the commands in the current pipeline
        and returns the results of the individual batched commands
        """
        return await self.__wrapped__.execute(raise_on_error=raise_on_error)

    async def reset(self) -> None:
        """
        Resets the command stack and releases any connections acquired from the
        pool
        """
        await self.__wrapped__.reset_pipeline()