Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap C functions to throw errors in wrapper Lua functions #478

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/Lua.cs
@@ -1,4 +1,4 @@
using System;
using System;
using System.Linq;
using System.Reflection;
using System.Collections.Generic;
Expand Down Expand Up @@ -315,6 +315,9 @@ void Init()
}
_luaState.PushGlobalTable();
_luaState.GetGlobal("luanet");
_luaState.PushString("_G");
_luaState.GetGlobal("_G");
_luaState.SetTable(-3);
_luaState.PushString("getmetatable");
_luaState.GetGlobal("getmetatable");
_luaState.SetTable(-3);
Expand Down
95 changes: 57 additions & 38 deletions src/Metatables.cs
@@ -1,4 +1,4 @@
using System;
using System;
using System.Linq;
using System.Collections;
using System.Reflection;
Expand Down Expand Up @@ -47,6 +47,25 @@ public class MetaFunctions
readonly Dictionary<object, Dictionary<object, object>> _memberCache = new Dictionary<object, Dictionary<object, object>>();
readonly ObjectTranslator _translator;

/*
* C function wrapper. Has to be in Lua to not mess up the CLR stack
*/
public const string LuaCFunctionWrapper = @"local function w(f)return function(...)local r={_G.pcall(f,...)}if not _G.table.remove(r, 1) then _G.error('UNWRAPPED LUA ERROR FROM MANAGED CODE!');elseif _G.table.remove(r, 1) then _G.error(_G.table.unpack(r));else return _G.table.unpack(r);end end end;return w";
//@"local function wrap(func)
// return function(...)
// local r = { _G.pcall(func, ...) }
// if not _G.table.remove(r, 1) then
// _G.error('UNWRAPPED LUA ERROR FROM MANAGED CODE!')
// elseif _G.table.remove(r, 1) then
// _G.error(_G.table.unpack(r))
// else
// return _G.table.unpack(r)
// end
// end
// end

// return wrap";

/*
* __index metafunction for CLR objects. Implemented in Lua.
*/
Expand Down Expand Up @@ -94,15 +113,15 @@ private static int RunFunctionDelegate(IntPtr luaState)
var translator = ObjectTranslatorPool.Instance.Find(state);
var func = (LuaNativeFunction)translator.GetRawNetObject(state, 1);
if (func == null)
return state.Error();
return translator.ErrorFromWrappedCFunction(state);

state.Remove(1);
int result = func(luaState);
var exception = translator.GetObject(state, -1) as LuaScriptException;

if (exception != null)
return state.Error();
return result;
return translator.ErrorFromWrappedCFunction(state);
return translator.ReturnFromWrappedCFunction(state, result);
}

