[Docs] add docs for search methods
fedebotu committed Aug 22, 2023
1 parent 0015d47 commit 1b6f027
Showing 2 changed files with 28 additions and 6 deletions.
4 changes: 4 additions & 0 deletions rl4co/models/zoo/eas/nn.py
@@ -6,6 +6,10 @@ class EASLayerNet(nn.Module):
     """Instantiate weights and biases for the added layer.
     The layer is defined as: h = relu(emb * W1 + b1); out = h * W2 + b2.
     Wrapping in `nn.Parameter` makes the parameters trainable and sets `requires_grad` to True.
+
+    Args:
+        num_instances: Number of instances in the dataset
+        emb_dim: Dimension of the embedding
     """

     def __init__(self, num_instances: int, emb_dim: int):
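Editor's note: for orientation, here is a minimal sketch of a per-instance two-layer adapter in the spirit of `EASLayerNet`. The shapes, batched-per-instance weights, and zero initialization of the second layer are assumptions (the zero init borrows the EAS paper's idea that the added layer should start as a no-op), not the actual rl4co implementation.

```python
import torch
import torch.nn as nn


class TwoLayerAdapterSketch(nn.Module):
    """Illustrative only: h = relu(emb @ W1 + b1); out = h @ W2 + b2."""

    def __init__(self, num_instances: int, emb_dim: int):
        super().__init__()
        # One weight set per problem instance, so each instance is tuned
        # independently during search; nn.Parameter registers the tensors
        # as trainable (requires_grad=True).
        self.W1 = nn.Parameter(torch.randn(num_instances, emb_dim, emb_dim) * 0.01)
        self.b1 = nn.Parameter(torch.zeros(num_instances, 1, emb_dim))
        # Second layer starts at zero so the adapter initially changes
        # nothing (an assumption, not confirmed by the source).
        self.W2 = nn.Parameter(torch.zeros(num_instances, emb_dim, emb_dim))
        self.b2 = nn.Parameter(torch.zeros(num_instances, 1, emb_dim))

    def forward(self, emb: torch.Tensor) -> torch.Tensor:
        # emb: [num_instances, num_nodes, emb_dim]
        h = torch.relu(torch.bmm(emb, self.W1) + self.b1)
        return torch.bmm(h, self.W2) + self.b2
```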
30 changes: 24 additions & 6 deletions rl4co/models/zoo/eas/search.py
@@ -21,12 +21,30 @@


 class EAS(SearchBase):
-    """Efficient Active Search (EAS) algorithm.
-    Reference: TODO
+    """Efficient Active Search for Neural Combinatorial Optimization from Hottung et al. (2022).
+    Fine-tunes a subset of parameters (such as node embeddings or newly added layers), thus avoiding
+    expensive re-encoding of the problem.
+    Reference: https://openreview.net/pdf?id=nO5caZwFwYu

     Args:
-        TODO
+        env: RL4CO environment to be solved
+        policy: policy network
+        dataset: dataset to be used for training
+        use_eas_embedding: whether to use the EAS embedding variant (EASEmb)
+        use_eas_layer: whether to use the EAS layer variant (EASLay)
+        eas_emb_cache_keys: keys to cache in the embedding
+        eas_lambda: weight of the imitation learning (IL) loss term
+        batch_size: batch size for training
+        max_iters: maximum number of iterations
+        augment_size: number of augmentations per state
+        augment_dihedral: whether to augment with dihedral rotations
+        parallel_runs: number of parallel runs
+        baseline: REINFORCE baseline type (multistart, symmetric, full)
+        max_runtime: maximum runtime in seconds
+        save_path: path to save solution checkpoints
+        optimizer: optimizer to use for training
+        optimizer_kwargs: keyword arguments for the optimizer
+        verbose: whether to print progress for each iteration
     """

     def __init__(
@@ -285,7 +303,7 @@ def on_train_epoch_end(self) -> None:
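Editor's note: as a usage illustration, the following hypothetical sketch is built only from the argument list documented above; the import path (inferred from the file location), the placeholder objects `env`, `policy`, and `test_dataset`, and all keyword values are assumptions to be checked against rl4co.

```python
# Hypothetical usage; verify the exact signature against rl4co.
from rl4co.models.zoo.eas.search import EAS

eas = EAS(
    env=env,                  # RL4CO environment to be solved, e.g. TSP
    policy=policy,            # pretrained policy network
    dataset=test_dataset,     # instances to solve
    use_eas_embedding=True,   # EASEmb: fine-tune the node embeddings
    use_eas_layer=False,      # EASLay: fine-tune a newly added layer
    eas_lambda=0.01,          # illustrative weight for the IL loss term
    batch_size=8,
    max_iters=200,
    augment_size=8,           # augmentations per state
    baseline="multistart",    # REINFORCE baseline type
)
```

Per the referenced paper, `eas_lambda` weights an imitation term added to the reinforcement learning objective, roughly loss = loss_RL + eas_lambda * loss_IL, where the IL term pushes the fine-tuned policy toward the best solution found so far.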


 class EASEmb(EAS):
-    """EAS adapting embedding"""
+    """EAS with embedding adaptation"""

     def __init__(
         self,
@@ -304,7 +322,7 @@ def __init__(


 class EASLay(EAS):
-    """EAS adapting layer"""
+    """EAS with layer adaptation"""

     def __init__(
         self,
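Editor's note: to make the difference between the two subclasses concrete, here is an illustrative sketch (not rl4co's code) of which parameters each variant would fine-tune. The residual application of the added layer, the optimizer choice, and the learning rate are assumptions.

```python
import torch
import torch.nn as nn

# Cached encoder output for all instances (dummy shapes for illustration):
embeddings = torch.randn(64, 100, 128)  # [num_instances, num_nodes, emb_dim]

# EASEmb: the cached embeddings themselves become the trainable parameters.
emb_param = nn.Parameter(embeddings.clone())
opt_emb = torch.optim.Adam([emb_param], lr=1e-3)

# EASLay: embeddings stay frozen; only a small added layer is trained,
# applied residually here so search starts from the original behavior
# (an assumption based on the paper's setup).
adapter = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 128))
opt_lay = torch.optim.Adam(adapter.parameters(), lr=1e-3)
adapted = embeddings + adapter(embeddings)
```

Either way, only a tiny parameter set is updated per instance, which is what lets EAS skip re-encoding the problem on every search iteration.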
