Locks

coredis.patterns.lock

coredis ships with an implementation of distributed locking (Lock) that can be used with a single Redis instance or a Redis cluster. A lock object can either be used as an async context manager or be acquired explicitly with the acquire() method (and subsequently released with release()).

For convenience, the clients provide the factory methods coredis.Redis.lock() and coredis.RedisCluster.lock().
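The explicit form pairs acquire() with release() in a try/finally block, so the lock is released even if the critical section raises. The sketch below assumes that acquire() returns a truthy value on success (as in the redis-py style lock API); increment_explicit is an illustrative name, and DemoClient/DemoLock are hypothetical in-memory stand-ins backed by asyncio.Lock so the sketch runs without a Redis server. With coredis, a real client would be passed instead.

```python
import asyncio
from typing import Any, Dict


async def increment_explicit(client: Any, key: str) -> int:
    """Explicit acquire()/release() version of the context-manager form.

    Assumption: acquire() returns a truthy value when the lock was obtained.
    """
    lock = client.lock(f"increment:{key}")
    if not await lock.acquire():
        raise RuntimeError(f"could not acquire lock for {key!r}")
    try:
        value = int(await client.get(key) or 0) + 1
        await client.set(key, value)
        return value
    finally:
        # Always release, even if the critical section raised.
        await lock.release()


# --- Hypothetical in-memory stand-ins so the sketch runs without Redis ---
class DemoLock:
    def __init__(self, inner: asyncio.Lock) -> None:
        self._inner = inner

    async def acquire(self) -> bool:
        await self._inner.acquire()
        return True

    async def release(self) -> None:
        self._inner.release()


class DemoClient:
    def __init__(self) -> None:
        self._data: Dict[str, Any] = {}
        self._locks: Dict[str, asyncio.Lock] = {}

    def lock(self, name: str) -> DemoLock:
        # Same lock name -> same underlying lock, like a shared Redis key.
        return DemoLock(self._locks.setdefault(name, asyncio.Lock()))

    async def get(self, key: str) -> Any:
        await asyncio.sleep(0)  # yield to other tasks, as a network call would
        return self._data.get(key)

    async def set(self, key: str, value: Any) -> None:
        await asyncio.sleep(0)
        self._data[key] = value


async def main() -> int:
    client = DemoClient()
    await asyncio.gather(*(increment_explicit(client, "fubar") for _ in range(64)))
    return int(await client.get("fubar"))


result = asyncio.run(main())
print(result)  # 64
```

Because every task serializes on the same lock name, none of the 64 read-modify-write cycles is lost.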

As an example, let’s implement an atomic increment operation using a distributed lock:

import asyncio
import coredis

async def increment(client: coredis.Redis, key: str) -> int:
    # Only the current holder of the lock may run the read-modify-write below
    async with client.lock(f"increment:{key}"):
        value = int(await client.get(key) or 0) + 1
        await client.set(key, value)
        return value

async def test():
    async with coredis.Redis() as client:
        await client.delete(["fubar"])
        await asyncio.gather(
            *(increment(client, "fubar") for _ in range(64))
        )
        assert int(await client.get("fubar")) == 64

asyncio.run(test())

Implementation

The implementation is based on the distributed locking pattern described in the Redis documentation.
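The single-instance pattern in the Redis documentation acquires a lock with one atomic command, SET name token NX PX ttl: NX makes the write succeed only if nobody holds the lock, the random token identifies the holder, and the PX expiry frees the lock if the holder crashes. The sketch below models only that acquisition step in memory, with a manual clock; InMemoryStore and try_acquire are hypothetical names, not coredis APIs, and coredis' own acquisition code is not reproduced here.

```python
import secrets
from typing import Dict, Optional, Tuple


class InMemoryStore:
    """Hypothetical single-node stand-in for Redis: key -> (value, expires_at_ms)."""

    def __init__(self) -> None:
        self.now_ms = 0  # manual clock so expiry is deterministic
        self._data: Dict[str, Tuple[str, Optional[int]]] = {}

    def _live(self, key: str) -> Optional[Tuple[str, Optional[int]]]:
        entry = self._data.get(key)
        if entry is not None and entry[1] is not None and entry[1] <= self.now_ms:
            del self._data[key]  # expired keys behave as missing
            return None
        return entry

    def set_nx_px(self, key: str, value: str, px: int) -> bool:
        """Model of SET key value NX PX px: only set if the key is absent."""
        if self._live(key) is not None:
            return False
        self._data[key] = (value, self.now_ms + px)
        return True

    def get(self, key: str) -> Optional[str]:
        entry = self._live(key)
        return None if entry is None else entry[0]


def try_acquire(store: InMemoryStore, name: str, ttl_ms: int) -> Optional[str]:
    """Return a fresh random token if the lock was acquired, else None."""
    token = secrets.token_hex(16)
    return token if store.set_nx_px(name, token, ttl_ms) else None


store = InMemoryStore()
first = try_acquire(store, "increment:fubar", ttl_ms=1000)
second = try_acquire(store, "increment:fubar", ttl_ms=1000)
print(first is not None, second is None)  # True True

store.now_ms += 1000  # the TTL elapses, so the lock frees itself
third = try_acquire(store, "increment:fubar", ttl_ms=1000)
print(third is not None and third != first)  # True
```

The token matters later: release and extension must check it, so that a client whose lock already expired cannot touch the next holder's lock.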

When used with a RedisCluster instance, acquiring the lock also ensures, via the ensure_replication() context manager, that the token set by the acquire() method has been replicated to at least n/2 replicas (n being the number of replicas of the shard that owns the lock key).
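To make the n/2 threshold concrete: Redis' WAIT command reports how many replicas acknowledged the preceding writes, and for the lock that count has to reach half of the n replicas. The helper below is a hypothetical illustration of the threshold arithmetic only; replication_quorum_met is not a coredis API, and the sketch does not claim to reproduce ensure_replication().

```python
def replication_quorum_met(acknowledged: int, total_replicas: int) -> bool:
    """True when at least n/2 of the n replicas acknowledged the write.

    ``2 * acknowledged >= total_replicas`` is ``acknowledged >= n / 2``
    kept in integer arithmetic, so an odd replica count rounds up.
    """
    return 2 * acknowledged >= total_replicas


# Smallest number of acknowledgements that satisfies the check for n = 1..4.
minimums = [
    min(a for a in range(n + 1) if replication_quorum_met(a, n)) for n in range(1, 5)
]
print(minimums)  # [1, 1, 2, 2]
```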

The implementation uses the following Lua scripts:

  1. Release the lock

    -- KEYS[1] - lock name
    -- ARGV[1] - token
    -- return 1 if the lock was released, otherwise 0
    
    local token = redis.call('get', KEYS[1])
    if not token or token ~= ARGV[1] then
        return 0
    end
    redis.call('del', KEYS[1])
    return 1
    
  2. Extend the lock

    -- KEYS[1] - lock name
    -- ARGV[1] - token
    -- ARGV[2] - additional milliseconds
    -- return 1 if the lock's expiry was extended, otherwise 0
    
    local token = redis.call('get', KEYS[1])
    if not token or token ~= ARGV[1] then
        return 0
    end
    local expiration = redis.call('pttl', KEYS[1])
    if not expiration then
        expiration = 0
    end
    if expiration < 0 then
        return 0
    end
    redis.call('pexpire', KEYS[1], expiration + ARGV[2])
    return 1
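Both scripts are compare-and-act operations: the key is only touched if it still holds the caller's token, so a client whose lock already expired (and was re-acquired by someone else) cannot delete or prolong the new holder's lock. The pure-Python model below mirrors the two scripts step by step against a hypothetical in-memory store; InMemoryStore, release and extend are illustrative names, not coredis APIs. In Redis the atomicity comes from the script running as a single unit, which this single-threaded model gets for free.

```python
from typing import Dict, Optional, Tuple


class InMemoryStore:
    """Hypothetical stand-in for Redis: key -> (value, expires_at_ms or None)."""

    def __init__(self) -> None:
        self.now_ms = 0  # manual clock keeps the example deterministic
        self.data: Dict[str, Tuple[str, Optional[int]]] = {}

    def get(self, key: str) -> Optional[str]:
        entry = self.data.get(key)
        if entry is not None and entry[1] is not None and entry[1] <= self.now_ms:
            del self.data[key]  # expired keys behave as missing
            return None
        return None if entry is None else entry[0]

    def pttl(self, key: str) -> int:
        # Same convention as Redis PTTL: -2 if missing, -1 if no expiry set.
        if self.get(key) is None:
            return -2
        expires_at = self.data[key][1]
        return -1 if expires_at is None else expires_at - self.now_ms


def release(store: InMemoryStore, name: str, token: str) -> int:
    """Mirror of script 1: delete the key only if it still holds our token."""
    if store.get(name) != token:
        return 0
    del store.data[name]
    return 1


def extend(store: InMemoryStore, name: str, token: str, additional_ms: int) -> int:
    """Mirror of script 2: add time to the remaining TTL, holder only."""
    if store.get(name) != token:
        return 0
    expiration = store.pttl(name)
    if expiration < 0:  # a key with no TTL (PTTL -1) cannot be extended
        return 0
    # PEXPIRE key (remaining + additional): new deadline measured from now
    store.data[name] = (token, store.now_ms + expiration + additional_ms)
    return 1


store = InMemoryStore()
store.data["lock"] = ("token-a", 1000)  # held by "token-a" until t = 1000 ms
print(release(store, "lock", "token-b"))  # 0: another client's token
print(extend(store, "lock", "token-a", 500))  # 1
print(store.pttl("lock"))  # 1500
print(release(store, "lock", "token-a"))  # 1
print(store.pttl("lock"))  # -2: the key is gone
store.data["forever"] = ("token-c", None)  # no TTL set
print(extend(store, "forever", "token-c", 500))  # 0
```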