From 9d0030f9d8017730738f578393b398da9c100ada Mon Sep 17 00:00:00 2001 From: Andrew Lalis Date: Wed, 21 Dec 2022 11:41:34 +0100 Subject: [PATCH] Create minify.lua --- minify.lua | 3227 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 3227 insertions(+) create mode 100644 minify.lua diff --git a/minify.lua b/minify.lua new file mode 100644 index 0000000..41d7a60 --- /dev/null +++ b/minify.lua @@ -0,0 +1,3227 @@ +--[[ +MIT License + +Copyright (c) 2017 Mark Langen + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +]] + +function lookupify(tb) + for _, v in pairs(tb) do + tb[v] = true + end + return tb +end + +function CountTable(tb) + local c = 0 + for _ in pairs(tb) do c = c + 1 end + return c +end + +function FormatTableInt(tb, atIndent, ignoreFunc) + if tb.Print then + return tb.Print() + end + atIndent = atIndent or 0 + local useNewlines = (CountTable(tb) > 1) + local baseIndent = string.rep(' ', atIndent+1) + local out = "{"..(useNewlines and '\n' or '') + for k, v in pairs(tb) do + if type(v) ~= 'function' and not ignoreFunc(k) then + out = out..(useNewlines and baseIndent or '') + if type(k) == 'number' then + --nothing to do + elseif type(k) == 'string' and k:match("^[A-Za-z_][A-Za-z0-9_]*$") then + out = out..k.." = " + elseif type(k) == 'string' then + out = out.."[\""..k.."\"] = " + else + out = out.."["..tostring(k).."] = " + end + if type(v) == 'string' then + out = out.."\""..v.."\"" + elseif type(v) == 'number' then + out = out..v + elseif type(v) == 'table' then + out = out..FormatTableInt(v, atIndent+(useNewlines and 1 or 0), ignoreFunc) + else + out = out..tostring(v) + end + if next(tb, k) then + out = out.."," + end + if useNewlines then + out = out..'\n' + end + end + end + out = out..(useNewlines and string.rep(' ', atIndent) or '').."}" + return out +end + +function FormatTable(tb, ignoreFunc) + ignoreFunc = ignoreFunc or function() + return false + end + return FormatTableInt(tb, 0, ignoreFunc) +end + +local WhiteChars = lookupify{' ', '\n', '\t', '\r'} + +local EscapeForCharacter = {['\r'] = '\\r', ['\n'] = '\\n', ['\t'] = '\\t', ['"'] = '\\"', ["'"] = "\\'", ['\\'] = '\\'} + +local CharacterForEscape = {['r'] = '\r', ['n'] = '\n', ['t'] = '\t', ['"'] = '"', ["'"] = "'", ['\\'] = '\\'} + +local AllIdentStartChars = lookupify{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', + 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', + 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', + 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', + 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '_'} + +local AllIdentChars = lookupify{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', + 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', + 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', + 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', + 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '_', + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9'} + +local Digits = lookupify{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'} + +local HexDigits = lookupify{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', + 'A', 'a', 'B', 'b', 'C', 'c', 'D', 'd', 'E', 'e', 'F', 'f'} + +local Symbols = lookupify{'+', '-', '*', '/', '^', '%', ',', '{', '}', '[', ']', '(', ')', ';', '#', '.', ':'} + +local EqualSymbols = lookupify{'~', '=', '>', '<'} + +local Keywords = lookupify{ + 'and', 'break', 'do', 'else', 'elseif', + 'end', 'false', 'for', 'function', 'goto', 'if', + 'in', 'local', 'nil', 'not', 'or', 'repeat', + 'return', 'then', 'true', 'until', 'while', +}; + +local BlockFollowKeyword = lookupify{'else', 'elseif', 'until', 'end'} + +local UnopSet = lookupify{'-', 'not', '#'} + +local BinopSet = lookupify{ + '+', '-', '*', '/', '%', '^', '#', + '..', '.', ':', + '>', '<', '<=', '>=', '~=', '==', + 'and', 'or' +} + +local GlobalRenameIgnore = lookupify{ + +} + +local BinaryPriority = { + ['+'] = {6, 6}; + ['-'] = {6, 6}; + ['*'] = {7, 7}; + ['/'] = {7, 7}; + ['%'] = {7, 7}; + ['^'] = {10, 9}; + ['..'] = {5, 4}; + ['=='] = {3, 3}; + ['~='] = {3, 3}; + ['>'] = {3, 3}; + ['<'] = {3, 3}; + ['>='] = {3, 3}; + ['<='] = {3, 3}; + ['and'] = {2, 2}; + ['or'] = {1, 1}; +}; +local UnaryPriority = 8 + +-- Eof, Ident, Keyword, Number, String, Symbol + +function CreateLuaTokenStream(text) + -- Tracking for the current position in the buffer, and + -- the current line / character we are on. + local p = 1 + local length = #text + + -- Output buffer for tokens + local tokenBuffer = {} + + -- Get a character, or '' if at eof + local function look(n) + n = p + (n or 0) + if n <= length then + return text:sub(n, n) + else + return '' + end + end + local function get() + if p <= length then + local c = text:sub(p, p) + p = p + 1 + return c + else + return '' + end + end + + -- Error + local olderr = error + local function error(str) + local q = 1 + local line = 1 + local char = 1 + while q <= p do + if text:sub(q, q) == '\n' then + line = line + 1 + char = 1 + else + char = char + 1 + end + q = q + 1 + end + for _, token in pairs(tokenBuffer) do + print(token.Type.."<"..token.Source..">") + end + olderr("file<"..line..":"..char..">: "..str) + end + + -- Consume a long data with equals count of `eqcount' + local function longdata(eqcount) + while true do + local c = get() + if c == '' then + error("Unfinished long string.") + elseif c == ']' then + local done = true -- Until contested + for i = 1, eqcount do + if look() == '=' then + p = p + 1 + else + done = false + break + end + end + if done and get() == ']' then + return + end + end + end + end + + -- Get the opening part for a long data `[` `=`* `[` + -- Precondition: The first `[` has been consumed + -- Return: nil or the equals count + local function getopen() + local startp = p + while look() == '=' do + p = p + 1 + end + if look() == '[' then + p = p + 1 + return p - startp - 1 + else + p = startp + return nil + end + end + + -- Add token + local whiteStart = 1 + local tokenStart = 1 + local function token(type) + local tk = { + Type = type; + LeadingWhite = text:sub(whiteStart, tokenStart-1); + Source = text:sub(tokenStart, p-1); + } + table.insert(tokenBuffer, tk) + whiteStart = p + tokenStart = p + return tk + end + + -- Parse tokens loop + while true do + -- Mark the whitespace start + whiteStart = p + + -- Get the leading whitespace + comments + while true do + local c = look() + if c == '' then + break + elseif c == '-' then + if look(1) == '-' then + p = p + 2 + -- Consume comment body + if look() == '[' then + p = p + 1 + local eqcount = getopen() + if eqcount then + -- Long comment body + longdata(eqcount) + else + -- Normal comment body + while true do + local c2 = get() + if c2 == '' or c2 == '\n' then + break + end + end + end + else + -- Normal comment body + while true do + local c2 = get() + if c2 == '' or c2 == '\n' then + break + end + end + end + else + break + end + elseif WhiteChars[c] then + p = p + 1 + else + break + end + end + local leadingWhite = text:sub(whiteStart, p-1) + + -- Mark the token start + tokenStart = p + + -- Switch on token type + local c1 = get() + if c1 == '' then + -- End of file + token('Eof') + break + elseif c1 == '\'' or c1 == '\"' then + -- String constant + while true do + local c2 = get() + if c2 == '\\' then + local c3 = get() + local esc = CharacterForEscape[c3] + if not esc then + error("Invalid Escape Sequence `"..c3.."`.") + end + elseif c2 == c1 then + break + end + end + token('String') + elseif AllIdentStartChars[c1] then + -- Ident or Keyword + while AllIdentChars[look()] do + p = p + 1 + end + if Keywords[text:sub(tokenStart, p-1)] then + token('Keyword') + else + token('Ident') + end + elseif Digits[c1] or (c1 == '.' and Digits[look()]) then + -- Number + if c1 == '0' and look() == 'x' then + p = p + 1 + -- Hex number + while HexDigits[look()] do + p = p + 1 + end + else + -- Normal Number + while Digits[look()] do + p = p + 1 + end + if look() == '.' then + -- With decimal point + p = p + 1 + while Digits[look()] do + p = p + 1 + end + end + if look() == 'e' or look() == 'E' then + -- With exponent + p = p + 1 + if look() == '-' then + p = p + 1 + end + while Digits[look()] do + p = p + 1 + end + end + end + token('Number') + elseif c1 == '[' then + -- '[' Symbol or Long String + local eqCount = getopen() + if eqCount then + -- Long string + longdata(eqCount) + token('String') + else + -- Symbol + token('Symbol') + end + elseif c1 == '.' then + -- Greedily consume up to 3 `.` for . / .. / ... tokens + if look() == '.' then + get() + if look() == '.' then + get() + end + end + token('Symbol') + elseif EqualSymbols[c1] then + if look() == '=' then + p = p + 1 + end + token('Symbol') + elseif Symbols[c1] then + token('Symbol') + else + error("Bad symbol `"..c1.."` in source.") + end + end + return tokenBuffer +end + +function CreateLuaParser(text) + -- Token stream and pointer into it + local tokens = CreateLuaTokenStream(text) + -- for _, tok in pairs(tokens) do + -- print(tok.Type..": "..tok.Source) + -- end + local p = 1 + + local function get() + local tok = tokens[p] + if p < #tokens then + p = p + 1 + end + return tok + end + local function peek(n) + n = p + (n or 0) + return tokens[n] or tokens[#tokens] + end + + local function getTokenStartPosition(token) + local line = 1 + local char = 0 + local tkNum = 1 + while true do + local tk = tokens[tkNum] + local text; + if tk == token then + text = tk.LeadingWhite + else + text = tk.LeadingWhite..tk.Source + end + for i = 1, #text do + local c = text:sub(i, i) + if c == '\n' then + line = line + 1 + char = 0 + else + char = char + 1 + end + end + if tk == token then + break + end + tkNum = tkNum + 1 + end + return line..":"..(char+1) + end + local function debugMark() + local tk = peek() + return "<"..tk.Type.." `"..tk.Source.."`> at: "..getTokenStartPosition(tk) + end + + local function isBlockFollow() + local tok = peek() + return tok.Type == 'Eof' or (tok.Type == 'Keyword' and BlockFollowKeyword[tok.Source]) + end + local function isUnop() + return UnopSet[peek().Source] or false + end + local function isBinop() + return BinopSet[peek().Source] or false + end + local function expect(type, source) + local tk = peek() + if tk.Type == type and (source == nil or tk.Source == source) then + return get() + else + for i = -3, 3 do + print("Tokens["..i.."] = `"..peek(i).Source.."`") + end + if source then + error(getTokenStartPosition(tk)..": `"..source.."` expected.") + else + error(getTokenStartPosition(tk)..": "..type.." expected.") + end + end + end + + local function MkNode(node) + local getf = node.GetFirstToken + local getl = node.GetLastToken + function node:GetFirstToken() + local t = getf(self) + assert(t) + return t + end + function node:GetLastToken() + local t = getl(self) + assert(t) + return t + end + return node + end + + -- Forward decls + local block; + local expr; + + -- Expression list + local function exprlist() + local exprList = {} + local commaList = {} + table.insert(exprList, expr()) + while peek().Source == ',' do + table.insert(commaList, get()) + table.insert(exprList, expr()) + end + return exprList, commaList + end + + local function prefixexpr() + local tk = peek() + if tk.Source == '(' then + local oparenTk = get() + local inner = expr() + local cparenTk = expect('Symbol', ')') + return MkNode{ + Type = 'ParenExpr'; + Expression = inner; + Token_OpenParen = oparenTk; + Token_CloseParen = cparenTk; + GetFirstToken = function(self) + return self.Token_OpenParen + end; + GetLastToken = function(self) + return self.Token_CloseParen + end; + } + elseif tk.Type == 'Ident' then + return MkNode{ + Type = 'VariableExpr'; + Token = get(); + GetFirstToken = function(self) + return self.Token + end; + GetLastToken = function(self) + return self.Token + end; + } + else + print(debugMark()) + error(getTokenStartPosition(tk)..": Unexpected symbol") + end + end + + function tableexpr() + local obrace = expect('Symbol', '{') + local entries = {} + local separators = {} + while peek().Source ~= '}' do + if peek().Source == '[' then + -- Index + local obrac = get() + local index = expr() + local cbrac = expect('Symbol', ']') + local eq = expect('Symbol', '=') + local value = expr() + table.insert(entries, { + EntryType = 'Index'; + Index = index; + Value = value; + Token_OpenBracket = obrac; + Token_CloseBracket = cbrac; + Token_Equals = eq; + }) + elseif peek().Type == 'Ident' and peek(1).Source == '=' then + -- Field + local field = get() + local eq = get() + local value = expr() + table.insert(entries, { + EntryType = 'Field'; + Field = field; + Value = value; + Token_Equals = eq; + }) + else + -- Value + local value = expr() + table.insert(entries, { + EntryType = 'Value'; + Value = value; + }) + end + + -- Comma or Semicolon separator + if peek().Source == ',' or peek().Source == ';' then + table.insert(separators, get()) + else + break + end + end + local cbrace = expect('Symbol', '}') + return MkNode{ + Type = 'TableLiteral'; + EntryList = entries; + Token_SeparatorList = separators; + Token_OpenBrace = obrace; + Token_CloseBrace = cbrace; + GetFirstToken = function(self) + return self.Token_OpenBrace + end; + GetLastToken = function(self) + return self.Token_CloseBrace + end; + } + end + + -- List of identifiers + local function varlist() + local varList = {} + local commaList = {} + if peek().Type == 'Ident' then + table.insert(varList, get()) + end + while peek().Source == ',' do + table.insert(commaList, get()) + local id = expect('Ident') + table.insert(varList, id) + end + return varList, commaList + end + + -- Body + local function blockbody(terminator) + local body = block() + local after = peek() + if after.Type == 'Keyword' and after.Source == terminator then + get() + return body, after + else + print(after.Type, after.Source) + error(getTokenStartPosition(after)..": "..terminator.." expected.") + end + end + + -- Function declaration + local function funcdecl(isAnonymous) + local functionKw = get() + -- + local nameChain; + local nameChainSeparator; + -- + if not isAnonymous then + nameChain = {} + nameChainSeparator = {} + -- + table.insert(nameChain, expect('Ident')) + -- + while peek().Source == '.' do + table.insert(nameChainSeparator, get()) + table.insert(nameChain, expect('Ident')) + end + if peek().Source == ':' then + table.insert(nameChainSeparator, get()) + table.insert(nameChain, expect('Ident')) + end + end + -- + local oparenTk = expect('Symbol', '(') + local argList, argCommaList = varlist() + local cparenTk = expect('Symbol', ')') + local fbody, enTk = blockbody('end') + -- + return MkNode{ + Type = (isAnonymous and 'FunctionLiteral' or 'FunctionStat'); + NameChain = nameChain; + ArgList = argList; + Body = fbody; + -- + Token_Function = functionKw; + Token_NameChainSeparator = nameChainSeparator; + Token_OpenParen = oparenTk; + Token_ArgCommaList = argCommaList; + Token_CloseParen = cparenTk; + Token_End = enTk; + GetFirstToken = function(self) + return self.Token_Function + end; + GetLastToken = function(self) + return self.Token_End; + end; + } + end + + -- Argument list passed to a funciton + local function functionargs() + local tk = peek() + if tk.Source == '(' then + local oparenTk = get() + local argList = {} + local argCommaList = {} + while peek().Source ~= ')' do + table.insert(argList, expr()) + if peek().Source == ',' then + table.insert(argCommaList, get()) + else + break + end + end + local cparenTk = expect('Symbol', ')') + return MkNode{ + CallType = 'ArgCall'; + ArgList = argList; + -- + Token_CommaList = argCommaList; + Token_OpenParen = oparenTk; + Token_CloseParen = cparenTk; + GetFirstToken = function(self) + return self.Token_OpenParen + end; + GetLastToken = function(self) + return self.Token_CloseParen + end; + } + elseif tk.Source == '{' then + return MkNode{ + CallType = 'TableCall'; + TableExpr = expr(); + GetFirstToken = function(self) + return self.TableExpr:GetFirstToken() + end; + GetLastToken = function(self) + return self.TableExpr:GetLastToken() + end; + } + elseif tk.Type == 'String' then + return MkNode{ + CallType = 'StringCall'; + Token = get(); + GetFirstToken = function(self) + return self.Token + end; + GetLastToken = function(self) + return self.Token + end; + } + else + error("Function arguments expected.") + end + end + + local function primaryexpr() + local base = prefixexpr() + assert(base, "nil prefixexpr") + while true do + local tk = peek() + if tk.Source == '.' then + local dotTk = get() + local fieldName = expect('Ident') + base = MkNode{ + Type = 'FieldExpr'; + Base = base; + Field = fieldName; + Token_Dot = dotTk; + GetFirstToken = function(self) + return self.Base:GetFirstToken() + end; + GetLastToken = function(self) + return self.Field + end; + } + elseif tk.Source == ':' then + local colonTk = get() + local methodName = expect('Ident') + local fargs = functionargs() + base = MkNode{ + Type = 'MethodExpr'; + Base = base; + Method = methodName; + FunctionArguments = fargs; + Token_Colon = colonTk; + GetFirstToken = function(self) + return self.Base:GetFirstToken() + end; + GetLastToken = function(self) + return self.FunctionArguments:GetLastToken() + end; + } + elseif tk.Source == '[' then + local obrac = get() + local index = expr() + local cbrac = expect('Symbol', ']') + base = MkNode{ + Type = 'IndexExpr'; + Base = base; + Index = index; + Token_OpenBracket = obrac; + Token_CloseBracket = cbrac; + GetFirstToken = function(self) + return self.Base:GetFirstToken() + end; + GetLastToken = function(self) + return self.Token_CloseBracket + end; + } + elseif tk.Source == '{' then + base = MkNode{ + Type = 'CallExpr'; + Base = base; + FunctionArguments = functionargs(); + GetFirstToken = function(self) + return self.Base:GetFirstToken() + end; + GetLastToken = function(self) + return self.FunctionArguments:GetLastToken() + end; + } + elseif tk.Source == '(' then + base = MkNode{ + Type = 'CallExpr'; + Base = base; + FunctionArguments = functionargs(); + GetFirstToken = function(self) + return self.Base:GetFirstToken() + end; + GetLastToken = function(self) + return self.FunctionArguments:GetLastToken() + end; + } + else + return base + end + end + end + + local function simpleexpr() + local tk = peek() + if tk.Type == 'Number' then + return MkNode{ + Type = 'NumberLiteral'; + Token = get(); + GetFirstToken = function(self) + return self.Token + end; + GetLastToken = function(self) + return self.Token + end; + } + elseif tk.Type == 'String' then + return MkNode{ + Type = 'StringLiteral'; + Token = get(); + GetFirstToken = function(self) + return self.Token + end; + GetLastToken = function(self) + return self.Token + end; + } + elseif tk.Source == 'nil' then + return MkNode{ + Type = 'NilLiteral'; + Token = get(); + GetFirstToken = function(self) + return self.Token + end; + GetLastToken = function(self) + return self.Token + end; + } + elseif tk.Source == 'true' or tk.Source == 'false' then + return MkNode{ + Type = 'BooleanLiteral'; + Token = get(); + GetFirstToken = function(self) + return self.Token + end; + GetLastToken = function(self) + return self.Token + end; + } + elseif tk.Source == '...' then + return MkNode{ + Type = 'VargLiteral'; + Token = get(); + GetFirstToken = function(self) + return self.Token + end; + GetLastToken = function(self) + return self.Token + end; + } + elseif tk.Source == '{' then + return tableexpr() + elseif tk.Source == 'function' then + return funcdecl(true) + else + return primaryexpr() + end + end + + local function subexpr(limit) + local curNode; + + -- Initial Base Expression + if isUnop() then + local opTk = get() + local ex = subexpr(UnaryPriority) + curNode = MkNode{ + Type = 'UnopExpr'; + Token_Op = opTk; + Rhs = ex; + GetFirstToken = function(self) + return self.Token_Op + end; + GetLastToken = function(self) + return self.Rhs:GetLastToken() + end; + } + else + curNode = simpleexpr() + assert(curNode, "nil simpleexpr") + end + + -- Apply Precedence Recursion Chain + while isBinop() and BinaryPriority[peek().Source][1] > limit do + local opTk = get() + local rhs = subexpr(BinaryPriority[opTk.Source][2]) + assert(rhs, "RhsNeeded") + curNode = MkNode{ + Type = 'BinopExpr'; + Lhs = curNode; + Rhs = rhs; + Token_Op = opTk; + GetFirstToken = function(self) + return self.Lhs:GetFirstToken() + end; + GetLastToken = function(self) + return self.Rhs:GetLastToken() + end; + } + end + + -- Return result + return curNode + end + + -- Expression + expr = function() + return subexpr(0) + end + + -- Expression statement + local function exprstat() + local ex = primaryexpr() + if ex.Type == 'MethodExpr' or ex.Type == 'CallExpr' then + -- all good, calls can be statements + return MkNode{ + Type = 'CallExprStat'; + Expression = ex; + GetFirstToken = function(self) + return self.Expression:GetFirstToken() + end; + GetLastToken = function(self) + return self.Expression:GetLastToken() + end; + } + else + -- Assignment expr + local lhs = {ex} + local lhsSeparator = {} + while peek().Source == ',' do + table.insert(lhsSeparator, get()) + local lhsPart = primaryexpr() + if lhsPart.Type == 'MethodExpr' or lhsPart.Type == 'CallExpr' then + error("Bad left hand side of assignment") + end + table.insert(lhs, lhsPart) + end + local eq = expect('Symbol', '=') + local rhs = {expr()} + local rhsSeparator = {} + while peek().Source == ',' do + table.insert(rhsSeparator, get()) + table.insert(rhs, expr()) + end + return MkNode{ + Type = 'AssignmentStat'; + Rhs = rhs; + Lhs = lhs; + Token_Equals = eq; + Token_LhsSeparatorList = lhsSeparator; + Token_RhsSeparatorList = rhsSeparator; + GetFirstToken = function(self) + return self.Lhs[1]:GetFirstToken() + end; + GetLastToken = function(self) + return self.Rhs[#self.Rhs]:GetLastToken() + end; + } + end + end + + -- If statement + local function ifstat() + local ifKw = get() + local condition = expr() + local thenKw = expect('Keyword', 'then') + local ifBody = block() + local elseClauses = {} + while peek().Source == 'elseif' or peek().Source == 'else' do + local elseifKw = get() + local elseifCondition, elseifThenKw; + if elseifKw.Source == 'elseif' then + elseifCondition = expr() + elseifThenKw = expect('Keyword', 'then') + end + local elseifBody = block() + table.insert(elseClauses, { + Condition = elseifCondition; + Body = elseifBody; + -- + ClauseType = elseifKw.Source; + Token = elseifKw; + Token_Then = elseifThenKw; + }) + if elseifKw.Source == 'else' then + break + end + end + local enKw = expect('Keyword', 'end') + return MkNode{ + Type = 'IfStat'; + Condition = condition; + Body = ifBody; + ElseClauseList = elseClauses; + -- + Token_If = ifKw; + Token_Then = thenKw; + Token_End = enKw; + GetFirstToken = function(self) + return self.Token_If + end; + GetLastToken = function(self) + return self.Token_End + end; + } + end + + -- Do statement + local function dostat() + local doKw = get() + local body, enKw = blockbody('end') + -- + return MkNode{ + Type = 'DoStat'; + Body = body; + -- + Token_Do = doKw; + Token_End = enKw; + GetFirstToken = function(self) + return self.Token_Do + end; + GetLastToken = function(self) + return self.Token_End + end; + } + end + + -- While statement + local function whilestat() + local whileKw = get() + local condition = expr() + local doKw = expect('Keyword', 'do') + local body, enKw = blockbody('end') + -- + return MkNode{ + Type = 'WhileStat'; + Condition = condition; + Body = body; + -- + Token_While = whileKw; + Token_Do = doKw; + Token_End = enKw; + GetFirstToken = function(self) + return self.Token_While + end; + GetLastToken = function(self) + return self.Token_End + end; + } + end + + -- For statement + local function forstat() + local forKw = get() + local loopVars, loopVarCommas = varlist() + local node = {} + if peek().Source == '=' then + local eqTk = get() + local exprList, exprCommaList = exprlist() + if #exprList < 2 or #exprList > 3 then + error("expected 2 or 3 values for range bounds") + end + local doTk = expect('Keyword', 'do') + local body, enTk = blockbody('end') + return MkNode{ + Type = 'NumericForStat'; + VarList = loopVars; + RangeList = exprList; + Body = body; + -- + Token_For = forKw; + Token_VarCommaList = loopVarCommas; + Token_Equals = eqTk; + Token_RangeCommaList = exprCommaList; + Token_Do = doTk; + Token_End = enTk; + GetFirstToken = function(self) + return self.Token_For + end; + GetLastToken = function(self) + return self.Token_End + end; + } + elseif peek().Source == 'in' then + local inTk = get() + local exprList, exprCommaList = exprlist() + local doTk = expect('Keyword', 'do') + local body, enTk = blockbody('end') + return MkNode{ + Type = 'GenericForStat'; + VarList = loopVars; + GeneratorList = exprList; + Body = body; + -- + Token_For = forKw; + Token_VarCommaList = loopVarCommas; + Token_In = inTk; + Token_GeneratorCommaList = exprCommaList; + Token_Do = doTk; + Token_End = enTk; + GetFirstToken = function(self) + return self.Token_For + end; + GetLastToken = function(self) + return self.Token_End + end; + } + else + error("`=` or in expected") + end + end + + -- Repeat statement + local function repeatstat() + local repeatKw = get() + local body, untilTk = blockbody('until') + local condition = expr() + return MkNode{ + Type = 'RepeatStat'; + Body = body; + Condition = condition; + -- + Token_Repeat = repeatKw; + Token_Until = untilTk; + GetFirstToken = function(self) + return self.Token_Repeat + end; + GetLastToken = function(self) + return self.Condition:GetLastToken() + end; + } + end + + -- Local var declaration + local function localdecl() + local localKw = get() + if peek().Source == 'function' then + -- Local function def + local funcStat = funcdecl(false) + if #funcStat.NameChain > 1 then + error(getTokenStartPosition(funcStat.Token_NameChainSeparator[1])..": `(` expected.") + end + return MkNode{ + Type = 'LocalFunctionStat'; + FunctionStat = funcStat; + Token_Local = localKw; + GetFirstToken = function(self) + return self.Token_Local + end; + GetLastToken = function(self) + return self.FunctionStat:GetLastToken() + end; + } + elseif peek().Type == 'Ident' then + -- Local variable declaration + local varList, varCommaList = varlist() + local exprList, exprCommaList = {}, {} + local eqToken; + if peek().Source == '=' then + eqToken = get() + exprList, exprCommaList = exprlist() + end + return MkNode{ + Type = 'LocalVarStat'; + VarList = varList; + ExprList = exprList; + Token_Local = localKw; + Token_Equals = eqToken; + Token_VarCommaList = varCommaList; + Token_ExprCommaList = exprCommaList; + GetFirstToken = function(self) + return self.Token_Local + end; + GetLastToken = function(self) + if #self.ExprList > 0 then + return self.ExprList[#self.ExprList]:GetLastToken() + else + return self.VarList[#self.VarList] + end + end; + } + else + error("`function` or ident expected") + end + end + + -- Return statement + local function retstat() + local returnKw = get() + local exprList; + local commaList; + if isBlockFollow() or peek().Source == ';' then + exprList = {} + commaList = {} + else + exprList, commaList = exprlist() + end + return { + Type = 'ReturnStat'; + ExprList = exprList; + Token_Return = returnKw; + Token_CommaList = commaList; + GetFirstToken = function(self) + return self.Token_Return + end; + GetLastToken = function(self) + if #self.ExprList > 0 then + return self.ExprList[#self.ExprList]:GetLastToken() + else + return self.Token_Return + end + end; + } + end + + -- Break statement + local function breakstat() + local breakKw = get() + return { + Type = 'BreakStat'; + Token_Break = breakKw; + GetFirstToken = function(self) + return self.Token_Break + end; + GetLastToken = function(self) + return self.Token_Break + end; + } + end + + -- Expression + local function statement() + local tok = peek() + if tok.Source == 'if' then + return false, ifstat() + elseif tok.Source == 'while' then + return false, whilestat() + elseif tok.Source == 'do' then + return false, dostat() + elseif tok.Source == 'for' then + return false, forstat() + elseif tok.Source == 'repeat' then + return false, repeatstat() + elseif tok.Source == 'function' then + return false, funcdecl(false) + elseif tok.Source == 'local' then + return false, localdecl() + elseif tok.Source == 'return' then + return true, retstat() + elseif tok.Source == 'break' then + return true, breakstat() + else + return false, exprstat() + end + end + + -- Chunk + block = function() + local statements = {} + local semicolons = {} + local isLast = false + while not isLast and not isBlockFollow() do + -- Parse statement + local stat; + isLast, stat = statement() + table.insert(statements, stat) + local next = peek() + if next.Type == 'Symbol' and next.Source == ';' then + semicolons[#statements] = get() + end + end + return { + Type = 'StatList'; + StatementList = statements; + SemicolonList = semicolons; + GetFirstToken = function(self) + if #self.StatementList == 0 then + return nil + else + return self.StatementList[1]:GetFirstToken() + end + end; + GetLastToken = function(self) + if #self.StatementList == 0 then + return nil + elseif self.SemicolonList[#self.StatementList] then + -- Last token may be one of the semicolon separators + return self.SemicolonList[#self.StatementList] + else + return self.StatementList[#self.StatementList]:GetLastToken() + end + end; + } + end + + return block() +end + +function VisitAst(ast, visitors) + local ExprType = lookupify{ + 'BinopExpr'; 'UnopExpr'; + 'NumberLiteral'; 'StringLiteral'; 'NilLiteral'; 'BooleanLiteral'; 'VargLiteral'; + 'FieldExpr'; 'IndexExpr'; + 'MethodExpr'; 'CallExpr'; + 'FunctionLiteral'; + 'VariableExpr'; + 'ParenExpr'; + 'TableLiteral'; + } + + local StatType = lookupify{ + 'StatList'; + 'BreakStat'; + 'ReturnStat'; + 'LocalVarStat'; + 'LocalFunctionStat'; + 'FunctionStat'; + 'RepeatStat'; + 'GenericForStat'; + 'NumericForStat'; + 'WhileStat'; + 'DoStat'; + 'IfStat'; + 'CallExprStat'; + 'AssignmentStat'; + } + + -- Check for typos in visitor construction + for visitorSubject, visitor in pairs(visitors) do + if not StatType[visitorSubject] and not ExprType[visitorSubject] then + error("Invalid visitor target: `"..visitorSubject.."`") + end + end + + -- Helpers to call visitors on a node + local function preVisit(exprOrStat) + local visitor = visitors[exprOrStat.Type] + if type(visitor) == 'function' then + return visitor(exprOrStat) + elseif visitor and visitor.Pre then + return visitor.Pre(exprOrStat) + end + end + local function postVisit(exprOrStat) + local visitor = visitors[exprOrStat.Type] + if visitor and type(visitor) == 'table' and visitor.Post then + return visitor.Post(exprOrStat) + end + end + + local visitExpr, visitStat; + + visitExpr = function(expr) + if preVisit(expr) then + -- Handler did custom child iteration or blocked child iteration + return + end + if expr.Type == 'BinopExpr' then + visitExpr(expr.Lhs) + visitExpr(expr.Rhs) + elseif expr.Type == 'UnopExpr' then + visitExpr(expr.Rhs) + elseif expr.Type == 'NumberLiteral' or expr.Type == 'StringLiteral' or + expr.Type == 'NilLiteral' or expr.Type == 'BooleanLiteral' or + expr.Type == 'VargLiteral' + then + -- No children to visit, single token literals + elseif expr.Type == 'FieldExpr' then + visitExpr(expr.Base) + elseif expr.Type == 'IndexExpr' then + visitExpr(expr.Base) + visitExpr(expr.Index) + elseif expr.Type == 'MethodExpr' or expr.Type == 'CallExpr' then + visitExpr(expr.Base) + if expr.FunctionArguments.CallType == 'ArgCall' then + for index, argExpr in pairs(expr.FunctionArguments.ArgList) do + visitExpr(argExpr) + end + elseif expr.FunctionArguments.CallType == 'TableCall' then + visitExpr(expr.FunctionArguments.TableExpr) + end + elseif expr.Type == 'FunctionLiteral' then + visitStat(expr.Body) + elseif expr.Type == 'VariableExpr' then + -- No children to visit + elseif expr.Type == 'ParenExpr' then + visitExpr(expr.Expression) + elseif expr.Type == 'TableLiteral' then + for index, entry in pairs(expr.EntryList) do + if entry.EntryType == 'Field' then + visitExpr(entry.Value) + elseif entry.EntryType == 'Index' then + visitExpr(entry.Index) + visitExpr(entry.Value) + elseif entry.EntryType == 'Value' then + visitExpr(entry.Value) + else + assert(false, "unreachable") + end + end + else + assert(false, "unreachable, type: "..expr.Type..":"..FormatTable(expr)) + end + postVisit(expr) + end + + visitStat = function(stat) + if preVisit(stat) then + -- Handler did custom child iteration or blocked child iteration + return + end + if stat.Type == 'StatList' then + for index, ch in pairs(stat.StatementList) do + visitStat(ch) + end + elseif stat.Type == 'BreakStat' then + -- No children to visit + elseif stat.Type == 'ReturnStat' then + for index, expr in pairs(stat.ExprList) do + visitExpr(expr) + end + elseif stat.Type == 'LocalVarStat' then + if stat.Token_Equals then + for index, expr in pairs(stat.ExprList) do + visitExpr(expr) + end + end + elseif stat.Type == 'LocalFunctionStat' then + visitStat(stat.FunctionStat.Body) + elseif stat.Type == 'FunctionStat' then + visitStat(stat.Body) + elseif stat.Type == 'RepeatStat' then + visitStat(stat.Body) + visitExpr(stat.Condition) + elseif stat.Type == 'GenericForStat' then + for index, expr in pairs(stat.GeneratorList) do + visitExpr(expr) + end + visitStat(stat.Body) + elseif stat.Type == 'NumericForStat' then + for index, expr in pairs(stat.RangeList) do + visitExpr(expr) + end + visitStat(stat.Body) + elseif stat.Type == 'WhileStat' then + visitExpr(stat.Condition) + visitStat(stat.Body) + elseif stat.Type == 'DoStat' then + visitStat(stat.Body) + elseif stat.Type == 'IfStat' then + visitExpr(stat.Condition) + visitStat(stat.Body) + for _, clause in pairs(stat.ElseClauseList) do + if clause.Condition then + visitExpr(clause.Condition) + end + visitStat(clause.Body) + end + elseif stat.Type == 'CallExprStat' then + visitExpr(stat.Expression) + elseif stat.Type == 'AssignmentStat' then + for index, ex in pairs(stat.Lhs) do + visitExpr(ex) + end + for index, ex in pairs(stat.Rhs) do + visitExpr(ex) + end + else + assert(false, "unreachable") + end + postVisit(stat) + end + + if StatType[ast.Type] then + visitStat(ast) + else + visitExpr(ast) + end +end + +function AddVariableInfo(ast) + local globalVars = {} + local currentScope = nil + + -- Numbering generator for variable lifetimes + local locationGenerator = 0 + local function markLocation() + locationGenerator = locationGenerator + 1 + return locationGenerator + end + + -- Scope management + local function pushScope() + currentScope = { + ParentScope = currentScope; + ChildScopeList = {}; + VariableList = {}; + BeginLocation = markLocation(); + } + if currentScope.ParentScope then + currentScope.Depth = currentScope.ParentScope.Depth + 1 + table.insert(currentScope.ParentScope.ChildScopeList, currentScope) + else + currentScope.Depth = 1 + end + function currentScope:GetVar(varName) + for _, var in pairs(self.VariableList) do + if var.Name == varName then + return var + end + end + if self.ParentScope then + return self.ParentScope:GetVar(varName) + else + for _, var in pairs(globalVars) do + if var.Name == varName then + return var + end + end + end + end + end + local function popScope() + local scope = currentScope + + -- Mark where this scope ends + scope.EndLocation = markLocation() + + -- Mark all of the variables in the scope as ending there + for _, var in pairs(scope.VariableList) do + var.ScopeEndLocation = scope.EndLocation + end + + -- Move to the parent scope + currentScope = scope.ParentScope + + return scope + end + pushScope() -- push initial scope + + -- Add / reference variables + local function addLocalVar(name, setNameFunc, localInfo) + assert(localInfo, "Misisng localInfo") + assert(name, "Missing local var name") + local var = { + Type = 'Local'; + Name = name; + RenameList = {setNameFunc}; + AssignedTo = false; + Info = localInfo; + UseCount = 0; + Scope = currentScope; + BeginLocation = markLocation(); + EndLocation = markLocation(); + ReferenceLocationList = {markLocation()}; + } + function var:Rename(newName) + self.Name = newName + for _, renameFunc in pairs(self.RenameList) do + renameFunc(newName) + end + end + function var:Reference() + self.UseCount = self.UseCount + 1 + end + table.insert(currentScope.VariableList, var) + return var + end + local function getGlobalVar(name) + for _, var in pairs(globalVars) do + if var.Name == name then + return var + end + end + local var = { + Type = 'Global'; + Name = name; + RenameList = {}; + AssignedTo = false; + UseCount = 0; + Scope = nil; -- Globals have no scope + BeginLocation = markLocation(); + EndLocation = markLocation(); + ReferenceLocationList = {}; + } + function var:Rename(newName) + self.Name = newName + for _, renameFunc in pairs(self.RenameList) do + renameFunc(newName) + end + end + function var:Reference() + self.UseCount = self.UseCount + 1 + end + table.insert(globalVars, var) + return var + end + local function addGlobalReference(name, setNameFunc) + assert(name, "Missing var name") + local var = getGlobalVar(name) + table.insert(var.RenameList, setNameFunc) + return var + end + local function getLocalVar(scope, name) + -- First search this scope + -- Note: Reverse iterate here because Lua does allow shadowing a local + -- within the same scope, and the later defined variable should + -- be the one referenced. + for i = #scope.VariableList, 1, -1 do + if scope.VariableList[i].Name == name then + return scope.VariableList[i] + end + end + + -- Then search parent scope + if scope.ParentScope then + local var = getLocalVar(scope.ParentScope, name) + if var then + return var + end + end + + -- Then + return nil + end + local function referenceVariable(name, setNameFunc) + assert(name, "Missing var name") + local var = getLocalVar(currentScope, name) + if var then + table.insert(var.RenameList, setNameFunc) + else + var = addGlobalReference(name, setNameFunc) + end + -- Update the end location of where this variable is used, and + -- add this location to the list of references to this variable. + local curLocation = markLocation() + var.EndLocation = curLocation + table.insert(var.ReferenceLocationList, var.EndLocation) + return var + end + + local visitor = {} + visitor.FunctionLiteral = { + -- Function literal adds a new scope and adds the function literal arguments + -- as local variables in the scope. + Pre = function(expr) + pushScope() + for index, ident in pairs(expr.ArgList) do + local var = addLocalVar(ident.Source, function(name) + ident.Source = name + end, { + Type = 'Argument'; + Index = index; + }) + end + end; + Post = function(expr) + popScope() + end; + } + visitor.VariableExpr = function(expr) + -- Variable expression references from existing local varibales + -- in the current scope, annotating the variable usage with variable + -- information. + expr.Variable = referenceVariable(expr.Token.Source, function(newName) + expr.Token.Source = newName + end) + end + visitor.StatList = { + -- StatList adds a new scope + Pre = function(stat) + pushScope() + end; + Post = function(stat) + popScope() + end; + } + visitor.LocalVarStat = { + Post = function(stat) + -- Local var stat adds the local variables to the current scope as locals + -- We need to visit the subexpressions first, because these new locals + -- will not be in scope for the initialization value expressions. That is: + -- `local bar = bar + 1` + -- Is valid code + for varNum, ident in pairs(stat.VarList) do + addLocalVar(ident.Source, function(name) + stat.VarList[varNum].Source = name + end, { + Type = 'Local'; + }) + end + end; + } + visitor.LocalFunctionStat = { + Pre = function(stat) + -- Local function stat adds the function itself to the current scope as + -- a local variable, and creates a new scope with the function arguments + -- as local variables. + addLocalVar(stat.FunctionStat.NameChain[1].Source, function(name) + stat.FunctionStat.NameChain[1].Source = name + end, { + Type = 'LocalFunction'; + }) + pushScope() + for index, ident in pairs(stat.FunctionStat.ArgList) do + addLocalVar(ident.Source, function(name) + ident.Source = name + end, { + Type = 'Argument'; + Index = index; + }) + end + end; + Post = function() + popScope() + end; + } + visitor.FunctionStat = { + Pre = function(stat) + -- Function stat adds a new scope containing the function arguments + -- as local variables. + -- A function stat may also assign to a global variable if it is in + -- the form `function foo()` with no additional dots/colons in the + -- name chain. + local nameChain = stat.NameChain + local var; + if #nameChain == 1 then + -- If there is only one item in the name chain, then the first item + -- is a reference to a global variable. + var = addGlobalReference(nameChain[1].Source, function(name) + nameChain[1].Source = name + end) + else + var = referenceVariable(nameChain[1].Source, function(name) + nameChain[1].Source = name + end) + end + var.AssignedTo = true + pushScope() + for index, ident in pairs(stat.ArgList) do + addLocalVar(ident.Source, function(name) + ident.Source = name + end, { + Type = 'Argument'; + Index = index; + }) + end + end; + Post = function() + popScope() + end; + } + visitor.GenericForStat = { + Pre = function(stat) + -- Generic fors need an extra scope holding the range variables + -- Need a custom visitor so that the generator expressions can be + -- visited before we push a scope, but the body can be visited + -- after we push a scope. + for _, ex in pairs(stat.GeneratorList) do + VisitAst(ex, visitor) + end + pushScope() + for index, ident in pairs(stat.VarList) do + addLocalVar(ident.Source, function(name) + ident.Source = name + end, { + Type = 'ForRange'; + Index = index; + }) + end + VisitAst(stat.Body, visitor) + popScope() + return true -- Custom visit + end; + } + visitor.NumericForStat = { + Pre = function(stat) + -- Numeric fors need an extra scope holding the range variables + -- Need a custom visitor so that the generator expressions can be + -- visited before we push a scope, but the body can be visited + -- after we push a scope. + for _, ex in pairs(stat.RangeList) do + VisitAst(ex, visitor) + end + pushScope() + for index, ident in pairs(stat.VarList) do + addLocalVar(ident.Source, function(name) + ident.Source = name + end, { + Type = 'ForRange'; + Index = index; + }) + end + VisitAst(stat.Body, visitor) + popScope() + return true -- Custom visit + end; + } + visitor.AssignmentStat = { + Post = function(stat) + -- For an assignment statement we need to mark the + -- "assigned to" flag on variables. + for _, ex in pairs(stat.Lhs) do + if ex.Variable then + ex.Variable.AssignedTo = true + end + end + end; + } + + VisitAst(ast, visitor) + + return globalVars, popScope() +end + +-- Prints out an AST to a string +function PrintAst(ast) + + local printStat, printExpr; + + local function printt(tk) + if not tk.LeadingWhite or not tk.Source then + error("Bad token: "..FormatTable(tk)) + end + io.write(tk.LeadingWhite) + io.write(tk.Source) + end + + printExpr = function(expr) + if expr.Type == 'BinopExpr' then + printExpr(expr.Lhs) + printt(expr.Token_Op) + printExpr(expr.Rhs) + elseif expr.Type == 'UnopExpr' then + printt(expr.Token_Op) + printExpr(expr.Rhs) + elseif expr.Type == 'NumberLiteral' or expr.Type == 'StringLiteral' or + expr.Type == 'NilLiteral' or expr.Type == 'BooleanLiteral' or + expr.Type == 'VargLiteral' + then + -- Just print the token + printt(expr.Token) + elseif expr.Type == 'FieldExpr' then + printExpr(expr.Base) + printt(expr.Token_Dot) + printt(expr.Field) + elseif expr.Type == 'IndexExpr' then + printExpr(expr.Base) + printt(expr.Token_OpenBracket) + printExpr(expr.Index) + printt(expr.Token_CloseBracket) + elseif expr.Type == 'MethodExpr' or expr.Type == 'CallExpr' then + printExpr(expr.Base) + if expr.Type == 'MethodExpr' then + printt(expr.Token_Colon) + printt(expr.Method) + end + if expr.FunctionArguments.CallType == 'StringCall' then + printt(expr.FunctionArguments.Token) + elseif expr.FunctionArguments.CallType == 'ArgCall' then + printt(expr.FunctionArguments.Token_OpenParen) + for index, argExpr in pairs(expr.FunctionArguments.ArgList) do + printExpr(argExpr) + local sep = expr.FunctionArguments.Token_CommaList[index] + if sep then + printt(sep) + end + end + printt(expr.FunctionArguments.Token_CloseParen) + elseif expr.FunctionArguments.CallType == 'TableCall' then + printExpr(expr.FunctionArguments.TableExpr) + end + elseif expr.Type == 'FunctionLiteral' then + printt(expr.Token_Function) + printt(expr.Token_OpenParen) + for index, arg in pairs(expr.ArgList) do + printt(arg) + local comma = expr.Token_ArgCommaList[index] + if comma then + printt(comma) + end + end + printt(expr.Token_CloseParen) + printStat(expr.Body) + printt(expr.Token_End) + elseif expr.Type == 'VariableExpr' then + printt(expr.Token) + elseif expr.Type == 'ParenExpr' then + printt(expr.Token_OpenParen) + printExpr(expr.Expression) + printt(expr.Token_CloseParen) + elseif expr.Type == 'TableLiteral' then + printt(expr.Token_OpenBrace) + for index, entry in pairs(expr.EntryList) do + if entry.EntryType == 'Field' then + printt(entry.Field) + printt(entry.Token_Equals) + printExpr(entry.Value) + elseif entry.EntryType == 'Index' then + printt(entry.Token_OpenBracket) + printExpr(entry.Index) + printt(entry.Token_CloseBracket) + printt(entry.Token_Equals) + printExpr(entry.Value) + elseif entry.EntryType == 'Value' then + printExpr(entry.Value) + else + assert(false, "unreachable") + end + local sep = expr.Token_SeparatorList[index] + if sep then + printt(sep) + end + end + printt(expr.Token_CloseBrace) + else + assert(false, "unreachable, type: "..expr.Type..":"..FormatTable(expr)) + end + end + + printStat = function(stat) + if stat.Type == 'StatList' then + for index, ch in pairs(stat.StatementList) do + printStat(ch) + if stat.SemicolonList[index] then + printt(stat.SemicolonList[index]) + end + end + elseif stat.Type == 'BreakStat' then + printt(stat.Token_Break) + elseif stat.Type == 'ReturnStat' then + printt(stat.Token_Return) + for index, expr in pairs(stat.ExprList) do + printExpr(expr) + if stat.Token_CommaList[index] then + printt(stat.Token_CommaList[index]) + end + end + elseif stat.Type == 'LocalVarStat' then + printt(stat.Token_Local) + for index, var in pairs(stat.VarList) do + printt(var) + local comma = stat.Token_VarCommaList[index] + if comma then + printt(comma) + end + end + if stat.Token_Equals then + printt(stat.Token_Equals) + for index, expr in pairs(stat.ExprList) do + printExpr(expr) + local comma = stat.Token_ExprCommaList[index] + if comma then + printt(comma) + end + end + end + elseif stat.Type == 'LocalFunctionStat' then + printt(stat.Token_Local) + printt(stat.FunctionStat.Token_Function) + printt(stat.FunctionStat.NameChain[1]) + printt(stat.FunctionStat.Token_OpenParen) + for index, arg in pairs(stat.FunctionStat.ArgList) do + printt(arg) + local comma = stat.FunctionStat.Token_ArgCommaList[index] + if comma then + printt(comma) + end + end + printt(stat.FunctionStat.Token_CloseParen) + printStat(stat.FunctionStat.Body) + printt(stat.FunctionStat.Token_End) + elseif stat.Type == 'FunctionStat' then + printt(stat.Token_Function) + for index, part in pairs(stat.NameChain) do + printt(part) + local sep = stat.Token_NameChainSeparator[index] + if sep then + printt(sep) + end + end + printt(stat.Token_OpenParen) + for index, arg in pairs(stat.ArgList) do + printt(arg) + local comma = stat.Token_ArgCommaList[index] + if comma then + printt(comma) + end + end + printt(stat.Token_CloseParen) + printStat(stat.Body) + printt(stat.Token_End) + elseif stat.Type == 'RepeatStat' then + printt(stat.Token_Repeat) + printStat(stat.Body) + printt(stat.Token_Until) + printExpr(stat.Condition) + elseif stat.Type == 'GenericForStat' then + printt(stat.Token_For) + for index, var in pairs(stat.VarList) do + printt(var) + local sep = stat.Token_VarCommaList[index] + if sep then + printt(sep) + end + end + printt(stat.Token_In) + for index, expr in pairs(stat.GeneratorList) do + printExpr(expr) + local sep = stat.Token_GeneratorCommaList[index] + if sep then + printt(sep) + end + end + printt(stat.Token_Do) + printStat(stat.Body) + printt(stat.Token_End) + elseif stat.Type == 'NumericForStat' then + printt(stat.Token_For) + for index, var in pairs(stat.VarList) do + printt(var) + local sep = stat.Token_VarCommaList[index] + if sep then + printt(sep) + end + end + printt(stat.Token_Equals) + for index, expr in pairs(stat.RangeList) do + printExpr(expr) + local sep = stat.Token_RangeCommaList[index] + if sep then + printt(sep) + end + end + printt(stat.Token_Do) + printStat(stat.Body) + printt(stat.Token_End) + elseif stat.Type == 'WhileStat' then + printt(stat.Token_While) + printExpr(stat.Condition) + printt(stat.Token_Do) + printStat(stat.Body) + printt(stat.Token_End) + elseif stat.Type == 'DoStat' then + printt(stat.Token_Do) + printStat(stat.Body) + printt(stat.Token_End) + elseif stat.Type == 'IfStat' then + printt(stat.Token_If) + printExpr(stat.Condition) + printt(stat.Token_Then) + printStat(stat.Body) + for _, clause in pairs(stat.ElseClauseList) do + printt(clause.Token) + if clause.Condition then + printExpr(clause.Condition) + printt(clause.Token_Then) + end + printStat(clause.Body) + end + printt(stat.Token_End) + elseif stat.Type == 'CallExprStat' then + printExpr(stat.Expression) + elseif stat.Type == 'AssignmentStat' then + for index, ex in pairs(stat.Lhs) do + printExpr(ex) + local sep = stat.Token_LhsSeparatorList[index] + if sep then + printt(sep) + end + end + printt(stat.Token_Equals) + for index, ex in pairs(stat.Rhs) do + printExpr(ex) + local sep = stat.Token_RhsSeparatorList[index] + if sep then + printt(sep) + end + end + else + assert(false, "unreachable") + end + end + + printStat(ast) +end + +-- Adds / removes whitespace in an AST to put it into a "standard formatting" +local function FormatAst(ast) + local formatStat, formatExpr; + + local currentIndent = 0 + + local function applyIndent(token) + local indentString = '\n'..('\t'):rep(currentIndent) + if token.LeadingWhite == '' or (token.LeadingWhite:sub(-#indentString, -1) ~= indentString) then + -- Trim existing trailing whitespace on LeadingWhite + -- Trim trailing tabs and spaces, and up to one newline + token.LeadingWhite = token.LeadingWhite:gsub("\n?[\t ]*$", "") + token.LeadingWhite = token.LeadingWhite..indentString + end + end + + local function indent() + currentIndent = currentIndent + 1 + end + + local function undent() + currentIndent = currentIndent - 1 + assert(currentIndent >= 0, "Undented too far") + end + + local function leadingChar(tk) + if #tk.LeadingWhite > 0 then + return tk.LeadingWhite:sub(1,1) + else + return tk.Source:sub(1,1) + end + end + + local function padToken(tk) + if not WhiteChars[leadingChar(tk)] then + tk.LeadingWhite = ' '..tk.LeadingWhite + end + end + + local function padExpr(expr) + padToken(expr:GetFirstToken()) + end + + local function formatBody(openToken, bodyStat, closeToken) + indent() + formatStat(bodyStat) + undent() + applyIndent(closeToken) + end + + formatExpr = function(expr) + if expr.Type == 'BinopExpr' then + formatExpr(expr.Lhs) + formatExpr(expr.Rhs) + if expr.Token_Op.Source == '..' then + -- No padding on .. + else + padExpr(expr.Rhs) + padToken(expr.Token_Op) + end + elseif expr.Type == 'UnopExpr' then + formatExpr(expr.Rhs) + --(expr.Token_Op) + elseif expr.Type == 'NumberLiteral' or expr.Type == 'StringLiteral' or + expr.Type == 'NilLiteral' or expr.Type == 'BooleanLiteral' or + expr.Type == 'VargLiteral' + then + -- Nothing to do + --(expr.Token) + elseif expr.Type == 'FieldExpr' then + formatExpr(expr.Base) + --(expr.Token_Dot) + --(expr.Field) + elseif expr.Type == 'IndexExpr' then + formatExpr(expr.Base) + formatExpr(expr.Index) + --(expr.Token_OpenBracket) + --(expr.Token_CloseBracket) + elseif expr.Type == 'MethodExpr' or expr.Type == 'CallExpr' then + formatExpr(expr.Base) + if expr.Type == 'MethodExpr' then + --(expr.Token_Colon) + --(expr.Method) + end + if expr.FunctionArguments.CallType == 'StringCall' then + --(expr.FunctionArguments.Token) + elseif expr.FunctionArguments.CallType == 'ArgCall' then + --(expr.FunctionArguments.Token_OpenParen) + for index, argExpr in pairs(expr.FunctionArguments.ArgList) do + formatExpr(argExpr) + if index > 1 then + padExpr(argExpr) + end + local sep = expr.FunctionArguments.Token_CommaList[index] + if sep then + --(sep) + end + end + --(expr.FunctionArguments.Token_CloseParen) + elseif expr.FunctionArguments.CallType == 'TableCall' then + formatExpr(expr.FunctionArguments.TableExpr) + end + elseif expr.Type == 'FunctionLiteral' then + --(expr.Token_Function) + --(expr.Token_OpenParen) + for index, arg in pairs(expr.ArgList) do + --(arg) + if index > 1 then + padToken(arg) + end + local comma = expr.Token_ArgCommaList[index] + if comma then + --(comma) + end + end + --(expr.Token_CloseParen) + formatBody(expr.Token_CloseParen, expr.Body, expr.Token_End) + elseif expr.Type == 'VariableExpr' then + --(expr.Token) + elseif expr.Type == 'ParenExpr' then + formatExpr(expr.Expression) + --(expr.Token_OpenParen) + --(expr.Token_CloseParen) + elseif expr.Type == 'TableLiteral' then + --(expr.Token_OpenBrace) + if #expr.EntryList == 0 then + -- Nothing to do + else + indent() + for index, entry in pairs(expr.EntryList) do + if entry.EntryType == 'Field' then + applyIndent(entry.Field) + padToken(entry.Token_Equals) + formatExpr(entry.Value) + padExpr(entry.Value) + elseif entry.EntryType == 'Index' then + applyIndent(entry.Token_OpenBracket) + formatExpr(entry.Index) + --(entry.Token_CloseBracket) + padToken(entry.Token_Equals) + formatExpr(entry.Value) + padExpr(entry.Value) + elseif entry.EntryType == 'Value' then + formatExpr(entry.Value) + applyIndent(entry.Value:GetFirstToken()) + else + assert(false, "unreachable") + end + local sep = expr.Token_SeparatorList[index] + if sep then + --(sep) + end + end + undent() + applyIndent(expr.Token_CloseBrace) + end + --(expr.Token_CloseBrace) + else + assert(false, "unreachable, type: "..expr.Type..":"..FormatTable(expr)) + end + end + + formatStat = function(stat) + if stat.Type == 'StatList' then + for _, stat in pairs(stat.StatementList) do + formatStat(stat) + applyIndent(stat:GetFirstToken()) + end + + elseif stat.Type == 'BreakStat' then + --(stat.Token_Break) + + elseif stat.Type == 'ReturnStat' then + --(stat.Token_Return) + for index, expr in pairs(stat.ExprList) do + formatExpr(expr) + padExpr(expr) + if stat.Token_CommaList[index] then + --(stat.Token_CommaList[index]) + end + end + elseif stat.Type == 'LocalVarStat' then + --(stat.Token_Local) + for index, var in pairs(stat.VarList) do + padToken(var) + local comma = stat.Token_VarCommaList[index] + if comma then + --(comma) + end + end + if stat.Token_Equals then + padToken(stat.Token_Equals) + for index, expr in pairs(stat.ExprList) do + formatExpr(expr) + padExpr(expr) + local comma = stat.Token_ExprCommaList[index] + if comma then + --(comma) + end + end + end + elseif stat.Type == 'LocalFunctionStat' then + --(stat.Token_Local) + padToken(stat.FunctionStat.Token_Function) + padToken(stat.FunctionStat.NameChain[1]) + --(stat.FunctionStat.Token_OpenParen) + for index, arg in pairs(stat.FunctionStat.ArgList) do + if index > 1 then + padToken(arg) + end + local comma = stat.FunctionStat.Token_ArgCommaList[index] + if comma then + --(comma) + end + end + --(stat.FunctionStat.Token_CloseParen) + formatBody(stat.FunctionStat.Token_CloseParen, stat.FunctionStat.Body, stat.FunctionStat.Token_End) + elseif stat.Type == 'FunctionStat' then + --(stat.Token_Function) + for index, part in pairs(stat.NameChain) do + if index == 1 then + padToken(part) + end + local sep = stat.Token_NameChainSeparator[index] + if sep then + --(sep) + end + end + --(stat.Token_OpenParen) + for index, arg in pairs(stat.ArgList) do + if index > 1 then + padToken(arg) + end + local comma = stat.Token_ArgCommaList[index] + if comma then + --(comma) + end + end + --(stat.Token_CloseParen) + formatBody(stat.Token_CloseParen, stat.Body, stat.Token_End) + elseif stat.Type == 'RepeatStat' then + --(stat.Token_Repeat) + formatBody(stat.Token_Repeat, stat.Body, stat.Token_Until) + formatExpr(stat.Condition) + padExpr(stat.Condition) + elseif stat.Type == 'GenericForStat' then + --(stat.Token_For) + for index, var in pairs(stat.VarList) do + padToken(var) + local sep = stat.Token_VarCommaList[index] + if sep then + --(sep) + end + end + padToken(stat.Token_In) + for index, expr in pairs(stat.GeneratorList) do + formatExpr(expr) + padExpr(expr) + local sep = stat.Token_GeneratorCommaList[index] + if sep then + --(sep) + end + end + padToken(stat.Token_Do) + formatBody(stat.Token_Do, stat.Body, stat.Token_End) + elseif stat.Type == 'NumericForStat' then + --(stat.Token_For) + for index, var in pairs(stat.VarList) do + padToken(var) + local sep = stat.Token_VarCommaList[index] + if sep then + --(sep) + end + end + padToken(stat.Token_Equals) + for index, expr in pairs(stat.RangeList) do + formatExpr(expr) + padExpr(expr) + local sep = stat.Token_RangeCommaList[index] + if sep then + --(sep) + end + end + padToken(stat.Token_Do) + formatBody(stat.Token_Do, stat.Body, stat.Token_End) + elseif stat.Type == 'WhileStat' then + --(stat.Token_While) + formatExpr(stat.Condition) + padExpr(stat.Condition) + padToken(stat.Token_Do) + formatBody(stat.Token_Do, stat.Body, stat.Token_End) + elseif stat.Type == 'DoStat' then + --(stat.Token_Do) + formatBody(stat.Token_Do, stat.Body, stat.Token_End) + elseif stat.Type == 'IfStat' then + --(stat.Token_If) + formatExpr(stat.Condition) + padExpr(stat.Condition) + padToken(stat.Token_Then) + -- + local lastBodyOpen = stat.Token_Then + local lastBody = stat.Body + -- + for _, clause in pairs(stat.ElseClauseList) do + formatBody(lastBodyOpen, lastBody, clause.Token) + lastBodyOpen = clause.Token + -- + if clause.Condition then + formatExpr(clause.Condition) + padExpr(clause.Condition) + padToken(clause.Token_Then) + lastBodyOpen = clause.Token_Then + end + lastBody = clause.Body + end + -- + formatBody(lastBodyOpen, lastBody, stat.Token_End) + + elseif stat.Type == 'CallExprStat' then + formatExpr(stat.Expression) + elseif stat.Type == 'AssignmentStat' then + for index, ex in pairs(stat.Lhs) do + formatExpr(ex) + if index > 1 then + padExpr(ex) + end + local sep = stat.Token_LhsSeparatorList[index] + if sep then + --(sep) + end + end + padToken(stat.Token_Equals) + for index, ex in pairs(stat.Rhs) do + formatExpr(ex) + padExpr(ex) + local sep = stat.Token_RhsSeparatorList[index] + if sep then + --(sep) + end + end + else + assert(false, "unreachable") + end + end + + formatStat(ast) +end + +-- Strips as much whitespace off of tokens in an AST as possible without causing problems +local function StripAst(ast) + local stripStat, stripExpr; + + local function stript(token) + token.LeadingWhite = '' + end + + -- Make to adjacent tokens as close as possible + local function joint(tokenA, tokenB) + -- Strip the second token's whitespace + stript(tokenB) + + -- Get the trailing A <-> leading B character pair + local lastCh = tokenA.Source:sub(-1, -1) + local firstCh = tokenB.Source:sub(1, 1) + + -- Cases to consider: + -- Touching minus signs -> comment: `- -42` -> `--42' is invalid + -- Touching words: `a b` -> `ab` is invalid + -- Touching digits: `2 3`, can't occurr in the Lua syntax as number literals aren't a primary expression + -- Abiguous syntax: `f(x)\n(x)()` is already disallowed, we can't cause a problem by removing newlines + + -- Figure out what separation is needed + if + (lastCh == '-' and firstCh == '-') or + (AllIdentChars[lastCh] and AllIdentChars[firstCh]) + then + tokenB.LeadingWhite = ' ' -- Use a separator + else + tokenB.LeadingWhite = '' -- Don't use a separator + end + end + + -- Join up a statement body and it's opening / closing tokens + local function bodyjoint(open, body, close) + stripStat(body) + stript(close) + local bodyFirst = body:GetFirstToken() + local bodyLast = body:GetLastToken() + if bodyFirst then + -- Body is non-empty, join body to open / close + joint(open, bodyFirst) + joint(bodyLast, close) + else + -- Body is empty, just join open and close token together + joint(open, close) + end + end + + stripExpr = function(expr) + if expr.Type == 'BinopExpr' then + stripExpr(expr.Lhs) + stript(expr.Token_Op) + stripExpr(expr.Rhs) + -- Handle the `a - -b` -/-> `a--b` case which would otherwise incorrectly generate a comment + -- Also handles operators "or" / "and" which definitely need joining logic in a bunch of cases + joint(expr.Token_Op, expr.Rhs:GetFirstToken()) + joint(expr.Lhs:GetLastToken(), expr.Token_Op) + elseif expr.Type == 'UnopExpr' then + stript(expr.Token_Op) + stripExpr(expr.Rhs) + -- Handle the `- -b` -/-> `--b` case which would otherwise incorrectly generate a comment + joint(expr.Token_Op, expr.Rhs:GetFirstToken()) + elseif expr.Type == 'NumberLiteral' or expr.Type == 'StringLiteral' or + expr.Type == 'NilLiteral' or expr.Type == 'BooleanLiteral' or + expr.Type == 'VargLiteral' + then + -- Just print the token + stript(expr.Token) + elseif expr.Type == 'FieldExpr' then + stripExpr(expr.Base) + stript(expr.Token_Dot) + stript(expr.Field) + elseif expr.Type == 'IndexExpr' then + stripExpr(expr.Base) + stript(expr.Token_OpenBracket) + stripExpr(expr.Index) + stript(expr.Token_CloseBracket) + elseif expr.Type == 'MethodExpr' or expr.Type == 'CallExpr' then + stripExpr(expr.Base) + if expr.Type == 'MethodExpr' then + stript(expr.Token_Colon) + stript(expr.Method) + end + if expr.FunctionArguments.CallType == 'StringCall' then + stript(expr.FunctionArguments.Token) + elseif expr.FunctionArguments.CallType == 'ArgCall' then + stript(expr.FunctionArguments.Token_OpenParen) + for index, argExpr in pairs(expr.FunctionArguments.ArgList) do + stripExpr(argExpr) + local sep = expr.FunctionArguments.Token_CommaList[index] + if sep then + stript(sep) + end + end + stript(expr.FunctionArguments.Token_CloseParen) + elseif expr.FunctionArguments.CallType == 'TableCall' then + stripExpr(expr.FunctionArguments.TableExpr) + end + elseif expr.Type == 'FunctionLiteral' then + stript(expr.Token_Function) + stript(expr.Token_OpenParen) + for index, arg in pairs(expr.ArgList) do + stript(arg) + local comma = expr.Token_ArgCommaList[index] + if comma then + stript(comma) + end + end + stript(expr.Token_CloseParen) + bodyjoint(expr.Token_CloseParen, expr.Body, expr.Token_End) + elseif expr.Type == 'VariableExpr' then + stript(expr.Token) + elseif expr.Type == 'ParenExpr' then + stript(expr.Token_OpenParen) + stripExpr(expr.Expression) + stript(expr.Token_CloseParen) + elseif expr.Type == 'TableLiteral' then + stript(expr.Token_OpenBrace) + for index, entry in pairs(expr.EntryList) do + if entry.EntryType == 'Field' then + stript(entry.Field) + stript(entry.Token_Equals) + stripExpr(entry.Value) + elseif entry.EntryType == 'Index' then + stript(entry.Token_OpenBracket) + stripExpr(entry.Index) + stript(entry.Token_CloseBracket) + stript(entry.Token_Equals) + stripExpr(entry.Value) + elseif entry.EntryType == 'Value' then + stripExpr(entry.Value) + else + assert(false, "unreachable") + end + local sep = expr.Token_SeparatorList[index] + if sep then + stript(sep) + end + end + stript(expr.Token_CloseBrace) + else + assert(false, "unreachable, type: "..expr.Type..":"..FormatTable(expr)) + end + end + + stripStat = function(stat) + if stat.Type == 'StatList' then + -- Strip all surrounding whitespace on statement lists along with separating whitespace + for i = 1, #stat.StatementList do + local chStat = stat.StatementList[i] + + -- Strip the statement and it's whitespace + stripStat(chStat) + stript(chStat:GetFirstToken()) + + -- If there was a last statement, join them appropriately + local lastChStat = stat.StatementList[i-1] + if lastChStat then + -- See if we can remove a semi-colon, the only case where we can't is if + -- this and the last statement have a `);(` pair, where removing the semi-colon + -- would introduce ambiguous syntax. + if stat.SemicolonList[i-1] and + (lastChStat:GetLastToken().Source ~= ')' or chStat:GetFirstToken().Source ~= ')') + then + stat.SemicolonList[i-1] = nil + end + + -- If there isn't a semi-colon, we should safely join the two statements + -- (If there is one, then no whitespace leading chStat is always okay) + if not stat.SemicolonList[i-1] then + joint(lastChStat:GetLastToken(), chStat:GetFirstToken()) + end + end + end + + -- A semi-colon is never needed on the last stat in a statlist: + stat.SemicolonList[#stat.StatementList] = nil + + -- The leading whitespace on the statlist should be stripped + if #stat.StatementList > 0 then + stript(stat.StatementList[1]:GetFirstToken()) + end + + elseif stat.Type == 'BreakStat' then + stript(stat.Token_Break) + + elseif stat.Type == 'ReturnStat' then + stript(stat.Token_Return) + for index, expr in pairs(stat.ExprList) do + stripExpr(expr) + if stat.Token_CommaList[index] then + stript(stat.Token_CommaList[index]) + end + end + if #stat.ExprList > 0 then + joint(stat.Token_Return, stat.ExprList[1]:GetFirstToken()) + end + elseif stat.Type == 'LocalVarStat' then + stript(stat.Token_Local) + for index, var in pairs(stat.VarList) do + if index == 1 then + joint(stat.Token_Local, var) + else + stript(var) + end + local comma = stat.Token_VarCommaList[index] + if comma then + stript(comma) + end + end + if stat.Token_Equals then + stript(stat.Token_Equals) + for index, expr in pairs(stat.ExprList) do + stripExpr(expr) + local comma = stat.Token_ExprCommaList[index] + if comma then + stript(comma) + end + end + end + elseif stat.Type == 'LocalFunctionStat' then + stript(stat.Token_Local) + joint(stat.Token_Local, stat.FunctionStat.Token_Function) + joint(stat.FunctionStat.Token_Function, stat.FunctionStat.NameChain[1]) + joint(stat.FunctionStat.NameChain[1], stat.FunctionStat.Token_OpenParen) + for index, arg in pairs(stat.FunctionStat.ArgList) do + stript(arg) + local comma = stat.FunctionStat.Token_ArgCommaList[index] + if comma then + stript(comma) + end + end + stript(stat.FunctionStat.Token_CloseParen) + bodyjoint(stat.FunctionStat.Token_CloseParen, stat.FunctionStat.Body, stat.FunctionStat.Token_End) + elseif stat.Type == 'FunctionStat' then + stript(stat.Token_Function) + for index, part in pairs(stat.NameChain) do + if index == 1 then + joint(stat.Token_Function, part) + else + stript(part) + end + local sep = stat.Token_NameChainSeparator[index] + if sep then + stript(sep) + end + end + stript(stat.Token_OpenParen) + for index, arg in pairs(stat.ArgList) do + stript(arg) + local comma = stat.Token_ArgCommaList[index] + if comma then + stript(comma) + end + end + stript(stat.Token_CloseParen) + bodyjoint(stat.Token_CloseParen, stat.Body, stat.Token_End) + elseif stat.Type == 'RepeatStat' then + stript(stat.Token_Repeat) + bodyjoint(stat.Token_Repeat, stat.Body, stat.Token_Until) + stripExpr(stat.Condition) + joint(stat.Token_Until, stat.Condition:GetFirstToken()) + elseif stat.Type == 'GenericForStat' then + stript(stat.Token_For) + for index, var in pairs(stat.VarList) do + if index == 1 then + joint(stat.Token_For, var) + else + stript(var) + end + local sep = stat.Token_VarCommaList[index] + if sep then + stript(sep) + end + end + joint(stat.VarList[#stat.VarList], stat.Token_In) + for index, expr in pairs(stat.GeneratorList) do + stripExpr(expr) + if index == 1 then + joint(stat.Token_In, expr:GetFirstToken()) + end + local sep = stat.Token_GeneratorCommaList[index] + if sep then + stript(sep) + end + end + joint(stat.GeneratorList[#stat.GeneratorList]:GetLastToken(), stat.Token_Do) + bodyjoint(stat.Token_Do, stat.Body, stat.Token_End) + elseif stat.Type == 'NumericForStat' then + stript(stat.Token_For) + for index, var in pairs(stat.VarList) do + if index == 1 then + joint(stat.Token_For, var) + else + stript(var) + end + local sep = stat.Token_VarCommaList[index] + if sep then + stript(sep) + end + end + joint(stat.VarList[#stat.VarList], stat.Token_Equals) + for index, expr in pairs(stat.RangeList) do + stripExpr(expr) + if index == 1 then + joint(stat.Token_Equals, expr:GetFirstToken()) + end + local sep = stat.Token_RangeCommaList[index] + if sep then + stript(sep) + end + end + joint(stat.RangeList[#stat.RangeList]:GetLastToken(), stat.Token_Do) + bodyjoint(stat.Token_Do, stat.Body, stat.Token_End) + elseif stat.Type == 'WhileStat' then + stript(stat.Token_While) + stripExpr(stat.Condition) + stript(stat.Token_Do) + joint(stat.Token_While, stat.Condition:GetFirstToken()) + joint(stat.Condition:GetLastToken(), stat.Token_Do) + bodyjoint(stat.Token_Do, stat.Body, stat.Token_End) + elseif stat.Type == 'DoStat' then + stript(stat.Token_Do) + stript(stat.Token_End) + bodyjoint(stat.Token_Do, stat.Body, stat.Token_End) + elseif stat.Type == 'IfStat' then + stript(stat.Token_If) + stripExpr(stat.Condition) + joint(stat.Token_If, stat.Condition:GetFirstToken()) + joint(stat.Condition:GetLastToken(), stat.Token_Then) + -- + local lastBodyOpen = stat.Token_Then + local lastBody = stat.Body + -- + for _, clause in pairs(stat.ElseClauseList) do + bodyjoint(lastBodyOpen, lastBody, clause.Token) + lastBodyOpen = clause.Token + -- + if clause.Condition then + stripExpr(clause.Condition) + joint(clause.Token, clause.Condition:GetFirstToken()) + joint(clause.Condition:GetLastToken(), clause.Token_Then) + lastBodyOpen = clause.Token_Then + end + stripStat(clause.Body) + lastBody = clause.Body + end + -- + bodyjoint(lastBodyOpen, lastBody, stat.Token_End) + + elseif stat.Type == 'CallExprStat' then + stripExpr(stat.Expression) + elseif stat.Type == 'AssignmentStat' then + for index, ex in pairs(stat.Lhs) do + stripExpr(ex) + local sep = stat.Token_LhsSeparatorList[index] + if sep then + stript(sep) + end + end + stript(stat.Token_Equals) + for index, ex in pairs(stat.Rhs) do + stripExpr(ex) + local sep = stat.Token_RhsSeparatorList[index] + if sep then + stript(sep) + end + end + else + assert(false, "unreachable") + end + end + + stripStat(ast) +end + +local idGen = 0 +local VarDigits = {} +for i = ('a'):byte(), ('z'):byte() do table.insert(VarDigits, string.char(i)) end +for i = ('A'):byte(), ('Z'):byte() do table.insert(VarDigits, string.char(i)) end +for i = ('0'):byte(), ('9'):byte() do table.insert(VarDigits, string.char(i)) end +table.insert(VarDigits, '_') +local VarStartDigits = {} +for i = ('a'):byte(), ('z'):byte() do table.insert(VarStartDigits, string.char(i)) end +for i = ('A'):byte(), ('Z'):byte() do table.insert(VarStartDigits, string.char(i)) end +local function indexToVarName(index) + local id = '' + local d = index % #VarStartDigits + index = (index - d) / #VarStartDigits + id = id..VarStartDigits[d+1] + while index > 0 do + local d = index % #VarDigits + index = (index - d) / #VarDigits + id = id..VarDigits[d+1] + end + return id +end +local function genNextVarName() + local varToUse = idGen + idGen = idGen + 1 + return indexToVarName(varToUse) +end +local function genVarName() + local varName = '' + repeat + varName = genNextVarName() + until not Keywords[varName] + return varName +end +local function MinifyVariables(globalScope, rootScope) + -- externalGlobals is a set of global variables that have not been assigned to, that is + -- global variables defined "externally to the script". We are not going to be renaming + -- those, and we have to make sure that we don't collide with them when renaming + -- things so we keep track of them in this set. + local externalGlobals = {} + + -- First we want to rename all of the variables to unique temoraries, so that we can + -- easily use the scope::GetVar function to check whether renames are valid. + local temporaryIndex = 0 + for _, var in pairs(globalScope) do + if var.AssignedTo then + var:Rename('_TMP_'..temporaryIndex..'_') + temporaryIndex = temporaryIndex + 1 + else + -- Not assigned to, external global + externalGlobals[var.Name] = true + end + end + local function temporaryRename(scope) + for _, var in pairs(scope.VariableList) do + var:Rename('_TMP_'..temporaryIndex..'_') + temporaryIndex = temporaryIndex + 1 + end + for _, childScope in pairs(scope.ChildScopeList) do + temporaryRename(childScope) + end + end + + -- Now we go through renaming, first do globals, we probably want them + -- to have shorter names in general. + -- TODO: Rename all vars based on frequency patterns, giving variables + -- used more shorter names. + local nextFreeNameIndex = 0 + for _, var in pairs(globalScope) do + if var.AssignedTo then + local varName = '' + repeat + varName = indexToVarName(nextFreeNameIndex) + nextFreeNameIndex = nextFreeNameIndex + 1 + until not Keywords[varName] and not externalGlobals[varName] + var:Rename(varName) + end + end + + -- Now rename all local vars + rootScope.FirstFreeName = nextFreeNameIndex + local function doRenameScope(scope) + for _, var in pairs(scope.VariableList) do + local varName = '' + repeat + varName = indexToVarName(scope.FirstFreeName) + scope.FirstFreeName = scope.FirstFreeName + 1 + until not Keywords[varName] and not externalGlobals[varName] + var:Rename(varName) + end + for _, childScope in pairs(scope.ChildScopeList) do + childScope.FirstFreeName = scope.FirstFreeName + doRenameScope(childScope) + end + end + doRenameScope(rootScope) +end + +local function MinifyVariables_2(globalScope, rootScope) + -- Variable names and other names that are fixed, that we cannot use + -- Either these are Lua keywords, or globals that are not assigned to, + -- that is environmental globals that are assigned elsewhere beyond our + -- control. + local globalUsedNames = {} + for kw, _ in pairs(Keywords) do + globalUsedNames[kw] = true + end + + -- Gather a list of all of the variables that we will rename + local allVariables = {} + local allLocalVariables = {} + do + -- Add applicable globals + for _, var in pairs(globalScope) do + if var.AssignedTo then + -- We can try to rename this global since it was assigned to + -- (and thus presumably initialized) in the script we are + -- minifying. + table.insert(allVariables, var) + else + -- We can't rename this global, mark it as an unusable name + -- and don't add it to the nename list + globalUsedNames[var.Name] = true + end + end + + -- Recursively add locals, we can rename all of those + local function addFrom(scope) + for _, var in pairs(scope.VariableList) do + table.insert(allVariables, var) + table.insert(allLocalVariables, var) + end + for _, childScope in pairs(scope.ChildScopeList) do + addFrom(childScope) + end + end + addFrom(rootScope) + end + + -- Add used name arrays to variables + for _, var in pairs(allVariables) do + var.UsedNameArray = {} + end + + -- Sort the least used variables first + table.sort(allVariables, function(a, b) + return #a.RenameList < #b.RenameList + end) + + -- Lazy generator for valid names to rename to + local nextValidNameIndex = 0 + local varNamesLazy = {} + local function varIndexToValidVarName(i) + local name = varNamesLazy[i] + if not name then + repeat + name = indexToVarName(nextValidNameIndex) + nextValidNameIndex = nextValidNameIndex + 1 + until not globalUsedNames[name] + varNamesLazy[i] = name + end + return name + end + + -- For each variable, go to rename it + for _, var in pairs(allVariables) do + -- Lazy... todo: Make theis pair a proper for-each-pair-like set of loops + -- rather than using a renamed flag. + var.Renamed = true + + -- Find the first unused name + local i = 1 + while var.UsedNameArray[i] do + i = i + 1 + end + + -- Rename the variable to that name + var:Rename(varIndexToValidVarName(i)) + + if var.Scope then + -- Now we need to mark the name as unusable by any variables: + -- 1) At the same depth that overlap lifetime with this one + -- 2) At a deeper level, which have a reference to this variable in their lifetimes + -- 3) At a shallower level, which are referenced during this variable's lifetime + for _, otherVar in pairs(allVariables) do + if not otherVar.Renamed then + if not otherVar.Scope or otherVar.Scope.Depth < var.Scope.Depth then + -- Check Global variable (Which is always at a shallower level) + -- or + -- Check case 3 + -- The other var is at a shallower depth, is there a reference to it + -- durring this variable's lifetime? + for _, refAt in pairs(otherVar.ReferenceLocationList) do + if refAt >= var.BeginLocation and refAt <= var.ScopeEndLocation then + -- Collide + otherVar.UsedNameArray[i] = true + break + end + end + + elseif otherVar.Scope.Depth > var.Scope.Depth then + -- Check Case 2 + -- The other var is at a greater depth, see if any of the references + -- to this variable are in the other var's lifetime. + for _, refAt in pairs(var.ReferenceLocationList) do + if refAt >= otherVar.BeginLocation and refAt <= otherVar.ScopeEndLocation then + -- Collide + otherVar.UsedNameArray[i] = true + break + end + end + + else --otherVar.Scope.Depth must be equal to var.Scope.Depth + -- Check case 1 + -- The two locals are in the same scope + -- Just check if the usage lifetimes overlap within that scope. That is, we + -- can shadow a local variable within the same scope as long as the usages + -- of the two locals do not overlap. + if var.BeginLocation < otherVar.EndLocation and + var.EndLocation > otherVar.BeginLocation + then + otherVar.UsedNameArray[i] = true + end + end + end + end + else + -- This is a global var, all other globals can't collide with it, and + -- any local variable with a reference to this global in it's lifetime + -- can't collide with it. + for _, otherVar in pairs(allVariables) do + if not otherVar.Renamed then + if otherVar.Type == 'Global' then + otherVar.UsedNameArray[i] = true + elseif otherVar.Type == 'Local' then + -- Other var is a local, see if there is a reference to this global within + -- that local's lifetime. + for _, refAt in pairs(var.ReferenceLocationList) do + if refAt >= otherVar.BeginLocation and refAt <= otherVar.ScopeEndLocation then + -- Collide + otherVar.UsedNameArray[i] = true + break + end + end + else + assert(false, "unreachable") + end + end + end + end + end + + + -- -- + -- print("Total Variables: "..#allVariables) + -- print("Total Range: "..rootScope.BeginLocation.."-"..rootScope.EndLocation) + -- print("") + -- for _, var in pairs(allVariables) do + -- io.write("`"..var.Name.."':\n\t#symbols: "..#var.RenameList.. + -- "\n\tassigned to: "..tostring(var.AssignedTo)) + -- if var.Type == 'Local' then + -- io.write("\n\trange: "..var.BeginLocation.."-"..var.EndLocation) + -- io.write("\n\tlocal type: "..var.Info.Type) + -- end + -- io.write("\n\n") + -- end + + -- -- First we want to rename all of the variables to unique temoraries, so that we can + -- -- easily use the scope::GetVar function to check whether renames are valid. + -- local temporaryIndex = 0 + -- for _, var in pairs(allVariables) do + -- var:Rename('_TMP_'..temporaryIndex..'_') + -- temporaryIndex = temporaryIndex + 1 + -- end + + -- For each variable, we need to build a list of names that collide with it + + -- + --error() +end + +local function BeautifyVariables(globalScope, rootScope) + local externalGlobals = {} + for _, var in pairs(globalScope) do + if not var.AssignedTo then + externalGlobals[var.Name] = true + end + end + + local localNumber = 1 + local globalNumber = 1 + + local function setVarName(var, name) + var.Name = name + for _, setter in pairs(var.RenameList) do + setter(name) + end + end + + for _, var in pairs(globalScope) do + if var.AssignedTo then + setVarName(var, 'G_'..globalNumber) + globalNumber = globalNumber + 1 + end + end + + local function modify(scope) + for _, var in pairs(scope.VariableList) do + local name = 'L_'..localNumber..'_' + if var.Info.Type == 'Argument' then + name = name..'arg'..var.Info.Index + elseif var.Info.Type == 'LocalFunction' then + name = name..'func' + elseif var.Info.Type == 'ForRange' then + name = name..'forvar'..var.Info.Index + end + setVarName(var, name) + localNumber = localNumber + 1 + end + for _, scope in pairs(scope.ChildScopeList) do + modify(scope) + end + end + modify(rootScope) +end + +local function usageError() + error( + "\nusage: minify or unminify \n" .. + " The modified code will be printed to the stdout, pipe it to a file, the\n" .. + " lua interpreter, or something else as desired EG:\n\n" .. + " lua minify.lua minify input.lua > output.lua\n\n" .. + " * minify will minify the code in the file.\n" .. + " * unminify will beautify the code and replace the variable names with easily\n" .. + " find-replacable ones to aide in reverse engineering minified code.\n", 0) +end + +local args = {...} +if #args ~= 2 then + usageError() +end + +local sourceFile = io.open(args[2], 'r') +if not sourceFile then + error("Could not open the input file `" .. args[2] .. "`", 0) +end + +local data = sourceFile:read('*all') +local ast = CreateLuaParser(data) +local global_scope, root_scope = AddVariableInfo(ast) + +local function minify(ast, global_scope, root_scope) + MinifyVariables(global_scope, root_scope) + StripAst(ast) + PrintAst(ast) +end + +local function beautify(ast, global_scope, root_scope) + BeautifyVariables(global_scope, root_scope) + FormatAst(ast) + PrintAst(ast) +end + +if args[1] == 'minify' then + minify(ast, global_scope, root_scope) +elseif args[1] == 'unminify' then + beautify(ast, global_scope, root_scope) +else + usageError() +end