diff --git a/tessellate_ipu/core/vertex/tile_hessenberg_vertex.cpp b/tessellate_ipu/core/vertex/tile_hessenberg_vertex.cpp index be81a8a..858d9aa 100644 --- a/tessellate_ipu/core/vertex/tile_hessenberg_vertex.cpp +++ b/tessellate_ipu/core/vertex/tile_hessenberg_vertex.cpp @@ -10,6 +10,7 @@ using namespace poplar; tessellate/tile/vertex/tile_qr_vertex.cpp \ -o tessellate/tile/vertex/tile_qr_vertex.gp */ +static constexpr size_t MIN_ALIGN = 8; class [[poplar::constraint("elem(*x) != elem(*y)")]] DotProduct1dIndexedVertex : public MultiVertex { @@ -19,18 +20,17 @@ class [[poplar::constraint("elem(*x) != elem(*y)")]] DotProduct1dIndexedVertex // Using `uint16` seems to be generating more efficient loops? using IndexType = unsigned short; - static constexpr size_t MIN_ALIGN = 8; Input> x; // (N,) x vector Input> y; // (N,) y vector - Input> + Input> start_idx; - Input> + Input> worker_offsets; // (7,) number threads + 1. - Output> partials; // float result. + Output> partials; // float result. bool compute(unsigned wid) { // Always assuming size % 2 == 0 @@ -70,14 +70,14 @@ class [[poplar::constraint("elem(*x) != elem(*y)")]] DotProduct1dIndexedVertex class HessenbergCorrectionVectorVertex : public MultiVertex { public: using T = float; - Input> Rcol; // (N,) R column. - Input> sdiag; // (N,) R diag. sign. - Input> + Input> Rcol; // (N,) R column. + Input> sdiag; // (N,) R diag. sign. + Input> cidx; - Output> + Output> v; // (N,) QR correction vector (not normalized) - Output> + Output> vrescale; // (1,) QR correction vector rescaling (2 / norm) @@ -186,14 +186,14 @@ class [[poplar::constraint( // Passing 2 scaling factors is more efficient for the QR implementation. // Avoids another full pass on the v vector in the vertex it is constructed. - Input> + Input> scale1; // (1,) first scaling factor. - Input> + Input> scale2; // (1,) 2nd scaling factor. - Input> + Input> start_idx_; - Input> + Input> worker_offsets; // (7,) threads work size + 1.