Source code for keras_aug.utils.augmentation

import enum
from typing import Sequence

import tensorflow as tf
from tensorflow import keras

from keras_aug.core import ConstantFactorSampler
from keras_aug.core import FactorSampler
from keras_aug.core import NormalFactorSampler
from keras_aug.core import SignedNormalFactorSampler
from keras_aug.core import UniformFactorSampler

# Indices of the height and width axes, counted from the end of a
# channels-last ``[..., H, W, C]`` image shape.
H_AXIS = -3
W_AXIS = -2

# String keys identifying each kind of data carried in an inputs dictionary
# (see `cast_to` and `compute_signature` in this module).
IMAGES = "images"
LABELS = "labels"
# NOTE(review): TARGETS is not referenced by any function visible in this
# module; its consumers are presumably elsewhere -- confirm.
TARGETS = "targets"
BOUNDING_BOXES = "bounding_boxes"
KEYPOINTS = "keypoints"
SEGMENTATION_MASKS = "segmentation_masks"
CUSTOM_ANNOTATIONS = "custom_annotations"

# NOTE(review): presumably a key flagging whether inputs carry a batch
# dimension; not used in the visible code -- confirm against callers.
BATCHED = "batched"


class PaddingPosition(enum.Enum):
    """Where the padding of an image is placed.

    The string values are the accepted spellings for
    ``get_padding_position``.
    """

    CENTER = "center"
    TOP_LEFT = "top_left"
    TOP_RIGHT = "top_right"
    BOTTOM_LEFT = "bottom_left"
    BOTTOM_RIGHT = "bottom_right"
    RANDOM = "random"


def get_padding_position(position):
    """Resolve ``position`` to a ``PaddingPosition`` member.

    Args:
        position (str|PaddingPosition): Either an enum member, which is
            returned unchanged, or one of the enum's string values. Strings
            are matched case-insensitively.

    Returns:
        PaddingPosition: The matching enum member.

    Raises:
        NotImplementedError: If ``position`` is not a ``PaddingPosition``
            and does not match any of its values. This includes non-string
            inputs such as ``None``.
    """
    if isinstance(position, PaddingPosition):
        return position
    # Only strings can be case-normalized. Anything else falls through to the
    # lookup and is reported as unrecognized instead of crashing on a missing
    # ``.lower()`` attribute.
    if isinstance(position, str):
        position = position.lower()
    try:
        return PaddingPosition(position)
    except ValueError:
        raise NotImplementedError(
            f"Value not recognized for `position`: {position}. Supported "
            f"values are: {PaddingPosition._value2member_map_.keys()}"
        ) from None


