我使用在 tensorflow 中实现的随机森林估计器来预测文本是否为英文。我使用以下代码(train_input_fn 函数返回特征和类标签)保存了我的模型(具有 2k 个样本和 2 个类标签 0/1(非英语/英语)的数据集):
model_path='test/'
TensorForestEstimator(params, model_dir='model/')
estimator.fit(input_fn=train_input_fn, max_steps=1)
运行上述代码后,graph.pbtxt 和 checkpoints 保存在模型文件夹中。现在我想在 Android 上使用它。我有两个问题:
作为第一步,我需要将图形和检查点冻结为 .pb 文件,以便在 Android 上使用它。我尝试了 freeze_graph(我在这里使用了代码:https ://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py )。当我在我的模式下调用 freeze_graph 时,我收到以下错误并且代码无法创建最终的 .pb 图:
文件“/Users/XXXXXXX/freeze_graph.py”,第 105 行,在 freeze_graph _ = tf.import_graph_def(input_graph_def, name="") 文件“/anaconda/envs/tensorflow/lib/python2.7/site-packages/tensorflow /python/framework/importer.py",第 258 行,在 import_graph_def op_def = op_dict[node.op] KeyError: u'CountExtremelyRandomStats'
这就是我所说的 freeze_graph:
def save_model_android():
checkpoint_state_name = "model.ckpt-1"
input_graph_name = "graph.pbtxt"
output_graph_name = "output_graph.pb"
checkpoint_path = os.path.join(model_path, checkpoint_state_name)
input_graph_path = os.path.join(model_path, input_graph_name)
input_saver_def_path = None
input_binary = False
output_node_names = "output"
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_graph_path = os.path.join(model_path, output_graph_name)
clear_devices = True
freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
input_binary, checkpoint_path,
output_node_names, restore_op_name,
filename_tensor_name, output_graph_path,
clear_devices, "")
我还尝试在“tf.contrib.learn.datasets.load_iris”中冻结 iris 数据集。我犯了同样的错误。所以我相信它与数据集无关。
- 第二步,我需要使用手机上的 .pb 文件来预测文本。我通过谷歌找到了相机演示示例,其中包含大量代码。我想知道是否有分步教程如何通过传递特征向量并获取类标签在Android上使用Tensorflow模型。
提前致谢!
更新
通过使用最新版本的 tensorflow(0.12),问题得到解决。但是,现在的问题是我应该传递给 output_node_names 什么???我怎样才能得到图中的输出节点是什么?