From 21bbdda13f26a28c5d6b9de5d2b4c538cb9eb937 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Tue, 5 Nov 2024 21:17:33 +0530 Subject: [PATCH] changes --- facefusion/processors/modules/deep_swapper.py | 4 ++-- facefusion/uis/components/deep_swapper_options.py | 12 ++++++------ facefusion/vision.py | 2 +- tests/test_vision.py | 14 ++++++++++---- 4 files changed, 19 insertions(+), 13 deletions(-) diff --git a/facefusion/processors/modules/deep_swapper.py b/facefusion/processors/modules/deep_swapper.py index b6c8de89..8b7ec1f1 100755 --- a/facefusion/processors/modules/deep_swapper.py +++ b/facefusion/processors/modules/deep_swapper.py @@ -21,7 +21,7 @@ from facefusion.processors.typing import DeepSwapperInputs from facefusion.program_helper import find_argument_group from facefusion.thread_helper import thread_semaphore from facefusion.typing import ApplyStateItem, Args, Face, InferencePool, Mask, ModelOptions, ModelSet, ProcessMode, QueuePayload, UpdateProgress, VisionFrame -from facefusion.vision import adaptive_match_frame_color, read_image, read_static_image, write_image +from facefusion.vision import conditional_match_frame_color, read_image, read_static_image, write_image MODEL_SET : ModelSet =\ { @@ -133,7 +133,7 @@ def swap_face(target_face : Face, temp_vision_frame : VisionFrame) -> VisionFram crop_vision_frame = prepare_crop_frame(crop_vision_frame) crop_vision_frame, crop_source_mask, crop_target_mask = forward(crop_vision_frame) crop_vision_frame = normalize_crop_frame(crop_vision_frame) - crop_vision_frame = adaptive_match_frame_color(crop_vision_frame_raw, crop_vision_frame) + crop_vision_frame = conditional_match_frame_color(crop_vision_frame_raw, crop_vision_frame) crop_masks.append(prepare_crop_mask(crop_source_mask, crop_target_mask)) crop_mask = numpy.minimum.reduce(crop_masks).clip(0, 1) paste_vision_frame = paste_back(temp_vision_frame, crop_vision_frame, crop_mask, affine_matrix) diff --git a/facefusion/uis/components/deep_swapper_options.py b/facefusion/uis/components/deep_swapper_options.py index 77b27572..6b2136b2 100755 --- a/facefusion/uis/components/deep_swapper_options.py +++ b/facefusion/uis/components/deep_swapper_options.py @@ -5,7 +5,7 @@ import gradio from facefusion import state_manager, wording from facefusion.processors import choices as processors_choices from facefusion.processors.core import load_processor_module -from facefusion.processors.typing import FaceEnhancerModel +from facefusion.processors.typing import DeepSwapperModel from facefusion.uis.core import get_ui_component, register_ui_component DEEP_SWAPPER_MODEL_DROPDOWN : Optional[gradio.Dropdown] = None @@ -24,7 +24,7 @@ def render() -> None: def listen() -> None: - DEEP_SWAPPER_MODEL_DROPDOWN.change(update_face_enhancer_model, inputs = DEEP_SWAPPER_MODEL_DROPDOWN, outputs = DEEP_SWAPPER_MODEL_DROPDOWN) + DEEP_SWAPPER_MODEL_DROPDOWN.change(update_deep_swapper_model, inputs = DEEP_SWAPPER_MODEL_DROPDOWN, outputs = DEEP_SWAPPER_MODEL_DROPDOWN) processors_checkbox_group = get_ui_component('processors_checkbox_group') if processors_checkbox_group: @@ -32,14 +32,14 @@ def listen() -> None: def remote_update(processors : List[str]) -> gradio.Dropdown: - has_face_enhancer = 'deep_swapper' in processors - return gradio.Dropdown(visible = has_face_enhancer) + has_deep_swapper = 'deep_swapper' in processors + return gradio.Dropdown(visible = has_deep_swapper) -def update_face_enhancer_model(face_enhancer_model : FaceEnhancerModel) -> gradio.Dropdown: +def update_deep_swapper_model(deep_swapper_model : DeepSwapperModel) -> gradio.Dropdown: deep_swapper_module = load_processor_module('deep_swapper') deep_swapper_module.clear_inference_pool() - state_manager.set_item('deep_swapper_model', face_enhancer_model) + state_manager.set_item('deep_swapper_model', deep_swapper_model) if deep_swapper_module.pre_check(): return gradio.Dropdown(value = state_manager.get_item('deep_swapper_model')) diff --git a/facefusion/vision.py b/facefusion/vision.py index 337b9b9e..8d1fac93 100644 --- a/facefusion/vision.py +++ b/facefusion/vision.py @@ -210,7 +210,7 @@ def normalize_frame_color(vision_frame : VisionFrame) -> VisionFrame: return cv2.cvtColor(vision_frame, cv2.COLOR_BGR2RGB) -def adaptive_match_frame_color(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame: +def conditional_match_frame_color(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame: histogram_factor = calc_histogram_difference(source_vision_frame, target_vision_frame) target_vision_frame = blend_vision_frames(target_vision_frame, match_frame_color(source_vision_frame, target_vision_frame), histogram_factor) return target_vision_frame diff --git a/tests/test_vision.py b/tests/test_vision.py index 9f5e770c..8922fe21 100644 --- a/tests/test_vision.py +++ b/tests/test_vision.py @@ -4,7 +4,7 @@ import cv2 import pytest from facefusion.download import conditional_download -from facefusion.vision import count_video_frame_total, create_image_resolutions, create_video_resolutions, detect_image_resolution, detect_video_duration, detect_video_fps, detect_video_resolution, get_video_frame, match_frame_color, normalize_resolution, pack_resolution, read_image, restrict_image_resolution, restrict_video_fps, restrict_video_resolution, unpack_resolution +from facefusion.vision import calc_histogram_difference, count_video_frame_total, create_image_resolutions, create_video_resolutions, detect_image_resolution, detect_video_duration, detect_video_fps, detect_video_resolution, get_video_frame, match_frame_color, normalize_resolution, pack_resolution, read_image, restrict_image_resolution, restrict_video_fps, restrict_video_resolution, unpack_resolution from .helper import get_test_example_file, get_test_examples_directory @@ -117,11 +117,17 @@ def test_unpack_resolution() -> None: assert unpack_resolution('2x2') == (2, 2) +def test_calc_histogram_difference() -> None: + source_vision_frame = read_image(get_test_example_file('target-1080p.jpg')) + target_vision_frame = cv2.cvtColor(cv2.cvtColor(source_vision_frame, cv2.COLOR_BGR2GRAY), cv2.COLOR_GRAY2BGR) + + assert calc_histogram_difference(source_vision_frame, source_vision_frame) == 1.0 + assert calc_histogram_difference(source_vision_frame, target_vision_frame) < 0.5 + + def test_match_frame_color() -> None: source_vision_frame = read_image(get_test_example_file('target-1080p.jpg')) target_vision_frame = cv2.cvtColor(cv2.cvtColor(source_vision_frame, cv2.COLOR_BGR2GRAY), cv2.COLOR_GRAY2BGR) output_vision_frame = match_frame_color(source_vision_frame, target_vision_frame) - histogram_source = cv2.calcHist([ cv2.cvtColor(source_vision_frame, cv2.COLOR_BGR2HSV) ], [ 0, 1 ], None, [ 50, 60 ], [ 0, 180, 0, 256 ]) - histogram_output = cv2.calcHist([ cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2HSV) ], [ 0, 1 ], None, [ 50, 60 ], [ 0, 180, 0, 256 ]) - assert cv2.compareHist(histogram_source, histogram_output, cv2.HISTCMP_CORREL) > 0.5 + assert calc_histogram_difference(source_vision_frame, output_vision_frame) > 0.5