Skip to content

Commit

Permalink
Add a example of withTensor, but it does not work.
Browse files Browse the repository at this point in the history
  • Loading branch information
junjihashimoto committed Aug 30, 2019
1 parent 2f2a5e9 commit eaebc3d
Showing 1 changed file with 35 additions and 11 deletions.
46 changes: 35 additions & 11 deletions hasktorch/src/Torch/Static.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import Data.Finite
import Data.Kind (Constraint)
import Data.Reflection
import GHC.TypeLits
import Data.Int
import Data.Word

import qualified Torch.Tensor as D
import qualified Torch.TensorFactories as D
Expand Down Expand Up @@ -94,19 +96,41 @@ someShape (h : t) = case someNatVal (fromIntegral h) of
(SomeShape (Proxy :: Proxy tt)) -> SomeShape $ Proxy @(ht ': tt)

data SomeDType where
SomeDType :: forall (dtype :: DType.DType). Proxy dtype -> SomeDType
SomeDType :: forall dtype. (Reifies dtype DType.DType) => Proxy dtype -> SomeDType

someDType :: DType.DType -> SomeDType
someDType DType.Float = SomeDType $ Proxy @DType.Float

withTensor :: D.Tensor ->
(forall (dtype :: DType.DType) (shape :: [Nat]).
KnownShape shape => Tensor dtype shape -> r) ->
r

withTensor d f = case someShape (D.shape d) of
(SomeShape (Proxy :: Proxy shape)) -> case someDType (D.dtype d) of
(SomeDType (Proxy :: Proxy dtype)) -> f $ UnsafeMkTensor @dtype @shape d
someDType DType.UInt8 = SomeDType $ Proxy @Word8
someDType DType.Int8 = SomeDType $ Proxy @Int8
someDType DType.Int16 = SomeDType $ Proxy @Int16
someDType DType.Int32 = SomeDType $ Proxy @Int32
someDType DType.Int64 = SomeDType $ Proxy @Int64
someDType DType.Half = undefined
someDType DType.Float = SomeDType $ Proxy @Float
someDType DType.Double = SomeDType $ Proxy @Double


withTensor
:: forall r. D.Tensor
-> (forall dtype (shape :: [Nat]). Tensor dtype shape -> r)
-> r
withTensor d f =
case (someShape (D.shape d),
someDType (D.dtype d)) of
(SomeShape (Proxy :: Proxy shape),
SomeDType (Proxy :: Proxy dtype)) ->
f $ UnsafeMkTensor d

mm' :: D.Tensor -> D.Tensor -> D.Tensor
mm' t0 t1 =
withTensor t0 $ \t0' ->
withTensor t1 $ \t1' ->
toDynamic $ mm t0' t1'

add' :: D.Tensor -> D.Tensor -> D.Tensor
add' t0 t1 =
withTensor t0 $ \t0' ->
withTensor t1 $ \t1' ->
toDynamic $ t0' + t1'

--------------------------------------------------------------------------------
-- Broadcast type-level function
Expand Down

0 comments on commit eaebc3d

Please sign in to comment.