
本文共 2457 字,大约阅读时间需要 8 分钟。
3.5 图像分类数据集(Fashion-MNIST)
在介绍Softmax回归的实现之前,我们引入一个多类图像分类数据集。这个数据集将在后续章节中被多次使用,以便我们比较不同算法在模型精度和计算效率上的表现。最常用的图像分类数据集是MNIST手写数字识别数据集,但大多数模型在MNIST上的分类精度通常超过95%。为了更直观地观察算法之间的差异,我们将使用一个内容更加复杂的数据集——Fashion-MNIST。
3.5.1 获取数据集
首先,我们需要导入相关的包或模块。以下是所需的代码:
import tensorflow as tffrom tensorflow.keras.datasets import fashion_mnistimport numpy as npimport timeimport sysimport matplotlib.pyplot as plt
接下来,我们通过Keras的数据集包下载这个数据集。通过参数train
指定获取训练数据集或测试数据集。测试数据集用于评价模型的表现,不用于模型训练。训练集中和测试集中每个类别的图像数分别为6000和1000,共10个类别,因此训练集和测试集的样本数分别为60000和10000。
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
我们可以通过方括号访问任意一个样本。以下是获取第一个样本的代码:
sample_feature, sample_label = x_train[0], y_train[0]
样本特征是一个28×28的二维numpy数组,每个像素的值在0到255之间(8位无符号整数)。图像标签是一个numpy标量。需要注意的是,Keras的Fashion-MNIST数据集与原书中MxNet提供的数据集有所不同。
sample_feature_shape, sample_feature_dtype = sample_feature.shape, sample_feature.dtypesample_label_type, _ = type(sample_label), sample_label.dtype
Fashion-MNIST共有10个类别,包括T恤、裤子、套衫、连衣裙、外套、凉鞋、衬衫、运动鞋、包和短靴。以下函数可以将数值标签转换为对应的文本标签:
def get_fashion_mnist_labels(labels): text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot'] return [text_labels[int(i)] for i in labels]
以下函数可以在一行内绘制多张图像及其对应的标签:
def show_fashion_mnist(images, labels): _, figs = plt.subplots(1, len(images), figsize=(12, 12)) for f, img, lbl in zip(figs, images, labels): f.imshow(img.reshape((28, 28))) f.set_title(lbl) f.axes.get_xaxis().set_visible(False) f.axes.get_yaxis().set_visible(False) plt.show()
以下代码显示训练数据集中前10个样本的图像内容及其标签:
samples = []for i in range(10): samples.append(x_train[i]) samples.append(y_train[i])show_fashion_mnist(samples, get_fashion_mnist_labels(samples))
3.5.2 读取小批量数据
我们将在训练数据集上训练模型,并在测试数据集中评价模型的表现。为了提高读取速度,我们使用DataLoader
实例,它每次读取一个batch_size
的小批量数据。batch_size
是一个超参数。此外,ToTensor
实例将图像数据从uint8格式转换为32位浮点数格式,并将像素值归一化到0到1之间。ToTensor
还将图像通道从最后一维移到最前一维,以便后续介绍卷积神经网络时进行计算。transform_first
函数会将ToTensor
的变换应用到每个样本的图像部分。
batch_size = 256if sys.platform.startswith('win'): num_workers = 0 # 0表示不使用额外进程加速读取数据else: num_workers = 4train_iter = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size)
以下代码显示读取一遍训练数据所需的时间:
start = time.time()for X, y in train_iter: continueprint('%.2f sec' % (time.time() - start))
3.5.3 小结
- Fashion-MNIST是一个10类服饰分类数据集。
- 图像的高和宽分别为28像素,形状记为28×28。
- 图像标签是NumPy标量。
- 图像的通道顺序会影响后续操作。
发表评论
最新留言
关于作者
