Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forward unification 1: add output type family, HasForward instances #477

Draft
wants to merge 8 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
284 changes: 229 additions & 55 deletions hasktorch/src/Torch/NN.hs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
Expand All @@ -16,14 +19,18 @@ module Torch.NN where
import Control.Applicative (Applicative (liftA2))
import Control.Monad.State.Strict
import Data.Kind
import Data.Proxy
import Data.Type.Bool
import GHC.Generics
import GHC.TypeLits
import System.IO.Unsafe (unsafePerformIO)
import Torch.Autograd
import Torch.Functional
import Torch.Initializers
import Torch.Internal.Cast (cast3)
import qualified Torch.Internal.Managed.Native as ATen
import qualified Torch.Internal.Managed.Type.Tensor as ATen
import Torch.Random (Generator)
import Torch.Tensor
import Torch.TensorFactories (ones', randIO', randnIO')

Expand All @@ -38,73 +45,240 @@ nextParameter = do
[] -> error "Not enough parameters supplied to replaceParameters"
(p : t) -> do put t; return p

class HasForward f a b | f a -> b where
forward :: f -> a -> b
default forward ::
( Generic f,
Generic a,
Generic b,
GHasForward (Rep f) (Rep a) (Rep b)
) =>
f ->
a ->
b
forward f a = to $ gForward (from f) (from a)
forwardStoch :: f -> a -> IO b
default forwardStoch ::
( Generic f,
Generic a,
Generic b,
GHasForward (Rep f) (Rep a) (Rep b)
) =>
f ->
a ->
IO b
forwardStoch f a = to <$> gForwardStoch (from f) (from a)

class GHasForward (f :: Type -> Type) (a :: Type -> Type) (b :: Type -> Type) | f a -> b where
gForward :: forall c c' c''. f c -> a c' -> b c''
gForwardStoch :: forall c c' c''. f c -> a c' -> IO (b c)

instance GHasForward U1 U1 U1 where
gForward U1 U1 = U1
gForwardStoch U1 U1 = return U1
class HasForward model input where
type Output model input :: Type
forward :: model -> input -> Output model input

data ModelRandomness = Deterministic | Stochastic

type family Contains (f :: k) (a :: Type) :: Bool where
Contains a a = 'True
Contains (f g) a = Contains f a || Contains g a
Contains _ _ = 'False

type family ModelRandomnessR (output :: Type) :: (ModelRandomness, Type) where
ModelRandomnessR (Generator -> (output, Generator)) =
If
(Contains output Generator)
(TypeError (Text "The random generator appears in a wrong position in the output type."))
'( 'Stochastic, output)
ModelRandomnessR output =
If
(Contains output Generator)
(TypeError (Text "The random generator appears in a wrong position in the output type."))
'( 'Deterministic, output)

class
HasForwardProduct
(modelARandomness :: ModelRandomness)
modelA
inputA
outputA
(modelBRandomness :: ModelRandomness)
modelB
inputB
outputB
where
type OutputProduct modelARandomness modelA inputA outputA modelBRandomness modelB inputB outputB :: Type
forwardProduct ::
Proxy modelARandomness ->
modelA ->
inputA ->
Proxy outputA ->
Proxy modelBRandomness ->
modelB ->
inputB ->
Proxy outputB ->
OutputProduct modelARandomness modelA inputA outputA modelBRandomness modelB inputB outputB

class
HasForwardSum
(modelARandomness :: ModelRandomness)
modelA
inputA
outputA
(modelBRandomness :: ModelRandomness)
modelB
inputB
outputB
where
type OutputSum modelARandomness modelA inputA outputA modelBRandomness modelB inputB outputB :: Type
forwardSum ::
Proxy modelARandomness ->
Proxy modelBRandomness ->
Either modelA modelB ->
Either inputA inputB ->
Proxy (Either outputA outputB) ->
OutputSum modelARandomness modelA inputA outputA modelBRandomness modelB inputB outputB

instance
( HasForward modelA inputA,
Output modelA inputA ~ outputA,
HasForward modelB inputB,
Output modelB inputB ~ outputB
) =>
HasForwardProduct 'Deterministic modelA inputA outputA 'Deterministic modelB inputB outputB
where
type OutputProduct 'Deterministic modelA inputA outputA 'Deterministic modelB inputB outputB = (outputA, outputB)
forwardProduct _ modelA inputA _ _ modelB inputB _ = (forward modelA inputA, forward modelB inputB)