def get_position_params(
    tops, bottoms, lefts, rights, position, random_generator
):
    """Redistribute per-side paddings according to ``position``.

    This function supposes arguments are at `center` padding method. The
    total vertical (``tops + bottoms``) and horizontal (``lefts + rights``)
    padding is preserved; only the split between the two sides changes.

    Args:
        tops: Top paddings. Converted with ``tf.convert_to_tensor``.
        bottoms: Bottom paddings. Converted with ``tf.convert_to_tensor``.
        lefts: Left paddings. Converted with ``tf.convert_to_tensor``.
        rights: Right paddings. Converted with ``tf.convert_to_tensor``.
        position (PaddingPosition): How to redistribute the paddings.
        random_generator: Object exposing
            ``random_uniform(shape, minval, maxval, dtype)``. Only used when
            ``position`` is ``PaddingPosition.RANDOM``.

    Returns:
        tuple: ``(tops, bottoms, lefts, rights)`` tensors.

    Raises:
        NotImplementedError: If ``position`` is not a ``PaddingPosition``
            member.

    Note:
        For ``PaddingPosition.RANDOM`` the random ratios are sampled with
        shape ``(batch_size, 1)`` where ``batch_size`` is the first
        dimension of ``tops``, and all four outputs use the dtype of
        ``tops``.
    """
    tops = tf.convert_to_tensor(tops)
    bottoms = tf.convert_to_tensor(bottoms)
    lefts = tf.convert_to_tensor(lefts)
    rights = tf.convert_to_tensor(rights)

    if position == PaddingPosition.CENTER:
        # Inputs are already expressed for center padding.
        pass
    elif position == PaddingPosition.TOP_LEFT:
        # No padding above / left: fold it into the opposite sides.
        bottoms += tops
        rights += lefts
        tops = tf.zeros_like(tops)
        lefts = tf.zeros_like(lefts)
    elif position == PaddingPosition.TOP_RIGHT:
        bottoms += tops
        lefts += rights
        tops = tf.zeros_like(tops)
        rights = tf.zeros_like(rights)
    elif position == PaddingPosition.BOTTOM_LEFT:
        tops += bottoms
        rights += lefts
        bottoms = tf.zeros_like(bottoms)
        lefts = tf.zeros_like(lefts)
    elif position == PaddingPosition.BOTTOM_RIGHT:
        tops += bottoms
        lefts += rights
        bottoms = tf.zeros_like(bottoms)
        rights = tf.zeros_like(rights)
    elif position == PaddingPosition.RANDOM:
        batch_size = tf.shape(tops)[0]
        original_dtype = tops.dtype
        h_pads = tf.cast(tops + bottoms, dtype=tf.float32)
        w_pads = tf.cast(lefts + rights, dtype=tf.float32)
        # Sample the fraction of the total padding that goes on top / left.
        tops = random_generator.random_uniform(
            shape=(batch_size, 1), minval=0, maxval=1, dtype=tf.float32
        )
        tops = tf.cast(tf.round(tops * h_pads), dtype=original_dtype)
        # Cast the total back to the same dtype as the sampled side so the
        # subtraction is valid for any input dtype, not just int32.
        bottoms = tf.cast(h_pads, dtype=original_dtype) - tops
        lefts = random_generator.random_uniform(
            shape=(batch_size, 1), minval=0, maxval=1, dtype=tf.float32
        )
        lefts = tf.cast(tf.round(lefts * w_pads), dtype=original_dtype)
        rights = tf.cast(w_pads, dtype=original_dtype) - lefts
    else:
        raise NotImplementedError(
            f"Value not recognized for `position`: {position}. Supported "
            f"values are: {PaddingPosition._value2member_map_.keys()}"
        )

    return tops, bottoms, lefts, rights


def parse_factor(
    param,
    min_value=0.0,
    max_value=1.0,
    center_value=0.5,
    param_name="factor",
    seed=None,
):
    """Turn a user-supplied factor specification into a ``FactorSampler``.

    Args:
        param: A ``FactorSampler`` (returned unchanged), a single ``int`` or
            ``float`` (interpreted as a half-width around ``center_value``),
            or an indexable pair ``(lower, upper)``.
        min_value: Smallest allowed lower bound, or ``None`` to skip the
            check. Defaults to ``0.0``.
        max_value: Largest allowed upper bound, or ``None`` to skip the
            check. Defaults to ``1.0``.
        center_value: Center used to expand a scalar ``param`` into a range.
            Defaults to ``0.5``.
        param_name (str): Name used in error messages. Defaults to
            ``"factor"``.
        seed: Forwarded to ``UniformFactorSampler``. Defaults to ``None``.

    Returns:
        The given sampler, a ``ConstantFactorSampler`` when both bounds are
        equal, or a ``UniformFactorSampler`` over ``[lower, upper]``.

    Raises:
        ValueError: If the lower bound exceeds the upper bound, or either
            bound lies outside ``[min_value, max_value]``.
    """
    # Ready-made samplers pass straight through.
    if isinstance(param, FactorSampler):
        return param

    # A bare number is a symmetric range around the center value.
    if isinstance(param, (float, int)):
        param = (center_value - param, center_value + param)

    lower, upper = param[0], param[1]

    if lower > upper:
        raise ValueError(
            f"`{param_name}[0] > {param_name}[1]`, `{param_name}[0]` must be "
            f"<= `{param_name}[1]`. Got `{param_name}={param}`"
        )

    out_of_range = (min_value is not None and lower < min_value) or (
        max_value is not None and upper > max_value
    )
    if out_of_range:
        raise ValueError(
            f"`{param_name}` should be inside of range "
            f"[{min_value}, {max_value}]. Got {param_name}={param}"
        )

    # A degenerate range needs no random sampling.
    if lower == upper:
        return ConstantFactorSampler(lower)
    return UniformFactorSampler(lower, upper, seed=seed)


