Skip to content

Commit

Permalink
detect naxis
Browse files Browse the repository at this point in the history
  • Loading branch information
oguyon committed Nov 15, 2023
1 parent ac40c91 commit 43146db
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 4 deletions.
27 changes: 24 additions & 3 deletions plugins/milk-extra-src/linalgebra/SGEMM.c
Original file line number Diff line number Diff line change
Expand Up @@ -159,83 +159,102 @@ errno_t computeSGEMM(
int inA_Mdim;
int inA_Mdim0;
int inA_Mdim1;
int inA_Mdim1_active = 1; // is axis used ?

int inA_Ndim;
int inA_Ndim0;
int inA_Ndim1;
int inA_Ndim1_active = 1; // is axis used ?

if(imginA.md->naxis == 3)
{
//printf("inA_Mdim : %d x %d\n", imginA.md->size[0], imginA.md->size[1]);
inA_Mdim = imginA.md->size[0] * imginA.md->size[1];
inA_Mdim0 = imginA.md->size[0];
inA_Mdim1 = imginA.md->size[1];
inA_Mdim1_active = 1;

//printf("inA_Ndim : %d\n", imginA.md->size[2]);
inA_Ndim = imginA.md->size[2];
inA_Ndim0 = imginA.md->size[2];
inA_Ndim1 = 1;
inA_Ndim1_active = 0;
}
else
{
//printf("inA_Mdim : %d\n", imginA.md->size[0]);
inA_Mdim = imginA.md->size[0];
inA_Mdim0 = imginA.md->size[1];
inA_Mdim0 = imginA.md->size[0];
inA_Mdim1 = 1;
inA_Mdim1_active = 0;

//printf("inNdim : %d\n", imginA.md->size[1]);
inA_Ndim = imginA.md->size[1];
inA_Ndim0 = imginA.md->size[1];
inA_Ndim1 = 1;
inA_Ndim1_active = 0;
}


int inB_Mdim;
int inB_Mdim0;
int inB_Mdim1;
int inB_Mdim1_active = 1;

int inB_Ndim;
int inB_Ndim0;
int inB_Ndim1;
int inB_Ndim1_active = 1;

if(imginB.md->naxis == 3)
{
//printf("inB_Mdim : %d x %d\n", imginB.md->size[0], imginB.md->size[1]);
inB_Mdim = imginB.md->size[0] * imginB.md->size[1];
inB_Mdim0 = imginB.md->size[0];
inB_Mdim1 = imginB.md->size[1];
inB_Mdim1_active = 1;

//printf("inB_Ndim : %d\n", imginB.md->size[2]);
inB_Ndim = imginB.md->size[2];
inB_Ndim0 = imginB.md->size[2];
inB_Ndim1 = 1;
inB_Ndim1_active = 0;
}
else
{
//printf("inB_Mdim : %d\n", imginB.md->size[0]);
inB_Mdim = imginB.md->size[0];
inB_Mdim0 = imginB.md->size[1];
inB_Mdim0 = imginB.md->size[0];
inB_Mdim1 = 1;
inB_Mdim1_active = 0;

//printf("inB_Ndim : %d\n", imginB.md->size[1]);
inB_Ndim = imginB.md->size[1];
inB_Ndim0 = imginB.md->size[1];
inB_Ndim1 = 1;
inB_Ndim1_active = 0;
}


// input to SGEMM function
int Mdim, Ndim, Kdim;
int Mdim0, Ndim0, Kdim0;
int Mdim1, Ndim1, Kdim1;
int Mdim1_active = 1;
int Ndim1_active = 1;


// if no transpose
Mdim = inA_Mdim;
Mdim0 = inA_Mdim0;
Mdim1 = inA_Mdim1;
Mdim1_active = inA_Mdim1_active;


Ndim = inB_Ndim;
Ndim0 = inB_Ndim0;
Ndim1 = inB_Ndim1;
Ndim1_active = inB_Ndim1_active;

Kdim = inA_Ndim;

Expand All @@ -244,6 +263,7 @@ errno_t computeSGEMM(
Mdim = inA_Ndim;
Mdim0 = inA_Ndim0;
Mdim1 = inA_Ndim1;
Mdim1_active = inA_Ndim1_active;

Kdim = inA_Mdim;

Expand All @@ -253,6 +273,7 @@ errno_t computeSGEMM(
Ndim = inB_Mdim;
Ndim0 = inB_Mdim0;
Ndim1 = inB_Mdim1;
Ndim1_active = inB_Mdim1_active;
}

printf("T %d %d -> SGEMM M=%d,(%d %d) N=%d, (%d %d) K=%d\n",
Expand All @@ -270,7 +291,7 @@ errno_t computeSGEMM(
//
int outMdim = Mdim;
int outNdim = Ndim;
if( Mdim1 == 1)
if( Mdim1_active == 0 )
{
// 2D output
outimg->naxis = 2;
Expand Down
2 changes: 1 addition & 1 deletion plugins/milk-extra-src/linalgebra/SingularValueDecomp.c
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ errno_t compute_SVD(
{
//printf("inMdim : %d\n", imgin.md->size[0]);
inMdim = imgin.md->size[0];
inMdim0 = imgin.md->size[1];
inMdim0 = imgin.md->size[0];
inMdim1 = 1;

//printf("inNdim : %d\n", imgin.md->size[1]);
Expand Down

0 comments on commit 43146db

Please sign in to comment.