TRTCudaGraphCache

inference_models.models.common.trt.TRTCudaGraphCache

LRU cache for captured CUDA graphs used in TensorRT inference.

Stores captured torch.cuda.CUDAGraph objects keyed by input (shape, dtype, device) tuples. When the cache exceeds its capacity, the least recently used entry is evicted and its GPU resources are released.

The cache is thread-safe — all mutating operations acquire an internal threading.RLock.

Parameters:

  • capacity

    (int) –

    Maximum number of CUDA graphs to store. Each entry holds a dedicated TensorRT execution context and GPU memory buffers, so higher values increase VRAM usage.

Examples:

Create a cache and pass it to a model:

>>> from inference_models.developer_tools import TRTCudaGraphCache
>>> from inference_models import AutoModel
>>> import torch
>>>
>>> cache = TRTCudaGraphCache(capacity=16)
>>> model = AutoModel.from_pretrained(
...     model_id_or_path="rfdetr-nano",
...     device=torch.device("cuda:0"),
...     backend="trt",
...     trt_cuda_graph_cache=cache,
... )
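The LRU behavior described above can be modeled with a small stand-in. This is a hypothetical sketch built on `collections.OrderedDict`, not the actual implementation; real entries also own a TensorRT execution context and GPU buffers that are released on eviction, which this sketch omits:

```python
from collections import OrderedDict

# Hypothetical stand-in illustrating the documented LRU semantics.
# The real TRTCudaGraphCache additionally releases the evicted entry's
# CUDA graph, execution context, and GPU buffers.
class LRUSketch:
    def __init__(self, capacity):
        self.capacity = capacity
        self._entries = OrderedDict()

    def put(self, key, graph):
        self._entries[key] = graph
        self._entries.move_to_end(key)         # now most recently used
        if len(self._entries) > self.capacity:
            self._entries.popitem(last=False)  # evict least recently used

    def get(self, key):
        self._entries.move_to_end(key)         # access refreshes recency
        return self._entries[key]

cache = LRUSketch(capacity=2)
cache.put("a", 1)
cache.put("b", 2)
cache.get("a")                   # "a" becomes most recently used
cache.put("c", 3)                # evicts "b", the least recently used
print(list(cache._entries))      # ['a', 'c']
```

In the real cache the keys are (shape, dtype, device) tuples, so a forward pass with a previously seen input signature refreshes that entry's recency rather than capturing a new graph.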
See Also
  • establish_trt_cuda_graph_cache(): Factory that creates a cache based on environment configuration
  • infer_from_trt_engine(): Uses the cache during TRT inference

Functions

get_current_size

get_current_size()

Return the number of CUDA graphs currently stored in the cache.

Returns:

  • int

    Number of cached entries.

Examples:

>>> cache = TRTCudaGraphCache(capacity=16)
>>> cache.get_current_size()
0

list_keys

list_keys()

Return a list of all keys currently in the cache.

Each key is a (shape, dtype, device) tuple representing a cached CUDA graph. Keys are returned from least to most recently used, so the first key is the next eviction candidate.

Returns:

  • List[Tuple[Tuple[int, ...], dtype, device]]

    List of (shape, dtype, device) tuples for all cached entries.

Examples:

>>> cache = TRTCudaGraphCache(capacity=16)
>>> # ... after some forward passes ...
>>> for shape, dtype, device in cache.list_keys():
...     print(f"Cached: shape={shape}, dtype={dtype}")

purge

purge(n_oldest=None)

Remove entries from the cache, starting with the least recently used.

When called without arguments, clears the entire cache. When n_oldest is specified, only that many entries are evicted (or all entries if the cache contains fewer).

GPU memory cleanup (torch.cuda.empty_cache()) is called once after all evictions, making this more efficient than calling safe_remove() in a loop.

Parameters:

  • n_oldest

    (Optional[int], default: None) –

    Number of least recently used entries to evict. When None (default), all entries are removed.

Examples:

Evict the 4 oldest entries:

>>> cache.purge(n_oldest=4)

Clear the entire cache:

>>> cache.purge()
>>> cache.get_current_size()
0
Note
  • Eviction order follows LRU policy — entries that haven't been accessed recently are removed first
  • Each evicted entry's CUDA graph, execution context, and GPU buffers are released
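The batched cleanup above can be sketched as follows. This is a hypothetical illustration of the documented semantics, not the actual implementation; the single `torch.cuda.empty_cache()` call that the real method performs after all evictions is represented by a comment:

```python
from collections import OrderedDict

# Hypothetical sketch of purge() semantics: evict entries oldest-first,
# releasing each one's resources, then reclaim GPU memory once at the
# end rather than once per eviction, as a safe_remove() loop would.
def purge_sketch(entries, n_oldest=None):
    n = len(entries) if n_oldest is None else min(n_oldest, len(entries))
    for _ in range(n):
        entries.popitem(last=False)  # drop the least recently used entry
        # real code would release this entry's CUDA graph, execution
        # context, and GPU buffers here
    # torch.cuda.empty_cache() would run once here, after all evictions
    return entries

entries = OrderedDict([("old", 1), ("mid", 2), ("new", 3)])
purge_sketch(entries, n_oldest=2)
print(list(entries))  # ['new']
```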
See Also
  • safe_remove(): Remove a single entry by key

safe_remove

safe_remove(key)

Remove a single entry from the cache by its key.

If the key exists, the associated CUDA graph, execution context, and GPU buffers are released and torch.cuda.empty_cache() is called. If the key does not exist, this method is a no-op.

Parameters:

  • key

    (Tuple[Tuple[int, ...], dtype, device]) –

    A (shape, dtype, device) tuple identifying the entry to remove.

Examples:

Remove a cached graph for a specific input shape:

>>> import torch
>>> key = ((1, 3, 384, 384), torch.float16, torch.device("cuda:0"))
>>> cache.safe_remove(key)

Safe to call with a non-existent key:

>>> cache.safe_remove(((99, 99), torch.float32, torch.device("cuda:0")))
>>> # no error raised
See Also
  • purge(): Remove multiple entries at once with batched GPU memory cleanup