Skip to content

Commit

Permalink
Update run and runq
Browse files Browse the repository at this point in the history
run - mirror changes to runq
  • Loading branch information
trholding committed Jul 20, 2024
1 parent e842bf7 commit 3d9ae22
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 11 deletions.
50 changes: 41 additions & 9 deletions run.c
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Inference for Llama-2 Transformer model in pure C */
/* Inference for Llama 2 & LLama 3 Transformer model in pure C */

// L2E Addition
/* The Llama 2 Everywhere @trholding (Vulcan) fork */
Expand Down Expand Up @@ -121,24 +121,32 @@ __static_yoink("zipos");
#endif

// ----------------------------------------------------------------------------
// AVX Support
// OpenMP and OpenACC Support

#ifdef ACCELAVX
#include <immintrin.h>
#ifdef OPENMP
#include <omp.h>
#endif

// ----------------------------------------------------------------------------
// OpenMP and OpenACC Support

// Macro that makes a pragma enabled with string substitution
#define MKPRAGMA_(x) _Pragma (#x)
#define MK_PRAGMA(x) MKPRAGMA_(x)

// Portable OpenMP and OpenACC pragma macros
#ifdef OPENMP
#define ACCELS() MK_PRAGMA(omp parallel for)
#define ACCEL(...) MK_PRAGMA(omp parallel for private(__VA_ARGS__))
#define ACCELRD(VAR) MK_PRAGMA(omp parallel for reduction(+:VAR))
#elif defined(OPENACC)
#define ACCELS() MK_PRAGMA(acc parallel loop)
#define ACCEL(...) MK_PRAGMA(acc parallel loop private(__VA_ARGS__))
#define ACCELRD(VAR) MK_PRAGMA(acc parallel loop reduction(+:VAR))
#endif

// ----------------------------------------------------------------------------
// AVX Support

#ifdef ACCELAVX
#include <immintrin.h>
#endif

// ----------------------------------------------------------------------------
Expand Down Expand Up @@ -355,6 +363,9 @@ void rmsnorm(float* o, float* x, float* weight, int size) {
#ifdef BLAS
ss = cblas_sdot(size, x, 1.0f, x, 1.0f);
#else
#ifdef ACCEL
ACCELRD(ss) // OMP/OACC Macro
#endif
// END L2E Addition
for (int j = 0; j < size; j++) {
ss += x[j] * x[j];
Expand All @@ -366,6 +377,11 @@ void rmsnorm(float* o, float* x, float* weight, int size) {
ss += 1e-5f;
ss = 1.0f / sqrtf(ss);
// normalize and scale
// L2E Addition
#ifdef ACCEL
ACCELS() // OMP/OACC Macro
#endif
// END L2E Addition
for (int j = 0; j < size; j++) {
o[j] = weight[j] * (ss * x[j]);
}
Expand Down Expand Up @@ -620,6 +636,11 @@ float* forward(Transformer* transformer, int token, int pos) {
matmul(s->xb2, s->xb, w->wo + l*dim*dim, dim, dim);

// residual connection back into x
// L2E Addition
#ifdef ACCEL
ACCELS() // OMP/OACC Macro
#endif
// END L2E Addition
for (int i = 0; i < dim; i++) {
x[i] += s->xb2[i];
}
Expand All @@ -633,6 +654,11 @@ float* forward(Transformer* transformer, int token, int pos) {
matmul(s->hb2, s->xb, w->w3 + l*dim*hidden_dim, dim, hidden_dim);

// SwiGLU non-linearity
// L2E Addition
#ifdef ACCEL
ACCELS() // OMP/OACC Macro
#endif
// END L2E Addition
for (int i = 0; i < hidden_dim; i++) {
float val = s->hb[i];
// silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
Expand Down Expand Up @@ -926,7 +952,6 @@ void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *
free(str_buffer);

}

// END L2E Addition

// ----------------------------------------------------------------------------
Expand Down Expand Up @@ -1332,7 +1357,14 @@ void error_usage() {
}

int main(int argc, char *argv[]) {

// L2E Addition
#ifdef OPENMP
int num_threads = omp_get_num_procs(); // get the number of CPU cores
omp_set_num_threads(num_threads); // set the number of threads to use for parallel regions
int num_levels = omp_get_supported_active_levels(); // get maximum number of nested parallel regions supported
omp_set_max_active_levels(num_levels); // set to maximum supported parallel regions
#endif
// END L2E Addition
// default parameters
char *checkpoint_path = NULL; // e.g. out/model.bin
char *tokenizer_path = "tokenizer.bin";
Expand Down
3 changes: 1 addition & 2 deletions runq.c
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Inference for Llama-2 Transformer model in pure C, int8 quantized forward pass. */
/* Inference for Llama 2 & Llama 3 Transformer model in pure C, int8 quantized forward pass. */

// L2E Addition
/* The Llama 2 Everywhere @trholding (Vulcan) fork */
Expand Down Expand Up @@ -1455,7 +1455,6 @@ int main(int argc, char *argv[]) {
omp_set_num_threads(num_threads); // set the number of threads to use for parallel regions
int num_levels = omp_get_supported_active_levels(); // get maximum number of nested parallel regions supported
omp_set_max_active_levels(num_levels); // set to maximum supported parallel regions
// printf("OMP > Max Cores: %d : Max Levels : %d \n",num_threads, num_levels);
#endif
// END L2E Addition
// default parameters
Expand Down

0 comments on commit 3d9ae22

Please sign in to comment.