Source code for coredis.pool._basic

from __future__ import annotations

from contextlib import asynccontextmanager
from typing import Any, cast, overload

from anyio import (
    TASK_STATUS_IGNORED,
    CapacityLimiter,
    fail_after,
    move_on_after,
)
from anyio.abc import TaskStatus

from coredis._concurrency import Queue
from coredis._telemetry import TelemetryProvider, get_telemetry_provider
from coredis.connection._base import (
    BaseConnectionParams,
    ConnectionT,
    Location,
)
from coredis.connection._tcp import TCPConnection, TCPLocation
from coredis.connection._uds import UnixDomainSocketConnection, UnixDomainSocketLocation
from coredis.exceptions import ConnectionError, RedisError
from coredis.patterns.cache import AbstractCache, NodeTrackingCache, TrackingCache
from coredis.typing import (
    AsyncGenerator,
    Callable,
    ClassVar,
    NotRequired,
    Self,
    Unpack,
)

from ._base import BaseConnectionPool, BaseConnectionPoolParams


class ConnectionPoolParams(BaseConnectionPoolParams[ConnectionT]):
    """
    Keyword parameters understood by :class:`coredis.pool.ConnectionPool`

    Adds an optional (private) cache instance on top of the parameters
    inherited from the base connection pool parameter set.
    """

    #: :meta private:
    _cache: NotRequired[AbstractCache | None]
