Fix multi model context in inference pool

This commit is contained in:
henryruhs 2024-09-05 17:16:28 +02:00
parent 14fd6c6a96
commit 1e35330798
12 changed files with 60 additions and 20 deletions

View File

@ -75,11 +75,15 @@ MODEL_SET : ModelSet =\
def get_inference_pool() -> InferencePool: def get_inference_pool() -> InferencePool:
_, model_sources = collect_model_downloads() _, model_sources = collect_model_downloads()
return inference_manager.get_inference_pool(__name__, model_sources) face_detector_model = state_manager.get_item('face_detector_model')
model_context = __name__ + '.' + face_detector_model
return inference_manager.get_inference_pool(model_context, model_sources)
def clear_inference_pool() -> None: def clear_inference_pool() -> None:
inference_manager.clear_inference_pool(__name__) face_detector_model = state_manager.get_item('face_detector_model')
model_context = __name__ + '.' + face_detector_model
inference_manager.clear_inference_pool(model_context)
def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]: def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:

View File

@ -76,11 +76,15 @@ MODEL_SET : ModelSet =\
def get_inference_pool() -> InferencePool: def get_inference_pool() -> InferencePool:
_, model_sources = collect_model_downloads() _, model_sources = collect_model_downloads()
return inference_manager.get_inference_pool(__name__, model_sources) face_landmarker_model = state_manager.get_item('face_landmarker_model')
model_context = __name__ + '.' + face_landmarker_model
return inference_manager.get_inference_pool(model_context, model_sources)
def clear_inference_pool() -> None: def clear_inference_pool() -> None:
inference_manager.clear_inference_pool(__name__) face_landmarker_model = state_manager.get_item('face_landmarker_model')
model_context = __name__ + '.' + face_landmarker_model
inference_manager.clear_inference_pool(model_context)
def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]: def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:

View File

@ -68,7 +68,7 @@ def get_static_model_initializer(model_path : str) -> ModelInitializer:
def resolve_execution_provider_keys(model_context : str) -> List[ExecutionProviderKey]: def resolve_execution_provider_keys(model_context : str) -> List[ExecutionProviderKey]:
if has_execution_provider('coreml') and model_context in [ 'facefusion.processors.modules.age_modifier', 'facefusion.processors.modules.frame_colorizer' ]: if has_execution_provider('coreml') and (model_context.startswith('facefusion.processors.modules.age_modifier') or model_context.startswith('facefusion.processors.modules.frame_colorizer')):
return [ 'cpu' ] return [ 'cpu' ]
return state_manager.get_item('execution_providers') return state_manager.get_item('execution_providers')

View File

@ -53,11 +53,15 @@ MODEL_SET : ModelSet =\
def get_inference_pool() -> InferencePool: def get_inference_pool() -> InferencePool:
model_sources = get_model_options().get('sources') model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_sources) age_modifier_model = state_manager.get_item('age_modifier_model')
model_context = __name__ + '.' + age_modifier_model
return inference_manager.get_inference_pool(model_context, model_sources)
def clear_inference_pool() -> None: def clear_inference_pool() -> None:
inference_manager.clear_inference_pool(__name__) age_modifier_model = state_manager.get_item('age_modifier_model')
model_context = __name__ + '.' + age_modifier_model
inference_manager.clear_inference_pool(model_context)
def get_model_options() -> ModelOptions: def get_model_options() -> ModelOptions:

View File

@ -73,7 +73,9 @@ MODEL_SET : ModelSet =\
def get_inference_pool() -> InferencePool: def get_inference_pool() -> InferencePool:
model_sources = get_model_options().get('sources') model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_sources) expression_restorer_model = state_manager.get_item('expression_restorer_model')
model_context = __name__ + '.' + expression_restorer_model
return inference_manager.get_inference_pool(model_context, model_sources)
def clear_inference_pool() -> None: def clear_inference_pool() -> None:

View File

@ -92,11 +92,15 @@ MODEL_SET : ModelSet =\
def get_inference_pool() -> InferencePool: def get_inference_pool() -> InferencePool:
model_sources = get_model_options().get('sources') model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_sources) face_editor_model = state_manager.get_item('face_editor_model')
model_context = __name__ + '.' + face_editor_model
return inference_manager.get_inference_pool(model_context, model_sources)
def clear_inference_pool() -> None: def clear_inference_pool() -> None:
inference_manager.clear_inference_pool(__name__) face_editor_model = state_manager.get_item('face_editor_model')
model_context = __name__ + '.' + face_editor_model
inference_manager.clear_inference_pool(model_context)
def get_model_options() -> ModelOptions: def get_model_options() -> ModelOptions:

View File

