Introduce download providers

This commit is contained in:
henryruhs 2024-11-07 23:12:53 +01:00
parent 191331c712
commit f9e5e7e2ce
10 changed files with 87 additions and 67 deletions

View File

@ -94,10 +94,13 @@ execution_providers =
execution_thread_count =
execution_queue_count =
[download]
download_providers =
skip_download =
[memory]
video_memory_strategy =
system_memory_limit =
[misc]
skip_download =
log_level =

View File

@ -106,11 +106,13 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None:
apply_state_item('execution_providers', args.get('execution_providers'))
apply_state_item('execution_thread_count', args.get('execution_thread_count'))
apply_state_item('execution_queue_count', args.get('execution_queue_count'))
# download
apply_state_item('download_providers', args.get('download_providers'))
apply_state_item('skip_download', args.get('skip_download'))
# memory
apply_state_item('video_memory_strategy', args.get('video_memory_strategy'))
apply_state_item('system_memory_limit', args.get('system_memory_limit'))
# misc
apply_state_item('skip_download', args.get('skip_download'))
apply_state_item('log_level', args.get('log_level'))
# jobs
apply_state_item('job_id', args.get('job_id'))

View File

@ -2,7 +2,7 @@ import logging
from typing import List, Sequence
from facefusion.common_helper import create_float_range, create_int_range
from facefusion.typing import Angle, ExecutionProviderSet, FaceDetectorSet, FaceLandmarkerModel, FaceMaskRegion, FaceMaskType, FaceSelectorMode, FaceSelectorOrder, Gender, JobStatus, LogLevelSet, OutputAudioEncoder, OutputVideoEncoder, OutputVideoPreset, Race, Score, TempFrameFormat, UiWorkflow, VideoMemoryStrategy
from facefusion.typing import Angle, DownloadProvider, ExecutionProviderSet, FaceDetectorSet, FaceLandmarkerModel, FaceMaskRegion, FaceMaskType, FaceSelectorMode, FaceSelectorOrder, Gender, JobStatus, LogLevelSet, OutputAudioEncoder, OutputVideoEncoder, OutputVideoPreset, Race, Score, TempFrameFormat, UiWorkflow, VideoMemoryStrategy
video_memory_strategies : List[VideoMemoryStrategy] = [ 'strict', 'moderate', 'tolerant' ]
@ -47,6 +47,8 @@ execution_provider_set : ExecutionProviderSet =\
'tensorrt': 'TensorrtExecutionProvider'
}
download_providers : List[DownloadProvider] = [ 'github', 'huggingface' ]
ui_workflows : List[UiWorkflow] = [ 'instant_runner', 'job_runner', 'job_manager' ]
job_statuses : List[JobStatus] = [ 'drafted', 'queued', 'completed', 'failed' ]

View File

@ -167,8 +167,8 @@ def force_download() -> ErrorCode:
processor_modules = get_processors_modules(available_processors)
for module in common_modules + processor_modules:
if hasattr(module, 'MODEL_SET'):
for model in module.MODEL_SET.values():
if hasattr(module, 'create_model_set'):
for model in module.create_model_set().values():
model_hashes = model.get('hashes')
model_sources = model.get('sources')

View File

