Feature: Support for saving/loading? #9

Open
rccarlson opened this issue Feb 23, 2024 · 3 comments

Comments

@rccarlson

Any chance some kind of serialization could be implemented to save/load Markov chains, rather than having to retrain every time the program runs?

@otac0n
Owner

otac0n commented Feb 23, 2024

Yes, I can look into that.

@otac0n
Owner

otac0n commented Feb 23, 2024

In the meantime, you can already do this with GetStates(), Add(state, next, weight), and AddTerminalState(); hopefully that unblocks you.
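The workaround suggested above amounts to exporting each weighted transition and replaying it through `Add` when loading. The sketch below illustrates only that idea, using a hypothetical `ToyChain` stand-in (string states, `int` weights) and a hypothetical `ToyChainExport` helper. It is not the library's `MarkovChain<T>`, and the real signatures of `GetStates()`, `Add(state, next, weight)`, and `AddTerminalState()` may differ.

```csharp
using System;
using System.Collections.Generic;
using System.IO;

// Hypothetical stand-in, NOT the library's MarkovChain<T>: states are plain strings
// and weights are ints, so the "export, then replay through Add" idea can run on its own.
public sealed class ToyChain
{
    // state -> (next item -> accumulated weight)
    public Dictionary<string, Dictionary<string, int>> Transitions { get; } = new();

    // state -> accumulated weight of ending the sequence in that state
    public Dictionary<string, int> Terminals { get; } = new();

    // Plays the role of Add(state, next, weight): adds `weight` to the state -> next edge.
    public void Add(string state, string next, int weight)
    {
        if (!Transitions.TryGetValue(state, out var nexts))
        {
            nexts = new Dictionary<string, int>();
            Transitions[state] = nexts;
        }
        nexts[next] = nexts.TryGetValue(next, out var current) ? current + weight : weight;
    }

    // Plays the role of AddTerminalState(): adds `weight` to ending in `state`.
    public void AddTerminalState(string state, int weight) =>
        Terminals[state] = Terminals.TryGetValue(state, out var current) ? current + weight : weight;
}

// Hypothetical helper: saves a ToyChain and rebuilds it through its public methods.
public static class ToyChainExport
{
    // Export: write every (state, next, weight) triple, then every terminal weight.
    public static void Save(ToyChain chain, BinaryWriter bw)
    {
        int count = 0;
        foreach (var nexts in chain.Transitions.Values)
            count += nexts.Count;

        bw.Write(count);
        foreach (var (state, nexts) in chain.Transitions)
        {
            foreach (var (next, weight) in nexts)
            {
                bw.Write(state);
                bw.Write(next);
                bw.Write(weight);
            }
        }

        bw.Write(chain.Terminals.Count);
        foreach (var (state, weight) in chain.Terminals)
        {
            bw.Write(state);
            bw.Write(weight);
        }
    }

    // Import: start from an empty chain and replay the saved weights through Add/AddTerminalState.
    public static ToyChain Load(BinaryReader br)
    {
        var chain = new ToyChain();

        int transitionCount = br.ReadInt32();
        for (int i = 0; i < transitionCount; i++)
        {
            string state = br.ReadString();
            string next = br.ReadString();
            int weight = br.ReadInt32();
            chain.Add(state, next, weight);
        }

        int terminalCount = br.ReadInt32();
        for (int i = 0; i < terminalCount; i++)
        {
            string state = br.ReadString();
            int weight = br.ReadInt32();
            chain.AddTerminalState(state, weight);
        }

        return chain;
    }
}
```

Because loading goes through the same `Add` path as training, this approach needs no access to private fields, at the cost of re-running the accumulation logic on every load.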

@rccarlson
Author

I ended up implementing a solution with BinaryReader and BinaryWriter:

In MarkovChain.cs:

	// Layout: transitions, order, terminal weights, training size.
	// Deserialize must read the fields back in exactly this order.
	public void Serialize(BinaryWriter bw, Action<BinaryWriter, T> tWriter)
	{
		// items: state -> (next item -> weight)
		items.Serialize(bw,
			(w, state) => state.Serialize(w, tWriter),
			(w, transitions) => transitions.Serialize(w, tWriter, (w2, weight) => w2.Write(weight)));
		bw.Write(order);
		// terminals: state -> weight of ending the chain in that state
		terminals.Serialize(bw,
			(w, state) => state.Serialize(w, tWriter),
			(w, weight) => w.Write(weight));
		bw.Write(trainingSize);
	}

	public static MarkovChain<T> Deserialize(BinaryReader br, Func<BinaryReader, T> tReader)
	{
		var items = Utility.ReadDictionary(br,
			r => ChainState<T>.Deserialize(r, tReader),
			r => Utility.ReadDictionary(r, tReader, r2 => r2.ReadInt32()));
		var order = br.ReadInt32();
		var terminals = Utility.ReadDictionary(br,
			r => ChainState<T>.Deserialize(r, tReader),
			r => r.ReadInt32());
		var trainingSize = br.ReadInt32();

		// Relies on a constructor that accepts all four fields.
		return new MarkovChain<T>(items, order, terminals, trainingSize);
	}

In ChainState.cs:

	// Length-prefixed: the number of items, then each item encoded by tWriter.
	public void Serialize(BinaryWriter bw, Action<BinaryWriter, T> tWriter)
	{
		bw.Write(items.Length);
		foreach (var item in items)
			tWriter(bw, item);
	}

	public static ChainState<T> Deserialize(BinaryReader br, Func<BinaryReader, T> tReader)
	{
		var len = br.ReadInt32();
		var items = new T[len];
		for (int i = 0; i < len; i++)
		{
			items[i] = tReader(br);
		}
		return new ChainState<T>(items);
	}
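`ChainState<T>.Serialize` and `Deserialize` leave the encoding of each `T` to the caller through the `tWriter`/`tReader` delegates. As a rough illustration, assuming a word-level chain where `T` is `string`, a matching pair is `(w, s) => w.Write(s)` and `r => r.ReadString()`; `BinaryWriter.Write(string)` length-prefixes each string, so the reader knows where every item ends. The hypothetical `LengthPrefixed` helper below mirrors the same count-then-items layout outside the library so it can run on its own.

```csharp
using System;
using System.IO;
using System.Text;

// Hypothetical helper mirroring ChainState<T>.Serialize/Deserialize: a count, then each item.
public static class LengthPrefixed
{
    public static void WriteArray<T>(BinaryWriter bw, T[] items, Action<BinaryWriter, T> tWriter)
    {
        bw.Write(items.Length);
        foreach (var item in items)
            tWriter(bw, item);
    }

    public static T[] ReadArray<T>(BinaryReader br, Func<BinaryReader, T> tReader)
    {
        int length = br.ReadInt32();
        var items = new T[length];
        for (int i = 0; i < length; i++)
            items[i] = tReader(br);
        return items;
    }

    // Round-trips an array of words through an in-memory stream using string delegates.
    public static string[] RoundTrip(string[] words)
    {
        using var stream = new MemoryStream();
        using (var bw = new BinaryWriter(stream, Encoding.UTF8, leaveOpen: true))
            WriteArray(bw, words, (w, s) => w.Write(s));   // tWriter for T = string

        stream.Position = 0;
        using var br = new BinaryReader(stream, Encoding.UTF8);
        return ReadArray(br, r => r.ReadString());          // matching tReader
    }
}
```

For a chain over `int` items the pair would be `(w, n) => w.Write(n)` and `r => r.ReadInt32()`. Whatever pair is used, the reader has to consume exactly the bytes the writer produced, or every later field is misaligned.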

In my Utility class:

	// Writes the entry count, then each key/value pair using the supplied serializers.
	public static void Serialize<TKey, TValue>(this Dictionary<TKey, TValue> dict, BinaryWriter bw,
		Action<BinaryWriter, TKey> keySerializer, Action<BinaryWriter, TValue> valueSerializer)
		where TKey : notnull
		where TValue : notnull
	{
		bw.Write(dict.Count);
		foreach (var (key, value) in dict)
		{
			keySerializer(bw, key);
			valueSerializer(bw, value);
		}
	}

	// Reads a dictionary written by Serialize: the entry count, then that many key/value pairs.
	public static Dictionary<TKey, TValue> ReadDictionary<TKey, TValue>(BinaryReader br,
		Func<BinaryReader, TKey> readKey, Func<BinaryReader, TValue> readValue)
		where TKey : notnull
		where TValue : notnull
	{
		var len = br.ReadInt32();
		var dict = new Dictionary<TKey, TValue>(len);
		for (int i = 0; i < len; i++)
		{
			// Key before value, matching the write order.
			var key = readKey(br);
			dict.Add(key, readValue(br));
		}
		return dict;
	}

It's a little goofy, but it gets the job done.
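To cover the original request (load a saved chain on startup instead of retraining on every run), a `Serialize`/`Deserialize` pair like this could be wrapped in a load-or-build helper. The sketch below is one possible approach, not part of the library: `ChainFile`, `LoadOrBuild`, and the header values are hypothetical, and the model type is generic so that, for a word-level chain, the caller might pass `(bw, chain) => chain.Serialize(bw, (w, s) => w.Write(s))` and `br => MarkovChain<string>.Deserialize(br, r => r.ReadString())`. Because the binary layout carries no type or version information of its own, a small header lets an incompatible file be detected and the chain rebuilt instead of misread. It assumes .NET Core 3.0 or later for `File.Move(..., overwrite: true)`.

```csharp
using System;
using System.IO;

// Hypothetical helper (not part of the library): reuse a saved model when a compatible
// file exists, otherwise build it from scratch and save it for the next run.
public static class ChainFile
{
    // Arbitrary values for this sketch; bump Version whenever the binary layout changes.
    private const int Magic = 0x4D4B4348;
    private const int Version = 1;

    public static TModel LoadOrBuild<TModel>(
        string path,
        Func<TModel> build,                   // e.g. train the chain from the corpus
        Action<BinaryWriter, TModel> write,   // e.g. (bw, chain) => chain.Serialize(bw, tWriter)
        Func<BinaryReader, TModel> read)      // e.g. br => MarkovChain<T>.Deserialize(br, tReader)
    {
        if (File.Exists(path))
        {
            using var br = new BinaryReader(File.OpenRead(path));
            // Only trust the file if the header matches; otherwise fall through and rebuild.
            if (br.BaseStream.Length >= 8 && br.ReadInt32() == Magic && br.ReadInt32() == Version)
                return read(br);
        }

        var model = build();

        // Write to a temporary file and move it into place, so an interrupted write
        // does not leave a partial file at `path`.
        var tempPath = path + ".tmp";
        using (var bw = new BinaryWriter(File.Create(tempPath)))
        {
            bw.Write(Magic);
            bw.Write(Version);
            write(bw, model);
        }
        File.Move(tempPath, path, overwrite: true);

        return model;
    }
}
```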
