Skip to content

Commit

Permalink
修复了部分bug 增加模型支持
Browse files Browse the repository at this point in the history
  • Loading branch information
zjkhahah committed Sep 17, 2024
1 parent 433b2ec commit 9a5f208
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 14 deletions.
Binary file removed 1.png
Binary file not shown.
16 changes: 11 additions & 5 deletions main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ int main(int argc, char* argv[]) {
cmdline::parser p;
p.add<std::string>("input_image", 'i', "Enter image path", false, "./demo/images/test2.jpg");
p.add<std::string>("output_image", 'o', "Enter image path", false, ".");
p.add<int>("out_size_kb", 'k', "image size", false, 20);
p.add<std::string>("segment_model", 's', "segment model name", false, "mnn_hivision_modnet.mnn");
p.add<int>("out_size_kb", 'k', "image size", false, 0);
p.add<int>("thread_num", 't', "model use thread num", false, 4);
p.add<int>("background_color_r", 'r', "background red", false, 255);
p.add<int>("background_color_g", 'g', "background green", false, 0);
Expand All @@ -19,20 +20,23 @@ int main(int argc, char* argv[]) {
p.add<int>("out_images_height", 'h', "out images height", false, 413);
p.add<int>("face_model", 'f', "face_model type 5 and 8", false, 8);
std::string face_model_path = "./model";
const char* segment_modnet = "./model/mnn_hivision_modnet.mnn";
//const char* segment_modnet = "./model/mnn_hivision_modnet.mnn";
p.parse_check(argc, argv);

params.out_image_width=p.get<int>("out_images_width");
params.out_image_height=p.get<int>("out_images_height");
params.rgb_b=p.get<int>("background_color_b");
params.rgb_g=p.get<int>("background_color_g");
params.rgb_r=p.get<int>("background_color_r");


std::string modelFilename = p.get<std::string>("segment_model");
std::string modelPath = "./model/" + modelFilename;
const char* modelPathCStr = modelPath.c_str();
cv::Vec3b newBackgroundColor(p.get<int>("background_color_b"), p.get<int>("background_color_g"), p.get<int>("background_color_r"));
LFFD* face_detector = new LFFD(face_model_path, p.get<int>("face_model"), p.get<int>("thread_num"));

cv::Mat image = cv::imread(p.get<std::string>("input_image"), cv::IMREAD_COLOR);
cv::Mat bgra_img= Interference(segment_modnet, image,4);
cv::Mat bgra_img= Interference(modelPathCStr, image,4);

cv::Mat add_background_img = addBackground(bgra_img, newBackgroundColor);
cv::cvtColor(add_background_img, add_background_img, cv::COLOR_BGRA2BGR);
Expand All @@ -53,7 +57,9 @@ int main(int argc, char* argv[]) {
}
free(face_detector);
cv::Mat hd_result= photo_adjust(params, add_background_img);
resizeImageToKB(hd_result, p.get < std::string>("output_image")+"result_kb.png",p.get <int>("out_size_kb") );
if(p.get<int>("out_size_kb")>0){
resizeImageToKB(hd_result, p.get < std::string>("output_image")+"result_kb.png",p.get <int>("out_size_kb") );
}
cv::Mat standard_result;
cv::Size standard_size(params.out_image_width, params.out_image_height);
cv::resize(hd_result, standard_result, standard_size);
Expand Down
23 changes: 14 additions & 9 deletions src/human_matting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,14 @@ cv::Mat Interference(const char* & mnn_path, cv::Mat input_BgrImg,int num_threa
auto input = net->getSessionInput(session, nullptr);
int size_w = 512; // 初始化宽度
int size_h = 512; // 初试化高度
int bpp =3;
auto shape = input->shape();
if (shape[0] != 1) {
shape[0] = 1;
net->resizeTensor(input, shape);
net->resizeSession(session);
}
{
int bpp = 0;
if(shape[2]!=-1||shape[3]!=-1){
if (shape[0] != 1) {
shape[0] = 1;
net->resizeTensor(input, shape);
net->resizeSession(session);
}
bpp = shape[1];
size_h = shape[2];
size_w = shape[3];
Expand All @@ -114,9 +114,14 @@ cv::Mat Interference(const char* & mnn_path, cv::Mat input_BgrImg,int num_threa
size_h = 1;
if (size_w == 0)
size_w = 1;
MNN_PRINT("input: w:%d , h:%d, bpp: %d\n", size_w, size_h, bpp);
}

else{
std::vector<int> shape = {1, bpp, size_h, size_w}; // 假设框架使用 NHWC 格式
net->resizeTensor(input, shape);
net->resizeSession(session);
}
MNN_PRINT("input: w:%d , h:%d, bpp: %d\n", size_w, size_h, bpp);

cv::Mat pre_img = seg_preprocess( matBgrImg,size_w, size_h);
std::vector<std::vector<cv::Mat>> nChannels;
std::vector<cv::Mat> rgbChannels(3);
Expand Down

0 comments on commit 9a5f208

Please sign in to comment.