Parallelize Rust function apply_phase_shift_in_place #230

Open
kevinsung opened this issue Jun 2, 2024 · 4 comments
Labels
rust Involves writing Rust

@kevinsung
Collaborator

kevinsung commented Jun 2, 2024

This function here:

indices.for_each(|&str0| {

A straightforward attempt doesn't compile: because each row is updated with an unsafe BLAS function operating on a mutable slice, each thread needs a mutable reference to a row of the array being modified. We know that no two threads will access the same row, but the compiler can't tell.
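For illustration, here is a minimal std-only sketch (not ffsim code) of how the disjointness can be made visible to the compiler while iterating over indices: sort and deduplicate the indices first, then peel each requested row off a row-major buffer with `split_at_mut`. The helper name `rows_at_mut` and the flat `&mut [T]` buffer with `n_cols` columns are assumptions standing in for the ndarray view.

```rust
/// Hypothetical helper (not part of ffsim): returns one mutable slice per
/// requested row of a row-major buffer with `n_cols` columns.
/// `sorted_indices` must be strictly increasing and in range; the function
/// panics otherwise. Because every row is carved out with `split_at_mut`,
/// the borrow checker can see that the returned slices never overlap.
pub fn rows_at_mut<'a, T>(
    data: &'a mut [T],
    n_cols: usize,
    sorted_indices: &[usize],
) -> Vec<&'a mut [T]> {
    let mut rows = Vec::with_capacity(sorted_indices.len());
    // `rest` always starts at row `next_row` of the original buffer.
    let mut rest: &'a mut [T] = data;
    let mut next_row = 0;
    for &i in sorted_indices {
        assert!(i >= next_row, "indices must be strictly increasing");
        // Move the slice out of `rest` so the split keeps the full lifetime 'a.
        let tail = std::mem::take(&mut rest);
        // Skip the rows between the previous selected row and row `i`.
        let (_, tail) = tail.split_at_mut((i - next_row) * n_cols);
        // Row `i` is now the first `n_cols` elements.
        let (row, tail) = tail.split_at_mut(n_cols);
        rows.push(row);
        rest = tail;
        next_row = i + 1;
    }
    rows
}
```

The returned slices could then be handed to separate threads, since the compiler already knows they do not overlap. The sketch assumes the caller has sorted and deduplicated `indices` beforehand (for example with `sort_unstable` followed by `dedup`).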

@kevinsung kevinsung added the rust Involves writing Rust label Jun 2, 2024
@kevinsung kevinsung added this to the v0.1 milestone Sep 8, 2024
@S-Erik

S-Erik commented Feb 7, 2025

Hey @kevinsung, I came up with a possible implementation. It most likely has worse performance per iteration, but par_iter may still make it faster overall.

We can iterate over the rows of vec instead of over indices. We then have to filter for the rows that appear in indices, but in exchange we can use into_par_iter without issues.

    let mut vec = vec.as_array_mut();
    let indices = indices.as_array();
    let shape = vec.shape();
    let dim_b = shape[1] as i32;

    vec.axis_iter_mut(Axis(0)).into_par_iter()
        .enumerate()
        .filter(|(i, _)| indices.iter().any(|idx| idx == i))
        .for_each(|(_, mut row)| {
            match row.as_slice_mut() {
                Some(row) => unsafe {
                    zscal(dim_b, phase, row, 1);
                },
                None => panic!(
                    "Failed to convert ArrayBase to slice, possibly because the data was not contiguous and in standard order."
                ),
            }
        })

pytest is passing all tests with this.

Is there a script which I can use to test the performance of apply_phase_shift_in_place?

I am wondering what your implementation idea was, since I had no issues with the unsafe block.

@kevinsung
Collaborator Author

Hi @S-Erik, thank you for looking into this! Here is a script that you can adapt to test the performance:

import cmath

import numpy as np

from ffsim._lib import apply_phase_shift_in_place

rng = np.random.default_rng(1234)

dim = 100
n_indices = 50

mat = rng.standard_normal((dim, dim)).astype(complex)
phase_shift = cmath.rect(1, rng.uniform(0, np.pi))
indices = rng.choice(dim, size=n_indices, replace=False).astype(np.uint64)

apply_phase_shift_in_place(mat, phase_shift, indices)

> I am wondering what your implementation idea was, since I had no issues with the unsafe block.

To be honest I don't remember at this point. Maybe I was mistaken.

@S-Erik

S-Erik commented Feb 9, 2025

Thanks for the quick answer.

I tested the performance of the current implementation against the version I suggested, slightly changed for readability (calling to_vec and filtering with contains):

    let mut vec = vec.as_array_mut();
    let indices = indices.as_array().to_vec();
    let shape = vec.shape();
    let dim_b = shape[1] as i32;
    vec.axis_iter_mut(Axis(0)).into_par_iter()
        .enumerate()
        .filter(|(i, _)| indices.contains(i))
        .map(|(_, row)| row)
        .for_each(|mut row| {
            match row.as_slice_mut() {
                Some(row) => unsafe {
                    zscal(dim_b, phase, row, 1);
                },
                None => panic!(
                    "Failed to convert ArrayBase to slice, possibly because the data was not contiguous and in standard order."
                ),
            }
        })

For this I timed different numbers of indices (n_indices) and matrix dimensions (dim). I ran each combination 1000 times and took the mean runtime in seconds (see my CPU info and Python script at the bottom).

[Figures: mean runtime in seconds versus matrix dimension, one curve per number of indices, for the "Current implementation" and for "My implementation using filter"]

With my filter implementation we see slight performance improvements for larger matrices and larger numbers of indices, but from what I see they are below a factor of two.

The main challenge in modifying the current implementation to use concurrent calls to zscal is that each rayon thread needs a mutable reference to vec when we iterate over indices with par_iter. This is not possible with par_iter, even though each rayon thread would access a different row. So in the end a concurrent implementation seems to be safe, since we never access the same data. Nevertheless, I think implementing this requires handling threads manually and using unsafe code blocks, since rayon does not allow implementing this concurrently.
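One possible shape of the manual-threading idea, sketched with only the standard library and under several assumptions: a row-major `&mut [C64]` buffer replaces the ndarray view, a small hand-written `C64` type replaces `Complex64`, and the plain loop in `scale_row` replaces the BLAS `zscal` call. All of these names, and `apply_phase_shift_threaded` itself, are hypothetical. It uses `std::thread::scope` (stable since Rust 1.63) and carves the rows out with `chunks_exact_mut`, so apart from the BLAS call (replaced here) no `unsafe` block is needed for the threading itself.

```rust
use std::ops::Mul;
use std::thread;

/// Minimal complex number standing in for num_complex::Complex64
/// (assumption: this sketch avoids external crates).
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct C64 {
    pub re: f64,
    pub im: f64,
}

impl Mul for C64 {
    type Output = C64;
    fn mul(self, other: C64) -> C64 {
        C64 {
            re: self.re * other.re - self.im * other.im,
            im: self.re * other.im + self.im * other.re,
        }
    }
}

/// Plain-Rust stand-in for a unit-stride BLAS zscal call: row *= phase.
fn scale_row(row: &mut [C64], phase: C64) {
    for x in row.iter_mut() {
        *x = *x * phase;
    }
}

/// Hypothetical manually threaded variant (not ffsim's implementation).
/// `data` is a row-major buffer with `n_cols` columns. Every row listed in
/// `indices` is multiplied by `phase` exactly once, even if listed twice.
pub fn apply_phase_shift_threaded(
    data: &mut [C64],
    n_cols: usize,
    phase: C64,
    indices: &[usize],
    n_threads: usize,
) {
    if n_cols == 0 {
        return;
    }
    let n_rows = data.len() / n_cols;
    // Mark the selected rows once: O(n_rows + indices.len()).
    let mut selected = vec![false; n_rows];
    for &i in indices {
        selected[i] = true; // panics if an index is out of range
    }
    // chunks_exact_mut yields non-overlapping rows, so the borrow checker
    // accepts moving each one into a different thread.
    let mut rows: Vec<&mut [C64]> = data
        .chunks_exact_mut(n_cols)
        .zip(selected)
        .filter_map(|(row, keep)| if keep { Some(row) } else { None })
        .collect();
    if rows.is_empty() {
        return;
    }
    // One contiguous group of rows per thread (at most `n_threads` groups).
    let n_threads = n_threads.max(1);
    let per_thread = (rows.len() + n_threads - 1) / n_threads;
    let groups = rows.chunks_mut(per_thread);
    thread::scope(|s| {
        for group in groups {
            s.spawn(move || {
                for row in group.iter_mut() {
                    scale_row(row, phase);
                }
            });
        }
    });
}
```

Each thread gets one contiguous group of selected rows, so the spawn overhead is paid at most `n_threads` times rather than once per row. Whether this beats the current sequential loop has not been measured.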

I found a Stack Overflow discussion about a very similar problem, where the filter approach was also suggested.

@kevinsung, what are your thoughts on that? Currently, I am not eager to implement a concurrent version without rayon.

Other Approaches

I also tried different approaches:

  • Creating a new array of bool values with one entry per row of vec, marking which rows appear in indices. Then I could par_iter over vec zipped with this bool array, which avoided the call to filter:
    let mut vec = vec.as_array_mut();
    let indices = indices.as_array().to_vec();
    let shape = vec.shape();
    let dim_b = shape[1] as i32;
    let rows_bool: Vec<bool> = (0..vec.len_of(Axis(0)))
        .map(|i| indices.contains(&i))
        .collect();
    vec.axis_iter_mut(Axis(0)).into_par_iter().zip(rows_bool).for_each(|(mut row, bool_val)| {
        if bool_val {
            match row.as_slice_mut() {
                Some(row) => unsafe {
                    zscal(dim_b, phase, row, 1);
                },
                None => panic!(
                    "Failed to convert ArrayBase to slice, possibly because the data was not contiguous and in standard order."
                ),
            }
        }
    })
  • Mapping the indices to the phase-shifted rows with par_iter, then sequentially assigning these rows back into vec:
    let mut vec = vec.as_array_mut();
    let indices = indices.as_array();
    let shape = vec.shape();
    let dim_b = shape[1] as i32;
    let indices_mapped: Vec<Array1<Complex64>> = indices.into_par_iter().map(|&str0| {
        let mut target = vec.row(str0).to_owned();
        match target.as_slice_mut() {
            Some(target) => unsafe {
                zscal(dim_b, phase, target, 1);
            },
            None => panic!(
                "Failed to convert ArrayBase to slice, possibly because the data was not contiguous and in standard order."
            ),
        };
        target
    }).collect();
    indices
        .into_iter()
        .zip(indices_mapped)
        .for_each(|(&str0, val)| {
            let mut target = vec.row_mut(str0);
            target.assign(&val);
        })

Both of these were slower than the filter approach.
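As a side note on the first alternative: building `rows_bool` with `indices.contains(&i)` for every row costs O(n_rows * n_indices) before any scaling happens. A minimal sketch of building the same mask in O(n_rows + n_indices) by scattering the indices into it; `row_mask` is a hypothetical name, and its effect on the timings has not been measured.

```rust
/// Hypothetical helper: marks which rows of a matrix with `n_rows` rows
/// appear in `indices`. Scattering into the mask costs
/// O(n_rows + indices.len()), whereas calling `indices.contains(&i)` for
/// every row costs O(n_rows * indices.len()).
pub fn row_mask(n_rows: usize, indices: &[usize]) -> Vec<bool> {
    let mut mask = vec![false; n_rows];
    for &i in indices {
        mask[i] = true; // panics if an index is out of range
    }
    mask
}
```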

My CPU info (lscpu):

Architecture:             x86_64
  CPU op-mode(s):         32-bit, 64-bit
  Address sizes:          43 bits physical, 48 bits virtual
  Byte Order:             Little Endian
CPU(s):                   16
  On-line CPU(s) list:    0-15
Vendor ID:                AuthenticAMD
  Model name:             AMD Ryzen 7 2700X Eight-Core Processor
    CPU family:           23
    Model:                8
    Thread(s) per core:   2
    Core(s) per socket:   8
    Socket(s):            1
    Stepping:             2
    Frequency boost:      enabled
    CPU max MHz:          3700,0000
    CPU min MHz:          2200,0000
    BogoMIPS:             7399.86
    Flags:                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov
                          pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext
                          fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl
                          nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq
                          monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave
                          avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm
                          sse4a misalignsse 3dnowprefetch osvw skinit wdt tce topoext
                          perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb
                          hw_pstate ssbd ibpb vmmcall fsgsbase bmi1 avx2 smep bmi2
                          rdseed adx smap clflushopt sha_ni xsaveopt xsavec xgetbv1
                          clzero irperf xsaveerptr arat npt lbrv svm_lock nrip_save
                          tsc_scale vmcb_clean flushbyasid decodeassists pausefilter
                          pfthreshold avic v_vmsave_vmload vgif overflow_recov succor
                          smca sev sev_es
Virtualization features:  
  Virtualization:         AMD-V
Caches (sum of all):      
  L1d:                    256 KiB (8 instances)
  L1i:                    512 KiB (8 instances)
  L2:                     4 MiB (8 instances)
  L3:                     16 MiB (2 instances)
NUMA:                     
  NUMA node(s):           1
  NUMA node0 CPU(s):      0-15
Vulnerabilities:          
  Gather data sampling:   Not affected
  Itlb multihit:          Not affected
  L1tf:                   Not affected
  Mds:                    Not affected
  Meltdown:               Not affected
  Mmio stale data:        Not affected
  Reg file data sampling: Not affected
  Retbleed:               Mitigation; untrained return thunk; SMT vulnerable
  Spec rstack overflow:   Mitigation; Safe RET
  Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl
  Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
  Spectre v2:             Mitigation; Retpolines; IBPB conditional; STIBP disabled;
                          RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
  Srbds:                  Not affected
  Tsx async abort:        Not affected

Python script

import time
import cmath
import numpy as np
import matplotlib.pyplot as plt

from ffsim._lib import apply_phase_shift_in_place

rng = np.random.default_rng(1234)

n_lst = np.arange(200, 701, 100)
dim_lst = np.arange(10, 201, 50) * 100

n = 1_000
mean_times = {}  # key is n_indices
for i, n_indices in enumerate(n_lst):
    print(f"{i}/{len(n_lst)}", end="\r")
    mean_times[n_indices] = []
    for dim in dim_lst:
        mat = rng.standard_normal((dim, dim)).astype(complex)
        phase_shift = cmath.rect(1, rng.uniform(0, np.pi))
        indices = rng.choice(dim, size=n_indices, replace=False).astype(np.uint64)
        time_sum = 0  # reset for every (n_indices, dim) combination
        for _ in range(n):
            start_time = time.perf_counter()
            apply_phase_shift_in_place(mat, phase_shift, indices)
            time_sum += time.perf_counter() - start_time
        mean_times[n_indices].append(time_sum / n)

plt.figure()
for key, vals in mean_times.items():
    plt.plot(dim_lst, vals, label=f"{key} indices", marker=".")
plt.ylabel(f"Mean runtime of {n} runs [s]")
plt.xlabel(f"Dimension of matrix")
plt.yscale("log")
plt.legend()
plt.grid()
plt.savefig("perf.png", bbox_inches="tight", dpi=128)

@kevinsung
Collaborator Author

@S-Erik Thank you very much for your investigation!

> The main challenge in modifying the current implementation to use concurrent calls to zscal is that each rayon thread needs a mutable reference to vec when we iterate over indices with par_iter.

You are exactly right about this. I've updated the opening post to reflect this underlying issue more accurately.

> Nevertheless, I think implementing this requires handling threads manually and using unsafe code blocks, since rayon does not allow implementing this concurrently.

Makes sense.

> @kevinsung, what are your thoughts on that? Currently, I am not eager to implement a concurrent version without rayon.

I think we should implement the threading manually rather than use your filter-based approach. No worries if you can't work on this. For what it's worth, the similar issue #229 is a higher priority because it's a more significant bottleneck in applications. In that issue, each thread needs to have mutable references to two rows, so I don't think the filter-based approach applies there.
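For the two-row case in #229, one way to hand out two provably disjoint rows without `unsafe` is `split_at_mut`. A minimal sketch, assuming a row-major buffer with `n_cols` columns; `two_rows_mut` is a hypothetical name. It only covers a single pair: running several pairs in parallel would additionally require that no row appears in two pairs, which this sketch does not check.

```rust
/// Hypothetical helper: returns mutable slices for two distinct rows `a` and
/// `b` of a row-major buffer with `n_cols` columns, in the order requested.
/// `split_at_mut` proves to the compiler that the two slices do not overlap.
/// Panics if `a == b` or either row is out of range.
pub fn two_rows_mut<T>(data: &mut [T], n_cols: usize, a: usize, b: usize) -> (&mut [T], &mut [T]) {
    assert_ne!(a, b, "rows must be distinct");
    let (lo, hi) = if a < b { (a, b) } else { (b, a) };
    // Everything before row `hi` goes to `head`; row `hi` starts `tail`.
    let (head, tail) = data.split_at_mut(hi * n_cols);
    let row_lo = &mut head[lo * n_cols..(lo + 1) * n_cols];
    let row_hi = &mut tail[..n_cols];
    if a < b {
        (row_lo, row_hi)
    } else {
        (row_hi, row_lo)
    }
}
```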

Projects
None yet
Development

No branches or pull requests

2 participants