Skip to content

Commit d6a3056

Browse files
committed
Add auto-bitcasts for AMX
1 parent 882f72d commit d6a3056

File tree

5 files changed

+125
-12
lines changed

5 files changed

+125
-12
lines changed

compiler/rustc_codegen_llvm/src/abi.rs

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use std::iter::zip;
55
use libc::c_uint;
66
use rustc_abi::{BackendRepr, HasDataLayout, Primitive, Reg, RegKind, Size};
77
use rustc_codegen_ssa::MemFlags;
8+
use rustc_codegen_ssa::common::TypeKind;
89
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
910
use rustc_codegen_ssa::mir::place::{PlaceRef, PlaceValue};
1011
use rustc_codegen_ssa::traits::*;
@@ -412,7 +413,35 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
412413
let expected_return_ty = cx.get_return_type(fn_ty);
413414
let expected_argument_tys = cx.func_params_types(fn_ty);
414415

415-
let equate_ty = |rust_ty, llvm_ty| rust_ty == llvm_ty;
416+
let equate_ty = |rust_ty, llvm_ty| {
417+
if rust_ty == llvm_ty {
418+
return true;
419+
}
420+
match cx.type_kind(llvm_ty) {
421+
TypeKind::X86_AMX => {
422+
// we will insert casts from/to x86amx in callsite, so this is fine
423+
if cx.type_kind(rust_ty) == TypeKind::Vector {
424+
let element_count = cx.vector_length(rust_ty);
425+
let element_ty = cx.element_type(rust_ty);
426+
let element_size_bits = match cx.type_kind(element_ty) {
427+
TypeKind::Half => 16,
428+
TypeKind::Float => 32,
429+
TypeKind::Double => 64,
430+
TypeKind::FP128 => 128,
431+
TypeKind::Integer => cx.int_width(element_ty),
432+
TypeKind::Pointer => cx.int_width(cx.isize_ty),
433+
_ => bug!(
434+
"Vector element type `{element_ty:?}` not one of integer, float or pointer"
435+
),
436+
};
437+
element_size_bits * element_count as u64 == 8192
438+
} else {
439+
false
440+
}
441+
}
442+
_ => false,
443+
}
444+
};
416445

417446
if actual_argument_tys.len() != expected_argument_tys.len() {
418447
todo!("A very friendly error msg")

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 89 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ impl<'a, 'll> SBuilder<'a, 'll> {
6767
) -> &'ll Value {
6868
debug!("call {:?} with args ({:?})", llfn, args);
6969

70-
let args = self.check_call("call", llty, llfn, args);
70+
let args = self.cast_arguments("call", llty, llfn, args);
7171
let funclet_bundle = funclet.map(|funclet| funclet.bundle());
7272
let mut bundles: SmallVec<[_; 2]> = SmallVec::new();
7373
if let Some(funclet_bundle) = funclet_bundle {
@@ -101,6 +101,51 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
101101
unsafe { llvm::LLVMBuildBitCast(self.llbuilder, val, dest_ty, UNNAMED) }
102102
}
103103

104+
pub(crate) fn cast_vector_to_tile(&mut self, val: &'ll Value) -> &'ll Value {
105+
let vector_type = self.cx.val_ty(val);
106+
107+
assert!(self.cx.type_kind(vector_type) == TypeKind::Vector);
108+
109+
let intrinsic = llvm::Intrinsic::lookup(b"llvm.x86.cast.vector.to.tile".as_ref()).unwrap();
110+
let (fn_ty, f) = self.cx.get_or_declare_intrinsic(intrinsic, &[vector_type]);
111+
unsafe {
112+
llvm::LLVMBuildCallWithOperandBundles(
113+
self.llbuilder,
114+
fn_ty,
115+
f,
116+
[val].as_ptr().cast(),
117+
1,
118+
[].as_ptr(),
119+
0,
120+
c"".as_ptr(),
121+
)
122+
}
123+
}
124+
125+
pub(crate) fn cast_tile_to_vector(
126+
&mut self,
127+
val: &'ll Value,
128+
vector_type: &'ll Type,
129+
) -> &'ll Value {
130+
assert!(self.cx.val_ty(val) == self.cx.type_x86amx());
131+
assert!(self.cx.type_kind(vector_type) == TypeKind::Vector);
132+
133+
let intrinsic = llvm::Intrinsic::lookup(b"llvm.x86.cast.tile.to.vector").unwrap();
134+
let (fn_ty, f) = self.cx.get_or_declare_intrinsic(intrinsic, &[vector_type]);
135+
unsafe {
136+
llvm::LLVMBuildCallWithOperandBundles(
137+
self.llbuilder,
138+
fn_ty,
139+
f,
140+
[val].as_ptr().cast(),
141+
1,
142+
[].as_ptr(),
143+
0,
144+
c"".as_ptr(),
145+
)
146+
}
147+
}
148+
104149
pub(crate) fn ret_void(&mut self) {
105150
llvm::LLVMBuildRetVoid(self.llbuilder);
106151
}
@@ -349,7 +394,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
349394
) -> &'ll Value {
350395
debug!("invoke {:?} with args ({:?})", llfn, args);
351396

