From db2b45f24dcb0678ef6075b9912f63f7522d65a7 Mon Sep 17 00:00:00 2001 From: mabdollahpo Date: Wed, 29 Jan 2025 17:09:07 +0100 Subject: [PATCH] [software] end2end mimo simulation with Sionna. --- software/PHY_emulator/MMSE_BER.py | 513 ++++++++++++----------- software/PHY_emulator/ene2end_mimo.py | 578 ++++++++++++++++++++++++++ 2 files changed, 842 insertions(+), 249 deletions(-) create mode 100644 software/PHY_emulator/ene2end_mimo.py diff --git a/software/PHY_emulator/MMSE_BER.py b/software/PHY_emulator/MMSE_BER.py index e9854a8ab..17b056e26 100644 --- a/software/PHY_emulator/MMSE_BER.py +++ b/software/PHY_emulator/MMSE_BER.py @@ -1,11 +1,7 @@ -#!/usr/bin/env python3 -# Copyright 2022 ETH Zurich and University of Bologna. -# Solderpad Hardware License, Version 0.51, see LICENSE for details. -# SPDX-License-Identifier: SHL-0.51 # Runs and end to end MIMO transmission with HW in the loop. -# Author: Marco Bertuletti +# Author: Mahdi Abdollahpour, University of Bologna import numpy as np import pandas as pd @@ -18,169 +14,9 @@ import sys import pyflexfloat as ff import matplotlib.pyplot as plt -from scipy.linalg import solve_triangular +import end2end_mimo -# __ __ ___ __ __ ___ _______ __ -# | \/ |_ _| \/ |/ _ \ __|_ _\ \/ / -# | |\/| || || |\/| | (_) |___|| | > < -# |_| |_|___|_| |_|\___/ |_| /_/\_\ -## - - -class Constellation: - def __init__(self, constellation_type): - self.constellation_type = constellation_type - self.symbols = None - self.codes = None - self.num_bits = None - self.get_constellation() - - def get_constellation(self): - if self.constellation_type == '4QAM': - self.symbols = np.array([-1 - 1j, +1 - 1j, - -1 + 1j, +1 + 1j]) - elif self.constellation_type == '16QAM': - self.symbols = np.array([-3 - 3j, -3 - 1j, -3 + 3j, -3 + 1j, - -1 - 3j, -1 - 1j, -1 + 3j, -1 + 1j, - +3 - 3j, +3 - 1j, +3 + 3j, +3 + 1j, - +1 - 3j, +1 - 1j, +1 + 3j, +1 + 1j]) - elif self.constellation_type == '64QAM': - self.symbols = np.array([-7 - 7j, -7 - 5j, -7 - 1j, -7 - 3j, -7 + 7j, -7 + 5j, -7 + 1j, -7 + 3j, - -5 - 7j, -5 - 5j, -5 - 1j, -5 - 3j, -5 + 7j, -5 + 5j, -5 + 1j, -5 + 3j, - -1 - 7j, -1 - 5j, -1 - 1j, -1 - 3j, -1 + 7j, -1 + 5j, -1 + 1j, -1 + 3j, - -3 - 7j, -3 - 5j, -3 - 1j, -3 - 3j, -3 + 7j, -3 + 5j, -3 + 1j, -3 + 3j, - +7 - 7j, +7 - 5j, +7 - 1j, +7 - 3j, +7 + 7j, +7 + 5j, +7 + 1j, +7 + 3j, - +5 - 7j, +5 - 5j, +5 - 1j, +5 - 3j, +5 + 7j, +5 + 5j, +5 + 1j, +5 + 3j, - +1 - 7j, +1 - 5j, +1 - 1j, +1 - 3j, +1 + 7j, +1 + 5j, +1 + 1j, +1 + 3j, - +3 - 7j, +3 - 5j, +3 - 1j, +3 - 3j, +3 + 7j, +3 + 5j, +3 + 1j, +3 + 3j]) - else: - raise ValueError("Unsupported constellation type.") - - # Calculate number of points and bits per symbol - num_points = len(self.symbols) - self.num_bits = np.int32(np.log2(num_points)) - - # Generate binary codes for each constellation point - self.codes = [format(i, f'0{self.num_bits}b') - for i in range(num_points)] - - def plot_constellation(self): - # Plot the constellation - plt.figure(figsize=(6, 6)) - plt.scatter(np.real(self.symbols), np.imag(self.symbols), color='blue') - - # Add labels (binary codes) to each point - for i, point in enumerate(self.symbols): - plt.text( - np.real(point) + 0.1, - np.imag(point) + 0.1, - self.codes[i], - fontsize=12) - - # Draw axes and grid - plt.axhline(0, color='black', linewidth=0.5) - plt.axvline(0, color='black', linewidth=0.5) - plt.grid(True) - - # Label the axes - plt.title('{} Constellation Diagram with Binary Labels'.format( - self.constellation_type)) - plt.xlabel("In-phase (I)") - plt.ylabel("Quadrature (Q)") - - # Save 
the plot as an image - plt.savefig('MMSE_QAM.png') - - # Show the plot - plt.show() - - def encode_symbol(self, x): - # Combine real and imaginary parts into complex numbers - x_complex = x[:, 0] + 1j * x[:, 1] - # Find the closest symbols in the constellation - idx_x = [np.argmin(np.abs(self.symbols - est)) for est in x_complex] - # Get the corresponding binary code for each symbol - bit_x = [self.codes[k] for k in idx_x] - - return bit_x - - -def mmse(x, H, y, N, my_type): - - # Type cast - H = H.real.astype(my_type) + 1j * H.imag.astype(my_type) - x = x.real.astype(my_type) + 1j * x.imag.astype(my_type) - y = y.real.astype(my_type) + 1j * y.imag.astype(my_type) - N = N.astype(my_type) - - # MMSE estimator - H_h = H.conj().T - G = np.matmul(H_h, H) + N * np.eye(H.shape[1]) - xhat = np.matmul(np.linalg.inv(G), np.dot(H_h, y)) - # Type cast - xhat = xhat.real.astype(my_type) + 1j * xhat.imag.astype(my_type) - H = H.flatten(order='C') - return N, H, y, x, xhat - - -def generate_mimo_transmission_f16( - N_tx, N_rx, N_itr, symbols, channel_type, SNRdB): - - # Create input vector - idx = np.random.randint(0, len(symbols), size=N_tx) - Es = np.mean(abs(symbols)**2) - x = symbols[idx] - - # Generate channel and noise - if channel_type == 'rayleigh': - # Generate Rayleigh fading channel - H = np.sqrt(np.random.chisquare(2, [N_rx, N_tx])) + 1j * \ - np.sqrt(np.random.chisquare(2, [N_rx, N_tx])) - Eh = (np.linalg.norm(H, 'fro')**2) / N_rx - elif channel_type == 'random': - H = np.sqrt(0.5) * \ - (np.random.normal(0, 1, [N_rx, N_tx]) + 1j * - np.random.normal(0, 1, [N_rx, N_tx])) - Eh = (np.linalg.norm(H, 'fro')**2) / N_rx - else: - # Generate AWGN channel - H = np.eye(N_rx) + 1.j * np.zeros([N_rx, N_tx]) - Eh = 1 - - # Noise variance - N = 0.5 * Es * Eh * 10**(-SNRdB / 10) - n = (np.random.normal(0, np.sqrt(N), N_rx) + 1j * - np.random.normal(0, np.sqrt(N), N_rx)) - N = N / Es - - # Channel propagation - y = np.dot(H, x) + n.flatten() - # MMSE estimator - N64, H64, y64, x64, xhat64 = mmse(x, H, y, N, np.float64) - N16, H16, y16, x16, xhat16 = mmse(x, H, y, N, np.float16) - - # 16b inputs and outputs - N16 = N16 * np.ones(2 * N_tx) - y16 = np.column_stack((y16.real, y16.imag)).flatten() - H16 = np.column_stack((H16.real, H16.imag)).flatten() - x16 = np.column_stack((x16.real, x16.imag)).flatten() - xhat16 = np.column_stack((xhat16.real, xhat16.imag)).flatten() - # Golden model - x64 = np.column_stack((x64.real, x64.imag)).flatten() - xhat64 = np.column_stack((xhat64.real, xhat64.imag)).flatten() - - output = { - "N16": N16, - "y16": y16, - "H16": H16, - "x16": x16, - "x64": x64, - "xhat16": xhat16, - "xhat64": xhat64 - } - return output - # ___ _ # | _ ) __ _ _ _ __| |_ ___ ___ @@ -189,6 +25,8 @@ def generate_mimo_transmission_f16( ## + + def stringify_array(arr, typ, name): count = 0 output_string = typ @@ -206,7 +44,6 @@ def stringify_array(arr, typ, name): output_string += "};\n" return output_string - def gen_data_header_file(outdir, my_type, **kwargs): file = outdir / f"{kwargs['name']}.h" @@ -214,6 +51,7 @@ def gen_data_header_file(outdir, my_type, **kwargs): // Licensed under the Apache License, Version 2.0, see LICENSE for details.\n \ // SPDX-License-Identifier: Apache-2.0\n\n \ // File generated with .data/print_header.py\n" + bb = kwargs['N_itr'] with open(file, 'w') as f: string += "#define N_TX ({})\n".format(kwargs['N_tx']) string += "#define N_RX ({})\n".format(kwargs['N_rx']) @@ -259,6 +97,7 @@ def banshee_call(banshee_dir: pathlib.Path.cwd(), return result + def 
banshee_cast_output(string): # Capture the output @@ -273,7 +112,7 @@ def banshee_cast_output(string): return np.array(half_float_array) -def plot_result(vBER, vEVM, vSNRdB, precisions): +def plot_result(vBER, vEVM, vSNRdB, precisions, name_tag): # Create a figure with two subplots side by side fig, axs = plt.subplots(1, 2, figsize=(16, 6)) @@ -283,7 +122,8 @@ def plot_result(vBER, vEVM, vSNRdB, precisions): axs[0].semilogy(vSNRdB, vBER[j], marker='o', label='{}'.format(precisions[j][0])) axs[0].set_title('BER vs SNR') - axs[0].set_xlabel('SNR (dB)') + # axs[0].set_xlabel('SNR (dB)') + axs[0].set_xlabel('Eb/No (dB)') axs[0].set_ylabel('BER') axs[0].grid(True, which='both') axs[0].legend() @@ -299,7 +139,65 @@ def plot_result(vBER, vEVM, vSNRdB, precisions): axs[1].legend() # Adjust layout and show the plots plt.tight_layout() - plt.savefig('MMSE_BER.png') + plt.savefig(f'MMSE_BER_{name_tag}.png') + # plt.show() + +def bits_per_symbol(mod_scheme): + mod_scheme = mod_scheme.lower() + mod_to_bits = { + 'bpsk': 1, + 'qpsk': 2, + '16qam': 4, + '64qam': 6, + '256qam': 8 + } + + if mod_scheme in mod_to_bits: + return mod_to_bits[mod_scheme] + else: + raise ValueError(f"Unknown modulation scheme: {mod_scheme}") + +def cast_reshape_to_2D_real(data, dtype=np.float16): + + first_dim_size = data.shape[0] + last_dim_size = data.shape[1]*2 + + + # stack as real float16 elements + if np.isrealobj(data): + data = np.stack( ( data.astype(dtype),data.astype(dtype) ) ,axis=-1) + else: + data = np.stack( ( data.real.astype(dtype),data.imag.astype(dtype) ) ,axis=-1) + + # reshape to 2D + data = data.reshape(first_dim_size,last_dim_size) + + # swap batch dimension for banshee call (put the batch dimension in 1'th axis) + data = np.transpose(data,(-1,0)) + + return data + +def cast_reshape_to_2D_complex(data,dtype=np.complex64): + + # [2*N_tx, batch_size(N_batch * num_symbols)] + shape = data.shape + data_size = int(shape[0]/2) + batch_size = shape[1] + + # to [batch_size(N_batch * num_symbols), 2*N_tx] + data = np.transpose(data,(1,0)) + + # to [batch_size, data_size, 2] + data = np.reshape( data, (batch_size,data_size,2) ) + + # to complex [batch_size, data_size] + data_complex = np.zeros((batch_size,data_size),dtype=dtype) + data_complex.real = np.squeeze(data[:,:,0]) + data_complex.imag = np.squeeze(data[:,:,1]) + + + return data_complex + def main(): @@ -340,8 +238,17 @@ def main(): help='Number of transmitting user equipments (UEs). Defaults to 4.') parser.add_argument("-m", "--receivers", type=int, default=4, required=False, help='Number of receiving antennas. Defaults to 4.') - parser.add_argument("-b", "--batchsize", type=int, default=500, required=False, - help='Batch size for transmission processing that fits within L1 cache. Defaults to 500.') + parser.add_argument("-b", "--batchsize", type=int, default=32, required=False, + help='Batch size for transmission processing that fits within L1 cache. Defaults to 32.') + parser.add_argument("-f", "--frequency", type=float, default=3.5e9, required=False, + help='Carrier frequency of the OFDM MIMO transmision. Defaults to 3.5 GHz') + parser.add_argument("-t", "--fftsize", type=int, default=128, required=False, + help='Number of subcarriers. Default to 128') + parser.add_argument("--numofdmsymbols", type=int, default=14, required=False, + help='Number of OFDM symbols. 
Default to 14') + + + args = parser.parse_args() # Directories @@ -355,12 +262,23 @@ def main(): N_rx = args.receivers N_batch = args.batchsize + carrier_frequency = args.frequency + fft_size = args.fftsize + num_ofdm_symbols = args.numofdmsymbols + + + num_symbols = num_ofdm_symbols * fft_size + + + num_bits_per_symbol = bits_per_symbol(args.constellation) + constellation = args.constellation + + # Arithmetic precisions + compiler flags if run_banshee & (channel_type == "rayleigh"): precisions = [['64b', ""], - ['16b-MP wDotp', "\"-DSINGLE -DBANSHEE -DVEC\""], - ['16b-MP cDotp', "\"-DSINGLE -DBANSHEE -DVEC -D__CDOTP\""]] - vSNRdB = range(0, 40, 4) + ['16b-MP wDotp', "\"-DSINGLE -DBANSHEE -DVEC\""]] + vSNRdB = range(-6, 40, 4) vITR = np.concatenate([np.full(9, 1), np.full(1, 2)]) elif run_banshee & (channel_type == "awgn"): precisions = [['64b', ""], @@ -368,25 +286,100 @@ def main(): ['16b-MP wDotp', "\"-DSINGLE -DBANSHEE -DVEC\""], ['16b-MP cDotp', "\"-DSINGLE -DBANSHEE -DVEC -D__CDOTP\""], ['8b-MP', "\"-DSINGLE -DBANSHEE\""], - ['8b-MP wDotp', "\"-DSINGLE -DBANSHEE -DVEC\""]] - vSNRdB = range(0, 20, 2) - vITR = np.concatenate([np.full(3, 1), np.full(5, 2), np.full(2, 6)]) + ['8b-MP wDotp', "\"-DSINGLE -DBANSHEE -DVEC\""] + ] + vSNRdB = range(0, 14, 2) + vITR = np.concatenate([np.full(3, 3), np.full(5, 6), np.full(5, 10)]) + max_it = 50 + be_target = 10 + elif run_banshee & (channel_type == "umi"): + precisions = [['64b', ""], + # ['16b-MP', "\"-DSINGLE -DBANSHEE\""], + ['16b-MP wDotp', "\"-DSINGLE -DBANSHEE -DVEC\""], #works + ['16b-MP wDotp dgLoaded', "\"-DSINGLE -DBANSHEE -DVEC\""], + # ['16b-MP cDotp', "\"-DSINGLE -DBANSHEE -DVEC -D__CDOTP\""], + # ['8b-MP', "\"-DSINGLE -DBANSHEE\""], + # ['8b-MP wDotp', "\"-DSINGLE -DBANSHEE -DVEC\""], + ] + vSNRdB = range(0, 40, 2) + vITR = np.concatenate([np.full(5, 10), np.full(5, 20), np.full(5, 70), np.full(8, 90), np.full(8, 100)]) + max_it = 150 + be_target = 10 + elif run_banshee & (channel_type == "flatfading"): + precisions = [['64b', ""], + # ['16b-MP', "\"-DSINGLE -DBANSHEE\""], + ['16b-MP wDotp', "\"-DSINGLE -DBANSHEE -DVEC\""], + ['16b-MP wDotp dgLoaded', "\"-DSINGLE -DBANSHEE -DVEC\""], + # ['16b-MP cDotp', "\"-DSINGLE -DBANSHEE -DVEC -D__CDOTP\""], + # ['8b-MP', "\"-DSINGLE -DBANSHEE\""], + # ['8b-MP dgLoaded', "\"-DSINGLE -DBANSHEE\""], + # ['8b-MP wDotp', "\"-DSINGLE -DBANSHEE -DVEC\""], + ] + vSNRdB = range(0, 40, 4) + vITR = np.concatenate([np.full(6, 1), np.full(6, 20), np.full(4, 30), np.full(8, 40), np.full(8, 50)]) + max_it = 50 + be_target = 10 else: precisions = [['64b', '']] - vSNRdB = range(0, 20, 2) - vITR = np.concatenate([np.full(3, 1), np.full(5, 2), np.full(2, 6)]) + vSNRdB = range(-6, 16, 2) + vITR = np.concatenate([np.full(3, 20), np.full(5, 20), np.full(2, 50)]) + max_it = 200 + be_target = 10 + + + + + # Golden Model Data Types + # tf.complex64: float32_real + float32_imag + bits_dtype = np.int16 + gm_dtype_bits = 64 + if gm_dtype_bits ==32: + gm_dtype_cplx = np.complex64 + gm_dtype_real = np.float32 + elif gm_dtype_bits==64: + gm_dtype_cplx = np.complex128 + gm_dtype_real = np.float64 + else: + raise UnexpectedError('Unexpected dtype!') - # Constellation - const = Constellation(args.constellation) + + + name_tag = f'{channel_type}_Tx{N_tx}_Rx{N_rx}_{constellation}_EbNo_{vSNRdB.start}_{vSNRdB.stop}_{vSNRdB.step}_perc_{len(precisions)}.npy' + + vTBE_file = os.path.join(script_dir.parents[0],f'vTBE_{name_tag}') + vTB_file = os.path.join(script_dir.parents[0],f'vTB_{name_tag}') + + # Initialize vTBE and vTB if files 
exist, else start from scratch + if os.path.exists(vTBE_file) and os.path.exists(vTB_file): + vTBE = np.load(vTBE_file) + vTB = np.load(vTB_file) + print('Previous simulation loaded.') + else: + vTBE = np.zeros([len(precisions), len(vSNRdB)], np.float64) + vTB = np.zeros([len(precisions), len(vSNRdB)], np.float64) + print('Previous simulation not found. Starting from scratch!') # Vectors for computation - vTBE = np.zeros([len(precisions), len(vSNRdB)], np.float64) vBER = np.zeros([len(precisions), len(vSNRdB)], np.float64) vMSE = np.zeros([len(precisions), len(vSNRdB)], np.float64) vEVM = np.zeros([len(precisions), len(vSNRdB)], np.float64) vVM = np.zeros(len(vSNRdB), np.float64) + + + + + # Monte Carlo Params + + # coderate = 553/1024 # i=14 + coderate = 1 + # Instanciate OFDMMIMO class + ofdm_mimo = end2end_mimo.OFDMMIMO(scenario=channel_type, perfect_csi=True, carrier_frequency=carrier_frequency,fft_size=fft_size, + num_ofdm_symbols=num_ofdm_symbols,subcarrier_spacing=30e3,cyclic_prefix_length=20,num_ut=N_tx,num_ut_ant=1, + num_bs_ant=N_rx, num_bits_per_symbol=num_bits_per_symbol,coderate=coderate, dtype_bits=gm_dtype_bits) + + num_data_symbols = ofdm_mimo.get_num_data_symbols() # ----------------------------------------------------------- # SNR LOOP # ----------------------------------------------------------- @@ -396,94 +389,107 @@ def main(): for iSNR in range(0, len(vSNRdB)): SNRdB = vSNRdB[iSNR] - N_itr = vITR[iSNR] * N_batch - + # N_itr = vITR[iSNR] * N_batch + N_itr = max_it + total_bits = 0 # ----------------------------------------------------------- # BATCH LOOP # ----------------------------------------------------------- # Loop over the batches that fit the cluster L1 memory - for iMC in range(0, N_itr, N_batch): - - # Vector for results + # for iMC in range(0, N_itr, N_batch): + for iMC in range(0, max_it): + # Vector for one batch results vxhat = np.empty( - (len(precisions), 2 * N_tx, N_batch), dtype=np.float64) - - # Generate data - np.random.seed(int(time.time())) - vN16 = np.empty((2 * N_tx, N_batch), dtype=np.float16) - vy16 = np.empty((2 * N_rx, N_batch), dtype=np.float16) - vH16 = np.empty((2 * N_tx * N_rx, N_batch), dtype=np.float16) - vx16 = np.empty((2 * N_tx, N_batch), dtype=np.float16) - vxhat16 = np.empty((2 * N_tx, N_batch), dtype=np.float16) - # Golden model - vx64 = np.empty((2 * N_tx, N_batch), dtype=np.float64) - vxhat64 = np.empty((2 * N_tx, N_batch), dtype=np.float64) - - # Random BER iterations - for iBatch in range(0, N_batch): - output_mmse = generate_mimo_transmission_f16( - N_tx, N_rx, N_batch, const.symbols, channel_type, SNRdB) - vN16[:, iBatch] = output_mmse["N16"] - vy16[:, iBatch] = output_mmse["y16"] - vH16[:, iBatch] = output_mmse["H16"] - vx16[:, iBatch] = output_mmse["x16"] - vx64[:, iBatch] = output_mmse["x64"] - vxhat16[:, iBatch] = output_mmse["xhat16"] - vxhat64[:, iBatch] = output_mmse["xhat64"] - - # Collect decoded symbols - vxhat[0, :, :] = vxhat64 + (len(precisions), 2 * N_tx, num_data_symbols*N_batch), dtype=gm_dtype_real) + vbhat = np.empty( + (len(precisions), N_tx* num_data_symbols*N_batch*num_bits_per_symbol), dtype=bits_dtype) + vxhat_gm = np.empty((2 * N_tx, num_data_symbols*N_batch), dtype=gm_dtype_real) + + # ebno_db = SNRdB - 10*np.log10(num_bits_per_symbol*coderate) + ebno_db = SNRdB + b, b_hat, x, xhat, y, H, s, no_eff, num_symbols, no_db = ofdm_mimo(batch_size=N_batch, ebno_db=ebno_db, min_ut_velocity=8.3, max_ut_velocity=8.3) + tbe, num_bits, ber = ofdm_mimo.count_total_bit_error(b,b_hat) + vTBE[0,iSNR] += tbe + 
vTB[0,iSNR] += num_bits + + + # ----- Dimensions ----- + # y: [batch_size* num_rx* num_ofdm_symbols* num_effective_subcarriers, num_rx_ant] + # h: [batch_size* num_rx* num_ofdm_symbols* num_effective_subcarriers,.. + # ..., num_rx_ant* num_streams_per_rx(num_Interfering_streams_per_rx)] + # s [batch_size* num_rx* num_ofdm_syms* fft_size, num_rx_ant] + # xhat [batch_size* num_data_symbols, num_tx* num_streams] + + vN16 = cast_reshape_to_2D_real(s.numpy().real, dtype=np.float16) + vy16 = cast_reshape_to_2D_real(y.numpy(), dtype=np.float16) + vH16 = cast_reshape_to_2D_real(H.numpy(), dtype=np.float16) + vx16 = cast_reshape_to_2D_real(x.numpy(), dtype=np.float16) + vx64 = cast_reshape_to_2D_real(x.numpy(), dtype=gm_dtype_real) + vxhat16 = cast_reshape_to_2D_real(xhat.numpy(), dtype=np.float16) + vxhat_gm = cast_reshape_to_2D_real(xhat.numpy(), dtype=gm_dtype_real) + # ---------------------------------------------------------------- # BANSHEE CALL # ---------------------------------------------------------------- + + banshee_batch_size = vy16.shape[1] # to support simple_mimo too if run_banshee: for iPrec, (precision, flag) in enumerate(precisions[1:]): + vN16_ = vN16 + if 'dgLoaded' in precision: + vN16_ = 0.001+vN16 + if precision in ('8b-MP', '8b-MP wDotp'): vH8 = ff.array(vH16, "e5m2") - vN8 = ff.array(vN16, "e5m2") + vN8 = ff.array(vN16_, "e5m2") vy8 = ff.array(vy16, "e5m2") kwargs = {'name': 'data_mimo_mmse_f8', 'H': vH8, 'N': vN8, 'y': vy8, - 'N_tx': N_tx, 'N_rx': N_rx, 'N_itr': N_batch} + 'N_tx': N_tx, 'N_rx': N_rx, 'N_itr': banshee_batch_size} gen_data_header_file(DATA_DIR, '__fp8', **kwargs) result = banshee_call( BANSHEE_DIR, SOFTWARE_DIR, flag, "mimo_mmse_f8") else: kwargs = {'name': 'data_mimo_mmse_f16', - 'H': vH16, 'N': vN16, 'y': vy16, - 'N_tx': N_tx, 'N_rx': N_rx, 'N_itr': N_batch} + 'H': vH16, 'N': vN16_, 'y': vy16, + 'N_tx': N_tx, 'N_rx': N_rx, 'N_itr': banshee_batch_size} gen_data_header_file(DATA_DIR, '__fp16', **kwargs) result = banshee_call( BANSHEE_DIR, SOFTWARE_DIR, flag, "mimo_mmse_f16") + + + # [2*N_tx * N_batch * num_symbols] result_casted = banshee_cast_output(result.stderr) - vxhat[iPrec + 1, :, :] = result_casted.reshape(2 * N_tx, N_batch, order='F') - # ---------------------------------------------------------------- + # [2*N_tx, N_batch * num_symbols] + xhat_ = result_casted.reshape(2 * N_tx, banshee_batch_size, order='F') + + # TO Complex: [ batch_size(N_batch * num_symbols), N_tx] + xhat_ = cast_reshape_to_2D_complex( xhat_.astype(gm_dtype_real),dtype=gm_dtype_cplx ) + + b_hat_, xhat_ = ofdm_mimo.demap_banshee(xhat_, no_eff, batch_size=N_batch, num_symbols=num_symbols) + tbe, num_bits, ber = ofdm_mimo.count_total_bit_error(b,b_hat_) + + vTBE[iPrec+1,iSNR] += tbe + vTB[iPrec+1,iSNR] += num_bits # ---------------------------------------------------------------- - # BER COMPUTATION - # ---------------------------------------------------------------- - for iBatch in range(0, N_batch): - - # Compute bit-encodings for x - x = (vx64[:, iBatch]).reshape(-1, 2) - bit_x = const.encode_symbol(x) - vVM[iSNR] += np.linalg.norm(x)**2 - - # Compute bit-encodings for xhat - for iPrec in range(0, len(precisions)): - xhat = (vxhat[iPrec, :, iBatch]).reshape(-1, 2) - bit_xhat = const.encode_symbol(xhat) - # Compute BER between x and each xhat in vvxhat - vTBE[iPrec][iSNR] += sum(a_bit != b_bit - for a_str, b_str in zip(bit_x, bit_xhat) - for a_bit, b_bit in zip(a_str, b_str)) - # Compute MSE between x and each xhat in vvxhat - vMSE[iPrec][iSNR] += np.linalg.norm(xhat - x)**2 + # 
Save Results (for one batch) # ---------------------------------------------------------------- + np.save(vTBE_file, vTBE) + np.save(vTB_file, vTB) + + # ---------------------------------------------------------------- + min_TBE = np.min(vTBE[:,iSNR]) + if min_TBE>=be_target and iMC+1>=vITR[iSNR]: + # print(f'target TBE reached: SNR:{SNRdB} iMC:{iMC} min_TBE:{min_TBE}') + N_itr = iMC+1 + break + if iMC >= max_it: + break # ----------------------------------------------------------- # END BATCH LOOP # ----------------------------------------------------------- @@ -493,32 +499,41 @@ def main(): time.gmtime(time.time() - startime)) checkpoint_print = elapstime + \ " SNR={}dB BER@{}itr= ".format(SNRdB, N_itr) - total_bits = (N_tx * const.num_bits * N_itr) + for iPrec in range(len(vTBE)): - checkpoint_print += "{:.4f}, ".format(vTBE[iPrec][iSNR] / total_bits) - vBER[iPrec][iSNR] = vTBE[iPrec][iSNR] / total_bits - vEVM[iPrec][iSNR] = np.sqrt(vMSE[iPrec][iSNR] / vVM[iSNR]) + vBER[iPrec][iSNR] = vTBE[iPrec][iSNR] / vTB[iPrec][iSNR] + checkpoint_print += "{:.7f}, ".format(vBER[iPrec][iSNR] ) + + print(checkpoint_print) # ----------------------------------------------------------- # END SNR LOOP # ----------------------------------------------------------- + + + + + # Store output in file current_local_time = time.localtime() timestr = time.strftime("%Y%m%d_%H%M%S", current_local_time) col_names = [precision[0] for precision in precisions] + label = f'{channel_type}_Tx{N_tx}_Rx{N_rx}_{constellation}' row_names = [f"{value} dB" for value in vSNRdB] df_ber = pd.DataFrame(np.transpose(vBER), columns=col_names, index=row_names) - df_evm = pd.DataFrame(np.transpose(vEVM), columns=col_names, index=row_names) - df_ber.to_excel(f"BER_{timestr}.ods", index=True, header=True, engine='odf') - df_evm.to_excel(f"EVM_{timestr}.ods", index=True, header=True, engine='odf') + df_ber.to_excel(f"BER_{label}.ods", index=True, header=True, engine='odf') + + # Plot output - plot_result(vBER, vEVM, vSNRdB, precisions) - const.plot_constellation() + plot_result(vBER, vEVM, vSNRdB, precisions, name_tag) + if __name__ == "__main__": main() + + diff --git a/software/PHY_emulator/ene2end_mimo.py b/software/PHY_emulator/ene2end_mimo.py new file mode 100644 index 000000000..575f4fddc --- /dev/null +++ b/software/PHY_emulator/ene2end_mimo.py @@ -0,0 +1,578 @@ + + +# Simulates end to end MIMO transmission. 
+# Author: Mahdi Abdollahpour, University of Bologna + +import os +gpu_num = "" # Use "" to use the CPU + +os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_num}" + + +# Import Sionna +try: + import sionna +except: + import sionna + + + +import tensorflow as tf + +# Avoid warnings from TensorFlow +tf.get_logger().setLevel('ERROR') + + + +import matplotlib.pyplot as plt +import numpy as np +import time +import pickle + +from sionna.mimo import StreamManagement + +from sionna.ofdm import ResourceGrid, ResourceGridMapper, LSChannelEstimator, LMMSEEqualizer +from sionna.ofdm import OFDMModulator, OFDMDemodulator, ZFPrecoder, RemoveNulledSubcarriers + +from sionna.channel.tr38901 import Antenna, AntennaArray, CDL, UMi, UMa, RMa +from sionna.channel import gen_single_sector_topology as gen_topology +from sionna.channel import subcarrier_frequencies, cir_to_ofdm_channel, cir_to_time_channel +from sionna.channel import ApplyOFDMChannel, ApplyTimeChannel, OFDMChannel + +from sionna.fec.ldpc.encoding import LDPC5GEncoder +from sionna.fec.ldpc.decoding import LDPC5GDecoder + +from sionna.mapping import Mapper, Demapper + +from sionna.utils import BinarySource, ebnodb2no, sim_ber, QAMSource, flatten_last_dims, count_errors +from sionna.utils.metrics import compute_ber + +from sionna.channel import FlatFadingChannel, AWGN + +import equalization as eqz + +import numpy as np + + + + + + +class OFDMMIMO(tf.keras.Model): + """Simulate OFDM MIMO transmissions over a 3GPP 38.901 model. + """ + def __init__(self, scenario='umi', perfect_csi=False, carrier_frequency=3.5e9,fft_size=128, + num_ofdm_symbols=14,subcarrier_spacing=30e3,cyclic_prefix_length=20,num_ut=4,num_ut_ant=1, + num_bs_ant=8, num_bits_per_symbol=4,coderate=1, dtype_bits=32): + + + + + + if dtype_bits==32: + self.m_dtype = tf.complex64 + self.m_dtype_real = tf.float32 + elif dtype_bits==64: + self.m_dtype = tf.complex128 + self.m_dtype_real = tf.float64 + else: + raise ValueError("[OFDMMIMO] Unsupported number of bits.") + super().__init__(dtype = self.m_dtype) + + + + self._scenario = scenario + self._perfect_csi = perfect_csi + + # Internally set parameters + self._carrier_frequency = carrier_frequency + self._fft_size = fft_size + self._subcarrier_spacing = subcarrier_spacing + self._num_ofdm_symbols = num_ofdm_symbols + self._cyclic_prefix_length = cyclic_prefix_length + self._pilot_ofdm_symbol_indices = [2, 11] + self._num_bs_ant = num_bs_ant + self._num_ut = num_ut + self._num_ut_ant = num_ut_ant + self._num_bits_per_symbol = num_bits_per_symbol + self._coderate = coderate + + + + + + # Create an RX-TX association matrix + # rx_tx_association[i,j]=1 means that receiver i gets at least one stream + # from transmitter j. Depending on the transmission direction (uplink or downlink), + # the role of UT and BS can change. 
+ bs_ut_association = np.zeros([1, self._num_ut]) + bs_ut_association[0, :] = 1 + self._rx_tx_association = bs_ut_association + self._num_tx = self._num_ut + self._num_streams_per_tx = self._num_ut_ant + + + # Setup an OFDM Resource Grid + self._rg = ResourceGrid(num_ofdm_symbols=self._num_ofdm_symbols, + fft_size=self._fft_size, + subcarrier_spacing=self._subcarrier_spacing, + num_tx=self._num_tx, + num_streams_per_tx=self._num_streams_per_tx, + cyclic_prefix_length=self._cyclic_prefix_length, + pilot_pattern="kronecker", + pilot_ofdm_symbol_indices=self._pilot_ofdm_symbol_indices, + dtype=self.m_dtype) + + + + # Precompute indices to extract data symbols + mask = self._rg.pilot_pattern.mask + num_data_symbols = self._rg.pilot_pattern.num_data_symbols + data_ind = tf.argsort(flatten_last_dims(mask), direction="ASCENDING") + self._data_ind = data_ind[...,:num_data_symbols] + + + # Setup StreamManagement + self._sm = StreamManagement(self._rx_tx_association, self._num_streams_per_tx) + + # Configure antenna arrays + self._ut_array = AntennaArray( + num_rows=1, + num_cols=1, + polarization="single", + polarization_type="V", + antenna_pattern="omni", + carrier_frequency=self._carrier_frequency, + dtype=self.m_dtype) + + self._bs_array = AntennaArray(num_rows=1, + num_cols=int(self._num_bs_ant/2), + polarization="dual", + polarization_type="cross", + antenna_pattern="38.901", + carrier_frequency=self._carrier_frequency, + dtype=self.m_dtype) + + # Configure the channel model + self._simple_mimo = False + if self._scenario == 'awgn': + self._channel_model = AWGN(dtype=self.m_dtype) + self._simple_mimo = True + elif self._scenario == 'flatfading': # transmissions over an i.i.d. Rayleigh fading channel. + spatial_corr = None + self._channel_model = FlatFadingChannel(self._num_ut, + self._num_bs_ant, + spatial_corr=spatial_corr, + add_awgn=True, + return_channel=True, + dtype=self.m_dtype) + self._simple_mimo = True + elif self._scenario == "umi": + self._channel_model = UMi(carrier_frequency=self._carrier_frequency, + o2i_model="low", + ut_array=self._ut_array, + bs_array=self._bs_array, + direction="uplink", + enable_pathloss=False, + enable_shadow_fading=False, + dtype=self.m_dtype) + elif self._scenario == "uma": + self._channel_model = UMa(carrier_frequency=self._carrier_frequency, + o2i_model="low", + ut_array=self._ut_array, + bs_array=self._bs_array, + direction="uplink", + enable_pathloss=False, + enable_shadow_fading=False, + dtype=self.m_dtype) + elif self._scenario == "rma": + self._channel_model = RMa(carrier_frequency=self._carrier_frequency, + ut_array=self._ut_array, + bs_array=self._bs_array, + direction="uplink", + enable_pathloss=False, + enable_shadow_fading=False, + dtype=self.m_dtype) + else: + raise ValueError("[OFDMMIMO] Unsupported scenario type.") + # Instantiate other building blocks + self._binary_source = BinarySource() + self._qam_source = QAMSource(self._num_bits_per_symbol,dtype=self.m_dtype) + + + if self._coderate==1 and not self._simple_mimo: + self._k = int(self._rg.num_data_symbols*self._num_bits_per_symbol) + self._num_symbols = int(self._k/self._num_bits_per_symbol) + self._hard_out = True + elif self._coderate<1 and not self._simple_mimo: + self._n = int(self._rg.num_data_symbols*self._num_bits_per_symbol) # Number of coded bits + self._k = int(self._n*self._coderate) # Number of information bits + self._encoder = LDPC5GEncoder(self._k, self._n) + self._decoder = LDPC5GDecoder(self._encoder) + self._hard_out = False + self._num_symbols = 
int(self._n/self._num_bits_per_symbol) + + if self._coderate==1 and self._simple_mimo: + self._k = 192 + self._hard_out = True + self._num_symbols = int(self._k/self._num_bits_per_symbol) + elif self._coderate<1 and self._simple_mimo: + self._n = 192 # Number of coded bits + self._k = int(self._n*self._coderate) #i=14, k=553 # Number of information bits + # self._k = 553 #i=14, k=553 # Number of information bits + # print(f'coderate:{self._coderate}, {self._k/self._n }') + self._num_symbols = int(self._n/self._num_bits_per_symbol) + self._encoder = LDPC5GEncoder(self._k, self._n) + self._decoder = LDPC5GDecoder(self._encoder) + self._hard_out = False + + + + + self._mapper = Mapper("qam", self._num_bits_per_symbol,dtype=self.m_dtype) + self._rg_mapper = ResourceGridMapper(self._rg, dtype=self.m_dtype) + + self._ofdm_channel = OFDMChannel(self._channel_model, self._rg, add_awgn=True, + normalize_channel=True, return_channel=True, dtype=self._dtype) + + self._remove_nulled_subcarriers = RemoveNulledSubcarriers(self._rg, dtype=self.m_dtype) + self._ls_est = LSChannelEstimator(self._rg, interpolation_type="nn", dtype=self.m_dtype) + self._lmmse_equ = eqz.LMMSEEqualizer(self._rg, self._sm, whiten_interference=False, dtype=self.m_dtype) + self._demapper = Demapper("app", "qam", self._num_bits_per_symbol,hard_out=self._hard_out, dtype=self.m_dtype) + # self._demapper = Demapper("maxlog", "qam", self._num_bits_per_symbol,hard_out=self._hard_out, dtype=self.m_dtype) + + def get_num_data_symbols(self): + return self._rg.num_data_symbols.numpy() + + def new_topology(self, batch_size, min_ut_velocity=0.0, max_ut_velocity=0.0): + """Set new topology""" + topology = gen_topology(batch_size, + self._num_ut, + self._scenario, + min_ut_velocity=min_ut_velocity, + max_ut_velocity=max_ut_velocity, + dtype = self.m_dtype) + + self._channel_model.set_topology(*topology) + + + + def _extract_data_symbols(self, x): + + # x: [batch_size, num_tx, num_streams_per_ut, num_ofdm_symbols, fft_size] + # to [batch_size, num_tx, num_streams_per_ut, num_ofdm_symbols, num_effective_subcarriers] + x = self._remove_nulled_subcarriers(x) + + # to [batch_size, num_tx, num_streams_per_ut, num_ofdm_symbols * num_effective_subcarriers] + shape = x.shape + x = tf.reshape( x, (-1, shape[1], shape[2], shape[3]*shape[4]) ) + + # [num_tx, num_streams, num_ofdm_symbols*num_effective_subcarriers,... 
+ # ..., batch_size] + x = tf.transpose(x,perm=[1,2,3,0]) + + # Gather data symbols + # [num_tx, num_streams, num_data_symbols, batch_size] + x = tf.gather(x, self._data_ind, batch_dims=2, axis=2) + + # Put batch_dim first + # [batch_size, num_tx, num_streams, num_data_symbols] + x = tf.transpose(x, [3, 0, 1, 2]) + return x + + def demap_banshee(self, x_hat, no_eff, batch_size,num_symbols): + + # cast to self.m_dtype + x_hat = tf.cast(x_hat,self.m_dtype) + + + if self._simple_mimo: + # x_hat [batch_size* num_symbols(k/M), num_ut] + # to [batch_size, num_symbols(k/M), num_ut] + # self._num_symbols + x_hat = tf.reshape(x_hat,(batch_size,num_symbols,self._num_ut) ) + no_eff = tf.reshape(no_eff,(batch_size,num_symbols,self._num_ut) ) + + + # to [batch_size, num_ut,num_symbols(k/M)] + x_hat = tf.transpose(x_hat,perm=[0,2,1]) + no_eff = tf.transpose(no_eff,perm=[0,2,1]) + + + if self._coderate==1: + b_hat = self._demapper([x_hat, no_eff]) + else: + llr = self._demapper([x_hat, no_eff]) + b_hat = self._decoder(llr) + + # b_hat: [batch_size, num_ut,num_bits(k)] + + else: + + # no_eff [batch_size, num_tx, num_streams, num_data_symbols] + # x_hat [ batch_size(N_batch * num_symbols), N_tx] + # x_hat complex numpy: [batch_size*num_ofdm_sym*fft_size, num_tx*num_streams] + + + + # reshape to [batch_size,num_ofdm_sym,fft_size, num_tx,num_streams] + x_hat = tf.reshape( x_hat,(batch_size, self._num_ofdm_symbols, self._fft_size, self._num_tx,self._num_streams_per_tx) ) + + # reshape to + # [batch_size, num_tx, num_streams_per_ut, num_ofdm_symbols, fft_size] + x_hat = tf.transpose( x_hat,perm=[0,3,4,1,2] ) + + # Extract data symbols + # [batch_size, num_tx, num_streams, num_data_symbols] + x_hat = self._extract_data_symbols(x_hat) + + if self._coderate==1: + b_hat = self._demapper([x_hat, no_eff]) + else: + llr = self._demapper([x_hat, no_eff]) + b_hat = self._decoder(llr) + + # x_hat: [batch_size, num_tx, num_streams, num_data_symbols] + # reshape to [batch_size, num_data_symbols, num_tx, num_streams] + x_hat = tf.transpose(x_hat,perm=[0,3,1,2]) + + # reshape to + # [batch_size* num_data_symbols, num_tx* num_streams] + shape = x_hat.shape + x_hat = tf.reshape(x_hat,(-1,shape[2]*shape[3])) + + # print(f'[demapp bansshe]b_hat:{b_hat.shape}') + + return b_hat, x_hat + def count_total_bit_error(self, b,b_hat): + # print(f'[count err]b:{b.shape},b_hat:{b_hat.shape}') + b = tf.cast(b,tf.float32) + b_hat = tf.cast(b_hat,tf.float32) + TBE = count_errors(b,b_hat) + num_bits = b.shape.num_elements() + ber = compute_ber(b,b_hat) + + + return TBE.numpy(), num_bits, ber.numpy() + + + # @tf.function # Run in graph mode. 
See the following guide: https://www.tensorflow.org/guide/function + def call(self, batch_size, ebno_db, min_ut_velocity=0.0, max_ut_velocity=0.0): + + + # print("GPUs available:", len(tf.config.list_physical_devices('GPU'))) + + + if self._simple_mimo: + + # b[batch_size, num_tx, num_bits] + b = self._binary_source([batch_size, self._num_tx, self._k]) + if self._coderate==1: + c = b + else: + c = self._encoder(b) + x = self._mapper(c) + num_symbols = tf.shape(x)[-1] + # self._num_symbols = tf.shape(x)[-1] + # self._num_symbols = tf.keras.backend.get_value(tf.shape(x)[-1]) + + # x: [batch_size,num_tx,num_symbols] + # to [batch_size,num_symbols,num_tx] + x = tf.transpose(x,perm=[0,2,1]) + shape = tf.shape(x) + + # to [batch_size*num_symbols,num_tx] + x = tf.reshape(x, [-1, self._num_tx]) + no = ebnodb2no(ebno_db, self._num_bits_per_symbol, self._coderate) + + if self._scenario != 'awgn': + no *= np.sqrt(self._num_bs_ant) + no = tf.cast(no,self.m_dtype_real) + y_dt, h_dt_desired = self._channel_model([x, no]) + else: + no = tf.cast(no,self.m_dtype_real) + y_dt = self._channel_model((x, no)) + h = tf.eye(self._num_ut, dtype=self.m_dtype_real) + h = tf.tile(tf.expand_dims(h, axis=0), [batch_size*num_symbols, 1, 1]) + h = tf.complex(h, tf.zeros_like(h),self.m_dtype) + h_dt_desired = h + + # y: [batch_size*num_symbols, num_rx_ant] + # H: [batch_size*num_symbols, num_rx_ant, num_tx_ant] + + s = tf.eye(self._num_bs_ant, self._num_bs_ant) + s = tf.cast(s,self.m_dtype_real) + s = no*s + s = tf.complex(s, tf.cast(0.0,self.m_dtype_real)) + + x_hat, no_eff = eqz.lmmse_equalizer(y_dt, h_dt_desired, s, whiten_interference=False) + + # [batch_size*num_symbols,num_tx] + # to [batch_size,num_symbols,num_tx] + x_hat = tf.reshape(x_hat, shape) + x_data = tf.reshape(x, shape) + no_eff = tf.reshape(no_eff, shape) + + + # to [batch_size,num_tx,num_symbols] + x_hat = tf.transpose(x_hat,perm=[0,2,1]) + no_eff = tf.transpose(no_eff,perm=[0,2,1]) + + if self._coderate==1: + b_hat = self._demapper([x_hat, no_eff]) + else: + llr = self._demapper([x_hat, no_eff]) + b_hat = self._decoder(llr) + + + # b [batch_size, num_tx, num_bits] + # b_hat: [batch_size,num_tx,num_bits] + # x_hat [batch_size,num_tx,num_symbols] + + + + + # Reshape for Banshee + + + # no_eff = tf.transpose(no_eff,perm=[0,2,1]) + # no_eff = tf.reshape(no_eff,(-1,self._num_ut) ) + + # # x_hat:[batch_size,num_tx,num_symbols] + # # to [batch_size, num_symbols(k/M), num_ut] + x_hat = tf.transpose(x_hat,perm=[0,2,1]) + # to [batch_size* num_symbols(k/M), num_ut] + x_hat = tf.reshape(x_hat,(-1,self._num_ut) ) + + x_data = tf.transpose(x_data,perm=[0,2,1]) + x_data = tf.reshape(x_data,(-1,self._num_ut) ) + + # H: [batch_size*num_symbols, num_rx_ant, num_tx_ant] + # to [batch_size*num_symbols, num_rx_ant* num_tx_ant] + h_dt_desired = tf.reshape(h_dt_desired,(-1,self._num_bs_ant*self._num_ut)) + + # s: [num_rx_ant] + s = tf.linalg.diag_part(s) + # to: [batch_size*num_symbols, num_rx_ant] + # dim0 = int(batch_size*self._k/self._num_bits_per_symbol) + dim0 = batch_size*num_symbols + s = tf.repeat(tf.expand_dims(s, axis=0),dim0,axis=0) + + + + # y: [batch_size*num_symbols, num_rx_ant] + # H: [batch_size*num_symbols, num_rx_ant* num_tx_ant] + # s: [batch_size*num_symbols, num_rx_ant] + # x_hat: [batch_size* num_symbols(k/M), num_ut] + + else: + + self.new_topology(batch_size,min_ut_velocity=min_ut_velocity, max_ut_velocity=max_ut_velocity) + no = ebnodb2no(ebno_db, self._num_bits_per_symbol, self._coderate, self._rg) + + # b: [batch_size, num_tx, num_stream_per_tx, 
k] + b = self._binary_source([batch_size, self._num_tx, self._num_streams_per_tx, self._k]) + if self._coderate==1: + c = b + else: + c = self._encoder(b) + + # x:[batch_size, num_tx, num_stream_per_tx, k/M] + x = self._mapper(c) + # self._num_symbols = tf.shape(x)[-1] + num_symbols = tf.shape(x)[-1] + + # x_rg: [batch_size, num_tx, num_streams_per_ut, num_ofdm_symbols, fft_size] + x_rg = self._rg_mapper(x) + + + y, h = self._ofdm_channel([x_rg, no]) + if self._perfect_csi: + h_hat = self._remove_nulled_subcarriers(h) + err_var = 0.0 + else: + h_hat, err_var = self._ls_est ([y, no]) + + + x_hat, no_eff, y_dt, h_dt_desired, s = self._lmmse_equ([y, h_hat, err_var, no]) + + # y_dt [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers, num_rx_ant] + + # x_hat [batch_size, num_tx, num_streams, num_data_symbols] + + # h_dt_desired [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers,.. + # ..., num_rx_ant, num_streams_per_rx(num_Interfering_streams_per_rx)] + + # s: [batch_size, num_rx, num_ofdm_syms, fft_size, num_rx_ant, num_rx_ant] + + # print(f'x_hat:{x_hat.shape}, no_eff:{no_eff.shape}, s:{s.shape}') + + if self._coderate==1: + b_hat = self._demapper([x_hat, no_eff]) + else: + llr = self._demapper([x_hat, no_eff]) + b_hat = self._decoder(llr) + + # b_hat: [batch_size, num_tx, num_streams, k=num_data_symbols*M] + # b [batch_size, num_tx, num_stream_per_tx, k] + ber = compute_ber(b,tf.cast(b_hat,tf.float32)) + + # num_rx>1 not considered (TODO) + # --------- reshape for banshee ----------- + + + # [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers, num_rx_ant] + # TO [batch_size* num_rx* num_ofdm_symbols* num_effective_subcarriers, num_rx_ant] + y_dt = tf.reshape(y_dt,(-1,self._num_bs_ant)) + + + + # TO [batch_size* num_rx* num_ofdm_symbols* num_effective_subcarriers,.. + # ..., num_rx_ant, num_streams_per_rx(num_Interfering_streams_per_rx)] + h_shape = tf.shape(h_dt_desired) + h_dt_desired = tf.reshape(h_dt_desired,(-1,h_shape[4],h_shape[5])) + + + #TO [batch_size* num_rx* num_ofdm_symbols* num_effective_subcarriers,.. + # ..., num_rx_ant* num_streams_per_rx(num_Interfering_streams_per_rx)] + h_dt_desired = tf.reshape(h_dt_desired,(-1,h_shape[4]*h_shape[5])) + + # s: [batch_size, num_rx, num_ofdm_syms, fft_size, num_rx_ant, num_rx_ant] + # to [batch_size, num_rx, num_ofdm_syms, fft_size, num_rx_ant] + s = tf.linalg.diag_part(s) + # to [batch_size* num_rx* num_ofdm_syms* fft_size, num_rx_ant] + s = tf.reshape(s, (-1,self._num_bs_ant)) + + + + # x: extract data symbols + # [batch_size, num_tx, num_streams, num_data_symbols] + x_data = self._extract_data_symbols(x=x_rg) + + # x_hat: [batch_size, num_tx, num_streams, num_data_symbols] + # reshape to [batch_size, num_data_symbols, num_tx, num_streams] + x_hat = tf.transpose(x_hat,perm=[0,3,1,2]) + x_data = tf.transpose(x_data,perm=[0,3,1,2]) + + + + # reshape to + # [batch_size* num_data_symbols, num_tx* num_streams] + shape = x_hat.shape + x_hat = tf.reshape(x_hat,(-1,shape[2]*shape[3])) + x_data = tf.reshape(x_data,(-1,shape[2]*shape[3])) + + + + + # ----------- End Simple MIMO or OFDM ----------- + no = tf.cast(no,self.m_dtype_real) + no_db = 10*tf.math.log(no)/tf.math.log(tf.constant(10.0,self.m_dtype_real)) + + return b, b_hat, x_data, x_hat, y_dt, h_dt_desired, s, no_eff, num_symbols, no_db + + + + + +
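+
+# Minimal usage sketch (illustration only; the parameter values below are assumptions
+# mirroring the defaults used by MMSE_BER.py, not a fixed configuration): it builds the
+# end-to-end model for an uplink UMi scenario with 4 single-antenna UTs, 8 BS antennas
+# and uncoded 16QAM, runs one batch, and reports the resulting bit error rate.
+if __name__ == "__main__":
+    ofdm_mimo = OFDMMIMO(scenario='umi', perfect_csi=True,
+                         carrier_frequency=3.5e9, fft_size=128,
+                         num_ofdm_symbols=14, subcarrier_spacing=30e3,
+                         cyclic_prefix_length=20, num_ut=4, num_ut_ant=1,
+                         num_bs_ant=8, num_bits_per_symbol=4, coderate=1,
+                         dtype_bits=64)
+    # One Monte-Carlo batch at a fixed Eb/No and UT speed (8.3 m/s, roughly 30 km/h),
+    # called the same way MMSE_BER.py invokes the model.
+    b, b_hat, x, x_hat, y, H, s, no_eff, num_symbols, no_db = \
+        ofdm_mimo(batch_size=32, ebno_db=10.0,
+                  min_ut_velocity=8.3, max_ut_velocity=8.3)
+    tbe, num_bits, ber = ofdm_mimo.count_total_bit_error(b, b_hat)
+    print(f"Eb/No = 10 dB: {tbe}/{num_bits} bit errors, BER = {ber:.6f}")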