Source code for keras_aug.layers.augmentation.mix.mix_up

import tensorflow as tf
from keras_cv import bounding_box
from tensorflow import keras

from keras_aug.layers.base.vectorized_base_random_layer import (
    VectorizedBaseRandomLayer,
)
from keras_aug.utils.augmentation import BATCHED
from keras_aug.utils.distribution import stateless_random_beta


@keras.utils.register_keras_serializable(package="keras_aug")
class MixUp(VectorizedBaseRandomLayer):
    """The MixUp data augmentation technique.

    MixUp pairs every image of a batch with another image of the same batch
    (chosen through a random permutation) and blends the two using a ratio
    sampled from a beta distribution. Labels are blended with the same ratio.
    Bounding boxes of the two images are concatenated.

    Args:
        alpha (float, optional): Value passed as both the ``alpha`` and the
            ``beta`` parameter of the beta distribution from which the
            mixing ratios are sampled, i.e. ``Beta(alpha, alpha)``. Defaults
            to ``0.2``, which is a recommended value when training an
            ImageNet classification model. For object detection, it is
            recommended to use a larger value. For example YOLOV8 uses
            ``32.0``.
        seed (int|float, optional): The random seed. Defaults to ``None``.

    References:
        - `MixUp <https://arxiv.org/abs/1710.09412>`_
        - `Bag of Freebies for Training Object Detection Neural Networks <https://arxiv.org/abs/1902.04103>`_
        - `KerasCV <https://github.com/keras-team/keras-cv>`_
    """  # noqa: E501

    def __init__(self, alpha=0.2, seed=None, **kwargs):
        super().__init__(seed=seed, **kwargs)
        self.alpha = alpha
        self.seed = seed
        # set force_no_unwrap_ragged_image_call=True because MixUp needs
        # to process images in batch.
        self.force_no_unwrap_ragged_image_call = True

    def get_random_transformation_batch(self, batch_size, **kwargs):
        """Sample the pairing order and the mixing ratios for one batch.

        Args:
            batch_size: Number of samples in the batch.

        Returns:
            dict: ``"permutation_order"`` holds the indices used to pick the
            partner of every sample, ``"lambda_samples"`` holds the mixing
            ratios with shape ``(batch_size, 1, 1, 1)``.
        """
        # Sorting uniform noise yields a random permutation of the batch
        # indices.
        permutation_order = tf.argsort(
            self._random_generator.random_uniform((batch_size,)), axis=-1
        )
        # The same value is used for both parameters, so the distribution
        # is symmetric.
        lambda_samples = stateless_random_beta(
            (batch_size, 1, 1, 1),
            seed_alpha=self._random_generator.make_seed_for_stateless_op(),
            seed_beta=self._random_generator.make_seed_for_stateless_op(),
            alpha=self.alpha,
            beta=self.alpha,
            dtype=self.compute_dtype,
        )
        return {
            "permutation_order": permutation_order,
            "lambda_samples": lambda_samples,
        }

    def augment_images(self, images, transformations, **kwargs):
        """Blend every image with its partner image.

        Args:
            images: Dense batch of images.
            transformations: Output of ``get_random_transformation_batch``.

        Returns:
            The blended images in ``self.compute_dtype``.

        Raises:
            ValueError: If ``images`` is a ``tf.RaggedTensor``.
        """
        if isinstance(images, tf.RaggedTensor):
            raise ValueError(
                "MixUp expects dense images. Received: images type: "
                f"{type(images)}"
            )
        permutation_order = transformations["permutation_order"]
        lambda_samples = transformations["lambda_samples"]
        # Cast before gathering so that both operands of the blend share the
        # dtype of ``lambda_samples`` (same order as in ``augment_labels``).
        images = tf.cast(images, dtype=self.compute_dtype)
        mixup_images = tf.gather(images, permutation_order)
        return lambda_samples * images + (1.0 - lambda_samples) * mixup_images

    def augment_labels(self, labels, transformations, **kwargs):
        """Blend every label with the label of its partner sample.

        Args:
            labels: Batch of labels. The ratios are reshaped to
                ``(batch_size, 1)`` before blending, so labels are
                presumably of shape ``(batch_size, num_classes)`` -- confirm
                against callers.
            transformations: Output of ``get_random_transformation_batch``.

        Returns:
            The blended labels in ``self.compute_dtype``.
        """
        permutation_order = transformations["permutation_order"]
        lambda_samples = transformations["lambda_samples"]
        labels = tf.cast(labels, dtype=self.compute_dtype)
        mixup_labels = tf.gather(labels, permutation_order)
        # Drop the two trailing singleton axes that exist for the images.
        lambda_samples = tf.reshape(lambda_samples, [-1, 1])
        return lambda_samples * labels + (1.0 - lambda_samples) * mixup_labels

    def augment_bounding_boxes(self, bounding_boxes, transformations, **kwargs):
        """Concatenate the boxes of every sample with those of its partner.

        The mixing ratio is not used here; boxes and classes of both images
        are kept.

        Args:
            bounding_boxes: Dictionary with ``"boxes"`` and ``"classes"``
                entries; it is converted to dense form first.
            transformations: Output of ``get_random_transformation_batch``.

        Returns:
            dict: A copy of the dense bounding boxes whose ``"boxes"`` and
            ``"classes"`` are concatenated along axis 1.
        """
        permutation_order = transformations["permutation_order"]
        bounding_boxes = bounding_box.to_dense(bounding_boxes)
        boxes, classes = bounding_boxes["boxes"], bounding_boxes["classes"]
        mixup_boxes = tf.gather(boxes, permutation_order)
        mixup_classes = tf.gather(classes, permutation_order)
        # Copy so that the dictionary returned by ``to_dense`` is not mutated.
        bounding_boxes = bounding_boxes.copy()
        bounding_boxes["boxes"] = tf.concat([boxes, mixup_boxes], axis=1)
        bounding_boxes["classes"] = tf.concat([classes, mixup_classes], axis=1)
        return bounding_boxes

    def _batch_augment(self, inputs):
        """Validate ``inputs`` and delegate to the parent implementation."""
        self._validate_inputs(inputs)
        return super()._batch_augment(inputs)

    def call(self, inputs):
        """Apply the layer; only batched inputs are accepted.

        Raises:
            ValueError: If the inputs are not flagged as batched.
        """
        _, metadata = self._format_inputs(inputs)
        if metadata[BATCHED] is not True:
            raise ValueError(
                "MixUp received a single image to `call`. The "
                "layer relies on combining multiple examples, and as such "
                "will not behave as expected. Please call the layer with 2 "
                "or more samples."
            )
        return super().call(inputs=inputs)

    def _validate_inputs(self, inputs):
        """Check that images plus labels and/or bounding boxes are given.

        Raises:
            ValueError: If ``"images"`` is missing, or if both ``"labels"``
                and ``"bounding_boxes"`` are missing.
        """
        images = inputs.get("images", None)
        labels = inputs.get("labels", None)
        bounding_boxes = inputs.get("bounding_boxes", None)
        if images is None or (labels is None and bounding_boxes is None):
            # The literals are concatenated, so each one carries its own
            # trailing separator.
            raise ValueError(
                "MixUp expects inputs in a dictionary with format "
                '{"images": images, "labels": labels} or '
                '{"images": images, "bounding_boxes": bounding_boxes}. '
                f"Got: inputs = {inputs}."
            )
        if bounding_boxes is not None:
            _ = bounding_box.validate_format(bounding_boxes)

    def sample_from_beta(self, alpha, beta, shape):
        """Sample from a beta distribution using two gamma samples.

        Computes ``x / (x + y)`` with ``x`` drawn from a gamma distribution
        with parameter ``alpha`` and ``y`` from one with parameter ``beta``.
        This method is not called by the other methods of this class; it is
        kept for backward compatibility.

        Args:
            alpha: ``alpha`` argument of the first gamma sample.
            beta: ``alpha`` argument of the second gamma sample.
            shape: Shape of the returned samples.

        Returns:
            The sampled values.
        """
        sample_alpha = tf.random.stateless_gamma(
            shape,
            alpha=alpha,
            seed=self._random_generator.make_seed_for_stateless_op(),
        )
        sample_beta = tf.random.stateless_gamma(
            shape,
            alpha=beta,
            seed=self._random_generator.make_seed_for_stateless_op(),
        )
        return sample_alpha / (sample_alpha + sample_beta)

    def get_config(self):
        """Return the parent config extended with ``alpha`` and ``seed``."""
        config = super().get_config()
        config.update({"alpha": self.alpha, "seed": self.seed})
        return config