@ -4,7 +4,7 @@ import ssl
import subprocess
import urllib.request
from functools import lru_cache
from typing import List, Tuple
from typing import List, Optional, Tuple
from urllib.parse import urlparse
from tqdm import tqdm
@ -30,8 +30,7 @@ def conditional_download(download_directory_path : str, urls : List[str]) -> Non
with tqdm(total = download_size, initial = initial_size, desc = wording.get('downloading'), unit = 'B', unit_scale = True, unit_divisor = 1024, ascii = ' =', disable = state_manager.get_item('log_level') in [ 'warn', 'error' ]) as progress:
subprocess.Popen([ shutil.which('curl'), '--create-dirs', '--silent', '--insecure', '--location', '--continue-at', '-', '--output', download_file_path, url ])
current_size = initial_size
progress.set_postfix(file = download_file_name)
progress.set_postfix(download_providers = state_manager.get_item('download_providers'), file_name = download_file_name)
while current_size < download_size:
if is_file(download_file_path):
current_size = get_file_size(download_file_path)
@ -129,3 +128,14 @@ def validate_source_paths(source_paths : List[str]) -> Tuple[List[str], List[str
else:
invalid_source_paths.append(source_path)
return valid_source_paths, invalid_source_paths
def resolve_download_url(base_name : str, file_name : str) -> Optional[str]:
download_providers = state_manager.get_item('download_providers')
for download_provider in download_providers:
if download_provider == 'github':
return 'https://github.com/facefusion/facefusion-assets/releases/download/' + base_name + '/' + file_name
if download_provider == 'huggingface':
return 'https://huggingface.co/facefusion/' + base_name + '/resolve/main/' + file_name
return None

View File

@ -62,12 +62,7 @@ def clear_processors_modules(processors : List[str]) -> None:
def multi_process_frames(source_paths : List[str], temp_frame_paths : List[str], process_frames : ProcessFrames) -> None:
queue_payloads = create_queue_payloads(temp_frame_paths)
with tqdm(total = len(queue_payloads), desc = wording.get('processing'), unit = 'frame', ascii = ' =', disable = state_manager.get_item('log_level') in [ 'warn', 'error' ]) as progress:
progress.set_postfix(
{
'execution_providers': state_manager.get_item('execution_providers'),
'execution_thread_count': state_manager.get_item('execution_thread_count'),
'execution_queue_count': state_manager.get_item('execution_queue_count')
})
progress.set_postfix(execution_providers = state_manager.get_item('execution_providers'))
with ThreadPoolExecutor(max_workers = state_manager.get_item('execution_thread_count')) as executor:
futures = []
queue : Queue[QueuePayload] = create_queue(queue_payloads)

View File

@ -10,7 +10,7 @@ import facefusion.processors.core as processors
from facefusion import config, content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, inference_manager, logger, process_manager, state_manager, voice_extractor, wording
from facefusion.audio import create_empty_audio_frame, get_voice_frame, read_static_voice
from facefusion.common_helper import get_first
from facefusion.download import conditional_download_hashes, conditional_download_sources
from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url
from facefusion.face_analyser import get_many_faces, get_one_face
from facefusion.face_helper import create_bounding_box, paste_back, warp_face_by_bounding_box, warp_face_by_face_landmark_5
from facefusion.face_masker import create_mouth_mask, create_occlusion_mask, create_static_box_mask
@ -24,49 +24,51 @@ from facefusion.thread_helper import conditional_thread_semaphore
from facefusion.typing import ApplyStateItem, Args, AudioFrame, Face, InferencePool, ModelOptions, ModelSet, ProcessMode, QueuePayload, UpdateProgress, VisionFrame
from facefusion.vision import read_image, read_static_image, restrict_video_fps, write_image
MODEL_SET : ModelSet =\
{
'wav2lip_96':
def create_model_set() -> ModelSet:
return\
{
'hashes':
'wav2lip_96':
{
'lip_syncer':
'hashes':
{
'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/wav2lip_96.hash',
'path': resolve_relative_path('../.assets/models/wav2lip_96.hash')
}
'lip_syncer':
{
'url': resolve_download_url('models-3.0.0', 'wav2lip_96.hash'),
'path': resolve_relative_path('../.assets/models/wav2lip_96.hash')
}
},
'sources':
{
'lip_syncer':
{
'url': resolve_download_url('models-3.0.0', 'wav2lip_96.onnx'),
'path': resolve_relative_path('../.assets/models/wav2lip_96.onnx')
}
},
'size': (96, 96)
},
'sources':
'wav2lip_gan_96':
{
'lip_syncer':
'hashes':
{
'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/wav2lip_96.onnx',
'path': resolve_relative_path('../.assets/models/wav2lip_96.onnx')
}
},
'size': (96, 96)
},
'wav2lip_gan_96':
{
'hashes':
{
'lip_syncer':
'lip_syncer':
{
'url': resolve_download_url('models-3.0.0', 'wav2lip_gan_96.hash'),
'path': resolve_relative_path('../.assets/models/wav2lip_gan_96.hash')
}
},
'sources':
{
'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/wav2lip_gan_96.hash',
'path': resolve_relative_path('../.assets/models/wav2lip_gan_96.hash')
}
},
'sources':
{
'lip_syncer':
{
'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/wav2lip_gan_96.onnx',
'path': resolve_relative_path('../.assets/models/wav2lip_gan_96.onnx')
}
},
'size': (96, 96)
'lip_syncer':
{
'url': resolve_download_url('models-3.0.0', 'wav2lip_gan_96.onnx'),
'path': resolve_relative_path('../.assets/models/wav2lip_gan_96.onnx')
}
},
'size': (96, 96)
}
}
}
def get_inference_pool() -> InferencePool:
@ -82,7 +84,7 @@ def clear_inference_pool() -> None:
def get_model_options() -> ModelOptions:
lip_syncer_model = state_manager.get_item('lip_syncer_model')
return MODEL_SET.get(lip_syncer_model)
return create_model_set().get(lip_syncer_model)
def register_args(program : ArgumentParser) -> None:

View File

