Source code for torch_simple_timing.timer

"""
This class enables timing of (PyTorch) code blocks. It internally
leverages the :class:`~torch_simple_timing.clock.Clock` class to measure
execution times.

When the constructor argument ``gpu`` is set to ``True``, the timer's clocks
will use :class:`torch.cuda.Event` to time GPU code. For timings to be meaningful,
:func:`torch.cuda.synchronize()` must be called before and after the code block.
In the case of distributed training, :func:`torch.distributed.barrier()` will also
be called.

All of the above is taken care of by the :class:`~torch_simple_timing.clock.Clock`
class, but be aware that it may slow down your code.

.. note::

    Wait, what?? Timing slows code down?? Yes, it does. But it's not as bad as
    you might think. It mainly means that you should be careful when you define
    :class:`~torch_simple_timing.clock.Clock` and
    :class:`~torch_simple_timing.timer.Timer`
    objects. For example, if you want to time a ``forward``
    function of a model, the fact that the *overall* epoch is slower does not matter,
    you want to accurately measure the time spent in the ``forward`` function.

.. warning::

    Because of the :func:`torch.cuda.synchronize()` calls, the
    :class:`~torch_simple_timing.timer.Timer` class should
    be carefully used in the context of training. For instance, if you want to time
    **epochs** the synchronization overhead will be negligible. However, if you want
    to time training **iterations**, you should be careful to only do that for one (or
    a few) epochs and not for the whole training. Otherwise, the overhead may become
    significant. Use the ``ignore`` argument to disable timing for a specific clock.

Example:

.. code-block:: python

    import torch
    from torch.nn import Sequential, Linear, ReLU
    from torch_simple_timing import Timer

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    gpu = device.type == "cuda"
    timer = Timer(gpu=gpu)

    # manual start
    timer.clock("init").start()

    batches = 32
    bs = 64
    n = batches * bs
    dim = 64
    labels = 10
    hidden = 1024
    epochs = 5

    t = torch.randn(n, dim, device=device)
    y = torch.randint(0, labels, (n,), device=device)

    model = Sequential(
        Linear(dim, hidden),
        ReLU(),
        Linear(hidden, hidden),
        ReLU(),
        Linear(hidden, labels),
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters())
    loss_func = torch.nn.CrossEntropyLoss()

    timer.clock("init").stop()

    with timer.clock("train-loop"):
        for epoch in range(epochs):
            with timer.clock("train-epoch"):
                for batch in range(batches):
                    optimizer.zero_grad()
                    # only time the first 3 epochs (epochs 0, 1 and 2)
                    with timer.clock("train-batch", ignore=epoch > 2):
                        with timer.clock("forward"):
                            pred = model(t[batch * bs : (batch + 1) * bs])
                        with timer.clock("loss", ignore=epoch > 2):
                            loss = loss_func(pred, y[batch * bs : (batch + 1) * bs])
                        with timer.clock("backward", ignore=epoch > 2):
                            loss.backward()
                        optimizer.step()

                    if batch % 10 == 0:
                        print(f"Epoch {epoch}, batch {batch}, loss {loss.item():.3f}")

    # compute mean/std stats for each clock in the timer
    stats = timer.stats()

    # stats will be computed internally if not provided
    print(timer.display(stats=stats, precision=5))

.. code-block:: text

    init        : 0.01141           (n=  1)
    train-loop  : 1.97650           (n=  1)
    train-epoch : 0.39529 ± 0.07439 (n=  5)
    train-batch : 0.01087 ± 0.00189 (n= 96)
    forward     : 0.00327 ± 0.00703 (n=160)
    loss        : 0.00010 ± 0.00023 (n= 96)
    backward    : 0.00438 ± 0.00074 (n= 96)

"""

import torch
from torch_simple_timing.clock import Clock
from typing import Dict, List, Optional, Union, Callable


