Source code for coredis.client.basic

from __future__ import annotations

import asyncio
import contextlib
import contextvars
import functools
import warnings
from collections import defaultdict
from ssl import SSLContext
from typing import TYPE_CHECKING, Any, cast, overload

from deprecated.sphinx import versionadded
from packaging import version
from packaging.version import InvalidVersion, Version

from coredis._utils import EncodingInsensitiveDict, nativestr
from coredis.cache import AbstractCache, SupportsClientTracking
from coredis.commands._key_spec import KeySpec
from coredis.commands.constants import CommandFlag, CommandName
from coredis.commands.core import CoreCommands
from coredis.commands.function import Library
from coredis.commands.monitor import Monitor
from coredis.commands.pubsub import PubSub
from coredis.commands.script import Script
from coredis.commands.sentinel import SentinelCommands
from coredis.config import Config
from coredis.connection import (
    BaseConnection,
    RedisSSLContext,
    UnixDomainSocketConnection,
)
from coredis.exceptions import (
    ConnectionError,
    PersistenceError,
    RedisError,
    ReplicationError,
    TimeoutError,
    UnknownCommandError,
    WatchError,
)
from coredis.globals import COMMAND_FLAGS, READONLY_COMMANDS
from coredis.modules import ModuleMixin
from coredis.pool import ConnectionPool
from coredis.response._callbacks import (
    AsyncPreProcessingCallback,
    NoopCallback,
    ResponseCallback,
)
from coredis.response.types import ScoredMember
from coredis.retry import ConstantRetryPolicy, NoRetryPolicy, RetryPolicy
from coredis.typing import (
    AnyStr,
    AsyncGenerator,
    AsyncIterator,
    Callable,
    ContextManager,
    Coroutine,
    Dict,
    Generator,
    Generic,
    Iterator,
    KeyT,
    Literal,
    Optional,
    Parameters,
    ParamSpec,
    ResponseType,
    StringT,
    Tuple,
    Type,
    TypeVar,
    ValueT,
)

P = ParamSpec("P")
R = TypeVar("R")

if TYPE_CHECKING:
    import coredis.pipeline


ClientT = TypeVar("ClientT", bound="Client[Any]")
RedisT = TypeVar("RedisT", bound="Redis[Any]")
RedisStringT = TypeVar("RedisStringT", bound="Redis[str]")
RedisBytesT = TypeVar("RedisBytesT", bound="Redis[bytes]")


