diff --git a/tinygrad/runtime/support/am/amdev.py b/tinygrad/runtime/support/am/amdev.py index 7ce9deae098d4..c410f6156c93a 100644 --- a/tinygrad/runtime/support/am/amdev.py +++ b/tinygrad/runtime/support/am/amdev.py @@ -48,32 +48,32 @@ def fmt_ver(hwip): return f"{adev.ip_versions[hwip]//10000}_{(adev.ip_versions[h # Load other fw self.ucode_start: dict[str, int] = {} - self.descs: list[tuple[int, memoryview]] = [] + self.descs: list[tuple[list[int], memoryview]] = [] blob, hdr = self.load_fw(f"smu_{fmt_ver(am.MP1_HWIP)}.bin", am.struct_smc_firmware_header_v1_0) - self.smu_psp_desc = self.desc(am.GFX_FW_TYPE_SMU, blob, hdr.header.ucode_array_offset_bytes, hdr.header.ucode_size_bytes) + self.smu_psp_desc = self.desc(blob, hdr.header.ucode_array_offset_bytes, hdr.header.ucode_size_bytes, am.GFX_FW_TYPE_SMU) # SDMA firmware blob, hdr = self.load_fw(f"sdma_{fmt_ver(am.SDMA0_HWIP)}.bin", am.struct_sdma_firmware_header_v2_0) - self.descs += [self.desc(am.GFX_FW_TYPE_SDMA_UCODE_TH0, blob, hdr.header.ucode_array_offset_bytes, hdr.ctx_ucode_size_bytes)] - self.descs += [self.desc(am.GFX_FW_TYPE_SDMA_UCODE_TH1, blob, hdr.ctl_ucode_offset, hdr.ctl_ucode_size_bytes)] + self.descs += [self.desc(blob, hdr.header.ucode_array_offset_bytes, hdr.ctx_ucode_size_bytes, am.GFX_FW_TYPE_SDMA_UCODE_TH0)] + self.descs += [self.desc(blob, hdr.ctl_ucode_offset, hdr.ctl_ucode_size_bytes, am.GFX_FW_TYPE_SDMA_UCODE_TH1)] # PFP, ME, MEC firmware for (fw_name, fw_cnt) in [('PFP', 2), ('ME', 2), ('MEC', 4)]: blob, hdr = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_{fw_name.lower()}.bin", am.struct_gfx_firmware_header_v2_0) # Code part - self.descs += [self.desc(getattr(am, f'GFX_FW_TYPE_RS64_{fw_name}'), blob, hdr.header.ucode_array_offset_bytes, hdr.ucode_size_bytes)] + self.descs += [self.desc(blob, hdr.header.ucode_array_offset_bytes, hdr.ucode_size_bytes, getattr(am, f'GFX_FW_TYPE_RS64_{fw_name}'))] # Stack - fw_types = [getattr(am, f'GFX_FW_TYPE_RS64_{fw_name}_P{fwnun}_STACK') for fwnun in range(fw_cnt)] - self.descs += [self.desc(typ, blob, hdr.data_offset_bytes, hdr.data_size_bytes) for typ in fw_types] + stack_fws = [getattr(am, f'GFX_FW_TYPE_RS64_{fw_name}_P{fwnum}_STACK') for fwnum in range(fw_cnt)] + self.descs += [self.desc(blob, hdr.data_offset_bytes, hdr.data_size_bytes, *stack_fws)] self.ucode_start[fw_name] = hdr.ucode_start_addr_lo | (hdr.ucode_start_addr_hi << 32) # IMU firmware blob, hdr = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_imu.bin", am.struct_imu_firmware_header_v1_0) imu_i_off, imu_i_sz, imu_d_sz = hdr.header.ucode_array_offset_bytes, hdr.imu_iram_ucode_size_bytes, hdr.imu_dram_ucode_size_bytes - self.descs += [self.desc(am.GFX_FW_TYPE_IMU_I, blob, imu_i_off, imu_i_sz), self.desc(am.GFX_FW_TYPE_IMU_D, blob, imu_i_off + imu_i_sz, imu_d_sz)] + self.descs += [self.desc(blob, imu_i_off, imu_i_sz, am.GFX_FW_TYPE_IMU_I), self.desc(blob, imu_i_off + imu_i_sz, imu_d_sz, am.GFX_FW_TYPE_IMU_D)] # RLC firmware blob, hdr0, hdr1, hdr2, hdr3 = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_rlc.bin", am.struct_rlc_firmware_header_v2_0, @@ -81,24 +81,24 @@ def fmt_ver(hwip): return f"{adev.ip_versions[hwip]//10000}_{(adev.ip_versions[h for mem in ['GPM', 'SRM']: off, sz = getattr(hdr1, f'save_restore_list_{mem.lower()}_offset_bytes'), getattr(hdr1, f'save_restore_list_{mem.lower()}_size_bytes') - self.descs += [self.desc(getattr(am, f'GFX_FW_TYPE_RLC_RESTORE_LIST_{mem}_MEM'), blob, off, sz)] + self.descs += [self.desc(blob, off, sz, getattr(am, f'GFX_FW_TYPE_RLC_RESTORE_LIST_{mem}_MEM'))] for mem,fmem in [('IRAM', 'iram'), ('DRAM_BOOT', 'dram')]: off, sz = getattr(hdr2, f'rlc_{fmem}_ucode_offset_bytes'), getattr(hdr2, f'rlc_{fmem}_ucode_size_bytes') - self.descs += [self.desc(getattr(am, f'GFX_FW_TYPE_RLC_{mem}'), blob, off, sz)] + self.descs += [self.desc(blob, off, sz, getattr(am, f'GFX_FW_TYPE_RLC_{mem}'))] for mem in ['P', 'V']: off, sz = getattr(hdr3, f'rlc{mem.lower()}_ucode_offset_bytes'), getattr(hdr3, f'rlc{mem.lower()}_ucode_size_bytes') - self.descs += [self.desc(getattr(am, f'GFX_FW_TYPE_RLC_{mem}'), blob, off, sz)] + self.descs += [self.desc(blob, off, sz, getattr(am, f'GFX_FW_TYPE_RLC_{mem}'))] - self.descs += [self.desc(am.GFX_FW_TYPE_RLC_G, blob, hdr0.header.ucode_array_offset_bytes, hdr0.header.ucode_size_bytes)] + self.descs += [self.desc(blob, hdr0.header.ucode_array_offset_bytes, hdr0.header.ucode_size_bytes, am.GFX_FW_TYPE_RLC_G)] def load_fw(self, fname:str, *headers): fpath = next(f for loc in ["/lib/firmware/updates/amdgpu/", "/lib/firmware/amdgpu/"] if (f:=pathlib.Path(loc + fname)).exists()) blob = memoryview(bytearray(fpath.read_bytes())) return tuple([blob] + [hdr.from_address(mv_address(blob)) for hdr in headers]) - def desc(self, typ:int, blob:memoryview, offset:int, size:int) -> tuple[int, memoryview]: return (typ, blob[offset:offset+size]) + def desc(self, blob:memoryview, offset:int, size:int, *types:int) -> tuple[list[int], memoryview]: return (list(types), blob[offset:offset+size]) @dataclasses.dataclass(frozen=True) class AMMapping: va_addr:int; size:int; paddrs:list[tuple[int, int]]; uncached:bool=False; system:bool=False; snooped:bool=False # noqa: E702 diff --git a/tinygrad/runtime/support/am/ip.py b/tinygrad/runtime/support/am/ip.py index 53bc1433f7a5d..c597f798fd16f 100644 --- a/tinygrad/runtime/support/am/ip.py +++ b/tinygrad/runtime/support/am/ip.py @@ -356,10 +356,10 @@ def init(self): self._tmr_init() # SMU fw should be loaded before TMR. - self._load_ip_fw_cmd(self.adev.fw.smu_psp_desc) + self._load_ip_fw_cmd(*self.adev.fw.smu_psp_desc) self._tmr_load_cmd() - for psp_desc in self.adev.fw.descs: self._load_ip_fw_cmd(psp_desc) + for psp_desc in self.adev.fw.descs: self._load_ip_fw_cmd(*psp_desc) self._rlc_autoload_cmd() def _wait_for_bootloader(self): self.adev.wait_reg(self.adev.regMP0_SMN_C2PMSG_35, mask=0xFFFFFFFF, value=0x80000000) @@ -433,16 +433,15 @@ def _prep_ring_cmd(self, hdr): cmd.cmd_id = hdr return cmd - def _load_ip_fw_cmd(self, psp_desc): - if DEBUG >= 2: print(f"am {self.adev.devfmt}: loading fw: {am.psp_gfx_fw_type__enumvalues[psp_desc[0]]}") - fw_type, fw_bytes = psp_desc - + def _load_ip_fw_cmd(self, fw_types, fw_bytes): self._prep_msg1(fw_bytes) - cmd = self._prep_ring_cmd(am.GFX_CMD_ID_LOAD_IP_FW) - cmd.cmd.cmd_load_ip_fw.fw_phy_addr_hi, cmd.cmd.cmd_load_ip_fw.fw_phy_addr_lo = data64(self.adev.paddr2mc(self.msg1_paddr)) - cmd.cmd.cmd_load_ip_fw.fw_size = len(fw_bytes) - cmd.cmd.cmd_load_ip_fw.fw_type = fw_type - return self._ring_submit() + for fw_type in fw_types: + if DEBUG >= 2: print(f"am {self.adev.devfmt}: loading fw: {am.psp_gfx_fw_type__enumvalues[fw_type]}") + cmd = self._prep_ring_cmd(am.GFX_CMD_ID_LOAD_IP_FW) + cmd.cmd.cmd_load_ip_fw.fw_phy_addr_hi, cmd.cmd.cmd_load_ip_fw.fw_phy_addr_lo = data64(self.adev.paddr2mc(self.msg1_paddr)) + cmd.cmd.cmd_load_ip_fw.fw_size = len(fw_bytes) + cmd.cmd.cmd_load_ip_fw.fw_type = fw_type + self._ring_submit() def _tmr_load_cmd(self): cmd = self._prep_ring_cmd(am.GFX_CMD_ID_SETUP_TMR)