Skip to content

Commit

Permalink
Fix 64-bit dtype for MSVC
Browse files Browse the repository at this point in the history
  • Loading branch information
turboderp committed Dec 1, 2024
1 parent e3b5549 commit 663eea1
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 24 deletions.
42 changes: 21 additions & 21 deletions exllamav2/exllamav2_ext/ext_rope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,50 +58,50 @@ void rope_
);
}

long gen_mrope_pos_ids
int64_t gen_mrope_pos_ids
(
torch::Tensor mrope_pos_ids,
torch::Tensor ids,
int merge_size,
const std::vector<std::tuple<long, long>> &spans,
const std::vector<std::tuple<long, long, long>> &grids
const std::vector<std::tuple<int64_t, int64_t>> &spans,
const std::vector<std::tuple<int64_t, int64_t, int64_t>> &grids
)
{
int max_length = mrope_pos_ids.size(1);
int in_length = ids.size(0);

long* in_ids = (long*) ids.data_ptr();
long* pos_ids = (long*) mrope_pos_ids.data_ptr();
int64_t* in_ids = (int64_t*) ids.data_ptr();
int64_t* pos_ids = (int64_t*) mrope_pos_ids.data_ptr();

long* out_t = pos_ids;
long* out_h = pos_ids + max_length;
long* out_w = pos_ids + 2 * max_length;
int64_t* out_t = pos_ids;
int64_t* out_h = pos_ids + max_length;
int64_t* out_w = pos_ids + 2 * max_length;

long base_t = 0;
long next_base_t = 0;
int64_t base_t = 0;
int64_t next_base_t = 0;

for (int i = 0; i < max_length; ++i)
{
bool is_emb = false;
if (i < in_length)
{
long id = in_ids[i];
int64_t id = in_ids[i];

for (int j = 0; j < spans.size(); ++j)
{
long span_start = std::get<0>(spans[j]);
long span_end = std::get<1>(spans[j]);
long span = span_end - span_start;
int64_t span_start = std::get<0>(spans[j]);
int64_t span_end = std::get<1>(spans[j]);
int64_t span = span_end - span_start;
if (id >= span_start && id < span_end)
{
is_emb = true;
long k = id - span_start;
long grid_t = std::get<0>(grids[j]);
long grid_h = std::get<1>(grids[j]) / (long)merge_size;
long grid_w = std::get<2>(grids[j]) / (long)merge_size;
long k_t = base_t + (k / grid_w / grid_h) % grid_t;
long k_h = base_t + (k / grid_w) % grid_h;
long k_w = base_t + k % grid_w;
int64_t k = id - span_start;
int64_t grid_t = std::get<0>(grids[j]);
int64_t grid_h = std::get<1>(grids[j]) / (int64_t)merge_size;
int64_t grid_w = std::get<2>(grids[j]) / (int64_t)merge_size;
int64_t k_t = base_t + (k / grid_w / grid_h) % grid_t;
int64_t k_h = base_t + (k / grid_w) % grid_h;
int64_t k_w = base_t + k % grid_w;
*out_t++ = k_t;
*out_h++ = k_h;
*out_w++ = k_w;
Expand Down
6 changes: 3 additions & 3 deletions exllamav2/exllamav2_ext/ext_rope.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ void rope_
bool neox_style
);

long gen_mrope_pos_ids
int64_t gen_mrope_pos_ids
(
torch::Tensor mrope_pos_ids,
torch::Tensor ids,
int merge_size,
const std::vector<std::tuple<long, long>> &spans,
const std::vector<std::tuple<long, long, long>> &grids
const std::vector<std::tuple<int64_t, int64_t>> &spans,
const std::vector<std::tuple<int64_t, int64_t, int64_t>> &grids
);

0 comments on commit 663eea1

Please sign in to comment.