Simplify and avoid knowing the provider values
This commit is contained in:
parent
c44dd275c9
commit
457eae1353
@ -6,7 +6,7 @@ from typing import Any, List
|
||||
from onnxruntime import get_available_providers, set_default_logger_severity
|
||||
|
||||
from facefusion.choices import execution_provider_set
|
||||
from facefusion.typing import ExecutionDevice, ExecutionProviderKey, ExecutionProviderSet, ExecutionProviderValue, ValueAndUnit
|
||||
from facefusion.typing import ExecutionDevice, ExecutionProviderKey, ExecutionProviderSet, ValueAndUnit
|
||||
|
||||
set_default_logger_severity(3)
|
||||
|
||||
@ -29,23 +29,18 @@ def get_available_execution_provider_set() -> ExecutionProviderSet:
|
||||
return available_execution_provider_set
|
||||
|
||||
|
||||
def extract_execution_providers(execution_provider_keys : List[ExecutionProviderKey]) -> List[ExecutionProviderValue]:
|
||||
return [ execution_provider_set[execution_provider_key] for execution_provider_key in execution_provider_keys if execution_provider_key in execution_provider_set ]
|
||||
|
||||
|
||||
def create_execution_providers(execution_device_id : str, execution_provider_keys : List[ExecutionProviderKey]) -> List[Any]:
|
||||
execution_providers = extract_execution_providers(execution_provider_keys)
|
||||
execution_providers_with_options : List[Any] = []
|
||||
execution_providers : List[Any] = []
|
||||
|
||||
for execution_provider in execution_providers:
|
||||
if execution_provider == 'CUDAExecutionProvider':
|
||||
execution_providers_with_options.append((execution_provider,
|
||||
for execution_provider_key in execution_provider_keys:
|
||||
if execution_provider_key == 'cuda':
|
||||
execution_providers.append((execution_provider_set.get(execution_provider_key),
|
||||
{
|
||||
'device_id': execution_device_id,
|
||||
'cudnn_conv_algo_search': 'EXHAUSTIVE' if use_exhaustive() else 'DEFAULT'
|
||||
}))
|
||||
if execution_provider == 'TensorrtExecutionProvider':
|
||||
execution_providers_with_options.append((execution_provider,
|
||||
if execution_provider_key == 'tensorrt':
|
||||
execution_providers.append((execution_provider_set.get(execution_provider_key),
|
||||
{
|
||||
'device_id': execution_device_id,
|
||||
'trt_engine_cache_enable': True,
|
||||
@ -54,24 +49,24 @@ def create_execution_providers(execution_device_id : str, execution_provider_key
|
||||
'trt_timing_cache_path': '.caches',
|
||||
'trt_builder_optimization_level': 5
|
||||
}))
|
||||
if execution_provider == 'OpenVINOExecutionProvider':
|
||||
execution_providers_with_options.append((execution_provider,
|
||||
if execution_provider_key == 'openvino':
|
||||
execution_providers.append((execution_provider_set.get(execution_provider_key),
|
||||
{
|
||||
'device_type': 'GPU.' + execution_device_id,
|
||||
'precision': 'FP32'
|
||||
}))
|
||||
if execution_provider in [ 'DmlExecutionProvider', 'ROCMExecutionProvider' ]:
|
||||
execution_providers_with_options.append((execution_provider,
|
||||
if execution_provider_key in [ 'directml', 'rocm' ]:
|
||||
execution_providers.append((execution_provider_set.get(execution_provider_key),
|
||||
{
|
||||
'device_id': execution_device_id
|
||||
}))
|
||||
if execution_provider == 'CoreMLExecutionProvider':
|
||||
execution_providers_with_options.append(execution_provider)
|
||||
if execution_provider_key == 'coreml':
|
||||
execution_providers.append(execution_provider_set.get(execution_provider_key))
|
||||
|
||||
if 'CPUExecutionProvider' in execution_providers:
|
||||
execution_providers_with_options.append('CPUExecutionProvider')
|
||||
if 'cpu' in execution_provider_keys:
|
||||
execution_providers.append(execution_provider_set.get('cpu'))
|
||||
|
||||
return execution_providers_with_options
|
||||
return execution_providers
|
||||
|
||||
|
||||
def use_exhaustive() -> bool:
|
||||
|
@ -11,7 +11,7 @@ def test_has_execution_provider() -> None:
|
||||
|
||||
|
||||
def test_multiple_execution_providers() -> None:
|
||||
execution_provider_with_options =\
|
||||
execution_providers =\
|
||||
[
|
||||
('CUDAExecutionProvider',
|
||||
{
|
||||
@ -21,4 +21,4 @@ def test_multiple_execution_providers() -> None:
|
||||
'CPUExecutionProvider'
|
||||
]
|
||||
|
||||
assert create_execution_providers('1', [ 'cpu', 'cuda' ]) == execution_provider_with_options
|
||||
assert create_execution_providers('1', [ 'cpu', 'cuda' ]) == execution_providers
|
||||
|
Loading…
Reference in New Issue
Block a user