import tensorflow as tf
from tensorflow import keras
from keras_aug.datapoints import image as image_utils
from keras_aug.layers.base.vectorized_base_random_layer import (
VectorizedBaseRandomLayer,
)
@keras.utils.register_keras_serializable(package="keras_aug")
class Equalize(VectorizedBaseRandomLayer):
    """Performs histogram equalization on a channel-wise basis.

    Args:
        value_range (Sequence[int|float]): The range of values the incoming
            images will have. This is typically either ``[0, 1]`` or
            ``[0, 255]`` depending on how your preprocessing pipeline is set
            up.
        bins (int, optional): The number of bins to use in histogram
            equalization. Should be in the range ``[0, 256]``. Defaults to
            ``256``. Note that ``bins - 1`` is used as a divisor and that the
            lookup table (which has ``bins`` entries) is indexed directly by
            pixel values in ``[0, 255]``, so values other than the default
            should be used with care.

    References:
        - `KerasCV <https://github.com/keras-team/keras-cv>`_
    """

    def __init__(self, value_range, bins=256, **kwargs):
        super().__init__(**kwargs)
        self.value_range = value_range
        self.bins = bins

    def augment_ragged_image(self, image, transformation, **kwargs):
        """Equalize a single image by routing it through ``augment_images``.

        A leading batch axis is added before the call and removed again
        afterwards.
        """
        image = tf.expand_dims(image, axis=0)
        image = self.augment_images(
            images=image, transformations=transformation, **kwargs
        )
        return tf.squeeze(image, axis=0)

    def augment_images(self, images, transformations=None, **kwargs):
        """Equalize every channel of every image in the batch.

        Args:
            images: Batch of images with values in ``self.value_range``. The
                channel axis is expected to be last (channels are selected
                with ``image[..., channel_index]``).
            transformations: Unused; this layer has no random parameters.

        Returns:
            The equalized images, rescaled back to ``self.value_range`` and
            cast to ``self.compute_dtype``.
        """
        # Remember the static shape so it can be re-attached after map_fn.
        original_shape = images.shape
        # The histogram / lookup-table logic works on integers in [0, 255].
        images = image_utils.transform_value_range(
            images, self.value_range, (0, 255), dtype=self.compute_dtype
        )
        images = tf.cast(images, dtype=tf.int32)
        images = tf.map_fn(
            self.equalize_single_image,
            images,
        )
        # equalize_single_image stacks the per-channel results on the leading
        # axis of each image, so move the channel axis back to the end.
        images = tf.transpose(images, (0, 2, 3, 1))
        images = tf.cast(images, dtype=self.compute_dtype)
        images = image_utils.transform_value_range(
            images, (0, 255), self.value_range, dtype=self.compute_dtype
        )
        images.set_shape(original_shape)
        return tf.cast(images, dtype=self.compute_dtype)

    def augment_labels(self, labels, transformations, **kwargs):
        """Return ``labels`` unchanged."""
        return labels

    def augment_bounding_boxes(self, bounding_boxes, transformations, **kwargs):
        """Return ``bounding_boxes`` unchanged."""
        return bounding_boxes

    def augment_segmentation_masks(
        self, segmentation_masks, transformations, **kwargs
    ):
        """Return ``segmentation_masks`` unchanged."""
        return segmentation_masks

    def augment_keypoints(self, keypoints, transformations, **kwargs):
        """Return ``keypoints`` unchanged."""
        return keypoints

    def equalize_single_image(self, image):
        """Equalize each channel of ``image`` independently.

        The per-channel results are stacked along the leading axis of the
        returned tensor (one entry per index of the last axis of ``image``).
        """
        return tf.map_fn(
            lambda channel_index: self.equalize_single_channel(
                image, channel_index
            ),
            tf.range(tf.shape(image)[-1]),
        )

    def equalize_single_channel(self, image, channel_index):
        """Equalize the histogram of one channel of ``image``.

        Args:
            image: Integer image tensor with the channel axis last; values
                are binned over the range ``[0, 255]``.
            channel_index: Index of the channel to equalize.

        Returns:
            The selected channel remapped through the equalization lookup
            table, or the untouched channel when the computed step is zero.
        """
        image = image[..., channel_index]
        # Compute the histogram of the image channel.
        histogram = tf.histogram_fixed_width(image, [0, 255], nbins=self.bins)

        # For the purposes of computing the step, ignore the empty bins.
        # Zeroes are replaced by a big number while calculating min to keep
        # shape constant across input sizes for compatibility with
        # vectorized_map
        big_number = 1410065408
        histogram_without_zeroes = tf.where(
            tf.equal(histogram, 0),
            tf.constant(big_number, dtype=histogram.dtype),
            histogram,
        )

        # NOTE(review): ``self.bins - 1`` is zero when ``bins == 1`` --
        # confirm whether such values should be rejected up front.
        step = (
            tf.reduce_sum(histogram) - tf.reduce_min(histogram_without_zeroes)
        ) // (self.bins - 1)

        def build_mapping(histogram, step):
            # Compute the cumulative sum, shifting by step // 2
            # and then normalization by step.
            lookup_table = (tf.cumsum(histogram) + (step // 2)) // step
            # Shift lookup_table, prepending with 0.
            lookup_table = tf.concat([[0], lookup_table[:-1]], 0)
            # Clip the counts to be in range. This is done
            # in the C code for image.point.
            return tf.clip_by_value(lookup_table, 0, 255)

        # If step is zero, return the original image. Otherwise, build
        # lookup table from the full histogram and step and then index from it.
        # NOTE(review): the table has ``self.bins`` entries but is indexed by
        # raw pixel values, so ``bins < 256`` looks unsafe here -- confirm.
        image = tf.cond(
            tf.equal(step, 0),
            lambda: image,
            lambda: tf.gather(build_mapping(histogram, step), image),
        )
        return image

    def get_config(self):
        """Return the layer config, including ``value_range`` and ``bins``."""
        config = super().get_config()
        config.update({"value_range": self.value_range, "bins": self.bins})
        return config