Skip to content

MXNet embedding extraction issue #912

@bootleg-dev

Description

@bootleg-dev

Hello everyone. I've slightly modified the example in the samples folder so that it extracts embeddings from an image. I'm using the MXNet models that can be downloaded from https://github.com/deepinsight/insightface/wiki/Model-Zoo. The problem is that model inference always outputs the same embedding for different images: even for aligned faces of completely different people, the model returns the same embedding. Could you please give me some guidance on how to fix this?

face1
face2

import java.io.Closeable;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.ListIterator;

import org.apache.mxnet.javaapi.*;
import org.bytedeco.javacpp.*;

import org.bytedeco.mxnet.*;
import static org.bytedeco.mxnet.global.mxnet.*;

import org.bytedeco.opencv.opencv_core.*;
import org.bytedeco.opencv.opencv_imgproc.*;
import static org.bytedeco.opencv.global.opencv_core.*;
import static org.bytedeco.opencv.global.opencv_imgcodecs.*;
import static org.bytedeco.opencv.global.opencv_imgproc.*;

/**
 * Runs a single forward pass of an MXNet model (through the C predict API exposed by the
 * JavaCPP presets) on one image and prints the first output tensor, e.g. a face embedding.
 *
 * <p>Usage: {@code java ImageClassificationPredict face.jpg [num_threads]}. The model files
 * default to the hard-coded paths in {@link #main} and can be overridden with the
 * {@code mxnet.symbol} and {@code mxnet.params} system properties.
 */
public class ImageClassificationPredict {

    /** Value subtracted from every channel when no mean image is supplied (classic MXNet sample value). */
    static final float DEFAULT_MEAN = 117.0f;

    /**
     * Value subtracted from every channel when extracting embeddings in {@link #main}.
     *
     * <p>NOTE(review): insightface's reference pipeline feeds raw 0..255 RGB pixels, and its
     * symbols start with their own normalization ({@code (x - 127.5) * 0.0078125}). Subtracting
     * {@link #DEFAULT_MEAN} on top of that shifts every input far away from the training
     * distribution, which can make embeddings of different faces look almost identical.
     * Confirm against the first operators in your model-symbol.json.
     */
    static final float EMBEDDING_MEAN = 0.0f;

    /** Reads a whole file into native memory so it can be handed to the MXNet C API. */
    static class BufferFile implements Closeable {
        public String file_path_;
        /** Number of bytes read from the file (excludes the trailing NUL); 0 if reading failed. */
        public int length_ = 0;
        /** File contents followed by one NUL byte, or null if the file could not be read. */
        public BytePointer buffer_;

        public BufferFile(String file_path) {
            file_path_ = file_path;
            try {
                byte[] bytes = Files.readAllBytes(Paths.get(file_path));
                length_ = bytes.length;
                System.out.println(file_path + " ... " + length_ + " bytes");
                // MXPredCreate reads the symbol JSON as a C string, so the buffer must be
                // NUL-terminated; new BytePointer(bytes) alone allocates exactly bytes.length
                // and lets native code read past the end of the allocation.
                buffer_ = new BytePointer(bytes.length + 1L);
                buffer_.put(bytes);
                buffer_.put(bytes.length, (byte) 0);
            } catch (IOException e) {
                System.err.println("Can't open the file: " + e + ". Please check " + file_path + ".");
                assert false;
            }
        }

        /** @return number of bytes read from the file, not counting the NUL terminator */
        public int GetLength() {
            return length_;
        }

        /** @return the native buffer holding the file contents, or null if reading failed */
        public BytePointer GetBuffer() {
            return buffer_;
        }

        /** Frees the native buffer; safe to call when reading failed or when already closed. */
        @Override
        public void close() throws IOException {
            if (buffer_ != null) {
                buffer_.deallocate();
                buffer_ = null;
            }
        }
    }

    /**
     * Reads an image and writes it into {@code image_data} as planar floats, subtracting
     * {@link #DEFAULT_MEAN} from every channel unless a mean image is given.
     * Equivalent to the six-argument overload with {@code default_mean = DEFAULT_MEAN}.
     */
    static void GetImageFile(String image_file, FloatPointer image_data,
                             int channels, Size resize_size, FloatPointer mean_data) {
        GetImageFile(image_file, image_data, channels, resize_size, mean_data, DEFAULT_MEAN);
    }

