1

我使用在 tensorflow 中实现的随机森林估计器来预测文本是否为英文。我使用以下代码(train_input_fn 函数返回特征和类标签)保存了我的模型(具有 2k 个样本和 2 个类标签 0/1(非英语/英语)的数据集):

model_path='test/'
TensorForestEstimator(params, model_dir='model/')
estimator.fit(input_fn=train_input_fn, max_steps=1)

运行上述代码后,graph.pbtxt 和 checkpoints 保存在模型文件夹中。现在我想在 Android 上使用它。我有两个问题:

  1. 作为第一步,我需要将图形和检查点冻结为 .pb 文件,以便在 Android 上使用它。我尝试了 freeze_graph(我在这里使用了代码:https ://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py )。当我在我的模式下调用 freeze_graph 时,我收到以下错误并且代码无法创建最终的 .pb 图:

    文件“/Users/XXXXXXX/freeze_graph.py”,第 105 行,在 freeze_graph _ = tf.import_graph_def(input_graph_def, name="") 文件“/anaconda/envs/tensorflow/lib/python2.7/site-packages/tensorflow /python/framework/importer.py",第 258 行,在 import_graph_def op_def = op_dict[node.op] KeyError: u'CountExtremelyRandomStats'

这就是我所说的 freeze_graph:

def save_model_android():
    checkpoint_state_name = "model.ckpt-1"
    input_graph_name = "graph.pbtxt"
    output_graph_name = "output_graph.pb"
    checkpoint_path = os.path.join(model_path, checkpoint_state_name)

    input_graph_path = os.path.join(model_path, input_graph_name)
    input_saver_def_path = None
    input_binary = False
    output_node_names = "output"
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"
    output_graph_path = os.path.join(model_path, output_graph_name)
    clear_devices = True

    freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                              input_binary, checkpoint_path,
                              output_node_names, restore_op_name,
                              filename_tensor_name, output_graph_path,
                              clear_devices, "")

我还尝试在“tf.contrib.learn.datasets.load_iris”中冻结 iris 数据集。我犯了同样的错误。所以我相信它与数据集无关。

  1. 第二步,我需要使用手机上的 .pb 文件来预测文本。我通过谷歌找到了相机演示示例,其中包含大量代码。我想知道是否有分步教程如何通过传递特征向量并获取类标签在Android上使用Tensorflow模型。

提前致谢!

更新

通过使用最新版本的 tensorflow(0.12),问题得到解决。但是,现在的问题是我应该传递给 output_node_names 什么???我怎样才能得到图中的输出节点是什么?

4

2 回答 2

1

Re (1) 看起来您正在一个无法访问 contrib ops 的 tensorflow 构建上运行 freeze_graph。也许在调用 freeze_graph 之前尝试显式导入 tensorforest?

Re (2) 我不知道一个更简单的例子。

于 2016-11-28T17:30:51.230 回答
0

CountExtremelyRandomStats 是 TensorForest 的自定义操作之一,存在于 tensorflow/contrib 中。正如所指出的,TF 在某些时候默认切换到包含 contrib ops。我认为在以前的版本中没有一种简单的方法可以将 contrib 自定义操作包含在全局注册表中,因为 TensorForest 使用构建 .so 文件的方法,该文件包含在运行时加载的数据文件中(一种方法这是创建 TensorForest 时的标准,但可能不再是)。因此,没有容易包含的 python 构建规则可以正确链接到 C++ 自定义操作中。您可以尝试在构建规则中包含 tensorflow/contrib/tensor_forest:ops_lib 作为一个 dep,但我认为它不会起作用。

在任何情况下,您都可以尝试安装 tensorflow 的夜间版本。替代方案包括修改 tensorforest 自定义操作的构建方式,这非常讨厌。

于 2016-11-28T19:22:54.703 回答