Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jart synapse #709

Open
wants to merge 86 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
86 commits
Select commit Hold shift + click to select a range
ba8a2a4
1st complie test
ZhenmingYu Jul 7, 2022
06343cf
partial fix
ZhenmingYu Jul 7, 2022
589da9a
compile success, need to hide write_noise_std
ZhenmingYu Jul 8, 2022
0bdca9d
fix pybind bugs
ZhenmingYu Jul 8, 2022
37ce140
fix weight granularity
ZhenmingYu Jul 11, 2022
a695df8
tried fixing
ZhenmingYu Jul 19, 2022
a47273c
fix nan
ZhenmingYu Jul 19, 2022
de4fe74
fix w_min
ZhenmingYu Jul 20, 2022
3aaceba
CPU version working!
ZhenmingYu Jul 26, 2022
d8057d6
merge with aihwkit 0.6.0
ZhenmingYu Jul 26, 2022
51839f0
cuda-stuff backup
ZhenmingYu Jul 29, 2022
3965a0d
Store Ndisc in w_persistent_
ZhenmingYu Aug 2, 2022
ec50faa
CUDA version, not able to change pytorch weights
ZhenmingYu Aug 13, 2022
1eb6369
remove redundent parameters
ZhenmingYu Aug 14, 2022
dd37c66
Merge branch 'JART_v1b_device' into old_dev_code_with_v0.5.1
ZhenmingYu Aug 15, 2022
bacc2c3
Merge pull request #1 from ZhenmingYu/old_dev_code_with_v0.5.1
ZhenmingYu Aug 15, 2022
6cf0580
fix merge issue
ZhenmingYu Aug 15, 2022
7923177
fix minor issue
ZhenmingYu Aug 15, 2022
6d3cb56
fix merge issues
ZhenmingYu Aug 15, 2022
0a02480
add JART v1b documentations
ZhenmingYu Aug 15, 2022
cae4f8a
write_noise_std does not fit our device, removed
ZhenmingYu Aug 16, 2022
4a34388
polishing documentation for JART v1b
ZhenmingYu Aug 16, 2022
6a07a5c
change basic test name
ZhenmingYu Aug 16, 2022
c13e089
update test for cuda
ZhenmingYu Aug 16, 2022
422bdfc
Simplify calculations, CUDA somehow works
ZhenmingYu Aug 28, 2022
be48182
fix CUDA version with a delay hack
ZhenmingYu Aug 31, 2022
610d906
Add support for noise
ZhenmingYu Sep 2, 2022
e8ea3cd
improve test useability
ZhenmingYu Sep 5, 2022
bca2d2c
debug test-releted problems
ZhenmingYu Sep 7, 2022
98e8f67
add dtod bounds/ctoc noise to uniform distribution
ZhenmingYu Sep 9, 2022
d5bfbf9
update test configs and ploting scripts
ZhenmingYu Sep 11, 2022
addf5e9
ignore figures
ZhenmingYu Sep 11, 2022
b673afc
bug fix
ZhenmingYu Sep 20, 2022
b9eaf2d
update test
ZhenmingYu Sep 20, 2022
397e4df
fix ctoc variation
ZhenmingYu Oct 11, 2022
0992e53
Update test configs
ZhenmingYu Oct 12, 2022
9fcb103
Merge pull request #1 from ZhenmingYu/JART_v1b_device
ZhenmingYu Oct 12, 2022
007c138
small plot fix
ZhenmingYu Oct 13, 2022
768f496
Merge pull request #2 from ZhenmingYu/JART_v1b_device
ZhenmingYu Oct 13, 2022
4b0a1de
Include JART v1b test instructions
ZhenmingYu Oct 18, 2022
f76383f
Merge pull request #3 from ZhenmingYu/JART_v1b_device
ZhenmingYu Oct 18, 2022
e443dbf
fix bugs tyops and documentations
ZhenmingYu Oct 24, 2022
25158a2
New release prep. Changes to version number and CHANGELOG
Jan 5, 2023
5410864
Merge branch 'release/0.7.0' of https://github.com/IBM-AI-Hardware-Ce…
Jan 30, 2023
d9ba361
Merge branch 'IBM-AI-Hardware-Center-release/0.7.0'
Jan 30, 2023
55f03f3
Merge branch 'master' of https://github.com/IBM/aihwkit
Mar 20, 2023
a9bec5a
upstream CHANGELOG
Mar 20, 2023
877b41f
Merge branch 'master' of https://github.com/IBM/aihwkit
Mar 22, 2023
73a98f0
Merge branch 'master' of https://github.com/IBM/aihwkit
Mar 23, 2023
147827c
Merge branch 'master' of https://github.com/IBM/aihwkit
Mar 24, 2023
b4d1ad8
Merge branch 'master' of https://github.com/IBM/aihwkit
Mar 24, 2023
121f041
plot script
ZhenmingYu May 2, 2023
a0896c6
update documentation to include citations
ZhenmingYu May 2, 2023
b97578f
Merge pull request #4 from ZhenmingYu/JART_v1b_device
ZhenmingYu May 2, 2023
5752105
yaml_loader that holds the configuration function
ZhenmingYu May 4, 2023
740d763
fix cuda __syncthreads
ZhenmingYu May 5, 2023
1cfc864
optimizing CUDA speed
ZhenmingYu May 5, 2023
6a0b28c
Merge branch 'master' of https://github.com/IBM/aihwkit
May 30, 2023
eb29d62
Merge branch 'master' of https://github.com/IBM/aihwkit
Jun 26, 2023
ea91677
Merge branch 'master' of https://github.com/IBM/aihwkit
Jun 27, 2023
843fc0e
Merge branch 'master' of https://github.com/IBM/aihwkit
Jul 10, 2023
ba88bc2
Merge branch 'master' of https://github.com/IBM/aihwkit
Jul 10, 2023
b442780
Merge branch 'master' of https://github.com/IBM/aihwkit
Jul 14, 2023
28b5dca
Merge branch 'master' of https://github.com/IBM/aihwkit
Aug 25, 2023
7ad577a
Merge branch 'master' of https://github.com/IBM/aihwkit
Sep 14, 2023
09f9ac3
Merge branch 'master' of https://github.com/IBM/aihwkit
Sep 15, 2023
e7dcc77
Merge branch 'master' of https://github.com/IBM/aihwkit
Sep 15, 2023
66eba8b
Merge branch 'master' of https://github.com/IBM/aihwkit
Sep 18, 2023
ae66db9
Merge branch 'master' of https://github.com/IBM/aihwkit
Dec 1, 2023
921e2c4
Merge branch 'master' of https://github.com/IBM/aihwkit
Dec 8, 2023
f8b1ef3
Merge branch 'master' of https://github.com/IBM/aihwkit
Dec 15, 2023
3e21469
Merge branch 'master' of https://github.com/IBM/aihwkit
Dec 19, 2023
60febbf
Merge branch 'master' of https://github.com/IBM/aihwkit
Dec 21, 2023
0d2dd78
Merge branch 'master' of https://github.com/IBM/aihwkit
Dec 27, 2023
d6d6f31
Merge branch 'master' of https://github.com/IBM/aihwkit
Dec 28, 2023
a263359
Merge branch 'master' into HEAD
Dec 28, 2023
67c6db9
merge
Dec 28, 2023
eb08472
still needs work
Jan 4, 2024
2376df4
Merge branch 'master' into jart-synapse
kaoutar55 May 8, 2024
a1f3134
Merge branch 'master' into jart-synapse
Borjagodoy Aug 6, 2024
d65d123
fixing src/rpucuda/rpu_JART_v1b_device.h
ZhenmingYu Aug 8, 2024
49971ae
Build success?
ZhenmingYu Aug 9, 2024
aa832bd
Add examples
ZhenmingYu Jan 20, 2025
8cd563d
update examples
ZhenmingYu Jan 21, 2025
032de01
problem solved by changing to public?
ZhenmingYu Jan 22, 2025
54fcb62
Merge branch 'IBM:master' into jart-synapse
ZhenmingYu Jan 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
partial fix
ZhenmingYu committed Jul 7, 2022
commit 06343cffba48196318dd7b174f5a5c1b6160ee76
184 changes: 102 additions & 82 deletions src/rpucuda/rpu_JART_v1b_static_device.cpp
Original file line number Diff line number Diff line change
@@ -38,13 +38,46 @@ void JARTv1bStaticRPUDevice<T>::populate(
}
}

