Source code for coredis.client.cluster

from __future__ import annotations

import asyncio
import contextlib
import contextvars
import functools
import inspect
import textwrap
from abc import ABCMeta
from ssl import SSLContext
from typing import TYPE_CHECKING, Any, cast, overload

from deprecated.sphinx import versionadded

from coredis._utils import b, hash_slot
from coredis.cache import AbstractCache, SupportsClientTracking
from coredis.client.basic import Client, Redis
from coredis.commands._key_spec import KeySpec
from coredis.commands.constants import CommandName, NodeFlag
from coredis.commands.pubsub import ClusterPubSub, ShardedPubSub
from coredis.connection import RedisSSLContext
from coredis.exceptions import (
    AskError,
    BusyLoadingError,
    ClusterDownError,
    ClusterError,
    ConnectionError,
    MovedError,
    RedisClusterException,
    TimeoutError,
    TryAgainError,
    WatchError,
)
from coredis.globals import MODULE_GROUPS, READONLY_COMMANDS
from coredis.pool import ClusterConnectionPool
from coredis.pool.nodemanager import ManagedNode
from coredis.response._callbacks import AsyncPreProcessingCallback, NoopCallback
from coredis.retry import CompositeRetryPolicy, ConstantRetryPolicy, RetryPolicy
from coredis.typing import (
    AnyStr,
    AsyncIterator,
    Awaitable,
    Callable,
    ContextManager,
    Coroutine,
    Dict,
    Iterable,
    Iterator,
    List,
    Literal,
    Node,
    Optional,
    Parameters,
    ParamSpec,
    ResponseType,
    Set,
    StringT,
    Tuple,
    Type,
    TypeVar,
    ValueT,
)

P = ParamSpec("P")
R = TypeVar("R")

if TYPE_CHECKING:
    import coredis.pipeline


class ClusterMeta(ABCMeta):
    ROUTING_FLAGS: Dict[bytes, NodeFlag]
    SPLIT_FLAGS: Dict[bytes, NodeFlag]
    RESULT_CALLBACKS: Dict[bytes, Callable[..., ResponseType]]
    NODE_FLAG_DOC_MAPPING = {
        NodeFlag.PRIMARIES: "all primaries",
        NodeFlag.REPLICAS: "all replicas",
        NodeFlag.RANDOM: "a random node",
        NodeFlag.ALL: "all nodes",
        NodeFlag.SLOT_ID: "one or more nodes based on the slots provided",
    }

    def __new__(
        cls, name: str, bases: Tuple[type, ...], namespace: Dict[str, object]
    ) -> ClusterMeta:
        kls = super().__new__(cls, name, bases, namespace)
        methods = dict(k for k in inspect.getmembers(kls) if inspect.isfunction(k[1]))
        for module in MODULE_GROUPS:
            methods.update(
                dict(
                    (f"{module.MODULE}.{k[0]}", k[1])
                    for k in inspect.getmembers(module)
                    if inspect.isfunction(k[1])
                )
            )
        for method_name, method in methods.items():
            doc_addition = ""
            cmd = getattr(method, "__coredis_command", None)
            if cmd:
                if cmd.cluster.route:
                    kls.ROUTING_FLAGS[cmd.command] = cmd.cluster.route
                    aggregate_note = ""
                    if cmd.cluster.multi_node:
                        if cmd.cluster.combine:
                            aggregate_note = (
                                f"and return {cmd.cluster.combine.response_policy}"
                            )
                        else:
                            aggregate_note = (
                                "and a mapping of nodes to results will be returned"
                            )
                    doc_addition = f"""
.. admonition:: Cluster note

   The command will be run on **{cls.NODE_FLAG_DOC_MAPPING[cmd.cluster.route]}** {aggregate_note}
                    """
                elif cmd.cluster.split and cmd.cluster.combine:
                    kls.SPLIT_FLAGS[cmd.command] = cmd.cluster.split
                    doc_addition = f"""
.. admonition:: Cluster note

   The command will be run on **{cls.NODE_FLAG_DOC_MAPPING[cmd.cluster.split]}**
   by distributing the keys to the appropriate nodes and return
   {cmd.cluster.combine.response_policy}.

   To disable this behavior set :paramref:`RedisCluster.non_atomic_cross_slot` to ``False``
                """
                if cmd.cluster.multi_node:
                    kls.RESULT_CALLBACKS[cmd.command] = cmd.cluster.combine
            if doc_addition and not hasattr(method, "__cluster_docs"):
                if not getattr(method, "__coredis_module", None):

                    def __w(
                        func: Callable[P, Awaitable[R]]
                    ) -> Callable[P, Awaitable[R]]:
                        @functools.wraps(func)
                        async def _w(*a: P.args, **k: P.kwargs) -> R:
                            return await func(*a, **k)

                        _w.__doc__ = f"""{textwrap.dedent(method.__doc__ or "")}
{doc_addition}
                    """
                        return _w

                    wrapped = __w(method)
                    setattr(wrapped, "__cluster_docs", doc_addition)
                    setattr(kls, method_name, wrapped)
                else:
                    method.__doc__ = f"""{textwrap.dedent(method.__doc__ or "")}
{doc_addition}
                    """
                    setattr(method, "__cluster_docs", doc_addition)
        return kls


RedisClusterT = TypeVar("RedisClusterT", bound="RedisCluster[Any]")
RedisClusterStringT = TypeVar("RedisClusterStringT", bound="RedisCluster[str]")
RedisClusterBytesT = TypeVar("RedisClusterBytesT", bound="RedisCluster[bytes]")


class RedisCluster(
    Client[AnyStr],
    metaclass=ClusterMeta,
):
    MAX_RETRIES = 16
    ROUTING_FLAGS: Dict[bytes, NodeFlag] = {}
    SPLIT_FLAGS: Dict[bytes, NodeFlag] = {}
    RESULT_CALLBACKS: Dict[bytes, Callable[..., Any]] = {}
    connection_pool: ClusterConnectionPool

    @overload
    def __init__(
        self: RedisCluster[bytes],
        host: Optional[str] = ...,
        port: Optional[int] = ...,
        *,
        startup_nodes: Optional[Iterable[Node]] = ...,
        stream_timeout: Optional[float] = ...,
        connect_timeout: Optional[float] = ...,
        ssl: bool = ...,
        ssl_context: Optional[SSLContext] = ...,
        ssl_keyfile: Optional[str] = ...,
        ssl_certfile: Optional[str] = ...,
        ssl_cert_reqs: Optional[Literal["optional", "required", "none"]] = ...,
        ssl_check_hostname: Optional[bool] = ...,
        ssl_ca_certs: Optional[str] = ...,
        max_connections: int = ...,
        max_connections_per_node: bool = ...,
        readonly: bool = ...,
        read_from_replicas: bool = ...,
        reinitialize_steps: Optional[int] = ...,
        skip_full_coverage_check: bool = ...,
        nodemanager_follow_cluster: bool = ...,
        decode_responses: Literal[False] = ...,
        connection_pool: Optional[ClusterConnectionPool] = ...,
        connection_pool_cls: Type[ClusterConnectionPool] = ...,
        protocol_version: Literal[2, 3] = ...,
        verify_version: bool = ...,
        non_atomic_cross_slot: bool = ...,
        cache: Optional[AbstractCache] = ...,
        noreply: bool = ...,
        noevict: bool = ...,
        notouch: bool = ...,
        retry_policy: RetryPolicy = ...,
        **kwargs: Any,
    ) -> None: ...

    @overload
    def __init__(
        self: RedisCluster[str],
        host: Optional[str] = ...,
        port: Optional[int] = ...,
        *,
        startup_nodes: Optional[Iterable[Node]] = ...,
        stream_timeout: Optional[float] = ...,
        connect_timeout: Optional[float] = ...,
        ssl: bool = ...,
        ssl_context: Optional[SSLContext] = ...,
        ssl_keyfile: Optional[str] = ...,
        ssl_certfile: Optional[str] = ...,
        ssl_cert_reqs: Optional[Literal["optional", "required", "none"]] = ...,
        ssl_check_hostname: Optional[bool] = ...,
        ssl_ca_certs: Optional[str] = ...,
        max_connections: int = ...,
        max_connections_per_node: bool = ...,
        readonly: bool = ...,
        read_from_replicas: bool = ...,
        reinitialize_steps: Optional[int] = ...,
        skip_full_coverage_check: bool = ...,
        nodemanager_follow_cluster: bool = ...,
        decode_responses: Literal[True],
        connection_pool: Optional[ClusterConnectionPool] = ...,
        connection_pool_cls: Type[ClusterConnectionPool] = ...,
        protocol_version: Literal[2, 3] = ...,
        verify_version: bool = ...,
        non_atomic_cross_slot: bool = ...,
        cache: Optional[AbstractCache] = ...,
        noreply: bool = ...,
        noevict: bool = ...,
        notouch: bool = ...,
        retry_policy: RetryPolicy = ...,
        **kwargs: Any,
    ) -> None: ...
    def __init__(
        self,
        host: Optional[str] = None,
        port: Optional[int] = None,
        *,
        startup_nodes: Optional[Iterable[Node]] = None,
        stream_timeout: Optional[float] = None,
        connect_timeout: Optional[float] = None,
        ssl: bool = False,
        ssl_context: Optional[SSLContext] = None,
        ssl_keyfile: Optional[str] = None,
        ssl_certfile: Optional[str] = None,
        ssl_cert_reqs: Optional[Literal["optional", "required", "none"]] = None,
        ssl_check_hostname: Optional[bool] = None,
        ssl_ca_certs: Optional[str] = None,
        max_connections: int = 32,
        max_connections_per_node: bool = False,
        readonly: bool = False,
        read_from_replicas: bool = False,
        reinitialize_steps: Optional[int] = None,
        skip_full_coverage_check: bool = False,
        nodemanager_follow_cluster: bool = True,
        decode_responses: bool = False,
        connection_pool: Optional[ClusterConnectionPool] = None,
        connection_pool_cls: Type[ClusterConnectionPool] = ClusterConnectionPool,
        protocol_version: Literal[2, 3] = 3,
        verify_version: bool = True,
        non_atomic_cross_slot: bool = True,
        cache: Optional[AbstractCache] = None,
        noreply: bool = False,
        noevict: bool = False,
        notouch: bool = False,
        retry_policy: RetryPolicy = CompositeRetryPolicy(
            ConstantRetryPolicy((ClusterDownError,), 2, 0.1),
            ConstantRetryPolicy(
                (
                    ConnectionError,
                    TimeoutError,
                ),
                2,
                0.1,
            ),
        ),
        **kwargs: Any,
    ) -> None:
        """
        Changes

        - .. versionadded:: 4.12.0

          - :paramref:`retry_policy`
          - :paramref:`noevict`
          - :paramref:`notouch`
          - :meth:`RedisCluster.ensure_persistence` context manager
          - Redis Module support

            - RedisJSON: :attr:`RedisCluster.json`
            - RedisBloom:

              - BloomFilter: :attr:`RedisCluster.bf`
              - CuckooFilter: :attr:`RedisCluster.cf`
              - CountMinSketch: :attr:`RedisCluster.cms`
              - TopK: :attr:`RedisCluster.topk`
              - TDigest: :attr:`RedisCluster.tdigest`

            - RedisTimeSeries: :attr:`RedisCluster.timeseries`
            - RedisGraph: :attr:`RedisCluster.graph`
            - RediSearch:

              - Search & Aggregation: :attr:`RedisCluster.search`
              - Autocomplete: :attr:`RedisCluster.autocomplete`

        - .. versionchanged:: 4.4.0

          - :paramref:`nodemanager_follow_cluster` now defaults to ``True``

        - .. deprecated:: 4.4.0

          - The :paramref:`readonly` argument is deprecated in favour of
            :paramref:`read_from_replicas`

        - .. versionadded:: 4.3.0

          - Added :paramref:`connection_pool_cls`

        - .. versionchanged:: 4.0.0

          - :paramref:`non_atomic_cross_slot` defaults to ``True``
          - :paramref:`protocol_version` defaults to ``3``

        - .. versionadded:: 3.11.0

          - Added :paramref:`noreply`

        - .. versionadded:: 3.10.0

          - Synchronized ssl constructor parameters with :class:`coredis.Redis`

        - .. versionadded:: 3.9.0

          - If :paramref:`cache` is provided the client will check & populate
            the cache for read only commands and invalidate it for commands
            that could change the key(s) in the request.

        - .. versionadded:: 3.6.0

          - The :paramref:`non_atomic_cross_slot` parameter was added

        - .. versionchanged:: 3.5.0

          - The :paramref:`verify_version` parameter now defaults to ``True``

        - .. versionadded:: 3.1.0

          - The :paramref:`protocol_version` and :paramref:`verify_version`
            parameters were added

        :param host: Can be used to point to a startup node
        :param port: Can be used to point to a startup node
        :param startup_nodes: List of nodes that initial bootstrapping can be done
         from
        :param stream_timeout: Timeout (seconds) when reading responses from the
         server
        :param connect_timeout: Timeout (seconds) for establishing a connection
         to the server
        :param ssl: Whether to use an SSL connection
        :param ssl_context: If provided the :class:`ssl.SSLContext` will be used
         when establishing the connection. Otherwise either the default context
         (if no other ssl related parameters are provided) or a custom context
         based on the other ``ssl_*`` parameters will be used.
        :param ssl_keyfile: Path of the private key to use
        :param ssl_certfile: Path to the certificate corresponding to
         :paramref:`ssl_keyfile`
        :param ssl_cert_reqs: Whether to try to verify the server's certificates
         and how to behave if verification fails
         (See :attr:`ssl.SSLContext.verify_mode`).
        :param ssl_check_hostname: Whether to enable hostname checking when
         establishing an ssl connection.
        :param ssl_ca_certs: Path to a concatenated certificate authority file
         or a directory containing several CA certificates to use for validating
         the server's certificates when :paramref:`ssl_cert_reqs` is not
         ``"none"`` (See :meth:`ssl.SSLContext.load_verify_locations`).
        :param max_connections: Maximum number of connections that should be
         kept open at one time
        :param max_connections_per_node: If ``True`` the :paramref:`max_connections`
         limit is applied to each node of the cluster separately instead of to
         the pool as a whole.
        :param read_from_replicas: If ``True`` the client will route readonly
         commands to replicas
        :param reinitialize_steps: Number of moved errors that result in a
         cluster topology refresh using the startup nodes provided
        :param skip_full_coverage_check: Skips the check of the
         ``cluster-require-full-coverage`` config, useful for clusters without
         the ``CONFIG`` command (like AWS ElastiCache)
        :param nodemanager_follow_cluster: The node manager will during
         initialization try the last set of nodes that it was operating on.
         This will allow the client to drift alongside the cluster if the
         cluster nodes move around a lot.
        :param decode_responses: If ``True`` string responses from the server
         will be decoded using :paramref:`encoding` before being returned.
         (See :ref:`handbook/encoding:encoding/decoding`)
        :param connection_pool: The connection pool instance to use. If not
         provided a new pool will be assigned to this client.
        :param connection_pool_cls: The connection pool class to use when
         constructing a connection pool for this instance.
        :param protocol_version: Whether to use the RESP (``2``) or RESP3 (``3``)
         protocol for parsing responses from the server (Default ``3``).
         (See :ref:`handbook/response:redis response`)
        :param verify_version: Validate the redis server version against the
         documented version introduced before executing a command and raise a
         :exc:`CommandNotSupportedError` if the required version is higher than
         the reported server version
        :param non_atomic_cross_slot: If ``True`` certain commands that can
         operate on multiple keys (cross slot) will be split across the
         relevant nodes by mapping the keys to the appropriate slot and the
         result merged before being returned.
        :param cache: If provided the cache will be used to avoid requests for
         read only commands if the client has already requested the data and it
         hasn't been invalidated. The cache is responsible for any mutations to
         the keys that happen outside of this client
        :param noreply: If ``True`` the client will not request a response for
         any commands sent to the server.
        :param noevict: Ensures that connections from the client will be
         excluded from the client eviction process even if we're above the
         configured client eviction threshold.
        :param notouch: Ensures that commands sent by the client will not alter
         the LRU/LFU of the keys they access.
        :param retry_policy: The retry policy to use when interacting with the
         cluster
        """
        if "db" in kwargs:  # noqa
            raise RedisClusterException(
                "Argument 'db' is not possible to use in cluster mode"
            )
        if connection_pool:
            pool = connection_pool
        else:
            startup_nodes = [] if startup_nodes is None else list(startup_nodes)

            # Support host/port as argument
            if host:
                startup_nodes.append(
                    Node(
                        host=host,
                        port=port if port else 7000,
                    )
                )
            if ssl_context is not None:
                kwargs["ssl_context"] = ssl_context
            elif ssl:
                ssl_context = RedisSSLContext(
                    ssl_keyfile,
                    ssl_certfile,
                    ssl_cert_reqs,
                    ssl_ca_certs,
                    ssl_check_hostname,
                ).get()
                kwargs["ssl_context"] = ssl_context
            pool = connection_pool_cls(
                startup_nodes=startup_nodes,
                max_connections=max_connections,
                reinitialize_steps=reinitialize_steps,
                max_connections_per_node=max_connections_per_node,
                skip_full_coverage_check=skip_full_coverage_check,
                nodemanager_follow_cluster=nodemanager_follow_cluster,
                read_from_replicas=readonly or read_from_replicas,
                decode_responses=decode_responses,
                protocol_version=protocol_version,
                noreply=noreply,
                noevict=noevict,
                notouch=notouch,
                stream_timeout=stream_timeout,
                connect_timeout=connect_timeout,
                **kwargs,
            )

        super().__init__(
            stream_timeout=stream_timeout,
            connect_timeout=connect_timeout,
            connection_pool=pool,
            connection_pool_cls=connection_pool_cls,
            decode_responses=decode_responses,
            verify_version=verify_version,
            protocol_version=protocol_version,
            noreply=noreply,
            noevict=noevict,
            notouch=notouch,
            retry_policy=retry_policy,
            **kwargs,
        )

        self.refresh_table_asap: bool = False
        self.route_flags: Dict[bytes, NodeFlag] = self.__class__.ROUTING_FLAGS.copy()
        self.split_flags: Dict[bytes, NodeFlag] = self.__class__.SPLIT_FLAGS.copy()
        self.result_callbacks: Dict[bytes, Callable[..., Any]] = (
            self.__class__.RESULT_CALLBACKS.copy()
        )
        self.non_atomic_cross_slot = non_atomic_cross_slot
        self.cache = cache
        self._decodecontext: contextvars.ContextVar[Optional[bool]] = (
            contextvars.ContextVar("decode", default=None)
        )
        self._encodingcontext: contextvars.ContextVar[Optional[str]] = (
            contextvars.ContextVar("encoding", default=None)
        )

    @classmethod
    @overload
    def from_url(
        cls: Type[RedisClusterBytesT],
        url: str,
        *,
        db: Optional[int] = ...,
        skip_full_coverage_check: bool = ...,
        decode_responses: Literal[False] = ...,
        protocol_version: Literal[2, 3] = ...,
        verify_version: bool = ...,
        noreply: bool = ...,
        noevict: bool = ...,
        notouch: bool = ...,
        retry_policy: RetryPolicy = ...,
        **kwargs: Any,
    ) -> RedisClusterBytesT: ...

    @classmethod
    @overload
    def from_url(
        cls: Type[RedisClusterStringT],
        url: str,
        *,
        db: Optional[int] = ...,
        skip_full_coverage_check: bool = ...,
        decode_responses: Literal[True],
        protocol_version: Literal[2, 3] = ...,
        verify_version: bool = ...,
        noreply: bool = ...,
        noevict: bool = ...,
        notouch: bool = ...,
        retry_policy: RetryPolicy = ...,
        **kwargs: Any,
    ) -> RedisClusterStringT: ...
    @classmethod
    def from_url(
        cls: Type[RedisClusterT],
        url: str,
        *,
        db: Optional[int] = None,
        skip_full_coverage_check: bool = False,
        decode_responses: bool = False,
        protocol_version: Literal[2, 3] = 3,
        verify_version: bool = True,
        noreply: bool = False,
        noevict: bool = False,
        notouch: bool = False,
        retry_policy: RetryPolicy = CompositeRetryPolicy(
            ConstantRetryPolicy((ClusterDownError,), 2, 0.1),
            ConstantRetryPolicy(
                (
                    ConnectionError,
                    TimeoutError,
                ),
                2,
                0.1,
            ),
        ),
        **kwargs: Any,
    ) -> RedisClusterT:
        """
        Return a cluster client object configured from the startup node in
        :paramref:`url`, which must use either the ``redis://`` scheme
        `<http://www.iana.org/assignments/uri-schemes/prov/redis>`_ or, for an
        SSL connection, the ``rediss://`` scheme.

        For example:

        - ``redis://[:password]@localhost:6379``
        - ``rediss://[:password]@localhost:6379``

        :paramref:`url` and :paramref:`kwargs` are passed as is to
        :meth:`coredis.ClusterConnectionPool.from_url`.
        """
        if decode_responses:
            return cls(
                decode_responses=True,
                protocol_version=protocol_version,
                verify_version=verify_version,
                noreply=noreply,
                retry_policy=retry_policy,
                connection_pool=ClusterConnectionPool.from_url(
                    url,
                    db=db,
                    skip_full_coverage_check=skip_full_coverage_check,
                    decode_responses=decode_responses,
                    protocol_version=protocol_version,
                    noreply=noreply,
                    noevict=noevict,
                    notouch=notouch,
                    **kwargs,
                ),
            )
        else:
            return cls(
                decode_responses=False,
                protocol_version=protocol_version,
                verify_version=verify_version,
                noreply=noreply,
                retry_policy=retry_policy,
                connection_pool=ClusterConnectionPool.from_url(
                    url,
                    db=db,
                    skip_full_coverage_check=skip_full_coverage_check,
                    decode_responses=decode_responses,
                    protocol_version=protocol_version,
                    noreply=noreply,
                    noevict=noevict,
                    notouch=notouch,
                    **kwargs,
                ),
            )
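    # Usage sketch (not authoritative; host/port are illustrative and assume a
    # cluster node reachable at localhost:7000):
    #
    #   import coredis
    #
    #   async def example() -> None:
    #       client = coredis.RedisCluster("localhost", 7000, decode_responses=True)
    #       # or, equivalently, bootstrapped from a startup-node URL:
    #       client = coredis.RedisCluster.from_url(
    #           "redis://localhost:7000", decode_responses=True
    #       )
    #       await client.set("fubar", "baz")
    #       assert await client.get("fubar") == "baz"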
    async def initialize(self) -> RedisCluster[AnyStr]:
        if self.refresh_table_asap:
            self.connection_pool.initialized = False
        await super().initialize()
        if self.cache:
            self.cache = await self.cache.initialize(self)
        self.refresh_table_asap = False
        return self

    def __repr__(self) -> str:
        servers = list(
            {
                "{}:{}".format(info.host, info.port)
                for info in self.connection_pool.nodes.startup_nodes
            }
        )
        servers.sort()
        return "{}<{}>".format(type(self).__name__, ", ".join(servers))

    @property
    def all_nodes(self) -> Iterator[Redis[AnyStr]]:
        """All nodes of the cluster"""
        for node in self.connection_pool.nodes.all_nodes():
            yield cast(
                Redis[AnyStr],
                self.connection_pool.nodes.get_redis_link(node.host, node.port),
            )

    @property
    def primaries(self) -> Iterator[Redis[AnyStr]]:
        """All primary nodes of the cluster"""
        for primary in self.connection_pool.nodes.all_primaries():
            yield cast(
                Redis[AnyStr],
                self.connection_pool.nodes.get_redis_link(primary.host, primary.port),
            )

    @property
    def replicas(self) -> Iterator[Redis[AnyStr]]:
        """All replica nodes of the cluster"""
        for replica in self.connection_pool.nodes.all_replicas():
            yield cast(
                Redis[AnyStr],
                self.connection_pool.nodes.get_redis_link(replica.host, replica.port),
            )

    @property
    def num_replicas_per_shard(self) -> int:
        """
        Number of replicas per shard of the cluster determined by
        initial cluster topology discovery
        """
        return self.connection_pool.nodes.replicas_per_shard

    async def _ensure_initialized(self) -> None:
        if not self.connection_pool.initialized or self.refresh_table_asap:
            await self

    def _determine_slots(
        self, command: bytes, *args: ValueT, **options: Optional[ValueT]
    ) -> Set[int]:
        """Determines the slots the command and its arguments would touch"""
        keys = cast(Tuple[ValueT, ...], options.get("keys")) or KeySpec.extract_keys(
            command, *args, readonly_command=self.connection_pool.read_from_replicas
        )
        if (
            command
            in {
                CommandName.EVAL,
                CommandName.EVAL_RO,
                CommandName.EVALSHA,
                CommandName.EVALSHA_RO,
                CommandName.FCALL,
                CommandName.FCALL_RO,
                CommandName.PUBLISH,
            }
            and not keys
        ):
            return set()
        return {hash_slot(b(key)) for key in keys}

    def _merge_result(
        self,
        command: bytes,
        res: Dict[str, R],
        **kwargs: Optional[ValueT],
    ) -> R:
        assert command in self.result_callbacks
        return cast(
            R,
            self.result_callbacks[command](
                res, version=self.protocol_version, **kwargs
            ),
        )

    def determine_node(
        self, command: bytes, **kwargs: Optional[ValueT]
    ) -> Optional[List[ManagedNode]]:
        node_flag = self.route_flags.get(command)
        if command in self.split_flags and self.non_atomic_cross_slot:
            node_flag = self.split_flags[command]
        if node_flag == NodeFlag.RANDOM:
            return [self.connection_pool.nodes.random_node(primary=True)]
        elif node_flag == NodeFlag.PRIMARIES:
            return list(self.connection_pool.nodes.all_primaries())
        elif node_flag == NodeFlag.ALL:
            return list(self.connection_pool.nodes.all_nodes())
        elif node_flag == NodeFlag.SLOT_ID:
            slot_id: Optional[ValueT] = kwargs.get("slot_id")
            node_from_slot = (
                self.connection_pool.nodes.node_from_slot(int(slot_id))
                if slot_id is not None
                else None
            )
            if node_from_slot:
                return [node_from_slot]
        return None

    async def on_connection_error(self, _: BaseException) -> None:
        self.connection_pool.disconnect()
        self.connection_pool.reset()
        self.refresh_table_asap = True

    async def on_cluster_down_error(self, _: BaseException) -> None:
        self.connection_pool.disconnect()
        self.connection_pool.reset()
        self.refresh_table_asap = True
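    # Slot-mapping sketch: the routing decisions above derive from the CRC16
    # based hash slot of each key (``hash_slot`` is imported at the top of
    # this module). Per the Redis cluster specification, keys sharing a
    # ``{hash tag}`` always map to the same slot, which is how multi-key
    # commands can be pinned to one node:
    #
    #   from coredis._utils import hash_slot
    #
    #   assert 0 <= hash_slot(b"user:1") < 16384
    #   assert hash_slot(b"{user:1}:a") == hash_slot(b"{user:1}:b")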
    async def execute_command(
        self,
        command: bytes,
        *args: ValueT,
        callback: Callable[..., R] = NoopCallback(),
        **kwargs: Optional[ValueT],
    ) -> R:
        """
        Sends a command to one or many nodes in the cluster with retries
        based on :paramref:`RedisCluster.retry_policy`
        """
        return await self.retry_policy.call_with_retries(
            lambda: self._execute_command(command, *args, callback=callback, **kwargs),
            failure_hook={
                ConnectionError: self.on_connection_error,
                ClusterDownError: self.on_cluster_down_error,
            },
            before_hook=self._ensure_initialized,
        )
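    # Sketch: the default retry behavior can be replaced at construction time
    # using the policy classes imported at the top of this module. The
    # argument order mirrors the defaults used above (exceptions to retry,
    # number of retries, delay in seconds); host/port are illustrative:
    #
    #   client = coredis.RedisCluster(
    #       "localhost",
    #       7000,
    #       retry_policy=CompositeRetryPolicy(
    #           ConstantRetryPolicy((ClusterDownError,), 5, 0.2),
    #           ConstantRetryPolicy((ConnectionError, TimeoutError), 3, 0.1),
    #       ),
    #   )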
    async def _execute_command(
        self,
        command: bytes,
        *args: ValueT,
        callback: Callable[..., R] = NoopCallback(),
        **kwargs: Optional[ValueT],
    ) -> R:
        """
        Sends a command to one or many nodes in the cluster
        """
        nodes = self.determine_node(command, **kwargs)
        if nodes and len(nodes) > 1:
            tasks: Dict[str, Coroutine[Any, Any, R]] = {}
            node_arg_mapping = self._split_args_over_nodes(nodes, command, *args)
            node_name_map = {n.name: n for n in nodes}
            for node_name in node_arg_mapping:
                for portion, pargs in enumerate(node_arg_mapping[node_name]):
                    tasks[f"{node_name}:{portion}"] = (
                        self._execute_command_on_single_node(
                            command,
                            *pargs,
                            callback=callback,
                            node=node_name_map[node_name],
                            slots=None,
                            **kwargs,
                        )
                    )
            results = await asyncio.gather(*tasks.values(), return_exceptions=True)
            if self.noreply:
                return None  # type: ignore
            return cast(
                R,
                self._merge_result(command, dict(zip(tasks.keys(), results)), **kwargs),
            )
        else:
            node = None
            slots = None
            if not nodes:
                slots = list(self._determine_slots(command, *args, **kwargs))
            else:
                node = nodes.pop()
            return await self._execute_command_on_single_node(
                command, *args, callback=callback, node=node, slots=slots, **kwargs
            )

    def _split_args_over_nodes(
        self,
        nodes: List[ManagedNode],
        command: bytes,
        *args: ValueT,
    ) -> Dict[str, List[Tuple[ValueT, ...]]]:
        if command in self.split_flags and self.non_atomic_cross_slot:
            keys = KeySpec.extract_keys(command, *args)
            node_arg_mapping: Dict[str, List[Tuple[ValueT, ...]]] = {}
            if keys:
                key_start: int = args.index(keys[0])
                key_end: int = args.index(keys[-1])
                assert (
                    args[key_start : 1 + key_end] == keys
                ), f"Unable to map {command.decode('latin-1')} by keys {keys}"
                for (
                    node_name,
                    key_groups,
                ) in self.connection_pool.nodes.keys_to_nodes_by_slot(*keys).items():
                    for _, node_keys in key_groups.items():
                        node_arg_mapping.setdefault(node_name, []).append(
                            (
                                *args[:key_start],
                                *node_keys,  # type: ignore
                                *args[1 + key_end :],
                            )
                        )
                if self.cache and command not in READONLY_COMMANDS:
                    self.cache.invalidate(*keys)
            return node_arg_mapping
        else:
            # This command is not meant to be split across nodes and each node
            # should be called with the same arguments
            return {node.name: [args] for node in nodes}

    async def _execute_command_on_single_node(
        self,
        command: bytes,
        *args: ValueT,
        callback: Callable[..., R] = NoopCallback(),
        node: Optional[ManagedNode] = None,
        slots: Optional[List[int]] = None,
        **kwargs: Optional[ValueT],
    ) -> R:
        redirect_addr = None
        asking = False
        if not node and not slots:
            try_random_node = True
            try_random_type = NodeFlag.PRIMARIES
        else:
            try_random_node = False
            try_random_type = NodeFlag.ALL
        remaining_attempts = int(self.MAX_RETRIES)

        while remaining_attempts > 0:
            remaining_attempts -= 1
            if self.refresh_table_asap and not slots:
                await self
            if asking and redirect_addr:
                node = self.connection_pool.nodes.nodes[redirect_addr]
                r = await self.connection_pool.get_connection_by_node(node)
            elif try_random_node:
                r = await self.connection_pool.get_random_connection(
                    primary=try_random_type == NodeFlag.PRIMARIES
                )
                if slots:
                    try_random_node = False
            elif node:
                r = await self.connection_pool.get_connection_by_node(node)
            elif slots:
                if self.refresh_table_asap:  # MOVED
                    node = self.connection_pool.get_primary_node_by_slots(slots)
                else:
                    node = self.connection_pool.get_node_by_slots(slots)
                r = await self.connection_pool.get_connection_by_node(node)
            else:
                continue
            quick_release = self.should_quick_release(command)
            released = False
            try:
                if asking:
                    request = await r.create_request(
                        CommandName.ASKING, noreply=self.noreply, decode=False
                    )
                    await request
                    asking = False
                if (
                    isinstance(self.cache, AbstractCache)
                    and isinstance(self.cache, SupportsClientTracking)
                    and r.tracking_client_id != self.cache.get_client_id(r)
                ):
                    self.cache.reset()
                    await r.update_tracking_client(True, self.cache.get_client_id(r))
                if self.cache and command not in READONLY_COMMANDS:
                    self.cache.invalidate(*KeySpec.extract_keys(command, *args))
                request = await r.create_request(
                    command,
                    *args,
                    noreply=self.noreply,
                    decode=kwargs.get("decode", self._decodecontext.get()),
                    encoding=self._encodingcontext.get(),
                )
                if quick_release and not (self.requires_wait or self.requires_waitaof):
                    released = True
                    self.connection_pool.release(r)
                reply = await request
                response = None
                maybe_wait = [
                    await self._ensure_wait(command, r),
                    await self._ensure_persistence(command, r),
                ]
                if not self.noreply:
                    if isinstance(callback, AsyncPreProcessingCallback):
                        await callback.pre_process(
                            self, reply, version=self.protocol_version, **kwargs
                        )
                    response = callback(
                        reply,
                        version=self.protocol_version,
                        **kwargs,
                    )
                await asyncio.gather(*maybe_wait)
                return response  # type: ignore
            except (RedisClusterException, BusyLoadingError, asyncio.CancelledError):
                raise
            except MovedError as e:
                # Reinitialize on every x-th MovedError where x is the
                # 'reinitialize_steps' argument given to the constructor.
                # The counter increases faster when the same client object is
                # shared between multiple threads; to reduce the refresh
                # frequency, increase 'reinitialize_steps'.
                self.refresh_table_asap = True
                await self.connection_pool.nodes.increment_reinitialize_counter()
                node = self.connection_pool.nodes.set_node(
                    e.host, e.port, server_type="primary"
                )
                try_random_node = False
                self.connection_pool.nodes.slots[e.slot_id][0] = node
            except TryAgainError:
                if remaining_attempts < self.MAX_RETRIES / 2:
                    await asyncio.sleep(0.05)
            except AskError as e:
                redirect_addr, asking = f"{e.host}:{e.port}", True
            finally:
                self._ensure_server_version(r.server_version)
                if not released:
                    self.connection_pool.release(r)
        raise ClusterError("Maximum retries exhausted.")

    @overload
    def decoding(
        self, mode: Literal[False], encoding: Optional[str] = None
    ) -> ContextManager[RedisCluster[bytes]]: ...

    @overload
    def decoding(
        self, mode: Literal[True], encoding: Optional[str] = None
    ) -> ContextManager[RedisCluster[str]]: ...
    @contextlib.contextmanager
    @versionadded(version="4.8.0")
    def decoding(
        self, mode: bool, encoding: Optional[str] = None
    ) -> Iterator[RedisCluster[Any]]:
        """
        Context manager to temporarily change the decoding behavior
        of the client

        :param mode: Whether to decode or not
        :param encoding: Optional encoding to use if decoding. If not provided
         the :paramref:`~coredis.RedisCluster.encoding` parameter provided to
         the client will be used.

        Example::

            client = coredis.RedisCluster(decode_responses=True)
            await client.set("fubar", "baz")
            assert await client.get("fubar") == "baz"

            with client.decoding(False):
                assert await client.get("fubar") == b"baz"

                with client.decoding(True):
                    assert await client.get("fubar") == "baz"
        """
        prev_decode = self._decodecontext.get()
        prev_encoding = self._encodingcontext.get()
        self._decodecontext.set(mode)
        self._encodingcontext.set(encoding)
        try:
            yield self
        finally:
            self._decodecontext.set(prev_decode)
            self._encodingcontext.set(prev_encoding)
    def pubsub(
        self,
        ignore_subscribe_messages: bool = False,
        retry_policy: Optional[RetryPolicy] = None,
        **kwargs: Any,
    ) -> ClusterPubSub[AnyStr]:
        """
        Return a Pub/Sub instance that can be used to subscribe to channels
        or patterns in a redis cluster and receive messages that get published
        to them.

        :param ignore_subscribe_messages: Whether to skip subscription
         acknowledgement messages
        :param retry_policy: An explicit retry policy to use in the subscriber.
        """
        return ClusterPubSub[AnyStr](
            self.connection_pool,
            ignore_subscribe_messages=ignore_subscribe_messages,
            retry_policy=retry_policy,
            **kwargs,
        )
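    # Usage sketch, assuming the ``subscribe``/``get_message`` Pub/Sub API
    # shared with :class:`coredis.Redis` (channel name illustrative):
    #
    #   pubsub = client.pubsub()
    #   await pubsub.subscribe("notifications")
    #   message = await pubsub.get_message(timeout=1)
    #   if message and message["type"] == "message":
    #       print(message["data"])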
[docs] @versionadded(version="3.6.0") def sharded_pubsub( self, ignore_subscribe_messages: bool = False, read_from_replicas: bool = False, retry_policy: Optional[RetryPolicy] = None, **kwargs: Any, ) -> ShardedPubSub[AnyStr]: """ Return a Pub/Sub instance that can be used to subscribe to channels in a redis cluster and receive messages that get published to them. The implementation returned differs from that returned by :meth:`pubsub` as it uses the Sharded Pub/Sub implementation which routes messages to cluster nodes using the same algorithm used to assign keys to slots. This effectively restricts the propagation of messages to be within the shard of a cluster hence affording horizontally scaling the use of Pub/Sub with the cluster itself. :param ignore_subscribe_messages: Whether to skip subscription acknowledgement messages :param read_from_replicas: Whether to read messages from replica nodes :param retry_policy: An explicit retry policy to use in the subscriber. New in :redis-version:`7.0.0` """ return ShardedPubSub[AnyStr]( self.connection_pool, ignore_subscribe_messages=ignore_subscribe_messages, read_from_replicas=read_from_replicas, retry_policy=retry_policy, **kwargs, )
    async def pipeline(
        self,
        transaction: Optional[bool] = None,
        watches: Optional[Parameters[StringT]] = None,
        timeout: Optional[float] = None,
    ) -> "coredis.pipeline.ClusterPipeline[AnyStr]":
        """
        Returns a new pipeline object that can queue multiple commands for
        batch execution. Pipelines in cluster mode only provide a subset of the
        functionality of pipelines in standalone mode. Specifically:

        - Each command in the pipeline should only access keys on the same node
        - Transactions are disabled by default and are only supported if all
          watched keys route to the same node as the keys touched by the
          commands in the multi/exec part of the pipeline.

        :param transaction: indicates whether all commands should be executed
         atomically.
        :param watches: If :paramref:`transaction` is True these keys are
         watched for external changes during the transaction.
        :param timeout: If specified this value will take precedence over
         :paramref:`RedisCluster.stream_timeout`
        """
        await self.connection_pool.initialize()

        from coredis.pipeline import ClusterPipeline

        return ClusterPipeline[AnyStr].proxy(
            client=self,
            transaction=transaction,
            watches=watches,
            timeout=timeout,
        )
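    # Usage sketch: queue commands and execute them as one batch. The ``{tag}``
    # hash tag (illustrative) keeps all keys on one node, as cluster pipelines
    # require:
    #
    #   async with await client.pipeline() as pipe:
    #       await pipe.set("{tag}key1", "1")
    #       await pipe.get("{tag}key1")
    #       results = await pipe.execute()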
    async def transaction(
        self,
        func: Callable[
            ["coredis.pipeline.ClusterPipeline[AnyStr]"],
            Coroutine[Any, Any, Any],
        ],
        *watches: StringT,
        value_from_callable: bool = False,
        watch_delay: Optional[float] = None,
        **kwargs: Any,
    ) -> Any:
        """
        Convenience method for executing the callable :paramref:`func` as a
        transaction while watching all keys specified in :paramref:`watches`.

        :param func: a callable that expects a single argument which is a
         :class:`coredis.pipeline.ClusterPipeline` object retrieved by calling
         :meth:`~coredis.RedisCluster.pipeline`.
        :param watches: The keys to watch during the transaction. The keys
         should route to the same node as the keys touched by the commands in
         :paramref:`func`
        :param value_from_callable: Whether to return the result of the
         transaction or the value returned from :paramref:`func`

        .. warning:: Cluster transactions can only be run with commands that
           route to the same slot.

        .. versionchanged:: 4.9.0
           When the transaction is started with :paramref:`watches` the
           :class:`~coredis.pipeline.ClusterPipeline` instance passed to
           :paramref:`func` will not start queuing commands until a call to
           :meth:`~coredis.pipeline.ClusterPipeline.multi` is made. This makes
           the cluster implementation consistent with
           :meth:`coredis.Redis.transaction`
        """
        async with await self.pipeline(True) as pipe:
            while True:
                try:
                    if watches:
                        await pipe.watch(*watches)
                    func_value = await func(pipe)
                    exec_value = await pipe.execute()
                    return func_value if value_from_callable else exec_value
                except WatchError:
                    if watch_delay is not None and watch_delay > 0:
                        await asyncio.sleep(watch_delay)
                    continue
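    # Usage sketch: with watched keys, commands issued before ``multi()`` run
    # immediately while the keys are being watched; commands after it are
    # queued and executed atomically. Key names are illustrative and pinned to
    # one slot via a hash tag:
    #
    #   async def incr_balance(
    #       pipe: "coredis.pipeline.ClusterPipeline[str]",
    #   ) -> None:
    #       balance = await pipe.get("{acct}balance")  # runs while watching
    #       await pipe.multi()
    #       await pipe.set("{acct}balance", int(balance or 0) + 1)
    #
    #   await client.transaction(incr_balance, "{acct}balance")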
    async def scan_iter(
        self,
        match: Optional[StringT] = None,
        count: Optional[int] = None,
        type_: Optional[StringT] = None,
    ) -> AsyncIterator[AnyStr]:
        """
        Iterates over all keys in the cluster by issuing ``SCAN`` against each
        primary node in turn.
        """
        for node in self.primaries:
            cursor = None
            while cursor != 0:
                cursor, data = await node.scan(cursor or 0, match, count, type_)
                for item in data:
                    yield item
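    # Usage sketch: lazily iterate over matching keys across every primary
    # (pattern and count are illustrative):
    #
    #   async for key in client.scan_iter(match="user:*", count=500):
    #       print(key)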