352-
let args = self.check_call("invoke", llty, llfn, args);
397+
let args = self.cast_arguments("invoke", llty, llfn, args);
353398
let funclet_bundle = funclet.map(|funclet| funclet.bundle());
354399
let mut bundles: SmallVec<[_; 2]> = SmallVec::new();
355400
if let Some(funclet_bundle) = funclet_bundle {
@@ -381,8 +426,10 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
381426
};
382427
if let Some(fn_abi) = fn_abi {
383428
fn_abi.apply_attrs_callsite(self, invoke);
429+
self.cast_return(fn_abi, llfn, invoke)
430+
} else {
431+
invoke
384432
}
385-
invoke
386433
}
387434

388435
fn unreachable(&mut self) {
@@ -1404,7 +1451,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
14041451
) -> &'ll Value {
14051452
debug!("call {:?} with args ({:?})", llfn, args);
14061453

1407-
let args = self.check_call("call", llty, llfn, args);
1454+
let args = self.cast_arguments("call", llty, llfn, args);
14081455
let funclet_bundle = funclet.map(|funclet| funclet.bundle());
14091456
let mut bundles: SmallVec<[_; 2]> = SmallVec::new();
14101457
if let Some(funclet_bundle) = funclet_bundle {
@@ -1434,8 +1481,10 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
14341481
};
14351482
if let Some(fn_abi) = fn_abi {
14361483
fn_abi.apply_attrs_callsite(self, call);
1484+
self.cast_return(fn_abi, llfn, call)
1485+
} else {
1486+
call
14371487
}
1438-
call
14391488
}
14401489

14411490
fn zext(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {
@@ -1596,7 +1645,7 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
15961645
ret.expect("LLVM does not have support for catchret")
15971646
}
15981647

1599-
fn check_call<'b>(
1648+
fn cast_arguments<'b>(
16001649
&mut self,
16011650
typ: &str,
16021651
fn_ty: &'ll Type,
@@ -1627,7 +1676,11 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
16271676
Expected {:?} for param {}, got {:?}; injecting bitcast",
16281677
llfn, expected_ty, i, actual_ty
16291678
);
1630-
self.bitcast(actual_val, expected_ty)
1679+
if self.cx.type_kind(expected_ty) == TypeKind::X86_AMX {
1680+
self.cast_vector_to_tile(actual_val)
1681+
} else {
1682+
self.bitcast(actual_val, expected_ty)
1683+
}
16311684
} else {
16321685
actual_val
16331686
}
@@ -1708,6 +1761,31 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
17081761
self.call(self.type_func(&[src_ty], dest_ty), None, None, f, &[val], None, None)
17091762
}
17101763

