from __future__ import annotations
import warnings
from contextlib import asynccontextmanager
from typing import Any
from anyio import TASK_STATUS_IGNORED, CapacityLimiter, fail_after, sleep
from anyio.abc import TaskStatus
from coredis._concurrency import Queue, QueueEmpty, QueueFull
from coredis._telemetry import TelemetryProvider, get_telemetry_provider
from coredis._utils import logger, query_param_to_bool
from coredis.cluster._discovery import DiscoveryService
from coredis.cluster._layout import ClusterLayout
from coredis.cluster._node import ClusterNodeLocation
from coredis.connection import (
BaseConnectionParams,
ClusterConnection,
)
from coredis.connection._tcp import TCPLocation
from coredis.exceptions import RedisError
from coredis.patterns.cache import AbstractCache, ClusterTrackingCache
from coredis.pool._basic import ConnectionPoolParams
from coredis.typing import (
AsyncGenerator,
Callable,
ClassVar,
Iterable,
Node,
NotRequired,
Self,
Unpack,
)
from ._base import BaseConnectionPool
[docs]
class ClusterConnectionPoolParams(ConnectionPoolParams[ClusterConnection]):
"""
Parameters accepted by :class:`coredis.pool.ClusterConnectionPool`
"""
#: The initial collection of nodes to use to map the cluster solts to individual primary & replica nodes.
startup_nodes: NotRequired[Iterable[Node | TCPLocation]]
#: Skips the check of cluster-require-full-coverage config, useful for clusters
#: without the :rediscommand:`CONFIG` command (For example with AWS Elasticache)
skip_full_coverage_check: NotRequired[bool]
#: Whether to use the value of ``max_connections``
#: on a per node basis or cluster wide. If ``False`` the per-node connection pools will have
#: a maximum size of :paramref:`max_connections` divided by the number of nodes in the cluster.
max_connections_per_node: NotRequired[bool]
#: If ``True`` connections to replicas will be returned for readonly commands
read_from_replicas: NotRequired[bool]
#: Interval (in seconds) for performing a cleanup of the pool to
#: remove any connections that are no longer in the cluster layout.
gc_interval: NotRequired[int]
[docs]
class ClusterConnectionPool(BaseConnectionPool[ClusterConnection]):
"""
Custom connection pool for :class:`~coredis.RedisCluster` client
"""
URL_QUERY_ARGUMENT_PARSERS: ClassVar[
dict[str, Callable[..., int | float | bool | str | None]]
] = {
**BaseConnectionPool.URL_QUERY_ARGUMENT_PARSERS,
"max_connections_per_node": query_param_to_bool,
"reinitialize_steps": int,
"skip_full_coverage_check": query_param_to_bool,
"read_from_replicas": query_param_to_bool,
"gc_interval": int,
}
connection_class: type[ClusterConnection]
_cluster_available_connections: dict[TCPLocation, Queue[ClusterConnection]]
_online_connections: set[ClusterConnection]
def __init__(
self,
startup_nodes: Iterable[Node | TCPLocation],
*,
connection_class: type[ClusterConnection] = ClusterConnection,
max_connections: int | None = None,
max_connections_per_node: bool = False,
reinitialize_steps: int | None = None,
skip_full_coverage_check: bool = False,
nodemanager_follow_cluster: bool = True,
readonly: bool = False,
read_from_replicas: bool = False,
timeout: float | None = None,
gc_interval: int = 30,
_cache: AbstractCache | None = None,
**connection_kwargs: Unpack[BaseConnectionParams],
):
"""
Cluster aware connection pool that tracks and manages sub pools for each node
in the redis cluster
Changes
- .. deprecated:: 6.2.0
- :paramref:`startup_nodes` should not be passed as dictionaries
and instead migrate to use instances of :class:`~coredis.connection.TCPLocation`
- .. versionchanged:: 4.4.0
- :paramref:`nodemanager_follow_cluster` now defaults to ``True``
- .. deprecated:: 4.4.0
- :paramref:`readonly` renamed to :paramref:`read_from_replicas`
:param startup_nodes: The initial collection of nodes to use to map the cluster
solts to individual primary & replica nodes.
:param connection_class: The connection class to use when creating new connections
:param max_connections: Maximum number of connections to allow concurrently from this
client. If the value is ``None`` it will default to 64.
:param max_connections_per_node: Whether to use the value of :paramref:`max_connections`
on a per node basis or cluster wide. If ``False`` the per-node connection pools will have
a maximum size of :paramref:`max_connections` divided by the number of nodes in the cluster.
:param timeout: Number of seconds to block when trying to obtain a connection.
:param skip_full_coverage_check:
Skips the check of cluster-require-full-coverage config, useful for clusters
without the :rediscommand:`CONFIG` command (For example with AWS Elasticache)
:param nodemanager_follow_cluster: The node manager will during initialization try
the last set of nodes that it was operating on. This will allow the client to drift
along side the cluster if the cluster nodes move around alot.
:param read_from_replicas: If ``True`` connections to replicas will be returned for readonly
commands
:param gc_interval: Interval (in seconds) for performing a cleanup of the pool to
remove any connections that are no longer in the cluster layout.
:param connection_kwargs: arguments to pass to the :paramref:`connection_class`
constructor when creating a new connection
"""
super().__init__(
connection_class=connection_class,
**connection_kwargs,
)
self._initialized = False
self._gc_interval = gc_interval
self.timeout = timeout
self.max_connections = max_connections or 64
self.max_connections_per_node = max_connections_per_node
# TODO: Remove support for Node
if any(isinstance(node, dict) for node in startup_nodes):
warnings.warn(
"Use coredis.connection.TCPLocation to specify startup nodes",
DeprecationWarning,
stacklevel=2,
)
self.startup_nodes = [
node if isinstance(node, TCPLocation) else TCPLocation(node["host"], node["port"])
for node in startup_nodes
]
self.cluster_layout = ClusterLayout(
DiscoveryService(
startup_nodes=self.startup_nodes,
skip_full_coverage_check=skip_full_coverage_check,
follow_cluster=nodemanager_follow_cluster,
**connection_kwargs,
),
error_threshold=reinitialize_steps or 2,
)
# Used by the connection to limit concurrently entering
# CPU hotspots to ensure fairness between connections in the pool.
# The main observed scenario where this is useful is if the connection pool
# is being used by multiple push message consumers that are constantly
# receiving data in the read task.
self._connection_processing_budget = CapacityLimiter(1)
self.connection_kwargs["processing_budget"] = self._connection_processing_budget
self.read_from_replicas = read_from_replicas or readonly
# TODO: Use the `max_failures` argument of tracking cache
self.cache = ClusterTrackingCache(self, _cache) if _cache else None
self._reset()
if "stream_timeout" not in self.connection_kwargs:
self.connection_kwargs["stream_timeout"] = None
[docs]
@classmethod
def from_url(
cls: type[Self],
url: str,
*,
decode_components: bool = False,
**kwargs: Unpack[ClusterConnectionPoolParams],
) -> Self:
"""
Returns a cluster connection pool configured from the given URL.
"""
location, merged_options = cls._parse_url(
url, decode_components, kwargs, ClusterConnectionPoolParams
)
if "startup_nodes" not in merged_options and location:
assert isinstance(location, TCPLocation)
merged_options["startup_nodes"] = [Node(host=location.host, port=location.port)]
merged_options["connection_class"] = ClusterConnection
return cls(
**merged_options,
)
def __repr__(self) -> str:
"""
Returns a string with all unique ip:port combinations that this pool
is connected to
"""
return "{}<{}>".format(
type(self).__name__,
", ".join([node.name for node in list(self.cluster_layout.nodes)]),
)
async def _initialize(self) -> None:
await self.cluster_layout.initialize()
total_nodes = len(self.cluster_layout.nodes)
if not self.max_connections_per_node and self.max_connections < total_nodes:
warnings.warn(
f"The value of max_connections={self.max_connections} "
"should be atleast equal to the number of nodes "
f"({total_nodes}) in the cluster and has been increased by "
f"{total_nodes - self.max_connections} connections."
)
self.max_connections = total_nodes
await self.task_group.start(self.cluster_layout.monitor)
await self.task_group.start(self._cleanup)
if self.cache:
# TODO: handle cache failure so that the pool doesn't die
# if the cache fails.
await self.task_group.start(self.cache.run)
[docs]
async def get_connection(
self, node: ClusterNodeLocation | None = None, primary: bool = True, **options: Any
) -> ClusterConnection:
"""
Acquires a connection from the cluster pool. If no node
is specified a random node is picked. The connection
must be returned back to the pool using the :meth:`release`
method.
:param node: The node for which to get a connection from
:param primary: If False a connection from the replica will be returned
"""
connection: ClusterConnection
with get_telemetry_provider().capture_connection_wait_time(self):
if node:
connection = await self._get_connection_by_node(node)
else:
connection = await self._get_random_connection(primary=primary)
self.statistics.connection_leased(connection)
return connection
[docs]
@asynccontextmanager
async def acquire(
self,
node: ClusterNodeLocation | None = None,
primary: bool = True,
**options: Any,
) -> AsyncGenerator[ClusterConnection]:
"""
Acquires a connection from the cluster pool. If no node
is specified a random node is picked. The connection
will be automatically released back to the pool when
the context manager exits.
:param node: The node for which to get a connection from
:param primary: If False a connection from the replica will be returned
"""
connection = await self.get_connection(node=node, primary=primary, **options)
yield connection
self.release(connection)
[docs]
def release(self, connection: ClusterConnection) -> None:
"""Releases the connection back to the pool"""
assert isinstance(connection, ClusterConnection)
self.statistics.connection_released(connection)
try:
if connection.usable:
self._node_pool(connection.location).put_nowait(connection)
except QueueFull:
pass
[docs]
def disconnect(self) -> None:
for connection in self._online_connections:
connection.invalidate()
self._online_connections.clear()
def _reset(self) -> None:
"""Resets the connection pool back to a clean state"""
self._cluster_available_connections = {}
self._online_connections = set()
def telemetry_attributes(self, provider: TelemetryProvider) -> dict[str, str | int]:
return {
"db.client.connection.pool.name": ",".join(
[node.name for node in list(self.cluster_layout.nodes)]
)
}
async def _wrap_connection(
self,
connection: ClusterConnection,
*,
task_status: TaskStatus[None | Exception] = TASK_STATUS_IGNORED,
) -> None:
try:
await connection.run(task_status=task_status)
except RedisError as error:
# Only coredis.exception.RedisError is explictly caught and returned with the task status
# As these are clear signals that an error case was handled by the connection
task_status.started(error)
finally:
node_pool = self._node_pool(connection.location)
self._online_connections.discard(connection)
if connection in node_pool:
node_pool.remove(connection)
node_pool.append_nowait(None)
async def _make_node_connection(self, node: ClusterNodeLocation) -> ClusterConnection:
"""Creates a new connection to a node"""
location = TCPLocation(node.host, node.port)
with get_telemetry_provider().capture_connection_create_time(self):
connection = self.connection_class(
location=location,
read_from_replicas=self.read_from_replicas and node.server_type == "replica",
**self.connection_kwargs,
)
if err := await self.task_group.start(self._wrap_connection, connection):
raise err
self.statistics.connection_created(connection)
self._online_connections.add(connection)
return connection
def _node_pool(self, location: TCPLocation) -> Queue[ClusterConnection]:
if self._cluster_available_connections.get(location) is None:
self._cluster_available_connections[location] = self._default_node_queue()
return self._cluster_available_connections[location]
def _default_node_queue(
self,
) -> Queue[ClusterConnection]:
q_size = max(
1,
self.max_connections
if self.max_connections_per_node
else self.max_connections // len(list(self.cluster_layout.nodes)),
)
return Queue[ClusterConnection](q_size)
async def _get_random_connection(self, primary: bool = False) -> ClusterConnection:
return await self._get_connection_by_node(self.cluster_layout.random_node(primary))
async def _get_connection_by_node(self, node: ClusterNodeLocation) -> ClusterConnection:
location = TCPLocation(node.host, node.port)
with fail_after(self.timeout):
connection = await self._node_pool(location).get()
if not connection or not connection.reusable:
# If the connection was in the pool but is "dirty" it should be
# invalidated (to avoid leaking) and discarded.
if connection:
connection.invalidate()
try:
connection = await self._make_node_connection(node)
except Exception:
self._node_pool(location).append_nowait(None)
raise
return connection
async def _cleanup(self, task_status: TaskStatus[None] = TASK_STATUS_IGNORED) -> None:
task_status.started()
while True:
if self.cluster_layout.nodes:
for location, pool in list(self._cluster_available_connections.items()):
if not self.cluster_layout.node_for_location(location):
dead_queue = self._cluster_available_connections.pop(location)
connections = []
try:
while True:
if c := dead_queue.get_nowait():
connections.append(c)
except QueueEmpty:
pass
logger.info(
f"Node for {location} is no longer in cluster layout, releasing from connection pool (connections: {len(connections)}"
)
for connection in connections:
connection.invalidate()
await sleep(self._gc_interval)