
Commit

internal change
PiperOrigin-RevId: 596940697
Scenic Authors committed Jan 16, 2024
1 parent 7f3bde0 commit ddaae63
Showing 1 changed file with 34 additions and 25 deletions.
59 changes: 34 additions & 25 deletions scenic/projects/pointcloud/models.py
@@ -1,6 +1,6 @@
 """Implementation of PCT model layers."""
 
-from typing import Any, Optional, Dict
+from typing import Any
 
 import flax.linen as nn
 from flax.linen.initializers import zeros
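Note on the typing change above: `Optional[X]` and the union spelling `X | None` are equivalent; the builtin-generic forms (`X | None`, `dict[...]`) only evaluate at runtime on Python 3.10+ (3.9+ for `dict[...]`), or on earlier versions with `from __future__ import annotations`. A minimal sketch of the equivalence (variable names are illustrative):

from typing import Optional

# Equivalent annotations; the second form drops the typing import.
kernel_size_a: Optional[int] = 1  # pre-PEP 604 spelling
kernel_size_b: int | None = 1     # requires Python 3.10+ at runtime

# Both normalize to the same union type.
assert Optional[int] == (int | None)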
@@ -53,10 +53,10 @@ class SelfAttentionLayer(nn.Module):
   """Self Attention Layer."""
   in_channels: Any
   out_channels: Any
-  kernel_size: Optional[int] = 1
-  mask_function: Optional[str] = 'linear'
-  attention_type: Optional[str] = 'naive'
-  attention_fn_configs: Optional[Dict[Any, Any]] = None
+  kernel_size: int | None = 1
+  mask_function: str | None = 'linear'
+  attention_type: str | None = 'naive'
+  attention_fn_configs: dict[Any, Any] | None = None
 
   @nn.compact
   def __call__(self,
@@ -157,15 +157,16 @@ class PointCloudTransformerEncoder(nn.Module):
   """Point Cloud Transformer Encoder."""
   in_dim: int
   feature_dim: int
-  kernel_size: Optional[int] = 1
-  encoder_feature_dim: Optional[int] = 1024
-  num_attention_layers: Optional[int] = 4
-  num_heads: Optional[int] = 1
-  attention_fn_configs: Optional[Dict[Any, Any]] = None
-  use_attention_masking: Optional[bool] = False
-  use_knn_mask: Optional[bool] = False
-  nearest_neighbour_count: Optional[int] = 256
-  mask_function: Optional[str] = 'linear'
+  kernel_size: int | None = 1
+  encoder_feature_dim: int | None = 1024
+  num_attention_layers: int | None = 4
+  num_heads: int | None = 1
+  attention_fn_configs: dict[Any, Any] | None = None
+  use_attention_masking: bool | None = False
+  use_knn_mask: bool | None = False
+  nearest_neighbour_count: int | None = 256
+  mask_function: str | None = 'linear'
+  out_dim: int | None = None
 
   @nn.compact
   def __call__(
@@ -204,26 +205,34 @@ def __call__(
     output = nn.Conv(
         self.encoder_feature_dim,
         kernel_size=(self.kernel_size, self.kernel_size),
-        use_bias=True)(
-            output)
+        use_bias=True)(output)
     output = nn.BatchNorm(use_running_average=not train)(output, mask=mask)
     output = nn.leaky_relu(output, negative_slope=0.2)
+
+    if self.out_dim is not None:
+      # conv-batchnorm-relu block
+      output = nn.Conv(
+          self.out_dim,
+          kernel_size=(self.kernel_size, self.kernel_size),
+          use_bias=True)(output)
+      output = nn.BatchNorm(use_running_average=not train)(output, mask=mask)
+      output = nn.leaky_relu(output, negative_slope=0.2)
     return output
 
 
 class PointCloudTransformerClassifier(nn.Module):
   """Point Cloud Transformer Classifier."""
   in_dim: int
   feature_dim: int
-  kernel_size: Optional[int] = 1
-  num_classes: Optional[int] = 40
-  dropout_rate: Optional[float] = 0.5
-  attention_type: Optional[str] = 'standard'
-  attention_fn_configs: Optional[Dict[Any, Any]] = None
-  use_attention_masking: Optional[bool] = False
-  use_knn_mask: Optional[bool] = False
-  nearest_neighbour_count: Optional[int] = 256
-  mask_function: Optional[str] = 'linear'
+  kernel_size: int | None = 1
+  num_classes: int | None = 40
+  dropout_rate: float | None = 0.5
+  attention_type: str | None = 'standard'
+  attention_fn_configs: dict[Any, Any] | None = None
+  use_attention_masking: bool | None = False
+  use_knn_mask: bool | None = False
+  nearest_neighbour_count: int | None = 256
+  mask_function: str | None = 'linear'
 
   @nn.compact
   def __call__(self, inputs, train: bool = False, debug: bool = False):
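For context on the substantive change here, the new `out_dim` field: when it is set, the encoder ends with one extra conv-batchnorm-leaky_relu block that projects features from `encoder_feature_dim` down to `out_dim` channels. A standalone sketch of that pattern (the module name `ProjectionBlock` is illustrative, not part of this commit):

import flax.linen as nn

class ProjectionBlock(nn.Module):
  """Mirrors the conv-batchnorm-leaky_relu block added behind out_dim."""
  out_dim: int
  kernel_size: int = 1

  @nn.compact
  def __call__(self, x, train: bool = False):
    # 1x1 convolution projecting to out_dim channels.
    x = nn.Conv(
        self.out_dim,
        kernel_size=(self.kernel_size, self.kernel_size),
        use_bias=True)(x)
    # Batch statistics are frozen outside of training.
    x = nn.BatchNorm(use_running_average=not train)(x)
    return nn.leaky_relu(x, negative_slope=0.2)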
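And a hedged usage sketch for the classifier, following the usual Flax init/apply pattern (the shapes, field values, and eval-mode call below are assumptions, not taken from this commit):

import jax
import jax.numpy as jnp

from scenic.projects.pointcloud.models import PointCloudTransformerClassifier

# Hypothetical input: a batch of 2 point clouds, 1024 points, 3 coordinates.
model = PointCloudTransformerClassifier(in_dim=3, feature_dim=128)
points = jnp.zeros((2, 1024, 3))
variables = model.init(jax.random.PRNGKey(0), points, train=False)
logits = model.apply(variables, points, train=False)  # expected: (2, num_classes)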
