diff --git a/LuaCppInterface/luacoroutine.cpp b/LuaCppInterface/luacoroutine.cpp index 55b9653..4affe26 100644 --- a/LuaCppInterface/luacoroutine.cpp +++ b/LuaCppInterface/luacoroutine.cpp @@ -14,6 +14,7 @@ std::string LuaCoroutine::RunScript(std::string script) lua_pop(state.get(), 1); int status = luaL_loadstring(thread, script.c_str()); status = lua_resume(thread, NULL, 0); + if (status != LUA_OK && status != LUA_YIELD) { return LuaGetLastError(thread); @@ -26,7 +27,8 @@ std::string LuaCoroutine::Resume() PushToStack(state.get()); lua_State* thread = lua_tothread(state.get(), -1); lua_pop(state.get(), 1); - int status = lua_resume(thread, NULL, 0); + int status = lua_resume(thread, NULL, lua_gettop(thread)); + if (status != LUA_OK && status != LUA_YIELD) { return LuaGetLastError(thread); diff --git a/LuaCppInterface/luacppinterface.h b/LuaCppInterface/luacppinterface.h index c043e5a..83bc488 100644 --- a/LuaCppInterface/luacppinterface.h +++ b/LuaCppInterface/luacppinterface.h @@ -65,8 +65,7 @@ class Lua static int lua_yieldingFunction(lua_State* state) { int numVals = LuaFunction::staticFunction(state); - lua_yield(state, numVals); - return numVals; + return lua_yield(state, numVals); }; template diff --git a/tests/Makefile.am b/tests/Makefile.am index 9cbc472..b2df5f9 100644 --- a/tests/Makefile.am +++ b/tests/Makefile.am @@ -30,9 +30,11 @@ TESTS = crashtest \ testgettypeofvalueat \ testinvalidscript \ testregistry \ + testreturnfromyieldingfunction \ demonstration1 \ demonstration2 \ - demonstration3 + demonstration3 \ + demonstration4 check_PROGRAMS = $(TESTS) @@ -123,6 +125,8 @@ testinvalidscript_LDADD = ../LuaCppInterface/libluacppinterface.a ../lua/src/lib testregistry_SOURCES = testregistry.cpp lua testregistry_LDADD = ../LuaCppInterface/libluacppinterface.a ../lua/src/liblua.a +testreturnfromyieldingfunction_SOURCES = testreturnfromyieldingfunction.cpp lua +testreturnfromyieldingfunction_LDADD = ../LuaCppInterface/libluacppinterface.a ../lua/src/liblua.a demonstration1_SOURCES = demonstration1.cpp lua @@ -134,6 +138,9 @@ demonstration2_LDADD = ../LuaCppInterface/libluacppinterface.a ../lua/src/liblua demonstration3_SOURCES = demonstration3.cpp lua demonstration3_LDADD = ../LuaCppInterface/libluacppinterface.a ../lua/src/liblua.a +demonstration4_SOURCES = demonstration4.cpp lua +demonstration4_LDADD = ../LuaCppInterface/libluacppinterface.a ../lua/src/liblua.a + BUILT_SOURCES = ../lua/src/liblua.a diff --git a/tests/config.lua b/tests/config.lua new file mode 100644 index 0000000..9a4f0ee --- /dev/null +++ b/tests/config.lua @@ -0,0 +1,3 @@ +width = 640 +height = 480 +windowTitle = "Lua Rocks" \ No newline at end of file diff --git a/tests/demonstration4.cpp b/tests/demonstration4.cpp new file mode 100644 index 0000000..d76002b --- /dev/null +++ b/tests/demonstration4.cpp @@ -0,0 +1,18 @@ +#include +#include +#include +#include +#include + +int main() +{ + std::ifstream file("config.lua"); + std::string script((std::istreambuf_iterator(file)), std::istreambuf_iterator()); + Lua lua; + lua.RunScript(script); + int width = lua.GetGlobalEnvironment().Get("width"); // get the width + int height = lua.GetGlobalEnvironment().Get("height"); // get the height + std::string windowTitle = lua.GetGlobalEnvironment().Get("windowTitle"); + + return width != 640 || height != 480 || windowTitle.compare("Lua Rocks"); +} diff --git a/tests/testreturnfromyieldingfunction.cpp b/tests/testreturnfromyieldingfunction.cpp new file mode 100644 index 0000000..121ca4e --- /dev/null +++ b/tests/testreturnfromyieldingfunction.cpp @@ -0,0 +1,50 @@ +#include +#include +#include +#include + +int main() +{ + Lua luaInstance; + + auto globalTable = luaInstance.GetGlobalEnvironment(); + std::stringstream ss; + + auto myOwnPrint = luaInstance.CreateYieldingFunction + ( + [&](std::string str) + { + ss << str; + } + ); + + auto lengthOf = luaInstance.CreateYieldingFunction + ( + [](std::string str) -> int + { + return str.size(); + } + ); + + globalTable.Set("myOwnPrint", myOwnPrint); + globalTable.Set("lengthOf", lengthOf); + + luaInstance.LoadStandardLibraries(); + + auto cr = luaInstance.CreateCoroutine(); + + auto err = cr.RunScript( + "x = lengthOf('haha')\n" + "myOwnPrint ('size:' .. x)\n" + ); + + while (cr.CanResume()) + { + ss << ";yield;"; + err = cr.Resume(); + } + + auto resstr = ss.str(); + + return resstr.compare(";yield;size:4;yield;"); +}