import enum
from typing import Sequence
import tensorflow as tf
from keras_aug.core import ConstantFactorSampler
from keras_aug.core import FactorSampler
from keras_aug.core import NormalFactorSampler
from keras_aug.core import SignedNormalFactorSampler
from keras_aug.core import UniformFactorSampler
# Negative axis indices used by `get_images_shape` to read the height and
# width out of `tf.shape(images)`; its docstring describes images as
# [B, H, W, C].
H_AXIS = -3
W_AXIS = -2
# String keys. `cast_to` and `compute_signature` in this module look up
# IMAGES, LABELS, BOUNDING_BOXES, KEYPOINTS, SEGMENTATION_MASKS and
# CUSTOM_ANNOTATIONS in their `inputs` dictionaries.
# NOTE(review): TARGETS and BATCHED are not referenced in the visible part of
# this module; presumably they are consumed by callers -- confirm.
IMAGES = "images"
LABELS = "labels"
TARGETS = "targets"
BOUNDING_BOXES = "bounding_boxes"
KEYPOINTS = "keypoints"
SEGMENTATION_MASKS = "segmentation_masks"
CUSTOM_ANNOTATIONS = "custom_annotations"
BATCHED = "batched"
class PaddingPosition(enum.Enum):
    """Placement modes understood by the padding helpers of this module.

    ``get_padding_position`` resolves a string to one of these members and
    ``get_position_params`` uses the member to decide how the per-side padding
    amounts (top/bottom/left/right) are redistributed.
    """

    # Padding amounts are kept as given.
    CENTER = "center"
    # Top and left paddings become zero; bottom and right absorb them.
    TOP_LEFT = "top_left"
    # Top and right paddings become zero; bottom and left absorb them.
    TOP_RIGHT = "top_right"
    # Bottom and left paddings become zero; top and right absorb them.
    BOTTOM_LEFT = "bottom_left"
    # Bottom and right paddings become zero; top and left absorb them.
    BOTTOM_RIGHT = "bottom_right"
    # Vertical and horizontal padding totals are split at random.
    RANDOM = "random"
def get_padding_position(position):
    """Resolve ``position`` to a ``PaddingPosition`` member.

    Args:
        position (str|PaddingPosition): Either an existing member, which is
            returned unchanged, or a string that is matched against the
            member values after lower-casing.

    Returns:
        PaddingPosition: The matching enum member.

    Raises:
        NotImplementedError: If the lower-cased string is not the value of
            any ``PaddingPosition`` member.
    """
    if isinstance(position, PaddingPosition):
        return position

    position = position.lower()
    members_by_value = PaddingPosition._value2member_map_
    if position in members_by_value:
        return members_by_value[position]

    raise NotImplementedError(
        f"Value not recognized for `position`: {position}. Supported "
        f"values are: {PaddingPosition._value2member_map_.keys()}"
    )
def get_position_params(
    tops, bottoms, lefts, rights, position, random_generator
):
    """Redistribute per-side padding amounts according to ``position``.

    This function supposes arguments are at `center` padding method.

    Args:
        tops: Top padding amounts; converted with ``tf.convert_to_tensor``.
        bottoms: Bottom padding amounts; converted the same way.
        lefts: Left padding amounts; converted the same way.
        rights: Right padding amounts; converted the same way.
        position (PaddingPosition): The requested placement. It is compared
            against the enum members directly.
        random_generator: Object exposing ``random_uniform(shape, minval,
            maxval, dtype)``. Only used when ``position`` is
            ``PaddingPosition.RANDOM``.

    Returns:
        tuple: ``(tops, bottoms, lefts, rights)`` tensors. For every mode the
        sums ``tops + bottoms`` and ``lefts + rights`` are the same as those
        of the inputs.

    Raises:
        NotImplementedError: If ``position`` is not a ``PaddingPosition``
            member handled below.
    """
    tops = tf.convert_to_tensor(tops)
    bottoms = tf.convert_to_tensor(bottoms)
    lefts = tf.convert_to_tensor(lefts)
    rights = tf.convert_to_tensor(rights)
    if position == PaddingPosition.CENTER:
        # Inputs already describe the requested layout.
        pass
    elif position == PaddingPosition.TOP_LEFT:
        bottoms += tops
        rights += lefts
        tops = tf.zeros_like(tops)
        lefts = tf.zeros_like(lefts)
    elif position == PaddingPosition.TOP_RIGHT:
        bottoms += tops
        lefts += rights
        tops = tf.zeros_like(tops)
        rights = tf.zeros_like(rights)
    elif position == PaddingPosition.BOTTOM_LEFT:
        tops += bottoms
        rights += lefts
        bottoms = tf.zeros_like(bottoms)
        lefts = tf.zeros_like(lefts)
    elif position == PaddingPosition.BOTTOM_RIGHT:
        tops += bottoms
        lefts += rights
        bottoms = tf.zeros_like(bottoms)
        rights = tf.zeros_like(rights)
    elif position == PaddingPosition.RANDOM:
        # The first dimension of `tops` is used as the batch size and one
        # random fraction is drawn per sample with shape (batch_size, 1).
        # NOTE(review): presumably the inputs are shaped [B, 1] -- confirm.
        batch_size = tf.shape(tops)[0]
        original_dtype = tops.dtype
        h_pads = tf.cast(tops + bottoms, dtype=tf.float32)
        w_pads = tf.cast(lefts + rights, dtype=tf.float32)
        tops = random_generator.random_uniform(
            shape=(batch_size, 1), minval=0, maxval=1, dtype=tf.float32
        )
        tops = tf.cast(tf.round(tops * h_pads), dtype=original_dtype)
        # Cast the total back to the input dtype (not a hard-coded int32) so
        # that the subtraction does not mix dtypes when the inputs are, for
        # example, int64.
        bottoms = tf.cast(h_pads, dtype=original_dtype) - tops
        lefts = random_generator.random_uniform(
            shape=(batch_size, 1), minval=0, maxval=1, dtype=tf.float32
        )
        lefts = tf.cast(tf.round(lefts * w_pads), dtype=original_dtype)
        rights = tf.cast(w_pads, dtype=original_dtype) - lefts
    else:
        raise NotImplementedError(
            f"Value not recognized for `position`: {position}. Supported "
            f"values are: {PaddingPosition._value2member_map_.keys()}"
        )
    return tops, bottoms, lefts, rights
