Skip to content

Commit

Permalink
fix: inpainting noise, composer initial span, readme
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Jul 27, 2022
1 parent a7091a5 commit b85b825
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 22 deletions.
45 changes: 27 additions & 18 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ t = torch.tensor([40, 10, 20])
y = unet(x, t) # [3, 1, 32768], 3 audio tracks of ~1.6s sampled at 20050 Hz
```

### Elucidated Diffusion
### Diffusion

```python
from audio_diffusion_pytorch.diffusion.elucidated import Diffusion, DiffusionSampler, LogNormalSampler, KerrasSchedule
Expand All @@ -72,42 +72,51 @@ sampler = DiffusionSampler(
s_churn=40,
s_noise=1.003
)
# Generate a sample starting from the provided noise
y = sampler(x = torch.randn(1,1,2 ** 15))

```


### Gaussian Diffusion (Old)
Note that this requires `use_learned_time_embedding=False` on the `UNet1d`.
### Diffusion Inpainting and Infinite Generation

```py
from audio_diffusion_pytorch.diffusion.ddpm import Diffusion, DiffusionSampler
# Build diffusion to train denoise function
diffusion = Diffusion(
denoise_fn=unet,
num_timesteps=50,
loss_fn='l1',
loss_weight_gamma=0.5,
loss_weight_k=1
from audio_diffusion_pytorch.diffusion.elucidated import DiffusionInpainter, KerrasSchedule, SpanBySpanComposer

inpainter = DiffusionInpainter(
diffusion,
num_steps=2,
num_resamples=5,
sigma_schedule=KerrasSchedule(
sigma_min=0.002,
sigma_max=1
),
s_tmin=0,
s_tmax=10,
s_churn=40,
s_noise=1.003
)

x = torch.randn(3, 1, 2 ** 15)
loss = diffusion(x)
loss.backwards() # Do this many times
inpaint = torch.randn(1,1,2 ** 15) # This should not be random but your start track, e.g. one sampled with DiffusionSampler
inpaint_mask = torch.randint(0,2, (1,1,2 ** 15), dtype=torch.bool) # Set to `True` the parts you want to keep
y = inpainter(inpaint = inpaint, inpaint_mask = inpaint_mask) # [1, 1, 32768]


# Sample from diffusion model by converting normal tensor to audio
sampler = DiffusionSampler(diffusion)
y = sampler(x = torch.randn(1, 1, 2 ** 15)) # [1, 1, 32768]
# Infinite generation using SpanBySpanComposer
composer = SpanBySpanComposer(inpainter, num_spans=4) # Generates 4 additional spans
y_long = composer(y, keep_start=True) # [1, 1, 98304]

```


## Experiments


| Report | Snapshot | Description |
| --- | --- | --- |
| [Alpha](https://wandb.ai/schneider/audio/reports/Audio-Diffusion-UNet-Alpha---VmlldzoyMjk3MzIz?accessToken=y0l3igdvnm4ogn4d3ph3b0i8twwcf7meufbviwt15f0qtasyn1i14hg340bkk1te) | [6bd9279f19](https://github.com/archinetai/audio-diffusion-pytorch/tree/6bd9279f192fc0c11eb8a21cd919d9c41181bf35) | Initial tests on LJSpeech dataset with new architecture and basic DDPM diffusion model. |
| [Bravo](https://wandb.ai/schneider/audio/reports/Audio-Diffusion-Bravo---VmlldzoyMzE4NjIx?accessToken=qt2w1jeqch9l5v3ffjns99p69jsmexk849dszyiennfbivgg396378u6ken2fm2d) | [a05f30aa94](https://github.com/archinetai/audio-diffusion-pytorch/tree/a05f30aa94e07600038d36cfb96f8492ef735a99) | Elucidated diffusion, improved architecture with patching, longer duration, initial good (unsupervised) results on LJSpeech.
| Charlie | (current) | . |
| [Charlie](https://wandb.ai/schneider/audio/reports/Audio-Diffusion-Charlie---VmlldzoyMzYyNDA1?accessToken=71gmurcwndv5e2abqrjnlh3n74j5555j3tycpd7h40tnv8fvb17k5pjkb57j9xxa) | (current) | Train on music with YoutubeDataset, larger patch tests for longer tracks, inpainting tests, initial test with infinite generation using StepByStepComposer. |


## Appreciation
Expand Down
8 changes: 5 additions & 3 deletions audio_diffusion_pytorch/diffusion/elucidated.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def step(
epsilon = self.s_noise * torch.randn_like(x)
noise = sqrt(sigma_hat ** 2 - sigma ** 2) * epsilon
# Add increased noise to mixed value
x_hat = (x * ~inpaint_mask + inpaint * inpaint_mask) * noise
x_hat = x * ~inpaint_mask + inpaint * inpaint_mask + noise
# Evaluate ∂x/∂sigma at sigma_hat
d = (x_hat - self.denoise_fn(x_hat, sigma=sigma_hat, clamp=clamp)) / sigma_hat
# Take euler step from sigma_hat to sigma_next
Expand Down Expand Up @@ -321,8 +321,10 @@ def __init__(
def forward(self, start: Tensor, keep_start: bool = False) -> Tensor:
half_length = start.shape[2] // 2

spans = [start[:, :, :half_length]] if keep_start else []
inpaint = start
spans = list(start.chunk(chunks=2, dim=-1)) if keep_start else []
# Inpaint second half from first half
inpaint = torch.zeros_like(start)
inpaint[:, :, :half_length] = start[:, :, half_length:]
inpaint_mask = sequential_mask(like=start, start=half_length)

for i in range(self.num_spans):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.0.6",
version="0.0.7",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit b85b825

Please sign in to comment.