Function burn::backend::autodiff::ops::broadcast_shape

pub fn broadcast_shape<B, const D: usize>(
    grad: <B as Backend>::FloatTensorPrimitive<D>,
    shape: &Shape<D>,
) -> <B as Backend>::FloatTensorPrimitive<D>
where B: Backend,
Expand description

Make sure the grad tensor has the given shape.

If broadcasting happened during the forward pass, the gradients are summed along the broadcasted dimensions so that the result matches the given shape.