def parse_factor(
    param,
    min_value=0.0,
    max_value=1.0,
    center_value=0.5,
    param_name="factor",
    seed=None,
):
    """Turn ``param`` into a factor sampler.

    Args:
        param: A ``FactorSampler`` (returned unchanged), a single ``int`` or
            ``float`` (expanded to
            ``(center_value - param, center_value + param)``), or an
            indexable pair ``(lower, upper)``.
        min_value: Lower limit for ``param[0]``; ``None`` disables the check.
        max_value: Upper limit for ``param[1]``; ``None`` disables the check.
        center_value: Center used when expanding a single number.
        param_name: Name used in the error messages.
        seed: Forwarded to ``UniformFactorSampler``.

    Returns:
        The given sampler, a ``ConstantFactorSampler`` when both bounds are
        equal, otherwise a ``UniformFactorSampler`` over the two bounds.

    Raises:
        ValueError: If the lower bound exceeds the upper bound or a bound
            lies outside ``[min_value, max_value]``.
    """
    if isinstance(param, FactorSampler):
        return param

    if isinstance(param, (float, int)):
        param = (center_value - param, center_value + param)

    lower, upper = param[0], param[1]
    if lower > upper:
        raise ValueError(
            f"`{param_name}[0] > {param_name}[1]`, `{param_name}[0]` must be "
            f"<= `{param_name}[1]`. Got `{param_name}={param}`"
        )

    below_minimum = min_value is not None and lower < min_value
    if below_minimum or (max_value is not None and upper > max_value):
        raise ValueError(
            f"`{param_name}` should be inside of range "
            f"[{min_value}, {max_value}]. Got {param_name}={param}"
        )

    if lower == upper:
        return ConstantFactorSampler(lower)
    return UniformFactorSampler(lower, upper, seed=seed)
def is_factor_working(factor, not_working_value=0.0):
    """Check whether ``factor`` is working or not.

    Args:
        factor (int|float|Sequence[int|float]|keras_aug.FactorSampler): The
            factor to check whether it is working or not.
        not_working_value (float, optional): The value indicating not working
            status. Defaults to ``0.0``.

    Returns:
        bool: ``False`` when ``factor`` is ``None`` or can only ever produce
        ``not_working_value``; ``True`` otherwise.

    Raises:
        ValueError: If the type of ``factor`` is not one of the supported
            kinds.
    """
    if factor is None:
        return False

    # Each branch decides whether the factor is pinned to the idle value.
    if isinstance(factor, (int, float)):
        pinned = factor == not_working_value
    elif isinstance(factor, Sequence):
        pinned = factor[0] == factor[1] and factor[0] == not_working_value
    elif isinstance(factor, ConstantFactorSampler):
        pinned = factor.value == not_working_value
    elif isinstance(factor, UniformFactorSampler):
        pinned = (
            factor.lower == not_working_value
            and factor.upper == not_working_value
        )
    elif isinstance(factor, (NormalFactorSampler, SignedNormalFactorSampler)):
        pinned = factor.stddev == 0 and factor.mean == not_working_value
    else:
        raise ValueError(
            f"Cannot recognize factor type: {factor} with type {type(factor)}"
        )
    return not pinned
