Source code for keras_aug.layers.preprocessing.intensity.grayscale

import tensorflow as tf
from tensorflow import keras

from keras_aug.datapoints import image as image_utils
from keras_aug.layers.base.vectorized_base_random_layer import (
    VectorizedBaseRandomLayer,
)


@keras.utils.register_keras_serializable(package="keras_aug")
class Grayscale(VectorizedBaseRandomLayer):
    """Grayscale transforms RGB images to grayscale images.

    Images are converted with ``image_utils.rgb_to_grayscale``. When
    ``output_channels`` is ``3`` the grayscale result is passed through
    ``tf.image.grayscale_to_rgb`` so the output has three channels again.
    Labels, bounding boxes, segmentation masks and keypoints are returned
    unchanged.

    Args:
        output_channels (int, optional): The number of the color channels of
            the outputs. Must be ``1`` or ``3``. Defaults to ``3``.

    Raises:
        ValueError: If ``output_channels`` is neither ``1`` nor ``3``.

    References:
        - `KerasCV <https://github.com/keras-team/keras-cv>`_
    """

    def __init__(self, output_channels=3, **kwargs):
        super().__init__(**kwargs)
        # Only single-channel or three-channel outputs are supported; reject
        # anything else at construction time.
        if output_channels not in (1, 3):
            raise ValueError(
                "Received invalid argument output_channels. "
                f"output_channels must be in 1 or 3. Got {output_channels}"
            )
        self.output_channels = output_channels

    def compute_ragged_image_signature(self, images):
        """Return the ``tf.RaggedTensorSpec`` describing one output image.

        The spec shape is ``images.shape[1:3]`` followed by
        ``self.output_channels``, because the channel count of the output
        may differ from that of the input.
        """
        # NOTE(review): dims 1 and 2 are presumably height and width of a
        # batched (N, H, W, C) input -- confirm against the base layer.
        ragged_spec = tf.RaggedTensorSpec(
            shape=images.shape[1:3] + (self.output_channels,),
            ragged_rank=1,
            dtype=self.compute_dtype,
        )
        return ragged_spec

    def augment_ragged_image(self, image, transformation, **kwargs):
        """Convert a single image by delegating to ``augment_images``."""
        return self.augment_images(
            image, transformations=transformation, **kwargs
        )

    def augment_images(self, images, transformations=None, **kwargs):
        """Convert ``images`` to grayscale with ``output_channels`` channels.

        ``transformations`` and ``kwargs`` are accepted for interface
        compatibility and are not used.

        Raises:
            ValueError: If ``self.output_channels`` is neither ``1`` nor
                ``3`` (only possible if the attribute was changed after
                construction).
        """
        grayscales = image_utils.rgb_to_grayscale(images)
        if self.output_channels == 1:
            return grayscales
        elif self.output_channels == 3:
            return tf.image.grayscale_to_rgb(grayscales)
        else:
            # Defensive: __init__ validates the value, but the attribute is
            # public and could have been reassigned afterwards.
            raise ValueError("Unsupported value for `output_channels`.")

    def augment_labels(self, labels, transformations, **kwargs):
        """Return ``labels`` unchanged."""
        return labels

    def augment_bounding_boxes(self, bounding_boxes, transformations, **kwargs):
        """Return ``bounding_boxes`` unchanged."""
        return bounding_boxes

    def augment_segmentation_masks(
        self, segmentation_masks, transformations, **kwargs
    ):
        """Return ``segmentation_masks`` unchanged."""
        return segmentation_masks

    def augment_keypoints(self, keypoints, transformations, **kwargs):
        """Return ``keypoints`` unchanged."""
        return keypoints

    def get_config(self):
        """Return the parent config extended with ``output_channels``."""
        config = super().get_config()
        config.update({"output_channels": self.output_channels})
        return config