/**
 * Dump per-device parameters to std::cout, one output line per d-index.
 *
 * For every (i, j) with i < d_count and j < x_count the following values are
 * written: max/min weight bound, up/down scale, up/down slope, decay scale,
 * diffusion rate, reset bias and - when the meta-parameter reports a
 * persistent weight - the persistent weight value.
 *
 * @param x_count number of x-entries to print per row; a negative value or a
 *                value larger than x_size_ selects all x_size_ entries.
 * @param d_count number of rows to print; a negative value or a value larger
 *                than d_size_ selects all d_size_ rows.
 *
 * Note: the stream precision of std::cout is changed while printing and is
 * not restored to its previous value afterwards.
 */
template <typename T> void JARTv1bStaticRPUDevice<T>::printDP(int x_count, int d_count) const {

  // Out-of-range requests fall back to the full array extent.
  const bool x_in_range = (x_count >= 0) && (x_count <= this->x_size_);
  const bool d_in_range = (d_count >= 0) && (d_count <= this->d_size_);
  const int n_cols = x_in_range ? x_count : this->x_size_;
  const int n_rows = d_in_range ? d_count : this->d_size_;

  // Queried once up front; decides whether the persistent weight is appended.
  const bool show_persistent = getPar().usesPersistentWeight();

  for (int row = 0; row < n_rows; ++row) {
    for (int col = 0; col < n_cols; ++col) {
      // Index prefix, bounds, scales and slopes use 5 significant digits.
      std::cout.precision(5);
      std::cout << row << "," << col << ": "
                << "[<" << this->w_max_bound_[row][col] << ","
                << this->w_min_bound_[row][col] << ">,<"
                << this->w_scale_up_[row][col] << ","
                << this->w_scale_down_[row][col] << ">,<"
                << w_slope_up_[row][col] << ","
                << w_slope_down_[row][col] << ">]";

      // The decay scale is printed with higher precision.
      std::cout.precision(10);
      std::cout << this->w_decay_scale_[row][col] << ", ";

      // Remaining values use 6 significant digits.
      std::cout.precision(6);
      std::cout << this->w_diffusion_rate_[row][col] << ", "
                << this->w_reset_bias_[row][col];
      if (show_persistent) {
        std::cout << ", " << this->w_persistent_[row][col];
      }
      std::cout << "]";
    }
    std::cout << std::endl;
  }
}