    /**
     * Reads an image, resizes it and writes it into {@code image_data} as planar floats
     * (plane 0 = R, plane 1 = G, plane 2 = B).
     *
     * <p>NOTE(review): only {@code channels == 3} is handled correctly. With 1, the plane
     * offsets and the 3-byte pixel stride below do not match.
     *
     * @param image_file   path of the image; decoded by OpenCV as 3-channel BGR
     * @param image_data   destination holding at least rows * cols * channels floats; its
     *                     position is not modified
     * @param channels     number of channels the model expects
     * @param resize_size  size the image is resized to before conversion
     * @param mean_data    optional mean image in the same planar layout, or a null pointer
     * @param default_mean value subtracted from every channel when {@code mean_data} is absent
     * @throws IllegalArgumentException if the image cannot be read or decoded
     */
    static void GetImageFile(String image_file, FloatPointer image_data,
                             int channels, Size resize_size, FloatPointer mean_data,
                             float default_mean) {
        // Read all kinds of file into a BGR color 3 channels image
        Mat im_ori = imread(image_file, IMREAD_COLOR);

        if (im_ori.empty()) {
            System.err.println("Can't open the image. Please check " + image_file + ".");
            // `assert false` is a no-op without -ea; fail loudly instead of resizing an empty Mat.
            throw new IllegalArgumentException("Can't open the image: " + image_file);
        }

        Mat im = new Mat();

        resize(im_ori, im, resize_size);

        int rows = im.rows();
        int cols = im.cols();
        int size = rows * cols * channels;

        // Work on copies so the caller's pointer positions are left untouched.
        FloatBuffer ptr_image_r = new FloatPointer(image_data).position(0).asBuffer();
        FloatBuffer ptr_image_g = new FloatPointer(image_data).position(size / 3).asBuffer();
        FloatBuffer ptr_image_b = new FloatPointer(image_data).position(size / 3 * 2).asBuffer();

        boolean has_mean = mean_data != null && !mean_data.isNull();
        FloatBuffer ptr_mean_r, ptr_mean_g, ptr_mean_b;
        ptr_mean_r = ptr_mean_g = ptr_mean_b = null;
        if (has_mean) {
            ptr_mean_r = new FloatPointer(mean_data).position(0).asBuffer();
            ptr_mean_g = new FloatPointer(mean_data).position(size / 3).asBuffer();
            ptr_mean_b = new FloatPointer(mean_data).position(size / 3 * 2).asBuffer();
        }

        float mean_b, mean_g, mean_r;
        mean_b = mean_g = mean_r = default_mean;

        for (int i = 0; i < rows; i++) {
            // One row of interleaved B,G,R bytes.
            ByteBuffer data = im.ptr(i).capacity(3 * cols).asBuffer();

            for (int j = 0; j < cols; j++) {
                if (has_mean) {
                    mean_r = ptr_mean_r.get();
                    if (channels > 1) {
                        mean_g = ptr_mean_g.get();
                        mean_b = ptr_mean_b.get();
                    }
                }
                // Pixel bytes arrive as B, G, R; they are routed to planes 2, 1, 0 (RGB order).
                if (channels > 1) {
                    ptr_image_b.put((float) (data.get() & 0xFF) - mean_b);
                    ptr_image_g.put((float) (data.get() & 0xFF) - mean_g);
                }

                ptr_image_r.put((float) (data.get() & 0xFF) - mean_r);
            }
        }
    }

    /**
     * Prints the floats from {@code data}'s position up to its limit on one line,
     * space-separated and terminated by a newline.
     */
    static void PrintOutputResult(FloatPointer data) {
        StringBuilder line = new StringBuilder();
        // get(i) is relative to the pointer's position, so iterate over position..limit.
        long count = data.limit() - data.position();
        for (long i = 0; i < count; i++) {
            line.append(data.get(i)).append(' ');
        }
        System.out.println(line);
    }

    /** Throws if an MXNet C API call returned a non-zero status, including MXNet's error text. */
    private static void checkMx(int status, String call) {
        if (status != 0) {
            throw new IllegalStateException(call + " failed (status " + status + "): " + lastMxError());
        }
    }

    /** Returns MXNet's last error message for the calling thread. */
    private static String lastMxError() {
        // Held as Object so this compiles whether the preset maps the C `const char*`
        // return of MXGetLastError to String or to BytePointer.
        Object err = MXGetLastError();
        if (err instanceof BytePointer) {
            return ((BytePointer) err).getString();
        }
        return String.valueOf(err);
    }