@ -219,11 +219,15 @@ MODEL_SET : ModelSet =\
def get_inference_pool() -> InferencePool: def get_inference_pool() -> InferencePool:
model_sources = get_model_options().get('sources') model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_sources) face_enhancer_model = state_manager.get_item('face_enhancer_model')
model_context = __name__ + '.' + face_enhancer_model
return inference_manager.get_inference_pool(model_context, model_sources)
def clear_inference_pool() -> None: def clear_inference_pool() -> None:
inference_manager.clear_inference_pool(__name__) face_enhancer_model = state_manager.get_item('face_enhancer_model')
model_context = __name__ + '.' + face_enhancer_model
inference_manager.clear_inference_pool(model_context)
def get_model_options() -> ModelOptions: def get_model_options() -> ModelOptions:

View File

@ -338,11 +338,15 @@ MODEL_SET : ModelSet =\
def get_inference_pool() -> InferencePool: def get_inference_pool() -> InferencePool:
model_sources = get_model_options().get('sources') model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_sources) face_swapper_model = state_manager.get_item('face_swapper_model')
model_context = __name__ + '.' + face_swapper_model
return inference_manager.get_inference_pool(model_context, model_sources)
def clear_inference_pool() -> None: def clear_inference_pool() -> None:
inference_manager.clear_inference_pool(__name__) face_swapper_model = state_manager.get_item('face_swapper_model')
model_context = __name__ + '.' + face_swapper_model
inference_manager.clear_inference_pool(model_context)
def get_model_options() -> ModelOptions: def get_model_options() -> ModelOptions:

View File

@ -125,11 +125,15 @@ MODEL_SET : ModelSet =\
def get_inference_pool() -> InferencePool: def get_inference_pool() -> InferencePool:
model_sources = get_model_options().get('sources') model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_sources) frame_colorizer_model = state_manager.get_item('frame_colorizer_model')
model_context = __name__ + '.' + frame_colorizer_model
return inference_manager.get_inference_pool(model_context, model_sources)
def clear_inference_pool() -> None: def clear_inference_pool() -> None:
inference_manager.clear_inference_pool(__name__) frame_colorizer_model = state_manager.get_item('frame_colorizer_model')
model_context = __name__ + '.' + frame_colorizer_model
inference_manager.clear_inference_pool(model_context)
def get_model_options() -> ModelOptions: def get_model_options() -> ModelOptions:

View File

@ -277,11 +277,15 @@ MODEL_SET : ModelSet =\
def get_inference_pool() -> InferencePool: def get_inference_pool() -> InferencePool:
model_sources = get_model_options().get('sources') model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_sources) frame_enhancer_model = state_manager.get_item('frame_enhancer_model')
model_context = __name__ + '.' + frame_enhancer_model
return inference_manager.get_inference_pool(model_context, model_sources)
def clear_inference_pool() -> None: def clear_inference_pool() -> None:
inference_manager.clear_inference_pool(__name__) frame_enhancer_model = state_manager.get_item('frame_enhancer_model')
model_context = __name__ + '.' + frame_enhancer_model
inference_manager.clear_inference_pool(model_context)
def get_model_options() -> ModelOptions: def get_model_options() -> ModelOptions:

View File

@ -71,11 +71,15 @@ MODEL_SET : ModelSet =\
def get_inference_pool() -> InferencePool: def get_inference_pool() -> InferencePool:
model_sources = get_model_options().get('sources') model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_sources) lip_syncer_model = state_manager.get_item('lip_syncer_model')
model_context = __name__ + '.' + lip_syncer_model
return inference_manager.get_inference_pool(model_context, model_sources)
def clear_inference_pool() -> None: def clear_inference_pool() -> None:
inference_manager.clear_inference_pool(__name__) lip_syncer_model = state_manager.get_item('lip_syncer_model')
model_context = __name__ + '.' + lip_syncer_model
inference_manager.clear_inference_pool(model_context)
def get_model_options() -> ModelOptions: def get_model_options() -> ModelOptions:

View File

@ -22,7 +22,9 @@ def before_all() -> None:
state_manager.init_item('execution_device_id', 0) state_manager.init_item('execution_device_id', 0)
state_manager.init_item('execution_providers', [ 'cpu' ]) state_manager.init_item('execution_providers', [ 'cpu' ])
state_manager.init_item('face_detector_angles', [ 0 ]) state_manager.init_item('face_detector_angles', [ 0 ])
state_manager.init_item('face_detector_model', 'many')
state_manager.init_item('face_detector_score', 0.5) state_manager.init_item('face_detector_score', 0.5)
state_manager.init_item('face_landmarker_model', 'many')
state_manager.init_item('face_landmarker_score', 0.5) state_manager.init_item('face_landmarker_score', 0.5)
face_classifier.pre_check() face_classifier.pre_check()
face_landmarker.pre_check() face_landmarker.pre_check()