diff --git a/src/rheed_learn/AFM.py b/src/rheed_learn/AFM.py new file mode 100644 index 0000000..c8ff0f7 --- /dev/null +++ b/src/rheed_learn/AFM.py @@ -0,0 +1,382 @@ +import numpy as np +import imutils +from matplotlib import (pyplot as plt, animation, colors, ticker, path, patches, patheffects) +import plotly.graph_objects as go +import pylab as pl +import scipy +from scipy import special +from scipy import signal +from scipy.signal import savgol_filter + +class afm_substrate(): + """ + This class is designed to facilitate the analysis of an atomic force microscopy (AFM) substrate image.  + The class includes methods for image rotation, coordinate transformation, peak detection, and step parameter calculation. + """ + def __init__(self, img, pixels, size): + ''' + img: the image to be analyzed + pixels: the number of pixels in the image + size: the size of the image in meters + ''' + self.img = img + self.pixels = pixels + self.size = size + + def rotate_image(self, angle, colorbar_range=None, demo=True): + ''' + angle: the angle to rotate the image in degrees + ''' + rad = np.radians(angle) + scale = 1/(np.abs(np.sin(rad)) + np.abs(np.cos(rad))) + size_rot = self.size * scale + + img_rot = imutils.rotate(self.img, angle=angle, scale=scale) + h, w = img_rot.shape[:2] + + if demo: + plt.figure(figsize=(10, 8)) + im = plt.imshow(img_rot) + plt.plot([0, w], [h//4, h//4], color='w') + plt.plot([0, w], [h//2, h//2], color='w') + plt.plot([0, w], [h*3//4, h*3//4], color='w') + if colorbar_range: + im.set_clim(colorbar_range) + plt.colorbar() + plt.show() + return img_rot, size_rot + + + def rotate_xz(self, x, z, xz_angle): + ''' + x: the x coordinates of the image + z: the z coordinates of the image + xz_angle: the angle to rotate the xz plane + ''' + theta = np.radians(xz_angle) + x_rot = x * np.cos(theta) - z * np.sin(theta) + z_rot = x * np.sin(theta) + z * np.cos(theta) + return x_rot, z_rot + + def show_peaks(self, x, z, peaks=None, valleys=None): + ''' + x: the x-axis data + z: the z-axis data - height + peaks: the indices of the peaks + valleys: the indices of the valleys + ''' + fig = go.Figure() + fig.add_trace(go.Scatter(x=x, y=z, mode='lines+markers', name='Original Plot')) + if isinstance(peaks, np.ndarray): + marker=dict(size=8, color='red', symbol='cross') + fig.add_trace(go.Scatter(x=x[peaks], y=z[peaks], mode='markers', marker=marker, name='Detected Peaks')) + if isinstance(valleys, np.ndarray): + marker=dict(size=8, color='black', symbol='cross') + fig.add_trace(go.Scatter(x=x[valleys], y=z[valleys], mode='markers', marker=marker, name='Detected valleys')) + fig.show() + + def slice_rotate(self, img_rot, size, j, prominence, width, xz_angle=0, demo=False): + ''' + img_rot: the rotated image + size: the size of the image in meters + j: the column to slice + xz_angle: the angle between the x and z axes in degrees + ''' + i = np.linspace(0, self.pixels-1, self.pixels) + x = i / self.pixels * size + z = img_rot[np.argwhere(img_rot[:, j]!=0).flatten(), j] + x = x[np.argwhere(img_rot[:, j]!=0).flatten()] + peak_indices, _ = signal.find_peaks(z, prominence=prominence, width=width) + valley_indices, _ = signal.find_peaks(-z, prominence=prominence, width=width) + + if xz_angle != 0: + x_min, x_max, z_min, z_max = np.min(x), np.max(x), np.min(z), np.max(z) + x_norm = (x - x_min) / (x_max - x_min) + z_norm = (z - z_min) / (z_max - z_min) + + peak_indices, _ = signal.find_peaks(z_norm, prominence=prominence, width=width) + valley_indices, _ = signal.find_peaks(-z_norm, 
prominence=prominence, width=width)
+
+            # rotate the xz plane to level the step
+            x_norm_rot, z_norm_rot = self.rotate_xz(x_norm, z_norm, xz_angle)
+            x, z = x_norm_rot * (x_max - x_min) + x_min, z_norm_rot * (z_max - z_min) + z_min
+
+        if demo:
+            self.show_peaks(x, z, peak_indices, valley_indices)
+        return x, z, peak_indices, valley_indices
+
+
+    def calculate_simple(self, x, z, peak_indices, fixed_height=None, demo=False):
+        '''
+        Calculate the height, width, and miscut of the steps in a straightforward way:
+        the height and width of each step are taken directly from the rotated line profile.
+        x: the x-axis data
+        z: the z-axis data - height
+        peak_indices: the indices of the peaks
+        fixed_height: the height of the steps
+        '''
+
+        # find the level of z and step height and width
+        step_widths = np.diff(x[peak_indices])
+        if fixed_height:
+            step_heights = np.full(len(step_widths), fixed_height)
+        else:
+            step_heights = z[peak_indices[1:]] - z[peak_indices[:-1]]
+        miscut = np.degrees(np.arctan(step_heights/step_widths))
+
+        if demo:
+            for i in range(len(step_heights)):
+                print(f"Step {i+1}: Height = {step_heights[i]:.2e}, Width = {step_widths[i]:.2e}, Miscut = {miscut[i]:.3f}°")
+            print('Results:')
+            print(f" Average step height = {np.mean(step_heights):.2e}, Standard deviation = {np.std(step_heights):.2e}")
+            print(f" Average step width = {np.mean(step_widths):.2e}, Standard deviation = {np.std(step_widths):.2e}")
+            print(f" Average miscut = {np.mean(miscut):.3f}°, Standard deviation = {np.std(miscut):.3f}°")
+        return step_heights, step_widths, miscut
+
+    def calculate_fit(self, x, z, peak_indices, valley_indices, fixed_height, demo=False):
+        '''
+        Calculate the step height, width, and miscut angle.
+        The step width is computed as the perpendicular distance from the left peak to the line fitted through the right peak and the valley; the step height is taken as fixed_height.
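+        The perpendicular distance from a point (x0, z0) to the line z = m*x + b is
+        |m*x0 - z0 + b| / sqrt(m**2 + 1); this is the quantity assigned to step_width below.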
+ x: the x-axis data + z: the z-axis data - height + peak_indices: the indices of the peaks + valley_indices: the indices of the valleys + fixed_height: the fixed step height + demo: whether to show the demo plot + ''' + # print(valley_indices) + step_widths = [] + for i, v_ind in enumerate(valley_indices): + x_valley, z_valley = x[v_ind], z[v_ind] + + # ignore if there's no peak on the left + if x_valley < np.min(x[peak_indices]): continue + # if there's no peak on the right, then the valley is the last one + if x_valley > np.max(x[peak_indices]): continue + + # find the nearest peak on the left of the valley v_ind + peaks_lhs = peak_indices[np.where(x[peak_indices] < x_valley)] + left_peak_indice = peaks_lhs[np.argmax(peaks_lhs)] + x_left_peak, z_left_peak = x[left_peak_indice], z[left_peak_indice] + + # find the nearest peak on the right of the valley v_ind + peaks_rhs = peak_indices[np.where(x[peak_indices] > x_valley)] + right_peak_indice = peaks_rhs[np.argmin(peaks_rhs)] + x_right_peak, z_right_peak = x[right_peak_indice], z[right_peak_indice] + + # ignore if can't make a peak, valley, peak sequence + if i!=0 and i!=len(valley_indices)-1: + if x[valley_indices[i-1]] > x_left_peak or x[valley_indices[i+1]] < x_right_peak: + continue + + # fit the linear function between the right peak and the valley + m, b = scipy.stats.linregress(x=[x_right_peak, x_valley], y=[z_right_peak, z_valley])[0:2] + m = (z_right_peak-z_valley)/(x_right_peak-x_valley) + b = z_valley - m*x_valley + + # calculate the euclidean distance between the left peak and fitted linear function + step_width = np.abs((m * x_left_peak - z_left_peak + b)) / (np.sqrt(m**2 + 1)) + step_widths.append(step_width) + + # print left peak, valley, right peak + if demo: + print(f'step {i}: step_width: {step_width:.2e}, left_peak: ({x_left_peak:.2e}, {z_left_peak:.2e}), valley: ({x_valley:.2e}, {z_valley:.2e}), right_peak: ({x_right_peak:.2e}, {z_right_peak:.2e})') + + step_heights = np.full(len(step_widths), fixed_height) + miscut = np.degrees(np.arctan(step_heights/step_widths)) + + if demo: + print('Results:') + print(f" Average step height = {np.mean(step_heights):.2e}, Standard deviation = {np.std(step_heights):.2e}") + print(f" Average step width = {np.mean(step_widths):.2e}, Standard deviation = {np.std(step_widths):.2e}") + print(f" Average miscut = {np.mean(miscut):.3f}°, Standard deviation = {np.std(miscut):.3f}°") + return step_heights, step_widths, miscut + + def clean_data(self, step_heights, step_widths, miscut, std_range=1, demo=False): + ''' + step_heights: the heights of the steps + step_widths: the widths of the steps + miscut: the miscut of the steps + std_range: the range of standard deviation to remove outliers + demo: whether to show the cleaned results + ''' + # remove outliers + miscut = miscut[np.abs(miscut-np.mean(miscut))start) & (xstart) & (x=dist] # avoid first one that is not full curve + x_peaks = x_peaks[x_peaks<=len(curve_y)-dist] + + # get all partial curve + xs, ys = [], [] + for i in range(1, len(x_peaks)): + # xs.append(list(curve_x[5+x_peaks[i-1]:x_peaks[i]])) + # ys.append(list(curve_y[5+x_peaks[i-1]:x_peaks[i]])) + xs.append(list(curve_x[x_peaks[i-1]:x_peaks[i]])) + ys.append(list(curve_y[x_peaks[i-1]:x_peaks[i]])) + return x_peaks/camera_freq, xs, ys + +def remove_outlier(x, y, ub): + + """ + Removes outliers from the given data based on the provided upper bound. + + Args: + x (numpy.array): The x-values of the data. + y (numpy.array): The y-values of the data. 
+ ub (float): The upper bound for z-score filtering. + + Returns: + tuple: A tuple containing the filtered x-values and y-values. + """ + + z = zscore(y, axis=0, ddof=0) + x = np.delete(x, np.where(z>ub)) + y = np.delete(y, np.where(z>ub)) + return x, y + +def smooth(y, box_pts): + """ + Applies a smoothing filter to the given data using a moving average window. + + Args: + y (numpy.array): The input data. + box_pts (int): The size of the moving average window. + + Returns: + numpy.array: The smoothed data. + """ + box = np.ones(box_pts)/box_pts + y_smooth = np.convolve(y, box, mode='same') + return y_smooth + + +def denoise_fft(sample_x, sample_y, cutoff_freq, denoise_order, sample_frequency, viz=False): + + nyquist = 0.5 * sample_frequency + low = cutoff_freq / nyquist + # b, a = butter(denoise_order, low, btype='low') + b, a = butter(denoise_order, low, btype='low') + + # Apply the low-pass filter to denoise the signal + denoised_sample_y = filtfilt(b, a, sample_y) + + # Compute the frequency spectrum of the original and denoised signals + freq = np.fft.rfftfreq(len(sample_x), d=1/sample_frequency) + fft_original = np.abs(np.fft.rfft(sample_y)) + fft_denoised = np.abs(np.fft.rfft(denoised_sample_y)) + + if viz: + fig, axes = plt.subplots(2, 1, figsize=(8, 4)) + axes[0].scatter(sample_x, sample_y, label='Original Signal') + axes[0].plot(sample_x, denoised_sample_y, color='r', label='Denoised Signal') + axes[0].legend() + + axes[1].plot(freq, fft_original, label='Original Spectrum') + axes[1].plot(freq, fft_denoised, label='Denoised Spectrum') + axes[1].set_xlabel('Frequency (Hz)') + axes[1].set_ylabel('Amplitude') + axes[1].set_yscale('log') + axes[1].legend() + plt.tight_layout() + plt.title('fft filter') + plt.show() + return denoised_sample_y + + +def bandpass_filter_fft(sample_x, sample_y, low_cutoff, high_cutoff, sample_frequency, viz=False): + # Compute the FFT of the signal + fft = np.fft.rfft(sample_y) + freq = np.fft.rfftfreq(len(sample_x), d=1/sample_frequency) + + # Create a mask for the band-pass filter + mask = (freq >= low_cutoff) & (freq <= high_cutoff) + + # Apply the mask to the FFT + fft_filtered = fft * mask + + # Compute the inverse FFT to get the filtered signal + filtered_sample_y = np.fft.irfft(fft_filtered, n=len(sample_y)) + + if viz: + fig, axes = plt.subplots(2, 1, figsize=(10, 8)) + axes[0].plot(sample_x, sample_y, label='Original Signal') + axes[0].plot(sample_x, filtered_sample_y, color='r', label='Filtered Signal') + axes[0].set_xlabel('Time') + axes[0].set_ylabel('Amplitude') + axes[0].legend() + + axes[1].plot(freq, np.abs(fft), label='Original Spectrum') + axes[1].plot(freq, np.abs(fft_filtered), label='Filtered Spectrum') + axes[1].set_xlabel('Frequency (Hz)') + axes[1].set_ylabel('Amplitude') + axes[1].set_yscale('log') + axes[1].set_xlim(-2, 20) + axes[1].axvline(low_cutoff, color='g', linestyle='--', label='Low cutoff') + axes[1].axvline(high_cutoff, color='r', linestyle='--', label='High cutoff') + axes[1].legend() + plt.tight_layout() + plt.suptitle('FFT Band-pass Filter') + plt.show() + + return filtered_sample_y + +from scipy.signal import medfilt +def denoise_median(sample_x, sample_y, kernel_size, viz=False): + denoised_sample_y = medfilt(sample_y, kernel_size=kernel_size) + # print(denoised_sample_y.shape, sample_y.shape) + + if viz: + plt.figure(figsize=(8,2)) + plt.scatter(sample_x, sample_y, label='Original Signal') + plt.plot(sample_x, denoised_sample_y, color='r', label='Denoised Signal') + plt.tight_layout() + plt.title('median filter') 
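+        # Overlay of the raw samples (scatter) and the median-filtered signal (red line).
+        # Note: scipy.signal.medfilt expects an odd kernel_size.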
+ plt.show() + + return sample_x, denoised_sample_y + +def process_rheed_data(sample_x, sample_y, camera_freq, denoise_params, viz=False): + + """Processes RHEED data by interpolating, denoising, and applying dimensionality reduction. + + Args: + savgol_window_order (tuple, optional): The order of the Savitzky-Golay filter window. Defaults to (15, 3). + pca_component (int, optional): The number of components for PCA. Defaults to 10. + + Returns: + tuple: A.""" + + savgol_window_order = denoise_params['savgol_window_order'] + pca_component = denoise_params['pca_component'] + fft_cutoff = denoise_params['fft_cutoff'] + fft_order = denoise_params['fft_order'] + median_kernel_size = denoise_params['median_kernel_size'] + + denoised_sample_y = sample_y + + # denoise the data + if not isinstance(savgol_window_order, type(None)): + denoised_sample_y = savgol_filter(sample_y, savgol_window_order[0], savgol_window_order[1]) + if viz: + fig, ax = plt.subplots(1, 1, figsize=(8, 2)) + ax.scatter(sample_x, sample_y, label='Original Signal') + ax.plot(sample_x, denoised_sample_y, color='r', label='Denoised Signal') + plt.legend() + plt.tight_layout() + plt.title('savgol_filter') + plt.show() + sample_y = denoised_sample_y + + # # apply PCA + # if pca_component: + # sample_y = denoised_sample_y + # pca = PCA(n_components=pca_component) + # denoised_sample_y = pca.inverse_transform(pca.fit_transform(sample_y)) + # if viz: + # fig, ax = plt.subplots(1, 1, figsize=(8, 4)) + # ax.scatter(sample_x, sample_y, label='Original Signal') + # ax.plot(sample_x, denoised_sample_y, color='r', label='Denoised Signal') + # plt.legend() + # plt.tight_layout() + # plt.show() + + # fft + if not isinstance(fft_cutoff, type(None)) or not isinstance(fft_order, type(None)): + sample_y = denoised_sample_y + # denoised_sample_y = denoise_fft(sample_x, sample_y, cutoff_freq=fft_cutoff[1], denoise_order=fft_order, + # sample_frequency=camera_freq, viz=viz) + denoised_sample_y = bandpass_filter_fft(sample_x, sample_y, low_cutoff=fft_cutoff[0], + high_cutoff=fft_cutoff[1], sample_frequency=camera_freq, viz=viz) + sample_y = denoised_sample_y + + # median filter + if not isinstance(median_kernel_size, type(None)): + sample_x, sample_y = denoise_median(sample_x, sample_y, kernel_size=median_kernel_size, viz=viz) + sample_y = denoised_sample_y + + return sample_x, sample_y + + +def reset_tails(ys, ratio=0.1): + for i, y in enumerate(ys): + num = int(len(y) * ratio) + y[-num:] = y[-2*num:-num] + ys[i] = y + return ys + +def linear_func(x, a, b): + return a*x + b + +from scipy.optimize import curve_fit +def remove_linear_bg(xs, ys, linear_ratio=0.8): + ''' + assume there is a background intensity change linearly with time, so extract the linear background from full range curve + ''' + for i in range(len(ys)): + length = int(len(ys[i]) * linear_ratio) + popt, pcov = curve_fit(linear_func, xs[i][-length:], ys[i][-length:]) + a, b = popt + y_fit = linear_func(np.array(xs[i]), a, b=0) + ys[i] = ys[i] - y_fit + return xs, ys + +# def find_sign_change(values, change='increase_to_decrease'): +# ''' +# change: 'increase_to_decrease' or 'decrease_to_increase' +# ''' +# for i in range(len(values) - 2): +# if change == 'increase_to_decrease': +# # print(values[i], values[i + 1], values[i + 2]) +# if values[i] < values[i + 1] and values[i + 1] > values[i + 2]: +# return i + 1 # The position where it changes from increase to decrease +# elif change == 'decrease_to_increase': +# if values[i] > values[i + 1] and values[i + 1] < values[i + 2]: +# return 
i + 1 # The position where it changes from decrease to increase +# return -1 # Return -1 if no such transition is found + +# def find_sign_change(values, rule='increase_to_decrease'): +# ''' +# change: 'increase_to_decrease' or 'decrease_to_increase' +# ''' +# for i in range(len(values) - 2): +# if rule == 'increase_to_decrease': +# # print(values[i], values[i + 1], values[i + 2]) +# if values[i] < values[i + 1] and values[i + 1] > values[i + 2]: +# return i + 1 # The position where it changes from increase to decrease +# elif rule == 'decrease_to_increase': +# if values[i] > values[i + 1] and values[i + 1] < values[i + 2]: +# return i + 1 # The position where it changes from decrease to increase +# return -1 # Return -1 if no such transition is found + +# def remove_starting_signal(xs, ys): +# xs_new, ys_new = [], [] +# for x_target, y_target in zip(xs, ys): +# x_pos = find_sign_change(y_target) +# xs.append(x_target[x_pos:]) +# ys.append(y_target[x_pos:]) +# return xs, ys + +def find_sign_change(values, rule='increase_to_decrease', window=20, threshold=16): + + for i in range(window, len(values)-window): + increase_count, decrease_count = 0, 0 + if rule == 'increase_to_decrease': + for j in range(i-window, i): + # print(j, j+1) + if values[j] < values[j+1]: + increase_count += 1 + for j in range(i, i+window): + if values[j] > values[j+1]: + decrease_count += 1 + # print(increase_count, decrease_count) + if increase_count >= threshold and decrease_count >= threshold: + return i + elif rule == 'decrease_to_increase': + for j in range(i-window, i): + if values[j] > values[j+1]: + decrease_count += 1 + for j in range(i, i+window): + if values[j] < values[j+1]: + increase_count += 1 + if increase_count >= threshold and decrease_count >= threshold: + return i + else: + # print('Cannot find trend change point.') + return -1 + +def process_curves(xs, ys, curve_params): + + tune_tail = curve_params['tune_tail'] + trim_first = curve_params['trim_first'] + linear_ratio = curve_params['linear_ratio'] + + # trim tails + if tune_tail: + ys = reset_tails(ys) + if trim_first != 0: + xs_trimed, ys_trimed = [], [] + for x, y in zip(xs, ys): + if isinstance(trim_first, str): + pos = find_sign_change(y, trim_first) + elif isinstance(trim_first, int): + pos = trim_first + ys_trimed.append(y[pos:]) + xs_trimed.append(np.linspace(x[0], x[-1], len(y[pos:]))) + xs, ys = xs_trimed, ys_trimed + + # remove linear background + if linear_ratio != 0 or linear_ratio != None: + xs, ys = remove_linear_bg(xs, ys, linear_ratio=linear_ratio) + return xs, ys + +# def process_rheed_data(xs, ys, +# camera_freq, savgol_window_order=(15, 3), pca_component=10, +# fft_cutoff=20, fft_order=1, median_kernel_size=51): + +# """Processes RHEED data by interpolating, denoising, and applying dimensionality reduction. + +# Args: +# xs (list): List of x-values for each partial curve. +# ys (list): List of y-values for each partial curve. +# length (int, optional): The desired length for interpolation. Defaults to 500. +# savgol_window_order (tuple, optional): The order of the Savitzky-Golay filter window. Defaults to (15, 3). +# pca_component (int, optional): The number of components for PCA. Defaults to 10. 
+ +# Returns: +# tuple: A.""" + +# # interpolate the data to same size +# if length == None: +# length = int(np.mean([len(x) for x in xs])) + +# xs_processed = [] +# ys_processed = [] +# for x, y in zip(xs, ys): +# x_sl = np.linspace(np.min(x), np.max(x), length) +# y_sl = np.interp(x_sl, x, y) +# xs_processed.append(x_sl) +# ys_processed.append(y_sl) +# xs_processed, ys_processed = np.array(xs_processed), np.array(ys_processed) + +# # denoise the data +# if savgol_window_order: +# ys_processed = savgol_filter(ys_processed, savgol_window_order[0], savgol_window_order[1]) + +# # apply PCA +# if pca_component: +# pca = PCA(n_components=pca_component) +# ys_processed = pca.inverse_transform(pca.fit_transform(ys_processed)) + +# # fft +# if fft_cutoff and fft_order: +# denoised_sample_y = denoise_fft(sample_x, sample_y, cutoff_freq=20, denoise_order=1, sample_frequency=camera_freq) + +# # median filter +# if median_kernel_size: +# denoised_sample_y = denoise_median(sample_x, sample_y, kernel_size=median_kernel_size) + +# # trim tails +# trim_first = fit_settings['trim_first'] +# if fit_settings['tune_tail']: +# ys = reset_tails(ys) +# if trim_first != 0: +# xs_trimed, ys_trimed = [], [] +# for x, y in zip(xs, ys): +# ys_trimed.append(y[trim_first:]) +# xs_trimed.append(np.linspace(x[0], x[-1], len(y[trim_first:]))) +# xs, ys = xs_trimed, ys_trimed + +# # remove linear background +# xs, ys = remove_linear_bg(xs, ys, linear_ratio=0.8) + +# return xs_processed, ys_processed + + +def normalize_0_1(y, I_start, I_end, I_diff=None, unify=True): + """ + Normalizes the given data to the range [0, 1] based on the provided intensity values. + + Args: + y (numpy.array): The input data. + I_start (float): The start intensity value. + I_end (float): The end intensity value. + I_diff (float, optional): The intensity difference used for normalization. Defaults to None. + unify (bool, optional): Whether to unify the normalization range regardless of the intensity order. Defaults to True. + + Returns: + numpy.array: The normalized data. + """ + if not I_diff: + I_diff = I_end-I_start + + # use I/I0, I0 is saturation intensity (last value) and scale to 0-1 based + if I_end - I_start == 0: # avoid devide by 0 + y_nor = (y-I_start) + elif unify: + y_nor = (y-I_start)/I_diff + # if I_end < I_start: + # y_nor = (y-I_start)/I_diff + # else: + # y_nor = (y-I_start)/I_diff + else: + if I_end < I_start: + y_nor = (y-I_end)/I_diff + else: + y_nor = (y-I_start)/I_diff + return y_nor + +def de_normalize_0_1(y_nor_fit, I_start, I_end, I_diff=None, unify=True): + """ + De-normalizes the given normalized data back to the original range based on the provided intensity values. + + Args: + y_nor_fit (numpy.array): The normalized data to be de-normalized. + I_start (float): The start intensity value. + I_end (float): The end intensity value. + I_diff (float, optional): The intensity difference used for normalization. Defaults to None. + unify (bool, optional): Whether to unify the normalization range regardless of the intensity order. Defaults to True. + + Returns: + numpy.array: The de-normalized data. 
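+
+    Example (illustrative round trip; intensity values assumed):
+        y_nor = normalize_0_1(y, I_start=0.2, I_end=1.0)
+        y_back = de_normalize_0_1(y_nor, I_start=0.2, I_end=1.0)
+        # with unify=True and I_diff=None, y_back reproduces y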
+ """ + if not I_diff: + I_diff = I_end-I_start + if not unify: + I_diff = np.abs(I_diff) + + # use I/I0, I0 is saturation intensity (last value) and scale to 0-1 based + if I_end - I_start == 0: # avoid devide by 0 + y_nor = (y_nor_fit-I_start) + elif unify: + if I_end < I_start: + y_fit = I_start-y_nor_fit*I_diff + else: + y_fit = y_nor_fit*I_diff+I_start + + else: + if I_end < I_start: + y_fit = y_nor_fit*I_diff+I_end + else: + y_fit = y_nor_fit*I_diff+I_start + return y_fit + + +# def fit_exp_function(xs, ys, growth_name, fit_settings={'I_diff': None, 'unify': True, 'bounds': [0.01, 1], 'p_init': (1, 0.1, 0.4), 'n_std': 1}): +# """ +# Fits an exponential function to the given data. + +# This function fits an exponential growth or decay function to each curve in the provided datasets, +# normalizing the data before fitting and optionally unifying the fitting parameters across all curves. + +# Args: +# xs (list of list of floats): List of x-values for each partial curve. Each element is a list of x-values for one curve. +# ys (list of list of floats): List of y-values for each partial curve. Each element is a list of y-values for one curve. +# growth_name (str): Name of the growth process, used for labeling the fitted curves. +# fit_settings (dict, optional): Dictionary of settings for the fitting process. Defaults to: +# { +# 'I_diff': None, # Optional intensity difference for normalization. +# 'unify': True, # Whether to use a unified fitting function for all curves. +# 'bounds': [0.01, 1], # Bounds for the fitting parameters. +# 'p_init': (1, 0.1, 0.4), # Initial guess for the fitting parameters. +# 'n_std': 1 # Number of standard deviations for normalization. +# } + +# Returns: +# tuple: +# - numpy.ndarray: Array of fitted parameters for each curve. +# - list: A list containing: +# - xs: Original x-values. +# - ys: Original y-values. +# - ys_fit: Fitted y-values. +# - ys_nor: Normalized y-values. +# - ys_nor_fit: Fitted normalized y-values. +# - ys_nor_fit_failed: Failed fitted normalized y-values (when applicable). +# - labels: Labels for each fitted curve. +# - losses: Losses for each fitted curve. 
+# """ +# import numpy as np +# from scipy.optimize import curve_fit + +# def normalize_0_1(y, I_start, I_end, I_diff, unify): +# """Normalize y-values to the range [0, 1].""" +# if I_diff is None: +# I_diff = I_end - I_start +# y_nor = (y - I_start) / I_diff +# return y_nor + +# def de_normalize_0_1(y_nor, I_start, I_end, I_diff, unify): +# """De-normalize y-values from the range [0, 1] back to the original scale.""" +# if I_diff is None: +# I_diff = I_end - I_start +# y = y_nor * I_diff + I_start +# return y + +# def exp_func_inc_simp(x, b1, relax1): +# """Simplified exponential growth function.""" +# return b1 * (1 - np.exp(-x / relax1)) + +# def exp_func_dec_simp(x, b2, relax2): +# """Simplified exponential decay function.""" +# return b2 * np.exp(-x / relax2) + +# def exp_func_inc(x, a1, b1, relax1): +# """Full exponential growth function.""" +# return (a1 * x + b1) * (1 - np.exp(-x / relax1)) + +# def exp_func_dec(x, a2, b2, relax2): +# """Full exponential decay function.""" +# return (a2 * x + b2) * np.exp(-x / relax2) + +# I_diff = fit_settings['I_diff'] +# bounds = fit_settings['bounds'] +# p_init = fit_settings['p_init'] +# unify = fit_settings['unify'] + +# parameters = [] +# ys_nor, ys_nor_fit, ys_nor_fit_failed, ys_fit = [], [], [], [] +# labels, losses = [], [] + +# for i in range(len(xs)): +# x = np.linspace(1e-5, 1, len(ys[i])) # Use second as x axis unit +# n_avg = len(ys[i]) // 100 + 3 +# I_end = np.mean(ys[i][-n_avg:]) +# I_start = np.mean(ys[i][:n_avg]) +# y_nor = normalize_0_1(ys[i], I_start, I_end, I_diff, unify) + +# if unify: +# if bounds is None and p_init is None: +# params, _ = curve_fit(exp_func_inc, x, y_nor, absolute_sigma=False) +# else: +# params, _ = curve_fit(exp_func_inc, x, y_nor, p0=p_init, bounds=bounds, absolute_sigma=False) + +# a, b, relax = params +# y_nor_fit = exp_func_inc(x, a, b, relax) +# labels.append(f'{growth_name}-index {i + 1}:\ny=({np.round(a, 2)}t+{np.round(b, 2)})*(1-exp(-t/{np.round(relax, 2)}))') +# parameters.append((a, b, relax)) +# losses.append((0, 0)) +# y_nor_fit_failed = y_nor_fit + +# else: +# params, _ = curve_fit(exp_func_inc_simp, x, y_nor, p0=p_init[1:], bounds=bounds, absolute_sigma=False) +# b1, relax1 = params +# y1_nor_fit = exp_func_inc_simp(x, b1, relax1) + +# params, _ = curve_fit(exp_func_dec_simp, x, y_nor, p0=p_init[1:], bounds=bounds, absolute_sigma=False) +# b2, relax2 = params +# y2_nor_fit = exp_func_dec_simp(x, b2, relax2) + +# loss1 = ((y_nor - y1_nor_fit) ** 2).mean() +# loss2 = ((y_nor - y2_nor_fit) ** 2).mean() + +# params, _ = curve_fit(exp_func_inc, x, y_nor, p0=p_init, bounds=bounds, absolute_sigma=False) +# a1, b1, relax1 = params +# y1_nor_fit = exp_func_inc(x, a1, b1, relax1) + +# params, _ = curve_fit(exp_func_dec, x, y_nor, p0=p_init, bounds=bounds, absolute_sigma=False) +# a2, b2, relax2 = params +# y2_nor_fit = exp_func_dec(x, a2, b2, relax2) + +# if loss1 < loss2: +# y_nor_fit = y1_nor_fit +# labels.append(f'{growth_name}-index {i + 1}:\ny1=({np.round(a1, 2)}t+{np.round(b1, 2)})*(1-exp(-t/{np.round(relax1, 2)}))') +# parameters.append((a1, b1, relax1)) +# y_nor_fit_failed = y2_nor_fit +# else: +# y_nor_fit = y2_nor_fit +# labels.append(f'{growth_name}-index {i + 1}:\ny2=({np.round(a2, 2)}t+{np.round(b2, 2)})*(exp(-t/{np.round(relax2, 2)}))') +# parameters.append((a2, b2, relax2)) +# y_nor_fit_failed = y1_nor_fit + +# losses.append((loss1, loss2)) + +# y_fit = de_normalize_0_1(y_nor_fit, I_start, I_end, I_diff, unify) +# ys_fit.append(y_fit) +# ys_nor.append(y_nor) +# 
ys_nor_fit.append(y_nor_fit) +# ys_nor_fit_failed.append(y_nor_fit_failed) + +# return np.array(parameters), [xs, ys, ys_fit, ys_nor, ys_nor_fit, ys_nor_fit_failed, labels, losses] + + +def analyze_curves(dataset, growth_dict, spot, metric, interval=1000, fit_settings={'step_size':5, 'prominence':0.1, 'length':500, 'savgol_window_order': (15,3), 'pca_component': 10, 'I_diff': 8000, 'unify':True, 'bounds':[0.01, 1], 'p_init':(1, 0.1)}): + + """ + Analyzes RHEED curves for a given spot and metric. + + Args: + dataset (str): Name of the dataset. + growth_dict (dict): Names of the growth index and corresponding frequency. + spot (str): Name of the RHEED spot to collect, choice of "spot_1", "spot_2" or "spot_3". + metric (str): Name of the metric to analyze the RHEED spot. + interval (int, optional): Number of RHEED curves to analyze at a time. Defaults to 1000. + fit_settings (dict, optional): Setting parameters for fitting function. Defaults to {'savgol_window_order': (15,3), 'pca_component': 10, 'I_diff': 8000, 'unify':True, 'bounds':[0.01, 1], 'p_init':(1, 0.1)}. + + Returns: + tuple: A tuple containing the fitted parameters for all RHEED curves, the laser ablation counts for all RHEED curves, and a list of processed RHEED data. + + """ + + parameters_all, x_list_all = [], [] + xs_all, ys_all, ys_fit_all, ys_nor_all, ys_nor_fit_all, ys_nor_fit_failed_all = [], [], [], [], [], [] + labels_all, losses_all = [], [] + + x_end = 0 + for growth in list(growth_dict.keys()): + + # load data + sample_x, sample_y = dataset.load_curve(growth, spot, metric, x_start=x_end) + # sample_x, sample_y = load_curve(h5_para_file, growth_name, 'spot_2', 'img_intensity', camera_freq=500, x_start=0) + + # detect peaks + curve_params = {'convolve_step':fit_settings['convolve_step'], 'prominence':fit_settings['prominence'], 'mode':'full'} + x_peaks, xs, ys = detect_peaks(sample_x, sample_y, camera_freq=dataset.camera_freq, + laser_freq=growth_dict[growth], curve_params=curve_params) + + xs, ys = process_rheed_data(xs, ys, length=fit_settings['length'], savgol_window_order=fit_settings['savgol_window_order'], + pca_component=fit_settings['pca_component']) + + # fit exponential function + parameters, info = fit_exp_function(xs, ys, growth, fit_settings=fit_settings) + parameters_all.append(parameters) + xs, ys, ys_fit, ys_nor, ys_nor_fit, ys_nor_fit_failed, labels, losses = info + xs_all.append(xs) + ys_all.append(ys) + ys_fit_all+=ys_fit + ys_nor_all+=ys_nor + ys_nor_fit_all+=ys_nor_fit + ys_nor_fit_failed_all+=ys_nor_fit_failed + labels_all += labels + losses_all += losses + + x_list = x_peaks[:-1] + x_end + x_end = round(x_end + (len(sample_x)+interval)/dataset.camera_freq, 2) + x_list_all.append(x_list) + + parameters_all = np.concatenate(parameters_all, 0) + x_list_all = np.concatenate(x_list_all)[:len(parameters_all)] + xs_all = np.concatenate(xs_all) + ys_all = np.concatenate(ys_all) + ys_nor_all = np.array(ys_nor_all) + ys_nor_fit_all = np.array(ys_nor_fit_all) + losses_all = np.array(losses_all) + ys_nor_fit_all_failed = np.array(ys_nor_fit_failed_all) + + return parameters_all, x_list_all, [xs_all, ys_all, ys_fit_all, ys_nor_all, ys_nor_fit_all, ys_nor_fit_all_failed, labels_all, losses_all] \ No newline at end of file diff --git a/src/rheed_learn/Analysis_umich.py b/src/rheed_learn/Analysis_umich.py new file mode 100644 index 0000000..36ae61c --- /dev/null +++ b/src/rheed_learn/Analysis_umich.py @@ -0,0 +1,175 @@ + +import glob, re +import numpy as np +import matplotlib.pyplot as plt +from 
m3_learning.viz.layout import layout_fig +from m3_learning.RHEED.Viz import Viz +from m3_learning.RHEED.Analysis import detect_peaks, process_rheed_data, process_curves, remove_linear_bg +from m3_learning.RHEED.Fitting import fit_exp_function +seq_colors = ['#00429d','#2e59a8','#4771b2','#5d8abd','#73a2c6','#8abccf','#a5d5d8','#c5eddf','#ffffe0'] +import numpy as np +import matplotlib.pyplot as plt + +import csv +def read_txt_to_numpy(filename): + # Load data using numpy.loadtxt + data = np.loadtxt(filename, dtype=float, skiprows=1, comments=None) + + # Extract header from the first row + with open(filename, 'r') as file: + header = file.readline().strip().split() + return header, data + + +def select_range(data, start, end, y_col=1): + x = data[:,0] + y = data[:, y_col] + x_selected = x[(x>start) & (xstart) & (x np.mean(tau) - n_std*np.std(tau))[0]] + tau = tau[np.where(tau > np.mean(tau) - n_std*np.std(tau))[0]] + # print('mean of tau:', np.mean(tau)) + + if viz_params['viz_ab']: + fig, axes = layout_fig(4, 1, figsize=(12, 3*4)) + Viz.plot_curve(axes[0], sample_x, sample_y, plot_type='lineplot', xlabel='Time (s)', ylabel='Intensity (a.u.)', yaxis_style='sci') + Viz.plot_curve(axes[1], x_list_all, parameters_all[:,0], plot_type='lineplot', xlabel='Time (s)', ylabel='Fitted a (a.u.)') + Viz.plot_curve(axes[2], x_list_all, parameters_all[:,1], plot_type='lineplot', xlabel='Time (s)', ylabel='Fitted b (a.u.)') + Viz.plot_curve(axes[3], x_clean, tau, plot_type='lineplot', xlabel='Time (s)', ylabel='Characteristic Time (s)') + plt.show() + + if viz_params['viz_tau']: + fig, ax1 = plt.subplots(1, 1, figsize=(8, 2.5), layout='compressed') + ax1.scatter(sample_x, sample_y, color='k', s=1) + Viz.set_labels(ax1, xlabel='Time (s)', ylabel='Intensity (a.u.)', ticks_both_sides=False) + + ax2 = ax1.twinx() + ax2.scatter(x_clean, tau, color=seq_colors[0], s=3) + ax2.plot(x_clean, tau, color='#bc5090', markersize=3) + Viz.set_labels(ax2, ylabel='Characteristic Time (s)', yaxis_style='lineplot', ticks_both_sides=False) + ax2.tick_params(axis="y", color='k', labelcolor=seq_colors[0]) + ax2.set_ylabel('Characteristic Time (s)', color=seq_colors[0]) + plt.title(f'mean of tau: {np.mean(tau):.2f}') + plt.show() + return parameters_all, x_list_all, info, tau + + +def plot_activation_energy(temp_list, tau_list, fit=False, title=None, save_path=None): + tau_mean_list = [np.mean(t_list) for t_list in tau_list] + fig, axes = plt.subplots(1, 2, figsize=(8,2.5)) + + tau_mean = np.array(tau_mean_list) + axes[0].scatter(temp_list, tau_mean, color='k', s=10) + axes[0].set_xlabel('T (C)') + axes[0].set_ylabel('tau (s)') + # axes[0].set_ylim(0,0.2) + + T = np.array(temp_list) + 273 + x = 1/(T) + y = -np.log(tau_mean) + axes[1].scatter(x, y, color='k', s=10) + + if fit: + m, b = np.polyfit(x, y, 1) + axes[1].plot(x, y, 'yo', x, m*x+b, '--k') + axes[1].set_xlabel('1/T (1/K))') + axes[1].set_ylabel(r'-ln($\tau$)') + # axes[1].set_title('Ea=' + str(round(m*-8.617e-5, 2)) + ' eV') + # axes[1].set_ylim(1.8,2.5) + + text = f'Ea={round(m*-8.617e-5, 2)}eV, b={b}' + bbox_props = dict(boxstyle="round,pad=0.3", edgecolor="white", facecolor="white") + # axes[1].text(0.25, 0.1, text, transform=axes[1].transAxes, fontsize=10, verticalalignment="center", horizontalalignment="center", bbox=bbox_props) + + if title: + plt.suptitle(title) + if save_path is not None: + plt.savefig(save_path, dpi=300) + plt.show() \ No newline at end of file diff --git a/src/rheed_learn/Dataset.py b/src/rheed_learn/Dataset.py new file mode 100644 index 
0000000..2aa919a --- /dev/null +++ b/src/rheed_learn/Dataset.py @@ -0,0 +1,580 @@ +import os +import h5py +import numpy as np +import matplotlib.pyplot as plt +from m3_learning.viz.layout import imagemap, layout_fig, labelfigs +from m3_learning.RHEED.Viz import Viz +from mpl_toolkits.axes_grid1 import make_axes_locatable +import matplotlib.ticker as ticker +import plotly.express as px +import glob +import json # For dealing with metadata +from datafed.CommandLib import API +import random +# from visualization_functions import show_images +from m3_learning.RHEED.Viz import show_images + +def NormalizeData(data, range=(0,1)): + return (((data - np.min(data)) * (range[1] - range[0])) / (np.max(data) - np.min(data))) + range[0] + +def datafed_upload(file_path, parent_id, metadata=None, wait=True): + df_api = API() + file_name = os.path.basename(file_path) + dc_resp = df_api.dataCreate(file_name, metadata=json.dumps(metadata), parent_id=parent_id) + rec_id = dc_resp[0].data[0].id + put_resp = df_api.dataPut(rec_id, file_path, wait=wait) + print(put_resp) + +def datafed_download(file_id, file_path, wait=True): + df_api = API() + get_resp = df_api.dataGet([file_id], # currently only accepts a list of IDs / aliases + file_path, # directory where data should be downloaded + orig_fname=True, # do not name file by its original name + wait=wait, # Wait until Globus transfer completes + ) + print(get_resp) + + +def pack_rheed_data(h5_path, source_dir, ds_names_load, ds_names_create=None, viz=True): + if ds_names_create==None: + ds_names_create = ds_names_load + h5 = h5py.File(h5_path, mode='a') + for ds_name_load, ds_name_create in zip(ds_names_load, ds_names_create): + file_list = glob.glob(source_dir+'/'+ds_name_load+'.*') + length = len(file_list) + img_shape = plt.imread(file_list[0]).shape + print(ds_name_load, ds_names_create, length, img_shape, plt.imread(file_list[0]).dtype) + + createdata = h5.create_dataset(ds_name_create, shape=(length, *img_shape), dtype=np.uint8) + for i, file in enumerate(file_list): + if i % 10000 == 0: + print(f'{i} - {i+10000} ...') + createdata[i] = plt.imread(file) + + imgs = [] + if viz: + random_files = random.choices(file_list, k=8) + for file in random_files: + imgs.append(plt.imread(file)) + show_images(imgs, img_per_row=8) + +def viz_unpacked_images(source_dir, ds_names): + for ds_name in ds_names: + print(ds_name, len(glob.glob(source_dir+'/'+ds_name+'.*')), + plt.imread(glob.glob(source_dir+'/'+ds_name+'.*')[0]).shape) + + files = glob.glob(source_dir+ds_name+'.*') + files = random.choices(files, k=16) + imgs = [] + for file in files: + imgs.append(plt.imread(file)) + print(ds_name) + show_images(imgs, img_per_row=8) + + +def compress_gaussian_params_H5(file_in, str=None, compression='gzip', compression_opts=9): + """ + Compresses Gaussian parameters in an HDF5 file. + + Args: + file_in (str): Path to the input HDF5 file. + str (str, optional): String to append to the output file name. Defaults to None. + compression (str, optional): Compression algorithm to use. Defaults to 'gzip'. + compression_opts (int, optional): Compression level. Defaults to 9. 
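+
+    Example (illustrative; file name assumed):
+        compress_gaussian_params_H5('gaussian_params.h5', str='_small.h5')
+        # writes 'gaussian_params_small.h5' with gzip-compressed copies of every dataset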
+ """ + + if str is None: + out_name = file_in[:-3] + '_compressed.h5' + else: + out_name = file_in[:-3] + str + + with h5py.File(f"{file_in}", 'r') as f_old: + print(f_old.keys()) + with h5py.File(out_name, 'w') as f_new: + for ds in f_old.keys(): + f_new.create_group(ds) + for spot in f_old[ds].keys(): + f_new[ds].create_group(spot) + for metric in f_old[ds][spot].keys(): + data = f_old[ds][spot][metric][:] + dset = f_new[ds][spot].create_dataset( + metric, data=data, compression=compression, compression_opts=compression_opts) + +def compress_RHEED_spot_H5(file_in, str=None, compression='gzip', compression_opts=9): + """ + Compresses RHEED spots in an HDF5 file. + + Args: + file_in (str): Path to the input HDF5 file. + str (str, optional): String to append to the output file name. Defaults to None. + compression (str, optional): Compression algorithm to use. Defaults to 'gzip'. + compression_opts (int, optional): Compression level. Defaults to 9. + """ + + if str is None: + out_name = file_in[:-3] + '_compressed.h5' + else: + out_name = file_in[:-3] + str + + with h5py.File(file_in, 'r') as f_old: + print(f_old.keys()) + with h5py.File(out_name, 'w') as f_new: + for growth in f_old.keys(): + print(growth) + data = f_old[growth][:] + dset = f_new.create_dataset( + growth, data=data, compression=compression, compression_opts=compression_opts) + + +class RHEED_spot_Dataset: + """A class representing a dataset of RHEED spots. + + Attributes: + path (str): The path to the dataset. + sample_name (str): The name of the sample. + verbose (bool): Whether to enable verbose mode or not. + + Methods: + data_info: Prints information about the dataset. + growth_dataset: Retrieves the RHEED spot data for a specific growth. + viz_RHEED_spot: Visualizes a specific RHEED spot. + + Properties: + sample_name: Getter and setter for the sample name. + + """ + def __init__(self, path, sample_name, verbose=False): + """ + Initializes a new instance of the RHEED_spot_Dataset class. + + Args: + path (str): The path to the dataset. + sample_name (str): The name of the sample. + verbose (bool, optional): Whether to enable verbose mode or not. Defaults to False. + """ + self.path = path + self._sample_name = sample_name + + @property + def data_info(self): + """ + Prints information about the dataset. + + This method reads the dataset file and prints the growth names along with the size of their data arrays. + """ + ... + with h5py.File(self.path, mode='r') as h5: + for g, data in h5.items(): + try: + print(f"Growth: {g}, Size of data: f{data.shape}") + except: + print(f"Growth: {g}") + + @property + def dataset_names(self): + """ + Return dataset names. + + This method reads the dataset file and return the names of dataset. + """ + ... + with h5py.File(self.path, mode='r') as h5: + datasets = list(h5.keys()) + return datasets + + def growth_dataset_length(self, growth): + with h5py.File(self.path, mode='r') as h5: + return len(h5[growth]) + + + def growth_dataset(self, growth, index = None): + """ + Retrieves the RHEED spot data for a specific growth. + + Args: + growth (str): The name of the growth. + index (int or list, optional): The index of the data array to retrieve. Defaults to None. + + Returns: + numpy.ndarray: The RHEED spot data as a numpy array. + + Raises: + ValueError: If the index is out of range. 
+ + """ + with h5py.File(self.path, mode='r') as h5: + if index is None: + return np.array(h5[growth]) + else: + if isinstance(index, int): + i_max = index + else: + i_max = np.max(index) + if i_max<0 or i_max>h5[growth].shape[0]: + raise ValueError('Index out of range') + else: + return np.array(h5[growth][index]) + + def viz_RHEED_spot(self, growth, index, figsize=(2, 2), viz_mode='print', clim=None, filename = None, printing=None, **kwargs): + """ + Visualizes a specific RHEED spot. + + Args: + growth (str): The name of the growth. + index (int): The index of the data array to visualize. + figsize (tuple, optional): The size of the figure. Defaults to (2, 2). + viz_mode (str): The visualization mode for spot image. Defaults to 'print', options: 'print', 'iteractive'. + clim (tuple, optional): The color limit for the plot. Defaults to None. + filename (str or bool, optional): The filename to save the plot. If True, a default filename will be used. Defaults to None. + printing: A printing object used for saving the figure. Defaults to None. + **kwargs: Additional keyword arguments to pass to the printing object. + + """ + print(f'\033[1mFig.\033[0m a: RHEED spot image for {growth} at index {index}.') + + # fig, axes = layout_fig(1, figsize=figsize) + + data = self.growth_dataset(growth, index) + # imagemap(axes[0], data, clim=clim, divider_=True) + # customized version of imagemap + + if viz_mode=='iteractive': + plt.figure(figsize=figsize) + im = px.imshow(data) + im.show() + elif viz_mode=='print': + fig, ax = plt.subplots(1, 1, figsize=figsize) + im = ax.imshow(data) + divider = make_axes_locatable(ax) + cax = divider.append_axes("right", size="10%", pad=0.05) + cbar = fig.colorbar(im, ticks=[data.min(), data.max(), np.mean([data.min(), data.max()])], cax=cax, format="%.2e") + ax.set_yticklabels("") + ax.set_xticklabels("") + ax.set_yticks([]) + ax.set_xticks([]) + labelfigs(ax, 0) + if filename is True: + filename = f"RHEED_{self.sample_name}_{growth}_{index}" + + # prints the figure + if printing is not None and filename is not None: + printing.savefig(fig, filename, **kwargs) + + plt.show() + + @property + def sample_name(self): + """ + Getter for the sample name. + + Returns: + str: The name of the sample. + """ + return self._sample_name + + @sample_name.setter + def sample_name(self, sample_name): + """ + Setter for the sample name. + + Returns: + sample_name (str): The name of the sample. + """ + + self._sample_name = sample_name + + +class RHEED_parameter_dataset(): + """A class representing a dataset of RHEED spots with associated parameters. + + Attributes: + path (str): The path to the dataset. + camera_freq (float): The camera frequency. + sample_name (str): The name of the sample. + verbose (bool): Whether to enable verbose mode or not. + + Methods: + data_info: Prints information about the dataset. + growth_dataset: Retrieves the RHEED spot parameter data for a specific growth, spot, and metric. + load_curve: Loads a parameter curve for a specific growth, spot, and metric. + load_multiple_curves: Loads multiple parameter curves for a list of growths, spot, and metric. + viz_RHEED_parameter: Visualizes the RHEED spot parameter for a specific growth, spot, and index. + viz_RHEED_parameter_trend: Visualizes the parameter trends for multiple growths, spot, and metrics. + + Properties: + camera_freq: Getter and setter for the camera frequency. + sample_name: Getter and setter for the sample name. 
+ + """ + + def __init__(self, path, camera_freq, sample_name, verbose=False): + """ + Initializes a new instance of the RHEED_parameter_dataset class. + + Args: + path (str): The path to the dataset. + camera_freq (float): The camera frequency. + sample_name (str): The name of the sample. + verbose (bool, optional): Whether to enable verbose mode or not. Defaults to False. + """ + self.path = path + self._camera_freq = camera_freq + self._sample_name = sample_name + + @property + def data_info(self): + """ + Prints information about the dataset. + + This method reads the dataset file and prints the growth names, spot names, and the size of their associated data arrays. + """ + with h5py.File(self.path, mode='r') as h5: + for g in h5.keys(): + print(f"Growth: {g}:") + for s in h5[g].keys(): + print(f"--spot: {s}:") + for k in h5[g][s].keys(): + try: + print(f"----{k}:, Size of data: {h5[g][s][k].shape}") + except: + print(f"----metric: {k}") + + + def growth_dataset(self, growth, spot, metric, index = None): + """ + Retrieves the RHEED spot parameter data for a specific growth, spot, and metric. + + Args: + growth (str): The name of the growth. + spot (str): The name of the spot. + metric (str): The name of the metric. Options: "raw_image", "reconstructed_image", "img_sum", "img_max", "img_mean", + "img_rec_sum", "img_rec_max", "img_rec_mean", "height", "x", "y", "width_x", "width_y". + index (int, optional): The index of the data array to retrieve. Defaults to None. + + Returns: + numpy.ndarray: The RHEED spot parameter data as a numpy array. + + Raises: + ValueError: If the index is out of range. + + Options for metric: + + """ + + with h5py.File(self.path, mode='r') as h5: + + if index is None: + return np.array(h5[growth][spot][metric]) + else: + if index<0 or index>h5[growth][spot][metric].shape[0]: + raise ValueError('Index out of range') + else: + return np.array(h5[growth][spot][metric][index]) + + def load_curve(self, growth, spot, metric, x_start): + """ + Loads a parameter curve for a specific growth, spot, and metric. + + Args: + growth (str): The name of the growth. + spot (str): The name of the spot. + metric (str): The name of the metric. + x_start (float): The starting x value for the curve. + + Returns: + tuple: A tuple containing the x and y values of the parameter curve. + + """ + with h5py.File(self.path, mode='r') as h5_para: + y = np.array(h5_para[growth][spot][metric]) + x = np.linspace(x_start, x_start+len(y)-1, len(y))/self.camera_freq + return x, y + + def load_multiple_curves(self, growth_list, spot, metric, x_start=0, head_tail=(100, 100), interval=200): + """ + Loads multiple parameter curves for a list of growths, spot, and metric. + + Args: + growth_list (list): The list of growth names. + spot (str): The name of the spot. + metric (str): The name of the metric. + x_start (float, optional): The starting x value for the first curve. Defaults to 0. + head_tail (tuple, optional): The number of elements to remove from the head and tail of each curve. Defaults to (100, 100). + interval (int, optional): The interval between curves. Defaults to 200. + + Returns: + tuple: A tuple containing the concatenated x and y values of the parameter curves. 
+ + """ + x_all, y_all = [], [] + + for growth in growth_list: + x, y = self.load_curve(growth, spot, metric, x_start) + x = x[head_tail[0]:-head_tail[1]] + y = y[head_tail[0]:-head_tail[1]] + x_start = x_start+len(y)+interval + x_all.append(x) + y_all.append(y) + + x_all = np.concatenate(x_all) + y_all = np.concatenate(y_all) + return x_all, y_all + + def viz_RHEED_parameter(self, growth, spot, index, figsize=None, filename=None, printing=None, **kwargs): + """ + Visualizes the RHEED spot parameter for a specific growth, spot, and index. + + Args: + growth (str): The name of the growth. + spot (str): The name of the spot. + index (int): The index of the data array to visualize. + figsize (tuple, optional): The size of the figure. Defaults to None. + filename (str, optional): The filename to save the plot. Defaults to None. + printing: A printing object used for saving the figure. Defaults to None. + **kwargs: Additional keyword arguments to pass to the printing object. + + """ + if figsize is None: + figsize = (1.25*3, 1.25*1) + # "img_mean", "img_rec_sum", "img_rec_max", "img_rec_mean", "height", "x", "y", "width_x", "width_y". + img = self.growth_dataset(growth, spot, 'raw_image', index) + img_rec = self.growth_dataset(growth, spot, 'reconstructed_image', index) + img_sum = self.growth_dataset(growth, spot, 'img_sum', index) + img_max = self.growth_dataset(growth, spot, 'img_max', index) + img_mean = self.growth_dataset(growth, spot, 'img_mean', index) + img_rec_sum = self.growth_dataset(growth, spot, 'img_rec_sum', index) + img_rec_max = self.growth_dataset(growth, spot, 'img_rec_max', index) + img_rec_mean = self.growth_dataset(growth, spot, 'img_rec_mean', index) + height = self.growth_dataset(growth, spot, 'height', index) + x = self.growth_dataset(growth, spot, 'x', index) + y = self.growth_dataset(growth, spot, 'y', index) + width_x = self.growth_dataset(growth, spot, 'width_x', index) + width_y = self.growth_dataset(growth, spot, 'width_y', index) + + sample_list = [img, img_rec, img_rec-img] + + clim = (img.min(), img.max()) + fig, axes = layout_fig(3, 3, figsize=figsize) + for i, ax in enumerate(axes): + if ax == axes[-1]: + imagemap(ax, sample_list[i], divider_=False, clim=clim, colorbars=True, **kwargs) + else: + imagemap(ax, sample_list[i], divider_=False, clim=clim, colorbars=False, **kwargs) + labelfigs(ax, i) + + if filename is True: + filename = f"RHEED_{self.sample_name}_{growth}_{spot}_{index}_img,img_rec,differerce" + + # prints the figure + if printing is not None and filename is not None: + printing.savefig(fig, filename, **kwargs) + plt.show() + print(f'\033[1mFig.\033[0m a: RHEED spot image, b: reconstructed RHEED spot image, c: difference between original and reconstructed image for {growth} at index {index}.') + #print first 2 digits of each parameter + print(f'img_sum={img_sum:.2f}, img_max={img_max:.2f}, img_mean={img_mean:.2f}') + print(f'img_rec_sum={img_rec_sum:.2f}, img_rec_max={img_rec_max:.2f}, img_rec_mean={img_rec_mean:.2f}') + print(f'height={height:.2f}, x={x:.2f}, y={y:.2f}, width_x={width_x:.2f}, width_y_max={width_y:.2f}') + + + def viz_RHEED_parameter_trend(self, growth_list, spot, metric_list=None, head_tail=(100, 100), interval=0, figsize=None, filename = None, printing=None, **kwargs): + """ + Visualizes the parameter trends for multiple growths, spot, and metrics. + + Args: + growth_list (list): The list of growth names. + spot (str): The name of the spot. + metric_list (list, optional): The list of metrics to visualize. Defaults to None. 
+ filename (str, optional): The filename to save the plot. Defaults to None. + printing: A printing object used for saving the figure. Defaults to None. + **kwargs: Additional keyword arguments to pass to the printing object. + + """ + if metric_list is None: + metric_list = ['img_sum', 'img_rec_sum', 'x', 'y', 'width_x', 'width_y'] + + if len(metric_list) == 1: + if figsize == None: + figsize=(6,2) + fig, ax = plt.subplots(len(metric_list), 1, figsize = figsize) + axes = [ax] + else: + if figsize == None: + figsize = (6, 1.5*len(metric_list)) + fig, axes = plt.subplots(len(metric_list), 1, figsize = figsize) + for i, (ax, metric) in enumerate(zip(axes, metric_list)): + x_curve, y_curve = self.load_multiple_curves(growth_list, spot=spot, metric=metric, head_tail=head_tail, interval=interval) #**kwargs) + ax.scatter(x_curve, y_curve, color='k', s=1) + if i < len(metric_list)-1: + Viz.set_labels(ax, ylabel=f'{metric} (a.u.)', yaxis_style='sci') + ax.set_xticklabels(['' for tick in ax.get_xticks()]) + else: + Viz.set_labels(ax, xlabel='Time (s)', ylabel=f'{metric} (a.u.)', yaxis_style='sci') + formatter = ticker.ScalarFormatter(useMathText=True) + formatter.set_powerlimits((-2, 3)) # Adjust the power limits as needed + ax.yaxis.set_major_formatter(formatter) + ax.yaxis.get_offset_text().set_x(-0.05) + labelfigs(ax, i, label=metric, loc='tl', style='b', size=8, inset_fraction=(0.08, 0.03)) + + fig.subplots_adjust(hspace=0) + + if filename: + filename = f"RHEED_{self.sample_name}_{spot}_metrics" + + # prints the figure + if printing is not None and filename is not None: + printing.savefig(fig, filename, **kwargs) + plt.show() + print(f'Gaussian fitted parameters in time: \033[1mFig.\033[0m a: sum of original image, b: sum of reconstructed image, c: spot center in spot x coordinate, d: spot center in y coordinate, e: spot width in x coordinate, f: spot width in y coordinate.') + + + @property + def camera_freq(self): + """ + Getter for the camera frequency. + + Returns: + float: The camera frequency. + """ + return self._camera_freq + + @camera_freq.setter + def camera_freq(self, camera_freq): + """ + Setter for the camera frequency. + + Args: + camera_freq (float): The new camera frequency. + """ + self._camera_freq = camera_freq + + @property + def sample_name(self): + """ + Getter for the sample name. + + Returns: + str: The name of the sample. + """ + return self._sample_name + + @sample_name.setter + def sample_name(self, sample_name): + """ + Setter for the sample name. + + Args: + sample_name (str): The new name for the sample. + """ + self._sample_name = sample_name + + @property + def dataset_names(self): + """ + Return dataset names. + + This method reads the dataset file and return the names of dataset. + """ + ... + with h5py.File(self.path, mode='r') as h5: + datasets = list(h5.keys()) + return datasets \ No newline at end of file diff --git a/src/rheed_learn/Fit.py b/src/rheed_learn/Fit.py new file mode 100644 index 0000000..404c8ef --- /dev/null +++ b/src/rheed_learn/Fit.py @@ -0,0 +1,567 @@ +import os +import h5py +import sys +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +from scipy.signal import butter, lfilter, sosfilt, freqz +from scipy import optimize +from joblib import Parallel, delayed +import sys +from m3_learning.viz.layout import layout_fig, labelfigs, imagemap +from m3_learning.RHEED.Viz import Viz + + +def NormalizeData(data, lb=0, ub=1): + """ + Normalize the data to a specified range. + + Args: + data: NumPy array or Pandas DataFrame. 
The input data to be normalized. + lb: float, optional. The lower bound of the normalization range. Default is 0. + ub: float, optional. The upper bound of the normalization range. Default is 1. + + Returns: + NumPy array or Pandas DataFrame: The normalized data. + + """ + return (data - lb) / (ub - lb) + +def show_fft_frequency(amplitude, samplingFrequency, ranges=None): + """ + Display the frequency domain representation of the input signal. + + Args: + amplitude: NumPy array. The amplitude values of the signal. + samplingFrequency: float. The sampling frequency of the signal. + ranges: tuple, optional. The frequency range to display. Default is None. + + Returns: + tuple: The frequencies and corresponding Fourier transform values. + + """ + # Frequency domain representation + fourierTransform = np.fft.fft( + amplitude)/len(amplitude) # Normalize amplitude + fourierTransform = fourierTransform[range( + int(len(amplitude)/2))] # Exclude sampling frequency + + tpCount = len(amplitude) + values = np.arange(int(tpCount/2)) + timePeriod = tpCount/samplingFrequency + + frequencies = values/timePeriod + fourierTransform[abs(fourierTransform) > 1] = 0 + if ranges: + frequencies_ = frequencies[frequencies > ranges[0]] + fourierTransform_ = fourierTransform[frequencies > ranges[0]] + + frequencies_range = frequencies_[frequencies_ < ranges[1]-ranges[0]] + fourierTransform_range = fourierTransform_[ + frequencies_ < ranges[1]-ranges[0]] + else: + frequencies_range = frequencies + fourierTransform_range = fourierTransform + + plt.figure(figsize=(15, 4)) + plt.plot(frequencies_range, abs(fourierTransform_range)) + plt.show() + return frequencies_range, abs(fourierTransform_range) + + +def butter_filter(data, method, filter_type, cutoff, samplingFrequency, order): + """ + Apply a Butterworth filter to the input data. + + Args: + data: NumPy array. The input data to be filtered. + method: str. The method for filter design. Possible values are 'ba' (Butterworth filter using `butter` function) + or 'sos' (Butterworth filter using `sosfilt` function). + filter_type: str. The filter type. Possible values are 'lowpass', 'highpass', 'bandpass', or 'bandstop'. + cutoff: float or tuple. The cutoff frequency or frequencies of the filter. + samplingFrequency: float. The sampling frequency of the data. + order: int. The filter order. + + Returns: + NumPy array: The filtered data. + + """ + nyq = 0.5 * samplingFrequency + + if type(cutoff) == tuple: + cutoff = list(cutoff) + + if type(cutoff) == list: + cutoff[0] = cutoff[0] / nyq + cutoff[1] = cutoff[1] / nyq + else: + cutoff = cutoff / nyq + + if method == 'ba': + b, a = butter(order, cutoff, btype=filter_type, + analog=False, output='ba') + y = lfilter(b, a, data) + if method == 'sos': + sos = butter(order, cutoff, btype=filter_type, + analog=False, output='sos') + y = sosfilt(sos, data) + return y + + +def process_pass_filter(sound, filter_type, method, cutoff, order, frame_range, samplingFrequency=100): + """ + Apply a Butterworth filter to the input sound data and visualize the results. + + Args: + sound: NumPy array. The input sound data. + filter_type: str. The filter type. Possible values are 'lowpass', 'highpass', 'bandpass', or 'bandstop'. + method: str. The method for filter design. Possible values are 'ba' (Butterworth filter using `butter` function) + or 'sos' (Butterworth filter using `sosfilt` function). + cutoff: float or tuple. The cutoff frequency or frequencies of the filter. + order: int. The filter order. + frame_range: tuple. 
The range of frames to display in the visualization. + samplingFrequency: float, optional. The sampling frequency of the sound data. Default is 100. + + Returns: + NumPy array: The filtered sound data. + + """ + sig = np.copy(sound) + t = np.arange(0, len(sig)) + ranges = None + + filtered = butter_filter(sig, method, filter_type, + cutoff, samplingFrequency, order) + + show_fft_frequency(sig, samplingFrequency, ranges) + show_fft_frequency(filtered, samplingFrequency, ranges) + + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 6), sharex=True) + ax1.plot(t[frame_range[0]:frame_range[1]], + sig[frame_range[0]:frame_range[1]], marker='v') + + ax2.plot(t[frame_range[0]:frame_range[1]], + filtered[frame_range[0]:frame_range[1]], marker='v') + ax2.set_xlabel('Frame') + plt.tight_layout() + plt.show() + return filtered + + +def show_metrics(data, ranges, plot_ranges): + """ + Display metrics and visualizations of the input data. + + Args: + data: List or NumPy array. The input data. + ranges: tuple, optional. The range of data to display in the visualizations. Default is None. + plot_ranges: tuple. The range of data to plot in the metrics visualization. + + Returns: + tuple: The computed metrics. + + """ + len_img = 16 + img_per_row = 8 + fig, ax = plt.subplots(len_img//img_per_row+1*int(len_img % img_per_row > 0), + img_per_row, figsize=(16, 2*len_img//img_per_row+1)) + for i in range(len_img): + ax[i//img_per_row, i % img_per_row].title.set_text(i) + if ranges: + ax[i//img_per_row, i % + img_per_row].imshow(data[i][ranges[0]:ranges[1], ranges[2]:ranges[3]]) + else: + ax[i//img_per_row, i % img_per_row].imshow(data[i]) + + plt.show() + + sum_list, max_list, min_list, mean_list, std_list = [], [], [], [], [] + for i in range(len(data)): + if ranges: + img = data[i][ranges[0]:ranges[1], ranges[2]:ranges[3]] + else: + img = data[i] + sum_list.append(np.sum(img)) + max_list.append(np.max(img)) + min_list.append(np.min(img)) + mean_list.append(np.mean(img)) + std_list.append(np.std(img)) + + fig, ax = plt.subplots(3, 2, figsize=(15, 12)) + + if ranges: + h = ax[0, 0].plot(sum_list[plot_ranges[0]:plot_ranges[1]]) + ax[0, 0].title.set_text('sum_list') + + h = ax[0, 1].plot(max_list[plot_ranges[0]:plot_ranges[1]]) + ax[0, 1].title.set_text('max_list') + + h = ax[1, 0].plot(min_list[plot_ranges[0]:plot_ranges[1]]) + ax[1, 0].title.set_text('min_list') + + h = ax[1, 1].plot(mean_list[plot_ranges[0]:plot_ranges[1]]) + ax[1, 1].title.set_text('mean_list') + + h = ax[2, 0].plot(std_list[plot_ranges[0]:plot_ranges[1]]) + ax[2, 0].title.set_text('std_list') + + else: + h = ax[0, 0].plot(sum_list) + ax[0, 0].title.set_text('sum_list') + + h = ax[0, 1].plot(max_list) + ax[0, 1].title.set_text('max_list') + + h = ax[1, 0].plot(min_list) + ax[1, 0].title.set_text('min_list') + + h = ax[1, 1].plot(mean_list) + ax[1, 1].title.set_text('mean_list') + + h = ax[2, 0].plot(std_list) + ax[2, 0].title.set_text('std_list') + + plt.show() + return sum_list, max_list, min_list, mean_list, std_list + + +class Gaussian(): + def __init__(self): + self.a = 0 + + def gaussian(self, height, center_x, center_y, width_x, width_y, rotation): + """ + Returns a Gaussian function with the given parameters. + + Args: + height: float. The height of the Gaussian function. + center_x: float. The x-coordinate of the center of the Gaussian function. + center_y: float. The y-coordinate of the center of the Gaussian function. + width_x: float. The width of the Gaussian function in the x-direction. + width_y: float. 
The width of the Gaussian function in the y-direction. + rotation: float. The rotation angle of the Gaussian function in degrees. + + Returns: + function: A function representing the Gaussian function with the given parameters. + + """ + width_x = float(width_x) + width_y = float(width_y) + + rotation = np.deg2rad(rotation) + center_x = center_x * np.cos(rotation) - center_y * np.sin(rotation) + center_y = center_x * np.sin(rotation) + center_y * np.cos(rotation) + + def rotgauss(x, y): + xp = x * np.cos(rotation) - y * np.sin(rotation) + yp = x * np.sin(rotation) + y * np.cos(rotation) + g = height*np.exp( + -(((center_x-xp)/width_x)**2 + + ((center_y-yp)/width_y)**2)/2.) + return g + return rotgauss + + def moments(self, data): + """ + Returns the Gaussian parameters (height, x, y, width_x, width_y) of a 2D distribution by calculating its moments. + + Args: + data: NumPy array. The input data representing the 2D distribution. + + Returns: + tuple: The Gaussian parameters (height, x, y, width_x, width_y). + + """ + + total = data.sum() + X, Y = np.indices(data.shape) + x = (X*data).sum()/total + y = (Y*data).sum()/total + col = data[:, int(y)] + width_x = np.sqrt(abs((np.arange(col.size)-y)**2*col).sum()/col.sum()) + row = data[int(x), :] + width_y = np.sqrt(abs((np.arange(row.size)-x)**2*row).sum()/row.sum()) + height = data.max() + return height, x, y, width_x, width_y, 0.0 + + def fitgaussian(self, data): + """ + Returns the Gaussian parameters (height, x, y, width_x, width_y) of a 2D distribution found by a fit. + + Args: + data: NumPy array. The input data representing the 2D distribution. + + Returns: + tuple: The Gaussian parameters (height, x, y, width_x, width_y). + + """ + params = self.moments(data) + def errorfunction(p): return np.ravel( + self.gaussian(*p)(*np.indices(data.shape)) - data) + p, success = optimize.leastsq(errorfunction, params) + return p + + def recreate_gaussian(self, image): + """ + Reconstructs a Gaussian function from an image by fitting the Gaussian parameters. + + Args: + image: NumPy array. The input image. + + Returns: + tuple: The reconstructed Gaussian image and its parameters. + + """ + para = self.fitgaussian(image) + y = np.linspace(0, image.shape[0], image.shape[0]) + x = np.linspace(0, image.shape[1], image.shape[1]) + x, y = np.meshgrid(x, y) + return self.gaussian(*para)(y, x), para + + +class RHEED_image_processer: + def __init__(self, spot_ds, crop_dict, fit_function): + """ + Initializes the RHEED_image_processer class. + + Args: + spot_ds: Object. The spot dataset. spots_names: ['spot_1', 'spot_2', 'spot_3', 'background', ...] + crop_dict: dict. The dictionary containing crop information for different spots. + fit_function: function. The function to fit the Gaussian parameters. 
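+
+        Example (a minimal sketch; `my_spot_dataset` and the crop-window values are
+        hypothetical placeholders):
+            gaussian = Gaussian()
+            crop_dict = {'spot_1': {'y_start': 10, 'y_end': 60, 'x_start': 20, 'x_end': 70}}
+            processer = RHEED_image_processer(my_spot_dataset, crop_dict, gaussian.recreate_gaussian)
+            img, img_rec, parameters = processer.visualize('growth_1', 'spot_1', frame=0)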
+ + """ + self.spot_ds = spot_ds + self.crop_dict = crop_dict + self.fit_function = fit_function + + def preview_metric(self, growth, spot, metric, camera_freq, viz=False, process_chunk_size=50000): + if metric not in ['max', 'sum', 'mean']: + print(f'can"t preview {metric}, requires to do gaussian fit') + return + metric_in_time = [] + total_length = self.spot_ds.growth_dataset_length(growth) + for i in range(0, total_length, process_chunk_size): + i_end = np.min((total_length, i+process_chunk_size)) + print(f'{i} - {i_end} ...') + + chunk = self.spot_ds.growth_dataset(growth, index=list(range(i,i_end))) + inputs = self.normalize_and_crop_inputs(chunk, spot) + + if metric == 'max': + metric_in_time.append(np.max(inputs, axis=(1,2))) + if metric == 'sum': + metric_in_time.append(np.sum(inputs, axis=(1,2))) + if metric == 'mean': + metric_in_time.append(np.mean(inputs, axis=(1,2))) + metric_in_time = np.concatenate(metric_in_time) + time = np.linspace(0, len(metric_in_time)/camera_freq, len(metric_in_time)) + + if viz: + fig, axes = plt.subplots(1, 1, figsize = (6, 2)) + plt.scatter(x=time, y=metric_in_time) + plt.show() + + return metric_in_time + + + def write_h5_file(self, parameters_file_path, growth_list, replace=False, num_workers=1, process_chunk_size=50000): + + """ + Writes the parameters of RHEED images to an HDF5 file. + + Args: + parameters_file_path: str. The file path of the HDF5 file. + growth_list: list. The list of growth names. + replace: bool, optional. Whether to replace the existing HDF5 file. Default is False. + num_workers: int, optional. The number of workers for parallel processing. Default is 1. + + """ + spots_names = list(self.crop_dict.keys()) + + if os.path.isfile(parameters_file_path): + print('h5 file exist.') + if replace: + os.remove(parameters_file_path) + print('Replace with new file.') + with h5py.File(parameters_file_path, mode='a') as h5_para: + for growth in growth_list: + print(f'{growth}:') + h5_growth = h5_para.create_group(growth) + + for spot in spots_names: + print(f' {spot} generation started:') + + total_length = self.spot_ds.growth_dataset_length(growth) + if total_length > process_chunk_size: + h5_spot = h5_growth.create_group(spot) + + # make some samples + samples = self.spot_ds.growth_dataset(growth, index=list(range(10))) + inputs = self.normalize_and_crop_inputs(samples, spot) + results = self.fit_batch(inputs, num_workers=1) + img_all = np.array([res[0] for res in results]) + img_rec_all = np.array([res[1] for res in results]) + parameters = np.array([res[2] for res in results]) + + raw_image = h5_spot.create_dataset('raw_image', shape=(total_length, *img_all.shape[1:]), dtype=img_all.dtype) + reconstructed_image = h5_spot.create_dataset('reconstructed_image', shape=(total_length, *img_rec_all.shape[1:]), dtype=img_all.dtype) + img_sum = h5_spot.create_dataset('img_sum', shape=(total_length, ), dtype=parameters[:, 0].dtype) + img_max = h5_spot.create_dataset('img_max', shape=(total_length, ), dtype=parameters[:, 1].dtype) + img_mean = h5_spot.create_dataset('img_mean', shape=(total_length, ), dtype=parameters[:, 2].dtype) + img_rec_sum = h5_spot.create_dataset('img_rec_sum', shape=(total_length, ), dtype=parameters[:, 3].dtype) + img_rec_max = h5_spot.create_dataset('img_rec_max', shape=(total_length, ), dtype=parameters[:, 4].dtype) + img_rec_mean = h5_spot.create_dataset('img_rec_mean', shape=(total_length, ), dtype=parameters[:, 5].dtype) + height = h5_spot.create_dataset('height', shape=(total_length, ), dtype=parameters[:, 
6].dtype) + x = h5_spot.create_dataset('x', shape=(total_length, ), dtype=parameters[:, 7].dtype) + y = h5_spot.create_dataset('y', shape=(total_length, ), dtype=parameters[:, 8].dtype) + width_x = h5_spot.create_dataset('width_x', shape=(total_length, ), dtype=parameters[:, 9].dtype) + width_y = h5_spot.create_dataset('width_y', shape=(total_length, ), dtype=parameters[:, 10].dtype) + + for i in range(0, total_length, process_chunk_size): + i_end = np.min((total_length, i+process_chunk_size)) + print(f' {i} - {i_end} ...') + chunk = self.spot_ds.growth_dataset(growth, index=list(range(i,i_end))) + + inputs = self.normalize_and_crop_inputs(chunk, spot) + results = self.fit_batch(inputs, num_workers) + + img_all = np.array([res[0] for res in results]) + img_rec_all = np.array([res[1] for res in results]) + parameters = np.array([res[2] for res in results]) + + raw_image[i:i_end] = img_all[:] + reconstructed_image[i:i_end] = img_rec_all[:] + img_sum[i:i_end] = parameters[:, 0] + img_max[i:i_end] = parameters[:, 1] + img_mean[i:i_end] = parameters[:, 2] + img_rec_sum[i:i_end] = parameters[:, 3] + img_rec_max[i:i_end] = parameters[:, 4] + img_rec_mean[i:i_end] = parameters[:, 5] + height[i:i_end] = parameters[:, 6] + x[i:i_end] = parameters[:, 7] + y[i:i_end] = parameters[:, 8] + width_x[i:i_end] = parameters[:, 9] + width_y[i:i_end] = parameters[:, 10] + + else: + print(f' {0} - {total_length} ...') + inputs = self.normalize_and_crop_inputs(self.spot_ds.growth_dataset(growth), spot) + results = self.fit_batch(inputs, num_workers) + + img_all = np.array([res[0] for res in results]) + img_rec_all = np.array([res[1] for res in results]) + parameters = np.array([res[2] for res in results]) + + h5_spot = h5_growth.create_group(spot) + h5_spot.create_dataset('raw_image', data=img_all) + h5_spot.create_dataset('reconstructed_image', data=img_rec_all) + h5_spot.create_dataset('img_sum', data=parameters[:, 0]) + h5_spot.create_dataset('img_max', data=parameters[:, 1]) + h5_spot.create_dataset('img_mean', data=parameters[:, 2]) + h5_spot.create_dataset('img_rec_sum', data=parameters[:, 3]) + h5_spot.create_dataset('img_rec_max', data=parameters[:, 4]) + h5_spot.create_dataset('img_rec_mean', data=parameters[:, 5]) + h5_spot.create_dataset('height', data=parameters[:, 6]) + h5_spot.create_dataset('x', data=parameters[:, 7]) + h5_spot.create_dataset('y', data=parameters[:, 8]) + h5_spot.create_dataset('width_x', data=parameters[:, 9]) + h5_spot.create_dataset('width_y', data=parameters[:, 10]) + + def normalize_and_crop_inputs(self, data, spot): + """ + Normalizes the input data for a specific spot. + + Args: + data: NumPy array. The input data. + spot: str. The spot name. + + Returns: + NumPy array: The normalized input data. + + """ + crop = self.crop_dict[spot] + if len(data.shape) == 2: + inputs = NormalizeData(data[crop['y_start']:crop['y_end'], crop['x_start']:crop['x_end']]) + elif len(data.shape) == 3: + inputs = NormalizeData( + np.array(data[:, crop['y_start']:crop['y_end'], crop['x_start']:crop['x_end']])) + return inputs + + def fit_batch(self, inputs, num_workers): + """ + Fits the Gaussian parameters for a batch of inputs. + + Args: + inputs: list. The list of input data. + num_workers: int. The number of workers for parallel processing. + + Returns: + list: The list of results containing the input, reconstructed image, and parameters. 
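+
+        Example (a minimal sketch; `inputs` is assumed to be a stack of cropped,
+        normalized spot images, e.g. from `normalize_and_crop_inputs`):
+            results = processer.fit_batch(inputs, num_workers=4)
+            parameters = np.array([res[2] for res in results])
+            # each row: img_sum, img_max, img_mean, img_rec_sum, img_rec_max,
+            # img_rec_mean, height, x, y, width_x, width_y, rotation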
+ + """ + if num_workers > 1: + tasks = [delayed(self.fit)(img) for img in inputs] + results = Parallel(n_jobs=num_workers)(tasks) + else: + results = [self.fit(img) for img in inputs] + return results + + def fit(self, img): + """ + Fits the Gaussian parameters for a single input image. + + Args: + img: NumPy array. The input image. + + Returns: + tuple: The input image, reconstructed image, and parameters. + + """ + # para: height, x, y, width_x, width_y, 0.0 + img_rec, para = self.fit_function(img) + img_sum, img_max, img_mean = np.sum(img), np.max(img), np.mean(img) + img_rec_sum, img_rec_max, img_rec_mean = np.sum(img_rec), np.max(img_rec), np.mean(img_rec) + parameters = [img_sum, img_max, img_mean, + img_rec_sum, img_rec_max, img_rec_mean, *para] + return img, img_rec, parameters + + def visualize(self, growth, spot, frame, figsize=(1.25*3, 1.25*1), **kwargs): + """ + Visualizes the RHEED image processing results for a specific growth, spot, and frame. + + Args: + growth: str. The growth name. + spot: str. The spot name. + frame: int. The frame number. + **kwargs: Additional keyword arguments for visualization. + + Returns: + tuple: The input image, reconstructed image, and parameters. + + """ + img = self.spot_ds.growth_dataset(growth, frame) + img = self.normalize_and_crop_inputs(img, spot) + img, img_rec, parameters = self.fit(img) + + sample_list = [img, img_rec, img_rec-img] + clim = (img.min(), img.max()) + + fig, axes = layout_fig(3, 3, figsize=figsize) + for i, ax in enumerate(axes): + if ax == axes[-1]: + imagemap(ax, sample_list[i], divider_=False, clim=clim, colorbars=True, **kwargs) + else: + imagemap(ax, sample_list[i], divider_=False, clim=clim, colorbars=False, **kwargs) + + # ax.imshow(sample_list[i]) + labelfigs(ax, i) + + + plt.show() + print(f'\033[1mFig.\033[0m a: RHEED spot image, b: reconstructed RHEED spot image, c: difference between original and reconstructed image for {growth} at index {frame}.') + #print first 2 digits of each parameter + print(f'The Gaussian fitted parameters are: img_sum={parameters[0]:.2f}, img_max={parameters[1]:.2f}, img_mean={parameters[2]:.2f},') + print(f'img_rec_sum={parameters[3]:.2f}, img_rec_max={parameters[4]:.2f}, img_rec_mean={parameters[5]:.2f},') + print(f'height={parameters[6]:.2f}, x={parameters[7]:.2f}, y={parameters[8]:.2f}, width_x={parameters[9]:.2f}, width_y_max={parameters[10]:.2f}.') + + return img, img_rec, parameters diff --git a/src/rheed_learn/Fitter1D.py b/src/rheed_learn/Fitter1D.py new file mode 100644 index 0000000..9047c1d --- /dev/null +++ b/src/rheed_learn/Fitter1D.py @@ -0,0 +1,1210 @@ +import torch.nn as nn +import torch +# from ...optimizers.AdaHessian import AdaHessian +""" +Created on Sun Feb 26 16:34:00 2021 +@author: Amir Gholami +@coauthor: David Samuel +""" + +import numpy as np +import torch + +class AdaHessian(torch.optim.Optimizer): + """ + Implements the AdaHessian algorithm from "ADAHESSIAN: An Adaptive Second OrderOptimizer for Machine Learning" + Arguments: + params (iterable) -- iterable of parameters to optimize or dicts defining parameter groups + lr (float, optional) -- learning rate (default: 0.1) + betas ((float, float), optional) -- coefficients used for computing running averages of gradient and the squared hessian trace (default: (0.9, 0.999)) + eps (float, optional) -- term added to the denominator to improve numerical stability (default: 1e-8) + weight_decay (float, optional) -- weight decay (L2 penalty) (default: 0.0) + hessian_power (float, optional) -- exponent of the 
hessian trace (default: 1.0) + update_each (int, optional) -- compute the hessian trace approximation only after *this* number of steps (to save time) (default: 1) + n_samples (int, optional) -- how many times to sample `z` for the approximation of the hessian trace (default: 1) + """ + + def __init__(self, params, lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0, + hessian_power=1.0, update_each=1, n_samples=1, average_conv_kernel=False): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= hessian_power <= 1.0: + raise ValueError(f"Invalid Hessian power value: {hessian_power}") + + self.n_samples = n_samples + self.update_each = update_each + self.average_conv_kernel = average_conv_kernel + + # use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training + self.generator = torch.Generator().manual_seed(2147483647) + + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, hessian_power=hessian_power) + super(AdaHessian, self).__init__(params, defaults) + + for p in self.get_params(): + p.hess = 0.0 + self.state[p]["hessian step"] = 0 + + def get_params(self): + """ + Gets all parameters in all param_groups with gradients + """ + + return (p for group in self.param_groups for p in group['params'] if p.requires_grad) + + def zero_hessian(self): + """ + Zeros out the accumalated hessian traces. + """ + + for p in self.get_params(): + if not isinstance(p.hess, float) and self.state[p]["hessian step"] % self.update_each == 0: + p.hess.zero_() + + @torch.no_grad() + def set_hessian(self): + """ + Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter. + """ + + params = [] + for p in filter(lambda p: p.grad is not None, self.get_params()): + if self.state[p]["hessian step"] % self.update_each == 0: # compute the trace only each `update_each` step + params.append(p) + self.state[p]["hessian step"] += 1 + + if len(params) == 0: + return + + if self.generator.device != params[0].device: # hackish way of casting the generator to the right device + self.generator = torch.Generator(params[0].device).manual_seed(2147483647) + + grads = [p.grad for p in params] + + for i in range(self.n_samples): + zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params] # Rademacher distribution {-1.0, 1.0} + h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < self.n_samples - 1) + for h_z, z, p in zip(h_zs, zs, params): + p.hess += h_z * z / self.n_samples # approximate the expected values of z*(H@z) + + @torch.no_grad() + def step(self, closure=None): + """ + Performs a single optimization step. 
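+
+        Note (usage sketch; `model`, `loss_func`, `batch` are placeholder names): the
+        Hutchinson estimate in `set_hessian` differentiates the gradient, so the backward
+        pass must build a graph before calling `step()`:
+            loss = loss_func(batch, model(batch))
+            loss.backward(create_graph=True)
+            optimizer.step()
+        `Model.fit` below drives AdaHessian in this way.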
+ Arguments: + closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None) + """ + + loss = None + if closure is not None: + loss = closure() + + self.zero_hessian() + self.set_hessian() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None or p.hess is None: + continue + + if self.average_conv_kernel and p.dim() == 4: + p.hess = torch.abs(p.hess).mean(dim=[2, 3], keepdim=True).expand_as(p.hess).clone() + + # Perform correct stepweight decay as in AdamW + p.mul_(1 - group['lr'] * group['weight_decay']) + + state = self.state[p] + + # State initialization + if len(state) == 1: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p.data) # Exponential moving average of gradient values + state['exp_hessian_diag_sq'] = torch.zeros_like(p.data) # Exponential moving average of Hessian diagonal square values + + exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq'] + beta1, beta2 = group['betas'] + state['step'] += 1 + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1) + exp_hessian_diag_sq.mul_(beta2).addcmul_(p.hess, p.hess, value=1 - beta2) + + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + k = group['hessian_power'] + denom = (exp_hessian_diag_sq / bias_correction2).pow_(k / 2).add_(group['eps']) + + # make update + step_size = group['lr'] / bias_correction1 + p.addcdiv_(exp_avg, denom, value=-step_size) + + return loss + + + +# from ...nn.random import random_seed + + +from m3_learning.nn.benchmarks.inference import computeTime +# from ...viz.layout import get_axis_range, set_axis, Axis_Ratio +from torch.utils.data import DataLoader +import time +import numpy as np +from sklearn.metrics import mean_squared_error +import matplotlib.pyplot as plt +from scipy.signal import resample +from m3_learning.util.file_IO import make_folder, append_to_csv +import itertools +# from m3_learning.optimizers.TrustRegion import TRCG +import torch +from torch.optim.optimizer import Optimizer, _use_grad_for_differentiable +import numpy as np + + +# extended from Zheng Shi zhs310@lehigh.edu and Majid Jahani maj316@lehigh.edu +# https://github.com/Optimization-and-Machine-Learning-Lab/TRCG +# BSD 3-Clause License + +# Copyright (c) 2023, Optimization-and-Machine-Learning-Lab + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + +class TRCG(Optimizer): + + + def __init__(self, model, + radius, device, + closure_size = 1, # specifies how many parts the +# lr=required, momentum=0, dampening=0, +# weight_decay=0, nesterov=False, *, maximize: bool = False, +# foreach: Optional[bool] = None, + cgopttol=1e-7,c0tr=0.2,c1tr=0.25,c2tr=0.75,t1tr=0.25,t2tr=2.0, + radius_max=5.0, + radius_initial=0.1, + differentiable: bool = False + ): + + + self.model = model + self.device = device + self.cgopttol = cgopttol + self.c0tr = c0tr + self.c1tr = c1tr + self.c2tr = c2tr + self.t1tr = t1tr + self.t2tr = t2tr + self.radius_max = radius_max + self.radius_initial = radius_initial + self.radius = radius + self.cgmaxiter = 60 + + + +# if lr is not required and lr < 0.0: +# raise ValueError("Invalid learning rate: {}".format(lr)) +# if momentum < 0.0: +# raise ValueError("Invalid momentum value: {}".format(momentum)) +# if weight_decay < 0.0: +# raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + + self.closure_size = closure_size + + + defaults = dict( +# lr=lr, momentum=momentum, dampening=dampening, +# weight_decay=weight_decay, nesterov=nesterov, +# maximize=maximize, foreach=foreach, + differentiable=differentiable + ) +# if nesterov and (momentum <= 0 or dampening != 0): +# raise ValueError("Nesterov momentum requires a momentum and zero dampening") + + self.params = list(model.parameters()) + + + super(TRCG, self).__init__(self.params, defaults) + + def findroot(self,x,p): + aa = 0.0; bb = 0.0; cc = 0.0 + for pi, xi in zip(p,x): + aa += (pi*pi).sum() + bb += (pi*xi).sum() + cc += (xi*xi).sum() + bb = bb*2.0 + cc = cc - self.radius**2 + alpha = (-2.0*cc)/(bb+(bb**2-(4.0*aa*cc)).sqrt()) + return alpha.data.item() + + def _init_group(self, group, params_with_grad, d_p_list, momentum_buffer_list): + + for p in group['params']: + if p.grad is not None: + params_with_grad.append(p) + d_p_list.append(p.grad) + + state = self.state[p] + if 'pk' not in state: + momentum_buffer_list.append(None) + else: + momentum_buffer_list.append(state['ph']) + + + + def CGSolver(self,loss_grad,cnt_compute, closure): + + cg_iter = 0 # iteration counter + x0 = [] # define x_0 as a list + for i in self.model.parameters(): + x0.append(torch.zeros(i.shape).to(self.device)) + + r0 = [] # set initial residual to gradient + p0 = [] # set initial conjugate direction to -r0 + self.cgopttol = 0.0 + + for i in loss_grad: + r0.append(i.data+0.0) + p0.append(0.0-i.data) + self.cgopttol+=torch.norm(i.data)**2 + + self.cgopttol = self.cgopttol.data.item()**0.5 + self.cgopttol = (min(0.5,self.cgopttol**0.5))*self.cgopttol + + cg_term = 0 + j = 0 + + while 1: + j+=1 + + # if CG does not solve model within max allowable iterations + if j > self.cgmaxiter: + j=j-1 + p1 = x0 + print ('\n\nCG has issues !!!\n\n') + break + # hessian vector product + + + + Hp = self.computeHessianVector(closure, p0) + cnt_compute+=1 + + + pHp = self.computeDotProduct(Hp, p0) # quadratic term + + # if nonpositive curvature detected, go for 
the boundary of trust region + if pHp.data.item() <= 0: + tau = self.findroot(x0,p0) + p1 = [] + for e in range(len(x0)): + p1.append(x0[e]+tau*p0[e]) + cg_term = 1 + break + + # if positive curvature + # vector product + rr0 = 0.0 + for i in r0: + rr0 += (i*i).sum() + + # update alpha + alpha = (rr0/pHp).data.item() + + x1 = [] + norm_x1 = 0.0 + for e in range(len(x0)): + x1.append(x0[e]+alpha*p0[e]) + norm_x1 += torch.norm(x0[e]+alpha*p0[e])**2 + norm_x1 = norm_x1**0.5 + + # if norm of the updated x1 > radius + if norm_x1.data.item() >= self.radius: + tau = self.findroot(x0,p0) + p1 = [] + for e in range(len(x0)): + p1.append(x0[e]+tau*p0[e]) + cg_term = 2 + break + + # update residual + r1 = [] + norm_r1 = 0.0 + for e in range(len(r0)): + r1.append(r0[e]+alpha*Hp[e]) + norm_r1 += torch.norm(r0[e]+alpha*Hp[e])**2 + norm_r1 = norm_r1**0.5 + + if norm_r1.data.item() < self.cgopttol: + p1 = x1 + cg_term = 3 + break + + rr1 = 0.0 + for i in r1: + rr1 += (i*i).sum() + + beta = (rr1/rr0).data.item() + + # update conjugate direction for next iterate + p1 = [] + for e in range(len(r1)): + p1.append(-r1[e]+beta*p0[e]) + + p0 = p1 + x0 = x1 + r0 = r1 + + + cg_iter = j + d = p1 + + return d,cg_iter,cg_term,cnt_compute + + + + def computeHessianVector(self, closure, p): + + + with torch.enable_grad(): + if self.closure_size == 1 and self.gradient_cache is not None: + # we reuse the gradient computation + + Hpp = torch.autograd.grad(self.gradient_cache, + self.params, + grad_outputs=p, + retain_graph=True) # hessian-vector in tuple + Hp = [Hpi.data+0.0 for Hpi in Hpp] + + + + + else: + + for part in range(self.closure_size): + loss = closure(part,self.closure_size, self.device) + loss_grad_v = torch.autograd.grad(loss,self.params,create_graph=True) + Hpp = torch.autograd.grad(loss_grad_v, + self.params, + grad_outputs=p, + retain_graph=False) # hessian-vector in tuple + if part == 0: + Hp = [Hpi.data+0.0 for Hpi in Hpp] + else: + for Hpi, Hppi in zip(Hp, Hpp): + Hpi.add_(Hppi) + + + return Hp + + def computeLoss(self, closure): + lossVal = 0.0 + with torch.no_grad(): + for part in range(self.closure_size): + loss = closure(part,self.closure_size, self.device) + lossVal+= loss.item() + + + return lossVal + + + def computeGradientAndLoss(self, closure): + lossVal = 0.0 + with torch.enable_grad(): + for part in range(self.closure_size): + loss = closure(part,self.closure_size, self.device) + lossVal+= loss.item() + if self.closure_size == 1 and self.gradient_cache is None: + + loss_grad = torch.autograd.grad(loss,self.params,retain_graph=True,create_graph=True) + self.gradient_cache = loss_grad + else: + + loss_grad = torch.autograd.grad(loss,self.params,create_graph=False) + + if part == 0: + grad = [p.data+0.0 for p in loss_grad] + else: + for gi, gip in zip(grad, loss_grad): + gi.add_(gip) + + + return lossVal, grad + + def computeGradient(self, closure): + return self.computeGradientAndLoss(closure)[1] + + + return grad + + def computeDotProduct(self,v,z): + return torch.sum(torch.vstack([ (vi*zi).sum() for vi, zi in zip(v, z) ])) + + def computeNorm(self,v): + return torch.sqrt(torch.sum(torch.vstack([ (p**2).sum() for p in v]))) + + @_use_grad_for_differentiable + def step(self, closure): + """Performs a single optimization step. + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
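+
+        Example (a minimal sketch; the `(part, total, device)` signature mirrors how
+        `computeGradientAndLoss` invokes the closure; `model`, `loss_func`, `batch` are
+        placeholders):
+            def closure(part, total, device):
+                pred, _ = model(batch)
+                return loss_func(batch, pred)
+            loss_val, radius, cnt_compute, cg_iter = trcg_optimizer.step(closure)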
+ """ + + self.gradient_cache = None + + # store the initial weights + wInit = [w+0.0 for w in self.params] + + update = 2 + + lossInit, loss_grad = self.computeGradientAndLoss(closure) + NormG = self.computeNorm(loss_grad) + + cnt_compute=1 + + + + # Conjugate Gradient Method + d, cg_iter, cg_term, cnt_compute = self.CGSolver(loss_grad,cnt_compute, closure) + + + Hd = self.computeHessianVector(closure, d) + dHd = self.computeDotProduct(Hd, d) + + + # update solution + for wi, di in zip(self.params, d): + with torch.no_grad(): + wi.add_(di) + + loss_new = self.computeLoss(closure) + numerator = lossInit - loss_new + + gd = self.computeDotProduct(loss_grad, d) + + norm_d = self.computeNorm(d) + + denominator = -gd.data.item() - 0.5*(dHd.data.item()) + + # ratio + rho = numerator/denominator + + + outFval = loss_new + if rho < self.c1tr: # shrink radius + self.radius = self.t1tr*self.radius + update = 0 + elif rho > self.c2tr and np.abs(norm_d.data.item() - self.radius) < 1e-10: # enlarge radius + self.radius = min(self.t2tr*self.radius,self.radius_max) + update = 1 + # otherwise, radius remains the same + if rho <= self.c0tr or numerator < 0: # reject d + update = 3 + self.radius = self.t1tr*self.radius + for wi, di in zip(self.params, d): + with torch.no_grad(): + wi.sub_(di) + outFval = lossInit + return outFval, self.radius, cnt_compute, cg_iter + +from m3_learning.util.rand_util import save_list_to_txt +import pandas as pd + +def static_state_decorator(func): + """Decorator that stops the function from changing the state + + Args: + func (method): any method + """ + def wrapper(*args, **kwargs): + current_state = args[0].get_state + out = func(*args, **kwargs) + args[0].set_attributes(**current_state) + return out + return wrapper + + +def write_csv(write_CSV, + path, + model_name, + optimizer_name, + epochs, + total_time, + train_loss, + batch_size, + loss_func, + seed, + stoppage_early, + model_updates): + + if write_CSV is not None: + headers = ["Model Name", + "Optimizer", + "Epochs", + "Training_Time", + "Train Loss", + "Batch Size", + "Loss Function", + "Seed", + "filename", + "early_stoppage", + "model updates"] + data = [model_name, + optimizer_name, + epochs, + total_time, + train_loss, + batch_size, + loss_func, + seed, + f"{path}/{model_name}_model_epoch_{epochs}_train_loss_{train_loss}.pth", + f"{stoppage_early}", + f"{model_updates}"] + append_to_csv(f"{path}/{write_CSV}", data, headers) + + +class Multiscale1DFitter(nn.Module): + + def __init__(self, + function, # function to fit + x_data, # x_data to generate + input_channels, # number of input channels + num_params, # number of parameters to fit + scaler = None, # scaler object + post_processing = None, + device = "cuda", + **kwargs): + + self.input_channels = input_channels + self.scaler = scaler + self.function = function + self.x_data = x_data + self.post_processing = post_processing + self.device = device + self.num_params = num_params + + super().__init__() + + # Input block of 1d convolution + self.hidden_x1 = nn.Sequential( + nn.Conv1d(in_channels=self.input_channels, out_channels=8, kernel_size=7), + nn.SELU(), + nn.Conv1d(in_channels=8, out_channels=6, kernel_size=7), + nn.SELU(), + nn.Conv1d(in_channels=6, out_channels=4, kernel_size=5), + nn.SELU(), + nn.AdaptiveAvgPool1d(64) + ) + + # fully connected block + self.hidden_xfc = nn.Sequential( + nn.Linear(256, 20), + nn.SELU(), + nn.Linear(20, 20), + nn.SELU(), + ) + + # 2nd block of 1d-conv layers + self.hidden_x2 = nn.Sequential( + nn.MaxPool1d(kernel_size=2), 
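+            # note: forward() reshapes the first block's output to (batch, 2, 128) before
+            # feeding it to this block, which is why the first Conv1d here takes in_channels=2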
+ nn.Conv1d(in_channels=2, out_channels=4, kernel_size=5), + nn.SELU(), + nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5), + nn.SELU(), + nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5), + nn.SELU(), + nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5), + nn.SELU(), + nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5), + nn.SELU(), + nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5), + nn.SELU(), + nn.AdaptiveAvgPool1d(16), # Adaptive pooling layer + nn.Conv1d(in_channels=4, out_channels=2, kernel_size=3), + nn.SELU(), + nn.AdaptiveAvgPool1d(8), # Adaptive pooling layer + nn.Conv1d(in_channels=2, out_channels=2, kernel_size=3), + nn.SELU(), + nn.AdaptiveAvgPool1d(4), # Adaptive pooling layer + ) + + # Flatten layer + self.flatten_layer = nn.Flatten() + + # Final embedding block - Output 4 values - linear + self.hidden_embedding = nn.Sequential( + nn.Linear(28, 16), + nn.SELU(), + nn.Linear(16, 8), + nn.SELU(), + nn.Linear(8, self.num_params), + ) + + def forward(self, x, n=-1): + # print(x.dtype) + # print('1', self.x_data.shape, x.shape) + # output shape - samples, (real, imag), frequency + # x = torch.swapaxes(x, 1, 2) + x = self.hidden_x1(x) + # print(x.dtype) + + xfc = torch.reshape(x, (n, 256)) # batch size, features + xfc = self.hidden_xfc(xfc) + + # batch size, (real, imag), timesteps + x = torch.reshape(x, (n, 2, 128)) + x = self.hidden_x2(x) + cnn_flat = self.flatten_layer(x) + + encoded = torch.cat((cnn_flat, xfc), 1) # merge dense and 1d conv. + embedding = self.hidden_embedding(encoded) # output is 4 parameters + + if self.scaler is not None: + # corrects the scaling of the parameters + unscaled_param = ( + embedding * + torch.tensor(self.scaler.var_ ** 0.5).to(self.device) + + torch.tensor(self.scaler.mean_).to(self.device) + ) + else: + unscaled_param = embedding + # print(unscaled_param.shape) + # unscaled_param[:,0] = torch.relu(unscaled_param[:,0]) + unscaled_param[:,0] = torch.tanh(unscaled_param[:,0]) + # unscaled_param[:,1] = torch.tanh(unscaled_param[:,2]) + unscaled_param[:,1] = torch.relu(unscaled_param[:,1])+1e-3 + + # frequency_bins = resample(self.dataset.frequency_bin, + # self.dataset.resampled_bins) + + # print(unscaled_param.shape, self.x_data.shape) + # passes to the pytorch fitting function + fits = self.function( + unscaled_param, self.x_data, device=self.device) + + # Does the post processing if required + if self.post_processing is not None: + out = self.post_processing.compute(fits) + else: + out = fits + return out, unscaled_param + + # if self.training == True: + # return out, unscaled_param + # if self.training == False: + # # this is a scaling that includes the corrections for shifts in the data + # embeddings = (unscaled_param.to(self.device) - torch.tensor(self.scaler.mean_).to(self.device) + # )/torch.tensor(self.scaler.var_ ** 0.5).to(self.device) + # return out, embeddings, unscaled_param + + +class Model(nn.Module): + + def __init__(self, + model, + dataset, + model_basename='', + training=True, + path='Trained Models/SHO Fitter/', + device=None, + **kwargs): + + super().__init__() + + if device is None: + if torch.cuda.is_available(): + self.device = "cuda" + print(f"Using GPU {torch.cuda.get_device_name(0)}") + else: + self.device = "cpu" + print("Using CPU") + + self.model = model + self.model.dataset = dataset + self.model.training = True + self.model_name = model_basename + self.path = make_folder(path) + + def fit(self, + data_train, + batch_size=200, + epochs=5, + loss_func=torch.nn.MSELoss(), + 
optimizer='Adam', + seed=42, + datatype=torch.float32, + save_all=False, + write_CSV=None, + closure=None, + basepath=None, + early_stopping_loss=None, + early_stopping_count=None, + early_stopping_time=None, + save_training_loss=True, + **kwargs): + + loss_ = [] + + if basepath is not None: + path = f"{self.path}/{basepath}/" + make_folder(path) + print(f"Saving to {path}") + else: + path = self.path + + # sets the model to be a specific datatype and on cuda + self.to(datatype).to(self.device) + + # Note that the seed will behave differently on different hardware targets (GPUs) + # random_seed(seed=seed) + + torch.cuda.empty_cache() + + # selects the optimizer + if optimizer == 'Adam': + optimizer_ = torch.optim.Adam(self.model.parameters()) + elif optimizer == "AdaHessian": + optimizer_ = AdaHessian(self.model.parameters(), lr=.5) + elif isinstance(optimizer, dict): + if optimizer['name'] == "TRCG": + optimizer_ = optimizer['optimizer']( + self.model, optimizer['radius'], optimizer['device']) + elif isinstance(optimizer, dict): + if optimizer['name'] == "TRCG": + optimizer_ = optimizer['optimizer']( + self.model, optimizer['radius'], optimizer['device']) + else: + try: + optimizer = optimizer(self.model.parameters()) + except: + raise ValueError("Optimizer not recognized") + + # instantiate the dataloader + train_dataloader = DataLoader( + data_train, batch_size=batch_size, shuffle=True) + + # if trust region optimizers stores the TR optimizer as an object and instantiates the ADAM optimizer + if isinstance(optimizer_, TRCG): + TRCG_OP = optimizer_ + optimizer_ = torch.optim.Adam(self.model.parameters(), **kwargs) + + total_time = 0 + low_loss_count = 0 + + # says if the model have already stopped early + already_stopped = False + + model_updates = 0 + + # loops around each epoch + for epoch in range(epochs): + + train_loss = 0.0 + total_num = 0 + epoch_time = 0 + + # sets the model to training mode + self.model.train() + + for train_batch in train_dataloader: + + model_updates += 1 + + # starts the timer + start_time = time.time() + + train_batch = train_batch.to(datatype).to(self.device) + + if "TRCG_OP" in locals() and epoch > optimizer.get("ADAM_epochs", -1): + + def closure(part, total, device): + pred, embedding = self.model(train_batch) + pred = pred.to(torch.float32) + embedding = embedding.to(torch.float32) + loss = loss_func(train_batch, pred) + return loss + + # if closure is not None: + loss, radius, cnt_compute, cg_iter = TRCG_OP.step( + closure) + train_loss += loss * train_batch.shape[0] + total_num += train_batch.shape[0] + optimizer_name = "Trust Region CG" + else: + pred, embedding = self.model(train_batch) + pred = pred.to(torch.float32) + pred = torch.atleast_3d(pred) + embedding = embedding.to(torch.float32) + optimizer_.zero_grad() + loss = loss_func(train_batch, pred) + loss.backward(create_graph=True) + train_loss += loss.item() * pred.shape[0] + total_num += pred.shape[0] + optimizer_.step() + if isinstance(optimizer_, torch.optim.Adam): + optimizer_name = "Adam" + elif isinstance(optimizer_, AdaHessian): + optimizer_name = "AdaHessian" + + epoch_time += (time.time() - start_time) + + total_time += (time.time() - start_time) + + try: + loss_.append(loss.item()) + except: + loss_.append(loss) + + if early_stopping_loss is not None and already_stopped == False: + if loss < early_stopping_loss: + low_loss_count += train_batch.shape[0] + if low_loss_count >= early_stopping_count: + torch.save(self.model.state_dict(), + 
f"{path}/Early_Stoppage_at_{total_time}_{self.model_name}_model_optimizer_{optimizer_name}_epoch_{epoch}_train_loss_{train_loss/total_num}.pth") + + write_csv(write_CSV, + path, + self.model_name, + optimizer_name, + epoch, + total_time, + train_loss/total_num, + batch_size, + loss_func, + seed, + True, + model_updates) + + already_stopped = True + else: + low_loss_count -= (train_batch.shape[0]*5) + + if "verbose" in kwargs: + if kwargs["verbose"] == True: + print(f"Loss = {loss.item()}") + + train_loss /= total_num + + print(optimizer_name) + print("epoch : {}/{}, recon loss = {:.8f}".format(epoch + + 1, epochs, train_loss)) + print("--- %s seconds ---" % (epoch_time)) + + if save_all: + torch.save(self.model.state_dict(), + f"{path}/{self.model_name}_model_optimizer_{optimizer_name}_epoch_{epoch}_train_loss_{train_loss}.pth") + + if early_stopping_time is not None: + if total_time > early_stopping_time: + torch.save(self.model.state_dict(), + f"{path}/Early_Stoppage_at_{total_time}_{self.model_name}_model_optimizer_{optimizer_name}_epoch_{epoch}_train_loss_{train_loss}.pth") + + write_csv(write_CSV, + path, + self.model_name, + optimizer_name, + epoch, + total_time, + train_loss, # already divided by total_num + batch_size, + loss_func, + seed, + True, + model_updates) + break + + torch.save(self.model.state_dict(), + f"{path}/{self.model_name}_model_optimizer_{optimizer_name}_epoch_{epoch}_train_loss_{train_loss}.pth") + write_csv(write_CSV, + path, + self.model_name, + optimizer_name, + epoch, + total_time, + train_loss, # already divided by total_num + batch_size, + loss_func, + seed, + False, + model_updates) + + if save_training_loss: + save_list_to_txt( + loss_, f"{path}/Training_loss_{self.model_name}_model_optimizer_{optimizer_name}_epoch_{epoch}_train_loss_{train_loss}.txt") + + self.model.eval() + + def load(self, model_path): + self.model.load_state_dict(torch.load(model_path)) + self.model.to(self.device) + + def inference_timer(self, data, batch_size=.5e4): + torch.cuda.empty_cache() + + batch_size = int(batch_size) + + dataloader = DataLoader(data, batch_size) + + # Computes the inference time + computeTime(self.model, dataloader, batch_size, device=self.device) + + def predict(self, data, batch_size=10000, + single=False, + translate_params=True): + + self.model.eval() + + dataloader = DataLoader(data, batch_size=batch_size) + + # preallocate the predictions + num_elements = len(dataloader.dataset) + num_batches = len(dataloader) + data = data.clone().detach().requires_grad_(True) + predictions = torch.zeros_like(data.clone().detach()) + params_scaled = torch.zeros((data.shape[0], 4)) + params = torch.zeros((data.shape[0], 4)) + + # compute the predictions + for i, train_batch in enumerate(dataloader): + start = i * batch_size + end = start + batch_size + + if i == num_batches - 1: + end = num_elements + + pred_batch, params_scaled_, params_ = self.model( + train_batch.to(self.device)) + + predictions[start:end] = pred_batch.cpu().detach() + params_scaled[start:end] = params_scaled_.cpu().detach() + params[start:end] = params_.cpu().detach() + + torch.cuda.empty_cache() + + # converts negative ampltiudes to positive and shifts the phase to compensate + if translate_params: + params[params[:, 0] < 0, 3] = params[params[:, 0] < 0, 3] - np.pi + params[params[:, 0] < 0, 0] = np.abs(params[params[:, 0] < 0, 0]) + + if self.model.dataset.NN_phase_shift is not None: + params_scaled[:, 3] = torch.Tensor(self.model.dataset.shift_phase( + params_scaled[:, 3].detach().numpy(), 
self.model.dataset.NN_phase_shift)) + params[:, 3] = torch.Tensor(self.model.dataset.shift_phase( + params[:, 3].detach().numpy(), self.model.dataset.NN_phase_shift)) + + return predictions, params_scaled, params + + @staticmethod + def mse_rankings(true, prediction, curves=False): + + def type_conversion(data): + + data = np.array(data) + data = np.rollaxis(data, 0, data.ndim-1) + + return data + + true = type_conversion(true) + prediction = type_conversion(prediction) + + errors = Model.MSE(prediction, true) + + index = np.argsort(errors) + + if curves: + # true will be in the form [ranked error, channel, timestep] + return index, errors[index], true[index], prediction[index] + + return index, errors[index] + + @staticmethod + def MSE(true, prediction): + + # calculates the mse + mse = np.mean((true.reshape( + true.shape[0], -1) - prediction.reshape(true.shape[0], -1))**2, axis=1) + + # converts to a scalar if there is only one value + if mse.shape[0] == 1: + return mse.item() + + return mse + + @staticmethod + def get_rankings(raw_data, pred, n=1, curves=True): + """simple function to get the best, median and worst reconstructions + + Args: + raw_data (np.array): array of the true values + pred (np.array): array of the predictions + n (int, optional): number of values for each. Defaults to 1. + curves (bool, optional): whether to return the curves or not. Defaults to True. + + Returns: + ind: indices of the best, median and worst reconstructions + mse: mse of the best, median and worst reconstructions + """ + index, mse, d1, d2 = Model.mse_rankings( + raw_data, pred, curves=curves) + middle_index = len(index) // 2 + start_index = middle_index - n // 2 + end_index = start_index + n + + ind = np.hstack( + (index[:n], index[start_index:end_index], index[-n:])).flatten().astype(int) + mse = np.hstack( + (mse[:n], mse[start_index:end_index], mse[-n:])) + + d1 = np.stack( + (d1[:n], d1[start_index:end_index], d1[-n:])).squeeze() + d2 = np.stack( + (d2[:n], d2[start_index:end_index], d2[-n:])).squeeze() + + # return ind, mse, np.swapaxes(d1[ind], 1, d1.ndim-1), np.swapaxes(d2[ind], 1, d2.ndim-1) + return ind, mse, d1, d2 + + def print_mse(self, data, labels): + """prints the MSE of the model + + Args: + data (tuple): tuple of datasets to calculate the MSE + labels (list): List of strings with the names of the datasets + """ + + # loops around the dataset and labels and prints the MSE for each + for data, label in zip(data, labels): + + if isinstance(data, torch.Tensor): + # computes the predictions + pred_data, scaled_param, parm = self.predict(data) + elif isinstance(data, dict): + pred_data, _ = self.model.dataset.get_raw_data_from_LSQF_SHO( + data) + data, _ = self.model.dataset.NN_data() + pred_data = torch.from_numpy(pred_data) + + # Computes the MSE + out = nn.MSELoss()(data, pred_data) + + # prints the MSE + print(f"{label} Mean Squared Error: {out:0.4f}") + + +@static_state_decorator +def batch_training(dataset, optimizers, noise, batch_size, epochs, seed, write_CSV="Batch_Training_Noisy_Data.csv", + basepath=None, early_stopping_loss=None, early_stopping_count=None, early_stopping_time=None, skip=-1, **kwargs, + ): + + # Generate all combinations + combinations = list(itertools.product( + optimizers, noise, batch_size, epochs, seed)) + + for i, training in enumerate(combinations): + if i < skip: + print( + f"Skipping combination {i}: {training[0]} {training[1]} {training[2]} {training[3]} {training[4]}") + continue + + optimizer = training[0] + noise = training[1] + batch_size = 
training[2] + epochs = training[3] + seed = training[4] + + print(f"The type is {type(training[0])}") + + if isinstance(optimizer, dict): + optimizer_name = optimizer['name'] + else: + optimizer_name = optimizer + + dataset.noise = noise + + # random_seed(seed=seed) + + # constructs a test train split + X_train, X_test, y_train, y_test = dataset.test_train_split_( + shuffle=True) + + model_name = f"SHO_{optimizer_name}_noise_{training[1]}_batch_size_{training[2]}_seed_{training[4]}" + + print(f'Working on combination: {model_name}') + + # instantiate the model + model = Model(dataset, training=True, model_basename=model_name) + + # fits the model + model.fit( + X_train, + batch_size=batch_size, + optimizer=optimizer, + + epochs=epochs, + write_CSV=write_CSV, + seed=seed, + basepath=basepath, + early_stopping_loss=early_stopping_loss, + early_stopping_count=early_stopping_count, + early_stopping_time=early_stopping_time, + **kwargs, + ) + + del model + +def find_best_model(basepath, filename): + + # Read the CSV + df = pd.read_csv(basepath + '/' + filename) + + # Extract noise level from the 'Model Name' column + df['Noise Level'] = df['Model Name'].apply(lambda x: float(x.split('_')[3])) + + # Create an empty dictionary to store the results + results = {} + + # Loop over each unique combination of noise level and optimizer + for noise_level in df['Noise Level'].unique(): + for optimizer in df['Optimizer'].unique(): + # Create a mask for the current combination + mask = (df['Noise Level'] == noise_level) & (df['Optimizer'] == optimizer) + + # If there's any row with this combination + if df[mask].shape[0] > 0: + # Find the index of the minimum 'Train Loss' + min_loss_index = df.loc[mask, 'Train Loss'].idxmin() + + # Store the result + results[(noise_level, optimizer)] = df.loc[min_loss_index].to_dict() + + return results \ No newline at end of file diff --git a/src/rheed_learn/Fitting.py b/src/rheed_learn/Fitting.py new file mode 100644 index 0000000..15f7040 --- /dev/null +++ b/src/rheed_learn/Fitting.py @@ -0,0 +1,158 @@ +import numpy as np +from scipy.optimize import curve_fit + +def normalize_0_1(y, I_start, I_end, I_diff): + """Normalize y-values to the range [0, 1].""" + if I_diff is None: + I_diff = I_end - I_start + y_nor = (y - I_start) / I_diff + return y_nor + +def de_normalize_0_1(y_nor, I_start, I_end, I_diff): + """De-normalize y-values from the range [0, 1] back to the original scale.""" + if I_diff is None: + I_diff = I_end - I_start + y = y_nor * I_diff + I_start + return y + +def exp_func_inc_simp(x, b1, relax1): + """Simplified exponential growth function.""" + return b1 * (1 - np.exp(-x / relax1)) + +def exp_func_dec_simp(x, b2, relax2): + """Simplified exponential decay function.""" + return b2 * np.exp(-x / relax2) + +def exp_func_inc(x, a1, b1, relax1): + """Full exponential growth function.""" + return (a1 * x + b1) * (1 - np.exp(-x / relax1)) + +def exp_func_dec(x, a2, b2, relax2): + """Full exponential decay function.""" + return (a2 * x + b2) * np.exp(-x / relax2) + +def normalize_and_extract_amplitude(xs, ys, I_diff, unify): + ys_nor = [] + I_starts = [] + I_ends = [] + + for i in range(len(xs)): + n_avg = len(ys[i]) // 100 + 3 + I_end = np.mean(ys[i][-n_avg:]) + I_start = np.mean(ys[i][:n_avg]) + I_starts.append(I_start) + I_ends.append(I_end) + y_nor = normalize_0_1(ys[i], I_start, I_end, I_diff) + ys_nor.append(y_nor) + + return ys_nor, I_starts, I_ends + +def fit_curves(xs, ys_nor, fit_settings, growth_name): + parameters = [] + ys_nor_fit, ys_nor_fit_failed, 
labels, losses = [], [], [], [] + + bounds = fit_settings['bounds'] + p_init = fit_settings['p_init'] + unify = fit_settings['unify'] + + for i in range(len(xs)): + x = np.linspace(1e-5, 1, len(ys_nor[i])) # Use second as x axis unit + y_nor = ys_nor[i] + + if unify: + params, _ = curve_fit(exp_func_inc, x, y_nor, p0=p_init, bounds=bounds, absolute_sigma=False) + a, b, relax = params + y_nor_fit = exp_func_inc(x, a, b, relax) + labels.append(f'{growth_name}-index {i + 1}:\ny=({np.round(a, 2)}t+{np.round(b, 2)})*(1-exp(-t/{np.round(relax, 2)}))') + parameters.append((a, b, relax)) + losses.append((0, 0)) + ys_nor_fit_failed.append(y_nor_fit) + else: + params1, _ = curve_fit(exp_func_inc_simp, x, y_nor, p0=p_init[1:], bounds=bounds, absolute_sigma=False) + b1, relax1 = params1 + y1_nor_fit = exp_func_inc_simp(x, b1, relax1) + + params2, _ = curve_fit(exp_func_dec_simp, x, y_nor, p0=p_init[1:], bounds=bounds, absolute_sigma=False) + b2, relax2 = params2 + y2_nor_fit = exp_func_dec_simp(x, b2, relax2) + + loss1 = ((y_nor - y1_nor_fit) ** 2).mean() + loss2 = ((y_nor - y2_nor_fit) ** 2).mean() + + params1_full, _ = curve_fit(exp_func_inc, x, y_nor, p0=p_init, bounds=bounds, absolute_sigma=False) + a1, b1_full, relax1_full = params1_full + y1_nor_fit_full = exp_func_inc(x, a1, b1_full, relax1_full) + + params2_full, _ = curve_fit(exp_func_dec, x, y_nor, p0=p_init, bounds=bounds, absolute_sigma=False) + a2, b2_full, relax2_full = params2_full + y2_nor_fit_full = exp_func_dec(x, a2, b2_full, relax2_full) + + if loss1 < loss2: + y_nor_fit = y1_nor_fit_full + labels.append(f'{growth_name}-index {i + 1}:\ny1=({np.round(a1, 2)}t+{np.round(b1_full, 2)})*(1-exp(-t/{np.round(relax1_full, 2)}))') + parameters.append((a1, b1_full, relax1_full)) + ys_nor_fit_failed.append(y2_nor_fit_full) + else: + y_nor_fit = y2_nor_fit_full + labels.append(f'{growth_name}-index {i + 1}:\ny2=({np.round(a2, 2)}t+{np.round(b2_full, 2)})*(exp(-t/{np.round(relax2_full, 2)}))') + parameters.append((a2, b2_full, relax2_full)) + ys_nor_fit_failed.append(y1_nor_fit_full) + + losses.append((loss1, loss2)) + + ys_nor_fit.append(y_nor_fit) + + return parameters, ys_nor_fit, ys_nor_fit_failed, labels, losses + +def de_normalize_and_assemble(xs, ys, ys_nor_fit, I_starts, I_ends, I_diff, unify): + ys_fit = [] + + for i in range(len(xs)): + y_fit = de_normalize_0_1(ys_nor_fit[i], I_starts[i], I_ends[i], I_diff) + ys_fit.append(y_fit) + + return ys_fit + +def fit_exp_function(xs, ys, growth_name, fit_settings={'I_diff': None, 'unify': True, 'bounds': [0.01, 1], 'p_init': (1, 0.1, 0.4), 'n_std': 1}): + """ + Fits an exponential function to the given data. + + Args: + xs (list of list of floats): List of x-values for each partial curve. Each element is a list of x-values for one curve. + ys (list of list of floats): List of y-values for each partial curve. Each element is a list of y-values for one curve. + growth_name (str): Name of the growth process, used for labeling the fitted curves. + fit_settings (dict, optional): Dictionary of settings for the fitting process. Defaults to: + { + 'I_diff': None, # Optional intensity difference for normalization. + 'unify': True, # Whether to use a unified fitting function for all curves. + 'bounds': [0.01, 1], # Bounds for the fitting parameters. + 'p_init': (1, 0.1, 0.4), # Initial guess for the fitting parameters. + 'n_std': 1 # Number of standard deviations for normalization. + } + + Returns: + tuple: + - numpy.ndarray: Array of fitted parameters for each curve. 
+ - list: A list containing: + - xs: Original x-values. + - ys: Original y-values. + - ys_fit: Fitted y-values. + - ys_nor: Normalized y-values. + - ys_nor_fit: Fitted normalized y-values. + - ys_nor_fit_failed: Failed fitted normalized y-values (when applicable). + - labels: Labels for each fitted curve. + - losses: Losses for each fitted curve. + """ + I_diff = fit_settings['I_diff'] + unify = fit_settings['unify'] + + # Step 1: Normalize and extract amplitude + ys_nor, I_starts, I_ends = normalize_and_extract_amplitude(xs, ys, I_diff, unify) + + # Step 2: Fit curves + parameters, ys_nor_fit, ys_nor_fit_failed, labels, losses = fit_curves(xs, ys_nor, fit_settings, growth_name) + + # Step 3: De-normalize and assemble results + ys_fit = de_normalize_and_assemble(xs, ys, ys_nor_fit, I_starts, I_ends, I_diff, unify) + + return np.array(parameters), [xs, ys, ys_fit, ys_nor, ys_nor_fit, ys_nor_fit_failed, labels, losses] diff --git a/src/rheed_learn/Packed_functions.py b/src/rheed_learn/Packed_functions.py new file mode 100644 index 0000000..4a795f0 --- /dev/null +++ b/src/rheed_learn/Packed_functions.py @@ -0,0 +1,224 @@ +import sys +import numpy as np +import matplotlib.pyplot as plt +from matplotlib import colors +import seaborn as sns + +sys.path.append('../../src') +from m3_learning.viz.layout import layout_fig, labelfigs +from m3_learning.RHEED.Viz import Viz +from m3_learning.RHEED.Analysis import analyze_curves, remove_outlier, smooth + +def decay_curve_examples(df_para, spot, metric, fit_settings, savefig=False, savepath=None): + """ + Plot decay curve examples. + + Args: + df_para (DataFrame): Dataframe containing parameters. + spot (str): Spot identifier. + metric (str): Metric to analyze. + fit_settings (dict): Settings for curve fitting. + + Returns: + None + """ + color_blue = (44/255,123/255,182/255) + seq_colors = ['#00429d','#2e59a8','#4771b2','#5d8abd','#73a2c6','#8abccf','#a5d5d8','#c5eddf','#ffffe0'] + bgc1, bgc2 = (*colors.hex2color(seq_colors[0]), 0.3), (*colors.hex2color(seq_colors[5]), 0.3) + + parameters_all, x_coor_all, info = analyze_curves(df_para, {'growth_1': 1}, spot, metric, interval=0, fit_settings=fit_settings) + [xs_all, ys_all, ys_fit_all, ys_nor_all, ys_nor_fit_all, ys_nor_fit_failed_all, labels_all, losses_all] = info + sample_list = [6, 21] + loc_list = ['ct', 'cb'] + fig, axes = layout_fig(2, 2, figsize=(5, 3)) + for i, ax in enumerate(axes): + Viz.draw_background_colors(ax, ([bgc1, bgc2][i])) + ax.scatter(xs_all[sample_list[i]], ys_nor_all[sample_list[i]], color=np.array(color_blue).reshape(1,-1), s=2) + ax.scatter(xs_all[sample_list[i]], ys_nor_fit_all[sample_list[i]], color='k', s=2) + ax.set_box_aspect(1) + Viz.set_labels(ax, xlabel='Time (s)', ylabel='Intensity (a.u.)', yaxis_style='linear') + labelfigs(ax, None, string_add=labels_all[sample_list[i]], loc=loc_list[i], style='b', size=6) + if savefig: + plt.savefig(f'{savepath}-decay_curve_examples.png', dpi=300) + plt.savefig(f'{savepath}-decay_curve_examples.svg', dpi=300) + + + +def compare_loss_difference(): + """ + Compare the difference in losses between two samples. 
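+
+    For each of the three saved samples (treated_213nm, treated_81nm, untreated_162nm),
+    the absolute difference between the two candidate exponential-fit losses,
+    |loss_1 - loss_2|, is loaded from 'Saved_data/' and plotted with Viz.plot_loss_difference.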
+ + Returns: + None + """ + + # load data if needed + x_all_sample1, y_all_sample1 = np.load('Saved_data/treated_213nm-x_all.npy'), np.load('Saved_data/treated_213nm-y_all.npy') + color_array_sample1 = np.load('Saved_data/treated_213nm-bg_growth.npy') + losses_all_sample1 = np.load('Saved_data/treated_213nm-losses_all.npy') + loss_diff_sample1 = np.abs(losses_all_sample1[:,0] - losses_all_sample1[:,1]) + + x_all_sample2, y_all_sample2 = np.load('Saved_data/treated_81nm-x_all.npy'), np.load('Saved_data/treated_81nm-y_all.npy') + color_array_sample2 = np.load('Saved_data/treated_81nm-bg_growth.npy') + losses_all_sample2 = np.load('Saved_data/treated_81nm-losses_all.npy') + loss_diff_sample2 = np.abs(losses_all_sample2[:,0] - losses_all_sample2[:,1]) + + x_all_sample3, y_all_sample3 = np.load('Saved_data/untreated_162nm-x_all.npy'), np.load('Saved_data/untreated_162nm-y_all.npy') + color_array_sample3 = np.load('Saved_data/untreated_162nm-bg_growth.npy') + losses_all_sample3 = np.load('Saved_data/untreated_162nm-losses_all.npy') + loss_diff_sample3 = np.abs(losses_all_sample3[:,0] - losses_all_sample3[:,1]) + + + seq_colors = ['#00429d','#2e59a8','#4771b2','#5d8abd','#73a2c6','#8abccf','#a5d5d8','#c5eddf','#ffffe0'] + fig, axes = layout_fig(3, 1, figsize=(5, 1.5*3)) + Viz.plot_loss_difference(axes[0], x_all_sample1, y_all_sample1, color_array_sample1[:,0], loss_diff_sample1, + color_array_sample1, color_2=seq_colors[0], title='treated_213nm') + Viz.plot_loss_difference(axes[1], x_all_sample2, y_all_sample2, color_array_sample2[:,0], loss_diff_sample2, + color_array_sample2, color_2=seq_colors[0], title='treated_81nm') + Viz.plot_loss_difference(axes[2], x_all_sample3, y_all_sample3, color_array_sample3[:,0], loss_diff_sample3, + color_array_sample3, color_2=seq_colors[0], title='untreated_162nm') + plt.show() + + +def compare_growth_mechanism(): + """ + Compare the growth mechanism of different samples. 
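+
+    The '*-boxes.npy' arrays loaded below are drawn with Viz.draw_boxes, which expects an
+    iterable of (x_start, x_end) spans along the time axis. A minimal illustration (the
+    numbers are made up):
+
+        boxes = np.array([[5.0, 12.0], [30.0, 41.5]])  # two highlighted time windows, in seconds
+        fig, ax = plt.subplots()
+        Viz.draw_boxes(ax, boxes, box_color=(0.5, 0.5, 0.5, 0.5))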
+ + Returns: + None + """ + + # load data if needed + x_all_sample1, y_all_sample1 = np.load('Saved_data/treated_213nm-x_all.npy'), np.load('Saved_data/treated_213nm-y_all.npy') + color_array_sample1 = np.load('Saved_data/treated_213nm-bg_growth.npy') + boxes_sample1 = np.load('Saved_data/treated_213nm-boxes.npy') + + x_all_sample2, y_all_sample2 = np.load('Saved_data/treated_81nm-x_all.npy'), np.load('Saved_data/treated_81nm-y_all.npy') + color_array_sample2 = np.load('Saved_data/treated_81nm-bg_growth.npy') + boxes_sample2 = np.load('Saved_data/treated_81nm-boxes.npy') + + x_all_sample3, y_all_sample3 = np.load('Saved_data/untreated_162nm-x_all.npy'), np.load('Saved_data/untreated_162nm-y_all.npy') + color_array_sample3 = np.load('Saved_data/untreated_162nm-bg_growth.npy') + boxes_sample3 = np.load('Saved_data/untreated_162nm-boxes.npy') + + color_gray = (128/255, 128/255, 128/255, 0.5) + + fig, axes = layout_fig(3, 1, figsize=(6, 2*3)) + Viz.draw_background_colors(axes[0], color_array_sample1) + Viz.draw_boxes(axes[0], boxes_sample1, color_gray) + axes[0].scatter(x_all_sample1, y_all_sample1, color='k', s=1) + Viz.set_labels(axes[0], xlabel='Time (s)', ylabel='Intensity (a.u.)', title='treated_213nm') + + Viz.draw_background_colors(axes[1], color_array_sample2) + Viz.draw_boxes(axes[1], boxes_sample2, color_gray) + axes[1].scatter(x_all_sample2, y_all_sample2, color='k', s=1) + Viz.set_labels(axes[1], xlabel='Time (s)', ylabel='Intensity (a.u.)', title='treated_81nm') + + Viz.draw_background_colors(axes[2], color_array_sample3) + Viz.draw_boxes(axes[2], boxes_sample3, color_gray) + axes[2].scatter(x_all_sample3, y_all_sample3, color='k', s=1) + Viz.set_labels(axes[2], xlabel='Time (s)', ylabel='Intensity (a.u.)', title='untreated_162nm') + plt.show() + + + +def visualize_characteristic_time(): + """ + Visualize the characteristic time for different samples. 
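+
+    The characteristic times are cleaned before plotting; a minimal sketch of the per-sample
+    pipeline used in the body below:
+
+        x_tau, tau = np.swapaxes(np.load('Saved_data/treated_213nm-fitting_results(sklearn).npy'), 0, 1)[[0, -1]]
+        x_tau, tau_clean = remove_outlier(x_tau, tau, 0.95)  # drop outlying tau values
+        tau_smooth = smooth(tau_clean, 3)                    # smooth the cleaned series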
+ + Returns: + None + """ + seq_colors = ['#00429d','#2e59a8','#4771b2','#5d8abd','#73a2c6','#8abccf','#a5d5d8','#c5eddf','#ffffe0'] + fig, axes = layout_fig(3, 1, figsize=(6, 6)) + ax1, ax3, ax5 = axes[0], axes[1], axes[2] + + x_all_sample1, y_all_sample1 = np.load('Saved_data/treated_213nm-x_all.npy'), np.load('Saved_data/treated_213nm-y_all.npy') + x_sklearn_sample1, tau_sklearn_sample1 = np.swapaxes(np.load('Saved_data/treated_213nm-fitting_results(sklearn).npy'), 0, 1)[[0, -1]] + x_sklearn_sample1, tau_clean_sample1 = remove_outlier(x_sklearn_sample1, tau_sklearn_sample1, 0.95) + tau_smooth_sample1 = smooth(tau_clean_sample1, 3) + + bg_growth_sample1 = np.load('Saved_data/treated_213nm-bg_growth.npy') + Viz.draw_background_colors(ax1, bg_growth_sample1) + ax1.scatter(x_all_sample1, y_all_sample1, color='k', s=1) + Viz.set_labels(ax1, xlabel='Time (s)', ylabel='Intensity (a.u.)', xlim=(-2, 130), title='Treated_213nm', ticks_both_sides=False) + + ax2 = ax1.twinx() + ax2.scatter(x_sklearn_sample1, tau_clean_sample1, color=seq_colors[0], s=3) + ax2.plot(x_sklearn_sample1, tau_smooth_sample1, color='#bc5090', markersize=3) + Viz.set_labels(ax2, ylabel='Characteristic Time (s)', yaxis_style='lineplot', ylim=(-0.05, 0.5), ticks_both_sides=False) + ax2.tick_params(axis="y", color='k', labelcolor=seq_colors[0]) + ax2.set_ylabel('Characteristic Time (s)', color=seq_colors[0]) + ax2.legend(['original', 'processed'], fontsize=8, loc="upper right", frameon=True) + + x_all_sample2, y_all_sample2 = np.load('Saved_data/treated_81nm-x_all.npy'), np.load('Saved_data/treated_81nm-y_all.npy') + x_sklearn_sample2, tau_sklearn_sample2 = np.swapaxes(np.load('Saved_data/treated_81nm-fitting_results(sklearn).npy'), 0, 1)[[0, -1]] + x_sklearn_sample2, tau_clean_sample2 = remove_outlier(x_sklearn_sample2, tau_sklearn_sample2, 0.95) + tau_smooth_sample2 = smooth(tau_clean_sample2, 3) + + bg_growth_sample2 = np.load('Saved_data/treated_81nm-bg_growth.npy') + + Viz.draw_background_colors(ax3, bg_growth_sample2) + ax3.scatter(x_all_sample2, y_all_sample2, color='k', s=1) + Viz.set_labels(ax3, xlabel='Time (s)', ylabel='Intensity (a.u.)', xlim=(-2, 115), title='Treated_81nm', ticks_both_sides=False) + + ax4 = ax3.twinx() + ax4.scatter(x_sklearn_sample2, tau_clean_sample2, color=seq_colors[0], s=3) + ax4.plot(x_sklearn_sample2, tau_smooth_sample2, color='#bc5090', markersize=3) + Viz.set_labels(ax4, ylabel='Characteristic Time (s)', yaxis_style='lineplot', ylim=(-0.05, 0.5), ticks_both_sides=False) + ax4.tick_params(axis="y", color='k', labelcolor=seq_colors[0]) + ax4.set_ylabel('Characteristic Time (s)', color=seq_colors[0]) + ax4.legend(['original', 'processed'], fontsize=8, loc="upper right", frameon=True) + + x_all_sample3, y_all_sample3 = np.load('Saved_data/untreated_162nm-x_all.npy'), np.load('Saved_data/untreated_162nm-y_all.npy') + x_sklearn_sample3, tau_sklearn_sample3 = np.swapaxes(np.load('Saved_data/untreated_162nm-fitting_results(sklearn).npy'), 0, 1)[[0, -1]] + x_sklearn_sample3, tau_clean_sample3 = remove_outlier(x_sklearn_sample3, tau_sklearn_sample3, 0.95) + tau_smooth_sample3 = smooth(tau_clean_sample3, 3) + + bg_growth_sample3 = np.load('Saved_data/untreated_162nm-bg_growth.npy') + Viz.draw_background_colors(ax5, bg_growth_sample3) + ax5.scatter(x_all_sample3, y_all_sample3, color='k', s=1) + Viz.set_labels(ax5, xlabel='Time (s)', ylabel='Intensity (a.u.)', xlim=(-2, 125), title='Untreated_163nm', ticks_both_sides=False) + + ax6 = ax5.twinx() + ax6.scatter(x_sklearn_sample3, tau_clean_sample3, 
color=seq_colors[0], s=3) + ax6.plot(x_sklearn_sample3, tau_smooth_sample3, color='#bc5090', markersize=3) + Viz.set_labels(ax6, ylabel='Characteristic Time (s)', yaxis_style='lineplot', ylim=(-0.05, 0.5), ticks_both_sides=False) + ax6.tick_params(axis="y", color='k', labelcolor=seq_colors[0]) + ax6.set_ylabel('Characteristic Time (s)', color=seq_colors[0]) + ax6.legend(['original', 'processed'], fontsize=8, loc="upper right", frameon=True) + plt.show() + +def violinplot_characteristic_time(): + """ + Generate a violin plot of the characteristic time for different samples. + + Returns: + None + """ + color_blue = (44/255,123/255,182/255) + color_orange = (217/255,95/255,2/255) + color_purple = (117/255,112/255,179/255) + + x_all_sample1, y_all_sample1 = np.load('Saved_data/treated_213nm-x_all.npy'), np.load('Saved_data/treated_213nm-y_all.npy') + x_sklearn_sample1, tau_sklearn_sample1 = np.swapaxes(np.load('Saved_data/treated_213nm-fitting_results(sklearn).npy'), 0, 1)[[0, -1]] + x_sklearn_sample1, tau_clean_sample1 = remove_outlier(x_sklearn_sample1, tau_sklearn_sample1, 0.95) + + x_all_sample2, y_all_sample2 = np.load('Saved_data/treated_81nm-x_all.npy'), np.load('Saved_data/treated_81nm-y_all.npy') + x_sklearn_sample2, tau_sklearn_sample2 = np.swapaxes(np.load('Saved_data/treated_81nm-fitting_results(sklearn).npy'), 0, 1)[[0, -1]] + x_sklearn_sample2, tau_clean_sample2 = remove_outlier(x_sklearn_sample2, tau_sklearn_sample2, 0.95) + + x_all_sample3, y_all_sample3 = np.load('Saved_data/untreated_162nm-x_all.npy'), np.load('Saved_data/untreated_162nm-y_all.npy') + x_sklearn_sample3, tau_sklearn_sample3 = np.swapaxes(np.load('Saved_data/untreated_162nm-fitting_results(sklearn).npy'), 0, 1)[[0, -1]] + x_sklearn_sample3, tau_clean_sample3 = remove_outlier(x_sklearn_sample3, tau_sklearn_sample3, 0.95) + + fig, ax = plt.subplots(figsize=(6, 2), layout='compressed') + titles = ['Treated substrate\n(step width=213±88nm)', + 'Treated substrate\n(step width=81±44nm)', + 'Untreated substrate\n(step width=162±83μm)'] + ax = sns.violinplot(data=[tau_clean_sample1, tau_clean_sample2, tau_clean_sample3], + palette=[color_blue, color_orange, color_purple], linewidth=0.8) + ax.set_xticklabels(titles) + Viz.set_labels(ax, ylabel='Characteristic Time (s)', ticks_both_sides=False, yaxis_style='linear') + Viz.label_violinplot(ax, [tau_clean_sample1, tau_clean_sample2, tau_clean_sample3], label_type='average', text_pos='right') + plt.show() \ No newline at end of file diff --git a/src/rheed_learn/Viz.py b/src/rheed_learn/Viz.py new file mode 100644 index 0000000..b05410d --- /dev/null +++ b/src/rheed_learn/Viz.py @@ -0,0 +1,583 @@ +import matplotlib.pyplot as plt +from matplotlib.ticker import FormatStrFormatter +import numpy as np +import pylab as pl +import seaborn as sns +from scipy.signal import savgol_filter +from m3_learning.viz.layout import layout_fig, labelfigs + +def trim_axes(axs, N): + """ + Reduce *axs* to *N* Axes. All further Axes are removed from the figure. + """ + axs = axs.flat + for ax in axs[N:]: + ax.remove() + return axs[:N] + +def show_images(images, labels=None, img_per_row=8, img_height=1, show_colorbar=False, + clim=3, scale_0_1=False, hist_bins=None, show_axis=False): + + ''' + Plots multiple images in grid. 
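+
+    Example (an illustrative call with random data):
+
+        imgs = [np.random.rand(32, 32) for _ in range(8)]
+        show_images(imgs, labels=[f'img {i}' for i in range(8)], img_per_row=4)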
+ + images + labels: labels for every images; + img_per_row: number of images to show per row; + img_height: height of image in axes; + show_colorbar: show colorbar; + clim: int or list of int, value of standard deviation of colorbar range; + scale_0_1: scale image to 0~1; + hist_bins: number of bins for histogram; + show_axis: show axis + ''' + + assert type(images) == list or type(images) == np.ndarray, "do not use torch.tensor for hist" + if type(clim) == list: + assert len(images) == len(clim), "length of clims is not matched with number of images" + + def scale(x): + if x.min() < 0: + return (x - x.min()) / (x.max() - x.min()) + else: + return x/(x.max() - x.min()) + + h = images[0].shape[1] // images[0].shape[0]*img_height + 1 + if not labels: + labels = range(len(images)) + + n = 1 + if hist_bins: n +=1 + + fig, axes = plt.subplots(n*len(images)//img_per_row+1*int(len(images)%img_per_row>0), img_per_row, + figsize=(16, n*h*len(images)//img_per_row+1)) + trim_axes(axes, len(images)) + + for i, img in enumerate(images): + +# if torch.is_tensor(x_tensor): +# if img.requires_grad: img = img.detach() +# img = img.numpy() + + if scale_0_1: img = scale(img) + + if len(images) <= img_per_row and not hist_bins: + index = i%img_per_row + else: + index = (i//img_per_row)*n, i%img_per_row + + axes[index].title.set_text(labels[i]) + im = axes[index].imshow(img) + if show_colorbar: + m, s = np.mean(img), np.std(img) + if type(clim) == list: + im.set_clim(m-clim[i]*s, m+clim[i]*s) + else: + im.set_clim(m-clim*s, m+clim*s) + + fig.colorbar(im, ax=axes[index]) + + if show_axis: + axes[index].tick_params(axis="x",direction="in", top=True) + axes[index].tick_params(axis="y",direction="in", right=True) + else: + axes[index].axis('off') + + if hist_bins: + index_hist = (i//img_per_row)*n+1, i%img_per_row + h = axes[index_hist].hist(img.flatten(), bins=hist_bins) + + plt.show() + + +class Viz: + def __init__(self, printing = None): + """ + Initialize a Viz object. + + Args: + printing (optional): An object used for saving figures. Defaults to None. + """ + self.Printer = printing + + @staticmethod + def make_fine_step(x, transparency, step, color, saturation=1, savgol_filter_level=(15,1)): + """ + Create a fine step for given data. + + Args: + x (ndarray): Input values. + transparency (ndarray): Transparency values. + step (int): Number of steps. + color: The color to use. + saturation (float, optional): Saturation level. Defaults to 1. + savgol_filter_level (tuple, optional): Savitzky-Golay filter level. Defaults to (15, 1). + + Returns: + tuple: Tuple containing the fine step of x and the colors. 
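+
+        Example (illustrative; savgol_filter_level=None skips the Savitzky-Golay smoothing,
+        which would otherwise require a longer input than this toy array):
+
+            x = np.array([0.0, 1.0, 2.0, 3.0])
+            alpha = np.array([0.2, 0.9, 0.5, 0.1])
+            x_fine, rgba = Viz.make_fine_step(x, alpha, step=10, color=(1.0, 0.0, 0.0),
+                                              savgol_filter_level=None)
+            # x_fine has (len(x) - 1) * step points; rgba is the matching (N, 4) RGBA array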
+ """ + x_FineStep = np.hstack([np.linspace(start, stop, num=step+1, endpoint=True)[:-1] for start, stop in zip(x, x[1:])]) + + transparency_FineStep = np.hstack([np.linspace(start, stop, num=step+1, endpoint=True)[:-1] for start, stop in zip(transparency, transparency[1:])]) + if not isinstance(savgol_filter_level, type(None)): + transparency_FineStep_before = np.copy(transparency_FineStep) + transparency_FineStep = savgol_filter(transparency_FineStep, savgol_filter_level[0]*step+1, savgol_filter_level[1]) + + transparency_FineStep_norm = np.expand_dims((transparency_FineStep / max(transparency_FineStep)) * saturation, 1) + transparency_FineStep_norm[transparency_FineStep_norm<0] = 0 + + colors = np.repeat([[*color]], len(transparency_FineStep_norm), 0) + colors_all = np.concatenate([colors, transparency_FineStep_norm], 1) + return x_FineStep, colors_all + + + @staticmethod + def two_color_array(x_all, x1, x2, c1, c2, transparency=1): + """ + Create a two-color array based on given conditions. + + Args: + x_all (ndarray): All values. + x1 (ndarray): Values to be colored with c1. + x2 (ndarray): Values to be colored with c2. + c1: Color 1. + c2: Color 2. + transparency (float, optional): Transparency of the colors. Defaults to 1. + + Returns: + ndarray: Color array. + """ + + color_array = np.zeros([len(x_all), 4], dtype=np.float32) + if len(c1) < 4: + color_array[np.isin(x_all, x1)] = [*c1, transparency] + else: + color_array[np.isin(x_all, x1)] = c1 + + if len(c2) < 4: + color_array[np.isin(x_all, x2)] = [*c2, transparency] + else: + color_array[np.isin(x_all, x2)] = c2 + return color_array + + + @staticmethod + def draw_background_colors(ax, bg_colors): + """ + Draw background colors on the given axes. + + Args: + ax: Axes object. + bg_colors: Background colors. + + Returns: + None + """ + if isinstance(bg_colors, tuple): + ax.set_facecolor(bg_colors) + elif bg_colors is not None: + x_coor = bg_colors[:, 0] + colors = bg_colors[:, 1:] + for i in range(len(x_coor)): + if i == 0: + end = (x_coor[i] + x_coor[i+1]) / 2 + start = end - (x_coor[i+1] - x_coor[i]) + elif i == len(x_coor) - 1: + start = (x_coor[i-1] + x_coor[i]) / 2 + end = start + (x_coor[i] - x_coor[i-1]) + else: + start = (x_coor[i-1] + x_coor[i]) / 2 + end = (x_coor[i] + x_coor[i+1]) / 2 + ax.axvspan(start, end, facecolor=colors[i]) + + @staticmethod + def draw_boxes(ax, boxes, box_color): + """ + Draw boxes on the given axes. + + Args: + ax: Axes object. + boxes: List of box coordinates. + box_color: Color of the boxes. + + Returns: + None + """ + for (box_start, box_end) in boxes: + ax.axvspan(box_start, box_end, facecolor=box_color, edgecolor=box_color) + + @staticmethod + def find_nearest(array, value): + """ + Find the nearest value in the array to the given value. + + Args: + array: Input array. + value: Target value. + + Returns: + float: Nearest value in the array. + """ + idx = np.abs(array - value).argmin() + return array[idx] + + @staticmethod + def set_labels(ax, xlabel=None, ylabel=None, title=None, xlim=None, ylim=None, yaxis_style='sci', + logscale=False, legend=None, ticks_both_sides=True): + """ + Set labels and other properties of the given axes. + + Args: + ax: Axes object. + xlabel (str, optional): X-axis label. Defaults to None. + ylabel (str, optional): Y-axis label. Defaults to None. + title (str, optional): Plot title. Defaults to None. + xlim (tuple, optional): X-axis limits. Defaults to None. + ylim (tuple, optional): Y-axis limits. Defaults to None. + yaxis_style (str, optional): Y-axis style. 
Defaults to 'sci'. + logscale (bool, optional): Use log scale on the y-axis. Defaults to False. + legend (list, optional): Legend labels. Defaults to None. + ticks_both_sides (bool, optional): Display ticks on both sides of the axes. Defaults to True. + + Returns: + None + """ + if type(xlabel) != type(None): ax.set_xlabel(xlabel) + if type(ylabel) != type(None): ax.set_ylabel(ylabel) + if type(title) != type(None): ax.set_title(title) + if type(xlim) != type(None): ax.set_xlim(xlim) + if type(ylim) != type(None): ax.set_ylim(ylim) + if yaxis_style == 'sci': + ax.ticklabel_format(axis='y', style='sci', scilimits=(0,0), useLocale=False) + elif yaxis_style == 'float': + ax.yaxis.set_major_formatter(FormatStrFormatter('%.3f')) + # ax.ticklabel_format(axis='y', style='plain') + + if logscale: ax.set_yscale("log") + if legend: ax.legend(legend) + ax.tick_params(axis="x",direction="in") + ax.tick_params(axis="y",direction="in") + if ticks_both_sides: + ax.yaxis.set_ticks_position('both') + ax.xaxis.set_ticks_position('both') + + @staticmethod + def plot_image_with_colorbar(fig, ax, image, style='3_values'): + """ + Plot an image with a colorbar. + + Args: + fig: Figure object. + ax: Axes object. + image: Image data. + style (str, optional): Style of the colorbar ('3_values' or 'continuous'). Defaults to '3_values'. + + Returns: + None + """ + im = ax.imshow(image, vmin=image.min(), vmax=image.max()) + cbar = plt.colorbar(im, ticks=[image.min(), image.max(), (image.min()+image.max())/2]) + cbar.ax.set_yticklabels([image.min(), image.max(), (image.min()+image.max())/2]) + + @staticmethod + def label_curves(ax, curve_x, curve_y, labels_dict): + """ + Label curves on a plot. + + Args: + ax: Axes object. + curve_x: X-axis values of the curve. + curve_y: Y-axis values of the curve. + labels_dict: Dictionary of labels for specific x-values. + + Returns: + None + """ + if type(labels_dict) != type(None): + for x in labels_dict.keys(): + y = curve_y[np.where(curve_x==Viz.find_nearest(curve_x, x))] + pl.text(x, y, str(labels_dict[x]), color="g", fontsize=6) + + + @staticmethod + def plot_curve(ax, curve_x, curve_y, curve_x_fit=None, curve_y_fit=None, plot_colors=['k', 'r'], plot_type='scatter', + markersize=1, xlabel=None, ylabel=None, xlim=None, ylim=None, logscale=False, yaxis_style='sci', + title=None, legend=None): + """ + Plot a curve on the given axes. + + Args: + ax: Axes object. + curve_x: X-axis values of the curve. + curve_y: Y-axis values of the curve. + curve_x_fit (optional): X-axis values of the fitted curve. Defaults to None. + curve_y_fit (optional): Y-axis values of the fitted curve. Defaults to None. + plot_colors (list, optional): Colors for plotting the curve and the fitted curve. Defaults to ['k', 'r']. + plot_type (str, optional): Type of plot to use ('scatter' or 'lineplot'). Defaults to 'scatter'. + markersize (int, optional): Size of markers for scatter plot. Defaults to 1. + xlabel (str, optional): X-axis label. Defaults to None. + ylabel (str, optional): Y-axis label. Defaults to None. + xlim (tuple, optional): X-axis limits. Defaults to None. + ylim (tuple, optional): Y-axis limits. Defaults to None. + logscale (bool, optional): Use log scale on the y-axis. Defaults to False. + yaxis_style (str, optional): Y-axis style. Defaults to 'sci'. + title (str, optional): Plot title. Defaults to None. + legend (list, optional): Legend labels. Defaults to None. 
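+
+        Example (a minimal sketch using the 'lineplot' style):
+
+            fig, ax = plt.subplots()
+            t = np.linspace(0, 10, 200)
+            Viz.plot_curve(ax, t, np.sin(t), curve_x_fit=t, curve_y_fit=0.9 * np.sin(t),
+                           plot_type='lineplot', xlabel='Time (s)', ylabel='Intensity (a.u.)',
+                           legend=['raw', 'fit'])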
+ + Returns: + None + """ + + if plot_type == 'scatter': + ax.plot(curve_x, curve_y, color=plot_colors[0], markersize=markersize) + if not isinstance(curve_y_fit, type(None)): + if not isinstance(curve_x_fit, type(None)): + ax.scatter(curve_x_fit, curve_y_fit, color=plot_colors[1], markersize=markersize) + # plot_scatter(ax, curve_x_fit, curve_y_fit, plot_colors[1], markersize) + else: + ax.scatter(curve_x, curve_y_fit, color=plot_colors[1], markersize=markersize) + # plot_scatter(ax, curve_x, curve_y_fit, plot_colors[1], markersize) + + if plot_type == 'lineplot': + ax.plot(curve_x, curve_y, color=plot_colors[0], markersize=markersize) + if not isinstance(curve_y_fit, type(None)): + if not isinstance(curve_x_fit, type(None)): + ax.plot(curve_x_fit, curve_y_fit, color=plot_colors[1], markersize=markersize) + # plot_lineplot(ax, curve_x_fit, curve_y_fit, plot_colors[1], markersize) + else: + ax.plot(curve_x, curve_y_fit, color=plot_colors[1], markersize=markersize) + # plot_lineplot(ax, curve_x, curve_y_fit, plot_colors[1], markersize) + + Viz.set_labels(ax, xlabel=xlabel, ylabel=ylabel, title=title, xlim=xlim, ylim=ylim, yaxis_style=yaxis_style, + logscale=logscale, legend=legend) + + + @staticmethod + def set_index(axes, index, total_length): + """ + Set the index for subplots. + + Args: + axes: Axes object. + index (int): Index value. + total_length (int): Total length of subplots. + + Returns: + None + """ + rows, img_per_row = axes.shape + if total_length <= img_per_row: + index = index%img_per_row + else: + index = (index//img_per_row), index%img_per_row + + # fig, axes = plt.subplots(len(ys)//img_per_row+1*int(len(ys)%img_per_row>0), img_per_row, + # figsize=(16, subplot_height*len(ys)//img_per_row+1)) + # def show_grid_plots(xs, ys, labels=None, ys_fit1=None, ys_fit2=None, img_per_row=4, subplot_height=3, ylim=None, legend=None): + @staticmethod + def show_grid_plots(axes, xs, ys, ys_2=None, labels=None, xlabel=None, ylabel=None, title=None, ylim=None, legend=None, color=None): + """ + Show a grid of plots. + + Args: + axes: Axes object. + xs: X-axis values. + ys: Y-axis values. + labels (optional): Labels for the plots. Defaults to None. + xlabel (str, optional): X-axis label. Defaults to None. + ylabel (str, optional): Y-axis label. Defaults to None. + ylim (tuple, optional): Y-axis limits. Defaults to None. + legend (list, optional): Legend labels. Defaults to None. + color (str, optional): Color for the plots. Defaults to None. + + Returns: + None + """ + if type(labels) == type(None): labels = range(len(ys)) + if isinstance(color, type(None)): color = 'k' + for i in range(len(ys)): + # i = Viz.set_index(axes, i, total_length=len(ys)) + axes[i].plot(xs[i], ys[i], marker='.', markersize=2, color=color) + + if not isinstance(ys_2, type(None)): + axes[i].plot(xs[i], ys_2[i], marker='+', markersize=2) + + Viz.set_labels(axes[i], xlabel=xlabel, ylabel=ylabel, ylim=ylim, legend=legend) + + if isinstance(labels, type(None)): + labelfigs(axes[i], i, loc='cb', size=6) + else: + labelfigs(axes[i], i, string_add=str(labels[i]), loc='cb', size=6) + if title: plt.suptitle(title) + plt.show() + + @staticmethod + def plot_loss_difference(ax1, x_all, y_all, x_coor_all, loss_diff, color_array, color_2, title=None): + """ + Plot the loss difference. + + Args: + ax1: Axes object. + x_all: X-axis values. + y_all: Y-axis values. + x_coor_all: X-axis values for the loss difference. + loss_diff: Loss difference values. + color_array: Array of colors for the background. 
+ color_2: Color for the loss difference plot. + title (str, optional): Plot title. Defaults to None. + + Returns: + None + """ + Viz.draw_background_colors(ax1, color_array) + ax1.scatter(x_all, y_all, c='k', s=1) + Viz.set_labels(ax1, xlabel='Time (s)', ylabel='Intensity (a.u.)') + ax1.tick_params(axis="y", labelcolor='k') + ax1.set_ylabel('Intensity (a.u.)', color='k') + ax1.tick_params(axis="x",direction="in") + + ax2 = ax1.twinx() + ax2.scatter(x_coor_all, loss_diff, color=color_2, s=1) + Viz.set_labels(ax2, xlabel='Time (s)', ylabel='Loss difference (a.u.)', logscale=True) + ax2.tick_params(axis="y", color=color_2, labelcolor=color_2) + ax2.set_ylabel('Loss difference (a.u.)', color=color_2) + ax2.tick_params(axis="x",direction="in") + plt.title(title) + + + @staticmethod + def plot_fit_details(x, y1, y2, y3, labels, figsize=None, mod=6, style='print', logscale=False, layout='compressed', save_name=None, printing=None): + """ + Plot the fit details. + + Args: + x: X-axis values. + y1: Y-axis values for raw data. + y2: Y-axis values for prediction. + y3: Y-axis values for failed data. + labels: labels. + figsize (tuple, optional): Figure size. Defaults to None. + style (str, optional): Style of the plot ('print' or 'presentation'). Defaults to 'print'. + save_name (str, optional): Name to save the plot. Defaults to None. + printing: Printing object. + + Returns: + None + """ + if labels is None: + labels = range(len(y1)) + + if len(y1)//mod > 10 and style == 'print': + n_page = len(y1) // mod // 10 + 1 + for np in range(n_page): + start_plot = np*10*mod + if np == n_page-1: + n_plot = len(y1)%(10*mod) + 1 + else: + n_plot = 10*mod + + if figsize == None: + figsize=(6, 1*(n_plot//mod+1)) + + fig, axes = layout_fig(n_plot, mod=mod, figsize=figsize, layout='compressed') + axes = axes.flatten()[:n_plot] + for i in range(start_plot, start_plot+n_plot): + if np == n_page-1 and i == start_plot+n_plot-1: + handles, labels = axes[-2].get_legend_handles_labels() + axes[-1].legend(handles=handles, labels=labels, loc='center') + axes[-1].set_xticks([]) + axes[-1].set_yticks([]) + axes[-1].set_frame_on(False) + + else: + xlabel, ylabel = None, None + l1 = axes[i%(10*mod)].plot(x[i], y1[i], marker='.', markersize=2, + color=(44/255,123/255,182/255, 0.5), label='Raw data') + l2 = axes[i%(10*mod)].plot(x[i], y2[i], linewidth=2, label='Prediction') + if not isinstance(y3, type(None)): + l3 = axes[i%(10*mod)].plot(x[i], y3[i], linewidth=1, label='Failed') + + if (i%(10*mod)+1) % mod == 1: ylabel = 'Intensity (a.u.)' + if np == n_page-1 and i%(10*mod)+1 >= len(axes)-mod: xlabel = 'Time (s)' + + Viz.set_labels(axes[i%(10*mod)], xlabel=xlabel, ylabel=ylabel, logscale=logscale) + + labelfigs(axes[i%(10*mod)], None, string_add=str(labels[i]), loc='ct', size=8, style='b') + # labelfigs(axes[i%(10*mod)], None, string_add=str(labels[i]), loc='ct', style='b') + axes[i%(10*mod)].set_xticks([]) + axes[i%(10*mod)].set_yticks([]) + axes[i%(10*mod)].xaxis.set_tick_params(labelbottom=False) + axes[i%(10*mod)].yaxis.set_tick_params(labelleft=False) + + plt.tight_layout(pad=-0.5, w_pad=-1, h_pad=-0.5) + if save_name: + printing.savefig(fig, save_name+'-'+str(np+1)) + plt.show() + else: + if figsize == None: + figsize=(6, 1*(n_plot//mod+1)) + + fig, axes = layout_fig(len(y1)+1, mod=mod, figsize=figsize, layout='compressed') + axes = axes.flatten()[:len(y1)+1] + for i in range(len(x)): + xlabel='Time (s)' + ylabel='Intensity (a.u.)' + + l1 = axes[i].plot(x[i], y1[i], marker='.', markersize=2, + 
color=(44/255,123/255,182/255, 0.5), label='Raw data') + l2 = axes[i].plot(x[i], y2[i], linewidth=2, label='Prediction') + if not isinstance(y3, type(None)): + l3 = axes[i].plot(x[i], y3[i], linewidth=1, label='Failed') + if i+1 < len(axes)-mod: xlabel = None + if not (i+1) % mod == 1: ylabel = None + Viz.set_labels(axes[i], xlabel=xlabel, ylabel=ylabel, logscale=logscale, yaxis_style='float') + labelfigs(axes[i], None, string_add=str(labels[i]), loc='ct', size=8, style='b') + # axes[i].set_xticks([]) + # axes[i].set_yticks([]) + # axes[i].xaxis.set_tick_params(labelbottom=False) + # axes[i].yaxis.set_tick_params(labelleft=False) + + handles, labels = axes[-2].get_legend_handles_labels() + axes[-1].legend(handles=handles, labels=labels, loc='center') + axes[-1].set_xticks([]) + axes[-1].set_yticks([]) + axes[-1].set_frame_on(False) + + # plt.tight_layout(pad=-0.5, w_pad=-1, h_pad=-0.5) + if save_name: + printing.savefig(fig, save_name) + plt.show() + + @staticmethod + def label_violinplot(ax, data, label_type='average', text_pos='center'): + """ + Label a violin plot. + + Args: + ax: Axes object. + data: Data for the violin plot. + label_type (str, optional): Type of label to use ('average' or 'number'). Defaults to 'average'. + text_pos (str, optional): Position of the label text ('center' or 'right'). Defaults to 'center'. + + Returns: + None + """ + + # Calculate number of obs per group & median to position labels + xloc = range(len(data)) + yloc, text = [], [] + + for i, d in enumerate(data): + yloc.append(np.median(d)) + + if label_type == 'number': + text.append("n: "+str(len(d))) + + if label_type == 'average': + text.append(str(round(np.median(d), 4))) + + for tick, label in zip(xloc, ax.get_xticklabels()): + if text_pos == 'center': + ax.text(xloc[tick], yloc[tick]*1.1, text[tick], horizontalalignment='center', size=14, weight='semibold') + if text_pos == 'right': + ax.text(xloc[tick]+0.02, yloc[tick]*0.7, text[tick], horizontalalignment='left', size=14, weight='semibold') \ No newline at end of file diff --git a/src/rheed_learn/XRD.py b/src/rheed_learn/XRD.py new file mode 100644 index 0000000..79d7711 --- /dev/null +++ b/src/rheed_learn/XRD.py @@ -0,0 +1,97 @@ +# pip install xrayutilities +import numpy as np +import matplotlib.pyplot as plt +import xrayutilities as xu +from matplotlib.ticker import MultipleLocator +import glob +from matplotlib import ticker, cm, colors + + +def plot_xrd(ax, files, labels, title=None, xrange=(0,90), diff=1e3, pad_sequence=[]): + """ + Plot X-ray diffraction patterns. + + Parameters: + - ax (matplotlib.axes.Axes): The axes object to plot the patterns. + - files (list of str): The paths to the files containing the X-ray diffraction data. + - labels (list of str): The labels for each X-ray diffraction pattern. + - title (str, optional): The title for the plot (default: None). + - xrange (tuple, optional): The x-axis range of the plot (default: (0, 90)). + - diff (float, optional): Scaling factor for intensity differences between patterns (default: 1e3). + - pad_sequence (list, optional): Padding sequence for X-ray diffraction patterns with different scan ranges (default: []). 
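+
+    Example (illustrative; the .xrdml file names are placeholders):
+
+        fig, ax = plt.subplots(figsize=(6, 4))
+        plot_xrd(ax, ['film_on_substrate.xrdml', 'bare_substrate.xrdml'],
+                 labels=['film', 'substrate'], xrange=(15, 60), diff=1e2)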
+
+    Returns:
+    None
+    """
+
+    Xs, Ys = [], []
+    length_list = []
+    for file in files:
+        out = xu.io.getxrdml_scan(file)
+        Xs.append(out[0])
+        Ys.append(out[1])
+        length_list.append(len(out[0]))
+
+    if np.mean(length_list) != np.max(length_list):
+        if pad_sequence == []:
+            print('The scans cover different ranges; provide pad_sequence to pad them to a common length.')
+            return
+        else:
+            for i in range(len(Ys)):
+                Ys[i] = np.pad(Ys[i], pad_sequence[i], mode='median')
+    X = Xs[np.argmax(length_list)]
+
+    # fig, axes = plt.subplots(figsize=figsize)
+
+    for i, Y in enumerate(Ys):
+        Y[Y==0] = 1  # replace zero counts so the log scale stays well defined
+        if diff:
+            Y = Y * diff**(len(Ys)-i-1)  # offset successive patterns for readability
+        ax.plot(X, Y, label=labels[i])
+
+    ax.set_xlabel(r"2$\Theta$ (deg)")
+    ax.set_ylabel('Intensity [a.u.]')
+    ax.legend()
+
+    ax.set_yscale('log', base=10)
+    ax.set_xlim(xrange)
+    if title: ax.set_title(title)
+    # ax.set_xticks(np.arange(*xrange, 1))
+
+
+def plot_rsm(ax, file, reciprocal_space=True, title=None):
+    """
+    Plot a reciprocal space map or a real space map.
+
+    Parameters:
+    - ax (matplotlib.axes.Axes): The axes object to plot the map.
+    - file (str): The path to the file containing the data.
+    - reciprocal_space (bool, optional): Whether to plot a reciprocal space map (default: True).
+    - title (str, optional): The title for the plot (default: None).
+
+    Returns:
+    None
+    """
+
+    curve_shape = xu.io.getxrdml_scan(file)[0].shape
+    omega, two_theta, intensity = xu.io.panalytical_xml.getxrdml_map(file)
+
+    omega = omega.reshape(curve_shape)
+    two_theta = two_theta.reshape(curve_shape)
+    intensity = intensity.reshape(curve_shape)
+    intensity[intensity==0] = 1
+
+    if reciprocal_space:
+        wavelength = 1.54  # X-ray wavelength in angstrom (Cu K-alpha)
+        Qz = (1/wavelength)*(np.sin(np.deg2rad(two_theta-omega))+np.sin(np.deg2rad(omega)))
+        Qx = (1/wavelength)*(np.cos(np.deg2rad(omega))-np.cos(np.deg2rad(two_theta-omega)))
+        cs = ax.contourf(Qx, Qz, intensity, locator=ticker.LogLocator(), cmap=cm.viridis, norm=colors.LogNorm())
+    else:
+        cs = ax.contourf(omega, two_theta, intensity, locator=ticker.LogLocator(), cmap=cm.viridis, norm=colors.LogNorm())
+
+    formatter = ticker.LogFormatterMathtext(base=10, labelOnlyBase=False)
+    ax.set_xlabel(r"Qx ($\AA^{-1}$)")
+    ax.set_ylabel(r"Qz ($\AA^{-1}$)")
+
+    plt.colorbar(cs, ax=ax, format=formatter)
+    if title: ax.set_title(title)
\ No newline at end of file
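
A minimal usage sketch for plot_rsm (the file name below is a placeholder):

    import matplotlib.pyplot as plt
    fig, ax = plt.subplots(figsize=(5, 4))
    plot_rsm(ax, 'rsm_film_peak.xrdml', reciprocal_space=True, title='Reciprocal space map')
    plt.show()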