Describe the issue
I seem to have stumbled upon a bug in the _create_inference_session method that causes the execution provider fallback to be incorrectly overwritten, defaulting to CPU even when GPU acceleration is available. Additionally, there seems to be a typo in the code that can affect the NvTensorRTRTXExecutionProvider provider when it is passed as a tuple.
onnxruntime/onnxruntime/python/onnxruntime_inference_collection.py
Lines 488 to 537 in 986b66a
```python
def _create_inference_session(self, providers, provider_options, disabled_optimizers=None):
    available_providers = C.get_available_providers()

    # Tensorrt can fall back to CUDA if it's explicitly assigned. All others fall back to CPU.
    if "TensorrtExecutionProvider" in available_providers:
        if (
            providers
            and any(
                provider == "CUDAExecutionProvider"
                or (isinstance(provider, tuple) and provider[0] == "CUDAExecutionProvider")
                for provider in providers
            )
            and any(
                provider == "TensorrtExecutionProvider"
                or (isinstance(provider, tuple) and provider[0] == "TensorrtExecutionProvider")
                for provider in providers
            )
        ):
            self._fallback_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            self._fallback_providers = ["CPUExecutionProvider"]
    if "NvTensorRTRTXExecutionProvider" in available_providers:
        if (
            providers
            and any(
                provider == "CUDAExecutionProvider"
                or (isinstance(provider, tuple) and provider[0] == "CUDAExecutionProvider")
                for provider in providers
            )
            and any(
                provider == "NvTensorRTRTXExecutionProvider"
                or (isinstance(provider, tuple) and provider[0] == "NvExecutionProvider")
                for provider in providers
            )
        ):
            self._fallback_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            self._fallback_providers = ["CPUExecutionProvider"]
    # MIGraphX can fall back to ROCM if it's explicitly assigned. All others fall back to CPU.
    elif "MIGraphXExecutionProvider" in available_providers:
        if providers and any(
            provider == "ROCMExecutionProvider"
            or (isinstance(provider, tuple) and provider[0] == "ROCMExecutionProvider")
            for provider in providers
        ):
            self._fallback_providers = ["ROCMExecutionProvider", "CPUExecutionProvider"]
        else:
            self._fallback_providers = ["CPUExecutionProvider"]
    else:
        self._fallback_providers = ["CPUExecutionProvider"]
```
Here I'm referring to this structure:

```python
if "TensorrtExecutionProvider" in available_providers:
    # Sets fallback to ["CUDAExecutionProvider", "CPUExecutionProvider"] when both TensorRT and CUDA are requested
    ...
if "NvTensorRTRTXExecutionProvider" in available_providers:  # Should be elif!
    ...
elif "MIGraphXExecutionProvider" in available_providers:
    ...
else:
    self._fallback_providers = ["CPUExecutionProvider"]
```

Because the NvTensorRTRTX check is a separate `if` rather than an `elif`, any run through that second if/elif/else chain can overwrite the fallback providers that were already set for the TensorRT case.
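To make the control flow concrete, here is a minimal standalone illustration of the overwrite (plain Python, no onnxruntime needed; the list mirrors my available providers):

```python
available = ["TensorrtExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"]

fallback = None
if "TensorrtExecutionProvider" in available:
    fallback = ["CUDAExecutionProvider", "CPUExecutionProvider"]  # set here...
if "NvTensorRTRTXExecutionProvider" in available:  # separate if, not elif
    pass
elif "MIGraphXExecutionProvider" in available:
    pass
else:
    fallback = ["CPUExecutionProvider"]  # ...and overwritten here

print(fallback)  # ['CPUExecutionProvider'], even though CUDA was chosen above
```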
Also, here is what appears to be the typo I mentioned earlier (the tuple check compares against "NvExecutionProvider" instead of "NvTensorRTRTXExecutionProvider"):
onnxruntime/onnxruntime/python/onnxruntime_inference_collection.py
Lines 517 to 520 in 986b66a
```python
            and any(
                provider == "NvTensorRTRTXExecutionProvider"
                or (isinstance(provider, tuple) and provider[0] == "NvExecutionProvider")
                for provider in providers
```
I'll be happy to contribute a fix if this is confirmed!
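If it is confirmed, the fix I'd propose boils down to chaining the branches with elif and comparing tuple entries against the full provider name. As a rough standalone sketch of the intended selection logic (a hypothetical helper for illustration, not the actual patch):

```python
def _select_fallback_providers(available_providers, providers):
    # Sketch of the intended behavior after the fix; names mirror the snippet above.
    def requested(name):
        return bool(providers) and any(
            p == name or (isinstance(p, tuple) and p[0] == name) for p in providers
        )

    if "TensorrtExecutionProvider" in available_providers:
        if requested("CUDAExecutionProvider") and requested("TensorrtExecutionProvider"):
            return ["CUDAExecutionProvider", "CPUExecutionProvider"]
        return ["CPUExecutionProvider"]
    elif "NvTensorRTRTXExecutionProvider" in available_providers:  # elif, so it can't overwrite the TensorRT case
        # Tuple entries are compared against the full "NvTensorRTRTXExecutionProvider" name (typo fixed).
        if requested("CUDAExecutionProvider") and requested("NvTensorRTRTXExecutionProvider"):
            return ["CUDAExecutionProvider", "CPUExecutionProvider"]
        return ["CPUExecutionProvider"]
    elif "MIGraphXExecutionProvider" in available_providers:
        if requested("ROCMExecutionProvider"):
            return ["ROCMExecutionProvider", "CPUExecutionProvider"]
        return ["CPUExecutionProvider"]
    return ["CPUExecutionProvider"]
```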
To reproduce
```python
from onnxruntime import InferenceSession, get_available_providers

print("Available providers:", get_available_providers())

providers = ["TensorrtExecutionProvider", "CUDAExecutionProvider"]
session = InferenceSession("model.onnx", providers=providers)
print(f"Actual fallback: {session._fallback_providers}")
```

(where the available providers are `['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']`)
When trying to reproduce this locally on my laptop, I hit the scenario where the TensorRT libraries were not installed, and I was wondering whether that scenario is valuable in itself: even today, shouldn't I be able to fall back to the available CUDA provider instead of losing GPU acceleration entirely? I've attached the output from that run below, followed by a sketch of the kind of user-side fallback I mean.
- Expected behavior: when TensorRT fails but CUDA is explicitly requested and available, the session should fall back to `["CUDAExecutionProvider", "CPUExecutionProvider"]`
- Actual behavior: it falls back to `["CPUExecutionProvider"]`
Output:

```
['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']
2025-06-24 04:43:15.201212241 [E:onnxruntime:Default, provider_bridge_ort.cc:2167 TryGetProviderInfo_TensorRT] /onnxruntime_src/onnxruntime/core/session/provider_bridge_ort.cc:1778 onnxruntime::Provider& onnxruntime::ProviderLibrary::Get() [ONNXRuntimeError] : 1 : FAIL : Failed to load library libonnxruntime_providers_tensorrt.so with error: libcublas.so.12: cannot open shared object file: No such file or directory
** EP Error **
EP Error /onnxruntime_src/onnxruntime/python/onnxruntime_pybind_state.cc:505 void onnxruntime::python::RegisterTensorRTPluginsAsCustomOps(PySessionOptions&, const onnxruntime::ProviderOptions&) Please install TensorRT libraries as mentioned in the GPU requirements page, make sure they're in the PATH or LD_LIBRARY_PATH, and that your GPU is supported.
when using ['TensorrtExecutionProvider', 'CUDAExecutionProvider']
Falling back to ['CPUExecutionProvider'] and retrying.
**
Actual fallback: ['CPUExecutionProvider']
```
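The user-side fallback I mean looks roughly like this (a sketch only, assuming CUDA itself is installed correctly; `get_providers()` reports which providers were actually enabled):

```python
from onnxruntime import InferenceSession

def create_session_prefer_gpu(model_path):
    # Ask for the full stack first; onnxruntime's internal fallback may drop to CPU-only
    # if TensorRT fails to load, so check what was actually enabled and retry without it.
    session = InferenceSession(
        model_path,
        providers=["TensorrtExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"],
    )
    if "CUDAExecutionProvider" not in session.get_providers():
        session = InferenceSession(
            model_path,
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
        )
    return session

session = create_session_prefer_gpu("model.onnx")
print("Active providers:", session.get_providers())
```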
Urgency
No response
Platform
Linux
OS Version
6.13.8-arch1-1
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.22.0
ONNX Runtime API
Python
Architecture
X64
Execution Provider
CUDA
Execution Provider Library Version
No response