class ConnectionPool(BaseConnectionPool[ConnectionT]):
    URL_QUERY_ARGUMENT_PARSERS: ClassVar[
        dict[str, Callable[..., int | float | bool | str | None]]
    ] = {
        **BaseConnectionPool.URL_QUERY_ARGUMENT_PARSERS,
    }

    @overload
    def __init__(
        self: ConnectionPool[ConnectionT],
        *,
        connection_class: type[ConnectionT],
        location: Location | None = ...,
        max_connections: int | None = ...,
        timeout: float | None = ...,
        _cache: AbstractCache | None = ...,
        **connection_kwargs: Unpack[BaseConnectionParams],
    ): ...

    @overload
    def __init__(
        self: ConnectionPool[TCPConnection],
        *,
        connection_class: None = ...,
        location: TCPLocation,
        max_connections: int | None = ...,
        timeout: float | None = ...,
        _cache: AbstractCache | None = ...,
        **connection_kwargs: Unpack[BaseConnectionParams],
    ): ...

    @overload
    def __init__(
        self: ConnectionPool[UnixDomainSocketConnection],
        *,
        connection_class: None = ...,
        location: UnixDomainSocketLocation,
        max_connections: int | None = ...,
        timeout: float | None = ...,
        _cache: AbstractCache | None = ...,
        **connection_kwargs: Unpack[BaseConnectionParams],
    ): ...

    @overload
    def __init__(
        self: ConnectionPool[TCPConnection],
        *,
        connection_class: None = ...,
        location: None = ...,
        # Retained for backward compatibility
        host: str = ...,
        port: int = ...,
        max_connections: int | None = ...,
        timeout: float | None = ...,
        _cache: AbstractCache | None = ...,
        **connection_kwargs: Unpack[BaseConnectionParams],
    ): ...

    def __init__(
        self,
        *,
        connection_class: type[ConnectionT] | None = None,
        location: Location | None = None,
        max_connections: int | None = None,
        timeout: float | None = None,
        _cache: AbstractCache | None = None,
        # host/port retained for backward compatibility
        host: str | None = None,
        port: int | None = None,
        **connection_kwargs: Unpack[BaseConnectionParams],
    ) -> None:
        """
        Blocking connection pool for single instance redis clients

        :param connection_class: The connection class to use when creating new
         connections. When omitted it is inferred from :paramref:`location`
         (:class:`TCPConnection` for a TCP location,
         :class:`UnixDomainSocketConnection` for a unix socket location) or,
         failing that, from :paramref:`host` & :paramref:`port`.
        :param location: Where new connections should connect to.
        :param max_connections: Maximum connections to grow the pool.
         Once the limit is reached clients will block to wait for a connection
         to be returned to the pool.
        :param timeout: Number of seconds to block when trying to obtain a connection.
        :param host: Legacy alternative to :paramref:`location`. Only consulted
         (together with :paramref:`port`) when no ``connection_class`` is given
         and ``location`` is neither a TCP nor a unix socket location; both must
         be provided.
        :param port: See :paramref:`host`.
        :param connection_kwargs: arguments to pass to the :paramref:`connection_class`
         constructor when creating a new connection
        :raises RuntimeError: if no connection class can be determined.
        """
        if connection_class is None:
            if isinstance(location, TCPLocation):
                connection_class = cast(type[ConnectionT], TCPConnection)
            elif isinstance(location, UnixDomainSocketLocation):
                connection_class = cast(type[ConnectionT], UnixDomainSocketConnection)
            elif host is not None and port is not None:
                connection_class = cast(type[ConnectionT], TCPConnection)
                location = TCPLocation(host, port)
        if not connection_class:
            raise RuntimeError("Unable to initialize pool without a `connection_class`")
        super().__init__(
            connection_class=connection_class,
            location=location,
            max_connections=max_connections,
            timeout=timeout,
            **connection_kwargs,
        )
        # TODO: Use the `max_failures` argument of tracking cache
        # Compare against ``None`` explicitly: a cache object that defines
        # ``__len__`` would be falsy while empty and caching would silently
        # be disabled by a plain truthiness check.
        self.cache: TrackingCache[Any] | None = (
            NodeTrackingCache(self, _cache) if _cache is not None else None
        )
        # The pool of available connections
        self._available_connections: Queue[ConnectionT] = Queue(self.max_connections)
        # All connections that are still active
        self._online_connections: set[ConnectionT] = set()
        # Used by the connection to limit concurrently entering
        # CPU hotspots to ensure fairness between connections in the pool.
        # The main observed scenario where this is useful is if the connection pool
        # is being used by multiple push message consumers that are constantly
        # receiving data in the read task.
        self._connection_processing_budget = CapacityLimiter(1)
        self.connection_kwargs["processing_budget"] = self._connection_processing_budget

    @overload
    @classmethod
    def from_url(
        cls: type[ConnectionPool[Any]],
        url: str,
        *,
        decode_components: bool = False,
        **kwargs: Unpack[ConnectionPoolParams[Any]],
    ) -> ConnectionPool[TCPConnection] | ConnectionPool[UnixDomainSocketConnection]: ...

    @overload
    @classmethod
    def from_url(
        cls: type[Self],
        url: str,
        *,
        decode_components: bool = False,
        **kwargs: Unpack[ConnectionPoolParams[Any]],
    ) -> Self: ...

    @classmethod
    def from_url(
        cls: type[Self],
        url: str,
        *,
        decode_components: bool = False,
        **kwargs: Unpack[ConnectionPoolParams[Any]],
    ) -> Self | ConnectionPool[TCPConnection] | ConnectionPool[UnixDomainSocketConnection]:
        """
        Returns a connection pool configured from the given URL.

        For example:

        - ``redis://[:password]@localhost:6379/0``
        - ``rediss://[:password]@localhost:6379/0``
        - ``unix://[:password]@/path/to/socket.sock?db=0``

        Three URL schemes are supported:

        - `redis:// <http://www.iana.org/assignments/uri-schemes/prov/redis>`__
          creates a normal TCP socket connection
        - `rediss:// <http://www.iana.org/assignments/uri-schemes/prov/rediss>`__
          creates a SSL wrapped TCP socket connection
        - ``unix://`` creates a Unix Domain Socket connection

        There are several ways to specify a database number. The parse function
        will return the first specified option:

        - A ``db`` querystring option, e.g. ``redis://localhost?db=0``
        - If using the ``redis://`` scheme, the path argument of the url, e.g.
          ``redis://localhost/0``
        - The ``db`` argument to this function.

        If none of these options are specified, ``db=0`` is used.

        The :paramref:`decode_components` argument allows this function to work
        with percent-encoded URLs. If this argument is set to ``True`` all
        ``%xx`` escapes will be replaced by their single-character equivalents
        after the URL has been parsed. This only applies to the ``hostname``,
        ``path``, ``username`` & ``password`` components.

        Any additional querystring arguments and keyword arguments will be
        passed along to the class constructor.

        .. note:: In the case of conflicting arguments, querystring arguments
           always win.
        """
        location, merged_options = cls._parse_url(
            url, decode_components, kwargs, ConnectionPoolParams
        )
        # The connection class is always derived from the parsed location,
        # overriding anything passed in via ``kwargs``.
        if isinstance(location, UnixDomainSocketLocation):
            merged_options["connection_class"] = UnixDomainSocketConnection
        else:
            merged_options["connection_class"] = TCPConnection
        return cls(
            location=location,
            **merged_options,
        )

    async def _initialize(self) -> None:
        # Start the tracking cache's background task (when a cache is
        # configured), giving up after ``connect_timeout`` seconds.
        if self.cache is not None:
            with move_on_after(self.connect_timeout) as cancel_scope:
                await self.task_group.start(self.cache.run)
            if cancel_scope.cancelled_caught:
                raise ConnectionError(
                    f"Unable to initialize cache within {self.connect_timeout} seconds"
                )

    async def get_connection(self, **_: Any) -> ConnectionT:
        """
        Gets or create a connection from the pool.

        Be careful to only release the connection AFTER all commands are sent,
        or race conditions are possible.

        :raises TimeoutError: if no connection could be obtained (or created)
         within :paramref:`timeout` seconds (raised by :func:`anyio.fail_after`).
        :raises RedisError: the error reported by a newly created connection's
         task when it fails to start.
        """
        with get_telemetry_provider().capture_connection_wait_time(self):
            with fail_after(self.timeout):
                # if stack has a connection, use that
                connection = await self._available_connections.get()
        if connection is None or not connection.reusable:
            # If the connection was in the pool but is "dirty" it should be
            # invalidated (to avoid leaking) and discarded.
            if connection is not None:
                connection.invalidate()
            with get_telemetry_provider().capture_connection_create_time(self):
                connection = await self._construct_connection()
                if err := await self.task_group.start(self.__wrap_connection, connection):
                    # NOTE(review): ``__wrap_connection``'s ``finally`` also appends a
                    # ``None`` slot when the failed task exits; confirm the queue
                    # bounds this so the slot is not handed back twice.
                    self._available_connections.append_nowait(None)
                    raise err
            self.statistics.connection_created(connection)
            self._online_connections.add(connection)
        self.statistics.connection_leased(connection)
        return connection

    @asynccontextmanager
    async def acquire(self, **_: Any) -> AsyncGenerator[ConnectionT]:
        """
        Gets or creates a connection from the pool, then release it afterwards.

        Multiplexing is automatic if you exit the context manager before
        waiting for command results. The connection is released even if the
        body of the ``async with`` block raises or is cancelled.

        .. caution:: Do not explicitly release connections acquired using
           this context manager.
        """
        connection = await self.get_connection()
        try:
            yield connection
        finally:
            # Without this ``finally`` an exception in the caller's block would
            # skip the release and the connection's slot would never return to
            # ``_available_connections``. ``release`` only re-queues usable
            # connections and ``get_connection`` discards non-reusable ones.
            self.release(connection)

    def release(self, connection: ConnectionT) -> None:
        """
        Checks connection for liveness and releases it back to the pool.
        """
        self.statistics.connection_released(connection)
        if connection.usable:
            self._available_connections.put_nowait(connection)

    def disconnect(self) -> None:
        """
        Disconnect all active connections in the pool
        """
        # Iterate over a snapshot so the loop stays safe even if
        # ``invalidate()`` ever removes the connection from the set synchronously.
        for connection in tuple(self._online_connections):
            connection.invalidate()
        self._online_connections.clear()

    def _reset(self) -> None:
        # TODO: seems like something should be cleared?
        pass

    def telemetry_attributes(self, provider: TelemetryProvider) -> dict[str, str | int]:
        # The pool is identified in telemetry by its (stringified) location.
        return {"db.client.connection.pool.name": str(self.location)}

    async def _construct_connection(self) -> ConnectionT:
        # Build (but do not start) a new connection for ``self.location``.
        assert self.location
        return self.connection_class(self.location, **self.connection_kwargs)

    async def __wrap_connection(
        self,
        connection: ConnectionT,
        *,
        task_status: TaskStatus[None | Exception] = TASK_STATUS_IGNORED,
    ) -> None:
        # Runs ``connection`` inside the pool's task group and, whenever the
        # connection task ends, removes it from the pool and hands its slot
        # back to ``_available_connections`` as ``None``.
        try:
            await connection.run(task_status=task_status)
        except RedisError as error:
            # Only coredis.exception.RedisError is explicitly caught and returned with the task status
            # As these are clear signals that an error case was handled by the connection
            # NOTE(review): if ``connection.run`` already reported startup through
            # ``task_status`` and raises RedisError later, a second ``started()``
            # call raises RuntimeError in anyio; confirm that cannot happen.
            task_status.started(error)
        finally:
            self._online_connections.discard(connection)
            if connection in self._available_connections:
                self._available_connections.remove(connection)
            self._available_connections.append_nowait(None)