Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: [BearlyML'24] Add DMA and quantized transformer drivers #1

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions platform/bearly24/inc/hal_dma.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/**
* @file hal_dma.h
* @author Jasmine Angle / [email protected]
* @brief
* @version 0.1
*
* @copyright Copyright (c) 2024
*
*/

#ifndef __HAL_DMA_H
#define __HAL_DMA_H

#ifdef __cplusplus
extern "C" {
#endif

// #include "rv_common.h"
#include "metal.h"
#include "ll_dma.h"

#define DMA0 ((DMA_Type*) 0x8800000)
#define DMA1 ((DMA_Type*) 0x8802000)
#define DMA2 ((DMA_Type*) 0x8804000)
#define DMA3 ((DMA_Type*) 0x8808000)

typedef enum {
MODE_COPY = 0x0,
MODE_MAC = 0x2
} DMA_Operation;

typedef enum {
DMA_OK,
DMA_BADMODE,
DMA_CNTERR,
DMA_DENYR,
DMA_CORRUPTR,
DMA_DENYW
} DMA_Status;

/* Returns the DMA_Status value corresponding with the given DMA engine's status register. */
DMA_Status get_status(DMA_Type* DMAX);

/* Returns whether the given DMA engine's last operation was completed successfully */
static inline uint8_t dma_operation_complete(DMA_Type* DMAX) {
return READ_BITS(DMAX->STATUS, DMA_COMPL_MSK) != 0;
}

/* Returns whether the given DMA engine has an operation in progress */
static inline uint8_t dma_operation_inprogress(DMA_Type* DMAX) {
return READ_BITS(DMAX->STATUS, DMA_INPROG_MSK) != 0;
}

/* Returns whether the given DMA engine has an operation in progress */
static inline uint8_t dma_operation_inprogress_and_not_error(DMA_Type* DMAX) {
return READ_BITS(DMAX->STATUS, DMA_INPROG_MSK) != 0 && READ_BITS(DMAX->STATUS, DMA_ERR_MSK) == 0;
}

/* Returns whether the given DMA engine last operation errored */
static inline uint8_t dma_operation_errored(DMA_Type* DMAX) {
return READ_BITS(DMAX->STATUS, DMA_ERR_MSK) != 0;
}

/* Initializes a block copy from SRC to DST on DMA engine
DMAX to transfer COUNT cache blocks and with stride
SRC_STRIDE between copies. SRC and DST must be cache
block aligned and SRC_STRIDE must be a multiple of a
cache block size. Waits for any previous operations to
complete before initiating current operation. Non blocking
and returns immediately after initiating operation. */
void dma_init_memcpy(DMA_Type* DMAX, void* src, void* dst, uint64_t src_stride, uint32_t count);

/* Initializes a saturating matrix-vector product of int8s
where SRC is a pointer to a matrix of int8s, OPERAND is
a pointer to a vector of int8s to multiply with, STRIDE
is the stride of the input matrix, and COUNT is the numnber
of matrix-vector multiplies to perform. SRC must be cache
block aligned, SRC_STRIDE must be a multiple of a cache
block size, and COUNT must be <= 32 */
void dma_init_MAC(DMA_Type* DMAX, void* src, int8_t* operand, uint64_t src_stride, uint32_t count);

/* Synchronously waits for DMA operation to complete and returns the DMA
Status Code. Refer to BearlyML23 Doc for interpretation of error codes*/
DMA_Status dma_await_result(DMA_Type* DMAX);

/* Synchronously waits for MAC operation to complete, copies result to
*dst, and returns the DMA Status Code. Refer to BearlyML23 Doc
for interpretation of error codes*/
DMA_Status dma_get_MAC_result(DMA_Type* DMAX, int16_t* dst, uint32_t count);

#ifdef __cplusplus
}
#endif

#endif /* __HAL_DMA_H */
49 changes: 49 additions & 0 deletions platform/bearly24/inc/hal_qt.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// This file contains the C wrapper for the operations the accelerator can do in hardware.
#include <inttypes.h>
#include <stdio.h>
#include "rocc.h"

#define LOAD_FUNCT7 0
#define STORE_FUNCT7 1
#define QUANTIZE_FUNCT7 2
#define SET_SCALE_FACTOR_FUNCT7 3
#define DOT_PROD_FUNCT7 4
#define RELU_FUNCT7 5
#define SOFTMAX_FUNCT7 6
#define LAYERNORM_FUNCT7 7

// M = CPU memory
// R = CPU RegFile
// vec = Vector RegFile
// Q = Quantization Register

// vec[rs1] = M[rs2]
#define V_LOAD(rs1, rs2) \
ROCC_INSTRUCTION_SS(1, rs1, rs2, LOAD_FUNCT7)

// M[rs2] = vec[rs1]
#define V_STORE(rs1, rs2) \
ROCC_INSTRUCTION_SS(1, rs1, rs2, STORE_FUNCT7)

#define QUANTIZE(rd, rs1, rs2) \
ROCC_INSTRUCTION_DSS(1, rd, rs1, rs2, QUANTIZE_FUNCT7)

// Q[0] = rs1, Q[1] = rs2
#define SET_SCALE_FACTOR(rs1, rs2) \
ROCC_INSTRUCTION_SS(1, rs1, rs2, SET_SCALE_FACTOR_FUNCT7)

// R[rd] = vec[rs1] dot vec[rs2]
#define V_DOT_PROD(rd, rs1, rs2) \
ROCC_INSTRUCTION_DSS(1, rd, rs1, rs2, DOT_PROD_FUNCT7)

// vec[vout] = ReLU(vec[rs1]) (elementwise)
#define V_RELU(rs1) \
ROCC_INSTRUCTION_S(1, rs1, RELU_FUNCT7)

// vec[vout] = SoftMax(vec[rs1]) (elementwise)
#define V_SOFTMAX(rs1) \
ROCC_INSTRUCTION_S(1, rs1, SOFTMAX_FUNCT7)

// vec[vout] = LayerNorm(vec[rs1]) (elementwise)
#define V_LAYERNORM(rs1) \
ROCC_INSTRUCTION_S(1, rs1, LAYERNORM_FUNCT7)
98 changes: 98 additions & 0 deletions platform/bearly24/inc/rocc.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Based on code by Schuyler Eldridge. Copyright (c) Boston University
// https://github.com/seldridge/rocket-rocc-examples/blob/master/src/main/c/rocc.h

#ifndef SRC_MAIN_C_ROCC_H
#define SRC_MAIN_C_ROCC_H

#include <stdint.h>

#define STR1(x) #x
#define STR(x) STR1(x)
#define EXTRACT(a, size, offset) (((~(~0 << size) << offset) & a) >> offset)

#define CUSTOMX_OPCODE(x) CUSTOM_ ## x
#define CUSTOM_0 0b0001011
#define CUSTOM_1 0b0101011
#define CUSTOM_2 0b1011011
#define CUSTOM_3 0b1111011

#define CUSTOMX(X, xd, xs1, xs2, rd, rs1, rs2, funct) \
CUSTOMX_OPCODE(X) | \
(rd << (7)) | \
(xs2 << (7+5)) | \
(xs1 << (7+5+1)) | \
(xd << (7+5+2)) | \
(rs1 << (7+5+3)) | \
(rs2 << (7+5+3+5)) | \
(EXTRACT(funct, 7, 0) << (7+5+3+5+5))

// Standard macro that passes rd, rs1, and rs2 via registers
#define ROCC_INSTRUCTION_DSS(X, rd, rs1, rs2, funct) \
ROCC_INSTRUCTION_R_R_R(X, rd, rs1, rs2, funct, 10, 11, 12)

#define ROCC_INSTRUCTION_DS(X, rd, rs1, funct) \
ROCC_INSTRUCTION_R_R_I(X, rd, rs1, 0, funct, 10, 11)

#define ROCC_INSTRUCTION_D(X, rd, funct) \
ROCC_INSTRUCTION_R_I_I(X, rd, 0, 0, funct, 10)

#define ROCC_INSTRUCTION_SS(X, rs1, rs2, funct) \
ROCC_INSTRUCTION_I_R_R(X, 0, rs1, rs2, funct, 11, 12)

#define ROCC_INSTRUCTION_S(X, rs1, funct) \
ROCC_INSTRUCTION_I_R_I(X, 0, rs1, 0, funct, 11)

#define ROCC_INSTRUCTION(X, funct) \
ROCC_INSTRUCTION_I_I_I(X, 0, 0, 0, funct)

// rd, rs1, and rs2 are data
// rd_n, rs_1, and rs2_n are the register numbers to use
#define ROCC_INSTRUCTION_R_R_R(X, rd, rs1, rs2, funct, rd_n, rs1_n, rs2_n) { \
register uint64_t rd_ asm ("x" # rd_n); \
register uint64_t rs1_ asm ("x" # rs1_n) = (uint64_t) rs1; \
register uint64_t rs2_ asm ("x" # rs2_n) = (uint64_t) rs2; \
asm volatile ( \
".word " STR(CUSTOMX(X, 1, 1, 1, rd_n, rs1_n, rs2_n, funct)) "\n\t" \
: "=r" (rd_) \
: [_rs1] "r" (rs1_), [_rs2] "r" (rs2_)); \
rd = rd_; \
}

#define ROCC_INSTRUCTION_R_R_I(X, rd, rs1, rs2, funct, rd_n, rs1_n) { \
register uint64_t rd_ asm ("x" # rd_n); \
register uint64_t rs1_ asm ("x" # rs1_n) = (uint64_t) rs1; \
asm volatile ( \
".word " STR(CUSTOMX(X, 1, 1, 0, rd_n, rs1_n, rs2, funct)) "\n\t" \
: "=r" (rd_) : [_rs1] "r" (rs1_)); \
rd = rd_; \
}

#define ROCC_INSTRUCTION_R_I_I(X, rd, rs1, rs2, funct, rd_n) { \
register uint64_t rd_ asm ("x" # rd_n); \
asm volatile ( \
".word " STR(CUSTOMX(X, 1, 0, 0, rd_n, rs1, rs2, funct)) "\n\t" \
: "=r" (rd_)); \
rd = rd_; \
}

#define ROCC_INSTRUCTION_I_R_R(X, rd, rs1, rs2, funct, rs1_n, rs2_n) { \
register uint64_t rs1_ asm ("x" # rs1_n) = (uint64_t) rs1; \
register uint64_t rs2_ asm ("x" # rs2_n) = (uint64_t) rs2; \
asm volatile ( \
".word " STR(CUSTOMX(X, 0, 1, 1, rd, rs1_n, rs2_n, funct)) "\n\t" \
:: [_rs1] "r" (rs1_), [_rs2] "r" (rs2_)); \
}

#define ROCC_INSTRUCTION_I_R_I(X, rd, rs1, rs2, funct, rs1_n) { \
register uint64_t rs1_ asm ("x" # rs1_n) = (uint64_t) rs1; \
asm volatile ( \
".word " STR(CUSTOMX(X, 0, 1, 0, rd, rs1_n, rs2, funct)) "\n\t" \
:: [_rs1] "r" (rs1_)); \
}

#define ROCC_INSTRUCTION_I_I_I(X, rd, rs1, rs2, funct) { \
asm volatile ( \
".word " STR(CUSTOMX(X, 0, 0, 0, rd, rs1, rs2, funct)) "\n\t" ); \
}

#endif // SRC_MAIN_C_ACCUMULATOR_H
63 changes: 63 additions & 0 deletions platform/bearly24/src/hal_dma.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#include "hal_dma.h"

DMA_Status get_status(DMA_Type* DMAX) {
if (READ_BITS(DMAX->STATUS, DMA_BADMODE_MSK))
return DMA_BADMODE;
else if (READ_BITS(DMAX->STATUS, DMA_CNTERR_MSK))
return DMA_CNTERR;
else if (READ_BITS(DMAX->STATUS, DMA_DENYR_MSK))
return DMA_DENYR;
else if (READ_BITS(DMAX->STATUS, DMA_CORRUPTR_MSK))
return DMA_CORRUPTR;
else if (READ_BITS(DMAX->STATUS, DMA_DENYW_MSK))
return DMA_DENYW;
return DMA_OK;
}

void dma_init_memcpy(DMA_Type* DMAX, void* src, void* dst, uint64_t src_stride, uint32_t count) {
while (dma_operation_inprogress_and_not_error(DMAX));

DMAX->SRC_ADDR = (uint64_t) src;
DMAX->DEST_ADDR = (uint64_t) dst;
DMAX->SRCSTRIDE = src_stride;
DMAX->MODE = MODE_COPY;
DMAX->COUNT = count;
}

void dma_init_MAC(DMA_Type* DMAX, void* src, int8_t* operand, uint64_t src_stride, uint32_t count) {
while (dma_operation_inprogress_and_not_error(DMAX));

uint64_t* op = (uint64_t*) operand;
for (size_t i = 0; i < 8; i++)
DMAX->OPERAND_REG[i] = op[i];
DMAX->SRC_ADDR = (uint64_t) src;
DMAX->SRCSTRIDE = src_stride;
DMAX->MODE = MODE_MAC;
DMAX->COUNT = count;

}

DMA_Status dma_await_result(DMA_Type* DMAX) {
while (dma_operation_inprogress_and_not_error(DMAX));
if (dma_operation_complete(DMAX))
return DMA_OK;
else
return get_status(DMAX);
}

DMA_Status dma_get_MAC_result(DMA_Type* DMAX, int16_t* dst, uint32_t count) {
while (dma_operation_inprogress_and_not_error(DMAX));
if (count > 32)
count = 32;

if (dma_operation_complete(DMAX)){
for (size_t i = 0; i < count; i++)
dst[i] = DMAX->DEST_REG[i];
return DMA_OK;
}
else {
for (size_t i = 0; i < count; i++)
dst[i] = -1;
return get_status(DMAX);
}
}