全部开发者教程

TensorFlow 入门教程

首页 慕课教程 TensorFlow 入门教程 TensorFlow 入门教程 在 TensorFlow 之中进行图像分割

在 TensorFlow 之中进行图像分割

在之前的学习之中,对于图像数据,我们进行过分类等一些常见的任务;这节课我们便来学习一下对于图像数据的另外一种任务:图像分割。

1. 什么是图像分割

图像分割,顾名思义,就是对图像数据进行分割,而分类的物体一般是我们认为进行指定的。比如物品分割、人脸分割、医学病灶分割等。

举个例子,如下图所示,原来的图像是一个马路的图片,通过图像分割,我们会按照不同的物体进行不同的分割,比如车分为一类、人分为一类、建筑分为一类、马路分为一类等。

图片描述

图像分割是很多任务的前提,有很多的任务只有进行了有效的分割之后才能进行有效的处理,比如:

  • 医学病灶识别;
  • 人脸情绪识别;
  • 路况检测;
  • 自动驾驶;
  • 等等。

2. 如何进行图像分割

图像分割看上去是一个很复杂的任务,但是实现起来的原理却是非常简单,具体来说分为以下几步:

  • 确定要分类的类别,比如,我们可以将图片中所有的物体分割为 10 类,包括车、人等;
  • 对于每个像素点进行数字分类,数字分类的类别数量对应于上述的类别,这里是 10 ;
  • 将每个数字类别对应于分类的类别,比如 0 代表车、1 代表人。

可以看出,图像分割任务其实就是一个分类任务,只不过是对于每个像素点进行分类,也就是确定每个像素点所对应的类别。

在这节课之中,我们会使用图像分割的基础数据集:oxford_iiit_pet 图像分割数据集来进行演示。与此同时,我们也会采用之前学习到的迁移学习的方式来进行模型的构建,从而完成图像分割的任务。

3. 使用 TensorFlow 进行图像分割的程序示例

在 oxford_iiit_pet 之中,所有的图片都是宠物,我们的任务是将图片中的宠物分割出来,所有的像素点都被分为三类

  • 1: 对应于宠物的一部分;
  • 2: 对应于宠物的边界;
  • 3: 不属于宠物的一部分。

在这里,我们使用代码有一部分来自 TensorFlow 官方的一个例子,这个例子非常的简单易懂,作为图像分割任务的入门是再适合不过的了。

我们会逐步进行代码的解释与理解,从而帮助大家学习图像分割的任务的特点。

1. 首先我们获取数据集

import tensorflow as tf
from tensorflow_examples.models.pix2pix import pix2pix
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)

这里会下载数据集,因为是图片数据,因此数据集相对比较大。

2. 定义归一化处理函数

def normalize(input_image, input_mask):
  input_image = tf.cast(input_image, tf.float32) / 255.0
  return input_image, input_mask

它接收两个参数,第一个参数是图片,我们会将其归一化到 [0, 1] ,第二个参数是图像的标签。

3. 构建数据集

def load_image_train(data):
  input_image = tf.image.resize(data['image'], (128, 128))
  input_mask = tf.image.resize(data['segmentation_mask'], (128, 128))

  input_image, input_mask = normalize(input_image, input_mask)

  return input_image, input_mask

def load_image_test(data):
  input_image = tf.image.resize(data['image'], (128, 128))
  input_mask = tf.image.resize(data['segmentation_mask'], (128, 128))

  input_image, input_mask = normalize(input_image, input_mask)

  return input_image, input_mask

num_examples = info.splits['train'].num_examples
BATCH = 64
step_per_epch = num_examples // BATCH

train = dataset['train'].map(load_image_train)
test = dataset['test'].map(load_image_test)

train_dataset = train.cache().shuffle(1000).batch(BATCH).repeat()
test_dataset = test.batch(BATCH)

在构建数据集函数之中,我们做了两件事情:

  • 将图像与标签重新调整大小到 [128, 128] ;
  • 将数据归一化。

然后我们进行了分批的处理,这里取批次的大小为 64 ,大家可以根据自己的内存或现存大小灵活调整。

4. 构建网络模型


output_channels = 3

# 获取基础模型
base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False)

# 定义要使用其输出的基础模型网络层
layer_names = [
    'block_1_expand_relu',   # 64x64
    'block_3_expand_relu',   # 32x32
    'block_6_expand_relu',   # 16x16
    'block_13_expand_relu',  # 8x8
    'block_16_project',      # 4x4
]
layers = [base_model.get_layer(name).output for name in layer_names]

# 创建特征提取模型
down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers)

down_stack.trainable = False

# 进行降频采样
up_stack = [
    pix2pix.upsample(512, 3),  # 4x4 -> 8x8
    pix2pix.upsample(256, 3),  # 8x8 -> 16x16
    pix2pix.upsample(128, 3),  # 16x16 -> 32x32
    pix2pix.upsample(64, 3),   # 32x32 -> 64x64
]

# 定义UNet网络模型
def unet_model(output_channels):
  inputs = tf.keras.layers.Input(shape=[128, 128, 3])
  x = inputs

  # 在模型中降频取样
  skips = down_stack(x)
  x = skips[-1]
  skips = reversed(skips[:-1])

  # 升频取样然后建立跳跃连接
  for up, skip in zip(up_stack, skips):
    x = up(x)
    concat = tf.keras.layers.Concatenate()
    x = concat([x, skip])

  # 这是模型的最后一层
  last = tf.keras.layers.Conv2DTranspose(
      output_channels, 3, strides=2,
      padding='same')  #64x64 -> 128x128

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)

model = unet_model(output_channels)
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

在这里,我们首先得到了一个预训练的 MobileNetV2 用于特征提取,在这里我们并没有包含它的输出层,因为我们要根据自己的任务灵活调节。

然后定义了我们要使用的 MobileNetV2 的网络层的输出,我们使用这些输出来作为我们提取的特征。

然后我们定义了我们的网络模型,这个模型的理解有些困难,大家可能不用详细了解网络的具体原理。大家只需要知道,这个网络大致经过的步骤包括:

  • 先将数据压缩(便于数据的处理)
  • 然后进行数据的处理
  • 最后将数据解压返回到原来的大小,从而完成网络的任务

最后我们编译该模型,我们使用 adam 优化器,交叉熵损失函数(因为图像分割是个分类任务)。

5. 模型的训练

epoch = 20
valid_steps = info.splits['test'].num_examples//BATCH

model_history = model.fit(train_dataset, epochs=epoch,
                          steps_per_epoch=step_per_epch,
                          validation_steps=valid_steps,
                          validation_data=test_dataset)

loss = model_history.history['loss']
val_loss = model_history.history['val_loss']

这边就是一个简单的训练过程,我们可以得到如下输出:

Epoch 1/20
57/57 [==============================] - 296s 5s/step - loss: 0.4928 - accuracy: 0.7995 - val_loss: 0.6747 - val_accuracy: 0.7758
......
Epoch 20/20
57/57 [==============================] - 276s 5s/step - loss: 0.2586 - accuracy: 0.9218 - val_loss: 0.2821 - val_accuracy: 0.9148

我们可以看到我们最后达到了 91% 的准确率,还是一个可以接受的结果。

感兴趣的同学可以尝试一下进行结果的可视化,从而更加直观的查看到结果。

4. 小结

在这节课之中,我们学习了什么是图像分割,同时了解了图像分割的简单的实现方式,最终我们通过一个示例来了解了如何在 TensorFlow 之中进行图像分割。