Skip to content

Commit

Permalink
fallback: matrix-vector multiplication fix
Browse files Browse the repository at this point in the history
  • Loading branch information
swfly committed Nov 29, 2024
1 parent f3ffaea commit fa87c51
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions src/backends/fallback/fallback_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1840,17 +1840,25 @@ class FallbackCodegen {
llvm::ArrayType *result_type = llvm::ArrayType::get(vec4_type, dimension);
llvm::Value *result = llvm::UndefValue::get(result_type);

std::vector<llvm::Value *> values(dimension*dimension);
// Transpose the matrix by swapping rows and columns
for (unsigned i = 0; i < dimension; ++i) {
for (unsigned j = 0; j < dimension; ++j) {
// Extract value at position (i, j) in the original matrix
llvm::Value *value = b.CreateExtractValue(llvm_matrix, {i, j});

// Insert the value at position (j, i) in the transposed matrix
result = b.CreateInsertValue(result, value, {j, i});
auto idx = i + j*dimension;
auto v = b.CreateExtractValue(llvm_matrix, {i, j});
values[idx] = v;
}
}
for (unsigned i = 0; i < dimension; ++i) {
for (unsigned j = 0; j < dimension; ++j) {
auto idx = i + j * dimension;
result = b.CreateInsertValue(result, values[idx], {j, i});
}
}

// Insert the value at position (j, i) in the transposed matrix

return result;
}

Expand Down Expand Up @@ -1908,7 +1916,7 @@ class FallbackCodegen {
llvm::Value *sum = llvm::ConstantFP::get(float_type, 0.0);
for (unsigned k = 0; k < dimension; ++k) {
// Load A[i][k] and B[k][j]
llvm::Value *a_ik = b.CreateExtractValue(llvm_A, {i, k});
llvm::Value *a_ik = b.CreateExtractValue(llvm_A, {k, i});
llvm::Value *b_kj = b.CreateExtractElement(llvm_B, {k});
llvm::Value *product = b.CreateFMul(a_ik, b_kj);
sum = b.CreateFAdd(sum, product);
Expand Down

0 comments on commit fa87c51

Please sign in to comment.