Skip to content

Commit

Permalink
int8 RoPE 모듈 작성 완료.
Browse files Browse the repository at this point in the history
  • Loading branch information
js4ngu committed Sep 10, 2024
1 parent 2bb446c commit 9e09ba8
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 7 deletions.
32 changes: 25 additions & 7 deletions src/main/scala/vfrope/RoPEModule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,31 @@ import chisel3.util._

class RoPEModule(width:Int) extends Module{
val io = IO(new Bundle{
val m = Input(UInt(width.W))
val theta = Input(UInt(width.W))
val i = Input(UInt(width.W))
val out = Output(UInt(width.W))
val in = Input(Vec(2, UInt(width.W)))
val m = Input(UInt(width.W))
val theta = Input(UInt(width.W))
val i = Input(UInt(width.W))
val out = Output(Vec(2, SInt(width.W)))
})

val inReg = RegInit(VecInit(Seq.fill(2)(0.U(width.W))))
val outReg = RegInit(VecInit(Seq.fill(2)(0.S(width.W))))
val m_theta_i = RegInit(0.U(width.W))
m_theta_i := io.m * io.theta * io.i
io.out := m_theta_i

val sinCosLUT = Module(new SinCosLUT())
val sinVal = RegInit(0.S(8.W))
val cosVal = RegInit(0.S(8.W))

inReg(0) := io.in(0)
inReg(1) := io.in(1)

m_theta_i := io.m * io.theta * io.i
sinCosLUT.io.angle := m_theta_i
sinVal := sinCosLUT.io.sinOut
cosVal := sinCosLUT.io.cosOut

outReg(0) := inReg(0) * cosVal - inReg(1) * sinVal
outReg(1) := inReg(1) * cosVal + inReg(0) * sinVal

io.out(0) := outReg(0)
io.out(1) := outReg(1)
}
46 changes: 46 additions & 0 deletions src/main/scala/vfrope/sincosLUT.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package vfrope
import chisel3._
import chisel3.util._

