Source code for keras_aug.layers.preprocessing.intensity.auto_contrast

import tensorflow as tf
from tensorflow import keras

from keras_aug.datapoints import image as image_utils
from keras_aug.layers.base.vectorized_base_random_layer import (
    VectorizedBaseRandomLayer,
)


@keras.utils.register_keras_serializable(package="keras_aug")
class AutoContrast(VectorizedBaseRandomLayer):
    """Performs the AutoContrast operation on the input images.

    Auto contrast stretches the values of an image across the entire available
    ``value_range``. This makes differences between pixels more obvious. An
    example of this is if an image only has values ``[0, 1]`` out of the range
    ``[0, 255]``, auto contrast will change the ``1`` values to be ``255``.

    Args:
        value_range (Sequence[int|float]): The range of values the incoming
            images will have. This is typically either ``[0, 1]`` or
            ``[0, 255]`` depending on how your preprocessing pipeline is set
            up.

    References:
        - `KerasCV <https://github.com/keras-team/keras-cv>`_
    """

    def __init__(self, value_range, **kwargs):
        super().__init__(**kwargs)
        # Stored as given; it is returned unchanged by ``get_config``.
        self.value_range = value_range

    def augment_ragged_image(self, image, transformation, **kwargs):
        """Apply auto contrast to a single image.

        The image gets a leading axis of size 1, is run through
        ``augment_images``, and has that axis removed again.
        """
        images = tf.expand_dims(image, axis=0)
        images = self.augment_images(
            images=images, transformations=transformation, **kwargs
        )
        return tf.squeeze(images, axis=0)

    def augment_images(self, images, transformations, **kwargs):
        """Stretch the values of each image to span the full value range.

        Args:
            images: Batch of images expressed in ``self.value_range``.
                NOTE(review): the reductions over axes ``(1, 2)`` assume a
                ``(batch, height, width, channels)`` layout -- confirm
                against the base layer.
            transformations: Unused by this layer.
            **kwargs: Unused by this layer.

        Returns:
            The contrast-stretched images, converted back to
            ``self.value_range`` with dtype ``self.compute_dtype``.
        """
        # Work in the (0, 255) range so the constants below apply regardless
        # of the configured ``value_range``.
        images = image_utils.transform_value_range(
            images,
            original_range=self.value_range,
            target_range=(0, 255),
            dtype=self.compute_dtype,
        )
        # Minimum and maximum over axes 1 and 2, kept broadcastable.
        lows = tf.reduce_min(images, axis=(1, 2), keepdims=True)
        highs = tf.reduce_max(images, axis=(1, 2), keepdims=True)
        scales = 255.0 / (highs - lows)
        # An infinite scale marks slices that cannot be stretched (e.g.
        # ``highs == lows``); use offset 0 and scale 1 so they pass through
        # unchanged.
        eq_idxs = tf.math.is_inf(scales)
        lows = tf.where(eq_idxs, 0.0, lows)
        scales = tf.where(eq_idxs, 1.0, scales)
        images = tf.clip_by_value((images - lows) * scales, 0, 255)
        images = image_utils.transform_value_range(
            images,
            original_range=(0, 255),
            target_range=self.value_range,
            dtype=self.compute_dtype,
        )
        return images

    def augment_labels(self, labels, transformations, **kwargs):
        """Return ``labels`` unchanged."""
        return labels

    def augment_bounding_boxes(self, bounding_boxes, transformations, **kwargs):
        """Return ``bounding_boxes`` unchanged."""
        return bounding_boxes

    def augment_segmentation_masks(
        self, segmentation_masks, transformations, **kwargs
    ):
        """Return ``segmentation_masks`` unchanged."""
        return segmentation_masks

    def augment_keypoints(self, keypoints, transformations, **kwargs):
        """Return ``keypoints`` unchanged."""
        return keypoints

    def get_config(self):
        """Return the base layer config extended with ``value_range``."""
        config = super().get_config()
        config.update({"value_range": self.value_range})
        return config