From 685a6b0fa202f5383486fe20d82600e5c48bedd4 Mon Sep 17 00:00:00 2001
From: Roman Bolgaryn <roman.bolgaryn@nrel.gov>
Date: Thu, 16 Jan 2025 17:09:56 -0700
Subject: [PATCH] wip: SolverData, penalty factors

---
 src/PowerFlowData.jl         | 15 +++++++++++++++
 src/common.jl                | 11 +++++++++++
 src/newton_ac_powerflow.jl   |  7 ++++++-
 test/test_utils/legacy_pf.jl |  3 +++
 4 files changed, 35 insertions(+), 1 deletion(-)

diff --git a/src/PowerFlowData.jl b/src/PowerFlowData.jl
index fd144b7b..021b99d8 100644
--- a/src/PowerFlowData.jl
+++ b/src/PowerFlowData.jl
@@ -1,3 +1,8 @@
+Base.@kwdef mutable struct SolverData
+    J::Union{SparseMatrixCSC{Float64, Int}, Nothing} = nothing
+    dSbus_dV_ref::Union{Vector{Float64}, Nothing} = nothing
+end
+
 abstract type PowerFlowContainer end
 
 """
@@ -103,6 +108,7 @@ struct PowerFlowData{
     aux_network_matrix::N
     neighbors::Vector{Set{Int}}
     converged::Vector{Bool}
+    solver_data::Vector{SolverData}
 end
 
 get_bus_lookup(pfd::PowerFlowData) = pfd.bus_lookup
@@ -131,6 +137,7 @@ get_aux_network_matrix(pfd::PowerFlowData) = pfd.aux_network_matrix
 get_neighbor(pfd::PowerFlowData) = pfd.neighbors
 supports_multi_period(::PowerFlowData) = true
 get_converged(pfd::PowerFlowData) = pfd.converged
+get_solver_data(pfd::PowerFlowData) = pfd.solver_data
 
 function clear_injection_data!(pfd::PowerFlowData)
     pfd.bus_activepower_injection .= 0.0
@@ -226,6 +233,7 @@ function PowerFlowData(
     neighbors = _calculate_neighbors(power_network_matrix)
     aux_network_matrix = nothing
     converged = fill(false, time_steps)
+    solver_data = [SolverData() for _ in 1:time_steps]
 
     return make_powerflowdata(
         sys,
@@ -242,6 +250,7 @@ function PowerFlowData(
         valid_ix,
         neighbors,
         converged,
+        solver_data,
     )
 end
 
@@ -299,6 +308,7 @@ function PowerFlowData(
     )
     valid_ix = setdiff(1:n_buses, aux_network_matrix.ref_bus_positions)
     converged = fill(false, time_steps)
+    solver_data = [SolverData() for _ in 1:time_steps]
     return make_dc_powerflowdata(
         sys,
         time_steps,
@@ -312,6 +322,7 @@ function PowerFlowData(
         temp_bus_map,
         valid_ix,
         converged,
+        solver_data,
     )
 end
 
@@ -370,6 +381,7 @@ function PowerFlowData(
     )
     valid_ix = setdiff(1:n_buses, aux_network_matrix.ref_bus_positions)
     converged = fill(false, time_steps)
+    solver_data = [SolverData() for _ in 1:time_steps]
     return make_dc_powerflowdata(
         sys,
         time_steps,
@@ -383,6 +395,7 @@ function PowerFlowData(
         temp_bus_map,
         valid_ix,
         converged,
+        solver_data,
     )
 end
 
@@ -440,6 +453,7 @@ function PowerFlowData(
     )
     valid_ix = setdiff(1:n_buses, aux_network_matrix.ref_bus_positions)
     converged = fill(false, time_steps)
+    solver_data = [SolverData() for _ in 1:time_steps]
     return make_dc_powerflowdata(
         sys,
         time_steps,
@@ -453,6 +467,7 @@ function PowerFlowData(
         temp_bus_map,
         valid_ix,
         converged,
+        solver_data,
     )
 end
 
diff --git a/src/common.jl b/src/common.jl
index a05dc7a3..a0bae85c 100644
--- a/src/common.jl
+++ b/src/common.jl
@@ -150,6 +150,7 @@ function make_dc_powerflowdata(
     temp_bus_map,
     valid_ix,
     converged,
+    solver_data,
 )
     branch_type = Vector{DataType}(undef, length(branch_lookup))
     for (ix, b) in enumerate(PNM.get_ac_branches(sys))
@@ -173,6 +174,7 @@ function make_dc_powerflowdata(
         valid_ix,
         neighbors,
         converged,
+        solver_data,
     )
 end
 
@@ -191,6 +193,7 @@ function make_powerflowdata(
     valid_ix,
     neighbors,
     converged,
+    solver_data,
 )
     bus_type = Vector{PSY.ACBusTypes}(undef, n_buses)
     bus_angles = zeros(Float64, n_buses)
@@ -285,5 +288,13 @@ function make_powerflowdata(
         aux_network_matrix,
         neighbors,
         converged,
+        solver_data,
     )
 end
+
+# work in progress - quick but not optimized function for POC
+function penalty_factors(solver_data::SolverData)
+    J_t = transpose(solver_data.J)
+    f = J_t \ solver_data.dSbus_dV_ref
+    return f
+end
diff --git a/src/newton_ac_powerflow.jl b/src/newton_ac_powerflow.jl
index 3b10675c..49d968e2 100644
--- a/src/newton_ac_powerflow.jl
+++ b/src/newton_ac_powerflow.jl
@@ -175,7 +175,7 @@ function solve_powerflow!(
 )
     pf = ACPowerFlow()  # todo: somehow store in data which PF to use (see issue #50)
 
-    sorted_time_steps = sort(collect(keys(data.timestep_map)))
+    sorted_time_steps = get(kwargs, :time_steps, sort(collect(keys(data.timestep_map))))
     # preallocate results
     ts_converged = fill(false, length(sorted_time_steps))
     ts_V = zeros(Complex{Float64}, length(data.bus_type[:, 1]), length(sorted_time_steps))
@@ -540,6 +540,8 @@ function _newton_powerflow(
     tol = get(kwargs, :tol, DEFAULT_NR_TOL)
     i = 0
 
+    solver_data = data.solver_data[time_step]
+
     Ybus = data.power_network_matrix.data
 
     # Find indices for each bus type
@@ -720,6 +722,9 @@ function _newton_powerflow(
         Sbus_result .*= NaN64
         @error("The powerflow solver with KLU did not converge after $i iterations")
     else
+        solver_data.J = J
+        solver_data.dSbus_dV_ref =
+            [vec(real.(dSbus_dVa[ref, :][:, pvpq])); vec(real.(dSbus_dVm[ref, :][:, pvpq]))]
         @info("The powerflow solver with KLU converged after $i iterations")
     end
 
diff --git a/test/test_utils/legacy_pf.jl b/test/test_utils/legacy_pf.jl
index 5f1c0a50..c7dabb4d 100644
--- a/test/test_utils/legacy_pf.jl
+++ b/test/test_utils/legacy_pf.jl
@@ -116,6 +116,9 @@ function _newton_powerflow(
         @error("The powerflow solver with KLU did not converge after $i iterations")
     else
         Sbus_result = V .* conj(Ybus * V)
+        solver_data.J = J
+        solver_data.dSbus_dV_ref =
+            [vec(real.(dSbus_dVa[ref, :][:, pvpq])); vec(real.(dSbus_dVm[ref, :][:, pvpq]))]
         @info("The powerflow solver with KLU converged after $i iterations")
     end
     return (converged, V, Sbus_result)