在 60,000(训练)和 26,000(测试)上使用以下 TF .9.0rc0 和 145 个编码列 (1,0) 来预测 1 或 0 以进行类识别。
classifier_TensorFlow = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10],n_classes=2, steps=100)
classifier_TensorFlow.fit(X_train, y_train.ravel())
我得到:
WARNING:tensorflow:TensorFlowDNNClassifier class is deprecated. Please consider using DNNClassifier as an alternative.
Out[34]:TensorFlowDNNClassifier(steps=100, batch_size=32)
然后很快就得到了很好的结果:
score = metrics.accuracy_score(y_test, classifier_TensorFlow.predict(X_test))
print('Accuracy: {0:f}'.format(score))
Accuracy: 0.923121
和:
print (metrics.confusion_matrix(y_test, X_pred_class))
[[23996 103]
[ 1992 15]]
但是当我尝试使用新的建议方法时:
classifier_TensorFlow = learn.DNNClassifier(hidden_units=[10, 20, 10],n_classes=2)
它挂起没有完成?它不会采用“步骤”参数吗?我没有收到任何错误消息或输出,所以没有太多可继续的...有什么想法或提示吗?文档有点“轻?”