diff --git a/src/configs/gemm-config.c b/src/configs/gemm-config.c index 59626ef93c2..28306797c01 100644 --- a/src/configs/gemm-config.c +++ b/src/configs/gemm-config.c @@ -1954,6 +1954,7 @@ static void init_qd8_f32_qb4w_gemm_config(void) { const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); // Zen4 has gfni but is slower and 8x16 works better on zen4. 14x16 is faster on Sapphire Rapids + #if XNN_ENABLE_AVX512VNNIGFNI if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512vnnigfni && cpuinfo_get_core(0)->uarch != cpuinfo_uarch_zen4) { qd8_f32_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16c8__avx512vnnigfni_prfm); qd8_f32_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(14)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qb4w_gemm_minmax_ukernel_14x16c8__avx512vnnigfni_prfm); @@ -1962,7 +1963,9 @@ static void init_qd8_f32_qb4w_gemm_config(void) { qd8_f32_qb4w_gemm_config.nr = 16; qd8_f32_qb4w_gemm_config.log2_kr = 3; qd8_f32_qb4w_gemm_config.planes = 2; - } else if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512vnni) { + } else + #endif // XNN_ENABLE_AVX512VNNIGFNI + if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512vnni) { qd8_f32_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16c8__avx512vnni_prfm); qd8_f32_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(8)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qb4w_gemm_minmax_ukernel_8x16c8__avx512vnni_prfm); qd8_f32_qb4w_gemm_config.init.f32_qb4w = xnn_init_f32_qb4w_minmax_avx512vnni_params; diff --git a/src/subgraph/fully-connected.c b/src/subgraph/fully-connected.c index 04c4c74609f..f43a7a9b3e0 100644 --- a/src/subgraph/fully-connected.c +++ b/src/subgraph/fully-connected.c @@ -796,7 +796,12 @@ static inline enum xnn_compute_type validate_datatypes_with_bias( } break; case xnn_datatype_qbint4: - if (input_datatype == xnn_datatype_qdint8 && + if (input_datatype == xnn_datatype_fp32 && + bias_datatype == xnn_datatype_fp32 && + output_datatype == xnn_datatype_fp32) + { + return xnn_compute_type_fp32; + } else if (input_datatype == xnn_datatype_qdint8 && bias_datatype == xnn_datatype_fp32 && output_datatype == xnn_datatype_fp32) { @@ -883,7 +888,9 @@ static inline enum xnn_compute_type validate_datatypes_without_bias( } break; case xnn_datatype_qbint4: - if (input_datatype == xnn_datatype_qdint8 && output_datatype == xnn_datatype_fp32) { + if (input_datatype == xnn_datatype_fp32 && output_datatype == xnn_datatype_fp32) { + return xnn_compute_type_fp32; + } else if (input_datatype == xnn_datatype_qdint8 && output_datatype == xnn_datatype_fp32) { return xnn_compute_type_qd8_to_fp32; } else if (input_datatype == xnn_datatype_qdint8 && output_datatype == xnn_datatype_fp16) { return xnn_compute_type_qd8_to_fp16;