diff --git a/tests/test_vision.py b/tests/test_vision.py index 364de536..418ce7af 100644 --- a/tests/test_vision.py +++ b/tests/test_vision.py @@ -1,5 +1,6 @@ import subprocess +import cv2 import pytest from facefusion.download import conditional_download @@ -117,5 +118,11 @@ def test_unpack_resolution() -> None: def test_match_frame_color() -> None: - assert match_frame_color(read_image(get_test_example_file('target-1080p.jpg')), read_image(get_test_example_file('target-240p.jpg'))).shape == (226, 426, 3) - assert match_frame_color(read_image(get_test_example_file('target-240p.jpg')), read_image(get_test_example_file('target-1080p.jpg'))).shape == (1080, 2048, 3) + source_vision_frame = read_image(get_test_example_file('target-1080p.jpg')) + target_vision_frame = cv2.cvtColor(cv2.cvtColor(source_vision_frame, cv2.COLOR_BGR2GRAY), cv2.COLOR_GRAY2BGR) + output_vision_frame = match_frame_color(source_vision_frame, target_vision_frame) + histogram_source = cv2.calcHist([ cv2.cvtColor(source_vision_frame, cv2.COLOR_BGR2HSV) ], [ 0, 1 ], None, [ 50, 60 ], [ 0, 180, 0, 256 ]) + histogram_output = cv2.calcHist([ cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2HSV) ], [ 0, 1 ], None, [ 50, 60 ], [ 0, 180, 0, 256 ]) + cv2.normalize(histogram_source, histogram_source, 0, 1, cv2.NORM_MINMAX) + cv2.normalize(histogram_output, histogram_output, 0, 1, cv2.NORM_MINMAX) + assert cv2.compareHist(histogram_source, histogram_output, cv2.HISTCMP_CORREL) > 0.5