Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add a example of withTensor, but it does not work. #161

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
46 changes: 35 additions & 11 deletions hasktorch/src/Torch/Static.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import Data.Finite
import Data.Kind (Constraint)
import Data.Reflection
import GHC.TypeLits
import Data.Int
import Data.Word

import qualified Torch.Tensor as D
import qualified Torch.TensorFactories as D
Expand Down Expand Up @@ -94,19 +96,41 @@ someShape (h : t) = case someNatVal (fromIntegral h) of
(SomeShape (Proxy :: Proxy tt)) -> SomeShape $ Proxy @(ht ': tt)

data SomeDType where
SomeDType :: forall (dtype :: DType.DType). Proxy dtype -> SomeDType
SomeDType :: forall dtype. (Reifies dtype DType.DType) => Proxy dtype -> SomeDType

someDType :: DType.DType -> SomeDType
someDType DType.Float = SomeDType $ Proxy @DType.Float

withTensor :: D.Tensor ->
(forall (dtype :: DType.DType) (shape :: [Nat]).
KnownShape shape => Tensor dtype shape -> r) ->
r

withTensor d f = case someShape (D.shape d) of
(SomeShape (Proxy :: Proxy shape)) -> case someDType (D.dtype d) of
(SomeDType (Proxy :: Proxy dtype)) -> f $ UnsafeMkTensor @dtype @shape d
someDType DType.UInt8 = SomeDType $ Proxy @Word8
someDType DType.Int8 = SomeDType $ Proxy @Int8
someDType DType.Int16 = SomeDType $ Proxy @Int16
someDType DType.Int32 = SomeDType $ Proxy @Int32
someDType DType.Int64 = SomeDType $ Proxy @Int64
someDType DType.Half = undefined
someDType DType.Float = SomeDType $ Proxy @Float
someDType DType.Double = SomeDType $ Proxy @Double


withTensor
:: forall r. D.Tensor
-> (forall dtype (shape :: [Nat]). Tensor dtype shape -> r)
-> r
withTensor d f =
case (someShape (D.shape d),
someDType (D.dtype d)) of
(SomeShape (Proxy :: Proxy shape),
SomeDType (Proxy :: Proxy dtype)) ->
f $ UnsafeMkTensor d

mm' :: D.Tensor -> D.Tensor -> D.Tensor
mm' t0 t1 =
withTensor t0 $ \t0' ->
withTensor t1 $ \t1' ->
toDynamic $ mm t0' t1'

add' :: D.Tensor -> D.Tensor -> D.Tensor
add' t0 t1 =
withTensor t0 $ \t0' ->
withTensor t1 $ \t1' ->
toDynamic $ t0' + t1'

--------------------------------------------------------------------------------
-- Broadcast type-level function
Expand Down