def is_factor_working(factor, not_working_value=0.0):
    """Check whether ``factor`` is working or not.

    A factor is considered not working when it is ``None`` or when it can
    only ever produce ``not_working_value``.

    Args:
        factor (int|float|Sequence[int|float]|keras_aug.FactorSampler): The
            factor to check whether it is working or not.
        not_working_value (float, optional): The value indicating not working
            status. Defaults to ``0.0``.

    Returns:
        bool: ``False`` if the factor is not working, ``True`` otherwise.

    Raises:
        ValueError: If the type of ``factor`` is not recognized.
    """
    if factor is None:
        return False

    # Each branch decides whether the factor is pinned to the no-op value.
    if isinstance(factor, (int, float)):
        disabled = factor == not_working_value
    elif isinstance(factor, Sequence):
        disabled = factor[0] == factor[1] and factor[0] == not_working_value
    elif isinstance(factor, ConstantFactorSampler):
        disabled = factor.value == not_working_value
    elif isinstance(factor, UniformFactorSampler):
        disabled = (
            factor.lower == not_working_value
            and factor.upper == not_working_value
        )
    elif isinstance(factor, (NormalFactorSampler, SignedNormalFactorSampler)):
        disabled = factor.stddev == 0 and factor.mean == not_working_value
    else:
        raise ValueError(
            f"Cannot recognize factor type: {factor} with type {type(factor)}"
        )
    return not disabled


def expand_dict_dims(dicts, axis):
    """Return a new dict with ``tf.expand_dims`` applied to every value.

    Args:
        dicts (dict): Mapping from keys to tensors. It is not modified.
        axis (int): The axis at which a new dimension is inserted.

    Returns:
        dict: A new dictionary with the same keys and expanded tensors.
    """
    return {
        key: tf.expand_dims(dicts[key], axis=axis) for key in dicts.keys()
    }


def get_images_shape(images, dtype=tf.int32):
    """Get ``heights`` and ``widths`` of the input images.

    Input images can be ``tf.Tensor`` or ``tf.RaggedTensor`` with the shape of
    [B, H|None, W|None, C].

    Args:
        images (tf.Tensor|tf.RaggedTensor): The input images.
        dtype (tf.dtypes.DType, optional): The dtype of the outputs. Defaults to
            ``tf.int32``.

    Returns:
        tuple: ``(heights, widths)``, each reshaped to ``(-1, 1)`` and cast
        to ``dtype``.
    """
    if isinstance(images, tf.RaggedTensor):
        # Rows per image give the heights; the longest row of each image
        # gives its width.
        rows_per_image = images.row_lengths()
        cols_per_row = images.row_lengths(axis=2)
        heights = tf.reshape(rows_per_image, (-1, 1))
        widths = tf.reshape(tf.reduce_max(cols_per_row, 1), (-1, 1))
    else:
        # Dense batch: every image shares the same height and width, so the
        # scalar is repeated once per batch entry.
        dense_shape = tf.shape(images)
        batch_size = dense_shape[0]
        heights = tf.reshape(
            tf.repeat(dense_shape[H_AXIS], repeats=[batch_size]),
            shape=(-1, 1),
        )
        widths = tf.reshape(
            tf.repeat(dense_shape[W_AXIS], repeats=[batch_size]),
            shape=(-1, 1),
        )
    return tf.cast(heights, dtype=dtype), tf.cast(widths, dtype=dtype)


def cast_to(inputs, dtype):
    """Cast every supported entry of ``inputs`` to ``dtype`` in place.

    Args:
        inputs (dict): Inputs dictionary. Entries under the images, labels,
            bounding boxes (``"boxes"`` and ``"classes"``), segmentation
            masks and keypoints keys are cast; other keys are left alone.
        dtype: The target dtype passed to ``tf.cast``.

    Returns:
        dict: The same ``inputs`` object, with its entries replaced.

    Raises:
        NotImplementedError: If ``inputs`` contains custom annotations.
    """
    for key in (IMAGES, LABELS):
        if key in inputs:
            inputs[key] = tf.cast(inputs[key], dtype)

    if BOUNDING_BOXES in inputs:
        bounding_boxes = inputs[BOUNDING_BOXES]
        for field in ("boxes", "classes"):
            bounding_boxes[field] = tf.cast(bounding_boxes[field], dtype)

    for key in (SEGMENTATION_MASKS, KEYPOINTS):
        if key in inputs:
            inputs[key] = tf.cast(inputs[key], dtype)

    if CUSTOM_ANNOTATIONS in inputs:
        raise NotImplementedError()
    return inputs


