
Let DataLoader(..., batchsize=0) produce one batch #145

Draft: wants to merge 1 commit into base: main


mcabbott (Contributor)

Closes #144

Before this change, batchsize=0 is treated like "less than 0". Perhaps we can call that a bug?

julia> DataLoader([1 2 3; 4 5 6]; batchsize=0)
3-element DataLoader(::Matrix{Int64}, batchsize=0)
  with first element:
  2-element Vector{Int64}

julia> collect(ans)
3-element Vector{Vector{Int64}}:
 [1, 4]
 [2, 5]
 [3, 6]

help?> DataLoader
search: DataLoader

  DataLoader(data; [batchsize, buffer, collate, parallel, partial, rng, shuffle])
[...]
    •  batchsize: If less than 0, iterates over individual observations. Otherwise, each
       iteration (except possibly the last) yields a mini-batch containing batchsize
       observations. Default 1.
[...]

julia> BatchView(1:10, batchsize=0) |> collect
ERROR: DivideError: integer division error

After this change, batchsize=0 is converted to numobs(data) on construction:

julia> DataLoader([1 2 3; 4 5 6]; batchsize=0)
1-element DataLoader(::Matrix{Int64}, batchsize=3)
  with first element:
  2×3 Matrix{Int64}

julia> BatchView(1:10, batchsize=0) |> collect  # unchanged
ERROR: DivideError: integer division error

This still needs tests, but the existing test suite passed locally with this change. That's some evidence that the zero case wasn't an intentional feature, in addition to the docstring saying "If less than 0".
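The proposed rule can be sketched in plain Julia. This is a hypothetical helper, not the code in this PR's commit: normalize_batchsize and nbatches are illustrative names, n stands for numobs(data), and the "negative means one observation at a time" branch follows the docstring quoted above. It assumes partial=true, so the last batch may be smaller.

```julia
# Hypothetical sketch of the rule proposed in this PR; not MLUtils source.
# n is the number of observations (what numobs(data) would return).
function normalize_batchsize(batchsize::Integer, n::Integer)
    batchsize == 0 && return n     # proposed: 0 means "one batch containing everything"
    return batchsize               # negative: per-observation; positive: unchanged
end

# Number of iterations, assuming partial=true (the last batch may be smaller).
function nbatches(batchsize::Integer, n::Integer)
    bs = normalize_batchsize(batchsize, n)
    bs < 0 && return n             # iterate over individual observations
    return cld(n, bs)              # ceiling division; throws DivideError if bs == 0
end

nbatches(0, 3)   # → 1, like the 1-element DataLoader shown above
nbatches(-1, 3)  # → 3
nbatches(2, 3)   # → 2
```

Converting at construction time means the rest of the iteration logic only ever sees a positive or negative size, so no separate zero branch is needed downstream.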

codecov-commenter commented Feb 11, 2023

Codecov Report

Merging #145 (fecc37c) into main (ff2fcc1) will increase coverage by 0.32%.
The diff coverage is 100.00%.


@@            Coverage Diff             @@
##             main     #145      +/-   ##
==========================================
+ Coverage   88.28%   88.60%   +0.32%     
==========================================
  Files          15       13       -2     
  Lines         589      588       -1     
==========================================
+ Hits          520      521       +1     
+ Misses         69       67       -2     
Impacted Files              Coverage Δ
src/eachobs.jl              87.35% <100.00%> (+0.14%) ⬆️
src/Datasets/Datasets.jl
src/MLUtils.jl


CarloLucibello (Member)

> Before, treats 0 like "less than 0", perhaps we can call that a bug?

Yes, and it doesn't deserve a major version bump.

CarloLucibello (Member)

Well, actually the previous batchsize=0 behavior is useful: it disables batching and makes the DataLoader equivalent to eachobs, something that cannot be obtained with batchsize=1.
Maybe we can use batchsize=nothing for the maximum batch size?

mcabbott (Contributor, Author)

> equivalent to eachobs, something that cannot be obtained with batchsize=1

Isn't this batchsize=-1? That's what the docs seem to say:

julia> DataLoader(ones(2,3); batchsize=1) |> first
2×1 Matrix{Float64}:
 1.0
 1.0

julia> DataLoader(ones(2,3); batchsize=-1) |> first
2-element Vector{Float64}:
 1.0
 1.0

julia> eachobs(ones(2,3)) |> first
2-element Vector{Float64}:
 1.0
 1.0
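The shape difference above mirrors ordinary Julia indexing along the last (observation) dimension: a range index keeps that dimension, while an integer index drops it. This is offered only as an analogy for what batchsize=1 and batchsize=-1 return, not as a claim about how MLUtils implements getobs internally.

```julia
# Plain Julia indexing along the last (observation) dimension.
x = ones(2, 3)

batch  = x[:, 1:1]   # range index keeps the dimension: 2×1 Matrix{Float64}
single = x[:, 1]     # integer index drops it: 2-element Vector{Float64}

size(batch)   # → (2, 1)
size(single)  # → (2,)
```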

mcabbott (Contributor, Author)

Ideally, perhaps batchsize=0 would make more sense for "individual observations without a batch dimension", and batchsize=-1 for "wrap around to the largest possible batch". Both are special cases. But swapping them now would clearly break documented behaviour.

My first idea in #144 was to use batchsize=Inf for "largest possible". It's a little weird to allow only one Float64 value.

Are there other good tokens? Using a function, as in batchsize=all, reads clearly but is pretty unusual. Base uses dims=: internally in many places to mean "all the dimensions", but not quite like this: here the meaning seems closer to sum(rand(2,3); dims=(1,2))::Matrix than to sum(rand(2,3); dims=:)::Number.
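Each of the tokens floated in this thread (nothing, Inf, :, all) could be resolved by dispatch. A minimal sketch, assuming a hypothetical resolve_batchsize(bs, n) called with n = numobs(data); none of these methods exist in MLUtils, and this is not the API the maintainers settled on.

```julia
# Hypothetical sketch: map each candidate "largest possible batch" token
# to a concrete Int, given n observations. Not part of MLUtils.
resolve_batchsize(bs::Integer, n::Integer) = Int(bs)   # ordinary sizes pass through
resolve_batchsize(::Nothing, n::Integer) = Int(n)      # batchsize=nothing
resolve_batchsize(::Colon, n::Integer) = Int(n)        # batchsize=:
resolve_batchsize(::typeof(all), n::Integer) = Int(n)  # batchsize=all

function resolve_batchsize(bs::AbstractFloat, n::Integer)
    # Only a single float value (+Inf) is accepted, the awkwardness noted above.
    isinf(bs) && bs > 0 && return Int(n)
    throw(ArgumentError("batchsize must be an integer, Inf, nothing, :, or all; got $bs"))
end

resolve_batchsize(nothing, 10)  # → 10
resolve_batchsize(Inf, 10)      # → 10
resolve_batchsize(:, 10)        # → 10
resolve_batchsize(all, 10)      # → 10
resolve_batchsize(4, 10)        # → 4
```

The float method shows why Inf is the odd one out: it needs a runtime check and an error path, whereas nothing, : and all are singleton types that dispatch cleanly.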

Successfully merging this pull request may close these issues: batchsize=Inf or something?