> 文档中心 > BP Network mnist手写数据集 基于sklearn

BP Network mnist手写数据集 基于sklearn

bp神经网络利用sklearn自带mnist手写数据集实现对数字的识别

首先导入数据

from sklearn.neural_network import MLPClassifierfrom sklearn.datasets import load_digitsfrom sklearn.model_selection import train_test_splitfrom sklearn.metrics import classification_reportimport matplotlib.pyplot as pltdigits = load_digits()x_data = digits.datay_data = digits.target##print(x_data.shape)##print(y_data.shape)y=digits.data[1].reshape([8, 8])

拆分训练集和测试

x_train,x_test,y_train,y_test = train_test_split(x_data,y_data)

构建模型

#构建模型,2个隐藏层,第一个隐藏层有100个神经元,第2隐藏层50个神经元,训练500周期

mlp = MLPClassifier(hidden_layer_sizes=(100,50), max_iter=500)

训练fit

mlp.fit(x_train,y_train)

测试集展示

#测试集准确率的评估predictions = mlp.predict(x_test)   for i in range(1, 21): plt.subplot(4,5, i)  #划分成2行5列 plt.imshow(x_test[i - 1].reshape([8, 8]), cmap=plt.cm.gray_r) plt.text(60, 3, str(predictions[i-1])) #在图片的任意位置添加文本 plt.xticks([]) #认为设置坐标轴显示的刻度值 plt.yticks([]) plt.rcParams['savefig.dpi'] = 128 #图片像素 plt.rcParams['figure.dpi'] = 128 #分辨率 plt.subplots_adjust(bottom=0.10,top=1.5)plt.show()

 

 效果如下:

 

完整代码:

from sklearn.neural_network import MLPClassifierfrom sklearn.datasets import load_digitsfrom sklearn.model_selection import train_test_splitfrom sklearn.metrics import classification_reportimport matplotlib.pyplot as pltdigits = load_digits()x_data = digits.datay_data = digits.target##print(x_data.shape)##print(y_data.shape)y=digits.data[1].reshape([8, 8])#数据拆分x_train,x_test,y_train,y_test = train_test_split(x_data,y_data)#构建模型,2个隐藏层,第一个隐藏层有100个神经元,第2隐藏层50个神经元,训练500周期mlp = MLPClassifier(hidden_layer_sizes=(100,50), max_iter=500)mlp.fit(x_train,y_train)#测试集准确率的评估predictions = mlp.predict(x_test)   for i in range(1, 21): plt.subplot(4,5, i)  #划分成2行5列 plt.imshow(x_test[i - 1].reshape([8, 8]), cmap=plt.cm.gray_r) plt.text(60, 3, str(predictions[i-1])) #在图片的任意位置添加文本 plt.xticks([]) #认为设置坐标轴显示的刻度值 plt.yticks([]) plt.rcParams['savefig.dpi'] = 128 #图片像素 plt.rcParams['figure.dpi'] = 128 #分辨率 plt.subplots_adjust(bottom=0.10,top=1.5)plt.show()##print(classification_report(y_test, predictions))

 

 

 

唱吧电脑版