mirror of
https://gitee.com/mindspore/mindarmour.git
synced 2025-12-06 11:59:05 +08:00
@@ -355,32 +355,32 @@ def _crop(arr, crop_width):
|
||||
return cropped
|
||||
|
||||
|
||||
def compute_ssim(image1, image2):
|
||||
def compute_ssim(original_image, compared_image):
|
||||
"""
|
||||
compute structural similarity between two images.
|
||||
|
||||
Args:
|
||||
image1 (numpy.ndarray): The first image to be compared.
|
||||
image2 (numpy.ndarray): The second image to be compared.
|
||||
original_image (numpy.ndarray): The first image to be compared.
|
||||
compared_image (numpy.ndarray): The second image to be compared.
|
||||
|
||||
Returns:
|
||||
float, structural similarity.
|
||||
"""
|
||||
if not image1.shape == image2.shape:
|
||||
if not original_image.shape == compared_image.shape:
|
||||
msg = 'Input images must have the same dimensions, but got ' \
|
||||
'image1.shape: {} and image2.shape: {}' \
|
||||
.format(image1.shape, image2.shape)
|
||||
'original_image.shape: {} and compared_image.shape: {}' \
|
||||
.format(original_image.shape, compared_image.shape)
|
||||
LOGGER.error(TAG, msg)
|
||||
raise ValueError()
|
||||
if len(image1.shape) == 3: # rgb mode
|
||||
if image1.shape[0] in [1, 3]: # from nhw to hwn
|
||||
image1 = np.array(image1).transpose(1, 2, 0)
|
||||
image2 = np.array(image2).transpose(1, 2, 0)
|
||||
if len(original_image.shape) == 3: # rgb mode
|
||||
if original_image.shape[0] in [1, 3]: # from nhw to hwn
|
||||
original_image = np.array(original_image).transpose(1, 2, 0)
|
||||
compared_image = np.array(compared_image).transpose(1, 2, 0)
|
||||
# loop over channels
|
||||
n_channels = image1.shape[-1]
|
||||
n_channels = original_image.shape[-1]
|
||||
total_ssim = np.empty(n_channels)
|
||||
for ch in range(n_channels):
|
||||
ch_result = compute_ssim(image1[..., ch], image2[..., ch])
|
||||
ch_result = compute_ssim(original_image[..., ch], compared_image[..., ch])
|
||||
total_ssim[..., ch] = ch_result
|
||||
return total_ssim.mean()
|
||||
|
||||
@@ -388,28 +388,28 @@ def compute_ssim(image1, image2):
|
||||
k2 = 0.03
|
||||
win_size = 7
|
||||
|
||||
if np.any((np.asarray(image1.shape) - win_size) < 0):
|
||||
if np.any((np.asarray(original_image.shape) - win_size) < 0):
|
||||
msg = 'Size of each dimension must be larger win_size:7, ' \
|
||||
'but got image.shape:{}.' \
|
||||
.format(image1.shape)
|
||||
.format(original_image.shape)
|
||||
LOGGER.error(TAG, msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
image1 = image1.astype(np.float64)
|
||||
image2 = image2.astype(np.float64)
|
||||
original_image = original_image.astype(np.float64)
|
||||
compared_image = compared_image.astype(np.float64)
|
||||
|
||||
ndim = image1.ndim
|
||||
ndim = original_image.ndim
|
||||
tmp = win_size ** ndim
|
||||
cov_norm = tmp / (tmp - 1)
|
||||
|
||||
# compute means
|
||||
ux = uniform_filter(image1, size=win_size)
|
||||
uy = uniform_filter(image2, size=win_size)
|
||||
ux = uniform_filter(original_image, size=win_size)
|
||||
uy = uniform_filter(compared_image, size=win_size)
|
||||
|
||||
# compute variances and covariances
|
||||
uxx = uniform_filter(image1*image1, size=win_size)
|
||||
uyy = uniform_filter(image2*image2, size=win_size)
|
||||
uxy = uniform_filter(image1*image2, size=win_size)
|
||||
uxx = uniform_filter(original_image*original_image, size=win_size)
|
||||
uyy = uniform_filter(compared_image*compared_image, size=win_size)
|
||||
uxy = uniform_filter(original_image*compared_image, size=win_size)
|
||||
|
||||
vx = cov_norm*(uxx - ux*ux)
|
||||
vy = cov_norm*(uyy - uy*uy)
|
||||
@@ -433,14 +433,14 @@ def compute_ssim(image1, image2):
|
||||
return mean_ssim
|
||||
|
||||
|
||||
def compute_psnr(image_true, image_test, data_range=None):
|
||||
def compute_psnr(original_image, compared_image, data_range=None):
|
||||
"""
|
||||
Compute the peak signal to noise ratio (PSNR) for an image.
|
||||
|
||||
Args:
|
||||
image_true : ndarray
|
||||
original_image : ndarray
|
||||
Ground-truth image, same shape as im_test.
|
||||
image_test : ndarray
|
||||
compared_image : ndarray
|
||||
Test image.
|
||||
data_range : int, optional
|
||||
The data range of the input image (distance between minimum and
|
||||
@@ -463,26 +463,26 @@ def compute_psnr(image_true, image_test, data_range=None):
|
||||
np.float32: (-1, 1),
|
||||
np.float64: (-1, 1),
|
||||
}
|
||||
image1 = image_true.astype(np.float64)
|
||||
image2 = image_test.astype(np.float64)
|
||||
if not image1.shape == image2.shape:
|
||||
original_image = original_image.astype(np.float64)
|
||||
compared_image = compared_image.astype(np.float64)
|
||||
if not original_image.shape == compared_image.shape:
|
||||
msg = 'Input images must have the same dimensions, but got ' \
|
||||
'image1.shape: {} and image2.shape: {}' \
|
||||
.format(image1.shape, image2.shape)
|
||||
'original_image.shape: {} and compared_image.shape: {}' \
|
||||
.format(original_image.shape, compared_image.shape)
|
||||
LOGGER.error(TAG, msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
if data_range is None:
|
||||
dmin, dmax = dtype_range.get(image1.dtype.type, [None, None])
|
||||
dmin, dmax = dtype_range.get(original_image.dtype.type, [None, None])
|
||||
if dmin is None or dmax is None:
|
||||
msg = 'Input image dtype error, the type should in {} ' \
|
||||
.format(list(dtype_range.keys()))
|
||||
LOGGER.error(TAG, msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
true_min, true_max = np.min(image1), np.max(image2)
|
||||
true_min, true_max = np.min(original_image), np.max(compared_image)
|
||||
if true_max > dmax or true_min < dmin:
|
||||
msg = 'image_true has intensity values outside the range expected' \
|
||||
msg = 'original_image has intensity values outside the range expected' \
|
||||
'for its data type. Please manually specify the data_range.'
|
||||
LOGGER.error(TAG, msg)
|
||||
raise ValueError(msg)
|
||||
@@ -493,7 +493,7 @@ def compute_psnr(image_true, image_test, data_range=None):
|
||||
else:
|
||||
data_range = dmax - dmin
|
||||
|
||||
err = np.mean((image1 - image2)**2, dtype=np.float64)
|
||||
err = np.mean((original_image - compared_image)**2, dtype=np.float64)
|
||||
if err != 0:
|
||||
return 10 * np.log10((data_range**2) / err)
|
||||
return float('inf')
|
||||
|
||||
Reference in New Issue
Block a user