From a9fbf54693d58cd4174015c9cf8b2904a1e7c83b Mon Sep 17 00:00:00 2001 From: Green Sky Date: Sun, 6 Oct 2024 14:52:34 +0200 Subject: [PATCH] add sdcpp stduhpf webapi https://github.com/leejet/stable-diffusion.cpp/pull/367 commit 1c599839800ed5984e72562968db7e4df5d052bd --- src/sd_bot.cpp | 73 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/src/sd_bot.cpp b/src/sd_bot.cpp index 3a9fc6a..cd15849 100644 --- a/src/sd_bot.cpp +++ b/src/sd_bot.cpp @@ -129,6 +129,77 @@ struct SDcpp_wip1_Endpoint : public SDBot::EndpointI { } }; +struct SDcpp_stduhpf_wip1_Endpoint : public SDBot::EndpointI { + SDcpp_stduhpf_wip1_Endpoint(RegistryMessageModelI& rmm, std::default_random_engine& rng) : SDBot::EndpointI(rmm, rng) {} + + bool handleResponse(Contact3 contact, ByteSpan data) override { + bool succ = true; + + try { + // extract json result + const auto j = nlohmann::json::parse( + std::string_view{reinterpret_cast(data.ptr), data.size} + ); + + if (!j.is_array()) { + std::cerr << "SDB: json response was not an array\n"; + return false; + } + + for (const auto& j_entry : j) { + if (!j_entry.is_object()) { + std::cerr << "SDB warning: non object entry, skipping\n"; + continue; + } + + // for each returned image + // "channel": 3, // rgb? + // "data": base64 encoded image file + // "encoding": "png", + // "height": 512, + // "width": 512 + + if (j_entry.contains("encoding")) { + if (!j_entry["encoding"].is_string() || j_entry["encoding"] != "png") { + std::cerr << "SDB warning: unknown encoding '" << j_entry["encoding"] << "'\n"; + } + } + + if (!j_entry.contains("data") || !j_entry.at("data").is_string()) { + std::cerr << "SDB warning: non data entry, skipping\n"; + continue; + } + + const auto& img_data_str = j_entry.at("data").get(); + // decode data (base64) + std::vector png_data(img_data_str.size()); // just init to upper bound + size_t decoded_size {0}; + sodium_base642bin( + png_data.data(), png_data.size(), + img_data_str.data(), img_data_str.size(), + " \n\t", + &decoded_size, + nullptr, + sodium_base64_VARIANT_ORIGINAL + ); + png_data.resize(decoded_size); + + std::filesystem::create_directories("sdbot_img_send"); + //const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_current_task.value()) + ".png"; + const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_rng()) + ".png"; + const std::string tmp_img_file_path = "sdbot_img_send/" + tmp_img_file_name; + + std::ofstream(tmp_img_file_path).write(reinterpret_cast(png_data.data()), png_data.size()); + succ = succ && _rmm.sendFilePath(contact, tmp_img_file_name, tmp_img_file_path); + } + } catch (...) { + return false; + } + + return succ; + } +}; + SDBot::SDBot( Contact3Registry& cr, RegistryMessageModelI& rmm, @@ -148,6 +219,8 @@ SDBot::SDBot( _endpoint = std::make_unique(_rmm, _rng); } else if (endpoint_type == "sdcpp_wip1") { _endpoint = std::make_unique(_rmm, _rng); + } else if (endpoint_type == "sdcpp_stduhpf_wip1") { + _endpoint = std::make_unique(_rmm, _rng); } else { std::cerr << "SDB error: unknown endpoint type '" << endpoint_type << "'\n"; // TODO: throw?