Source code for keras_aug.layers.preprocessing.geometry.center_crop

import tensorflow as tf
from tensorflow import keras

from keras_aug.datapoints import bounding_box
from keras_aug.layers.base.vectorized_base_random_layer import (
    VectorizedBaseRandomLayer,
)
from keras_aug.utils import augmentation as augmentation_utils


@keras.utils.register_keras_serializable(package="keras_aug")
class CenterCrop(VectorizedBaseRandomLayer):
    """Center crops the images.

    CenterCrop crops the central portion of the images to a specified
    ``(height, width)``. If an image is smaller than the target size, it will
    be padded and then cropped.

    Args:
        height (int): The height of result image.
        width (int): The width of result image.
        padding_value (int|float, optional): The padding value.
            Defaults to ``0``.
        bounding_box_format (str, optional): The format of bounding
            boxes of input dataset. Refer
            https://github.com/james77777778/keras-aug/blob/main/keras_aug/datapoints/bounding_box/converter.py
            for more details on supported bounding box formats.
        bounding_box_min_area_ratio (float, optional): The threshold to
            apply sanitize_bounding_boxes. Defaults to ``None``.
        bounding_box_max_aspect_ratio (float, optional): The threshold to
            apply sanitize_bounding_boxes. Defaults to ``None``.
        seed (int|float, optional): The random seed. Defaults to ``None``.
    """  # noqa: E501

    def __init__(
        self,
        height,
        width,
        padding_value=0,
        bounding_box_format=None,
        bounding_box_min_area_ratio=None,
        bounding_box_max_aspect_ratio=None,
        seed=None,
        **kwargs,
    ):
        super().__init__(seed=seed, **kwargs)
        self.height = height
        self.width = width
        # Padding is always distributed around the center for this layer.
        self.position = augmentation_utils.get_padding_position("center")
        self.padding_value = padding_value
        self.bounding_box_format = bounding_box_format
        self.bounding_box_min_area_ratio = bounding_box_min_area_ratio
        self.bounding_box_max_aspect_ratio = bounding_box_max_aspect_ratio
        self.seed = seed

        # set force_output_dense_images=True because the output images must
        # have same shape (B, height, width, C)
        self.force_output_dense_images = True

    def get_random_transformation_batch(
        self, batch_size, images=None, **kwargs
    ):
        """Compute the per-image paddings needed to reach the target size.

        For every image that is smaller than ``(self.height, self.width)``
        along an axis, the missing amount is split between the two sides of
        that axis; images that are already large enough get zero padding on
        that axis.

        Args:
            batch_size: Unused here; part of the base-layer interface.
            images: The batch of images whose heights/widths are inspected.
            **kwargs: Ignored.

        Returns:
            dict: ``"pad_tops"``, ``"pad_bottoms"``, ``"pad_lefts"`` and
            ``"pad_rights"`` as returned by
            ``augmentation_utils.get_position_params``.
        """
        heights, widths = augmentation_utils.get_images_shape(images)
        no_height_pad = tf.zeros_like(heights, dtype=heights.dtype)
        no_width_pad = tf.zeros_like(widths, dtype=widths.dtype)

        # Half of the missing rows go on top; the remainder (which absorbs an
        # odd leftover) goes on the bottom.
        tops = tf.where(
            heights < self.height,
            tf.cast((self.height - heights) / 2, heights.dtype),
            no_height_pad,
        )
        bottoms = tf.where(
            heights < self.height,
            self.height - heights - tops,
            no_height_pad,
        )
        # Same split for the missing columns: left first, remainder right.
        lefts = tf.where(
            widths < self.width,
            tf.cast((self.width - widths) / 2, widths.dtype),
            no_width_pad,
        )
        rights = tf.where(
            widths < self.width,
            self.width - widths - lefts,
            no_width_pad,
        )
        (tops, bottoms, lefts, rights) = augmentation_utils.get_position_params(
            tops, bottoms, lefts, rights, self.position, self._random_generator
        )
        return {
            "pad_tops": tops,
            "pad_bottoms": bottoms,
            "pad_lefts": lefts,
            "pad_rights": rights,
        }

    def _pad_then_center_crop(self, inputs, transformations, constant_values):
        """Pad a dense batch, then crop its center to ``(height, width)``.

        Shared implementation for ``augment_images`` and
        ``augment_segmentation_masks``; the two only differ in the value used
        to fill the padded area.

        Args:
            inputs: A dense 4D batch (images or segmentation masks).
            transformations: The dict produced by
                ``get_random_transformation_batch``.
            constant_values: The fill value passed to ``tf.pad``.

        Returns:
            The batch with static spatial shape ``(self.height, self.width)``,
            cast to ``self.compute_dtype``.
        """
        ori_height = tf.shape(inputs)[augmentation_utils.H_AXIS]
        ori_width = tf.shape(inputs)[augmentation_utils.W_AXIS]
        # A dense batch has a single height/width, so the paddings of the
        # first sample are applied to the whole batch.
        pad_top = transformations["pad_tops"][0][0]
        pad_bottom = transformations["pad_bottoms"][0][0]
        pad_left = transformations["pad_lefts"][0][0]
        pad_right = transformations["pad_rights"][0][0]

        # One (before, after) pair per axis; the first and last axes are
        # never padded.
        no_padding = tf.zeros(shape=(2,), dtype=pad_top.dtype)
        paddings = tf.stack(
            (
                no_padding,
                tf.stack((pad_top, pad_bottom)),
                tf.stack((pad_left, pad_right)),
                no_padding,
            )
        )
        inputs = tf.pad(
            inputs, paddings=paddings, constant_values=constant_values
        )

        # center crop
        offset_height = (ori_height + pad_top + pad_bottom - self.height) // 2
        offset_width = (ori_width + pad_left + pad_right - self.width) // 2
        inputs = tf.image.crop_to_bounding_box(
            inputs,
            offset_height,
            offset_width,
            self.height,
            self.width,
        )
        # Make the (now fixed) spatial size visible to static shape inference.
        inputs = tf.ensure_shape(
            inputs, shape=(None, self.height, self.width, None)
        )
        return tf.cast(inputs, dtype=self.compute_dtype)

    def compute_ragged_image_signature(self, images):
        """Return the spec of a single augmented image for ragged inputs."""
        return tf.RaggedTensorSpec(
            shape=(self.height, self.width, images.shape[-1]),
            ragged_rank=1,
            dtype=self.compute_dtype,
        )

    def augment_ragged_image(self, image, transformation, **kwargs):
        """Augment one image by routing it through the batched code path.

        The image and its transformation get a leading axis of size one, are
        processed by ``augment_images`` and the extra axis is removed again.
        """
        image = tf.expand_dims(image, axis=0)
        transformation = augmentation_utils.expand_dict_dims(
            transformation, axis=0
        )
        image = self.augment_images(
            images=image, transformations=transformation, **kwargs
        )
        return tf.squeeze(image, axis=0)

    def augment_images(self, images, transformations, **kwargs):
        """Pad (with ``self.padding_value``) and center crop a dense batch.

        Args:
            images: A dense 4D batch of images.
            transformations: The dict produced by
                ``get_random_transformation_batch``.
            **kwargs: Ignored.

        Returns:
            The images with spatial size ``(self.height, self.width)``, cast
            to ``self.compute_dtype``.
        """
        return self._pad_then_center_crop(
            images, transformations, constant_values=self.padding_value
        )

    def augment_labels(self, labels, transformations, **kwargs):
        """Return the labels unchanged; cropping does not affect them."""
        return labels

    def augment_bounding_boxes(
        self,
        bounding_boxes,
        transformations,
        images=None,
        raw_images=None,
        **kwargs,
    ):
        """Shift, clip and sanitize bounding boxes to follow the crop.

        Args:
            bounding_boxes: The bounding boxes dict (with a ``"boxes"`` entry)
                in ``self.bounding_box_format``.
            transformations: The dict produced by
                ``get_random_transformation_batch``.
            images: The augmented images, used as the reference for clipping
                and for converting back to the configured format.
            raw_images: The original images, used as the reference for the
                incoming boxes.
            **kwargs: Ignored.

        Returns:
            The updated bounding boxes in ``self.bounding_box_format``.

        Raises:
            ValueError: If no ``bounding_box_format`` was given to the
                constructor.
        """
        if self.bounding_box_format is None:
            raise ValueError(
                "`CenterCrop()` was called with bounding boxes, "
                "but no `bounding_box_format` was specified in the "
                "constructor. Please specify a bounding box format in the "
                "constructor. i.e. "
                "`CenterCrop(..., bounding_box_format='xyxy')`"
            )
        bounding_boxes = bounding_box.to_dense(bounding_boxes)
        bounding_boxes = bounding_box.convert_format(
            bounding_boxes,
            source=self.bounding_box_format,
            target="xyxy",
            images=raw_images,
        )
        # Keep the un-shifted boxes as the reference for sanitizing.
        original_bounding_boxes = bounding_boxes.copy()
        x1s, y1s, x2s, y2s = tf.split(bounding_boxes["boxes"], 4, axis=-1)

        pad_tops = tf.cast(transformations["pad_tops"], dtype=tf.float32)
        pad_lefts = tf.cast(transformations["pad_lefts"], dtype=tf.float32)
        pad_bottoms = tf.cast(transformations["pad_bottoms"], dtype=tf.float32)
        pad_rights = tf.cast(transformations["pad_rights"], dtype=tf.float32)
        heights, widths = augmentation_utils.get_images_shape(
            raw_images, dtype=tf.float32
        )
        # Same crop offsets as used for the images, computed per sample.
        offset_heights = (heights + pad_tops + pad_bottoms - self.height) // 2
        offset_widths = (widths + pad_lefts + pad_rights - self.width) // 2

        # Padding moves boxes right/down, cropping moves them left/up. The
        # extra axis lets the per-sample shift broadcast over the boxes.
        x_shifts = tf.expand_dims(pad_lefts - offset_widths, axis=1)
        y_shifts = tf.expand_dims(pad_tops - offset_heights, axis=1)
        outputs = tf.concat(
            [x1s + x_shifts, y1s + y_shifts, x2s + x_shifts, y2s + y_shifts],
            axis=-1,
        )

        # Work on a copy so that assigning "boxes" does not touch the dict
        # the boxes were read from.
        bounding_boxes = bounding_boxes.copy()
        bounding_boxes["boxes"] = outputs
        bounding_boxes = bounding_box.clip_to_image(
            bounding_boxes,
            bounding_box_format="xyxy",
            images=images,
        )
        bounding_boxes = bounding_box.sanitize_bounding_boxes(
            bounding_boxes,
            min_area_ratio=self.bounding_box_min_area_ratio,
            max_aspect_ratio=self.bounding_box_max_aspect_ratio,
            bounding_box_format="xyxy",
            reference_bounding_boxes=original_bounding_boxes,
            images=images,
            reference_images=raw_images,
        )
        bounding_boxes = bounding_box.convert_format(
            bounding_boxes,
            source="xyxy",
            target=self.bounding_box_format,
            dtype=self.compute_dtype,
            images=images,
        )
        return bounding_boxes

    def compute_ragged_segmentation_mask_signature(self, segmentation_masks):
        """Return the spec of a single augmented mask for ragged inputs."""
        return tf.RaggedTensorSpec(
            shape=(self.height, self.width, segmentation_masks.shape[-1]),
            ragged_rank=1,
            dtype=self.compute_dtype,
        )

    def augment_ragged_segmentation_mask(
        self, segmentation_mask, transformation, **kwargs
    ):
        """Augment one mask by routing it through the batched code path.

        The mask and its transformation get a leading axis of size one, are
        processed by ``augment_segmentation_masks`` and the extra axis is
        removed again.
        """
        segmentation_mask = tf.expand_dims(segmentation_mask, axis=0)
        transformation = augmentation_utils.expand_dict_dims(
            transformation, axis=0
        )
        segmentation_mask = self.augment_segmentation_masks(
            segmentation_masks=segmentation_mask,
            transformations=transformation,
            **kwargs,
        )
        return tf.squeeze(segmentation_mask, axis=0)

    def augment_segmentation_masks(
        self, segmentation_masks, transformations, **kwargs
    ):
        """Pad (with ``0``) and center crop a dense batch of masks.

        Args:
            segmentation_masks: A dense 4D batch of segmentation masks.
            transformations: The dict produced by
                ``get_random_transformation_batch``.
            **kwargs: Ignored.

        Returns:
            The masks with spatial size ``(self.height, self.width)``, cast
            to ``self.compute_dtype``.
        """
        # Masks are always padded with 0, independent of ``padding_value``.
        return self._pad_then_center_crop(
            segmentation_masks, transformations, constant_values=0
        )

    def get_config(self):
        """Return the layer config, including the constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "height": self.height,
                "width": self.width,
                "padding_value": self.padding_value,
                "bounding_box_format": self.bounding_box_format,
                "bounding_box_min_area_ratio": self.bounding_box_min_area_ratio,  # noqa: E501
                "bounding_box_max_aspect_ratio": self.bounding_box_max_aspect_ratio,  # noqa: E501
                "seed": self.seed,
            }
        )
        return config