Skip to content

Commit

Permalink
am: load fw in batches (tinygrad#9185)
Browse files Browse the repository at this point in the history
* am: load fw in batches

* am: 1mb less fw copies

* mypy

* list
  • Loading branch information
nimlgen authored Feb 21, 2025
1 parent 1db4341 commit 041b6d5
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 24 deletions.
26 changes: 13 additions & 13 deletions tinygrad/runtime/support/am/amdev.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,57 +48,57 @@ def fmt_ver(hwip): return f"{adev.ip_versions[hwip]//10000}_{(adev.ip_versions[h

# Load other fw
self.ucode_start: dict[str, int] = {}
self.descs: list[tuple[int, memoryview]] = []
self.descs: list[tuple[list[int], memoryview]] = []

blob, hdr = self.load_fw(f"smu_{fmt_ver(am.MP1_HWIP)}.bin", am.struct_smc_firmware_header_v1_0)
self.smu_psp_desc = self.desc(am.GFX_FW_TYPE_SMU, blob, hdr.header.ucode_array_offset_bytes, hdr.header.ucode_size_bytes)
self.smu_psp_desc = self.desc(blob, hdr.header.ucode_array_offset_bytes, hdr.header.ucode_size_bytes, am.GFX_FW_TYPE_SMU)

# SDMA firmware
blob, hdr = self.load_fw(f"sdma_{fmt_ver(am.SDMA0_HWIP)}.bin", am.struct_sdma_firmware_header_v2_0)
self.descs += [self.desc(am.GFX_FW_TYPE_SDMA_UCODE_TH0, blob, hdr.header.ucode_array_offset_bytes, hdr.ctx_ucode_size_bytes)]
self.descs += [self.desc(am.GFX_FW_TYPE_SDMA_UCODE_TH1, blob, hdr.ctl_ucode_offset, hdr.ctl_ucode_size_bytes)]
self.descs += [self.desc(blob, hdr.header.ucode_array_offset_bytes, hdr.ctx_ucode_size_bytes, am.GFX_FW_TYPE_SDMA_UCODE_TH0)]
self.descs += [self.desc(blob, hdr.ctl_ucode_offset, hdr.ctl_ucode_size_bytes, am.GFX_FW_TYPE_SDMA_UCODE_TH1)]

# PFP, ME, MEC firmware
for (fw_name, fw_cnt) in [('PFP', 2), ('ME', 2), ('MEC', 4)]:
blob, hdr = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_{fw_name.lower()}.bin", am.struct_gfx_firmware_header_v2_0)

# Code part
self.descs += [self.desc(getattr(am, f'GFX_FW_TYPE_RS64_{fw_name}'), blob, hdr.header.ucode_array_offset_bytes, hdr.ucode_size_bytes)]
self.descs += [self.desc(blob, hdr.header.ucode_array_offset_bytes, hdr.ucode_size_bytes, getattr(am, f'GFX_FW_TYPE_RS64_{fw_name}'))]

# Stack
fw_types = [getattr(am, f'GFX_FW_TYPE_RS64_{fw_name}_P{fwnun}_STACK') for fwnun in range(fw_cnt)]
self.descs += [self.desc(typ, blob, hdr.data_offset_bytes, hdr.data_size_bytes) for typ in fw_types]
stack_fws = [getattr(am, f'GFX_FW_TYPE_RS64_{fw_name}_P{fwnum}_STACK') for fwnum in range(fw_cnt)]
self.descs += [self.desc(blob, hdr.data_offset_bytes, hdr.data_size_bytes, *stack_fws)]
self.ucode_start[fw_name] = hdr.ucode_start_addr_lo | (hdr.ucode_start_addr_hi << 32)

# IMU firmware
blob, hdr = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_imu.bin", am.struct_imu_firmware_header_v1_0)
imu_i_off, imu_i_sz, imu_d_sz = hdr.header.ucode_array_offset_bytes, hdr.imu_iram_ucode_size_bytes, hdr.imu_dram_ucode_size_bytes
self.descs += [self.desc(am.GFX_FW_TYPE_IMU_I, blob, imu_i_off, imu_i_sz), self.desc(am.GFX_FW_TYPE_IMU_D, blob, imu_i_off + imu_i_sz, imu_d_sz)]
self.descs += [self.desc(blob, imu_i_off, imu_i_sz, am.GFX_FW_TYPE_IMU_I), self.desc(blob, imu_i_off + imu_i_sz, imu_d_sz, am.GFX_FW_TYPE_IMU_D)]

# RLC firmware
blob, hdr0, hdr1, hdr2, hdr3 = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_rlc.bin", am.struct_rlc_firmware_header_v2_0,
am.struct_rlc_firmware_header_v2_1, am.struct_rlc_firmware_header_v2_2, am.struct_rlc_firmware_header_v2_3)

for mem in ['GPM', 'SRM']:
off, sz = getattr(hdr1, f'save_restore_list_{mem.lower()}_offset_bytes'), getattr(hdr1, f'save_restore_list_{mem.lower()}_size_bytes')
self.descs += [self.desc(getattr(am, f'GFX_FW_TYPE_RLC_RESTORE_LIST_{mem}_MEM'), blob, off, sz)]
self.descs += [self.desc(blob, off, sz, getattr(am, f'GFX_FW_TYPE_RLC_RESTORE_LIST_{mem}_MEM'))]