class Timer:
    """
    ``Clock`` manager. Store and display timing statistics.

    .. warning::

        In order to accurately measure GPU timings,
        :func:`torch.cuda.synchronize()` will be called before and after
        each clock's ``start`` and ``stop``

    Args:
        gpu (bool, optional): Whether or not to use GPU timing using CUDA
            events. Defaults to ``False``.
        ignore (bool, optional): Whether to disable this timer. Can be
            useful when the same piece of code is used in various contexts,
            for instance in training or validation modes you may want to
            disable timing. Defaults to ``False``.
    """

    def __init__(self, gpu: bool = False, ignore: bool = False):
        # ``times`` maps a clock name to its list of recorded durations. The
        # same dict object is handed to every ``Clock`` created by
        # :meth:`clock`.
        self.times = {}
        # ``clocks`` maps a clock name to its ``Clock`` instance.
        self.clocks = {}
        self.gpu = gpu
        self.ignore = ignore

    def __repr__(self) -> str:
        """Short summary: flags plus the number of measurements per clock."""
        counts = {name: len(values) for name, values in self.times.items()}
        return f"Timer(gpu={self.gpu}, ignore={self.ignore}, times={counts})"

    @staticmethod
    def _as_name_list(names):
        """Wrap a single clock name in a list; pass anything else through."""
        if isinstance(names, str):
            return [names]
        return names

    def reset(self, keys: Optional[Union[str, List[str]]] = None) -> None:
        """
        Deletes specified ``keys``. If ``keys`` is None, resets all timers.

        Args:
            keys (Union[str, List[str]], optional): Specific named timers to
                reset, or all of them if ``keys`` is ``None`` .
                Defaults to ``None``.
        """
        keys = self._as_name_list(keys)
        if keys is None:
            self.times = {}
            self.clocks = {}
            return
        for key in keys:
            # Unknown names are silently skipped.
            self.times.pop(key, None)
            self.clocks.pop(key, None)

    def clock(
        self,
        name: str,
        ignore: Optional[bool] = None,
        gpu: Optional[bool] = None,
    ) -> "Clock":
        """
        Create a new ``Clock`` object with name ``name`` and add it to the
        ``Timer``. If the ``Clock`` already exists, it will be returned.

        .. note::

            If ``ignore`` is ``None`` when the clock is created, the
            ``Timer``'s ``ignore`` attribute will be used. For an existing
            clock, ``None`` leaves its current ``ignore`` state untouched.

        .. note::

            If ``ignore`` is not ``None``, the ``Clock`` 's ``ignore``
            attribute will be updated.

        .. warning::

            Don't forget to call ``.start()`` and ``.stop()`` on the returned
            ``Clock`` if you're not using ``timer.clock()`` as a context
            manager.

        Args:
            name (str): A name for the requested clock.
            ignore (Optional[bool], optional): Whether to ignore this clock
                and don't time anything. This is useful in case timing slows
                you down (because of :func:`torch.cuda.synchronize()` and
                :func:`torch.distributed.barrier()`) and you only want to
                time the first epoch for instance. Defaults to ``None``.
            gpu (Optional[bool], optional): Whether to enable GPU timing with
                CUDA events. Only used when the clock is created; falls back
                to the ``Timer``'s ``gpu`` attribute when ``None``.
                Defaults to ``None``.

        Returns:
            Clock: The requested ``Clock`` object.
        """
        existing = self.clocks.get(name)
        if existing is None:
            existing = Clock(
                name,
                self.times,
                self.gpu if gpu is None else gpu,
                self.ignore if ignore is None else ignore,
            )
            self.clocks[name] = existing
        elif ignore is not None and ignore != existing.ignore:
            # Only an explicit value may change the state: overwriting with
            # ``None`` would re-enable a clock that was disabled on purpose.
            existing.ignore = ignore
        return existing

    def disable(self, clock_names: Optional[Union[str, List[str]]] = None) -> None:
        """
        Disable the specified clocks based on their names.

        Args:
            clock_names (Optional[Union[str, List[str]]], optional): The clock
                name or list of clock names to disable. If ``None``, all
                clocks in this timer are disabled. Unknown names are skipped.
                Defaults to ``None``.
        """
        clock_names = self._as_name_list(clock_names)
        if clock_names is None:
            clock_names = list(self.clocks)
        for name in clock_names:
            if name in self.clocks:
                self.clocks[name].ignore = True

    def stats(
        self,
        clock_names: Optional[Union[str, List[str]]] = None,
        map_funcs: Optional[Dict[str, Callable]] = None,
    ) -> Dict[str, Dict[str, Union[int, float]]]:
        """
        Computes the mean and standard deviation of the times for each clock.

        Returns a dictionary of dictionaries with the following structure:

        .. code-block:: python

            {
                "clock_name": {
                    "mean": float,
                    "std": float,
                    "n": int
                }
            }

        Optionally, you can provide a dictionary of functions to apply to
        each recorded time of a clock. If a clock name is not in the
        dictionary, no function will be applied (equivalent to
        ``lambda t: t``).

        .. code-block:: python

            throughput = timer.stats(
                map_funcs={"forward": lambda t: batch_size / t}
            )

        This method will be called internally by ``timer.display()`` or you
        can provide it there if you want to do something else with the stats
        (log them for instance).

        Args:
            clock_names (Union[str, List[str]], optional): Clock name or list
                of clock names to compute the stats for, or all of them if
                ``None`` . Defaults to ``None``.
            map_funcs (Dict[str, Callable], optional): Dictionary of functions
                to pre-process the times of each clock.
                Defaults to ``None``.

        Returns:
            Dict[str, Dict[str, Union[int, float]]]: A dictionary of
            dictionaries, mapping clock names to a dictionary of statistics.
            ``"std"`` is ``nan`` for clocks with fewer than two measurements.
        """
        clock_names = self._as_name_list(clock_names)
        wanted = None if clock_names is None else set(clock_names)
        if map_funcs is None:
            map_funcs = {}

        stats = {}
        # Iterate ``self.times`` (not ``wanted``) so the result keeps the
        # clocks' insertion order and unknown names are dropped.
        for name, times in self.times.items():
            if wanted is not None and name not in wanted:
                continue
            transform = map_funcs.get(name)
            values = (
                list(times) if transform is None else [transform(t) for t in times]
            )
            tensor = torch.tensor(values).float()
            count = len(times)
            # The sample standard deviation is undefined below two samples;
            # return NaN directly instead of asking torch for it.
            std = torch.std(tensor).item() if count > 1 else float("nan")
            stats[name] = {
                "mean": torch.mean(tensor).item(),
                "std": std,
                "n": count,
            }
        return stats

    def display(
        self,
        clock_names: Optional[Union[str, List[str]]] = None,
        precision: int = 3,
        sort_keys_func: Optional[Callable] = None,
        stats: Optional[Dict[str, Dict[str, Union[int, float]]]] = None,
    ) -> str:
        """
        Display the mean, standard deviation and support of the times for
        each clock.

        :meth:`Timer.stats` is called internally to compute the stats. You
        can pre-compute stats independently and pass them to this method with
        the ``stats=`` argument.

        Optionally, you can provide a function ``sort_keys_func`` to sort the
        clocks by a specific key. For instance, you can sort them
        alphabetically with ``sort_keys_func=lambda k: k``. By default, they
        will be displayed in the order of the ``stats`` dictionary.

        .. code-block:: python

            >>> print(timer.display())
            epoch    : 0.251 ± 0.027 (n=10)
            forward  : 0.002 ± 0.005 (n=50)
            backward : 0.002 ± 0.004 (n=50)

        Args:
            clock_names (Optional[Union[str, List[str]]], optional): The clock
                name or list of clock names to display. If ``None``, all
                clocks are displayed. Also applied to a user-provided
                ``stats``. Defaults to ``None``.
            precision (int, optional): The number of digits to display after
                the decimal point. Defaults to ``3``.
            sort_keys_func (Callable, optional): A function to use to sort
                the displayed clocks. Defaults to ``None``.
            stats (Dict[str, Dict[str, Union[int, float]]], optional): The
                stats to display. If ``None``, the stats will be computed
                internally. Defaults to ``None``.

        Returns:
            str: A string representation of the stats, one line per clock.
            Empty string when there is nothing to display.
        """
        clock_names = self._as_name_list(clock_names)
        if stats is None:
            stats = self.stats(clock_names)

        keys = list(stats.keys())
        if clock_names is not None:
            wanted = set(clock_names)
            keys = [k for k in keys if k in wanted]
        if sort_keys_func is not None:
            keys = sorted(keys, key=sort_keys_func)
        if not keys:
            # ``max()`` below would raise on an empty sequence.
            return ""

        mean_strs = [f"{stats[k]['mean']:.{precision}f}" for k in keys]
        std_strs = [f"{stats[k]['std']:.{precision}f}" for k in keys]
        n_strs = [f"{stats[k]['n']}" for k in keys]

        # Column widths, so that every row lines up.
        key_width = max(len(k) for k in keys) + 1
        mean_width = max(len(s) for s in mean_strs)
        std_width = max(len(s) for s in std_strs)
        n_width = max(len(s) for s in n_strs)

        rows = []
        for key, mean, std, n in zip(keys, mean_strs, std_strs, n_strs):
            if stats[key]["n"] > 1:
                std_col = f" ± {std:>{std_width}}"
            else:
                # No deviation to show for a single sample: pad with blanks
                # of the same width (3 accounts for " ± ").
                std_col = " " * (std_width + 3)
            rows.append(
                f"{key:<{key_width}}: {mean:>{mean_width}}{std_col} (n={n:>{n_width}})"
            )
        return "\n".join(rows)