def expand_dict_dims(dicts, axis):
    """Apply ``tf.expand_dims`` to every value of a dictionary.

    Args:
        dicts (dict): Mapping whose values are passed to ``tf.expand_dims``.
        axis (int): The ``axis`` argument forwarded to ``tf.expand_dims``.

    Returns:
        dict: A new dictionary with the same keys and the expanded values.
    """
    return {
        name: tf.expand_dims(value, axis=axis)
        for name, value in dicts.items()
    }
def get_images_shape(images, dtype=tf.int32):
    """Get ``heights`` and ``widths`` of the input images.

    Input images can be ``tf.Tensor`` or ``tf.RaggedTensor`` with the shape of
    [B, H|None, W|None, C].

    Args:
        images (tf.Tensor|tf.RaggedTensor): The input images.
        dtype (tf.dtypes.DType, optional): The dtype of the outputs. Defaults to
            ``tf.int32``.

    Returns:
        tuple: ``(heights, widths)``, each reshaped to ``(-1, 1)`` and cast to
        ``dtype``.
    """
    if isinstance(images, tf.RaggedTensor):
        # Height is the row count per sample; width is the largest row length
        # at axis 2 within each sample.
        heights = tf.reshape(images.row_lengths(), (-1, 1))
        widest_rows = tf.reduce_max(images.row_lengths(axis=2), 1)
        widths = tf.reshape(widest_rows, (-1, 1))
    else:
        # Dense batch: every sample has the same size, so the shared height
        # and width are repeated once per sample.
        dims = tf.shape(images)
        batch_size = dims[0]
        heights = tf.reshape(
            tf.repeat(dims[H_AXIS], repeats=[batch_size]), shape=(-1, 1)
        )
        widths = tf.reshape(
            tf.repeat(dims[W_AXIS], repeats=[batch_size]), shape=(-1, 1)
        )
    return tf.cast(heights, dtype=dtype), tf.cast(widths, dtype=dtype)
def cast_to(inputs, dtype):
    """Cast the known entries of ``inputs`` to ``dtype`` in place.

    Args:
        inputs (dict): Dictionary that may contain the image, label,
            bounding box, segmentation mask and keypoint keys. Bounding boxes
            are expected to be a dictionary with ``"boxes"`` and
            ``"classes"`` entries.
        dtype: Target dtype passed to ``tf.cast``.

    Returns:
        dict: The same ``inputs`` object, with the cast values written back.

    Raises:
        NotImplementedError: If ``inputs`` contains custom annotations. This
            check runs after the other entries have been cast.
    """

    def _cast_entry(key):
        # Entries that are absent are simply skipped.
        if key in inputs:
            inputs[key] = tf.cast(inputs[key], dtype)

    _cast_entry(IMAGES)
    _cast_entry(LABELS)
    if BOUNDING_BOXES in inputs:
        bounding_boxes = inputs[BOUNDING_BOXES]
        for field in ("boxes", "classes"):
            bounding_boxes[field] = tf.cast(bounding_boxes[field], dtype)
    _cast_entry(SEGMENTATION_MASKS)
    _cast_entry(KEYPOINTS)
    if CUSTOM_ANNOTATIONS in inputs:
        raise NotImplementedError()
    return inputs
def compute_signature(inputs, dtype):
    """Build a dictionary of type specs describing one element of ``inputs``.

    The leading dimension of each entry's shape is dropped (``shape[1:]``).

    Args:
        inputs (dict): Dictionary that may contain the image, label,
            bounding box, segmentation mask and keypoint keys.
        dtype: The dtype used for every generated spec.

    Returns:
        dict: Specs for the keys present in ``inputs``. Images, segmentation
        masks and keypoints get a ``tf.TensorSpec`` when the value is a
        ``tf.Tensor`` and a ``tf.RaggedTensorSpec`` with ``ragged_rank=1``
        otherwise. Labels always get a ``tf.TensorSpec``. Bounding boxes get
        a dictionary of ragged specs for ``"boxes"`` and ``"classes"``.

    Raises:
        NotImplementedError: If ``inputs`` contains custom annotations.
    """

    def _dense_or_ragged_spec(value):
        # Shared rule for the entries that may be dense or ragged.
        element_shape = value.shape[1:]
        if isinstance(value, tf.Tensor):
            return tf.TensorSpec(element_shape, dtype)
        return tf.RaggedTensorSpec(
            shape=element_shape,
            ragged_rank=1,
            dtype=dtype,
        )

    signature = {}
    if IMAGES in inputs:
        signature[IMAGES] = _dense_or_ragged_spec(inputs[IMAGES])
    if LABELS in inputs:
        signature[LABELS] = tf.TensorSpec(inputs[LABELS].shape[1:], dtype)
    if BOUNDING_BOXES in inputs:
        boxes_spec = tf.RaggedTensorSpec(
            shape=[None, 4],
            ragged_rank=1,
            dtype=dtype,
        )
        classes_spec = tf.RaggedTensorSpec(
            shape=[None], ragged_rank=0, dtype=dtype
        )
        signature[BOUNDING_BOXES] = {
            "boxes": boxes_spec,
            "classes": classes_spec,
        }
    if SEGMENTATION_MASKS in inputs:
        signature[SEGMENTATION_MASKS] = _dense_or_ragged_spec(
            inputs[SEGMENTATION_MASKS]
        )
    if KEYPOINTS in inputs:
        signature[KEYPOINTS] = _dense_or_ragged_spec(inputs[KEYPOINTS])
    if CUSTOM_ANNOTATIONS in inputs:
        raise NotImplementedError()
    return signature
