Source code for coredis.connection._tcp
from __future__ import annotations

import dataclasses
import socket

from anyio import aclose_forcefully, connect_tcp, fail_after
from anyio.abc import ByteStream, SocketAttribute

from coredis._telemetry import TelemetryProvider
from coredis.typing import Unpack

from ._base import BaseConnection, BaseConnectionParams, Location
[docs]
@dataclasses.dataclass(unsafe_hash=True)
class TCPLocation(Location):
    """Host and port of a redis server that is reached over TCP"""

    #: hostname of the server to connect to
    host: str
    #: port on which the server is listening
    port: int

    def __repr__(self) -> str:
        # Rendered as ``host:port`` instead of the default dataclass repr.
        return f"{self.host}:{self.port}"

    async def check(self) -> bool:
        """
        Attempt to open, and immediately close, a TCP connection to
        ``host:port``.

        :return: ``True`` when the connection was opened and closed without
         an :exc:`OSError`; ``False`` otherwise.
        """
        reachable = False
        try:
            async with await connect_tcp(self.host, self.port):
                reachable = True
        except OSError:
            # Raised either while connecting or while closing the stream.
            reachable = False
        return reachable
[docs]
class TCPConnection(BaseConnection):
    """
    Connection to a redis server over TCP. The stream is wrapped in TLS
    when the connection has an ssl context (``self._ssl_context``, which is
    presumably populated by :class:`BaseConnection` from ``kwargs``).
    """

    location: TCPLocation

    def __init__(
        self,
        location: TCPLocation,
        *,
        socket_keepalive: bool | None = None,
        socket_keepalive_options: dict[int, int | bytes] | None = None,
        **kwargs: Unpack[BaseConnectionParams],
    ) -> None:
        """
        :param location: host and port of the server to connect to
        :param socket_keepalive: when truthy, ``SO_KEEPALIVE`` is enabled on
         the raw socket after the connection is established
        :param socket_keepalive_options: mapping of TCP-level socket option
         to value, each applied with ``setsockopt`` after keepalive has been
         enabled. ``None`` is treated as an empty mapping.
        :param kwargs: passed through unchanged to :class:`BaseConnection`
        """
        super().__init__(location, **kwargs)
        self._socket_keepalive = socket_keepalive
        self._socket_keepalive_options: dict[int, int | bytes] = socket_keepalive_options or {}
        # FIXME: this is only for backward compatibility as 6.0 still had
        #  host/port in TCP connections
        self.host = self.location.host
        self.port = self.location.port

    async def _connect(self) -> ByteStream:
        """
        Open the stream to ``self.location`` and configure keepalive on its
        raw socket.

        Establishing the connection is bounded by ``self._connect_timeout``
        via ``fail_after``. If configuring the socket fails, the freshly
        opened stream is closed before the error is re-raised so that it is
        not leaked.

        :return: the connected (plain or TLS) byte stream
        """
        with fail_after(self._connect_timeout):
            if self._ssl_context:
                connection: ByteStream = await connect_tcp(
                    self.location.host,
                    self.location.port,
                    tls=True,
                    ssl_context=self._ssl_context,
                    tls_standard_compatible=False,
                )
            else:
                connection = await connect_tcp(self.location.host, self.location.port)

            # Streams that do not expose a raw socket are returned as is.
            sock = connection.extra(SocketAttribute.raw_socket, default=None)
            # NOTE(review): the keepalive options are only applied when
            # keepalive itself is enabled - confirm this nesting is intended.
            if sock is not None and self._socket_keepalive:  # TCP_KEEPALIVE
                try:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for option, value in self._socket_keepalive_options.items():
                        sock.setsockopt(socket.SOL_TCP, option, value)
                except BaseException:
                    # The stream is already open at this point and nothing
                    # else holds a reference to it; close it before
                    # propagating the error so the socket is not leaked.
                    await aclose_forcefully(connection)
                    raise
            return connection

    def describe(self) -> str:
        """Return a short description containing host, port and db."""
        return f"Connection<host={self.location.host},port={self.location.port},db={self._db}>"

    def telemetry_attributes(self, provider: TelemetryProvider) -> dict[str, str | int]:
        """
        Return the attributes reported by the base class for ``provider``
        extended with the peer/server host and port of this connection.
        Keys defined here override same-named keys from the base class.
        """
        return {
            **super().telemetry_attributes(provider),
            "network.peer.hostname": self.location.host,
            "network.peer.port": self.location.port,
            "server.address": self.location.host,
            "server.port": self.location.port,
        }