Image Augmentation with Keras Preprocessing Layers and tf.image

Author: Adrian Tam

When we work on a machine learning problem related to images, we not only need to collect some images as training data but also need to employ augmentation to create variations of them. This is especially true for more complex object recognition problems.

There are many ways to augment images. You may use an external library or write your own functions for it. There are also some modules in TensorFlow and Keras for augmentation. In this post, you will discover how to use the Keras preprocessing layers as well as the tf.image module in TensorFlow for image augmentation.

After reading this post, you will know:

  • What the Keras preprocessing layers are and how to use them
  • What functions the tf.image module provides for image augmentation
  • How to use augmentation together with a tf.data dataset

Let’s get started.

Image Augmentation with Keras Preprocessing Layers and tf.image.
Photo by Steven Kamenar. Some rights reserved.

Overview

This article is split into five sections; they are:

  • Getting Images
  • Visualizing the Images
  • Keras Preprocessing Layers
  • Using tf.image API for Augmentation
  • Using Preprocessing Layers in Neural Networks

Getting Images

Before we see how we can do augmentation, we need to get the images. Ultimately, we need the images to be represented as arrays, for example, as HxWx3 arrays of 8-bit integers holding the RGB pixel values. There are many ways to get the images. Some can be downloaded as a ZIP file. If you're using TensorFlow, you may get some image datasets from the tensorflow_datasets library.

In this tutorial, we are going to use the citrus leaves images, a small dataset of less than 100MB. It can be downloaded from tensorflow_datasets as follows:

import tensorflow_datasets as tfds
ds, meta = tfds.load('citrus_leaves', with_info=True, split='train', shuffle_files=True)

Running this code for the first time will download the image dataset onto your computer, with the following output:

Downloading and preparing dataset 63.87 MiB (download: 63.87 MiB, generated: 37.89 MiB, total: 101.76 MiB) to ~/tensorflow_datasets/citrus_leaves/0.1.2...
Extraction completed...: 100%|██████████████████████████████| 1/1 [00:06<00:00,  6.54s/ file]
Dl Size...: 100%|██████████████████████████████████████████| 63/63 [00:06<00:00,  9.63 MiB/s]
Dl Completed...: 100%|███████████████████████████████████████| 1/1 [00:06<00:00,  6.54s/ url]
Dataset citrus_leaves downloaded and prepared to ~/tensorflow_datasets/citrus_leaves/0.1.2. Subsequent calls will reuse this data.

The function above returns the images as a tf.data dataset object, together with the metadata. This is a classification dataset. We can print the class labels with the following:

...
for i in range(meta.features['label'].num_classes):
    print(meta.features['label'].int2str(i))

and this prints:

Black spot
canker
greening
healthy

If you run this code again at a later time, it will reuse the downloaded images. Another way to load the downloaded images into a tf.data dataset is to use the image_dataset_from_directory() function.

As we can see from the screen output above, the dataset is downloaded into the directory ~/tensorflow_datasets. If you look into that directory, you will see the following structure:

.../Citrus/Leaves
├── Black spot
├── Melanose
├── canker
├── greening
└── healthy

The directory names are the labels, and the images are files stored under their corresponding directory. We can let the function read the directory recursively into a dataset:

import tensorflow as tf
from tensorflow.keras.utils import image_dataset_from_directory

# set to fixed image size 256x256
PATH = ".../Citrus/Leaves"
ds = image_dataset_from_directory(PATH,
                                  validation_split=0.2, subset="training",
                                  image_size=(256,256), interpolation="bilinear",
                                  crop_to_aspect_ratio=True,
                                  seed=42, shuffle=True, batch_size=32)

You may want to set batch_size=None if you do not want the dataset to be batched. Usually we would like the dataset to be batched for training a neural network model.
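To see how directory names become labels and what batch_size changes, here is a small self-contained sketch. It does not use the citrus data: it writes a hypothetical two-class directory tree of random PNG images into a temporary folder (the class names, file names, and sizes are made up for illustration) and loads it twice, once unbatched and once in batches of 4.

```python
import os
import tempfile

import tensorflow as tf

# Build a tiny, hypothetical directory tree: one subdirectory per class
root = tempfile.mkdtemp()
for class_name in ["healthy", "canker"]:
    os.makedirs(os.path.join(root, class_name))
    for k in range(3):
        pixels = tf.random.uniform((40, 60, 3), maxval=256, dtype=tf.int32)
        png = tf.io.encode_png(tf.cast(pixels, tf.uint8))
        tf.io.write_file(os.path.join(root, class_name, f"img{k}.png"), png)

# batch_size=None: each element is a single (image, label) pair
unbatched = tf.keras.utils.image_dataset_from_directory(
    root, image_size=(32, 32), batch_size=None, shuffle=False)
# batch_size=4: each element is a batch of up to 4 images and 4 labels
batched = tf.keras.utils.image_dataset_from_directory(
    root, image_size=(32, 32), batch_size=4, shuffle=False)

print(unbatched.class_names)  # ['canker', 'healthy']
image, label = next(iter(unbatched))
print(image.shape)            # (32, 32, 3)
images, labels = next(iter(batched))
print(images.shape)           # (4, 32, 32, 3)
```

The class names are sorted, which is why "canker" receives label 0 here. Note also that the loaded images are float32 with values from 0 to 255, even though the files store 8-bit integers.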

Visualizing the Images

It is important to visualize the augmentation result so we can verify that it is what we want it to be. We can use matplotlib for this.

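In matplotlib, we have the imshow() function to display an image. However, for an RGB image to be displayed correctly, its pixel values should be either floating point numbers between 0 and 1 or integers between 0 and 255. Since image_dataset_from_directory() produces float32 pixel values in the range 0 to 255, we convert the images to arrays of 8-bit unsigned integers (uint8) before displaying them.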
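A plain .astype("uint8") assumes every value is already within 0 to 255. Some transformations, such as tf.image.sobel_edges() used later in this post, can produce values outside that range. A minimal sketch of a safer conversion, using a hypothetical helper named to_displayable(), clips first and then casts:

```python
import tensorflow as tf

def to_displayable(image):
    """Clip pixel values to [0, 255] and cast to uint8 so imshow() shows them faithfully."""
    image = tf.convert_to_tensor(image, dtype=tf.float32)
    image = tf.clip_by_value(image, 0.0, 255.0)
    return tf.cast(image, tf.uint8).numpy()  # tf.cast drops the fractional part

# A hypothetical 1x1 RGB "image" with one value below and one above the valid range
fake = tf.constant([[[-20.0, 128.4, 300.0]]])
print(to_displayable(fake))  # [[[  0 128 255]]]
```

In the plotting loops, a call such as ax[i][j].imshow(to_displayable(images[i*3+j])) could stand in for the .numpy().astype("uint8") conversion.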

Given we have a dataset created using image_dataset_from_directory(), we can get the first batch (of 32 images) and display a few of them using imshow(), as follows:

...
import matplotlib.pyplot as plt

fig, ax = plt.subplots(3, 3, sharex=True, sharey=True, figsize=(5,5))

for images, labels in ds.take(1):
    for i in range(3):
        for j in range(3):
            ax[i][j].imshow(images[i*3+j].numpy().astype("uint8"))
            ax[i][j].set_title(ds.class_names[labels[i*3+j]])
plt.show()

Here we display 9 images in a grid and label each with its corresponding class name using ds.class_names. The images are converted to NumPy arrays of uint8 for display. This code displays an image like the following:

The complete code from loading the image to display is as follows.

from tensorflow.keras.utils import image_dataset_from_directory
import matplotlib.pyplot as plt

# use image_dataset_from_directory() to load images, with image size scaled to 256x256
PATH='.../Citrus/Leaves'  # modify to your path
ds = image_dataset_from_directory(PATH,
                                  validation_split=0.2, subset="training",
                                  image_size=(256,256), interpolation="mitchellcubic",
                                  crop_to_aspect_ratio=True,
                                  seed=42, shuffle=True, batch_size=32)

# Take one batch from dataset and display the images
fig, ax = plt.subplots(3, 3, sharex=True, sharey=True, figsize=(5,5))

for images, labels in ds.take(1):
    for i in range(3):
        for j in range(3):
            ax[i][j].imshow(images[i*3+j].numpy().astype("uint8"))
            ax[i][j].set_title(ds.class_names[labels[i*3+j]])
plt.show()

Note that if you're using tensorflow_datasets to get the images, the samples are presented as a dictionary instead of an (image, label) tuple. You should change your code slightly, as follows:

import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

# use tfds.load() or image_dataset_from_directory() to load images
ds, meta = tfds.load('citrus_leaves', with_info=True, split='train', shuffle_files=True)
ds = ds.batch(32)

# Take one batch from dataset and display the images
fig, ax = plt.subplots(3, 3, sharex=True, sharey=True, figsize=(5,5))

for sample in ds.take(1):
    images, labels = sample["image"], sample["label"]
    for i in range(3):
        for j in range(3):
            ax[i][j].imshow(images[i*3+j].numpy().astype("uint8"))
            ax[i][j].set_title(meta.features['label'].int2str(labels[i*3+j]))
plt.show()

In the rest of this post, we assume the dataset is created using image_dataset_from_directory(). You may need to tweak the code slightly if your dataset is created differently.
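One way to make a tensorflow_datasets pipeline look like the (image, label) tuples assumed from here on is to map the dictionaries into tuples. The sketch below uses a small synthetic dictionary dataset as a stand-in for the one returned by tfds.load(); the casting to float32 is an assumption made to mirror what image_dataset_from_directory() returns. For datasets that define supervised keys, passing as_supervised=True to tfds.load() is an alternative that yields tuples directly.

```python
import tensorflow as tf

# Hypothetical stand-in for a tfds dataset: uint8 images in dicts with "image" and "label" keys
images = tf.cast(tf.random.uniform((6, 64, 64, 3), maxval=256, dtype=tf.int32), tf.uint8)
labels = tf.constant([0, 1, 2, 3, 0, 1], dtype=tf.int64)
dict_ds = tf.data.Dataset.from_tensor_slices({"image": images, "label": labels})

def to_tuple(sample):
    # Unpack the dict; cast pixels to float32 like image_dataset_from_directory() does
    return tf.cast(sample["image"], tf.float32), sample["label"]

ds = dict_ds.map(to_tuple).batch(3)
for batch_images, batch_labels in ds.take(1):
    print(batch_images.shape, batch_labels.numpy())  # (3, 64, 64, 3) [0 1 2]
```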

Keras Preprocessing Layers

Keras comes with many neural network layers, such as convolution layers, that we need to train. There are also layers with no parameters to train, such as the Flatten layer, which converts an array such as an image into a vector.
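The difference between trainable and parameter-free layers can be checked directly. This sketch builds each layer by calling it on a dummy batch of one 256×256 RGB image; the layer configurations are chosen only for illustration.

```python
import tensorflow as tf

dummy = tf.zeros((1, 256, 256, 3))
conv = tf.keras.layers.Conv2D(32, 3)
flatten = tf.keras.layers.Flatten()
resize = tf.keras.layers.Resizing(128, 256)

# Calling a layer on an input builds it and creates any weights it needs
conv(dummy)
flatten(dummy)
resize(dummy)

print(conv.count_params())   # 896 = 3*3*3*32 kernel weights + 32 biases
print(len(flatten.weights))  # 0
print(len(resize.weights))   # 0
```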

The preprocessing layers in Keras are specifically designed to be used in the early stages of a neural network. We can use them for image preprocessing, such as to resize or rotate the image or to adjust the brightness and contrast. While the preprocessing layers are meant to be part of a larger neural network, we can also use them as functions. Below is how we can use the resizing layer as a function to transform some images and display them side by side with the originals:

...

# create a resizing layer
out_height, out_width = 128,256
resize = tf.keras.layers.Resizing(out_height, out_width)

# show original vs resized
fig, ax = plt.subplots(2, 3, figsize=(6,4))

for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # resize
        ax[1][i].imshow(resize(images[i]).numpy().astype("uint8"))
        ax[1][i].set_title("resize")
plt.show()

Our images are 256×256 pixels, and the resizing layer will make them 128 pixels tall and 256 pixels wide. The output of the above code is as follows:

Since the resizing layer can be called like a function, we can also apply it to the dataset itself. For example,

...
def augment(image, label):
    return resize(image), label

resized_ds = ds.map(augment)

for image, label in resized_ds:
   ...

The dataset ds has samples in the form of (image, label). Hence we created a function that takes such a tuple and preprocesses the image with the resizing layer. We then passed this function as the argument to the dataset's map() method. When we draw a sample from the new dataset created with map(), the image will be the transformed one.
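Because ds is batched, the function given to map() receives a whole batch at a time, and the resizing layer handles that 4-D batch as well as single images. A self-contained sketch with synthetic data (8 random images in batches of 4, standing in for the citrus dataset) shows the shape change; the num_parallel_calls argument is an optional addition that lets tf.data run the function in parallel.

```python
import tensorflow as tf

# Synthetic stand-in for the batched dataset: 8 images of 256x256, batches of 4
images = tf.random.uniform((8, 256, 256, 3), maxval=255.0)
labels = tf.range(8)
ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)

resize = tf.keras.layers.Resizing(128, 256)  # output height 128, width 256

def augment(image, label):
    # image is a batch of shape (batch, 256, 256, 3) here
    return resize(image), label

resized_ds = ds.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
for image, label in resized_ds.take(1):
    print(image.shape)  # (4, 128, 256, 3)
```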

There are more preprocessing layers available. Below, we demonstrate some of them.

As we saw above, we can resize the image. We can also randomly enlarge or shrink the height or width of an image. Similarly, we can zoom in or zoom out on an image. Below is an example to manipulate the image size in various ways for a maximum of 30% increase or decrease:

...

# Create preprocessing layers
out_height, out_width = 128,256
resize = tf.keras.layers.Resizing(out_height, out_width)
height = tf.keras.layers.RandomHeight(0.3)
width = tf.keras.layers.RandomWidth(0.3)
zoom = tf.keras.layers.RandomZoom(0.3)

# Visualize images and augmentations
fig, ax = plt.subplots(5, 3, figsize=(6,14))

for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # resize
        ax[1][i].imshow(resize(images[i]).numpy().astype("uint8"))
        ax[1][i].set_title("resize")
        # height
        ax[2][i].imshow(height(images[i]).numpy().astype("uint8"))
        ax[2][i].set_title("height")
        # width
        ax[3][i].imshow(width(images[i]).numpy().astype("uint8"))
        ax[3][i].set_title("width")
        # zoom
        ax[4][i].imshow(zoom(images[i]).numpy().astype("uint8"))
        ax[4][i].set_title("zoom")
plt.show()

This code shows images as follows:

While the Resizing layer always produces the fixed dimensions we specified, the other layers apply a random amount of change each time they are called.
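The contrast between a fixed and a random transformation can be verified by calling each layer twice on the same input. This sketch uses a synthetic image; RandomZoom stands in for the random layers in general.

```python
import numpy as np
import tensorflow as tf

image = tf.random.uniform((256, 256, 3), maxval=255.0)  # synthetic image, values 0 to 255
resize = tf.keras.layers.Resizing(128, 256)
zoom = tf.keras.layers.RandomZoom(0.3)

# Resizing is deterministic: two calls give identical results
resize_same = np.allclose(resize(image).numpy(), resize(image).numpy())
# RandomZoom draws a new zoom factor on every call, so two calls almost surely differ
zoom_differs = not np.allclose(zoom(image).numpy(), zoom(image).numpy())
print(resize_same, zoom_differs)
```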

We can also do flipping, rotation, cropping, and geometric translation using preprocessing layers:

...
# Create preprocessing layers
flip = tf.keras.layers.RandomFlip("horizontal_and_vertical") # or "horizontal", "vertical"
rotate = tf.keras.layers.RandomRotation(0.2)
crop = tf.keras.layers.RandomCrop(out_height, out_width)
translation = tf.keras.layers.RandomTranslation(height_factor=0.2, width_factor=0.2)

# Visualize augmentations
fig, ax = plt.subplots(5, 3, figsize=(6,14))

for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # flip
        ax[1][i].imshow(flip(images[i]).numpy().astype("uint8"))
        ax[1][i].set_title("flip")
        # crop
        ax[2][i].imshow(crop(images[i]).numpy().astype("uint8"))
        ax[2][i].set_title("crop")
        # translation
        ax[3][i].imshow(translation(images[i]).numpy().astype("uint8"))
        ax[3][i].set_title("translation")
        # rotate
        ax[4][i].imshow(rotate(images[i]).numpy().astype("uint8"))
        ax[4][i].set_title("rotate")
plt.show()

This code shows the following images:

Finally, we can also augment the colors by adjusting the brightness and contrast:

...
brightness = tf.keras.layers.RandomBrightness([-0.8,0.8])
contrast = tf.keras.layers.RandomContrast(0.2)

# Visualize augmentation
fig, ax = plt.subplots(3, 3, figsize=(6,7))

for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # brightness
        ax[1][i].imshow(brightness(images[i]).numpy().astype("uint8"))
        ax[1][i].set_title("brightness")
        # contrast
        ax[2][i].imshow(contrast(images[i]).numpy().astype("uint8"))
        ax[2][i].set_title("contrast")
plt.show()

This shows the images as follows:

For completeness, below is the code to display the result of various augmentations:

from tensorflow.keras.utils import image_dataset_from_directory
import tensorflow as tf
import matplotlib.pyplot as plt

# use image_dataset_from_directory() to load images, with image size scaled to 256x256
PATH='.../Citrus/Leaves'  # modify to your path
ds = image_dataset_from_directory(PATH,
                                  validation_split=0.2, subset="training",
                                  image_size=(256,256), interpolation="mitchellcubic",
                                  crop_to_aspect_ratio=True,
                                  seed=42, shuffle=True, batch_size=32)

# Create preprocessing layers
out_height, out_width = 128,256
resize = tf.keras.layers.Resizing(out_height, out_width)
height = tf.keras.layers.RandomHeight(0.3)
width = tf.keras.layers.RandomWidth(0.3)
zoom = tf.keras.layers.RandomZoom(0.3)

flip = tf.keras.layers.RandomFlip("horizontal_and_vertical")
rotate = tf.keras.layers.RandomRotation(0.2)
crop = tf.keras.layers.RandomCrop(out_height, out_width)
translation = tf.keras.layers.RandomTranslation(height_factor=0.2, width_factor=0.2)

brightness = tf.keras.layers.RandomBrightness([-0.8,0.8])
contrast = tf.keras.layers.RandomContrast(0.2)

# Visualize images and augmentations
fig, ax = plt.subplots(5, 3, figsize=(6,14))
for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # resize
        ax[1][i].imshow(resize(images[i]).numpy().astype("uint8"))
        ax[1][i].set_title("resize")
        # height
        ax[2][i].imshow(height(images[i]).numpy().astype("uint8"))
        ax[2][i].set_title("height")
        # width
        ax[3][i].imshow(width(images[i]).numpy().astype("uint8"))
        ax[3][i].set_title("width")
        # zoom
        ax[4][i].imshow(zoom(images[i]).numpy().astype("uint8"))
        ax[4][i].set_title("zoom")
plt.show()

fig, ax = plt.subplots(5, 3, figsize=(6,14))
for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # flip
        ax[1][i].imshow(flip(images[i]).numpy().astype("uint8"))
        ax[1][i].set_title("flip")
        # crop
        ax[2][i].imshow(crop(images[i]).numpy().astype("uint8"))
        ax[2][i].set_title("crop")
        # translation
        ax[3][i].imshow(translation(images[i]).numpy().astype("uint8"))
        ax[3][i].set_title("translation")
        # rotate
        ax[4][i].imshow(rotate(images[i]).numpy().astype("uint8"))
        ax[4][i].set_title("rotate")
plt.show()

fig, ax = plt.subplots(3, 3, figsize=(6,7))
for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # brightness
        ax[1][i].imshow(brightness(images[i]).numpy().astype("uint8"))
        ax[1][i].set_title("brightness")
        # contrast
        ax[2][i].imshow(contrast(images[i]).numpy().astype("uint8"))
        ax[2][i].set_title("contrast")
plt.show()

Finally, it is important to point out that most neural network models work better if the input images are scaled. While we usually use 8-bit unsigned integers for the pixel values in an image (e.g., for display using imshow() as above), neural networks prefer pixel values between 0 and 1, or between -1 and +1. This can be done with a preprocessing layer, too. Below is how we can update one of our examples above to add a scaling layer to the augmentation:

...
out_height, out_width = 128,256
resize = tf.keras.layers.Resizing(out_height, out_width)
rescale = tf.keras.layers.Rescaling(1/127.5, offset=-1)  # rescale pixel values to [-1,1]

def augment(image, label):
    return rescale(resize(image)), label

rescaled_resized_ds = ds.map(augment)

for image, label in rescaled_resized_ds:
   ...
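The Rescaling layer computes inputs * scale + offset. A quick check on three pixel values shows why a scale of 1/127.5 with offset -1 maps [0, 255] onto [-1, 1], and how 1/255 maps it onto [0, 1]; a scale of 1/127 would slightly overshoot, sending 255 to about 1.008.

```python
import numpy as np
import tensorflow as tf

pixels = tf.constant([0.0, 127.5, 255.0])
to_signed = tf.keras.layers.Rescaling(1/127.5, offset=-1)  # [0, 255] -> [-1, 1]
to_unit = tf.keras.layers.Rescaling(1/255)                  # [0, 255] -> [0, 1]
too_wide = tf.keras.layers.Rescaling(1/127, offset=-1)      # 255 -> about 1.008

print(to_signed(pixels).numpy())  # approximately [-1, 0, 1]
print(to_unit(pixels).numpy())    # approximately [0, 0.5, 1]
print(too_wide(pixels).numpy())
```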

Using tf.image API for Augmentation

Besides the preprocessing layers, the tf.image module also provides some functions for augmentation. Unlike the preprocessing layers, these functions are meant to be called inside a user-defined function that is then applied to a dataset using map(), as we saw above.

The functions provided by tf.image are not duplicates of the preprocessing layers, although there is some overlap. Below is an example of using tf.image functions to resize and crop images:

...

fig, ax = plt.subplots(5, 3, figsize=(6,14))

for images, labels in ds.take(1):
    for i in range(3):
        # original
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # resize
        h = int(256 * tf.random.uniform([], minval=0.8, maxval=1.2))
        w = int(256 * tf.random.uniform([], minval=0.8, maxval=1.2))
        ax[1][i].imshow(tf.image.resize(images[i], [h,w]).numpy().astype("uint8"))
        ax[1][i].set_title("resize")
        # crop
        y, x, h, w = (128 * tf.random.uniform((4,))).numpy().astype("uint8")
        ax[2][i].imshow(tf.image.crop_to_bounding_box(images[i], y, x, h, w).numpy().astype("uint8"))
        ax[2][i].set_title("crop")
        # central crop
        x = tf.random.uniform([], minval=0.4, maxval=1.0)
        ax[3][i].imshow(tf.image.central_crop(images[i], x).numpy().astype("uint8"))
        ax[3][i].set_title("central crop")
        # crop to (h,w) at random offset
        h, w = (256 * tf.random.uniform((2,))).numpy().astype("uint8")
        seed = tf.random.uniform((2,), minval=0, maxval=65536).numpy().astype("int32")
        ax[4][i].imshow(tf.image.stateless_random_crop(images[i], [h,w,3], seed).numpy().astype("uint8"))
        ax[4][i].set_title("random crop")
plt.show()

Below is the output of the above code:

While the displayed images match what we would expect from the code, using the tf.image functions is quite different from using the preprocessing layers. Each tf.image function has its own conventions: for example, crop_to_bounding_box() takes pixel coordinates, whereas central_crop() takes the fraction of the image to keep as its argument.
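The two argument conventions are easiest to compare through output shapes. In this sketch on a synthetic 256×256 image, crop_to_bounding_box() is given an offset and a size in pixels, while central_crop() is given a fraction.

```python
import tensorflow as tf

image = tf.random.uniform((256, 256, 3), maxval=255.0)  # synthetic image

# Pixel coordinates: offset (row 10, column 20), size 100 rows x 50 columns
box = tf.image.crop_to_bounding_box(image, 10, 20, 100, 50)
# Fraction: keep the central 50% along each spatial dimension
center = tf.image.central_crop(image, 0.5)

print(box.shape)     # (100, 50, 3)
print(center.shape)  # (128, 128, 3)
```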

These functions also differ in how they handle randomness. Some of them are not random at all; therefore, for a random resize, we must generate the output size with a separate random number generator before calling resize(). Other functions, such as stateless_random_crop(), do augment randomly, but they require an explicit random seed given as a pair of integers (int32 in the code above).
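When the stateless functions run inside map(), each sample needs its own seed, or every image would receive the same "random" transformation. One approach, sketched here under the assumption of a seeded tf.random.Generator (the seed 123 and the 48×48 crop size are arbitrary), draws fresh seed pairs inside the mapped function. The data are synthetic, unbatched (image, label) samples.

```python
import tensorflow as tf

# A seeded generator that hands out fresh seeds on every call
rng = tf.random.Generator.from_seed(123, alg="philox")

def augment(image, label):
    seeds = rng.make_seeds(2)  # shape [2, 2]: each column is one seed pair
    image = tf.image.stateless_random_flip_left_right(image, seed=seeds[:, 0])
    image = tf.image.stateless_random_crop(image, size=[48, 48, 3], seed=seeds[:, 1])
    return image, label

# Synthetic, unbatched samples with pixel values in [0, 255]
images = tf.random.uniform((6, 64, 64, 3), maxval=255.0)
labels = tf.range(6)
aug_ds = tf.data.Dataset.from_tensor_slices((images, labels)).map(augment).batch(3)

for batch, batch_labels in aug_ds.take(1):
    print(batch.shape)  # (3, 48, 48, 3)
```

Batching after map() keeps each call to augment() working on a single image, so every image gets its own seed pair.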

To continue the example, here are the functions for flipping an image and extracting Sobel edges:

...
fig, ax = plt.subplots(5, 3, figsize=(6,14))

for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # flip
        seed = tf.random.uniform((2,), minval=0, maxval=65536).numpy().astype("int32")
        ax[1][i].imshow(tf.image.stateless_random_flip_left_right(images[i], seed).numpy().astype("uint8"))
        ax[1][i].set_title("flip left-right")
        # flip
        seed = tf.random.uniform((2,), minval=0, maxval=65536).numpy().astype("int32")
        ax[2][i].imshow(tf.image.stateless_random_flip_up_down(images[i], seed).numpy().astype("uint8"))
        ax[2][i].set_title("flip up-down")
        # sobel edge
        sobel = tf.image.sobel_edges(images[i:i+1])
        ax[3][i].imshow(sobel[0, ..., 0].numpy().astype("uint8"))
        ax[3][i].set_title("sobel y")
        # sobel edge
        ax[4][i].imshow(sobel[0, ..., 1].numpy().astype("uint8"))
        ax[4][i].set_title("sobel x")
plt.show()

which shows the following:

And the following are the functions to manipulate the brightness, contrast, and colors:

...
fig, ax = plt.subplots(5, 3, figsize=(6,14))

for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # brightness
        seed = tf.random.uniform((2,), minval=0, maxval=65536).numpy().astype("int32")
        ax[1][i].imshow(tf.image.stateless_random_brightness(images[i], 0.3, seed).numpy().astype("uint8"))
        ax[1][i].set_title("brightness")
        # contrast
        ax[2][i].imshow(tf.image.stateless_random_contrast(images[i], 0.7, 1.3, seed).numpy().astype("uint8"))
        ax[2][i].set_title("contrast")
        # saturation
        ax[3][i].imshow(tf.image.stateless_random_saturation(images[i], 0.7, 1.3, seed).numpy().astype("uint8"))
        ax[3][i].set_title("saturation")
        # hue
        ax[4][i].imshow(tf.image.stateless_random_hue(images[i], 0.3, seed).numpy().astype("uint8"))
        ax[4][i].set_title("hue")
plt.show()

This code shows the following:

Below is the complete code to display all of the above:

from tensorflow.keras.utils import image_dataset_from_directory
import tensorflow as tf
import matplotlib.pyplot as plt

# use image_dataset_from_directory() to load images, with image size scaled to 256x256
PATH='.../Citrus/Leaves'  # modify to your path
ds = image_dataset_from_directory(PATH,
                                  validation_split=0.2, subset="training",
                                  image_size=(256,256), interpolation="mitchellcubic",
                                  crop_to_aspect_ratio=True,
                                  seed=42, shuffle=True, batch_size=32)

# Visualize tf.image augmentations

fig, ax = plt.subplots(5, 3, figsize=(6,14))
for images, labels in ds.take(1):
    for i in range(3):
        # original
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # resize
        h = int(256 * tf.random.uniform([], minval=0.8, maxval=1.2))
        w = int(256 * tf.random.uniform([], minval=0.8, maxval=1.2))
        ax[1][i].imshow(tf.image.resize(images[i], [h,w]).numpy().astype("uint8"))
        ax[1][i].set_title("resize")
        # crop
        y, x, h, w = (128 * tf.random.uniform((4,))).numpy().astype("uint8")
        ax[2][i].imshow(tf.image.crop_to_bounding_box(images[i], y, x, h, w).numpy().astype("uint8"))
        ax[2][i].set_title("crop")
        # central crop
        x = tf.random.uniform([], minval=0.4, maxval=1.0)
        ax[3][i].imshow(tf.image.central_crop(images[i], x).numpy().astype("uint8"))
        ax[3][i].set_title("central crop")
        # crop to (h,w) at random offset
        h, w = (256 * tf.random.uniform((2,))).numpy().astype("uint8")
        seed = tf.random.uniform((2,), minval=0, maxval=65536).numpy().astype("int32")
        ax[4][i].imshow(tf.image.stateless_random_crop(images[i], [h,w,3], seed).numpy().astype("uint8"))
        ax[4][i].set_title("random crop")
plt.show()

fig, ax = plt.subplots(5, 3, figsize=(6,14))
for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # flip
        seed = tf.random.uniform((2,), minval=0, maxval=65536).numpy().astype("int32")
        ax[1][i].imshow(tf.image.stateless_random_flip_left_right(images[i], seed).numpy().astype("uint8"))
        ax[1][i].set_title("flip left-right")
        # flip
        seed = tf.random.uniform((2,), minval=0, maxval=65536).numpy().astype("int32")
        ax[2][i].imshow(tf.image.stateless_random_flip_up_down(images[i], seed).numpy().astype("uint8"))
        ax[2][i].set_title("flip up-down")
        # sobel edge
        sobel = tf.image.sobel_edges(images[i:i+1])
        ax[3][i].imshow(sobel[0, ..., 0].numpy().astype("uint8"))
        ax[3][i].set_title("sobel y")
        # sobel edge
        ax[4][i].imshow(sobel[0, ..., 1].numpy().astype("uint8"))
        ax[4][i].set_title("sobel x")
plt.show()

fig, ax = plt.subplots(5, 3, figsize=(6,14))
for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # brightness
        seed = tf.random.uniform((2,), minval=0, maxval=65536).numpy().astype("int32")
        ax[1][i].imshow(tf.image.stateless_random_brightness(images[i], 0.3, seed).numpy().astype("uint8"))
        ax[1][i].set_title("brightness")
        # contrast
        ax[2][i].imshow(tf.image.stateless_random_contrast(images[i], 0.7, 1.3, seed).numpy().astype("uint8"))
        ax[2][i].set_title("contrast")
        # saturation
        ax[3][i].imshow(tf.image.stateless_random_saturation(images[i], 0.7, 1.3, seed).numpy().astype("uint8"))
        ax[3][i].set_title("saturation")
        # hue
        ax[4][i].imshow(tf.image.stateless_random_hue(images[i], 0.3, seed).numpy().astype("uint8"))
        ax[4][i].set_title("hue")
plt.show()
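The tf.image documentation expects images stored as floating point values to lie in the range [0, 1). The color examples above pass float32 images in [0, 255] directly, so a brightness delta of 0.3 shifts pixels by at most 0.3 on the 0 to 255 scale, which is barely visible. A sketch of the alternative, assuming synthetic images and arbitrary seeds, scales to [0, 1] first, applies the color functions, then clips and scales back for display:

```python
import tensorflow as tf

# Synthetic float32 images in [0, 255], like those from image_dataset_from_directory()
images = tf.random.uniform((4, 64, 64, 3), maxval=255.0)
unit = images / 255.0  # tf.image color functions assume floats in [0, 1)

aug = tf.image.stateless_random_brightness(unit, max_delta=0.3, seed=[1, 2])
aug = tf.clip_by_value(aug, 0.0, 1.0)  # a brightness shift can leave [0, 1]
aug = tf.image.stateless_random_saturation(aug, 0.7, 1.3, seed=[3, 4])
aug = tf.image.stateless_random_hue(aug, max_delta=0.3, seed=[5, 6])
aug = tf.clip_by_value(aug, 0.0, 1.0)

# Back to uint8 in [0, 255] for imshow()
display = tf.cast(aug * 255.0, tf.uint8)
print(display.shape, display.dtype)
```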

These augmentation functions should be enough for most uses. But if you have a specific augmentation in mind that they cannot express, you will probably need a more full-featured image processing library. OpenCV and Pillow are common yet powerful libraries that allow you to transform images in many more ways.

Using Preprocessing Layers in Neural Networks

We used the Keras preprocessing layers as functions in the examples above, but they can also be used as layers in a neural network, and doing so is straightforward. Below is an example of how we can incorporate preprocessing layers into a classification network and train it using a dataset:

from tensorflow.keras.utils import image_dataset_from_directory
import tensorflow as tf
import matplotlib.pyplot as plt

# use image_dataset_from_directory() to load images, with image size scaled to 256x256
PATH='.../Citrus/Leaves'  # modify to your path
ds = image_dataset_from_directory(PATH,
                                  validation_split=0.2, subset="training",
                                  image_size=(256,256), interpolation="mitchellcubic",
                                  crop_to_aspect_ratio=True,
                                  seed=42, shuffle=True, batch_size=32)

AUTOTUNE = tf.data.AUTOTUNE
ds = ds.cache().prefetch(buffer_size=AUTOTUNE)

num_classes = 5
model = tf.keras.Sequential([
  tf.keras.layers.RandomFlip("horizontal_and_vertical"),
  tf.keras.layers.RandomRotation(0.2),
  tf.keras.layers.Rescaling(1/127.5, offset=-1),  # map [0, 255] to [-1, 1]
  tf.keras.layers.Conv2D(32, 3, activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(32, 3, activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(32, 3, activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(num_classes)
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
  
model.fit(ds, epochs=3)

Running this code gives the following output:

Found 609 files belonging to 5 classes.
Using 488 files for training.
Epoch 1/3
16/16 [==============================] - 5s 253ms/step - loss: 1.4114 - accuracy: 0.4283
Epoch 2/3
16/16 [==============================] - 4s 259ms/step - loss: 0.8101 - accuracy: 0.6475
Epoch 3/3
16/16 [==============================] - 4s 267ms/step - loss: 0.7015 - accuracy: 0.7111

In the code above, we applied cache() and prefetch() to the dataset. These are performance techniques: cache() keeps the loaded images in memory after the first epoch so they are not read from disk again, and prefetch() lets the dataset prepare the next batches asynchronously while the neural network is training. This matters even more if the dataset has additional augmentation assigned using the map() function.
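If random augmentation is moved into map() instead of the model, the order of the pipeline steps matters: cache() stores whatever it receives on the first pass and replays it afterwards, so a random map() placed before cache() would repeat the same augmented images every epoch. A sketch of one reasonable ordering, with synthetic data standing in for the citrus dataset and RandomFlip as an example augmentation, caches first and augments afterwards:

```python
import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE

# Synthetic stand-in for the batched dataset: 8 images in batches of 4
images = tf.random.uniform((8, 64, 64, 3), maxval=255.0)
labels = tf.range(8) % 5
ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)

flip = tf.keras.layers.RandomFlip("horizontal_and_vertical")

def augment(image, label):
    # training=True makes sure the random flip is applied inside map()
    return flip(image, training=True), label

# cache() first: the cache holds un-augmented data, and map() re-runs every epoch
train_ds = ds.cache().map(augment, num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)

for epoch in range(2):
    for batch_images, batch_labels in train_ds:
        pass  # a training step on batch_images would go here
```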

You may see some improvement in training accuracy if you remove the RandomFlip and RandomRotation layers, because that makes the problem easier. However, since we want the network to predict well on images with a wide variety of quality and properties, augmentation can help make the resulting network more robust.
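Keras applies these random augmentation layers only in training mode; at inference, for example in model.predict() or model.evaluate(), they pass their inputs through unchanged. This sketch shows the difference by calling a small augmentation stack, similar to the first two layers of the model above, with training set explicitly on synthetic images.

```python
import numpy as np
import tensorflow as tf

augment = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal_and_vertical"),
    tf.keras.layers.RandomRotation(0.2),
])

images = tf.random.uniform((2, 64, 64, 3), maxval=255.0)  # synthetic batch
train_out = augment(images, training=True)   # randomly flipped and rotated
infer_out = augment(images, training=False)  # returned unchanged

print(np.allclose(infer_out.numpy(), images.numpy()))  # True
```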


Summary

In this post, you have seen how we can use the tf.data dataset with image augmentation functions from Keras and TensorFlow.

Specifically, you learned:

  • How to use the preprocessing layers from Keras, both as a function and as part of a neural network
  • How to create your own image augmentation function and apply it to the dataset using the map() function
  • How to use the functions provided by the tf.image module for image augmentation

The post Image Augmentation with Keras Preprocessing Layers and tf.image appeared first on Machine Learning Mastery.
