Array Type Pain Points #34

Open
dcvz opened this issue Apr 18, 2024 · 0 comments

At the moment, you have to spell out the element type by hand for many operations, and a typo won't produce an error until runtime. For example, on an `Array`:

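```rust
let mut array = Array::zeros::<f32>(&[2, 3], StreamOrDevice::default());
array.eval();
// The element type (f32) has to be restated here; a mismatch is only caught at runtime.
let data: &[f32] = array.as_slice().unwrap();
```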
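A small mock makes the problem concrete. This is not the mlx-rs API: `UntypedArray` is a hypothetical stand-in whose element type is erased behind `std::any::Any`, so `as_slice::<T>()` can only check the requested type when it is called.

```rust
use std::any::Any;

/// Hypothetical untyped array (not the mlx-rs API): the element type is
/// erased behind `Any` and only recoverable at runtime.
struct UntypedArray {
    data: Box<dyn Any>,
}

impl UntypedArray {
    fn zeros<T: Default + Clone + 'static>(shape: &[usize]) -> Self {
        let len: usize = shape.iter().product();
        UntypedArray { data: Box::new(vec![T::default(); len]) }
    }

    /// The caller must name `T`; a wrong guess is only detected here, at runtime.
    fn as_slice<T: 'static>(&self) -> Option<&[T]> {
        self.data.downcast_ref::<Vec<T>>().map(|v| v.as_slice())
    }
}

fn main() {
    let array = UntypedArray::zeros::<f32>(&[2, 3]);
    let data: &[f32] = array.as_slice().unwrap(); // type restated by hand
    let typo: Option<&[i32]> = array.as_slice(); // compiles, but is None
    println!("{} elements, typo is none: {}", data.len(), typo.is_none());
}
```

Both calls compile; the wrong guess (`i32`) only shows up as `None` when the program runs.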

It would be nice to infer the element type of the array, since it is already known when the array is created. One option is to wrap `Array` in a typed array that we expose to users:

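```rust
/// Typed wrapper around the untyped `wrapper::Array`: the element type `E`
/// and the rank `D` are tracked at compile time.
pub struct MlxArray<E: kind::Element, const D: usize> {
    pub tensor: wrapper::Array,
    // `E` only exists at the type level; `D` is a const generic and needs no marker.
    phantom: std::marker::PhantomData<E>,
}

impl<E: kind::Element, const D: usize> MlxArray<E, D> {
    pub fn eval(&mut self) {
        self.tensor.eval();
    }

    /// Rank-`D` shape; assumes `Shape<D>` can be built from the untyped shape.
    pub fn shape(&self) -> Shape<D> {
        Shape::from(self.tensor.shape())
    }

    /// Returns `&[E]` without the caller restating the element type.
    pub fn as_slice(&self) -> Option<&[E]> {
        self.tensor.as_slice()
    }
}
```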
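A self-contained sketch of the same idea, assuming a hypothetical `UntypedArray` in place of `wrapper::Array` and a plain `[usize; D]` in place of `Shape<D>` (neither is the mlx-rs API). The element type and rank are written once, at construction, and the rest is inferred from them:

```rust
use std::any::Any;
use std::marker::PhantomData;

/// Hypothetical stand-in for the untyped `wrapper::Array`: the element type
/// is only known at runtime.
struct UntypedArray {
    shape: Vec<usize>,
    data: Box<dyn Any>,
}

impl UntypedArray {
    fn as_slice<T: 'static>(&self) -> Option<&[T]> {
        self.data.downcast_ref::<Vec<T>>().map(|v| v.as_slice())
    }
}

/// Typed wrapper: element type `E` and rank `D` are fixed at compile time.
struct MlxArray<E, const D: usize> {
    tensor: UntypedArray,
    phantom: PhantomData<E>,
}

impl<E: Default + Clone + 'static, const D: usize> MlxArray<E, D> {
    fn zeros(shape: [usize; D]) -> Self {
        let len: usize = shape.iter().product();
        let tensor = UntypedArray {
            shape: shape.to_vec(),
            data: Box::new(vec![E::default(); len]),
        };
        MlxArray { tensor, phantom: PhantomData }
    }

    fn shape(&self) -> [usize; D] {
        // The rank was fixed at construction, so the lengths always match.
        let mut shape = [0usize; D];
        shape.copy_from_slice(&self.tensor.shape);
        shape
    }

    fn as_slice(&self) -> &[E] {
        // `E` was fixed at construction, so this runtime check cannot fail.
        self.tensor.as_slice::<E>().expect("element type matches E")
    }
}

fn main() {
    // Element type and rank are stated once, when the array is created.
    let array = MlxArray::<f32, 2>::zeros([2, 3]);
    let data = array.as_slice(); // inferred as &[f32]
    println!("shape {:?}, {} elements", array.shape(), data.len());
}
```

The untyped storage still does a runtime check internally, but because `E` is part of the wrapper's type, callers no longer write `&[f32]` and cannot request the wrong type.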

Wrapping `Array` this way would let the compiler infer types in many places, including once we start adding ops that transform the array, such as `abs()`.

An example, which will quickly become outdated, can be found in #31.
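To sketch how ops such as `abs()` would carry the types along, here is a hypothetical `abs()` on a stripped-down typed wrapper. `Element`, `elem_abs`, `UntypedArray` and `from_vec` are made-up names for this illustration (not mlx-rs's `kind::Element` or its real ops), and the computation runs eagerly on the CPU rather than through MLX:

```rust
use std::any::Any;
use std::marker::PhantomData;

/// Hypothetical element trait for this sketch (not mlx-rs's `kind::Element`).
trait Element: Copy + 'static {
    fn elem_abs(self) -> Self;
}

impl Element for f32 {
    fn elem_abs(self) -> Self {
        self.abs()
    }
}

impl Element for i32 {
    fn elem_abs(self) -> Self {
        self.abs()
    }
}

/// Hypothetical stand-in for the untyped `wrapper::Array`.
struct UntypedArray {
    shape: Vec<usize>,
    data: Box<dyn Any>,
}

/// Typed wrapper: `E` and `D` are fixed at compile time.
struct MlxArray<E, const D: usize> {
    tensor: UntypedArray,
    phantom: PhantomData<E>,
}

impl<E: Element, const D: usize> MlxArray<E, D> {
    fn from_vec(shape: [usize; D], data: Vec<E>) -> Self {
        assert_eq!(shape.iter().product::<usize>(), data.len(), "shape does not match data length");
        let tensor = UntypedArray { shape: shape.to_vec(), data: Box::new(data) };
        MlxArray { tensor, phantom: PhantomData }
    }

    fn as_slice(&self) -> &[E] {
        // `E` was fixed at construction, so this downcast always succeeds.
        self.tensor.data.downcast_ref::<Vec<E>>().expect("element type matches E")
    }

    /// Element-wise absolute value; the result keeps the same `E` and `D`.
    fn abs(&self) -> MlxArray<E, D> {
        let data: Vec<E> = self.as_slice().iter().map(|&x| x.elem_abs()).collect();
        let mut shape = [0usize; D];
        shape.copy_from_slice(&self.tensor.shape);
        MlxArray::from_vec(shape, data)
    }
}

fn main() {
    let a = MlxArray::<f32, 1>::from_vec([3], vec![-1.5, 0.0, 2.0]);
    let b = a.abs(); // inferred as MlxArray<f32, 1>
    println!("{:?}", b.as_slice());
}
```

Because `abs()` returns `MlxArray<E, D>`, the element type and rank flow through chained ops without any annotations at the call site.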
