infer_from_trt_engine¶
inference_models.models.common.trt.infer_from_trt_engine
infer_from_trt_engine(pre_processed_images, trt_config, engine, context, device, input_name, outputs, stream=None, trt_cuda_graph_cache=None)
Run inference using a TensorRT engine, optionally with CUDA graph acceleration.
Executes inference on preprocessed images using a TensorRT engine. Handles both static and dynamic batch sizes, automatically splitting large batches if needed.
When trt_cuda_graph_cache is provided, CUDA graphs are captured and replayed
for improved performance on repeated inference with the same input shape. Each
graph is keyed by (shape, dtype, device) and stored in the cache. The cache
itself must be created by the caller (typically in the model class).
When trt_cuda_graph_cache is None, inference runs through the standard
TRT execution path using the provided context.
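The cache behavior described above can be sketched as a small LRU store keyed by (shape, dtype, device). This is an illustrative model only, not the library's actual TRTCudaGraphCache implementation; the class name, `capacity` parameter, and `get_or_capture` method are assumptions for the sketch:

```python
# Illustrative sketch of a graph cache keyed by (shape, dtype, device).
# Not the library implementation; capture_fn stands in for CUDA graph capture.
from collections import OrderedDict


class GraphCacheSketch:
    def __init__(self, capacity=16):
        self.capacity = capacity
        self._graphs = OrderedDict()

    def get_or_capture(self, shape, dtype, device, capture_fn):
        key = (tuple(shape), str(dtype), str(device))
        if key in self._graphs:
            self._graphs.move_to_end(key)  # mark as most recently used
            return self._graphs[key]
        graph = capture_fn()  # capture a new graph for this key
        self._graphs[key] = graph
        if len(self._graphs) > self.capacity:
            self._graphs.popitem(last=False)  # evict least recently used
        return graph
```

Because the key includes the input shape, a new graph is captured the first time each shape is seen, and subsequent calls with the same shape replay the cached graph instead of re-executing the context.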
Parameters:
- pre_processed_images (Tensor) – Preprocessed input tensor on a CUDA device. Shape: (batch_size, channels, height, width).
- trt_config (TRTConfig) – TensorRT configuration object containing batch size settings and other engine-specific parameters.
- engine (ICudaEngine) – TensorRT CUDA engine to use for inference.
- context (IExecutionContext) – TensorRT execution context for running inference. Required when trt_cuda_graph_cache is None. Ignored when using CUDA graphs (each cached graph owns its own execution context).
- device (device) – PyTorch CUDA device to use for inference.
- input_name (str) – Name of the input tensor in the TensorRT engine.
- outputs (List[str]) – List of output tensor names to retrieve from the engine.
- stream (Optional[Stream], default: None) – CUDA stream to use for inference. Defaults to the current stream for the given device.
- trt_cuda_graph_cache (Optional[TRTCudaGraphCache], default: None) – Optional CUDA graph cache. When provided, CUDA graphs are used for inference. When None, standard TRT execution is used.
Returns:
- List[Tensor] – List of output tensors from the TensorRT engine, in the order specified by the outputs parameter.
Examples:
Run TensorRT inference (standard path):
>>> from inference_models.developer_tools import (
... load_trt_model,
... get_trt_engine_inputs_and_outputs,
... infer_from_trt_engine
... )
>>> from inference_models.models.common.roboflow.model_packages import (
... parse_trt_config
... )
>>> import torch
>>>
>>> # Load engine and config
>>> engine = load_trt_model("model.plan")
>>> trt_config = parse_trt_config("trt_config.json")
>>> context = engine.create_execution_context()
>>>
>>> # Get input/output names
>>> inputs, outputs = get_trt_engine_inputs_and_outputs(engine)
>>>
>>> # Prepare input
>>> images = torch.randn(1, 3, 640, 640, device="cuda:0")
>>>
>>> # Run inference
>>> results = infer_from_trt_engine(
... pre_processed_images=images,
... trt_config=trt_config,
... engine=engine,
... context=context,
... device=torch.device("cuda:0"),
... input_name=inputs[0],
... outputs=outputs,
... )
Handle large batches:
>>> # Large batch will be automatically split
>>> large_batch = torch.randn(100, 3, 640, 640, device="cuda:0")
>>>
>>> results = infer_from_trt_engine(
... pre_processed_images=large_batch,
... trt_config=trt_config,
... engine=engine,
... context=context,
... device=torch.device("cuda:0"),
... input_name=inputs[0],
... outputs=outputs,
... )
>>> # Results are automatically concatenated
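The splitting behavior can be sketched roughly as follows. This is a simplified model of the documented behavior, not the library's actual code; `run_in_chunks`, `max_batch_size`, and `infer_fn` (standing in for a single-chunk engine call) are names invented for the sketch:

```python
# Illustrative sketch: split an oversized batch into chunks no larger than
# the engine's maximum batch size, run each chunk, and concatenate the
# per-chunk outputs in order along the batch dimension.
def run_in_chunks(batch, max_batch_size, infer_fn):
    outputs = []
    for start in range(0, len(batch), max_batch_size):
        # Each slice respects the engine's batch size constraint.
        outputs.extend(infer_fn(batch[start:start + max_batch_size]))
    return outputs
```

The same slicing idiom works for torch tensors, where the per-chunk results would be concatenated with torch.cat along dim 0.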
Run with CUDA graph acceleration:
>>> from inference_models.models.common.trt import TRTCudaGraphCache
>>> cache = TRTCudaGraphCache(capacity=16)
>>>
>>> results = infer_from_trt_engine(
... pre_processed_images=images,
... trt_config=trt_config,
... engine=engine,
... context=context,
... device=torch.device("cuda:0"),
... input_name=inputs[0],
... outputs=outputs,
... trt_cuda_graph_cache=cache,
... )
>>> # context is still passed (the signature requires it) but is ignored:
>>> # each cached graph owns its own execution context
Note
- Requires TensorRT and PyCUDA to be installed
- Input must be on CUDA device
- Automatically handles batch size constraints from trt_config
- Uses asynchronous execution with CUDA streams
Raises:
- ModelRuntimeError – If inference execution fails.
See Also
- load_trt_model(): Load a TensorRT engine from file
- get_trt_engine_inputs_and_outputs(): Get engine tensor names