
Advanced

advanced

Configuration for the advanced features.

Classes

AdvancedConfiguration

Bases: BaseModel

Configuration for the advanced features.

Attributes
backend class-attribute instance-attribute
backend: Literal['jax', 'numpy'] = Field(default='jax', description='Backend for the computation.')

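Backend used for the computation, either 'jax' (the default) or 'numpy'. Note that validate_backend, documented below, raises a ValueError for 'numpy'.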

jit class-attribute instance-attribute
jit: bool = Field(default=True, description='Whether to jit critical functions.')

Whether to just-in-time (JIT) compile critical functions. Defaults to True.

vectorization class-attribute instance-attribute
vectorization: Literal['vmap', 'manual loop', 'pmap'] = Field(default='vmap', description='Vectorization method.')

Vectorization method: 'vmap' (the default), 'manual loop', or 'pmap'.

device class-attribute instance-attribute
device: Literal['cpu', 'gpu', 'tpu', 'auto'] = Field(default='auto', description='Device to use for the run.')

Device to use for the run: 'cpu', 'gpu', 'tpu', or 'auto' (the default).
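A minimal sketch of how these fields behave, assuming AdvancedConfiguration is a pydantic BaseModel (the page lists BaseModel as its base and declares the fields with Field). The class below is a hypothetical re-declaration for illustration only, not the library's source, and it omits validate_backend. It shows that every field has a default, that individual fields can be overridden by keyword, and that a value outside a field's Literal is rejected at construction time.

```python
from typing import Literal

from pydantic import BaseModel, Field, ValidationError


# Hypothetical re-declaration mirroring the documented fields; not the library's own code.
class AdvancedConfiguration(BaseModel):
    backend: Literal["jax", "numpy"] = Field(default="jax", description="Backend for the computation.")
    jit: bool = Field(default=True, description="Whether to jit critical functions.")
    vectorization: Literal["vmap", "manual loop", "pmap"] = Field(default="vmap", description="Vectorization method.")
    device: Literal["cpu", "gpu", "tpu", "auto"] = Field(default="auto", description="Device to use for the run.")


# Every field has a default, so the model can be built without arguments.
default_config = AdvancedConfiguration()
print(default_config.backend, default_config.jit, default_config.vectorization, default_config.device)
# jax True vmap auto

# Override only the settings you need; the rest keep their defaults.
cpu_config = AdvancedConfiguration(device="cpu", jit=False)

# A value that is not one of the Literal options fails validation.
try:
    AdvancedConfiguration(vectorization="threads")
    invalid_rejected = False
except ValidationError:
    invalid_rejected = True
print(invalid_rejected)  # True
```

Because every field has a default, AdvancedConfiguration() is valid on its own; pass keyword arguments only for the settings you want to change.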

Functions
validate_backend classmethod
validate_backend(value: Literal['jax', 'numpy']) -> Literal['jax', 'numpy']

Validate the backend to use.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| value | Literal['jax', 'numpy'] | Name of the backend. | required |

Returns:

| Type | Description |
|---|---|
| Literal['jax', 'numpy'] | Name of the validated backend. |

Raises:

| Type | Description |
|---|---|
| ValueError | If value is 'numpy'. |
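One way such a validator could be wired up, assuming pydantic v2 and assuming validate_backend is registered for the backend field with field_validator (this page only shows that it is a classmethod). The model is a hypothetical, trimmed-down re-declaration and the error message is invented for illustration. Since the validator runs after the Literal check, value is already 'jax' or 'numpy' when it is called; pydantic then wraps the raised ValueError in a ValidationError, which is itself a ValueError subclass.

```python
from typing import Literal

from pydantic import BaseModel, Field, ValidationError, field_validator


class AdvancedConfiguration(BaseModel):
    # Hypothetical re-declaration: only the field this validator checks.
    backend: Literal["jax", "numpy"] = Field(default="jax", description="Backend for the computation.")

    # Assumption: attached to `backend` as a pydantic v2 field validator ("after" mode,
    # so the Literal check has already passed when this runs).
    @field_validator("backend")
    @classmethod
    def validate_backend(cls, value: Literal["jax", "numpy"]) -> Literal["jax", "numpy"]:
        if value == "numpy":
            # Documented behaviour: 'numpy' raises ValueError. The message text is hypothetical.
            raise ValueError("The 'numpy' backend is not supported; use 'jax'.")
        return value


jax_config = AdvancedConfiguration(backend="jax")
print(jax_config.backend)  # jax

try:
    AdvancedConfiguration(backend="numpy")
    numpy_rejected = False
except ValidationError:
    # pydantic wraps the validator's ValueError in a ValidationError.
    numpy_rejected = True
print(numpy_rejected)  # True
```

Callers that already catch ValueError will also catch the wrapped error, because ValidationError subclasses ValueError.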