class Client(
    Generic[AnyStr],
    CoreCommands[AnyStr],
    ModuleMixin[AnyStr],
    SentinelCommands[AnyStr],
):
    cache: Optional[AbstractCache]
    connection_pool: ConnectionPool
    decode_responses: bool
    encoding: str
    protocol_version: Literal[2, 3]
    server_version: Optional[Version]
    callback_storage: Dict[Type[ResponseCallback[Any, Any, Any]], Dict[str, Any]]

    def __init__(
        self,
        host: Optional[str] = "localhost",
        port: Optional[int] = 6379,
        db: int = 0,
        username: Optional[str] = None,
        password: Optional[str] = None,
        stream_timeout: Optional[float] = None,
        connect_timeout: Optional[float] = None,
        connection_pool: Optional[ConnectionPool] = None,
        connection_pool_cls: Type[ConnectionPool] = ConnectionPool,
        unix_socket_path: Optional[str] = None,
        encoding: str = "utf-8",
        decode_responses: bool = False,
        ssl: bool = False,
        ssl_context: Optional[SSLContext] = None,
        ssl_keyfile: Optional[str] = None,
        ssl_certfile: Optional[str] = None,
        ssl_cert_reqs: Optional[Literal["optional", "required", "none"]] = None,
        ssl_check_hostname: Optional[bool] = None,
        ssl_ca_certs: Optional[str] = None,
        max_connections: Optional[int] = None,
        max_idle_time: float = 0,
        idle_check_interval: float = 1,
        client_name: Optional[str] = None,
        protocol_version: Literal[2, 3] = 3,
        verify_version: bool = True,
        noreply: bool = False,
        retry_policy: RetryPolicy = NoRetryPolicy(),
        noevict: bool = False,
        notouch: bool = False,
        **kwargs: Any,
    ):
        if not connection_pool:
            kwargs = {
                "db": db,
                "username": username,
                "password": password,
                "encoding": encoding,
                "stream_timeout": stream_timeout,
                "connect_timeout": connect_timeout,
                "max_connections": max_connections,
                "decode_responses": decode_responses,
                "max_idle_time": max_idle_time,
                "idle_check_interval": idle_check_interval,
                "client_name": client_name,
                "protocol_version": protocol_version,
                "noreply": noreply,
                "noevict": noevict,
                "notouch": notouch,
            }

            if unix_socket_path is not None:
                kwargs.update(
                    {
                        "path": unix_socket_path,
                        "connection_class": UnixDomainSocketConnection,
                    }
                )
            else:
                # TCP specific options
                kwargs.update({"host": host, "port": port})

                if ssl_context is not None:
                    kwargs["ssl_context"] = ssl_context
                elif ssl:
                    ssl_context = RedisSSLContext(
                        ssl_keyfile,
                        ssl_certfile,
                        ssl_cert_reqs,
                        ssl_ca_certs,
                        ssl_check_hostname,
                    ).get()
                    kwargs["ssl_context"] = ssl_context
            connection_pool = connection_pool_cls(**kwargs)

        self.connection_pool = connection_pool
        self.encoding = str(connection_pool.connection_kwargs.get("encoding", encoding))
        self.decode_responses = bool(
            connection_pool.connection_kwargs.get("decode_responses", decode_responses)
        )
        connection_protocol_version = (
            connection_pool.connection_kwargs.get("protocol_version")
            or protocol_version
        )
        assert connection_protocol_version in {
            2,
            3,
        }, "Protocol version can only be one of {2,3}"
        self.protocol_version = connection_protocol_version
        self.server_version: Optional[Version] = None
        self.verify_version = verify_version
        self.__noreply = noreply
        self._noreplycontext: contextvars.ContextVar[Optional[bool]] = (
            contextvars.ContextVar("noreply", default=None)
        )
        self._waitcontext: contextvars.ContextVar[Optional[Tuple[int, int]]] = (
            contextvars.ContextVar("wait", default=None)
        )
        self._waitaof_context: contextvars.ContextVar[
            Optional[Tuple[int, int, int]]
        ] = contextvars.ContextVar("waitaof", default=None)
        self.retry_policy = retry_policy
        self._module_info: Optional[Dict[str, version.Version]] = None
        self.callback_storage = defaultdict(lambda: {})

    @property
    def noreply(self) -> bool:
        if not hasattr(self, "_noreplycontext"):
            return False
        ctx = self._noreplycontext.get()
        if ctx is not None:
            return ctx
        return self.__noreply

    @property
    def requires_wait(self) -> bool:
        if not hasattr(self, "_waitcontext") or not self._waitcontext.get():
            return False
        return True

    @property
    def requires_waitaof(self) -> bool:
        if not hasattr(self, "_waitaof_context") or not self._waitaof_context.get():
            return False
        return True

    async def get_server_module_version(self, module: str) -> Optional[version.Version]:
        if self._module_info is None:
            await self._populate_module_versions()
        return (self._module_info or {}).get(module)

    def _ensure_server_version(self, version: Optional[str]) -> None:
        if not self.verify_version or Config.optimized:
            return
        if not version:
            return
        if not self.server_version and version:
            try:
                self.server_version = Version(nativestr(version))
            except InvalidVersion:
                warnings.warn(
                    (
                        f"Server reported an invalid version: {version}."
                        "If this is expected you can dismiss this warning by passing "
                        "verify_version=False to the client constructor"
                    ),
                    category=UserWarning,
                )
                self.verify_version = False
                self.server_version = None

    async def _ensure_wait(
        self, command: bytes, connection: BaseConnection
    ) -> asyncio.Future[None]:
        maybe_wait: asyncio.Future[None] = asyncio.get_running_loop().create_future()
        wait = self._waitcontext.get()
        if wait and wait[0] > 0:

            def check_wait(
                wait: Tuple[int, int], response: asyncio.Future[ResponseType]
            ) -> None:
                exc = response.exception()
                if exc:
                    maybe_wait.set_exception(exc)
                elif not cast(int, response.result()) >= wait[0]:
                    maybe_wait.set_exception(
                        ReplicationError(command, wait[0], wait[1])
                    )
                else:
                    maybe_wait.set_result(None)

            request = await connection.create_request(
                CommandName.WAIT, *wait, decode=False
            )
            request.add_done_callback(functools.partial(check_wait, wait))
        else:
            maybe_wait.set_result(None)
        return maybe_wait

    async def _ensure_persistence(
        self, command: bytes, connection: BaseConnection
    ) -> asyncio.Future[None]:
        maybe_wait: asyncio.Future[None] = asyncio.get_running_loop().create_future()
        waitaof = self._waitaof_context.get()
        if waitaof and waitaof[0] > 0:

            def check_wait(
                waitaof: Tuple[int, int, int], response: asyncio.Future[ResponseType]
            ) -> None:
                exc = response.exception()
                if exc:
                    maybe_wait.set_exception(exc)
                else:
                    res = cast(Tuple[int, int], response.result())
                    if not (res[0] >= waitaof[0] and res[1] >= waitaof[1]):
                        maybe_wait.set_exception(PersistenceError(command, *waitaof))
                    else:
                        maybe_wait.set_result(None)

            request = await connection.create_request(
                CommandName.WAITAOF, *waitaof, decode=False
            )
            request.add_done_callback(functools.partial(check_wait, waitaof))
        else:
            maybe_wait.set_result(None)
        return maybe_wait

    async def _populate_module_versions(self) -> None:
        if self.noreply:
            return
        try:
            modules = await self.module_list()
            self._module_info = defaultdict(lambda: version.Version("0"))
            for module in modules:
                mod = EncodingInsensitiveDict(module)
                name = nativestr(mod["name"])
                ver = mod["ver"]
                ver, patch = divmod(ver, 100)
                ver, minor = divmod(ver, 100)
                ver, major = divmod(ver, 100)
                self._module_info[name] = version.Version(f"{major}.{minor}.{patch}")
        except UnknownCommandError:
            self._module_info = {}

    async def initialize(self: ClientT) -> ClientT:
        await self.connection_pool.initialize()
        return self

    def __await__(self: ClientT) -> Generator[Any, None, ClientT]:
        return self.initialize().__await__()

    def __repr__(self) -> str:
        return f"{type(self).__name__}<{repr(self.connection_pool)}>"

    async def scan_iter(
        self,
        match: Optional[StringT] = None,
        count: Optional[int] = None,
        type_: Optional[StringT] = None,
    ) -> AsyncIterator[AnyStr]:
        """
        Make an iterator using the SCAN command so that the client doesn't
        need to remember the cursor position.
        """
        cursor = None

        while cursor != 0:
            cursor, data = await self.scan(
                cursor=cursor, match=match, count=count, type_=type_
            )

            for item in data:
                yield item

    async def sscan_iter(
        self,
        key: KeyT,
        match: Optional[StringT] = None,
        count: Optional[int] = None,
    ) -> AsyncIterator[AnyStr]:
        """
        Make an iterator using the SSCAN command so that the client doesn't
        need to remember the cursor position.
        """
        cursor = None

        while cursor != 0:
            cursor, data = await self.sscan(
                key, cursor=cursor, match=match, count=count
            )

            for item in data:
                yield item

    async def hscan_iter(
        self,
        key: KeyT,
        match: Optional[StringT] = None,
        count: Optional[int] = None,
    ) -> AsyncGenerator[Tuple[AnyStr, AnyStr], None]:
        """
        Make an iterator using the HSCAN command so that the client doesn't
        need to remember the cursor position.
        """
        cursor = None

        while cursor != 0:
            cursor, data = await self.hscan(
                key, cursor=cursor, match=match, count=count
            )

            for item in data.items():
                yield item

    async def zscan_iter(
        self,
        key: KeyT,
        match: Optional[StringT] = None,
        count: Optional[int] = None,
    ) -> AsyncIterator[ScoredMember]:
        """
        Make an iterator using the ZSCAN command so that the client doesn't
        need to remember the cursor position.
        """
        cursor = None

        while cursor != 0:
            cursor, data = await self.zscan(
                key,
                cursor=cursor,
                match=match,
                count=count,
            )

            for item in data:
                yield item

    def register_script(self, script: ValueT) -> Script[AnyStr]:
        """
        Registers a Lua :paramref:`script`

        :return: A :class:`coredis.commands.script.Script` instance that is
         callable and hides the complexity of dealing with scripts, keys, and
         shas.
        """
        return Script[AnyStr](self, script)  # type: ignore

    @versionadded(version="3.1.0")
    async def register_library(
        self, name: StringT, code: StringT, replace: bool = False
    ) -> Library[AnyStr]:
        """
        Register a new library

        :param name: name of the library
        :param code: raw code for the library
        :param replace: Whether to replace the library when intializing. If ``False``
         an exception will be raised if the library was already loaded in the target
         redis instance.
        """
        return await Library[AnyStr](self, name=name, code=code, replace=replace)

    @versionadded(version="3.1.0")
    async def get_library(self, name: StringT) -> Library[AnyStr]:
        """
        Fetch a pre registered library

        :param name: name of the library
        """
        return await Library[AnyStr](self, name)

    @contextlib.contextmanager
    def ignore_replies(self: ClientT) -> Iterator[ClientT]:
        """
        Context manager to run commands without waiting for a reply.

        Example::

            client = coredis.Redis()
            with client.ignore_replies():
                assert None == await client.set("fubar", 1), "noreply"
            assert True == await client.set("fubar", 1), "reply"
        """
        self._noreplycontext.set(True)
        try:
            yield self
        finally:
            self._noreplycontext.set(None)

    @contextlib.contextmanager
    def ensure_replication(
        self: ClientT, replicas: int = 1, timeout_ms: int = 100
    ) -> Iterator[ClientT]:
        """
        Context manager to ensure that commands executed within the context
        are replicated to atleast :paramref:`replicas` within
        :paramref:`timeout_ms` milliseconds.

        Internally this uses `WAIT <https://redis.io/commands/wait>`_ after
        each command executed within the context

        :raises: :exc:`coredis.exceptions.ReplicationError`

        Example::

            client = coredis.RedisCluster("localhost", 7000)
            with client.ensure_replication(1, 20):
                await client.set("fubar", 1)

        """
        self._waitcontext.set((replicas, timeout_ms))
        try:
            yield self
        finally:
            self._waitcontext.set(None)

    @versionadded(version="4.12.0")
    @contextlib.contextmanager
    def ensure_persistence(
        self: ClientT,
        local: Literal[0, 1] = 0,
        replicas: int = 0,
        timeout_ms: int = 100,
    ) -> Iterator[ClientT]:
        """
        Context manager to ensure that commands executed within the context
        are synced to the AOF of a :paramref:`local` host and/or :paramref:`replicas`
        within :paramref:`timeout_ms` milliseconds.

        Internally this uses `WAITAOF <https://redis.io/commands/waitaof>`_ after
        each command executed within the context

        :raises: :exc:`coredis.exceptions.PersistenceError`

        Example for standalone client::

            client = coredis.Redis()
            with client.ensure_persistence(1, 0, 20):
                await client.set("fubar", 1)

        Example for cluster::

            client = coredis.RedisCluster("localhost", 7000)
            with client.ensure_persistence(1, 1, 20):
                await client.set("fubar", 1)

        """
        self._waitaof_context.set((local, replicas, timeout_ms))
        try:
            yield self
        finally:
            self._waitaof_context.set(None)

    def should_quick_release(self, command: bytes) -> bool:
        return CommandFlag.BLOCKING not in COMMAND_FLAGS[command]


