我正在使用 TF v0.9 构建基于 skflow 的 DNN 预测(0 或 1)模型。我的代码TensorFlowDNNClassifier
是这样的。我训练了大约 26,000 条记录并测试了 6,500 条记录。
classifier = learn.TensorFlowDNNClassifier(hidden_units=[64, 128, 64], n_classes=2)
classifier.fit(features, labels, steps=50000)
test_pred = classifier.predict(test_features)
print(classification_report(test_labels, test_pred))
大约需要 1 分钟并得到结果。
precision recall f1-score support
0 0.77 0.92 0.84 4265
1 0.75 0.47 0.58 2231
avg / total 0.76 0.76 0.75 6496
但我得到了
WARNING:tensorflow:TensorFlowDNNClassifier class is deprecated.
Please consider using DNNClassifier as an alternative.
所以我简单地更新了我的代码DNNClassifier
。
classifier = learn.DNNClassifier(hidden_units=[64, 128, 64], n_classes=2)
classifier.fit(features, labels, steps=50000)
它也很好用。但结果不一样。
precision recall f1-score support
0 0.77 0.96 0.86 4265
1 0.86 0.45 0.59 2231
avg / total 0.80 0.79 0.76 6496
1
的精度提高了。当然这对我来说是一件好事,但为什么会有所改进呢?大约需要2个小时。
这比前面的例子慢了大约 120 倍。
我有什么问题吗?或错过一些参数?或者DNNClassifier
TF v0.9 不稳定?