Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ImageFolder dataloader #600

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
2 changes: 1 addition & 1 deletion examples/alexNet/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import Torch.NN as N
import Torch.Functional as F hiding (take)
import qualified Torch.DType as D
import qualified Torch.Vision as V
import qualified Torch.Typed.Vision as V hiding (getImages')
import qualified Torch.Typed.Vision as V hiding (getImages', hwc2chw, readImageAsRGB8)

data DataSet = DataSet {
images :: [Tensor],
Expand Down
1 change: 1 addition & 0 deletions hasktorch/hasktorch.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ library
, reflection
, stm
, JuicyPixels
, directory
, vector
, bytestring
, safe-exceptions
Expand Down
9 changes: 9 additions & 0 deletions hasktorch/src/Torch/Functional.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2910,6 +2910,15 @@ upsampleNearest2d ::
Tensor
upsampleNearest2d (outputHeight, outputWidth) scales_h scales_w self = unsafePerformIO $ cast4 ATen.upsample_nearest2d_tldd self [outputHeight, outputWidth] scales_h scales_w

-- | Applies a 2D nearest neighbor upsampling to an input signal composed of several input channels.
-- Variant of 'upsampleNearest2d' without explicit scale-factor arguments:
-- only the target size is passed to the underlying ATen kernel.
upsampleNearest2d' ::
  -- | output_size as (height, width)
  (Int, Int) ->
  -- | self — input tensor; NOTE(review): presumably NCHW layout, confirm with callers
  Tensor ->
  Tensor
upsampleNearest2d' (outputHeight, outputWidth) self = unsafePerformIO $ cast2 ATen.upsample_nearest2d_tl self [outputHeight, outputWidth]

-- | Splits the tensor into chunks of given size if possible.
split ::
-- | split-size
Expand Down
15 changes: 15 additions & 0 deletions hasktorch/src/Torch/Typed/Aux.hs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,15 @@ type family Init (xs :: [a]) :: [a] where
Init (x ': '[]) = '[]
Init (x ': xs) = x ': Init xs

-- | Type-level tail of a list; rejects the empty list with a 'TypeError'.
-- NOTE(review): the @'[x]@ equation is subsumed by the @(x ': xs)@ case.
type family Tail (xs :: [a]) :: [a] where
  Tail '[] = TypeError (Text "Tail of empty list.")
  Tail '[x] = '[]
  Tail (x ': xs) = xs

-- | Type-level head of a list; rejects the empty list with a 'TypeError'.
type family Head (xs :: [a]) :: a where
  Head '[] = TypeError (Text "Head of empty list.")
  Head (x ': xs) = x

type family Last (xs :: [a]) :: a where
Last '[] = TypeError (Text "Last of empty list.")
Last (x ': '[]) = x
Expand Down Expand Up @@ -218,6 +227,12 @@ type family ReplaceDim (dim :: Nat) (shape :: [Nat]) (n :: Nat) :: Maybe [Nat] w
ReplaceDim dim (h ': t) n = AppendToMaybe h (ReplaceDim (dim - 1) t n)
ReplaceDim _ _ _ = Nothing

-- | Unwrap the result of a 'ReplaceDim' lookup, turning a failed lookup
-- ('Nothing') into a 'DimOutOfBound' type error.
type family CheckReplace (dim :: Nat) (shape :: [Nat]) (result :: Maybe [Nat]) :: [Nat] where
  CheckReplace dim shape Nothing = DimOutOfBound shape dim
  CheckReplace _ _ (Just result) = result

-- | Like 'ReplaceDim', but raises a type error instead of returning 'Nothing'
-- when @dim@ is out of bounds for @shape@.
type ReplaceDim' dim shape n = CheckReplace dim shape (ReplaceDim dim shape n)

type family If c t e where
If 'True t e = t
If 'False t e = e
Expand Down
16 changes: 16 additions & 0 deletions hasktorch/src/Torch/Typed/Factories.hs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import qualified Torch.Scalar as D
import qualified Torch.Tensor as D
import qualified Torch.TensorFactories as D
import qualified Torch.TensorOptions as D
import qualified Torch.Internal.Managed.TensorFactories as LibTorch
import Torch.Typed.Aux
import Torch.Typed.Tensor
import Prelude hiding (sin)
Expand All @@ -55,6 +56,21 @@ zeros =
. D.withDType (optionsRuntimeDType @shape @dtype @device)
$ D.defaultOpts
)
-- | Returns a tensor of the given shape, dtype, and device with
-- uninitialized contents. Because the element values are whatever happened
-- to be in the allocated memory, the result lives in 'IO' (unlike 'zeros').
empty ::
  forall shape dtype device.
  (TensorOptions shape dtype device) =>
  IO (Tensor device dtype shape)
empty =
  cast2
    LibTorch.empty_lo
    (optionsRuntimeShape @shape @dtype @device)
    ( D.withDevice (optionsRuntimeDevice @shape @dtype @device)
        . D.withDType (optionsRuntimeDType @shape @dtype @device)
        $ D.defaultOpts
    )



full ::
forall shape dtype device a.
Expand Down
168 changes: 159 additions & 9 deletions hasktorch/src/Torch/Typed/Functional.hs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import Data.Kind
import Data.Maybe
import Data.Proxy
import Data.Reflection
import Data.Type.Equality ((:~:)(..))
import Foreign.ForeignPtr
import GHC.Generics (Generic)
import GHC.Natural (Natural)
Expand Down Expand Up @@ -1693,6 +1694,59 @@ anyDim ::
Tensor device 'D.Bool shape'
anyDim input = unsafePerformIO $ ATen.cast3 ATen.Managed.any_tlb input (natValI @dim) (keepOrDropDimVal @keepOrDropDim)

-- | Type-level maximum of a list of naturals; 0 for the empty list.
type family ListMax (list :: [Nat]) :: Nat where
  ListMax '[] = 0
  ListMax (x:xs) = Max x (ListMax xs)

-- | Type-level sum of a list of naturals.
type family Sum (list :: [Nat]) :: Nat where
  Sum '[] = 0
  Sum (x:xs) = x + (Sum xs)


-- | Constraints under which @permuteDims@ is accepted as a reordering of the
-- axes of @shape@, with @outputShape@ the resulting shape.
--
-- The "is a permutation" test only checks that the largest index equals
-- @rank - 1@ and that the indices sum to @0 + 1 + ... + (rank - 1)@.
-- NOTE(review): this admits some non-permutations — e.g. @'[0, 0, 3, 3]@ for
-- rank 4 has the right max and sum — so a uniqueness check would be sounder.
type CheckPermuteDims (shape :: [Nat]) (permuteDims :: [Nat]) (outputShape :: [Nat]) = (
  KnownShape permuteDims,
  KnownShape outputShape,
  ListLength shape ~ ListLength permuteDims,
  ListLength shape ~ (ListMax permuteDims + 1),
  ((ListLength shape - 1) * ListLength shape) ~ (Sum permuteDims * 2),

  outputShape ~ PermuteDims shape shape permuteDims 0
  )

-- | Build the permuted shape: for each position @idx@, write
-- @shape !! (permuteDims !! idx)@ into the accumulator @shape'@.
type family PermuteDims (shape :: [Nat]) (shape' :: [Nat]) (permuteDims :: [Nat]) (idx :: Nat) :: [Nat] where
  PermuteDims shape shape' '[] idx = shape'
  PermuteDims shape shape' (x ': xs) idx = PermuteDims shape (ReplaceDim' idx shape' (Index shape x)) xs (idx + 1)



-- | permute — reorder the dimensions of a tensor according to the
-- type-level index list @permuteDims@ (a reordering of @0 .. rank - 1@).
--
-- >>> t = ones :: CPUTensor 'D.Float '[1, 2, 3, 4]
-- >>> t' = permute @'[1, 3, 2, 0] t
-- >>> dtype &&& shape $ t'
-- (Float,[2,4,3,1])
--
-- >>> t = ones :: CPUTensor 'D.Float '[1, 2]
-- >>> t' = permute @'[1, 0] t
-- >>> dtype &&& shape $ t'
-- (Float,[2,1])
--
-- >>> t = ones :: CPUTensor 'D.Float '[6, 1, 5, 2, 4, 3]
-- >>> t' = permute @'[1, 3, 5, 4, 2, 0] t
-- >>> dtype &&& shape $ t'
-- (Float,[1,2,3,4,5,6])
permute ::
  forall permuteDims shape device dtype shape'.
  ( CheckPermuteDims shape permuteDims shape'
  ) =>
  Tensor device dtype shape ->
  Tensor device dtype shape' -- ^ output
permute t = unsafePerformIO $ ATen.cast2 ATen.Managed.tensor_permute_l t permuteDims'
  where
    -- runtime value of the type-level permutation, passed straight to ATen
    permuteDims' = shapeVal @permuteDims

-- | dropout
-- TODO: probably only defined for floating point tensors, or maybe numeric type is lifted?
-- TODO: get rid of IO by exposing the RNG state
Expand Down Expand Up @@ -4318,6 +4372,62 @@ stack ::
Tensor device dtype shape
stack tensors = unsafePerformIO $ ATen.cast2 ATen.Managed.stack_ll tensors (natValI @dim :: Int)


-- | Constraint for 'stack'': @shape@ is expected to be @preShape@ with a new
-- axis of size @n@ (the runtime list length) inserted at position @dim@.
-- NOTE(review): only the rank relation and @n ~ Index shape dim@ are checked
-- here; the remaining dimensions of @shape@ are not derived from @preShape@,
-- so the caller must supply a consistent @shape@ by hand.
type Stack' (dim :: Nat) (preShape :: [Nat]) (shape :: [Nat]) (n :: Nat) = (
  KnownNat dim,
  KnownNat n,
  All KnownNat preShape,
  All KnownNat shape,
  ListLength shape ~ (ListLength preShape + 1),
  n ~ Index shape dim
  )

-- | stack' - Untyped-esque stack that accepts a list of tensors
-- >>> t = ones :: CPUTensor 'D.Float '[]
-- >>> t' = stack' @0 @'[] @'[1] [t]
-- >>> :type t'
-- t' :: Tensor '( 'D.CPU, 0) 'D.Float '[1]
-- >>> dtype &&& shape &&& (\t'' -> D.asValue (toDynamic t'') :: [Float]) $ t'
-- (Float,([1],[1.0]))
-- >>> t = ones :: CPUTensor 'D.Float '[2,2]
-- >>> t' = stack' @0 @'[2, 2] @'[1, 2, 2] [t]
-- >>> :type t'
-- t' :: Tensor '( 'D.CPU, 0) 'D.Float '[1, 2, 2]
-- >>> dtype &&& shape &&& (\t'' -> D.asValue (toDynamic t'') :: [[[Float]]]) $ t'
-- (Float,([1,2,2],[[[1.0,1.0],[1.0,1.0]]]))
-- >>> t' = stack' @1 @'[2, 2] @'[2, 1, 2] [t]
-- >>> :type t'
-- t' :: Tensor '( 'D.CPU, 0) 'D.Float '[2, 1, 2]
-- >>> dtype &&& shape &&& (\t'' -> D.asValue (toDynamic t'') :: [[[Float]]]) $ t'
-- (Float,([2,1,2],[[[1.0,1.0]],[[1.0,1.0]]]))
-- >>> t' = stack' @2 @'[2, 2] @'[2, 2, 1] [t]
-- >>> :type t'
-- t' :: Tensor '( 'D.CPU, 0) 'D.Float '[2, 2, 1]
-- >>> dtype &&& shape &&& (\t'' -> D.asValue (toDynamic t'') :: [[[Float]]]) $ t'
-- (Float,([2,2,1],[[[1.0],[1.0]],[[1.0],[1.0]]]))
-- >>> t' = stack' @2 @'[2, 2] @'[2, 2, 3] [t, t, t]
-- >>> :type t'
-- t' :: Tensor '( 'D.CPU, 0) 'D.Float '[2, 2, 3]
-- >>> dtype &&& shape &&& (\t'' -> D.asValue (toDynamic t'') :: [[[Float]]]) $ t'
-- (Float,([2,2,3],[[[1.0,1.0,1.0],[1.0,1.0,1.0]],[[1.0,1.0,1.0],[1.0,1.0,1.0]]]))

-- | Stack a runtime list of tensors along a new dimension @dim@.
--
-- NOTE(review): the output @shape@ is not computed from @preShape@; the
-- caller supplies it, and only the list length is verified at runtime
-- (it must equal @Index shape dim@). Prefer 'vecStack' with a sized vector
-- for a fully static guarantee.
stack' ::
  forall dim preShape shape dtype device n.
  ( KnownNat dim,
    Stack' dim preShape shape n
  ) =>
  -- | input list of tensors
  [Tensor device dtype preShape] ->
  -- | output
  Tensor device dtype shape
stack' tensors = case someNatVal (fromIntegral $ length tensors) of
  -- someNatVal only fails on negative inputs; a list length never is.
  Nothing -> error "stack': impossible, list length is non-negative"
  Just (SomeNat len) ->
    case sameNat len (Proxy :: Proxy n) of
      Just Refl -> unsafePerformIO $ ATen.cast2 ATen.Managed.stack_ll tensors (natValI @dim :: Int)
      Nothing ->
        error $
          "stack': expected "
            <> show (natValI @n :: Int)
            <> " tensors (the size of dimension "
            <> show (natValI @dim :: Int)
            <> " in the requested shape), but got "
            <> show (length tensors)

vecStack ::
forall dim n shape dtype device.
( KnownNat dim, KnownNat n ) =>
Expand Down Expand Up @@ -6012,12 +6122,12 @@ avgPool3d input =
-- upsample_linear1d _input _output_size _align_corners = unsafePerformIO $ (ATen.cast3 ATen.Managed.upsample_linear1d_tlb) _input _output_size _align_corners

type family Upsample2dCheck shape h w where
Upsample2dCheck (b : c : w : h : '[]) h' w' =
Upsample2dCheck (b : c : h : w : '[]) h' w' =
If
(h <=? h')
( If
(w <=? w')
(b : c : w' : h' : '[])
(b : c : h' : w' : '[])
(TypeError (Text "Target width must be greater than current width!"))
)
(TypeError (Text "Target height must be greater than current height!"))
Expand All @@ -6030,14 +6140,14 @@ type Upsample2d shape h w = Upsample2dCheck shape h w
-- >>> (dtype &&& shape) $ upsample_bilinear2d @3 @5 False (ones :: CPUTensor 'D.Float '[2,3,2,2])
-- (Float,[2,3,3,5])
-- | Applies a 2D bilinear upsampling to an input signal composed of several
-- input channels. Target spatial size is given as type-level @h@ and @w@;
-- 'Upsample2d' checks that the target is no smaller than the input.
upsample_bilinear2d ::
  forall h w shape dtype device.
  (KnownNat h, KnownNat w, All KnownNat shape) =>
  -- | if True, the corner pixels of the input and output tensors are aligned, and thus preserving the values at those pixels.
  Bool ->
  Tensor device dtype shape ->
  Tensor device dtype (Upsample2d shape h w)
upsample_bilinear2d _align_corners _input =
  -- ATen takes the output size as [height, width]
  unsafePerformIO $ (ATen.cast3 ATen.Managed.upsample_bilinear2d_tlb) _input ([h, w] :: [Int]) _align_corners
  where
    w = natValI @w :: Int
    h = natValI @h :: Int
Expand All @@ -6047,12 +6157,12 @@ upsample_bilinear2d _align_corners _input =
-- >>> (dtype &&& shape) $ upsample_bicubic2d @3 @5 False (ones :: CPUTensor 'D.Float '[2,3,2,2])
-- (Float,[2,3,3,5])
upsample_bicubic2d ::
  forall h w shape dtype device.
  (KnownNat h, KnownNat w, All KnownNat shape) =>
  -- | align_corners — if True, the corner pixels of input and output are aligned
  Bool ->
  Tensor device dtype shape ->
  Tensor device dtype (Upsample2d shape h w)
-- ATen takes the output size as [height, width]
upsample_bicubic2d _align_corners _input = unsafePerformIO $ (ATen.cast3 ATen.Managed.upsample_bicubic2d_tlb) _input ([h, w] :: [Int]) _align_corners
  where
    w = natValI @w :: Int
    h = natValI @h :: Int
Expand All @@ -6063,20 +6173,60 @@ upsample_bicubic2d _align_corners _input = unsafePerformIO $ (ATen.cast3 ATen.Ma
-- upsample_nearest1d :: Tensor device dtype shape -> Int -> Tensor device dtype shape
-- upsample_nearest1d _input _output_size = unsafePerformIO $ (ATen.cast2 ATen.Managed.upsample_nearest1d_tl) _input _output_size

-- | Applies a 2D nearest neighbor upsampling to an input signal composed of several input channels.
--
-- >>> (dtype &&& shape) $ upsample_nearest2d @3 @5 (ones :: CPUTensor 'D.Float '[2,3,2,2])
-- (Float,[2,3,3,5])
upsample_nearest2d ::
  forall h w shape dtype device.
  (KnownNat h, KnownNat w, All KnownNat shape) =>
  Tensor device dtype shape ->
  Tensor device dtype (Upsample2d shape h w)
-- ATen takes the output size as [height, width]
upsample_nearest2d _input = unsafePerformIO $ (ATen.cast2 ATen.Managed.upsample_nearest2d_tl) _input ([h, w] :: [Int])
  where
    w = natValI @w :: Int
    h = natValI @h :: Int


-- Freeform resizing.

-- | Resize a 'D.Float tensor to target spatial size @(h, w)@ using the
-- interpolation algorithm named by the type-level symbol @mode@
-- (\"nearest\" or \"bilinear\"); any other mode is a runtime 'error'.
--
-- NOTE(review): @newShape@ is not computed from @shape@, @h@, and @w@ —
-- the caller is trusted to supply a consistent target shape (see
-- 'resize2d', which constrains it). A sum type for the mode (and for
-- align-corners) would be safer than a runtime 'Symbol' comparison.
interpolate2d ::
  forall newShape mode w h shape dtype device.
  (KnownShape shape, KnownNat w, KnownNat h, KnownSymbol mode) =>
  -- | align_corners (only meaningful for \"bilinear\")
  Bool ->
  Tensor device 'D.Float shape ->
  Tensor device 'D.Float newShape
interpolate2d alignCorners tensor
  | symbolVal (Proxy @mode) == "nearest" = unsafePerformIO $ (ATen.cast2 ATen.Managed.upsample_nearest2d_tl) tensor ([h, w] :: [Int])
  | symbolVal (Proxy @mode) == "bilinear" = unsafePerformIO $ (ATen.cast3 ATen.Managed.upsample_bilinear2d_tlb) tensor ([h, w] :: [Int]) alignCorners
  | otherwise = error "Invalid mode for interpolation"
  where
    w = natValI @w
    h = natValI @h

-- | Constraint for 'resize2d': both @shape@ and @newShape@ are rank-4
-- (NCHW-style), and resizing may change only the spatial dimensions —
-- batch (@n ~ n0@) and channel (@c ~ c0@) counts are preserved.
type Resize2D (shape :: [Nat]) (newShape :: [Nat]) (n :: Nat) (c :: Nat) (h :: Nat) (w :: Nat) (n0 :: Nat) (c0 :: Nat) (h0 :: Nat) (w0 :: Nat) =
  (
    KnownShape shape,
    KnownShape newShape,
    All KnownNat newShape,
    4 ~ ListLength shape,
    4 ~ ListLength newShape,
    (n ': c ': h ': w : '[]) ~ newShape,
    (n0 ': c0 ': h0 ': w0 ': '[]) ~ shape,
    n ~ n0,
    c ~ c0
  )


-- | Resize a rank-4 (NCHW) tensor to @newShape@ via 'interpolate2d',
-- round-tripping through 'D.Float because the interpolation kernels operate
-- on floats.
-- NOTE(review): for non-float dtypes the float round-trip truncates values
-- on the way back — confirm this is intended.
resize2d ::
  forall newShape mode shape dtype device n c h w n0 c0 h0 w0.
  (Resize2D shape newShape n c h w n0 c0 h0 w0, KnownSymbol mode, KnownDType dtype) =>
  -- | align_corners, forwarded to 'interpolate2d'
  Bool ->
  Tensor device dtype shape ->
  Tensor device dtype newShape
resize2d alignCorners = toDType @dtype @'D.Float . interpolate2d @newShape @mode @w @h alignCorners . toDType @'D.Float @dtype
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

similar comments as for interpolate2d. I think that you may have intended interpolate2d as an internal method? If so, please add it in a where clause within the scope of resize2d.



-- upsample_nearest3d :: Tensor device dtype shape -> (Int,Int,Int) -> Tensor device dtype shape
-- upsample_nearest3d _input _output_size = unsafePerformIO $ (ATen.cast2 ATen.Managed.upsample_nearest3d_tl) _input _output_size

Expand Down
20 changes: 20 additions & 0 deletions hasktorch/src/Torch/Typed/Tensor.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@ import Data.Proxy
import Data.Reflection
import Data.Vector.Sized (Vector)
import qualified Data.Vector.Sized as V
import Foreign.C.Types (CBool (..))
import Foreign.ForeignPtr
import Foreign.Storable
import GHC.Exts
import GHC.Generics
import GHC.TypeLits
import System.IO.Unsafe (unsafePerformIO)
import qualified Torch.DType as D
import qualified Torch.Device as D
import qualified Torch.Functional as D hiding (select)
Expand All @@ -50,6 +52,7 @@ import Torch.Internal.Class
CppTuple4 (..),
)
import qualified Torch.Internal.Type as ATen
import qualified Torch.Internal.Unmanaged.Type.Tensor as ATen (tensor_is_contiguous, tensor_contiguous, tensor_data_ptr)
import qualified Torch.Tensor as D
import qualified Torch.TensorFactories as D
import Torch.Typed.Aux
Expand Down Expand Up @@ -305,6 +308,23 @@ withTensor untypedTensor f = case someShape (D.shape untypedTensor) of
(SomeDType (Proxy :: Proxy dtype)) -> case someDevice (D.device untypedTensor) of
(SomeDevice (Proxy :: Proxy device)) -> f $ UnsafeMkTensor @device @dtype @shape untypedTensor

-- | Whether the tensor's storage is contiguous, as reported by the
-- underlying ATen predicate. Reading this flag has no side effects, so the
-- 'unsafePerformIO' is benign.
isContiguous ::
  forall shape dtype device.
  Tensor device dtype shape ->
  Bool
isContiguous t = unsafePerformIO $ cast1 (\x -> cast (x :: D.ATenTensor) ATen.tensor_is_contiguous) t

-- | Return a contiguous tensor with the same data.
-- NOTE(review): presumably ATen's @contiguous()@ returns the original tensor
-- when it is already contiguous — confirm.
contiguous ::
  forall shape dtype device.
  Tensor device dtype shape ->
  Tensor device dtype shape
contiguous t = unsafePerformIO $ do
  -- Single IO block instead of an unsafePerformIO nested inside the
  -- expression already under unsafePerformIO.
  t' <- cast1 (\x -> cast (x :: D.ATenTensor) ATen.tensor_contiguous) t
  uncast (t' :: D.ATenTensor) return

-- | Run an action with a raw pointer to the tensor's underlying data.
-- The tensor is made contiguous first so the buffer is a dense block; the
-- pointer is only valid for the duration of @fn@.
-- Fixed: the type previously read @Tensor shape dtype device@, swapping the
-- file-wide @Tensor device dtype shape@ argument-naming convention.
withTensorPtr :: forall shape dtype device a. Tensor device dtype shape -> (Ptr () -> IO a) -> IO a
withTensorPtr t fn =
  let tensor = if isContiguous t then t else contiguous t
   in cast tensor $ \t' -> withForeignPtr t' $ \tensor_ptr -> ATen.tensor_data_ptr tensor_ptr >>= fn

--------------------------------------------------------------------------------
-- Broadcast type-level function
--------------------------------------------------------------------------------
Expand Down