namespace {
template <typename T>
struct Voltages
struct Voltages_holder
{
T V_disk;
T V_series;
T V_disk;
T V_plug;
T V_Schottky;
@@ -62,6 +95,7 @@ inline T calculate_current_negative(
const T &beta1,
const T &c0,
const T &c1,
const T &c2,
const T &c3,
const T &d0,
const T &d1,
@@ -87,14 +121,14 @@ inline T calculate_current_positive(
const T &h3,
const T &j_0,
const T &k0,
T &Ndiscmin) {
const T &Ndiscmin) {
return (-g0*(exp(-g1*applied_voltage)-1))/(pow((1+(h0+h1*applied_voltage+h2*exp(-h3*applied_voltage))*pow((Ndisc/Ndiscmin),(-j_0))),(1/k0)));
}

template <typename T>
inline T invert_positive_current(
T &I_mem,
T &applied_voltage,
const T &read_voltage,
const T &g0,
const T &g1,
const T &h0,
@@ -103,9 +137,9 @@ inline T invert_positive_current(
const T &h3,
const T &j_0,
const T &k0,
T &Ndiscmin) {
const T &Ndiscmin) {
if (I_mem>0){
return pow(((pow(((-g0*(exp(-g1*applied_voltage)-1))/I_mem), k0)-1)/(h0+h1*applied_voltage+h2*exp(-h3*applied_voltage))),(1/-j_0))*Ndiscmin;
return pow(((pow(((-g0*(exp(-g1*read_voltage)-1))/I_mem), k0)-1)/(h0+h1*read_voltage+h2*exp(-h3*read_voltage))),(1/-j_0))*Ndiscmin;
}
else{
return 0;
@@ -124,6 +158,7 @@ inline T calculate_current(
const T &beta1,
const T &c0,
const T &c1,
const T &c2,
const T &c3,
const T &d0,
const T &d1,
@@ -141,9 +176,9 @@ inline T calculate_current(
const T &h3,
const T &j_0,
const T &k0,
T &Ndiscmin) {
const T &Ndiscmin) {
if (applied_voltage < 0) {
return calculate_current_negative(Ndisc, applied_voltage, alpha0, alpha1, alpha2, alpha3, beta0, beta1, c0, c1, c3, d0, d1, d2, d3, f0, f1, f2, f3);
return calculate_current_negative(Ndisc, applied_voltage, alpha0, alpha1, alpha2, alpha3, beta0, beta1, c0, c1, c2, c3, d0, d1, d2, d3, f0, f1, f2, f3);
} else {
return calculate_current_positive(Ndisc, applied_voltage, g0, g1, h0, h1, h2, h3, j_0, k0, Ndiscmin);
}
@@ -156,18 +191,16 @@ inline T calculate_T(
const T &T0,
const T &Rth0,
const T &Rtheff_scaling,
T &V_disk,
T &V_plug,
T &V_Schottky) {
Voltages_holder<T> &Voltages) {
if (applied_voltage > 0) {
return T0 + I_mem*(V_disk+V_plug+V_Schottky)*Rth0*Rtheff_scaling;
return T0 + I_mem*(Voltages.V_disk+Voltages.V_plug+Voltages.V_Schottky)*Rth0*Rtheff_scaling;
} else {
return T0 + I_mem*(V_disk+V_plug+V_Schottky)*Rth0;
return T0 + I_mem*(Voltages.V_disk+Voltages.V_plug+Voltages.V_Schottky)*Rth0;
}
}

template <typename T>
inline T* calculate_voltages(
inline Voltages_holder<T> calculate_voltages(
T &applied_voltage,
T &I_mem,
const T &R0,
@@ -176,21 +209,21 @@ inline T* calculate_voltages(
const T &RseriesTiOx,
const T &lcell,
T &ldet,
const T &zvo,
const int &zvo,
const T &e,
T &A,
const T &Nplug,
T &Ndisc,
const T &un) {
T Voltages[4]
Voltages_holder<T> Voltages;
// V_series
Voltages[0] = I_mem*(RseriesTiOx + (R0*(1+alphaline*R0*(I_mem**2)*Rthline)));
Voltages.V_series = I_mem*(RseriesTiOx + (R0*(1+alphaline*R0*pow(I_mem,2)*Rthline)));
// V_disk
Voltages[1] = I_mem*(ldet/(zvo*e*A*Ndisc*un));
Voltages.V_disk = I_mem*(ldet/(zvo*e*A*Ndisc*un));
// V_plug
Voltages[2] = I_mem*((lcell-ldet)/(zvo*e*A*Nplug*un));
Voltages.V_plug = I_mem*((lcell-ldet)/(zvo*e*A*Nplug*un));
// V_Schottky
Voltages[3] = applied_voltage-V_series_result-V_plug_result-V_disk_result;
Voltages.V_Schottky = applied_voltage-Voltages.V_series-Voltages.V_disk-Voltages.V_plug;
return Voltages;
}

@@ -210,15 +243,13 @@ inline T calculate_F1(
template <typename T>
inline T calculate_Eion(
T &applied_voltage,
T &V_disk,
T &V_plug,
T &V_Schottky,
Voltages_holder<T> &Voltages,
const T &lcell,
T &ldet) {
if (applied_voltage > 0) {
return V_disk/ldet;
return Voltages.V_disk/ldet;
} else {
return (V_Schottky + V_plug + V_disk)/lcell;
return (Voltages.V_Schottky + Voltages.V_plug + Voltages.V_disk)/lcell;
}
}

@@ -232,7 +263,7 @@ inline T calculate_dNdt(
const T &Arichardson,
const T &mdiel,
const T &h,
const T &zvo,
const int &zvo,
const T &eps_0,
const T &T0,
const T &eps,
@@ -260,13 +291,13 @@ inline T calculate_dNdt(

T F1 = calculate_F1(applied_voltage, Ndisc, Ndiscmin, Ndiscmax);

T* Voltages = calculate_voltages(applied_voltage, I_mem, R0, alphaline, Rthline, RseriesTiOx, lcell, ldet, zvo, e, A, Nplug, Ndisc, un);
Voltages_holder<T> Voltages = calculate_voltages(applied_voltage, I_mem, R0, alphaline, Rthline, RseriesTiOx, lcell, ldet, zvo, e, A, Nplug, Ndisc, un);

T Eion = calculate_Eion(applied_voltage, *(Voltages + 1),*(Voltages + 2),*(Voltages + 3), lcell, ldet);
T Eion = calculate_Eion(applied_voltage, Voltages, lcell, ldet);

T gamma = zvo*a*Eion/(dWa*M_PI);

T Treal = calculate_T(applied_voltage, I_mem, T0, Rth0, Rtheff_scaling, *(Voltages + 1),*(Voltages + 2),*(Voltages + 3));
T Treal = calculate_T(applied_voltage, I_mem, T0, Rth0, Rtheff_scaling, Voltages);

// dWamin
T dWa_f = dWa*(sqrt(1-pow(gamma,2))-(gamma*M_PI)/2+gamma*asin(gamma));
@@ -278,9 +309,9 @@ inline T calculate_dNdt(
}

template <typename T>
inline T step(
inline void step(
T &applied_voltage,
T &time_step,
const T &time_step,
T &Ndisc,
const T &alpha0,
const T &alpha1,
@@ -290,6 +321,7 @@ inline T step(
const T &beta1,
const T &c0,
const T &c1,
const T &c2,
const T &c3,
const T &d0,
const T &d1,
@@ -312,7 +344,7 @@ inline T step(
const T &Arichardson,
const T &mdiel,
const T &h,
const T &zvo,
const int &zvo,
const T &eps_0,
const T &T0,
const T &eps,
@@ -336,9 +368,9 @@ inline T step(
const T &Rthline,
const T &alphaline,
T &A,
T &Ndisc_min_bound,
T &Ndisc_max_bound) {
T I_mem = calculate_current(Ndisc, applied_voltage, alpha0, alpha1, alpha2, alpha3, beta0, beta1, c0, c1, c3, d0, d1, d2, d3, f0, f1, f2, f3, g0, g1, h0, h1, h2, h3, j_0, k0, Original_Ndiscmin);
const T &Ndisc_min_bound,
const T &Ndisc_max_bound) {
T I_mem = calculate_current(Ndisc, applied_voltage, alpha0, alpha1, alpha2, alpha3, beta0, beta1, c0, c1, c2, c3, d0, d1, d2, d3, f0, f1, f2, f3, g0, g1, h0, h1, h2, h3, j_0, k0, Original_Ndiscmin);
T dNdt = calculate_dNdt(applied_voltage, I_mem, Ndisc, e, kb, Arichardson, mdiel, h, zvo, eps_0, T0, eps, epsphib, phiBn0, phin, un, Ndiscmax, Ndiscmin, Nplug, a, ny0, dWa, Rth0, lcell, ldet, Rtheff_scaling, RseriesTiOx, R0, Rthline, alphaline, A);
Ndisc = Ndisc + dNdt*time_step;
if (Ndisc>Ndisc_max_bound){
@@ -353,8 +385,8 @@ template <typename T>
inline T map_Ndisc_to_weight(
const T &read_voltage,
T &Ndisc,
T &conductance_min,
T &conductance_max,
const T &conductance_min,
const T &conductance_max,
T &weight_min_bound,
T &weight_max_bound,
const T &g0,
@@ -366,18 +398,19 @@ inline T map_Ndisc_to_weight(
const T &j_0,
const T &k0,
const T &Original_Ndiscmin) {
T conductance = calculate_current_positive(Ndisc, read_voltage, g0, g1, h0, h1, h2, h3, j_0, k0, Original_Ndiscmin);
T applied_voltage = read_voltage;
T conductance = calculate_current_positive(Ndisc, applied_voltage, g0, g1, h0, h1, h2, h3, j_0, k0, Original_Ndiscmin);
T weight = ((conductance-conductance_min)/(conductance_max-conductance_min))*(weight_max_bound-weight_min_bound)+weight_min_bound;
return weight;
}

template <typename T>
inline T update_once(
inline void update_once(
const T &read_voltage,
T &pulse_voltage_SET,
T &pulse_voltage_RESET,
T &pulse_length,
T &base_time_step,
const T &pulse_voltage_SET,
const T &pulse_voltage_RESET,
const T &pulse_length,
const T &base_time_step,
const T &alpha0,
const T &alpha1,
const T &alpha2,
@@ -386,6 +419,7 @@ inline T update_once(
const T &beta1,
const T &c0,
const T &c1,
const T &c2,
const T &c3,
const T &d0,
const T &d1,
@@ -408,7 +442,7 @@ inline T update_once(
const T &Arichardson,
const T &mdiel,
const T &h,
const T &zvo,
const int &zvo,
const T &eps_0,
const T &T0,
const T &eps,
@@ -435,12 +469,12 @@ inline T update_once(
T &w,
T &w_apparent,
int &sign,
T &conductance_min,
T &conductance_max,
const T &conductance_min,
const T &conductance_max,
T &weight_min_bound,
T &weight_max_bound,
T &Ndisc_min_bound,
T &Ndisc_max_bound) {
const T &Ndisc_min_bound,
const T &Ndisc_max_bound) {
int pulse_counter = (int) pulse_length/base_time_step;
T pulse_voltage = pulse_voltage_RESET;

@@ -449,7 +483,7 @@ inline T update_once(
}

for (int i = 0; i < pulse_counter; i++) {
step(pulse_voltage, base_time_step, w, alpha0, alpha1, alpha2, alpha3, beta0, beta1, c0, c1, c3, d0, d1, d2, d3, f0, f1, f2, f3, g0, g1, h0, h1, h2, h3, j_0, k0, e, kb, Arichardson, mdiel, h, zvo, eps_0, T0, eps, epsphib, phiBn0, phin, un, Original_Ndiscmin, Ndiscmax, Ndiscmin, Nplug, a, ny0, dWa, Rth0, lcell, ldet, Rtheff_scaling, RseriesTiOx, R0, Rthline, alphaline, A, Ndisc_min_bound, Ndisc_max_bound);
step(pulse_voltage, base_time_step, w, alpha0, alpha1, alpha2, alpha3, beta0, beta1, c0, c1, c2, c3, d0, d1, d2, d3, f0, f1, f2, f3, g0, g1, h0, h1, h2, h3, j_0, k0, e, kb, Arichardson, mdiel, h, zvo, eps_0, T0, eps, epsphib, phiBn0, phin, un, Original_Ndiscmin, Ndiscmax, Ndiscmin, Nplug, a, ny0, dWa, Rth0, lcell, ldet, Rtheff_scaling, RseriesTiOx, R0, Rthline, alphaline, A, Ndisc_min_bound, Ndisc_max_bound);
}

w_apparent = map_Ndisc_to_weight(read_voltage, w, conductance_min, conductance_max, weight_min_bound, weight_max_bound, g0, g1, h0, h1, h2, h3, j_0, k0, Original_Ndiscmin);
@@ -508,18 +542,13 @@ inline T update_once(
// w_apparent = w + write_noise_std * rng->sampleGauss();
// }
// }
} // namespace

template <typename T>
void JARTv1bStaticRPUDevice<T>::doSparseUpdate(
T **weights, int i, const int *x_signed_indices, int x_count, int d_sign, RNG<T> *rng) {

const auto &par = getPar();

T *scale_down = this->w_scale_down_[i];
T *scale_up = this->w_scale_up_[i];
T *slope_down = w_slope_down_[i];
T *slope_up = w_slope_up_[i];
T *w = par.usesPersistentWeight() ? this->w_persistent_[i] : weights[i];
T *w_apparent = weights[i];
T *min_bound = this->w_min_bound_[i];
@@ -528,8 +557,6 @@ void JARTv1bStaticRPUDevice<T>::doSparseUpdate(
T *Ndiscmin = device_specific_Ndiscmin[i];
T *ldet = device_specific_ldet[i];
T *A = device_specific_A[i];

T write_noise_std = par.getScaledWriteNoise();
// if (par.ls_mult_noise) {
// PULSED_UPDATE_W_LOOP(update_once_mult(
// w[j], w_apparent[j], sign, scale_down[j], scale_up[j], slope_down[j],
@@ -541,9 +568,9 @@ void JARTv1bStaticRPUDevice<T>::doSparseUpdate(
// slope_up[j], min_bound[j], max_bound[j], par.dw_min_std,
// write_noise_std, rng););
// }
PULSED_UPDATE_W_LOOP(update_once(par.read_voltage, par.pulse_voltage, par.pulse_length, par.base_time_step,
PULSED_UPDATE_W_LOOP(update_once(par.read_voltage, par.pulse_voltage_SET, par.pulse_voltage_RESET, par.pulse_length, par.base_time_step,
par.alpha0, par.alpha1, par.alpha2, par.alpha3, par.beta0, par.beta1,
par.c0, par.c1, par.c3, par.d0, par.d1, par.d2, par.d3,
par.c0, par.c1, par.c2, par.c3, par.d0, par.d1, par.d2, par.d3,
par.f0, par.f1, par.f2, par.f3, par.g0, par.g1, par.h0, par.h1, par.h2, par.h3, par.j_0, par.k0,
par.e, par.kb, par.Arichardson, par.mdiel, par.h, par.zvo, par.eps_0,
par.T0, par.eps, par.epsphib, par.phiBn0, par.phin, par.un, par.Ndiscmin,
@@ -562,19 +589,14 @@ void JARTv1bStaticRPUDevice<T>::doDenseUpdate(T **weights, int *coincidences, RN

const auto &par = getPar();

T *scale_down = this->w_scale_down_[0];
T *scale_up = this->w_scale_up_[0];
T *slope_down = w_slope_down_[0];
T *slope_up = w_slope_up_[0];
T *w = par.usesPersistentWeight() ? this->w_persistent_[0] : weights[0];
T *w_apparent = weights[0];
T *min_bound = this->w_min_bound_[0];
T *max_bound = this->w_max_bound_[0];
T write_noise_std = par.getScaledWriteNoise();
T *Ndiscmax = device_specific_Ndiscmax[i];
T *Ndiscmin = device_specific_Ndiscmin[i];
T *ldet = device_specific_ldet[i];
T *A = device_specific_A[i];
T *Ndiscmax = device_specific_Ndiscmax[0];
T *Ndiscmin = device_specific_Ndiscmin[0];
T *ldet = device_specific_ldet[0];
T *A = device_specific_A[0];

// if (par.ls_mult_noise) {
// PULSED_UPDATE_W_LOOP_DENSE(update_once_mult(
@@ -587,9 +609,9 @@ void JARTv1bStaticRPUDevice<T>::doDenseUpdate(T **weights, int *coincidences, RN
// slope_down[j], slope_up[j], min_bound[j], max_bound[j],
// par.dw_min_std, write_noise_std, rng););
// }
PULSED_UPDATE_W_LOOP(update_once(par.read_voltage, par.pulse_voltage, par.pulse_length, par.base_time_step,
PULSED_UPDATE_W_LOOP_DENSE(update_once(par.read_voltage, par.pulse_voltage_SET, par.pulse_voltage_RESET, par.pulse_length, par.base_time_step,
par.alpha0, par.alpha1, par.alpha2, par.alpha3, par.beta0, par.beta1,
par.c0, par.c1, par.c3, par.d0, par.d1, par.d2, par.d3,
par.c0, par.c1, par.c2, par.c3, par.d0, par.d1, par.d2, par.d3,
par.f0, par.f1, par.f2, par.f3, par.g0, par.g1, par.h0, par.h1, par.h2, par.h3, par.j_0, par.k0,
par.e, par.kb, par.Arichardson, par.mdiel, par.h, par.zvo, par.eps_0,
par.T0, par.eps, par.epsphib, par.phiBn0, par.phin, par.un, par.Ndiscmin,
@@ -672,6 +694,7 @@ template <typename T>
void JARTv1bStaticRPUDevice<T>::resetCols(
T **weights, int start_col, int n_col, T reset_prob, RealWorldRNG<T> &rng) {

const auto &par = getPar();
if (getPar().usesPersistentWeight()) {
T reset_std = getPar().reset_std;
for (int j = 0; j < this->x_size_; ++j) {
@@ -685,9 +708,8 @@ void JARTv1bStaticRPUDevice<T>::resetCols(
w_reset_bias_[i][j] + (reset_std > 0 ? reset_std * rng.sampleGauss() : (T)0.0);
weights[i][j] = MIN(weights[i][j], w_max_bound_[i][j]);
weights[i][j] = MAX(weights[i][j], w_min_bound_[i][j]);
w_persistent_[i][j] = invert_positive_current(
(((weights[i][j]-w_min_bound_[i][j])/(w_max_bound_[i][j]-w_min_bound_[i][j]))*(par.conductance_max-par.conductance_min)+par.conductance_min)*par.read_voltage,
par.read_voltage, par.g0, par.g1, par.h0, par.h1, par.h2, par.h3, par.j_0, par.k0, par.Ndiscmin);
T current = (((weights[i][j]-w_min_bound_[i][j])/(w_max_bound_[i][j]-w_min_bound_[i][j]))*(par.conductance_max-par.conductance_min)+par.conductance_min)*par.read_voltage;
w_persistent_[i][j] = invert_positive_current(current,par.read_voltage, par.g0, par.g1, par.h0, par.h1, par.h2, par.h3, par.j_0, par.k0, par.Ndiscmin);
}
}
}
@@ -716,6 +738,7 @@ template <typename T>
void JARTv1bStaticRPUDevice<T>::resetAtIndices(
T **weights, std::vector<int> x_major_indices, RealWorldRNG<T> &rng) {

const auto &par = getPar();
if (getPar().usesPersistentWeight()) {
T reset_std = getPar().reset_std;

@@ -726,9 +749,8 @@ void JARTv1bStaticRPUDevice<T>::resetAtIndices(
weights[i][j] = w_reset_bias_[i][j] + (reset_std > 0 ? reset_std * rng.sampleGauss() : (T)0.0);
weights[i][j] = MIN(weights[i][j], w_max_bound_[i][j]);
weights[i][j] = MAX(weights[i][j], w_min_bound_[i][j]);
w_persistent_[i][j] = invert_positive_current(
(((weights[i][j]-w_min_bound_[i][j])/(w_max_bound_[i][j]-w_min_bound_[i][j]))*(par.conductance_max-par.conductance_min)+par.conductance_min)*par.read_voltage,
par.read_voltage, par.g0, par.g1, par.h0, par.h1, par.h2, par.h3, par.j_0, par.k0, par.Ndiscmin);
T current = (((weights[i][j]-w_min_bound_[i][j])/(w_max_bound_[i][j]-w_min_bound_[i][j]))*(par.conductance_max-par.conductance_min)+par.conductance_min)*par.read_voltage;
w_persistent_[i][j] = invert_positive_current(current,par.read_voltage, par.g0, par.g1, par.h0, par.h1, par.h2, par.h3, par.j_0, par.k0, par.Ndiscmin);

}
}
@@ -764,9 +786,8 @@ template <typename T> bool JARTv1bStaticRPUDevice<T>::onSetWeights(T **weights)
if (getPar().usesPersistentWeight()) {
PRAGMA_SIMD
for (int i = 0; i < this->size_; i++) {
w_persistent_[0][i] = invert_positive_current(
(((w[i]-min_bound[i])/(max_bound[i]-min_bound[i]))*(par.conductance_max-par.conductance_min)+par.conductance_min)*par.read_voltage,
par.read_voltage, par.g0, par.g1, par.h0, par.h1, par.h2, par.h3, par.j_0, par.k0, par.Ndiscmin);
T current = (((w[i]-min_bound[i])/(max_bound[i]-min_bound[i]))*(par.conductance_max-par.conductance_min)+par.conductance_min)*par.read_voltage;
w_persistent_[0][i] = invert_positive_current(current,par.read_voltage, par.g0, par.g1, par.h0, par.h1, par.h2, par.h3, par.j_0, par.k0, par.Ndiscmin);
}
// applyUpdateWriteNoise(weights);
return true; // modified device thus true
@@ -783,9 +804,8 @@ template <typename T> void JARTv1bStaticRPUDevice<T>::applyUpdateWriteNoise(T **
T *min_bound = &(w_min_bound_[0][0]);

for (int i = 0; i < this->size_; i++) {
w_persistent_[0][i] = invert_positive_current(
(((w[i]-min_bound[i])/(max_bound[i]-min_bound[i]))*(par.conductance_max-par.conductance_min)+par.conductance_min)*par.read_voltage,
par.read_voltage, par.g0, par.g1, par.h0, par.h1, par.h2, par.h3, par.j_0, par.k0, par.Ndiscmin);
T current = (((w[i]-min_bound[i])/(max_bound[i]-min_bound[i]))*(par.conductance_max-par.conductance_min)+par.conductance_min)*par.read_voltage;
w_persistent_[0][i] = invert_positive_current(current,par.read_voltage, par.g0, par.g1, par.h0, par.h1, par.h2, par.h3, par.j_0, par.k0, par.Ndiscmin);
}
}

23 changes: 13 additions & 10 deletions src/rpucuda/rpu_JART_v1b_static_device.h
Original file line number Diff line number Diff line change
@@ -166,8 +166,7 @@ BUILD_PULSED_DEVICE_META_PARAMETER(
,
/*Add*/
// bool implementsWriteNoise() const override { return true; };);
bool implementsWriteNoise() const override { return true; };
bool usesPersistentWeight() const override { return true; };);
bool implementsWriteNoise() const override { return true; };);
// Use the hidden weight w as Ndisc, and the write-noised weight w_apparent as the true w mapped from conductance

template <typename T> class JARTv1bStaticRPUDevice : public PulsedRPUDevice<T> {
@@ -250,14 +249,14 @@ template <typename T> class JARTv1bStaticRPUDevice : public PulsedRPUDevice<T> {
device_specific_A[0][i] = data_ptrs[n_prev + 3][i];
}

T *w = out_weights[0];
T *max_bound = &(w_max_bound_[0][0]);
T *min_bound = &(w_min_bound_[0][0]);
PRAGMA_SIMD
for (int i = 0; i < this->size_; ++i) {
w[i] = MIN(w[i], max_bound[i]);
w[i] = MAX(w[i], min_bound[i]);
}
// T *w = out_weights[0];
// T *max_bound = &(w_max_bound_[0][0]);
// T *min_bound = &(w_min_bound_[0][0]);
// PRAGMA_SIMD
// for (int i = 0; i < this->size_; ++i) {
// w[i] = MIN(w[i], max_bound[i]);
// w[i] = MAX(w[i], min_bound[i]);
// }

// if (getPar().usesPersistentWeight()) {
// PRAGMA_SIMD
@@ -276,6 +275,8 @@ template <typename T> class JARTv1bStaticRPUDevice : public PulsedRPUDevice<T> {
);


void printDP(int x_count, int d_count) const override;

void decayWeights(T **weights, bool bias_no_decay) override;
void decayWeights(T **weights, T alpha, bool bias_no_decay) override;
void driftWeights(T **weights, T time_since_last_call, RNG<T> &rng) override;
@@ -293,6 +294,8 @@ template <typename T> class JARTv1bStaticRPUDevice : public PulsedRPUDevice<T> {
override;
void doDenseUpdate(T **weights, int *coincidences, RNG<T> *rng) override;

// Unconditionally reports true for this device.
// NOTE(review): presumably this makes w_persistent_ the hidden state while the weights hold the apparent value — confirm against the base-class contract.
bool usesPersistentWeight() const override { return true; };

private:
T **device_specific_Ndiscmax = nullptr;
T **device_specific_Ndiscmin = nullptr;
2 changes: 1 addition & 1 deletion src/rpucuda/rpu_simple_device.h
Original file line number Diff line number Diff line change
@@ -43,7 +43,7 @@ enum DeviceUpdateType {
MixedPrec,
PowStep,
BufferedTransfer,
JARTv1bstatic
JARTv1bStatic
};

// inherit from Simple