def compute_signature(inputs, dtype):
    """Build a per-sample type signature for the entries of ``inputs``.

    The leading (batch) dimension of each entry's shape is dropped.

    Args:
        inputs (dict): Batched inputs dictionary.
        dtype: The dtype used for every spec in the signature.

    Returns:
        dict: Maps each supported key present in ``inputs`` to a
        ``tf.TensorSpec`` / ``tf.RaggedTensorSpec`` (a nested dict of specs
        for bounding boxes).

    Raises:
        NotImplementedError: If ``inputs`` contains custom annotations.
    """
    # NOTE(review): the result is presumably used as `fn_output_signature`
    # for `tf.map_fn` -- confirm at the call site.

    def _per_sample_spec(batched):
        # Dense tensors get a TensorSpec, anything else a rank-1 ragged spec.
        per_sample_shape = batched.shape[1:]
        if isinstance(batched, tf.Tensor):
            return tf.TensorSpec(per_sample_shape, dtype)
        return tf.RaggedTensorSpec(
            shape=per_sample_shape,
            ragged_rank=1,
            dtype=dtype,
        )

    signature = {}
    if IMAGES in inputs:
        signature[IMAGES] = _per_sample_spec(inputs[IMAGES])
    if LABELS in inputs:
        signature[LABELS] = tf.TensorSpec(inputs[LABELS].shape[1:], dtype)
    if BOUNDING_BOXES in inputs:
        boxes_spec = tf.RaggedTensorSpec(
            shape=[None, 4],
            ragged_rank=1,
            dtype=dtype,
        )
        classes_spec = tf.RaggedTensorSpec(
            shape=[None], ragged_rank=0, dtype=dtype
        )
        signature[BOUNDING_BOXES] = {
            "boxes": boxes_spec,
            "classes": classes_spec,
        }
    if SEGMENTATION_MASKS in inputs:
        signature[SEGMENTATION_MASKS] = _per_sample_spec(
            inputs[SEGMENTATION_MASKS]
        )
    if KEYPOINTS in inputs:
        signature[KEYPOINTS] = _per_sample_spec(inputs[KEYPOINTS])
    if CUSTOM_ANNOTATIONS in inputs:
        raise NotImplementedError()
    return signature


def blend(images_1, images_2, factors, value_range=None):
    """Blend ``images_1`` and ``images_2`` using ``factors``.

    Inputs can be batched. A factor of ``0.0`` means only ``images_1`` is
    used and ``1.0`` means only ``images_2`` is used. A value between
    ``0.0`` and ``1.0`` linearly interpolates the pixel values between the
    two images, and a value greater than ``1.0`` "extrapolates" the
    difference between the two pixel values. If ``value_range`` is set, the
    results will be clipped into ``value_range``.

    Args:
        images_1 (tf.Tensor): First image(s).
        images_2 (tf.Tensor): Second image(s).
        factors (float|tf.Tensor): The blend factor(s).
        value_range (Sequence[int|float], optional): The value range of the
            results. Defaults to ``None``.

    Returns:
        The blended image(s), clipped when ``value_range`` is given.

    References:
        - `KerasCV <https://github.com/keras-team/keras-cv>`_
    """
    delta = images_2 - images_1
    blended = images_1 + factors * delta
    if value_range is None:
        return blended
    return tf.clip_by_value(blended, value_range[0], value_range[1])


def rgb_to_grayscale(images):
    """Converts images from RGB to Grayscale.

    Compared to ``tf.image.rgb_to_grayscale``, this function replaces
    ``tf.tensordot`` with ``tf.math.multiply`` and ``tf.math.add`` to reduce
    memory usage.

    Args:
        images (tf.Tensor): The RGB tensor to convert. The last dimension must
            have size 3 and should contain RGB values.

    Returns:
        The weighted sum of the three channels, with a last dimension of
        size 1.

    References:
        - `torchvision <https://github.com/pytorch/vision>`_
    """
    # Slices (rather than integer indices) keep the channel axis.
    red = images[..., 0:1]
    green = images[..., 1:2]
    blue = images[..., 2:3]
    return red * 0.2989 + green * 0.587 + blue * 0.114


