From 1e353307980ea05b6cb29e7dd7de9af3ccda9e58 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Thu, 5 Sep 2024 17:16:28 +0200 Subject: [PATCH] Fix multi model context in inference pool --- facefusion/face_detector.py | 8 ++++++-- facefusion/face_landmarker.py | 8 ++++++-- facefusion/inference_manager.py | 2 +- facefusion/processors/modules/age_modifier.py | 8 ++++++-- facefusion/processors/modules/expression_restorer.py | 4 +++- facefusion/processors/modules/face_editor.py | 8 ++++++-- facefusion/processors/modules/face_enhancer.py | 8 ++++++-- facefusion/processors/modules/face_swapper.py | 8 ++++++-- facefusion/processors/modules/frame_colorizer.py | 8 ++++++-- facefusion/processors/modules/frame_enhancer.py | 8 ++++++-- facefusion/processors/modules/lip_syncer.py | 8 ++++++-- tests/test_face_analyser.py | 2 ++ 12 files changed, 60 insertions(+), 20 deletions(-) diff --git a/facefusion/face_detector.py b/facefusion/face_detector.py index aed580c5..c6345f71 100644 --- a/facefusion/face_detector.py +++ b/facefusion/face_detector.py @@ -75,11 +75,15 @@ MODEL_SET : ModelSet =\ def get_inference_pool() -> InferencePool: _, model_sources = collect_model_downloads() - return inference_manager.get_inference_pool(__name__, model_sources) + face_detector_model = state_manager.get_item('face_detector_model') + model_context = __name__ + '.' + face_detector_model + return inference_manager.get_inference_pool(model_context, model_sources) def clear_inference_pool() -> None: - inference_manager.clear_inference_pool(__name__) + face_detector_model = state_manager.get_item('face_detector_model') + model_context = __name__ + '.' + face_detector_model + inference_manager.clear_inference_pool(model_context) def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]: diff --git a/facefusion/face_landmarker.py b/facefusion/face_landmarker.py index 65d5b6ca..15949cf5 100644 --- a/facefusion/face_landmarker.py +++ b/facefusion/face_landmarker.py @@ -76,11 +76,15 @@ MODEL_SET : ModelSet =\ def get_inference_pool() -> InferencePool: _, model_sources = collect_model_downloads() - return inference_manager.get_inference_pool(__name__, model_sources) + face_landmarker_model = state_manager.get_item('face_landmarker_model') + model_context = __name__ + '.' + face_landmarker_model + return inference_manager.get_inference_pool(model_context, model_sources) def clear_inference_pool() -> None: - inference_manager.clear_inference_pool(__name__) + face_landmarker_model = state_manager.get_item('face_landmarker_model') + model_context = __name__ + '.' + face_landmarker_model + inference_manager.clear_inference_pool(model_context) def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]: diff --git a/facefusion/inference_manager.py b/facefusion/inference_manager.py index 80bba58a..e5385a40 100644 --- a/facefusion/inference_manager.py +++ b/facefusion/inference_manager.py @@ -68,7 +68,7 @@ def get_static_model_initializer(model_path : str) -> ModelInitializer: def resolve_execution_provider_keys(model_context : str) -> List[ExecutionProviderKey]: - if has_execution_provider('coreml') and model_context in [ 'facefusion.processors.modules.age_modifier', 'facefusion.processors.modules.frame_colorizer' ]: + if has_execution_provider('coreml') and (model_context.startswith('facefusion.processors.modules.age_modifier') or model_context.startswith('facefusion.processors.modules.frame_colorizer')): return [ 'cpu' ] return state_manager.get_item('execution_providers') diff --git a/facefusion/processors/modules/age_modifier.py b/facefusion/processors/modules/age_modifier.py index be52cb50..1792c44a 100755 --- a/facefusion/processors/modules/age_modifier.py +++ b/facefusion/processors/modules/age_modifier.py @@ -53,11 +53,15 @@ MODEL_SET : ModelSet =\ def get_inference_pool() -> InferencePool: model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_sources) + age_modifier_model = state_manager.get_item('age_modifier_model') + model_context = __name__ + '.' + age_modifier_model + return inference_manager.get_inference_pool(model_context, model_sources) def clear_inference_pool() -> None: - inference_manager.clear_inference_pool(__name__) + age_modifier_model = state_manager.get_item('age_modifier_model') + model_context = __name__ + '.' + age_modifier_model + inference_manager.clear_inference_pool(model_context) def get_model_options() -> ModelOptions: diff --git a/facefusion/processors/modules/expression_restorer.py b/facefusion/processors/modules/expression_restorer.py index d2c8ebd4..3fdffe03 100755 --- a/facefusion/processors/modules/expression_restorer.py +++ b/facefusion/processors/modules/expression_restorer.py @@ -73,7 +73,9 @@ MODEL_SET : ModelSet =\ def get_inference_pool() -> InferencePool: model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_sources) + expression_restorer_model = state_manager.get_item('expression_restorer_model') + model_context = __name__ + '.' + expression_restorer_model + return inference_manager.get_inference_pool(model_context, model_sources) def clear_inference_pool() -> None: diff --git a/facefusion/processors/modules/face_editor.py b/facefusion/processors/modules/face_editor.py index 1301e748..b2d3c58a 100755 --- a/facefusion/processors/modules/face_editor.py +++ b/facefusion/processors/modules/face_editor.py @@ -92,11 +92,15 @@ MODEL_SET : ModelSet =\ def get_inference_pool() -> InferencePool: model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_sources) + face_editor_model = state_manager.get_item('face_editor_model') + model_context = __name__ + '.' + face_editor_model + return inference_manager.get_inference_pool(model_context, model_sources) def clear_inference_pool() -> None: - inference_manager.clear_inference_pool(__name__) + face_editor_model = state_manager.get_item('face_editor_model') + model_context = __name__ + '.' + face_editor_model + inference_manager.clear_inference_pool(model_context) def get_model_options() -> ModelOptions: diff --git a/facefusion/processors/modules/face_enhancer.py b/facefusion/processors/modules/face_enhancer.py index 56addc06..a0c2acdc 100755 --- a/facefusion/processors/modules/face_enhancer.py +++ b/facefusion/processors/modules/face_enhancer.py @@ -219,11 +219,15 @@ MODEL_SET : ModelSet =\ def get_inference_pool() -> InferencePool: model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_sources) + face_enhancer_model = state_manager.get_item('face_enhancer_model') + model_context = __name__ + '.' + face_enhancer_model + return inference_manager.get_inference_pool(model_context, model_sources) def clear_inference_pool() -> None: - inference_manager.clear_inference_pool(__name__) + face_enhancer_model = state_manager.get_item('face_enhancer_model') + model_context = __name__ + '.' + face_enhancer_model + inference_manager.clear_inference_pool(model_context) def get_model_options() -> ModelOptions: diff --git a/facefusion/processors/modules/face_swapper.py b/facefusion/processors/modules/face_swapper.py index a0fa0e5d..a7358d2b 100755 --- a/facefusion/processors/modules/face_swapper.py +++ b/facefusion/processors/modules/face_swapper.py @@ -338,11 +338,15 @@ MODEL_SET : ModelSet =\ def get_inference_pool() -> InferencePool: model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_sources) + face_swapper_model = state_manager.get_item('face_swapper_model') + model_context = __name__ + '.' + face_swapper_model + return inference_manager.get_inference_pool(model_context, model_sources) def clear_inference_pool() -> None: - inference_manager.clear_inference_pool(__name__) + face_swapper_model = state_manager.get_item('face_swapper_model') + model_context = __name__ + '.' + face_swapper_model + inference_manager.clear_inference_pool(model_context) def get_model_options() -> ModelOptions: diff --git a/facefusion/processors/modules/frame_colorizer.py b/facefusion/processors/modules/frame_colorizer.py index 3fed36b9..04f65e90 100644 --- a/facefusion/processors/modules/frame_colorizer.py +++ b/facefusion/processors/modules/frame_colorizer.py @@ -125,11 +125,15 @@ MODEL_SET : ModelSet =\ def get_inference_pool() -> InferencePool: model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_sources) + frame_colorizer_model = state_manager.get_item('frame_colorizer_model') + model_context = __name__ + '.' + frame_colorizer_model + return inference_manager.get_inference_pool(model_context, model_sources) def clear_inference_pool() -> None: - inference_manager.clear_inference_pool(__name__) + frame_colorizer_model = state_manager.get_item('frame_colorizer_model') + model_context = __name__ + '.' + frame_colorizer_model + inference_manager.clear_inference_pool(model_context) def get_model_options() -> ModelOptions: diff --git a/facefusion/processors/modules/frame_enhancer.py b/facefusion/processors/modules/frame_enhancer.py index c532c2ca..d97adb14 100644 --- a/facefusion/processors/modules/frame_enhancer.py +++ b/facefusion/processors/modules/frame_enhancer.py @@ -277,11 +277,15 @@ MODEL_SET : ModelSet =\ def get_inference_pool() -> InferencePool: model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_sources) + frame_enhancer_model = state_manager.get_item('frame_enhancer_model') + model_context = __name__ + '.' + frame_enhancer_model + return inference_manager.get_inference_pool(model_context, model_sources) def clear_inference_pool() -> None: - inference_manager.clear_inference_pool(__name__) + frame_enhancer_model = state_manager.get_item('frame_enhancer_model') + model_context = __name__ + '.' + frame_enhancer_model + inference_manager.clear_inference_pool(model_context) def get_model_options() -> ModelOptions: diff --git a/facefusion/processors/modules/lip_syncer.py b/facefusion/processors/modules/lip_syncer.py index cab56bd8..5edf56d0 100755 --- a/facefusion/processors/modules/lip_syncer.py +++ b/facefusion/processors/modules/lip_syncer.py @@ -71,11 +71,15 @@ MODEL_SET : ModelSet =\ def get_inference_pool() -> InferencePool: model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_sources) + lip_syncer_model = state_manager.get_item('lip_syncer_model') + model_context = __name__ + '.' + lip_syncer_model + return inference_manager.get_inference_pool(model_context, model_sources) def clear_inference_pool() -> None: - inference_manager.clear_inference_pool(__name__) + lip_syncer_model = state_manager.get_item('lip_syncer_model') + model_context = __name__ + '.' + lip_syncer_model + inference_manager.clear_inference_pool(model_context) def get_model_options() -> ModelOptions: diff --git a/tests/test_face_analyser.py b/tests/test_face_analyser.py index feecd6b2..7c351861 100644 --- a/tests/test_face_analyser.py +++ b/tests/test_face_analyser.py @@ -22,7 +22,9 @@ def before_all() -> None: state_manager.init_item('execution_device_id', 0) state_manager.init_item('execution_providers', [ 'cpu' ]) state_manager.init_item('face_detector_angles', [ 0 ]) + state_manager.init_item('face_detector_model', 'many') state_manager.init_item('face_detector_score', 0.5) + state_manager.init_item('face_landmarker_model', 'many') state_manager.init_item('face_landmarker_score', 0.5) face_classifier.pre_check() face_landmarker.pre_check()