-
Notifications
You must be signed in to change notification settings - Fork 103
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ImageFolder dataloader #600
base: master
Are you sure you want to change the base?
Changes from all commits
328dccb
51b17e6
b1c3df3
b52f7aa
234c0f5
ac747b8
cd06d9f
b6bcc30
d096685
6a0e386
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -111,6 +111,7 @@ library | |
, reflection | ||
, stm | ||
, JuicyPixels | ||
, directory | ||
, vector | ||
, bytestring | ||
, safe-exceptions | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,7 @@ import Data.Kind | |
import Data.Maybe | ||
import Data.Proxy | ||
import Data.Reflection | ||
import Data.Type.Equality ((:~:)(..)) | ||
import Foreign.ForeignPtr | ||
import GHC.Generics (Generic) | ||
import GHC.Natural (Natural) | ||
|
@@ -1693,6 +1694,59 @@ anyDim :: | |
Tensor device 'D.Bool shape' | ||
anyDim input = unsafePerformIO $ ATen.cast3 ATen.Managed.any_tlb input (natValI @dim) (keepOrDropDimVal @keepOrDropDim) | ||
|
||
type family ListMax (list :: [Nat]) :: Nat where | ||
ListMax '[] = 0 | ||
ListMax (x:xs) = Max x (ListMax xs) | ||
|
||
type family Sum (list :: [Nat]) :: Nat where | ||
Sum '[] = 0 | ||
Sum (x:xs) = x + (Sum xs) | ||
|
||
|
||
type CheckPermuteDims (shape :: [Nat]) (permuteDims :: [Nat]) (outputShape :: [Nat]) = ( | ||
KnownShape permuteDims, | ||
KnownShape outputShape, | ||
ListLength shape ~ ListLength permuteDims, | ||
ListLength shape ~ (ListMax permuteDims + 1), | ||
((ListLength shape - 1) * ListLength shape) ~ (Sum permuteDims * 2), | ||
|
||
outputShape ~ PermuteDims shape shape permuteDims 0 | ||
) | ||
|
||
type family PermuteDims (shape :: [Nat]) (shape' :: [Nat]) (permuteDims :: [Nat]) (idx :: Nat) :: [Nat] where | ||
PermuteDims shape shape' '[] idx = shape' | ||
PermuteDims shape shape' (x ': xs) idx = PermuteDims shape (ReplaceDim' idx shape' (Index shape x)) xs (idx + 1) | ||
|
||
|
||
|
||
-- | permute | ||
-- | ||
-- >>> t = ones :: CPUTensor 'D.Float '[1, 2, 3, 4] | ||
-- >>> t' = permute @'[1, 3, 2, 0] t | ||
-- >>> dtype &&& shape $ t' | ||
-- (Float,[2,4,3,1]) | ||
-- | ||
-- >>> t = ones :: CPUTensor 'D.Float '[1, 2] | ||
-- >>> t' = permute @'[1, 0] t | ||
-- >>> dtype &&& shape $ t' | ||
-- (Float,[2,1]) | ||
-- | ||
-- >>> t = ones :: CPUTensor 'D.Float '[6, 1, 5, 2, 4, 3] | ||
-- >>> t' = permute @'[1, 3, 5, 4, 2, 0] t | ||
-- >>> dtype &&& shape $ t' | ||
-- (Float,[1,2,3,4,5,6]) | ||
|
||
|
||
permute :: | ||
forall permuteDims shape device dtype shape'. | ||
( CheckPermuteDims shape permuteDims shape' | ||
) => | ||
Tensor device dtype shape -> | ||
Tensor device dtype shape' -- output | ||
permute t = unsafePerformIO $ ATen.cast2 ATen.Managed.tensor_permute_l t permuteDims' | ||
where | ||
permuteDims' = shapeVal @permuteDims | ||
|
||
-- | dropout | ||
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted? | ||
-- TODO: get rid of IO by exposing the RNG state | ||
|
@@ -4318,6 +4372,62 @@ stack :: | |
Tensor device dtype shape | ||
stack tensors = unsafePerformIO $ ATen.cast2 ATen.Managed.stack_ll tensors (natValI @dim :: Int) | ||
|
||
|
||
-- | Stack' | ||
|
||
type Stack' (dim :: Nat) (preShape :: [Nat]) (shape :: [Nat]) (n :: Nat) = ( | ||
KnownNat dim, | ||
KnownNat n, | ||
All KnownNat preShape, | ||
All KnownNat shape, | ||
ListLength shape ~ (ListLength preShape + 1), | ||
n ~ Index shape dim | ||
) | ||
|
||
-- | stack' - Untyped-esque stack that accepts a list of tensors | ||
-- >>> t = ones :: CPUTensor 'D.Float '[] | ||
-- >>> t' = stack' @0 @'[] @'[1] [t] | ||
-- >>> :type t' | ||
-- t' :: Tensor '( 'D.CPU, 0) 'D.Float '[1] | ||
-- >>> dtype &&& shape &&& (\t'' -> D.asValue (toDynamic t'') :: [Float]) $ t' | ||
-- (Float,([1],[1.0])) | ||
-- >>> t = ones :: CPUTensor 'D.Float '[2,2] | ||
-- >>> t' = stack' @0 @'[2, 2] @'[1, 2, 2] [t] | ||
-- >>> :type t' | ||
-- t' :: Tensor '( 'D.CPU, 0) 'D.Float '[1, 2, 2] | ||
-- >>> dtype &&& shape &&& (\t'' -> D.asValue (toDynamic t'') :: [[[Float]]]) $ t' | ||
-- (Float,([1,2,2],[[[1.0,1.0],[1.0,1.0]]])) | ||
-- >>> t' = stack' @1 @'[2, 2] @'[2, 1, 2] [t] | ||
-- >>> :type t' | ||
-- t' :: Tensor '( 'D.CPU, 0) 'D.Float '[2, 1, 2] | ||
-- >>> dtype &&& shape &&& (\t'' -> D.asValue (toDynamic t'') :: [[[Float]]]) $ t' | ||
-- (Float,([2,1,2],[[[1.0,1.0]],[[1.0,1.0]]])) | ||
-- >>> t' = stack' @2 @'[2, 2] @'[2, 2, 1] [t] | ||
-- >>> :type t' | ||
-- t' :: Tensor '( 'D.CPU, 0) 'D.Float '[2, 2, 1] | ||
-- >>> dtype &&& shape &&& (\t'' -> D.asValue (toDynamic t'') :: [[[Float]]]) $ t' | ||
-- (Float,([2,2,1],[[[1.0],[1.0]],[[1.0],[1.0]]])) | ||
-- >>> t' = stack' @2 @'[2, 2] @'[2, 2, 3] [t, t, t] | ||
-- >>> :type t' | ||
-- t' :: Tensor '( 'D.CPU, 0) 'D.Float '[2, 2, 3] | ||
-- >>> dtype &&& shape &&& (\t'' -> D.asValue (toDynamic t'') :: [[[Float]]]) $ t' | ||
-- (Float,([2,2,3],[[[1.0,1.0,1.0],[1.0,1.0,1.0]],[[1.0,1.0,1.0],[1.0,1.0,1.0]]])) | ||
|
||
stack' :: | ||
forall dim preShape shape dtype device n. | ||
( KnownNat dim, | ||
Stack' dim preShape shape n | ||
) => | ||
-- | input list of tensors | ||
[Tensor device dtype preShape] -> | ||
-- | output | ||
Tensor device dtype shape | ||
stack' tensors = case someNatVal (fromIntegral $ length tensors) of | ||
Just (SomeNat len) -> | ||
case sameNat len (Proxy :: Proxy n) of | ||
Just Refl -> unsafePerformIO $ ATen.cast2 ATen.Managed.stack_ll tensors (natValI @dim :: Int) | ||
Nothing -> error "Shape did not match length of tensor list" | ||
|
||
vecStack :: | ||
forall dim n shape dtype device. | ||
( KnownNat dim, KnownNat n ) => | ||
|
@@ -6012,12 +6122,12 @@ avgPool3d input = | |
-- upsample_linear1d _input _output_size _align_corners = unsafePerformIO $ (ATen.cast3 ATen.Managed.upsample_linear1d_tlb) _input _output_size _align_corners | ||
|
||
type family Upsample2dCheck shape h w where | ||
Upsample2dCheck (b : c : w : h : '[]) h' w' = | ||
Upsample2dCheck (b : c : h : w : '[]) h' w' = | ||
If | ||
(h <=? h') | ||
( If | ||
(w <=? w') | ||
(b : c : w' : h' : '[]) | ||
(b : c : h' : w' : '[]) | ||
(TypeError (Text "Target width must be greater than current width!")) | ||
) | ||
(TypeError (Text "Target height must be greater than current height!")) | ||
|
@@ -6030,14 +6140,14 @@ type Upsample2d shape h w = Upsample2dCheck shape h w | |
-- >>> (dtype &&& shape) $ upsample_bilinear2d @3 @5 False (ones :: CPUTensor 'D.Float '[2,3,2,2]) | ||
-- (Float,[2,3,3,5]) | ||
upsample_bilinear2d :: | ||
forall w h shape dtype device. | ||
forall h w shape dtype device. | ||
(KnownNat h, KnownNat w, All KnownNat shape) => | ||
-- | if True, the corner pixels of the input and output tensors are aligned, and thus preserving the values at those pixels. | ||
Bool -> | ||
Tensor device dtype shape -> | ||
Tensor device dtype (Upsample2d shape h w) | ||
upsample_bilinear2d _align_corners _input = | ||
unsafePerformIO $ (ATen.cast3 ATen.Managed.upsample_bilinear2d_tlb) _input ([w, h] :: [Int]) _align_corners | ||
unsafePerformIO $ (ATen.cast3 ATen.Managed.upsample_bilinear2d_tlb) _input ([h, w] :: [Int]) _align_corners | ||
where | ||
w = natValI @w :: Int | ||
h = natValI @h :: Int | ||
|
@@ -6047,12 +6157,12 @@ upsample_bilinear2d _align_corners _input = | |
-- >>> (dtype &&& shape) $ upsample_bicubic2d @3 @5 False (ones :: CPUTensor 'D.Float '[2,3,2,2]) | ||
-- (Float,[2,3,3,5]) | ||
upsample_bicubic2d :: | ||
forall w h shape dtype device. | ||
forall h w shape dtype device. | ||
(KnownNat h, KnownNat w, All KnownNat shape) => | ||
Bool -> | ||
Tensor device dtype shape -> | ||
Tensor device dtype (Upsample2d shape h w) | ||
upsample_bicubic2d _align_corners _input = unsafePerformIO $ (ATen.cast3 ATen.Managed.upsample_bicubic2d_tlb) _input ([w, h] :: [Int]) _align_corners | ||
upsample_bicubic2d _align_corners _input = unsafePerformIO $ (ATen.cast3 ATen.Managed.upsample_bicubic2d_tlb) _input ([h, w] :: [Int]) _align_corners | ||
where | ||
w = natValI @w :: Int | ||
h = natValI @h :: Int | ||
|
@@ -6063,20 +6173,60 @@ upsample_bicubic2d _align_corners _input = unsafePerformIO $ (ATen.cast3 ATen.Ma | |
-- upsample_nearest1d :: Tensor device dtype shape -> Int -> Tensor device dtype shape | ||
-- upsample_nearest1d _input _output_size = unsafePerformIO $ (ATen.cast2 ATen.Managed.upsample_nearest1d_tl) _input _output_size | ||
|
||
-- | Applies a 2D bicubic upsampling to an input signal composed of several input channels. | ||
-- | Applies a 2D nearest neighbor upsampling to an input signal composed of several input channels. | ||
-- | ||
-- >>> (dtype &&& shape) $ upsample_nearest2d @3 @5 (ones :: CPUTensor 'D.Float '[2,3,2,2]) | ||
-- (Float,[2,3,3,5]) | ||
|
||
upsample_nearest2d :: | ||
forall w h shape dtype device. | ||
forall h w shape dtype device. | ||
(KnownNat h, KnownNat w, All KnownNat shape) => | ||
Tensor device dtype shape -> | ||
Tensor device dtype (Upsample2d shape h w) | ||
upsample_nearest2d _input = unsafePerformIO $ (ATen.cast2 ATen.Managed.upsample_nearest2d_tl) _input ([w, h] :: [Int]) | ||
upsample_nearest2d _input = unsafePerformIO $ (ATen.cast2 ATen.Managed.upsample_nearest2d_tl) _input ([h, w] :: [Int]) | ||
where | ||
w = natValI @w :: Int | ||
h = natValI @h :: Int | ||
|
||
|
||
-- Freeform Resizing | ||
interpolate2d :: | ||
forall newShape mode w h shape dtype device. | ||
(KnownShape shape, KnownNat w, KnownNat h, KnownSymbol mode) => | ||
Bool -> | ||
Tensor device 'D.Float shape -> | ||
Tensor device 'D.Float newShape | ||
interpolate2d alignCorners tensor | ||
| symbolVal (Proxy @mode) == "nearest" = unsafePerformIO $ (ATen.cast2 ATen.Managed.upsample_nearest2d_tl) tensor ([h, w] :: [Int]) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Did you forget to compute […]? Can I also ask you to not use a
can you add a sum type for the interpolation `mode` (e.g. `Nearest | Bilinear`) instead of a `Symbol`? |
||
| symbolVal (Proxy @mode) == "bilinear" = unsafePerformIO $ (ATen.cast3 ATen.Managed.upsample_bilinear2d_tlb) tensor ([h, w] :: [Int]) alignCorners | ||
| otherwise = error "Invalid mode for interpolation" | ||
where | ||
w = natValI @w | ||
h = natValI @h | ||
|
||
type Resize2D (shape :: [Nat]) (newShape :: [Nat]) (n :: Nat) (c :: Nat) (h :: Nat) (w :: Nat) (n0 :: Nat) (c0 :: Nat) (h0 :: Nat) (w0 :: Nat) = | ||
( | ||
KnownShape shape, | ||
KnownShape newShape, | ||
All KnownNat newShape, | ||
4 ~ ListLength shape, | ||
4 ~ ListLength newShape, | ||
(n ': c ': h ': w : '[]) ~ newShape, | ||
(n0 ': c0 ': h0 ': w0 ': '[]) ~ shape, | ||
n ~ n0, | ||
c ~ c0 | ||
) | ||
|
||
|
||
resize2d :: | ||
forall newShape mode shape dtype device n c h w n0 c0 h0 w0. | ||
(Resize2D shape newShape n c h w n0 c0 h0 w0, KnownSymbol mode, KnownDType dtype) => | ||
Bool -> | ||
Tensor device dtype shape -> | ||
Tensor device dtype newShape | ||
resize2d alignCorners = toDType @dtype @'D.Float . interpolate2d @newShape @mode @w @h alignCorners . toDType @'D.Float @dtype | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Similar comments as for `interpolate2d` apply here. |
||
|
||
|
||
-- upsample_nearest3d :: Tensor device dtype shape -> (Int,Int,Int) -> Tensor device dtype shape | ||
-- upsample_nearest3d _input _output_size = unsafePerformIO $ (ATen.cast2 ATen.Managed.upsample_nearest3d_tl) _input _output_size | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe I'm missing something here, but it seems that the output shape is not at all computed from the input shape, is it? So, it's up to the user to know exactly how many elements the list has and to do the shape computation by hand?
I understand that there can be a desire to do this, but it is a very unsafe thing to do. A better way would be to use
`vecStack`
from below. You can convert your list first to a sized vector — which can fail, see https://hackage.haskell.org/package/vector-sized-1.4.4/docs/Data-Vector-Generic-Sized.html#v:fromList — and then use `vecStack`
on it safely.