Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a Monad instance for Tensor #515

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions hasktorch/hasktorch.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ library
, Torch.Optim
, Torch.Optim.CppOptim
, Torch.Vision
, Torch.Monad
, Torch.Typed.Monad
, Torch.NN
, Torch.NN.Recurrent.Cell.Elman
, Torch.NN.Recurrent.Cell.GRU
Expand Down
223 changes: 223 additions & 0 deletions hasktorch/src/Torch/Monad.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE IncoherentInstances #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.Monad where

import qualified Torch.Tensor as T
import Control.Monad
import Control.Applicative
import GHC.TypeLits
import Data.Proxy
import Data.Coerce
import Data.Finite
import Data.Kind
import qualified Data.Vector.Sized as V
import Data.Functor.Compose
import Data.Singletons.Prelude (Reverse)

-- Value-level wrappers that tag one nesting level of a tensor's payload.
-- The first four wrap plain lists, so they are zero-cost via 'coerce';
-- their names presumably denote axes (batch / channel / x / y) — named
-- only, nothing here enforces that interpretation.

-- | Presumably a batch axis (used as the outermost wrapper in the examples below).
newtype Batch a = Batch [a]

-- | Presumably a channel axis; NOTE(review): the type-level 'ch' is never
-- checked against the list length anywhere in this module — confirm intent.
newtype Channel (ch::Nat) a = Channel [a]

-- | Presumably a spatial X axis.
newtype DX a = DX [a]

-- | Presumably a spatial Y axis.
newtype DY a = DY [a]

-- | Device tag for CPU-resident values (only referenced from commented-out code below).
newtype CPU a = CPU a

-- | Device tag for CUDA-resident values (only referenced from commented-out code below).
newtype CUDA a = CUDA a

-- | A deeply-embedded tensor computation in free-monad style:
--
-- * 'Prim' embeds an existing raw 'T.Tensor' (the 'T.TensorLike'
--   constraint lets the interpreter read it back with 'T.asValue');
-- * 'Return' lifts a pure value;
-- * 'Bind' sequences two computations.
--
-- The structure is interpreted by 'toTensor'.
data Tensor a where
  Prim :: (T.TensorLike a) => T.Tensor -> Tensor a
  Return :: a -> Tensor a
  Bind :: Tensor a -> (a -> Tensor b) -> Tensor b

