1

我正在从不推荐使用的迁移我的神经网络:

init(device: MTLDevice, convolutionDescriptor: MPSCNNConvolutionDescriptor, kernelWeights: UnsafePointer<Float>, biasTerms: UnsafePointer<Float>?, flags: MPSCNNConvolutionFlags)

init(device: MTLDevice, weights: MPSCNNConvolutionDataSource)

我已经实现了一个MPSCNNConvolutionDataSource经过良好调试的并且适用于我的所有层,除了一个。仅出于测试目的,我在这里自己调用数据源函数以及不推荐使用的 init() ,MPSCNNFullyConnected以确保正确实现数据源。我知道这不是它的预期用途,但我希望相同的数据进入两个 MPSCNNFullyConnected() 构造函数。以下代码运行,NN 正常工作。

  /* This code runs as intended */
  let datasource = DataSource("test", 8, 8, 224, 1024, .reLU)
  _ = datasource.load()
    let layer = MPSCNNFullyConnected(device: device,
                                   convolutionDescriptor: datasource.descriptor(),
                                   kernelWeights: UnsafeMutablePointer<Float>(mutating: datasource.weights().assumingMemoryBound(to: Float.self)),
                                   biasTerms: datasource.biasTerms(),
                                   flags: .none)

当我用新的 init() 实例化全连接层时,网络会失败。以下代码运行但 NN 无法正常工作。

  /* This code does run, but the layer does NOT output the correct values */
  let datasource = DataSource("test", 8, 8, 224, 1024, .reLU)
  let layer = MPSCNNFullyConnected(device: device, weights: datasource)

有什么建议为什么两个电话不相同?

4

2 回答 2

3

最后我解决了。这两个调用之间的区别在于,如果您使用以下命令,则必须明确设置 layer.offset:

init(device: MTLDevice, weights: MPSCNNConvolutionDataSource) 

不推荐使用的调用:

init(device: MTLDevice, convolutionDescriptor: MPSCNNConvolutionDescriptor, kernelWeights: UnsafePointer<Float>, biasTerms: UnsafePointer<Float>?, flags: MPSCNNConvolutionFlags)

似乎隐含地做到了这一点。

此代码有效:

let datasource = DataSource("test", 8, 8, 224, 1024, .reLU)
let layer = MPSCNNFullyConnected(device: device, weights: datasource)
layer.offset = MPSOffset(x: 8/2, y: 8/2, z: 0)

我想这无处记录!感谢苹果三天的硬核调试。

于 2019-03-21T11:52:57.673 回答
0

大多数此类 Metal 内容的 Apple 文档似乎已经消失,但如果您查看头文件(或在 Xcode 中按住 Alt 键单击然后跳转到定义),它仍然存在。

权重的顺序没有改变。您仍然以与以前相同的方式加载它们。

有关如何编写此类数据源对象的示例,请查看此 repo:https ://github.com/hollance/YOLO-CoreML-MPSNNGraph/blob/2ba3435bfacb8d2f792b95887fc9df85d7048ae1/TinyYOLO-NNGraph/TinyYOLO-NNGraph/YOLO.swift# L254

于 2019-03-20T10:21:40.377 回答