def get_rotation_matrix(
    angles, image_height, image_width, to_square=False, name=None
):
    """Returns projective transforms for the given angles.

    The offsets are chosen so that the point
    ``((image_width - 1) / 2, (image_height - 1) / 2)`` is mapped to itself.

    Args:
        angles (tf.Tensor): The angles to rotate each image in the batch.
            They are concatenated along axis 1 with the other columns, so
            presumably shaped ``(batch_size, 1)`` -- confirm with callers.
        image_height (tf.Tensor): Height of the images to be transformed.
        image_width (tf.Tensor): Width of the images to be transformed.
        to_square (bool, optional): Whether to append ones to last dimension
            and reshape to ``(batch_size, 3, 3)``. Defaults to ``False``.
        name (str, optional): The name of the op. Defaults to ``None``.

    Returns:
        tf.Tensor: The flattened transforms, or ``(batch_size, 3, 3)``
        matrices when ``to_square`` is ``True``.

    References:
        - `KerasCV <https://github.com/keras-team/keras-cv>`_
    """
    with keras.backend.name_scope(name or "rotation_matrix"):
        cos_angles = tf.cos(angles)
        sin_angles = tf.sin(angles)
        max_x = image_width - 1
        max_y = image_height - 1

        x_offset = (max_x - (cos_angles * max_x - sin_angles * max_y)) / 2.0
        y_offset = (max_y - (sin_angles * max_x + cos_angles * max_y)) / 2.0

        batch = tf.shape(angles)[0]
        # Row-major layout: [a0, a1, a2, b0, b1, b2, c0, c1].
        columns = [
            cos_angles,
            -sin_angles,
            x_offset,
            sin_angles,
            cos_angles,
            y_offset,
            tf.zeros((batch, 2), angles.dtype),
        ]
        flat = tf.concat(columns, axis=1)
        if not to_square:
            return flat

        # Append the implicit trailing 1 and fold into 3x3 matrices.
        flat = tf.concat([flat, tf.ones((batch, 1), angles.dtype)], axis=1)
        return tf.reshape(flat, (batch, 3, 3))


def get_translation_matrix(
    translations, image_height, image_width, to_square=False, name=None
):
    """Returns projective transforms for the given translations.

    Args:
        translations (tf.Tensor): A matrix of 2-element lists representing
            ``[dx, dy]`` to translate for a batch of images. ``dx`` is
            multiplied by ``image_width`` and ``dy`` by ``image_height``.
        image_height (tf.Tensor): Height of the images to be transformed.
        image_width (tf.Tensor): Width of the images to be transformed.
        to_square (bool, optional): Whether to append ones to last dimension
            and reshape to ``(batch_size, 3, 3)``. Defaults to ``False``.
        name (str, optional): The name of the op. Defaults to ``None``.

    Returns:
        tf.Tensor: ``(batch_size, 8)`` flattened transforms, or
        ``(batch_size, 3, 3)`` matrices when ``to_square`` is ``True``. The
        constant entries use the dtype of ``translations``.

    References:
        - `KerasCV <https://github.com/keras-team/keras-cv>`_
    """
    with keras.backend.name_scope(name or "translation_matrix"):
        batch = tf.shape(translations)[0]
        dtype = translations.dtype
        shift_x = -translations[:, 0, tf.newaxis] * image_width
        shift_y = -translations[:, 1, tf.newaxis] * image_height

        # Each row is the flattened matrix
        #     [[1 0 -dx * image_width]
        #      [0 1 -dy * image_height]
        #      [0 0 1]]
        # with the final 1 left implicit unless `to_square` is set.
        columns = [
            tf.ones((batch, 1), dtype),
            tf.zeros((batch, 1), dtype),
            shift_x,
            tf.zeros((batch, 1), dtype),
            tf.ones((batch, 1), dtype),
            shift_y,
            tf.zeros((batch, 2), dtype),
        ]
        flat = tf.concat(values=columns, axis=1)
        if not to_square:
            return flat

        flat = tf.concat(
            [flat, tf.ones((batch, 1), dtype)],
            axis=1,
        )
        return tf.reshape(flat, (batch, 3, 3))


