Skip to content

Commit

Permalink
code refactor
Browse files Browse the repository at this point in the history
Signed-off-by: jagadeesh <[email protected]>
  • Loading branch information
jagadeesh committed Aug 11, 2023
1 parent ec7e7f6 commit 327ad4d
Showing 1 changed file with 22 additions and 10 deletions.
32 changes: 22 additions & 10 deletions cpp/src/examples/image_classifier/resnet-18/resnet-18_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@

namespace resnet {

constexpr int kTargetImageSize = 224;
constexpr double kImageNormalizationMeanR = 0.485;
constexpr double kImageNormalizationMeanG = 0.456;
constexpr double kImageNormalizationMeanB = 0.406;
constexpr double kImageNormalizationStdR = 0.229;
constexpr double kImageNormalizationStdG = 0.224;
constexpr double kImageNormalizationStdB = 0.225;
constexpr int kTopKClasses = 5;

std::vector<torch::jit::IValue> ResnetHandler::Preprocess(
std::shared_ptr<torch::Device>& device,
std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
Expand Down Expand Up @@ -88,24 +97,24 @@ std::vector<torch::jit::IValue> ResnetHandler::Preprocess(
image = image(roi);

// Resize
cv::resize(image, image, cv::Size(224, 224));
cv::resize(image, image, cv::Size(kTargetImageSize, kTargetImageSize));

// Convert BGR to RGB format
cv::cvtColor(image, image, cv::COLOR_BGR2RGB);

image.convertTo(image, CV_32FC3, 1 / 255.0);

// Convert the OpenCV image to a torch tensor
torch::Tensor tensorImage = torch::from_blob(image.data, {image.rows, image.cols, 3}, c10::kFloat);
torch::Tensor tensorImage = torch::from_blob(
image.data, {image.rows, image.cols, 3}, c10::kFloat);
tensorImage = tensorImage.permute({2, 0, 1});
tensorImage.unsqueeze_(0);

// Normalize
std::vector<double> norm_mean = {0.485, 0.456, 0.406};
std::vector<double> norm_std = {0.229, 0.224, 0.225};
std::vector<double> norm_mean = {kImageNormalizationMeanR, kImageNormalizationMeanG, kImageNormalizationMeanB};
std::vector<double> norm_std = {kImageNormalizationStdR, kImageNormalizationStdG, kImageNormalizationStdB};

tensorImage =
torch::data::transforms::Normalize<>(norm_mean, norm_std)(tensorImage);
tensorImage = torch::data::transforms::Normalize<>(
norm_mean, norm_std)(tensorImage);

tensorImage.clone();
batch_tensors.emplace_back(tensorImage.to(*device));
Expand Down Expand Up @@ -144,12 +153,14 @@ void ResnetHandler::Postprocess(
try {
auto response = (*response_batch)[kv.second];
namespace F = torch::nn::functional;

// Perform softmax and top-k operations
torch::Tensor ps = F::softmax(data, F::SoftmaxFuncOptions(1));
std::tuple<torch::Tensor, torch::Tensor> result =
torch::topk(ps, 5, 1, true, true);
torch::topk(ps, kTopKClasses, 1, true, true);
auto [probs, classes] = result;
// tensor([[0.4097, 0.3467, 0.1300, 0.0239, 0.0115]]) tensor([[281, 282,
// 285, 287, 463]])

// Serialize and set the response
response->SetResponse(200, "data_tpye",
torchserve::PayloadType::kDATA_TYPE_BYTES,
torch::pickle_save(probs[kv.first]));
Expand All @@ -172,6 +183,7 @@ void ResnetHandler::Postprocess(
}
}
}

} // namespace resnet

#if defined(__linux__) || defined(__APPLE__)
Expand Down

0 comments on commit 327ad4d

Please sign in to comment.