Source code for keras_aug.layers.augmentation.regularization.random_erase

import tensorflow as tf
from keras_cv.utils import fill_utils
from tensorflow import keras

from keras_aug.layers.base.vectorized_base_random_layer import (
    VectorizedBaseRandomLayer,
)
from keras_aug.utils import augmentation as augmentation_utils


[docs]@keras.utils.register_keras_serializable(package="keras_aug") class RandomErase(VectorizedBaseRandomLayer): """Randomly erase rectangles from images and fill them. Args: area_factor (float|Sequence[float]|keras_aug.FactorSampler, optional): The range of the area factor that controls the area of the erasing. When represented as a single float, the factor will be picked between ``[0.0, 0.0 + upper]``. ``0.0`` means the rectangle will be of size 0% of the image area. ``0.1`` means the rectangle will have a size of 10% of the image area. Defaults to ``(0.02, 0.4)`` aspect_ratio_factor (float|Sequence[float]|keras_aug.FactorSampler, optional): The range of the aspect ratio factor that controls the aspect ratio of the erasing. When represented as a single float, the factor will be picked between ``[1.0 - lower, 1.0 + upper]``. ``1.0`` means the erasing will be square. Defaults to ``(0.3, 1.0 / 0.3)``. fill_mode (str, optional): Pixels inside the erasing are filled according to the given mode. Supported values: ``"constant", "gaussian_noise"``. Defaults to ``"constant"``. fill_value (tuple(float), optional): The values to be filled in the erasing when ``fill_mode="constant"``. Defaults to ``(125, 123, 114)`` which is the means of ImageNet. seed (int|float, optional): The random seed. Defaults to ``None``. References: - `Random Erasing <https://arxiv.org/abs/1708.04896v2>`_ - `Random Erasing Official Repo <https://github.com/zhunzhong07/Random-Erasing>`_ - `Tensorflow Model augment <https://github.com/tensorflow/models/blob/v2.12.0/official/vision/ops/augment.py>`_ - `KerasCV <https://github.com/keras-team/keras-cv>`_ """ # noqa: E501 def __init__( self, area_factor=(0.02, 0.4), aspect_ratio_factor=(0.3, 1.0 / 0.3), fill_mode="constant", fill_value=(125, 123, 114), seed=None, **kwargs, ): super().__init__(seed=seed, **kwargs) if isinstance(area_factor, (int, float)): area_factor = (0, area_factor) self.area_factor = augmentation_utils.parse_factor( area_factor, seed=seed ) if isinstance(aspect_ratio_factor, (int, float)): aspect_ratio_factor = ( 1.0 - aspect_ratio_factor, 1.0 + aspect_ratio_factor, ) self.aspect_ratio_factor = augmentation_utils.parse_factor( aspect_ratio_factor, max_value=None, seed=seed ) if fill_mode not in ["gaussian_noise", "constant"]: raise ValueError( '`fill_mode` should be "gaussian_noise" ' f'or "constant". Got `fill_mode`={fill_mode}' ) self.fill_mode = fill_mode self.fill_value = tf.convert_to_tensor( fill_value, dtype=self.compute_dtype ) self.seed = seed self.max_attemp = 10 def get_random_transformation_batch( self, batch_size, images=None, **kwargs ): heights, widths = augmentation_utils.get_images_shape( images, dtype=self.compute_dtype ) areas = heights * widths is_success = False for _ in range(self.max_attemp): if not is_success: erasing_areas = self.area_factor( shape=(batch_size, 1), dtype=self.compute_dtype ) erasing_areas = erasing_areas * areas erasing_aspect_ratios = self.aspect_ratio_factor( shape=(batch_size, 1), dtype=self.compute_dtype ) erasing_heights = tf.round( tf.sqrt(erasing_areas * erasing_aspect_ratios) ) erasing_widths = tf.round( tf.sqrt(erasing_areas / erasing_aspect_ratios) ) if tf.reduce_all( tf.less(erasing_widths, widths) ) and tf.reduce_all(tf.less(erasing_heights, heights)): is_success = True center_xs = self._random_generator.random_uniform( shape=(batch_size, 1), minval=0, maxval=1, dtype=self.compute_dtype ) center_ys = self._random_generator.random_uniform( shape=(batch_size, 1), minval=0, maxval=1, dtype=self.compute_dtype ) center_xs = tf.round(center_xs * (widths - erasing_widths)) center_xs = tf.cast(center_xs + erasing_widths / 2, dtype=tf.int32) center_ys = tf.round(center_ys * (heights - erasing_heights)) center_ys = tf.cast(center_ys + erasing_heights / 2, dtype=tf.int32) return { "center_xs": center_xs, "center_ys": center_ys, "erasing_heights": tf.cast(erasing_heights, dtype=tf.int32), "erasing_widths": tf.cast(erasing_widths, dtype=tf.int32), } def augment_ragged_image(self, image, transformation, **kwargs): image = tf.expand_dims(image, axis=0) transformation = augmentation_utils.expand_dict_dims( transformation, axis=0 ) image = self.augment_images( images=image, transformations=transformation, **kwargs ) return tf.squeeze(image, axis=0) def augment_images(self, images, transformations, **kwargs): images = tf.cast(images, dtype=self.compute_dtype) center_xs = transformations["center_xs"] center_ys = transformations["center_ys"] erasing_heights = transformations["erasing_heights"] erasing_widths = transformations["erasing_widths"] rectangle_fills = self.compute_rectangle_fills(images) images = fill_utils.fill_rectangle( images, center_xs[..., 0], center_ys[..., 0], erasing_widths[..., 0], erasing_heights[..., 0], rectangle_fills, ) return images def augment_labels(self, labels, transformations, **kwargs): return labels def compute_rectangle_fills(self, inputs): if self.fill_mode == "constant": fills = tf.ones(tf.shape(inputs), dtype=self.compute_dtype) fills = ( fills * self.fill_value[tf.newaxis, tf.newaxis, tf.newaxis, ...] ) else: # gaussian noise fills = tf.random.normal(tf.shape(inputs), dtype=self.compute_dtype) return fills def get_config(self): config = super().get_config() config.update( { "area_factor": self.area_factor, "aspect_ratio_factor": self.aspect_ratio_factor, "fill_mode": self.fill_mode, "fill_value": self.fill_value, "seed": self.seed, } ) return config