Source code for keras_aug.layers.augmentation.geometry.random_affine

import math

import tensorflow as tf
from keras_cv import bounding_box
from keras_cv.utils import preprocessing as preprocessing_utils
from tensorflow import keras

from keras_aug.layers.base.vectorized_base_random_layer import (
    VectorizedBaseRandomLayer,
)
from keras_aug.utils import augmentation as augmentation_utils
from keras_aug.utils import bounding_box as bounding_box_utils


@keras.utils.register_keras_serializable(package="keras_aug")
class RandomAffine(VectorizedBaseRandomLayer):
    """Apply a random affine transformation to images, keeping the center
    invariant.

    The transformation combines rotation, translation, zoom and shear.
    RandomAffine processes the images with a single combined transformation
    matrix, so it is fast.

    Args:
        rotation_factor (float|Sequence[float]|keras_aug.FactorSampler): The
            range of the degree for random rotation. When represented as a
            single float, the factor will be picked between
            ``[0.0 - lower, 0.0 + upper]``. A positive value means rotating
            counter clock-wise, while a negative value means clock-wise.
            Defaults to ``None``.
        translation_height_factor (float|Sequence[float]|keras_aug.FactorSampler):
            The range for random vertical translation. When represented as a
            single float, the factor will be picked between
            ``[0.0 - lower, 0.0 + upper]``. A negative value means shifting
            image up, while a positive value means shifting image down.
            Defaults to ``None``.
        translation_width_factor (float|Sequence[float]|keras_aug.FactorSampler):
            The range for random horizontal translation. When represented as
            a single float, the factor will be picked between
            ``[0.0 - lower, 0.0 + upper]``. A negative value means shifting
            image left, while a positive value means shifting image right.
            Defaults to ``None``.
        zoom_height_factor (float|Sequence[float]|keras_aug.FactorSampler):
            The range for random vertical zoom. When represented as a single
            float, the factor will be picked between
            ``[1.0 - lower, 1.0 + upper]``. A negative value means zooming in
            while a positive value means zooming out. Defaults to ``None``.
        zoom_width_factor (float|Sequence[float]|keras_aug.FactorSampler):
            The range for random horizontal zoom. When represented as a
            single float, the factor will be picked between
            ``[1.0 - lower, 1.0 + upper]``. A negative value means zooming in
            while a positive value means zooming out. Defaults to ``None``.
        shear_height_factor (float|Sequence[float]|keras_aug.FactorSampler):
            The range for random vertical shear. When represented as a single
            float, the factor will be picked between
            ``[0.0 - lower, 0.0 + upper]``. Defaults to ``None``.
        shear_width_factor (float|Sequence[float]|keras_aug.FactorSampler):
            The range for random horizontal shear. When represented as a
            single float, the factor will be picked between
            ``[0.0 - lower, 0.0 + upper]``. Defaults to ``None``.
        same_zoom_factor (bool, optional): If True, the zoom factor sampled
            from ``zoom_height_factor`` will be applied to both height and
            width. It is useful to keep aspect ratio. Defaults to ``False``.
        interpolation (str, optional): The interpolation mode. Supported
            values: ``"nearest", "bilinear"``. Defaults to ``"bilinear"``.
        fill_mode (str, optional): The fill mode. Supported values:
            ``"constant", "reflect", "wrap", "nearest"``. Defaults to
            ``"constant"``.
        fill_value (int|float, optional): The value to be filled outside the
            boundaries when ``fill_mode="constant"``. Defaults to ``0``.
        bounding_box_format (str, optional): The format of bounding boxes of
            input dataset. Refer
            https://github.com/keras-team/keras-cv/blob/master/keras_cv/bounding_box/converters.py
            for more details on supported bounding box formats.
        bounding_box_min_area_ratio (float, optional): The threshold to
            apply sanitize_bounding_boxes. Defaults to ``None``.
        bounding_box_max_aspect_ratio (float, optional): The threshold to
            apply sanitize_bounding_boxes. Defaults to ``None``.
        seed (int|float, optional): The random seed. Defaults to ``None``.

    References:
        - `KerasCV <https://github.com/keras-team/keras-cv>`_
    """  # noqa: E501

    def __init__(
        self,
        rotation_factor=None,
        translation_height_factor=None,
        translation_width_factor=None,
        zoom_height_factor=None,
        zoom_width_factor=None,
        shear_height_factor=None,
        shear_width_factor=None,
        same_zoom_factor=False,
        interpolation="bilinear",
        fill_mode="constant",
        fill_value=0,
        bounding_box_format=None,
        bounding_box_min_area_ratio=None,
        bounding_box_max_aspect_ratio=None,
        seed=None,
        **kwargs,
    ):
        super().__init__(seed=seed, **kwargs)
        # Every factor stays ``None`` unless the caller configured it.
        self.rotation_factor = None
        self.translation_height_factor = None
        self.translation_width_factor = None
        self.zoom_height_factor = None
        self.zoom_width_factor = None
        self.shear_height_factor = None
        self.shear_width_factor = None

        # rotation
        if rotation_factor is not None:
            self.rotation_factor = augmentation_utils.parse_factor(
                rotation_factor,
                min_value=-180,
                max_value=180,
                center_value=0,
                seed=seed,
            )
        # translation
        if translation_height_factor is not None:
            self.translation_height_factor = augmentation_utils.parse_factor(
                translation_height_factor,
                min_value=-1,
                max_value=1,
                center_value=0.0,
                seed=seed,
            )
        if translation_width_factor is not None:
            self.translation_width_factor = augmentation_utils.parse_factor(
                translation_width_factor,
                min_value=-1,
                max_value=1,
                center_value=0.0,
                seed=seed,
            )
        # zoom
        if zoom_height_factor is not None:
            self.zoom_height_factor = augmentation_utils.parse_factor(
                zoom_height_factor,
                min_value=0,
                max_value=None,
                center_value=1.0,
                seed=seed,
            )
        if zoom_width_factor is not None:
            self.zoom_width_factor = augmentation_utils.parse_factor(
                zoom_width_factor,
                min_value=0,
                max_value=None,
                center_value=1.0,
                seed=seed,
            )
        # shear
        if shear_height_factor is not None:
            self.shear_height_factor = augmentation_utils.parse_factor(
                shear_height_factor,
                min_value=-1,
                max_value=1,
                center_value=0.0,
                seed=seed,
            )
        if shear_width_factor is not None:
            self.shear_width_factor = augmentation_utils.parse_factor(
                shear_width_factor,
                min_value=-1,
                max_value=1,
                center_value=0.0,
                seed=seed,
            )

        self.same_zoom_factor = same_zoom_factor
        preprocessing_utils.check_fill_mode_and_interpolation(
            fill_mode, interpolation
        )
        self.interpolation = interpolation
        self.fill_mode = fill_mode
        self.fill_value = fill_value
        self.bounding_box_format = bounding_box_format
        self.bounding_box_min_area_ratio = bounding_box_min_area_ratio
        self.bounding_box_max_aspect_ratio = bounding_box_max_aspect_ratio
        self.seed = seed

        # decide whether to enable the augmentation
        self._enable_rotation = augmentation_utils.is_factor_working(
            self.rotation_factor, not_working_value=0.0
        )
        _enable_translation_height = augmentation_utils.is_factor_working(
            self.translation_height_factor, not_working_value=0.0
        )
        _enable_translation_width = augmentation_utils.is_factor_working(
            self.translation_width_factor, not_working_value=0.0
        )
        self._enable_translation = (
            _enable_translation_height or _enable_translation_width
        )
        # Zoom factors are parsed around ``center_value=1.0`` above, so the
        # no-op zoom is 1.0 rather than 0.0.
        # NOTE(review): relies on `is_factor_working` treating
        # `not_working_value` as the identity value -- confirm in
        # keras_aug.utils.augmentation.
        _enable_zoom_height = augmentation_utils.is_factor_working(
            self.zoom_height_factor, not_working_value=1.0
        )
        _enable_zoom_width = augmentation_utils.is_factor_working(
            self.zoom_width_factor, not_working_value=1.0
        )
        self._enable_zoom = _enable_zoom_height or _enable_zoom_width
        _enable_shear_height = augmentation_utils.is_factor_working(
            self.shear_height_factor, not_working_value=0.0
        )
        _enable_shear_width = augmentation_utils.is_factor_working(
            self.shear_width_factor, not_working_value=0.0
        )
        self._enable_shear = _enable_shear_height or _enable_shear_width

    def _sample_factor(self, factor, shape, default):
        """Sample ``factor`` with ``shape``, or return ``default`` if unset.

        The translation/zoom/shear switches are the OR of their height and
        width flags, so one axis may legitimately be ``None`` while the
        transformation as a whole is enabled.
        """
        if factor is None:
            return default
        return factor(shape, dtype=self.compute_dtype)

    def _flatten_matrixes(self, combined_matrixes, batch_size):
        """Flatten ``(batch_size, 3, 3)`` matrixes and drop the last entry.

        Returns a float32 tensor with one row per image holding the first
        eight matrix entries.
        """
        flat = tf.reshape(combined_matrixes, shape=(batch_size, -1))
        # presumably the 8-parameter layout `preprocessing_utils.transform`
        # expects -- verify against keras_cv.
        flat = flat[:, :-1]
        return tf.cast(flat, dtype=tf.float32)  # must be float32

    def get_random_transformation_batch(
        self, batch_size, images=None, **kwargs
    ):
        """Sample the per-image parameters and build the combined matrixes.

        Args:
            batch_size: Number of images in the batch.
            images: The batch of images; only used to read heights/widths.

        Returns:
            dict with keys ``"angles"`` (radians), ``"translations"``,
            ``"zooms"``, ``"shears"`` (each concatenated as
            ``[width, height]`` on axis 1) and ``"combined_matrixes"`` of
            shape ``(batch_size, 3, 3)``.
        """
        heights, widths = augmentation_utils.get_images_shape(
            images, dtype=self.compute_dtype
        )

        factor_shape = (batch_size, 1)
        # Identity placeholders, used for any factor that is not configured.
        angles = tf.zeros(factor_shape, dtype=self.compute_dtype)
        translation_heights = tf.zeros(factor_shape, dtype=self.compute_dtype)
        translation_widths = tf.zeros(factor_shape, dtype=self.compute_dtype)
        # The identity zoom is 1.0: a 0.0 placeholder on an unset axis would
        # be divided by in `augment_bounding_boxes`.
        zoom_heights = tf.ones(factor_shape, dtype=self.compute_dtype)
        zoom_widths = tf.ones(factor_shape, dtype=self.compute_dtype)
        shear_heights = tf.zeros(factor_shape, dtype=self.compute_dtype)
        shear_widths = tf.zeros(factor_shape, dtype=self.compute_dtype)

        if self._enable_rotation:
            angles = self.rotation_factor(
                factor_shape, dtype=self.compute_dtype
            )
        if self._enable_translation:
            translation_heights = self._sample_factor(
                self.translation_height_factor,
                factor_shape,
                translation_heights,
            )
            translation_widths = self._sample_factor(
                self.translation_width_factor,
                factor_shape,
                translation_widths,
            )
        if self._enable_zoom:
            zoom_heights = self._sample_factor(
                self.zoom_height_factor, factor_shape, zoom_heights
            )
            if self.same_zoom_factor:
                zoom_widths = zoom_heights
            else:
                zoom_widths = self._sample_factor(
                    self.zoom_width_factor, factor_shape, zoom_widths
                )
        if self._enable_shear:
            shear_heights = self._sample_factor(
                self.shear_height_factor, factor_shape, shear_heights
            )
            shear_widths = self._sample_factor(
                self.shear_width_factor, factor_shape, shear_widths
            )

        # degrees -> radians
        angles = angles / 360.0 * 2.0 * math.pi
        translations = tf.concat(
            [translation_widths, translation_heights], axis=1
        )
        zooms = tf.concat([zoom_widths, zoom_heights], axis=1)
        shears = tf.concat([shear_widths, shear_heights], axis=1)

        # start from identity matrixes:
        # [[1 0 0]
        #  [0 1 0]
        #  [0 0 1]]
        identity_matrixes = tf.concat(
            [
                tf.ones((batch_size, 1), dtype=self.compute_dtype),
                tf.zeros((batch_size, 3), dtype=self.compute_dtype),
                tf.ones((batch_size, 1), dtype=self.compute_dtype),
                tf.zeros((batch_size, 3), dtype=self.compute_dtype),
                tf.ones((batch_size, 1), dtype=self.compute_dtype),
            ],
            axis=1,
        )
        combined_matrixes = tf.reshape(identity_matrixes, (batch_size, 3, 3))
        # process zoom
        if self._enable_zoom:
            zoom_matrixes = augmentation_utils.get_zoom_matrix(
                zooms, heights, widths, to_square=True
            )
            combined_matrixes = zoom_matrixes @ combined_matrixes
        # process rotations
        if self._enable_rotation:
            rotation_matrixes = augmentation_utils.get_rotation_matrix(
                angles, heights, widths, to_square=True
            )
            combined_matrixes = rotation_matrixes @ combined_matrixes
        # process shear
        if self._enable_shear:
            shear_matrixes = augmentation_utils.get_shear_matrix(
                shears, to_square=True
            )
            combined_matrixes = shear_matrixes @ combined_matrixes
        # process translations
        if self._enable_translation:
            translation_matrixes = augmentation_utils.get_translation_matrix(
                translations, heights, widths, to_square=True
            )
            combined_matrixes = translation_matrixes @ combined_matrixes

        return {
            "angles": angles,
            "translations": translations,
            "zooms": zooms,
            "shears": shears,
            "combined_matrixes": combined_matrixes,  # (batch_size, 3, 3)
        }

    def augment_ragged_image(self, image, transformation, **kwargs):
        """Augment one image by routing it through `augment_images`."""
        image = tf.expand_dims(image, axis=0)
        transformation = augmentation_utils.expand_dict_dims(
            transformation, axis=0
        )
        image = self.augment_images(
            images=image, transformations=transformation, **kwargs
        )
        return tf.squeeze(image, axis=0)

    def augment_images(self, images, transformations, **kwargs):
        """Warp ``images`` with the sampled combined matrixes."""
        original_shape = images.shape
        batch_size = tf.shape(images)[0]
        images = preprocessing_utils.transform(
            images,
            self._flatten_matrixes(
                transformations["combined_matrixes"], batch_size
            ),
            fill_mode=self.fill_mode,
            fill_value=self.fill_value,
            interpolation=self.interpolation,
        )
        # Re-attach the static shape recorded before the transform.
        images = tf.ensure_shape(images, shape=original_shape)
        return images

    def augment_labels(self, labels, transformations, **kwargs):
        """Return ``labels`` unchanged."""
        return labels

    def augment_bounding_boxes(
        self, bounding_boxes, transformations, raw_images=None, **kwargs
    ):
        """Transform bounding boxes to follow the sampled transformation.

        Boxes are converted to ``"xyxy"``, updated for each enabled
        transformation, clipped to the image, sanitized, and converted back
        to ``self.bounding_box_format``.

        Raises:
            ValueError: If ``bounding_box_format`` was not given to the
                constructor.
        """
        if self.bounding_box_format is None:
            raise ValueError(
                "`RandomAffine()` was called with bounding boxes, "
                "but no `bounding_box_format` was specified in the "
                "constructor. "
                "Please specify a bounding box format in the constructor. "
                "i.e. `RandomAffine(bounding_box_format='xyxy')`"
            )
        heights, widths = augmentation_utils.get_images_shape(
            raw_images, dtype=self.compute_dtype
        )
        bounding_boxes = bounding_box.to_dense(bounding_boxes)
        bounding_boxes = bounding_box.convert_format(
            bounding_boxes,
            source=self.bounding_box_format,
            target="xyxy",
            images=raw_images,
            dtype=self.compute_dtype,
        )
        boxes = bounding_boxes["boxes"]
        # Kept as the reference for `sanitize_bounding_boxes` below.
        original_bounding_boxes = bounding_boxes.copy()

        # process rotations
        if self._enable_rotation:
            origin_x = widths / 2
            origin_y = heights / 2
            angles = -transformations["angles"]
            angles = angles[:, tf.newaxis, tf.newaxis]
            # points: (batch_size, max_num_boxes, 4, 2)
            points = tf.stack(
                [
                    tf.stack([boxes[..., 0], boxes[..., 1]], axis=2),
                    tf.stack([boxes[..., 2], boxes[..., 1]], axis=2),
                    tf.stack([boxes[..., 2], boxes[..., 3]], axis=2),
                    tf.stack([boxes[..., 0], boxes[..., 3]], axis=2),
                ],
                axis=2,
            )
            point_x_offsets = points[..., 0] - origin_x[..., tf.newaxis]
            point_y_offsets = points[..., 1] - origin_y[..., tf.newaxis]
            new_x = (
                origin_x[..., tf.newaxis, tf.newaxis]
                + tf.multiply(tf.cos(angles), point_x_offsets[..., tf.newaxis])
                - tf.multiply(tf.sin(angles), point_y_offsets[..., tf.newaxis])
            )
            new_y = (
                origin_y[..., tf.newaxis, tf.newaxis]
                + tf.multiply(tf.sin(angles), point_x_offsets[..., tf.newaxis])
                + tf.multiply(tf.cos(angles), point_y_offsets[..., tf.newaxis])
            )
            out = tf.concat([new_x, new_y], axis=3)
            # Axis-aligned box enclosing the four rotated corners.
            min_cordinates = tf.math.reduce_min(out, axis=2)
            max_cordinates = tf.math.reduce_max(out, axis=2)
            boxes = tf.concat([min_cordinates, max_cordinates], axis=2)

        # process translations
        if self._enable_translation:
            translations = transformations["translations"]
            # Translation factors are scaled by the image size here.
            translation_widths = tf.expand_dims(
                translations[:, 0:1] * widths, axis=-1
            )
            translation_heights = tf.expand_dims(
                translations[:, 1:2] * heights, axis=-1
            )
            x1s, y1s, x2s, y2s = tf.split(boxes, 4, axis=-1)
            x1s = x1s + translation_widths
            y1s = y1s + translation_heights
            x2s = x2s + translation_widths
            y2s = y2s + translation_heights
            boxes = tf.concat([x1s, y1s, x2s, y2s], axis=-1)

        # process shear
        if self._enable_shear:
            shears = transformations["shears"]
            shear_widths = tf.expand_dims(shears[:, 0:1], axis=-1)
            shear_heights = tf.expand_dims(shears[:, 1:2], axis=-1)
            _x1s, _y1s, _x2s, _y2s = tf.split(boxes, 4, axis=-1)
            # x1, x2
            x1_tops = _x1s - (shear_widths * _y1s)
            x1_bottoms = _x1s - (shear_widths * _y2s)
            x1s = tf.where(shear_widths < 0, x1_tops, x1_bottoms)
            x2_tops = _x2s - (shear_widths * _y1s)
            x2_bottoms = _x2s - (shear_widths * _y2s)
            x2s = tf.where(shear_widths < 0, x2_bottoms, x2_tops)
            # y1, y2
            y1_lefts = _y1s - (shear_heights * _x1s)
            y1_rights = _y1s - (shear_heights * _x2s)
            y1s = tf.where(shear_heights > 0, y1_rights, y1_lefts)
            y2_lefts = _y2s - (shear_heights * _x1s)
            y2_rights = _y2s - (shear_heights * _x2s)
            y2s = tf.where(shear_heights > 0, y2_lefts, y2_rights)
            boxes = tf.concat([x1s, y1s, x2s, y2s], axis=-1)

        # process zoom
        if self._enable_zoom:
            zooms = transformations["zooms"]
            zoom_widths = tf.expand_dims(zooms[:, 0:1], axis=-1)
            zoom_heights = tf.expand_dims(zooms[:, 1:2], axis=-1)
            x1s, y1s, x2s, y2s = tf.split(boxes, 4, axis=-1)
            x_offsets = ((tf.expand_dims(widths, axis=-1) - 1.0) / 2.0) * (
                1.0 - zoom_widths
            )
            y_offsets = ((tf.expand_dims(heights, axis=-1) - 1.0) / 2.0) * (
                1.0 - zoom_heights
            )
            x1s = (x1s - x_offsets) / zoom_widths
            x2s = (x2s - x_offsets) / zoom_widths
            y1s = (y1s - y_offsets) / zoom_heights
            y2s = (y2s - y_offsets) / zoom_heights
            boxes = tf.concat([x1s, y1s, x2s, y2s], axis=-1)

        bounding_boxes = bounding_boxes.copy()
        bounding_boxes["boxes"] = boxes
        bounding_boxes = bounding_box_utils.clip_to_image(
            bounding_boxes,
            bounding_box_format="xyxy",
            images=raw_images,
        )
        bounding_boxes = bounding_box_utils.sanitize_bounding_boxes(
            bounding_boxes,
            min_area_ratio=self.bounding_box_min_area_ratio,
            max_aspect_ratio=self.bounding_box_max_aspect_ratio,
            bounding_box_format="xyxy",
            reference_bounding_boxes=original_bounding_boxes,
            images=raw_images,
            reference_images=raw_images,
        )
        bounding_boxes = bounding_box.convert_format(
            bounding_boxes,
            source="xyxy",
            target=self.bounding_box_format,
            dtype=self.compute_dtype,
            images=raw_images,
        )
        return bounding_boxes

    def augment_ragged_segmentation_mask(
        self, segmentation_mask, transformation, **kwargs
    ):
        """Augment one mask by routing it through
        `augment_segmentation_masks`."""
        segmentation_mask = tf.expand_dims(segmentation_mask, axis=0)
        transformation = augmentation_utils.expand_dict_dims(
            transformation, axis=0
        )
        segmentation_mask = self.augment_segmentation_masks(
            segmentation_masks=segmentation_mask,
            transformations=transformation,
            **kwargs,
        )
        return tf.squeeze(segmentation_mask, axis=0)

    def augment_segmentation_masks(
        self, segmentation_masks, transformations, **kwargs
    ):
        """Warp masks with the same matrixes as the images.

        Masks always use ``"nearest"`` interpolation and a fill value of 0,
        regardless of the layer's ``interpolation`` / ``fill_value``.
        """
        original_shape = segmentation_masks.shape
        batch_size = tf.shape(segmentation_masks)[0]
        segmentation_masks = preprocessing_utils.transform(
            segmentation_masks,
            self._flatten_matrixes(
                transformations["combined_matrixes"], batch_size
            ),
            fill_mode=self.fill_mode,
            fill_value=0,
            interpolation="nearest",
        )
        segmentation_masks = tf.ensure_shape(
            segmentation_masks, shape=original_shape
        )
        return segmentation_masks

    def get_config(self):
        """Return the layer config, extending the base layer's config."""
        config = super().get_config()
        config.update(
            {
                "rotation_factor": self.rotation_factor,
                "translation_height_factor": self.translation_height_factor,
                "translation_width_factor": self.translation_width_factor,
                "zoom_height_factor": self.zoom_height_factor,
                "zoom_width_factor": self.zoom_width_factor,
                "shear_height_factor": self.shear_height_factor,
                "shear_width_factor": self.shear_width_factor,
                "same_zoom_factor": self.same_zoom_factor,
                "fill_mode": self.fill_mode,
                "fill_value": self.fill_value,
                "interpolation": self.interpolation,
                "bounding_box_format": self.bounding_box_format,
                "bounding_box_min_area_ratio": self.bounding_box_min_area_ratio,  # noqa: E501
                "bounding_box_max_aspect_ratio": self.bounding_box_max_aspect_ratio,  # noqa: E501
                "seed": self.seed,
            }
        )
        return config