From 457eae135383cde587dcfd55c014d2a2f8956e47 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Mon, 30 Sep 2024 17:05:50 +0200 Subject: [PATCH] Simplify and avoid knowing the provider values --- facefusion/execution.py | 37 ++++++++++++++++--------------------- tests/test_execution.py | 4 ++-- 2 files changed, 18 insertions(+), 23 deletions(-) diff --git a/facefusion/execution.py b/facefusion/execution.py index 771923e4..6ce3a697 100644 --- a/facefusion/execution.py +++ b/facefusion/execution.py @@ -6,7 +6,7 @@ from typing import Any, List from onnxruntime import get_available_providers, set_default_logger_severity from facefusion.choices import execution_provider_set -from facefusion.typing import ExecutionDevice, ExecutionProviderKey, ExecutionProviderSet, ExecutionProviderValue, ValueAndUnit +from facefusion.typing import ExecutionDevice, ExecutionProviderKey, ExecutionProviderSet, ValueAndUnit set_default_logger_severity(3) @@ -29,23 +29,18 @@ def get_available_execution_provider_set() -> ExecutionProviderSet: return available_execution_provider_set -def extract_execution_providers(execution_provider_keys : List[ExecutionProviderKey]) -> List[ExecutionProviderValue]: - return [ execution_provider_set[execution_provider_key] for execution_provider_key in execution_provider_keys if execution_provider_key in execution_provider_set ] - - def create_execution_providers(execution_device_id : str, execution_provider_keys : List[ExecutionProviderKey]) -> List[Any]: - execution_providers = extract_execution_providers(execution_provider_keys) - execution_providers_with_options : List[Any] = [] + execution_providers : List[Any] = [] - for execution_provider in execution_providers: - if execution_provider == 'CUDAExecutionProvider': - execution_providers_with_options.append((execution_provider, + for execution_provider_key in execution_provider_keys: + if execution_provider_key == 'cuda': + execution_providers.append((execution_provider_set.get(execution_provider_key), { 'device_id': execution_device_id, 'cudnn_conv_algo_search': 'EXHAUSTIVE' if use_exhaustive() else 'DEFAULT' })) - if execution_provider == 'TensorrtExecutionProvider': - execution_providers_with_options.append((execution_provider, + if execution_provider_key == 'tensorrt': + execution_providers.append((execution_provider_set.get(execution_provider_key), { 'device_id': execution_device_id, 'trt_engine_cache_enable': True, @@ -54,24 +49,24 @@ def create_execution_providers(execution_device_id : str, execution_provider_key 'trt_timing_cache_path': '.caches', 'trt_builder_optimization_level': 5 })) - if execution_provider == 'OpenVINOExecutionProvider': - execution_providers_with_options.append((execution_provider, + if execution_provider_key == 'openvino': + execution_providers.append((execution_provider_set.get(execution_provider_key), { 'device_type': 'GPU.' + execution_device_id, 'precision': 'FP32' })) - if execution_provider in [ 'DmlExecutionProvider', 'ROCMExecutionProvider' ]: - execution_providers_with_options.append((execution_provider, + if execution_provider_key in [ 'directml', 'rocm' ]: + execution_providers.append((execution_provider_set.get(execution_provider_key), { 'device_id': execution_device_id })) - if execution_provider == 'CoreMLExecutionProvider': - execution_providers_with_options.append(execution_provider) + if execution_provider_key == 'coreml': + execution_providers.append(execution_provider_set.get(execution_provider_key)) - if 'CPUExecutionProvider' in execution_providers: - execution_providers_with_options.append('CPUExecutionProvider') + if 'cpu' in execution_provider_keys: + execution_providers.append(execution_provider_set.get('cpu')) - return execution_providers_with_options + return execution_providers def use_exhaustive() -> bool: diff --git a/tests/test_execution.py b/tests/test_execution.py index 1823f1e8..790b7408 100644 --- a/tests/test_execution.py +++ b/tests/test_execution.py @@ -11,7 +11,7 @@ def test_has_execution_provider() -> None: def test_multiple_execution_providers() -> None: - execution_provider_with_options =\ + execution_providers =\ [ ('CUDAExecutionProvider', { @@ -21,4 +21,4 @@ def test_multiple_execution_providers() -> None: 'CPUExecutionProvider' ] - assert create_execution_providers('1', [ 'cpu', 'cuda' ]) == execution_provider_with_options + assert create_execution_providers('1', [ 'cpu', 'cuda' ]) == execution_providers