OpenEquivariance API

OpenEquivariance exposes two key classes: openequivariance.TensorProduct, which replaces o3.TensorProduct from e3nn, and openequivariance.TensorProductConv, which fuses the CG tensor product with a subsequent graph convolution. Initializing either class triggers JIT compilation of a custom kernel, which can take a few seconds.

Both classes require a configuration object specified by openequivariance.TPProblem, which has a constructor almost identical to o3.TensorProduct. We recommend reading the e3nn documentation before trying our code. OpenEquivariance cannot accelerate all tensor products; see this page for a list of supported configurations.

class openequivariance.TensorProduct(*args, **kwargs)

Drop-in replacement for o3.TensorProduct from e3nn. Supports forward, backward, and double-backward passes using JIT-compiled kernels. Initialization fails if:

  • There are no visible GPUs.

  • The provided tensor product specification is unsupported.

Parameters:

problem (TPProblem) – Specification of the tensor product.

forward(x, y, W)

Computes \(W (x \otimes_{\textrm{CG}} y)\), identical to o3.TensorProduct.forward.

Parameters:
  • x (torch.Tensor) – Tensor of shape [batch_size, problem.irreps_in1.dim()], datatype problem.irrep_dtype.

  • y (torch.Tensor) – Tensor of shape [batch_size, problem.irreps_in2.dim()], datatype problem.irrep_dtype.

  • W (torch.Tensor) –

    Tensor of datatype problem.weight_dtype and shape

    • [batch_size, problem.weight_numel] if problem.shared_weights=False

    • [problem.weight_numel] if problem.shared_weights=True

Returns:

Tensor of shape [batch_size, problem.irreps_out.dim()], datatype problem.irrep_dtype.

Return type:

torch.Tensor
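As a rough illustration of the shapes listed above, the sketch below builds a small problem and calls forward. The irreps, the single "uvu" instruction (written in e3nn's tuple form), and the batch size are arbitrary choices made for illustration, and expected_weight_shape and run_demo are hypothetical helper names that are not part of the library. run_demo assumes a visible CUDA GPU with torch and openequivariance installed, so its imports are deferred until it is called.

```python
def expected_weight_shape(batch_size, weight_numel, shared_weights):
    """Shape of W expected by forward, following the parameter list above."""
    if shared_weights:
        return (weight_numel,)
    return (batch_size, weight_numel)


def run_demo(batch_size=1000):
    """Hypothetical end-to-end call; requires a visible CUDA GPU."""
    # Deferred imports: this function is only usable where the GPU stack exists.
    import torch
    import openequivariance as oeq

    # Illustrative irreps (an assumption, not a recommended configuration).
    in1, in2, out = oeq.Irreps("1x2e"), oeq.Irreps("1x3e"), oeq.Irreps("1x2e")
    # e3nn-style instruction: (i_in1, i_in2, i_out, connection_mode, has_weight).
    instructions = [(0, 0, 0, "uvu", True)]

    problem = oeq.TPProblem(in1, in2, out, instructions,
                            shared_weights=False, internal_weights=False)
    tp = oeq.TensorProduct(problem)  # initialization JIT-compiles the kernel

    # Irreps.dim is a property; float32 matches the default irrep_dtype.
    x = torch.rand(batch_size, in1.dim, device="cuda")
    y = torch.rand(batch_size, in2.dim, device="cuda")
    w_shape = expected_weight_shape(batch_size, problem.weight_numel, False)
    w = torch.rand(*w_shape, device="cuda")

    z = tp.forward(x, y, w)
    return tuple(z.shape)  # expected to be (batch_size, out.dim)
```

With shared_weights=True, the same helper yields a one-dimensional weight shape, and a single weight vector is reused for every batch element.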

class openequivariance.TensorProductConv(*args, **kwargs)

Given a symmetric, directed graph \(G = (V, E)\), inputs \(x_1...x_{|V|}\), \(y_1...y_{|E|}\), and weights \(W_1...W_{|E|}\), computes

\[\begin{split}z_i = \sum_{(i, j, e) \in \mathcal{N}(i)} W_e (x_j \otimes_{\textrm{CG}} y_e)\end{split}\]

where \((i, j, e) \in \mathcal{N}(i)\) indicates that node \(i\) is connected to node \(j\) via the edge indexed \(e\).

This class offers two ways to perform the summation: an atomic algorithm and a deterministic algorithm that relies on a sorted adjacency matrix input. If you use the deterministic algorithm, you must also supply a permutation that transposes the adjacency matrix.
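To make the indexing in the sum concrete, here is a minimal pure-Python reference, assuming scalar inputs so that the CG tensor product reduces to ordinary multiplication (normalization constants are ignored). It also assumes that rows[e] holds the receiving node i and cols[e] the neighbor j, matching the formula above. The function name reference_conv_scalar is hypothetical, and this is a readability aid rather than how the fused kernel is implemented.

```python
def reference_conv_scalar(num_nodes, x, y, w, rows, cols):
    """Accumulate z[i] = sum over edges e = (i, j) of w[e] * (x[j] * y[e])."""
    z = [0.0] * num_nodes
    for e, (i, j) in enumerate(zip(rows, cols)):
        # x is indexed by the neighbor node j; y and w are indexed by the edge e.
        z[i] += w[e] * (x[j] * y[e])
    return z


# A symmetric path graph on three nodes: edges 0<->1 and 1<->2.
rows = [0, 1, 1, 2]
cols = [1, 0, 2, 1]
x = [1.0, 2.0, 3.0]        # one scalar per node
y = [1.0, 1.0, 2.0, 0.5]   # one scalar per edge
w = [1.0, 2.0, 1.0, 4.0]   # one weight per edge

z = reference_conv_scalar(3, x, y, w, rows, cols)
print(z)  # [2.0, 8.0, 4.0]
```

Node 1 receives contributions from both of its neighbors, which is the case where the order of accumulation matters for floating-point reproducibility.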

Parameters:
  • problem (TPProblem) – Specification of the tensor product.

  • deterministic (bool) – If False, uses atomics for the convolution. If True, uses a deterministic fixup-based algorithm. Default: False.

  • kahan (bool) – If True, uses Kahan summation to improve accuracy during aggregation. This option requires float32 input tensors and deterministic=True. Default: False.
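For readers unfamiliar with Kahan (compensated) summation, the following is a generic sketch of the textbook algorithm in plain Python. It runs in double precision on the CPU and is not the library's kernel code; it only shows why carrying a compensation term reduces the rounding error of a long accumulation.

```python
import math


def naive_sum(values):
    """Plain left-to-right accumulation."""
    total = 0.0
    for v in values:
        total += v
    return total


def kahan_sum(values):
    """Compensated summation: track the low-order bits lost at each step."""
    total = 0.0
    compensation = 0.0  # running estimate of the rounding error so far
    for v in values:
        corrected = v - compensation
        new_total = total + corrected
        # (new_total - total) is what was actually added; subtracting
        # `corrected` recovers the rounding error made in this step.
        compensation = (new_total - total) - corrected
        total = new_total
    return total


# Ten tiny terms that a naive sum drops entirely when added to 1.0.
values = [1.0] + [1e-16] * 10
reference = math.fsum(values)  # correctly rounded sum from the standard library

naive_error = abs(naive_sum(values) - reference)
kahan_error = abs(kahan_sum(values) - reference)
print(naive_error > kahan_error)  # True
```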

forward(X, Y, W, rows, cols, sender_perm=None)

Computes the fused CG tensor product + convolution.

Parameters:
  • X (torch.Tensor) – Tensor of shape [|V|, problem.irreps_in1.dim()], datatype problem.irrep_dtype.

  • Y (torch.Tensor) – Tensor of shape [|E|, problem.irreps_in2.dim()], datatype problem.irrep_dtype.

  • W (torch.Tensor) –

    Tensor of datatype problem.weight_dtype and shape

    • [|E|, problem.weight_numel] if problem.shared_weights=False

    • [problem.weight_numel] if problem.shared_weights=True

  • rows (torch.Tensor) – Tensor of shape [|E|] with row indices for each nonzero in the adjacency matrix, datatype torch.int64. Must be row-major sorted along with cols when deterministic=True.

  • cols (torch.Tensor) – Tensor of shape [|E|] with column indices for each nonzero in the adjacency matrix, datatype torch.int64.

  • sender_perm (torch.Tensor | None) – Tensor of shape [|E|] and torch.int64 datatype containing a permutation that transposes the adjacency matrix nonzeros from row-major to column-major order. Must be provided when deterministic=True.

Returns:

Tensor of shape [|V|, problem.irreps_out.dim()], datatype problem.irrep_dtype.

Return type:

torch.Tensor
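The inputs needed for deterministic=True can be prepared along the following lines. This is a pure-Python sketch under two assumptions: that "row-major sorted" means edges ordered by (row, col), and that sender_perm is the argsort that lists those sorted edges in (col, row) order. The helper names are hypothetical.

```python
def sort_row_major(rows, cols):
    """Return (order, sorted_rows, sorted_cols) with edges ordered by (row, col)."""
    order = sorted(range(len(rows)), key=lambda e: (rows[e], cols[e]))
    return order, [rows[e] for e in order], [cols[e] for e in order]


def column_major_permutation(rows, cols):
    """Indices that visit row-major sorted edges in column-major order."""
    return sorted(range(len(rows)), key=lambda e: (cols[e], rows[e]))


# Unsorted edge list for a symmetric path graph on three nodes.
raw_rows = [1, 0, 2, 1]
raw_cols = [2, 1, 1, 0]

order, rows, cols = sort_row_major(raw_rows, raw_cols)
sender_perm = column_major_permutation(rows, cols)

print(order)        # [1, 3, 0, 2]
print(rows, cols)   # [0, 1, 1, 2] [1, 0, 2, 1]
print(sender_perm)  # [1, 0, 3, 2]
```

Because Y and unshared W are indexed by edge, their rows would need to be reordered with the same order before the call. The resulting lists can then be converted to torch.int64 tensors on the GPU and passed as rows, cols, and sender_perm.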

class openequivariance.TPProblem(irreps_in1, irreps_in2, irreps_out, instructions, in1_var=None, in2_var=None, out_var=None, irrep_normalization='component', path_normalization='element', internal_weights=False, shared_weights=None, label=None, irrep_dtype=numpy.float32, weight_dtype=numpy.float32)

Specification for a CG tensor product. All parameters from e3nn’s o3.TensorProduct are available, along with additional parameters for the types of weights and irreps.

Parameters:
  • irreps_in1 (Irreps) – Irreps for the first CG argument

  • irreps_in2 (Irreps) – Irreps for the second CG argument

  • irreps_out (Irreps) – Irreps for the output

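  • instructions (List[Any]) – A list of tuples in the same form as the instructions of o3.TensorProduct in e3nn: (i_in1, i_in2, i_out, connection_mode, has_weight, path_weight), where path_weight may be omitted. i_in1, i_in2, and i_out each index an Irrep from irreps_in1, irreps_in2, and irreps_out, respectively. connection_mode is a string such as "uvu" that selects how multiplicities are combined. has_weight (True / False) controls whether trainable weights are included for the instruction, and path_weight controls output normalization.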

  • irrep_dtype (type[numpy.generic]) – Datatype of irrep inputs; one of np.float32 or np.float64. Default: np.float32.

  • weight_dtype (type[numpy.generic]) – Datatype of weights; one of np.float32 or np.float64. Default: np.float32.

  • label (str) – A name for this problem specification (useful for testing / benchmarking).

  • shared_weights (bool) – If True, all elements in a batch of inputs share a common set of weights. If False, each batch element has a unique set of weights. Default: True.

  • internal_weights (bool) – Must be False; OpenEquivariance does not support internal weights. Default: False.

  • irrep_normalization (str) – One of ["component", "norm", "none"]. Default: "component".

  • path_normalization (str) – One of ["element", "path", "none"]. Default: "element".

  • in1_var (List[float] | None)

  • in2_var (List[float] | None)

  • out_var (List[float] | None)

instructions: List[Any]
internal_weights: bool
label: str
shared_weights: bool
weight_numel: int
weight_range_and_shape_for_instruction(instruction)
Parameters:

instruction (int)

Return type:

Tuple[int, int, tuple]
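The return type suggests a (start, end, shape) triple locating one instruction's weights inside the flat weight vector. The sketch below shows how such a triple could be computed, assuming e3nn's conventions: weights of weighted instructions are laid out contiguously in instruction order, with path shape (mul_in1, mul_in2) for "uvu" and (mul_in1, mul_in2, mul_out) for "uvw". The function names and the multiplicity lists are hypothetical, and only these two connection modes are covered.

```python
from math import prod


def path_shape(mode, mul_in1, mul_in2, mul_out):
    """Weight shape for one instruction (assumed e3nn convention)."""
    if mode == "uvu":
        return (mul_in1, mul_in2)
    if mode == "uvw":
        return (mul_in1, mul_in2, mul_out)
    raise ValueError(f"connection mode {mode!r} is not covered by this sketch")


def weight_range_and_shape(instructions, muls_in1, muls_in2, muls_out, index):
    """Return (start, end, shape) of the weights of instruction `index`."""
    def shape_of(ins):
        i_in1, i_in2, i_out, mode, _has_weight = ins
        return path_shape(mode, muls_in1[i_in1], muls_in2[i_in2], muls_out[i_out])

    if not instructions[index][4]:
        raise ValueError("instruction has no weights")
    # Unweighted instructions occupy no space in the flat weight vector.
    start = sum(prod(shape_of(ins)) for ins in instructions[:index] if ins[4])
    shape = shape_of(instructions[index])
    return start, start + prod(shape), shape


# Multiplicities of each irrep in irreps_in1, irreps_in2, and irreps_out.
muls_in1, muls_in2, muls_out = [32, 16], [1, 1], [32, 16]
instructions = [
    (0, 0, 0, "uvu", True),
    (0, 1, 0, "uvu", False),  # no trainable weights
    (1, 1, 1, "uvw", True),
]

print(weight_range_and_shape(instructions, muls_in1, muls_in2, muls_out, 0))
# (0, 32, (32, 1))
print(weight_range_and_shape(instructions, muls_in1, muls_in2, muls_out, 2))
# (32, 288, (16, 1, 16))
```

Under these assumptions, the end of the last weighted instruction's range (288 here) would equal weight_numel.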

openequivariance.torch_to_oeq_dtype(torch_dtype)

Convenience function; converts a torch datatype to the corresponding numpy datatype for use in TPProblem.

Parameters:

torch_dtype – torch datatype (e.g., torch.float32, torch.float64)

Returns:

numpy datatype (e.g., np.float32, np.float64)

Return type:

type[numpy.generic]

openequivariance.torch_ext_so_path()
Returns:

Path to a .so file that must be linked to use OpenEquivariance from the PyTorch C++ Interface.

API Identical to e3nn

These remaining API members are identical to the corresponding objects in e3nn.o3. You can freely mix these objects from both packages.

class openequivariance.Irreps(irreps=None)
Return type:

_MulIr | Irreps

count(ir)
Returns:

Total multiplicity of ir.

Parameters:

ir (Irrep)

Return type:

int

property dim: int
index(_object)

Return first index of value.

Raises ValueError if the value is not present.

property lmax: int
property ls: List[int]
property num_irreps: int
regroup()
Return type:

Irreps

remove_zero_multiplicities()
Return type:

Irreps

simplify()
Return type:

Irreps

slices()

List of slices corresponding to indices for each irrep.

Examples:

>>> Irreps('2x0e + 1e').slices()
[slice(0, 2, None), slice(2, 5, None)]
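The slices in this example follow from each term MULxLp occupying MUL * (2L + 1) consecutive entries. A small pure-Python sketch of that bookkeeping is shown below; parse_irreps, irreps_dim, and irreps_slices are hypothetical helpers written for illustration, and they handle only the simple string form used in the example.

```python
def parse_irreps(text):
    """Parse 'MULxLp + ...' into (mul, l) pairs; parity is checked but unused."""
    pairs = []
    for term in text.split("+"):
        term = term.strip()
        mul_str, sep, ir = term.partition("x")
        if not sep:  # no explicit multiplicity, e.g. '1e'
            mul_str, ir = "1", term
        if ir[-1] not in "eo":
            raise ValueError(f"missing parity in term {term!r}")
        pairs.append((int(mul_str), int(ir[:-1])))
    return pairs


def irreps_dim(text):
    """Total dimension: sum of mul * (2l + 1) over all terms."""
    return sum(mul * (2 * l + 1) for mul, l in parse_irreps(text))


def irreps_slices(text):
    """Index range occupied by each term in the concatenated feature vector."""
    slices, start = [], 0
    for mul, l in parse_irreps(text):
        stop = start + mul * (2 * l + 1)
        slices.append(slice(start, stop))
        start = stop
    return slices


print(irreps_slices("2x0e + 1e"))  # [slice(0, 2, None), slice(2, 5, None)]
print(irreps_dim("2x0e + 1e"))     # 5
```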
sort()
static spherical_harmonics(lmax, p=-1)
Parameters:
  • lmax (int)

  • p (int)

Return type:

Irreps