
How do I compute the mean and variance of an image with 16 channels using Metal?

I want to compute the mean and variance of each channel separately!

Ex.:

kernel void meanandvariance(texture2d_array<float, access::read> in[[texture(0)]],
                            texture2d_array<float, access::write> out[[texture(1)]],
                            ushort3 gid[[thread_position_in_grid]],
                            ushort tid[[thread_index_in_threadgroup]],
                            ushort3 tg_size[[threads_per_threadgroup]]) {
}



2 Answers


There may be a way to do this with MPSImageStatisticsMeanAndVariance, by creating a sequence of texture views over the input and output texture arrays and encoding one kernel invocation per slice.
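A minimal sketch of that route might look like the following (my own illustration, untested; the helper function, its parameters, and the per-slice destination layout are assumptions on my part, not something from the original answer):

import Metal
import MetalPerformanceShaders

// Hypothetical helper: encode MPSImageStatisticsMeanAndVariance once per slice,
// using 2D texture views onto the 2D-array source and destination textures.
func encodeMeanVariancePerSlice(commandBuffer: MTLCommandBuffer,
                                device: MTLDevice,
                                source: MTLTexture,        // .type2DArray
                                destination: MTLTexture) { // .type2DArray, at least 2x1 per slice
    let kernel = MPSImageStatisticsMeanAndVariance(device: device)
    for slice in 0..<source.arrayLength {
        // 2D views of a single slice; the pixel format is left unchanged.
        guard
            let srcView = source.makeTextureView(pixelFormat: source.pixelFormat,
                                                 textureType: .type2D,
                                                 levels: 0..<1,
                                                 slices: slice..<(slice + 1)),
            let dstView = destination.makeTextureView(pixelFormat: destination.pixelFormat,
                                                      textureType: .type2D,
                                                      levels: 0..<1,
                                                      slices: slice..<(slice + 1))
        else { continue }
        // MPS writes the mean to pixel (0, 0) and the variance to pixel (1, 0).
        kernel.encode(commandBuffer: commandBuffer,
                      sourceTexture: srcView,
                      destinationTexture: dstView)
    }
}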

But let's look at how we could do it ourselves. There are many possible approaches, so I picked a simple one that also uses some interesting statistical results.

Essentially, we'll do the following:

  1. Write a kernel that produces subset means and variances for a single row of the image.
  2. Write a kernel that produces the overall mean and variance from the partial results of step 1.

Here are the kernels:

kernel void compute_row_mean_variance_array(texture2d_array<float, access::read> inTexture [[texture(0)]],
                                            texture2d_array<float, access::write> outTexture [[texture(1)]],
                                            uint3 tpig [[thread_position_in_grid]])
{
    uint row = tpig.x;
    uint slice = tpig.y;
    uint width = inTexture.get_width();

    if (row >= inTexture.get_height() || slice >= inTexture.get_array_size()) { return; }

    float4 mean(0.0f);
    float4 var(0.0f);
    for (uint col = 0; col < width; ++col) {
        float4 rgba = inTexture.read(ushort2(col, row), slice);
        // http://datagenetics.com/blog/november22017/index.html
        float weight = 1.0f / (col + 1);
        float4 oldMean = mean;
        mean = mean + (rgba - mean) * weight;
        var = var + (rgba - oldMean) * (rgba - mean);
    }

    var = var / width;

    outTexture.write(mean, ushort2(row, 0), slice);
    outTexture.write(var, ushort2(row, 1), slice);
}

kernel void reduce_mean_variance_array(texture2d_array<float, access::read> inTexture [[texture(0)]],
                                       texture2d_array<float, access::write> outTexture [[texture(1)]],
                                       uint3 tpig [[thread_position_in_grid]])
{
    uint width = inTexture.get_width();
    uint slice = tpig.x;

    // https://arxiv.org/pdf/1007.1012.pdf
    float4 mean(0.0f);
    float4 meanOfVar(0.0f);
    float4 varOfMean(0.0f);
    for (uint col = 0; col < width; ++col) {
        float weight = 1.0f / (col + 1);

        float4 oldMean = mean;
        float4 submean = inTexture.read(ushort2(col, 0), slice);
        mean = mean + (submean - mean) * weight;

        float4 subvar = inTexture.read(ushort2(col, 1), slice);
        meanOfVar = meanOfVar +  (subvar - meanOfVar) * weight;

        varOfMean = varOfMean + (submean - oldMean) * (submean - mean);
    }
    float4 var = meanOfVar + varOfMean / width;

    outTexture.write(mean, ushort2(0, 0), slice);
    outTexture.write(var, ushort2(1, 0), slice);
}

In summary: to implement step 1, we use an "online" (incremental) algorithm to compute the partial mean/variance of each row, which is more numerically stable than simply adding up all the pixel values and dividing by the width. My reference for this kernel is the article linked in its comment. Each thread in the grid writes its row's statistics to the appropriate column and slice of the intermediate texture array.
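For intuition, here is the same incremental update as a CPU-side Swift sketch (my own addition, not part of the answer's pipeline); it agrees with the naive two-pass computation:

// Welford-style incremental mean/variance, mirroring the per-row loop in the kernel above.
func onlineMeanVariance(_ values: [Float]) -> (mean: Float, variance: Float) {
    var mean: Float = 0
    var m2: Float = 0                         // running sum of squared deviations
    for (i, x) in values.enumerated() {
        let weight = 1 / Float(i + 1)
        let oldMean = mean
        mean += (x - mean) * weight
        m2 += (x - oldMean) * (x - mean)
    }
    return (mean, m2 / Float(values.count))   // population variance, as in the kernel
}

let samples: [Float] = [1, 2, 3, 4, 5, 6, 7, 8]
print(onlineMeanVariance(samples))            // (mean: 4.5, variance: 5.25)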

To implement step 2, we need a statistically sound way to compute the overall statistics from the partial results. For the mean this is quite simple: the overall mean is the mean of the subset means (this holds when every subset has the same sample size; in the general case, the overall mean is a weighted sum of the subset means). The variance is trickier, but it turns out that the variance of the whole is the sum of the mean of the subset variances and the variance of the subset means (the same caveat about equal-sized subsets applies). This is a handy fact that we can combine with the incremental approach above to produce the final mean and variance for each slice, which we write to the corresponding slice of the output texture.
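Here is a quick numeric check of that combination rule for equal-sized subsets (again just my own illustration with made-up numbers):

// Split a sample into equal-sized subsets and recombine their statistics:
//   overall mean     = mean of the subset means
//   overall variance = mean of the subset variances + variance of the subset means
func stats(_ xs: [Float]) -> (mean: Float, variance: Float) {
    let mean = xs.reduce(0, +) / Float(xs.count)
    let variance = xs.map { ($0 - mean) * ($0 - mean) }.reduce(0, +) / Float(xs.count)
    return (mean, variance)
}

let rows: [[Float]] = [[1, 2, 3, 4], [5, 6, 7, 8]]          // two equal-sized "rows"
let rowStats = rows.map(stats)
let meanOfMeans = stats(rowStats.map { $0.mean })
let meanOfVariances = rowStats.map { $0.variance }.reduce(0, +) / Float(rows.count)
let combinedVariance = meanOfVariances + meanOfMeans.variance

print(stats(rows.flatMap { $0 }))             // (mean: 4.5, variance: 5.25)
print((meanOfMeans.mean, combinedVariance))   // (4.5, 5.25) -- same mean and variance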

For completeness, here's the Swift code I use to drive these kernels:

let library = device.makeDefaultLibrary()!

let meanVarKernelFunction = library.makeFunction(name: "compute_row_mean_variance_array")!
let meanVarComputePipelineState = try! device.makeComputePipelineState(function: meanVarKernelFunction)

let reduceKernelFunction = library.makeFunction(name: "reduce_mean_variance_array")!
let reduceComputePipelineState = try! device.makeComputePipelineState(function: reduceKernelFunction)

let width = sourceTexture.width
let height = sourceTexture.height
let arrayLength = sourceTexture.arrayLength

let textureDescriptor = MTLTextureDescriptor.texture2DDescriptor(pixelFormat: .rgba32Float, width: width, height: height, mipmapped: false)
textureDescriptor.textureType = .type2DArray
textureDescriptor.arrayLength = arrayLength
textureDescriptor.width = height
textureDescriptor.height = 2
textureDescriptor.usage = [.shaderRead, .shaderWrite]

let partialResultsTexture = device.makeTexture(descriptor: textureDescriptor)!

textureDescriptor.width = 2
textureDescriptor.height = 1
textureDescriptor.usage = .shaderWrite

let destTexture = device.makeTexture(descriptor: textureDescriptor)!

let commandBuffer = commandQueue.makeCommandBuffer()!

let computeCommandEncoder = commandBuffer.makeComputeCommandEncoder()!

computeCommandEncoder.setComputePipelineState(meanVarComputePipelineState)
computeCommandEncoder.setTexture(sourceTexture, index: 0)
computeCommandEncoder.setTexture(partialResultsTexture, index: 1)
let meanVarGridSize = MTLSize(width: sourceTexture.height, height: sourceTexture.arrayLength, depth: 1)
let meanVarThreadgroupSize = MTLSizeMake(meanVarComputePipelineState.threadExecutionWidth, 1, 1)
let meanVarThreadgroupCount = MTLSizeMake((meanVarGridSize.width + meanVarThreadgroupSize.width - 1) / meanVarThreadgroupSize.width,
                                          (meanVarGridSize.height + meanVarThreadgroupSize.height - 1) / meanVarThreadgroupSize.height,
                                          1)
computeCommandEncoder.dispatchThreadgroups(meanVarThreadgroupCount, threadsPerThreadgroup: meanVarThreadgroupSize)

computeCommandEncoder.setComputePipelineState(reduceComputePipelineState)
computeCommandEncoder.setTexture(partialResultsTexture, index: 0)
computeCommandEncoder.setTexture(destTexture, index: 1)
let reduceThreadgroupSize = MTLSizeMake(1, 1, 1)
let reduceThreadgroupCount = MTLSizeMake(arrayLength, 1, 1)
computeCommandEncoder.dispatchThreadgroups(reduceThreadgroupCount, threadsPerThreadgroup: reduceThreadgroupSize)

computeCommandEncoder.endEncoding()

// Comparison path mentioned below: encode MPSImageStatisticsMeanAndVariance on a plain 2D
// copy of the source. `meanVarKernel` and `sourceTexture2D` are assumed to be created elsewhere.
let destTexture2DDesc = MTLTextureDescriptor.texture2DDescriptor(pixelFormat: .rgba32Float, width: 2, height: 1, mipmapped: false)
destTexture2DDesc.usage = .shaderWrite
let destTexture2D = device.makeTexture(descriptor: destTexture2DDesc)!

meanVarKernel.encode(commandBuffer: commandBuffer, sourceTexture: sourceTexture2D, destinationTexture: destTexture2D)

#if os(macOS)
let blitCommandEncoder = commandBuffer.makeBlitCommandEncoder()!
blitCommandEncoder.synchronize(resource: destTexture)
blitCommandEncoder.synchronize(resource: destTexture2D)
blitCommandEncoder.endEncoding()
#endif

commandBuffer.commit()

commandBuffer.waitUntilCompleted()

In my experiments, this program produced identical results to MPSImageStatisticsMeanAndVariance, give or take differences on the order of 1e-7. It was also about 2.5x slower than MPS on my Mac, probably in part because it fails to exploit latency hiding with granular parallelism.

Answered 2020-01-20T00:09:20.200
#include <metal_stdlib>
using namespace metal;

kernel void instance_norm(constant float4* scale[[buffer(0)]],
                          constant float4* shift[[buffer(1)]],
                          texture2d_array<float, access::read> in[[texture(0)]],
                          texture2d_array<float, access::write> out[[texture(1)]],

                          ushort3 gid[[thread_position_in_grid]],
                          ushort tid[[thread_index_in_threadgroup]],
                          ushort3 tg_size[[threads_per_threadgroup]]) {

    ushort width = in.get_width();
    ushort height = in.get_height();
    const ushort thread_count = tg_size.x * tg_size.y;

    threadgroup float4 shared_mem [256];

    float4 sum = 0;
    for(ushort xIndex = gid.x; xIndex < width; xIndex += tg_size.x) {
        for(ushort yIndex = gid.y; yIndex < height; yIndex += tg_size.y) {
            sum += in.read(ushort2(xIndex, yIndex), gid.z);
        }
    }
    shared_mem[tid] = sum;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Reduce to 32 values
    sum = 0;
    if (tid < 32) {
        for (ushort i = tid + 32; i < thread_count; i += 32) {
            sum += shared_mem[i];
        }
    }
    shared_mem[tid] += sum;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Calculate mean
    sum = 0;
    if (tid == 0) {
        ushort top = min(ushort(32), thread_count);
        for (ushort i = 0; i < top; i += 1) {
            sum += shared_mem[i];
        }
        shared_mem[0] = sum / (width * height);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float4 mean = shared_mem[0];

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Variance
    sum = 0;
    for(ushort xIndex = gid.x; xIndex < width; xIndex += tg_size.x) {
        for(ushort yIndex = gid.y; yIndex < height; yIndex += tg_size.y) {
            sum += pow(in.read(ushort2(xIndex, yIndex), gid.z) - mean, 2);
        }
    }

    shared_mem[tid] = sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Reduce to 32 values
    sum = 0;
    if (tid < 32) {
        for (ushort i = tid + 32; i < thread_count; i += 32) {
            sum += shared_mem[i];
        }
    }
    shared_mem[tid] += sum;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Calculate variance
    sum = 0;
    if (tid == 0) {
        ushort top = min(ushort(32), thread_count);
        for (ushort i = 0; i < top; i += 1) {
            sum += shared_mem[i];
        }
        shared_mem[0] = sum / (width * height);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float4 sigma = sqrt(shared_mem[0] + float4(1e-4));

    float4 multiplier = scale[gid.z] / sigma;
    for(ushort xIndex = gid.x; xIndex < width; xIndex += tg_size.x) {
        for(ushort yIndex = gid.y; yIndex < height; yIndex += tg_size.y) {
            float4 val = in.read(ushort2(xIndex, yIndex), gid.z);
            out.write(clamp((val - mean) * multiplier + shift[gid.z], -10.0, 10.0), ushort2(xIndex, yIndex), gid.z);
        }
    }

}

This is how Bender implements it, but I don't think it's correct. Can someone address this?

https://github.com/xmartlabs/Bender/blob/master/Sources/Metal/instanceNorm.metal

Answered 2020-01-21T09:36:02.860