Source code for keras_aug.layers.augmentation.utility.random_choice

from functools import partial

import tensorflow as tf
from keras_cv import bounding_box
from tensorflow import keras

from keras_aug.layers.base.vectorized_base_random_layer import (
    VectorizedBaseRandomLayer,
)
from keras_aug.utils import augmentation as augmentation_utils
from keras_aug.utils.augmentation import BOUNDING_BOXES
from keras_aug.utils.augmentation import CUSTOM_ANNOTATIONS
from keras_aug.utils.augmentation import IMAGES
from keras_aug.utils.augmentation import KEYPOINTS
from keras_aug.utils.augmentation import LABELS
from keras_aug.utils.augmentation import SEGMENTATION_MASKS


@keras.utils.register_keras_serializable(package="keras_aug")
class RandomChoice(VectorizedBaseRandomLayer):
    """RandomChoice constructs a pipeline based on provided arguments.

    The implemented policy does the following: for each input provided in
    `call`(), the policy selects a random layer from the provided list of
    `layers`. It then calls the `layer()` on the inputs.

    Notes:
        The shape and type of the outputs must be the same as the inputs.

    Args:
        layers (list(VectorizedBaseRandomLayer|keras.Layer|keras.Sequential)): The
            list of the layers that will be picked randomly for the pipeline.
        batchwise (bool, optional): Whether to pass entire batches to the
            underlying layer. When set to ``True``, each batch is passed to a
            single layer, instead of each sample to an independent layer.
            This is useful when using ``MixUp()``, ``CutMix()``,
            ``Mosaic()``, etc. Defaults to ``False``.
        seed (int|float, optional): The random seed. Defaults to ``None``.

    References:
        - `KerasCV <https://github.com/keras-team/keras-cv>`_
    """  # noqa: E501

    def __init__(self, layers, batchwise=False, seed=None, **kwargs):
        super().__init__(seed=seed, **kwargs)
        self.layers = layers
        self.batchwise = batchwise
        self.seed = seed

    def get_random_transformation_batch(self, batch_size, **kwargs):
        """Draw one int32 layer index per sample.

        Args:
            batch_size: Number of indices to draw.
            **kwargs: Accepted for signature compatibility; unused here.

        Returns:
            A ``(batch_size,)`` int32 tensor drawn with ``minval=0`` and
            ``maxval=len(self.layers)``.
        """
        return self._random_generator.random_uniform(
            shape=(batch_size,),
            minval=0,
            maxval=len(self.layers),
            dtype=tf.int32,
        )

    def _per_sample_spec(self, batched):
        """Build the per-sample spec (leading axis dropped) for ``batched``.

        A ``tf.Tensor`` yields a ``tf.TensorSpec``; anything else is
        described with a ``tf.RaggedTensorSpec`` of ``ragged_rank=1``.
        Both use ``self.compute_dtype``.
        """
        if isinstance(batched, tf.Tensor):
            return tf.TensorSpec(batched.shape[1:], self.compute_dtype)
        return tf.RaggedTensorSpec(
            shape=batched.shape[1:],
            ragged_rank=1,
            dtype=self.compute_dtype,
        )

    def compute_inputs_signature(self, inputs):
        """Describe the per-sample outputs of ``augment`` for ``tf.map_fn``.

        Args:
            inputs: Dict of batched inputs keyed by the annotation constants
                (``IMAGES``, ``LABELS``, ``BOUNDING_BOXES``, ...).

        Returns:
            A dict with the same keys as ``inputs`` whose values are tensor
            specs, suitable as ``fn_output_signature``.

        Raises:
            NotImplementedError: If ``CUSTOM_ANNOTATIONS`` is in ``inputs``.
        """
        fn_output_signature = {}
        if IMAGES in inputs:
            fn_output_signature[IMAGES] = self._per_sample_spec(inputs[IMAGES])
        if LABELS in inputs:
            fn_output_signature[LABELS] = tf.TensorSpec(
                inputs[LABELS].shape[1:], self.compute_dtype
            )
        if BOUNDING_BOXES in inputs:
            # ``augment`` always returns ragged bounding boxes, so the
            # signature is ragged regardless of the input format.
            fn_output_signature[BOUNDING_BOXES] = {
                "boxes": tf.RaggedTensorSpec(
                    shape=[None, 4],
                    ragged_rank=1,
                    dtype=self.compute_dtype,
                ),
                "classes": tf.RaggedTensorSpec(
                    shape=[None], ragged_rank=0, dtype=self.compute_dtype
                ),
            }
        if SEGMENTATION_MASKS in inputs:
            fn_output_signature[SEGMENTATION_MASKS] = self._per_sample_spec(
                inputs[SEGMENTATION_MASKS]
            )
        if KEYPOINTS in inputs:
            fn_output_signature[KEYPOINTS] = self._per_sample_spec(
                inputs[KEYPOINTS]
            )
        if CUSTOM_ANNOTATIONS in inputs:
            raise NotImplementedError()
        return fn_output_signature

    def _batch_augment(self, inputs):
        """Apply a randomly chosen layer to the batch or to each sample.

        With ``batchwise=True`` the first drawn index selects a single layer
        that receives the whole batch. Otherwise ``augment`` is mapped over
        the samples with ``tf.map_fn`` and each sample uses its own index.

        Args:
            inputs: Dict of batched inputs. It is not modified.

        Returns:
            The dict of augmented outputs.
        """
        images = inputs.get(IMAGES, None)
        batch_size = tf.shape(images)[0]
        transformations = self.get_random_transformation_batch(batch_size)

        if self.batchwise:
            # One index for the whole batch; the remaining draws are unused.
            return self.augment(
                {"inputs": inputs, "transformations": transformations[0]}
            )

        # Stays None when there are no bounding boxes, so the restore step
        # below can never read an unbound name.
        ori_bbox_info = None
        bounding_boxes = inputs.get(BOUNDING_BOXES, None)
        if bounding_boxes is not None:
            ori_bbox_info = bounding_box.validate_format(bounding_boxes)
            # Densify on a shallow copy so the caller's dict is left
            # untouched.
            inputs = {
                **inputs,
                BOUNDING_BOXES: bounding_box.to_dense(bounding_boxes),
            }

        result = tf.map_fn(
            self.augment,
            {"inputs": inputs, "transformations": transformations},
            fn_output_signature=self.compute_inputs_signature(inputs),
        )

        # Restore the bounding boxes to the format the caller supplied.
        bounding_boxes = result.get(BOUNDING_BOXES, None)
        if bounding_boxes is not None and ori_bbox_info is not None:
            if ori_bbox_info["ragged"]:
                bounding_boxes = bounding_box.to_ragged(bounding_boxes)
            else:
                bounding_boxes = bounding_box.to_dense(bounding_boxes)
            result[BOUNDING_BOXES] = bounding_boxes
        return result

    def augment(self, inputs):
        """Run the layer selected by ``inputs["transformations"]``.

        Args:
            inputs: Dict with key ``"inputs"`` (the data handed to the chosen
                layer) and key ``"transformations"`` (the index of the layer
                to run).

        Returns:
            The selected layer's output, with bounding boxes (if any)
            converted to ragged and everything cast to
            ``self.compute_dtype``.
        """
        layer_inputs = inputs.get("inputs")
        selected_op_idx = inputs.get("transformations")

        # One branch per layer; ``partial`` binds the data so each branch is
        # a zero-argument callable as ``tf.switch_case`` requires.
        branch_fns = {
            idx: partial(layer, layer_inputs)
            for idx, layer in enumerate(self.layers)
        }
        result = tf.switch_case(selected_op_idx, branch_fns=branch_fns)

        if BOUNDING_BOXES in result:
            result[BOUNDING_BOXES] = bounding_box.to_ragged(
                result[BOUNDING_BOXES], dtype=self.compute_dtype
            )
        result = augmentation_utils.cast_to(result, dtype=self.compute_dtype)
        return result

    def get_config(self):
        """Return the layer config extended with this layer's arguments."""
        config = super().get_config()
        config.update(
            {
                "layers": self.layers,
                "batchwise": self.batchwise,
                "seed": self.seed,
            }
        )
        return config