Source code for coredis.pool.basic

from __future__ import annotations

import asyncio
import os
import threading
import time
import warnings
from itertools import chain
from ssl import SSLContext, VerifyMode
from typing import Any, cast
from urllib.parse import parse_qs, unquote, urlparse

import async_timeout

from coredis.connection import (
    BaseConnection,
    Connection,
    RedisSSLContext,
    UnixDomainSocketConnection,
)
from coredis.exceptions import ConnectionError
from coredis.typing import (
    Callable,
    ClassVar,
    Dict,
    List,
    Optional,
    Set,
    StringT,
    Type,
    TypeVar,
    Union,
    ValueT,
)

FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value: Optional[StringT]) -> Optional[bool]:
    if value is None or value == "":
        return None

    if isinstance(value, str) and value.upper() in FALSE_STRINGS:
        return False

    return bool(value)


_CPT = TypeVar("_CPT", bound="ConnectionPool")


[docs] class ConnectionPool: """Generic connection pool""" #: Mapping of querystring arguments to their parser functions URL_QUERY_ARGUMENT_PARSERS: ClassVar[ Dict[str, Callable[..., Optional[Union[int, float, bool, str]]]] ] = { "client_name": str, "stream_timeout": float, "connect_timeout": float, "max_connections": int, "max_idle_time": int, "protocol_version": int, "idle_check_interval": int, "noreply": bool, "noevict": bool, "notouch": bool, }
[docs] @classmethod def from_url( cls: Type[_CPT], url: str, db: Optional[int] = None, decode_components: bool = False, **kwargs: Any, ) -> _CPT: """ Returns a connection pool configured from the given URL. For example: - ``redis://[:password]@localhost:6379/0`` - ``rediss://[:password]@localhost:6379/0`` - ``unix://[:password]@/path/to/socket.sock?db=0`` Three URL schemes are supported: - `redis:// <http://www.iana.org/assignments/uri-schemes/prov/redis>`__ creates a normal TCP socket connection - `rediss:// <http://www.iana.org/assignments/uri-schemes/prov/rediss>`__ creates a SSL wrapped TCP socket connection - ``unix://`` creates a Unix Domain Socket connection There are several ways to specify a database number. The parse function will return the first specified option: - A ``db`` querystring option, e.g. ``redis://localhost?db=0`` - If using the ``redis://`` scheme, the path argument of the url, e.g. ``redis://localhost/0`` - The ``db`` argument to this function. If none of these options are specified, ``db=0`` is used. The :paramref:`decode_components` argument allows this function to work with percent-encoded URLs. If this argument is set to ``True`` all ``%xx`` escapes will be replaced by their single-character equivalents after the URL has been parsed. This only applies to the ``hostname``, ``path``, and ``password`` components. See :attr:`URL_QUERY_ARGUMENT_PARSERS` for a comprehensive mapping of querystring parameters to how they are parsed. Any additional querystring arguments and keyword arguments will be passed along to the class constructor. The querystring arguments ``connect_timeout`` and ``stream_timeout`` if supplied are parsed as float values. .. note:: In the case of conflicting arguments, querystring arguments always win. """ parsed_url = urlparse(url) qs = parsed_url.query url_options: Dict[ str, Optional[Union[int, float, bool, str, Type[BaseConnection], SSLContext]], ] = {} for name, value in iter(parse_qs(qs).items()): if value and len(value) > 0: parser = cls.URL_QUERY_ARGUMENT_PARSERS.get(name) if parser: try: url_options[name] = parser(value[0]) except (TypeError, ValueError): warnings.warn( UserWarning( "Invalid value for `%s` in connection URL." % name ) ) else: url_options[name] = value[0] username: Optional[str] = parsed_url.username password: Optional[str] = parsed_url.password path: Optional[str] = parsed_url.path hostname: Optional[str] = parsed_url.hostname if decode_components: username = unquote(username) if username else None password = unquote(password) if password else None path = unquote(path) if path else None hostname = unquote(hostname) if hostname else None # We only support redis:// and unix:// schemes. if parsed_url.scheme == "unix": url_options.update( { "username": username, "password": password, "path": path, "connection_class": UnixDomainSocketConnection, } ) else: url_options.update( { "host": hostname, "port": int(parsed_url.port or 6379), "username": username, "password": password, } ) # If there's a path argument, use it as the db argument if a # querystring value wasn't specified if "db" not in url_options and path: try: url_options["db"] = int(path.replace("/", "")) except (AttributeError, ValueError): pass if parsed_url.scheme == "rediss": keyfile = cast(Optional[str], url_options.pop("ssl_keyfile", None)) certfile = cast(Optional[str], url_options.pop("ssl_certfile", None)) check_hostname = cast( Optional[bool], url_options.pop("ssl_check_hostname", None) ) cert_reqs = cast( Optional[Union[str, VerifyMode]], url_options.pop("ssl_cert_reqs", None), ) ca_certs = cast(Optional[str], url_options.pop("ssl_ca_certs", None)) url_options["ssl_context"] = RedisSSLContext( keyfile, certfile, cert_reqs, ca_certs, check_hostname ).get() # last shot at the db value _db = url_options.get("db", db or 0) assert isinstance(_db, (int, str, bytes)) url_options["db"] = int(_db) # update the arguments from the URL values kwargs.update(url_options) return cls(**kwargs)
def __init__( self, *, connection_class: Optional[Type[Connection]] = None, max_connections: Optional[int] = None, max_idle_time: int = 0, idle_check_interval: int = 1, **connection_kwargs: Optional[Any], ) -> None: """ Creates a connection pool. If :paramref:`max_connections` is set, then this object raises :class:`~coredis.ConnectionError` when the pool's limit is reached. By default, TCP connections are created :paramref:`connection_class` is specified. Use :class:`~coredis.UnixDomainSocketConnection` for unix sockets. Any additional keyword arguments are passed to the constructor of connection_class. """ self.connection_class = connection_class or Connection self.connection_kwargs = connection_kwargs self.max_connections = max_connections or 2**31 self.max_idle_time = max_idle_time self.idle_check_interval = idle_check_interval self.initialized = False self.reset() async def initialize(self) -> None: self.initialized = True def __repr__(self) -> str: return "{}<{}>".format( type(self).__name__, self.connection_class.describe(self.connection_kwargs) ) def __del__(self) -> None: self.disconnect() async def disconnect_on_idle_time_exceeded(self, connection: Connection) -> None: while True: if ( time.time() - connection.last_active_at > self.max_idle_time and not connection.requests_pending ): connection.disconnect() if connection in self._available_connections: self._available_connections.remove(connection) self._created_connections -= 1 break await asyncio.sleep(self.idle_check_interval) def reset(self) -> None: self.pid = os.getpid() self._created_connections = 0 self._available_connections: List[Connection] = [] self._in_use_connections: Set[Connection] = set() self._check_lock = threading.Lock() def checkpid(self) -> None: # noqa if self.pid != os.getpid(): with self._check_lock: # Double check if self.pid == os.getpid(): return self.disconnect() self.reset() def peek_available(self) -> Optional[BaseConnection]: return self._available_connections[0] if self._available_connections else None
[docs] async def get_connection( self, command_name: Optional[bytes] = None, *args: ValueT, acquire: bool = True, **kwargs: Optional[ValueT], ) -> Connection: """Gets a connection from the pool""" self.checkpid() try: connection = self._available_connections.pop() if connection.is_connected and connection.needs_handshake: await connection.perform_handshake() except IndexError: if self._created_connections >= self.max_connections: raise ConnectionError("Too many connections") connection = self._make_connection(**kwargs) if acquire: self._in_use_connections.add(connection) else: self._available_connections.append(connection) return connection
[docs] def release(self, connection: Connection) -> None: """ Releases the :paramref:`connection` back to the pool """ self.checkpid() if connection.pid == self.pid: self._in_use_connections.remove(connection) self._available_connections.append(connection)
[docs] def disconnect(self) -> None: """Closes all connections in the pool""" all_conns = chain(self._available_connections, self._in_use_connections) for connection in all_conns: connection.disconnect() self._created_connections -= 1
def _make_connection(self, **options: Optional[ValueT]) -> Connection: """ Creates a new connection """ self._created_connections += 1 connection = self.connection_class( **self.connection_kwargs, # type: ignore ) if self.max_idle_time > self.idle_check_interval > 0: # do not await the future asyncio.ensure_future(self.disconnect_on_idle_time_exceeded(connection)) return connection
[docs] class BlockingConnectionPool(ConnectionPool): """ Blocking connection pool:: >>> from coredis import Redis >>> client = Redis(connection_pool=BlockingConnectionPool()) It performs the same function as the default :class:`~coredis.ConnectionPool`, in that, it maintains a pool of reusable connections that can be shared by multiple redis clients. The difference is that, in the event that a client tries to get a connection from the pool when all of the connections are in use, rather than raising a :exc:`~coredis.ConnectionError` (as the default :class:`~coredis.ConnectionPool` implementation does), it makes the client blocks for a specified number of seconds until a connection becomes available. Use :paramref:`max_connections` to increase / decrease the pool size:: >>> pool = BlockingConnectionPool(max_connections=10) Use :paramref:`timeout` to tell it either how many seconds to wait for a connection to become available, or to block forever:: >>> # Block forever. >>> pool = BlockingConnectionPool(timeout=None) >>> # Raise a ``ConnectionError`` after five seconds if a connection is >>> # not available. >>> pool = BlockingConnectionPool(timeout=5) """ def __init__( self, connection_class: Optional[Type[Connection]] = None, queue_class: Type[asyncio.Queue[Optional[Connection]]] = asyncio.LifoQueue, max_connections: Optional[int] = None, timeout: int = 20, max_idle_time: int = 0, idle_check_interval: int = 1, **connection_kwargs: Optional[ValueT], ): self.timeout = timeout self.queue_class = queue_class self.total_wait = 0 self.total_allocated = 0 max_connections = max_connections or 50 super().__init__( connection_class=connection_class or Connection, max_connections=max_connections, max_idle_time=max_idle_time, idle_check_interval=idle_check_interval, **connection_kwargs, ) async def disconnect_on_idle_time_exceeded(self, connection: Connection) -> None: while True: if time.time() - connection.last_active_at > self.max_idle_time: # Unlike the non blocking pool, we don't free the connection object, # but always reuse it connection.disconnect() break await asyncio.sleep(self.idle_check_interval) def reset(self) -> None: self._pool: asyncio.Queue[Optional[Connection]] = self.queue_class( self.max_connections ) while True: try: self._pool.put_nowait(None) except asyncio.QueueFull: break super().reset() def peek_available(self) -> Optional[BaseConnection]: return ( self._pool._queue[-1] # type: ignore if (self._pool and not self._pool.empty()) else None )
[docs] async def get_connection( self, command_name: Optional[bytes] = None, *args: ValueT, acquire: bool = True, **kwargs: Optional[ValueT], ) -> Connection: """Gets a connection from the pool""" self.checkpid() try: async with async_timeout.timeout(self.timeout): connection = await self._pool.get() if connection and connection.is_connected and connection.needs_handshake: await connection.perform_handshake() except asyncio.TimeoutError: raise ConnectionError("No connection available.") if connection is None: connection = self._make_connection() if acquire: self._in_use_connections.add(connection) else: self._pool.put_nowait(connection) return connection
[docs] def release(self, connection: Connection) -> None: """Releases the connection back to the pool""" _connection: Optional[Connection] = connection self.checkpid() if _connection and _connection.pid == self.pid: self._in_use_connections.remove(_connection) try: self._pool.put_nowait(_connection) except asyncio.QueueFull: _connection.disconnect()
[docs] def disconnect(self) -> None: """Closes all connections in the pool""" pooled_connections: List[Optional[Connection]] = [] while True: try: pooled_connections.append(self._pool.get_nowait()) except asyncio.QueueEmpty: break for conn in pooled_connections: self._pool.put_nowait(conn) all_conns = chain(pooled_connections, self._in_use_connections) for connection in all_conns: if connection is not None: connection.disconnect()