!569 更新反演攻击评估指标PSNR/SSIM

Merge pull request !569 from 刘思铭/master
This commit is contained in:
刘思铭
2024-06-05 07:46:49 +00:00
committed by Gitee

View File

@@ -355,32 +355,32 @@ def _crop(arr, crop_width):
return cropped
def compute_ssim(image1, image2):
def compute_ssim(original_image, compared_image):
"""
compute structural similarity between two images.
Args:
image1 (numpy.ndarray): The first image to be compared.
image2 (numpy.ndarray): The second image to be compared.
original_image (numpy.ndarray): The first image to be compared.
compared_image (numpy.ndarray): The second image to be compared.
Returns:
float, structural similarity.
"""
if not image1.shape == image2.shape:
if not original_image.shape == compared_image.shape:
msg = 'Input images must have the same dimensions, but got ' \
'image1.shape: {} and image2.shape: {}' \
.format(image1.shape, image2.shape)
'original_image.shape: {} and compared_image.shape: {}' \
.format(original_image.shape, compared_image.shape)
LOGGER.error(TAG, msg)
raise ValueError()
if len(image1.shape) == 3: # rgb mode
if image1.shape[0] in [1, 3]: # from nhw to hwn
image1 = np.array(image1).transpose(1, 2, 0)
image2 = np.array(image2).transpose(1, 2, 0)
if len(original_image.shape) == 3: # rgb mode
if original_image.shape[0] in [1, 3]: # from nhw to hwn
original_image = np.array(original_image).transpose(1, 2, 0)
compared_image = np.array(compared_image).transpose(1, 2, 0)
# loop over channels
n_channels = image1.shape[-1]
n_channels = original_image.shape[-1]
total_ssim = np.empty(n_channels)
for ch in range(n_channels):
ch_result = compute_ssim(image1[..., ch], image2[..., ch])
ch_result = compute_ssim(original_image[..., ch], compared_image[..., ch])
total_ssim[..., ch] = ch_result
return total_ssim.mean()
@@ -388,28 +388,28 @@ def compute_ssim(image1, image2):
k2 = 0.03
win_size = 7
if np.any((np.asarray(image1.shape) - win_size) < 0):
if np.any((np.asarray(original_image.shape) - win_size) < 0):
msg = 'Size of each dimension must be larger win_size:7, ' \
'but got image.shape:{}.' \
.format(image1.shape)
.format(original_image.shape)
LOGGER.error(TAG, msg)
raise ValueError(msg)
image1 = image1.astype(np.float64)
image2 = image2.astype(np.float64)
original_image = original_image.astype(np.float64)
compared_image = compared_image.astype(np.float64)
ndim = image1.ndim
ndim = original_image.ndim
tmp = win_size ** ndim
cov_norm = tmp / (tmp - 1)
# compute means
ux = uniform_filter(image1, size=win_size)
uy = uniform_filter(image2, size=win_size)
ux = uniform_filter(original_image, size=win_size)
uy = uniform_filter(compared_image, size=win_size)
# compute variances and covariances
uxx = uniform_filter(image1*image1, size=win_size)
uyy = uniform_filter(image2*image2, size=win_size)
uxy = uniform_filter(image1*image2, size=win_size)
uxx = uniform_filter(original_image*original_image, size=win_size)
uyy = uniform_filter(compared_image*compared_image, size=win_size)
uxy = uniform_filter(original_image*compared_image, size=win_size)
vx = cov_norm*(uxx - ux*ux)
vy = cov_norm*(uyy - uy*uy)
@@ -433,14 +433,14 @@ def compute_ssim(image1, image2):
return mean_ssim
def compute_psnr(image_true, image_test, data_range=None):
def compute_psnr(original_image, compared_image, data_range=None):
"""
Compute the peak signal to noise ratio (PSNR) for an image.
Args:
image_true : ndarray
original_image : ndarray
Ground-truth image, same shape as im_test.
image_test : ndarray
compared_image : ndarray
Test image.
data_range : int, optional
The data range of the input image (distance between minimum and
@@ -463,26 +463,26 @@ def compute_psnr(image_true, image_test, data_range=None):
np.float32: (-1, 1),
np.float64: (-1, 1),
}
image1 = image_true.astype(np.float64)
image2 = image_test.astype(np.float64)
if not image1.shape == image2.shape:
original_image = original_image.astype(np.float64)
compared_image = compared_image.astype(np.float64)
if not original_image.shape == compared_image.shape:
msg = 'Input images must have the same dimensions, but got ' \
'image1.shape: {} and image2.shape: {}' \
.format(image1.shape, image2.shape)
'original_image.shape: {} and compared_image.shape: {}' \
.format(original_image.shape, compared_image.shape)
LOGGER.error(TAG, msg)
raise ValueError(msg)
if data_range is None:
dmin, dmax = dtype_range.get(image1.dtype.type, [None, None])
dmin, dmax = dtype_range.get(original_image.dtype.type, [None, None])
if dmin is None or dmax is None:
msg = 'Input image dtype error, the type should in {} ' \
.format(list(dtype_range.keys()))
LOGGER.error(TAG, msg)
raise ValueError(msg)
true_min, true_max = np.min(image1), np.max(image2)
true_min, true_max = np.min(original_image), np.max(compared_image)
if true_max > dmax or true_min < dmin:
msg = 'image_true has intensity values outside the range expected' \
msg = 'original_image has intensity values outside the range expected' \
'for its data type. Please manually specify the data_range.'
LOGGER.error(TAG, msg)
raise ValueError(msg)
@@ -493,7 +493,7 @@ def compute_psnr(image_true, image_test, data_range=None):
else:
data_range = dmax - dmin
err = np.mean((image1 - image2)**2, dtype=np.float64)
err = np.mean((original_image - compared_image)**2, dtype=np.float64)
if err != 0:
return 10 * np.log10((data_range**2) / err)
return float('inf')