Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 12 additions & 42 deletions onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2866,6 +2866,14 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
trt_runtime_config = std::unique_ptr<nvinfer1::IRuntimeConfig>(trt_engine->createRuntimeConfig());
if (trt_runtime_config && cuda_graph_enable_) {
trt_runtime_config->setDynamicShapesKernelSpecializationStrategy(nvinfer1::DynamicShapesKernelSpecializationStrategy::kEAGER);
#if TRT_MAJOR_RTX > 1 || (TRT_MAJOR_RTX == 1 && TRT_MINOR_RTX >= 3)
auto cuda_strategy_flag = trt_runtime_config->setCudaGraphStrategy(nvinfer1::CudaGraphStrategy::kWHOLE_GRAPH_CAPTURE);
LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] CUDA graph strategy with RTX Graph capture enabled : " << cuda_strategy_flag;
#else
LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] CUDA graph is enabled but RTX Graph capture is not available. "
<< "The current TRT RTX version does not support RTX Graph. "
<< "Please upgrade to TRT RTX >= 1.3 to use RTX Graph capture feature for optimal CUDA graph performance.";
#endif
}
trt_runtime_config->setExecutionContextAllocationStrategy(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED);
if (!runtime_cache_.empty()) {
Expand Down Expand Up @@ -3148,22 +3156,8 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
}
}

// Start CUDA graph capture with the correct stream
// Note: We need to set the stream and start capture here because this is where we have access to the actual compute stream
// Get the graph annotation ID that was stored during OnRunStart
CudaGraphAnnotation_t cuda_graph_annotation_id = GetPerThreadContext().GetCurrentGraphAnnotationId();
bool graph_replay_on_this_run = false;
bool should_start_capture = false;

HandleCudaGraphStart(stream, require_io_binding, cuda_graph_annotation_id,
graph_replay_on_this_run, should_start_capture);

if (!graph_replay_on_this_run) {
if (!trt_context->enqueueV3(stream)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "NvTensorRTRTX EP execution context enqueue failed.");
}
} else {
ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(cuda_graph_annotation_id, sync_stream_after_enqueue_));
if (!trt_context->enqueueV3(stream)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "NvTensorRTRTX EP execution context enqueue failed.");
}

/*
Expand All @@ -3181,11 +3175,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
* However, if cuda graph is enabled, TRT EP won't call cudaStreamSynchronize() since it's not allowed during graph capture.
*/

if (cuda_graph_enable_ && should_start_capture) {
GetPerThreadContext().CaptureEnd(cuda_graph_annotation_id);
ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(cuda_graph_annotation_id, sync_stream_after_enqueue_));
}

if (sync_stream_after_enqueue_) {
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
}
Expand Down Expand Up @@ -3474,22 +3463,8 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra
trt_context->setAuxStreams(aux_streams_, (int32_t)auxiliary_streams_);
}

// Start CUDA graph capture with the correct stream
// Note: We need to set the stream and start capture here because this is where we have access to the actual compute stream
// Get the graph annotation ID that was stored during OnRunStart
CudaGraphAnnotation_t cuda_graph_annotation_id = GetPerThreadContext().GetCurrentGraphAnnotationId();
bool graph_replay_on_this_run = false;
bool should_start_capture = false;

HandleCudaGraphStart(stream, require_io_binding, cuda_graph_annotation_id,
graph_replay_on_this_run, should_start_capture);

if (!graph_replay_on_this_run) {
if (!trt_context->enqueueV3(stream)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "NvTensorRTRTX EP execution context enqueue failed.");
}
} else {
ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(cuda_graph_annotation_id, sync_stream_after_enqueue_));
if (!trt_context->enqueueV3(stream)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "NvTensorRTRTX EP execution context enqueue failed.");
}

/*
Expand All @@ -3507,11 +3482,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra
* However, if cuda graph is enabled, TRT EP won't call cudaStreamSynchronize() since it's not allowed during graph capture.
*/

if (cuda_graph_enable_ && should_start_capture) {
GetPerThreadContext().CaptureEnd(cuda_graph_annotation_id);
ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(cuda_graph_annotation_id, sync_stream_after_enqueue_));
}

if (sync_stream_after_enqueue_) {
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ struct NvExecutionProviderInfo {
std::string profile_min_shapes{""};
std::string profile_max_shapes{""};
std::string profile_opt_shapes{""};
bool cuda_graph_enable{false};
bool cuda_graph_enable{true};
bool multi_profile_enable{false};
bool dump_ep_context_model{false};
std::string ep_context_file_path{""};
Expand Down
Loading