Classifying the MNIST dataset with Keras
Source: Internet  Category: Computer Rank Exam Level 3 Database  Editor: 程序博客网  Time: 2024/06/02 15:01
Just the code:
Original dataset download: http://yann.lecun.com/exdb/mnist/
Code notes: http://keras.io/getting-started/sequential-model-guide/#examples
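The files at the dataset link use the IDX binary format: a big-endian header (magic number 2049 for label files, 2051 for image files, then the item count and, for images, the row and column counts) followed by raw unsigned bytes. The download is gzip-compressed, and `load_mnist` in the script expects the decompressed files in `../data`. The sketch below round-trips a tiny synthetic file pair through the same layout so the parsing can be checked without the real download. `write_idx_pair`, `read_idx_pair` and the toy arrays are hypothetical, and unlike the original this version checks the magic numbers and takes the image size from the header instead of hard-coding 784.

```python
import os
import struct
import tempfile

import numpy as np

LABEL_MAGIC = 2049  # 0x00000801, IDX label file
IMAGE_MAGIC = 2051  # 0x00000803, IDX image file


def write_idx_pair(path, kind, images, labels):
    """Hypothetical helper: write a tiny MNIST-style IDX pair for testing."""
    n, rows, cols = images.shape
    with open(os.path.join(path, '%s-labels-idx1-ubyte' % kind), 'wb') as f:
        # label header: magic number, item count (big-endian uint32)
        f.write(struct.pack('>II', LABEL_MAGIC, n))
        f.write(labels.astype(np.uint8).tobytes())
    with open(os.path.join(path, '%s-images-idx3-ubyte' % kind), 'wb') as f:
        # image header: magic number, image count, rows, cols
        f.write(struct.pack('>IIII', IMAGE_MAGIC, n, rows, cols))
        f.write(images.astype(np.uint8).tobytes())


def read_idx_pair(path, kind):
    """Hypothetical helper: same parsing as load_mnist, plus header checks."""
    with open(os.path.join(path, '%s-labels-idx1-ubyte' % kind), 'rb') as f:
        magic, n = struct.unpack('>II', f.read(8))
        if magic != LABEL_MAGIC:
            raise ValueError('not an IDX label file: magic %d' % magic)
        labels = np.fromfile(f, dtype=np.uint8)
    with open(os.path.join(path, '%s-images-idx3-ubyte' % kind), 'rb') as f:
        magic, num, rows, cols = struct.unpack('>IIII', f.read(16))
        if magic != IMAGE_MAGIC:
            raise ValueError('not an IDX image file: magic %d' % magic)
        # one flattened row of rows*cols pixels per image
        images = np.fromfile(f, dtype=np.uint8).reshape(num, rows * cols)
    return images, labels


tmp = tempfile.mkdtemp()
fake_images = np.arange(12).reshape(3, 2, 2)  # three 2x2 "images"
fake_labels = np.array([5, 0, 4])
write_idx_pair(tmp, 'train', fake_images, fake_labels)
X, y = read_idx_pair(tmp, 'train')
print(X.shape, y)  # (3, 4) [5 0 4]
```

Passing `'t10k'` as `kind` selects the test files, matching the `load_mnist('../data', kind='t10k')` call in the script.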
"""
@version:
@author: vinsin
@license: Apache Licence
@software: PyCharm
@file: test_keras.py
@time: 16-7-19 4:53 PM
"""
# Uses the Keras 1.x API (output_dim, init, nb_epoch) with the Theano backend, as in the log below.
import os
import struct

import numpy as np


def load_mnist(path, kind='train'):
    """Load MNIST data from `path`"""
    labels_path = os.path.join(path, '%s-labels-idx1-ubyte' % kind)
    images_path = os.path.join(path, '%s-images-idx3-ubyte' % kind)
    with open(labels_path, 'rb') as lbpath:
        # 8-byte header: magic number and item count (big-endian uint32)
        magic, n = struct.unpack('>II', lbpath.read(8))
        labels = np.fromfile(lbpath, dtype=np.uint8)
    with open(images_path, 'rb') as imgpath:
        # 16-byte header: magic number, image count, rows, cols
        magic, num, rows, cols = struct.unpack(">IIII", imgpath.read(16))
        # flatten each 28x28 image into a 784-element row
        images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784)
    return images, labels


X_train, y_train = load_mnist('../data', kind='train')
print('Rows: %d, columns: %d' % (X_train.shape[0], X_train.shape[1]))
X_test, y_test = load_mnist('../data', kind='t10k')
print('Rows: %d, columns: %d' % (X_test.shape[0], X_test.shape[1]))

import theano

theano.config.floatX = 'float32'
X_train = X_train.astype(theano.config.floatX)
X_test = X_test.astype(theano.config.floatX)

from keras.utils import np_utils

print('First 3 data: ', X_train[:3])
print('First 3 labels: ', y_train[:3])
# one-hot encode the integer labels 0-9
y_train_ohe = np_utils.to_categorical(y_train)
print('First 3 labels (one-hot):', y_train_ohe[:3])

from keras.models import Sequential
from keras.layers.core import Dense
from keras.optimizers import SGD

np.random.seed(1)

# 784 -> 50 (tanh) -> 50 (tanh) -> 10 (softmax)
model = Sequential()
model.add(Dense(input_dim=X_train.shape[1], output_dim=50, init='uniform', activation='tanh'))
model.add(Dense(input_dim=50, output_dim=50, init='uniform', activation='tanh'))
model.add(Dense(input_dim=50, output_dim=y_train_ohe.shape[1], init='uniform', activation='softmax'))

sgd = SGD(lr=0.001, decay=1e-7, momentum=.9)
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=["accuracy"])
# the last 10% of the training rows are held out for validation
model.fit(X_train, y_train_ohe, nb_epoch=50, batch_size=300, verbose=1, validation_split=0.1)

y_train_pred = model.predict_classes(X_train, verbose=0)
print('First 3 predictions: ', y_train_pred[:3])
# float() avoids integer division under Python 2
train_acc = np.sum(y_train == y_train_pred, axis=0) / float(X_train.shape[0])
print('Training accuracy: %.2f%%' % (train_acc * 100))

y_test_pred = model.predict_classes(X_test, verbose=0)
test_acc = np.sum(y_test == y_test_pred, axis=0) / float(X_test.shape[0])
print('Test accuracy: %.2f%%' % (test_acc * 100))
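Two Keras 1.x helpers in the script do simple array work: `np_utils.to_categorical` turns integer labels into one-hot rows, and `predict_classes` picks the highest-probability class from the softmax output. The sketch below mirrors both, plus the accuracy formula, in plain NumPy so they can be checked without Keras or Theano. `to_one_hot` and `predict_classes_from_proba` are hypothetical names, and the probability matrix is made-up data, not model output.

```python
import numpy as np


def to_one_hot(labels, num_classes=None):
    """Hypothetical helper: one-hot encode integer labels 0..n-1."""
    labels = np.asarray(labels, dtype=int)
    if num_classes is None:
        # infer the class count from the largest label
        num_classes = labels.max() + 1
    # row i of the identity matrix is the one-hot vector for class i
    return np.eye(num_classes, dtype='float32')[labels]


def predict_classes_from_proba(proba):
    """Hypothetical helper: most probable class per row of a softmax output."""
    return np.argmax(proba, axis=-1)


y = np.array([5, 0, 4])
print(to_one_hot(y, 10))

# made-up softmax outputs for three samples over 10 classes (rows sum to 1)
proba = np.full((3, 10), 0.05)
proba[0, 5] = proba[1, 0] = proba[2, 9] = 0.55
pred = predict_classes_from_proba(proba)
# accuracy = fraction of predictions that match the labels
acc = np.mean(pred == y)
print(pred, acc)  # [5 0 9] 0.6666666666666666
```

`np.mean` over the boolean match array gives the same fraction as the script's `np.sum(...) / X.shape[0]`, without depending on Python 2 vs. Python 3 division rules.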
The output log is as follows:
Rows: 60000, columns: 784
Rows: 10000, columns: 784
Using Theano backend.
First 3 data: [[ 0. 0. 0. ..., 0. 0. 0.]
 [ 0. 0. 0. ..., 0. 0. 0.]
 [ 0. 0. 0. ..., 0. 0. 0.]]
First 3 labels: [5 0 4]
First 3 labels (one-hot): [[ 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [ 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [ 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]]
Train on 54000 samples, validate on 6000 samples
Epoch 1/50
54000/54000 [==============================] - 1s - loss: 2.2290 - acc: 0.3594 - val_loss: 2.1092 - val_acc: 0.5343
Epoch 2/50
54000/54000 [==============================] - 1s - loss: 1.8848 - acc: 0.5272 - val_loss: 1.6081 - val_acc: 0.5518
Epoch 3/50
54000/54000 [==============================] - 1s - loss: 1.3900 - acc: 0.5886 - val_loss: 1.1653 - val_acc: 0.6645
Epoch 4/50
54000/54000 [==============================] - 1s - loss: 1.0614 - acc: 0.6958 - val_loss: 0.9047 - val_acc: 0.7703
Epoch 5/50
54000/54000 [==============================] - 1s - loss: 0.8566 - acc: 0.7734 - val_loss: 0.7286 - val_acc: 0.8297
Epoch 6/50
54000/54000 [==============================] - 1s - loss: 0.7208 - acc: 0.8192 - val_loss: 0.6151 - val_acc: 0.8647
Epoch 7/50
54000/54000 [==============================] - 1s - loss: 0.6218 - acc: 0.8495 - val_loss: 0.5358 - val_acc: 0.8845
Epoch 8/50
54000/54000 [==============================] - 1s - loss: 0.5518 - acc: 0.8664 - val_loss: 0.4613 - val_acc: 0.8970
Epoch 9/50
54000/54000 [==============================] - 1s - loss: 0.4973 - acc: 0.8773 - val_loss: 0.4200 - val_acc: 0.9015
Epoch 10/50
54000/54000 [==============================] - 1s - loss: 0.4537 - acc: 0.8867 - val_loss: 0.3881 - val_acc: 0.9102
Epoch 11/50
54000/54000 [==============================] - 1s - loss: 0.4244 - acc: 0.8926 - val_loss: 0.3479 - val_acc: 0.9160
Epoch 12/50
54000/54000 [==============================] - 1s - loss: 0.3913 - acc: 0.8990 - val_loss: 0.3307 - val_acc: 0.9163
Epoch 13/50
54000/54000 [==============================] - 1s - loss: 0.3747 - acc: 0.9008 - val_loss: 0.3198 - val_acc: 0.9200
Epoch 14/50
54000/54000 [==============================] - 1s - loss: 0.3559 - acc: 0.9042 - val_loss: 0.3076 - val_acc: 0.9178
Epoch 15/50
54000/54000 [==============================] - 1s - loss: 0.3416 - acc: 0.9092 - val_loss: 0.2982 - val_acc: 0.9228
Epoch 16/50
54000/54000 [==============================] - 1s - loss: 0.3351 - acc: 0.9092 - val_loss: 0.2883 - val_acc: 0.9222
Epoch 17/50
54000/54000 [==============================] - 1s - loss: 0.3254 - acc: 0.9100 - val_loss: 0.2785 - val_acc: 0.9257
Epoch 18/50
54000/54000 [==============================] - 1s - loss: 0.3134 - acc: 0.9136 - val_loss: 0.2789 - val_acc: 0.9270
Epoch 19/50
54000/54000 [==============================] - 1s - loss: 0.3087 - acc: 0.9141 - val_loss: 0.2524 - val_acc: 0.9295
Epoch 20/50
54000/54000 [==============================] - 1s - loss: 0.3067 - acc: 0.9135 - val_loss: 0.2622 - val_acc: 0.9247
Epoch 21/50
54000/54000 [==============================] - 1s - loss: 0.2891 - acc: 0.9188 - val_loss: 0.2459 - val_acc: 0.9338
Epoch 22/50
54000/54000 [==============================] - 1s - loss: 0.2813 - acc: 0.9193 - val_loss: 0.2438 - val_acc: 0.9332
Epoch 23/50
54000/54000 [==============================] - 1s - loss: 0.2826 - acc: 0.9204 - val_loss: 0.2444 - val_acc: 0.9335
Epoch 24/50
54000/54000 [==============================] - 1s - loss: 0.2713 - acc: 0.9228 - val_loss: 0.2208 - val_acc: 0.9400
Epoch 25/50
54000/54000 [==============================] - 1s - loss: 0.2680 - acc: 0.9232 - val_loss: 0.2231 - val_acc: 0.9395
Epoch 26/50
54000/54000 [==============================] - 1s - loss: 0.2630 - acc: 0.9247 - val_loss: 0.2298 - val_acc: 0.9347
Epoch 27/50
54000/54000 [==============================] - 1s - loss: 0.2635 - acc: 0.9242 - val_loss: 0.2280 - val_acc: 0.9345
Epoch 28/50
54000/54000 [==============================] - 1s - loss: 0.2527 - acc: 0.9286 - val_loss: 0.2105 - val_acc: 0.9417
Epoch 29/50
54000/54000 [==============================] - 1s - loss: 0.2460 - acc: 0.9296 - val_loss: 0.2020 - val_acc: 0.9465
Epoch 30/50
54000/54000 [==============================] - 1s - loss: 0.2466 - acc: 0.9278 - val_loss: 0.2037 - val_acc: 0.9425
Epoch 31/50
54000/54000 [==============================] - 1s - loss: 0.2373 - acc: 0.9307 - val_loss: 0.2058 - val_acc: 0.9425
Epoch 32/50
54000/54000 [==============================] - 1s - loss: 0.2483 - acc: 0.9278 - val_loss: 0.2157 - val_acc: 0.9347
Epoch 33/50
54000/54000 [==============================] - 1s - loss: 0.2503 - acc: 0.9291 - val_loss: 0.2073 - val_acc: 0.9452
Epoch 34/50
54000/54000 [==============================] - 1s - loss: 0.2394 - acc: 0.9305 - val_loss: 0.2097 - val_acc: 0.9413
Epoch 35/50
54000/54000 [==============================] - 1s - loss: 0.2341 - acc: 0.9310 - val_loss: 0.1993 - val_acc: 0.9423
Epoch 36/50
54000/54000 [==============================] - 1s - loss: 0.2316 - acc: 0.9323 - val_loss: 0.2063 - val_acc: 0.9412
Epoch 37/50
54000/54000 [==============================] - 1s - loss: 0.2300 - acc: 0.9330 - val_loss: 0.1971 - val_acc: 0.9432
Epoch 38/50
54000/54000 [==============================] - 1s - loss: 0.2259 - acc: 0.9338 - val_loss: 0.1969 - val_acc: 0.9440
Epoch 39/50
54000/54000 [==============================] - 1s - loss: 0.2211 - acc: 0.9356 - val_loss: 0.1933 - val_acc: 0.9442
Epoch 40/50
54000/54000 [==============================] - 1s - loss: 0.2233 - acc: 0.9345 - val_loss: 0.2032 - val_acc: 0.9432
Epoch 41/50
54000/54000 [==============================] - 1s - loss: 0.2238 - acc: 0.9331 - val_loss: 0.1878 - val_acc: 0.9468
Epoch 42/50
54000/54000 [==============================] - 1s - loss: 0.2141 - acc: 0.9383 - val_loss: 0.1871 - val_acc: 0.9463
Epoch 43/50
54000/54000 [==============================] - 1s - loss: 0.2135 - acc: 0.9380 - val_loss: 0.1871 - val_acc: 0.9475
Epoch 44/50
54000/54000 [==============================] - 1s - loss: 0.2135 - acc: 0.9384 - val_loss: 0.1826 - val_acc: 0.9488
Epoch 45/50
54000/54000 [==============================] - 1s - loss: 0.2073 - acc: 0.9402 - val_loss: 0.1847 - val_acc: 0.9468
Epoch 46/50
54000/54000 [==============================] - 1s - loss: 0.2120 - acc: 0.9371 - val_loss: 0.1765 - val_acc: 0.9498
Epoch 47/50
54000/54000 [==============================] - 1s - loss: 0.2059 - acc: 0.9408 - val_loss: 0.1786 - val_acc: 0.9502
Epoch 48/50
54000/54000 [==============================] - 1s - loss: 0.1995 - acc: 0.9419 - val_loss: 0.1793 - val_acc: 0.9483
Epoch 49/50
54000/54000 [==============================] - 1s - loss: 0.2044 - acc: 0.9414 - val_loss: 0.1781 - val_acc: 0.9450
Epoch 50/50
54000/54000 [==============================] - 1s - loss: 0.2047 - acc: 0.9403 - val_loss: 0.1987 - val_acc: 0.9403
First 3 predictions: [5 0 4]
Training accuracy: 93.31%
Test accuracy: 92.47%
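The counts at the top of the log follow from the `fit` arguments: `validation_split=0.1` holds out the last 10% of the 60,000 training rows, and `batch_size=300` fixes how many SGD updates happen per epoch. The sketch below redoes that arithmetic and, assuming the Keras 1.x time-based decay rule lr / (1 + decay * iterations) applied once per mini-batch, estimates how far `decay=1e-7` moves the learning rate over 50 epochs. The helper name `schedule_summary` is hypothetical.

```python
import math


def schedule_summary(n_samples, validation_split, batch_size, epochs, lr, decay):
    """Hypothetical helper: split / batch arithmetic implied by model.fit(...)."""
    # the last `validation_split` fraction of the rows is held out
    n_train = int(n_samples * (1.0 - validation_split))
    n_val = n_samples - n_train
    # one SGD update per mini-batch; a final partial batch still counts
    updates_per_epoch = math.ceil(n_train / batch_size)
    total_updates = updates_per_epoch * epochs
    # assumed time-based decay: lr_t = lr / (1 + decay * t)
    final_lr = lr / (1.0 + decay * total_updates)
    return n_train, n_val, updates_per_epoch, total_updates, final_lr


n_train, n_val, per_epoch, total, final_lr = schedule_summary(
    60000, 0.1, 300, 50, 0.001, 1e-7)
print(n_train, n_val)     # 54000 6000
print(per_epoch, total)   # 180 9000
print('%.6f' % final_lr)  # 0.000999
```

Under that assumption the learning rate ends within 0.1% of its starting value, so the decay term has almost no effect on this run.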