# Source code for keras_aug.layers.preprocessing.geometry.pad_if_needed

import tensorflow as tf
from keras_cv import bounding_box
from tensorflow import keras

from keras_aug.layers.base.vectorized_base_random_layer import (
    VectorizedBaseRandomLayer,
)
from keras_aug.utils import augmentation as augmentation_utils


@keras.utils.register_keras_serializable(package="keras_aug")
class PadIfNeeded(VectorizedBaseRandomLayer):
    """Pads the images if needed.

    PadIfNeeded can be configured per axis, either with a minimum size or
    with a divisor. Along the height axis the images are padded to at least
    ``min_height`` or, when ``min_height`` is not given, up to the next
    multiple of ``pad_height_divisor``. The width axis works the same way
    with ``min_width`` / ``pad_width_divisor``.

    For each axis at least one of the two options must be specified. When
    both are specified for the same axis, the minimum size takes precedence
    and the divisor is ignored.

    Args:
        min_height (int, optional): The minimum height of the result image.
        min_width (int, optional): The minimum width of the result image.
        pad_height_divisor (int, optional): The divisor that the height of
            the result image must be divisible by. Only used when
            ``min_height`` is ``None``.
        pad_width_divisor (int, optional): The divisor that the width of
            the result image must be divisible by. Only used when
            ``min_width`` is ``None``.
        position (str, optional): The padding method. Supported values:
            ``"center", "top_left", "top_right", "bottom_left",
            "bottom_right", "random"``. Defaults to ``"center"``.
        padding_value (int|float, optional): The padding value used for
            images. Segmentation masks are always padded with ``0``.
            Defaults to ``0``.
        bounding_box_format (str, optional): The format of bounding
            boxes of input dataset. Refer
            https://github.com/keras-team/keras-cv/blob/master/keras_cv/bounding_box/converters.py
            for more details on supported bounding box formats.
        seed (int|float, optional): The random seed. Defaults to ``None``.

    Raises:
        ValueError: If neither ``min_height`` nor ``pad_height_divisor`` is
            set, or neither ``min_width`` nor ``pad_width_divisor`` is set.

    References:
        - `Albumentations <https://github.com/albumentations-team/albumentations>`_
    """  # noqa: E501

    def __init__(
        self,
        min_height=None,
        min_width=None,
        pad_height_divisor=None,
        pad_width_divisor=None,
        position="center",
        padding_value=0,
        bounding_box_format=None,
        seed=None,
        **kwargs,
    ):
        super().__init__(seed=seed, **kwargs)
        # Each axis is handled independently in
        # `get_random_transformation_batch`, so each axis only needs one of
        # its two options (minimum size or divisor) to be configured.
        height_unset = min_height is None and pad_height_divisor is None
        width_unset = min_width is None and pad_width_divisor is None
        if height_unset or width_unset:
            raise ValueError(
                "`PadIfNeeded()` expects `min_height` or "
                "`pad_height_divisor` and `min_width` or "
                "`pad_width_divisor` to be set."
            )
        self.min_height = min_height
        self.min_width = min_width
        self.pad_height_divisor = pad_height_divisor
        self.pad_width_divisor = pad_width_divisor
        self.position = augmentation_utils.get_padding_position(position)
        self.padding_value = padding_value
        self.bounding_box_format = bounding_box_format
        self.seed = seed

    @staticmethod
    def _compute_axis_paddings(sizes, min_size, divisor):
        """Computes the (before, after) padding amounts along one axis.

        Args:
            sizes: Tensor with the current size of every image along the
                axis.
            min_size: The minimum size along the axis or ``None``.
            divisor: The divisor for the size along the axis. Only used when
                ``min_size`` is ``None``.

        Returns:
            A tuple ``(before, after)`` of tensors holding the padding to
            add on the leading (top/left) and trailing (bottom/right) side.
        """
        if min_size is not None:
            needs_padding = sizes < min_size
            zeros = tf.zeros_like(sizes, dtype=sizes.dtype)
            # The leading side gets half of the deficit (cast back to the
            # dtype of `sizes`); the trailing side gets the remainder so
            # both always sum up to the full deficit.
            before = tf.where(
                needs_padding,
                tf.cast((min_size - sizes) / 2, sizes.dtype),
                zeros,
            )
            after = tf.where(needs_padding, min_size - sizes - before, zeros)
            return before, after

        remainders = sizes % divisor
        # Only sizes that are not already a multiple of `divisor` are padded.
        pad_total = tf.where(
            tf.math.greater(remainders, 0),
            divisor - remainders,
            tf.zeros_like(remainders, dtype=remainders.dtype),
        )
        pad_total = tf.cast(pad_total, dtype=tf.float32)
        before = tf.round(pad_total / 2.0)
        after = tf.cast(pad_total - before, dtype=tf.int32)
        before = tf.cast(before, dtype=tf.int32)
        return before, after

    def _pad_dense_batch(self, tensors, transformations, constant_values):
        """Pads the two middle axes of a rank-4 batch with a constant.

        The padding amounts are read from the first sample only
        (``[0][0]``) and applied to the whole batch.

        NOTE(review): bounding boxes are shifted with per-sample offsets,
        so if `get_position_params` can return different offsets per sample
        (e.g. ``position="random"``) the two may disagree -- confirm.
        """
        pad_top = transformations["pad_tops"][0][0]
        pad_bottom = transformations["pad_bottoms"][0][0]
        pad_left = transformations["pad_lefts"][0][0]
        pad_right = transformations["pad_rights"][0][0]
        # First and last axes are left untouched.
        no_padding = tf.zeros(shape=(2,), dtype=pad_top.dtype)
        paddings = tf.stack(
            (
                no_padding,
                tf.stack((pad_top, pad_bottom)),
                tf.stack((pad_left, pad_right)),
                no_padding,
            )
        )
        return tf.pad(
            tensors, paddings=paddings, constant_values=constant_values
        )

    def get_random_transformation_batch(
        self, batch_size, images=None, **kwargs
    ):
        """Computes the padding amounts for every image of the batch.

        Returns:
            A dict with the keys ``"pad_tops"``, ``"pad_bottoms"``,
            ``"pad_lefts"`` and ``"pad_rights"``.
        """
        heights, widths = augmentation_utils.get_images_shape(images)
        tops, bottoms = self._compute_axis_paddings(
            heights, self.min_height, self.pad_height_divisor
        )
        lefts, rights = self._compute_axis_paddings(
            widths, self.min_width, self.pad_width_divisor
        )
        # Redistribute the centered padding according to `self.position`.
        (tops, bottoms, lefts, rights) = augmentation_utils.get_position_params(
            tops, bottoms, lefts, rights, self.position, self._random_generator
        )
        return {
            "pad_tops": tops,
            "pad_bottoms": bottoms,
            "pad_lefts": lefts,
            "pad_rights": rights,
        }

    def augment_ragged_image(self, image, transformation, **kwargs):
        """Pads a single image by routing it through `augment_images`."""
        # Add a leading batch axis to both the image and its transformation
        # so the batched implementation can be reused.
        image = tf.expand_dims(image, axis=0)
        transformation = augmentation_utils.expand_dict_dims(
            transformation, axis=0
        )
        image = self.augment_images(
            images=image, transformations=transformation, **kwargs
        )
        return tf.squeeze(image, axis=0)

    def augment_images(self, images, transformations, **kwargs):
        """Pads a batch of images with `self.padding_value`."""
        return self._pad_dense_batch(
            images, transformations, constant_values=self.padding_value
        )

    def augment_labels(self, labels, transformations, **kwargs):
        """Returns the labels unchanged; padding does not affect them."""
        return labels

    def augment_bounding_boxes(
        self,
        bounding_boxes,
        transformations,
        images=None,
        raw_images=None,
        **kwargs,
    ):
        """Shifts the bounding boxes by the top/left padding.

        Raises:
            ValueError: If `bounding_box_format` was not specified in the
                constructor.
        """
        if self.bounding_box_format is None:
            raise ValueError(
                "`PadIfNeeded()` was called with bounding boxes, "
                "but no `bounding_box_format` was specified in the "
                "constructor. Please specify a bounding box format in the "
                "constructor. i.e. "
                "`PadIfNeeded(..., bounding_box_format='xyxy')`"
            )
        bounding_boxes = bounding_box.to_dense(bounding_boxes)
        bounding_boxes = bounding_box.convert_format(
            bounding_boxes,
            source=self.bounding_box_format,
            target="xyxy",
            images=raw_images,
        )

        pad_tops = tf.cast(transformations["pad_tops"], dtype=tf.float32)
        pad_lefts = tf.cast(transformations["pad_lefts"], dtype=tf.float32)
        # Extra axis so the per-sample offsets broadcast over the boxes.
        x_offsets = tf.expand_dims(pad_lefts, axis=1)
        y_offsets = tf.expand_dims(pad_tops, axis=1)
        x1s, y1s, x2s, y2s = tf.split(bounding_boxes["boxes"], 4, axis=-1)
        outputs = tf.concat(
            [
                x1s + x_offsets,
                y1s + y_offsets,
                x2s + x_offsets,
                y2s + y_offsets,
            ],
            axis=-1,
        )

        # Copy so the caller's dict is not mutated.
        bounding_boxes = bounding_boxes.copy()
        bounding_boxes["boxes"] = outputs
        bounding_boxes = bounding_box.convert_format(
            bounding_boxes,
            source="xyxy",
            target=self.bounding_box_format,
            dtype=self.compute_dtype,
            images=images,
        )
        return bounding_boxes

    def augment_ragged_segmentation_mask(
        self, segmentation_mask, transformation, **kwargs
    ):
        """Pads a single mask via `augment_segmentation_masks`."""
        segmentation_mask = tf.expand_dims(segmentation_mask, axis=0)
        transformation = augmentation_utils.expand_dict_dims(
            transformation, axis=0
        )
        segmentation_mask = self.augment_segmentation_masks(
            segmentation_masks=segmentation_mask,
            transformations=transformation,
            **kwargs,
        )
        return tf.squeeze(segmentation_mask, axis=0)

    def augment_segmentation_masks(
        self, segmentation_masks, transformations, **kwargs
    ):
        """Pads a batch of segmentation masks with ``0``."""
        return self._pad_dense_batch(
            segmentation_masks, transformations, constant_values=0
        )

    def augment_segmentation_mask_single(self, inputs):
        """Pads one unbatched segmentation mask with ``0``.

        Args:
            inputs: A dict holding the mask under
                ``augmentation_utils.SEGMENTATION_MASKS`` and the padding
                amounts under ``"transformation"``.
        """
        segmentation_mask = inputs.get(
            augmentation_utils.SEGMENTATION_MASKS, None
        )
        transformation = inputs.get("transformation", None)
        pad_top = transformation["pad_tops"][0]
        pad_bottom = transformation["pad_bottoms"][0]
        pad_left = transformation["pad_lefts"][0]
        pad_right = transformation["pad_rights"][0]
        # Three axes here (no batch axis); the last one is left untouched.
        paddings = tf.stack(
            (
                tf.stack((pad_top, pad_bottom)),
                tf.stack((pad_left, pad_right)),
                tf.zeros(shape=(2,), dtype=pad_top.dtype),
            )
        )
        segmentation_mask = tf.pad(
            segmentation_mask, paddings=paddings, constant_values=0
        )
        return segmentation_mask

    def get_config(self):
        """Returns the constructor arguments of the layer as a dict."""
        config = super().get_config()
        config.update(
            {
                "min_height": self.min_height,
                "min_width": self.min_width,
                "pad_height_divisor": self.pad_height_divisor,
                "pad_width_divisor": self.pad_width_divisor,
                # NOTE(review): this is the value returned by
                # `get_padding_position`, not necessarily the original
                # constructor argument -- confirm it round-trips.
                "position": self.position,
                "padding_value": self.padding_value,
                "bounding_box_format": self.bounding_box_format,
                "seed": self.seed,
            }
        )
        return config