Source code for keras_aug.layers.augmentation.auto.trivial_augment_wide

from functools import partial

import tensorflow as tf
from keras_cv import bounding_box
from keras_cv.utils import preprocessing as preprocessing_utils
from tensorflow import keras

from keras_aug import layers
from keras_aug.core import UniformFactorSampler
from keras_aug.layers.base.vectorized_base_random_layer import (
    VectorizedBaseRandomLayer,
)
from keras_aug.utils import augmentation as augmentation_utils
from keras_aug.utils.augmentation import BOUNDING_BOXES
from keras_aug.utils.augmentation import IMAGES


[docs]@keras.utils.register_keras_serializable(package="keras_aug") class TrivialAugmentWide(VectorizedBaseRandomLayer): """TrivialAugmentWide performs the Wide version of Trivial Augment operation on input images. TrivialAugmentWide can be thought of as an all-in-one image augmentation layer. The policy implemented by TrivialAugmentWide has been benchmarked extensively and is effective on a wide variety of datasets. The input images will be converted to the range [0, 255], performed TrivialAugment and then converted back to the original value range. For object detection tasks, you should set ``fill_mode="constant"`` and ``fill_value=128`` to avoid artifacts. Moreover, you can set ``use_geometry=False`` to turn off all geometric augmentations if the distortion of the bounding boxes is too large. Args: value_range (Sequence[int|float]): The range of values the incoming images will have. This is typically either ``[0, 1]`` or ``[0, 255]`` depending on how your preprocessing pipeline is set up. use_geometry (bool, optional): whether to include geometric augmentations. This should be set to ``False`` when performing object detection. Defaults to ``True``. interpolation (str, optional): The interpolation mode. Supported values: ``"nearest", "bilinear"``. Defaults to `"nearest"`. fill_mode (str, optional): The fill mode. Supported values: ``"constant", "reflect", "wrap", "nearest"``. Defaults to ``"reflect"``. fill_value (int|float, optional): The value to be filled outside the boundaries when ``fill_mode="constant"``. Defaults to ``0``. exclude_ops (list(str), optional): Exclude selected operations. Defaults to ``None``. bounding_box_format (str, optional): The format of bounding boxes of input dataset. Refer https://github.com/keras-team/keras-cv/blob/master/keras_cv/bounding_box/converters.py for more details on supported bounding box formats. seed (int|float, optional): The random seed. Defaults to ``None``. References: - `TrivialAugment <https://arxiv.org/abs/2103.10158>`_ - `TrivialAugment Official Repo <https://github.com/automl/trivialaugment>`_ - `torchvision <https://github.com/pytorch/vision>`_ """ # noqa: E501 def __init__( self, value_range, use_geometry=True, interpolation="nearest", fill_mode="reflect", fill_value=0, exclude_ops=None, bounding_box_format=None, seed=None, **kwargs, ): super().__init__(seed=seed, **kwargs) self.value_range = value_range self.use_geometry = use_geometry self.interpolation = interpolation self.fill_mode = fill_mode self.fill_value = fill_value self.exclude_ops = exclude_ops self.bounding_box_format = bounding_box_format self.seed = seed self.aug_layers = self.get_standard_policy( use_geometry, exclude_ops, bounding_box_format, seed, **kwargs, ) self.num_layers = len(self.aug_layers) def get_standard_policy( self, use_geometry, exclude_ops, bounding_box_format=None, seed=None, **kwargs, ): policy = create_trivial_augment_policy(seed) aug_layers = [] if exclude_ops is not None: for op in exclude_ops: policy.pop(op) for key in policy.keys(): if key == "identity": aug_layers.append(layers.Identity(**policy[key], **kwargs)) elif key == "auto_contrast": aug_layers.append( layers.AutoContrast( **policy[key], value_range=(0, 255), seed=seed, **kwargs, ) ) elif key == "equalize": aug_layers.append( layers.Equalize( **policy[key], value_range=(0, 255), seed=seed, **kwargs, ) ) elif key == "posterize": aug_layers.append( layers.RandomPosterize( **policy[key], value_range=(0, 255), seed=seed, **kwargs, ) ) elif key == "solarize": aug_layers.append( layers.RandomSolarize( **policy[key], value_range=(0, 255), seed=seed, **kwargs, ) ) elif key == "color": aug_layers.append( layers.RandomColorJitter( **policy[key], value_range=(0, 255), seed=seed, **kwargs, ) ) elif key == "contrast": aug_layers.append( layers.RandomColorJitter( **policy[key], value_range=(0, 255), seed=seed, **kwargs, ) ) elif key == "brightness": aug_layers.append( layers.RandomColorJitter( **policy[key], value_range=(0, 255), seed=seed, **kwargs, ) ) elif key == "sharpness": aug_layers.append( layers.RandomSharpness( **policy[key], value_range=(0, 255), seed=seed, **kwargs, ) ) elif key == "rotate": if use_geometry: aug_layers.append( layers.RandomAffine( **policy[key], interpolation=self.interpolation, fill_mode=self.fill_mode, fill_value=self.fill_value, bounding_box_format=bounding_box_format, seed=seed, **kwargs, ) ) elif key == "shear_x": if use_geometry: aug_layers.append( layers.RandomAffine( **policy[key], interpolation=self.interpolation, fill_mode=self.fill_mode, fill_value=self.fill_value, bounding_box_format=bounding_box_format, seed=seed, **kwargs, ) ) elif key == "shear_y": if use_geometry: aug_layers.append( layers.RandomAffine( **policy[key], interpolation=self.interpolation, fill_mode=self.fill_mode, fill_value=self.fill_value, bounding_box_format=bounding_box_format, seed=seed, **kwargs, ) ) elif key == "translate_x": if use_geometry: aug_layers.append( layers.RandomAffine( **policy[key], interpolation=self.interpolation, fill_mode=self.fill_mode, fill_value=self.fill_value, bounding_box_format=bounding_box_format, seed=seed, **kwargs, ) ) elif key == "translate_y": if use_geometry: aug_layers.append( layers.RandomAffine( **policy[key], interpolation=self.interpolation, fill_mode=self.fill_mode, fill_value=self.fill_value, bounding_box_format=bounding_box_format, seed=seed, **kwargs, ) ) else: raise ValueError(f"Not recognized policy key: {key}") return aug_layers def get_random_transformation_batch(self, batch_size): random_indices = self._random_generator.random_uniform( shape=(batch_size,), minval=0, maxval=self.num_layers, dtype=tf.int32, ) return random_indices def _batch_augment(self, inputs): images = inputs.get(augmentation_utils.IMAGES, None) batch_size = tf.shape(images)[0] transformations = self.get_random_transformation_batch(batch_size) # images value_range transform to [0, 255] images = preprocessing_utils.transform_value_range( images, self.value_range, (0, 255), dtype=self.compute_dtype ) inputs[IMAGES] = images bounding_boxes = inputs.get(BOUNDING_BOXES, None) # make bounding_boxes to dense first if bounding_boxes is not None: ori_bbox_info = bounding_box.validate_format(bounding_boxes) inputs[BOUNDING_BOXES] = bounding_box.to_dense(bounding_boxes) inputs_for_trivial_augment_single_input = { "inputs": inputs, "transformations": transformations, } result = tf.map_fn( self.trivial_augment_single_input, inputs_for_trivial_augment_single_input, fn_output_signature=augmentation_utils.compute_signature( inputs, self.compute_dtype ), ) bounding_boxes = result.get(BOUNDING_BOXES, None) if bounding_boxes is not None: if ori_bbox_info["ragged"]: bounding_boxes = bounding_box.to_ragged(bounding_boxes) else: bounding_boxes = bounding_box.to_dense(bounding_boxes) result[BOUNDING_BOXES] = bounding_boxes # recover value_range images = result.get(IMAGES, None) images = preprocessing_utils.transform_value_range( images, (0, 255), self.value_range, self.compute_dtype ) result[IMAGES] = images return result def trivial_augment_single_input(self, inputs): input = inputs.get("inputs") random_indice = inputs.get("transformations") # construct branch_fns branch_fns = {} for idx, layer in enumerate(self.aug_layers): branch_fns[idx] = partial(layer, input) # TODO: # figure out why tf will make tf.float32 instead of tf.float16 # keras.mixed_precision.set_global_policy("mixed_float16") result = tf.switch_case(random_indice, branch_fns=branch_fns) if BOUNDING_BOXES in result: result[BOUNDING_BOXES] = bounding_box.to_ragged( result[BOUNDING_BOXES], dtype=self.compute_dtype ) return result def get_config(self): config = super().get_config() config.update( { "value_range": self.value_range, "use_geometry": self.use_geometry, "interpolation": self.interpolation, "fill_mode": self.fill_mode, "fill_value": self.fill_value, "exclude_ops": self.exclude_ops, "bounding_box_format": self.bounding_box_format, "seed": self.seed, } ) return config
def create_trivial_augment_policy(seed): """Create TrivialAugment Policy. References: https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/_auto_augment.py """ # noqa: E501 policy = {} policy["identity"] = {} policy["auto_contrast"] = {} policy["equalize"] = {} policy["rotate"] = { "rotation_factor": UniformFactorSampler( lower=-135, upper=135, seed=seed ), } policy["posterize"] = { "factor": UniformFactorSampler(lower=2, upper=8 + 1, seed=seed) } policy["solarize"] = { "threshold_factor": UniformFactorSampler(lower=0, upper=255, seed=seed), "addition_factor": 0, } policy["color"] = { "saturation_factor": UniformFactorSampler( lower=1, upper=1.99, seed=seed ), } policy["contrast"] = { "contrast_factor": UniformFactorSampler(lower=1, upper=1.99, seed=seed), } policy["brightness"] = { "brightness_factor": UniformFactorSampler( lower=1, upper=1.99, seed=seed ), } policy["sharpness"] = { "factor": UniformFactorSampler(lower=1, upper=1.99, seed=seed), } policy["shear_x"] = { "shear_height_factor": 0, "shear_width_factor": UniformFactorSampler( lower=-0.99, upper=0.99, seed=seed ), } policy["shear_y"] = { "shear_height_factor": UniformFactorSampler( lower=-0.99, upper=0.99, seed=seed ), "shear_width_factor": 0, } policy["translate_x"] = { "translation_height_factor": 0, "translation_width_factor": UniformFactorSampler( lower=-1, upper=1, seed=seed ), } policy["translate_y"] = { "translation_height_factor": UniformFactorSampler( lower=-1, upper=1, seed=seed ), "translation_width_factor": 0, } return policy