The Keras API provides the ImageDataGenerator class for augmenting image data. In this tutorial, we'll briefly learn how to create augmented images with ImageDataGenerator in Python. The tutorial covers:
- Loading the image
- Defining the ImageDataGenerator
- Generating images
- Source code listing
We'll start by importing the required functions and classes.
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from keras.preprocessing.image import array_to_img
from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
Loading the image
First, we'll load a sample image to use in this tutorial. You can use any image you have for this example.
path = "/Pictures/rabbit1.jpg"
image = load_img(path, target_size=(200, 250))
plt.imshow(image)
plt.show()
Then, we'll convert it into an array with the img_to_array() function and reshape it to add a leading batch dimension.
img_arr = img_to_array(image)
print(img_arr.shape)
(200, 250, 3)
img_arr = img_arr.reshape((1,)+img_arr.shape)
print(img_arr.shape)
(1, 200, 250, 3)
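The extra leading axis is there because flow() works on a batch of images, shaped (samples, height, width, channels), rather than on a single image. As a minimal sketch, assuming a zero-filled NumPy array as a stand-in for the loaded image, the reshape above can also be written in two other equivalent ways:

```python
import numpy as np

# Stand-in for the output of img_to_array(): height 200, width 250, 3 channels.
img_arr = np.zeros((200, 250, 3), dtype="float32")

# Three equivalent ways to add a leading batch axis of size 1.
a = img_arr.reshape((1,) + img_arr.shape)
b = np.expand_dims(img_arr, axis=0)
c = img_arr[np.newaxis, ...]

print(a.shape, b.shape, c.shape)
```

All three produce the same (1, 200, 250, 3) shape, so the choice is mostly a matter of readability.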
Defining the ImageDataGenerator
Next, we'll define the image generator with the ImageDataGenerator class. Here, we can set the transformations we want to apply to the image. You can check the Keras documentation for more information about each option.
- rotation_range defines the range, in degrees, for random rotations
- width_shift_range and height_shift_range translate the image horizontally or vertically
- shear_range applies a shearing transform
- zoom_range zooms the picture in or out
- fill_mode defines how newly created pixels are filled
- horizontal_flip randomly flips the image horizontally
datagen = ImageDataGenerator(rotation_range=20,
width_shift_range=0.1,
height_shift_range=0.1,
shear_range=0.1,
zoom_range=0.2,
horizontal_flip=True)
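To make these values more concrete, here is a rough sketch of the ranges they imply for our 200 x 250 image. It assumes the behavior described in the Keras documentation: a shift value below 1 is treated as a fraction of the image size, and a single zoom value z means a zoom factor drawn from [1 - z, 1 + z]. The variable names max_dx, max_dy, zoom_lo and zoom_hi are only illustrative and are not part of the Keras API.

```python
# Target size used when loading the image: (height, width).
height, width = 200, 250

# Same values as passed to ImageDataGenerator above.
rotation_range = 20
width_shift_range = 0.1
height_shift_range = 0.1
zoom_range = 0.2

# Shift values below 1 are fractions of the image size (assumption from the docs).
max_dx = width_shift_range * width    # largest horizontal shift, in pixels
max_dy = height_shift_range * height  # largest vertical shift, in pixels

# A single zoom value z is read as the interval [1 - z, 1 + z].
zoom_lo, zoom_hi = 1 - zoom_range, 1 + zoom_range

print(f"rotation: between -{rotation_range} and +{rotation_range} degrees")
print(f"horizontal shift: up to about {max_dx:.0f} pixels either way")
print(f"vertical shift: up to about {max_dy:.0f} pixels either way")
print(f"zoom factor: between {zoom_lo:.1f} and {zoom_hi:.1f}")
```

Each generated image gets its own random draw from these ranges, which is why the augmented images differ from one another.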
Generating images
Next, we'll generate images by using the datagen object. Here, we'll create 9 augmented images from the original image.
n = 9
imgs = []
# flow() keeps yielding batches, so we stop once we have n images
for batch in datagen.flow(img_arr, batch_size=1):
    imgs.append(array_to_img(batch[0], scale=True))
    if len(imgs) == n:
        break
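The break matters because the iterator returned by flow() does not stop on its own: it keeps producing batches for as long as we ask. The sketch below shows the same pattern without Keras. The fake_flow() function is a hypothetical stand-in for datagen.flow() that only applies a random horizontal flip, and itertools.islice is one possible alternative to the explicit break.

```python
import itertools
import numpy as np

def fake_flow(batch, seed=0):
    """Hypothetical stand-in for datagen.flow(): yields batches forever."""
    rng = np.random.default_rng(seed)
    while True:
        # Toy "augmentation": flip along the width axis about half of the time.
        yield batch[:, :, ::-1, :] if rng.random() < 0.5 else batch

# A batch holding one 200 x 250 RGB image, like img_arr in the tutorial.
batch = np.zeros((1, 200, 250, 3), dtype="float32")

n = 9
# islice takes the first n batches and then stops the endless generator.
imgs = [b[0] for b in itertools.islice(fake_flow(batch), n)]

print(len(imgs), imgs[0].shape)
```

With the real generator, the same islice call could wrap datagen.flow(img_arr, batch_size=1) in place of the for loop with break.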
Finally, we'll plot the augmented images in a 3 x 3 grid.
plt.subplots_adjust(wspace=0, hspace=0)
plt.tight_layout()
for i in range(n):
    plt.subplot(3, 3, i + 1)
    # hide the tick labels on both axes
    plt.tick_params(labelbottom=False)
    plt.tick_params(labelleft=False)
    plt.imshow(imgs[i])
plt.show()
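The same grid can also be built with plt.subplots(), which creates all the axes in one call. This is only an alternative sketch, not the approach used in this tutorial, and it uses random arrays as stand-ins for the augmented images so that it runs on its own.

```python
import numpy as np
import matplotlib.pyplot as plt

# Random stand-ins for the 9 augmented images (values in [0, 1]).
rng = np.random.default_rng(0)
imgs = [rng.random((200, 250, 3)) for _ in range(9)]

# Create the 3 x 3 grid of axes up front.
fig, axes = plt.subplots(3, 3, figsize=(6, 6))
for ax, img in zip(axes.flat, imgs):
    ax.imshow(img)
    ax.axis("off")  # hide ticks, labels and the frame

# Leave only a thin gap between the images.
fig.subplots_adjust(wspace=0.02, hspace=0.02)
```

Calling plt.show() afterwards displays the figure. To plot the real augmented images, replace the random arrays with the imgs list built in the previous step.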
By changing the parameter values of the ImageDataGenerator class, you can get different outputs and increase the amount of training data.
In this tutorial, we've briefly learned how to generate augmented data by using the Keras ImageDataGenerator class in Python.
Source code listing
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from keras.preprocessing.image import array_to_img
from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
path = "/Pictures/rabbit1.jpg"
image = load_img(path, target_size=(200, 250))
plt.imshow(image)
plt.show()
img_arr = img_to_array(image)
print(img_arr.shape)
img_arr = img_arr.reshape((1,)+img_arr.shape)
print(img_arr.shape)
datagen = ImageDataGenerator(rotation_range=20,
                             width_shift_range=0.1,
                             height_shift_range=0.1,
                             shear_range=0.1,
                             zoom_range=0.2,
                             horizontal_flip=True)
n = 9
imgs = []
# flow() keeps yielding batches, so we stop once we have n images
for batch in datagen.flow(img_arr, batch_size=1):
    imgs.append(array_to_img(batch[0], scale=True))
    if len(imgs) == n:
        break
plt.subplots_adjust(wspace=0, hspace=0)
plt.tight_layout()
for i in range(n):
    plt.subplot(3, 3, i + 1)
    # hide the tick labels on both axes
    plt.tick_params(labelbottom=False)
    plt.tick_params(labelleft=False)
    plt.imshow(imgs[i])
plt.show()