Skip to content

Commit

Permalink
Refine Lite API (VeriSilicon#221)
Browse files Browse the repository at this point in the history
Signed-off-by: Zongwu Yang <[email protected]>
  • Loading branch information
ZongwuYang authored Nov 19, 2021
1 parent 0ca4970 commit c90efe7
Show file tree
Hide file tree
Showing 7 changed files with 243 additions and 240 deletions.
26 changes: 18 additions & 8 deletions include/tim/lite/execution.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,24 @@ namespace tim {
namespace lite {

class Execution {
public:
static std::shared_ptr<Execution> Create(
const void* executable, size_t executable_size);
virtual Execution& BindInputs(const std::vector<std::shared_ptr<Handle>>& handles) = 0;
virtual Execution& BindOutputs(const std::vector<std::shared_ptr<Handle>>& handles) = 0;
virtual bool Trigger() = 0;
public:
static std::shared_ptr<Execution> Create(const void* executable,
size_t executable_size);
virtual std::shared_ptr<Handle> CreateInputHandle(uint32_t in_idx,
uint8_t* buffer,
size_t size) = 0;
virtual std::shared_ptr<Handle> CreateOutputHandle(uint32_t out_idx,
uint8_t* buffer,
size_t size) = 0;
virtual Execution& BindInputs(
const std::vector<std::shared_ptr<Handle>>& handles) = 0;
virtual Execution& BindOutputs(
const std::vector<std::shared_ptr<Handle>>& handles) = 0;
virtual Execution& UnBindInput(const std::shared_ptr<Handle>& Handle) = 0;
virtual Execution& UnBindOutput(const std::shared_ptr<Handle>& handle) = 0;
virtual bool Trigger() = 0;
};

}
}
} // namespace lite
} // namespace tim
#endif
15 changes: 2 additions & 13 deletions include/tim/lite/handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,10 @@
namespace tim {
namespace lite {

class HandleImpl;

class Handle {
public:
std::unique_ptr<HandleImpl>& impl() { return impl_; }
bool Flush();
bool Invalidate();
protected:
std::unique_ptr<HandleImpl> impl_;
};

class UserHandle : public Handle {
public:
UserHandle(void* buffer, size_t size);
~UserHandle();
virtual bool Flush() = 0;
virtual bool Invalidate() = 0;
};

}
Expand Down
20 changes: 16 additions & 4 deletions samples/lenet_lite/lenet_lite_asymu8.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,16 +121,28 @@ int main() {
assert(input);
assert(output);
memset(output, 0, output_sz);

auto input_handle = exec->CreateInputHandle(0, input, input_sz);
auto output_handle = exec->CreateOutputHandle(0, (uint8_t*)output, output_sz);

exec->BindInputs({input_handle});
exec->BindOutputs({output_handle});
memcpy(input, input_data.data(), input_data.size());
input_handle->Flush();
exec->Trigger();
output_handle->Invalidate();
printTopN(output, lenet_output_size, 5);

auto input_handle = std::make_shared<tim::lite::UserHandle>(
input, input_data.size());
auto output_handle = std::make_shared<tim::lite::UserHandle>(
output, lenet_output_size * sizeof(float));
// rebind input and output
exec->UnBindInput(input_handle);
exec->UnBindOutput(output_handle);
exec->BindInputs({input_handle});
exec->BindOutputs({output_handle});
input_handle->Flush();
exec->Trigger();
output_handle->Invalidate();
printTopN(output, lenet_output_size, 5);

free(output);
free(input);
} else {
Expand Down
168 changes: 57 additions & 111 deletions src/tim/lite/execution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,74 +28,16 @@
#include <cstring>
#include <vector>
#include <memory>
#include <algorithm>
#include <iostream>
#include <cassert>
#include "handle_private.h"

#include "vip_lite.h"

namespace tim {
namespace lite {

namespace {
bool QueryInputBufferParameters(
vip_buffer_create_params_t& param, uint32_t index, vip_network network) {
uint32_t count = 0;
vip_query_network(network, VIP_NETWORK_PROP_INPUT_COUNT, &count);
if (index >= count) {
return false;
}
memset(&param, 0, sizeof(param));
param.memory_type = VIP_BUFFER_MEMORY_TYPE_DEFAULT;
vip_query_input(network, index, VIP_BUFFER_PROP_DATA_FORMAT, &param.data_format);
vip_query_input(network, index, VIP_BUFFER_PROP_NUM_OF_DIMENSION, &param.num_of_dims);
vip_query_input(network, index, VIP_BUFFER_PROP_SIZES_OF_DIMENSION, param.sizes);
vip_query_input(network, index, VIP_BUFFER_PROP_QUANT_FORMAT, &param.quant_format);
switch(param.quant_format) {
case VIP_BUFFER_QUANTIZE_DYNAMIC_FIXED_POINT:
vip_query_input(network, index, VIP_BUFFER_PROP_FIXED_POINT_POS,
&param.quant_data.dfp.fixed_point_pos);
break;
case VIP_BUFFER_QUANTIZE_TF_ASYMM:
vip_query_input(network, index, VIP_BUFFER_PROP_TF_SCALE,
&param.quant_data.affine.scale);
vip_query_input(network, index, VIP_BUFFER_PROP_TF_ZERO_POINT,
&param.quant_data.affine.zeroPoint);
default:
break;
}
return true;
}

bool QueryOutputBufferParameters(
vip_buffer_create_params_t& param, uint32_t index, vip_network network) {
uint32_t count = 0;
vip_query_network(network, VIP_NETWORK_PROP_OUTPUT_COUNT, &count);
if (index >= count) {
return false;
}
memset(&param, 0, sizeof(param));
param.memory_type = VIP_BUFFER_MEMORY_TYPE_DEFAULT;
vip_query_output(network, index, VIP_BUFFER_PROP_DATA_FORMAT, &param.data_format);
vip_query_output(network, index, VIP_BUFFER_PROP_NUM_OF_DIMENSION, &param.num_of_dims);
vip_query_output(network, index, VIP_BUFFER_PROP_SIZES_OF_DIMENSION, param.sizes);
vip_query_output(network, index, VIP_BUFFER_PROP_QUANT_FORMAT, &param.quant_format);
switch(param.quant_format) {
case VIP_BUFFER_QUANTIZE_DYNAMIC_FIXED_POINT:
vip_query_output(network, index, VIP_BUFFER_PROP_FIXED_POINT_POS,
&param.quant_data.dfp.fixed_point_pos);
break;
case VIP_BUFFER_QUANTIZE_TF_ASYMM:
vip_query_output(network, index, VIP_BUFFER_PROP_TF_SCALE,
&param.quant_data.affine.scale);
vip_query_output(network, index, VIP_BUFFER_PROP_TF_ZERO_POINT,
&param.quant_data.affine.zeroPoint);
break;
default:
break;
}
return true;
}
}

ExecutionImpl::ExecutionImpl(const void* executable, size_t executable_size) {
vip_status_e status = VIP_SUCCESS;
vip_network network = nullptr;
Expand Down Expand Up @@ -130,41 +72,44 @@ ExecutionImpl::~ExecutionImpl() {
vip_finish_network(network_);
vip_destroy_network(network_);
}
input_maps_.clear();
output_maps_.clear();
input_handles_.clear();
output_handles_.clear();
vip_destroy();
}

std::shared_ptr<Handle> ExecutionImpl::CreateInputHandle(uint32_t in_idx, uint8_t* buffer, size_t size) {
auto handle = std::make_shared<HandleImpl>(buffer, size);
if (handle->CreateVipInputBuffer(network_, in_idx)) {
return handle;
} else {
return nullptr;
}
}

std::shared_ptr<Handle> ExecutionImpl::CreateOutputHandle(uint32_t out_idx, uint8_t* buffer, size_t size) {
auto handle = std::make_shared<HandleImpl>(buffer, size);
if (handle->CreateVipPOutputBuffer(network_, out_idx)) {
return handle;
} else {
return nullptr;
}
}

Execution& ExecutionImpl::BindInputs(const std::vector<std::shared_ptr<Handle>>& handles) {
if (!IsValid()) {
return *this;
}
vip_status_e status = VIP_SUCCESS;
vip_buffer_create_params_t param;
for (uint32_t i = 0; i < handles.size(); i ++) {
auto handle = handles[i];
if (!handle) {
status = VIP_ERROR_FAILURE;
break;
}
std::shared_ptr<InternalHandle> internal_handle = nullptr;
if (input_maps_.find(handle) == input_maps_.end()) {
if (!QueryInputBufferParameters(param, i, network_)) {
status = VIP_ERROR_FAILURE;
break;
for (auto handle : handles) {
if (input_handles_.end() == std::find(input_handles_.begin(), input_handles_.end(), handle)) {
input_handles_.push_back(handle);
auto handle_impl = std::dynamic_pointer_cast<HandleImpl>(handle);
vip_status_e status = vip_set_input(network_, handle_impl->Index(), handle_impl->VipHandle());
if (status != VIP_SUCCESS) {
std::cout << "Set input for network failed." << std::endl;
assert(false);
}
internal_handle = handle->impl()->Register(param);
if (!internal_handle) {
status = VIP_ERROR_FAILURE;
break;
}
input_maps_[handle] = internal_handle;
} else {
internal_handle = input_maps_.at(handle);
}
status = vip_set_input(network_, i, internal_handle->handle());
if (status != VIP_SUCCESS) {
break;
std::cout << "The input handle has been binded, need not bind it again." << std::endl;
}
}
return *this;
Expand All @@ -174,37 +119,38 @@ Execution& ExecutionImpl::BindOutputs(const std::vector<std::shared_ptr<Handle>>
if (!IsValid()) {
return *this;
}
vip_status_e status = VIP_SUCCESS;
vip_buffer_create_params_t param;
for (uint32_t i = 0; i < handles.size(); i ++) {
auto handle = handles[i];
if (!handle) {
status = VIP_ERROR_FAILURE;
break;
}
std::shared_ptr<InternalHandle> internal_handle = nullptr;
if (output_maps_.find(handle) == output_maps_.end()) {
if (!QueryOutputBufferParameters(param, i, network_)) {
status = VIP_ERROR_FAILURE;
break;
for (auto handle : handles) {
if (output_handles_.end() == std::find(output_handles_.begin(), output_handles_.end(), handle)) {
output_handles_.push_back(handle);
auto handle_impl = std::dynamic_pointer_cast<HandleImpl>(handle);
vip_status_e status = vip_set_output(network_, handle_impl->Index(), handle_impl->VipHandle());
if (status != VIP_SUCCESS) {
std::cout << "Set output for network failed." << std::endl;
assert(false);
}
internal_handle = handle->impl()->Register(param);
if (!internal_handle) {
status = VIP_ERROR_FAILURE;
break;
}
output_maps_[handle] = internal_handle;
} else {
internal_handle = output_maps_.at(handle);
}
status = vip_set_output(network_, i, internal_handle->handle());
if (status != VIP_SUCCESS) {
break;
std::cout << "The output handle has been binded, need not bind it again." << std::endl;
}
}
return *this;
};

Execution& ExecutionImpl::UnBindInput(const std::shared_ptr<Handle>& handle) {
auto it = std::find(input_handles_.begin(), input_handles_.end(), handle);
if (input_handles_.end() != it) {
input_handles_.erase(it);
}
return *this;
}

Execution& ExecutionImpl::UnBindOutput(const std::shared_ptr<Handle>& handle) {
auto it = std::find(output_handles_.begin(), output_handles_.end(), handle);
if (output_handles_.end() != it) {
output_handles_.erase(it);
}
return *this;
}

bool ExecutionImpl::Trigger() {
if (!IsValid()) {
return false;
Expand Down
39 changes: 24 additions & 15 deletions src/tim/lite/execution_private.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,30 @@ namespace tim {
namespace lite {

class ExecutionImpl : public Execution {
public :
ExecutionImpl(const void* executable, size_t executable_size);
~ExecutionImpl();
Execution& BindInputs(const std::vector<std::shared_ptr<Handle>>& handles) override;
Execution& BindOutputs(const std::vector<std::shared_ptr<Handle>>& handles) override;
bool Trigger() override;
bool IsValid() const { return valid_; };
vip_network network() { return network_; };
private:
std::map<std::shared_ptr<Handle>, std::shared_ptr<InternalHandle>> input_maps_;
std::map<std::shared_ptr<Handle>, std::shared_ptr<InternalHandle>> output_maps_;
bool valid_;
vip_network network_;
public:
ExecutionImpl(const void* executable, size_t executable_size);
~ExecutionImpl();
std::shared_ptr<Handle> CreateInputHandle(uint32_t in_idx, uint8_t* buffer,
size_t size) override;
std::shared_ptr<Handle> CreateOutputHandle(uint32_t out_idx, uint8_t* buffer,
size_t size) override;
Execution& BindInputs(
const std::vector<std::shared_ptr<Handle>>& handles) override;
Execution& BindOutputs(
const std::vector<std::shared_ptr<Handle>>& handles) override;
Execution& UnBindInput(const std::shared_ptr<Handle>& Handle) override;
Execution& UnBindOutput(const std::shared_ptr<Handle>& handle) override;
bool Trigger() override;
bool IsValid() const { return valid_; };
vip_network network() { return network_; };

private:
std::vector<std::shared_ptr<Handle>> input_handles_;
std::vector<std::shared_ptr<Handle>> output_handles_;
bool valid_;
vip_network network_;
};

}
}
} // namespace lite
} // namespace tim
#endif
Loading

0 comments on commit c90efe7

Please sign in to comment.