Source code for keras_aug.layers.augmentation.intensity.random_posterize

import tensorflow as tf
from tensorflow import keras

from keras_aug.datapoints import image as image_utils
from keras_aug.layers.base.vectorized_base_random_layer import (
    VectorizedBaseRandomLayer,
)
from keras_aug.utils import augmentation as augmentation_utils


@keras.utils.register_keras_serializable(package="keras_aug")
class RandomPosterize(VectorizedBaseRandomLayer):
    """Randomly reduces the number of bits for each color channel.

    Args:
        value_range (Sequence[int|float]): The range of values the incoming
            images will have. This is typically either ``[0, 1]`` or
            ``[0, 255]`` depending on how your preprocessing pipeline is set
            up.
        factor (int|Sequence[int]|keras_aug.FactorSampler): The number of
            bits to keep for each channel. Must be a value between
            ``[1, 8]``. ``factor=(5, 8)`` means RandomPosterize will randomly
            keep 5 to 8 bits for the image. A single int is used as the
            lower bound, with 8 as the upper bound.
        seed (int|float, optional): The random seed. Defaults to ``None``.

    Raises:
        ValueError: If ``factor`` is an int outside ``[1, 8]``.

    References:
        - `AutoAugment <https://arxiv.org/abs/1805.09501>`_
        - `RandAugment <https://arxiv.org/abs/1909.13719>`_
        - `KerasCV <https://github.com/keras-team/keras-cv>`_
    """  # noqa: E501

    def __init__(self, value_range, factor, seed=None, **kwargs):
        super().__init__(seed=seed, **kwargs)
        self.value_range = value_range
        # The upper bound is widened by one below because the sampled value
        # is later truncated to an integer; presumably this keeps the
        # requested maximum bit count reachable.
        if isinstance(factor, int):
            if not (0 < factor < 9):
                raise ValueError(
                    "factor value must be between [1, 8]. "
                    f"Received bits: {factor}."
                )
            factor = (factor, 8 + 1)
        elif isinstance(factor, (tuple, list)):
            factor = (factor[0], factor[1] + 1)
        # Any other value (e.g. an already-built sampler) is handed to
        # parse_factor unchanged.
        self.factor = augmentation_utils.parse_factor(
            factor, min_value=0, max_value=8 + 1
        )
        self.seed = seed

    def get_random_transformation_batch(self, batch_size, **kwargs):
        """Sample the number of bits to keep for every image in the batch.

        Args:
            batch_size: Number of images in the batch.

        Returns:
            An ``int32`` tensor of shape ``(batch_size, 1)`` whose values
            lie in ``[1, 8]``.
        """
        # cannot sample from tf.int32 due to self.factor might be
        # NormalFactorSampler
        factors = self.factor(shape=(batch_size, 1))
        # Clamp to [1, 8]: keeping 0 bits would turn into a shift by 8 on a
        # uint8 tensor, and TensorFlow documents a shift greater than or
        # equal to the bit width as implementation defined.
        factors = tf.clip_by_value(factors, 1, 8)
        return tf.cast(factors, dtype=tf.int32)

    def augment_ragged_image(self, image, transformation, **kwargs):
        """Posterize one image by routing it through ``augment_images``.

        A leading batch axis is added to both the image and its
        transformation, then removed again from the result.
        """
        image = tf.expand_dims(image, axis=0)
        transformation = tf.expand_dims(transformation, axis=0)
        image = self.augment_images(
            images=image, transformations=transformation, **kwargs
        )
        return tf.squeeze(image, axis=0)

    def augment_images(self, images, transformations, **kwargs):
        """Posterize a batch of images.

        Args:
            images: Batch of images expressed in ``self.value_range``.
            transformations: Per-image bit counts, as produced by
                ``get_random_transformation_batch``.

        Returns:
            The posterized images, converted back to ``self.value_range``
            with dtype ``self.compute_dtype``.
        """
        # The bit shifts need integer pixels, so work in uint8 on the
        # (0, 255) range.
        images = image_utils.transform_value_range(
            images,
            original_range=self.value_range,
            target_range=(0, 255),
        )
        images = tf.cast(images, tf.uint8)
        inputs_for_posterize_single_image = {
            augmentation_utils.IMAGES: images,
            "bits": transformations,
        }
        images = tf.vectorized_map(
            self.posterize_single_image, inputs_for_posterize_single_image
        )
        images = tf.cast(images, self.compute_dtype)
        images = image_utils.transform_value_range(
            images,
            original_range=(0, 255),
            target_range=self.value_range,
            dtype=self.compute_dtype,
        )
        return images

    def augment_labels(self, labels, transformations, **kwargs):
        """Return ``labels`` unchanged."""
        return labels

    def augment_bounding_boxes(self, bounding_boxes, transformations, **kwargs):
        """Return ``bounding_boxes`` unchanged."""
        return bounding_boxes

    def augment_segmentation_masks(
        self, segmentation_masks, transformations, **kwargs
    ):
        """Return ``segmentation_masks`` unchanged."""
        return segmentation_masks

    def augment_keypoints(self, keypoints, transformations, **kwargs):
        """Return ``keypoints`` unchanged."""
        return keypoints

    def posterize_single_image(self, inputs):
        """Zero the low-order bits of one ``uint8`` image.

        Args:
            inputs: Dict holding the image under
                ``augmentation_utils.IMAGES`` and the number of bits to keep
                under ``"bits"``.

        Returns:
            The image with its lowest ``8 - bits`` bits cleared.
        """
        image = inputs.get(augmentation_utils.IMAGES, None)
        shift = 8 - tf.cast(inputs.get("bits", None), dtype=tf.uint8)
        # Shifting right and then left by the same amount clears the low
        # bits.
        return tf.bitwise.left_shift(
            tf.bitwise.right_shift(image, shift), shift
        )

    def get_config(self):
        """Return the layer config extended with this layer's arguments.

        Note that ``"factor"`` holds the parsed ``self.factor`` object, not
        the raw value passed to ``__init__``.
        """
        config = super().get_config()
        config.update(
            {
                "value_range": self.value_range,
                "factor": self.factor,
                "seed": self.seed,
            }
        )
        return config