Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lenses for indexing/slicing #613

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from
Draft
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ test-suite spec
, tokenizers
, vector
, vector-sized
, template-haskell

build-tool-depends: hspec-discover:hspec-discover

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ import Torch.GraduallyTyped.NN.Type (HasDropout, SHasDropout)
import Torch.GraduallyTyped.Prelude (Catch, forgetIsChecked, pattern (:|:))
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient (..), SGradient (..))
import Torch.GraduallyTyped.Shape.Type (By (..), Dim (..), Name (..), SBy (..), SDim, SSelectDim (..), SelectDim (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.Indexing (IndexDims, IndexType (..), Indices (..), SIndexType (..), SIndices (..), (!))
import Torch.GraduallyTyped.Tensor.Indexing (IndexShape, IndexType (..), Indices (..), SIndexType (..), SIndices (..), (!))
import Torch.GraduallyTyped.Tensor.IndexingSlicingJoining (GatherDimF, SqueezeDimF, UnsqueezeF, sGatherDim, sSqueezeDim, sUnsqueeze)
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (mulScalar)
import Torch.GraduallyTyped.Tensor.MathOperations.Reduction (MeanAllCheckF, meanAll)
Expand Down Expand Up @@ -846,7 +846,7 @@ instance
targetLayout
targetDevice
targetDataType
(IndexDims ('Indices '[ 'SliceAll, 'SliceUpTo ('NegativeIndex 1)]) targetShape),
(IndexShape ('Indices '[ 'SliceAll, 'SliceUpTo ('NegativeIndex 1)]) targetShape),
decoderOutput
~ Tensor
doGradient
Expand Down
167 changes: 141 additions & 26 deletions experimental/gradually-typed/src/Torch/GraduallyTyped/Tensor/Indexing.hs
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveLift #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
Expand All @@ -22,14 +26,20 @@ module Torch.GraduallyTyped.Tensor.Indexing
SIndexType (..),
Indices (..),
SIndices (..),
IndexDims,
IndexShape,
(!),
slice,
parseSlice,
setAt,
setAtLike,
toLens,
toLensLike,
)
where

import Control.Arrow ((>>>))
import Control.Monad (forM_, void, (<=<))
import Control.Lens (Traversal)
import Control.Monad (forM_, join, void, (<=<))
import Control.Monad.Catch (MonadThrow)
import Control.Monad.Trans (lift)
import Data.Coerce (coerce)
Expand All @@ -40,20 +50,24 @@ import Data.Singletons.Prelude (Reverse, SBool (..), SList (..), Sing)
import Data.Singletons.TH (genSingletons)
import Data.Type.Equality (type (==))
import Data.Void (Void)
import Foreign (fromBool)
import Foreign (ForeignPtr, fromBool)
import GHC.TypeLits (Div, ErrorMessage (..), Nat, Symbol, type (+), type (-), type (<=?))
import Language.Haskell.TH.Quote (QuasiQuoter (..))
import qualified Language.Haskell.TH as TH
import Text.Megaparsec as M
import qualified Language.Haskell.TH.Syntax as TH
import Text.Megaparsec (ParsecT, between, empty, eof, errorBundlePretty, optional, runParserT, sepBy, some, try, (<|>))
import qualified Text.Megaparsec.Char as M
import qualified Text.Megaparsec.Char.Lexer as L
import Torch.GraduallyTyped.DType (DataType (..))
import Torch.GraduallyTyped.Index.Type (DemotedIndex (..), Index (..), SIndex (..))
import Torch.GraduallyTyped.Prelude (If, IsChecked (..), forgetIsChecked, type (<?))
import Torch.GraduallyTyped.Shape.Class (PrependDimF)
import Torch.GraduallyTyped.Prelude (Catch, If, IsChecked (..), forgetIsChecked, type (<?))
import Torch.GraduallyTyped.Shape.Class (BroadcastShapesF, PrependDimF)
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.Type (Tensor (..))
import Torch.GraduallyTyped.Tensor.Type (SGetShape, Tensor (..), TensorLike, fromTensor, toTensor)
import Torch.GraduallyTyped.Unify (type (<+>))
import Torch.Internal.GC (unsafeThrowableIO)
import qualified Torch.Internal.Managed.Native as ATen
import qualified Torch.Internal.Managed.Type.TensorIndex as ATen
import qualified Torch.Internal.Type as ATen
import Type.Errors.Pretty (TypeError, type (%), type (<>))

data IndexType a
Expand All @@ -69,10 +83,12 @@ data IndexType a
| SliceFromWithStep a a
| SliceUpToWithStep a a
| SliceFromUpToWithStep a a a
deriving (Show, Eq, Functor)
deriving stock (Show, Eq, Functor, TH.Lift)

genSingletons [''IndexType]

deriving stock instance Show (SIndexType (indexType :: IndexType (Index Nat)))

type ReverseShape :: Shape [Dim (Name Symbol) (Size Nat)] -> Shape [Dim (Name Symbol) (Size Nat)]
type family ReverseShape shape where
ReverseShape 'UncheckedShape = 'UncheckedShape
Expand Down Expand Up @@ -173,7 +189,10 @@ type family CheckStep (step :: Index Nat) ok where
CheckStep ('Index 0) _ = TypeError StepZeroErrorMessage
CheckStep _ ok = ok

type IndexDimsImpl :: [IndexType (Index Nat)] -> [Dim (Name Symbol) (Size Nat)] -> Shape [Dim (Name Symbol) (Size Nat)]
type IndexDimsImpl ::
[IndexType (Index Nat)] ->
[Dim (Name Symbol) (Size Nat)] ->
Shape [Dim (Name Symbol) (Size Nat)]
type family IndexDimsImpl indices dims where
IndexDimsImpl '[] dims = 'Shape dims
IndexDimsImpl ('NewAxis ': ixs) dims = 'Dim ('Name "*") ('Size 1) `PrependDimF` IndexDimsImpl ixs dims
Expand Down Expand Up @@ -213,19 +232,26 @@ type family IndexDimsImpl indices dims where
IndexDimsImpl ('SliceFromUpToWithStep _ _ step ': ixs) ('Dim name _ ': dims) =
CheckStep step ('Dim name 'UncheckedSize `PrependDimF` IndexDimsImpl ixs dims)

type family IndexDims indices shape where
IndexDims 'UncheckedIndices _ = 'UncheckedShape
IndexDims _ 'UncheckedShape = 'UncheckedShape
IndexDims ('Indices indices) ('Shape dims) = IndexDimsImpl indices dims
type IndexShape ::
Indices [IndexType (Index Nat)] ->
Shape [Dim (Name Symbol) (Size Nat)] ->
Shape [Dim (Name Symbol) (Size Nat)]
type family IndexShape indices shape where
IndexShape 'UncheckedIndices _ = 'UncheckedShape
IndexShape _ 'UncheckedShape = 'UncheckedShape
IndexShape ('Indices indices) ('Shape dims) = IndexDimsImpl indices dims

data Indices (indexTypes :: Type) where
UncheckedIndices :: forall indexTypes. Indices indexTypes
Indices :: forall indexTypes. indexTypes -> Indices indexTypes
deriving (Show)

data SIndices (indices :: Indices [IndexType (Index Nat)]) where
SUncheckedIndices :: [IndexType Integer] -> SIndices 'UncheckedIndices
SIndices :: forall indexTypes. SList indexTypes -> SIndices ('Indices indexTypes)

deriving stock instance Show (SIndices (indices :: Indices [IndexType (Index Nat)]))

type instance Sing = SIndices

instance SingI indexTypes => SingI ('Indices (indexTypes :: [IndexType (Index Nat)])) where
Expand All @@ -238,19 +264,13 @@ instance SingKind (Indices [IndexType (Index Nat)]) where
toSing (Unchecked indexTypes) = SomeSing . SUncheckedIndices $ fmap forgetIsChecked <$> indexTypes
toSing (Checked indexTypes) = withSomeSing ((fmap . fmap . fmap) DemotedIndex indexTypes) $ SomeSing . SIndices

(!) ::
forall indices requiresGradient layout device dataType shape m.
MonadThrow m =>
Tensor requiresGradient layout device dataType shape ->
SIndices indices ->
m (Tensor requiresGradient layout device dataType (IndexDims indices shape))
(UnsafeTensor t) ! sIndices = unsafeThrowableIO $ do
toTensorIndexList :: [IndexType Integer] -> IO (ForeignPtr (ATen.StdVector ATen.TensorIndex))
toTensorIndexList indices = do
indexList <- ATen.newTensorIndexList
tensorIndices <- traverse toTensorIndex indices
forM_ tensorIndices $ ATen.tensorIndexList_push_back indexList
UnsafeTensor <$> ATen.index t indexList
pure indexList
where
indices = fmap forgetIsChecked <$> forgetIsChecked (fromSing sIndices)
toTensorIndex =
fmap fromIntegral >>> \case
NewAxis -> ATen.newTensorIndexWithNone
Expand All @@ -266,6 +286,101 @@ instance SingKind (Indices [IndexType (Index Nat)]) where
SliceUpToWithStep upTo step -> ATen.newTensorIndexWithSlice 0 upTo step
SliceFromUpToWithStep from upTo step -> ATen.newTensorIndexWithSlice from upTo step

-- | Indexes/slices a tensor.
-- >>> g <- sMkGenerator (SDevice SCPU) 0
-- >>> sRandn' = sRandn . TensorSpec (SGradient SWithGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat)
-- >>> t = sRandn' (SShape $ SName @"*" :&: SSize @3 :|: SName @"*" :&: SSize @5 :|: SNil) g
-- >>> result <- t ! [slice|:2, 3|]
(!) ::
forall indices gradient layout device dataType shape m.
MonadThrow m =>
Tensor gradient layout device dataType shape ->
SIndices indices ->
m (Tensor gradient layout device dataType (IndexShape indices shape))
(UnsafeTensor t) ! sIndices = unsafeThrowableIO $ do
indexList <- toTensorIndexList indices
UnsafeTensor <$> ATen.index t indexList
where
indices = fmap forgetIsChecked <$> forgetIsChecked (fromSing sIndices)

setAt ::
forall gradient layout device dataType shape shape' indices m.
( shape' ~ BroadcastShapesF shape' (IndexShape indices shape),
SingI gradient,
SingI layout,
SingI device,
MonadThrow m
) =>
Tensor gradient layout device dataType shape ->
SIndices indices ->
Tensor gradient layout device dataType shape' ->
m (Tensor gradient layout device dataType shape)
setAt (UnsafeTensor t') sIndices (UnsafeTensor x) = unsafeThrowableIO $ do
t <- ATen.clone_t t'
indexList <- toTensorIndexList indices
UnsafeTensor <$> ATen.index_put_ t indexList x
where
indices = fmap forgetIsChecked <$> forgetIsChecked (fromSing sIndices)

setAtLike ::
forall gradient layout device dataType shape indices a dType dims m.
( TensorLike a dType dims,
dataType ~ 'DataType dType,
Catch ('Shape dims <+> BroadcastShapesF ('Shape dims) (IndexShape indices shape)),
SingI gradient,
SingI layout,
SingI device,
MonadThrow m
) =>
Tensor gradient layout device dataType shape ->
SIndices indices ->
a ->
m (Tensor gradient layout device dataType shape)
setAtLike (UnsafeTensor t') sIndices x' = unsafeThrowableIO $ do
t <- ATen.clone_t t'
indexList <- toTensorIndexList indices
UnsafeTensor x <- toTensor @gradient @layout @device x'
UnsafeTensor <$> ATen.index_put_ t indexList x
where
indices = fmap forgetIsChecked <$> forgetIsChecked (fromSing sIndices)

toLens ::
forall gradient layout device dataType shape s a indices m.
( s ~ Tensor gradient layout device dataType shape,
a ~ Tensor gradient layout device dataType (IndexShape indices shape),
SingI gradient,
SingI layout,
SingI device,
MonadThrow m,
Traversable m
) =>
SIndices indices ->
Traversal s (m s) a a
toLens sIndices f s =
let fmms = sequenceA $ fmap (setAt s sIndices) . f <$> s ! sIndices
in join <$> fmms

toLensLike ::
forall s gradient layout device dataType shape indices a dType dims shape' m.
( s ~ Tensor gradient layout device dataType shape,
TensorLike a dType dims,
dataType ~ 'DataType dType,
shape' ~ 'Shape dims,
shape' ~ IndexShape indices shape,
Catch ('Shape dims <+> BroadcastShapesF ('Shape dims) (IndexShape indices shape)),
SGetShape shape',
SingI gradient,
SingI layout,
SingI device,
MonadThrow m,
Traversable m
) =>
SIndices indices ->
Traversal s (m s) a a
toLensLike sIndices f s =
let fmms = sequenceA $ fmap (setAtLike s sIndices) . f . fromTensor <$> s ! sIndices
in join <$> fmms

type Parser = ParsecT Void String TH.Q

sc :: Parser ()
Expand All @@ -281,7 +396,7 @@ string :: String -> Parser String
string = lexeme . M.string

parseSlice :: String -> TH.Q TH.Exp
parseSlice = either (fail . errorBundlePretty) pure <=< M.runParserT indicesP ""
parseSlice = either (fail . errorBundlePretty) pure <=< runParserT indicesP ""
where
indicesP :: Parser TH.Exp
indicesP = do
Expand All @@ -308,7 +423,7 @@ parseSlice = either (fail . errorBundlePretty) pure <=< M.runParserT indicesP ""
[ do
index <- L.signed sc $ lexeme L.decimal
let con = if index < 0 then [|SNegativeIndex|] else [|SIndex|]
nat = TH.litT $ TH.numTyLit $ abs index
nat = pure $ TH.LitT $ TH.NumTyLit $ abs index
lift [|$con @($nat)|],
TH.VarE . TH.mkName <$> lexeme (between (char '{') (char '}') (some M.alphaNumChar))
]
Expand Down Expand Up @@ -369,4 +484,4 @@ slice =
quoteDec = notHandled
}
where
notHandled = const . fail $ "'slice' quasiquoter can only be used as an expression."
notHandled = const $ fail "'slice' quasiquoter can only be used as an expression."