--[[
MIT License

Copyright (c) 2017 Mark Langen

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
]]

--- Turn a list of values into a set: every value v gets tb[v] = true.
-- The original entries are kept, and the table is modified in place.
-- Values are collected before any key is added, because assigning to a
-- new field while traversing with pairs/next is undefined behaviour.
-- @param tb table of values
-- @return tb
function lookupify(tb)
	local values = {}
	for _, v in pairs(tb) do
		values[#values + 1] = v
	end
	for _, v in ipairs(values) do
		tb[v] = true
	end
	return tb
end

--- Count all entries (array and hash part) of a table.
-- @param tb table
-- @return number of key/value pairs
function CountTable(tb)
	local c = 0
	for _ in pairs(tb) do
		c = c + 1
	end
	return c
end

--- Recursive worker for FormatTable: render a table as Lua-like text.
-- Function values and keys for which ignoreFunc(k) is truthy are skipped.
-- A table with a Print field is rendered by calling tb.Print() instead.
-- @param tb table to render
-- @param atIndent current indentation depth (defaults to 0)
-- @param ignoreFunc predicate on keys; must not be nil
-- @return string
function FormatTableInt(tb, atIndent, ignoreFunc)
	if tb.Print then
		return tb.Print()
	end
	atIndent = atIndent or 0
	-- Only spread the table over several lines when it has > 1 entry
	local useNewlines = (CountTable(tb) > 1)
	local baseIndent = string.rep(' ', atIndent+1)
	local out = "{"..(useNewlines and '\n' or '')
	for k, v in pairs(tb) do
		if type(v) ~= 'function' and not ignoreFunc(k) then
			out = out..(useNewlines and baseIndent or '')
			if type(k) == 'number' then
				--nothing to do
			elseif type(k) == 'string' and k:match("^[A-Za-z_][A-Za-z0-9_]*$") then
				out = out..k.." = "
			elseif type(k) == 'string' then
				out = out.."[\""..k.."\"] = "
			else
				out = out.."["..tostring(k).."] = "
			end
			if type(v) == 'string' then
				out = out.."\""..v.."\""
			elseif type(v) == 'number' then
				out = out..v
			elseif type(v) == 'table' then
				out = out..FormatTableInt(v, atIndent+(useNewlines and 1 or 0), ignoreFunc)
			else
				out = out..tostring(v)
			end
			if next(tb, k) then
				out = out..","
			end
			if useNewlines then
				out = out..'\n'
			end
		end
	end
	out = out..(useNewlines and string.rep(' ', atIndent) or '').."}"
	return out
end

--- Render a table as Lua-like text (debugging aid).
-- @param tb table to render
-- @param ignoreFunc optional predicate; keys for which it returns a truthy
--        value are omitted
-- @return string
function FormatTable(tb, ignoreFunc)
	ignoreFunc = ignoreFunc or function()
		return false
	end
	return FormatTableInt(tb, 0, ignoreFunc)
end

local WhiteChars = lookupify{' ', '\n', '\t', '\r'}
-- Character -> escape sequence that reproduces it inside a string literal
local EscapeForCharacter = {['\r'] = '\\r', ['\n'] = '\\n', ['\t'] = '\\t', ['"'] = '\\"', ["'"] = "\\'", ['\\'] = '\\\\'}
-- Escape letter (the character after `\`) -> the character it stands for
local CharacterForEscape = {['r'] = '\r', ['n'] = '\n', ['t'] = '\t', ['"'] = '"', ["'"] = "'", ['\\'] = '\\'}
-- Other characters that may legally follow `\` in a short string: \a \b \f
-- \v, a backslash-newline line continuation, and the Lua 5.2 \x (hex) and
-- \z (skip whitespace) forms. Decimal escapes (\ddd) are checked via Digits.
local OtherEscapeStartChars = lookupify{'a', 'b', 'f', 'v', 'x', 'z', '\n', '\r'}

local AllIdentStartChars = lookupify{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j',
                                     'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't',
                                     'u', 'v', 'w', 'x', 'y', 'z',
                                     'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J',
                                     'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T',
                                     'U', 'V', 'W', 'X', 'Y', 'Z', '_'}

local AllIdentChars = lookupify{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j',
                                'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't',
                                'u', 'v', 'w', 'x', 'y', 'z',
                                'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J',
                                'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T',
                                'U', 'V', 'W', 'X', 'Y', 'Z', '_',
                                '0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}

local Digits = lookupify{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}

local HexDigits = lookupify{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                            'A', 'a', 'B', 'b', 'C', 'c', 'D', 'd', 'E', 'e', 'F', 'f'}

local Symbols = lookupify{'+', '-', '*', '/', '^', '%', ',', '{', '}', '[', ']', '(', ')', ';', '#', '.', ':'}

-- Symbols that may be followed by `=` to form a two character operator
local EqualSymbols = lookupify{'~', '=', '>', '<'}

local Keywords = lookupify{
	'and', 'break', 'do', 'else', 'elseif',
	'end', 'false', 'for', 'function', 'goto', 'if',
	'in', 'local', 'nil', 'not', 'or', 'repeat',
	'return', 'then', 'true', 'until', 'while',
};

-- Keywords that terminate a block
local BlockFollowKeyword = lookupify{'else', 'elseif', 'until', 'end'}

local UnopSet = lookupify{'-', 'not', '#'}

local BinopSet = lookupify{
	'+', '-', '*', '/', '%', '^', '#',
	'..', '.', ':',
	'>', '<', '<=', '>=', '~=', '==',
	'and', 'or'
}

local GlobalRenameIgnore = lookupify{

}

-- {left, right} binding priorities; a right priority lower than the left
-- one makes the operator right associative ('^' and '..').
local BinaryPriority = {
	['+'] = {6, 6};
	['-'] = {6, 6};
	['*'] = {7, 7};
	['/'] = {7, 7};
	['%'] = {7, 7};
	['^'] = {10, 9};
	['..'] = {5, 4};
	['=='] = {3, 3};
	['~='] = {3, 3};
	['>'] = {3, 3};
	['<'] = {3, 3};
	['>='] = {3, 3};
	['<='] = {3, 3};
	['and'] = {2, 2};
	['or'] = {1, 1};
};
local UnaryPriority = 8

-- Eof, Ident, Keyword, Number, String, Symbol

--- Split Lua source text into an array of tokens.
-- Each token is a table { Type, LeadingWhite, Source }. Type is one of
-- 'Eof', 'Ident', 'Keyword', 'Number', 'String' or 'Symbol'. Source is the
-- raw token text, and LeadingWhite is the whitespace and comments that
-- precede it. The last token is always an 'Eof' token.
-- Lexical errors are raised with a "file<line:char>" prefix, after the
-- tokens read so far have been printed.
-- @param text Lua source code
-- @return array of tokens
function CreateLuaTokenStream(text)
	-- Tracking for the current position in the buffer, and
	-- the current line / character we are on.
	local p = 1
	local length = #text

	-- Output buffer for tokens
	local tokenBuffer = {}

	-- Get a character, or '' if at eof
	local function look(n)
		n = p + (n or 0)
		if n <= length then
			return text:sub(n, n)
		else
			return ''
		end
	end
	local function get()
		if p <= length then
			local c = text:sub(p, p)
			p = p + 1
			return c
		else
			return ''
		end
	end

	-- Error: report the current line / character, dumping the tokens read
	-- so far to help locate the problem.
	local olderr = error
	local function error(str)
		local q = 1
		local line = 1
		local char = 1
		while q <= p do
			if text:sub(q, q) == '\n' then
				line = line + 1
				char = 1
			else
				char = char + 1
			end
			q = q + 1
		end
		for _, token in ipairs(tokenBuffer) do
			print(token.Type.."<"..token.Source..">")
		end
		olderr("file<"..line..":"..char..">: "..str)
	end

	-- Consume a long data with equals count of `eqcount'
	local function longdata(eqcount)
		while true do
			local c = get()
			if c == '' then
				error("Unfinished long string.")
			elseif c == ']' then
				local done = true -- Until contested
				for i = 1, eqcount do
					if look() == '=' then
						p = p + 1
					else
						done = false
						break
					end
				end
				if done and get() == ']' then
					return
				end
			end
		end
	end

	-- Get the opening part for a long data `[` `=`* `[`
	-- Precondition: The first `[` has been consumed
	-- Return: nil or the equals count
	local function getopen()
		local startp = p
		while look() == '=' do
			p = p + 1
		end
		if look() == '[' then
			p = p + 1
			return p - startp - 1
		else
			p = startp
			return nil
		end
	end

	-- Skip the rest of a single line comment, including the newline
	local function skipLine()
		while true do
			local c = get()
			if c == '' or c == '\n' then
				return
			end
		end
	end

	-- Add token
	local whiteStart = 1
	local tokenStart = 1
	local function token(type)
		local tk = {
			Type = type;
			LeadingWhite = text:sub(whiteStart, tokenStart-1);
			Source = text:sub(tokenStart, p-1);
		}
		table.insert(tokenBuffer, tk)
		whiteStart = p
		tokenStart = p
		return tk
	end

	-- Parse tokens loop
	while true do
		-- Mark the whitespace start
		whiteStart = p

		-- Get the leading whitespace + comments
		while true do
			local c = look()
			if c == '' then
				break
			elseif c == '-' then
				if look(1) == '-' then
					p = p + 2
					-- Consume comment body
					if look() == '[' then
						p = p + 1
						local eqcount = getopen()
						if eqcount then
							-- Long comment body
							longdata(eqcount)
						else
							-- Normal comment body
							skipLine()
						end
					else
						-- Normal comment body
						skipLine()
					end
				else
					break
				end
			elseif WhiteChars[c] then
				p = p + 1
			else
				break
			end
		end

		-- Mark the token start
		tokenStart = p

		-- Switch on token type
		local c1 = get()
		if c1 == '' then
			-- End of file
			token('Eof')
			break
		elseif c1 == '\'' or c1 == '\"' then
			-- String constant. It is only validated here; the token keeps
			-- the raw source text.
			while true do
				local c2 = get()
				if c2 == '' or c2 == '\n' or c2 == '\r' then
					-- get() keeps returning '' at EOF, so without this check
					-- an unterminated string would loop forever.
					error("Unfinished string.")
				elseif c2 == '\\' then
					local c3 = get()
					if c3 == '' then
						error("Unfinished string.")
					elseif not (CharacterForEscape[c3] or OtherEscapeStartChars[c3] or Digits[c3]) then
						error("Invalid Escape Sequence `"..c3.."`.")
					end
					-- A backslash line continuation may use a two character
					-- \r\n or \n\r line break; consume its second half too.
					if (c3 == '\n' or c3 == '\r') and (look() == '\n' or look() == '\r') and look() ~= c3 then
						p = p + 1
					end
				elseif c2 == c1 then
					break
				end
			end
			token('String')
		elseif AllIdentStartChars[c1] then
			-- Ident or Keyword
			while AllIdentChars[look()] do
				p = p + 1
			end
			if Keywords[text:sub(tokenStart, p-1)] then
				token('Keyword')
			else
				token('Ident')
			end
		elseif Digits[c1] or (c1 == '.' and Digits[look()]) then
			-- Number
			if c1 == '0' and (look() == 'x' or look() == 'X') then
				p = p + 1
				-- Hex number
				while HexDigits[look()] do
					p = p + 1
				end
			else
				-- Normal Number
				while Digits[look()] do
					p = p + 1
				end
				if look() == '.' then
					-- With decimal point
					p = p + 1
					while Digits[look()] do
						p = p + 1
					end
				end
				if look() == 'e' or look() == 'E' then
					-- With exponent, optionally signed (1e-5 / 1e+5)
					p = p + 1
					if look() == '-' or look() == '+' then
						p = p + 1
					end
					while Digits[look()] do
						p = p + 1
					end
				end
			end
			token('Number')
		elseif c1 == '[' then
			-- '[' Symbol or Long String
			local eqCount = getopen()
			if eqCount then
				-- Long string
				longdata(eqCount)
				token('String')
			else
				-- Symbol
				token('Symbol')
			end
		elseif c1 == '.' then
			-- Greedily consume up to 3 `.` for . / .. / ... tokens
			if look() == '.' then
				get()
				if look() == '.' then
					get()
				end
			end
			token('Symbol')
		elseif EqualSymbols[c1] then
			if look() == '=' then
				p = p + 1
			end
			token('Symbol')
		elseif Symbols[c1] then
			token('Symbol')
		else
			error("Bad symbol `"..c1.."` in source.")
		end
	end
	return tokenBuffer
end

--- Parse Lua source text into an AST.
-- The root is a 'StatList' node. Nodes keep the tokens they were built from
-- (fields named Token_*), so each token's Source and LeadingWhite (including
-- comments) stay reachable from the tree. Parse errors are raised with
-- error(); most messages are prefixed with a "line:char" position.
-- Known limitation: `...` in a parameter list and goto / labels are not
-- parsed.
-- @param text Lua source code
-- @return the root 'StatList' node
function CreateLuaParser(text)
	-- Token stream and pointer into it
	local tokens = CreateLuaTokenStream(text)
	-- for _, tok in pairs(tokens) do
	-- 	print(tok.Type..": "..tok.Source)
	-- end
	local p = 1

	-- Consume a token; the pointer never moves past the final Eof token
	local function get()
		local tok = tokens[p]
		if p < #tokens then
			p = p + 1
		end
		return tok
	end

	-- Look ahead n tokens (default 0); out of range yields the Eof token
	local function peek(n)
		n = p + (n or 0)
		return tokens[n] or tokens[#tokens]
	end

	-- "line:char" of the first character of `token` in the source text
	local function getTokenStartPosition(token)
		local line = 1
		local char = 0
		local tkNum = 1
		while true do
			local tk = tokens[tkNum]
			local text;
			if tk == token then
				text = tk.LeadingWhite
			else
				text = tk.LeadingWhite..tk.Source
			end
			for i = 1, #text do
				local c = text:sub(i, i)
				if c == '\n' then
					line = line + 1
					char = 0
				else
					char = char + 1
				end
			end
			if tk == token then
				break
			end
			tkNum = tkNum + 1
		end
		return line..":"..(char+1)
	end

	local function debugMark()
		local tk = peek()
		return "<"..tk.Type.." `"..tk.Source.."`> at: "..getTokenStartPosition(tk)
	end

	local function isBlockFollow()
		local tok = peek()
		return tok.Type == 'Eof' or (tok.Type == 'Keyword' and BlockFollowKeyword[tok.Source])
	end

	local function isUnop()
		return UnopSet[peek().Source] or false
	end

	local function isBinop()
		return BinopSet[peek().Source] or false
	end

	-- {left, right} priority of the upcoming binary operator, or nil.
	-- BinopSet also lists '#', '.' and ':', which have no BinaryPriority
	-- entry, so isBinop() alone is not enough to index BinaryPriority.
	local function peekBinaryPriority()
		if isBinop() then
			return BinaryPriority[peek().Source]
		end
		return nil
	end

	-- Consume a token of the given type (and source, if given), or fail
	local function expect(type, source)
		local tk = peek()
		if tk.Type == type and (source == nil or tk.Source == source) then
			return get()
		else
			for i = -3, 3 do
				print("Tokens["..i.."] = `"..peek(i).Source.."`")
			end
			if source then
				error(getTokenStartPosition(tk)..": `"..source.."` expected.")
			else
				error(getTokenStartPosition(tk)..": "..type.." expected.")
			end
		end
	end

	-- Wrap a node's GetFirstToken / GetLastToken so they assert non-nil
	local function MkNode(node)
		local getf = node.GetFirstToken
		local getl = node.GetLastToken
		function node:GetFirstToken()
			local t = getf(self)
			assert(t)
			return t
		end
		function node:GetLastToken()
			local t = getl(self)
			assert(t)
			return t
		end
		return node
	end

	-- Forward decls
	local block;
	local expr;

	-- Expression list
	local function exprlist()
		local exprList = {}
		local commaList = {}
		table.insert(exprList, expr())
		while peek().Source == ',' do
			table.insert(commaList, get())
			table.insert(exprList, expr())
		end
		return exprList, commaList
	end

	-- Parenthesised expression or a plain variable name
	local function prefixexpr()
		local tk = peek()
		if tk.Source == '(' then
			local oparenTk = get()
			local inner = expr()
			local cparenTk = expect('Symbol', ')')
			return MkNode{
				Type = 'ParenExpr';
				Expression = inner;
				Token_OpenParen = oparenTk;
				Token_CloseParen = cparenTk;
				GetFirstToken = function(self) return self.Token_OpenParen end;
				GetLastToken = function(self) return self.Token_CloseParen end;
			}
		elseif tk.Type == 'Ident' then
			return MkNode{
				Type = 'VariableExpr';
				Token = get();
				GetFirstToken = function(self) return self.Token end;
				GetLastToken = function(self) return self.Token end;
			}
		else
			print(debugMark())
			error(getTokenStartPosition(tk)..": Unexpected symbol")
		end
	end

	-- Table constructor `{ ... }`
	local function tableexpr()
		local obrace = expect('Symbol', '{')
		local entries = {}
		local separators = {}
		while peek().Source ~= '}' do
			if peek().Source == '[' then
				-- Index
				local obrac = get()
				local index = expr()
				local cbrac = expect('Symbol', ']')
				local eq = expect('Symbol', '=')
				local value = expr()
				table.insert(entries, {
					EntryType = 'Index';
					Index = index;
					Value = value;
					Token_OpenBracket = obrac;
					Token_CloseBracket = cbrac;
					Token_Equals = eq;
				})
			elseif peek().Type == 'Ident' and peek(1).Source == '=' then
				-- Field
				local field = get()
				local eq = get()
				local value = expr()
				table.insert(entries, {
					EntryType = 'Field';
					Field = field;
					Value = value;
					Token_Equals = eq;
				})
			else
				-- Value
				local value = expr()
				table.insert(entries, {
					EntryType = 'Value';
					Value = value;
				})
			end

			-- Comma or Semicolon separator
			if peek().Source == ',' or peek().Source == ';' then
				table.insert(separators, get())
			else
				break
			end
		end
		local cbrace = expect('Symbol', '}')
		return MkNode{
			Type = 'TableLiteral';
			EntryList = entries;
			Token_SeparatorList = separators;
			Token_OpenBrace = obrace;
			Token_CloseBrace = cbrace;
			GetFirstToken = function(self) return self.Token_OpenBrace end;
			GetLastToken = function(self) return self.Token_CloseBrace end;
		}
	end

	-- List of identifiers
	local function varlist()
		local varList = {}
		local commaList = {}
		if peek().Type == 'Ident' then
			table.insert(varList, get())
		end
		while peek().Source == ',' do
			table.insert(commaList, get())
			local id = expect('Ident')
			table.insert(varList, id)
		end
		return varList, commaList
	end

	-- Body: a block followed by the keyword `terminator`
	local function blockbody(terminator)
		local body = block()
		local after = peek()
		if after.Type == 'Keyword' and after.Source == terminator then
			get()
			return body, after
		else
			print(after.Type, after.Source)
			error(getTokenStartPosition(after)..": "..terminator.." expected.")
		end
	end

	-- Function declaration. With isAnonymous the argument list follows the
	-- `function` keyword directly ('FunctionLiteral'); otherwise a name
	-- chain such as a.b.c or a.b:c comes first ('FunctionStat').
	local function funcdecl(isAnonymous)
		local functionKw = get()

		local nameChain;
		local nameChainSeparator;

		if not isAnonymous then
			nameChain = {}
			nameChainSeparator = {}

			table.insert(nameChain, expect('Ident'))

			while peek().Source == '.' do
				table.insert(nameChainSeparator, get())
				table.insert(nameChain, expect('Ident'))
			end
			if peek().Source == ':' then
				table.insert(nameChainSeparator, get())
				table.insert(nameChain, expect('Ident'))
			end
		end

		local oparenTk = expect('Symbol', '(')
		local argList, argCommaList = varlist()
		local cparenTk = expect('Symbol', ')')
		local fbody, enTk = blockbody('end')

		return MkNode{
			Type = (isAnonymous and 'FunctionLiteral' or 'FunctionStat');
			NameChain = nameChain;
			ArgList = argList;
			Body = fbody;

			Token_Function = functionKw;
			Token_NameChainSeparator = nameChainSeparator;
			Token_OpenParen = oparenTk;
			Token_ArgCommaList = argCommaList;
			Token_CloseParen = cparenTk;
			Token_End = enTk;
			GetFirstToken = function(self) return self.Token_Function end;
			GetLastToken = function(self) return self.Token_End; end;
		}
	end

	-- Argument list passed to a function: `( args )`, a table literal or a
	-- string literal.
	local function functionargs()
		local tk = peek()
		if tk.Source == '(' then
			local oparenTk = get()
			local argList = {}
			local argCommaList = {}
			while peek().Source ~= ')' do
				table.insert(argList, expr())
				if peek().Source == ',' then
					table.insert(argCommaList, get())
				else
					break
				end
			end
			local cparenTk = expect('Symbol', ')')
			return MkNode{
				CallType = 'ArgCall';
				ArgList = argList;

				Token_CommaList = argCommaList;
				Token_OpenParen = oparenTk;
				Token_CloseParen = cparenTk;
				GetFirstToken = function(self) return self.Token_OpenParen end;
				GetLastToken = function(self) return self.Token_CloseParen end;
			}
		elseif tk.Source == '{' then
			return MkNode{
				CallType = 'TableCall';
				-- Only the table constructor is the argument: `f{1} + 2`
				-- is `(f{1}) + 2`, so a full expression must not be parsed.
				TableExpr = tableexpr();
				GetFirstToken = function(self) return self.TableExpr:GetFirstToken() end;
				GetLastToken = function(self) return self.TableExpr:GetLastToken() end;
			}
		elseif tk.Type == 'String' then
			return MkNode{
				CallType = 'StringCall';
				Token = get();
				GetFirstToken = function(self) return self.Token end;
				GetLastToken = function(self) return self.Token end;
			}
		else
			error("Function arguments expected.")
		end
	end

	-- Prefix expression followed by any number of field / method / index /
	-- call suffixes.
	local function primaryexpr()
		local base = prefixexpr()
		assert(base, "nil prefixexpr")
		while true do
			local tk = peek()
			if tk.Source == '.' then
				local dotTk = get()
				local fieldName = expect('Ident')
				base = MkNode{
					Type = 'FieldExpr';
					Base = base;
					Field = fieldName;
					Token_Dot = dotTk;
					GetFirstToken = function(self) return self.Base:GetFirstToken() end;
					GetLastToken = function(self) return self.Field end;
				}
			elseif tk.Source == ':' then
				local colonTk = get()
				local methodName = expect('Ident')
				local fargs = functionargs()
				base = MkNode{
					Type = 'MethodExpr';
					Base = base;
					Method = methodName;
					FunctionArguments = fargs;
					Token_Colon = colonTk;
					GetFirstToken = function(self) return self.Base:GetFirstToken() end;
					GetLastToken = function(self) return self.FunctionArguments:GetLastToken() end;
				}
			elseif tk.Source == '[' then
				local obrac = get()
				local index = expr()
				local cbrac = expect('Symbol', ']')
				base = MkNode{
					Type = 'IndexExpr';
					Base = base;
					Index = index;
					Token_OpenBracket = obrac;
					Token_CloseBracket = cbrac;
					GetFirstToken = function(self) return self.Base:GetFirstToken() end;
					GetLastToken = function(self) return self.Token_CloseBracket end;
				}
			elseif tk.Source == '{' or tk.Source == '(' or tk.Type == 'String' then
				-- Call with table, parenthesised or string arguments
				-- (e.g. `require "mod"` or `f[[text]]`)
				base = MkNode{
					Type = 'CallExpr';
					Base = base;
					FunctionArguments = functionargs();
					GetFirstToken = function(self) return self.Base:GetFirstToken() end;
					GetLastToken = function(self) return self.FunctionArguments:GetLastToken() end;
				}
			else
				return base
			end
		end
	end

	-- Literal, table constructor, function literal or primary expression
	local function simpleexpr()
		local tk = peek()
		if tk.Type == 'Number' then
			return MkNode{
				Type = 'NumberLiteral';
				Token = get();
				GetFirstToken = function(self) return self.Token end;
				GetLastToken = function(self) return self.Token end;
			}
		elseif tk.Type == 'String' then
			return MkNode{
				Type = 'StringLiteral';
				Token = get();
				GetFirstToken = function(self) return self.Token end;
				GetLastToken = function(self) return self.Token end;
			}
		elseif tk.Source == 'nil' then
			return MkNode{
				Type = 'NilLiteral';
				Token = get();
				GetFirstToken = function(self) return self.Token end;
				GetLastToken = function(self) return self.Token end;
			}
		elseif tk.Source == 'true' or tk.Source == 'false' then
			return MkNode{
				Type = 'BooleanLiteral';
				Token = get();
				GetFirstToken = function(self) return self.Token end;
				GetLastToken = function(self) return self.Token end;
			}
		elseif tk.Source == '...' then
			return MkNode{
				Type = 'VargLiteral';
				Token = get();
				GetFirstToken = function(self) return self.Token end;
				GetLastToken = function(self) return self.Token end;
			}
		elseif tk.Source == '{' then
			return tableexpr()
		elseif tk.Source == 'function' then
			return funcdecl(true)
		else
			return primaryexpr()
		end
	end

	-- Expression whose binary operators all bind tighter than `limit`
	-- (precedence climbing over BinaryPriority / UnaryPriority).
	local function subexpr(limit)
		local curNode;

		-- Initial Base Expression
		if isUnop() then
			local opTk = get()
			local ex = subexpr(UnaryPriority)
			curNode = MkNode{
				Type = 'UnopExpr';
				Token_Op = opTk;
				Rhs = ex;
				GetFirstToken = function(self) return self.Token_Op end;
				GetLastToken = function(self) return self.Rhs:GetLastToken() end;
			}
		else
			curNode = simpleexpr()
			assert(curNode, "nil simpleexpr")
		end

		-- Apply Precedence Recursion Chain
		local priority = peekBinaryPriority()
		while priority and priority[1] > limit do
			local opTk = get()
			local rhs = subexpr(priority[2])
			assert(rhs, "RhsNeeded")
			curNode = MkNode{
				Type = 'BinopExpr';
				Lhs = curNode;
				Rhs = rhs;
				Token_Op = opTk;
				GetFirstToken = function(self) return self.Lhs:GetFirstToken() end;
				GetLastToken = function(self) return self.Rhs:GetLastToken() end;
			}
			priority = peekBinaryPriority()
		end

		-- Return result
		return curNode
	end

	-- Expression
	expr = function()
		return subexpr(0)
	end

	-- Expression statement: a call, or an assignment
	local function exprstat()
		local ex = primaryexpr()
		if ex.Type == 'MethodExpr' or ex.Type == 'CallExpr' then
			-- all good, calls can be statements
			return MkNode{
				Type = 'CallExprStat';
				Expression = ex;
				GetFirstToken = function(self) return self.Expression:GetFirstToken() end;
				GetLastToken = function(self) return self.Expression:GetLastToken() end;
			}
		else
			-- Assignment expr
			local lhs = {ex}
			local lhsSeparator = {}
			while peek().Source == ',' do
				table.insert(lhsSeparator, get())
				local lhsPart = primaryexpr()
				if lhsPart.Type == 'MethodExpr' or lhsPart.Type == 'CallExpr' then
					error("Bad left hand side of assignment")
				end
				table.insert(lhs, lhsPart)
			end
			local eq = expect('Symbol', '=')
			local rhs = {expr()}
			local rhsSeparator = {}
			while peek().Source == ',' do
				table.insert(rhsSeparator, get())
				table.insert(rhs, expr())
			end
			return MkNode{
				Type = 'AssignmentStat';
				Rhs = rhs;
				Lhs = lhs;
				Token_Equals = eq;
				Token_LhsSeparatorList = lhsSeparator;
				Token_RhsSeparatorList = rhsSeparator;
				GetFirstToken = function(self) return self.Lhs[1]:GetFirstToken() end;
				GetLastToken = function(self) return self.Rhs[#self.Rhs]:GetLastToken() end;
			}
		end
	end

	-- If statement
	local function ifstat()
		local ifKw = get()
		local condition = expr()
		local thenKw = expect('Keyword', 'then')
		local ifBody = block()
		local elseClauses = {}
		while peek().Source == 'elseif' or peek().Source == 'else' do
			local elseifKw = get()
			local elseifCondition, elseifThenKw;
			if elseifKw.Source == 'elseif' then
				elseifCondition = expr()
				elseifThenKw = expect('Keyword', 'then')
			end
			local elseifBody = block()
			table.insert(elseClauses, {
				Condition = elseifCondition;
				Body = elseifBody;

				ClauseType = elseifKw.Source;
				Token = elseifKw;
				Token_Then = elseifThenKw;
			})
			if elseifKw.Source == 'else' then
				break
			end
		end
		local enKw = expect('Keyword', 'end')
		return MkNode{
			Type = 'IfStat';
			Condition = condition;
			Body = ifBody;
			ElseClauseList = elseClauses;

			Token_If = ifKw;
			Token_Then = thenKw;
			Token_End = enKw;
			GetFirstToken = function(self) return self.Token_If end;
			GetLastToken = function(self) return self.Token_End end;
		}
	end

	-- Do statement
	local function dostat()
		local doKw = get()
		local body, enKw = blockbody('end')

		return MkNode{
			Type = 'DoStat';
			Body = body;

			Token_Do = doKw;
			Token_End = enKw;
			GetFirstToken = function(self) return self.Token_Do end;
			GetLastToken = function(self) return self.Token_End end;
		}
	end

	-- While statement
	local function whilestat()
		local whileKw = get()
		local condition = expr()
		local doKw = expect('Keyword', 'do')
		local body, enKw = blockbody('end')

		return MkNode{
			Type = 'WhileStat';
			Condition = condition;
			Body = body;

			Token_While = whileKw;
			Token_Do = doKw;
			Token_End = enKw;
			GetFirstToken = function(self) return self.Token_While end;
			GetLastToken = function(self) return self.Token_End end;
		}
	end

	-- For statement: numeric (`for v = a, b[, c] do`) or generic
	-- (`for a, b in exprs do`)
	local function forstat()
		local forKw = get()
		local loopVars, loopVarCommas = varlist()
		if peek().Source == '=' then
			local eqTk = get()
			local exprList, exprCommaList = exprlist()
			if #exprList < 2 or #exprList > 3 then
				error("expected 2 or 3 values for range bounds")
			end
			local doTk = expect('Keyword', 'do')
			local body, enTk = blockbody('end')
			return MkNode{
				Type = 'NumericForStat';
				VarList = loopVars;
				RangeList = exprList;
				Body = body;

				Token_For = forKw;
				Token_VarCommaList = loopVarCommas;
				Token_Equals = eqTk;
				Token_RangeCommaList = exprCommaList;
				Token_Do = doTk;
				Token_End = enTk;
				GetFirstToken = function(self) return self.Token_For end;
				GetLastToken = function(self) return self.Token_End end;
			}
		elseif peek().Source == 'in' then
			local inTk = get()
			local exprList, exprCommaList = exprlist()
			local doTk = expect('Keyword', 'do')
			local body, enTk = blockbody('end')
			return MkNode{
				Type = 'GenericForStat';
				VarList = loopVars;
				GeneratorList = exprList;
				Body = body;

				Token_For = forKw;
				Token_VarCommaList = loopVarCommas;
				Token_In = inTk;
				Token_GeneratorCommaList = exprCommaList;
				Token_Do = doTk;
				Token_End = enTk;
				GetFirstToken = function(self) return self.Token_For end;
				GetLastToken = function(self) return self.Token_End end;
			}
		else
			error("`=` or in expected")
		end
	end

	-- Repeat statement
	local function repeatstat()
		local repeatKw = get()
		local body, untilTk = blockbody('until')
		local condition = expr()
		return MkNode{
			Type = 'RepeatStat';
			Body = body;
			Condition = condition;

			Token_Repeat = repeatKw;
			Token_Until = untilTk;
			GetFirstToken = function(self) return self.Token_Repeat end;
			GetLastToken = function(self) return self.Condition:GetLastToken() end;
		}
	end

	-- Local var declaration: `local function f` or `local a, b [= exprs]`
	local function localdecl()
		local localKw = get()
		if peek().Source == 'function' then
			-- Local function def
			local funcStat = funcdecl(false)
			if #funcStat.NameChain > 1 then
				error(getTokenStartPosition(funcStat.Token_NameChainSeparator[1])..": `(` expected.")
			end
			return MkNode{
				Type = 'LocalFunctionStat';
				FunctionStat = funcStat;
				Token_Local = localKw;
				GetFirstToken = function(self) return self.Token_Local end;
				GetLastToken = function(self) return self.FunctionStat:GetLastToken() end;
			}
		elseif peek().Type == 'Ident' then
			-- Local variable declaration
			local varList, varCommaList = varlist()
			local exprList, exprCommaList = {}, {}
			local eqToken;
			if peek().Source == '=' then
				eqToken = get()
				exprList, exprCommaList = exprlist()
			end
			return MkNode{
				Type = 'LocalVarStat';
				VarList = varList;
				ExprList = exprList;
				Token_Local = localKw;
				Token_Equals = eqToken;
				Token_VarCommaList = varCommaList;
				Token_ExprCommaList = exprCommaList;
				GetFirstToken = function(self) return self.Token_Local end;
				GetLastToken = function(self)
					if #self.ExprList > 0 then
						return self.ExprList[#self.ExprList]:GetLastToken()
					else
						return self.VarList[#self.VarList]
					end
				end;
			}
		else
			error("`function` or ident expected")
		end
	end

	-- Return statement
	local function retstat()
		local returnKw = get()
		local exprList;
		local commaList;
		if isBlockFollow() or peek().Source == ';' then
			exprList = {}
			commaList = {}
		else
			exprList, commaList = exprlist()
		end
		return MkNode{
			Type = 'ReturnStat';
			ExprList = exprList;
			Token_Return = returnKw;
			Token_CommaList = commaList;
			GetFirstToken = function(self) return self.Token_Return end;
			GetLastToken = function(self)
				if #self.ExprList > 0 then
					return self.ExprList[#self.ExprList]:GetLastToken()
				else
					return self.Token_Return
				end
			end;
		}
	end

	-- Break statement
	local function breakstat()
		local breakKw = get()
		return MkNode{
			Type = 'BreakStat';
			Token_Break = breakKw;
			GetFirstToken = function(self) return self.Token_Break end;
			GetLastToken = function(self) return self.Token_Break end;
		}
	end

	-- Statement. Returns (isLast, node); isLast is true for return / break,
	-- which must end their block.
	local function statement()
		local tok = peek()
		if tok.Source == 'if' then
			return false, ifstat()
		elseif tok.Source == 'while' then
			return false, whilestat()
		elseif tok.Source == 'do' then
			return false, dostat()
		elseif tok.Source == 'for' then
			return false, forstat()
		elseif tok.Source == 'repeat' then
			return false, repeatstat()
		elseif tok.Source == 'function' then
			return false, funcdecl(false)
		elseif tok.Source == 'local' then
			return false, localdecl()
		elseif tok.Source == 'return' then
			return true, retstat()
		elseif tok.Source == 'break' then
			return true, breakstat()
		else
			return false, exprstat()
		end
	end

	-- Chunk: statements up to a block-follow keyword, Eof, or a final
	-- return / break. Not wrapped in MkNode because an empty block has no
	-- first / last token.
	block = function()
		local statements = {}
		local semicolons = {}
		local isLast = false
		while not isLast and not isBlockFollow() do
			-- Parse statement
			local stat;
			isLast, stat = statement()
			table.insert(statements, stat)
			local nextTk = peek()
			if nextTk.Type == 'Symbol' and nextTk.Source == ';' then
				semicolons[#statements] = get()
			end
		end
		return {
			Type = 'StatList';
			StatementList = statements;
			SemicolonList = semicolons;
			GetFirstToken = function(self)
				if #self.StatementList == 0 then
					return nil
				else
					return self.StatementList[1]:GetFirstToken()
				end
			end;
			GetLastToken = function(self)
				if #self.StatementList == 0 then
					return nil
				elseif self.SemicolonList[#self.StatementList] then
					-- Last token may be one of the semicolon separators
					return self.SemicolonList[#self.StatementList]
				else
					return self.StatementList[#self.StatementList]:GetLastToken()
				end
			end;
		}
	end

	local ast = block()
	-- block() stops at a block-follow keyword or after return / break. At
	-- the top level anything still left (e.g. code after `return`, or a
	-- stray `end`) is a syntax error; it must not be dropped silently.
	local rest = peek()
	if rest.Type ~= 'Eof' then
		error(getTokenStartPosition(rest)..": <eof> expected.")
	end
	return ast
end

--- Walk an AST produced by CreateLuaParser, calling visitor callbacks.
-- `visitors` maps a node Type name to either a function or a table with
-- optional Pre / Post functions. A plain function (or Pre) runs before the
-- node's children; if it returns a truthy value, the children and Post are
-- skipped. Post runs after the children. Children are visited in source
-- order.
-- @param ast root node (a statement or an expression node)
-- @param visitors table of callbacks keyed by node Type
function VisitAst(ast, visitors)
	local ExprType = lookupify{
		'BinopExpr'; 'UnopExpr';
		'NumberLiteral'; 'StringLiteral'; 'NilLiteral'; 'BooleanLiteral'; 'VargLiteral';
		'FieldExpr'; 'IndexExpr';
		'MethodExpr'; 'CallExpr';
		'FunctionLiteral';
		'VariableExpr';
		'ParenExpr';
		'TableLiteral';
	}

	local StatType = lookupify{
		'StatList';
		'BreakStat';
		'ReturnStat';
		'LocalVarStat';
		'LocalFunctionStat';
		'FunctionStat';
		'RepeatStat';
		'GenericForStat';
		'NumericForStat';
		'WhileStat';
		'DoStat';
		'IfStat';
		'CallExprStat';
		'AssignmentStat';
	}

	-- Check for typos in visitor construction
	for visitorSubject, visitor in pairs(visitors) do
		if not StatType[visitorSubject] and not ExprType[visitorSubject] then
			error("Invalid visitor target: `"..visitorSubject.."`")
		end
	end

	-- Helpers to call visitors on a node
	local function preVisit(exprOrStat)
		local visitor = visitors[exprOrStat.Type]
		if type(visitor) == 'function' then
			return visitor(exprOrStat)
		elseif visitor and visitor.Pre then
			return visitor.Pre(exprOrStat)
		end
	end
	local function postVisit(exprOrStat)
		local visitor = visitors[exprOrStat.Type]
		if visitor and type(visitor) == 'table' and visitor.Post then
			return visitor.Post(exprOrStat)
		end
	end

	-- Child lists are sequences built with table.insert, so ipairs gives a
	-- guaranteed source-order walk (pairs order is unspecified).
	local visitExpr, visitStat;

	visitExpr = function(expr)
		if preVisit(expr) then
			-- Handler did custom child iteration or blocked child iteration
			return
		end
		if expr.Type == 'BinopExpr' then
			visitExpr(expr.Lhs)
			visitExpr(expr.Rhs)
		elseif expr.Type == 'UnopExpr' then
			visitExpr(expr.Rhs)
		elseif expr.Type == 'NumberLiteral' or expr.Type == 'StringLiteral' or
			expr.Type == 'NilLiteral' or expr.Type == 'BooleanLiteral' or
			expr.Type == 'VargLiteral'
		then
			-- No children to visit, single token literals
		elseif expr.Type == 'FieldExpr' then
			visitExpr(expr.Base)
		elseif expr.Type == 'IndexExpr' then
			visitExpr(expr.Base)
			visitExpr(expr.Index)
		elseif expr.Type == 'MethodExpr' or expr.Type == 'CallExpr' then
			visitExpr(expr.Base)
			if expr.FunctionArguments.CallType == 'ArgCall' then
				for _, argExpr in ipairs(expr.FunctionArguments.ArgList) do
					visitExpr(argExpr)
				end
			elseif expr.FunctionArguments.CallType == 'TableCall' then
				visitExpr(expr.FunctionArguments.TableExpr)
			end
		elseif expr.Type == 'FunctionLiteral' then
			visitStat(expr.Body)
		elseif expr.Type == 'VariableExpr' then
			-- No children to visit
		elseif expr.Type == 'ParenExpr' then
			visitExpr(expr.Expression)
		elseif expr.Type == 'TableLiteral' then
			for _, entry in ipairs(expr.EntryList) do
				if entry.EntryType == 'Field' then
					visitExpr(entry.Value)
				elseif entry.EntryType == 'Index' then
					visitExpr(entry.Index)
					visitExpr(entry.Value)
				elseif entry.EntryType == 'Value' then
					visitExpr(entry.Value)
				else
					assert(false, "unreachable")
				end
			end
		else
			assert(false, "unreachable, type: "..expr.Type..":"..FormatTable(expr))
		end
		postVisit(expr)
	end

	visitStat = function(stat)
		if preVisit(stat) then
			-- Handler did custom child iteration or blocked child iteration
			return
		end
		if stat.Type == 'StatList' then
			for _, ch in ipairs(stat.StatementList) do
				visitStat(ch)
			end
		elseif stat.Type == 'BreakStat' then
			-- No children to visit
		elseif stat.Type == 'ReturnStat' then
			for _, expr in ipairs(stat.ExprList) do
				visitExpr(expr)
			end
		elseif stat.Type == 'LocalVarStat' then
			if stat.Token_Equals then
				for _, expr in ipairs(stat.ExprList) do
					visitExpr(expr)
				end
			end
		elseif stat.Type == 'LocalFunctionStat' then
			visitStat(stat.FunctionStat.Body)
		elseif stat.Type == 'FunctionStat' then
			visitStat(stat.Body)
		elseif stat.Type == 'RepeatStat' then
			visitStat(stat.Body)
			visitExpr(stat.Condition)
		elseif stat.Type == 'GenericForStat' then
			for _, expr in ipairs(stat.GeneratorList) do
				visitExpr(expr)
			end
			visitStat(stat.Body)
		elseif stat.Type == 'NumericForStat' then
			for _, expr in ipairs(stat.RangeList) do
				visitExpr(expr)
			end
			visitStat(stat.Body)
		elseif stat.Type == 'WhileStat' then
			visitExpr(stat.Condition)
			visitStat(stat.Body)
		elseif stat.Type == 'DoStat' then
			visitStat(stat.Body)
		elseif stat.Type == 'IfStat' then
			visitExpr(stat.Condition)
			visitStat(stat.Body)
			for _, clause in ipairs(stat.ElseClauseList) do
				if clause.Condition then
					visitExpr(clause.Condition)
				end
				visitStat(clause.Body)
			end
		elseif stat.Type == 'CallExprStat' then
			visitExpr(stat.Expression)
		elseif stat.Type == 'AssignmentStat' then
			for _, ex in ipairs(stat.Lhs) do
				visitExpr(ex)
			end
			for _, ex in ipairs(stat.Rhs) do
				visitExpr(ex)
			end
		else
			assert(false, "unreachable")
		end
		postVisit(stat)
	end

	if StatType[ast.Type] then
		visitStat(ast)
	else
		visitExpr(ast)
	end
end

function AddVariableInfo(ast) local globalVars = {} local currentScope = nil -- Numbering generator for variable lifetimes local locationGenerator = 0 local function markLocation() locationGenerator = locationGenerator + 1 return locationGenerator end -- Scope management local function pushScope() currentScope = { ParentScope = currentScope; ChildScopeList = {}; VariableList = {}; BeginLocation = markLocation(); } if currentScope.ParentScope then currentScope.Depth = currentScope.ParentScope.Depth + 1 table.insert(currentScope.ParentScope.ChildScopeList, currentScope) else currentScope.Depth = 1 end function currentScope:GetVar(varName) for _, var in pairs(self.VariableList) do if var.Name == varName then return var end end if self.ParentScope then return self.ParentScope:GetVar(varName) else for _, var in pairs(globalVars) do if var.Name == varName then return var end end end end end local function popScope() local scope = currentScope -- Mark where this scope ends scope.EndLocation = markLocation() -- Mark all of the variables in the scope as ending there for _, var in pairs(scope.VariableList) do var.ScopeEndLocation = scope.EndLocation end -- Move to the parent scope currentScope = scope.ParentScope return scope end pushScope() -- push initial scope -- Add / reference variables local function addLocalVar(name, setNameFunc, localInfo) 
assert(localInfo, "Misisng localInfo") assert(name, "Missing local var name") local var = { Type = 'Local'; Name = name; RenameList = {setNameFunc}; AssignedTo = false; Info = localInfo; UseCount = 0; Scope = currentScope; BeginLocation = markLocation(); EndLocation = markLocation(); ReferenceLocationList = {markLocation()}; } function var:Rename(newName) self.Name = newName for _, renameFunc in pairs(self.RenameList) do renameFunc(newName) end end function var:Reference() self.UseCount = self.UseCount + 1 end table.insert(currentScope.VariableList, var) return var end local function getGlobalVar(name) for _, var in pairs(globalVars) do if var.Name == name then return var end end local var = { Type = 'Global'; Name = name; RenameList = {}; AssignedTo = false; UseCount = 0; Scope = nil; -- Globals have no scope BeginLocation = markLocation(); EndLocation = markLocation(); ReferenceLocationList = {}; } function var:Rename(newName) self.Name = newName for _, renameFunc in pairs(self.RenameList) do renameFunc(newName) end end function var:Reference() self.UseCount = self.UseCount + 1 end table.insert(globalVars, var) return var end local function addGlobalReference(name, setNameFunc) assert(name, "Missing var name") local var = getGlobalVar(name) table.insert(var.RenameList, setNameFunc) return var end local function getLocalVar(scope, name) -- First search this scope -- Note: Reverse iterate here because Lua does allow shadowing a local -- within the same scope, and the later defined variable should -- be the one referenced. 
for i = #scope.VariableList, 1, -1 do if scope.VariableList[i].Name == name then return scope.VariableList[i] end end -- Then search parent scope if scope.ParentScope then local var = getLocalVar(scope.ParentScope, name) if var then return var end end -- Then return nil end local function referenceVariable(name, setNameFunc) assert(name, "Missing var name") local var = getLocalVar(currentScope, name) if var then table.insert(var.RenameList, setNameFunc) else var = addGlobalReference(name, setNameFunc) end -- Update the end location of where this variable is used, and -- add this location to the list of references to this variable. local curLocation = markLocation() var.EndLocation = curLocation table.insert(var.ReferenceLocationList, var.EndLocation) return var end local visitor = {} visitor.FunctionLiteral = { -- Function literal adds a new scope and adds the function literal arguments -- as local variables in the scope. Pre = function(expr) pushScope() for index, ident in pairs(expr.ArgList) do local var = addLocalVar(ident.Source, function(name) ident.Source = name end, { Type = 'Argument'; Index = index; }) end end; Post = function(expr) popScope() end; } visitor.VariableExpr = function(expr) -- Variable expression references from existing local varibales -- in the current scope, annotating the variable usage with variable -- information. expr.Variable = referenceVariable(expr.Token.Source, function(newName) expr.Token.Source = newName end) end visitor.StatList = { -- StatList adds a new scope Pre = function(stat) pushScope() end; Post = function(stat) popScope() end; } visitor.LocalVarStat = { Post = function(stat) -- Local var stat adds the local variables to the current scope as locals -- We need to visit the subexpressions first, because these new locals -- will not be in scope for the initialization value expressions. 
That is: -- `local bar = bar + 1` -- Is valid code for varNum, ident in pairs(stat.VarList) do addLocalVar(ident.Source, function(name) stat.VarList[varNum].Source = name end, { Type = 'Local'; }) end end; } visitor.LocalFunctionStat = { Pre = function(stat) -- Local function stat adds the function itself to the current scope as -- a local variable, and creates a new scope with the function arguments -- as local variables. addLocalVar(stat.FunctionStat.NameChain[1].Source, function(name) stat.FunctionStat.NameChain[1].Source = name end, { Type = 'LocalFunction'; }) pushScope() for index, ident in pairs(stat.FunctionStat.ArgList) do addLocalVar(ident.Source, function(name) ident.Source = name end, { Type = 'Argument'; Index = index; }) end end; Post = function() popScope() end; } visitor.FunctionStat = { Pre = function(stat) -- Function stat adds a new scope containing the function arguments -- as local variables. -- A function stat may also assign to a global variable if it is in -- the form `function foo()` with no additional dots/colons in the -- name chain. local nameChain = stat.NameChain local var; if #nameChain == 1 then -- If there is only one item in the name chain, then the first item -- is a reference to a global variable. var = addGlobalReference(nameChain[1].Source, function(name) nameChain[1].Source = name end) else var = referenceVariable(nameChain[1].Source, function(name) nameChain[1].Source = name end) end var.AssignedTo = true pushScope() for index, ident in pairs(stat.ArgList) do addLocalVar(ident.Source, function(name) ident.Source = name end, { Type = 'Argument'; Index = index; }) end end; Post = function() popScope() end; } visitor.GenericForStat = { Pre = function(stat) -- Generic fors need an extra scope holding the range variables -- Need a custom visitor so that the generator expressions can be -- visited before we push a scope, but the body can be visited -- after we push a scope. 
for _, ex in pairs(stat.GeneratorList) do VisitAst(ex, visitor) end pushScope() for index, ident in pairs(stat.VarList) do addLocalVar(ident.Source, function(name) ident.Source = name end, { Type = 'ForRange'; Index = index; }) end VisitAst(stat.Body, visitor) popScope() return true -- Custom visit end; } visitor.NumericForStat = { Pre = function(stat) -- Numeric fors need an extra scope holding the range variables -- Need a custom visitor so that the generator expressions can be -- visited before we push a scope, but the body can be visited -- after we push a scope. for _, ex in pairs(stat.RangeList) do VisitAst(ex, visitor) end pushScope() for index, ident in pairs(stat.VarList) do addLocalVar(ident.Source, function(name) ident.Source = name end, { Type = 'ForRange'; Index = index; }) end VisitAst(stat.Body, visitor) popScope() return true -- Custom visit end; } visitor.AssignmentStat = { Post = function(stat) -- For an assignment statement we need to mark the -- "assigned to" flag on variables. 
for _, ex in pairs(stat.Lhs) do if ex.Variable then ex.Variable.AssignedTo = true end end end; } VisitAst(ast, visitor) return globalVars, popScope() end -- Prints out an AST to a string function PrintAst(ast) local printStat, printExpr; local function printt(tk) if not tk.LeadingWhite or not tk.Source then error("Bad token: "..FormatTable(tk)) end io.write(tk.LeadingWhite) io.write(tk.Source) end printExpr = function(expr) if expr.Type == 'BinopExpr' then printExpr(expr.Lhs) printt(expr.Token_Op) printExpr(expr.Rhs) elseif expr.Type == 'UnopExpr' then printt(expr.Token_Op) printExpr(expr.Rhs) elseif expr.Type == 'NumberLiteral' or expr.Type == 'StringLiteral' or expr.Type == 'NilLiteral' or expr.Type == 'BooleanLiteral' or expr.Type == 'VargLiteral' then -- Just print the token printt(expr.Token) elseif expr.Type == 'FieldExpr' then printExpr(expr.Base) printt(expr.Token_Dot) printt(expr.Field) elseif expr.Type == 'IndexExpr' then printExpr(expr.Base) printt(expr.Token_OpenBracket) printExpr(expr.Index) printt(expr.Token_CloseBracket) elseif expr.Type == 'MethodExpr' or expr.Type == 'CallExpr' then printExpr(expr.Base) if expr.Type == 'MethodExpr' then printt(expr.Token_Colon) printt(expr.Method) end if expr.FunctionArguments.CallType == 'StringCall' then printt(expr.FunctionArguments.Token) elseif expr.FunctionArguments.CallType == 'ArgCall' then printt(expr.FunctionArguments.Token_OpenParen) for index, argExpr in pairs(expr.FunctionArguments.ArgList) do printExpr(argExpr) local sep = expr.FunctionArguments.Token_CommaList[index] if sep then printt(sep) end end printt(expr.FunctionArguments.Token_CloseParen) elseif expr.FunctionArguments.CallType == 'TableCall' then printExpr(expr.FunctionArguments.TableExpr) end elseif expr.Type == 'FunctionLiteral' then printt(expr.Token_Function) printt(expr.Token_OpenParen) for index, arg in pairs(expr.ArgList) do printt(arg) local comma = expr.Token_ArgCommaList[index] if comma then printt(comma) end end 
printt(expr.Token_CloseParen) printStat(expr.Body) printt(expr.Token_End) elseif expr.Type == 'VariableExpr' then printt(expr.Token) elseif expr.Type == 'ParenExpr' then printt(expr.Token_OpenParen) printExpr(expr.Expression) printt(expr.Token_CloseParen) elseif expr.Type == 'TableLiteral' then printt(expr.Token_OpenBrace) for index, entry in pairs(expr.EntryList) do if entry.EntryType == 'Field' then printt(entry.Field) printt(entry.Token_Equals) printExpr(entry.Value) elseif entry.EntryType == 'Index' then printt(entry.Token_OpenBracket) printExpr(entry.Index) printt(entry.Token_CloseBracket) printt(entry.Token_Equals) printExpr(entry.Value) elseif entry.EntryType == 'Value' then printExpr(entry.Value) else assert(false, "unreachable") end local sep = expr.Token_SeparatorList[index] if sep then printt(sep) end end printt(expr.Token_CloseBrace) else assert(false, "unreachable, type: "..expr.Type..":"..FormatTable(expr)) end end printStat = function(stat) if stat.Type == 'StatList' then for index, ch in pairs(stat.StatementList) do printStat(ch) if stat.SemicolonList[index] then printt(stat.SemicolonList[index]) end end elseif stat.Type == 'BreakStat' then printt(stat.Token_Break) elseif stat.Type == 'ReturnStat' then printt(stat.Token_Return) for index, expr in pairs(stat.ExprList) do printExpr(expr) if stat.Token_CommaList[index] then printt(stat.Token_CommaList[index]) end end elseif stat.Type == 'LocalVarStat' then printt(stat.Token_Local) for index, var in pairs(stat.VarList) do printt(var) local comma = stat.Token_VarCommaList[index] if comma then printt(comma) end end if stat.Token_Equals then printt(stat.Token_Equals) for index, expr in pairs(stat.ExprList) do printExpr(expr) local comma = stat.Token_ExprCommaList[index] if comma then printt(comma) end end end elseif stat.Type == 'LocalFunctionStat' then printt(stat.Token_Local) printt(stat.FunctionStat.Token_Function) printt(stat.FunctionStat.NameChain[1]) printt(stat.FunctionStat.Token_OpenParen) for 
index, arg in pairs(stat.FunctionStat.ArgList) do printt(arg) local comma = stat.FunctionStat.Token_ArgCommaList[index] if comma then printt(comma) end end printt(stat.FunctionStat.Token_CloseParen) printStat(stat.FunctionStat.Body) printt(stat.FunctionStat.Token_End) elseif stat.Type == 'FunctionStat' then printt(stat.Token_Function) for index, part in pairs(stat.NameChain) do printt(part) local sep = stat.Token_NameChainSeparator[index] if sep then printt(sep) end end printt(stat.Token_OpenParen) for index, arg in pairs(stat.ArgList) do printt(arg) local comma = stat.Token_ArgCommaList[index] if comma then printt(comma) end end printt(stat.Token_CloseParen) printStat(stat.Body) printt(stat.Token_End) elseif stat.Type == 'RepeatStat' then printt(stat.Token_Repeat) printStat(stat.Body) printt(stat.Token_Until) printExpr(stat.Condition) elseif stat.Type == 'GenericForStat' then printt(stat.Token_For) for index, var in pairs(stat.VarList) do printt(var) local sep = stat.Token_VarCommaList[index] if sep then printt(sep) end end printt(stat.Token_In) for index, expr in pairs(stat.GeneratorList) do printExpr(expr) local sep = stat.Token_GeneratorCommaList[index] if sep then printt(sep) end end printt(stat.Token_Do) printStat(stat.Body) printt(stat.Token_End) elseif stat.Type == 'NumericForStat' then printt(stat.Token_For) for index, var in pairs(stat.VarList) do printt(var) local sep = stat.Token_VarCommaList[index] if sep then printt(sep) end end printt(stat.Token_Equals) for index, expr in pairs(stat.RangeList) do printExpr(expr) local sep = stat.Token_RangeCommaList[index] if sep then printt(sep) end end printt(stat.Token_Do) printStat(stat.Body) printt(stat.Token_End) elseif stat.Type == 'WhileStat' then printt(stat.Token_While) printExpr(stat.Condition) printt(stat.Token_Do) printStat(stat.Body) printt(stat.Token_End) elseif stat.Type == 'DoStat' then printt(stat.Token_Do) printStat(stat.Body) printt(stat.Token_End) elseif stat.Type == 'IfStat' then 
printt(stat.Token_If) printExpr(stat.Condition) printt(stat.Token_Then) printStat(stat.Body) for _, clause in pairs(stat.ElseClauseList) do printt(clause.Token) if clause.Condition then printExpr(clause.Condition) printt(clause.Token_Then) end printStat(clause.Body) end printt(stat.Token_End) elseif stat.Type == 'CallExprStat' then printExpr(stat.Expression) elseif stat.Type == 'AssignmentStat' then for index, ex in pairs(stat.Lhs) do printExpr(ex) local sep = stat.Token_LhsSeparatorList[index] if sep then printt(sep) end end printt(stat.Token_Equals) for index, ex in pairs(stat.Rhs) do printExpr(ex) local sep = stat.Token_RhsSeparatorList[index] if sep then printt(sep) end end else assert(false, "unreachable") end end printStat(ast) end -- Adds / removes whitespace in an AST to put it into a "standard formatting" local function FormatAst(ast) local formatStat, formatExpr; local currentIndent = 0 local function applyIndent(token) local indentString = '\n'..('\t'):rep(currentIndent) if token.LeadingWhite == '' or (token.LeadingWhite:sub(-#indentString, -1) ~= indentString) then -- Trim existing trailing whitespace on LeadingWhite -- Trim trailing tabs and spaces, and up to one newline token.LeadingWhite = token.LeadingWhite:gsub("\n?[\t ]*$", "") token.LeadingWhite = token.LeadingWhite..indentString end end local function indent() currentIndent = currentIndent + 1 end local function undent() currentIndent = currentIndent - 1 assert(currentIndent >= 0, "Undented too far") end local function leadingChar(tk) if #tk.LeadingWhite > 0 then return tk.LeadingWhite:sub(1,1) else return tk.Source:sub(1,1) end end local function padToken(tk) if not WhiteChars[leadingChar(tk)] then tk.LeadingWhite = ' '..tk.LeadingWhite end end local function padExpr(expr) padToken(expr:GetFirstToken()) end local function formatBody(openToken, bodyStat, closeToken) indent() formatStat(bodyStat) undent() applyIndent(closeToken) end formatExpr = function(expr) if expr.Type == 'BinopExpr' then 
formatExpr(expr.Lhs) formatExpr(expr.Rhs) if expr.Token_Op.Source == '..' then -- No padding on .. else padExpr(expr.Rhs) padToken(expr.Token_Op) end elseif expr.Type == 'UnopExpr' then formatExpr(expr.Rhs) --(expr.Token_Op) elseif expr.Type == 'NumberLiteral' or expr.Type == 'StringLiteral' or expr.Type == 'NilLiteral' or expr.Type == 'BooleanLiteral' or expr.Type == 'VargLiteral' then -- Nothing to do --(expr.Token) elseif expr.Type == 'FieldExpr' then formatExpr(expr.Base) --(expr.Token_Dot) --(expr.Field) elseif expr.Type == 'IndexExpr' then formatExpr(expr.Base) formatExpr(expr.Index) --(expr.Token_OpenBracket) --(expr.Token_CloseBracket) elseif expr.Type == 'MethodExpr' or expr.Type == 'CallExpr' then formatExpr(expr.Base) if expr.Type == 'MethodExpr' then --(expr.Token_Colon) --(expr.Method) end if expr.FunctionArguments.CallType == 'StringCall' then --(expr.FunctionArguments.Token) elseif expr.FunctionArguments.CallType == 'ArgCall' then --(expr.FunctionArguments.Token_OpenParen) for index, argExpr in pairs(expr.FunctionArguments.ArgList) do formatExpr(argExpr) if index > 1 then padExpr(argExpr) end local sep = expr.FunctionArguments.Token_CommaList[index] if sep then --(sep) end end --(expr.FunctionArguments.Token_CloseParen) elseif expr.FunctionArguments.CallType == 'TableCall' then formatExpr(expr.FunctionArguments.TableExpr) end elseif expr.Type == 'FunctionLiteral' then --(expr.Token_Function) --(expr.Token_OpenParen) for index, arg in pairs(expr.ArgList) do --(arg) if index > 1 then padToken(arg) end local comma = expr.Token_ArgCommaList[index] if comma then --(comma) end end --(expr.Token_CloseParen) formatBody(expr.Token_CloseParen, expr.Body, expr.Token_End) elseif expr.Type == 'VariableExpr' then --(expr.Token) elseif expr.Type == 'ParenExpr' then formatExpr(expr.Expression) --(expr.Token_OpenParen) --(expr.Token_CloseParen) elseif expr.Type == 'TableLiteral' then --(expr.Token_OpenBrace) if #expr.EntryList == 0 then -- Nothing to do else indent() 
for index, entry in pairs(expr.EntryList) do if entry.EntryType == 'Field' then applyIndent(entry.Field) padToken(entry.Token_Equals) formatExpr(entry.Value) padExpr(entry.Value) elseif entry.EntryType == 'Index' then applyIndent(entry.Token_OpenBracket) formatExpr(entry.Index) --(entry.Token_CloseBracket) padToken(entry.Token_Equals) formatExpr(entry.Value) padExpr(entry.Value) elseif entry.EntryType == 'Value' then formatExpr(entry.Value) applyIndent(entry.Value:GetFirstToken()) else assert(false, "unreachable") end local sep = expr.Token_SeparatorList[index] if sep then --(sep) end end undent() applyIndent(expr.Token_CloseBrace) end --(expr.Token_CloseBrace) else assert(false, "unreachable, type: "..expr.Type..":"..FormatTable(expr)) end end formatStat = function(stat) if stat.Type == 'StatList' then for _, stat in pairs(stat.StatementList) do formatStat(stat) applyIndent(stat:GetFirstToken()) end elseif stat.Type == 'BreakStat' then --(stat.Token_Break) elseif stat.Type == 'ReturnStat' then --(stat.Token_Return) for index, expr in pairs(stat.ExprList) do formatExpr(expr) padExpr(expr) if stat.Token_CommaList[index] then --(stat.Token_CommaList[index]) end end elseif stat.Type == 'LocalVarStat' then --(stat.Token_Local) for index, var in pairs(stat.VarList) do padToken(var) local comma = stat.Token_VarCommaList[index] if comma then --(comma) end end if stat.Token_Equals then padToken(stat.Token_Equals) for index, expr in pairs(stat.ExprList) do formatExpr(expr) padExpr(expr) local comma = stat.Token_ExprCommaList[index] if comma then --(comma) end end end elseif stat.Type == 'LocalFunctionStat' then --(stat.Token_Local) padToken(stat.FunctionStat.Token_Function) padToken(stat.FunctionStat.NameChain[1]) --(stat.FunctionStat.Token_OpenParen) for index, arg in pairs(stat.FunctionStat.ArgList) do if index > 1 then padToken(arg) end local comma = stat.FunctionStat.Token_ArgCommaList[index] if comma then --(comma) end end --(stat.FunctionStat.Token_CloseParen) 
formatBody(stat.FunctionStat.Token_CloseParen, stat.FunctionStat.Body, stat.FunctionStat.Token_End) elseif stat.Type == 'FunctionStat' then --(stat.Token_Function) for index, part in pairs(stat.NameChain) do if index == 1 then padToken(part) end local sep = stat.Token_NameChainSeparator[index] if sep then --(sep) end end --(stat.Token_OpenParen) for index, arg in pairs(stat.ArgList) do if index > 1 then padToken(arg) end local comma = stat.Token_ArgCommaList[index] if comma then --(comma) end end --(stat.Token_CloseParen) formatBody(stat.Token_CloseParen, stat.Body, stat.Token_End) elseif stat.Type == 'RepeatStat' then --(stat.Token_Repeat) formatBody(stat.Token_Repeat, stat.Body, stat.Token_Until) formatExpr(stat.Condition) padExpr(stat.Condition) elseif stat.Type == 'GenericForStat' then --(stat.Token_For) for index, var in pairs(stat.VarList) do padToken(var) local sep = stat.Token_VarCommaList[index] if sep then --(sep) end end padToken(stat.Token_In) for index, expr in pairs(stat.GeneratorList) do formatExpr(expr) padExpr(expr) local sep = stat.Token_GeneratorCommaList[index] if sep then --(sep) end end padToken(stat.Token_Do) formatBody(stat.Token_Do, stat.Body, stat.Token_End) elseif stat.Type == 'NumericForStat' then --(stat.Token_For) for index, var in pairs(stat.VarList) do padToken(var) local sep = stat.Token_VarCommaList[index] if sep then --(sep) end end padToken(stat.Token_Equals) for index, expr in pairs(stat.RangeList) do formatExpr(expr) padExpr(expr) local sep = stat.Token_RangeCommaList[index] if sep then --(sep) end end padToken(stat.Token_Do) formatBody(stat.Token_Do, stat.Body, stat.Token_End) elseif stat.Type == 'WhileStat' then --(stat.Token_While) formatExpr(stat.Condition) padExpr(stat.Condition) padToken(stat.Token_Do) formatBody(stat.Token_Do, stat.Body, stat.Token_End) elseif stat.Type == 'DoStat' then --(stat.Token_Do) formatBody(stat.Token_Do, stat.Body, stat.Token_End) elseif stat.Type == 'IfStat' then --(stat.Token_If) 
formatExpr(stat.Condition) padExpr(stat.Condition) padToken(stat.Token_Then) -- local lastBodyOpen = stat.Token_Then local lastBody = stat.Body -- for _, clause in pairs(stat.ElseClauseList) do formatBody(lastBodyOpen, lastBody, clause.Token) lastBodyOpen = clause.Token -- if clause.Condition then formatExpr(clause.Condition) padExpr(clause.Condition) padToken(clause.Token_Then) lastBodyOpen = clause.Token_Then end lastBody = clause.Body end -- formatBody(lastBodyOpen, lastBody, stat.Token_End) elseif stat.Type == 'CallExprStat' then formatExpr(stat.Expression) elseif stat.Type == 'AssignmentStat' then for index, ex in pairs(stat.Lhs) do formatExpr(ex) if index > 1 then padExpr(ex) end local sep = stat.Token_LhsSeparatorList[index] if sep then --(sep) end end padToken(stat.Token_Equals) for index, ex in pairs(stat.Rhs) do formatExpr(ex) padExpr(ex) local sep = stat.Token_RhsSeparatorList[index] if sep then --(sep) end end else assert(false, "unreachable") end end formatStat(ast) end -- Strips as much whitespace off of tokens in an AST as possible without causing problems local function StripAst(ast) local stripStat, stripExpr; local function stript(token) token.LeadingWhite = '' end -- Make to adjacent tokens as close as possible local function joint(tokenA, tokenB) -- Strip the second token's whitespace stript(tokenB) -- Get the trailing A <-> leading B character pair local lastCh = tokenA.Source:sub(-1, -1) local firstCh = tokenB.Source:sub(1, 1) -- Cases to consider: -- Touching minus signs -> comment: `- -42` -> `--42' is invalid -- Touching words: `a b` -> `ab` is invalid -- Touching digits: `2 3`, can't occurr in the Lua syntax as number literals aren't a primary expression -- Abiguous syntax: `f(x)\n(x)()` is already disallowed, we can't cause a problem by removing newlines -- Figure out what separation is needed if (lastCh == '-' and firstCh == '-') or (AllIdentChars[lastCh] and AllIdentChars[firstCh]) then tokenB.LeadingWhite = ' ' -- Use a separator else 
tokenB.LeadingWhite = '' -- Don't use a separator end end -- Join up a statement body and it's opening / closing tokens local function bodyjoint(open, body, close) stripStat(body) stript(close) local bodyFirst = body:GetFirstToken() local bodyLast = body:GetLastToken() if bodyFirst then -- Body is non-empty, join body to open / close joint(open, bodyFirst) joint(bodyLast, close) else -- Body is empty, just join open and close token together joint(open, close) end end stripExpr = function(expr) if expr.Type == 'BinopExpr' then stripExpr(expr.Lhs) stript(expr.Token_Op) stripExpr(expr.Rhs) -- Handle the `a - -b` -/-> `a--b` case which would otherwise incorrectly generate a comment -- Also handles operators "or" / "and" which definitely need joining logic in a bunch of cases joint(expr.Token_Op, expr.Rhs:GetFirstToken()) joint(expr.Lhs:GetLastToken(), expr.Token_Op) elseif expr.Type == 'UnopExpr' then stript(expr.Token_Op) stripExpr(expr.Rhs) -- Handle the `- -b` -/-> `--b` case which would otherwise incorrectly generate a comment joint(expr.Token_Op, expr.Rhs:GetFirstToken()) elseif expr.Type == 'NumberLiteral' or expr.Type == 'StringLiteral' or expr.Type == 'NilLiteral' or expr.Type == 'BooleanLiteral' or expr.Type == 'VargLiteral' then -- Just print the token stript(expr.Token) elseif expr.Type == 'FieldExpr' then stripExpr(expr.Base) stript(expr.Token_Dot) stript(expr.Field) elseif expr.Type == 'IndexExpr' then stripExpr(expr.Base) stript(expr.Token_OpenBracket) stripExpr(expr.Index) stript(expr.Token_CloseBracket) elseif expr.Type == 'MethodExpr' or expr.Type == 'CallExpr' then stripExpr(expr.Base) if expr.Type == 'MethodExpr' then stript(expr.Token_Colon) stript(expr.Method) end if expr.FunctionArguments.CallType == 'StringCall' then stript(expr.FunctionArguments.Token) elseif expr.FunctionArguments.CallType == 'ArgCall' then stript(expr.FunctionArguments.Token_OpenParen) for index, argExpr in pairs(expr.FunctionArguments.ArgList) do stripExpr(argExpr) local sep 
= expr.FunctionArguments.Token_CommaList[index] if sep then stript(sep) end end stript(expr.FunctionArguments.Token_CloseParen) elseif expr.FunctionArguments.CallType == 'TableCall' then stripExpr(expr.FunctionArguments.TableExpr) end elseif expr.Type == 'FunctionLiteral' then stript(expr.Token_Function) stript(expr.Token_OpenParen) for index, arg in pairs(expr.ArgList) do stript(arg) local comma = expr.Token_ArgCommaList[index] if comma then stript(comma) end end stript(expr.Token_CloseParen) bodyjoint(expr.Token_CloseParen, expr.Body, expr.Token_End) elseif expr.Type == 'VariableExpr' then stript(expr.Token) elseif expr.Type == 'ParenExpr' then stript(expr.Token_OpenParen) stripExpr(expr.Expression) stript(expr.Token_CloseParen) elseif expr.Type == 'TableLiteral' then stript(expr.Token_OpenBrace) for index, entry in pairs(expr.EntryList) do if entry.EntryType == 'Field' then stript(entry.Field) stript(entry.Token_Equals) stripExpr(entry.Value) elseif entry.EntryType == 'Index' then stript(entry.Token_OpenBracket) stripExpr(entry.Index) stript(entry.Token_CloseBracket) stript(entry.Token_Equals) stripExpr(entry.Value) elseif entry.EntryType == 'Value' then stripExpr(entry.Value) else assert(false, "unreachable") end local sep = expr.Token_SeparatorList[index] if sep then stript(sep) end end stript(expr.Token_CloseBrace) else assert(false, "unreachable, type: "..expr.Type..":"..FormatTable(expr)) end end stripStat = function(stat) if stat.Type == 'StatList' then -- Strip all surrounding whitespace on statement lists along with separating whitespace for i = 1, #stat.StatementList do local chStat = stat.StatementList[i] -- Strip the statement and it's whitespace stripStat(chStat) stript(chStat:GetFirstToken()) -- If there was a last statement, join them appropriately local lastChStat = stat.StatementList[i-1] if lastChStat then -- See if we can remove a semi-colon, the only case where we can't is if -- this and the last statement have a `);(` pair, where removing the 
semi-colon -- would introduce ambiguous syntax. if stat.SemicolonList[i-1] and (lastChStat:GetLastToken().Source ~= ')' or chStat:GetFirstToken().Source ~= ')') then stat.SemicolonList[i-1] = nil end -- If there isn't a semi-colon, we should safely join the two statements -- (If there is one, then no whitespace leading chStat is always okay) if not stat.SemicolonList[i-1] then joint(lastChStat:GetLastToken(), chStat:GetFirstToken()) end end end -- A semi-colon is never needed on the last stat in a statlist: stat.SemicolonList[#stat.StatementList] = nil -- The leading whitespace on the statlist should be stripped if #stat.StatementList > 0 then stript(stat.StatementList[1]:GetFirstToken()) end elseif stat.Type == 'BreakStat' then stript(stat.Token_Break) elseif stat.Type == 'ReturnStat' then stript(stat.Token_Return) for index, expr in pairs(stat.ExprList) do stripExpr(expr) if stat.Token_CommaList[index] then stript(stat.Token_CommaList[index]) end end if #stat.ExprList > 0 then joint(stat.Token_Return, stat.ExprList[1]:GetFirstToken()) end elseif stat.Type == 'LocalVarStat' then stript(stat.Token_Local) for index, var in pairs(stat.VarList) do if index == 1 then joint(stat.Token_Local, var) else stript(var) end local comma = stat.Token_VarCommaList[index] if comma then stript(comma) end end if stat.Token_Equals then stript(stat.Token_Equals) for index, expr in pairs(stat.ExprList) do stripExpr(expr) local comma = stat.Token_ExprCommaList[index] if comma then stript(comma) end end end elseif stat.Type == 'LocalFunctionStat' then stript(stat.Token_Local) joint(stat.Token_Local, stat.FunctionStat.Token_Function) joint(stat.FunctionStat.Token_Function, stat.FunctionStat.NameChain[1]) joint(stat.FunctionStat.NameChain[1], stat.FunctionStat.Token_OpenParen) for index, arg in pairs(stat.FunctionStat.ArgList) do stript(arg) local comma = stat.FunctionStat.Token_ArgCommaList[index] if comma then stript(comma) end end stript(stat.FunctionStat.Token_CloseParen) 
bodyjoint(stat.FunctionStat.Token_CloseParen, stat.FunctionStat.Body, stat.FunctionStat.Token_End) elseif stat.Type == 'FunctionStat' then stript(stat.Token_Function) for index, part in pairs(stat.NameChain) do if index == 1 then joint(stat.Token_Function, part) else stript(part) end local sep = stat.Token_NameChainSeparator[index] if sep then stript(sep) end end stript(stat.Token_OpenParen) for index, arg in pairs(stat.ArgList) do stript(arg) local comma = stat.Token_ArgCommaList[index] if comma then stript(comma) end end stript(stat.Token_CloseParen) bodyjoint(stat.Token_CloseParen, stat.Body, stat.Token_End) elseif stat.Type == 'RepeatStat' then stript(stat.Token_Repeat) bodyjoint(stat.Token_Repeat, stat.Body, stat.Token_Until) stripExpr(stat.Condition) joint(stat.Token_Until, stat.Condition:GetFirstToken()) elseif stat.Type == 'GenericForStat' then stript(stat.Token_For) for index, var in pairs(stat.VarList) do if index == 1 then joint(stat.Token_For, var) else stript(var) end local sep = stat.Token_VarCommaList[index] if sep then stript(sep) end end joint(stat.VarList[#stat.VarList], stat.Token_In) for index, expr in pairs(stat.GeneratorList) do stripExpr(expr) if index == 1 then joint(stat.Token_In, expr:GetFirstToken()) end local sep = stat.Token_GeneratorCommaList[index] if sep then stript(sep) end end joint(stat.GeneratorList[#stat.GeneratorList]:GetLastToken(), stat.Token_Do) bodyjoint(stat.Token_Do, stat.Body, stat.Token_End) elseif stat.Type == 'NumericForStat' then stript(stat.Token_For) for index, var in pairs(stat.VarList) do if index == 1 then joint(stat.Token_For, var) else stript(var) end local sep = stat.Token_VarCommaList[index] if sep then stript(sep) end end joint(stat.VarList[#stat.VarList], stat.Token_Equals) for index, expr in pairs(stat.RangeList) do stripExpr(expr) if index == 1 then joint(stat.Token_Equals, expr:GetFirstToken()) end local sep = stat.Token_RangeCommaList[index] if sep then stript(sep) end end 
joint(stat.RangeList[#stat.RangeList]:GetLastToken(), stat.Token_Do) bodyjoint(stat.Token_Do, stat.Body, stat.Token_End) elseif stat.Type == 'WhileStat' then stript(stat.Token_While) stripExpr(stat.Condition) stript(stat.Token_Do) joint(stat.Token_While, stat.Condition:GetFirstToken()) joint(stat.Condition:GetLastToken(), stat.Token_Do) bodyjoint(stat.Token_Do, stat.Body, stat.Token_End) elseif stat.Type == 'DoStat' then stript(stat.Token_Do) stript(stat.Token_End) bodyjoint(stat.Token_Do, stat.Body, stat.Token_End) elseif stat.Type == 'IfStat' then stript(stat.Token_If) stripExpr(stat.Condition) joint(stat.Token_If, stat.Condition:GetFirstToken()) joint(stat.Condition:GetLastToken(), stat.Token_Then) -- local lastBodyOpen = stat.Token_Then local lastBody = stat.Body -- for _, clause in pairs(stat.ElseClauseList) do bodyjoint(lastBodyOpen, lastBody, clause.Token) lastBodyOpen = clause.Token -- if clause.Condition then stripExpr(clause.Condition) joint(clause.Token, clause.Condition:GetFirstToken()) joint(clause.Condition:GetLastToken(), clause.Token_Then) lastBodyOpen = clause.Token_Then end stripStat(clause.Body) lastBody = clause.Body end -- bodyjoint(lastBodyOpen, lastBody, stat.Token_End) elseif stat.Type == 'CallExprStat' then stripExpr(stat.Expression) elseif stat.Type == 'AssignmentStat' then for index, ex in pairs(stat.Lhs) do stripExpr(ex) local sep = stat.Token_LhsSeparatorList[index] if sep then stript(sep) end end stript(stat.Token_Equals) for index, ex in pairs(stat.Rhs) do stripExpr(ex) local sep = stat.Token_RhsSeparatorList[index] if sep then stript(sep) end end else assert(false, "unreachable") end end stripStat(ast) end local idGen = 0 local VarDigits = {} for i = ('a'):byte(), ('z'):byte() do table.insert(VarDigits, string.char(i)) end for i = ('A'):byte(), ('Z'):byte() do table.insert(VarDigits, string.char(i)) end for i = ('0'):byte(), ('9'):byte() do table.insert(VarDigits, string.char(i)) end table.insert(VarDigits, '_') local VarStartDigits 
= {} -- (completes `local VarStartDigits` above) letters allowed as the first char of a generated name
for i = ('a'):byte(), ('z'):byte() do table.insert(VarStartDigits, string.char(i)) end
for i = ('A'):byte(), ('Z'):byte() do table.insert(VarStartDigits, string.char(i)) end

-- Map a non-negative integer index to an identifier-shaped name.
-- The lowest "digit" is taken from VarStartDigits (letters only, so the name
-- never starts with a digit); any remaining digits come from VarDigits. Because
-- the loop stops once the remaining value is 0, the last emitted digit is never
-- the zero digit, so distinct indices produce distinct names. Filtering out
-- keywords / reserved names is left to the caller.
local function indexToVarName(index)
	local id = ''
	local d = index % #VarStartDigits
	index = (index - d) / #VarStartDigits
	id = id..VarStartDigits[d+1]
	while index > 0 do
		local d = index % #VarDigits
		index = (index - d) / #VarDigits
		id = id..VarDigits[d+1]
	end
	return id
end

-- Return the name for the current value of the `idGen` counter and advance it.
-- NOTE(review): genNextVarName/genVarName are not called anywhere later in
-- this file; kept for compatibility.
local function genNextVarName()
	local varToUse = idGen
	idGen = idGen + 1
	return indexToVarName(varToUse)
end

-- Like genNextVarName, but skips generated names that are Lua keywords.
local function genVarName()
	local varName = ''
	repeat
		varName = genNextVarName()
	until not Keywords[varName]
	return varName
end

-- Rename script-defined variables to short generated names (used by `minify`).
-- Assigned globals are named first, starting from index 0, so they get the
-- shortest names. Locals are then named scope by scope, starting after the
-- last global index. Every scope continues numbering from where its parent's
-- own variables left off, so sibling scopes reuse the same names while nested
-- scopes never reuse a name visible from an enclosing scope. Generated names
-- that are Lua keywords or equal to an external global's name are skipped.
-- @param globalScope collection of global variable records (from AddVariableInfo)
-- @param rootScope   outermost local scope (has VariableList / ChildScopeList)
local function MinifyVariables(globalScope, rootScope)
	-- externalGlobals is the set of globals that are never assigned to in the
	-- script, i.e. globals defined "externally to the script". They are not
	-- renamed, and no generated name may collide with them.
	local externalGlobals = {}

	-- Give every assigned global a unique temporary name first. (This was meant
	-- to let renames be validated with scope:GetVar; the final naming pass below
	-- overwrites these temporaries.)
	local temporaryIndex = 0
	for _, var in pairs(globalScope) do
		if var.AssignedTo then
			var:Rename('_TMP_'..temporaryIndex..'_')
			temporaryIndex = temporaryIndex + 1
		else
			-- Not assigned to: external global
			externalGlobals[var.Name] = true
		end
	end

	-- Rename globals first; we probably want them to have shorter names in
	-- general.
	-- TODO: Rename all vars based on frequency patterns, giving variables
	-- used more shorter names.
	local nextFreeNameIndex = 0
	for _, var in pairs(globalScope) do
		if var.AssignedTo then
			local varName = ''
			repeat
				varName = indexToVarName(nextFreeNameIndex)
				nextFreeNameIndex = nextFreeNameIndex + 1
			until not Keywords[varName] and not externalGlobals[varName]
			var:Rename(varName)
		end
	end

	-- Now rename all local vars, starting after the last index used by globals.
	rootScope.FirstFreeName = nextFreeNameIndex
	local function doRenameScope(scope)
		for _, var in pairs(scope.VariableList) do
			local varName = ''
			repeat
				varName = indexToVarName(scope.FirstFreeName)
				scope.FirstFreeName = scope.FirstFreeName + 1
			until not Keywords[varName] and not externalGlobals[varName]
			var:Rename(varName)
		end
		for _, childScope in pairs(scope.ChildScopeList) do
			-- Children start after all of this scope's names, so siblings
			-- share names but never reuse an enclosing scope's name.
			childScope.FirstFreeName = scope.FirstFreeName
			doRenameScope(childScope)
		end
	end
	doRenameScope(rootScope)
end

-- Experimental alternative to MinifyVariables that tries to reuse short names
-- across variables whose lifetimes do not conflict.
-- NOTE(review): not called anywhere in this file; `minify` uses
-- MinifyVariables. The *Location fields are presumably token positions
-- recorded by AddVariableInfo (defined elsewhere) -- confirm before relying on
-- this function.
local function MinifyVariables_2(globalScope, rootScope)
	-- Variable names and other names that are fixed, that we cannot use.
	-- Either these are Lua keywords, or globals that are not assigned to,
	-- that is environmental globals that are assigned elsewhere beyond our
	-- control.
	local globalUsedNames = {}
	for kw, _ in pairs(Keywords) do
		globalUsedNames[kw] = true
	end

	-- Gather a list of all of the variables that we will rename
	local allVariables = {}
	local allLocalVariables = {}
	do
		-- Add applicable globals
		for _, var in pairs(globalScope) do
			if var.AssignedTo then
				-- We can try to rename this global since it was assigned to
				-- (and thus presumably initialized) in the script we are
				-- minifying.
				table.insert(allVariables, var)
			else
				-- We can't rename this global, mark it as an unusable name
				-- and don't add it to the rename list
				globalUsedNames[var.Name] = true
			end
		end

		-- Recursively add locals, we can rename all of those
		local function addFrom(scope)
			for _, var in pairs(scope.VariableList) do
				table.insert(allVariables, var)
				table.insert(allLocalVariables, var)
			end
			for _, childScope in pairs(scope.ChildScopeList) do
				addFrom(childScope)
			end
		end
		addFrom(rootScope)
	end

	-- Add used name arrays to variables
	for _, var in pairs(allVariables) do
		var.UsedNameArray = {}
	end

	-- Sort the least used variables first
	table.sort(allVariables, function(a, b)
		return #a.RenameList < #b.RenameList
	end)

	-- Lazy generator for valid names to rename to: slot i maps to the i-th
	-- generated name that is not in globalUsedNames.
	local nextValidNameIndex = 0
	local varNamesLazy = {}
	local function varIndexToValidVarName(i)
		local name = varNamesLazy[i]
		if not name then
			repeat
				name = indexToVarName(nextValidNameIndex)
				nextValidNameIndex = nextValidNameIndex + 1
			until not globalUsedNames[name]
			varNamesLazy[i] = name
		end
		return name
	end

	-- For each variable, go to rename it
	for _, var in pairs(allVariables) do
		-- Lazy... todo: Make this pair a proper for-each-pair-like set of loops
		-- rather than using a renamed flag.
		var.Renamed = true

		-- Find the first unused name
		local i = 1
		while var.UsedNameArray[i] do
			i = i + 1
		end

		-- Rename the variable to that name
		var:Rename(varIndexToValidVarName(i))

		if var.Scope then
			-- Now we need to mark the name as unusable by any variables:
			-- 1) At the same depth that overlap lifetime with this one
			-- 2) At a deeper level, which have a reference to this variable in their lifetimes
			-- 3) At a shallower level, which are referenced during this variable's lifetime
			for _, otherVar in pairs(allVariables) do
				if not otherVar.Renamed then
					if not otherVar.Scope or otherVar.Scope.Depth < var.Scope.Depth then
						-- Check Global variable (which is always at a shallower level)
						-- or
						-- Check case 3
						-- The other var is at a shallower depth, is there a reference to it
						-- during this variable's lifetime?
						for _, refAt in pairs(otherVar.ReferenceLocationList) do
							if refAt >= var.BeginLocation and refAt <= var.ScopeEndLocation then
								-- Collide
								otherVar.UsedNameArray[i] = true
								break
							end
						end
					elseif otherVar.Scope.Depth > var.Scope.Depth then
						-- Check Case 2
						-- The other var is at a greater depth, see if any of the references
						-- to this variable are in the other var's lifetime.
						for _, refAt in pairs(var.ReferenceLocationList) do
							if refAt >= otherVar.BeginLocation and refAt <= otherVar.ScopeEndLocation then
								-- Collide
								otherVar.UsedNameArray[i] = true
								break
							end
						end
					else
						-- otherVar.Scope.Depth must be equal to var.Scope.Depth
						-- Check case 1
						-- The two locals are in the same scope.
						-- Just check if the usage lifetimes overlap within that scope. That is, we
						-- can shadow a local variable within the same scope as long as the usages
						-- of the two locals do not overlap.
						if var.BeginLocation < otherVar.EndLocation and var.EndLocation > otherVar.BeginLocation then
							otherVar.UsedNameArray[i] = true
						end
					end
				end
			end
		else
			-- This is a global var, all other globals can't collide with it, and
			-- any local variable with a reference to this global in its lifetime
			-- can't collide with it.
			for _, otherVar in pairs(allVariables) do
				if not otherVar.Renamed then
					if otherVar.Type == 'Global' then
						otherVar.UsedNameArray[i] = true
					elseif otherVar.Type == 'Local' then
						-- Other var is a local, see if there is a reference to this global within
						-- that local's lifetime.
						for _, refAt in pairs(var.ReferenceLocationList) do
							if refAt >= otherVar.BeginLocation and refAt <= otherVar.ScopeEndLocation then
								-- Collide
								otherVar.UsedNameArray[i] = true
								break
							end
						end
					else
						assert(false, "unreachable")
					end
				end
			end
		end
	end

	--
	--
	--
	-- print("Total Variables: "..#allVariables)
	-- print("Total Range: "..rootScope.BeginLocation.."-"..rootScope.EndLocation)
	-- print("")
	-- for _, var in pairs(allVariables) do
	-- 	io.write("`"..var.Name.."':\n\t#symbols: "..#var.RenameList..
	-- 		"\n\tassigned to: "..tostring(var.AssignedTo))
	-- 	if var.Type == 'Local' then
	-- 		io.write("\n\trange: "..var.BeginLocation.."-"..var.EndLocation)
	-- 		io.write("\n\tlocal type: "..var.Info.Type)
	-- 	end
	-- 	io.write("\n\n")
	-- end
	--
	-- -- First we want to rename all of the variables to unique temporaries, so that we can
	-- -- easily use the scope::GetVar function to check whether renames are valid.
	-- local temporaryIndex = 0
	-- for _, var in pairs(allVariables) do
	-- 	var:Rename('_TMP_'..temporaryIndex..'_')
	-- 	temporaryIndex = temporaryIndex + 1
	-- end
	-- For each variable, we need to build a list of names that collide with it
	--
	--error()
end

-- Rename variables to easily find-and-replaceable names (used by `unminify`).
-- Assigned globals become G_<n>; locals become L_<n>_ plus a suffix describing
-- the declaration kind (arg<index>, func, forvar<index>). Names equal to an
-- external (never-assigned) global are skipped so that references to that
-- global are not captured by a renamed variable.
-- @param globalScope collection of global variable records (from AddVariableInfo)
-- @param rootScope   outermost local scope (has VariableList / ChildScopeList)
local function BeautifyVariables(globalScope, rootScope)
	-- Globals that are never assigned come from outside the script and keep
	-- their names; generated names must avoid them.
	local externalGlobals = {}
	for _, var in pairs(globalScope) do
		if not var.AssignedTo then
			externalGlobals[var.Name] = true
		end
	end

	local localNumber = 1
	local globalNumber = 1

	-- Set the variable's name and push it to every place that spells it.
	local function setVarName(var, name)
		var.Name = name
		for _, setter in pairs(var.RenameList) do
			setter(name)
		end
	end

	for _, var in pairs(globalScope) do
		if var.AssignedTo then
			local name
			repeat
				name = 'G_'..globalNumber
				globalNumber = globalNumber + 1
			until not externalGlobals[name]
			setVarName(var, name)
		end
	end

	local function modify(scope)
		for _, var in pairs(scope.VariableList) do
			local name
			repeat
				name = 'L_'..localNumber..'_'
				if var.Info.Type == 'Argument' then
					name = name..'arg'..var.Info.Index
				elseif var.Info.Type == 'LocalFunction' then
					name = name..'func'
				elseif var.Info.Type == 'ForRange' then
					name = name..'forvar'..var.Info.Index
				end
				localNumber = localNumber + 1
			until not externalGlobals[name]
			setVarName(var, name)
		end
		for _, childScope in pairs(scope.ChildScopeList) do
			modify(childScope)
		end
	end
	modify(rootScope)
end

-- Abort with the command-line usage text (level 0: no position prefix).
local function usageError()
	error(
		"\nusage: minify or unminify \n" ..
		" The modified code will be printed to the stdout, pipe it to a file, the\n" ..
		" lua interpreter, or something else as desired EG:\n\n" ..
		" lua minify.lua minify input.lua > output.lua\n\n" ..
		" * minify will minify the code in the file.\n" ..
		" * unminify will beautify the code and replace the variable names with easily\n" ..
		" find-replacable ones to aide in reverse engineering minified code.\n", 0)
end

-- Script entry point: `lua minify.lua <minify|unminify> <file>`.
local args = {...}
if #args ~= 2 then
	usageError()
end

local sourceFile, openErr = io.open(args[2], 'r')
if not sourceFile then
	-- Include io.open's reason (missing file, permissions, ...).
	error("Could not open the input file `" .. args[2] .. "`: " .. tostring(openErr), 0)
end
local data = sourceFile:read('*all')
-- Release the handle as soon as the contents are in memory.
sourceFile:close()

local ast = CreateLuaParser(data)
local global_scope, root_scope = AddVariableInfo(ast)

local function minify(ast, global_scope, root_scope)
	MinifyVariables(global_scope, root_scope)
	StripAst(ast)
	PrintAst(ast)
end

local function beautify(ast, global_scope, root_scope)
	BeautifyVariables(global_scope, root_scope)
	FormatAst(ast)
	PrintAst(ast)
end

if args[1] == 'minify' then
	minify(ast, global_scope, root_scope)
elseif args[1] == 'unminify' then
	beautify(ast, global_scope, root_scope)
else
	usageError()
end