From 447ca53d54d0a999599ba76fb4d29463ec2054fd Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Mon, 4 Nov 2024 17:26:50 +0530 Subject: [PATCH] adaptive color correction --- facefusion/processors/modules/deep_swapper.py | 29 +++++-------------- facefusion/vision.py | 18 ++++++++++++ 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/facefusion/processors/modules/deep_swapper.py b/facefusion/processors/modules/deep_swapper.py index 0f7c2370..4ebb4e5b 100755 --- a/facefusion/processors/modules/deep_swapper.py +++ b/facefusion/processors/modules/deep_swapper.py @@ -20,7 +20,7 @@ from facefusion.processors.typing import DeepSwapperInputs from facefusion.program_helper import find_argument_group from facefusion.thread_helper import thread_semaphore from facefusion.typing import ApplyStateItem, Args, Face, InferencePool, Mask, ModelOptions, ModelSet, ProcessMode, QueuePayload, UpdateProgress, VisionFrame -from facefusion.vision import read_image, read_static_image, write_image +from facefusion.vision import adaptive_match_frame_color, read_image, read_static_image, write_image MODEL_SET : ModelSet =\ { @@ -127,8 +127,11 @@ def swap_face(target_face : Face, temp_vision_frame : VisionFrame) -> VisionFram crop_vision_frame = prepare_crop_frame(crop_vision_frame) crop_vision_frame, crop_source_mask, crop_target_mask = forward(crop_vision_frame) crop_vision_frame = normalize_crop_frame(crop_vision_frame) - crop_vision_frame = match_frame_color_with_mask(crop_vision_frame_raw, crop_vision_frame, crop_source_mask, crop_target_mask) - crop_masks.append(numpy.maximum.reduce([ feather_crop_mask(crop_source_mask), feather_crop_mask(crop_target_mask) ]).clip(0, 1)) + crop_vision_frame = adaptive_match_frame_color(crop_vision_frame_raw, crop_vision_frame) + crop_source_mask = feather_crop_mask(crop_source_mask) + crop_target_mask = feather_crop_mask(crop_target_mask) + crop_combine_mask = numpy.maximum.reduce([ crop_source_mask, crop_target_mask ]) + crop_masks.append(crop_combine_mask) crop_mask = numpy.minimum.reduce(crop_masks).clip(0, 1) paste_vision_frame = paste_back(temp_vision_frame, crop_vision_frame, crop_mask, affine_matrix) return paste_vision_frame @@ -167,27 +170,11 @@ def normalize_crop_frame(crop_vision_frame : VisionFrame) -> VisionFrame: def feather_crop_mask(crop_source_mask : Mask) -> Mask: model_size = get_model_options().get('size') crop_mask = crop_source_mask.reshape(model_size).clip(0, 1) - crop_mask = cv2.erode(crop_mask, numpy.ones((7, 7), numpy.uint8), iterations = 1) - crop_mask = cv2.GaussianBlur(crop_mask, (15, 15), 0) + crop_mask = cv2.erode(crop_mask, numpy.ones((5, 5), numpy.uint8), iterations = 1) + crop_mask = cv2.GaussianBlur(crop_mask, (7, 7), 0) return crop_mask -def match_frame_color_with_mask(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame, source_mask : Mask, target_mask : Mask) -> VisionFrame: - target_lab_frame = cv2.cvtColor(target_vision_frame, cv2.COLOR_BGR2LAB).astype(numpy.float32) / 255 - source_lab_frame = cv2.cvtColor(source_vision_frame, cv2.COLOR_BGR2LAB).astype(numpy.float32) / 255 - source_mask = (source_mask > 0.5).astype(numpy.float32) - target_mask = (target_mask > 0.5).astype(numpy.float32) - target_lab_filter = target_lab_frame * cv2.cvtColor(source_mask, cv2.COLOR_GRAY2BGR) - source_lab_filter = source_lab_frame * cv2.cvtColor(target_mask, cv2.COLOR_GRAY2BGR) - target_lab_frame -= target_lab_filter.mean(axis = ( 0, 1 )) - target_lab_frame /= target_lab_filter.std(axis = ( 0, 1 )) + 1e-6 - target_lab_frame *= source_lab_filter.std(axis = ( 0, 1 )) - target_lab_frame += source_lab_filter.mean(axis = ( 0, 1 )) - target_lab_frame = numpy.multiply(target_lab_frame.clip(0, 1), 255).astype(numpy.uint8) - target_vision_frame = cv2.cvtColor(target_lab_frame, cv2.COLOR_LAB2BGR) - return target_vision_frame - - def get_reference_frame(source_face : Face, target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame: return swap_face(target_face, temp_vision_frame) diff --git a/facefusion/vision.py b/facefusion/vision.py index 26021e32..0a1127dd 100644 --- a/facefusion/vision.py +++ b/facefusion/vision.py @@ -210,6 +210,12 @@ def normalize_frame_color(vision_frame : VisionFrame) -> VisionFrame: return cv2.cvtColor(vision_frame, cv2.COLOR_BGR2RGB) +def adaptive_match_frame_color(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame: + histogram_factor = calc_histogram_difference(source_vision_frame, target_vision_frame) + target_vision_frame = blend_vision_frames(target_vision_frame, match_frame_color(source_vision_frame, target_vision_frame), histogram_factor) + return target_vision_frame + + def match_frame_color(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame: color_difference_sizes = numpy.linspace(16, target_vision_frame.shape[0], 3, endpoint = False) @@ -228,6 +234,18 @@ def equalize_frame_color(source_vision_frame : VisionFrame, target_vision_frame return target_vision_frame +def calc_histogram_difference(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> float: + histogram_source = cv2.calcHist([cv2.cvtColor(source_vision_frame, cv2.COLOR_BGR2HSV)], [ 0, 1 ], None, [ 50, 60 ], [ 0, 180, 0, 256 ]) + histogram_target = cv2.calcHist([cv2.cvtColor(target_vision_frame, cv2.COLOR_BGR2HSV)], [ 0, 1 ], None, [ 50, 60 ], [ 0, 180, 0, 256 ]) + histogram_differnce = float(numpy.interp(cv2.compareHist(histogram_source, histogram_target, cv2.HISTCMP_CORREL), [ -1, 1 ], [ 0, 1 ])) + return histogram_differnce + + +def blend_vision_frames(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame, factor : float) -> VisionFrame: + blend_vision_frame = cv2.addWeighted(target_vision_frame, 1 - factor, source_vision_frame, factor, 0) + return blend_vision_frame + + def create_tile_frames(vision_frame : VisionFrame, size : Size) -> Tuple[List[VisionFrame], int, int]: vision_frame = numpy.pad(vision_frame, ((size[1], size[1]), (size[1], size[1]), (0, 0))) tile_width = size[0] - 2 * size[2]