/*
Expand All @@ -115,7 +134,7 @@ private static int CollectObject(IntPtr state)
{
var luaState = LuaState.FromIntPtr(state);
var translator = ObjectTranslatorPool.Instance.Find(luaState);
return CollectObject(luaState, translator);
return translator.ReturnFromWrappedCFunction(luaState, CollectObject(luaState, translator));
}

private static int CollectObject(LuaState luaState, ObjectTranslator translator)
Expand All @@ -138,7 +157,7 @@ private static int ToStringLua(IntPtr state)
{
var luaState = LuaState.FromIntPtr(state);
var translator = ObjectTranslatorPool.Instance.Find(luaState);
return ToStringLua(luaState, translator);
return translator.ReturnFromWrappedCFunction(luaState, ToStringLua(luaState, translator));
}

private static int ToStringLua(LuaState luaState, ObjectTranslator translator)
Expand Down Expand Up @@ -168,8 +187,8 @@ static int AddLua(IntPtr luaState)
var exception = translator.GetObject(state, -1) as LuaScriptException;

if (exception != null)
return state.Error();
return result;
return translator.ErrorFromWrappedCFunction(state);
return translator.ReturnFromWrappedCFunction(state, result);
}

/*
Expand All @@ -186,8 +205,8 @@ static int SubtractLua(IntPtr luaState)
var exception = translator.GetObject(state, -1) as LuaScriptException;

if (exception != null)
return state.Error();
return result;
return translator.ErrorFromWrappedCFunction(state);
return translator.ReturnFromWrappedCFunction(state, result);
}

/*
Expand All @@ -204,8 +223,8 @@ static int MultiplyLua(IntPtr luaState)
var exception = translator.GetObject(state, -1) as LuaScriptException;

if (exception != null)
return state.Error();
return result;
return translator.ErrorFromWrappedCFunction(state);
return translator.ReturnFromWrappedCFunction(state, result);
}

/*
Expand All @@ -222,8 +241,8 @@ static int DivideLua(IntPtr luaState)
var exception = translator.GetObject(state, -1) as LuaScriptException;

if (exception != null)
return state.Error();
return result;
return translator.ErrorFromWrappedCFunction(state);
return translator.ReturnFromWrappedCFunction(state, result);
}

/*
Expand All @@ -240,8 +259,8 @@ static int ModLua(IntPtr luaState)
var exception = translator.GetObject(state, -1) as LuaScriptException;

if (exception != null)
return state.Error();
return result;
return translator.ErrorFromWrappedCFunction(state);
return translator.ReturnFromWrappedCFunction(state, result);
}

/*
Expand All @@ -258,8 +277,8 @@ static int UnaryNegationLua(IntPtr luaState)
var exception = translator.GetObject(state, -1) as LuaScriptException;

if (exception != null)
return state.Error();
return result;
return translator.ErrorFromWrappedCFunction(state);
return translator.ReturnFromWrappedCFunction(state, result);
}

static int UnaryNegationLua(LuaState luaState, ObjectTranslator translator) //-V3009
Expand Down Expand Up @@ -300,8 +319,8 @@ static int EqualLua(IntPtr luaState)
var exception = translator.GetObject(state, -1) as LuaScriptException;

if (exception != null)
return state.Error();
return result;
return translator.ErrorFromWrappedCFunction(state);
return translator.ReturnFromWrappedCFunction(state, result);
}

/*
Expand All @@ -318,8 +337,8 @@ static int LessThanLua(IntPtr luaState)
var exception = translator.GetObject(state, -1) as LuaScriptException;

if (exception != null)
return state.Error();
return result;
return translator.ErrorFromWrappedCFunction(state);
return translator.ReturnFromWrappedCFunction(state, result);
}

/*
Expand All @@ -336,8 +355,8 @@ static int LessThanOrEqualLua(IntPtr luaState)
var exception = translator.GetObject(state, -1) as LuaScriptException;

if (exception != null)
return state.Error();
return result;
return translator.ErrorFromWrappedCFunction(state);
return translator.ReturnFromWrappedCFunction(state, result);
}

/// <summary>
Expand Down Expand Up @@ -387,8 +406,8 @@ private static int GetMethod(IntPtr state)
var exception = translator.GetObject(luaState, -1) as LuaScriptException;

if (exception != null)
return luaState.Error();
return result;
return translator.ErrorFromWrappedCFunction(luaState);
return translator.ReturnFromWrappedCFunction(luaState, result);
}

private int GetMethodInternal(LuaState luaState)
Expand Down Expand Up @@ -691,8 +710,8 @@ private static int GetBaseMethod(IntPtr state)
var exception = translator.GetObject(luaState, -1) as LuaScriptException;

if (exception != null)
return luaState.Error();
return result;
return translator.ErrorFromWrappedCFunction(luaState);
return translator.ReturnFromWrappedCFunction(luaState, result);
}

private int GetBaseMethodInternal(LuaState luaState)
Expand Down Expand Up @@ -994,8 +1013,8 @@ private static int SetFieldOrProperty(IntPtr state)
var exception = translator.GetObject(luaState, -1) as LuaScriptException;

if (exception != null)
return luaState.Error();
return result;
return translator.ErrorFromWrappedCFunction(luaState);
return translator.ReturnFromWrappedCFunction(luaState, result);
}

private int SetFieldOrPropertyInternal(LuaState luaState)
Expand Down Expand Up @@ -1199,8 +1218,8 @@ private static int GetClassMethod(IntPtr state)
var exception = translator.GetObject(luaState, -1) as LuaScriptException;

if (exception != null)
return luaState.Error();
return result;
return translator.ErrorFromWrappedCFunction(luaState);
return translator.ReturnFromWrappedCFunction(luaState, result);
}

private int GetClassMethodInternal(LuaState luaState)
Expand Down Expand Up @@ -1245,8 +1264,8 @@ private static int SetClassFieldOrProperty(IntPtr state)
var exception = translator.GetObject(luaState, -1) as LuaScriptException;

if (exception != null)
return luaState.Error();
return result;
return translator.ErrorFromWrappedCFunction(luaState);
return translator.ReturnFromWrappedCFunction(luaState, result);
}

private int SetClassFieldOrPropertyInternal(LuaState luaState)
Expand Down Expand Up @@ -1277,9 +1296,9 @@ static int CallDelegate(IntPtr state)
var exception = translator.GetObject(luaState, -1) as LuaScriptException;

if (exception != null)
return luaState.Error();
return translator.ErrorFromWrappedCFunction(luaState);

return result;
return translator.ReturnFromWrappedCFunction(luaState, result);
}

int CallDelegateInternal(LuaState luaState)
Expand Down Expand Up @@ -1346,8 +1365,8 @@ private static int CallConstructor(IntPtr state)
var exception = translator.GetObject(luaState, -1) as LuaScriptException;

if (exception != null)
return luaState.Error();
return result;
return translator.ErrorFromWrappedCFunction(luaState);
return translator.ReturnFromWrappedCFunction(luaState, result);
}

private static ConstructorInfo[] ReorderConstructors(ConstructorInfo[] constructors)
Expand Down