Skip to content

Commit

Permalink
fallback codegen: matrix transpose
Browse files Browse the repository at this point in the history
  • Loading branch information
swfly committed Nov 29, 2024
1 parent a8f62b4 commit f3ffaea
Showing 1 changed file with 42 additions and 1 deletion.
43 changes: 42 additions & 1 deletion src/backends/fallback/fallback_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1815,6 +1815,47 @@ class FallbackCodegen {
return b.CreateLoad(transform_type, transform_alloca, "");
}


/// Lowers a MATRIX_TRANSPOSE intrinsic to LLVM IR.
///
/// Reads operand 0 of `inst`, checks that it is a matrix with dimension
/// greater than 2, and builds a new aggregate in which the scalar taken from
/// index path {i, j} of the operand is placed at index path {j, i}.
///
/// @param current  Per-function codegen state, forwarded to `_lookup_value`.
/// @param b        IR builder used to emit the extract/insert instructions.
/// @param inst     The transpose intrinsic; operand 0 is the matrix to transpose.
/// @return The transposed aggregate value. Slots outside the
///         `dimension x dimension` range (if the aggregate has any) are left
///         undefined.
[[nodiscard]] llvm::Value *_translate_matrix_transpose(
    CurrentFunction &current,
    IRBuilder &b,
    const xir::IntrinsicInst *inst) noexcept {
    // Validate the operand before anything is read from it or looked up for it.
    auto matrix = inst->operand(0u);
    LUISA_ASSERT(matrix->type()->is_matrix(), "Matrix transpose type mismatch");
    // NOTE(review): this restriction is kept from the original implementation,
    // which hard-coded 4-wide inner arrays. With the result type now taken from
    // the operand it may no longer be necessary -- confirm the 2x2 LLVM layout
    // before lifting it.
    LUISA_ASSERT(matrix->type()->dimension()>2, "2x2 Matrix is yet unsupported");

    // Use the dimension of the operand that was just validated, so the loop
    // bounds and the checks above refer to the same type.
    const auto dimension = static_cast<unsigned>(matrix->type()->dimension());

    // Lookup the LLVM value of the operand.
    auto llvm_matrix = _lookup_value(current, b, matrix);

    // The transposed value gets the operand's own LLVM type, so the result
    // always matches however the operand's type was lowered.
    // Assumes the operand is lowered as a nested aggregate indexable by
    // {i, j} -- TODO confirm against the matrix type translation.
    llvm::Value *result = llvm::UndefValue::get(llvm_matrix->getType());

    // Move every element from index path {i, j} to index path {j, i}.
    for (unsigned i = 0u; i < dimension; ++i) {
        for (unsigned j = 0u; j < dimension; ++j) {
            llvm::Value *element = b.CreateExtractValue(llvm_matrix, {i, j});
            result = b.CreateInsertValue(result, element, {j, i});
        }
    }

    return result;
}



[[nodiscard]] llvm::Value *_translate_matrix_linalg_multiply(
CurrentFunction &current,
IRBuilder &b,
Expand Down Expand Up @@ -2273,7 +2314,7 @@ class FallbackCodegen {
case xir::IntrinsicOp::MATRIX_COMP_DIV: break;
case xir::IntrinsicOp::MATRIX_LINALG_MUL: return _translate_matrix_linalg_multiply(current, b, inst);
case xir::IntrinsicOp::MATRIX_DETERMINANT: break;
case xir::IntrinsicOp::MATRIX_TRANSPOSE: break;
case xir::IntrinsicOp::MATRIX_TRANSPOSE: return _translate_matrix_transpose(current, b, inst);
case xir::IntrinsicOp::MATRIX_INVERSE: break;
case xir::IntrinsicOp::ATOMIC_EXCHANGE: break;
case xir::IntrinsicOp::ATOMIC_COMPARE_EXCHANGE: break;
Expand Down

0 comments on commit f3ffaea

Please sign in to comment.