Skip to content

Commit

Permalink
generation demo: smoother support for silent vids
Browse files Browse the repository at this point in the history
  • Loading branch information
v-iashin committed Oct 22, 2021
1 parent 2a3ecba commit c9ae30a
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 33 deletions.
8 changes: 8 additions & 0 deletions feature_extraction/demo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ def which_ffprobe() -> str:
return ffprobe_path


def check_video_for_audio(path):
    '''Return True if the media file at `path` contains at least one audio stream.

    Runs ffprobe to list the codec type of every stream and checks whether
    any of them is reported as audio.

    Args:
        path: path to the video/audio file to probe.

    Returns:
        bool: True if an audio stream is present, False otherwise.
    '''
    assert which_ffprobe() != '', 'Is ffmpeg installed? Check if the conda environment is activated.'
    # Build the command as an argument list instead of an f-string + .split():
    # splitting on whitespace breaks for paths that contain spaces.
    cmd = [
        which_ffprobe(),
        '-loglevel', 'error',
        '-show_entries', 'stream=codec_type',
        '-of', 'default=nw=1',
        path,
    ]
    result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    output = result.stdout.decode('utf-8')
    # NOTE(review): removed leftover debug `print(...)` of the raw ffprobe output.
    return 'codec_type=audio' in output

def get_duration(path):
assert which_ffprobe() != '', 'Is ffmpeg installed? Check if the conda environment is activated.'
cmd = f'{which_ffprobe()} -hide_banner -loglevel panic' \
Expand Down
69 changes: 36 additions & 33 deletions generation_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@
"from torchvision.utils import make_grid\n",
"from tqdm import tqdm\n",
"\n",
"from feature_extraction.demo_utils import (ExtractResNet50,\n",
"from feature_extraction.demo_utils import (ExtractResNet50, check_video_for_audio,\n",
" extract_melspectrogram, load_model,\n",
" show_grid, trim_video)\n",
"from sample_visualization import (all_attention_to_st, get_class_preditions,\n",
" last_attention_to_st, spec_to_audio_to_st,\n",
" tensor_to_plt)\n",
"from specvqgan.data.vggsound import CropImage\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n"
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
]
},
{
Expand Down Expand Up @@ -172,12 +172,12 @@
"source": [
"## Codebook Reconstruction of the Input Spectrogram\n",
"\n",
"This cell can be ignored if the video doesn't have any audio.\n",
"\n",
"If the video has an audio, it will extract mel-spectrogram from the \n",
"audio track using the same pre-processing pipeline as in our experiments.\n",
"This is also useful if you plan to prime the sampling with a half of \n",
"the ground truth codes.\n",
"\n",
"This cell can be ignored if the video doesn't have any audio."
"the ground truth codes."
]
},
{
Expand Down Expand Up @@ -222,31 +222,34 @@
}
],
"source": [
"# Extract Spectrogram\n",
"audio_fps = 22050\n",
"spectrogram = extract_melspectrogram(video_path, audio_fps)\n",
"spectrogram = {'input': spectrogram}\n",
"# [80, 860] -> [80, 848]\n",
"random_crop = False\n",
"crop_img_fn = CropImage([config.data.params.mel_num, config.data.params.spec_crop_len], random_crop)\n",
"spectrogram = crop_img_fn(spectrogram)\n",
"\n",
"# Prepare input\n",
"batch = default_collate([spectrogram])\n",
"batch['image'] = batch['input'].to(device)\n",
"x = sampler.get_input(sampler.first_stage_key, batch)\n",
"\n",
"# Encode and Decode the Spectrogram\n",
"with torch.no_grad():\n",
" quant_z, z_indices = sampler.encode_to_z(x)\n",
" xrec = sampler.first_stage_model.decode(quant_z)\n",
"\n",
"print('Original Spectrogram:')\n",
"display.display(tensor_to_plt(x, flip_dims=(2,)))\n",
"print('Reconstructed Spectrogram:')\n",
"display.display(tensor_to_plt(xrec, flip_dims=(2,)))\n",
"plt.close()\n",
"plt.close()"
"if check_video_for_audio(video_path):\n",
" # Extract Spectrogram\n",
" audio_fps = 22050\n",
" spectrogram = extract_melspectrogram(video_path, audio_fps)\n",
" spectrogram = {'input': spectrogram}\n",
" # [80, 860] -> [80, 848]\n",
" random_crop = False\n",
" crop_img_fn = CropImage([config.data.params.mel_num, config.data.params.spec_crop_len], random_crop)\n",
" spectrogram = crop_img_fn(spectrogram)\n",
"\n",
" # Prepare input\n",
" batch = default_collate([spectrogram])\n",
" batch['image'] = batch['input'].to(device)\n",
" x = sampler.get_input(sampler.first_stage_key, batch)\n",
"\n",
" # Encode and Decode the Spectrogram\n",
" with torch.no_grad():\n",
" quant_z, z_indices = sampler.encode_to_z(x)\n",
" xrec = sampler.first_stage_model.decode(quant_z)\n",
"\n",
" print('Original Spectrogram:')\n",
" display.display(tensor_to_plt(x, flip_dims=(2,)))\n",
" print('Reconstructed Spectrogram:')\n",
" display.display(tensor_to_plt(xrec, flip_dims=(2,)))\n",
" plt.close()\n",
" plt.close()\n",
"else:\n",
" print('Could not find an audio track in the video file...')"
]
},
{
Expand Down Expand Up @@ -355,13 +358,13 @@
"\n",
" B, D, hr_h, hr_w = sampling_shape = (1, 256, 5, 53*W_scale)\n",
"\n",
" z_pred_indices = torch.zeros((B, hr_h*hr_w)).long().to(device)\n",
"\n",
" if mode == 'full':\n",
" start_step = 0\n",
" else:\n",
" start_step = (patch_size_j // 2) * patch_size_i\n",
"\n",
" z_pred_indices = torch.zeros((B, hr_h*hr_w)).long().to(device)\n",
" z_pred_indices[:, :start_step] = z_indices[:, :start_step]\n",
" z_pred_indices[:, :start_step] = z_indices[:, :start_step]\n",
"\n",
" pbar = tqdm(range(start_step, hr_w * hr_h), desc='Sampling Codebook Indices')\n",
" for step in pbar:\n",
Expand Down

0 comments on commit c9ae30a

Please sign in to comment.