instance
( HasForward modelA inputA,
Output modelA inputA ~ outputA,
HasForward modelB inputB,
Output modelB inputB ~ outputB
) =>
HasForwardSum 'Deterministic modelA inputA outputA 'Deterministic modelB inputB outputB
where
type OutputSum 'Deterministic modelA inputA outputA 'Deterministic modelB inputB outputB = Maybe (Either outputA outputB)
forwardSum _ _ (Left modelA) (Left inputA) _ = Just . Left $ forward modelA inputA
forwardSum _ _ (Right modelB) (Right inputB) _ = Just . Right $ forward modelB inputB
forwardSum _ _ _ _ _ = Nothing

instance
( HasForward modelA inputA,
Output modelA inputA ~ (Generator -> (outputA, Generator)),
HasForward modelB inputB,
Output modelB inputB ~ outputB
) =>
HasForwardProduct 'Stochastic modelA inputA outputA 'Deterministic modelB inputB outputB
where
type OutputProduct 'Stochastic modelA inputA outputA 'Deterministic modelB inputB outputB = Generator -> ((outputA, outputB), Generator)
forwardProduct _ modelA inputA _ _ modelB inputB _ = \g -> let (outputA, g') = forward modelA inputA g in ((outputA, forward modelB inputB), g')

instance
( HasForward modelA inputA,
Output modelA inputA ~ outputA,
HasForward modelB inputB,
Output modelB inputB ~ (Generator -> (outputB, Generator))
) =>
HasForwardProduct 'Deterministic modelA inputA outputA 'Stochastic modelB inputB outputB
where
type OutputProduct 'Deterministic modelA inputA outputA 'Stochastic modelB inputB outputB = Generator -> ((outputA, outputB), Generator)
forwardProduct _ modelA inputA _ _ modelB inputB _ = \g -> let (outputB, g') = forward modelB inputB g in ((forward modelA inputA, outputB), g')

instance
( GHasForward f a b,
GHasForward g a' b',
b'' ~ (b :+: b')
( HasForward modelA inputA,
Output modelA inputA ~ (Generator -> (outputA, Generator)),
HasForward modelB inputB,
Output modelB inputB ~ outputB
) =>
GHasForward (f :+: g) (a :+: a') b''
HasForwardSum 'Stochastic modelA inputA outputA 'Deterministic modelB inputB outputB
where
gForward (L1 f) (L1 a) = L1 $ gForward f a
gForward (R1 g) (R1 a') = R1 $ gForward g a'
gForwardStoch (L1 f) (L1 a) = L1 <$> gForwardStoch f a
gForwardStoch (R1 g) (R1 a') = R1 <$> gForwardStoch g a'
type OutputSum 'Stochastic modelA inputA outputA 'Deterministic modelB inputB outputB = Generator -> (Maybe (Either outputA outputB), Generator)
forwardSum _ _ (Left modelA) (Left inputA) _ = \g -> let (outputA, g') = forward modelA inputA g in (Just $ Left outputA, g')
forwardSum _ _ (Right modelB) (Right inputB) _ = \g -> (Just . Right $ forward modelB inputB, g)
forwardSum _ _ _ _ _ = \g -> (Nothing, g)

instance
( GHasForward f a b,
GHasForward g a' b',
b'' ~ (b :*: b')
( HasForward modelA inputA,
Output modelA inputA ~ outputA,
HasForward modelB inputB,
Output modelB inputB ~ (Generator -> (outputB, Generator))
) =>
GHasForward (f :*: g) (a :*: a') b''
HasForwardSum 'Deterministic modelA inputA outputA 'Stochastic modelB inputB outputB
where
gForward (f :*: g) (a :*: a') = gForward f a :*: gForward g a'
gForwardStoch (f :*: g) (a :*: a') = liftA2 (:*:) (gForwardStoch f a) (gForwardStoch g a')
type OutputSum 'Deterministic modelA inputA outputA 'Stochastic modelB inputB outputB = Generator -> (Maybe (Either outputA outputB), Generator)
forwardSum _ _ (Left modelA) (Left inputA) _ = \g -> (Just . Left $ forward modelA inputA, g)
forwardSum _ _ (Right modelB) (Right inputB) _ = \g -> let (outputA, g') = forward modelB inputB g in (Just $ Right outputA, g')
forwardSum _ _ _ _ _ = \g -> (Nothing, g)

--
-- Fully-stochastic instances
--

instance
( HasForward modelA inputA,
Output modelA inputA ~ (Generator -> (outputA, Generator)),
HasForward modelB inputB,
Output modelB inputB ~ (Generator -> (outputB, Generator))
) =>
HasForwardProduct 'Stochastic modelA inputA outputA 'Stochastic modelB inputB outputB
where
type OutputProduct 'Stochastic modelA inputA outputA 'Stochastic modelB inputB outputB = Generator -> ((outputA, outputB), Generator)
forwardProduct _ modelA inputA _ _ modelB inputB _ = runState $ do
outputA <- state (forward modelA inputA)
outputB <- state (forward modelB inputB)
return (outputA, outputB)

instance
( HasForward modelA inputA,
Output modelA inputA ~ (Generator -> (outputA, Generator)),
HasForward modelB inputB,
Output modelB inputB ~ (Generator -> (outputB, Generator))
) =>
HasForwardSum 'Stochastic modelA inputA outputA 'Stochastic modelB inputB outputB
where
type OutputSum 'Stochastic modelA inputA outputA 'Stochastic modelB inputB outputB = Generator -> (Maybe (Either outputA outputB), Generator)
forwardSum _ _ (Left modelA) (Left inputA) _ = \g -> let (outputA, g') = forward modelA inputA g in (Just $ Left outputA, g')
forwardSum _ _ (Right modelB) (Right inputB) _ = \g -> let (outputA, g') = forward modelB inputB g in (Just $ Right outputA, g')
forwardSum _ _ _ _ _ = \g -> (Nothing, g)

-- TODO: move to Torch.Typed.Prelude?
type family Fst (t :: (k, k')) :: k where
Fst '(x, _) = x

type family Snd (t :: (k, k')) :: k' where
Snd '(_, y) = y

instance
(HasForward f a b) =>
GHasForward (K1 i f) (K1 i a) (K1 i b)
( '(modelARandomness, outputA) ~ ModelRandomnessR (Output modelA inputA),
'(modelBRandomness, outputB) ~ ModelRandomnessR (Output modelB inputB),
HasForwardProduct modelARandomness modelA inputA outputA modelBRandomness modelB inputB outputB
) =>
HasForward (modelA, modelB) (inputA, inputB)
where
gForward (K1 f) (K1 a) = K1 $ forward f a
gForwardStoch (K1 f) (K1 a) = K1 <$> forwardStoch f a
type
Output (modelA, modelB) (inputA, inputB) =
OutputProduct
(Fst (ModelRandomnessR (Output modelA inputA)))
modelA
inputA
(Snd (ModelRandomnessR (Output modelA inputA)))
(Fst (ModelRandomnessR (Output modelB inputB)))
modelB
inputB
(Snd (ModelRandomnessR (Output modelB inputB)))
forward (modelA, modelB) (inputA, inputB) =
forwardProduct
(Proxy :: Proxy modelARandomness)
modelA
inputA
(Proxy :: Proxy outputA)
(Proxy :: Proxy modelBRandomness)
modelB
inputB
(Proxy :: Proxy outputB)

instance
(GHasForward f a b) =>
GHasForward (M1 i t f) (M1 i t' a) (M1 i t' b)
( '(modelARandomness, outputA) ~ ModelRandomnessR (Output modelA inputA),
'(modelBRandomness, outputB) ~ ModelRandomnessR (Output modelB inputB),
HasForwardSum modelARandomness modelA inputA outputA modelBRandomness modelB inputB outputB
) =>
HasForward (Either modelA modelB) (Either inputA inputB)
where
gForward (M1 f) (M1 a) = M1 $ gForward f a
gForwardStoch (M1 f) (M1 a) = M1 <$> gForwardStoch f a
type
Output (Either modelA modelB) (Either inputA inputB) =
OutputSum
(Fst (ModelRandomnessR (Output modelA inputA)))
modelA
inputA
(Snd (ModelRandomnessR (Output modelA inputA)))
(Fst (ModelRandomnessR (Output modelB inputB)))
modelB
inputB
(Snd (ModelRandomnessR (Output modelB inputB)))
forward eitherModel eitherIn =
forwardSum
(Proxy :: Proxy modelARandomness)
(Proxy :: Proxy modelBRandomness)
eitherModel
eitherIn
(Proxy :: Proxy (Either outputA outputB))

--
-- Parameterized
--

class Parameterized f where
flattenParameters :: f -> [Parameter]
Expand Down Expand Up @@ -214,9 +388,9 @@ linear layer input = linear' input w b
linearForward :: Linear -> Tensor -> Tensor
linearForward = linear -- temporary alias until dependencies are updated

instance HasForward Linear Tensor Tensor where
instance HasForward Linear Tensor where
type Output Linear Tensor = Tensor
forward = linearForward
forwardStoch m x = pure $ linearForward m x

instance Randomizable LinearSpec Linear where
sample LinearSpec {..} = do
Expand Down