OpenEquivariance API
OpenEquivariance exposes two key classes: openequivariance.TensorProduct, which replaces o3.TensorProduct from e3nn, and openequivariance.TensorProductConv, which fuses the CG tensor product with a subsequent graph convolution. Initializing either class triggers JIT compilation of a custom kernel, which can take a few seconds.
Both classes require a configuration object specified by openequivariance.TPProblem, whose constructor is almost identical to that of o3.TensorProduct.
We recommend reading the e3nn documentation before
trying our code. OpenEquivariance cannot accelerate all tensor products; see
this page for a list of supported configurations.
- class openequivariance.TensorProduct(*args, **kwargs)
Drop-in replacement for o3.TensorProduct from e3nn. Supports forward, backward, and double-backward passes using JIT-compiled kernels. Initialization fails if there are no visible GPUs or if the provided tensor product specification is unsupported.
- Parameters:
problem (TPProblem) – Specification of the tensor product.
- forward(x, y, W)
Computes \(W (x \otimes_{\textrm{CG}} y)\), identical to o3.TensorProduct.forward.
- Parameters:
x (torch.Tensor) – Tensor of shape [batch_size, problem.irreps_in1.dim()], datatype problem.irrep_dtype.
y (torch.Tensor) – Tensor of shape [batch_size, problem.irreps_in2.dim()], datatype problem.irrep_dtype.
W (torch.Tensor) – Tensor of datatype problem.weight_dtype and shape [batch_size, problem.weight_numel] if problem.shared_weights=False, or [problem.weight_numel] if problem.shared_weights=True.
- Returns:
Tensor of shape [batch_size, problem.irreps_out.dim()], datatype problem.irrep_dtype.
- Return type:
torch.Tensor
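The shape rules for forward can be collected into a small helper that checks inputs before the call. This is a sketch: forward_shapes is a hypothetical name, not part of OpenEquivariance, and the sizes in the example are illustrative only.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def forward_shapes(batch_size: int, dim_in1: int, dim_in2: int, dim_out: int,
                   weight_numel: int, shared_weights: bool) -> Dict[str, Shape]:
    """Expected shapes for TensorProduct.forward(x, y, W) and its output.

    dim_in1, dim_in2, and dim_out stand in for problem.irreps_in1.dim(),
    problem.irreps_in2.dim(), and problem.irreps_out.dim().
    """
    # Shared weights: a single weight vector reused by every batch element.
    # Unshared weights: a separate weight vector for each batch element.
    w_shape: Shape = (weight_numel,) if shared_weights else (batch_size, weight_numel)
    return {
        "x": (batch_size, dim_in1),
        "y": (batch_size, dim_in2),
        "W": w_shape,
        "out": (batch_size, dim_out),
    }


# Illustrative sizes: "32x0e + 32x1e" has dim 32*1 + 32*3 = 128,
# and "1x0e + 1x1e" has dim 1 + 3 = 4.
shapes = forward_shapes(batch_size=1000, dim_in1=128, dim_in2=4, dim_out=128,
                        weight_numel=64, shared_weights=False)
print(shapes["W"])  # (1000, 64)
```

With real tensors, comparing tuple(W.shape) against shapes["W"] is a cheap sanity check before calling forward.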
- class openequivariance.TensorProductConv(*args, **kwargs)
Given a symmetric, directed graph \(G = (V, E)\), inputs \(x_1, \ldots, x_{|V|}\), \(y_1, \ldots, y_{|E|}\), and weights \(W_1, \ldots, W_{|E|}\), computes
\[z_i = \sum_{(i, j, e) \in \mathcal{N}(i)} W_e \left(x_j \otimes_{\textrm{CG}} y_e\right)\]
where \((i, j, e) \in \mathcal{N}(i)\) indicates that node \(i\) is connected to node \(j\) via the edge indexed \(e\).
This class offers two algorithms for the summation: an atomic algorithm, and a deterministic algorithm that requires the adjacency matrix nonzeros to be sorted in row-major order. If you use the deterministic algorithm, you must also supply a permutation that transposes the adjacency matrix.
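To make the aggregation concrete, here is a sequential Python reference for the formula above in which scalars stand in for feature vectors and a user-supplied edge_op stands in for the weighted CG product. conv_reference and edge_op are hypothetical names, and reading rows[e] as the receiving node i and cols[e] as the sending node j is an assumption. On a GPU, many edges update the same z_i concurrently: with atomic additions the order of floating-point additions (and therefore the rounding) can vary between runs, whereas the deterministic option relies on sorted input so that results do not depend on thread scheduling.

```python
from typing import Callable, List, Sequence

# Stand-in for the weighted CG product W_e (x_j ⊗ y_e), acting on scalars.
EdgeOp = Callable[[float, float, float], float]


def conv_reference(x: Sequence[float], y: Sequence[float], w: Sequence[float],
                   rows: Sequence[int], cols: Sequence[int],
                   num_nodes: int, edge_op: EdgeOp) -> List[float]:
    """Sequential reference for z_i = sum over (i, j, e) in N(i) of W_e (x_j ⊗ y_e).

    Assumes rows[e] is the receiving node i and cols[e] the sending node j.
    """
    z = [0.0] * num_nodes
    for e in range(len(rows)):
        i, j = rows[e], cols[e]
        # Gather sender features x_j, combine with edge e, scatter-add into z_i.
        z[i] += edge_op(x[j], y[e], w[e])
    return z


# Symmetric 3-node graph: node 0 <-> node 1 and node 0 <-> node 2.
rows = [0, 0, 1, 2]
cols = [1, 2, 0, 0]
x = [1.0, 2.0, 3.0]
y = [1.0, 1.0, 1.0, 1.0]
w = [0.5, 0.5, 2.0, 2.0]
z = conv_reference(x, y, w, rows, cols, num_nodes=3,
                   edge_op=lambda xj, ye, we: we * xj * ye)
print(z)  # [2.5, 2.0, 2.0]
```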
- Parameters:
problem (TPProblem) – Specification of the tensor product.
deterministic (bool) – If False, uses atomics for the convolution. If True, uses a deterministic fixup-based algorithm. Default: False.
kahan (bool) – If True, uses Kahan summation to improve accuracy during aggregation. To use this option, the input tensors must be float32 AND you must set deterministic=True. Default: False.
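Kahan (compensated) summation keeps a running correction term that captures the low-order bits lost when a small value is added to a large running total. The sketch below shows the general technique in plain Python; it is not OpenEquivariance's kernel code. Python floats are float64, so the example uses terms below half an ulp of 1.0 to make the loss visible; in float32 the same effect shows up at much larger magnitudes. naive_sum is written as an explicit loop because the built-in sum() in recent Python versions already compensates float sums.

```python
from typing import Iterable


def naive_sum(values: Iterable[float]) -> float:
    """Plain left-to-right accumulation with no error compensation."""
    total = 0.0
    for v in values:
        total += v
    return total


def kahan_sum(values: Iterable[float]) -> float:
    """Compensated (Kahan) summation of a sequence of floats."""
    total = 0.0
    comp = 0.0  # running compensation: low-order bits lost so far
    for v in values:
        y = v - comp            # re-inject the error from the previous step
        t = total + y           # large + small: low-order bits of y may be lost
        comp = (t - total) - y  # recover what was lost (with opposite sign)
        total = t
    return total


# Each 1e-16 is below half an ulp of 1.0, so naive summation drops all of them.
vals = [1.0] + [1e-16] * 10
print(naive_sum(vals))  # 1.0
print(kahan_sum(vals))  # 1.000000000000001
```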
- forward(X, Y, W, rows, cols, sender_perm=None)
Computes the fused CG tensor product + convolution.
- Parameters:
X (torch.Tensor) – Tensor of shape [|V|, problem.irreps_in1.dim()], datatype problem.irrep_dtype.
Y (torch.Tensor) – Tensor of shape [|E|, problem.irreps_in2.dim()], datatype problem.irrep_dtype.
W (torch.Tensor) – Tensor of datatype problem.weight_dtype and shape [|E|, problem.weight_numel] if problem.shared_weights=False, or [problem.weight_numel] if problem.shared_weights=True.
rows (torch.Tensor) – Tensor of shape [|E|] and datatype torch.int64 with the row index of each nonzero in the adjacency matrix. When deterministic=True, rows and cols must be sorted together in row-major order.
cols (torch.Tensor) – Tensor of shape [|E|] and datatype torch.int64 with the column index of each nonzero in the adjacency matrix.
sender_perm (torch.Tensor | None) – Tensor of shape [|E|] and datatype torch.int64 containing a permutation that transposes the adjacency matrix nonzeros from row-major to column-major order. Must be provided when deterministic=True.
- Returns:
Tensor of shape [|V|, problem.irreps_out.dim()], datatype problem.irrep_dtype.
- Return type:
torch.Tensor
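For deterministic=True, the edge list has to be sorted in row-major order and paired with a transpose permutation. The stdlib sketch below shows one way to compute both; sort_edges_row_major and transpose_permutation are hypothetical helpers, and the convention that sender_perm[k] is the row-major position of the k-th nonzero in column-major order (an argsort by (col, row)) is an assumption to check against the library's own examples.

```python
from typing import List, Sequence, Tuple


def sort_edges_row_major(rows: Sequence[int], cols: Sequence[int]
                         ) -> Tuple[List[int], List[int], List[int]]:
    """Sort edges by (row, col).

    Returns the sorted rows and cols plus edge_order, where edge_order[k] is
    the original index of the k-th sorted edge. Per-edge inputs (Y, and W if
    weights are not shared) must be reordered with edge_order as well.
    """
    edge_order = sorted(range(len(rows)), key=lambda e: (rows[e], cols[e]))
    return ([rows[e] for e in edge_order],
            [cols[e] for e in edge_order],
            edge_order)


def transpose_permutation(rows: Sequence[int], cols: Sequence[int]) -> List[int]:
    """Indices into the row-major edge arrays, listed in column-major order.

    Assumed meaning of sender_perm: perm[k] is the row-major position of
    the k-th nonzero when nonzeros are ordered by (col, row).
    """
    return sorted(range(len(rows)), key=lambda e: (cols[e], rows[e]))


# Symmetric 3-node graph with edges listed in arbitrary order.
rows, cols, order = sort_edges_row_major([2, 0, 1, 0], [0, 2, 0, 1])
perm = transpose_permutation(rows, cols)
print(rows, cols, order)  # [0, 0, 1, 2] [1, 2, 0, 0] [3, 1, 2, 0]
print(perm)               # [2, 3, 0, 1]
```

With torch tensors, the same orderings can be obtained with torch.argsort on a combined integer key, for example rows * num_nodes + cols for row-major order and cols * num_nodes + rows for the transpose.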
- class openequivariance.TPProblem(irreps_in1, irreps_in2, irreps_out, instructions, in1_var=None, in2_var=None, out_var=None, irrep_normalization='component', path_normalization='element', internal_weights=False, shared_weights=None, label=None, irrep_dtype=numpy.float32, weight_dtype=numpy.float32)
Specification for a CG tensor product. All parameters from e3nn's o3.TensorProduct are available, along with additional parameters that set the datatypes of the weights and irreps.
- Parameters:
irreps_in1 (Irreps) – Irreps for the first CG argument
irreps_in2 (Irreps) – Irreps for the second CG argument
irreps_out (Irreps) – Irreps for the output
instructions (List[Any]) – A list of tuples following e3nn's convention, each of the form (i_in1, i_in2, i_out, connection_mode, has_weight), with an optional sixth entry path_weight. i_in1, i_in2, and i_out each index an Irrep from irreps_in1, irreps_in2, and irreps_out, respectively. connection_mode (e.g., "uvu") determines how the input and output multiplicities are connected, has_weight (True / False) controls whether trainable weights are included for the instruction, and path_weight controls output normalization.
irrep_dtype (type[numpy.generic]) – Datatype of irrep inputs; one of np.float32 or np.float64. Default: np.float32.
weight_dtype (type[numpy.generic]) – Datatype of weights; one of np.float32 or np.float64. Default: np.float32.
label (str) – A name for this problem specification (useful for testing / benchmarking).
shared_weights (bool) – If True, all elements in a batch of inputs share a common set of weights. If False, each batch element has a unique set of weights. Default: True.
internal_weights (bool) – Must be False; OpenEquivariance does not support internal weights. Default: False.
irrep_normalization (str) – One of ["component", "norm", "none"]. Default: "component".
path_normalization (str) – One of ["element", "path", "none"]. Default: "element".
in1_var (List[float] | None)
in2_var (List[float] | None)
out_var (List[float] | None)
- instructions: List[Any]
- internal_weights: bool
- label: str
- weight_numel: int
- weight_range_and_shape_for_instruction(instruction)
- Parameters:
instruction (int)
- Return type:
Tuple[int, int, tuple]
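weight_numel and weight_range_and_shape_for_instruction describe how per-instruction weights are packed into the flat last dimension of W. The sketch below reproduces that bookkeeping under e3nn's conventions: one block per weighted instruction, concatenated in instruction order, with the block shape set by the connection mode (for example "uvu"). weight_layout and PATH_SHAPES are hypothetical names, and reading the documented Tuple[int, int, tuple] return value as (start, stop, shape) is an assumption.

```python
from math import prod
from typing import Callable, Dict, List, Optional, Sequence, Tuple

Layout = Optional[Tuple[int, int, Tuple[int, ...]]]

# Weight shape per connection mode in e3nn, given the multiplicities
# (m1, m2, mo) of the irreps an instruction reads and writes.
PATH_SHAPES: Dict[str, Callable[[int, int, int], Tuple[int, ...]]] = {
    "uvw": lambda m1, m2, mo: (m1, m2, mo),
    "uvu": lambda m1, m2, mo: (m1, m2),
    "uvv": lambda m1, m2, mo: (m1, m2),
    "uuw": lambda m1, m2, mo: (m1, mo),
    "uuu": lambda m1, m2, mo: (m1,),
}


def weight_layout(muls_in1: Sequence[int], muls_in2: Sequence[int],
                  muls_out: Sequence[int],
                  instructions: Sequence[Tuple[int, int, int, str, bool]]
                  ) -> Tuple[int, List[Layout]]:
    """Return (weight_numel, per-instruction (start, stop, shape) or None)."""
    offset = 0
    layout: List[Layout] = []
    for i1, i2, io, mode, has_weight in instructions:
        if not has_weight:
            layout.append(None)  # no trainable weights for this path
            continue
        shape = PATH_SHAPES[mode](muls_in1[i1], muls_in2[i2], muls_out[io])
        size = prod(shape)
        # Blocks are packed back to back in instruction order.
        layout.append((offset, offset + size, shape))
        offset += size
    return offset, layout


# "32x0e + 32x1e" x "1x0e + 1x1e" -> "32x0e + 32x1e" with four "uvu" paths.
instructions = [(0, 0, 0, "uvu", True), (1, 1, 0, "uvu", True),
                (0, 1, 1, "uvu", True), (1, 0, 1, "uvu", True)]
numel, layout = weight_layout([32, 32], [1, 1], [32, 32], instructions)
print(numel)      # 128
print(layout[1])  # (32, 64, (32, 1))
```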
- openequivariance.torch_to_oeq_dtype(torch_dtype)
Convenience function; converts a torch datatype to the corresponding numpy datatype for use in TPProblem.
- Parameters:
torch_dtype – torch datatype (e.g., torch.float32, torch.float64)
- Returns:
numpy datatype (e.g., np.float32, np.float64)
- Return type:
type[numpy.generic]
- openequivariance.torch_ext_so_path()
- Returns:
Path to a .so file that must be linked to use OpenEquivariance from the PyTorch C++ interface.
API Identical to e3nn
These remaining API members are identical to the corresponding objects in e3nn.o3. You can freely mix these objects from both packages.
- class openequivariance.Irreps(irreps=None)
- Return type:
_MulIr | Irreps
- count(ir)
- Returns:
Total multiplicity of ir.
- Parameters:
ir (Irrep)
- Return type:
int
- property dim: int
- index(_object)
Return first index of value.
Raises ValueError if the value is not present.
- property lmax: int
- property ls: List[int]
- property num_irreps: int
- slices()
List of slices corresponding to indices for each irrep.
Examples:
>>> Irreps('2x0e + 1e').slices()
[slice(0, 2, None), slice(2, 5, None)]
- sort()
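The Irreps properties above are simple functions of the list of (multiplicity, irrep) entries. The stdlib sketch below parses a simplified string form (only "e"/"o" parity suffixes) and recomputes dim, num_irreps, lmax, count, and slices to illustrate their semantics, reproducing the slices() example. All names are hypothetical stand-ins; in practice you would use openequivariance.Irreps or e3nn.o3.Irreps directly.

```python
import re
from typing import List, Tuple

# (multiplicity, l, parity) with parity +1 for "e" and -1 for "o".
MulIr = Tuple[int, int, int]


def parse_irreps(text: str) -> List[MulIr]:
    """Parse simplified irreps strings such as '2x0e + 1e'."""
    result = []
    for term in text.split("+"):
        m = re.fullmatch(r"\s*(?:(\d+)x)?(\d+)([eo])\s*", term)
        if m is None:
            raise ValueError(f"cannot parse irreps term {term!r}")
        mul = int(m.group(1)) if m.group(1) else 1  # multiplicity defaults to 1
        result.append((mul, int(m.group(2)), 1 if m.group(3) == "e" else -1))
    return result


def dim(irreps: List[MulIr]) -> int:
    # Each copy of irrep l contributes 2l + 1 components.
    return sum(mul * (2 * l + 1) for mul, l, _ in irreps)


def num_irreps(irreps: List[MulIr]) -> int:
    return sum(mul for mul, _, _ in irreps)


def lmax(irreps: List[MulIr]) -> int:
    return max(l for _, l, _ in irreps)


def count(irreps: List[MulIr], l: int, parity: int) -> int:
    # Total multiplicity of one irrep, summed over every matching entry.
    return sum(mul for mul, li, p in irreps if (li, p) == (l, parity))


def slices(irreps: List[MulIr]) -> List[slice]:
    # Consecutive, non-overlapping index ranges, one per (mul, irrep) entry.
    out, start = [], 0
    for mul, l, _ in irreps:
        stop = start + mul * (2 * l + 1)
        out.append(slice(start, stop))
        start = stop
    return out


irreps = parse_irreps("2x0e + 1e")
print(dim(irreps), num_irreps(irreps), lmax(irreps))  # 5 3 1
print(slices(irreps))  # [slice(0, 2, None), slice(2, 5, None)]
```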