deepmd.pt.model.descriptor.sezm_nn.ffn

Equivariant feed-forward layers for SeZM.

This module defines the full SO(3)-equivariant feed-forward network used inside SeZM interaction blocks and the descriptor output head.

Classes

EquivariantFFN

Full equivariant FFN operating on all spherical harmonic degrees.

Module Contents

class deepmd.pt.model.descriptor.sezm_nn.ffn.EquivariantFFN(*, lmax: int, channels: int, hidden_channels: int, kmax: int = 1, grid_mlp: bool = False, grid_branch: int = 0, dtype: torch.dtype, s2_activation: bool = False, ffn_so3_grid: bool = False, lebedev_quadrature: bool = False, activation_function: str = 'silu', glu_activation: bool = True, mlp_bias: bool = False, trainable: bool, seed: int | list[int] | None = None)

Bases: torch.nn.Module

Full equivariant FFN operating on all spherical harmonic degrees.

Structure with glu_activation=False:

SO3 linear (in -> hidden) -> GatedActivation -> SO3 linear (hidden -> out)

Default structure (glu_activation=True):

SO3 linear (in -> 2*hidden) -> split -> GatedActivation(val, gate) -> SO3 linear (hidden -> out)
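Both structures start and end with an SO3 linear. The sketch below shows what such a layer can look like on a single node's packed coefficients of shape (D, C). It assumes the standard construction (one channel-mixing matrix per degree l, shared by all 2l+1 orders m, bias on l=0 only, as the mlp_bias description suggests) and degree-major packing; the function names are hypothetical and this is NumPy, not the module's torch code.

```python
from __future__ import annotations

import numpy as np


def degree_slices(lmax: int) -> list[slice]:
    """Rows of each degree l on a packed axis of length (lmax+1)^2 (assumed degree-major)."""
    return [slice(l * l, (l + 1) * (l + 1)) for l in range(lmax + 1)]


def so3_linear(x: np.ndarray, weights: np.ndarray, bias: np.ndarray | None = None) -> np.ndarray:
    """Per-degree channel mixing: x (D, C_in) -> (D, C_out).

    weights has shape (lmax+1, C_in, C_out): one matrix per degree, shared by all
    orders m, so any rotation acting on the m axis commutes with the layer.
    The optional bias (C_out,) is added to the l=0 row only.
    """
    lmax = weights.shape[0] - 1
    out = np.empty((x.shape[0], weights.shape[2]), dtype=x.dtype)
    for l, rows in enumerate(degree_slices(lmax)):
        out[rows] = x[rows] @ weights[l]
    if bias is not None:
        out[0] += bias
    return out


rng = np.random.default_rng(0)
lmax, c_in, c_out = 2, 4, 6
x = rng.normal(size=((lmax + 1) ** 2, c_in))
w = rng.normal(size=(lmax + 1, c_in, c_out))
y = so3_linear(x, w)
print(y.shape)  # (9, 6)

# A random orthogonal 3x3 matrix stands in for the l=1 block of a Wigner-D matrix.
q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
x_rot = x.copy()
x_rot[1:4] = q @ x[1:4]
print(np.allclose(so3_linear(x_rot, w)[1:4], q @ y[1:4]))  # True
```

Sharing one matrix across m is what makes the layer commute with rotations: (Q X) W = Q (X W) for any matrix Q acting on the m axis.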

Optional grid-FFN structure (s2_activation=True or ffn_so3_grid=True):

SO3 linear (in -> hidden) -> project packed SO(3) coefficients to the S2 or SO3 grid -> grid GLU, polynomial MLP, or scalar-routed attention on hidden features -> project grid features back to packed SO(3) coefficients -> add scalar LinearSwiGLU branch to l=0 -> SO3 linear (hidden -> out)
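The grid path round-trips packed coefficients through point values on the sphere. The sketch below illustrates the S2 case under simplifying assumptions: lmax=1, orthonormal real spherical harmonics in an assumed (l=0; l=1 as y, z, x) order, the 6-point Lebedev rule (octahedron vertices, weight 4π/6 each) instead of the module's default s2_grid_method='e3nn' grid, and a fixed point-wise silu standing in for the grid GLU / MLP / branch operations. The names POINTS, WEIGHTS, to_grid, and from_grid are hypothetical.

```python
import numpy as np

# 6-point Lebedev rule: octahedron vertices with equal weights 4*pi/6.
# It integrates polynomials up to degree 3 exactly on the unit sphere.
POINTS = np.array(
    [[1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0], [0, 0, 1], [0, 0, -1]], dtype=float
)
WEIGHTS = np.full(6, 4 * np.pi / 6)


def real_sh_l1(points: np.ndarray) -> np.ndarray:
    """Orthonormal real spherical harmonics for l <= 1, shape (P, 4).

    Assumed packed order: l=0, then l=1 as (y, z, x) for m = -1, 0, 1.
    """
    x, y, z = points.T
    c0 = np.sqrt(1 / (4 * np.pi))
    c1 = np.sqrt(3 / (4 * np.pi))
    return np.stack([np.full_like(x, c0), c1 * y, c1 * z, c1 * x], axis=1)


Y = real_sh_l1(POINTS)  # (P=6, D=4)


def to_grid(coeffs: np.ndarray) -> np.ndarray:
    """Packed coefficients (D, C) -> grid values (P, C)."""
    return Y @ coeffs


def from_grid(values: np.ndarray) -> np.ndarray:
    """Grid values (P, C) -> packed coefficients (D, C) by weighted quadrature."""
    return Y.T @ (WEIGHTS[:, None] * values)


def silu(v: np.ndarray) -> np.ndarray:
    return v / (1 + np.exp(-v))


rng = np.random.default_rng(0)
coeffs = rng.normal(size=(4, 8))  # D = (1+1)^2 = 4, hidden C = 8
print(np.allclose(from_grid(to_grid(coeffs)), coeffs))  # True
out = from_grid(silu(to_grid(coeffs)))  # point-wise nonlinearity on the grid
print(out.shape)  # (4, 8)
```

The identity round trip is exact here because every product of two l<=1 harmonics has degree 2, within the rule's exactness. A nonlinear grid operation produces higher-degree content that a finite grid can only approximate, which is one reason the choice of grid and quadrature (e.g. lebedev_quadrature) matters.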

GatedActivation plays the role of the activation function in an equivariant network, analogous to SiLU in a standard MLP, while preserving SO(3) equivariance:

- l=0: applies activation_function (or its GLU variant when glu_activation=True).
- l>0: multiplies the features by sigmoid gates computed from the l=0 scalar features.
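A minimal NumPy sketch of the gating rule above, for one node with packed coefficients of shape (D, C). The gate weight layout (one C×C matrix per degree l>=1, applied to the raw l=0 row) and the degree-major packing are assumptions for illustration, not the module's actual parameterization.

```python
import numpy as np


def silu(v: np.ndarray) -> np.ndarray:
    return v / (1.0 + np.exp(-v))


def sigmoid(v: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-v))


def gated_activation(x: np.ndarray, gate_w: np.ndarray) -> np.ndarray:
    """Equivariant gate on packed coefficients x with shape (D, C).

    gate_w (lmax, C, C) is a hypothetical linear map from the l=0 scalars to
    one gate per channel for each degree l >= 1.
    """
    lmax = int(round(np.sqrt(x.shape[0]))) - 1
    scalars = x[0]                                # l=0 row, rotation-invariant
    out = np.empty_like(x)
    out[0] = silu(scalars)                        # ordinary activation on l=0
    for l in range(1, lmax + 1):
        rows = slice(l * l, (l + 1) * (l + 1))    # the 2l+1 orders of degree l
        gate = sigmoid(scalars @ gate_w[l - 1])   # (C,), built from l=0 only
        out[rows] = x[rows] * gate                # same gate for every m
    return out


rng = np.random.default_rng(1)
lmax, c = 2, 5
x = rng.normal(size=((lmax + 1) ** 2, c))
gate_w = rng.normal(size=(lmax, c, c))
y = gated_activation(x, gate_w)

# Rotate the l=1 block with a random orthogonal matrix (stand-in for a Wigner-D block).
q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
x_rot = x.copy()
x_rot[1:4] = q @ x[1:4]
print(np.allclose(gated_activation(x_rot, gate_w)[1:4], q @ y[1:4]))  # True
```

Applying silu element-wise to the l=1 rows instead would not commute with q; rescaling each degree by invariant per-channel gates is what keeps the nonlinearity equivariant.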

When glu_activation=True, the first SO3 linear outputs 2*hidden_channels channels, which are split into a value branch and a gate branch. This turns silu into swiglu and gelu into geglu. A single wide linear followed by a split is cheaper than two separate linear layers, since it performs one larger matrix multiplication instead of two smaller ones.
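The fused-split trick can be checked on the scalar (l=0) channels alone. This NumPy sketch assumes the first half of the 2*hidden outputs is the value branch and the second half the gate; the module may order them differently. Swapping silu for gelu gives the geglu variant.

```python
import numpy as np


def silu(v: np.ndarray) -> np.ndarray:
    return v / (1 + np.exp(-v))


def swiglu_fused(s: np.ndarray, w_fused: np.ndarray, hidden: int) -> np.ndarray:
    """s (N, C_in), w_fused (C_in, 2*hidden): one matmul, then split.

    Assumption: columns [:hidden] are the value branch, [hidden:] the gate.
    """
    h = s @ w_fused
    value, gate = h[:, :hidden], h[:, hidden:]
    return value * silu(gate)


def swiglu_two_linears(s: np.ndarray, w_value: np.ndarray, w_gate: np.ndarray) -> np.ndarray:
    """Same result computed with two separate (C_in, hidden) weight matrices."""
    return (s @ w_value) * silu(s @ w_gate)


rng = np.random.default_rng(2)
n, c_in, hidden = 3, 4, 6
s = rng.normal(size=(n, c_in))
w_value = rng.normal(size=(c_in, hidden))
w_gate = rng.normal(size=(c_in, hidden))
w_fused = np.concatenate([w_value, w_gate], axis=1)  # (c_in, 2*hidden)
print(np.allclose(swiglu_fused(s, w_fused, hidden), swiglu_two_linears(s, w_value, w_gate)))  # True
```

Concatenating the two weight matrices column-wise makes the results identical, so the fused form changes only how the work is scheduled, not the function computed.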

Parameters:
lmax

Maximum degree.

channels

Number of channels per (l, m) coefficient.

hidden_channels

Hidden dimension for the FFN.

kmax

Maximum Wigner-D frame order (|k|) used by the SO3 Wigner-D FFN grid.

grid_mlp

If True, select the polynomial grid MLP operation when the block-internal FFN grid path is enabled.

grid_branch

Number of scalar-routed polynomial product branches used when the block-internal FFN grid path is enabled. 0 disables this branch mixer. Positive values take precedence over grid_mlp.

dtype

Parameter dtype.

s2_activation

If True, enable the S2 FFN grid path.

ffn_so3_grid

If True, enable the SO3 Wigner-D FFN grid path.

lebedev_quadrature

If True, use Lebedev quadrature for the S2 projector in this FFN.

activation_function

Activation function for l=0 components (e.g., “silu”, “tanh”, “gelu”).

glu_activation

If True, use GLU-style gating (e.g., silu -> swiglu, gelu -> geglu).

mlp_bias

Whether to add biases to SO3Linear (l=0 only), to the gate linear in GatedActivation, and to the scalar point-wise projection when grid_mlp=True.

trainable

Whether parameters are trainable.

seed

Random seed for weight initialization.
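The grid-related parameters interact: the grid operation only matters when s2_activation or ffn_so3_grid enables a grid path, a positive grid_branch overrides grid_mlp, and (reading the grid-FFN structure above) grid GLU is used otherwise. The helper below is a hypothetical sketch of that precedence; its name and return strings are illustrative, not part of the module.

```python
from typing import Optional


def select_grid_op(
    *,
    s2_activation: bool = False,
    ffn_so3_grid: bool = False,
    grid_mlp: bool = False,
    grid_branch: int = 0,
) -> Optional[str]:
    """Hypothetical mirror of the documented grid-operation precedence.

    Returns None when no FFN grid path is enabled (plain GatedActivation FFN).
    """
    if not (s2_activation or ffn_so3_grid):
        return None
    if grid_branch > 0:
        return "scalar_routed_branch"  # positive grid_branch wins over grid_mlp
    if grid_mlp:
        return "polynomial_mlp"
    return "grid_glu"


print(select_grid_op())  # None
print(select_grid_op(s2_activation=True))  # grid_glu
print(select_grid_op(ffn_so3_grid=True, grid_mlp=True, grid_branch=2))  # scalar_routed_branch
```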

lmax
channels
hidden_channels
kmax = 1
use_grid_mlp = False
grid_branch = 0
use_grid_branch = False
s2_activation = False
ffn_so3_grid = False
lebedev_quadrature = False
s2_grid_method = 'e3nn'
s2_grid_resolution
activation_function = 'silu'
glu_activation = True
mlp_bias = False
dtype
compute_dtype
device
precision
grid_n_frames = 1
use_grid_net = False
so3_linear_1
so3_linear_2
forward(x: torch.Tensor) -> torch.Tensor
Parameters:
x

Input with shape (N, D, F, C), where D = (lmax+1)^2 is the number of packed (l, m) coefficients and C = channels.

Returns:
torch.Tensor

Output with shape (N, D, F, C).
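D = (lmax+1)^2 because degree l contributes 2l+1 orders m, and the sum of 2l+1 for l = 0..lmax is (lmax+1)^2. The sketch below maps (l, m) to a row on the D axis, assuming degree-major packing with m running from -l to l inside each degree; the helper names are hypothetical.

```python
def packed_index(l: int, m: int) -> int:
    """Row of coefficient (l, m) on the packed D axis (assumed degree-major, m = -l..l)."""
    if not (-l <= m <= l):
        raise ValueError(f"m={m} out of range for l={l}")
    return l * l + (m + l)


def packed_dim(lmax: int) -> int:
    """D = sum over l of (2l + 1), which equals (lmax + 1)^2."""
    return sum(2 * l + 1 for l in range(lmax + 1))


print(packed_dim(2))  # 9
print(packed_index(2, 2))  # 8, the last row for lmax=2
print(packed_index(1, -1))  # 1, the first l=1 row
```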

serialize() -> dict[str, Any]
classmethod deserialize(data: dict[str, Any]) -> EquivariantFFN