ContextualConv – PyTorch convolutional layers with global context conditioning: per‑channel bias, per‑channel scale, or full FiLM‑style modulation.
```python
import torch
from contextual_conv import ContextualConv2d

# FiLM-style conditioning (scale + bias)
conv = ContextualConv2d(
    in_channels=16,
    out_channels=32,
    kernel_size=3,
    padding=1,
    context_dim=10,    # size of global vector c
    h_dim=64,          # optional MLP hidden dim
    use_scale=True,    # γ(c)
    use_bias=True,     # β(c)
    scale_mode="film", # or "scale"
)

x = torch.randn(8, 16, 32, 32)  # feature map
c = torch.randn(8, 10)          # context vector
out = conv(x, c)                # shape: (8, 32, 32, 32)
```
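The 1‑D layer follows the same pattern. A minimal sketch, assuming `ContextualConv1d` mirrors the 2‑D constructor and forward signature:

```python
import torch
from contextual_conv import ContextualConv1d  # assumed 1-D counterpart of ContextualConv2d

# Bias-only conditioning on a 1-D signal (illustrative settings)
conv1d = ContextualConv1d(
    in_channels=8,
    out_channels=16,
    kernel_size=5,
    padding=2,
    context_dim=4,
    use_scale=False,  # no γ(c)
    use_bias=True,    # β(c) only
)

x = torch.randn(8, 8, 100)  # (batch, channels, length)
c = torch.randn(8, 4)       # one context vector per sample
out = conv1d(x, c)          # shape: (8, 16, 100)
```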
| `use_scale` | `use_bias` | `scale_mode` | Behaviour |
|---|---|---|---|
| `False` | `True` | – | Contextual bias only |
| `True` | `False` | `"scale"` | Scale only: `out * γ` |
| `True` | `True` | `"film"` | FiLM: `out * (1 + γ) + β` |
| `True` | `True` | `"scale"` | Scale + shift: `out * γ + β` |
| `False` | `False` | – | Plain convolution (no modulation) |
If `context_dim` is provided, at least one of `use_scale` or `use_bias` must be `True`.
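For example, the bias‑only and scale‑and‑shift rows of the table correspond to the following configurations (a minimal sketch using only the parameters shown above):

```python
from contextual_conv import ContextualConv2d

# Contextual bias only: the context adds a per-channel β(c), no scaling
bias_only = ContextualConv2d(
    in_channels=16,
    out_channels=32,
    kernel_size=3,
    padding=1,
    context_dim=10,
    use_scale=False,
    use_bias=True,
)

# Scale + shift: out * γ(c) + β(c)
scale_shift = ContextualConv2d(
    in_channels=16,
    out_channels=32,
    kernel_size=3,
    padding=1,
    context_dim=10,
    use_scale=True,
    use_bias=True,
    scale_mode="scale",
)
```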
- ⚙️ Drop‑in replacement for `nn.Conv1d` / `nn.Conv2d` – same arguments plus optional context options (see the sketch after this list)
- 🧠 Global‑vector conditioning via learnable γ(c) and/or β(c)
- 🔀 Modulation modes: `scale_mode="film"` → `out * (1 + γ)`; `scale_mode="scale"` → `out * γ`
- 🪶 Lightweight – one small MLP (or a single `Linear`) per layer
- 🧑‍🔬 FiLM ready – reproduce Feature‑wise Linear Modulation in two lines
- 🧩 Modular – combine with any architecture, works on CPU / GPU
- 📤 Infer context vectors from unmodulated outputs with `.infer_context()`
- ✅ Unit‑tested and documented
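A minimal sketch of the drop‑in usage above: an `nn.Conv2d` inside an existing block is replaced by `ContextualConv2d`, and the global vector `c` is threaded through the forward pass (the surrounding module is illustrative, not part of the package):

```python
import torch
import torch.nn as nn
from contextual_conv import ContextualConv2d

class ConditionedBlock(nn.Module):
    """Illustrative block where an nn.Conv2d is swapped for ContextualConv2d."""

    def __init__(self, in_ch: int, out_ch: int, context_dim: int):
        super().__init__()
        # Same positional arguments as nn.Conv2d, plus the context options.
        self.conv = ContextualConv2d(
            in_ch, out_ch, kernel_size=3, padding=1,
            context_dim=context_dim,
            use_scale=True, use_bias=True, scale_mode="film",
        )
        self.act = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        return self.act(self.conv(x, c))

block = ConditionedBlock(3, 64, context_dim=10)
out = block(torch.randn(4, 3, 32, 32), torch.randn(4, 10))
print(out.shape)  # torch.Size([4, 64, 32, 32])
```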
```bash
pip install contextual-conv  # version 0.6.3 on PyPI
```

Or install from source:

```bash
git clone https://github.com/abbassix/ContextualConv.git
cd ContextualConv
pip install -e .[dev]
```
- Shape: `(B, context_dim)` – one global descriptor per sample (class‑label embedding, latent code, etc.)
- Processed by a `ContextProcessor`:
  - `Linear(context_dim, out_dim)` for bias‑only / scale‑only
  - a small MLP if `h_dim` is set
- Output dims:
  - `out_channels` → bias or scale
  - `2 × out_channels` → FiLM (scale + bias)
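As one example of where `c` can come from, a class label can be embedded into a `(B, context_dim)` vector before being passed to the layer (the embedding module here is illustrative, not part of the package):

```python
import torch
import torch.nn as nn
from contextual_conv import ContextualConv2d

num_classes, context_dim = 10, 16

# Illustrative: embed a class label into the global context vector c
label_embed = nn.Embedding(num_classes, context_dim)
conv = ContextualConv2d(
    3, 32, kernel_size=3, padding=1,
    context_dim=context_dim,
    use_scale=True, use_bias=True, scale_mode="film",
)

labels = torch.randint(0, num_classes, (4,))
c = label_embed(labels)                    # (B, context_dim)
out = conv(torch.randn(4, 3, 28, 28), c)   # (4, 32, 28, 28)
```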
You can extract the context vector inferred from the output using:

```python
context = conv.infer_context(x)
```

To also get the unmodulated output of the convolution layer:

```python
context, raw_out = conv.infer_context(x, return_raw_output=True)
```

This is useful when you need both the input's context and its original, unmodulated features.
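For instance, the raw output can be used to measure how strongly the context modulation changed the features (a short sketch reusing the quick‑start layer settings):

```python
import torch
from contextual_conv import ContextualConv2d

conv = ContextualConv2d(
    16, 32, kernel_size=3, padding=1,
    context_dim=10, use_scale=True, use_bias=True, scale_mode="film",
)

x = torch.randn(8, 16, 32, 32)
c = torch.randn(8, 10)

modulated = conv(x, c)                                           # out * (1 + γ) + β
context, raw_out = conv.infer_context(x, return_raw_output=True)

# How much did the context modulation change the features?
effect = (modulated - raw_out).abs().mean()
print(context.shape, raw_out.shape, effect.item())
```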
Run the full test suite with coverage:

```bash
pytest --cov=contextual_conv --cov-report=term-missing
```
Full API reference & tutorials: https://contextualconv.readthedocs.io
Bug reports, feature requests, and PRs are welcome! See `CONTRIBUTING.md`.
GNU GPLv3 – see the `LICENSE` file for details.