Fix multi model context in inference pool (#721)

* Fix multi model context in inference pool

* Fix multi model context in inference pool part2
This commit is contained in:
Henry Ruhs 2024-09-05 18:16:01 +02:00 committed by GitHub
parent 14fd6c6a96
commit 5e725a9c7f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 41 additions and 20 deletions

View File

@ -75,11 +75,13 @@ MODEL_SET : ModelSet =\
def get_inference_pool() -> InferencePool:
_, model_sources = collect_model_downloads()
return inference_manager.get_inference_pool(__name__, model_sources)
model_context = __name__ + '.' + state_manager.get_item('face_detector_model')
return inference_manager.get_inference_pool(model_context, model_sources)
def clear_inference_pool() -> None:
inference_manager.clear_inference_pool(__name__)
model_context = __name__ + '.' + state_manager.get_item('face_detector_model')
inference_manager.clear_inference_pool(model_context)
def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:

View File

@ -76,11 +76,13 @@ MODEL_SET : ModelSet =\
def get_inference_pool() -> InferencePool:
_, model_sources = collect_model_downloads()
return inference_manager.get_inference_pool(__name__, model_sources)
model_context = __name__ + '.' + state_manager.get_item('face_landmarker_model')
return inference_manager.get_inference_pool(model_context, model_sources)
def clear_inference_pool() -> None:
inference_manager.clear_inference_pool(__name__)
model_context = __name__ + '.' + state_manager.get_item('face_landmarker_model')
inference_manager.clear_inference_pool(model_context)
def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:

View File

@ -68,7 +68,7 @@ def get_static_model_initializer(model_path : str) -> ModelInitializer:
def resolve_execution_provider_keys(model_context : str) -> List[ExecutionProviderKey]:
if has_execution_provider('coreml') and model_context in [ 'facefusion.processors.modules.age_modifier', 'facefusion.processors.modules.frame_colorizer' ]:
if has_execution_provider('coreml') and (model_context.startswith('facefusion.processors.modules.age_modifier') or model_context.startswith('facefusion.processors.modules.frame_colorizer')):
return [ 'cpu' ]
return state_manager.get_item('execution_providers')

View File

@ -53,11 +53,13 @@ MODEL_SET : ModelSet =\
def get_inference_pool() -> InferencePool:
model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_sources)
model_context = __name__ + '.' + state_manager.get_item('age_modifier_model')
return inference_manager.get_inference_pool(model_context, model_sources)
def clear_inference_pool() -> None:
inference_manager.clear_inference_pool(__name__)
model_context = __name__ + '.' + state_manager.get_item('age_modifier_model')
inference_manager.clear_inference_pool(model_context)
def get_model_options() -> ModelOptions:

View File

@ -73,7 +73,8 @@ MODEL_SET : ModelSet =\
def get_inference_pool() -> InferencePool:
model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_sources)
model_context = __name__ + '.' + state_manager.get_item('expression_restorer_model')
return inference_manager.get_inference_pool(model_context, model_sources)
def clear_inference_pool() -> None:

View File

@ -92,11 +92,13 @@ MODEL_SET : ModelSet =\
def get_inference_pool() -> InferencePool:
model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_sources)
model_context = __name__ + '.' + state_manager.get_item('face_editor_model')
return inference_manager.get_inference_pool(model_context, model_sources)
def clear_inference_pool() -> None:
inference_manager.clear_inference_pool(__name__)
model_context = __name__ + '.' + state_manager.get_item('face_editor_model')
inference_manager.clear_inference_pool(model_context)
def get_model_options() -> ModelOptions:

View File

@ -219,11 +219,13 @@ MODEL_SET : ModelSet =\
def get_inference_pool() -> InferencePool:
model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_sources)
model_context = __name__ + '.' + state_manager.get_item('face_enhancer_model')
return inference_manager.get_inference_pool(model_context, model_sources)
def clear_inference_pool() -> None:
inference_manager.clear_inference_pool(__name__)
model_context = __name__ + '.' + state_manager.get_item('face_enhancer_model')
inference_manager.clear_inference_pool(model_context)
def get_model_options() -> ModelOptions:

View File

@ -338,11 +338,13 @@ MODEL_SET : ModelSet =\
def get_inference_pool() -> InferencePool:
model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_sources)
model_context = __name__ + '.' + state_manager.get_item('face_swapper_model')
return inference_manager.get_inference_pool(model_context, model_sources)
def clear_inference_pool() -> None:
inference_manager.clear_inference_pool(__name__)
model_context = __name__ + '.' + state_manager.get_item('face_swapper_model')
inference_manager.clear_inference_pool(model_context)
def get_model_options() -> ModelOptions:

View File

@ -125,11 +125,13 @@ MODEL_SET : ModelSet =\
def get_inference_pool() -> InferencePool:
model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_sources)
model_context = __name__ + '.' + state_manager.get_item('frame_colorizer_model')
return inference_manager.get_inference_pool(model_context, model_sources)
def clear_inference_pool() -> None:
inference_manager.clear_inference_pool(__name__)
model_context = __name__ + '.' + state_manager.get_item('frame_colorizer_model')
inference_manager.clear_inference_pool(model_context)
def get_model_options() -> ModelOptions:

View File

@ -277,11 +277,13 @@ MODEL_SET : ModelSet =\
def get_inference_pool() -> InferencePool:
model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_sources)
model_context = __name__ + '.' + state_manager.get_item('frame_enhancer_model')
return inference_manager.get_inference_pool(model_context, model_sources)
def clear_inference_pool() -> None:
inference_manager.clear_inference_pool(__name__)
model_context = __name__ + '.' + state_manager.get_item('frame_enhancer_model')
inference_manager.clear_inference_pool(model_context)
def get_model_options() -> ModelOptions:

View File

@ -71,11 +71,13 @@ MODEL_SET : ModelSet =\
def get_inference_pool() -> InferencePool:
model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_sources)
model_context = __name__ + '.' + state_manager.get_item('lip_syncer_model')
return inference_manager.get_inference_pool(model_context, model_sources)
def clear_inference_pool() -> None:
inference_manager.clear_inference_pool(__name__)
model_context = __name__ + '.' + state_manager.get_item('lip_syncer_model')
inference_manager.clear_inference_pool(model_context)
def get_model_options() -> ModelOptions:

View File

@ -22,7 +22,9 @@ def before_all() -> None:
state_manager.init_item('execution_device_id', 0)
state_manager.init_item('execution_providers', [ 'cpu' ])
state_manager.init_item('face_detector_angles', [ 0 ])
state_manager.init_item('face_detector_model', 'many')
state_manager.init_item('face_detector_score', 0.5)
state_manager.init_item('face_landmarker_model', 'many')
state_manager.init_item('face_landmarker_score', 0.5)
face_classifier.pre_check()
face_landmarker.pre_check()