"""The ``RandomFlip`` layer (``keras_aug.layers.augmentation.geometry.random_flip``)."""

import tensorflow as tf
from tensorflow import keras

from keras_aug.datapoints import bounding_box
from keras_aug.layers.base.vectorized_base_random_layer import (
    VectorizedBaseRandomLayer,
)
from keras_aug.utils import augmentation as augmentation_utils

# Accepted values for the ``mode`` argument of ``RandomFlip``.
HORIZONTAL: str = "horizontal"  # left-right flip only
VERTICAL: str = "vertical"  # top-bottom flip only
HORIZONTAL_AND_VERTICAL: str = "horizontal_and_vertical"  # both flips


@keras.utils.register_keras_serializable(package="keras_aug")
class RandomFlip(VectorizedBaseRandomLayer):
    """Randomly flips the input images.

    This layer will flip the images horizontally and or vertically based on
    the ``mode`` attribute. In ``"horizontal_and_vertical"`` mode the two
    flips are decided independently for each sample.

    Args:
        mode (str, optional): The flip mode to use. Supported values:
            ``"horizontal", "vertical", "horizontal_and_vertical"``.
            Defaults to ``"horizontal"``. ``"horizontal"`` is a left-right
            flip and ``"vertical"`` is a top-bottom flip.
        rate (float, optional): The frequency of flipping. ``1.0`` indicates
            that images are always flipped. ``0.0`` indicates no flipping.
            Defaults to ``0.5``.
        bounding_box_format (str, optional): The format of bounding boxes of
            input dataset. Refer
            https://github.com/james77777778/keras-aug/blob/main/keras_aug/datapoints/bounding_box/converter.py
            for more details on supported bounding box formats.
        seed (int|float, optional): The random seed. Defaults to ``None``.

    Raises:
        ValueError: If ``mode`` is not one of the supported values, or if
            ``rate`` is not inside ``[0, 1]``.

    References:
        - `KerasCV <https://github.com/keras-team/keras-cv>`_
    """  # noqa: E501

    def __init__(
        self,
        mode=HORIZONTAL,
        rate=0.5,
        bounding_box_format=None,
        seed=None,
        **kwargs,
    ):
        super().__init__(seed=seed, **kwargs)
        if mode == HORIZONTAL:
            self.horizontal = True
            self.vertical = False
        elif mode == VERTICAL:
            self.horizontal = False
            self.vertical = True
        elif mode == HORIZONTAL_AND_VERTICAL:
            self.horizontal = True
            self.vertical = True
        else:
            raise ValueError(
                "RandomFlip layer {name} received an unknown mode="
                "{arg}".format(name=self.name, arg=mode)
            )
        self.mode = mode
        # Chained comparison also rejects NaN, which would otherwise pass
        # and silently disable flipping.
        if not 0.0 <= rate <= 1.0:
            raise ValueError(
                f"`rate` should be inside of range [0, 1]. Got rate={rate}"
            )
        self.rate = rate
        self.bounding_box_format = bounding_box_format
        self.seed = seed

    def get_random_transformation_batch(self, batch_size, **kwargs):
        """Draw the per-sample random values that decide each flip.

        Args:
            batch_size: Number of samples in the batch.

        Returns:
            A dict with ``"flip_horizontals"`` and ``"flip_verticals"``, each
            of shape ``(batch_size, 1)``. A direction that is disabled by
            ``mode`` keeps all-zero values; it is never consulted because the
            flip code checks ``self.horizontal`` / ``self.vertical`` first.
        """
        flip_horizontals = tf.zeros(shape=(batch_size, 1))
        flip_verticals = tf.zeros(shape=(batch_size, 1))
        if self.horizontal:
            flip_horizontals = self._random_generator.random_uniform(
                shape=(batch_size, 1)
            )
        if self.vertical:
            flip_verticals = self._random_generator.random_uniform(
                shape=(batch_size, 1)
            )
        return {
            "flip_horizontals": flip_horizontals,
            "flip_verticals": flip_verticals,
        }

    def _should_flip(self, random_values):
        """Return a boolean mask of the samples that should be flipped.

        ``random_values`` are draws from ``random_uniform``, assumed to lie
        in ``[0, 1)``. Comparing with ``>=`` makes ``rate=1.0`` flip every
        sample even when a draw is exactly ``0.0``. ``rate=0.0`` still never
        flips, because no draw reaches ``1.0``.
        """
        return random_values >= (1.0 - self.rate)

    def augment_ragged_image(self, image, transformation, **kwargs):
        """Flip a single (unbatched) image taken from a ragged batch.

        A leading batch axis is added to the image and to every entry of
        ``transformation``. The image then goes through ``augment_images``
        and the batch axis is removed again.
        """
        image = tf.expand_dims(image, axis=0)
        transformation = augmentation_utils.expand_dict_dims(
            transformation, axis=0
        )
        image = self.augment_images(
            images=image, transformations=transformation, **kwargs
        )
        return tf.squeeze(image, axis=0)

    def augment_images(self, images, transformations, **kwargs):
        """Flip a batch of images according to ``transformations``."""
        return self.flip_images(images, transformations)

    def augment_labels(self, labels, transformations, **kwargs):
        """Return ``labels`` unchanged; flipping does not affect labels."""
        return labels

    def augment_bounding_boxes(
        self, bounding_boxes, transformations=None, raw_images=None, **kwargs
    ):
        """Flip bounding boxes consistently with their images.

        Boxes are densified and converted to ``rel_xyxy``. The per-sample
        flips are applied, the boxes are clipped to the image, and they are
        converted back to ``self.bounding_box_format``.

        Args:
            bounding_boxes: A bounding-box dict with a ``"boxes"`` entry.
            transformations: The dict from
                ``get_random_transformation_batch``.
            raw_images: The images, used for format conversion and clipping.

        Returns:
            A copy of ``bounding_boxes`` with flipped ``"boxes"``.

        Raises:
            ValueError: If no ``bounding_box_format`` was given to the
                constructor.
        """
        if self.bounding_box_format is None:
            raise ValueError(
                "`RandomFlip()` was called with bounding boxes, "
                "but no `bounding_box_format` was specified in the "
                "constructor. Please specify a bounding box format in the "
                "constructor. i.e. `RandomFlip(bounding_box_format='xyxy')`"
            )
        bounding_boxes = bounding_box.to_dense(bounding_boxes)
        bounding_boxes = bounding_box.convert_format(
            bounding_boxes,
            source=self.bounding_box_format,
            target="rel_xyxy",
            images=raw_images,
        )
        boxes = bounding_boxes["boxes"]
        # (batch, 1) -> (batch, 1, 1) so one decision covers all boxes of a
        # sample.
        flip_horizontals = transformations["flip_horizontals"][
            :, tf.newaxis, :
        ]
        flip_verticals = transformations["flip_verticals"][:, tf.newaxis, :]
        if self.horizontal:
            boxes = tf.where(
                self._should_flip(flip_horizontals),
                self.flip_boxes_horizontal(boxes),
                boxes,
            )
        if self.vertical:
            boxes = tf.where(
                self._should_flip(flip_verticals),
                self.flip_boxes_vertical(boxes),
                boxes,
            )
        # Copy so the caller's dict is not mutated.
        bounding_boxes = bounding_boxes.copy()
        bounding_boxes["boxes"] = boxes
        bounding_boxes = bounding_box.clip_to_image(
            bounding_boxes,
            bounding_box_format="rel_xyxy",
            images=raw_images,
        )
        bounding_boxes = bounding_box.convert_format(
            bounding_boxes,
            source="rel_xyxy",
            target=self.bounding_box_format,
            dtype=self.compute_dtype,
            images=raw_images,
        )
        return bounding_boxes

    def augment_ragged_segmentation_mask(
        self, segmentation_mask, transformation, **kwargs
    ):
        """Flip a single (unbatched) mask taken from a ragged batch.

        This mirrors ``augment_ragged_image``: add a batch axis, delegate to
        ``augment_segmentation_masks``, then remove the batch axis.
        """
        segmentation_mask = tf.expand_dims(segmentation_mask, axis=0)
        transformation = augmentation_utils.expand_dict_dims(
            transformation, axis=0
        )
        segmentation_mask = self.augment_segmentation_masks(
            segmentation_masks=segmentation_mask,
            transformations=transformation,
            **kwargs,
        )
        return tf.squeeze(segmentation_mask, axis=0)

    def augment_segmentation_masks(
        self, segmentation_masks, transformations=None, **kwargs
    ):
        """Flip segmentation masks exactly like the corresponding images."""
        return self.flip_images(segmentation_masks, transformations)

    def flip_images(self, images, transformations):
        """Apply the per-sample flips to a batch of images or masks.

        Args:
            images: A batched tensor. The mask broadcasting below assumes
                rank 4 (presumably ``(batch, height, width, channels)``;
                TODO confirm against callers).
            transformations: The dict from
                ``get_random_transformation_batch``.

        Returns:
            ``images`` with the selected flips applied, with the static shape
            of the input.
        """
        original_shape = images.shape
        # (batch, 1) -> (batch, 1, 1, 1) so one decision covers a whole image.
        flip_horizontals = transformations["flip_horizontals"][
            :, tf.newaxis, tf.newaxis, :
        ]
        flip_verticals = transformations["flip_verticals"][
            :, tf.newaxis, tf.newaxis, :
        ]
        if self.horizontal:
            images = tf.where(
                self._should_flip(flip_horizontals),
                tf.image.flip_left_right(images),
                images,
            )
        if self.vertical:
            images = tf.where(
                self._should_flip(flip_verticals),
                tf.image.flip_up_down(images),
                images,
            )
        # Re-assert the static shape captured on entry.
        images.set_shape(original_shape)
        return images

    def flip_boxes_horizontal(self, boxes):
        """Mirror ``rel_xyxy`` boxes left-right: x' = 1 - x, swapping x1/x2."""
        x1, y1, x2, y2 = tf.split(boxes, 4, axis=-1)
        return tf.concat([1 - x2, y1, 1 - x1, y2], axis=-1)

    def flip_boxes_vertical(self, boxes):
        """Mirror ``rel_xyxy`` boxes top-bottom: y' = 1 - y, swapping y1/y2."""
        x1, y1, x2, y2 = tf.split(boxes, 4, axis=-1)
        return tf.concat([x1, 1 - y2, x2, 1 - y1], axis=-1)

    def get_config(self):
        """Return the layer config, including all constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "mode": self.mode,
                "rate": self.rate,
                "bounding_box_format": self.bounding_box_format,
                "seed": self.seed,
            }
        )
        return config