def random_inversion(rng):
    """Randomly pick a sign multiplier.

    Args:
        rng: Object exposing ``random_uniform``; a single scalar is drawn
            with arguments ``((), 0, 1, dtype=tf.float32)``.

    Returns:
        The result of ``tf.cond``: ``-1.0`` when the drawn value is greater
        than ``0.5``, otherwise ``1.0``.
    """
    draw = rng.random_uniform((), 0, 1, dtype=tf.float32)
    should_negate = draw > 0.5
    return tf.cond(should_negate, lambda: -1.0, lambda: 1.0)
def ensure_tensor(inputs, dtype=None):
    """Ensures the input is a Tensor, SparseTensor or RaggedTensor.

    Args:
        inputs: A ``tf.Tensor``, ``tf.RaggedTensor``, ``tf.SparseTensor`` or
            any value accepted by ``tf.convert_to_tensor``.
        dtype: Optional dtype. It is used for the conversion and, when the
            resulting dtype differs from it, for a final ``tf.cast``.

    Returns:
        The converted (and possibly cast) tensor.
    """
    tensor_types = (tf.Tensor, tf.RaggedTensor, tf.SparseTensor)
    if not isinstance(inputs, tensor_types):
        inputs = tf.convert_to_tensor(inputs, dtype)
    needs_cast = dtype is not None and inputs.dtype != dtype
    return tf.cast(inputs, dtype) if needs_cast else inputs
# Lookup table from lower-case interpolation names to the corresponding
# `tf.image.ResizeMethod` members; consulted by `get_interpolation`.
_TF_INTERPOLATION_METHODS = {
    "bilinear": tf.image.ResizeMethod.BILINEAR,
    "nearest": tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    "bicubic": tf.image.ResizeMethod.BICUBIC,
    "area": tf.image.ResizeMethod.AREA,
    "lanczos3": tf.image.ResizeMethod.LANCZOS3,
    "lanczos5": tf.image.ResizeMethod.LANCZOS5,
    "gaussian": tf.image.ResizeMethod.GAUSSIAN,
    "mitchellcubic": tf.image.ResizeMethod.MITCHELLCUBIC,
}
def get_interpolation(interpolation):
    """Translate an interpolation name into a ``tf.image.ResizeMethod``.

    Args:
        interpolation (str): Name of the method; matched case-insensitively
            against the keys of ``_TF_INTERPOLATION_METHODS``.

    Returns:
        The ``tf.image.ResizeMethod`` member registered for the name.

    Raises:
        NotImplementedError: If the lower-cased name is not a known key.
    """
    interpolation = interpolation.lower()
    if interpolation in _TF_INTERPOLATION_METHODS:
        return _TF_INTERPOLATION_METHODS[interpolation]
    raise NotImplementedError(
        "Value not recognized for `interpolation`: {}. Supported values "
        "are: {}".format(interpolation, _TF_INTERPOLATION_METHODS.keys())
    )
def check_fill_mode_and_interpolation(fill_mode, interpolation):
    """Validate the ``fill_mode`` and ``interpolation`` arguments.

    Args:
        fill_mode (str): Must be one of ``"reflect"``, ``"wrap"``,
            ``"constant"`` or ``"nearest"``.
        interpolation (str): Must be ``"nearest"`` or ``"bilinear"``.

    Raises:
        NotImplementedError: If either argument is not one of the supported
            values. ``fill_mode`` is checked first.
    """
    supported_fill_modes = {"reflect", "wrap", "constant", "nearest"}
    supported_interpolations = {"nearest", "bilinear"}

    if not (fill_mode in supported_fill_modes):
        message = (
            " Want fillmode to be one of `reflect`, `wrap`, "
            "`constant` or `nearest`. Got `fill_mode` {}. "
        )
        raise NotImplementedError(message.format(fill_mode))

    if not (interpolation in supported_interpolations):
        message = (
            "Unknown `interpolation` {}. Only `nearest` and "
            "`bilinear` are supported."
        )
        raise NotImplementedError(message.format(interpolation))