Trait burn::prelude::Backend

pub trait Backend:
    Sized
    + FloatTensorOps<Self>
    + BoolTensorOps<Self>
    + IntTensorOps<Self>
    + ModuleOps<Self>
    + ActivationOps<Self>
    + QTensorOps<Self>
    + Clone
    + Default
    + Send
    + Sync
    + Debug
    + 'static {
    type Device: DeviceOps;
    type FullPrecisionBridge: BackendBridge<Self> + 'static;
    type FloatTensorPrimitive: Clone + Send + Sync + 'static + Debug;
    type FloatElem: Element;
    type IntTensorPrimitive: Clone + Send + Sync + 'static + Debug;
    type IntElem: Element;
    type BoolTensorPrimitive: Clone + Send + Sync + 'static + Debug;
    type QuantizedTensorPrimitive: QTensorPrimitive + Clone + Send + Sync + 'static + Debug;
    type QuantizedEncoding: Element;

    // Required methods
    fn name() -> String;
    fn seed(seed: u64);

    // Provided methods
    fn ad_enabled() -> bool { ... }
    fn sync(_device: &Self::Device) { ... }
}

This trait defines all types and functions needed for a backend to be used with burn.
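
For orientation, a minimal sketch of how downstream code is typically written against this trait: generic over B: Backend, with a concrete backend chosen only at the call site. It assumes burn's user-facing Tensor API from the prelude; the scale and demo functions are illustrative, not part of burn.

use burn::prelude::*;

// Generic over any backend: the same code runs on NdArray, LibTorch, Candle, etc.
fn scale<B: Backend>(input: Tensor<B, 2>, factor: f64) -> Tensor<B, 2> {
    input * factor
}

fn demo<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
    // Allocate on a backend-specific device, then apply the generic function.
    let input = Tensor::<B, 2>::ones([2, 3], device);
    scale(input, 0.5)
}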

§Design

This trait aims to be as unopinionated as possible and allows implementations to define their own types and patterns. Therefore, there are few pre-defined abstractions baked into this trait.

Backends must define their own tensor types for each data type: float, int, and bool. Since we minimize assumptions, we chose to separate these types, as they are used in different contexts. However, some backends may have a generic tensor type that is used for all data types.
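
As a sketch of that second option, the toy trait below mirrors the idea in miniature; MiniBackend, MyTensor, and MyBackend are illustrative stand-ins, not burn's actual definitions. One generic container fills the separate float, int, and bool slots.

use std::fmt::Debug;
use std::sync::Arc;

// A scaled-down stand-in for the associated-type portion of the Backend trait.
trait MiniBackend: Clone + Send + Sync + Debug + 'static {
    type FloatTensor: Clone + Send + Sync + Debug + 'static;
    type IntTensor: Clone + Send + Sync + Debug + 'static;
    type BoolTensor: Clone + Send + Sync + Debug + 'static;
}

// One generic tensor type used for all data types.
#[derive(Clone, Debug)]
struct MyTensor<E> {
    data: Arc<Vec<E>>, // shared buffer; cloning only bumps a refcount
}

#[derive(Clone, Debug)]
struct MyBackend;

impl MiniBackend for MyBackend {
    // The same container satisfies every slot, specialized by element type.
    type FloatTensor = MyTensor<f32>;
    type IntTensor = MyTensor<i64>;
    type BoolTensor = MyTensor<bool>;
}

fn main() {
    let t = MyTensor::<f32> { data: Arc::new(vec![1.0, 2.0]) };
    let t2 = t.clone();
    assert_eq!(t2.data.len(), 2);
}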

§Eager Mode

Because burn supports dynamic graphs, the backend trait is designed around kernel implementations that can be called without any mutable context or graph. This may not be ideal for backends that want to configure their computational graphs and execute them multiple times.

To implement this kind of backend, channels could be used to communicate with a backend server thread that builds the computation graphs and, with some form of cache, re-executes the ones that repeat. Once that pattern has matured, a graph-mode backend trait could be extracted from it, allowing other backends of the same kind to be integrated with burn quickly. The same pattern could also be used to create an operation fusion trait, letting backends define what kinds of graph structures can be fused into one operation.
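
A minimal sketch of that channel pattern, with every name illustrative: the eager-style entry points stay free of mutable context by messaging a server thread, which owns the cache of graphs it has already built.

use std::collections::HashMap;
use std::sync::mpsc;
use std::thread;

// Messages sent from kernel entry points to the backend server thread.
enum Msg {
    Run { op: String, reply: mpsc::Sender<u64> },
    Shutdown,
}

fn spawn_server() -> mpsc::Sender<Msg> {
    let (tx, rx) = mpsc::channel::<Msg>();
    thread::spawn(move || {
        // Hypothetical cache of built graphs, keyed by an op description.
        let mut cache: HashMap<String, u64> = HashMap::new();
        let mut next_id = 0u64;
        for msg in rx {
            match msg {
                Msg::Run { op, reply } => {
                    // Reuse the graph if this op pattern was seen before.
                    let id = *cache.entry(op).or_insert_with(|| {
                        next_id += 1;
                        next_id // stands in for building a new graph
                    });
                    let _ = reply.send(id);
                }
                Msg::Shutdown => break,
            }
        }
    });
    tx
}

fn main() {
    let server = spawn_server();
    let (reply_tx, reply_rx) = mpsc::channel();
    server.send(Msg::Run { op: "matmul".into(), reply: reply_tx }).unwrap();
    println!("graph id: {}", reply_rx.recv().unwrap());
    let _ = server.send(Msg::Shutdown);
}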

§Multi-Threaded

Backend tensor types are all Clone + Send + Sync, which allows them to be shared and safely sent between threads. It is recommended to wrap tensors' buffers in Arc, which avoids copying the data when a tensor is cloned. Note that it is still possible to mutate and reuse a tensor's buffer without locking; see the Mutable API section below.
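
A minimal sketch of that recommendation, assuming an Arc-wrapped buffer (Tensor here is illustrative): cloning bumps a reference count instead of copying the data, and the clone can move to another thread.

use std::sync::Arc;
use std::thread;

#[derive(Clone, Debug)]
struct Tensor {
    buffer: Arc<Vec<f32>>, // shared, reference-counted buffer
}

fn main() {
    let t = Tensor { buffer: Arc::new(vec![0.0; 1_000_000]) };
    let t2 = t.clone(); // no buffer copy: both handles share one allocation
    assert!(Arc::ptr_eq(&t.buffer, &t2.buffer));

    // The Arc-backed tensor is Send, so it can move across threads freely.
    thread::spawn(move || println!("len = {}", t2.buffer.len()))
        .join()
        .unwrap();
}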

§Mutable API

There is no mutable or in-place operation API to implement, but that does not mean backends cannot support them. Using try_unwrap and get_mut gives backends access to an owned value or a mutable reference to their tensor buffer data structure when the tensor is not shared. In that case, backends can dispatch to their owned in-place operations for better performance.
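
A minimal sketch of that dispatch, again assuming an Arc-backed buffer; Tensor and add_scalar are illustrative. Arc::get_mut succeeds only when the tensor is not shared, in which case the operation runs in place (Arc::try_unwrap works the same way but yields the owned value).

use std::sync::Arc;

struct Tensor {
    buffer: Arc<Vec<f32>>,
}

fn add_scalar(mut tensor: Tensor, value: f32) -> Tensor {
    if let Some(buffer) = Arc::get_mut(&mut tensor.buffer) {
        // Sole owner: mutate the buffer in place, no new allocation.
        for x in buffer.iter_mut() {
            *x += value;
        }
    } else {
        // Shared: fall back to an allocating path, leaving other clones intact.
        let data: Vec<f32> = tensor.buffer.iter().map(|x| x + value).collect();
        tensor.buffer = Arc::new(data);
    }
    tensor
}

fn main() {
    let t = Tensor { buffer: Arc::new(vec![1.0, 2.0]) };
    let alias = Arc::clone(&t.buffer); // simulate a shared tensor
    let t = add_scalar(t, 1.0); // allocates: buffer is shared with `alias`
    drop(alias);
    let t = add_scalar(t, 1.0); // in place: sole owner now
    assert_eq!(*t.buffer, vec![3.0, 4.0]);
}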

§Documentation

Most of the documentation for each function can be found on the user-facing tensor API struct. For module operations, public functions are often provided so that they can be reused by burn-core modules.

Required Associated Types§

type Device: DeviceOps

Device type.

type FullPrecisionBridge: BackendBridge<Self> + 'static

A bridge that can cast tensors to full precision.

type FloatTensorPrimitive: Clone + Send + Sync + 'static + Debug

Tensor primitive to be used for all float operations.

type FloatElem: Element

Float element type.

type IntTensorPrimitive: Clone + Send + Sync + 'static + Debug

Tensor primitive to be used for all int operations.

type IntElem: Element

Int element type.

type BoolTensorPrimitive: Clone + Send + Sync + 'static + Debug

Tensor primitive to be used for all bool operations.

type QuantizedTensorPrimitive: QTensorPrimitive + Clone + Send + Sync + 'static + Debug

Tensor primitive to be used for all quantized operations.

type QuantizedEncoding: Element

Quantized tensor encoding type.

Required Methods§

fn name() -> String

Name of the backend.

fn seed(seed: u64)

Seed the backend.

Provided Methods§

fn ad_enabled() -> bool

Returns true if autodiff is enabled.

fn sync(_device: &Self::Device)

Sync the backend, ensuring that all computations are finished.

Object Safety§

This trait is not object safe.

Implementations on Foreign Types§

impl<B> Backend for Fusion<B>
where B: FusionBackend,
{
    type Device = <B as Backend>::Device;
    type FullPrecisionBridge = PrecisionBridge<<B as FusionBackend>::FullPrecisionBackend>;
    type FloatTensorPrimitive = FusionTensor<<B as FusionBackend>::FusionRuntime>;
    type FloatElem = <B as Backend>::FloatElem;
    type IntTensorPrimitive = FusionTensor<<B as FusionBackend>::FusionRuntime>;
    type IntElem = <B as Backend>::IntElem;
    type BoolTensorPrimitive = FusionTensor<<B as FusionBackend>::FusionRuntime>;
    type QuantizedTensorPrimitive = QFusionTensor<<B as FusionBackend>::FusionRuntime>;
    type QuantizedEncoding = <B as Backend>::QuantizedEncoding;

    fn name() -> String;
    fn seed(seed: u64);
    fn sync(device: &<Fusion<B> as Backend>::Device);
    fn ad_enabled() -> bool;
}

Implementors§

impl<B, C> Backend for Autodiff<B, C>

impl<E, I, Q> Backend for NdArray<E, I, Q>
where E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement,

impl<E, Q> Backend for LibTorch<E, Q>
where E: TchElement, Q: QuantElement,

impl<F, I> Backend for Candle<F, I>
where F: FloatCandleElement, I: IntCandleElement,

impl<R, F, I> Backend for JitBackend<R, F, I>
where
    R: JitRuntime,
    <R as Runtime>::Server: ComputeServer,
    <R as Runtime>::Device: DeviceOps,
    F: FloatElement,
    I: IntElement,
{
    type Device = <R as Runtime>::Device;
    type FullPrecisionBridge = PrecisionBridge<R, f32, i32>;
    type FloatElem = F;
    type IntElem = I;
    type FloatTensorPrimitive = JitTensor<R, <JitBackend<R, F, I> as Backend>::FloatElem>;
    type IntTensorPrimitive = JitTensor<R, <JitBackend<R, F, I> as Backend>::IntElem>;
    type BoolTensorPrimitive = JitTensor<R, u32>;
    type QuantizedTensorPrimitive = QJitTensor<R, <JitBackend<R, F, I> as Backend>::FloatElem, <JitBackend<R, F, I> as Backend>::IntElem>;
    type QuantizedEncoding = u32;
}