有人告诉我,使用数据增强可以帮助我对从文档中提取的手写数字进行更准确的预测(所以它不在我正在使用的 MNIST 数据集中),所以我在我的模型中使用了它。但是,我很好奇我是否做得对,因为在使用数据增强之前训练集的大小是 60000,但是在添加数据增强之后它下降到每个 epoch 3750?我这样做正确吗?
在本教程的数据增强部分之后,我对其进行了调整,使其适用于我如何创建和训练我的模型。我还遗漏了目前我还不太了解的函数中的可选参数。
我正在使用的当前模型来自我从上一个问题中得到的答案之一,因为它比我为了尝试而拼凑起来的第二个模型表现更好。我认为我所做的唯一更改是让它sparse_categorical_crossentropy
代替丢失,因为我正在对数字进行分类,手写字符不能属于两类数字,嗯,对吧?
def createModel():
model = keras.models.Sequential()
# 1st conv & maxpool
model.add(keras.layers.Conv2D(40, (5, 5), padding="same", activation='relu', input_shape=(28, 28, 1)))
model.add(keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
# 2nd conv & maxpool
model.add(keras.layers.Conv2D(200, (3, 3), padding="same", activation='relu'))
model.add(keras.layers.MaxPooling2D(pool_size=(3, 3), strides=(1,1)))
# 3rd conv & maxpool
model.add(keras.layers.Conv2D(512, (3, 3), padding="valid", activation='relu'))
model.add(keras.layers.MaxPooling2D(pool_size=(3, 3), strides=(1,1)))
# reduces dims from 2d to 1d
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(units=100, activation='relu'))
# dropout for preventing overfitting
model.add(keras.layers.Dropout(0.5))
# final fully-connected layer
model.add(keras.layers.Dense(10, activation='softmax'))
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
return model
训练由一个单独的函数完成,这就是我插入数据增强部分的地方:
def trainModel(file_model, epochs=5, create_new=False, show=False):
model = createModel()
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape((60000, 28, 28, 1)) / 255.0
x_test = x_test.reshape((10000, 28, 28, 1)) /255.0
x_gen = np.array(x_train, copy=True)
y_gen = np.array(y_train, copy=True)
datagen = keras.preprocessing.image.ImageDataGenerator(featurewise_center=True, featurewise_std_normalization=True, rotation_range=20)
datagen.fit(x_gen)
x_train = np.concatenate((x_train, x_gen), axis=0)
y_train = np.concatenate((y_train, y_gen), axis=0)
# if prev model exists and new model isn't needed..
if (os.path.exists(file_model)==True and create_new==False):
model = keras.models.load_model(file_model)
else:
history = model.fit_generator(datagen.flow(x_train, y_train), epochs=epochs, validation_data=(x_test, y_test))
model.save(file_model)
if (show==True):
model.summary()
return model
我希望它会极大地帮助正确识别手写字符,假设它被正确使用。但我什至不确定我是否正确地做到了这一点,从而对模型的准确性做出了很大贡献。
编辑:它确实有助于识别一些提取的字符,但模型仍然没有正确提取大部分提取的字符,这让我怀疑我是否正确实现了它。