本文共 2155 字,大约阅读时间需要 7 分钟。
1,基本概念
MNIST是一个非常有名的手写体数字识别数据集,在很多资料中,这个数据集都会被用作深度学习的入门样例。而TensorFlow的封装让使用MNIST数据集变得更加方便。MNIST数据集是NIST数据集的一个子集,MNIST 数据集可在 获取, 它包含了四个部分: (1)Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本) (2)Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签) (3)Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本) (4)Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签) 2,代码解读import tensorflow as tfimport numpy as npimport matplotlib.pyplot as pltfrom tensorflow.examples.tutorials.mnist import input_data#下载mnist数据集并查看大小mnist = input_data.read_data_sets('data',one_hot = True)print("type of mnist is %s" %(type(mnist)))print("number of trian data is %d" %(mnist.train.num_examples))#训练集样本个数print("number of test data is %d" %(mnist.test.num_examples))#测试集样本个数#mnist数据集具体信息(样本,标签)trainimg = mnist.train.imagestrainlabel = mnist.train.labelstestimg = mnist.test.imagestestlabel = mnist.test.labelsprint("type of trainimg is %s" %(type(trainimg))) #训练集样本的类型print("type of trainlabel is %s" %(type(trainlabel)))#训练标签的类型print("type of testimg is %s" %(type(testimg)))print("type of testlabel is %s" %(type(testlabel)))print("shape of trainimg is %s" %(trainimg.shape,))#训练样本的个数,单个样本像素点的个数(28*28)print("shape of trainlabel is %s" %(trainlabel.shape,))#训练标签的个数,单个样本可能类别的个数(10)print("shape of testimg is %s" %(testimg.shape,))print("shape of testlabel is %s" %(testlabel.shape,))
运行结果:
#展示训练集样本的实例nsample=3randidx = np.random.randint(trainimg.shape[0],size=nsample)for i in randidx: cur_img = np.reshape(trainimg[i,:],(28,28)) cur_label = np.argmax(trainlabel[i,:]) plt.matshow(cur_img,cmap = plt.get_cmap('gray')) plt.title(""+str(i)+"th Training Data"+"Label is"+str(cur_label)) plt.show()
运行结果:
#Batch Learningbatch_size = 128 #batch大小batch_xs, batch_ys = mnist.train.next_batch(batch_size) #训练集batch中样本的个数(batch_xs),相应标签的个数(batch_ys)print("type of batch_xs is %s" %(type(batch_xs)))print("type of batch_ys is %s" %(type(batch_ys)))print("shape of batch_xs is %s" %(batch_xs.shape,))print("shape of batch_ys is %s" %(batch_ys.shape,))
运行结果:
转载地址:http://bohwi.baihongyu.com/