1

我正在使用 TF v0.9 构建基于 skflow 的 DNN 预测(0 或 1)模型。我的代码TensorFlowDNNClassifier是这样的。我训练了大约 26,000 条记录并测试了 6,500 条记录。

classifier = learn.TensorFlowDNNClassifier(hidden_units=[64, 128, 64], n_classes=2)
classifier.fit(features, labels, steps=50000)
test_pred = classifier.predict(test_features)
print(classification_report(test_labels, test_pred))

大约需要 1 分钟并得到结果。

             precision    recall  f1-score   support
          0       0.77      0.92      0.84      4265
          1       0.75      0.47      0.58      2231
avg / total       0.76      0.76      0.75      6496

但我得到了

WARNING:tensorflow:TensorFlowDNNClassifier class is deprecated. 
Please consider using DNNClassifier as an alternative.

所以我简单地更新了我的代码DNNClassifier

classifier = learn.DNNClassifier(hidden_units=[64, 128, 64], n_classes=2)
classifier.fit(features, labels, steps=50000)

它也很好用。但结果不一样。

             precision    recall  f1-score   support
          0       0.77      0.96      0.86      4265
          1       0.86      0.45      0.59      2231
avg / total       0.80      0.79      0.76      6496

1的精度提高了。当然这对我来说是一件好事,但为什么会有所改进呢?大约需要2个小时。 这比前面的例子慢了大约 120 倍。

我有什么问题吗?或错过一些参数?或者DNNClassifierTF v0.9 不稳定?

4

1 回答 1

0

我给出与这里相同的答案。您可能会遇到这种情况,因为您使用了 steps 参数而不是max_steps。它只是 TensorFlowDNNClassifier 上的步骤,实际上做了 max_steps。现在,您可以决定是否真的需要 50000 步或提前自动中止。

于 2016-10-23T21:54:36.110 回答