torch_helpers: Add layer_index argument to loadGRU and loadLSTM (#156)
* torch_helpers: Add layer_index argument to loadGRU and loadLSTM

* Apply clang-format

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
jatinchowdhury18 and github-actions[bot] authored Nov 16, 2024
1 parent 5909c44 commit cd80257
Showing 9 changed files with 2,054 additions and 2,018 deletions.
32 changes: 20 additions & 12 deletions RTNeural/torch_helpers.h
@@ -132,19 +132,23 @@ namespace torch_helpers
         }
     }
 
-    /** Loads a GRU layer from a JSON object containing a PyTorch state_dict. */
+    /**
+     * Loads a GRU layer from a JSON object containing a PyTorch state_dict.
+     * If your PyTorch GRU has num_layers > 1, you must call this method once
+     * for each layer, with the correct layer object and layer_index.
+     */
     template <typename T, typename GRUType>
-    void loadGRU(const nlohmann::json& modelJson, const std::string& layerPrefix, GRUType& gru, bool hasBias = true)
+    void loadGRU(const nlohmann::json& modelJson, const std::string& layerPrefix, GRUType& gru, bool hasBias = true, int layer_index = 0)
     {
         // For the kernel and recurrent weights, PyTorch stores the weights similar to the
         // Tensorflow format, but transposed, and with the "r" and "z" indexes swapped.
 
-        const std::vector<std::vector<T>> gru_ih_weights = modelJson.at(layerPrefix + "weight_ih_l0");
+        const std::vector<std::vector<T>> gru_ih_weights = modelJson.at(layerPrefix + "weight_ih_l" + std::to_string(layer_index));
         auto wVals = detail::transpose(gru_ih_weights);
         detail::swap_rz(wVals, gru.out_size);
         gru.setWVals(wVals);
 
-        const std::vector<std::vector<T>> gru_hh_weights = modelJson.at(layerPrefix + "weight_hh_l0");
+        const std::vector<std::vector<T>> gru_hh_weights = modelJson.at(layerPrefix + "weight_hh_l" + std::to_string(layer_index));
         auto uVals = detail::transpose(gru_hh_weights);
         detail::swap_rz(uVals, gru.out_size);
         gru.setUVals(uVals);
@@ -154,8 +158,8 @@ namespace torch_helpers
 
         if(hasBias)
         {
-            const std::vector<T> gru_ih_bias = modelJson.at(layerPrefix + "bias_ih_l0");
-            const std::vector<T> gru_hh_bias = modelJson.at(layerPrefix + "bias_hh_l0");
+            const std::vector<T> gru_ih_bias = modelJson.at(layerPrefix + "bias_ih_l" + std::to_string(layer_index));
+            const std::vector<T> gru_hh_bias = modelJson.at(layerPrefix + "bias_hh_l" + std::to_string(layer_index));
             std::vector<std::vector<T>> gru_bias { gru_ih_bias, gru_hh_bias };
             detail::swap_rz(gru_bias, gru.out_size);
             gru.setBVals(gru_bias);
@@ -169,20 +173,24 @@ namespace torch_helpers
         }
     }
 
-    /** Loads a LSTM layer from a JSON object containing a PyTorch state_dict. */
+    /**
+     * Loads a LSTM layer from a JSON object containing a PyTorch state_dict.
+     * If your PyTorch LSTM has num_layers > 1, you must call this method once
+     * for each layer, with the correct layer object and layer_index.
+     */
     template <typename T, typename LSTMType>
-    void loadLSTM(const nlohmann::json& modelJson, const std::string& layerPrefix, LSTMType& lstm, bool hasBias = true)
+    void loadLSTM(const nlohmann::json& modelJson, const std::string& layerPrefix, LSTMType& lstm, bool hasBias = true, int layer_index = 0)
     {
-        const std::vector<std::vector<T>> lstm_weights_ih = modelJson.at(layerPrefix + "weight_ih_l0");
+        const std::vector<std::vector<T>> lstm_weights_ih = modelJson.at(layerPrefix + "weight_ih_l" + std::to_string(layer_index));
         lstm.setWVals(detail::transpose(lstm_weights_ih));
 
-        const std::vector<std::vector<T>> lstm_weights_hh = modelJson.at(layerPrefix + "weight_hh_l0");
+        const std::vector<std::vector<T>> lstm_weights_hh = modelJson.at(layerPrefix + "weight_hh_l" + std::to_string(layer_index));
         lstm.setUVals(detail::transpose(lstm_weights_hh));
 
         if(hasBias)
         {
-            std::vector<T> lstm_bias_ih = modelJson.at(layerPrefix + "bias_ih_l0");
-            std::vector<T> lstm_bias_hh = modelJson.at(layerPrefix + "bias_hh_l0");
+            std::vector<T> lstm_bias_ih = modelJson.at(layerPrefix + "bias_ih_l" + std::to_string(layer_index));
+            std::vector<T> lstm_bias_hh = modelJson.at(layerPrefix + "bias_hh_l" + std::to_string(layer_index));
             for(size_t i = 0; i < lstm_bias_ih.size(); ++i)
                 lstm_bias_hh[i] += lstm_bias_ih[i];
             lstm.setBVals(lstm_bias_hh);
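For context, a minimal sketch of how the new layer_index argument might be used to load the multi-layer GRU exported by python/gru_torch.py. The ModelT layout, include path, and "gru2." prefix here are illustrative assumptions, not part of this commit:

// Sketch: loading torch.nn.GRU(8, 8, num_layers=3) -- one loadGRU call
// per PyTorch sub-layer, with the same prefix and an increasing layer_index.
#include <RTNeural/torch_helpers.h>
#include <fstream>

int main()
{
    // Assumed layout: three stacked GRU layers mirroring num_layers=3.
    RTNeural::ModelT<float, 8, 8,
        RTNeural::GRULayerT<float, 8, 8>,
        RTNeural::GRULayerT<float, 8, 8>,
        RTNeural::GRULayerT<float, 8, 8>>
        model;

    std::ifstream jsonStream { "models/gru_torch.json", std::ifstream::binary };
    nlohmann::json modelJson;
    jsonStream >> modelJson;

    // Same prefix every time; layer_index selects the "_l0" / "_l1" / "_l2" weights.
    RTNeural::torch_helpers::loadGRU<float>(modelJson, "gru2.", model.get<0>(), true, 0);
    RTNeural::torch_helpers::loadGRU<float>(modelJson, "gru2.", model.get<1>(), true, 1);
    RTNeural::torch_helpers::loadGRU<float>(modelJson, "gru2.", model.get<2>(), true, 2);
    return 0;
}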
2 changes: 1 addition & 1 deletion models/gru_torch.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion models/lstm_torch.json

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions python/gru_torch.py
@@ -18,10 +18,12 @@ class Model(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.gru = torch.nn.GRU(1, 8)
+        self.gru2 = torch.nn.GRU(8, 8, num_layers=3)
         self.dense = torch.nn.Linear(8, 1)
 
     def forward(self, torch_in):
         x, _ = self.gru(torch_in)
+        x, _ = self.gru2(x)
         return self.dense(x)
 
 x = np.random.uniform(-1, 1, 1000)
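With num_layers=3, PyTorch keeps one parameter set per sub-layer in the module's state_dict; those per-layer suffixes are what the new layer_index argument selects. Assuming the script exports model.state_dict() unchanged, the gru2 module contributes keys of the form:

gru2.weight_ih_l0  gru2.weight_hh_l0  gru2.bias_ih_l0  gru2.bias_hh_l0
gru2.weight_ih_l1  gru2.weight_hh_l1  gru2.bias_ih_l1  gru2.bias_hh_l1
gru2.weight_ih_l2  gru2.weight_hh_l2  gru2.bias_ih_l2  gru2.bias_hh_l2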
2 changes: 2 additions & 0 deletions python/lstm_torch.py
@@ -18,10 +18,12 @@ class Model(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.lstm = torch.nn.LSTM(1, 8)
+        self.lstm2 = torch.nn.LSTM(8, 8, num_layers=3)
         self.dense = torch.nn.Linear(8, 1)
 
     def forward(self, torch_in):
         x, _ = self.lstm(torch_in)
+        x, _ = self.lstm2(x)
         return self.dense(x)
 
 x = np.random.uniform(-1, 1, 1000)
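The LSTM case follows the same pattern as the GRU sketch above (reusing its modelJson); the layer types and "lstm2." prefix below are again illustrative assumptions:

// Assumed layout: three stacked LSTM layers mirroring torch.nn.LSTM(8, 8, num_layers=3).
RTNeural::ModelT<float, 8, 8,
    RTNeural::LSTMLayerT<float, 8, 8>,
    RTNeural::LSTMLayerT<float, 8, 8>,
    RTNeural::LSTMLayerT<float, 8, 8>>
    model;

// loadLSTM sums the ih and hh bias vectors internally, so each call only
// needs the matching layer object and layer_index.
RTNeural::torch_helpers::loadLSTM<float>(modelJson, "lstm2.", model.get<0>(), true, 0);
RTNeural::torch_helpers::loadLSTM<float>(modelJson, "lstm2.", model.get<1>(), true, 1);
RTNeural::torch_helpers::loadLSTM<float>(modelJson, "lstm2.", model.get<2>(), true, 2);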