Classifying the MNIST dataset with Keras

Source: Internet  Posted in: Computer Level 3 Database  Editor: 程序博客网  Time: 2024/06/02 15:01

Code only:

Original dataset download: http://yann.lecun.com/exdb/mnist/
Code notes: http://keras.io/getting-started/sequential-model-guide/#examples
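The raw files at the dataset URL use the IDX binary format: a big-endian header (magic number 2049 plus an item count for label files; magic number 2051, count, rows and columns for image files) followed by one unsigned byte per label or pixel. The sketch below, assuming only NumPy and the standard library, writes a tiny synthetic IDX pair to a temporary directory and reads it back with the same `struct.unpack` / `np.fromfile` steps that `load_mnist` uses. `write_idx`, `read_idx` and the magic-number checks are hypothetical additions for illustration, not part of the original script.

```python
import os
import struct
import tempfile

import numpy as np

LABEL_MAGIC = 2049  # 0x00000801: unsigned-byte data, 1 dimension
IMAGE_MAGIC = 2051  # 0x00000803: unsigned-byte data, 3 dimensions


def write_idx(path, labels, images, kind='train'):
    """Hypothetical helper: write a tiny MNIST-style label/image IDX pair."""
    n, rows, cols = images.shape
    with open(os.path.join(path, '%s-labels-idx1-ubyte' % kind), 'wb') as f:
        f.write(struct.pack('>II', LABEL_MAGIC, n))        # big-endian header
        f.write(labels.astype(np.uint8).tobytes())          # one byte per label
    with open(os.path.join(path, '%s-images-idx3-ubyte' % kind), 'wb') as f:
        f.write(struct.pack('>IIII', IMAGE_MAGIC, n, rows, cols))
        f.write(images.astype(np.uint8).tobytes())          # row-major pixels


def read_idx(path, kind='train'):
    """Hypothetical helper: parse an IDX pair the way load_mnist does,
    plus magic-number and count checks."""
    with open(os.path.join(path, '%s-labels-idx1-ubyte' % kind), 'rb') as f:
        magic, n = struct.unpack('>II', f.read(8))
        assert magic == LABEL_MAGIC, magic
        labels = np.fromfile(f, dtype=np.uint8)             # rest of the file
    with open(os.path.join(path, '%s-images-idx3-ubyte' % kind), 'rb') as f:
        magic, num, rows, cols = struct.unpack('>IIII', f.read(16))
        assert magic == IMAGE_MAGIC and num == n == len(labels)
        images = np.fromfile(f, dtype=np.uint8).reshape(num, rows * cols)
    return images, labels


with tempfile.TemporaryDirectory() as tmp:
    demo_labels = np.array([5, 0, 4])
    demo_images = (np.arange(3 * 28 * 28) % 256).reshape(3, 28, 28)
    write_idx(tmp, demo_labels, demo_images)
    images, labels = read_idx(tmp)

print(images.shape, labels)  # (3, 784) [5 0 4]
```

Checking the magic number is cheap and catches the common mistake of pointing the loader at the still-gzipped downloads.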
```python
"""
@version:
@author: vinsin
@license: Apache Licence
@software: PyCharm
@file: test_keras.py
@time: 16-7-19 4:53 PM
"""
# Written for Keras 1.x with the Theano backend. Keras 2 renamed
# output_dim -> units, init -> kernel_initializer and nb_epoch -> epochs;
# Sequential.predict_classes was removed in later versions.
import os
import struct

import numpy as np


def load_mnist(path, kind='train'):
    """Load MNIST data from `path`"""
    labels_path = os.path.join(path, '%s-labels-idx1-ubyte' % kind)
    images_path = os.path.join(path, '%s-images-idx3-ubyte' % kind)
    with open(labels_path, 'rb') as lbpath:
        # 8-byte big-endian header: magic number, number of labels
        magic, n = struct.unpack('>II', lbpath.read(8))
        labels = np.fromfile(lbpath, dtype=np.uint8)
    with open(images_path, 'rb') as imgpath:
        # 16-byte big-endian header: magic number, number of images, rows, cols
        magic, num, rows, cols = struct.unpack(">IIII", imgpath.read(16))
        # Flatten each 28x28 image into one 784-element row
        images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784)
    return images, labels


# Expects the four MNIST files, decompressed (gunzip), in ../data
X_train, y_train = load_mnist('../data', kind='train')
print('Rows: %d, columns: %d' % (X_train.shape[0], X_train.shape[1]))
X_test, y_test = load_mnist('../data', kind='t10k')
print('Rows: %d, columns: %d' % (X_test.shape[0], X_test.shape[1]))

import theano
theano.config.floatX = 'float32'
X_train = X_train.astype(theano.config.floatX)
X_test = X_test.astype(theano.config.floatX)

from keras.utils import np_utils
print('First 3 data: ', X_train[:3])
print('First 3 labels: ', y_train[:3])
# One-hot encode the integer labels 0-9 into 10-column rows
y_train_ohe = np_utils.to_categorical(y_train)
print('First 3 labels (one-hot):', y_train_ohe[:3])

from keras.models import Sequential
from keras.layers.core import Dense
from keras.optimizers import SGD

np.random.seed(1)
# 784 -> 50 (tanh) -> 50 (tanh) -> 10 (softmax) fully connected network
model = Sequential()
model.add(Dense(input_dim=X_train.shape[1],
                output_dim=50,
                init='uniform',
                activation='tanh'))
model.add(Dense(input_dim=50,
                output_dim=50,
                init='uniform',
                activation='tanh'))
model.add(Dense(input_dim=50,
                output_dim=y_train_ohe.shape[1],
                init='uniform',
                activation='softmax'))
sgd = SGD(lr=0.001, decay=1e-7, momentum=.9)
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=["accuracy"])
model.fit(X_train,
          y_train_ohe,
          nb_epoch=50,
          batch_size=300,
          verbose=1,
          validation_split=0.1)  # hold out 10% of the training set

# Accuracy = fraction of samples whose predicted class equals the true label
y_train_pred = model.predict_classes(X_train, verbose=0)
print('First 3 predictions: ', y_train_pred[:3])
train_acc = np.sum(y_train == y_train_pred, axis=0) / X_train.shape[0]
print('Training accuracy: %.2f%%' % (train_acc * 100))
y_test_pred = model.predict_classes(X_test, verbose=0)
test_acc = np.sum(y_test == y_test_pred, axis=0) / X_test.shape[0]
print('Test accuracy: %.2f%%' % (test_acc * 100))
```
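Two steps in the script are easy to check in isolation: `np_utils.to_categorical` turns integer labels into one-hot rows, and the accuracy lines compare predicted and true class indices element by element. The plain-NumPy sketch below mimics both; `to_one_hot` and `accuracy` are hypothetical helper names, not Keras functions. It also shows why the one-hot width depends on the data: when no class count is passed, it is inferred as the largest label plus one, which should match Keras' `to_categorical` behaviour when the class count is omitted.

```python
import numpy as np


def to_one_hot(labels, num_classes=None):
    """Hypothetical helper: one-hot encode integer class labels.

    If num_classes is omitted it is inferred as max(label) + 1.
    """
    labels = np.asarray(labels, dtype=int)
    if num_classes is None:
        num_classes = int(labels.max()) + 1
    one_hot = np.zeros((labels.shape[0], num_classes))
    one_hot[np.arange(labels.shape[0]), labels] = 1.0  # exactly one 1 per row
    return one_hot


def accuracy(y_true, y_pred):
    """Hypothetical helper: fraction of samples whose predicted class is correct."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    return np.sum(y_true == y_pred) / y_true.shape[0]


first_three = [5, 0, 4]                       # first labels printed in the log
print(to_one_hot(first_three).shape)          # (3, 6): only labels up to 5 seen
print(to_one_hot(first_three, 10)[0])         # row for label 5
print(accuracy([5, 0, 4, 1], [5, 0, 9, 1]))   # 0.75
```

On the full MNIST training labels the largest label is 9, so the inferred width is 10, matching the 10-column one-hot rows in the log.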

The output log is as follows:
```
Rows: 60000, columns: 784
Rows: 10000, columns: 784
Using Theano backend.
First 3 data:  [[ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]]
First 3 labels:  [5 0 4]
First 3 labels (one-hot): [[ 0.  0.  0.  0.  0.  1.  0.  0.  0.  0.]
 [ 1.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  1.  0.  0.  0.  0.  0.]]
Train on 54000 samples, validate on 6000 samples
Epoch 1/50
54000/54000 [==============================] - 1s - loss: 2.2290 - acc: 0.3594 - val_loss: 2.1092 - val_acc: 0.5343
Epoch 2/50
54000/54000 [==============================] - 1s - loss: 1.8848 - acc: 0.5272 - val_loss: 1.6081 - val_acc: 0.5518
Epoch 3/50
54000/54000 [==============================] - 1s - loss: 1.3900 - acc: 0.5886 - val_loss: 1.1653 - val_acc: 0.6645
Epoch 4/50
54000/54000 [==============================] - 1s - loss: 1.0614 - acc: 0.6958 - val_loss: 0.9047 - val_acc: 0.7703
Epoch 5/50
54000/54000 [==============================] - 1s - loss: 0.8566 - acc: 0.7734 - val_loss: 0.7286 - val_acc: 0.8297
Epoch 6/50
54000/54000 [==============================] - 1s - loss: 0.7208 - acc: 0.8192 - val_loss: 0.6151 - val_acc: 0.8647
Epoch 7/50
54000/54000 [==============================] - 1s - loss: 0.6218 - acc: 0.8495 - val_loss: 0.5358 - val_acc: 0.8845
Epoch 8/50
54000/54000 [==============================] - 1s - loss: 0.5518 - acc: 0.8664 - val_loss: 0.4613 - val_acc: 0.8970
Epoch 9/50
54000/54000 [==============================] - 1s - loss: 0.4973 - acc: 0.8773 - val_loss: 0.4200 - val_acc: 0.9015
Epoch 10/50
54000/54000 [==============================] - 1s - loss: 0.4537 - acc: 0.8867 - val_loss: 0.3881 - val_acc: 0.9102
Epoch 11/50
54000/54000 [==============================] - 1s - loss: 0.4244 - acc: 0.8926 - val_loss: 0.3479 - val_acc: 0.9160
Epoch 12/50
54000/54000 [==============================] - 1s - loss: 0.3913 - acc: 0.8990 - val_loss: 0.3307 - val_acc: 0.9163
Epoch 13/50
54000/54000 [==============================] - 1s - loss: 0.3747 - acc: 0.9008 - val_loss: 0.3198 - val_acc: 0.9200
Epoch 14/50
54000/54000 [==============================] - 1s - loss: 0.3559 - acc: 0.9042 - val_loss: 0.3076 - val_acc: 0.9178
Epoch 15/50
54000/54000 [==============================] - 1s - loss: 0.3416 - acc: 0.9092 - val_loss: 0.2982 - val_acc: 0.9228
Epoch 16/50
54000/54000 [==============================] - 1s - loss: 0.3351 - acc: 0.9092 - val_loss: 0.2883 - val_acc: 0.9222
Epoch 17/50
54000/54000 [==============================] - 1s - loss: 0.3254 - acc: 0.9100 - val_loss: 0.2785 - val_acc: 0.9257
Epoch 18/50
54000/54000 [==============================] - 1s - loss: 0.3134 - acc: 0.9136 - val_loss: 0.2789 - val_acc: 0.9270
Epoch 19/50
54000/54000 [==============================] - 1s - loss: 0.3087 - acc: 0.9141 - val_loss: 0.2524 - val_acc: 0.9295
Epoch 20/50
54000/54000 [==============================] - 1s - loss: 0.3067 - acc: 0.9135 - val_loss: 0.2622 - val_acc: 0.9247
Epoch 21/50
54000/54000 [==============================] - 1s - loss: 0.2891 - acc: 0.9188 - val_loss: 0.2459 - val_acc: 0.9338
Epoch 22/50
54000/54000 [==============================] - 1s - loss: 0.2813 - acc: 0.9193 - val_loss: 0.2438 - val_acc: 0.9332
Epoch 23/50
54000/54000 [==============================] - 1s - loss: 0.2826 - acc: 0.9204 - val_loss: 0.2444 - val_acc: 0.9335
Epoch 24/50
54000/54000 [==============================] - 1s - loss: 0.2713 - acc: 0.9228 - val_loss: 0.2208 - val_acc: 0.9400
Epoch 25/50
54000/54000 [==============================] - 1s - loss: 0.2680 - acc: 0.9232 - val_loss: 0.2231 - val_acc: 0.9395
Epoch 26/50
54000/54000 [==============================] - 1s - loss: 0.2630 - acc: 0.9247 - val_loss: 0.2298 - val_acc: 0.9347
Epoch 27/50
54000/54000 [==============================] - 1s - loss: 0.2635 - acc: 0.9242 - val_loss: 0.2280 - val_acc: 0.9345
Epoch 28/50
54000/54000 [==============================] - 1s - loss: 0.2527 - acc: 0.9286 - val_loss: 0.2105 - val_acc: 0.9417
Epoch 29/50
54000/54000 [==============================] - 1s - loss: 0.2460 - acc: 0.9296 - val_loss: 0.2020 - val_acc: 0.9465
Epoch 30/50
54000/54000 [==============================] - 1s - loss: 0.2466 - acc: 0.9278 - val_loss: 0.2037 - val_acc: 0.9425
Epoch 31/50
54000/54000 [==============================] - 1s - loss: 0.2373 - acc: 0.9307 - val_loss: 0.2058 - val_acc: 0.9425
Epoch 32/50
54000/54000 [==============================] - 1s - loss: 0.2483 - acc: 0.9278 - val_loss: 0.2157 - val_acc: 0.9347
Epoch 33/50
54000/54000 [==============================] - 1s - loss: 0.2503 - acc: 0.9291 - val_loss: 0.2073 - val_acc: 0.9452
Epoch 34/50
54000/54000 [==============================] - 1s - loss: 0.2394 - acc: 0.9305 - val_loss: 0.2097 - val_acc: 0.9413
Epoch 35/50
54000/54000 [==============================] - 1s - loss: 0.2341 - acc: 0.9310 - val_loss: 0.1993 - val_acc: 0.9423
Epoch 36/50
54000/54000 [==============================] - 1s - loss: 0.2316 - acc: 0.9323 - val_loss: 0.2063 - val_acc: 0.9412
Epoch 37/50
54000/54000 [==============================] - 1s - loss: 0.2300 - acc: 0.9330 - val_loss: 0.1971 - val_acc: 0.9432
Epoch 38/50
54000/54000 [==============================] - 1s - loss: 0.2259 - acc: 0.9338 - val_loss: 0.1969 - val_acc: 0.9440
Epoch 39/50
54000/54000 [==============================] - 1s - loss: 0.2211 - acc: 0.9356 - val_loss: 0.1933 - val_acc: 0.9442
Epoch 40/50
54000/54000 [==============================] - 1s - loss: 0.2233 - acc: 0.9345 - val_loss: 0.2032 - val_acc: 0.9432
Epoch 41/50
54000/54000 [==============================] - 1s - loss: 0.2238 - acc: 0.9331 - val_loss: 0.1878 - val_acc: 0.9468
Epoch 42/50
54000/54000 [==============================] - 1s - loss: 0.2141 - acc: 0.9383 - val_loss: 0.1871 - val_acc: 0.9463
Epoch 43/50
54000/54000 [==============================] - 1s - loss: 0.2135 - acc: 0.9380 - val_loss: 0.1871 - val_acc: 0.9475
Epoch 44/50
54000/54000 [==============================] - 1s - loss: 0.2135 - acc: 0.9384 - val_loss: 0.1826 - val_acc: 0.9488
Epoch 45/50
54000/54000 [==============================] - 1s - loss: 0.2073 - acc: 0.9402 - val_loss: 0.1847 - val_acc: 0.9468
Epoch 46/50
54000/54000 [==============================] - 1s - loss: 0.2120 - acc: 0.9371 - val_loss: 0.1765 - val_acc: 0.9498
Epoch 47/50
54000/54000 [==============================] - 1s - loss: 0.2059 - acc: 0.9408 - val_loss: 0.1786 - val_acc: 0.9502
Epoch 48/50
54000/54000 [==============================] - 1s - loss: 0.1995 - acc: 0.9419 - val_loss: 0.1793 - val_acc: 0.9483
Epoch 49/50
54000/54000 [==============================] - 1s - loss: 0.2044 - acc: 0.9414 - val_loss: 0.1781 - val_acc: 0.9450
Epoch 50/50
54000/54000 [==============================] - 1s - loss: 0.2047 - acc: 0.9403 - val_loss: 0.1987 - val_acc: 0.9403
First 3 predictions:  [5 0 4]
Training accuracy: 93.31%
Test accuracy: 92.47%
```
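The line `Train on 54000 samples, validate on 6000 samples` follows from `validation_split=0.1`: Keras holds out the last 10% of the training arrays, taken before shuffling, rather than a random subset. A minimal sketch of that split arithmetic, assuming only NumPy (`split_like_validation_split` is a hypothetical name, not a Keras function):

```python
import numpy as np


def split_like_validation_split(x, y, validation_split):
    """Hypothetical helper mirroring Keras' documented behaviour: the LAST
    `validation_split` fraction of the samples (before any shuffling) is
    held out for validation; the rest is used for training."""
    split_at = int(len(x) * (1.0 - validation_split))
    return (x[:split_at], y[:split_at]), (x[split_at:], y[split_at:])


n_train = 60000                         # "Rows: 60000" in the log
x = np.arange(n_train)                  # stand-in for the 60000 x 784 images
y = np.arange(n_train) % 10             # stand-in labels
(x_fit, y_fit), (x_val, y_val) = split_like_validation_split(x, y, 0.1)
print(len(x_fit), len(x_val))           # 54000 6000
```

Because the held-out rows are the tail of the array, data that is sorted by class should be shuffled before calling `fit` with `validation_split`.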