@ -185,11 +185,12 @@ def create_memory_program() -> ArgumentParser:
return program
def create_skip_download_program() -> ArgumentParser:
def create_download_program() -> ArgumentParser:
program = ArgumentParser(add_help = False)
group_misc = program.add_argument_group('misc')
group_misc.add_argument('--skip-download', help = wording.get('help.skip_download'), action = 'store_true', default = config.get_bool_value('misc.skip_download'))
job_store.register_job_keys([ 'skip_download' ])
group_download = program.add_argument_group('download')
group_download.add_argument('--download-providers', help = wording.get('help.download_providers').format(choices = ', '.join(facefusion.choices.download_providers)), default = config.get_str_list('download.download_providers', 'github'), choices = facefusion.choices.download_providers, nargs = '+', metavar = 'DOWNLOAD_PROVIDERS')
group_download.add_argument('--skip-download', help = wording.get('help.skip_download'), action = 'store_true', default = config.get_bool_value('misc.skip_download'))
job_store.register_job_keys([ 'download_providers', 'skip_download' ])
return program
@ -225,7 +226,7 @@ def collect_step_program() -> ArgumentParser:
def collect_job_program() -> ArgumentParser:
return ArgumentParser(parents= [ create_execution_program(), create_memory_program(), create_skip_download_program(), create_log_level_program() ], add_help = False)
return ArgumentParser(parents= [ create_execution_program(), create_download_program(), create_memory_program(), create_log_level_program() ], add_help = False)
def create_program() -> ArgumentParser:

View File

@ -109,13 +109,6 @@ OutputAudioEncoder = Literal['aac', 'libmp3lame', 'libopus', 'libvorbis']
OutputVideoEncoder = Literal['libx264', 'libx265', 'libvpx-vp9', 'h264_nvenc', 'hevc_nvenc', 'h264_amf', 'hevc_amf','h264_qsv', 'hevc_qsv', 'h264_videotoolbox', 'hevc_videotoolbox']
OutputVideoPreset = Literal['ultrafast', 'superfast', 'veryfast', 'faster', 'fast', 'medium', 'slow', 'slower', 'veryslow']
Download = TypedDict('Download',
{
'url' : str,
'path' : str
})
DownloadSet = Dict[str, Download]
ModelOptions = Dict[str, Any]
ModelSet = Dict[str, ModelOptions]
ModelInitializer = NDArray[Any]
@ -124,6 +117,14 @@ ExecutionProviderKey = Literal['cpu', 'coreml', 'cuda', 'directml', 'openvino',
ExecutionProviderValue = Literal['CPUExecutionProvider', 'CoreMLExecutionProvider', 'CUDAExecutionProvider', 'DmlExecutionProvider', 'OpenVINOExecutionProvider', 'ROCMExecutionProvider', 'TensorrtExecutionProvider']
ExecutionProviderSet = Dict[ExecutionProviderKey, ExecutionProviderValue]
Download = TypedDict('Download',
{
'url' : str,
'path' : str
})
DownloadSet = Dict[str, Download]
DownloadProvider = Literal['github', 'huggingface']
ValueAndUnit = TypedDict('ValueAndUnit',
{
'value' : int,
@ -237,9 +238,10 @@ StateKey = Literal\
'execution_providers',
'execution_thread_count',
'execution_queue_count',
'download_providers',
'skip_download',
'video_memory_strategy',
'system_memory_limit',
'skip_download',
'log_level',
'job_id',
'job_status',
@ -294,9 +296,10 @@ State = TypedDict('State',
'execution_providers' : List[ExecutionProviderKey],
'execution_thread_count' : int,
'execution_queue_count' : int,
'download_providers' : List[DownloadProvider],
'skip_download': bool,
'video_memory_strategy' : VideoMemoryStrategy,
'system_memory_limit' : int,
'skip_download' : bool,
'log_level' : LogLevel,
'job_id' : str,
'job_status' : JobStatus,

View File

@ -178,14 +178,16 @@ WORDING : Dict[str, Any] =\
'ui_workflow': 'choose the ui workflow',
# execution
'execution_device_id': 'specify the device used for processing',
'execution_providers': 'accelerate the model inference using different providers (choices: {choices}, ...)',
'execution_providers': 'inference using different providers (choices: {choices}, ...)',
'execution_thread_count': 'specify the amount of parallel threads while processing',
'execution_queue_count': 'specify the amount of frames each thread is processing',
# download
'download_providers': 'download using different providers (choices: {choices}, ...)',
'skip_download': 'omit downloads and remote lookups',
# memory
'video_memory_strategy': 'balance fast processing and low VRAM usage',
'system_memory_limit': 'limit the available RAM that can be used while processing',
# misc
'skip_download': 'omit downloads and remote lookups',
'log_level': 'adjust the message severity displayed in the terminal',
# run
'run': 'run the program',