To extract a good vector representation from an image, the filters of our convolutional layers need to learn useful patterns. For that, our model needs to see many different images; but once we have a good vector representation, training the dense layers is relatively easy. The dense layers make the final decision, while the convolutional layers only extract patterns from the image.
Following this idea, we can take a neural network pretrained on a very large dataset, like ImageNet, keep its convolutional layers with their well-learned filters, add our own dense layers on top, and use the result to solve our problem.
First, we need to load the base model:
from tensorflow.keras.applications.xception import Xception

# Load Xception pretrained on ImageNet, without its final dense layers (the "top").
base_model = Xception(
    weights='imagenet',
    include_top=False,
    input_shape=(200, 200, 3),
)
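Before moving on, it is worth checking the shape of the features this base produces; a quick sketch (the spatial grid size depends on the input resolution we chose):
# Inspect the output shape of the convolutional base.
# Xception with include_top=False ends in 2048 feature channels;
# the spatial dimensions follow from the (200, 200, 3) input above.
print(base_model.output_shape)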
As in the first part, we can visualize the output of the pretrained model:
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import load_img, img_to_array

# Load one image, run it through the convolutional base,
# and plot the first 32 feature maps (channels) of its output.
img = load_img(images[0], target_size=(200, 200))
img = img_to_array(img)
feature_maps = base_model.predict(tf.expand_dims(img, axis=0))
plt.subplots(4, 8, figsize=(10, 6), dpi=300)
for i in range(32):
    plt.subplot(4, 8, i + 1)
    plt.imshow(feature_maps[0][:, :, i], cmap='gray')
    plt.axis('off')
plt.suptitle("Feature Maps")
plt.tight_layout()
The argument include_top=False excludes the last dense layers, the layers that make the final decision. In Keras terminology, the “top” is the set of final layers of the network. If include_top=True, Keras includes those dense layers in the pretrained network.
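To see what that top looks like, we can load the full network once and inspect its last layer; a small sketch (it downloads the larger set of ImageNet weights):
# Load the full Xception network only to inspect its final layer.
full_model = Xception(weights='imagenet', include_top=True)
print(full_model.layers[-1].name, full_model.layers[-1].units)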
We can see that the layer named 'predictions' is a Dense layer with 1000 units. This network can classify 1000 different classes, but we only need to classify 2 classes. Therefore, we need to replace this last layer with our own Dense layer with a single unit.
Before continuing, we need to freeze the weights of the pretrained network. If we did not freeze these weights and trained them together with the new layers, the technique would be called “fine-tuning” instead.
# Freeze the convolutional base so its weights are not updated during training.
base_model.trainable = False
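As a quick sanity check, we can confirm that no weights in the base remain trainable (a small optional sketch):
# After freezing, the base model reports zero trainable weights.
print(len(base_model.trainable_weights))      # 0
print(len(base_model.non_trainable_weights))  # all of the base's weights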
Now, we can continue:
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras import Model, Input

inputs = Input(shape=(200, 200, 3))
# Run the frozen base in inference mode so its BatchNormalization layers
# keep using their pretrained statistics.
base = base_model(inputs, training=False)
# Collapse each feature map to a single value, then classify with one sigmoid unit.
pooled = GlobalAveragePooling2D()(base)
outputs = Dense(1, activation='sigmoid')(pooled)
model = Model(inputs=inputs, outputs=outputs)
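From here, the model can be compiled and trained as usual; a minimal sketch, assuming binary cross-entropy for our two classes:
# Only the weights of the new Dense layer are updated during training.
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
model.summary()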