Function burn::backend::autodiff::ops::broadcast_shape

pub fn broadcast_shape<B>(
    grad: <B as Backend>::FloatTensorPrimitive,
    shape: &Shape,
) -> <B as Backend>::FloatTensorPrimitive
where B: Backend,

Make sure the grad tensor has the given shape.

If broadcasting happened during the forward pass, the gradient is summed along each broadcasted dimension so that the result matches the given shape.
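The reduction rule can be illustrated without Burn at all. The sketch below is not Burn's implementation and does not use Burn's API. `reduce_grad_to_shape` is a hypothetical helper that works on a flat row-major `&[f32]` buffer, not on a `FloatTensorPrimitive`. It assumes that the gradient and the target shape have the same rank and that every target dimension either equals the gradient's dimension or is 1 (the broadcast case). Each gradient element is accumulated into the target position obtained by collapsing the broadcast coordinates to 0. This is equivalent to summing along those dimensions while keeping them with size 1.

```rust
/// Hypothetical helper (not part of Burn): sums a row-major `grad` buffer of
/// shape `grad_shape` down to `target`. Assumes equal rank and that each target
/// dimension is either equal to the grad dimension or 1 (broadcast in forward).
fn reduce_grad_to_shape(grad: &[f32], grad_shape: &[usize], target: &[usize]) -> Vec<f32> {
    assert_eq!(grad_shape.len(), target.len(), "ranks must match");
    assert_eq!(grad.len(), grad_shape.iter().product::<usize>(), "buffer/shape mismatch");
    for (g, t) in grad_shape.iter().zip(target) {
        assert!(g == t || *t == 1, "dimension {t} cannot broadcast to {g}");
    }

    let mut out = vec![0.0f32; target.iter().product()];
    let rank = grad_shape.len();

    for (flat, &value) in grad.iter().enumerate() {
        // Decode the flat index into coordinates (last dimension fastest),
        // then re-encode them in the target shape. Broadcast dims map to 0,
        // so every element along such a dim lands in the same output slot.
        let mut rem = flat;
        let mut target_flat = 0;
        let mut stride = 1;
        for d in (0..rank).rev() {
            let coord = rem % grad_shape[d];
            rem /= grad_shape[d];
            let t_coord = if target[d] == 1 { 0 } else { coord };
            target_flat += t_coord * stride;
            stride *= target[d];
        }
        out[target_flat] += value;
    }
    out
}

fn main() {
    // Forward pass (assumed): a [1, 3] bias was broadcast against a [2, 3]
    // tensor, so the incoming gradient has shape [2, 3].
    let grad = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let reduced = reduce_grad_to_shape(&grad, &[2, 3], &[1, 3]);
    println!("{reduced:?}"); // [5.0, 7.0, 9.0]
}
```

When no dimension was broadcast, every target dimension equals the gradient's. The mapping is then the identity and the gradient is returned unchanged.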