Source code for dask_cuda.explicit_comms.comms

# Copyright (c) 2021-2025 NVIDIA CORPORATION.
import asyncio
import concurrent.futures
import contextlib
import time
import uuid
import weakref
from typing import Any, Dict, Hashable, Iterable, List, Optional

import distributed.comm
from dask.tokenize import tokenize
from distributed import Client, Worker, default_client, get_worker
from distributed.comm.addressing import parse_address, parse_host_port, unparse_address

# Mapping tokenize(client ID, [worker addresses]) to CommsContext
_comms_cache: weakref.WeakValueDictionary[
    str, "CommsContext"
] = weakref.WeakValueDictionary()


def get_multi_lock_or_null_context(multi_lock_context, *args, **kwargs):
    """Return either a MultiLock or a NULL context

    Parameters
    ----------
    multi_lock_context: bool
        If True, return a MultiLock context; otherwise return a NULL context
        that does nothing

    *args, **kwargs:
        Arguments passed to the MultiLock constructor

    Returns
    -------
    context: context
        Either `MultiLock(*args, **kwargs)` or a NULL context
    """
    if multi_lock_context:
        from distributed import MultiLock

        return MultiLock(*args, **kwargs)
    else:
        return contextlib.nullcontext()
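
# Usage sketch (illustrative only; the worker addresses are made up): take a
# MultiLock over a set of workers only when exclusive access is required.
# Entering the MultiLock context requires a running distributed Client, since
# MultiLock coordinates through the scheduler.
#
#     workers = ["tcp://127.0.0.1:40000", "tcp://127.0.0.1:40001"]
#     need_exclusive_access = True
#     with get_multi_lock_or_null_context(need_exclusive_access, workers):
#         ...  # code that may need exclusive access to these workers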


def default_comms(client: Optional[Client] = None) -> "CommsContext":
    """Return the default comms object for ``client``.

    Creates a new default comms object if one does not already exist
    for ``client``.

    Parameters
    ----------
    client: Client, optional
        If no default comms object exists, create the new comms object
        on ``client``. If None, use the default client.

    Returns
    -------
    comms: CommsContext
        The default comms object

    Notes
    -----
    There are some subtle points around explicit-comms and the lifecycle
    of a Dask Cluster.

    A :class:`CommsContext` establishes explicit communication channels
    between the workers *at the time it's created*. If workers are added
    or removed, they will not be included in the communication channels
    with the other workers.

    If you need to refresh the explicit communications channels, then
    create a new :class:`CommsContext` object or call ``default_comms``
    again after workers have been added to or removed from the cluster.
    """
    # Comms are unique to a {client, [workers]} pair, so we key our
    # cache by the token of that.
    client = client or default_client()
    token = tokenize(client.id, list(client.scheduler_info()["workers"].keys()))
    maybe_comms = _comms_cache.get(token)
    if maybe_comms is None:
        maybe_comms = CommsContext(client=client)
        _comms_cache[token] = maybe_comms

    return maybe_comms
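
# Usage sketch (illustrative only; the cluster setup is an assumption): comms
# are established between the workers present when the CommsContext is created,
# so fetch a fresh object after scaling the cluster.
#
#     from distributed import Client
#     from dask_cuda import LocalCUDACluster
#     from dask_cuda.explicit_comms.comms import default_comms
#
#     client = Client(LocalCUDACluster())
#     comms = default_comms(client)  # channels between the current workers
#     client.cluster.scale(2)        # add or remove workers (hypothetical)
#     comms = default_comms(client)  # new CommsContext covering the new worker set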


def worker_state(sessionId: Optional[int] = None) -> dict:
    """Retrieve the state(s) of the current worker

    Parameters
    ----------
    sessionId: int, optional
        Worker session state ID. If None, all states of the worker
        are returned.

    Returns
    -------
    state: dict
        Either a single state dict (if ``sessionId`` is given) or a dict
        mapping session IDs to state dicts
    """
    worker: Any = get_worker()
    if not hasattr(worker, "_explicit_comm_state"):
        worker._explicit_comm_state = {}
    if sessionId is not None:
        if sessionId not in worker._explicit_comm_state:
            worker._explicit_comm_state[sessionId] = {
                "ts": time.time(),
                "eps": {},
                "loop": worker.loop.asyncio_loop,
                "worker": worker,
            }
        return worker._explicit_comm_state[sessionId]
    return worker._explicit_comm_state
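
# Usage sketch (illustrative only; ``inspect_state`` is a hypothetical coroutine):
# coroutines dispatched through ``CommsContext.run``/``submit`` (defined below)
# receive this per-session state dict as their first argument; by that point it
# also holds the worker's "rank", "nworkers", and the "eps" endpoints to its peers.
#
#     async def inspect_state(session_state):
#         eps = session_state["eps"]  # endpoints to all other workers, keyed by rank
#         return session_state["rank"], session_state["nworkers"], len(eps)
#
#     # comms = default_comms()
#     # comms.run(inspect_state)  # e.g. [(0, 2, 1), (1, 2, 1)] on a two-worker cluster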


def _run_coroutine_on_worker(sessionId, coroutine, args):
    session_state = worker_state(sessionId)

    def _run():
        # Schedule the coroutine on the worker's event loop and block until
        # it has finished.
        future = asyncio.run_coroutine_threadsafe(
            coroutine(session_state, *args), session_state["loop"]
        )
        return future.result()

    # Wait for the coroutine from a dedicated thread.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        return executor.submit(_run).result()


async def _create_listeners(session_state, nworkers, rank):
    assert session_state["loop"] is asyncio.get_event_loop()
    assert "nworkers" not in session_state
    session_state["nworkers"] = nworkers
    assert "rank" not in session_state
    session_state["rank"] = rank

    async def server_handler(ep):
        peer_rank = await ep.read()
        session_state["eps"][peer_rank] = ep

    # We listen on the same protocol and address as the worker address
    protocol, address = parse_address(session_state["worker"].address)
    address = parse_host_port(address)[0]
    address = unparse_address(protocol, address)

    session_state["lf"] = distributed.comm.listen(address, server_handler)
    await session_state["lf"].start()
    return session_state["lf"].listen_address


async def _create_endpoints(session_state, peers):
    """Each worker creates a UCX endpoint to all workers with greater rank"""
    assert session_state["loop"] is asyncio.get_event_loop()

    myrank = session_state["rank"]
    peers = list(enumerate(peers))

    # Create endpoints to workers with a greater rank than my rank
    for rank, address in peers[myrank + 1 :]:
        ep = await distributed.comm.connect(address)
        await ep.write(session_state["rank"])
        session_state["eps"][rank] = ep

    # Block until all endpoints have been created
    while len(session_state["eps"]) < session_state["nworkers"] - 1:
        await asyncio.sleep(0.1)


async def _stop_ucp_listeners(session_state):
    assert len(session_state["eps"]) == session_state["nworkers"] - 1
    assert session_state["loop"] is asyncio.get_event_loop()
    session_state["lf"].stop()
    del session_state["lf"]


async def _stage_keys(session_state: dict, name: str, keys: set):
    worker: Worker = session_state["worker"]
    data = worker.data
    my_keys = keys.intersection(data)

    stages = session_state.get("stages", {})
    stage = stages.get(name, {})
    for k in my_keys:
        stage[k] = data[k]
    stages[name] = stage
    session_state["stages"] = stages
    return (session_state["rank"], my_keys)


class CommsContext:
    """Communication handler for explicit communication

    Parameters
    ----------
    client: Client, optional
        Specify client to use for communication. If None, use the default client.
    """

    client: Client
    sessionId: int
    worker_addresses: List[str]

    def __init__(self, client: Optional[Client] = None):
        self.client = client if client is not None else default_client()
        self.sessionId = uuid.uuid4().int

        # Get address of all workers (not Nanny addresses)
        self.worker_addresses = list(self.client.scheduler_info()["workers"].keys())

        # Make all workers listen and get all listen addresses
        self.worker_direct_addresses = []
        for rank, address in enumerate(self.worker_addresses):
            self.worker_direct_addresses.append(
                self.submit(
                    address,
                    _create_listeners,
                    len(self.worker_addresses),
                    rank,
                    wait=True,
                )
            )

        # Each worker creates an endpoint to all workers with greater rank
        self.run(_create_endpoints, self.worker_direct_addresses)

        # At this point all workers should have a rank and endpoints to
        # all other workers thus we can now stop the listening.
        self.run(_stop_ucp_listeners)

    def submit(self, worker, coroutine, *args, wait=False):
        """Run a coroutine on a single worker

        The coroutine is given the worker's state dict as the first argument
        and ``*args`` as the following arguments.

        Parameters
        ----------
        worker: str
            Worker to run the ``coroutine``
        coroutine: coroutine
            The function to run on the worker
        *args:
            Arguments for ``coroutine``
        wait: boolean, optional
            If True, wait for the coroutine to finish before returning.

        Returns
        -------
        ret: object or Future
            If wait=True, the result of `coroutine`.
            If wait=False, a Future that can be waited on later.
        """
        ret = self.client.submit(
            _run_coroutine_on_worker,
            self.sessionId,
            coroutine,
            args,
            workers=[worker],
            pure=False,
        )
        return ret.result() if wait else ret

    def run(self, coroutine, *args, workers=None, lock_workers=False):
        """Run a coroutine on multiple workers

        The coroutine is given the worker's state dict as the first argument
        and ``*args`` as the following arguments.

        Parameters
        ----------
        coroutine: coroutine
            The function to run on each worker
        *args:
            Arguments for ``coroutine``
        workers: list, optional
            List of workers. Default is all workers
        lock_workers: bool, optional
            Use distributed.MultiLock to get exclusive access to the workers. Use
            this flag to support parallel runs.

        Returns
        -------
        ret: list
            List of the output from each worker
        """
        if workers is None:
            workers = self.worker_addresses

        with get_multi_lock_or_null_context(lock_workers, workers):
            ret = []
            for worker in workers:
                ret.append(
                    self.client.submit(
                        _run_coroutine_on_worker,
                        self.sessionId,
                        coroutine,
                        args,
                        workers=[worker],
                        pure=False,
                    )
                )
            return self.client.gather(ret)

    def stage_keys(self, name: str, keys: Iterable[Hashable]) -> Dict[int, set]:
        """Stage keys on workers under the given name

        In an explicit-comms task, use `pop_staging_area(..., name)` to access
        the staged keys and the associated data.

        Notes
        -----
        In the context of explicit-comms, staging is the act of duplicating the
        responsibility of Dask keys. When staging a key, the worker owning the
        key (as assigned by the Dask scheduler) saves a reference to the key and
        the associated data in its local staging area. From this point on, if
        the scheduler cancels the key, the worker (and the task running on the
        worker) has exclusive access to the key and the associated data. This
        way, staging makes it possible for long-running explicit-comms tasks to
        free input data ASAP.

        Parameters
        ----------
        name: str
            Name for the staging area
        keys: iterable
            The keys to stage

        Returns
        -------
        dict
            dict that maps each worker rank to that worker's set of staged keys
        """
        return dict(self.run(_stage_keys, name, set(keys)))
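

# Usage sketch (illustrative only; ``get_rank`` is a hypothetical coroutine):
# ``submit`` targets a single worker and can block for the result, while ``run``
# broadcasts to every worker (or a given subset).
#
#     async def get_rank(session_state):
#         return session_state["rank"]
#
#     comms = default_comms()
#     comms.submit(comms.worker_addresses[0], get_rank, wait=True)  # -> 0
#     comms.run(get_rank)  # -> [0, 1, ..., len(comms.worker_addresses) - 1]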


def pop_staging_area(session_state: dict, name: str) -> Dict[str, Any]:
    """Pop the staging area called `name`

    This function must be called within a running explicit-comms task.

    Parameters
    ----------
    session_state: dict
        Worker session state
    name: str
        Name for the staging area

    Returns
    -------
    dict
        The staging area, which is a dict that maps keys to their data.
    """
    return session_state["stages"].pop(name)
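

# Usage sketch (illustrative only; ``sum_staged`` and the dataframe are
# hypothetical): stage the keys of a persisted collection from the client, then
# claim the staged partitions inside an explicit-comms task with
# ``pop_staging_area``.
#
#     import dask.dataframe as dd
#     from distributed import wait
#
#     df = dd.from_dict({"a": range(100)}, npartitions=4).persist()
#     wait(df)  # make sure the partitions are in worker memory before staging
#
#     comms = default_comms()
#     rank_to_keys = comms.stage_keys("my-stage", df.__dask_keys__())
#
#     async def sum_staged(session_state, name):
#         # Take ownership of the partitions staged on this worker
#         parts = pop_staging_area(session_state, name)
#         return sum(p["a"].sum() for p in parts.values())
#
#     results = comms.run(sum_staged, "my-stage")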