Introduce download providers
This commit is contained in:
parent
191331c712
commit
f9e5e7e2ce
@ -94,10 +94,13 @@ execution_providers =
|
||||
execution_thread_count =
|
||||
execution_queue_count =
|
||||
|
||||
[download]
|
||||
download_providers =
|
||||
skip_download =
|
||||
|
||||
[memory]
|
||||
video_memory_strategy =
|
||||
system_memory_limit =
|
||||
|
||||
[misc]
|
||||
skip_download =
|
||||
log_level =
|
||||
|
@ -106,11 +106,13 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None:
|
||||
apply_state_item('execution_providers', args.get('execution_providers'))
|
||||
apply_state_item('execution_thread_count', args.get('execution_thread_count'))
|
||||
apply_state_item('execution_queue_count', args.get('execution_queue_count'))
|
||||
# download
|
||||
apply_state_item('download_providers', args.get('download_providers'))
|
||||
apply_state_item('skip_download', args.get('skip_download'))
|
||||
# memory
|
||||
apply_state_item('video_memory_strategy', args.get('video_memory_strategy'))
|
||||
apply_state_item('system_memory_limit', args.get('system_memory_limit'))
|
||||
# misc
|
||||
apply_state_item('skip_download', args.get('skip_download'))
|
||||
apply_state_item('log_level', args.get('log_level'))
|
||||
# jobs
|
||||
apply_state_item('job_id', args.get('job_id'))
|
||||
|
@ -2,7 +2,7 @@ import logging
|
||||
from typing import List, Sequence
|
||||
|
||||
from facefusion.common_helper import create_float_range, create_int_range
|
||||
from facefusion.typing import Angle, ExecutionProviderSet, FaceDetectorSet, FaceLandmarkerModel, FaceMaskRegion, FaceMaskType, FaceSelectorMode, FaceSelectorOrder, Gender, JobStatus, LogLevelSet, OutputAudioEncoder, OutputVideoEncoder, OutputVideoPreset, Race, Score, TempFrameFormat, UiWorkflow, VideoMemoryStrategy
|
||||
from facefusion.typing import Angle, DownloadProvider, ExecutionProviderSet, FaceDetectorSet, FaceLandmarkerModel, FaceMaskRegion, FaceMaskType, FaceSelectorMode, FaceSelectorOrder, Gender, JobStatus, LogLevelSet, OutputAudioEncoder, OutputVideoEncoder, OutputVideoPreset, Race, Score, TempFrameFormat, UiWorkflow, VideoMemoryStrategy
|
||||
|
||||
video_memory_strategies : List[VideoMemoryStrategy] = [ 'strict', 'moderate', 'tolerant' ]
|
||||
|
||||
@ -47,6 +47,8 @@ execution_provider_set : ExecutionProviderSet =\
|
||||
'tensorrt': 'TensorrtExecutionProvider'
|
||||
}
|
||||
|
||||
download_providers : List[DownloadProvider] = [ 'github', 'huggingface' ]
|
||||
|
||||
ui_workflows : List[UiWorkflow] = [ 'instant_runner', 'job_runner', 'job_manager' ]
|
||||
job_statuses : List[JobStatus] = [ 'drafted', 'queued', 'completed', 'failed' ]
|
||||
|
||||
|
@ -167,8 +167,8 @@ def force_download() -> ErrorCode:
|
||||
processor_modules = get_processors_modules(available_processors)
|
||||
|
||||
for module in common_modules + processor_modules:
|
||||
if hasattr(module, 'MODEL_SET'):
|
||||
for model in module.MODEL_SET.values():
|
||||
if hasattr(module, 'create_model_set'):
|
||||
for model in module.create_model_set().values():
|
||||
model_hashes = model.get('hashes')
|
||||
model_sources = model.get('sources')
|
||||
|
||||
|
@ -4,7 +4,7 @@ import ssl
|
||||
import subprocess
|
||||
import urllib.request
|
||||
from functools import lru_cache
|
||||
from typing import List, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from tqdm import tqdm
|
||||
@ -30,8 +30,7 @@ def conditional_download(download_directory_path : str, urls : List[str]) -> Non
|
||||
with tqdm(total = download_size, initial = initial_size, desc = wording.get('downloading'), unit = 'B', unit_scale = True, unit_divisor = 1024, ascii = ' =', disable = state_manager.get_item('log_level') in [ 'warn', 'error' ]) as progress:
|
||||
subprocess.Popen([ shutil.which('curl'), '--create-dirs', '--silent', '--insecure', '--location', '--continue-at', '-', '--output', download_file_path, url ])
|
||||
current_size = initial_size
|
||||
|
||||
progress.set_postfix(file = download_file_name)
|
||||
progress.set_postfix(download_providers = state_manager.get_item('download_providers'), file_name = download_file_name)
|
||||
while current_size < download_size:
|
||||
if is_file(download_file_path):
|
||||
current_size = get_file_size(download_file_path)
|
||||
@ -129,3 +128,14 @@ def validate_source_paths(source_paths : List[str]) -> Tuple[List[str], List[str
|
||||
else:
|
||||
invalid_source_paths.append(source_path)
|
||||
return valid_source_paths, invalid_source_paths
|
||||
|
||||
|
||||
def resolve_download_url(base_name : str, file_name : str) -> Optional[str]:
|
||||
download_providers = state_manager.get_item('download_providers')
|
||||
|
||||
for download_provider in download_providers:
|
||||
if download_provider == 'github':
|
||||
return 'https://github.com/facefusion/facefusion-assets/releases/download/' + base_name + '/' + file_name
|
||||
if download_provider == 'huggingface':
|
||||
return 'https://huggingface.co/facefusion/' + base_name + '/resolve/main/' + file_name
|
||||
return None
|
||||
|
@ -62,12 +62,7 @@ def clear_processors_modules(processors : List[str]) -> None:
|
||||
def multi_process_frames(source_paths : List[str], temp_frame_paths : List[str], process_frames : ProcessFrames) -> None:
|
||||
queue_payloads = create_queue_payloads(temp_frame_paths)
|
||||
with tqdm(total = len(queue_payloads), desc = wording.get('processing'), unit = 'frame', ascii = ' =', disable = state_manager.get_item('log_level') in [ 'warn', 'error' ]) as progress:
|
||||
progress.set_postfix(
|
||||
{
|
||||
'execution_providers': state_manager.get_item('execution_providers'),
|
||||
'execution_thread_count': state_manager.get_item('execution_thread_count'),
|
||||
'execution_queue_count': state_manager.get_item('execution_queue_count')
|
||||
})
|
||||
progress.set_postfix(execution_providers = state_manager.get_item('execution_providers'))
|
||||
with ThreadPoolExecutor(max_workers = state_manager.get_item('execution_thread_count')) as executor:
|
||||
futures = []
|
||||
queue : Queue[QueuePayload] = create_queue(queue_payloads)
|
||||
|
@ -10,7 +10,7 @@ import facefusion.processors.core as processors
|
||||
from facefusion import config, content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, inference_manager, logger, process_manager, state_manager, voice_extractor, wording
|
||||
from facefusion.audio import create_empty_audio_frame, get_voice_frame, read_static_voice
|
||||
from facefusion.common_helper import get_first
|
||||
from facefusion.download import conditional_download_hashes, conditional_download_sources
|
||||
from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url
|
||||
from facefusion.face_analyser import get_many_faces, get_one_face
|
||||
from facefusion.face_helper import create_bounding_box, paste_back, warp_face_by_bounding_box, warp_face_by_face_landmark_5
|
||||
from facefusion.face_masker import create_mouth_mask, create_occlusion_mask, create_static_box_mask
|
||||
@ -24,49 +24,51 @@ from facefusion.thread_helper import conditional_thread_semaphore
|
||||
from facefusion.typing import ApplyStateItem, Args, AudioFrame, Face, InferencePool, ModelOptions, ModelSet, ProcessMode, QueuePayload, UpdateProgress, VisionFrame
|
||||
from facefusion.vision import read_image, read_static_image, restrict_video_fps, write_image
|
||||
|
||||
MODEL_SET : ModelSet =\
|
||||
{
|
||||
'wav2lip_96':
|
||||
|
||||
def create_model_set() -> ModelSet:
|
||||
return\
|
||||
{
|
||||
'hashes':
|
||||
'wav2lip_96':
|
||||
{
|
||||
'lip_syncer':
|
||||
'hashes':
|
||||
{
|
||||
'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/wav2lip_96.hash',
|
||||
'path': resolve_relative_path('../.assets/models/wav2lip_96.hash')
|
||||
}
|
||||
'lip_syncer':
|
||||
{
|
||||
'url': resolve_download_url('models-3.0.0', 'wav2lip_96.hash'),
|
||||
'path': resolve_relative_path('../.assets/models/wav2lip_96.hash')
|
||||
}
|
||||
},
|
||||
'sources':
|
||||
{
|
||||
'lip_syncer':
|
||||
{
|
||||
'url': resolve_download_url('models-3.0.0', 'wav2lip_96.onnx'),
|
||||
'path': resolve_relative_path('../.assets/models/wav2lip_96.onnx')
|
||||
}
|
||||
},
|
||||
'size': (96, 96)
|
||||
},
|
||||
'sources':
|
||||
'wav2lip_gan_96':
|
||||
{
|
||||
'lip_syncer':
|
||||
'hashes':
|
||||
{
|
||||
'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/wav2lip_96.onnx',
|
||||
'path': resolve_relative_path('../.assets/models/wav2lip_96.onnx')
|
||||
}
|
||||
},
|
||||
'size': (96, 96)
|
||||
},
|
||||
'wav2lip_gan_96':
|
||||
{
|
||||
'hashes':
|
||||
{
|
||||
'lip_syncer':
|
||||
'lip_syncer':
|
||||
{
|
||||
'url': resolve_download_url('models-3.0.0', 'wav2lip_gan_96.hash'),
|
||||
'path': resolve_relative_path('../.assets/models/wav2lip_gan_96.hash')
|
||||
}
|
||||
},
|
||||
'sources':
|
||||
{
|
||||
'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/wav2lip_gan_96.hash',
|
||||
'path': resolve_relative_path('../.assets/models/wav2lip_gan_96.hash')
|
||||
}
|
||||
},
|
||||
'sources':
|
||||
{
|
||||
'lip_syncer':
|
||||
{
|
||||
'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/wav2lip_gan_96.onnx',
|
||||
'path': resolve_relative_path('../.assets/models/wav2lip_gan_96.onnx')
|
||||
}
|
||||
},
|
||||
'size': (96, 96)
|
||||
'lip_syncer':
|
||||
{
|
||||
'url': resolve_download_url('models-3.0.0', 'wav2lip_gan_96.onnx'),
|
||||
'path': resolve_relative_path('../.assets/models/wav2lip_gan_96.onnx')
|
||||
}
|
||||
},
|
||||
'size': (96, 96)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def get_inference_pool() -> InferencePool:
|
||||
@ -82,7 +84,7 @@ def clear_inference_pool() -> None:
|
||||
|
||||
def get_model_options() -> ModelOptions:
|
||||
lip_syncer_model = state_manager.get_item('lip_syncer_model')
|
||||
return MODEL_SET.get(lip_syncer_model)
|
||||
return create_model_set().get(lip_syncer_model)
|
||||
|
||||
|
||||
def register_args(program : ArgumentParser) -> None:
|
||||
|
@ -185,11 +185,12 @@ def create_memory_program() -> ArgumentParser:
|
||||
return program
|
||||
|
||||
|
||||
def create_skip_download_program() -> ArgumentParser:
|
||||
def create_download_program() -> ArgumentParser:
|
||||
program = ArgumentParser(add_help = False)
|
||||
group_misc = program.add_argument_group('misc')
|
||||
group_misc.add_argument('--skip-download', help = wording.get('help.skip_download'), action = 'store_true', default = config.get_bool_value('misc.skip_download'))
|
||||
job_store.register_job_keys([ 'skip_download' ])
|
||||
group_download = program.add_argument_group('download')
|
||||
group_download.add_argument('--download-providers', help = wording.get('help.download_providers').format(choices = ', '.join(facefusion.choices.download_providers)), default = config.get_str_list('download.download_providers', 'github'), choices = facefusion.choices.download_providers, nargs = '+', metavar = 'DOWNLOAD_PROVIDERS')
|
||||
group_download.add_argument('--skip-download', help = wording.get('help.skip_download'), action = 'store_true', default = config.get_bool_value('misc.skip_download'))
|
||||
job_store.register_job_keys([ 'download_providers', 'skip_download' ])
|
||||
return program
|
||||
|
||||
|
||||
@ -225,7 +226,7 @@ def collect_step_program() -> ArgumentParser:
|
||||
|
||||
|
||||
def collect_job_program() -> ArgumentParser:
|
||||
return ArgumentParser(parents= [ create_execution_program(), create_memory_program(), create_skip_download_program(), create_log_level_program() ], add_help = False)
|
||||
return ArgumentParser(parents= [ create_execution_program(), create_download_program(), create_memory_program(), create_log_level_program() ], add_help = False)
|
||||
|
||||
|
||||
def create_program() -> ArgumentParser:
|
||||
|
@ -109,13 +109,6 @@ OutputAudioEncoder = Literal['aac', 'libmp3lame', 'libopus', 'libvorbis']
|
||||
OutputVideoEncoder = Literal['libx264', 'libx265', 'libvpx-vp9', 'h264_nvenc', 'hevc_nvenc', 'h264_amf', 'hevc_amf','h264_qsv', 'hevc_qsv', 'h264_videotoolbox', 'hevc_videotoolbox']
|
||||
OutputVideoPreset = Literal['ultrafast', 'superfast', 'veryfast', 'faster', 'fast', 'medium', 'slow', 'slower', 'veryslow']
|
||||
|
||||
Download = TypedDict('Download',
|
||||
{
|
||||
'url' : str,
|
||||
'path' : str
|
||||
})
|
||||
DownloadSet = Dict[str, Download]
|
||||
|
||||
ModelOptions = Dict[str, Any]
|
||||
ModelSet = Dict[str, ModelOptions]
|
||||
ModelInitializer = NDArray[Any]
|
||||
@ -124,6 +117,14 @@ ExecutionProviderKey = Literal['cpu', 'coreml', 'cuda', 'directml', 'openvino',
|
||||
ExecutionProviderValue = Literal['CPUExecutionProvider', 'CoreMLExecutionProvider', 'CUDAExecutionProvider', 'DmlExecutionProvider', 'OpenVINOExecutionProvider', 'ROCMExecutionProvider', 'TensorrtExecutionProvider']
|
||||
ExecutionProviderSet = Dict[ExecutionProviderKey, ExecutionProviderValue]
|
||||
|
||||
Download = TypedDict('Download',
|
||||
{
|
||||
'url' : str,
|
||||
'path' : str
|
||||
})
|
||||
DownloadSet = Dict[str, Download]
|
||||
DownloadProvider = Literal['github', 'huggingface']
|
||||
|
||||
ValueAndUnit = TypedDict('ValueAndUnit',
|
||||
{
|
||||
'value' : int,
|
||||
@ -237,9 +238,10 @@ StateKey = Literal\
|
||||
'execution_providers',
|
||||
'execution_thread_count',
|
||||
'execution_queue_count',
|
||||
'download_providers',
|
||||
'skip_download',
|
||||
'video_memory_strategy',
|
||||
'system_memory_limit',
|
||||
'skip_download',
|
||||
'log_level',
|
||||
'job_id',
|
||||
'job_status',
|
||||
@ -294,9 +296,10 @@ State = TypedDict('State',
|
||||
'execution_providers' : List[ExecutionProviderKey],
|
||||
'execution_thread_count' : int,
|
||||
'execution_queue_count' : int,
|
||||
'download_providers' : List[DownloadProvider],
|
||||
'skip_download': bool,
|
||||
'video_memory_strategy' : VideoMemoryStrategy,
|
||||
'system_memory_limit' : int,
|
||||
'skip_download' : bool,
|
||||
'log_level' : LogLevel,
|
||||
'job_id' : str,
|
||||
'job_status' : JobStatus,
|
||||
|
@ -178,14 +178,16 @@ WORDING : Dict[str, Any] =\
|
||||
'ui_workflow': 'choose the ui workflow',
|
||||
# execution
|
||||
'execution_device_id': 'specify the device used for processing',
|
||||
'execution_providers': 'accelerate the model inference using different providers (choices: {choices}, ...)',
|
||||
'execution_providers': 'inference using different providers (choices: {choices}, ...)',
|
||||
'execution_thread_count': 'specify the amount of parallel threads while processing',
|
||||
'execution_queue_count': 'specify the amount of frames each thread is processing',
|
||||
# download
|
||||
'download_providers': 'download using different providers (choices: {choices}, ...)',
|
||||
'skip_download': 'omit downloads and remote lookups',
|
||||
# memory
|
||||
'video_memory_strategy': 'balance fast processing and low VRAM usage',
|
||||
'system_memory_limit': 'limit the available RAM that can be used while processing',
|
||||
# misc
|
||||
'skip_download': 'omit downloads and remote lookups',
|
||||
'log_level': 'adjust the message severity displayed in the terminal',
|
||||
# run
|
||||
'run': 'run the program',
|
||||
|
Loading…
Reference in New Issue
Block a user