def get_zoom_matrix(
    zooms, image_height, image_width, to_square=False, name=None
):
    """Returns projective transforms for the given zooms.

    The offsets are chosen so that the point
    ``((image_width - 1) / 2, (image_height - 1) / 2)`` is mapped to itself.

    Args:
        zooms (tf.Tensor): A matrix of 2-element lists representing
            ``[zx, zy]`` to zoom for a batch of images.
        image_height (tf.Tensor): Height of the images to be transformed.
        image_width (tf.Tensor): Width of the images to be transformed.
        to_square (bool, optional): Whether to append ones to last dimension
            and reshape to ``(batch_size, 3, 3)``. Defaults to ``False``.
        name (str, optional): The name of the op. Defaults to ``None``.

    Returns:
        tf.Tensor: ``(batch_size, 8)`` flattened transforms, or
        ``(batch_size, 3, 3)`` matrices when ``to_square`` is ``True``. The
        constant entries use the dtype of ``zooms``.

    References:
        - `KerasCV <https://github.com/keras-team/keras-cv>`_
    """
    with keras.backend.name_scope(name or "zoom_matrix"):
        batch = tf.shape(zooms)[0]
        dtype = zooms.dtype
        zoom_x = zooms[:, 0, tf.newaxis]
        zoom_y = zooms[:, 1, tf.newaxis]

        x_offset = ((image_width - 1.0) / 2.0) * (1.0 - zoom_x)
        y_offset = ((image_height - 1.0) / 2.0) * (1.0 - zoom_y)

        # Each row is the flattened matrix
        #     [[zx 0 x_offset]
        #      [0 zy y_offset]
        #      [0 0 1]]
        # with the final 1 left implicit unless `to_square` is set.
        columns = [
            zoom_x,
            tf.zeros((batch, 1), dtype),
            x_offset,
            tf.zeros((batch, 1), dtype),
            zoom_y,
            y_offset,
            tf.zeros((batch, 2), dtype),
        ]
        flat = tf.concat(values=columns, axis=1)
        if not to_square:
            return flat

        flat = tf.concat([flat, tf.ones((batch, 1), dtype)], axis=1)
        return tf.reshape(flat, (batch, 3, 3))


def get_shear_matrix(shears, to_square=False, name=None):
    """Returns projective transforms for the given shears.

    Args:
        shears (tf.Tensor): A matrix of 2-element lists representing `[sx, sy]`
            to shear for a batch of images.
        to_square (bool, optional): Whether to append ones to last dimension
            and reshape to ``(batch_size, 3, 3)``. Defaults to ``False``.
        name (str, optional): The name of the op. Defaults to ``None``.

    Returns:
        tf.Tensor: ``(batch_size, 8)`` flattened transforms, or
        ``(batch_size, 3, 3)`` matrices when ``to_square`` is ``True``. The
        constant entries use the dtype of ``shears``.

    References:
        - `KerasCV <https://github.com/keras-team/keras-cv>`_
    """
    with keras.backend.name_scope(name or "shear_matrix"):
        batch = tf.shape(shears)[0]
        dtype = shears.dtype
        shear_x = shears[:, 0, tf.newaxis]
        shear_y = shears[:, 1, tf.newaxis]

        # Each row is the flattened matrix
        #     [[1  sx 0]
        #      [sy 1  0]
        #      [0  0  1]]
        # with the final 1 left implicit unless `to_square` is set. The
        # trailing three zeros cover the last entry of the second row and
        # the first two entries of the third row.
        columns = [
            tf.ones((batch, 1), dtype),
            shear_x,
            tf.zeros((batch, 1), dtype),
            shear_y,
            tf.ones((batch, 1), dtype),
            tf.zeros((batch, 3), dtype),
        ]
        flat = tf.concat(values=columns, axis=1)
        if not to_square:
            return flat

        flat = tf.concat([flat, tf.ones((batch, 1), dtype)], axis=1)
        return tf.reshape(flat, (batch, 3, 3))