Overview

In this post we walk through how to use the Keras module in TensorFlow 2.0 to build a CNN image classifier.

1. Loading the data

We will use the classic cats-vs-dogs dataset for this classification task.

import os
import numpy as np
import matplotlib.pyplot as plt
 
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
 
# Download and extract the dataset (URL of the filtered cats-and-dogs dataset
# used in the official TensorFlow tutorial)
_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
 
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')

Count the images we just downloaded:

train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')
 
train_cats_dir = os.path.join(train_dir, 'cats')  # training cat images
train_dogs_dir = os.path.join(train_dir, 'dogs')  # training dog images
validation_cats_dir = os.path.join(validation_dir, 'cats')  # validation cat images
validation_dogs_dir = os.path.join(validation_dir, 'dogs')  # validation dog images
 
num_cats_train = len(os.listdir(train_cats_dir))
num_dogs_train = len(os.listdir(train_dogs_dir))
 
num_cats_val = len(os.listdir(validation_cats_dir))
num_dogs_val = len(os.listdir(validation_dogs_dir))
 
total_train = num_cats_train + num_dogs_train
total_val = num_cats_val + num_dogs_val
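
To sanity-check the download before training, we can print the counts computed above. A minimal sketch that only uses the variables already defined:

print('training cat images:', num_cats_train)
print('training dog images:', num_dogs_train)
print('validation cat images:', num_cats_val)
print('validation dog images:', num_dogs_val)
print('--')
print('total training images:', total_train)
print('total validation images:', total_val)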

2. Preprocessing the data

If we train directly on the raw training set, the model quickly overfits, so we augment the training data. Image augmentation is straightforward: horizontal or vertical flips, rotations, zooming, and so on.

batch_size = 128
epochs = 15
IMG_HEIGHT = 150
IMG_WIDTH = 150
 
train_image_generator = ImageDataGenerator(
                    rescale=1./255,
                    rotation_range=45,
                    width_shift_range=.15,
                    height_shift_range=.15,
                    horizontal_flip=True,
                    zoom_range=0.5
                    )
 
# Non-augmented alternative, kept here for comparison:
# train_image_generator = ImageDataGenerator(rescale=1./255)
validation_image_generator = ImageDataGenerator(rescale=1./255)

Read the images from disk:

train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size,
                                                           directory=train_dir,
                                                           shuffle=True,
                                                           target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                           class_mode='binary')
 
val_data_gen = validation_image_generator.flow_from_directory(batch_size=batch_size,
                                                              directory=validation_dir,
                                                              target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                              class_mode='binary')
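
To check that the augmentation looks reasonable, we can pull one batch from the training generator and plot a few samples. A minimal sketch using the matplotlib import above (with a 45-degree rotation range and 0.5 zoom the images will look heavily distorted, which is expected):

sample_images, sample_labels = next(train_data_gen)   # one augmented batch of (images, labels)
class_names = {v: k for k, v in train_data_gen.class_indices.items()}

plt.figure(figsize=(10, 10))
for i in range(9):
    plt.subplot(3, 3, i + 1)
    plt.imshow(sample_images[i])                       # pixel values already rescaled to [0, 1]
    plt.title(class_names[int(sample_labels[i])])
    plt.axis('off')
plt.show()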

3. Building the model

We stack the convolutional layers with tensorflow.keras.models.Sequential and add a Dropout layer to help prevent overfitting.

model = Sequential([
    Conv2D(16, 3, padding='same', activation='relu', input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
    MaxPooling2D(),
    Conv2D(32, 3, padding='same', activation='relu'),
    MaxPooling2D(),
    Conv2D(64, 3, padding='same', activation='relu'),
    MaxPooling2D(),
    Dropout(0.2),  # dropout helps prevent overfitting
    Flatten(),
    Dense(512, activation='relu'),
    Dense(1)
])

Inspect the network architecture:

model.summary()

The output looks like this:

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #  
=================================================================
conv2d (Conv2D)              (None, 150, 150, 16)      448      
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 75, 75, 16)        0        
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 75, 75, 32)        4640     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 37, 37, 32)        0        
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 37, 37, 64)        18496    
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 18, 18, 64)        0        
_________________________________________________________________
dropout (Dropout)            (None, 18, 18, 64)        0        
_________________________________________________________________
flatten (Flatten)            (None, 20736)             0        
_________________________________________________________________
dense (Dense)                (None, 512)               10617344 
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 513      
=================================================================
Total params: 10,641,441
Trainable params: 10,641,441
Non-trainable params: 0

4. Compiling and training the model

model.compile(optimizer='adam',
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['acc'])
history = model.fit(
    train_data_gen,
    steps_per_epoch=total_train // batch_size,
    epochs=epochs,
    validation_data=val_data_gen,
    validation_steps=total_val // batch_size
)

When picking an optimizer, Adam is a reasonable default in most cases.
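Also note that the final Dense(1) layer has no activation, so the model outputs raw logits; that is exactly why the loss is built with from_logits=True. When you later use the trained model for prediction, you therefore have to apply a sigmoid yourself to get probabilities. A minimal sketch (run after training; the variable names are just illustrative):

images, labels = next(val_data_gen)             # one validation batch
logits = model.predict(images)                  # raw logits, shape (batch_size, 1)
probs = tf.sigmoid(logits).numpy().ravel()      # probability of the 'dogs' class
preds = (probs > 0.5).astype(int)               # 0 = cats, 1 = dogs

Since flow_from_directory assigns class indices alphabetically, cats map to 0 and dogs to 1.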
If training fails with the following error:

ImportError: Could not import PIL.Image. The use of `load_img` requires PIL.

then install the pillow package:

pip install pillow

If training runs normally, the output looks like the following (with 2,000 training images and a batch size of 128, total_train // batch_size gives 15 steps per epoch, matching the 15/15 in the log):

Epoch 1/15
15/15 [==============================] - 31s 2s/step - loss: 0.9137 - acc: 0.4995 - val_loss: 0.6909 - val_acc: 0.5011
Epoch 2/15
15/15 [==============================] - 30s 2s/step - loss: 0.6883 - acc: 0.5016 - val_loss: 0.6954 - val_acc: 0.6205
Epoch 3/15
15/15 [==============================] - 30s 2s/step - loss: 0.6933 - acc: 0.5123 - val_loss: 0.6849 - val_acc: 0.4911
Epoch 4/15
15/15 [==============================] - 30s 2s/step - loss: 0.6825 - acc: 0.5037 - val_loss: 0.6637 - val_acc: 0.6105
Epoch 5/15
15/15 [==============================] - 30s 2s/step - loss: 0.6613 - acc: 0.5716 - val_loss: 0.6391 - val_acc: 0.6529
Epoch 6/15
15/15 [==============================] - 30s 2s/step - loss: 0.6439 - acc: 0.5946 - val_loss: 0.6150 - val_acc: 0.6641
Epoch 7/15
15/15 [==============================] - 31s 2s/step - loss: 0.6196 - acc: 0.6167 - val_loss: 0.6006 - val_acc: 0.6641
Epoch 8/15
15/15 [==============================] - 31s 2s/step - loss: 0.6096 - acc: 0.6282 - val_loss: 0.6135 - val_acc: 0.6708
Epoch 9/15
15/15 [==============================] - 31s 2s/step - loss: 0.6063 - acc: 0.6325 - val_loss: 0.5719 - val_acc: 0.6908
Epoch 10/15
15/15 [==============================] - 30s 2s/step - loss: 0.6040 - acc: 0.6485 - val_loss: 0.6070 - val_acc: 0.6920
Epoch 11/15
15/15 [==============================] - 30s 2s/step - loss: 0.5804 - acc: 0.6688 - val_loss: 0.5853 - val_acc: 0.6998
Epoch 12/15
15/15 [==============================] - 30s 2s/step - loss: 0.5908 - acc: 0.6683 - val_loss: 0.5668 - val_acc: 0.6362
Epoch 13/15
15/15 [==============================] - 30s 2s/step - loss: 0.5828 - acc: 0.6667 - val_loss: 0.5698 - val_acc: 0.7054
Epoch 14/15
15/15 [==============================] - 31s 2s/step - loss: 0.5673 - acc: 0.6976 - val_loss: 0.5514 - val_acc: 0.7054
Epoch 15/15
15/15 [==============================] - 30s 2s/step - loss: 0.5627 - acc: 0.6971 - val_loss: 0.5685 - val_acc: 0.7154

Plot the training history:

acc = history.history['acc']
val_acc = history.history['val_acc']
 
loss=history.history['loss']
val_loss=history.history['val_loss']
 
epochs_range = range(epochs)
 
plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
 
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

[Figure: training and validation accuracy/loss curves]
The curves show no obvious overfitting, so the training result is fairly satisfactory.

This post mainly follows the TensorFlow 2.0 official documentation, Image classification.