Advanced
advanced
Configuration for the advanced features.
Classes
AdvancedConfiguration
Bases: BaseModel
Configuration for the advanced features.
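The class is a pydantic BaseModel whose four fields are documented under Attributes. The stdlib dataclass below is a minimal sketch that runs without pydantic. It mirrors those fields and their defaults and emulates the Literal membership check that pydantic performs at construction. The name AdvancedConfigSketch is hypothetical. The extra rejection of 'numpy' performed by validate_backend is deliberately not reproduced here.

```python
from dataclasses import dataclass, fields
from typing import Literal, get_args

# Allowed values copied from the documented field annotations.
Backend = Literal["jax", "numpy"]
Vectorization = Literal["vmap", "manual loop", "pmap"]
Device = Literal["cpu", "gpu", "tpu", "auto"]


@dataclass(frozen=True)
class AdvancedConfigSketch:
    """Hypothetical stdlib stand-in for AdvancedConfiguration (not the real class)."""

    backend: Backend = "jax"
    jit: bool = True
    vectorization: Vectorization = "vmap"
    device: Device = "auto"

    def __post_init__(self) -> None:
        # For every Literal-typed field, reject values outside the allowed set,
        # roughly as pydantic would at construction. get_args(bool) is empty,
        # so the jit field is skipped.
        for f in fields(self):
            allowed = get_args(f.type)
            value = getattr(self, f.name)
            if allowed and value not in allowed:
                raise ValueError(f"{f.name}={value!r} is not one of {allowed}")


# Usage: defaults match the documented Field defaults; overrides are checked.
default = AdvancedConfigSketch()
per_device = AdvancedConfigSketch(vectorization="pmap", device="gpu")
```

Passing an unlisted value such as device="cuda" raises ValueError in this sketch. The real model would raise pydantic's own validation error instead.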
Attributes
backend (class-attribute, instance-attribute)
backend: Literal['jax', 'numpy'] = Field(default='jax', description='Backend for the computation.')
Array backend used for the computation. Because validate_backend rejects 'numpy', 'jax' is effectively the only accepted value.
jit (class-attribute, instance-attribute)
jit: bool = Field(default=True, description='Whether to jit critical functions.')
Whether to JIT-compile critical functions.
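In JAX, JIT compilation is applied by wrapping a function with jax.jit. The following is a minimal sketch of how a flag like jit could gate that wrapping, assuming the flag is consumed this way. The helper maybe_jit and its compiler parameter are hypothetical. The compiler is injectable so the sketch runs without JAX installed.

```python
from typing import Callable, Optional


def maybe_jit(fn: Callable, enabled: bool, compiler: Optional[Callable] = None) -> Callable:
    """Return fn wrapped by `compiler` when enabled, else fn unchanged.

    compiler defaults to jax.jit, imported lazily so JAX is only required
    when compilation is requested without an explicit compiler.
    """
    if not enabled:
        return fn
    if compiler is None:
        import jax  # only needed on this path

        compiler = jax.jit
    return compiler(fn)


def step(x):
    # Stand-in for a "critical function" that would benefit from compilation.
    return x * 2 + 1
```

With jit=False the original Python function runs eagerly, which is often easier to debug. With jit=True, jax.jit traces the function on its first call and reuses the compiled version for later calls with the same input shapes and dtypes.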
vectorization (class-attribute, instance-attribute)
vectorization: Literal['vmap', 'manual loop', 'pmap'] = Field(default='vmap', description='Vectorization method.')
Method used to vectorize computations over a batch.
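The option names match JAX's transformations. jax.vmap vectorizes a function over a batch axis on one device, and jax.pmap maps it across multiple devices. 'manual loop' presumably calls the function once per batch element. The sketch below is an assumption about how such a setting might be dispatched. The helper batched and its vmap/pmap parameters are hypothetical, and the transforms are injectable so the loop path runs without JAX.

```python
from typing import Callable, Optional, Sequence


def batched(
    fn: Callable,
    method: str,
    vmap: Optional[Callable] = None,
    pmap: Optional[Callable] = None,
) -> Callable:
    """Map fn over the leading axis of its input using the chosen method."""
    if method == "manual loop":
        def looped(xs: Sequence):
            # One Python-level call per element; results collected in a list.
            return [fn(x) for x in xs]

        return looped
    if method not in ("vmap", "pmap"):
        raise ValueError(f"unknown vectorization method: {method!r}")
    transform = vmap if method == "vmap" else pmap
    if transform is None:
        import jax  # only needed for the JAX transforms

        transform = jax.vmap if method == "vmap" else jax.pmap
    return transform(fn)
```

The loop path returns a list, whereas jax.vmap and jax.pmap return stacked arrays. jax.pmap additionally requires the mapped axis to be no larger than the number of local devices.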
device (class-attribute, instance-attribute)
device: Literal['cpu', 'gpu', 'tpu', 'auto'] = Field(default='auto', description='Device to use for the run.')
Device to run on. With 'auto', the device is selected automatically.
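How 'auto' is resolved is not documented here. The following is one plausible sketch, assuming 'auto' picks the most capable available platform in the order tpu, gpu, cpu. The helper resolve_device, its available argument, and that preference order are all hypothetical.

```python
from typing import Iterable

# Assumed preference order for "auto": accelerators first, CPU as fallback.
PREFERENCE = ("tpu", "gpu", "cpu")


def resolve_device(device: str, available: Iterable[str]) -> str:
    """Map the configured device to a concrete platform name."""
    platforms = set(available)
    if device == "auto":
        for candidate in PREFERENCE:
            if candidate in platforms:
                return candidate
        raise RuntimeError("no supported device is available")
    if device not in platforms:
        raise RuntimeError(f"device {device!r} requested but not available")
    return device
```

In JAX, the available input could be built from the platform attribute of the devices JAX reports. Whether this library does so is an assumption.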
Functions
validate_backend (classmethod)
validate_backend(value: Literal['jax', 'numpy']) -> Literal['jax', 'numpy']
Validate the backend to use.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| value | Literal['jax', 'numpy'] | Name of the backend. | required |
Returns:
| Type | Description |
|---|---|
| Literal['jax', 'numpy'] | Name of the validated backend. |
Raises:
| Type | Description |
|---|---|
| ValueError | If value is 'numpy'. |
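Although 'numpy' appears in the backend type, the documented contract rejects it. The following is a minimal sketch of a function that follows that contract. The real implementation is not shown in this reference, and the error message below is made up.

```python
from typing import Literal

Backend = Literal["jax", "numpy"]


def validate_backend(value: Backend) -> Backend:
    """Return value unchanged unless it is 'numpy', which is rejected."""
    if value == "numpy":
        # Error message is illustrative, not the library's actual text.
        raise ValueError("backend 'numpy' is not supported; use 'jax'")
    return value
```

In pydantic v2, a validator like this is typically registered by stacking @field_validator("backend") on @classmethod. Whether this library registers it that way is an assumption.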