Simplify and avoid knowing the provider values (#782)

This commit is contained in:
Henry Ruhs 2024-10-01 11:26:31 +02:00 committed by henryruhs
parent d1309a06e1
commit ab73908305
2 changed files with 18 additions and 23 deletions

View File

@ -6,7 +6,7 @@ from typing import Any, List
from onnxruntime import get_available_providers, set_default_logger_severity
from facefusion.choices import execution_provider_set
from facefusion.typing import ExecutionDevice, ExecutionProviderKey, ExecutionProviderSet, ExecutionProviderValue, ValueAndUnit
from facefusion.typing import ExecutionDevice, ExecutionProviderKey, ExecutionProviderSet, ValueAndUnit
set_default_logger_severity(3)
@ -29,23 +29,18 @@ def get_available_execution_provider_set() -> ExecutionProviderSet:
return available_execution_provider_set
def extract_execution_providers(execution_provider_keys : List[ExecutionProviderKey]) -> List[ExecutionProviderValue]:
return [ execution_provider_set[execution_provider_key] for execution_provider_key in execution_provider_keys if execution_provider_key in execution_provider_set ]
def create_execution_providers(execution_device_id : str, execution_provider_keys : List[ExecutionProviderKey]) -> List[Any]:
execution_providers = extract_execution_providers(execution_provider_keys)
execution_providers_with_options : List[Any] = []
execution_providers : List[Any] = []
for execution_provider in execution_providers:
if execution_provider == 'CUDAExecutionProvider':
execution_providers_with_options.append((execution_provider,
for execution_provider_key in execution_provider_keys:
if execution_provider_key == 'cuda':
execution_providers.append((execution_provider_set.get(execution_provider_key),
{
'device_id': execution_device_id,
'cudnn_conv_algo_search': 'EXHAUSTIVE' if use_exhaustive() else 'DEFAULT'
}))
if execution_provider == 'TensorrtExecutionProvider':
execution_providers_with_options.append((execution_provider,
if execution_provider_key == 'tensorrt':
execution_providers.append((execution_provider_set.get(execution_provider_key),
{
'device_id': execution_device_id,
'trt_engine_cache_enable': True,
@ -54,24 +49,24 @@ def create_execution_providers(execution_device_id : str, execution_provider_key
'trt_timing_cache_path': '.caches',
'trt_builder_optimization_level': 5
}))
if execution_provider == 'OpenVINOExecutionProvider':
execution_providers_with_options.append((execution_provider,
if execution_provider_key == 'openvino':
execution_providers.append((execution_provider_set.get(execution_provider_key),
{
'device_type': 'GPU.' + execution_device_id,
'precision': 'FP32'
}))
if execution_provider in [ 'DmlExecutionProvider', 'ROCMExecutionProvider' ]:
execution_providers_with_options.append((execution_provider,
if execution_provider_key in [ 'directml', 'rocm' ]:
execution_providers.append((execution_provider_set.get(execution_provider_key),
{
'device_id': execution_device_id
}))
if execution_provider == 'CoreMLExecutionProvider':
execution_providers_with_options.append(execution_provider)
if execution_provider_key == 'coreml':
execution_providers.append(execution_provider_set.get(execution_provider_key))
if 'CPUExecutionProvider' in execution_providers:
execution_providers_with_options.append('CPUExecutionProvider')
if 'cpu' in execution_provider_keys:
execution_providers.append(execution_provider_set.get('cpu'))
return execution_providers_with_options
return execution_providers
def use_exhaustive() -> bool:

View File

@ -11,7 +11,7 @@ def test_has_execution_provider() -> None:
def test_multiple_execution_providers() -> None:
execution_provider_with_options =\
execution_providers =\
[
('CUDAExecutionProvider',
{
@ -21,4 +21,4 @@ def test_multiple_execution_providers() -> None:
'CPUExecutionProvider'
]
assert create_execution_providers('1', [ 'cpu', 'cuda' ]) == execution_provider_with_options
assert create_execution_providers('1', [ 'cpu', 'cuda' ]) == execution_providers