Skip to content

Commit

Permalink
update submod
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Sep 1, 2023
1 parent f7b874b commit 20986b6
Show file tree
Hide file tree
Showing 9 changed files with 857 additions and 311 deletions.
8 changes: 4 additions & 4 deletions luisa_compute/src/lang/math.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
pub use super::swizzle::*;
use super::{Aggregate, ExprProxy, Value, VarProxy, __extract, traits::*, Float};
use crate::*;
use serde::{Serialize, Deserialize};
use half::f16;
use luisa_compute_ir::{
context::register_type,
ir::{Func, MatrixType, NodeRef, Primitive, Type, VectorElementType, VectorType},
TypeOf,
};
use serde::{Deserialize, Serialize};
use std::ops::Mul;

macro_rules! def_vec {
($name:ident, $glam_type:ident, $scalar:ty, $align:literal, $($comp:ident), *) => {
#[repr(C, align($align))]
#[derive(Copy, Clone, Debug, Default, Serialize, Deserialize)]
#[derive(Copy, Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
pub struct $name {
$(pub $comp: $scalar), *
}
Expand Down Expand Up @@ -44,7 +44,7 @@ macro_rules! def_vec {
macro_rules! def_packed_vec {
($name:ident, $vec_type:ident, $glam_type:ident, $scalar:ty, $($comp:ident), *) => {
#[repr(C)]
#[derive(Copy, Clone, Debug, Default, __Value, Serialize, Deserialize)]
#[derive(Copy, Clone, Debug, Default, __Value,PartialEq, Serialize, Deserialize)]
pub struct $name {
$(pub $comp: $scalar), *
}
Expand Down Expand Up @@ -480,7 +480,7 @@ macro_rules! impl_vec_proxy {
}
}
impl VectorVarTrait for $expr_proxy { }
impl ScalarOrVector for $expr_proxy {
impl ScalarOrVector for $expr_proxy {
type Element = Expr<$scalar>;
type ElementHost = $scalar;
}
Expand Down
143 changes: 116 additions & 27 deletions luisa_compute/src/lang/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ use math::Uint3;
use std::cell::{Cell, RefCell, UnsafeCell};
use std::ffi::CString;
use std::ops::{Bound, Deref, DerefMut, RangeBounds};

use std::sync::atomic::AtomicUsize;
// use self::math::Uint3;
pub mod math;
pub mod poly;
Expand All @@ -50,6 +50,14 @@ pub use math::*;
pub use poly::*;
pub use printer::*;

pub(crate) static KERNEL_ID: AtomicUsize = AtomicUsize::new(0);
// prevent node being shared across kernels
// TODO: replace NodeRef with SafeNodeRef
#[derive(Clone, Copy, Debug)]
pub(crate) struct SafeNodeRef {
pub(crate) node: NodeRef,
pub(crate) kernel_id: usize,
}
pub trait Value: Copy + ir::TypeOf + 'static {
type Expr: ExprProxy<Value = Self>;
type Var: VarProxy<Value = Self>;
Expand Down Expand Up @@ -128,18 +136,51 @@ macro_rules! impl_aggregate_for_tuple {
}
impl_aggregate_for_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15);

pub unsafe trait _Mask: ToNode {}
pub unsafe trait Mask: ToNode {}
pub trait IntoIndex {
fn to_u64(&self) -> Expr<u64>;
}
impl IntoIndex for i32 {
fn to_u64(&self) -> Expr<u64> {
const_(*self as u64)
}
}
impl IntoIndex for i64 {
fn to_u64(&self) -> Expr<u64> {
const_(*self as u64)
}
}
impl IntoIndex for u32 {
fn to_u64(&self) -> Expr<u64> {
const_(*self as u64)
}
}
impl IntoIndex for u64 {
fn to_u64(&self) -> Expr<u64> {
const_(*self)
}
}
impl IntoIndex for PrimExpr<u32> {
fn to_u64(&self) -> Expr<u64> {
self.ulong()
}
}
impl IntoIndex for PrimExpr<u64> {
fn to_u64(&self) -> Expr<u64> {
*self
}
}

pub trait IndexRead: ToNode {
type Element: Value;
fn read<I: Into<Expr<u32>>>(&self, i: I) -> Expr<Self::Element>;
fn read<I: IntoIndex>(&self, i: I) -> Expr<Self::Element>;
}

pub trait IndexWrite: IndexRead {
fn write<I: Into<Expr<u32>>, V: Into<Expr<Self::Element>>>(&self, i: I, value: V);
fn write<I: IntoIndex, V: Into<Expr<Self::Element>>>(&self, i: I, value: V);
}

pub fn select<A: Aggregate>(mask: impl _Mask, a: A, b: A) -> A {
pub fn select<A: Aggregate>(mask: impl Mask, a: A, b: A) -> A {
let a_nodes = a.to_vec_nodes();
let b_nodes = b.to_vec_nodes();
assert_eq!(a_nodes.len(), b_nodes.len());
Expand Down Expand Up @@ -178,9 +219,9 @@ impl ToNode for bool {
}
}

unsafe impl _Mask for bool {}
unsafe impl Mask for bool {}

unsafe impl _Mask for Bool {}
unsafe impl Mask for Bool {}

pub trait ExprProxy: Copy + Aggregate + FromNode {
type Value: Value;
Expand Down Expand Up @@ -553,6 +594,7 @@ impl<T: Value> CpuFn<T> {

pub(crate) struct Recorder {
pub(crate) scopes: Vec<IrBuilder>,
pub(crate) kernel_id: Option<usize>,
pub(crate) lock: bool,
pub(crate) captured_buffer: IndexMap<Binding, (usize, NodeRef, Binding, Arc<dyn Any>)>,
pub(crate) cpu_custom_ops: IndexMap<u64, (usize, CArc<CpuCustomOp>)>,
Expand All @@ -576,6 +618,7 @@ impl Recorder {
self.block_size = None;
self.arena.reset();
self.shared.clear();
self.kernel_id = None;
}
pub(crate) fn new() -> Self {
Recorder {
Expand All @@ -590,6 +633,7 @@ impl Recorder {
pools: None,
arena: Bump::new(),
building_kernel: false,
kernel_id: None,
}
}
}
Expand Down Expand Up @@ -671,6 +715,15 @@ pub fn __module_pools() -> &'static CArc<ModulePools> {
unsafe { std::mem::transmute(pool) }
})
}
// pub fn __load<T: Value>(node: NodeRef) -> Expr<T> {
// __current_scope(|b| {
// let node = b.load(node);
// Expr::<T>::from_node(node)
// })
// }
// pub fn __store(var:NodeRef, value:NodeRef) {
// let inst = &var.get().instruction;
// }

pub fn __extract<T: Value>(node: NodeRef, index: usize) -> NodeRef {
let inst = &node.get().instruction;
Expand All @@ -685,6 +738,14 @@ pub fn __extract<T: Value>(node: NodeRef, index: usize) -> NodeRef {
Func::GetElementPtr
}
}
Instruction::Call(f, args) => match f {
Func::AtomicRef => {
let mut indices = args.to_vec();
indices.push(i);
return b.call(Func::AtomicRef, &indices, <T as TypeOf>::type_());
}
_ => Func::ExtractElement,
},
_ => Func::ExtractElement,
};
let node = b.call(op, &[node, i], <T as TypeOf>::type_());
Expand Down Expand Up @@ -759,6 +820,12 @@ macro_rules! var {
($t:ty, $init:expr) => {
local::<$t>($init.into())
};
($e:expr) => {
def($e)
};
}
pub fn def<E: ExprProxy<Value = T>, T: Value>(init: E) -> Var<T> {
Var::<T>::from_node(__current_scope(|b| b.local(init.node())))
}
pub fn local<T: Value>(init: Expr<T>) -> Var<T> {
Var::<T>::from_node(__current_scope(|b| b.local(init.node())))
Expand Down Expand Up @@ -1294,9 +1361,9 @@ impl<T: Value> Shared<T> {
}),
}
}
pub fn len(&self) -> Expr<u32> {
pub fn len(&self) -> Expr<u64> {
match self.node.type_().as_ref() {
Type::Array(ArrayType { element: _, length }) => const_(*length as u32),
Type::Array(ArrayType { element: _, length }) => const_(*length as u64),
_ => unreachable!(),
}
}
Expand All @@ -1306,8 +1373,8 @@ impl<T: Value> Shared<T> {
_ => unreachable!(),
}
}
pub fn write<I: Into<Expr<u32>>, V: Into<Expr<T>>>(&self, i: I, value: V) {
let i = i.into();
pub fn write<I: IntoIndex, V: Into<Expr<T>>>(&self, i: I, value: V) {
let i = i.to_u64();
let value = value.into();

if need_runtime_check() {
Expand Down Expand Up @@ -1467,8 +1534,8 @@ impl<T: Value> VLArrayExpr<T> {
_ => unreachable!(),
}
}
pub fn read<I: Into<Expr<u32>>>(&self, i: I) -> Expr<T> {
let i = i.into();
pub fn read<I: IntoIndex>(&self, i: I) -> Expr<T> {
let i = i.to_u64();
if need_runtime_check() {
lc_assert!(i.cmplt(self.len()));
}
Expand All @@ -1477,20 +1544,20 @@ impl<T: Value> VLArrayExpr<T> {
b.call(Func::ExtractElement, &[self.node, i.node()], T::type_())
}))
}
pub fn len(&self) -> Expr<u32> {
pub fn len(&self) -> Expr<u64> {
match self.node.type_().as_ref() {
Type::Array(ArrayType { element: _, length }) => const_(*length as u32),
Type::Array(ArrayType { element: _, length }) => const_(*length as u64),
_ => unreachable!(),
}
}
}

impl<T: Value, const N: usize> IndexRead for ArrayExpr<T, N> {
type Element = T;
fn read<I: Into<Expr<u32>>>(&self, i: I) -> Expr<T> {
let i = i.into();
fn read<I: IntoIndex>(&self, i: I) -> Expr<T> {
let i = i.to_u64();

lc_assert!(i.cmplt(const_(N as u32)));
lc_assert!(i.cmplt(const_(N as u64)));

Expr::<T>::from_node(__current_scope(|b| {
b.call(Func::ExtractElement, &[self.node, i.node()], T::type_())
Expand All @@ -1500,10 +1567,10 @@ impl<T: Value, const N: usize> IndexRead for ArrayExpr<T, N> {

impl<T: Value, const N: usize> IndexRead for ArrayVar<T, N> {
type Element = T;
fn read<I: Into<Expr<u32>>>(&self, i: I) -> Expr<T> {
let i = i.into();
fn read<I: IntoIndex>(&self, i: I) -> Expr<T> {
let i = i.to_u64();
if need_runtime_check() {
lc_assert!(i.cmplt(const_(N as u32)));
lc_assert!(i.cmplt(const_(N as u64)));
}

Expr::<T>::from_node(__current_scope(|b| {
Expand All @@ -1514,12 +1581,12 @@ impl<T: Value, const N: usize> IndexRead for ArrayVar<T, N> {
}

impl<T: Value, const N: usize> IndexWrite for ArrayVar<T, N> {
fn write<I: Into<Expr<u32>>, V: Into<Expr<T>>>(&self, i: I, value: V) {
let i = i.into();
fn write<I: IntoIndex, V: Into<Expr<T>>>(&self, i: I, value: V) {
let i = i.to_u64();
let value = value.into();

if need_runtime_check() {
lc_assert!(i.cmplt(const_(N as u32)));
lc_assert!(i.cmplt(const_(N as u64)));
}

__current_scope(|b| {
Expand Down Expand Up @@ -1666,7 +1733,14 @@ impl<T: Value + 'static> CallableParameter for BufferVar<T> {
encoder.buffer(self)
}
}

impl CallableParameter for ByteBufferVar {
fn def_param(_: Option<Rc<dyn Any>>, builder: &mut KernelBuilder) -> Self {
builder.byte_buffer()
}
fn encode(&self, encoder: &mut CallableArgEncoder) {
encoder.byte_buffer(self)
}
}
impl<T: IoTexel + 'static> CallableParameter for Tex2dVar<T> {
fn def_param(_: Option<Rc<dyn Any>>, builder: &mut KernelBuilder) -> Self {
builder.tex2d()
Expand Down Expand Up @@ -1716,7 +1790,11 @@ where
builder.uniform::<T>()
}
}

impl KernelParameter for ByteBufferVar {
fn def_param(builder: &mut KernelBuilder) -> Self {
builder.byte_buffer()
}
}
impl<T: Value> KernelParameter for BufferVar<T> {
fn def_param(builder: &mut KernelBuilder) -> Self {
builder.buffer()
Expand Down Expand Up @@ -1810,6 +1888,17 @@ impl KernelBuilder {
self.args.push(node);
FromNode::from_node(node)
}
pub fn byte_buffer(&mut self) -> ByteBufferVar {
let node = new_node(
__module_pools(),
Node::new(CArc::new(Instruction::Buffer), Type::void()),
);
self.args.push(node);
ByteBufferVar {
node,
handle: None,
}
}
pub fn buffer<T: Value>(&mut self) -> BufferVar<T> {
let node = new_node(
__module_pools(),
Expand Down Expand Up @@ -2288,7 +2377,7 @@ macro_rules! impl_kernel_build_for_fn {
impl_kernel_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15);

pub fn if_then_else<R: Aggregate>(
cond: impl _Mask,
cond: impl Mask,
then: impl Fn() -> R,
else_: impl Fn() -> R,
) -> R {
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/src/lang/printer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ impl Printer {
let item_id = items.len() as u32;

if_!(
offset.cmplt(data.len()) & (offset + 1 + args.count as u32).cmple(data.len()),
offset.cmplt(data.len().uint()) & (offset + 1 + args.count as u32).cmple(data.len().uint()),
{
data.atomic_fetch_add(0, 1);
data.write(offset, item_id);
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub mod prelude {
pub use crate::lang::traits::{CommonVarOp, FloatVarTrait, IntVarTrait, VarCmp, VarCmpEq};
pub use crate::lang::{
Aggregate, ExprProxy, FromNode, IndexRead, IndexWrite, KernelBuildFn, KernelParameter,
KernelSignature, Value, VarProxy, _Mask,
KernelSignature, Value, VarProxy, Mask,
};
pub use crate::lang::{
__compose, __cpu_dbg, __current_scope, __env_need_backtrace, __extract, __insert,
Expand Down
Loading

0 comments on commit 20986b6

Please sign in to comment.