support image_embs input #799

Merged: 11 commits, Dec 15, 2023
Changes from 3 commits
12 changes: 12 additions & 0 deletions lmdeploy/serve/turbomind/triton_models/interactive/config.pbtxt
@@ -59,6 +59,18 @@ input [
    data_type: TYPE_UINT32
    dims: [ -1 ]
  },
  {
    name: "image_embs"
    data_type: TYPE_FP16
    dims: [ -1, -1, -1 ]
    optional: true
  },
  {
    name: "image_offsets"
    data_type: TYPE_INT32
    dims: [ -1 ]
    optional: true
  },
  {
    name: "step"
    data_type: TYPE_INT32
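Both new inputs are optional, so existing clients are unaffected. Below is a minimal sketch of how a Triton gRPC client could populate them; the tensor layout and concrete sizes (image_dim = 256, hidden size = 4096) are assumptions for illustration only and are not part of this change.

    import numpy as np
    import tritonclient.grpc as grpcclient

    # Assumed layout: [batch, n_images, image_dim, hidden_units] for the features,
    # [batch, n_images] for the token offsets where each image is spliced in.
    embs = np.zeros((1, 1, 256, 4096), dtype=np.float16)
    offsets = np.array([[5]], dtype=np.int32)

    image_embs_input = grpcclient.InferInput('image_embs', list(embs.shape), 'FP16')
    image_embs_input.set_data_from_numpy(embs)
    image_offsets_input = grpcclient.InferInput('image_offsets', list(offsets.shape), 'INT32')
    image_offsets_input.set_data_from_numpy(offsets)
    # Append these to the usual interactive-model inputs before calling infer().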
1 change: 1 addition & 0 deletions lmdeploy/turbomind/deploy/target_model/base.py
@@ -58,6 +58,7 @@ class TurbomindModelConfig:
    max_position_embeddings: int = 0
    rope_scaling_factor: float = 0.0
    use_logn_attn: int = 0
    image_dim: int = 0

    @classmethod
    def from_dict(cls, env, allow_none=False):
25 changes: 25 additions & 0 deletions lmdeploy/turbomind/turbomind.py
@@ -459,6 +459,8 @@ async def async_stream_infer(self, *args, **kwargs):
    def stream_infer(self,
                     session_id,
                     input_ids,
                     image_embs=None,
                     image_offsets=None,
                     request_output_len: int = 512,
                     sequence_start: bool = True,
                     sequence_end: bool = False,
@@ -544,6 +546,29 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
            CORRID=np.array(session_id, dtype=np.uint64),
            STOP=_broadcast_np((1 if stop else 0), np.int32))

        if image_embs is not None:
            assert len(image_offsets) == len(image_embs)
            # image_embs Union[List[np.array], List[List[np.array]]]
            # image_offsets Union[List[int], List[List[int]]]
            if isinstance(image_offsets[0], int):
                image_offsets = [image_offsets]
                image_embs = [image_embs]
            image_embs = [[
                torch.from_numpy(x).squeeze().unsqueeze(0) for x in y
            ] for y in image_embs]
            image_embs = [torch.cat(x) for x in image_embs]
            image_embs = pad_sequence(image_embs, batch_first=True)
            image_offsets = [torch.IntTensor(x) for x in image_offsets]
            image_offsets = pad_sequence(image_offsets,
                                         batch_first=True,
                                         padding_value=-1)
            if self.tm_model.config.weight_type == 'fp32':
                image_embs = image_embs.float()
            else:
                image_embs = image_embs.half()
            inputs['image_embs'] = image_embs
            inputs['image_offsets'] = image_offsets

Collaborator: weight_type may be 'int4'.

@irexyc (Collaborator, Author) on Dec 8, 2023: For int4, the lookup table is also half precision, isn't it?

        if ignore_eos:
            stop_words = None
            bad_words = torch.tensor([[[self.eos_id], [1]]], dtype=torch.int32)
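On the Python side the new arguments ride along with the prompt token ids. A minimal usage sketch, assuming a loaded TurboMind model handle (tm_model) and placeholder sizes (image_dim = 256, hidden size = 4096) chosen only for illustration; the prompt is expected to already contain image_dim placeholder tokens starting at each offset:

    import numpy as np

    input_ids = [1] * 300                                      # dummy prompt tokens
    image_embs = [np.zeros((1, 256, 4096), dtype=np.float16)]  # one [1, image_dim, hidden] array per image
    image_offsets = [5]                                        # token index where the image starts

    generator = tm_model.create_instance()   # tm_model: loaded TurboMind model (assumed)
    for outputs in generator.stream_infer(session_id=1,
                                          input_ids=input_ids,
                                          image_embs=image_embs,
                                          image_offsets=image_offsets,
                                          request_output_len=64):
        pass  # consume streamed output tokens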
31 changes: 30 additions & 1 deletion src/turbomind/models/llama/LlamaBatch.cc
@@ -258,6 +258,30 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
            output_ids = Copy(input_ids, input_length, output_ids);
        }

        // copy image embeddings
        if (model_->image_dim_ > 0 && r->inputs[rank_].isExist("image_embs")) {
            auto image_embs_tensor = r->inputs[rank_].at("image_embs");
            const auto n_offsets = r->inputs[rank_].at("image_offsets").shape.back();
            if (image_embs_tensor.shape.size() != 4 || image_embs_tensor.shape[1] != n_offsets
                || image_embs_tensor.shape[2] != model_->image_dim_) {
                TM_LOG_WARNING("[ImageFeature] Invalid image feature, id = %ld, info = %s",
                               (long)seq.id,
                               image_embs_tensor.toString().c_str());
                continue;
            }

            T* image_embs = r->inputs[rank_].getPtr<T>("image_embs");
            const int* h_image_offsets = r->inputs[rank_].getPtr<int>("image_offsets");
            const int count = model_->image_dim_ * model_->hidden_units_;
            for (size_t i = 0; i < n_offsets && h_image_offsets[i] > 0; i++) {
                seq.image_offsets.push_back(seq.tokens.size() + h_image_offsets[i]);
                auto& emb = seq.image_embs.emplace_back();
                emb.resize(count * sizeof(T));
                std::memcpy(emb.data(), image_embs, count * sizeof(T));
                image_embs += count;
            }
        }

        // total context length (history + input)
        state.h_context_length[idx] = output_ids - output_ids_base;
        state.h_finished[idx] = false;

@@ -1420,6 +1444,8 @@ bool LlamaBatch<T>::Forward(GenerationState& g, int iter)
    std::vector<int> decode_indices{};
    std::vector<int> decode_lengths{};

    std::vector<const Sequence*> sequences;

    BatchedCopy batched_copy;
    for (int i = first; i < last; ++i) {
        input_ids = batched_copy.Add(input_d_ptrs[i], h_input_length_buf_[i], input_ids);
@@ -1436,6 +1462,7 @@ bool LlamaBatch<T>::Forward(GenerationState& g, int iter)
        }
        decode_indices.push_back(i);
        decode_lengths.push_back(h_input_length_buf_[i]);
        sequences.push_back(state_->sequences[i]);
        max_input_len = std::max(max_input_len, h_input_length_buf_[i]);
    }
    int token_count = input_ids - context_decoder_ids_buf_;
@@ -1482,7 +1509,9 @@ bool LlamaBatch<T>::Forward(GenerationState& g, int iter)
                               pf_batch_size,
                               max_input_len,
                               max_context_cnts[p],
                               max_context_cnts[p],
                               decode_lengths.data(),
                               sequences.data());

    if (iter == 0) {
        // compute logits of inputs if requested
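The block above validates that each request's image_embs tensor is 4-D with shape [1, n_offsets, image_dim, hidden_units], then caches one contiguous chunk per offset on the Sequence, shifting every offset by the tokens already in the session history. A small Python sketch of that bookkeeping (sizes are arbitrary illustrations, not values from this diff):

    import numpy as np

    hidden_units, image_dim = 4096, 256                  # assumed model sizes
    history_len = 0                                      # len(seq.tokens) for a fresh session

    image_embs = np.zeros((1, 2, image_dim, hidden_units), dtype=np.float16)
    image_offsets = np.array([5, 300], dtype=np.int32)   # padded with -1 when batched

    seq_offsets, seq_embs = [], []
    for off, emb in zip(image_offsets, image_embs[0]):
        if off <= 0:                                     # stop at the first pad value
            break
        seq_offsets.append(history_len + off)            # offset within the full token stream
        seq_embs.append(emb.copy())                      # one [image_dim, hidden] chunk per image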
84 changes: 62 additions & 22 deletions src/turbomind/models/llama/LlamaV2.cc
@@ -48,6 +48,7 @@ LlamaV2<T>::LlamaV2(size_t head_num,
size_t inter_size,
size_t num_layer,
size_t vocab_size,
size_t image_dim,
float norm_eps,
const LlamaAttentionParams& attn_params,
int start_id,
@@ -69,6 +70,7 @@ LlamaV2<T>::LlamaV2(size_t head_num,
inter_size_(inter_size),
num_layer_(num_layer),
vocab_size_(vocab_size),
image_dim_(image_dim),
attn_params_(attn_params),
vocab_size_padded_(vocab_size),
rmsnorm_eps_(norm_eps),
@@ -166,28 +168,63 @@ void LlamaV2<T>::embeddingLookup(T* embeddings, const int* token_ids_buf, int ba
}

template<typename T>
void LlamaV2<T>::forwardUnified(T* out,
T* decoder_output,
T* decoder_input,
void** k_block_ptrs,
void** v_block_ptrs,
const int* input_ids,
const int* cu_block_cnts,
const float* rope_theta,
const bool* dc_finished,
const int* pf_input_length,
const int* pf_context_length,
T** pf_tmp_k_ptrs,
T** pf_tmp_v_ptrs,
size_t token_num,
int dc_batch_size,
int dc_step,
int dc_sum_seq_len,
int dc_max_seq_len,
int pf_batch_size,
int pf_max_input_len,
int pf_max_context_len,
int pf_session_len)
void LlamaV2<T>::updateImageEmbedding(T* decoder_input,
                                      const int bsz,
                                      const int* decode_lengths,
                                      const Sequence** sequences)
{
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);

    if (image_dim_ <= 0) {
        return;
    }

    for (int i = 0; i < bsz; i++) {
        decoder_input += ((i > 0) ? decode_lengths[i - 1] : 0) * hidden_units_;
        if (decode_lengths[i] == 1) {
            continue;
        }
        const auto& seq = *sequences[i];
        for (int j = 0; j < seq.image_offsets.size(); j++) {
            if (seq.image_offsets[j] + image_dim_ <= seq.cache_len) {
                continue;
            }
            int off_dst = std::max(0, seq.image_offsets[j] - seq.cache_len);
            int off_src = std::max(0, seq.cache_len - seq.image_offsets[j]);
            T* dst_ptr = decoder_input + off_dst * hidden_units_;
            std::byte* src_ptr = seq.image_embs[j].data() + off_src * hidden_units_;
            size_t count = (image_dim_ - off_src) * hidden_units_ * sizeof(T);
            cudaMemcpyAsync(dst_ptr, src_ptr, count, cudaMemcpyDefault, stream_);
        }
    }
    sync_check_cuda_error();
}

template<typename T>
void LlamaV2<T>::forwardUnified(T* out,
T* decoder_output,
T* decoder_input,
void** k_block_ptrs,
void** v_block_ptrs,
const int* input_ids,
const int* cu_block_cnts,
const float* rope_theta,
const bool* dc_finished,
const int* pf_input_length,
const int* pf_context_length,
T** pf_tmp_k_ptrs,
T** pf_tmp_v_ptrs,
size_t token_num,
int dc_batch_size,
int dc_step,
int dc_sum_seq_len,
int dc_max_seq_len,
int pf_batch_size,
int pf_max_input_len,
int pf_max_context_len,
int pf_session_len,
const int* decode_lengths,
const Sequence** sequences)
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);

@@ -203,6 +240,9 @@ void LlamaV2<T>::forwardUnified(T* out,
1,
hidden_units_,
stream_);

updateImageEmbedding(decoder_input, dc_batch_size + pf_batch_size, decode_lengths, sequences);

sync_check_cuda_error();

const auto dtype = getTensorType<T>();
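updateImageEmbedding only copies the part of each cached image embedding that falls beyond the sequence's already-processed prefix (cache_len), so an image that straddles the boundary between restored history and the new input chunk is spliced correctly. A small sketch of that arithmetic, mirroring the off_dst/off_src computation above (illustrative only):

    def splice_window(image_offset: int, image_dim: int, cache_len: int):
        """Rows of one image embedding still to be written into decoder_input."""
        if image_offset + image_dim <= cache_len:
            return None                             # embedding already fully consumed
        dst_row = max(0, image_offset - cache_len)  # row within the current input chunk
        src_row = max(0, cache_len - image_offset)  # row within the stored embedding
        return dst_row, src_row, image_dim - src_row

    # An image occupying tokens [100, 356) with 180 tokens already cached:
    # the remaining 176 rows are written starting at row 0 of the new chunk.
    print(splice_window(100, 256, 180))             # -> (0, 80, 176)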
51 changes: 29 additions & 22 deletions src/turbomind/models/llama/LlamaV2.h
@@ -58,6 +58,7 @@ class LlamaV2 {
size_t inter_size,
size_t num_layer,
size_t vocab_size,
size_t image_dim,
float norm_eps,
const LlamaAttentionParams& attn_params,
int start_id,
@@ -107,28 +108,32 @@ class LlamaV2 {

void embeddingLookup(T* embeddings, const int* token_ids_buf, int batch_size, int step);

void forwardUnified(T* out,
T* decoder_output,
T* decoder_input,
void** k_block_ptrs,
void** v_block_ptrs,
const int* input_ids,
const int* cu_block_cnts,
const float* rope_theta,
const bool* dc_finished,
const int* pf_input_length,
const int* pf_context_length,
T** pf_tmp_k_ptrs,
T** pf_tmp_v_ptrs,
size_t token_num,
int dc_batch_size,
int dc_step,
int dc_sum_seq_len,
int dc_max_seq_len,
int pf_batch_size,
int pf_max_input_len,
int pf_max_context_len,
int pf_session_len);
void updateImageEmbedding(T* decoder_input, const int bsz, const int* decode_lengths, const Sequence** sequences);

void forwardUnified(T* out,
T* decoder_output,
T* decoder_input,
void** k_block_ptrs,
void** v_block_ptrs,
const int* input_ids,
const int* cu_block_cnts,
const float* rope_theta,
const bool* dc_finished,
const int* pf_input_length,
const int* pf_context_length,
T** pf_tmp_k_ptrs,
T** pf_tmp_v_ptrs,
size_t token_num,
int dc_batch_size,
int dc_step,
int dc_sum_seq_len,
int dc_max_seq_len,
int pf_batch_size,
int pf_max_input_len,
int pf_max_context_len,
int pf_session_len,
const int* decode_lengths,
const Sequence** sequences);

void postDecodeEmbedding(float* logits, float* local_logits, const T* decoder_output, int batch_size);

@@ -172,6 +177,8 @@ class LlamaV2 {
const size_t local_kv_head_num_;
NcclParam tensor_para_;

const size_t image_dim_;

cudaStream_t stream_;
cublasMMWrapper* cublas_wrapper_;
IAllocator* allocator_;
4 changes: 4 additions & 0 deletions src/turbomind/models/llama/SequenceManager.h
@@ -33,6 +33,10 @@ struct Sequence {

    mutable float rope_theta = 0.f;

    // image data
    mutable std::vector<std::vector<std::byte>> image_embs{};
    mutable std::vector<int> image_offsets{};

    Sequence(uint64_t _id): id(_id) {}

    friend std::ostream& operator<<(std::ostream& os, const Sequence& seq);
3 changes: 3 additions & 0 deletions src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -172,6 +172,8 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
    quant_policy_ = reader.GetInteger("llama", "quant_policy", 0);
    group_size_ = reader.GetInteger("llama", "group_size", 0);

    image_dim_ = reader.GetInteger("llama", "image_dim", image_dim_);

    // rotary embedding parameters
    attn_params_.rotary_embedding_dim = reader.GetInteger("llama", "rotary_embedding");
    attn_params_.rotary_embedding_base = reader.GetFloat("llama", "rope_theta", 10000.0f);
@@ -273,6 +275,7 @@ std::unique_ptr<LlamaTritonSharedModelInstance<T>> LlamaTritonModel<T>::createSh
inter_size_,
num_layer_,
vocab_size_,
image_dim_,
norm_eps_,
attn_params_,
start_id_,
1 change: 1 addition & 0 deletions src/turbomind/triton_backend/llama/LlamaTritonModel.h
@@ -101,6 +101,7 @@ struct LlamaTritonModel: public AbstractTransformerModel {
    bool attn_bias_;
    int quant_policy_;
    int group_size_;
    size_t image_dim_;

    // shared weights for each device
    std::vector<std::shared_ptr<ft::LlamaWeight<T>>> shared_weights_;