fix: make torch.load safer with weights_only=True everywhere possible
Our prepared datasets pickle a lot of our code, so they contain more than just
tensors and basic data types. Most of our uses of `torch.load()`, however, work
properly with `weights_only=True`, so let's use it everywhere possible and be
explicit in the one function where it's not.

Partly fixes #621
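
For context on the distinction the message draws, here is a minimal sketch (file names illustrative): `weights_only=True` restricts unpickling to tensors and other basic types, while files that pickle arbitrary classes, like our prepared datasets, still require `weights_only=False`.

```python
import torch

# Tensors and basic containers round-trip fine under weights_only=True.
torch.save({"weights": torch.zeros(3), "step": 42}, "safe.pt")
print(torch.load("safe.pt", weights_only=True))

# A pickled instance of an arbitrary class does not: the restricted
# unpickler refuses to construct it.
class PreparedItem:  # stand-in for the classes our datasets pickle
    pass

torch.save(PreparedItem(), "unsafe.pt")
try:
    torch.load("unsafe.pt", weights_only=True)
except Exception as err:
    print("rejected:", err)
```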
joanise committed Jan 16, 2025
1 parent 29b0d79 commit fd98358
Showing 8 changed files with 44 additions and 28 deletions.
everyvoice/base_cli/checkpoint.py (3 additions & 1 deletion)

@@ -59,7 +59,9 @@ def load_checkpoint(model_path: Path, minimal=True) -> Dict[str, Any]:
     """
     import torch
 
-    checkpoint = torch.load(str(model_path), map_location=torch.device("cpu"))
+    checkpoint = torch.load(
+        str(model_path), map_location=torch.device("cpu"), weights_only=True
+    )
 
     if minimal:
         # Some clean up of useless stuff.
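
A hypothetical usage sketch for the function above (the import path is taken from the diff; the checkpoint path is illustrative):

```python
from pathlib import Path

from everyvoice.base_cli.checkpoint import load_checkpoint

# Load a checkpoint's contents onto CPU for inspection; minimal=True also
# strips fields that are not needed for inspection.
ckpt = load_checkpoint(Path("logs/last.ckpt"), minimal=True)
print(sorted(ckpt.keys()))
```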
everyvoice/base_cli/helpers.py (3 additions & 1 deletion)

@@ -366,7 +366,9 @@ def train_base_command(
         # removes all paths for checkpoint portability. However, some paths, like "vocoder_path"
         # should be still accessible when training is resumed.
         new_config_with_paths = model_obj.config.model_dump(mode="json")
-        old_ckpt = torch.load(last_ckpt, map_location=torch.device("cpu"))
+        old_ckpt = torch.load(
+            last_ckpt, map_location=torch.device("cpu"), weights_only=True
+        )
         old_ckpt["hyper_parameters"]["config"] = new_config_with_paths
         # TODO: check if we need to do the same thing with stats and any thing else registered on the model
         with tempfile.NamedTemporaryFile() as tmp:
everyvoice/dataloader/__init__.py (8 additions & 3 deletions)

@@ -42,11 +42,16 @@ def __init__(
 
     def setup(self, stage: Optional[str] = None):
         # load it back here
+        # Here we consider it safe to use torch.load() with weights_only=False
+        # because the dataset files are prepared and saved by this software, and
+        # not shared.
+        # TODO: investigate the possibility of changing our prepared dataset
+        # formats not to need weights_only=False
         if stage == "fit":
-            self.train_dataset = torch.load(self.train_path)
-            self.val_dataset = torch.load(self.val_path)
+            self.train_dataset = torch.load(self.train_path, weights_only=False)
+            self.val_dataset = torch.load(self.val_path, weights_only=False)
         if stage == "predict":
-            self.predict_dataset = torch.load(self.predict_path)
+            self.predict_dataset = torch.load(self.predict_path, weights_only=False)
 
     def train_dataloader(self):
         sampler = (
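
One possible direction for the TODO above, assuming PyTorch 2.4+ and a hypothetical `PreparedDataset` class standing in for whatever the prepared files actually pickle: allowlist the trusted classes so the files load under `weights_only=True`.

```python
import torch
from torch.serialization import add_safe_globals

from everyvoice.dataloader import PreparedDataset  # hypothetical class name

# Only allowlisted classes may be unpickled; everything else is still
# rejected by the weights-only unpickler.
add_safe_globals([PreparedDataset])
train_dataset = torch.load("preprocessed/train.pt", weights_only=True)  # illustrative path
```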
everyvoice/demo/app.py (3 additions & 1 deletion)

@@ -191,7 +191,9 @@ def create_demo_app(
 
     require_ffmpeg()
    device = get_device_from_accelerator(accelerator)
-    vocoder_ckpt = torch.load(spec_to_wav_model_path, map_location=device)
+    vocoder_ckpt = torch.load(
+        spec_to_wav_model_path, map_location=device, weights_only=True
+    )
     # TODO: Should we also wrap this load_hifigan_from_checkpoint in case the checkpoint is not a Vocoder?
     vocoder_model, vocoder_config = load_hifigan_from_checkpoint(vocoder_ckpt, device)
     model: FastSpeech2 = FastSpeech2.load_from_checkpoint(text_to_spec_model_path).to(  # type: ignore
everyvoice/model/e2e/dataset.py (2 additions & 1 deletion)

@@ -56,7 +56,8 @@ def __init__(self, dataset, config: EveryVoiceConfig, use_segments=True):
 
     def _load_file(self, bn, spk, lang, dir, fn):
         return torch.load(
-            self.preprocessed_dir / dir / self.sep.join([bn, spk, lang, fn])
+            self.preprocessed_dir / dir / self.sep.join([bn, spk, lang, fn]),
+            weights_only=True,
         )
 
     def _load_audio(self, bn, spk, lang, dir, fn):
everyvoice/preprocessor/preprocessor.py (16 additions & 12 deletions)

@@ -380,12 +380,12 @@ def compute_stats(
             logger.info("Gathering energy values")
             with tqdm_joblib_context(tqdm(desc="Gathering energy values")):
                 for energy_data in parallel(
-                    delayed(torch.load)(path) for path in paths
+                    delayed(torch.load)(path, weights_only=True) for path in paths
                 ):
                     energy_scaler.data.append(energy_data)
         else:
             for path in tqdm(paths, desc="Gathering energy values"):
-                energy_data = torch.load(path)
+                energy_data = torch.load(path, weights_only=True)
                 energy_scaler.data.append(energy_data)
         if pitch:
             pitch_scaler = Scaler()
@@ -398,12 +398,12 @@ def compute_stats(
             logger.info("Gathering pitch values")
             with tqdm_joblib_context(tqdm(desc="Gathering pitch values")):
                 for pitch_data in parallel(
-                    delayed(torch.load)(path) for path in paths
+                    delayed(torch.load)(path, weights_only=True) for path in paths
                 ):
                     pitch_scaler.data.append(pitch_data)
         else:
             for path in tqdm(paths, desc="Gathering pitch values"):
-                pitch_data = torch.load(path)
+                pitch_data = torch.load(path, weights_only=True)
                 pitch_scaler.data.append(pitch_data)
         return energy_scaler if energy else energy, pitch_scaler if pitch else pitch

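An aside on the parallel pattern above: joblib's `delayed` captures keyword arguments as well as positional ones, so `weights_only=True` is replayed in every worker. A minimal sketch with illustrative file names:

```python
import torch
from joblib import Parallel, delayed

paths = ["energy-0.pt", "energy-1.pt"]  # illustrative file names

# delayed(fn)(*args, **kwargs) records the call and replays it in a worker,
# keyword arguments included.
with Parallel(n_jobs=2) as parallel:
    tensors = parallel(delayed(torch.load)(p, weights_only=True) for p in paths)
```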
@@ -423,7 +423,7 @@ def normalize_stats(self, energy_scaler: Scaler, pitch_scaler: Scaler):
                 ),
                 desc="Normalizing energy values",
             ):
-                energy = torch.load(path)
+                energy = torch.load(path, weights_only=True)
                 energy = energy_scaler.normalize(energy)
                 save_tensor(energy, path)
             stats["energy"] = energy_stats
@@ -437,7 +437,7 @@ def normalize_stats(self, energy_scaler: Scaler, pitch_scaler: Scaler):
                 ),
                 desc="Normalizing pitch values",
             ):
-                pitch = torch.load(path)
+                pitch = torch.load(path, weights_only=True)
                 pitch = pitch_scaler.normalize(pitch)
                 save_tensor(pitch, path)
             stats["pitch"] = pitch_stats
@@ -587,15 +587,15 @@ def process_energy(self, item):
             "spec",
             f"spec-{self.input_sampling_rate}-{self.audio_config.spec_type}.pt",
         )
-        spec = torch.load(spec_path)
+        spec = torch.load(spec_path, weights_only=True)
         energy = self.extract_energy(spec)
         if (
             isinstance(self.config, FeaturePredictionConfig)
             and self.config.model.variance_predictors.energy.level == "phone"
             and not self.config.model.learn_alignment
         ):
             dur_path = self.create_path(item, "duration", "duration.pt")
-            durs = torch.load(dur_path)
+            durs = torch.load(dur_path, weights_only=True)
             energy = self.average_data_by_durations(energy, durs)
         save_tensor(energy, energy_path)

@@ -614,7 +614,7 @@ def process_pitch(self, item):
             and not self.config.model.learn_alignment
         ):
             dur_path = self.create_path(item, "duration", "duration.pt")
-            durs = torch.load(dur_path)
+            durs = torch.load(dur_path, weights_only=True)
             pitch = self.average_data_by_durations(pitch, durs)
         save_tensor(pitch, pitch_path)

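For readers unfamiliar with the phone-level averaging used in the two hunks above, here is a sketch of the assumed semantics of `average_data_by_durations` (not the actual implementation): frame-level values are averaged over each phone's duration span.

```python
import torch

def average_data_by_durations(data: torch.Tensor, durs: torch.Tensor) -> torch.Tensor:
    # Collapse frame-level values (pitch or energy) to one value per phone by
    # averaging over each phone's span of frames.
    out, start = [], 0
    for d in durs.long().tolist():
        chunk = data[start : start + d]
        out.append(chunk.mean() if d > 0 else data.new_zeros(()))
        start += d
    return torch.stack(out)
```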
@@ -674,7 +674,7 @@ def process_attn_prior(self, item):
             "spec",
             f"spec-{self.input_sampling_rate}-{self.audio_config.spec_type}.pt",
         )
-        input_spec = torch.load(input_spec_path)
+        input_spec = torch.load(input_spec_path, weights_only=True)
         if process_phones:
             phone_attn_prior = torch.from_numpy(
                 binomial_interpolator(input_spec.size(1), len(phone_tokens))
@@ -955,8 +955,12 @@ def check_data(
                 + audio[audio <= audio_min].size(0)
                 - 2
             )
-            pitch = torch.load(self.create_path(item, "pitch", "pitch.pt"))
-            energy = torch.load(self.create_path(item, "energy", "energy.pt"))
+            pitch = torch.load(
+                self.create_path(item, "pitch", "pitch.pt"), weights_only=True
+            )
+            energy = torch.load(
+                self.create_path(item, "energy", "energy.pt"), weights_only=True
+            )
             audio_length_s = len(audio) / self.input_sampling_rate
             data_point["total_clipped_samples"] = total_clipping
             data_point["pitch_min"] = float(pitch.min())
everyvoice/tests/test_model.py (5 additions & 5 deletions)

@@ -144,7 +144,7 @@ def monkey_on_save_checkpoint(checkpoint):
         )
 
         # We don't want just serializable, but actually serialized!
-        ckpt = torch.load(tmpdir / "model.ckpt")
+        ckpt = torch.load(tmpdir / "model.ckpt", weights_only=True)
         try:
             json.dumps(ckpt["hyper_parameters"])
         except (TypeError, OverflowError):
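
The comment in this hunk deserves a gloss: a config can be JSON-serializable in principle yet land in the checkpoint as a live object, so the test reloads from disk before calling `json.dumps`. A standalone illustration:

```python
import json

import torch

json.dumps({"lr": 1e-4, "layers": 4})  # plain types pass

try:
    json.dumps({"weight": torch.zeros(2)})  # a live tensor does not
except TypeError as err:
    print("not actually serialized:", err)
```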
@@ -312,11 +312,11 @@ def test_wrong_model_type(self):
             trainer.strategy.connect(model)
             ckpt_fn = tmpdir_str + "/checkpoint.ckpt"
             trainer.save_checkpoint(ckpt_fn)
-            m = torch.load(ckpt_fn)
+            m = torch.load(ckpt_fn, weights_only=True)
             self.assertIn("model_info", m.keys())
             m["model_info"]["name"] = "BAD_TYPE"
             torch.save(m, ckpt_fn)
-            m = torch.load(ckpt_fn)
+            m = torch.load(ckpt_fn, weights_only=True)
             self.assertIn("model_info", m.keys())
             self.assertEqual(m["model_info"]["name"], "BAD_TYPE")
             # self.assertEqual(m["model_info"]["version"], "1.0")
@@ -396,7 +396,7 @@ def test_missing_model_version(self):
             trainer.strategy.connect(model)
             ckpt_fn = tmpdir_str + "/checkpoint.ckpt"
             trainer.save_checkpoint(ckpt_fn)
-            m = torch.load(ckpt_fn)
+            m = torch.load(ckpt_fn, weights_only=True)
             self.assertIn("model_info", m.keys())
             self.assertEqual(m["model_info"]["name"], ModelType.__name__)
             self.assertEqual(m["model_info"]["version"], CANARY_VERSION)
@@ -475,7 +475,7 @@ def test_newer_model_version(self):
             trainer.strategy.connect(model)
             ckpt_fn = tmpdir_str + "/checkpoint.ckpt"
             trainer.save_checkpoint(ckpt_fn)
-            m = torch.load(ckpt_fn)
+            m = torch.load(ckpt_fn, weights_only=True)
             self.assertIn("model_info", m.keys())
             self.assertEqual(m["model_info"]["name"], ModelType.__name__)
             self.assertEqual(m["model_info"]["version"], NEWER_VERSION)
everyvoice/tests/test_preprocessing.py (4 additions & 4 deletions)

@@ -266,7 +266,7 @@ def test_pitch(self):
                 ]
             )
         )
-        durs = torch.load(dur_path)
+        durs = torch.load(dur_path, weights_only=True)
         feats = self.preprocessor.extract_spectral_features(
             audio, self.preprocessor.input_spectral_transform
         )
@@ -304,7 +304,7 @@ def test_duration(self):
                 ]
             )
         )
-        durs = torch.load(dur_path)
+        durs = torch.load(dur_path, weights_only=True)
         feats = self.preprocessor.extract_spectral_features(
             audio, self.preprocessor.input_spectral_transform
         )
@@ -341,7 +341,7 @@ def test_energy(self):
                 ]
             )
         )
-        durs = torch.load(dur_path)
+        durs = torch.load(dur_path, weights_only=True)
         # ming024_energy = np.load(
         #     self.data_dir
         #     / "ming024"
@@ -451,7 +451,7 @@ def test_text_processing(self):
             if "phone_tokens" in x
         ]
         phonological_features = [
-            torch.load(f)
+            torch.load(f, weights_only=True)
             for f in sorted(
                 list((output_filelist.parent / "pfs").glob("*.pt"))
             )
