diff --git a/facefusion/choices.py b/facefusion/choices.py index 2775155b..6f128651 100755 --- a/facefusion/choices.py +++ b/facefusion/choices.py @@ -55,8 +55,16 @@ execution_provider_set : ExecutionProviderSet =\ execution_providers : List[ExecutionProvider] = list(execution_provider_set.keys()) download_provider_set : DownloadProviderSet =\ { - 'github': 'https://github.com/facefusion/facefusion-assets/releases/download/{base_name}/{file_name}', - 'huggingface': 'https://huggingface.co/facefusion/{base_name}/resolve/main/{file_name}' + 'github': + { + 'url': 'https://github.com', + 'path': '/facefusion/facefusion-assets/releases/download/{base_name}/{file_name}' + }, + 'huggingface': + { + 'url': 'https://huggingface.co', + 'path': '/facefusion/{base_name}/resolve/main/{file_name}' + } } download_providers : List[DownloadProvider] = list(download_provider_set.keys()) download_scopes : List[DownloadScope] = [ 'lite', 'full' ] diff --git a/facefusion/download.py b/facefusion/download.py index 818e3ef3..43452c83 100644 --- a/facefusion/download.py +++ b/facefusion/download.py @@ -1,8 +1,6 @@ import os import shutil -import ssl import subprocess -import urllib.request from functools import lru_cache from typing import List, Optional, Tuple from urllib.parse import urlparse @@ -11,13 +9,15 @@ from tqdm import tqdm import facefusion.choices from facefusion import logger, process_manager, state_manager, wording -from facefusion.common_helper import is_macos from facefusion.filesystem import get_file_size, is_file, remove_file from facefusion.hash_helper import validate_hash from facefusion.typing import DownloadProvider, DownloadSet -if is_macos(): - ssl._create_default_https_context = ssl._create_unverified_context + +def open_curl(args : List[str]) -> subprocess.Popen[bytes]: + commands = [ shutil.which('curl'), '--silent', '--insecure', '--location' ] + commands.extend(args) + return subprocess.Popen(commands, stdin = subprocess.PIPE, stdout = subprocess.PIPE) def conditional_download(download_directory_path : str, urls : List[str]) -> None: @@ -25,13 +25,15 @@ def conditional_download(download_directory_path : str, urls : List[str]) -> Non download_file_name = os.path.basename(urlparse(url).path) download_file_path = os.path.join(download_directory_path, download_file_name) initial_size = get_file_size(download_file_path) - download_size = get_download_size(url) + download_size = get_static_download_size(url) if initial_size < download_size: with tqdm(total = download_size, initial = initial_size, desc = wording.get('downloading'), unit = 'B', unit_scale = True, unit_divisor = 1024, ascii = ' =', disable = state_manager.get_item('log_level') in [ 'warn', 'error' ]) as progress: - subprocess.Popen([ shutil.which('curl'), '--create-dirs', '--silent', '--insecure', '--location', '--continue-at', '-', '--output', download_file_path, url ]) + commands = [ '--create-dirs', '--continue-at', '-', '--output', download_file_path, url ] + open_curl(commands) current_size = initial_size progress.set_postfix(download_providers = state_manager.get_item('download_providers'), file_name = download_file_name) + while current_size < download_size: if is_file(download_file_path): current_size = get_file_size(download_file_path) @@ -39,13 +41,27 @@ def conditional_download(download_directory_path : str, urls : List[str]) -> Non @lru_cache(maxsize = None) -def get_download_size(url : str) -> int: - try: - response = urllib.request.urlopen(url, timeout = 10) - content_length = response.headers.get('Content-Length') - return int(content_length) - except (OSError, TypeError, ValueError): - return 0 +def get_static_download_size(url : str) -> int: + commands = [ '-I', url ] + process = open_curl(commands) + lines = reversed(process.stdout.readlines()) + + for line in lines: + __line__ = line.decode().lower() + + if 'content-length:' in __line__: + _, content_length = __line__.split('content-length:') + return int(content_length) + + return 0 + + +@lru_cache(maxsize = None) +def ping_static_url(url : str) -> bool: + commands = [ '-I', url ] + process = open_curl(commands) + process.wait() + return process.returncode == 0 def conditional_download_hashes(hashes : DownloadSet) -> bool: @@ -61,6 +77,7 @@ def conditional_download_hashes(hashes : DownloadSet) -> bool: conditional_download(download_directory_path, [ invalid_hash_url ]) valid_hash_paths, invalid_hash_paths = validate_hash_paths(hash_paths) + for valid_hash_path in valid_hash_paths: valid_hash_file_name, _ = os.path.splitext(os.path.basename(valid_hash_path)) logger.debug(wording.get('validating_hash_succeed').format(hash_file_name = valid_hash_file_name), __name__) @@ -86,6 +103,7 @@ def conditional_download_sources(sources : DownloadSet) -> bool: conditional_download(download_directory_path, [ invalid_source_url ]) valid_source_paths, invalid_source_paths = validate_source_paths(source_paths) + for valid_source_path in valid_source_paths: valid_source_file_name, _ = os.path.splitext(os.path.basename(valid_source_path)) logger.debug(wording.get('validating_source_succeed').format(source_file_name = valid_source_file_name), __name__) @@ -128,11 +146,17 @@ def validate_source_paths(source_paths : List[str]) -> Tuple[List[str], List[str def resolve_download_url(base_name : str, file_name : str) -> Optional[str]: download_providers = state_manager.get_item('download_providers') - for download_provider in facefusion.choices.download_provider_set: - if download_provider in download_providers: + for download_provider in download_providers: + if ping_download_provider(download_provider): return resolve_download_url_by_provider(download_provider, base_name, file_name) return None +def ping_download_provider(download_provider : DownloadProvider) -> bool: + download_provider_value = facefusion.choices.download_provider_set.get(download_provider) + return ping_static_url(download_provider_value.get('url')) + + def resolve_download_url_by_provider(download_provider : DownloadProvider, base_name : str, file_name : str) -> Optional[str]: - return facefusion.choices.download_provider_set.get(download_provider).format(base_name = base_name, file_name = file_name) + download_provider_value = facefusion.choices.download_provider_set.get(download_provider) + return download_provider_value.get('url') + download_provider_value.get('path').format(base_name = base_name, file_name = file_name) diff --git a/facefusion/ffmpeg.py b/facefusion/ffmpeg.py index 48530282..05fe0fa7 100644 --- a/facefusion/ffmpeg.py +++ b/facefusion/ffmpeg.py @@ -22,10 +22,15 @@ def run_ffmpeg_with_progress(args: List[str], update_progress : UpdateProgress) while process_manager.is_processing(): try: - while line := process.stdout.readline().decode(): - if 'frame=' in line: - _, frame_number = line.split('frame=') + lines = process.stdout.readlines() + + for line in lines: + __line__ = line.decode().lower() + + if 'frame=' in __line__: + _, frame_number = __line__.split('frame=') update_progress(int(frame_number)) + if log_level == 'debug': log_debug(process) process.wait(timeout = 0.5) diff --git a/facefusion/program.py b/facefusion/program.py index 2f45d8e3..1a25aab2 100755 --- a/facefusion/program.py +++ b/facefusion/program.py @@ -206,7 +206,7 @@ def create_download_providers_program() -> ArgumentParser: program = ArgumentParser(add_help = False) download_providers = list(facefusion.choices.download_provider_set.keys()) group_download = program.add_argument_group('download') - group_download.add_argument('--download-providers', help = wording.get('help.download_providers').format(choices = ', '.join(download_providers)), default = config.get_str_list('download.download_providers', 'github'), choices = download_providers, nargs = '+', metavar = 'DOWNLOAD_PROVIDERS') + group_download.add_argument('--download-providers', help = wording.get('help.download_providers').format(choices = ', '.join(download_providers)), default = config.get_str_list('download.download_providers', ' '.join(facefusion.choices.download_providers)), choices = download_providers, nargs = '+', metavar = 'DOWNLOAD_PROVIDERS') job_store.register_job_keys([ 'download_providers' ]) return program diff --git a/facefusion/typing.py b/facefusion/typing.py index 0e006db4..4b2f441a 100755 --- a/facefusion/typing.py +++ b/facefusion/typing.py @@ -159,7 +159,12 @@ ExecutionDevice = TypedDict('ExecutionDevice', }) DownloadProvider = Literal['github', 'huggingface'] -DownloadProviderSet = Dict[DownloadProvider, str] +DownloadProviderValue = TypedDict('DownloadProviderValue', +{ + 'url' : str, + 'path' : str +}) +DownloadProviderSet = Dict[DownloadProvider, DownloadProviderValue] DownloadScope = Literal['lite', 'full'] Download = TypedDict('Download', { diff --git a/tests/test_download.py b/tests/test_download.py index 3129b345..49865a6d 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -1,18 +1,13 @@ -import pytest - -from facefusion.download import conditional_download, get_download_size -from .helper import get_test_examples_directory +from facefusion.download import get_static_download_size, ping_static_url -@pytest.fixture(scope = 'module', autouse = True) -def before_all() -> None: - conditional_download(get_test_examples_directory(), - [ - 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4' - ]) +def test_get_static_download_size() -> None: + assert get_static_download_size('https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/fairface.onnx') == 85170772 + assert get_static_download_size('https://huggingface.co/facefusion/models-3.0.0/resolve/main/fairface.onnx') == 85170772 + assert get_static_download_size('invalid') == 0 -def test_get_download_size() -> None: - assert get_download_size('https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4') == 191675 - assert get_download_size('https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-360p.mp4') == 370732 - assert get_download_size('invalid') == 0 +def test_static_ping_url() -> None: + assert ping_static_url('https://github.com') is True + assert ping_static_url('https://huggingface.co') is True + assert ping_static_url('invalid') is False