hard_shrink

Function hard_shrink 

pub fn hard_shrink<const D: usize, B>(
    tensor: Tensor<B, D>,
    lambda: f64,
) -> Tensor<B, D>
where B: Backend,

Applies the HardShrink function element-wise.

hard_shrink(x) = x if x > lambda, x if x < -lambda, 0 otherwise
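A minimal sketch of the formula above, assuming plain `f64` values and a slice standing in for a one-dimensional tensor. The names `hard_shrink_scalar` and `hard_shrink_slice` are hypothetical and are not part of burn's API. This only illustrates the element-wise rule; it is not burn's tensor implementation.

```rust
/// Hard shrink of a single value (hypothetical helper, not burn's API):
/// keep `x` when x > lambda or x < -lambda, otherwise return 0.
fn hard_shrink_scalar(x: f64, lambda: f64) -> f64 {
    if x > lambda || x < -lambda {
        x
    } else {
        0.0
    }
}

/// Applies `hard_shrink_scalar` to every element of a slice, mimicking
/// the element-wise behaviour on a one-dimensional tensor.
fn hard_shrink_slice(values: &[f64], lambda: f64) -> Vec<f64> {
    values.iter().map(|&x| hard_shrink_scalar(x, lambda)).collect()
}
```

With `lambda = 0.5`, the input `[-1.0, -0.5, -0.2, 0.0, 0.3, 0.5, 2.0]` becomes `[-1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0]`: values exactly equal to `±lambda` are zeroed, because both comparisons are strict.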

Arguments

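  • tensor: the input tensor; HardShrink is applied to each of its elements.
  • lambda: the threshold for the HardShrink function. Elements whose absolute value is at most lambda are set to zero. The commonly used default is 0.5, but Rust has no default arguments, so the value must always be passed explicitly.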
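To show how `lambda` changes the result, here is a sketch that writes the operation as "build a 0/1 mask from |x| > lambda, then multiply". This is the style in which the function can be expressed with whole-tensor operations (absolute value, comparison, multiplication). The helper name `hard_shrink_masked` is hypothetical, and the mask formulation is an assumption about one possible implementation, not burn's source.

```rust
/// Mask-based hard shrink on a slice (hypothetical helper, not burn's API).
/// Each element is multiplied by 1.0 when |x| > lambda and by 0.0 otherwise.
/// Note: zeroed negative inputs come out as -0.0, which compares equal to 0.0.
fn hard_shrink_masked(values: &[f64], lambda: f64) -> Vec<f64> {
    values
        .iter()
        .map(|&x| {
            // 0/1 mask playing the role of a boolean tensor comparison
            let keep = if x.abs() > lambda { 1.0 } else { 0.0 };
            x * keep
        })
        .collect()
}
```

For the input `[-1.0, -0.7, 0.2, 0.7, 1.5]`, `lambda = 0.5` keeps everything except `0.2`, while `lambda = 1.0` zeroes all values except `1.5`. A larger `lambda` therefore shrinks more elements to zero. For NaN inputs the mask form yields NaN (NaN times 0.0), whereas a branch-based version returns 0.0, so the two are interchangeable only for non-NaN inputs.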