Skip to content

Commit

Permalink
feat: Nx.put_slice, more complete support (#46)
Browse files Browse the repository at this point in the history
  • Loading branch information
grzuy committed Nov 27, 2023
1 parent a006dbf commit c4dd79c
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 48 deletions.
39 changes: 16 additions & 23 deletions lib/candlex/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -408,34 +408,27 @@ defmodule Candlex.Backend do
@impl true
def put_slice(
%T{} = out,
%T{shape: shape} = t,
%T{} = t,
[_ | _] = start_indices,
%T{shape: slice_shape} = slice
) do
[_last_dim | leading_dimensions] = shape |> Tuple.to_list() |> Enum.reverse()

[_last_slice_dim | leading_slice_dimensions] =
slice_shape |> Tuple.to_list() |> Enum.reverse()

[last_start_index | leading_start_indices] = Enum.reverse(start_indices)
ranges =
slice_shape
|> Tuple.to_list()
|> Enum.with_index(fn axis_size, i ->
start_index =
start_indices
|> Enum.at(i)
|> Nx.to_number()

if leading_dimensions != leading_slice_dimensions do
raise "Unsupported put_slice shapes, tensor=#{inspect(shape)} and slice=#{inspect(slice_shape)}. All-but-last dimensions in slice need to be equal to corresponding dimension in tensor."
end
{start_index, start_index + axis_size - 1}
end)

if Enum.all?(leading_start_indices, fn i -> Nx.equal(i, 0) end) do
t
|> from_nx()
|> Native.slice_scatter(
from_nx(slice),
length(start_indices) - 1,
Nx.to_number(last_start_index)
)
|> unwrap!()
|> to_nx(out)
else
raise "put_slice only supports last start index not to be 0 for now"
end
t
|> from_nx()
|> Native.slice_assign(ranges, from_nx(slice))
|> unwrap!()
|> to_nx(out)
end

@impl true
Expand Down
2 changes: 1 addition & 1 deletion lib/candlex/native.ex
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ defmodule Candlex.Native do
def concatenate(_tensors, _axis), do: error()
def conv1d(_tensor, _kernel, _opts), do: error()
def conv2d(_tensor, _kernel, _opts), do: error()
def slice_scatter(_tensor, _src, _dim, _start), do: error()
def slice_assign(_tensor, _ranges, _src), do: error()
def pad_with_zeros(_tensor, _left, _right), do: error()
def clamp(_tensor, _min, _max), do: error()
def reverse(_tensor, _axes), do: error()
Expand Down
2 changes: 1 addition & 1 deletion native/candlex/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ rustler::init! {
tensors::conv1d,
tensors::conv2d,
tensors::permute,
tensors::slice_scatter,
tensors::slice_assign,
tensors::pad_with_zeros,
tensors::dot,
tensors::matmul,
Expand Down
17 changes: 13 additions & 4 deletions native/candlex/src/tensors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,13 +288,22 @@ pub fn reshape(t: ExTensor, shape: Term) -> Result<ExTensor, CandlexError> {
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn slice_scatter(
pub fn slice_assign(
t: ExTensor,
ranges: Vec<(usize, usize)>,
src: ExTensor,
dim: usize,
start: usize,
) -> Result<ExTensor, CandlexError> {
Ok(ExTensor::new(t.slice_scatter(src.deref(), dim, start)?))
use std::ops::Bound;

Ok(ExTensor::new(
t.slice_assign(
&ranges
.iter()
.map(|(start, end)| (Bound::Included(*start), Bound::Included(*end)))
.collect::<Vec<(Bound<usize>, Bound<usize>)>>(),
src.deref(),
)?,
))
}

#[rustler::nif(schedule = "DirtyCpu")]
Expand Down
43 changes: 24 additions & 19 deletions test/candlex_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2031,32 +2031,24 @@ defmodule CandlexTest do
|> assert_equal(t([0, 1, 5, 6, 4]))

t([[1, 2, 3], [4, 5, 6]])
|> Nx.put_slice([0, 0], t([[7, 8, 9], [10, 11, 12]]))
|> Nx.put_slice([0, 1], t([[7, 8], [9, 10]]))
|> assert_equal(
t([
[7, 8, 9],
[10, 11, 12]
[1, 7, 8],
[4, 9, 10]
])
)

t([[1, 2, 3], [4, 5, 6]])
|> Nx.put_slice([0, 1], t([[7, 8], [9, 10]]))
t([[1.0, 2, 3], [4, 5, 6]])
|> Nx.put_slice([t(0), t(1)], t([[10.0, 11.0]]))
|> assert_equal(
t([
[1, 7, 8],
[4, 9, 10]
[1.0, 10.0, 11.0],
[4.0, 5.0, 6.0]
])
)

# t([[1, 2, 3], [4, 5, 6]])
# |> Nx.put_slice([t(0), t(1)], t([[10.0, 11.0]]))
# |> assert_equal(t(
# [
# [1.0, 10.0, 11.0],
# [4.0, 5.0, 6.0]
# ]
# ))

# Start index clipping
# t([[1, 2, 3], [4, 5, 6]])
# |> Nx.put_slice([1, 1], t([[7, 8], [9, 10]]))
# |> assert_equal(t(
Expand All @@ -2066,6 +2058,15 @@ defmodule CandlexTest do
# ]
# ))

t([[1, 2, 3], [4, 5, 6]])
|> Nx.put_slice([0, 0], t([[7, 8, 9], [10, 11, 12]]))
|> assert_equal(
t([
[7, 8, 9],
[10, 11, 12]
])
)

t([
[
[1, 2],
Expand All @@ -2090,9 +2091,13 @@ defmodule CandlexTest do
])
)

# t([[[1, 2], [3, 4]]])
# |> Nx.put_slice([0, 0, 0], t([[[10, 11]]]))
# |> assert_equal(t([[[10, 11], [3, 4]]]))
t([[[1, 2], [3, 4]]])
|> Nx.put_slice([0, 0, 0], t([[[10, 11]]]))
|> assert_equal(t([[[10, 11], [3, 4]]]))

t([[[1, 2], [3, 4]]])
|> Nx.put_slice([0, 1, 0], t([[[10, 11]]]))
|> assert_equal(t([[[1, 2], [10, 11]]]))
end

test "pad" do
Expand Down

0 comments on commit c4dd79c

Please sign in to comment.