class SinCosLUT extends Module {
val io = IO(new Bundle {
val angle = Input(UInt(8.W)) // Input angle in discrete steps (e.g., 0 to 255)
val sinOut = Output(SInt(8.W)) // Output sine value
val cosOut = Output(SInt(8.W)) // Output cosine value
})

val sinLUT = VecInit(Seq(
0.S, 3.S, 6.S, 9.S, 12.S, 16.S, 19.S, 22.S, 25.S, 28.S, 31.S, 34.S, 37.S, 40.S, 43.S, 46.S, 49.S, 52.S, 54.S, 57.S,
60.S, 63.S, 66.S, 68.S, 71.S, 73.S, 76.S, 78.S, 81.S, 83.S, 86.S, 88.S, 90.S, 92.S, 94.S, 96.S, 98.S, 100.S, 102.S, 104.S,
106.S, 108.S, 109.S, 111.S, 112.S, 114.S, 115.S, 116.S, 118.S, 119.S, 120.S, 121.S, 122.S, 123.S, 123.S, 124.S, 125.S, 125.S, 126.S, 126.S,
126.S, 127.S, 127.S, 127.S, 127.S, 127.S, 127.S, 127.S, 126.S, 126.S, 125.S, 125.S, 124.S, 124.S, 123.S, 122.S, 121.S, 120.S, 119.S, 118.S,
117.S, 116.S, 114.S, 113.S, 112.S, 110.S, 108.S, 107.S, 105.S, 103.S, 101.S, 99.S, 97.S, 95.S, 93.S, 91.S, 89.S, 87.S, 84.S, 82.S,
80.S, 77.S, 75.S, 72.S, 69.S, 67.S, 64.S, 61.S, 59.S, 56.S, 53.S, 50.S, 47.S, 44.S, 41.S, 39.S, 36.S, 32.S, 29.S, 26.S,
23.S, 20.S, 17.S, 14.S, 11.S, 8.S, 5.S, 2.S, -2.S, -5.S, -8.S, -11.S, -14.S, -17.S, -20.S, -23.S, -26.S, -29.S, -32.S, -36.S,
-39.S, -41.S, -44.S, -47.S, -50.S, -53.S, -56.S, -59.S, -61.S, -64.S, -67.S, -69.S, -72.S, -75.S, -77.S, -80.S, -82.S, -84.S, -87.S, -89.S,
-91.S, -93.S, -95.S, -97.S, -99.S, -101.S, -103.S, -105.S, -107.S, -108.S, -110.S, -112.S, -113.S, -114.S, -116.S, -117.S, -118.S, -119.S, -120.S, -121.S,
-122.S, -123.S, -124.S, -124.S, -125.S, -125.S, -126.S, -126.S, -127.S, -127.S, -127.S, -127.S, -127.S, -127.S, -127.S, -126.S, -126.S, -126.S, -125.S, -125.S,
-124.S, -123.S, -123.S, -122.S, -121.S, -120.S, -119.S, -118.S, -116.S, -115.S, -114.S, -112.S, -111.S, -109.S, -108.S, -106.S, -104.S, -102.S, -100.S, -98.S,
-96.S, -94.S, -92.S, -90.S, -88.S, -86.S, -83.S, -81.S, -78.S, -76.S, -73.S, -71.S, -68.S, -66.S, -63.S, -60.S, -57.S, -54.S, -52.S, -49.S,
-46.S, -43.S, -40.S, -37.S, -34.S, -31.S, -28.S, -25.S, -22.S, -19.S, -16.S, -12.S, -9.S, -6.S, -3.S, 0.S
))

val cosLUT = VecInit(Seq(
127.S, 127.S, 127.S, 127.S, 126.S, 126.S, 126.S, 125.S, 125.S, 124.S, 123.S, 122.S, 121.S, 121.S, 120.S, 118.S, 117.S, 116.S, 115.S, 113.S,
112.S, 110.S, 109.S, 107.S, 105.S, 104.S, 102.S, 100.S, 98.S, 96.S, 94.S, 92.S, 90.S, 87.S, 85.S, 83.S, 80.S, 78.S, 75.S, 73.S,
70.S, 68.S, 65.S, 62.S, 59.S, 57.S, 54.S, 51.S, 48.S, 45.S, 42.S, 39.S, 36.S, 33.S, 30.S, 27.S, 24.S, 21.S, 18.S, 15.S,
12.S, 9.S, 5.S, 2.S, -1.S, -4.S, -7.S, -10.S, -13.S, -16.S, -19.S, -23.S, -26.S, -29.S, -32.S, -35.S, -38.S, -41.S, -44.S, -47.S,
-50.S, -52.S, -55.S, -58.S, -61.S, -63.S, -66.S, -69.S, -71.S, -74.S, -77.S, -79.S, -81.S, -84.S, -86.S, -88.S, -91.S, -93.S, -95.S, -97.S,
-99.S, -101.S, -103.S, -105.S, -106.S, -108.S, -110.S, -111.S, -113.S, -114.S, -115.S, -117.S, -118.S, -119.S, -120.S, -121.S, -122.S, -123.S, -124.S, -124.S,
-125.S, -125.S, -126.S, -126.S, -127.S, -127.S, -127.S, -127.S, -127.S, -127.S, -127.S, -127.S, -126.S, -126.S, -125.S, -125.S, -124.S, -124.S, -123.S, -122.S,
-121.S, -120.S, -119.S, -118.S, -117.S, -115.S, -114.S, -113.S, -111.S, -110.S, -108.S, -106.S, -105.S, -103.S, -101.S, -99.S, -97.S, -95.S, -93.S, -91.S,
-88.S, -86.S, -84.S, -81.S, -79.S, -77.S, -74.S, -71.S, -69.S, -66.S, -64.S, -61.S, -58.S, -55.S, -52.S, -50.S, -47.S, -44.S, -41.S, -38.S,
-35.S, -32.S, -29.S, -26.S, -23.S, -19.S, -16.S, -13.S, -10.S, -7.S, -4.S, -1.S, 2.S, 5.S, 9.S, 12.S, 15.S, 18.S, 21.S, 24.S,
27.S, 30.S, 33.S, 36.S, 39.S, 42.S, 45.S, 48.S, 51.S, 54.S, 57.S, 59.S, 62.S, 65.S, 68.S, 70.S, 73.S, 75.S, 78.S, 80.S,
83.S, 85.S, 87.S, 90.S, 92.S, 94.S, 96.S, 98.S, 100.S, 102.S, 104.S, 105.S, 107.S, 109.S, 110.S, 112.S, 113.S, 115.S, 116.S, 117.S,
118.S, 120.S, 121.S, 121.S, 122.S, 123.S, 124.S, 125.S, 125.S, 126.S, 126.S, 126.S, 127.S, 127.S, 127.S, 127.S
))

io.sinOut := sinLUT(io.angle)
io.cosOut := cosLUT(io.angle)
}
25 changes: 25 additions & 0 deletions src/test/scala/vfrope/RoPEModuleTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package vfrope

import chisel3._
import chiseltest._
import org.scalatest.flatspec.AnyFlatSpec

class RoPEModuleTester extends AnyFlatSpec with ChiselScalatestTester {
"RoPEModule" should "work" in {
test(new RoPEModule(width = 8)) { dut =>
dut.io.m.poke(0.U)
dut.io.theta.poke(0.U)
dut.io.i.poke(0.U)
dut.io.in(0).poke(1.U)
dut.io.in(1).poke(1.U)
dut.clock.step(10)

// Fetch io.out(0) and io.out(1) values and print them correctly
val out0 = dut.io.out(0).peek().litValue()
val out1 = dut.io.out(1).peek().litValue()

// Use string interpolation to print values
println(s"RoPE module out : $out0, $out1")
}
}
}
33 changes: 33 additions & 0 deletions src/test/scala/vfrope/sincosLUTtest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package vfrope

import chisel3._
import chiseltest._
import org.scalatest.flatspec.AnyFlatSpec

class SinCosLUTTest extends AnyFlatSpec with ChiselScalatestTester {
behavior of "SinCosLUT"

it should "output correct sine and cosine values for a specific angle" in {
test(new SinCosLUT()) { dut =>
// List of angles to test
//val anglesToTest = Seq(0,10,20,30,40,50,60,64,70,80,90,100,110,120,127,130,140,150,160,170,180,190,200,210,220,230,240,250,255)
val anglesToTest = Seq(0)

// Iterate through each angle
for (angle <- anglesToTest) {
// Set the input angle
dut.io.angle.poke(angle.U)

// Step the clock
dut.clock.step(10)

// Capture the sine and cosine output
val sinOut = dut.io.sinOut.peek().litValue()
val cosOut = dut.io.cosOut.peek().litValue()

// Print out the results for the angle
println(s"Angle: $angle, Sin: $sinOut, Cos: $cosOut")
}
}
}
}

0 comments on commit 9e09ba8

Please sign in to comment.