Feat/download provider fallback (#837)
* Introduce download providers fallback, Use CURL everywhre * Fix CI * Use readlines() over readline() to avoid while * Use readlines() over readline() to avoid while * Use readlines() over readline() to avoid while
This commit is contained in:
parent
e26381753c
commit
034d029a41
@ -55,8 +55,16 @@ execution_provider_set : ExecutionProviderSet =\
|
|||||||
execution_providers : List[ExecutionProvider] = list(execution_provider_set.keys())
|
execution_providers : List[ExecutionProvider] = list(execution_provider_set.keys())
|
||||||
download_provider_set : DownloadProviderSet =\
|
download_provider_set : DownloadProviderSet =\
|
||||||
{
|
{
|
||||||
'github': 'https://github.com/facefusion/facefusion-assets/releases/download/{base_name}/{file_name}',
|
'github':
|
||||||
'huggingface': 'https://huggingface.co/facefusion/{base_name}/resolve/main/{file_name}'
|
{
|
||||||
|
'url': 'https://github.com',
|
||||||
|
'path': '/facefusion/facefusion-assets/releases/download/{base_name}/{file_name}'
|
||||||
|
},
|
||||||
|
'huggingface':
|
||||||
|
{
|
||||||
|
'url': 'https://huggingface.co',
|
||||||
|
'path': '/facefusion/{base_name}/resolve/main/{file_name}'
|
||||||
|
}
|
||||||
}
|
}
|
||||||
download_providers : List[DownloadProvider] = list(download_provider_set.keys())
|
download_providers : List[DownloadProvider] = list(download_provider_set.keys())
|
||||||
download_scopes : List[DownloadScope] = [ 'lite', 'full' ]
|
download_scopes : List[DownloadScope] = [ 'lite', 'full' ]
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import ssl
|
|
||||||
import subprocess
|
import subprocess
|
||||||
import urllib.request
|
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
@ -11,13 +9,15 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
import facefusion.choices
|
import facefusion.choices
|
||||||
from facefusion import logger, process_manager, state_manager, wording
|
from facefusion import logger, process_manager, state_manager, wording
|
||||||
from facefusion.common_helper import is_macos
|
|
||||||
from facefusion.filesystem import get_file_size, is_file, remove_file
|
from facefusion.filesystem import get_file_size, is_file, remove_file
|
||||||
from facefusion.hash_helper import validate_hash
|
from facefusion.hash_helper import validate_hash
|
||||||
from facefusion.typing import DownloadProvider, DownloadSet
|
from facefusion.typing import DownloadProvider, DownloadSet
|
||||||
|
|
||||||
if is_macos():
|
|
||||||
ssl._create_default_https_context = ssl._create_unverified_context
|
def open_curl(args : List[str]) -> subprocess.Popen[bytes]:
|
||||||
|
commands = [ shutil.which('curl'), '--silent', '--insecure', '--location' ]
|
||||||
|
commands.extend(args)
|
||||||
|
return subprocess.Popen(commands, stdin = subprocess.PIPE, stdout = subprocess.PIPE)
|
||||||
|
|
||||||
|
|
||||||
def conditional_download(download_directory_path : str, urls : List[str]) -> None:
|
def conditional_download(download_directory_path : str, urls : List[str]) -> None:
|
||||||
@ -25,13 +25,15 @@ def conditional_download(download_directory_path : str, urls : List[str]) -> Non
|
|||||||
download_file_name = os.path.basename(urlparse(url).path)
|
download_file_name = os.path.basename(urlparse(url).path)
|
||||||
download_file_path = os.path.join(download_directory_path, download_file_name)
|
download_file_path = os.path.join(download_directory_path, download_file_name)
|
||||||
initial_size = get_file_size(download_file_path)
|
initial_size = get_file_size(download_file_path)
|
||||||
download_size = get_download_size(url)
|
download_size = get_static_download_size(url)
|
||||||
|
|
||||||
if initial_size < download_size:
|
if initial_size < download_size:
|
||||||
with tqdm(total = download_size, initial = initial_size, desc = wording.get('downloading'), unit = 'B', unit_scale = True, unit_divisor = 1024, ascii = ' =', disable = state_manager.get_item('log_level') in [ 'warn', 'error' ]) as progress:
|
with tqdm(total = download_size, initial = initial_size, desc = wording.get('downloading'), unit = 'B', unit_scale = True, unit_divisor = 1024, ascii = ' =', disable = state_manager.get_item('log_level') in [ 'warn', 'error' ]) as progress:
|
||||||
subprocess.Popen([ shutil.which('curl'), '--create-dirs', '--silent', '--insecure', '--location', '--continue-at', '-', '--output', download_file_path, url ])
|
commands = [ '--create-dirs', '--continue-at', '-', '--output', download_file_path, url ]
|
||||||
|
open_curl(commands)
|
||||||
current_size = initial_size
|
current_size = initial_size
|
||||||
progress.set_postfix(download_providers = state_manager.get_item('download_providers'), file_name = download_file_name)
|
progress.set_postfix(download_providers = state_manager.get_item('download_providers'), file_name = download_file_name)
|
||||||
|
|
||||||
while current_size < download_size:
|
while current_size < download_size:
|
||||||
if is_file(download_file_path):
|
if is_file(download_file_path):
|
||||||
current_size = get_file_size(download_file_path)
|
current_size = get_file_size(download_file_path)
|
||||||
@ -39,13 +41,27 @@ def conditional_download(download_directory_path : str, urls : List[str]) -> Non
|
|||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize = None)
|
@lru_cache(maxsize = None)
|
||||||
def get_download_size(url : str) -> int:
|
def get_static_download_size(url : str) -> int:
|
||||||
try:
|
commands = [ '-I', url ]
|
||||||
response = urllib.request.urlopen(url, timeout = 10)
|
process = open_curl(commands)
|
||||||
content_length = response.headers.get('Content-Length')
|
lines = reversed(process.stdout.readlines())
|
||||||
return int(content_length)
|
|
||||||
except (OSError, TypeError, ValueError):
|
for line in lines:
|
||||||
return 0
|
__line__ = line.decode().lower()
|
||||||
|
|
||||||
|
if 'content-length:' in __line__:
|
||||||
|
_, content_length = __line__.split('content-length:')
|
||||||
|
return int(content_length)
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize = None)
|
||||||
|
def ping_static_url(url : str) -> bool:
|
||||||
|
commands = [ '-I', url ]
|
||||||
|
process = open_curl(commands)
|
||||||
|
process.wait()
|
||||||
|
return process.returncode == 0
|
||||||
|
|
||||||
|
|
||||||
def conditional_download_hashes(hashes : DownloadSet) -> bool:
|
def conditional_download_hashes(hashes : DownloadSet) -> bool:
|
||||||
@ -61,6 +77,7 @@ def conditional_download_hashes(hashes : DownloadSet) -> bool:
|
|||||||
conditional_download(download_directory_path, [ invalid_hash_url ])
|
conditional_download(download_directory_path, [ invalid_hash_url ])
|
||||||
|
|
||||||
valid_hash_paths, invalid_hash_paths = validate_hash_paths(hash_paths)
|
valid_hash_paths, invalid_hash_paths = validate_hash_paths(hash_paths)
|
||||||
|
|
||||||
for valid_hash_path in valid_hash_paths:
|
for valid_hash_path in valid_hash_paths:
|
||||||
valid_hash_file_name, _ = os.path.splitext(os.path.basename(valid_hash_path))
|
valid_hash_file_name, _ = os.path.splitext(os.path.basename(valid_hash_path))
|
||||||
logger.debug(wording.get('validating_hash_succeed').format(hash_file_name = valid_hash_file_name), __name__)
|
logger.debug(wording.get('validating_hash_succeed').format(hash_file_name = valid_hash_file_name), __name__)
|
||||||
@ -86,6 +103,7 @@ def conditional_download_sources(sources : DownloadSet) -> bool:
|
|||||||
conditional_download(download_directory_path, [ invalid_source_url ])
|
conditional_download(download_directory_path, [ invalid_source_url ])
|
||||||
|
|
||||||
valid_source_paths, invalid_source_paths = validate_source_paths(source_paths)
|
valid_source_paths, invalid_source_paths = validate_source_paths(source_paths)
|
||||||
|
|
||||||
for valid_source_path in valid_source_paths:
|
for valid_source_path in valid_source_paths:
|
||||||
valid_source_file_name, _ = os.path.splitext(os.path.basename(valid_source_path))
|
valid_source_file_name, _ = os.path.splitext(os.path.basename(valid_source_path))
|
||||||
logger.debug(wording.get('validating_source_succeed').format(source_file_name = valid_source_file_name), __name__)
|
logger.debug(wording.get('validating_source_succeed').format(source_file_name = valid_source_file_name), __name__)
|
||||||
@ -128,11 +146,17 @@ def validate_source_paths(source_paths : List[str]) -> Tuple[List[str], List[str
|
|||||||
def resolve_download_url(base_name : str, file_name : str) -> Optional[str]:
|
def resolve_download_url(base_name : str, file_name : str) -> Optional[str]:
|
||||||
download_providers = state_manager.get_item('download_providers')
|
download_providers = state_manager.get_item('download_providers')
|
||||||
|
|
||||||
for download_provider in facefusion.choices.download_provider_set:
|
for download_provider in download_providers:
|
||||||
if download_provider in download_providers:
|
if ping_download_provider(download_provider):
|
||||||
return resolve_download_url_by_provider(download_provider, base_name, file_name)
|
return resolve_download_url_by_provider(download_provider, base_name, file_name)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def ping_download_provider(download_provider : DownloadProvider) -> bool:
|
||||||
|
download_provider_value = facefusion.choices.download_provider_set.get(download_provider)
|
||||||
|
return ping_static_url(download_provider_value.get('url'))
|
||||||
|
|
||||||
|
|
||||||
def resolve_download_url_by_provider(download_provider : DownloadProvider, base_name : str, file_name : str) -> Optional[str]:
|
def resolve_download_url_by_provider(download_provider : DownloadProvider, base_name : str, file_name : str) -> Optional[str]:
|
||||||
return facefusion.choices.download_provider_set.get(download_provider).format(base_name = base_name, file_name = file_name)
|
download_provider_value = facefusion.choices.download_provider_set.get(download_provider)
|
||||||
|
return download_provider_value.get('url') + download_provider_value.get('path').format(base_name = base_name, file_name = file_name)
|
||||||
|
@ -22,10 +22,15 @@ def run_ffmpeg_with_progress(args: List[str], update_progress : UpdateProgress)
|
|||||||
|
|
||||||
while process_manager.is_processing():
|
while process_manager.is_processing():
|
||||||
try:
|
try:
|
||||||
while line := process.stdout.readline().decode():
|
lines = process.stdout.readlines()
|
||||||
if 'frame=' in line:
|
|
||||||
_, frame_number = line.split('frame=')
|
for line in lines:
|
||||||
|
__line__ = line.decode().lower()
|
||||||
|
|
||||||
|
if 'frame=' in __line__:
|
||||||
|
_, frame_number = __line__.split('frame=')
|
||||||
update_progress(int(frame_number))
|
update_progress(int(frame_number))
|
||||||
|
|
||||||
if log_level == 'debug':
|
if log_level == 'debug':
|
||||||
log_debug(process)
|
log_debug(process)
|
||||||
process.wait(timeout = 0.5)
|
process.wait(timeout = 0.5)
|
||||||
|
@ -206,7 +206,7 @@ def create_download_providers_program() -> ArgumentParser:
|
|||||||
program = ArgumentParser(add_help = False)
|
program = ArgumentParser(add_help = False)
|
||||||
download_providers = list(facefusion.choices.download_provider_set.keys())
|
download_providers = list(facefusion.choices.download_provider_set.keys())
|
||||||
group_download = program.add_argument_group('download')
|
group_download = program.add_argument_group('download')
|
||||||
group_download.add_argument('--download-providers', help = wording.get('help.download_providers').format(choices = ', '.join(download_providers)), default = config.get_str_list('download.download_providers', 'github'), choices = download_providers, nargs = '+', metavar = 'DOWNLOAD_PROVIDERS')
|
group_download.add_argument('--download-providers', help = wording.get('help.download_providers').format(choices = ', '.join(download_providers)), default = config.get_str_list('download.download_providers', ' '.join(facefusion.choices.download_providers)), choices = download_providers, nargs = '+', metavar = 'DOWNLOAD_PROVIDERS')
|
||||||
job_store.register_job_keys([ 'download_providers' ])
|
job_store.register_job_keys([ 'download_providers' ])
|
||||||
return program
|
return program
|
||||||
|
|
||||||
|
@ -159,7 +159,12 @@ ExecutionDevice = TypedDict('ExecutionDevice',
|
|||||||
})
|
})
|
||||||
|
|
||||||
DownloadProvider = Literal['github', 'huggingface']
|
DownloadProvider = Literal['github', 'huggingface']
|
||||||
DownloadProviderSet = Dict[DownloadProvider, str]
|
DownloadProviderValue = TypedDict('DownloadProviderValue',
|
||||||
|
{
|
||||||
|
'url' : str,
|
||||||
|
'path' : str
|
||||||
|
})
|
||||||
|
DownloadProviderSet = Dict[DownloadProvider, DownloadProviderValue]
|
||||||
DownloadScope = Literal['lite', 'full']
|
DownloadScope = Literal['lite', 'full']
|
||||||
Download = TypedDict('Download',
|
Download = TypedDict('Download',
|
||||||
{
|
{
|
||||||
|
@ -1,18 +1,13 @@
|
|||||||
import pytest
|
from facefusion.download import get_static_download_size, ping_static_url
|
||||||
|
|
||||||
from facefusion.download import conditional_download, get_download_size
|
|
||||||
from .helper import get_test_examples_directory
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope = 'module', autouse = True)
|
def test_get_static_download_size() -> None:
|
||||||
def before_all() -> None:
|
assert get_static_download_size('https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/fairface.onnx') == 85170772
|
||||||
conditional_download(get_test_examples_directory(),
|
assert get_static_download_size('https://huggingface.co/facefusion/models-3.0.0/resolve/main/fairface.onnx') == 85170772
|
||||||
[
|
assert get_static_download_size('invalid') == 0
|
||||||
'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4'
|
|
||||||
])
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_download_size() -> None:
|
def test_static_ping_url() -> None:
|
||||||
assert get_download_size('https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4') == 191675
|
assert ping_static_url('https://github.com') is True
|
||||||
assert get_download_size('https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-360p.mp4') == 370732
|
assert ping_static_url('https://huggingface.co') is True
|
||||||
assert get_download_size('invalid') == 0
|
assert ping_static_url('invalid') is False
|
||||||
|
Loading…
Reference in New Issue
Block a user