Skip to content

Commit

Permalink
Impl round to even for candle
Browse files Browse the repository at this point in the history
  • Loading branch information
med1844 committed Oct 17, 2024
1 parent ea38cd0 commit f5cadd6
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion crates/burn-candle/src/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,25 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
}

fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
CandleTensor::new(tensor.tensor.round().unwrap())
let inner = |tensor: FloatTensor<Self>| -> candle_core::Result<FloatTensor<Self>> {
// implements round_to_even for consistent behavior vs libtorch
// https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/runtime/register_ops_utils.h#L65-L67

let floor_a = tensor.tensor.floor()?;
let frac_part = tensor.tensor.sub(&floor_a)?;

let half = (candle_core::Tensor::ones_like(&tensor.tensor)? * 0.5)?;
let mask_half = frac_part.eq(&half)?;
let half_tensor = tensor.tensor.mul(&half)?;
let rounded_half = half_tensor.round()?;
let doubled =
rounded_half.mul(&(candle_core::Tensor::ones_like(&tensor.tensor)? * 2.0)?)?;
let standard_round = tensor.tensor.round()?;
Ok(CandleTensor::new(
mask_half.where_cond(&doubled, &standard_round)?,
))
};
inner(tensor).unwrap()
}

fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
Expand Down

0 comments on commit f5cadd6

Please sign in to comment.