[docs] class Redis(Client[AnyStr]): connection_pool: ConnectionPool @overload def __init__( self: Redis[bytes], host: Optional[str] = ..., port: Optional[int] = ..., db: int = ..., *, username: Optional[str] = ..., password: Optional[str] = ..., stream_timeout: Optional[float] = ..., connect_timeout: Optional[float] = ..., connection_pool: Optional[ConnectionPool] = ..., connection_pool_cls: Type[ConnectionPool] = ..., unix_socket_path: Optional[str] = ..., encoding: str = ..., decode_responses: Literal[False] = ..., ssl: bool = ..., ssl_context: Optional[SSLContext] = ..., ssl_keyfile: Optional[str] = ..., ssl_certfile: Optional[str] = ..., ssl_cert_reqs: Optional[Literal["optional", "required", "none"]] = ..., ssl_check_hostname: Optional[bool] = ..., ssl_ca_certs: Optional[str] = ..., max_connections: Optional[int] = ..., max_idle_time: float = ..., idle_check_interval: float = ..., client_name: Optional[str] = ..., protocol_version: Literal[2, 3] = ..., verify_version: bool = ..., cache: Optional[AbstractCache] = ..., noreply: bool = ..., noevict: bool = ..., notouch: bool = ..., retry_policy: RetryPolicy = ..., **kwargs: Any, ) -> None: ... @overload def __init__( self: Redis[str], host: Optional[str] = ..., port: Optional[int] = ..., db: int = ..., *, username: Optional[str] = ..., password: Optional[str] = ..., stream_timeout: Optional[float] = ..., connect_timeout: Optional[float] = ..., connection_pool: Optional[ConnectionPool] = ..., connection_pool_cls: Type[ConnectionPool] = ..., unix_socket_path: Optional[str] = ..., encoding: str = ..., decode_responses: Literal[True], ssl: bool = ..., ssl_context: Optional[SSLContext] = ..., ssl_keyfile: Optional[str] = ..., ssl_certfile: Optional[str] = ..., ssl_cert_reqs: Optional[Literal["optional", "required", "none"]] = ..., ssl_check_hostname: Optional[bool] = ..., ssl_ca_certs: Optional[str] = ..., max_connections: Optional[int] = ..., max_idle_time: float = ..., idle_check_interval: float = ..., client_name: Optional[str] = ..., protocol_version: Literal[2, 3] = ..., verify_version: bool = ..., cache: Optional[AbstractCache] = ..., noreply: bool = ..., noevict: bool = ..., notouch: bool = ..., retry_policy: RetryPolicy = ..., **kwargs: Any, ) -> None: ... def __init__( self, host: Optional[str] = "localhost", port: Optional[int] = 6379, db: int = 0, *, username: Optional[str] = None, password: Optional[str] = None, stream_timeout: Optional[float] = None, connect_timeout: Optional[float] = None, connection_pool: Optional[ConnectionPool] = None, connection_pool_cls: Type[ConnectionPool] = ConnectionPool, unix_socket_path: Optional[str] = None, encoding: str = "utf-8", decode_responses: bool = False, ssl: bool = False, ssl_context: Optional[SSLContext] = None, ssl_keyfile: Optional[str] = None, ssl_certfile: Optional[str] = None, ssl_cert_reqs: Optional[Literal["optional", "required", "none"]] = None, ssl_check_hostname: Optional[bool] = None, ssl_ca_certs: Optional[str] = None, max_connections: Optional[int] = None, max_idle_time: float = 0, idle_check_interval: float = 1, client_name: Optional[str] = None, protocol_version: Literal[2, 3] = 3, verify_version: bool = True, cache: Optional[AbstractCache] = None, noreply: bool = False, noevict: bool = False, notouch: bool = False, retry_policy: RetryPolicy = ConstantRetryPolicy( (ConnectionError, TimeoutError), 2, 0.01 ), **kwargs: Any, ) -> None: """ Changes - .. versionadded:: 4.12.0 - :paramref:`retry_policy` - :paramref:`noevict` - :paramref:`notouch` - :meth:`Redis.ensure_persistence` context manager - Redis Module support - RedisJSON: :attr:`Redis.json` - RedisBloom: - BloomFilter: :attr:`Redis.bf` - CuckooFilter: :attr:`Redis.cf` - CountMinSketch: :attr:`Redis.cms` - TopK: :attr:`Redis.topk` - TDigest: :attr:`Redis.tdigest` - RedisTimeSeries: :attr:`Redis.timeseries` - RedisGraph: :attr:`Redis.graph` - RediSearch: - Search & Aggregation: :attr:`Redis.search` - Autocomplete: Added :attr:`Redis.autocomplete` - .. versionchanged:: 4.12.0 - Removed :paramref:`retry_on_timeout` constructor argument. Use :paramref:`retry_policy` instead. - .. versionadded:: 4.3.0 - Added :paramref:`connection_pool_cls` - .. versionchanged:: 4.0.0 - :paramref:`non_atomic_cross_slot` defaults to ``True`` - :paramref:`protocol_version`` defaults to ``3`` - .. versionadded:: 3.11.0 - Added :paramref:`noreply` - .. versionadded:: 3.9.0 - If :paramref:`cache` is provided the client will check & populate the cache for read only commands and invalidate it for commands that could change the key(s) in the request. - .. versionchanged:: 3.5.0 - The :paramref:`verify_version` parameter now defaults to ``True`` - .. versionadded:: 3.1.0 - The :paramref:`protocol_version` and :paramref:`verify_version` :parameters were added :param host: The hostname of the redis server :param port: The port at which th redis server is listening on :param db: database number to switch to upon connection :param username: Username for authenticating with the redis server :param password: Password for authenticating with the redis server :param stream_timeout: Timeout (seconds) when reading responses from the server :param connect_timeout: Timeout (seconds) for establishing a connection to the server :param connection_pool: The connection pool instance to use. If not provided a new pool will be assigned to this client. :param connection_pool_cls: The connection pool class to use when constructing a connection pool for this instance. :param unix_socket_path: Path to the UDS which the redis server is listening at :param encoding: The codec to use to encode strings transmitted to redis and decode responses with. (See :ref:`handbook/encoding:encoding/decoding`) :param decode_responses: If ``True`` string responses from the server will be decoded using :paramref:`encoding` before being returned. (See :ref:`handbook/encoding:encoding/decoding`) :param ssl: Whether to use an SSL connection :param ssl_context: If provided the :class:`ssl.SSLContext` will be used when establishing the connection. Otherwise either the default context (if no other ssl related parameters are provided) or a custom context based on the other ``ssl_*`` parameters will be used. :param ssl_keyfile: Path of the private key to use :param ssl_certfile: Path to the certificate corresponding to :paramref:`ssl_keyfile` :param ssl_cert_reqs: Whether to try to verify the server's certificates and how to behave if verification fails (See :attr:`ssl.SSLContext.verify_mode`). :param ssl_check_hostname: Whether to enable hostname checking when establishing an ssl connection. :param ssl_ca_certs: Path to a concatenated certificate authority file or a directory containing several CA certifcates to use for validating the server's certificates when :paramref:`ssl_cert_reqs` is not ``"none"`` (See :meth:`ssl.SSLContext.load_verify_locations`). :param max_connections: Maximum capacity of the connection pool (Ignored if :paramref:`connection_pool` is not ``None``. :param max_idle_time: Maximum number of a seconds an unused connection is cached before it is disconnected. :param idle_check_interval: Periodicity of idle checks (seconds) to release idle connections. :param client_name: The client name to identifiy with the redis server :param protocol_version: Whether to use the RESP (``2``) or RESP3 (``3``) protocol for parsing responses from the server (Default ``3``). (See :ref:`handbook/response:redis response`) :param verify_version: Validate redis server version against the documented version introduced before executing a command and raises a :exc:`CommandNotSupportedError` error if the required version is higher than the reported server version :param cache: If provided the cache will be used to avoid requests for read only commands if the client has already requested the data and it hasn't been invalidated. The cache is responsible for any mutations to the keys that happen outside of this client :param noreply: If ``True`` the client will not request a response for any commands sent to the server. :param noevict: Ensures that connections from the client will be excluded from the client eviction process even if we're above the configured client eviction threshold. :param notouch: Ensures that commands sent by the client will not alter the LRU/LFU of the keys they access. :param retry_policy: The retry policy to use when interacting with the redis server """ super().__init__( host=host, port=port, db=db, username=username, password=password, stream_timeout=stream_timeout, connect_timeout=connect_timeout, connection_pool=connection_pool, connection_pool_cls=connection_pool_cls, unix_socket_path=unix_socket_path, encoding=encoding, decode_responses=decode_responses, ssl=ssl, ssl_context=ssl_context, ssl_keyfile=ssl_keyfile, ssl_certfile=ssl_certfile, ssl_cert_reqs=ssl_cert_reqs, ssh_check_hostname=ssl_check_hostname, ssl_ca_certs=ssl_ca_certs, max_connections=max_connections, max_idle_time=max_idle_time, idle_check_interval=idle_check_interval, client_name=client_name, protocol_version=protocol_version, verify_version=verify_version, noreply=noreply, noevict=noevict, notouch=notouch, retry_policy=retry_policy, **kwargs, ) self.cache = cache self._decodecontext: contextvars.ContextVar[Optional[bool],] = ( contextvars.ContextVar("decode", default=None) ) self._encodingcontext: contextvars.ContextVar[Optional[str],] = ( contextvars.ContextVar("decode", default=None) ) @classmethod @overload def from_url( cls: Type[RedisBytesT], url: str, db: Optional[int] = ..., *, decode_responses: Literal[False] = ..., protocol_version: Literal[2, 3] = ..., verify_version: bool = ..., noreply: bool = ..., noevict: bool = ..., notouch: bool = ..., retry_policy: RetryPolicy = ..., **kwargs: Any, ) -> RedisBytesT: ... @classmethod @overload def from_url( cls: Type[RedisStringT], url: str, db: Optional[int] = ..., *, decode_responses: Literal[True], protocol_version: Literal[2, 3] = ..., verify_version: bool = ..., noreply: bool = ..., noevict: bool = ..., notouch: bool = ..., retry_policy: RetryPolicy = ..., **kwargs: Any, ) -> RedisStringT: ...
[docs] @classmethod def from_url( cls: Type[RedisT], url: str, db: Optional[int] = None, *, decode_responses: bool = False, protocol_version: Literal[2, 3] = 3, verify_version: bool = True, noreply: bool = False, noevict: bool = False, notouch: bool = False, retry_policy: RetryPolicy = ConstantRetryPolicy( (ConnectionError, TimeoutError), 2, 0.01 ), **kwargs: Any, ) -> RedisT: """ Return a Redis client object configured from the given URL, which must use either the `redis:// scheme <http://www.iana.org/assignments/uri-schemes/prov/redis>`_ for RESP connections or the ``unix://`` scheme for Unix domain sockets. For example: - ``redis://[:password]@localhost:6379/0`` - ``rediss://[:password]@localhost:6379/0`` - ``unix://[:password]@/path/to/socket.sock?db=0`` :paramref:`url` and :paramref:`kwargs` are passed as is to the :func:`coredis.ConnectionPool.from_url`. """ if decode_responses: return cls( decode_responses=True, protocol_version=protocol_version, verify_version=verify_version, noreply=noreply, retry_policy=retry_policy, connection_pool=ConnectionPool.from_url( url, db=db, decode_responses=decode_responses, protocol_version=protocol_version, noreply=noreply, noevict=noevict, notouch=notouch, **kwargs, ), ) else: return cls( decode_responses=False, protocol_version=protocol_version, verify_version=verify_version, noreply=noreply, retry_policy=retry_policy, connection_pool=ConnectionPool.from_url( url, db=db, decode_responses=decode_responses, protocol_version=protocol_version, noreply=noreply, noevict=noevict, notouch=notouch, **kwargs, ), )
async def initialize(self) -> Redis[AnyStr]: if not self.connection_pool.initialized: await super().initialize() if self.cache: self.cache = await self.cache.initialize(self) return self
[docs] async def execute_command( self, command: bytes, *args: ValueT, callback: Callable[..., R] = NoopCallback(), **options: Optional[ValueT], ) -> R: """ Executes a command with configured retries and returns the parsed response """ return await self.retry_policy.call_with_retries( lambda: self._execute_command(command, *args, callback=callback, **options), before_hook=self.initialize, )
async def _execute_command( self, command: bytes, *args: ValueT, callback: Callable[..., R] = NoopCallback(), **options: Optional[ValueT], ) -> R: pool = self.connection_pool quick_release = self.should_quick_release(command) connection = await pool.get_connection( command, *args, acquire=not quick_release or self.requires_wait or self.requires_waitaof, ) if ( self.cache and isinstance(self.cache, SupportsClientTracking) and connection.tracking_client_id != self.cache.get_client_id(connection) ): self.cache.reset() await connection.update_tracking_client( True, self.cache.get_client_id(connection) ) try: if self.cache and command not in READONLY_COMMANDS: self.cache.invalidate(*KeySpec.extract_keys(command, *args)) request = await connection.create_request( command, *args, noreply=self.noreply, decode=options.get("decode", self._decodecontext.get()), encoding=self._encodingcontext.get(), ) maybe_wait = [ await self._ensure_wait(command, connection), await self._ensure_persistence(command, connection), ] reply = await request await asyncio.gather(*maybe_wait) if self.noreply: return None # type: ignore if isinstance(callback, AsyncPreProcessingCallback): await callback.pre_process( self, reply, version=self.protocol_version, **options ) return callback( reply, version=self.protocol_version, **options, ) except RedisError: connection.disconnect() raise finally: self._ensure_server_version(connection.server_version) if not quick_release or self.requires_wait or self.requires_waitaof: pool.release(connection) @overload def decoding( self, mode: Literal[False], encoding: Optional[str] = None ) -> ContextManager[Redis[bytes]]: ... @overload def decoding( self, mode: Literal[True], encoding: Optional[str] = None ) -> ContextManager[Redis[str]]: ...
[docs] @contextlib.contextmanager @versionadded(version="4.8.0") def decoding( self, mode: bool, encoding: Optional[str] = None ) -> Iterator[Redis[Any]]: """ Context manager to temporarily change the decoding behavior of the client :param mode: Whether to decode or not :param encoding: Optional encoding to use if decoding. If not provided the :paramref:`~coredis.Redis.encoding` parameter provided to the client will be used. Example:: client = coredis.Redis(decode_responses=True) await client.set("fubar", "baz") assert await client.get("fubar") == "baz" with client.decoding(False): assert await client.get("fubar") == b"baz" with client.decoding(True): assert await client.get("fubar") == "baz" """ prev_decode = self._decodecontext.get() prev_encoding = self._encodingcontext.get() self._decodecontext.set(mode) self._encodingcontext.set(encoding) try: yield self finally: self._decodecontext.set(prev_decode) self._encodingcontext.set(prev_encoding)
[docs] def monitor(self) -> Monitor[AnyStr]: """ Return an instance of a :class:`~coredis.commands.monitor.Monitor` The monitor can be used as an async iterator or individual commands can be fetched via :meth:`~coredis.commands.monitor.Monitor.get_command`. """ return Monitor[AnyStr](self)
[docs] def pubsub( self, ignore_subscribe_messages: bool = False, retry_policy: Optional[RetryPolicy] = None, **kwargs: Any, ) -> PubSub[AnyStr]: """ Return a Pub/Sub instance that can be used to subscribe to channels and patterns and receive messages that get published to them. :param ignore_subscribe_messages: Whether to skip subscription acknowledgement messages :param retry_policy: An explicit retry policy to use in the subscriber. """ return PubSub[AnyStr]( self.connection_pool, ignore_subscribe_messages=ignore_subscribe_messages, retry_policy=retry_policy, **kwargs, )
[docs] async def pipeline( self, transaction: Optional[bool] = True, watches: Optional[Parameters[KeyT]] = None, timeout: Optional[float] = None, ) -> "coredis.pipeline.Pipeline[AnyStr]": """ Returns a new pipeline object that can queue multiple commands for batch execution. :param transaction: indicates whether all commands should be executed atomically. :param watches: If :paramref:`transaction` is True these keys are watched for external changes during the transaction. :param timeout: If specified this value will take precedence over :paramref:`Redis.stream_timeout` """ from coredis.pipeline import Pipeline return Pipeline[AnyStr].proxy(self, transaction, watches, timeout)
[docs] async def transaction( self, func: Callable[["coredis.pipeline.Pipeline[AnyStr]"], Coroutine[Any, Any, Any]], *watches: KeyT, value_from_callable: bool = False, watch_delay: Optional[float] = None, **kwargs: Any, ) -> Optional[Any]: """ Convenience method for executing the callable :paramref:`func` as a transaction while watching all keys specified in :paramref:`watches`. :param func: callable should expect a single argument which is a :class:`coredis.pipeline.Pipeline` object retrieved by calling :meth:`~coredis.Redis.pipeline`. :param watches: The keys to watch during the transaction :param value_from_callable: Whether to return the result of transaction or the value returned from :paramref:`func` """ async with await self.pipeline(True) as pipe: while True: try: if watches: await pipe.watch(*watches) func_value = await func(pipe) exec_value = await pipe.execute() return func_value if value_from_callable else exec_value except WatchError: if watch_delay is not None and watch_delay > 0: await asyncio.sleep(watch_delay) continue