pub trait Backend:
    Sized
    + FloatTensorOps<Self>
    + BoolTensorOps<Self>
    + IntTensorOps<Self>
    + ModuleOps<Self>
    + ActivationOps<Self>
    + QTensorOps<Self>
    + Clone
    + Default
    + Send
    + Sync
    + Debug
    + 'static {
    type Device: DeviceOps;
    type FullPrecisionBridge: BackendBridge<Self> + 'static;
    type FloatTensorPrimitive: Clone + Send + Sync + 'static + Debug;
    type FloatElem: Element;
    type IntTensorPrimitive: Clone + Send + Sync + 'static + Debug;
    type IntElem: Element;
    type BoolTensorPrimitive: Clone + Send + Sync + 'static + Debug;
    type QuantizedTensorPrimitive: QTensorPrimitive + Clone + Send + Sync + 'static + Debug;
    type QuantizedEncoding: Element;

    // Required methods
    fn name() -> String;
    fn seed(seed: u64);

    // Provided methods
    fn ad_enabled() -> bool { ... }
    fn sync(_device: &Self::Device) { ... }
}
This trait defines all types and functions needed for a backend to be used with burn.
§Design
This trait aims to be as unopinionated as possible and allows implementations to define their own types and patterns. Therefore, there are few pre-defined abstractions baked into this trait.
Backends must define their own tensor types for each data type: float, int, and bool.
Since we minimize assumptions, we chose to separate these types, as they are used in
different contexts. However, some backends may have a generic tensor type that is used
for all data types.
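As a rough illustration of that last point, the sketch below uses a hypothetical, heavily simplified `MiniBackend` trait (not Burn's real `Backend`) that keeps only the three tensor associated types. The hypothetical `GenericCpu` backend fills all three slots with one generic `CpuTensor<E>` type, parameterized by the element type.

```rust
use std::fmt::Debug;

// Hypothetical stand-in for the Backend trait: only the three tensor
// associated types are modeled, with the same bounds as the real ones.
trait MiniBackend {
    type FloatTensorPrimitive: Clone + Send + Sync + Debug + 'static;
    type IntTensorPrimitive: Clone + Send + Sync + Debug + 'static;
    type BoolTensorPrimitive: Clone + Send + Sync + Debug + 'static;
}

// One generic tensor type, parameterized by its element type.
#[derive(Clone, Debug, PartialEq)]
struct CpuTensor<E> {
    data: Vec<E>,
    shape: Vec<usize>,
}

// Hypothetical backend reusing the same generic tensor for every data type.
#[derive(Clone, Debug, Default)]
struct GenericCpu;

impl MiniBackend for GenericCpu {
    type FloatTensorPrimitive = CpuTensor<f32>;
    type IntTensorPrimitive = CpuTensor<i64>;
    type BoolTensorPrimitive = CpuTensor<bool>;
}

fn main() {
    let floats: <GenericCpu as MiniBackend>::FloatTensorPrimitive =
        CpuTensor { data: vec![1.0, 2.0], shape: vec![2] };
    let mask: <GenericCpu as MiniBackend>::BoolTensorPrimitive =
        CpuTensor { data: vec![true, false], shape: vec![2] };
    println!("{:?} {:?}", floats, mask);
}
```

A backend with genuinely different float, int, and bool representations would simply name three unrelated types instead.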
§Eager Mode
Because burn supports dynamic graphs, the backend trait is designed around kernel implementations that can be called without any mutable context or graph. This may not be ideal for backends that want to configure their computational graphs and execute them multiple times.
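The sketch below illustrates what "no mutable context or graph" means in practice, using a hypothetical `MiniFloatOps` trait and a toy `NaiveCpu` backend (neither is Burn's API). Every kernel is an associated function without `self` that takes owned tensors and returns a new one immediately, so the graph exists only implicitly in the order of the calls.

```rust
// Hypothetical op trait: kernels are associated functions with no `self`,
// so they can be called eagerly without any context object.
trait MiniFloatOps {
    type FloatTensor: Clone;

    fn from_vec(data: Vec<f32>) -> Self::FloatTensor;
    fn float_add(lhs: Self::FloatTensor, rhs: Self::FloatTensor) -> Self::FloatTensor;
    fn float_mul_scalar(tensor: Self::FloatTensor, scalar: f32) -> Self::FloatTensor;
}

#[derive(Clone, Debug, Default)]
struct NaiveCpu;

impl MiniFloatOps for NaiveCpu {
    type FloatTensor = Vec<f32>;

    fn from_vec(data: Vec<f32>) -> Vec<f32> {
        data
    }

    fn float_add(lhs: Vec<f32>, rhs: Vec<f32>) -> Vec<f32> {
        // Executes immediately and returns a freshly computed tensor.
        lhs.iter().zip(rhs.iter()).map(|(a, b)| a + b).collect()
    }

    fn float_mul_scalar(tensor: Vec<f32>, scalar: f32) -> Vec<f32> {
        tensor.into_iter().map(|x| x * scalar).collect()
    }
}

fn main() {
    let a = NaiveCpu::from_vec(vec![1.0, 2.0, 3.0]);
    let b = NaiveCpu::from_vec(vec![10.0, 20.0, 30.0]);
    // No graph is built: each call runs as soon as it is made.
    let c = NaiveCpu::float_mul_scalar(NaiveCpu::float_add(a, b), 2.0);
    println!("{:?}", c); // [22.0, 44.0, 66.0]
}
```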
A backend that wants to build and replay computation graphs could instead use channels to communicate with a backend server thread, which builds the computation graphs and re-executes the repeated ones using some form of cache. Once that pattern has matured, a graph-mode backend trait could be extracted from it, allowing other backends of the same kind to be integrated with burn quickly. The same pattern could also be used to create an operation fusion trait, which would let backends define which graph structures can be fused into a single operation.
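A minimal sketch of the server-thread idea, using only `std::sync::mpsc` and `std::thread`. Everything here is hypothetical: `Op`, `Request`, `Compiled`, `compile`, and `spawn_server` are illustrative names, the "tensors" are plain `Vec<i64>`, and "compilation" just folds scalar ops into one `x * mul + add` step as a toy form of fusion. The point is the shape of the design: callers send a recorded graph over a channel, and the server compiles each distinct graph once and reuses the cached result afterwards.

```rust
use std::collections::HashMap;
use std::sync::mpsc::{self, Sender};
use std::thread;

// Hypothetical recorded operations; a real backend would record tensor ops.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum Op {
    AddScalar(i64),
    MulScalar(i64),
}

// Message from the eager, user-facing side to the server thread.
struct Request {
    graph: Vec<Op>,
    input: Vec<i64>,
    reply: Sender<Vec<i64>>,
}

// Toy "compiled" graph: all ops fused into a single x * mul + add.
#[derive(Clone, Copy)]
struct Compiled {
    mul: i64,
    add: i64,
}

fn compile(graph: &[Op]) -> Compiled {
    let mut c = Compiled { mul: 1, add: 0 };
    for op in graph {
        match op {
            // (x * m + a) + k == x * m + (a + k)
            Op::AddScalar(k) => c.add += k,
            // (x * m + a) * k == x * (m * k) + (a * k)
            Op::MulScalar(k) => {
                c.mul *= k;
                c.add *= k;
            }
        }
    }
    c
}

// Spawns the server; the join handle yields how many graphs were compiled.
fn spawn_server() -> (Sender<Request>, thread::JoinHandle<usize>) {
    let (tx, rx) = mpsc::channel::<Request>();
    let handle = thread::spawn(move || {
        let mut cache: HashMap<Vec<Op>, Compiled> = HashMap::new();
        let mut compilations = 0;
        for req in rx {
            // Compile a graph only the first time it is seen, then reuse it.
            let compiled = *cache.entry(req.graph.clone()).or_insert_with(|| {
                compilations += 1;
                compile(&req.graph)
            });
            let out = req.input.iter().map(|x| x * compiled.mul + compiled.add).collect();
            let _ = req.reply.send(out);
        }
        compilations // returned once every sender has been dropped
    });
    (tx, handle)
}

fn main() {
    let (server, handle) = spawn_server();
    let graph = vec![Op::AddScalar(1), Op::MulScalar(3)];
    for input in vec![vec![1, 2], vec![10, 20]] {
        let (reply_tx, reply_rx) = mpsc::channel();
        server.send(Request { graph: graph.clone(), input, reply: reply_tx }).unwrap();
        println!("{:?}", reply_rx.recv().unwrap()); // [6, 9], then [33, 63]
    }
    drop(server);
    println!("compilations: {}", handle.join().unwrap()); // compilations: 1
}
```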
§Multi-Threaded
Backend tensor types are all Clone + Send, which allows them to be safely
sent between threads. It is recommended to wrap tensors with Arc,
which avoids copying the tensor's buffer. Note that it is still possible to mutate and
reuse a tensor's buffer without locking; see the next section on the Mutable API.
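A small sketch of the Arc recommendation, assuming a hypothetical `ArcTensor` primitive whose buffer lives behind an `Arc`. Cloning the tensor only increments a reference count, which `Arc::ptr_eq` confirms, and the clone can then be moved to another thread because the type is `Send`.

```rust
use std::sync::Arc;
use std::thread;

// Hypothetical tensor primitive: the buffer sits behind an Arc, so cloning
// the tensor bumps a reference count instead of copying the data.
#[derive(Clone, Debug)]
struct ArcTensor {
    buffer: Arc<Vec<f32>>,
    shape: Vec<usize>,
}

impl ArcTensor {
    fn new(data: Vec<f32>, shape: Vec<usize>) -> Self {
        Self { buffer: Arc::new(data), shape }
    }
}

fn main() {
    let tensor = ArcTensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
    let copy = tensor.clone();

    // Both handles point at the same allocation: the clone copied no data.
    println!("shared buffer: {}", Arc::ptr_eq(&tensor.buffer, &copy.buffer));

    // Clone + Send lets the copy move into another thread.
    let sum = thread::spawn(move || copy.buffer.iter().sum::<f32>())
        .join()
        .unwrap();
    println!("sum computed on another thread: {}", sum);
    println!("shape: {:?}", tensor.shape);
}
```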
§Mutable API
There is no mutable or in-place operation API to implement, but that does not mean that backends cannot support such operations. Arc::try_unwrap and Arc::get_mut give backends owned or mutable access to their tensor buffer data structure when the tensor is not shared. In that case, backends can dispatch to their own in-place operations for better performance.
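The sketch below shows that dispatch pattern with standard-library calls only. `Tensor`, `add_scalar`, and `into_vec` are hypothetical names, not Burn functions. `Arc::get_mut` succeeds only when no other handle shares the buffer, in which case the kernel mutates in place; otherwise it falls back to allocating a new output. `Arc::try_unwrap` similarly moves the data out without copying when the reference is unique.

```rust
use std::sync::Arc;

// Hypothetical tensor primitive holding its buffer behind an Arc.
#[derive(Clone, Debug)]
struct Tensor {
    buffer: Arc<Vec<f32>>,
}

// Hypothetical kernel: add a scalar, reusing the buffer when it is not shared.
fn add_scalar(mut tensor: Tensor, rhs: f32) -> Tensor {
    if let Some(data) = Arc::get_mut(&mut tensor.buffer) {
        // Sole owner: mutate in place, no allocation.
        for x in data.iter_mut() {
            *x += rhs;
        }
        return tensor;
    }
    // Buffer is shared with another tensor: allocate a new output instead.
    let data: Vec<f32> = tensor.buffer.iter().map(|x| x + rhs).collect();
    Tensor { buffer: Arc::new(data) }
}

// Take ownership of the data, copying only if the buffer is still shared.
fn into_vec(tensor: Tensor) -> Vec<f32> {
    match Arc::try_unwrap(tensor.buffer) {
        Ok(vec) => vec,                   // unique: moved out without a copy
        Err(shared) => (*shared).clone(), // shared: a copy is unavoidable
    }
}

fn main() {
    let unique = Tensor { buffer: Arc::new(vec![1.0, 2.0]) };
    let before = Arc::as_ptr(&unique.buffer);
    let out = add_scalar(unique, 1.0);
    // Same allocation as before: the kernel ran in place.
    println!("in place: {}", Arc::as_ptr(&out.buffer) == before);

    let shared = out.clone();
    let out2 = add_scalar(out, 1.0);
    // `shared` keeps the old values because a new buffer was allocated.
    println!("{:?} {:?}", shared.buffer, out2.buffer);
    println!("{:?}", into_vec(out2));
}
```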
§Documentation
Most of the documentation for each function can be found on the user API tensor struct.
For modules, public functions are often created, which can be used by burn-core
modules.
§Required Associated Types
type FullPrecisionBridge: BackendBridge<Self> + 'static
A bridge that can cast tensors to full precision.
type FloatTensorPrimitive: Clone + Send + Sync + 'static + Debug
Tensor primitive to be used for all float operations.
type IntTensorPrimitive: Clone + Send + Sync + 'static + Debug
Tensor primitive to be used for all int operations.
type BoolTensorPrimitive: Clone + Send + Sync + 'static + Debug
Tensor primitive to be used for all bool operations.
type QuantizedTensorPrimitive: QTensorPrimitive + Clone + Send + Sync + 'static + Debug
Tensor primitive to be used for all quantized operations.
type QuantizedEncoding: Element
Quantized tensor encoding type.
§Required Methods
§Provided Methods
fn ad_enabled() -> bool
Returns whether autodiff is enabled for this backend.
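To show how a provided method like this behaves, here is a sketch with a hypothetical `MiniBackend` trait whose `ad_enabled` defaults to `false`, a plain `Cpu` backend that keeps the default, and a hypothetical `WithAutodiff<B>` decorator that overrides it. These names and the `false` default are assumptions made for illustration, not a description of Burn's own implementations. Generic code can then query the flag without knowing the concrete backend.

```rust
use std::marker::PhantomData;

// Hypothetical, simplified backend trait with a provided method.
trait MiniBackend {
    fn name() -> String;

    // Provided method: assumed default of "no autodiff".
    fn ad_enabled() -> bool {
        false
    }
}

struct Cpu;

impl MiniBackend for Cpu {
    fn name() -> String {
        "cpu".to_string()
    }
    // ad_enabled() is inherited from the default.
}

// Hypothetical decorator adding autodiff on top of an inner backend.
struct WithAutodiff<B>(PhantomData<B>);

impl<B: MiniBackend> MiniBackend for WithAutodiff<B> {
    fn name() -> String {
        format!("autodiff<{}>", B::name())
    }

    fn ad_enabled() -> bool {
        true
    }
}

// Generic code can branch on the flag without knowing the concrete backend.
fn describe<B: MiniBackend>() -> String {
    format!("{} (autodiff: {})", B::name(), B::ad_enabled())
}

fn main() {
    println!("{}", describe::<Cpu>()); // cpu (autodiff: false)
    println!("{}", describe::<WithAutodiff<Cpu>>()); // autodiff<cpu> (autodiff: true)
}
```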