-- | Marshalling a 'Batch' to/from a raw tensor is exactly marshalling the
-- wrapped list; the wrapper only labels the axis.  Newtype pattern matches
-- are irrefutable, so unwrapping here is as lazy as the 'coerce' it replaces.
instance {-# OVERLAPPING #-} (T.TensorLike a) => T.TensorLike (Batch a) where
  asTensor' (Batch xs) opt = T.asTensor' xs opt
  asTensor (Batch xs) = T.asTensor xs
  _asValue t = Batch (T._asValue t)
  _dtype = T._dtype @[a]
  _dims (Batch xs) = T._dims xs
  _deepDims (Batch xs) = T._deepDims xs
  _peekElemOff ptr offset dims = Batch <$> T._peekElemOff ptr offset dims
  _pokeElemOff ptr offset (Batch xs) = T._pokeElemOff ptr offset xs

-- | Marshalling a 'Channel' to/from a raw tensor is exactly marshalling the
-- wrapped list.  NOTE(review): the type-level 'ch' is not validated against
-- the runtime list length — confirm whether that is intended.
instance {-# OVERLAPPING #-} (T.TensorLike a) => T.TensorLike (Channel ch a) where
  asTensor' (Channel xs) opt = T.asTensor' xs opt
  asTensor (Channel xs) = T.asTensor xs
  _asValue t = Channel (T._asValue t)
  _dtype = T._dtype @[a]
  _dims (Channel xs) = T._dims xs
  _deepDims (Channel xs) = T._deepDims xs
  _peekElemOff ptr offset dims = Channel <$> T._peekElemOff ptr offset dims
  _pokeElemOff ptr offset (Channel xs) = T._pokeElemOff ptr offset xs

-- | Marshalling a 'DX' to/from a raw tensor is exactly marshalling the
-- wrapped list; the wrapper only labels the axis.
instance {-# OVERLAPPING #-} (T.TensorLike a) => T.TensorLike (DX a) where
  asTensor' (DX xs) opt = T.asTensor' xs opt
  asTensor (DX xs) = T.asTensor xs
  _asValue t = DX (T._asValue t)
  _dtype = T._dtype @[a]
  _dims (DX xs) = T._dims xs
  _deepDims (DX xs) = T._deepDims xs
  _peekElemOff ptr offset dims = DX <$> T._peekElemOff ptr offset dims
  _pokeElemOff ptr offset (DX xs) = T._pokeElemOff ptr offset xs

-- | Marshalling a 'DY' to/from a raw tensor is exactly marshalling the
-- wrapped list; the wrapper only labels the axis.
instance {-# OVERLAPPING #-} (T.TensorLike a) => T.TensorLike (DY a) where
  asTensor' (DY xs) opt = T.asTensor' xs opt
  asTensor (DY xs) = T.asTensor xs
  _asValue t = DY (T._asValue t)
  _dtype = T._dtype @[a]
  _dims (DY xs) = T._dims xs
  _deepDims (DY xs) = T._deepDims xs
  _peekElemOff ptr offset dims = DY <$> T._peekElemOff ptr offset dims
  _pokeElemOff ptr offset (DY xs) = T._pokeElemOff ptr offset xs

-- | Partial instance: only the two directions that make sense for an
-- embedded computation are supported — rendering it ('asTensor' via
-- 'toTensor') and embedding a raw tensor ('_asValue' via 'Prim').
-- All other methods deliberately 'error' out; calling them at runtime
-- crashes with the message below.
instance (T.TensorLike a) => T.TensorLike (Tensor a) where
  asTensor' = error "Not implemented for Tensor-a-type"
  asTensor = toTensor
  _asValue = Prim
  _dtype = error "Not implemented for Tensor-a-type"
  _dims v = error "Not implemented for Tensor-a-type"
  _deepDims v = error "Not implemented for Tensor-a-type"
  _peekElemOff = error "Not implemented for Tensor-a-type"
  _pokeElemOff = error "Not implemented for Tensor-a-type"

{-
instance (T.TensorLike a) => T.TensorLike (CPU a) where
asTensor' = error "Not implemented for Tensor-a-type"
asTensor = T.toCPU . toTensor
_asValue = Prim . T.toCPU
_dtype = error "Not implemented for Tensor-a-type"
_dims v = error "Not implemented for Tensor-a-type"
_deepDims v = error "Not implemented for Tensor-a-type"
_peekElemOff = error "Not implemented for Tensor-a-type"
_pokeElemOff = error "Not implemented for Tensor-a-type"

instance (T.TensorLike a) => T.TensorLike (CUDA a) where
asTensor' = error "Not implemented for Tensor-a-type"
asTensor = T.toCUDA . toTensor
_asValue = Prim
_dtype = error "Not implemented for Tensor-a-type"
_dims v = error "Not implemented for Tensor-a-type"
_deepDims v = error "Not implemented for Tensor-a-type"
_peekElemOff = error "Not implemented for Tensor-a-type"
_pokeElemOff = error "Not implemented for Tensor-a-type"
-}

-- | Interpret an embedded 'Tensor' computation down to a raw 'T.Tensor'.
--
-- A 'Bind' whose left side is a 'Prim' round-trips the raw tensor through
-- 'T.asValue' to feed the continuation; a left-nested 'Bind' is first
-- re-associated (monad associativity) so the recursion always shrinks the
-- left-hand side.
toTensor :: (T.TensorLike a) => Tensor a -> T.Tensor
toTensor t =
  case t of
    Prim s -> s
    Return a -> T.asTensor a
    Bind (Prim s) f -> toTensor (f (T.asValue s))
    Bind (Return a) f -> toTensor (f a)
    Bind (Bind ma f) g -> toTensor (Bind ma (\x -> Bind (f x) g))

-- (!!) :: Int -> Tensor [a] -> Tensor a
-- (!!) n tensor = return $ (toTensor tensor) T.! n

-- | Free-monad instances for the embedded 'Tensor' computation.
--
-- 'pure' is defined canonically on 'Applicative' (with 'return' falling
-- back to its default of 'pure'), instead of the previous
-- @pure = return@ / @return = Return@ pair, which is flagged by
-- -Wnoncanonical-monad-instances and conflicts with the
-- monad-of-no-return migration.  Behaviour is unchanged.
instance Functor Tensor where
  -- Equivalent to the previous 'liftM', expressed on the free structure.
  fmap f t = Bind t (Return . f)

instance Applicative Tensor where
  pure = Return
  (<*>) = ap

instance Monad Tensor where
  (>>=) = Bind

-- | Run an embedded 'Tensor' computation with 'toTensor' and read the
-- resulting raw tensor back into a Haskell value.
asValue :: (T.TensorLike a) => Tensor a -> a
asValue t = T.asValue (toTensor t)

-- | 'Batch' is the list monad under a newtype: 'pure' is a singleton batch
-- and '>>=' concatenates the batches produced for each element.
--
-- 'pure' is defined canonically on 'Applicative' (previously
-- @pure = return@ with 'return' on 'Monad', which is flagged by
-- -Wnoncanonical-monad-instances).  Behaviour is unchanged.
instance Functor Batch where
  fmap f (Batch xs) = Batch (map f xs)

instance Applicative Batch where
  pure v = Batch [v]
  (<*>) = ap

instance Monad Batch where
  Batch xs >>= f = Batch (concatMap (\x -> let Batch ys = f x in ys) xs)

-- | 'Channel' is the list monad under a newtype (the type-level 'ch' is
-- carried along unchanged and not validated here).
--
-- 'pure' is defined canonically on 'Applicative' (previously
-- @pure = return@ with 'return' on 'Monad', which is flagged by
-- -Wnoncanonical-monad-instances).  Behaviour is unchanged.
instance Functor (Channel ch) where
  fmap f (Channel xs) = Channel (map f xs)

instance Applicative (Channel ch) where
  pure v = Channel [v]
  (<*>) = ap

instance Monad (Channel ch) where
  Channel xs >>= f = Channel (concatMap (\x -> let Channel ys = f x in ys) xs)

-- | 'DX' is the list monad under a newtype.
--
-- 'pure' is defined canonically on 'Applicative' (previously
-- @pure = return@ with 'return' on 'Monad', which is flagged by
-- -Wnoncanonical-monad-instances).  Behaviour is unchanged.
instance Functor DX where
  fmap f (DX xs) = DX (map f xs)

instance Applicative DX where
  pure v = DX [v]
  (<*>) = ap

instance Monad DX where
  DX xs >>= f = DX (concatMap (\x -> let DX ys = f x in ys) xs)

-- | 'DY' is the list monad under a newtype.
--
-- 'pure' is defined canonically on 'Applicative' (previously
-- @pure = return@ with 'return' on 'Monad', which is flagged by
-- -Wnoncanonical-monad-instances).  Behaviour is unchanged.
instance Functor DY where
  fmap f (DY xs) = DY (map f xs)

instance Applicative DY where
  pure v = DY [v]
  (<*>) = ap

instance Monad DY where
  DY xs >>= f = DY (concatMap (\x -> let DY ys = f x in ys) xs)

-- concat :: [Tensor a] -> Tensor [a]
--



-- | Example: a plain list of floats lifted into the embedded monad.
foo :: Tensor [Float]
foo = return [1,2,3]

-- | 'foo' rendered to a raw tensor via the interpreter.
foo' :: T.Tensor
foo' = toTensor foo

-- | Example: one extra nesting level — a singleton 'Batch' of the list.
bfoo :: Tensor (Batch [Float])
bfoo = return $ return [1,2,3]

-- | 'bfoo' rendered to a raw tensor.
bfoo' :: T.Tensor
bfoo' = toTensor bfoo

-- | Example: batch and channel wrappers stacked (each 'return' adds one
-- singleton nesting level).
bcfoo :: Tensor (Batch (Channel 1 [Float]))
bcfoo = return $ return $ return [1,2,3]

-- | 'bcfoo' rendered to a raw tensor.
bcfoo' :: T.Tensor
bcfoo' = toTensor bcfoo

-- | Example: the full stack — batch, channel, Y and X wrappers.
bcxyfoo :: Tensor (Batch (Channel 1 (DY (DX [Float]))))
bcxyfoo = return $ return $ return $ return $ return [1,2,3]

-- | 'bcxyfoo' rendered to a raw tensor.
bcxyfoo' :: T.Tensor
bcxyfoo' = toTensor bcxyfoo
43 changes: 43 additions & 0 deletions hasktorch/src/Torch/Tensor.hs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ import Data.List (intercalate)
import Data.Proxy
import Data.Reflection
import qualified Data.Vector as V
import qualified Data.Vector.Sized as S
import Data.Word (Word8)
import GHC.TypeLits
import Data.Finite
import Foreign.C.Types
import Foreign.ForeignPtr
import Foreign.Ptr
Expand All @@ -46,6 +49,7 @@ import qualified Torch.Internal.Managed.Type.TensorOptions as ATen
import qualified Torch.Internal.Type as ATen
import qualified Torch.Internal.Unmanaged.Type.Tensor as Unmanaged (tensor_data_ptr)
import Torch.TensorOptions
import qualified Data.Vector.Sized as SV

type ATenTensor = ForeignPtr ATen.Tensor

Expand Down Expand Up @@ -595,6 +599,45 @@ instance {-# OVERLAPPING #-} TensorLike a => TensorLike [a] where
then (_pokeElemOff @a) ptr (offset + i * width) d
else throwIO $ userError $ "There are lists having different length."

-- | Statically-sized vectors marshal like lists: the outermost dimension is
-- the vector's length, inner dimensions come from the (first) element.
instance {-# OVERLAPPING #-} (KnownNat n, TensorLike a) => TensorLike (SV.Vector n a) where
  asTensor' v opts = unsafePerformIO $ do
    -- Allocate an uninitialised tensor with the inferred shape and dtype,
    -- then fill its storage element by element.
    t <- ((cast2 LibTorch.empty_lo) :: [Int] -> TensorOptions -> IO Tensor) (_dims v) $ withDType (_dtype @a) opts
    withTensor t $ \ptr -> do
      _pokeElemOff ptr 0 v
    return t

  asTensor v = asTensor' v defaultOpts

  _asValue t = unsafePerformIO $ do
    if _dtype @a == dtype t
      then do
        withTensor t $ \ptr -> do
          _peekElemOff ptr 0 (shape t)
      else throwIO $ userError $ "The inferred DType of asValue is " ++ show (_dtype @a) ++ ", but the DType of tensor on memory is " ++ show (dtype t) ++ "."

  _dtype = _dtype @a

  -- NOTE(review): 'SV.index v 0' requires n >= 1; a zero-length vector
  -- fails here at runtime — TODO confirm n ~ 0 cannot reach this path.
  _dims v = (SV.length v) : (_dims (SV.index v 0))
  _deepDims v = Just $ _dims v

  _peekElemOff ptr offset [] = throwIO $ userError $ "Sized vector's size is zero."
  _peekElemOff ptr offset (d : dims) =
    -- 'width' is the flat element count of one sub-tensor; element i of the
    -- outer dimension starts at offset + i * width.  'SV.fromList' yields
    -- 'Nothing' when the runtime count d disagrees with the type-level n.
    let width = product dims
     in do
          v <- fmap SV.fromList $ forM [0 .. (d - 1)] $ \i ->
            _peekElemOff ptr (offset + i * width) dims
          case v of
            Nothing -> throwIO $ userError $ "Sized vector is not correct."
            Just v' -> return v'

  -- BUGFIX: the stride (and the size check) must use the flat size of ONE
  -- element, not 'product (_dims v)', which includes the outer length n.
  -- With the old width = n * elemWidth, the validation
  -- 'product (_dims d) == width' failed for every vector with n > 1 and the
  -- writes were strided n times too far apart.
  _pokeElemOff ptr offset v =
    case SV.toList v of
      [] -> return ()
      xs@(x : _) ->
        let width = product (_dims x)
         in forM_ (zip [0 ..] xs) $ \(i, d) ->
              if product (_dims d) == width -- This validation may be slow.
                then (_pokeElemOff @a) ptr (offset + i * width) d
                else throwIO $ userError $ "There are lists having different length."

class AsTensors as where
toTensors :: as -> V.Vector Tensor
default toTensors :: (Generic as, GAsTensors (Rep as)) => as -> V.Vector Tensor
Expand Down