1764+
fn cast_return(
1765+
&mut self,
1766+
fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
1767+
llfn: &'ll Value,
1768+
ret: &'ll Value,
1769+
) -> &'ll Value {
1770+
let expected_ty = fn_abi.llvm_return_type(self.cx);
1771+
let actual_ty = self.cx.val_ty(ret);
1772+
1773+
if expected_ty != actual_ty {
1774+
debug!(
1775+
"type mismatch in function call of {:?}. \
1776+
Expected {:?} for return value, got {:?}; injecting bitcast",
1777+
llfn, expected_ty, actual_ty
1778+
);
1779+
if self.cx.type_kind(actual_ty) == TypeKind::X86_AMX {
1780+
self.cast_tile_to_vector(ret, expected_ty)
1781+
} else {
1782+
self.bitcast(ret, expected_ty)
1783+
}
1784+
} else {
1785+
ret
1786+
}
1787+
}
1788+
17111789
pub(crate) fn landing_pad(
17121790
&mut self,
17131791
ty: &'ll Type,
@@ -1737,7 +1815,7 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
17371815
) -> &'ll Value {
17381816
debug!("invoke {:?} with args ({:?})", llfn, args);
17391817

1740-
let args = self.check_call("callbr", llty, llfn, args);
1818+
let args = self.cast_arguments("callbr", llty, llfn, args);
17411819
let funclet_bundle = funclet.map(|funclet| funclet.bundle());
17421820
let mut bundles: SmallVec<[_; 2]> = SmallVec::new();
17431821
if let Some(funclet_bundle) = funclet_bundle {
@@ -1770,8 +1848,10 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
17701848
};
17711849
if let Some(fn_abi) = fn_abi {
17721850
fn_abi.apply_attrs_callsite(self, callbr);
1851+
self.cast_return(fn_abi, llfn, callbr)
1852+
} else {
1853+
callbr
17731854
}
1774-
callbr
17751855
}
17761856

17771857
// Emits CFI pointer type membership tests.

compiler/rustc_codegen_llvm/src/type_.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,6 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
175175
unsafe { llvm::LLVMX86AMXTypeInContext(self.llcx()) }
176176
}
177177

178-
#[expect(unused)]
179178
pub(crate) fn get_or_declare_intrinsic(
180179
&self,
181180
intrinsic: llvm::Intrinsic,

compiler/rustc_monomorphize/src/mono_checks/abi_check.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ fn do_check_simd_vector_abi<'tcx>(
4949
// Find the first feature that provides at least this vector size.
5050
let feature = match feature_def.iter().find(|(bits, _)| size.bits() <= *bits) {
5151
Some((_, feature)) => feature,
52+
None if matches!(&*tcx.sess.target.arch, "x86" | "x86_64")
53+
&& size.bits() == 8192 =>
54+
{
55+
"amx-tile"
56+
}
5257
None => {
5358
let (span, _hir_id) = loc();
5459
tcx.dcx().emit_err(errors::AbiErrorUnsupportedVectorType {

compiler/rustc_target/src/target_features.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,7 @@ pub fn all_rust_features() -> impl Iterator<Item = (&'static str, Stability)> {
786786
// certain size to have their "proper" ABI on each architecture.
787787
// Note that they must be kept sorted by vector size.
788788
const X86_FEATURES_FOR_CORRECT_VECTOR_ABI: &'static [(u64, &'static str)] =
789-
&[(128, "sse"), (256, "avx"), (512, "avx512f")]; // FIXME: might need changes for AVX10.
789+
&[(128, "sse"), (256, "avx"), (512, "avx512f")];
790790
const AARCH64_FEATURES_FOR_CORRECT_VECTOR_ABI: &'static [(u64, &'static str)] = &[(128, "neon")];
791791

792792
// We might want to add "helium" too.

0 commit comments

Comments
 (0)