diff --git a/src/tensor/npy.rs b/src/tensor/npy.rs index 416f422d..8d40e55f 100644 --- a/src/tensor/npy.rs +++ b/src/tensor/npy.rs @@ -129,15 +129,23 @@ impl Header { if descr.starts_with('>') { return Err(TchError::FileFormat(format!("little-endian descr {}", descr))); } + // the only supported types in tensor are: + // float64, float32, float16, + // complex64, complex128, + // int64, int32, int16, int8, + // uint8, and bool. match descr.trim_matches(|c: char| c == '=' || c == '<' || c == '|') { - "f2" => Kind::Half, - "f4" => Kind::Float, - "f8" => Kind::Double, - "i4" => Kind::Int, - "i8" => Kind::Int64, - "i2" => Kind::Int16, - "i1" => Kind::Int8, - "u1" => Kind::Uint8, + "e" | "f2" => Kind::Half, + "f" | "f4" => Kind::Float, + "d" | "f8" => Kind::Double, + "i" | "i4" => Kind::Int, + "q" | "i8" => Kind::Int64, + "h" | "i2" => Kind::Int16, + "b" | "i1" => Kind::Int8, + "B" | "u1" => Kind::Uint8, + "?" | "b1" => Kind::Bool, + "F" | "F4" => Kind::ComplexFloat, + "D" | "F8" => Kind::ComplexDouble, descr => { return Err(TchError::FileFormat(format!("unrecognized descr {}", descr))) }