for mem,fmem in [('IRAM', 'iram'), ('DRAM_BOOT', 'dram')]:
off, sz = getattr(hdr2, f'rlc_{fmem}_ucode_offset_bytes'), getattr(hdr2, f'rlc_{fmem}_ucode_size_bytes')
self.descs += [self.desc(getattr(am, f'GFX_FW_TYPE_RLC_{mem}'), blob, off, sz)]
self.descs += [self.desc(blob, off, sz, getattr(am, f'GFX_FW_TYPE_RLC_{mem}'))]

for mem in ['P', 'V']:
off, sz = getattr(hdr3, f'rlc{mem.lower()}_ucode_offset_bytes'), getattr(hdr3, f'rlc{mem.lower()}_ucode_size_bytes')
self.descs += [self.desc(getattr(am, f'GFX_FW_TYPE_RLC_{mem}'), blob, off, sz)]
self.descs += [self.desc(blob, off, sz, getattr(am, f'GFX_FW_TYPE_RLC_{mem}'))]

self.descs += [self.desc(am.GFX_FW_TYPE_RLC_G, blob, hdr0.header.ucode_array_offset_bytes, hdr0.header.ucode_size_bytes)]
self.descs += [self.desc(blob, hdr0.header.ucode_array_offset_bytes, hdr0.header.ucode_size_bytes, am.GFX_FW_TYPE_RLC_G)]

def load_fw(self, fname:str, *headers):
fpath = next(f for loc in ["/lib/firmware/updates/amdgpu/", "/lib/firmware/amdgpu/"] if (f:=pathlib.Path(loc + fname)).exists())
blob = memoryview(bytearray(fpath.read_bytes()))
return tuple([blob] + [hdr.from_address(mv_address(blob)) for hdr in headers])

def desc(self, typ:int, blob:memoryview, offset:int, size:int) -> tuple[int, memoryview]: return (typ, blob[offset:offset+size])
def desc(self, blob:memoryview, offset:int, size:int, *types:int) -> tuple[list[int], memoryview]: return (list(types), blob[offset:offset+size])

@dataclasses.dataclass(frozen=True)
class AMMapping: va_addr:int; size:int; paddrs:list[tuple[int, int]]; uncached:bool=False; system:bool=False; snooped:bool=False # noqa: E702
Expand Down
21 changes: 10 additions & 11 deletions tinygrad/runtime/support/am/ip.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,10 +356,10 @@ def init(self):
self._tmr_init()

# SMU fw should be loaded before TMR.
self._load_ip_fw_cmd(self.adev.fw.smu_psp_desc)
self._load_ip_fw_cmd(*self.adev.fw.smu_psp_desc)
self._tmr_load_cmd()

for psp_desc in self.adev.fw.descs: self._load_ip_fw_cmd(psp_desc)
for psp_desc in self.adev.fw.descs: self._load_ip_fw_cmd(*psp_desc)
self._rlc_autoload_cmd()

def _wait_for_bootloader(self): self.adev.wait_reg(self.adev.regMP0_SMN_C2PMSG_35, mask=0xFFFFFFFF, value=0x80000000)
Expand Down Expand Up @@ -433,16 +433,15 @@ def _prep_ring_cmd(self, hdr):
cmd.cmd_id = hdr
return cmd

def _load_ip_fw_cmd(self, psp_desc):
if DEBUG >= 2: print(f"am {self.adev.devfmt}: loading fw: {am.psp_gfx_fw_type__enumvalues[psp_desc[0]]}")
fw_type, fw_bytes = psp_desc

def _load_ip_fw_cmd(self, fw_types, fw_bytes):
self._prep_msg1(fw_bytes)
cmd = self._prep_ring_cmd(am.GFX_CMD_ID_LOAD_IP_FW)
cmd.cmd.cmd_load_ip_fw.fw_phy_addr_hi, cmd.cmd.cmd_load_ip_fw.fw_phy_addr_lo = data64(self.adev.paddr2mc(self.msg1_paddr))
cmd.cmd.cmd_load_ip_fw.fw_size = len(fw_bytes)
cmd.cmd.cmd_load_ip_fw.fw_type = fw_type
return self._ring_submit()
for fw_type in fw_types:
if DEBUG >= 2: print(f"am {self.adev.devfmt}: loading fw: {am.psp_gfx_fw_type__enumvalues[fw_type]}")
cmd = self._prep_ring_cmd(am.GFX_CMD_ID_LOAD_IP_FW)
cmd.cmd.cmd_load_ip_fw.fw_phy_addr_hi, cmd.cmd.cmd_load_ip_fw.fw_phy_addr_lo = data64(self.adev.paddr2mc(self.msg1_paddr))
cmd.cmd.cmd_load_ip_fw.fw_size = len(fw_bytes)
cmd.cmd.cmd_load_ip_fw.fw_type = fw_type
self._ring_submit()

def _tmr_load_cmd(self):
cmd = self._prep_ring_cmd(am.GFX_CMD_ID_SETUP_TMR)
Expand Down

0 comments on commit 041b6d5

Please sign in to comment.