PyTorch SNPE inference outputs a misaligned image

bbmckpt7 · posted 2023-10-20

I am developing an Android de-weathering app built around a pix2pix model (similar to UNet). The app uses the phone camera (a OnePlus 7) to capture an image, removes the weather effects, and displays the result on the main screen. Inference runs on Qualcomm's SNPE framework. We are currently hitting a problem: the model output, once converted to a Bitmap, is misaligned, as shown in the attached screenshot. Below is the inference code:

final List<String> result = new LinkedList<>();  // (unused in this snippet)

final FloatTensor tensor = mNeuralNetwork.createFloatTensor(
        mNeuralNetwork.getInputTensorsShapes().get(mInputLayer));
Log.e("[MODEL]", "create tensor");

// Scale the captured photo down to the model's input resolution.
Bitmap smImage = Bitmap.createScaledBitmap(mImage, 1080, 720, true);

// A trailing channel dimension of 1 means the model expects grayscale input.
final int[] dimensions = tensor.getShape();
final boolean isGrayScale = (dimensions[dimensions.length - 1] == 1);
float[] rgbBitmapAsFloat;
if (!isGrayScale) {
    rgbBitmapAsFloat = loadRgbBitmapAsFloat(smImage);
} else {
    rgbBitmapAsFloat = loadGrayScaleBitmapAsFloat(smImage);
}
tensor.write(rgbBitmapAsFloat, 0, rgbBitmapAsFloat.length);
Log.e("[MODEL]", "create tensor done!");

final Map<String, FloatTensor> inputs = new HashMap<>();
inputs.put(mInputLayer, tensor);
Log.e("[MODEL]", "create input tensor done!");

// Run inference and measure the wall-clock execution time.
final long javaExecuteStart = SystemClock.elapsedRealtime();
final Map<String, FloatTensor> outputs = mNeuralNetwork.execute(inputs);
Log.e("[MODEL]", "model execute!");
final long javaExecuteEnd = SystemClock.elapsedRealtime();
mJavaExecuteTime = javaExecuteEnd - javaExecuteStart;

// No-op placeholder, returned only if the expected output layer is missing.
FloatTensor outputTensor = new FloatTensor() {
    @Override
    public void write(float[] floats, int offset, int length, int... position) { }

    @Override
    public void write(float value, int... position) { }

    @Override
    public int read(float[] floats, int offset, int length, int... position) {
        return 0;
    }

    @Override
    public float read(int... position) {
        return 0;
    }

    @Override
    public void release() { }
};

for (Map.Entry<String, FloatTensor> output : outputs.entrySet()) {
    Log.e("[MODEL]", "output_layer: " + output.getKey());
    if (output.getKey().equals(mOutputLayer)) {
        outputTensor = output.getValue();
        Log.e("[MODEL]", "output_layer: " + output.getKey() + ", shape: " +
                outputTensor.getShape()[0] + " " +
                outputTensor.getShape()[1] + " " +
                outputTensor.getShape()[2] + " " +
                outputTensor.getShape()[3]);
    }
}
return outputTensor;
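
For context, mNeuralNetwork above is an SNPE NeuralNetwork instance. A minimal sketch of how it might be built with SNPE's Android builder API follows; the helper name, the model file, and the GPU-then-CPU runtime order are my assumptions, not taken from the original post, and builder options can differ between SDK versions:

import java.io.File;
import java.io.IOException;
import android.app.Application;
import com.qualcomm.qti.snpe.NeuralNetwork;
import com.qualcomm.qti.snpe.SNPE;

// Hypothetical loader sketch, assuming the converted .dlc file is on disk.
NeuralNetwork buildNetwork(Application application, File dlcFile) throws IOException {
    return new SNPE.NeuralNetworkBuilder(application)
            .setModel(dlcFile)                            // the converted .dlc model
            .setRuntimeOrder(NeuralNetwork.Runtime.GPU,   // try the GPU first,
                             NeuralNetwork.Runtime.CPU)   // fall back to the CPU
            .build();
}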

Below is the code that converts the SNPE FloatTensor into a Java Bitmap:

final float[] pixelsBatched = new float[tensor.getSize()];
tensor.read(pixelsBatched, 0, tensor.getSize());
Log.i("[IMAGE]", "size: " + tensor.getSize());

int w = 1080;
int h = 720;
Bitmap img = Bitmap.createBitmap(w, h, Bitmap.Config.ARGB_8888);
for (int y = 0; y < h; y++) {
    for (int x = 0; x < w; x++) {
        // Reads assuming an interleaved HWC (RGBRGB...) layout.
        float r = pixelsBatched[y * w * 3 + x * 3 + 0] * 255;
        float g = pixelsBatched[y * w * 3 + x * 3 + 1] * 255;
        float b = pixelsBatched[y * w * 3 + x * 3 + 2] * 255;
        int color = ((int) r << 16) | ((int) g << 8) | (int) b | 0xFF000000;
        img.setPixel(x, y, color);
    }
}

return img;
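
A side note, independent of the layout issue discussed in the answer below: the bare (int) casts wrap around for values outside [0, 255], which generator outputs can slightly exceed, showing up as speckled pixels. A small clamping helper (hypothetical, not in the original code) could guard against that:

// Hypothetical helper: clamp an already-scaled channel value to [0, 255]
// before packing, so slightly out-of-range model outputs don't wrap around.
static int clampToByte(float scaled) {
    return Math.min(255, Math.max(0, (int) scaled));
}

The packing line would then become int color = (clampToByte(r) << 16) | (clampToByte(g) << 8) | clampToByte(b) | 0xFF000000;.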

To narrow the problem down further, I returned the input tensor directly instead of running inference on it:

return tensor;

After converting this input tensor to a Bitmap, the image came out correct, so I suspect the inference step is where things go wrong. I trained the model in PyTorch and exported it to ONNX; tested within PyTorch, the model produces the correct image. I then simplified the model with onnx-sim and converted it to a DLC model with SNPE's conversion tool. The structure of the ONNX network is shown below.

What could be causing this misalignment? Thank you very much!

Update! ##################

I changed the read to planar (channel-first) order:

int channelSize = w * h;
float r = pixelsBatched[y * w + x] * 255;
float g = pixelsBatched[y * w + x + channelSize] * 255;
float b = pixelsBatched[y * w + x + 2 * channelSize] * 255;

The result is as follows:

Update! ##################

Results from snpe-dlc-viewer:

Input layer info:

Output layer info:

Update! ##################

float[] loadRgbBitmapAsFloat(Bitmap image) {
    final int[] pixels = new int[image.getWidth() * image.getHeight()];
    image.getPixels(pixels, 0, image.getWidth(), 0, 0,
            image.getWidth(), image.getHeight());

    // Writes the input in interleaved HWC order (RGBRGB...).
    final float[] pixelsBatched = new float[pixels.length * 3];
    for (int y = 0; y < image.getHeight(); y++) {
        for (int x = 0; x < image.getWidth(); x++) {
            final int idx = y * image.getWidth() + x;
            final int batchIdx = idx * 3;

            final float[] rgb = extractColorChannels(pixels[idx]);
            pixelsBatched[batchIdx]     = rgb[0];
            pixelsBatched[batchIdx + 1] = rgb[1];
            pixelsBatched[batchIdx + 2] = rgb[2];
        }
    }
    return pixelsBatched;
}
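
extractColorChannels is referenced above but not shown in the post. Judging from how its result is used (the values are multiplied by 255 again when converting back to a Bitmap), it presumably normalizes each 8-bit channel to [0, 1]. A plausible sketch, assuming plain normalization with no mean/std preprocessing:

// Hypothetical reconstruction of the helper used above, assuming
// plain [0, 1] normalization and an RGB channel order.
float[] extractColorChannels(int argbPixel) {
    final float r = ((argbPixel >> 16) & 0xFF) / 255.0f;
    final float g = ((argbPixel >> 8) & 0xFF) / 255.0f;
    final float b = (argbPixel & 0xFF) / 255.0f;
    return new float[] { r, g, b };
}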

ars1skjm 1#

I think you may have gotten the output tensor's layout wrong. When you walk the output tensor like this:

float r = pixelsBatched[y * w * 3 + x * 3 + 0] * 255;
float g = pixelsBatched[y * w * 3 + x * 3 + 1] * 255;
float b = pixelsBatched[y * w * 3 + x * 3 + 2] * 255;

you read the RGB values as if they were interleaved. However, the output tensor uses a channel-first layout (NCHW, 1 × 3 × H × W), which means all R values are stored contiguously first, followed by all G values, then all B values. In general, element (c, y, x) of such a tensor lives at flat index c * H * W + y * W + x.
So you need to define

int channelSize = w * h;

and then read the values like this:

float r = pixelsBatched[y * w + x] * 255;
float g = pixelsBatched[y * w + x + channelSize] * 255;
float b = pixelsBatched[y * w + x + 2 * channelSize] * 255;
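
Putting it all together, here is a minimal sketch of the corrected conversion, assuming the output really is a planar NCHW float tensor in [0, 1] with h = 720 and w = 1080 as in the question. It also batches the pixels into one setPixels call, which is considerably faster than per-pixel setPixel:

import android.graphics.Bitmap;
import com.qualcomm.qti.snpe.FloatTensor;

// Sketch: convert a planar (NCHW, RGB) float tensor in [0, 1] to an ARGB Bitmap.
Bitmap chwTensorToBitmap(FloatTensor tensor, int w, int h) {
    final float[] values = new float[tensor.getSize()];
    tensor.read(values, 0, tensor.getSize());

    final int channelSize = w * h;
    final int[] argb = new int[channelSize];
    for (int y = 0; y < h; y++) {
        for (int x = 0; x < w; x++) {
            final int idx = y * w + x;
            // Channel planes are stored back to back: all R, then all G, then all B.
            final int r = clamp((int) (values[idx] * 255));
            final int g = clamp((int) (values[idx + channelSize] * 255));
            final int b = clamp((int) (values[idx + 2 * channelSize] * 255));
            argb[idx] = 0xFF000000 | (r << 16) | (g << 8) | b;
        }
    }
    final Bitmap img = Bitmap.createBitmap(w, h, Bitmap.Config.ARGB_8888);
    // One bulk write instead of w * h setPixel calls.
    img.setPixels(argb, 0, w, 0, 0, w, h);
    return img;
}

static int clamp(int v) {
    return Math.min(255, Math.max(0, v));
}

If snpe-dlc-viewer reports an NHWC output shape instead, the original interleaved read would already be correct, so it is worth confirming the reported layout before changing the indexing.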
