Source code for keras_aug.layers.base.vectorized_base_random_layer

import tensorflow as tf
from keras_cv import bounding_box
from keras_cv.utils import preprocessing
from tensorflow import keras

from keras_aug.layers.__internal__.base_layer import BaseRandomLayer

H_AXIS = -3
W_AXIS = -2

IMAGES = "images"
LABELS = "labels"
TARGETS = "targets"
BOUNDING_BOXES = "bounding_boxes"
KEYPOINTS = "keypoints"
SEGMENTATION_MASKS = "segmentation_masks"
CUSTOM_ANNOTATIONS = "custom_annotations"

IS_DICT = "is_dict"
BATCHED = "batched"
USE_TARGETS = "use_targets"


[docs]@keras.utils.register_keras_serializable(package="keras_aug") class VectorizedBaseRandomLayer(BaseRandomLayer): """Abstract base layer for vectorized image augmentation. This layer contains base functionalities for preprocessing layers which augment image related data, e.g. image and in the future, label and bounding boxes. The subclasses could avoid making certain mistakes and reduce code duplications. This layer requires you to implement one method ``augment_images()``, which augments one single image during the training. There are a few additional methods that you can implement for added functionality on the layer. ``augment_ragged_image()`` and ``compute_ragged_image_signature()``, which handles ragged images augmentation if the layer supports that. ``augment_labels()``, which handles label augmentation if the layer supports that. ``augment_bounding_boxes()``, which handles the bounding box augmentation, if the layer supports that. ``augment_keypoints()``, which handles the keypoints augmentation, if the layer supports that. ``augment_segmentation_masks()``, which handles the segmentation masks augmentation, if the layer supports that. ``augment_custom_annotations()``, which handles the custom annotations augmentation, if the layer supports that. This is useful to implement augmentation for special annotatinos. ``get_random_transformations()``, which should produce a batch of random transformation settings. The transformation object, which must be a batched Tensor or a dictionary where each input is a batched Tensor, will be passed to ``augment_images``, ``augment_labels`` and `augment_bounding_boxes`, to coordinate the randomness behavior, eg, in the RandomFlip layer, the image and bounding_boxes should be changed in the same way. The ``call()`` method support two formats of inputs:: 1. Single image tensor with 3D (HWC) or 4D (NHWC) format. 2. A dict of tensors with stable keys. The supported keys are ``"images"``, ``"labels"``, ``"bounding_boxes"``, ``segmentation_masks``, ``keypoints`` and ``custom_annotations`` at the moment. We might add more keys in future when we support more types of augmentation. The output of the ``call()`` will be in two formats, which will be the same structure as the inputs. The ``call()`` will unpack the inputs, forward to the correct function, and pack the output back to the same structure as the inputs. By default, the dense or ragged status of the output will be preserved. However, you can override this behavior by setting ``self.force_output_dense_images = True`` in your ``__init__()`` method. When enabled, images and segmentation masks will be converted to dense tensor by ``to_tensor()`` if ragged. .. code-block:: python class SubclassLayer(VectorizedBaseImageAugmentationLayer): def __init__(self): super().__init__() self.force_output_dense_images = True Note that since the randomness is also a common functionality, this layer also includes a keras.backend.RandomGenerator, which can be used to produce the random numbers. The random number generator is stored in the `self._random_generator` attribute. References: - `KerasCV <https://github.com/keras-team/keras-cv>`_ """ def __init__( self, seed=None, force_generator=True, rng_type="stateless", **kwargs ): # Defaults to stateless random operations super().__init__( seed=seed, force_generator=force_generator, rng_type=rng_type, **kwargs, ) @property def force_no_unwrap_ragged_image_call(self): """Control whether to force not to unwrap ragged image call.""" return getattr(self, "_force_no_unwrap_ragged_image_call", False) @force_no_unwrap_ragged_image_call.setter def force_no_unwrap_ragged_image_call( self, force_no_unwrap_ragged_image_call ): self._force_no_unwrap_ragged_image_call = ( force_no_unwrap_ragged_image_call ) @property def force_output_dense_images(self): """Control whether to force outputting of dense images.""" return getattr(self, "_force_output_dense_images", False) @force_output_dense_images.setter def force_output_dense_images(self, force_output_dense_images): self._force_output_dense_images = force_output_dense_images def get_random_transformation_batch( self, batch_size, images=None, labels=None, bounding_boxes=None, keypoints=None, segmentation_masks=None, custom_annotations=None, ): """Produce random transformations config for a batch of inputs. This is used to produce same randomness between image / label / bounding_box. Args: batch_size: the batch size of transformations configuration to sample. images: 3D image tensor from inputs. labels: optional 1D label tensor from inputs. bounding_boxes: optional 2D bounding boxes tensor from inputs. segmentation_masks: optional 3D segmentation mask tensor from inputs. Returns: Any type of object, which will be forwarded to `augment_images`, `augment_labels` and `augment_bounding_boxes` as the `transformations` parameter. """ # Required to work with map_fn in the ragged cast. return tf.zeros((batch_size)) def compute_ragged_image_signature(self, images): """Computes the output image signature for the `_unwrap_ragged_image_call()` function. Must be overridden to return tensors with different shapes than the input images. By default, returns either a `tf.RaggedTensorSpec` matching the input image spec, or a `tf.TensorSpec` matching the input image spec. """ return tf.RaggedTensorSpec( shape=images.shape[1:], ragged_rank=1, dtype=self.compute_dtype, ) def compute_ragged_segmentation_mask_signature(self, segmentation_maks): """Computes the output segmentation_mask signature for the `_unwrap_ragged_segmentation_call()` function. Must be overridden to return tensors with different shapes than the input images. By default, returns either a `tf.RaggedTensorSpec` matching the input image spec, or a `tf.TensorSpec` matching the input image spec. """ return tf.RaggedTensorSpec( shape=segmentation_maks.shape[1:], ragged_rank=1, dtype=self.compute_dtype, ) def augment_ragged_image(self, image, transformation, **kwargs): """Augment an image from a ragged image batch. This method accepts a single Dense image Tensor, and returns a Dense image. The resulting images are then stacked back into a ragged image batch. The behavior of this method should be identical to that of `augment_images()` but is to operate on a batch-wise basis. Args: image: a single image from the batch transformation: a single transformation sampled from `get_random_transformations()`. kwargs: all the other call arguments (i.e. bounding_boxes, labels, etc.). Returns: Augmented image. """ raise NotImplementedError( "A ragged image batch was passed to layer of type " f"`{type(self).__name__}`. This layer does not implement " "`augment_ragged_image()`. If this is a `keras_aug`, open a GitHub " "issue requesting Ragged functionality on the layer titled: " f"'`{type(self).__name__}`: ragged image support'. If this is a " "custom layer, implement the `augment_ragged_image()` method." ) def augment_images(self, images, transformations, **kwargs): """Augment a batch of images. Args: images: 4D image input tensor to the layer. Forwarded from `layer.call()`. This should generally have the shape [B, H, W, C]. Forwarded from `layer.call()`. transformations: The transformations object produced by `get_random_transformations`. Used to coordinate the randomness between image, label, bounding box, keypoints, and segmentation mask. Returns: output 4D tensor, which will be forward to `layer.call()`. """ raise NotImplementedError() def augment_labels(self, labels, transformations, **kwargs): """Augment a batch of labels. Args: labels: 2D label to the layer. Forwarded from `layer.call()`. transformations: The transformations object produced by `get_random_transformations`. Used to coordinate the randomness between image, label, bounding box, keypoints, and segmentation mask. Returns: output 2D tensor, which will be forward to `layer.call()`. """ raise NotImplementedError() def augment_bounding_boxes(self, bounding_boxes, transformations, **kwargs): """Augment bounding boxes for one image. Args: bounding_boxes: 3D bounding boxes to the layer. Forwarded from `call()`. transformations: The transformations object produced by `get_random_transformations`. Used to coordinate the randomness between image, label, bounding box, keypoints, and segmentation mask. Returns: output 3D tensor, which will be forward to `layer.call()`. """ raise NotImplementedError() def augment_keypoints(self, keypoints, transformations, **kwargs): """Augment a batch of keypoints for one image. Args: keypoints: 3D keypoints input tensor to the layer. Forwarded from `layer.call()`. Shape should be [batch, num_keypoints, 2] in the specified keypoint format. transformations: The transformations object produced by `get_random_transformations`. Used to coordinate the randomness between image, label, bounding box, keypoints, and segmentation mask. Returns: output 3D tensor, which will be forward to `layer.call()`. """ raise NotImplementedError() def augment_ragged_segmentation_mask( self, segmentation_mask, transformation, **kwargs ): """Augment an image from a ragged segmentation mask batch. This method accepts a single Dense image Tensor, and returns a Dense image. The resulting images are then stacked back into a ragged image batch. The behavior of this method should be identical to that of `augment_segmentation_masks()` but is to operate on a batch-wise basis. Args: segmentation_mask: a single image from the batch transformation: a single transformation sampled from `get_random_transformations()`. kwargs: all the other call arguments (i.e. bounding_boxes, labels, etc.). Returns: Augmented segmentation mask. """ return segmentation_mask def augment_segmentation_masks( self, segmentation_masks, transformations, **kwargs ): """Augment a batch of images' segmentation masks. Args: segmentation_masks: 4D segmentation mask input tensor to the layer. This should generally have the shape [B, H, W, 1], or in some cases [B, H, W, C] for multilabeled data. Forwarded from `layer.call()`. transformations: The transformations object produced by `get_random_transformations`. Used to coordinate the randomness between image, label, bounding box, keypoints, and segmentation mask. Returns: output 4D tensor containing the augmented segmentation mask, which will be forward to `layer.call()`. """ raise NotImplementedError() def augment_custom_annotations( self, custom_annotations, transformations, **kwargs ): """Augment a batch of images' custom_annotations. Args: custom_annotations: 4D custom annotations input tensor to the layer. transformations: The transformations object produced by `get_random_transformations`. Used to coordinate the randomness between image, label, bounding box, keypoints, and segmentation mask. Returns: output 4D tensor containing the augmented segmentation mask, which will be forward to `layer.call()`. """ raise NotImplementedError() def _unwrap_ragged_image_call(self, inputs): images = inputs.get(IMAGES, None) labels = inputs.get(LABELS, None) bounding_boxes = inputs.get(BOUNDING_BOXES, None) keypoints = inputs.get(KEYPOINTS, None) segmentation_masks = inputs.get(SEGMENTATION_MASKS, None) transformations = inputs.get("transformations") images = images.to_tensor() images = self.augment_ragged_image( image=images, label=labels, bounding_boxes=bounding_boxes, keypoints=keypoints, segmentation_mask=segmentation_masks, transformation=transformations, ) return tf.RaggedTensor.from_tensor(images) def _unwrap_ragged_segmentation_mask_call(self, inputs): segmentation_masks = inputs.get(SEGMENTATION_MASKS, None) transformations = inputs.get("transformations") labels = inputs.get(LABELS, None) bounding_boxes = inputs.get(BOUNDING_BOXES, None) images = inputs.get(IMAGES, None) raw_images = inputs.get("raw_images", None) segmentation_masks = segmentation_masks.to_tensor() segmentation_masks = self.augment_ragged_segmentation_mask( segmentation_mask=segmentation_masks, transformation=transformations, label=labels, bounding_boxes=bounding_boxes, image=images, raw_image=raw_images, ) return tf.RaggedTensor.from_tensor(segmentation_masks) def _batch_augment(self, inputs): images = inputs.get(IMAGES, None) raw_images = images labels = inputs.get(LABELS, None) bounding_boxes = inputs.get(BOUNDING_BOXES, None) keypoints = inputs.get(KEYPOINTS, None) segmentation_masks = inputs.get(SEGMENTATION_MASKS, None) custom_annotations = inputs.get(CUSTOM_ANNOTATIONS, None) batch_size = tf.shape(images)[0] transformations = self.get_random_transformation_batch( batch_size, images=images, labels=labels, bounding_boxes=bounding_boxes, keypoints=keypoints, segmentation_masks=segmentation_masks, custom_annotations=custom_annotations, ) if ( isinstance(images, tf.RaggedTensor) and not self.force_no_unwrap_ragged_image_call ): inputs_for_raggeds = {"transformations": transformations, **inputs} images = tf.map_fn( self._unwrap_ragged_image_call, inputs_for_raggeds, fn_output_signature=self.compute_ragged_image_signature(images), ) else: images = self.augment_images( images, transformations=transformations, bounding_boxes=bounding_boxes, labels=labels, ) if ( isinstance(images, tf.RaggedTensor) and self.force_output_dense_images ): images = images.to_tensor() result = {IMAGES: images} if labels is not None: labels = self.augment_labels( labels, transformations=transformations, bounding_boxes=bounding_boxes, images=images, raw_images=raw_images, ) result[LABELS] = labels if bounding_boxes is not None: ori_bounding_boxes_info = bounding_box.validate_format( bounding_boxes ) bounding_boxes = self.augment_bounding_boxes( bounding_boxes, transformations=transformations, labels=labels, images=images, raw_images=raw_images, ) if ori_bounding_boxes_info["ragged"]: bounding_boxes = bounding_box.to_ragged( bounding_boxes, dtype=self.compute_dtype ) result[BOUNDING_BOXES] = bounding_boxes if keypoints is not None: keypoints = self.augment_keypoints( keypoints, transformations=transformations, labels=labels, bounding_boxes=bounding_boxes, images=images, raw_images=raw_images, ) result[KEYPOINTS] = keypoints if segmentation_masks is not None: if ( isinstance(segmentation_masks, tf.RaggedTensor) and not self.force_no_unwrap_ragged_image_call ): inputs_for_raggeds = { "transformations": transformations, "raw_images": raw_images, **inputs, } segmentation_masks = tf.map_fn( self._unwrap_ragged_segmentation_mask_call, inputs_for_raggeds, fn_output_signature=self.compute_ragged_segmentation_mask_signature( # noqa: E501 segmentation_masks ), ) else: segmentation_masks = self.augment_segmentation_masks( segmentation_masks, transformations=transformations, labels=labels, bounding_boxes=bounding_boxes, images=images, raw_images=raw_images, ) if ( isinstance(segmentation_masks, tf.RaggedTensor) and self.force_output_dense_images ): segmentation_masks = segmentation_masks.to_tensor() result[SEGMENTATION_MASKS] = segmentation_masks if custom_annotations is not None: custom_annotations = self.augment_custom_annotations( custom_annotations, transformations=transformations, labels=labels, bounding_boxes=bounding_boxes, images=images, raw_images=raw_images, ) result[CUSTOM_ANNOTATIONS] = custom_annotations # preserve any additional inputs unmodified by this layer. for key in inputs.keys() - result.keys(): result[key] = inputs[key] return result def _ensure_inputs_are_compute_dtype(self, inputs): if not isinstance(inputs, dict): return preprocessing.ensure_tensor( inputs, self.compute_dtype, ) inputs[IMAGES] = preprocessing.ensure_tensor( inputs[IMAGES], self.compute_dtype, ) if BOUNDING_BOXES in inputs: inputs[BOUNDING_BOXES]["boxes"] = preprocessing.ensure_tensor( inputs[BOUNDING_BOXES]["boxes"], self.compute_dtype, ) inputs[BOUNDING_BOXES]["classes"] = preprocessing.ensure_tensor( inputs[BOUNDING_BOXES]["classes"], self.compute_dtype, ) return inputs def _format_inputs(self, inputs): metadata = {IS_DICT: True, USE_TARGETS: False} if tf.is_tensor(inputs): metadata[IS_DICT] = False inputs = {IMAGES: inputs} metadata[BATCHED] = inputs["images"].shape.rank == 4 if inputs["images"].shape.rank == 3: for key in list(inputs.keys()): if key == BOUNDING_BOXES: inputs[BOUNDING_BOXES]["boxes"] = tf.expand_dims( inputs[BOUNDING_BOXES]["boxes"], axis=0 ) inputs[BOUNDING_BOXES]["classes"] = tf.expand_dims( inputs[BOUNDING_BOXES]["classes"], axis=0 ) else: inputs[key] = tf.expand_dims(inputs[key], axis=0) if not isinstance(inputs, dict): raise ValueError( "Expect the inputs to be image tensor or dict. Got " f"inputs={inputs}" ) if BOUNDING_BOXES in inputs: inputs[BOUNDING_BOXES] = self._format_bounding_boxes( inputs[BOUNDING_BOXES] ) if isinstance(inputs, dict) and TARGETS in inputs: inputs[LABELS] = inputs[TARGETS] del inputs[TARGETS] metadata[USE_TARGETS] = True return inputs, metadata return inputs, metadata def _format_bounding_boxes(self, bounding_boxes): # We can't catch the case where this is None, sometimes RaggedTensor # drops this dimension. if "classes" not in bounding_boxes: raise ValueError( "Bounding boxes are missing class_id. If you would like to pad " "the bounding boxes with class_id, use: " "`bounding_boxes['classes'] = " "tf.ones_like(bounding_boxes['boxes'])`." ) return bounding_boxes def _format_output(self, output, metadata): if not metadata[BATCHED]: for key in list(output.keys()): if key == BOUNDING_BOXES: output[BOUNDING_BOXES]["boxes"] = tf.squeeze( output[BOUNDING_BOXES]["boxes"], axis=0 ) output[BOUNDING_BOXES]["classes"] = tf.squeeze( output[BOUNDING_BOXES]["classes"], axis=0 ) else: output[key] = tf.squeeze(output[key], axis=0) if not metadata[IS_DICT]: return output[IMAGES] elif metadata[USE_TARGETS]: output[TARGETS] = output[LABELS] del output[LABELS] return output def call(self, inputs): inputs = self._ensure_inputs_are_compute_dtype(inputs) inputs, metadata = self._format_inputs(inputs) images = inputs[IMAGES] if images.shape.rank == 3 or images.shape.rank == 4: return self._format_output(self._batch_augment(inputs), metadata) else: raise ValueError( "Image augmentation layers are expecting inputs to be " "rank 3 (HWC) or 4D (NHWC) tensors. Got shape: " f"{images.shape}" ) @classmethod def from_config(cls, config): return cls(**config)