diff --git a/frontend/persist.lua b/frontend/persist.lua index b2a559084..dfaf332df 100644 --- a/frontend/persist.lua +++ b/frontend/persist.lua @@ -82,11 +82,20 @@ local codecs = { if not ok then return nil, "cannot serialize " .. tostring(t) .. " (" .. str .. ")" end + + local cbuff, clen = zstd.zstd_compress(str, #str) + + if not path then + local result = ffi.string(cbuff, clen) + C.free(cbuff) + return result, clen + end + local f = C.fopen(path, "wb") if f == nil then + C.free(cbuff) return nil, "fopen: " .. ffi.string(C.strerror(ffi.errno())) end - local cbuff, clen = zstd.zstd_compress(str, #str) if C.fwrite(cbuff, 1, clen, f) < clen then C.fclose(f) C.free(cbuff) @@ -100,7 +109,19 @@ local codecs = { return true, clen end, - deserialize = function(path) + deserialize = function(str, path) + if str and not path then + local buff, ulen = zstd.zstd_uncompress(str, #str) + if not buff then + return nil, "failed to decompress string" + end + local ok, t = pcall(buffer.decode, ffi.string(buff, ulen)) + C.free(buff) + if not ok then + return nil, "malformed serialized data (" .. t .. ")" + end + return t + end local f = C.fopen(path, "rb") if f == nil then return nil, "fopen: " .. ffi.string(C.strerror(ffi.errno())) @@ -124,9 +145,8 @@ local codecs = { local buff, ulen = zstd.zstd_uncompress(data, size) C.free(data) - local str = ffi.string(buff, ulen) + str = ffi.string(buff, ulen) C.free(buff) - local ok, t = pcall(buffer.decode, str) if not ok then return nil, "malformed serialized data (" .. t .. ")" @@ -156,8 +176,11 @@ local codecs = { return content end, - deserialize = function(str) - local t, err = loadfile(str) + deserialize = function(str, path) + local t, err + if path then + t, err = loadfile(path) + end if not t then t, err = loadstring(str) end @@ -221,7 +244,7 @@ end function Persist:load() local t, err if codecs[self.codec].reads_from_file then - t, err = codecs[self.codec].deserialize(self.path) + t, err = codecs[self.codec].deserialize(nil, self.path) else local str str, err = readFile(self.path) diff --git a/spec/unit/persist_spec.lua b/spec/unit/persist_spec.lua index def52d643..0c04c1fd1 100644 --- a/spec/unit/persist_spec.lua +++ b/spec/unit/persist_spec.lua @@ -81,8 +81,7 @@ describe("Persist module", function() it("should return standalone serializers/deserializers", function() local tab = sample - -- NOTE: zstd only deser from a *file*, not a string. - for _, codec in ipairs({"dump", "serpent", "bitser", "luajit"}) do + for _, codec in ipairs({"dump", "serpent", "bitser", "luajit", "zstd"}) do assert.is_true(Persist.getCodec(codec).id == codec) local ser = Persist.getCodec(codec).serialize local deser = Persist.getCodec(codec).deserialize @@ -91,7 +90,7 @@ describe("Persist module", function() if not t then print(codec, "deser failed:", err) end - assert.are.same(t, tab) + assert.are.same(tab, t) end end) @@ -101,7 +100,7 @@ describe("Persist module", function() local ser = Persist.getCodec(codec).serialize local deser = Persist.getCodec(codec).deserialize local str = ser(tab) - assert.are.same(deser(str), tab) + assert.are.same(tab, deser(str)) end end)