    /**
     * Feeds {@code image_data} to input "data", runs a forward pass and prints output 0.
     * Always frees {@code pred_hnd}, and {@code nd_hnd} when it is a non-null handle, even
     * on failure. The predictor therefore cannot be reused after this call.
     *
     * @param pred_hnd   predictor created by MXPredCreate
     * @param image_data input tensor; all floats up to its limit are sent
     * @param nd_hnd     optional NDList handle to release, or a null handle
     * @param n          unused; kept for compatibility with the original sample
     * @throws IllegalStateException if any MXNet call reports an error
     */
    static void predict(PredictorHandle pred_hnd, FloatPointer image_data,
                        NDListHandle nd_hnd, int n) {
        try {
            int image_size = (int) image_data.limit();
            // Set Input Image
            checkMx(MXPredSetInput(pred_hnd, "data", image_data.position(0), image_size),
                    "MXPredSetInput");
            // Do Predict Forward
            checkMx(MXPredForward(pred_hnd), "MXPredForward");

            int output_index = 0;

            // The shape array is owned by the predictor, so read it before MXPredFree.
            IntPointer shape = new IntPointer((IntPointer) null);
            IntPointer shape_len = new IntPointer(1);

            checkMx(MXPredGetOutputShape(pred_hnd, output_index, shape, shape_len),
                    "MXPredGetOutputShape");

            int size = 1;
            for (int i = 0; i < shape_len.get(0); i++) {
                size *= shape.get(i);
            }

            FloatPointer data = new FloatPointer(size);
            checkMx(MXPredGetOutput(pred_hnd, output_index, data.position(0), size),
                    "MXPredGetOutput");

            // Print Output Data
            PrintOutputResult(data.position(0));
        } finally {
            // Release NDList (only if a real handle was supplied)
            if (nd_hnd != null && !nd_hnd.isNull()) {
                MXNDListFree(nd_hnd);
            }
            // Release Predictor
            MXPredFree(pred_hnd);
        }
    }

    public static void main(String[] args) throws Exception {
        // Preload required by JavaCPP
        Loader.load(org.bytedeco.mxnet.global.mxnet.class);

        if (args.length < 1) {
            System.out.println("No test image here.");
            System.out.println("Usage: java ImageClassificationPredict apple.jpg [num_threads]");
            return;
        }

        final String test_file = args[0];
        int num_threads = 1;
        if (args.length >= 2) {
            num_threads = Integer.parseInt(args[1]);
        }
        if (num_threads != 1) {
            // Only the single-threaded path exists; previously any other value skipped
            // prediction entirely while still reporting success.
            System.err.println("Multi-threaded prediction is not implemented; running single-threaded.");
        }

        // Model files; override with -Dmxnet.symbol=... and -Dmxnet.params=...
        final String json_file = System.getProperty("mxnet.symbol",
                "/home/adileg/JavaProjects/javacpp-presets-release/mxnet/samples/feature_model/model-r100-ii/model-symbol.json");
        final String param_file = System.getProperty("mxnet.params",
                "/home/adileg/JavaProjects/javacpp-presets-release/mxnet/samples/feature_model/model-r100-ii/model-0000.params");

        // Parameters
        int dev_type = 1;  // 1: cpu, 2: gpu
        int dev_id = 0;  // arbitrary.
        int num_input_nodes = 1;  // 1 for feedforward
        String[] input_keys = { "data" };

        // Image size and channels
        int width = 112;
        int height = 112;
        int channels = 3;

        // One input node with a 4-D shape (N, C, H, W).
        int[] input_shape_indptr = { 0, 4 };
        int[] input_shape_data = { 1, channels, height, width };

        try (BufferFile json_data = new BufferFile(json_file);
             BufferFile param_data = new BufferFile(param_file)) {

            if (json_data.GetLength() == 0 || param_data.GetLength() == 0) {
                System.exit(1 /* EXIT_FAILURE */);
            }

            final int image_size = width * height * channels;

            // No mean image: GetImageFile falls back to a constant per-channel mean.
            final FloatPointer nd_data = new FloatPointer((Pointer) null);
            final NDListHandle nd_hnd = new NDListHandle((Pointer) null);

            // Read Image Data (see EMBEDDING_MEAN for why no mean is subtracted here)
            final FloatPointer image_data = new FloatPointer(image_size);
            GetImageFile(test_file, image_data, channels, new Size(width, height), nd_data,
                    EMBEDDING_MEAN);

            // Create Predictor
            final PointerPointer<PredictorHandle> pred_hnd = new PointerPointer<PredictorHandle>(1);
            checkMx(MXPredCreate(json_data.GetBuffer(),
                    param_data.GetBuffer(),
                    param_data.GetLength(),
                    dev_type,
                    dev_id,
                    num_input_nodes,
                    new PointerPointer(input_keys),
                    new IntPointer(input_shape_indptr),
                    new IntPointer(input_shape_data),
                    pred_hnd), "MXPredCreate");

            Pointer created = pred_hnd.get();
            if (created == null || created.isNull()) {
                throw new IllegalStateException("MXPredCreate returned a null predictor handle");
            }
            predict(pred_hnd.get(PredictorHandle.class), image_data, nd_hnd, 0);
        }

        System.out.println("run successfully");

        System.exit(0 /* EXIT_SUCCESS */);
    }
}

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions