Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Action transition system (wip) #447

Draft
wants to merge 130 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
130 commits
Select commit Hold shift + click to select a range
28a7c62
initial experiment
tscholak Jul 6, 2020
c536abc
add unbuffered parser
tscholak Jul 8, 2020
16462f4
add buffer
tscholak Jul 9, 2020
fede361
more testing
tscholak Jul 9, 2020
5271a86
misc
tscholak Jul 9, 2020
35e19d0
switch to yoctoparsec
tscholak Jul 9, 2020
00bfe99
add check
tscholak Jul 10, 2020
7c98e0f
try something
tscholak Jul 10, 2020
c24d514
this may work
tscholak Jul 11, 2020
c532c90
add production annotation
tscholak Jul 11, 2020
3fead87
use lenses
tscholak Jul 11, 2020
80f7480
better batching
tscholak Jul 12, 2020
de6bb16
cleanup
tscholak Jul 12, 2020
c36f43a
add bool
tscholak Jul 12, 2020
ad127d9
test bool
tscholak Jul 12, 2020
823c95b
add relations
tscholak Jul 13, 2020
02c7015
state-based relations
tscholak Jul 13, 2020
fbf001e
add hedgehog STLC code
tscholak Jul 13, 2020
4109374
finish attention
tscholak Jul 15, 2020
6495ba9
use discovery for tests
tscholak Jul 15, 2020
1cbfe97
towards locally-nameless terms
tscholak Jul 15, 2020
581e590
add stlc using bound
tscholak Jul 17, 2020
9d53cc2
add a bunch of boilerplate for FreshT
tscholak Jul 17, 2020
4bebe63
line everything up
tscholak Jul 17, 2020
1eb58ed
remove unused code
tscholak Jul 17, 2020
a69c5c6
clean up
tscholak Jul 18, 2020
c4c6afb
add rat
tscholak Jul 22, 2020
98e531d
add test program
tscholak Jul 22, 2020
b51e297
add batching
tscholak Jul 22, 2020
0f20dba
clean up and finish training loop
tscholak Jul 22, 2020
1ff2891
replace buggy EndsWith with IsSuffixOf typeclass, turn action transit…
tscholak Jul 23, 2020
b72e349
reimplement cross entropy loss
tscholak Jul 23, 2020
b41b5f9
add some debugging, make transformer better
tscholak Jul 24, 2020
5c533d7
debug attention and key padding masks
tscholak Jul 24, 2020
e2f6b3a
switch back to bigger model
tscholak Jul 24, 2020
3a5527e
fix some of the segfaults with cuda
tscholak Jul 24, 2020
396b962
fix segfaults in display2dTensor
tscholak Jul 25, 2020
92a55e2
slightly better logging
tscholak Jul 25, 2020
719b412
Merge remote-tracking branch 'upstream/master' into action-transition…
tscholak Jul 25, 2020
e1b719e
use concurrent streaming
tscholak Jul 25, 2020
33b9bd5
make the model bigger
tscholak Jul 25, 2020
65ca1db
make seq shorter and model slightly smaller
tscholak Jul 25, 2020
a1cffdf
move data to CUDA only at the last possible moment
tscholak Jul 25, 2020
b6e7f50
add made batch debug statement
tscholak Jul 25, 2020
ad72bec
add more debug output
tscholak Jul 25, 2020
351f5ec
fix infinite loop
tscholak Jul 25, 2020
2fe315e
add performGC to hopefully help with memory cleanup
tscholak Jul 25, 2020
e508449
add rts ghc option
tscholak Jul 25, 2020
df532ee
add rtsopts instead
tscholak Jul 25, 2020
eb1f87d
add threaded ;)
tscholak Jul 25, 2020
b5fa55a
temporarily turn off evaluation to see if that affects memory consump…
tscholak Jul 25, 2020
9f0d447
resolve parse error
tscholak Jul 25, 2020
fda6558
make it compile
tscholak Jul 25, 2020
5e8fc4c
restore evaluation, use fewer iterations to avoid OOM
tscholak Jul 25, 2020
0b8035b
clean up
tscholak Jul 26, 2020
0441295
remove some debug print statements
tscholak Jul 26, 2020
375cadb
add pretty printing and improve Exp distribution
tscholak Jul 27, 2020
fa570fc
try to improve sampling
tscholak Jul 27, 2020
101f529
compute and the report the loss on the masked tokens and the unmasked…
tscholak Jul 27, 2020
6e6b0a0
fix boolean logic
tscholak Jul 27, 2020
371a483
switch masked and non-masked
tscholak Jul 27, 2020
895fe28
add evaluation of initial model
tscholak Jul 27, 2020
30196e6
try again some more evaluation
tscholak Jul 27, 2020
a3b35cc
ok, maybe not
tscholak Jul 27, 2020
aa6a035
add more sophisticated loss evaluation
tscholak Jul 27, 2020
19d38c4
make it possible to use different pMask values for input and target s…
tscholak Jul 27, 2020
abf5e5f
set pMaskInput to 0
tscholak Jul 27, 2020
fa2b190
better generator seeding
tscholak Jul 28, 2020
e28396b
increase number of layers
tscholak Jul 28, 2020
a932f02
remove seed debug output
tscholak Jul 28, 2020
bc2f35e
lower memory consumption
tscholak Jul 28, 2020
a2cc27c
use 4 layers
tscholak Jul 28, 2020
fa4de9e
go back to 3 layers
tscholak Jul 28, 2020
e70eb41
more debug outputs about relations
tscholak Jul 28, 2020
aeed13d
improve relations
tscholak Jul 28, 2020
9113f62
remove debug output
tscholak Jul 28, 2020
d3e6020
make relations more sparse
tscholak Jul 28, 2020
679f54e
richer meta data and support for nested type constructors in ReplaceD…
tscholak Jul 29, 2020
27d12d3
better parent-child tracking
tscholak Jul 29, 2020
5b0d212
turn off gradient clipping temporarily and significantly reduce learn…
tscholak Jul 29, 2020
9aba3d6
switch to concatenation rather than summing of embeddings
tscholak Jul 29, 2020
66be642
add a tiny amount of masking to the input
tscholak Jul 29, 2020
f15e820
use our favourite learning rate
tscholak Jul 29, 2020
a0d742c
test if the data is constant
tscholak Jul 30, 2020
05e3282
turn off masking as well
tscholak Jul 30, 2020
c4d4709
try to overfit
tscholak Jul 30, 2020
6cae1b6
increase learning rate by two orders of magnitude to speed up overfit…
tscholak Jul 30, 2020
e942a48
reduce learning rate by one order of magnitude because training is un…
tscholak Jul 30, 2020
d4280e7
add simple learning rate schedule
tscholak Jul 30, 2020
65f0228
compute loss only on token mask
tscholak Jul 30, 2020
bd23d2d
generalize transformer implementation
tscholak Jul 30, 2020
04a51a0
switch to gdm for testing
tscholak Jul 31, 2020
9ec40db
add debug output to loss
tscholak Jul 31, 2020
7938396
smaller batches and shorter sequences
tscholak Jul 31, 2020
25f12fb
more loss debug output
tscholak Jul 31, 2020
cad545c
add global next function, switch to batchSize 1 for debugging
tscholak Jul 31, 2020
e4715bf
one batch only
tscholak Jul 31, 2020
829c767
print selection mask, too
tscholak Jul 31, 2020
e077c28
back to proper training
tscholak Jul 31, 2020
599b87f
smaller batch sizes
tscholak Jul 31, 2020
5a59f1d
towards constrained training
tscholak Aug 3, 2020
1fac2e2
about to switch to BS
tscholak Aug 4, 2020
02136b5
clean up, fix parsin
tscholak Aug 4, 2020
8c2b355
lets risk a bit more generality
tscholak Aug 10, 2020
db66696
improve naming
tscholak Aug 10, 2020
6cba5d7
guard against invalid actions
tscholak Aug 10, 2020
0de9f42
misc renaming
tscholak Aug 11, 2020
6a165bc
create invalid token mask
tscholak Aug 11, 2020
bb5a24d
move mask token to base actions
tscholak Aug 14, 2020
f8e4b9c
Merge branch 'master' into action-transition-system
tscholak Aug 14, 2020
f572f7d
apply ormolu
tscholak Aug 14, 2020
eb8bd99
resolve conflicts
tscholak Aug 14, 2020
1fb9115
run ormolu to avoid merge conflicts
tscholak Aug 22, 2020
9cdd3bc
resolve merge conflicts
tscholak Aug 22, 2020
ce090e8
switch to new Dataset class
tscholak Aug 23, 2020
41bfbcc
add plotting
tscholak Aug 23, 2020
757dd48
back to proper training again
tscholak Aug 23, 2020
d1b9c09
lower batch size
tscholak Aug 23, 2020
5fe78f3
reimplement batching to avoid exploding compile time and ghc memory
tscholak Aug 24, 2020
0de28c1
increase data loading parallelism
tscholak Aug 24, 2020
8893f39
misc improvements
tscholak Aug 25, 2020
e5b0d85
set core affinity and export opengl-driver in library path
tscholak Aug 25, 2020
030c482
completely revamp the definition and implementation of the Parameteri…
tscholak Aug 27, 2020
483bfd1
fix tests
tscholak Aug 27, 2020
bab34be
switch to GPU
tscholak Aug 27, 2020
7525458
add missing model and optim saving implementation
tscholak Aug 27, 2020
3fc54fc
add saving and loading notifications
tscholak Aug 28, 2020
5306305
fix and format typed examples
tscholak Aug 28, 2020
4706e24
Merge branch 'master' into action-transition-system
tscholak Aug 28, 2020
e53876d
resolve merge conflicts
tscholak Sep 19, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
167 changes: 167 additions & 0 deletions experimental/action-transition-system/Control/Monad/Fresh.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

module Control.Monad.Fresh where

import Control.Monad.Cont (ContT, MonadCont (..))
import Control.Monad.Except (MonadError (..))
import Control.Monad.Except (ExceptT)
import Control.Monad.Fix (MonadFix (..))
import Control.Monad.IO.Class (MonadIO (..))
import Control.Monad.Identity (Identity (..), IdentityT)
import Control.Monad.Morph (MFunctor (..))
import Control.Monad.RWS (RWST)
import Control.Monad.Reader (MonadReader (..), ReaderT, asks, runReaderT)
import Control.Monad.State (MonadState (..), StateT, evalStateT, modify)
import Control.Monad.Trans (MonadTrans (..))
import Control.Monad.Trans.Maybe (MaybeT)
import Control.Monad.Writer (MonadWriter (..), WriterT)
import GHC.Base (Alternative (..), MonadPlus (..))
import Hedgehog (MonadGen (..), distributeT)

newtype Successor a = Successor {suc :: a -> a}

-- | The monad transformer for generating fresh values.
data FreshT e m a = FreshT {unFreshT :: ReaderT (Successor e) (StateT e m) a}
deriving (Functor)

instance Monad m => MonadFresh e (FreshT e m) where
fresh = FreshT $ do
e <- get
s <- asks suc
modify s
pure e

instance Monad m => Monad (FreshT e m) where
return = FreshT . return
(FreshT m) >>= f = FreshT $ m >>= unFreshT . f

instance MonadPlus m => MonadPlus (FreshT e m) where
mzero = FreshT mzero
mplus (FreshT m) (FreshT m') = FreshT $ mplus m m'

instance (Functor f, Monad f) => Applicative (FreshT e f) where
pure = FreshT . pure
(FreshT f) <*> (FreshT a) = FreshT $ f <*> a

instance (Monad m, Functor m, MonadPlus m) => Alternative (FreshT e m) where
empty = mzero
(<|>) = mplus

type Fresh e = FreshT e Identity

instance MonadTrans (FreshT e) where
lift = FreshT . lift . lift

instance MonadReader r m => MonadReader r (FreshT e m) where
local f m = FreshT $ ask >>= lift . local f . runReaderT (unFreshT m)
ask = FreshT (lift ask)

instance MonadState s m => MonadState s (FreshT e m) where
get = FreshT $ (lift . lift) get
put = FreshT . lift . lift . put

instance (MonadWriter w m) => MonadWriter w (FreshT e m) where
tell m = lift $ tell m
listen = FreshT . listen . unFreshT
pass = FreshT . pass . unFreshT

instance MonadFix m => MonadFix (FreshT e m) where
mfix = FreshT . mfix . (unFreshT .)

instance MonadIO m => MonadIO (FreshT e m) where
liftIO = FreshT . liftIO

instance MonadCont m => MonadCont (FreshT e m) where
callCC f = FreshT $ callCC (unFreshT . f . (FreshT .))

instance MonadError e m => MonadError e (FreshT e' m) where
throwError = FreshT . throwError
catchError m h = FreshT $ catchError (unFreshT m) (unFreshT . h)

instance MFunctor (FreshT e) where
hoist nat m = FreshT $ hoist (hoist nat) (unFreshT m)

instance MonadGen m => MonadGen (FreshT e m) where
type GenBase (FreshT e m) = FreshT e (GenBase m)
toGenT = hoist FreshT . distributeT . hoist distributeT . unFreshT . hoist toGenT
fromGenT = hoist fromGenT . distributeT

successor :: forall e . (e -> e) -> Successor e
successor = Successor

enumSucc :: forall e . Enum e => Successor e
enumSucc = Successor succ

-- | Run a @FreshT@ computation starting from the value
-- @toEnum 0@
runFreshT :: forall e m a . (Enum e, Monad m) => FreshT e m a -> m a
runFreshT = runFreshTFrom (toEnum 0)

-- | Run a @Fresh@ computation starting from the value
-- @toEnum 0@
runFresh :: forall e a . Enum e => Fresh e a -> a
runFresh = runFreshFrom (toEnum 0)

-- | Run a @FreshT@ computation starting from a specific value @e@.
runFreshTFrom :: forall e m a . (Monad m, Enum e) => e -> FreshT e m a -> m a
runFreshTFrom e = runFreshTWith enumSucc e

-- | Run a @Fresh@ computation starting from a specific value @e@.
runFreshFrom :: forall e a . Enum e => e -> Fresh e a -> a
runFreshFrom e = runFreshWith enumSucc e

-- | Run a @FreshT@ computation starting from a specific value @e@ with
-- a the next fresh value determined by @Successor e@.
runFreshTWith :: forall e m a . Monad m => Successor e -> e -> FreshT e m a -> m a
runFreshTWith s e =
flip evalStateT e
. flip runReaderT s
. unFreshT

-- | Run a @FreshT@ computation starting from a specific value @e@ with
-- a the next fresh value determined by @Successor e@.
runFreshWith :: forall e a . Successor e -> e -> Fresh e a -> a
runFreshWith s e = runIdentity . runFreshTWith s e

---------------------------

-- | The MTL style class for generating fresh values
class Monad m => MonadFresh e m | m -> e where
-- | Generate a fresh value @e@, @fresh@ should never produce the
-- same value within a monadic computation.
fresh :: m e

instance MonadFresh e m => MonadFresh e (IdentityT m) where
fresh = lift fresh

instance MonadFresh e m => MonadFresh e (StateT s m) where
fresh = lift fresh

instance MonadFresh e m => MonadFresh e (ReaderT s m) where
fresh = lift fresh

instance (MonadFresh e m, Monoid s) => MonadFresh e (WriterT s m) where
fresh = lift fresh

instance MonadFresh e m => MonadFresh e (MaybeT m) where
fresh = lift fresh

instance MonadFresh e m => MonadFresh e (ContT r m) where
fresh = lift fresh

instance (Monoid w, MonadFresh e m) => MonadFresh e (RWST r w s m) where
fresh = lift fresh

-- instance MonadFresh e m => MonadFresh e (SS.StateT s m) where
-- fresh = lift fresh
-- instance (Monoid w, MonadFresh e m) => MonadFresh e (SW.WriterT w m) where
-- fresh = lift fresh
instance (MonadFresh e m) => MonadFresh e (ExceptT e' m) where
fresh = lift fresh
31 changes: 31 additions & 0 deletions experimental/action-transition-system/Main.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
module Main where

import Torch.Data.ActionTransitionSystem (Config (..), testProgram)
import Torch.Data.Pipeline (MapStyleOptions (..), Sample (Sequential))

main :: IO ()
main =
let config =
Config
{ trainingLen = 8192,
evaluationLen = 96,
-- trainingLen = 65536,
-- evaluationLen = 4096,
probMaskInput = 0.15,
probMaskTarget = 0.15,
maxLearningRate = 0.0005,
finalLearningRate = 1e-6,
numEpochs = 1000,
numWarmupEpochs = 10,
numCooldownEpochs = 10,
modelCheckpointFile = "modelCheckpoint",
optimCheckpointFile = "optimCheckpoint",
plotFile = "plot.html",
options =
MapStyleOptions
{ bufferSize = 256,
numWorkers = 8,
shuffle = Sequential
}
}
in testProgram config