-- File: movescript/minify.lua
--
-- (source listing metadata: 3228 lines, 84 KiB, Lua)

--[[
MIT License
Copyright (c) 2017 Mark Langen
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
]]
--- Turn a list into a set in place: every value v in `tb` also becomes a key
-- with tb[v] = true. The original entries are kept, so the result can still
-- be iterated as a list.
-- @param tb table of values (in this file, arrays of strings)
-- @return tb, after modification
function lookupify(tb)
  -- Snapshot the values first: assigning new keys to a table while it is
  -- being traversed with pairs/next is undefined behavior in Lua.
  local values = {}
  for _, v in pairs(tb) do
    values[#values + 1] = v
  end
  for _, v in ipairs(values) do
    tb[v] = true
  end
  return tb
end
--- Number of key/value pairs in `tb`, counting array and hash entries alike.
-- @param tb any table
-- @return integer count
function CountTable(tb)
  local n = 0
  for _, _ in pairs(tb) do
    n = n + 1
  end
  return n
end
--- Render a table as Lua-like source text (debugging helper).
-- Function values are skipped, as are keys for which ignoreFunc(key) is
-- truthy. If the table has a `Print` field, it is called (with no
-- arguments) and its result is returned instead.
-- Tables with more than one key (functions included in that count) are laid
-- out one entry per line; nested tables are rendered recursively.
-- NOTE(review): string keys/values are quoted but not escaped.
-- @param tb table to render
-- @param atIndent indentation level of the enclosing table (default 0)
-- @param ignoreFunc optional key predicate; defaults to ignoring nothing
-- @return string
function FormatTableInt(tb, atIndent, ignoreFunc)
  if tb.Print then
    return tb.Print()
  end
  atIndent = atIndent or 0
  ignoreFunc = ignoreFunc or function()
    return false
  end
  local useNewlines = (CountTable(tb) > 1)
  local baseIndent = string.rep(' ', atIndent + 1)

  -- Render each visible entry as "key = value" (array-style keys omit "key = ")
  local entries = {}
  for k, v in pairs(tb) do
    if type(v) ~= 'function' and not ignoreFunc(k) then
      local keyText
      if type(k) == 'number' then
        keyText = ''
      elseif type(k) == 'string' and k:match("^[A-Za-z_][A-Za-z0-9_]*$") then
        keyText = k.." = "
      elseif type(k) == 'string' then
        keyText = "[\""..k.."\"] = "
      else
        keyText = "["..tostring(k).."] = "
      end
      local valueText
      if type(v) == 'string' then
        valueText = "\""..v.."\""
      elseif type(v) == 'table' then
        valueText = FormatTableInt(v, atIndent + (useNewlines and 1 or 0), ignoreFunc)
      else
        valueText = tostring(v)
      end
      entries[#entries + 1] = keyText..valueText
    end
  end

  -- Join with commas only between emitted entries, so skipped trailing
  -- entries (functions / ignored keys) cannot leave a dangling comma.
  local out = {useNewlines and "{\n" or "{"}
  for i, entryText in ipairs(entries) do
    local comma = (i < #entries) and "," or ""
    if useNewlines then
      out[#out + 1] = baseIndent..entryText..comma.."\n"
    else
      out[#out + 1] = entryText..comma
    end
  end
  out[#out + 1] = (useNewlines and string.rep(' ', atIndent) or "").."}"
  return table.concat(out)
end
--- Render `tb` as Lua-like text starting at indentation level 0.
-- @param tb table to render
-- @param ignoreFunc optional key predicate; keys for which it returns a
--   truthy value are left out. When absent, no key is ignored.
-- @return string produced by FormatTableInt
function FormatTable(tb, ignoreFunc)
  local skip = ignoreFunc
  if not skip then
    skip = function()
      return false
    end
  end
  return FormatTableInt(tb, 0, skip)
end
-- Character classes and grammar tables shared by the tokenizer and parser.
-- Each list goes through lookupify, so it stays indexable as an array and
-- also works as a set (e.g. WhiteChars[' '] == true).
local WhiteChars = lookupify{' ', '\n', '\t', '\r'}
-- Character -> the escape sequence that writes it inside a quoted string
-- (a backslash must be written as two backslashes).
local EscapeForCharacter = {['\r'] = '\\r', ['\n'] = '\\n', ['\t'] = '\\t', ['"'] = '\\"', ["'"] = "\\'", ['\\'] = '\\\\'}
-- Escape letter (the character after a backslash) -> the character it denotes
local CharacterForEscape = {['r'] = '\r', ['n'] = '\n', ['t'] = '\t', ['"'] = '"', ["'"] = "'", ['\\'] = '\\'}
-- Characters that may start an identifier
local AllIdentStartChars = lookupify{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i',
  'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r',
  's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
  'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I',
  'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R',
  'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '_'}
-- Characters that may continue an identifier
local AllIdentChars = lookupify{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i',
  'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r',
  's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
  'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I',
  'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R',
  'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '_',
  '0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}
local Digits = lookupify{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}
local HexDigits = lookupify{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
  'A', 'a', 'B', 'b', 'C', 'c', 'D', 'd', 'E', 'e', 'F', 'f'}
-- Single-character symbol tokens
local Symbols = lookupify{'+', '-', '*', '/', '^', '%', ',', '{', '}', '[', ']', '(', ')', ';', '#', '.', ':'}
-- Characters that form a symbol on their own or, followed by '=', a
-- two-character symbol ('~=', '==', '>=', '<=')
local EqualSymbols = lookupify{'~', '=', '>', '<'}
local Keywords = lookupify{
  'and', 'break', 'do', 'else', 'elseif',
  'end', 'false', 'for', 'function', 'goto', 'if',
  'in', 'local', 'nil', 'not', 'or', 'repeat',
  'return', 'then', 'true', 'until', 'while',
};
-- Keywords that terminate a block
local BlockFollowKeyword = lookupify{'else', 'elseif', 'until', 'end'}
local UnopSet = lookupify{'-', 'not', '#'}
-- Tokens the parser's isBinop check accepts.
-- NOTE(review): '#', '.' and ':' are not binary operators and have no
-- BinaryPriority entry; kept in case code outside this view relies on them.
local BinopSet = lookupify{
  '+', '-', '*', '/', '%', '^', '#',
  '..', '.', ':',
  '>', '<', '<=', '>=', '~=', '==',
  'and', 'or'
}
-- NOTE(review): empty set; presumably global names exempt from renaming.
-- Its use is outside this view.
local GlobalRenameIgnore = lookupify{
}
-- Binary operator precedence as {left, right}. An operator binds when its
-- left priority exceeds the current limit; its right operand is parsed with
-- the right priority, so right < left ('^', '..') makes it right-associative.
local BinaryPriority = {
  ['+'] = {6, 6};
  ['-'] = {6, 6};
  ['*'] = {7, 7};
  ['/'] = {7, 7};
  ['%'] = {7, 7};
  ['^'] = {10, 9};
  ['..'] = {5, 4};
  ['=='] = {3, 3};
  ['~='] = {3, 3};
  ['>'] = {3, 3};
  ['<'] = {3, 3};
  ['>='] = {3, 3};
  ['<='] = {3, 3};
  ['and'] = {2, 2};
  ['or'] = {1, 1};
};
-- Limit used when parsing the operand of a unary operator: only '^'
-- (left priority 10) binds more tightly than a unary operator.
local UnaryPriority = 8
--- Split Lua source text into a list of tokens.
--
-- Each token is a table with:
--   Type         -- 'Eof', 'Ident', 'Keyword', 'Number', 'String' or 'Symbol'
--   LeadingWhite -- whitespace and comments between the previous token and this one
--   Source       -- the token's exact source text
-- The list always ends with an 'Eof' token (empty Source), and concatenating
-- LeadingWhite..Source over all tokens reproduces `text`.
-- Malformed input raises an error whose message is prefixed with
-- "file<line:col>: " (approximate position); before raising, the tokens
-- lexed so far are printed to stdout.
function CreateLuaTokenStream(text)
  -- Cursor (1-based index into `text`)
  local p = 1
  local length = #text
  -- Output buffer for tokens
  local tokenBuffer = {}

  -- Characters accepted after a backslash in addition to those in
  -- CharacterForEscape; a decimal digit (start of \ddd) is accepted too.
  -- The lexer only validates here: string tokens keep their raw text.
  local extraEscapeChars = {
    a = true, b = true, f = true, v = true, x = true, z = true,
    ['\n'] = true, ['\r'] = true,
  }

  -- Character n places ahead of the cursor (default 0), or '' past the end
  local function look(n)
    n = p + (n or 0)
    if n <= length then
      return text:sub(n, n)
    else
      return ''
    end
  end

  -- Consume and return the character at the cursor, or '' at end of input
  local function get()
    if p <= length then
      local c = text:sub(p, p)
      p = p + 1
      return c
    else
      return ''
    end
  end

  -- Raise a lexing error annotated with the cursor's line/column
  local olderr = error
  local function error(str)
    local q = 1
    local line = 1
    local char = 1
    while q <= p do
      if text:sub(q, q) == '\n' then
        line = line + 1
        char = 1
      else
        char = char + 1
      end
      q = q + 1
    end
    for _, token in ipairs(tokenBuffer) do
      print(token.Type.."<"..token.Source..">")
    end
    olderr("file<"..line..":"..char..">: "..str)
  end

  -- Consume the body of a long bracket ([[...]], [==[...]==], ...) up to and
  -- including the closing bracket with `eqcount` equals signs.
  local function longdata(eqcount)
    while true do
      local c = get()
      if c == '' then
        error("Unfinished long string.")
      elseif c == ']' then
        local done = true -- until a mismatched '=' count is found
        for _ = 1, eqcount do
          if look() == '=' then
            p = p + 1
          else
            done = false
            break
          end
        end
        if done and get() == ']' then
          return
        end
      end
    end
  end

  -- Read the opening of a long bracket `[` `=`* `[`.
  -- Precondition: the first `[` has already been consumed.
  -- Returns the number of `=` signs, or nil (cursor restored) if this is
  -- not a long bracket.
  local function getopen()
    local startp = p
    while look() == '=' do
      p = p + 1
    end
    if look() == '[' then
      p = p + 1
      return p - startp - 1
    else
      p = startp
      return nil
    end
  end

  -- Consume the rest of a line comment, through the terminating newline
  local function skipLineComment()
    while true do
      local c = get()
      if c == '' or c == '\n' then
        return
      end
    end
  end

  -- Emit a token spanning tokenStart..p-1, with whiteStart..tokenStart-1
  -- recorded as its leading whitespace/comments.
  local whiteStart = 1
  local tokenStart = 1
  local function token(type)
    local tk = {
      Type = type;
      LeadingWhite = text:sub(whiteStart, tokenStart-1);
      Source = text:sub(tokenStart, p-1);
    }
    table.insert(tokenBuffer, tk)
    whiteStart = p
    tokenStart = p
    return tk
  end

  while true do
    whiteStart = p
    -- Skip whitespace and comments preceding the next token
    while true do
      local c = look()
      if c == '' then
        break
      elseif c == '-' then
        if look(1) ~= '-' then
          break -- a minus sign, not a comment
        end
        p = p + 2
        local isLong = false
        if look() == '[' then
          p = p + 1
          local eqcount = getopen()
          if eqcount then
            longdata(eqcount)
            isLong = true
          end
        end
        if not isLong then
          skipLineComment()
        end
      elseif WhiteChars[c] then
        p = p + 1
      else
        break
      end
    end

    tokenStart = p
    local c1 = get()
    if c1 == '' then
      token('Eof')
      break
    elseif c1 == '\'' or c1 == '"' then
      -- Quoted string: validate escapes, stop at the matching quote
      while true do
        local c2 = get()
        if c2 == '' then
          error("Unfinished string.")
        elseif c2 == '\\' then
          local c3 = get()
          if c3 == '' then
            error("Unfinished string.")
          elseif not (CharacterForEscape[c3] or extraEscapeChars[c3] or Digits[c3]) then
            error("Invalid Escape Sequence `"..c3.."`.")
          end
        elseif c2 == c1 then
          break
        end
      end
      token('String')
    elseif AllIdentStartChars[c1] then
      -- Identifier or keyword
      while AllIdentChars[look()] do
        p = p + 1
      end
      if Keywords[text:sub(tokenStart, p-1)] then
        token('Keyword')
      else
        token('Ident')
      end
    elseif Digits[c1] or (c1 == '.' and Digits[look()]) then
      if c1 == '0' and (look() == 'x' or look() == 'X') then
        -- Hexadecimal number
        p = p + 1
        while HexDigits[look()] do
          p = p + 1
        end
      else
        -- Decimal number: digits [. digits] [(e|E) [+|-] digits]
        while Digits[look()] do
          p = p + 1
        end
        if look() == '.' then
          p = p + 1
          while Digits[look()] do
            p = p + 1
          end
        end
        if look() == 'e' or look() == 'E' then
          p = p + 1
          if look() == '-' or look() == '+' then
            p = p + 1
          end
          while Digits[look()] do
            p = p + 1
          end
        end
      end
      token('Number')
    elseif c1 == '[' then
      -- '[' symbol or long string
      local eqCount = getopen()
      if eqCount then
        longdata(eqCount)
        token('String')
      else
        token('Symbol')
      end
    elseif c1 == '.' then
      -- Greedily consume up to 3 `.` for . / .. / ... tokens
      if look() == '.' then
        get()
        if look() == '.' then
          get()
        end
      end
      token('Symbol')
    elseif EqualSymbols[c1] then
      if look() == '=' then
        p = p + 1
      end
      token('Symbol')
    elseif Symbols[c1] then
      token('Symbol')
    else
      error("Bad symbol `"..c1.."` in source.")
    end
  end
  return tokenBuffer
end
--- Parse Lua source text into an AST.
--
-- Tokenizes `text` with CreateLuaTokenStream and returns the root 'StatList'
-- node for the whole chunk. Nodes carry a `Type` string plus
-- GetFirstToken/GetLastToken methods returning the tokens that bound them
-- (an empty StatList returns nil). Keyword and punctuation tokens are kept in
-- Token_* fields so the original text can be re-emitted. Syntax errors raise
-- a Lua error; some error paths print diagnostics to stdout first.
--
-- NOTE(review): parameter lists accept identifiers only, so vararg
-- parameters (`...`) are rejected. `goto` and labels are not parsed.
function CreateLuaParser(text)
  -- Token stream and a cursor into it
  local tokens = CreateLuaTokenStream(text)
  local p = 1

  -- Consume and return the current token. The cursor never advances past the
  -- last token (Eof), so reading at the end keeps returning Eof.
  local function get()
    local tok = tokens[p]
    if p < #tokens then
      p = p + 1
    end
    return tok
  end

  -- Token n positions ahead (default 0), without consuming it.
  -- Out-of-range positions fall back to the last token (Eof).
  local function peek(n)
    n = p + (n or 0)
    return tokens[n] or tokens[#tokens]
  end

  -- "line:column" (1-based) where `token` starts, computed by replaying the
  -- whitespace and source text of every preceding token.
  local function getTokenStartPosition(token)
    local line = 1
    local char = 0
    local tkNum = 1
    while true do
      local tk = tokens[tkNum]
      local chunk
      if tk == token then
        chunk = tk.LeadingWhite
      else
        chunk = tk.LeadingWhite..tk.Source
      end
      for i = 1, #chunk do
        if chunk:sub(i, i) == '\n' then
          line = line + 1
          char = 0
        else
          char = char + 1
        end
      end
      if tk == token then
        break
      end
      tkNum = tkNum + 1
    end
    return line..":"..(char+1)
  end

  -- Description of the current token, for diagnostics
  local function debugMark()
    local tk = peek()
    return "<"..tk.Type.." `"..tk.Source.."`> at: "..getTokenStartPosition(tk)
  end

  -- True at Eof or at a keyword that ends a block (else/elseif/until/end)
  local function isBlockFollow()
    local tok = peek()
    return tok.Type == 'Eof' or (tok.Type == 'Keyword' and BlockFollowKeyword[tok.Source])
  end

  local function isUnop()
    return UnopSet[peek().Source] or false
  end

  local function isBinop()
    return BinopSet[peek().Source] or false
  end

  -- Consume a token of type `tokType` (and text `source`, if given) or raise
  -- a syntax error; on failure the surrounding tokens are printed.
  local function expect(tokType, source)
    local tk = peek()
    if tk.Type == tokType and (source == nil or tk.Source == source) then
      return get()
    end
    for i = -3, 3 do
      print("Tokens["..i.."] = `"..peek(i).Source.."`")
    end
    if source then
      error(getTokenStartPosition(tk)..": `"..source.."` expected.")
    else
      error(getTokenStartPosition(tk)..": "..tokType.." expected.")
    end
  end

  -- Wrap a node's GetFirstToken/GetLastToken so they assert a token exists
  local function MkNode(node)
    local getf = node.GetFirstToken
    local getl = node.GetLastToken
    function node:GetFirstToken()
      local t = getf(self)
      assert(t)
      return t
    end
    function node:GetLastToken()
      local t = getl(self)
      assert(t)
      return t
    end
    return node
  end

  -- Node whose whole span is the current token; consumes that token
  local function singleTokenNode(nodeType)
    return MkNode{
      Type = nodeType;
      Token = get();
      GetFirstToken = function(self) return self.Token end;
      GetLastToken = function(self) return self.Token end;
    }
  end

  -- Forward declarations (mutually recursive with the functions below)
  local block
  local expr

  -- Comma-separated expression list; returns expressions and comma tokens
  local function exprlist()
    local exprList = {}
    local commaList = {}
    table.insert(exprList, expr())
    while peek().Source == ',' do
      table.insert(commaList, get())
      table.insert(exprList, expr())
    end
    return exprList, commaList
  end

  -- '(' expr ')' or a plain variable name
  local function prefixexpr()
    local tk = peek()
    if tk.Source == '(' then
      local oparenTk = get()
      local inner = expr()
      local cparenTk = expect('Symbol', ')')
      return MkNode{
        Type = 'ParenExpr';
        Expression = inner;
        Token_OpenParen = oparenTk;
        Token_CloseParen = cparenTk;
        GetFirstToken = function(self) return self.Token_OpenParen end;
        GetLastToken = function(self) return self.Token_CloseParen end;
      }
    elseif tk.Type == 'Ident' then
      return singleTokenNode('VariableExpr')
    else
      print(debugMark())
      error(getTokenStartPosition(tk)..": Unexpected symbol")
    end
  end

  -- Table constructor: '{' [entry {(',' | ';') entry} [',' | ';']] '}'
  local function tableexpr()
    local obrace = expect('Symbol', '{')
    local entries = {}
    local separators = {}
    while peek().Source ~= '}' do
      if peek().Source == '[' then
        -- [index] = value
        local obrac = get()
        local index = expr()
        local cbrac = expect('Symbol', ']')
        local eq = expect('Symbol', '=')
        local value = expr()
        table.insert(entries, {
          EntryType = 'Index';
          Index = index;
          Value = value;
          Token_OpenBracket = obrac;
          Token_CloseBracket = cbrac;
          Token_Equals = eq;
        })
      elseif peek().Type == 'Ident' and peek(1).Source == '=' then
        -- name = value
        local field = get()
        local eq = get()
        local value = expr()
        table.insert(entries, {
          EntryType = 'Field';
          Field = field;
          Value = value;
          Token_Equals = eq;
        })
      else
        -- positional value
        local value = expr()
        table.insert(entries, {
          EntryType = 'Value';
          Value = value;
        })
      end
      if peek().Source == ',' or peek().Source == ';' then
        table.insert(separators, get())
      else
        break
      end
    end
    local cbrace = expect('Symbol', '}')
    return MkNode{
      Type = 'TableLiteral';
      EntryList = entries;
      Token_SeparatorList = separators;
      Token_OpenBrace = obrace;
      Token_CloseBrace = cbrace;
      GetFirstToken = function(self) return self.Token_OpenBrace end;
      GetLastToken = function(self) return self.Token_CloseBrace end;
    }
  end

  -- Comma-separated identifiers (possibly empty); returns names and commas
  local function varlist()
    local varList = {}
    local commaList = {}
    if peek().Type == 'Ident' then
      table.insert(varList, get())
    end
    while peek().Source == ',' do
      table.insert(commaList, get())
      table.insert(varList, expect('Ident'))
    end
    return varList, commaList
  end

  -- A block followed by the keyword `terminator`; returns body and that token
  local function blockbody(terminator)
    local body = block()
    local after = peek()
    if after.Type == 'Keyword' and after.Source == terminator then
      get()
      return body, after
    end
    print(after.Type, after.Source)
    error(getTokenStartPosition(after)..": "..terminator.." expected.")
  end

  -- `function` [name {'.' name} [':' name]] '(' params ')' block 'end'
  -- The name chain is absent when isAnonymous (function literal).
  local function funcdecl(isAnonymous)
    local functionKw = get()
    local nameChain, nameChainSeparator
    if not isAnonymous then
      nameChain = {}
      nameChainSeparator = {}
      table.insert(nameChain, expect('Ident'))
      while peek().Source == '.' do
        table.insert(nameChainSeparator, get())
        table.insert(nameChain, expect('Ident'))
      end
      if peek().Source == ':' then
        table.insert(nameChainSeparator, get())
        table.insert(nameChain, expect('Ident'))
      end
    end
    local oparenTk = expect('Symbol', '(')
    local argList, argCommaList = varlist()
    local cparenTk = expect('Symbol', ')')
    local fbody, enTk = blockbody('end')
    return MkNode{
      Type = (isAnonymous and 'FunctionLiteral' or 'FunctionStat');
      NameChain = nameChain;
      ArgList = argList;
      Body = fbody;
      Token_Function = functionKw;
      Token_NameChainSeparator = nameChainSeparator;
      Token_OpenParen = oparenTk;
      Token_ArgCommaList = argCommaList;
      Token_CloseParen = cparenTk;
      Token_End = enTk;
      GetFirstToken = function(self) return self.Token_Function end;
      GetLastToken = function(self) return self.Token_End end;
    }
  end

  -- Arguments passed to a function: (args), {table} or "string"
  local function functionargs()
    local tk = peek()
    if tk.Source == '(' then
      local oparenTk = get()
      local argList = {}
      local argCommaList = {}
      while peek().Source ~= ')' do
        table.insert(argList, expr())
        if peek().Source == ',' then
          table.insert(argCommaList, get())
        else
          break
        end
      end
      local cparenTk = expect('Symbol', ')')
      return MkNode{
        CallType = 'ArgCall';
        ArgList = argList;
        Token_CommaList = argCommaList;
        Token_OpenParen = oparenTk;
        Token_CloseParen = cparenTk;
        GetFirstToken = function(self) return self.Token_OpenParen end;
        GetLastToken = function(self) return self.Token_CloseParen end;
      }
    elseif tk.Source == '{' then
      -- Only the constructor is the argument: parsing a full expression here
      -- would turn `f{1} + 2` into `f({1} + 2)`.
      return MkNode{
        CallType = 'TableCall';
        TableExpr = tableexpr();
        GetFirstToken = function(self) return self.TableExpr:GetFirstToken() end;
        GetLastToken = function(self) return self.TableExpr:GetLastToken() end;
      }
    elseif tk.Type == 'String' then
      return MkNode{
        CallType = 'StringCall';
        Token = get();
        GetFirstToken = function(self) return self.Token end;
        GetLastToken = function(self) return self.Token end;
      }
    else
      error("Function arguments expected.")
    end
  end

  -- prefixexpr followed by any chain of .name, :name(args), [index], calls
  local function primaryexpr()
    local base = prefixexpr()
    assert(base, "nil prefixexpr")
    while true do
      local tk = peek()
      if tk.Source == '.' then
        local dotTk = get()
        local fieldName = expect('Ident')
        base = MkNode{
          Type = 'FieldExpr';
          Base = base;
          Field = fieldName;
          Token_Dot = dotTk;
          GetFirstToken = function(self) return self.Base:GetFirstToken() end;
          GetLastToken = function(self) return self.Field end;
        }
      elseif tk.Source == ':' then
        local colonTk = get()
        local methodName = expect('Ident')
        local fargs = functionargs()
        base = MkNode{
          Type = 'MethodExpr';
          Base = base;
          Method = methodName;
          FunctionArguments = fargs;
          Token_Colon = colonTk;
          GetFirstToken = function(self) return self.Base:GetFirstToken() end;
          GetLastToken = function(self) return self.FunctionArguments:GetLastToken() end;
        }
      elseif tk.Source == '[' then
        local obrac = get()
        local index = expr()
        local cbrac = expect('Symbol', ']')
        base = MkNode{
          Type = 'IndexExpr';
          Base = base;
          Index = index;
          Token_OpenBracket = obrac;
          Token_CloseBracket = cbrac;
          GetFirstToken = function(self) return self.Base:GetFirstToken() end;
          GetLastToken = function(self) return self.Token_CloseBracket end;
        }
      elseif tk.Source == '(' or tk.Source == '{' or tk.Type == 'String' then
        -- Call: f(args), f{table} or f"string" / f[[string]]
        base = MkNode{
          Type = 'CallExpr';
          Base = base;
          FunctionArguments = functionargs();
          GetFirstToken = function(self) return self.Base:GetFirstToken() end;
          GetLastToken = function(self) return self.FunctionArguments:GetLastToken() end;
        }
      else
        return base
      end
    end
  end

  -- Literals, table constructors, function literals, or a primaryexpr
  local function simpleexpr()
    local tk = peek()
    if tk.Type == 'Number' then
      return singleTokenNode('NumberLiteral')
    elseif tk.Type == 'String' then
      return singleTokenNode('StringLiteral')
    elseif tk.Source == 'nil' then
      return singleTokenNode('NilLiteral')
    elseif tk.Source == 'true' or tk.Source == 'false' then
      return singleTokenNode('BooleanLiteral')
    elseif tk.Source == '...' then
      return singleTokenNode('VargLiteral')
    elseif tk.Source == '{' then
      return tableexpr()
    elseif tk.Source == 'function' then
      return funcdecl(true)
    else
      return primaryexpr()
    end
  end

  -- Precedence climbing: parse an expression whose binary operators all have
  -- a left priority greater than `limit`.
  local function subexpr(limit)
    local curNode
    if isUnop() then
      local opTk = get()
      local ex = subexpr(UnaryPriority)
      curNode = MkNode{
        Type = 'UnopExpr';
        Token_Op = opTk;
        Rhs = ex;
        GetFirstToken = function(self) return self.Token_Op end;
        GetLastToken = function(self) return self.Rhs:GetLastToken() end;
      }
    else
      curNode = simpleexpr()
      assert(curNode, "nil simpleexpr")
    end
    -- BinopSet contains tokens with no BinaryPriority entry ('#', '.', ':');
    -- those end the expression instead of indexing a nil priority.
    while isBinop() do
      local priority = BinaryPriority[peek().Source]
      if not priority or priority[1] <= limit then
        break
      end
      local opTk = get()
      local rhs = subexpr(priority[2])
      assert(rhs, "RhsNeeded")
      curNode = MkNode{
        Type = 'BinopExpr';
        Lhs = curNode;
        Rhs = rhs;
        Token_Op = opTk;
        GetFirstToken = function(self) return self.Lhs:GetFirstToken() end;
        GetLastToken = function(self) return self.Rhs:GetLastToken() end;
      }
    end
    return curNode
  end

  expr = function()
    return subexpr(0)
  end

  -- Call statement, or assignment `lhs {, lhs} = expr {, expr}`
  local function exprstat()
    local ex = primaryexpr()
    if ex.Type == 'MethodExpr' or ex.Type == 'CallExpr' then
      return MkNode{
        Type = 'CallExprStat';
        Expression = ex;
        GetFirstToken = function(self) return self.Expression:GetFirstToken() end;
        GetLastToken = function(self) return self.Expression:GetLastToken() end;
      }
    end
    local lhs = {ex}
    local lhsSeparator = {}
    while peek().Source == ',' do
      table.insert(lhsSeparator, get())
      local lhsPart = primaryexpr()
      if lhsPart.Type == 'MethodExpr' or lhsPart.Type == 'CallExpr' then
        error("Bad left hand side of assignment")
      end
      table.insert(lhs, lhsPart)
    end
    local eq = expect('Symbol', '=')
    local rhs = {expr()}
    local rhsSeparator = {}
    while peek().Source == ',' do
      table.insert(rhsSeparator, get())
      table.insert(rhs, expr())
    end
    return MkNode{
      Type = 'AssignmentStat';
      Rhs = rhs;
      Lhs = lhs;
      Token_Equals = eq;
      Token_LhsSeparatorList = lhsSeparator;
      Token_RhsSeparatorList = rhsSeparator;
      GetFirstToken = function(self) return self.Lhs[1]:GetFirstToken() end;
      GetLastToken = function(self) return self.Rhs[#self.Rhs]:GetLastToken() end;
    }
  end

  -- if cond then block {elseif cond then block} [else block] end
  local function ifstat()
    local ifKw = get()
    local condition = expr()
    local thenKw = expect('Keyword', 'then')
    local ifBody = block()
    local elseClauses = {}
    while peek().Source == 'elseif' or peek().Source == 'else' do
      local elseifKw = get()
      local elseifCondition, elseifThenKw
      if elseifKw.Source == 'elseif' then
        elseifCondition = expr()
        elseifThenKw = expect('Keyword', 'then')
      end
      local elseifBody = block()
      table.insert(elseClauses, {
        Condition = elseifCondition;
        Body = elseifBody;
        ClauseType = elseifKw.Source;
        Token = elseifKw;
        Token_Then = elseifThenKw;
      })
      if elseifKw.Source == 'else' then
        break
      end
    end
    local enKw = expect('Keyword', 'end')
    return MkNode{
      Type = 'IfStat';
      Condition = condition;
      Body = ifBody;
      ElseClauseList = elseClauses;
      Token_If = ifKw;
      Token_Then = thenKw;
      Token_End = enKw;
      GetFirstToken = function(self) return self.Token_If end;
      GetLastToken = function(self) return self.Token_End end;
    }
  end

  -- do block end
  local function dostat()
    local doKw = get()
    local body, enKw = blockbody('end')
    return MkNode{
      Type = 'DoStat';
      Body = body;
      Token_Do = doKw;
      Token_End = enKw;
      GetFirstToken = function(self) return self.Token_Do end;
      GetLastToken = function(self) return self.Token_End end;
    }
  end

  -- while cond do block end
  local function whilestat()
    local whileKw = get()
    local condition = expr()
    local doKw = expect('Keyword', 'do')
    local body, enKw = blockbody('end')
    return MkNode{
      Type = 'WhileStat';
      Condition = condition;
      Body = body;
      Token_While = whileKw;
      Token_Do = doKw;
      Token_End = enKw;
      GetFirstToken = function(self) return self.Token_While end;
      GetLastToken = function(self) return self.Token_End end;
    }
  end

  -- Numeric `for v = a, b [, c] do ... end` or generic `for vs in exps do ... end`
  local function forstat()
    local forKw = get()
    local loopVars, loopVarCommas = varlist()
    if #loopVars == 0 then
      error(getTokenStartPosition(peek())..": Ident expected.")
    end
    if peek().Source == '=' then
      local eqTk = get()
      if #loopVars ~= 1 then
        -- `for a, b = ...` is not valid Lua: multiple names need `in`
        error(getTokenStartPosition(eqTk)..": `in` expected.")
      end
      local exprList, exprCommaList = exprlist()
      if #exprList < 2 or #exprList > 3 then
        error("expected 2 or 3 values for range bounds")
      end
      local doTk = expect('Keyword', 'do')
      local body, enTk = blockbody('end')
      return MkNode{
        Type = 'NumericForStat';
        VarList = loopVars;
        RangeList = exprList;
        Body = body;
        Token_For = forKw;
        Token_VarCommaList = loopVarCommas;
        Token_Equals = eqTk;
        Token_RangeCommaList = exprCommaList;
        Token_Do = doTk;
        Token_End = enTk;
        GetFirstToken = function(self) return self.Token_For end;
        GetLastToken = function(self) return self.Token_End end;
      }
    elseif peek().Source == 'in' then
      local inTk = get()
      local exprList, exprCommaList = exprlist()
      local doTk = expect('Keyword', 'do')
      local body, enTk = blockbody('end')
      return MkNode{
        Type = 'GenericForStat';
        VarList = loopVars;
        GeneratorList = exprList;
        Body = body;
        Token_For = forKw;
        Token_VarCommaList = loopVarCommas;
        Token_In = inTk;
        Token_GeneratorCommaList = exprCommaList;
        Token_Do = doTk;
        Token_End = enTk;
        GetFirstToken = function(self) return self.Token_For end;
        GetLastToken = function(self) return self.Token_End end;
      }
    else
      error("`=` or in expected")
    end
  end

  -- repeat block until cond
  local function repeatstat()
    local repeatKw = get()
    local body, untilTk = blockbody('until')
    local condition = expr()
    return MkNode{
      Type = 'RepeatStat';
      Body = body;
      Condition = condition;
      Token_Repeat = repeatKw;
      Token_Until = untilTk;
      GetFirstToken = function(self) return self.Token_Repeat end;
      GetLastToken = function(self) return self.Condition:GetLastToken() end;
    }
  end

  -- `local function name ...` or `local names [= exprs]`
  local function localdecl()
    local localKw = get()
    if peek().Source == 'function' then
      local funcStat = funcdecl(false)
      if #funcStat.NameChain > 1 then
        error(getTokenStartPosition(funcStat.Token_NameChainSeparator[1])..": `(` expected.")
      end
      return MkNode{
        Type = 'LocalFunctionStat';
        FunctionStat = funcStat;
        Token_Local = localKw;
        GetFirstToken = function(self) return self.Token_Local end;
        GetLastToken = function(self) return self.FunctionStat:GetLastToken() end;
      }
    elseif peek().Type == 'Ident' then
      local varList, varCommaList = varlist()
      local exprList, exprCommaList = {}, {}
      local eqToken
      if peek().Source == '=' then
        eqToken = get()
        exprList, exprCommaList = exprlist()
      end
      return MkNode{
        Type = 'LocalVarStat';
        VarList = varList;
        ExprList = exprList;
        Token_Local = localKw;
        Token_Equals = eqToken;
        Token_VarCommaList = varCommaList;
        Token_ExprCommaList = exprCommaList;
        GetFirstToken = function(self) return self.Token_Local end;
        GetLastToken = function(self)
          if #self.ExprList > 0 then
            return self.ExprList[#self.ExprList]:GetLastToken()
          else
            return self.VarList[#self.VarList]
          end
        end;
      }
    else
      error("`function` or ident expected")
    end
  end

  -- return [exprs]
  local function retstat()
    local returnKw = get()
    local exprList, commaList
    if isBlockFollow() or peek().Source == ';' then
      exprList = {}
      commaList = {}
    else
      exprList, commaList = exprlist()
    end
    return MkNode{
      Type = 'ReturnStat';
      ExprList = exprList;
      Token_Return = returnKw;
      Token_CommaList = commaList;
      GetFirstToken = function(self) return self.Token_Return end;
      GetLastToken = function(self)
        if #self.ExprList > 0 then
          return self.ExprList[#self.ExprList]:GetLastToken()
        else
          return self.Token_Return
        end
      end;
    }
  end

  -- break
  local function breakstat()
    local breakKw = get()
    return MkNode{
      Type = 'BreakStat';
      Token_Break = breakKw;
      GetFirstToken = function(self) return self.Token_Break end;
      GetLastToken = function(self) return self.Token_Break end;
    }
  end

  -- Parse one statement. Returns (isLast, node): isLast is true for
  -- return/break, which must be the final statement of their block.
  local function statement()
    local tok = peek()
    if tok.Source == 'if' then
      return false, ifstat()
    elseif tok.Source == 'while' then
      return false, whilestat()
    elseif tok.Source == 'do' then
      return false, dostat()
    elseif tok.Source == 'for' then
      return false, forstat()
    elseif tok.Source == 'repeat' then
      return false, repeatstat()
    elseif tok.Source == 'function' then
      return false, funcdecl(false)
    elseif tok.Source == 'local' then
      return false, localdecl()
    elseif tok.Source == 'return' then
      return true, retstat()
    elseif tok.Source == 'break' then
      return true, breakstat()
    else
      return false, exprstat()
    end
  end

  -- Chunk: statements, each optionally followed by ';'.
  -- SemicolonList[i] is the ';' token after statement i, if present.
  block = function()
    local statements = {}
    local semicolons = {}
    local isLast = false
    while not isLast and not isBlockFollow() do
      local stat
      isLast, stat = statement()
      table.insert(statements, stat)
      local nextTk = peek()
      if nextTk.Type == 'Symbol' and nextTk.Source == ';' then
        semicolons[#statements] = get()
      end
    end
    return {
      Type = 'StatList';
      StatementList = statements;
      SemicolonList = semicolons;
      GetFirstToken = function(self)
        if #self.StatementList == 0 then
          return nil
        else
          return self.StatementList[1]:GetFirstToken()
        end
      end;
      GetLastToken = function(self)
        if #self.StatementList == 0 then
          return nil
        elseif self.SemicolonList[#self.StatementList] then
          -- The last token may be the separator after the final statement
          return self.SemicolonList[#self.StatementList]
        else
          return self.StatementList[#self.StatementList]:GetLastToken()
        end
      end;
    }
  end

  local ast = block()
  -- The top-level block may stop early (a stray `end`, or code after
  -- `return`); anything left over is a syntax error, not silently dropped.
  expect('Eof')
  return ast
end
--- Walk an AST produced by CreateLuaParser, invoking visitor callbacks.
--
-- `visitors` maps node Type names (e.g. 'CallExpr', 'LocalVarStat') to either
--   * a function, called before the node's children are visited, or
--   * a table with optional `Pre` and `Post` functions, called before and
--     after the children respectively.
-- If the pre-visit callback returns a truthy value, the node's children and
-- its Post callback are skipped. Raises an error if `visitors` contains a key
-- that is not a known node type.
--
-- Child lists are walked with ipairs, so children are always visited in
-- source order; pairs gives no ordering guarantee.
-- @param ast root node (a statement or an expression)
-- @param visitors table of callbacks keyed by node Type
function VisitAst(ast, visitors)
  local ExprType = lookupify{
    'BinopExpr'; 'UnopExpr';
    'NumberLiteral'; 'StringLiteral'; 'NilLiteral'; 'BooleanLiteral'; 'VargLiteral';
    'FieldExpr'; 'IndexExpr';
    'MethodExpr'; 'CallExpr';
    'FunctionLiteral';
    'VariableExpr';
    'ParenExpr';
    'TableLiteral';
  }
  local StatType = lookupify{
    'StatList';
    'BreakStat';
    'ReturnStat';
    'LocalVarStat';
    'LocalFunctionStat';
    'FunctionStat';
    'RepeatStat';
    'GenericForStat';
    'NumericForStat';
    'WhileStat';
    'DoStat';
    'IfStat';
    'CallExprStat';
    'AssignmentStat';
  }

  -- Reject unknown visitor keys up front: they are almost always typos
  for visitorSubject in pairs(visitors) do
    if not StatType[visitorSubject] and not ExprType[visitorSubject] then
      error("Invalid visitor target: `"..visitorSubject.."`")
    end
  end

  -- Run the pre-visit callback; a truthy result means "skip children"
  local function preVisit(node)
    local visitor = visitors[node.Type]
    if type(visitor) == 'function' then
      return visitor(node)
    elseif visitor and visitor.Pre then
      return visitor.Pre(node)
    end
  end

  -- Run the post-visit callback (table-form visitors only)
  local function postVisit(node)
    local visitor = visitors[node.Type]
    if visitor and type(visitor) == 'table' and visitor.Post then
      return visitor.Post(node)
    end
  end

  local visitExpr, visitStat

  -- Visit every element of a sequence, in order
  local function visitEach(list, visit)
    for _, item in ipairs(list) do
      visit(item)
    end
  end

  visitExpr = function(expr)
    if preVisit(expr) then
      return
    end
    local kind = expr.Type
    if kind == 'BinopExpr' then
      visitExpr(expr.Lhs)
      visitExpr(expr.Rhs)
    elseif kind == 'UnopExpr' then
      visitExpr(expr.Rhs)
    elseif kind == 'NumberLiteral' or kind == 'StringLiteral' or
      kind == 'NilLiteral' or kind == 'BooleanLiteral' or
      kind == 'VargLiteral'
    then
      -- Single-token literals have no children
    elseif kind == 'FieldExpr' then
      visitExpr(expr.Base)
    elseif kind == 'IndexExpr' then
      visitExpr(expr.Base)
      visitExpr(expr.Index)
    elseif kind == 'MethodExpr' or kind == 'CallExpr' then
      visitExpr(expr.Base)
      local args = expr.FunctionArguments
      if args.CallType == 'ArgCall' then
        visitEach(args.ArgList, visitExpr)
      elseif args.CallType == 'TableCall' then
        visitExpr(args.TableExpr)
      end
      -- StringCall: the argument is a bare token, nothing to visit
    elseif kind == 'FunctionLiteral' then
      visitStat(expr.Body)
    elseif kind == 'VariableExpr' then
      -- No children
    elseif kind == 'ParenExpr' then
      visitExpr(expr.Expression)
    elseif kind == 'TableLiteral' then
      for _, entry in ipairs(expr.EntryList) do
        if entry.EntryType == 'Field' then
          visitExpr(entry.Value)
        elseif entry.EntryType == 'Index' then
          visitExpr(entry.Index)
          visitExpr(entry.Value)
        elseif entry.EntryType == 'Value' then
          visitExpr(entry.Value)
        else
          assert(false, "unreachable")
        end
      end
    else
      assert(false, "unreachable, type: "..kind..":"..FormatTable(expr))
    end
    postVisit(expr)
  end

  visitStat = function(stat)
    if preVisit(stat) then
      return
    end
    local kind = stat.Type
    if kind == 'StatList' then
      visitEach(stat.StatementList, visitStat)
    elseif kind == 'BreakStat' then
      -- No children
    elseif kind == 'ReturnStat' then
      visitEach(stat.ExprList, visitExpr)
    elseif kind == 'LocalVarStat' then
      if stat.Token_Equals then
        visitEach(stat.ExprList, visitExpr)
      end
    elseif kind == 'LocalFunctionStat' then
      visitStat(stat.FunctionStat.Body)
    elseif kind == 'FunctionStat' then
      visitStat(stat.Body)
    elseif kind == 'RepeatStat' then
      visitStat(stat.Body)
      visitExpr(stat.Condition)
    elseif kind == 'GenericForStat' then
      visitEach(stat.GeneratorList, visitExpr)
      visitStat(stat.Body)
    elseif kind == 'NumericForStat' then
      visitEach(stat.RangeList, visitExpr)
      visitStat(stat.Body)
    elseif kind == 'WhileStat' then
      visitExpr(stat.Condition)
      visitStat(stat.Body)
    elseif kind == 'DoStat' then
      visitStat(stat.Body)
    elseif kind == 'IfStat' then
      visitExpr(stat.Condition)
      visitStat(stat.Body)
      for _, clause in ipairs(stat.ElseClauseList) do
        if clause.Condition then
          visitExpr(clause.Condition)
        end
        visitStat(clause.Body)
      end
    elseif kind == 'CallExprStat' then
      visitExpr(stat.Expression)
    elseif kind == 'AssignmentStat' then
      visitEach(stat.Lhs, visitExpr)
      visitEach(stat.Rhs, visitExpr)
    else
      assert(false, "unreachable")
    end
    postVisit(stat)
  end

  if StatType[ast.Type] then
    visitStat(ast)
  else
    visitExpr(ast)
  end
end
--- Annotates an AST with variable and scope information.
-- Walks `ast` with VisitAst while maintaining a stack of scopes.
-- Every local declaration becomes a variable record in the current scope:
-- local statements, function arguments, for-loop variables and local
-- function names.
-- Every name reference is resolved against the enclosing scopes, falling
-- back to a shared list of globals. Each VariableExpr node receives an
-- `expr.Variable` field pointing at the record it resolved to.
--
-- Each variable record collects in RenameList one callback per token that
-- names it, so `var:Rename(newName)` rewrites all of those tokens at once.
-- Declarations, references and scope boundaries are stamped with integers
-- from a single increasing counter, which gives them a total order.
--
-- @param ast Root AST node (a statement or an expression).
-- @return globalVars List of records for names that resolved to no local.
-- @return The outermost scope record; nested scopes hang off ChildScopeList.
function AddVariableInfo(ast)
	local globalVars = {}
	local currentScope = nil

	-- Numbering generator for variable lifetimes: each call returns the next
	-- integer in a strictly increasing sequence.
	local locationGenerator = 0
	local function markLocation()
		locationGenerator = locationGenerator + 1
		return locationGenerator
	end

	-- Scope management
	local function pushScope()
		currentScope = {
			ParentScope = currentScope;
			ChildScopeList = {};
			VariableList = {};
			BeginLocation = markLocation();
		}
		if currentScope.ParentScope then
			currentScope.Depth = currentScope.ParentScope.Depth + 1
			table.insert(currentScope.ParentScope.ChildScopeList, currentScope)
		else
			currentScope.Depth = 1
		end
		-- Look up a variable by name from this scope outwards, then among the
		-- globals. Searches newest-first so that a local shadowing an earlier
		-- local of the same name in the same scope wins, matching getLocalVar.
		function currentScope:GetVar(varName)
			local list = self.VariableList
			for i = #list, 1, -1 do
				if list[i].Name == varName then
					return list[i]
				end
			end
			if self.ParentScope then
				return self.ParentScope:GetVar(varName)
			else
				for _, var in ipairs(globalVars) do
					if var.Name == varName then
						return var
					end
				end
			end
		end
	end

	local function popScope()
		local scope = currentScope
		-- Mark where this scope ends
		scope.EndLocation = markLocation()
		-- Mark all of the variables in the scope as ending there
		for _, var in ipairs(scope.VariableList) do
			var.ScopeEndLocation = scope.EndLocation
		end
		-- Move to the parent scope
		currentScope = scope.ParentScope
		return scope
	end

	pushScope() -- push initial scope

	-- Add / reference variables
	local function addLocalVar(name, setNameFunc, localInfo)
		assert(localInfo, "Missing localInfo")
		assert(name, "Missing local var name")
		local var = {
			Type = 'Local';
			Name = name;
			RenameList = {setNameFunc};
			AssignedTo = false;
			Info = localInfo;
			UseCount = 0;
			Scope = currentScope;
			BeginLocation = markLocation();
			EndLocation = markLocation();
			ReferenceLocationList = {markLocation()};
		}
		function var:Rename(newName)
			self.Name = newName
			for _, renameFunc in ipairs(self.RenameList) do
				renameFunc(newName)
			end
		end
		function var:Reference()
			self.UseCount = self.UseCount + 1
		end
		table.insert(currentScope.VariableList, var)
		return var
	end

	-- Return the record for global `name`, creating it on first use.
	local function getGlobalVar(name)
		for _, var in ipairs(globalVars) do
			if var.Name == name then
				return var
			end
		end
		local var = {
			Type = 'Global';
			Name = name;
			RenameList = {};
			AssignedTo = false;
			UseCount = 0;
			Scope = nil; -- Globals have no scope
			BeginLocation = markLocation();
			EndLocation = markLocation();
			ReferenceLocationList = {};
		}
		function var:Rename(newName)
			self.Name = newName
			for _, renameFunc in ipairs(self.RenameList) do
				renameFunc(newName)
			end
		end
		function var:Reference()
			self.UseCount = self.UseCount + 1
		end
		table.insert(globalVars, var)
		return var
	end

	local function addGlobalReference(name, setNameFunc)
		assert(name, "Missing var name")
		local var = getGlobalVar(name)
		table.insert(var.RenameList, setNameFunc)
		return var
	end

	local function getLocalVar(scope, name)
		-- First search this scope
		-- Note: Reverse iterate here because Lua does allow shadowing a local
		-- within the same scope, and the later defined variable should
		-- be the one referenced.
		for i = #scope.VariableList, 1, -1 do
			if scope.VariableList[i].Name == name then
				return scope.VariableList[i]
			end
		end
		-- Then search parent scope
		if scope.ParentScope then
			local var = getLocalVar(scope.ParentScope, name)
			if var then
				return var
			end
		end
		-- Not a visible local
		return nil
	end

	-- Resolve `name` to the innermost visible local, or else to a global, and
	-- record this use of it.
	local function referenceVariable(name, setNameFunc)
		assert(name, "Missing var name")
		local var = getLocalVar(currentScope, name)
		if var then
			table.insert(var.RenameList, setNameFunc)
		else
			var = addGlobalReference(name, setNameFunc)
		end
		-- Update the end location of where this variable is used, and
		-- add this location to the list of references to this variable.
		local curLocation = markLocation()
		var.EndLocation = curLocation
		table.insert(var.ReferenceLocationList, var.EndLocation)
		return var
	end

	-- Declare each identifier token in `identList` as a local in the current
	-- scope, tagged with `infoType` and its 1-based position in the list.
	local function declareIndexedLocals(identList, infoType)
		for index, ident in ipairs(identList) do
			addLocalVar(ident.Source, function(name)
				ident.Source = name
			end, {
				Type = infoType;
				Index = index;
			})
		end
	end

	local visitor = {}

	visitor.FunctionLiteral = {
		-- Function literal adds a new scope and adds the function literal arguments
		-- as local variables in the scope.
		Pre = function(expr)
			pushScope()
			declareIndexedLocals(expr.ArgList, 'Argument')
		end;
		Post = function(expr)
			popScope()
		end;
	}

	visitor.VariableExpr = function(expr)
		-- Variable expression references from existing local variables
		-- in the current scope, annotating the variable usage with variable
		-- information.
		expr.Variable = referenceVariable(expr.Token.Source, function(newName)
			expr.Token.Source = newName
		end)
	end

	visitor.StatList = {
		-- StatList adds a new scope
		Pre = function(stat)
			pushScope()
		end;
		Post = function(stat)
			popScope()
		end;
	}

	visitor.LocalVarStat = {
		Post = function(stat)
			-- Local var stat adds the local variables to the current scope as locals
			-- We need to visit the subexpressions first, because these new locals
			-- will not be in scope for the initialization value expressions. That is:
			--  `local bar = bar + 1`
			-- Is valid code
			for varNum, ident in ipairs(stat.VarList) do
				addLocalVar(ident.Source, function(name)
					stat.VarList[varNum].Source = name
				end, {
					Type = 'Local';
				})
			end
		end;
	}

	visitor.LocalFunctionStat = {
		Pre = function(stat)
			-- Local function stat adds the function itself to the current scope as
			-- a local variable (before the body, so the body can recurse), and
			-- creates a new scope with the function arguments as local variables.
			addLocalVar(stat.FunctionStat.NameChain[1].Source, function(name)
				stat.FunctionStat.NameChain[1].Source = name
			end, {
				Type = 'LocalFunction';
			})
			pushScope()
			declareIndexedLocals(stat.FunctionStat.ArgList, 'Argument')
		end;
		Post = function()
			popScope()
		end;
	}

	visitor.FunctionStat = {
		Pre = function(stat)
			-- Function stat adds a new scope containing the function arguments
			-- as local variables.
			-- `function foo()` is sugar for `foo = function() end`, so the head
			-- of the name chain resolves like any other reference: to the
			-- innermost visible local `foo` if there is one, otherwise to a
			-- global. (For `function a.b:c()` the head `a` is only read, but
			-- it is still conservatively flagged as assigned to.)
			local head = stat.NameChain[1]
			local var = referenceVariable(head.Source, function(name)
				head.Source = name
			end)
			var.AssignedTo = true
			pushScope()
			declareIndexedLocals(stat.ArgList, 'Argument')
		end;
		Post = function()
			popScope()
		end;
	}

	visitor.GenericForStat = {
		Pre = function(stat)
			-- Generic fors need an extra scope holding the range variables
			-- Need a custom visitor so that the generator expressions can be
			-- visited before we push a scope, but the body can be visited
			-- after we push a scope.
			for _, ex in ipairs(stat.GeneratorList) do
				VisitAst(ex, visitor)
			end
			pushScope()
			declareIndexedLocals(stat.VarList, 'ForRange')
			VisitAst(stat.Body, visitor)
			popScope()
			return true -- Custom visit
		end;
	}

	visitor.NumericForStat = {
		Pre = function(stat)
			-- Numeric fors need an extra scope holding the range variables
			-- Need a custom visitor so that the range expressions can be
			-- visited before we push a scope, but the body can be visited
			-- after we push a scope.
			for _, ex in ipairs(stat.RangeList) do
				VisitAst(ex, visitor)
			end
			pushScope()
			declareIndexedLocals(stat.VarList, 'ForRange')
			VisitAst(stat.Body, visitor)
			popScope()
			return true -- Custom visit
		end;
	}

	visitor.RepeatStat = {
		Pre = function(stat)
			-- In Lua the scope of a repeat body only ends after the `until`
			-- condition, so locals declared in the body are visible there:
			--  `repeat local x = f() until x`
			-- Visit the body's statements and then the condition inside one
			-- shared scope, instead of letting the body's StatList pop its
			-- scope before the condition is reached.
			local body = stat.Body
			pushScope()
			if body.Type == 'StatList' then
				for _, ch in ipairs(body.StatementList) do
					VisitAst(ch, visitor)
				end
			else
				VisitAst(body, visitor)
			end
			VisitAst(stat.Condition, visitor)
			popScope()
			return true -- Custom visit
		end;
	}

	visitor.AssignmentStat = {
		Post = function(stat)
			-- For an assignment statement we need to mark the
			-- "assigned to" flag on variables.
			for _, ex in ipairs(stat.Lhs) do
				if ex.Variable then
					ex.Variable.AssignedTo = true
				end
			end
		end;
	}

	VisitAst(ast, visitor)
	return globalVars, popScope()
end
--- Writes an AST back out as Lua source text through io.write.
-- Every token is emitted as its LeadingWhite followed by its Source, so the
-- output carries exactly the whitespace currently stored on the tokens.
-- @param ast Root statement node (normally the top-level StatList).
function PrintAst(ast)
	local printStat, printExpr

	-- Expression types that consist of nothing but a single `Token` field.
	local singleTokenTypes = {
		NumberLiteral = true;
		StringLiteral = true;
		NilLiteral = true;
		BooleanLiteral = true;
		VargLiteral = true;
		VariableExpr = true;
	}

	-- Emit one token, raising an error if it lacks LeadingWhite or Source.
	local function printt(tk)
		if not (tk.LeadingWhite and tk.Source) then
			error("Bad token: "..FormatTable(tk))
		end
		io.write(tk.LeadingWhite)
		io.write(tk.Source)
	end

	-- Print every element of `list` with `printItem`, following element i
	-- with separators[i] whenever that separator token exists.
	local function printSeparated(list, separators, printItem)
		for index, item in pairs(list) do
			printItem(item)
			local sep = separators[index]
			if sep then
				printt(sep)
			end
		end
	end

	-- Print an opening token, a statement body, and a closing token.
	local function printBlock(open, body, close)
		printt(open)
		printStat(body)
		printt(close)
	end

	-- Print the `(args) body end` tail shared by function literals and
	-- function statements.
	local function printFuncBody(fn)
		printt(fn.Token_OpenParen)
		printSeparated(fn.ArgList, fn.Token_ArgCommaList, printt)
		printBlock(fn.Token_CloseParen, fn.Body, fn.Token_End)
	end

	-- Print one entry of a table constructor (without its separator).
	local function printTableEntry(entry)
		local entryType = entry.EntryType
		if entryType == 'Field' then
			printt(entry.Field)
			printt(entry.Token_Equals)
		elseif entryType == 'Index' then
			printt(entry.Token_OpenBracket)
			printExpr(entry.Index)
			printt(entry.Token_CloseBracket)
			printt(entry.Token_Equals)
		elseif entryType ~= 'Value' then
			assert(false, "unreachable")
		end
		printExpr(entry.Value)
	end

	printExpr = function(expr)
		local kind = expr.Type
		if kind == 'BinopExpr' then
			printExpr(expr.Lhs)
			printt(expr.Token_Op)
			printExpr(expr.Rhs)
		elseif kind == 'UnopExpr' then
			printt(expr.Token_Op)
			printExpr(expr.Rhs)
		elseif singleTokenTypes[kind] then
			printt(expr.Token)
		elseif kind == 'FieldExpr' then
			printExpr(expr.Base)
			printt(expr.Token_Dot)
			printt(expr.Field)
		elseif kind == 'IndexExpr' then
			printExpr(expr.Base)
			printt(expr.Token_OpenBracket)
			printExpr(expr.Index)
			printt(expr.Token_CloseBracket)
		elseif kind == 'MethodExpr' or kind == 'CallExpr' then
			printExpr(expr.Base)
			if kind == 'MethodExpr' then
				printt(expr.Token_Colon)
				printt(expr.Method)
			end
			local args = expr.FunctionArguments
			local callType = args.CallType
			if callType == 'StringCall' then
				printt(args.Token)
			elseif callType == 'ArgCall' then
				printt(args.Token_OpenParen)
				printSeparated(args.ArgList, args.Token_CommaList, printExpr)
				printt(args.Token_CloseParen)
			elseif callType == 'TableCall' then
				printExpr(args.TableExpr)
			end
		elseif kind == 'FunctionLiteral' then
			printt(expr.Token_Function)
			printFuncBody(expr)
		elseif kind == 'ParenExpr' then
			printt(expr.Token_OpenParen)
			printExpr(expr.Expression)
			printt(expr.Token_CloseParen)
		elseif kind == 'TableLiteral' then
			printt(expr.Token_OpenBrace)
			printSeparated(expr.EntryList, expr.Token_SeparatorList, printTableEntry)
			printt(expr.Token_CloseBrace)
		else
			assert(false, "unreachable, type: "..expr.Type..":"..FormatTable(expr))
		end
	end

	printStat = function(stat)
		local kind = stat.Type
		if kind == 'StatList' then
			printSeparated(stat.StatementList, stat.SemicolonList, printStat)
		elseif kind == 'BreakStat' then
			printt(stat.Token_Break)
		elseif kind == 'ReturnStat' then
			printt(stat.Token_Return)
			printSeparated(stat.ExprList, stat.Token_CommaList, printExpr)
		elseif kind == 'LocalVarStat' then
			printt(stat.Token_Local)
			printSeparated(stat.VarList, stat.Token_VarCommaList, printt)
			if stat.Token_Equals then
				printt(stat.Token_Equals)
				printSeparated(stat.ExprList, stat.Token_ExprCommaList, printExpr)
			end
		elseif kind == 'LocalFunctionStat' then
			local fn = stat.FunctionStat
			printt(stat.Token_Local)
			printt(fn.Token_Function)
			printt(fn.NameChain[1])
			printFuncBody(fn)
		elseif kind == 'FunctionStat' then
			printt(stat.Token_Function)
			printSeparated(stat.NameChain, stat.Token_NameChainSeparator, printt)
			printFuncBody(stat)
		elseif kind == 'RepeatStat' then
			printBlock(stat.Token_Repeat, stat.Body, stat.Token_Until)
			printExpr(stat.Condition)
		elseif kind == 'GenericForStat' then
			printt(stat.Token_For)
			printSeparated(stat.VarList, stat.Token_VarCommaList, printt)
			printt(stat.Token_In)
			printSeparated(stat.GeneratorList, stat.Token_GeneratorCommaList, printExpr)
			printBlock(stat.Token_Do, stat.Body, stat.Token_End)
		elseif kind == 'NumericForStat' then
			printt(stat.Token_For)
			printSeparated(stat.VarList, stat.Token_VarCommaList, printt)
			printt(stat.Token_Equals)
			printSeparated(stat.RangeList, stat.Token_RangeCommaList, printExpr)
			printBlock(stat.Token_Do, stat.Body, stat.Token_End)
		elseif kind == 'WhileStat' then
			printt(stat.Token_While)
			printExpr(stat.Condition)
			printBlock(stat.Token_Do, stat.Body, stat.Token_End)
		elseif kind == 'DoStat' then
			printBlock(stat.Token_Do, stat.Body, stat.Token_End)
		elseif kind == 'IfStat' then
			printt(stat.Token_If)
			printExpr(stat.Condition)
			printt(stat.Token_Then)
			printStat(stat.Body)
			-- `elseif` clauses carry a Condition and a `then`; `else` does not
			for _, clause in pairs(stat.ElseClauseList) do
				printt(clause.Token)
				if clause.Condition then
					printExpr(clause.Condition)
					printt(clause.Token_Then)
				end
				printStat(clause.Body)
			end
			printt(stat.Token_End)
		elseif kind == 'CallExprStat' then
			printExpr(stat.Expression)
		elseif kind == 'AssignmentStat' then
			printSeparated(stat.Lhs, stat.Token_LhsSeparatorList, printExpr)
			printt(stat.Token_Equals)
			printSeparated(stat.Rhs, stat.Token_RhsSeparatorList, printExpr)
		else
			assert(false, "unreachable")
		end
	end

	printStat(ast)
end
--- Rewrites the whitespace on an AST's tokens, in place, into a standard layout.
-- Every statement starts on its own line, indented with one tab per level of
-- nesting. Closing tokens (`end`, `until`, `else`, ...) are re-indented to the
-- level of the enclosing statement. Non-empty table constructors place each
-- entry on its own line.
-- A space is inserted before most operands, keywords and operators whose
-- leading text does not already begin with a whitespace character (per
-- WhiteChars); `..` is left unpadded.
-- @param ast Root statement node (normally the top-level StatList).
local function FormatAst(ast)
	local formatStat, formatExpr
	local depth = 0

	-- Put `token` on a new line indented to the current depth. LeadingWhite
	-- that already ends in exactly that newline + indent is left untouched;
	-- otherwise trailing tabs/spaces and at most one newline are trimmed off
	-- before the newline + indent is appended.
	local function applyIndent(token)
		local wanted = '\n'..string.rep('\t', depth)
		local white = token.LeadingWhite
		if white ~= '' and white:sub(-#wanted, -1) == wanted then
			return
		end
		white = white:gsub("\n?[\t ]*$", "")
		token.LeadingWhite = white..wanted
	end

	local function indent()
		depth = depth + 1
	end

	local function undent()
		depth = depth - 1
		assert(depth >= 0, "Undented too far")
	end

	-- Prefix `tk` with a space unless the first character of its leading
	-- whitespace (or, when it has none, of its source) is already whitespace.
	local function padToken(tk)
		local lead = tk.LeadingWhite
		local firstChar
		if #lead > 0 then
			firstChar = lead:sub(1, 1)
		else
			firstChar = tk.Source:sub(1, 1)
		end
		if not WhiteChars[firstChar] then
			tk.LeadingWhite = ' '..lead
		end
	end

	local function padExpr(expr)
		padToken(expr:GetFirstToken())
	end

	-- Pad every token in `list`.
	local function padAll(list)
		for _, tk in pairs(list) do
			padToken(tk)
		end
	end

	-- Pad every token in `list` except the first one.
	local function padAllButFirst(list)
		for index, tk in pairs(list) do
			if index > 1 then
				padToken(tk)
			end
		end
	end

	-- Format each expression of `list`, then pad it; when `skipFirst` is set
	-- the first expression is formatted but left unpadded.
	local function formatExprList(list, skipFirst)
		for index, e in pairs(list) do
			formatExpr(e)
			if not skipFirst or index > 1 then
				padExpr(e)
			end
		end
	end

	-- Format `body` one level deeper, then re-indent `closeToken` at the
	-- enclosing level.
	local function formatBody(body, closeToken)
		indent()
		formatStat(body)
		undent()
		applyIndent(closeToken)
	end

	-- Lay out one table-constructor entry at the current (already increased) depth.
	local function formatTableEntry(entry)
		local entryType = entry.EntryType
		if entryType == 'Field' then
			applyIndent(entry.Field)
			padToken(entry.Token_Equals)
			formatExpr(entry.Value)
			padExpr(entry.Value)
		elseif entryType == 'Index' then
			applyIndent(entry.Token_OpenBracket)
			formatExpr(entry.Index)
			padToken(entry.Token_Equals)
			formatExpr(entry.Value)
			padExpr(entry.Value)
		elseif entryType == 'Value' then
			formatExpr(entry.Value)
			applyIndent(entry.Value:GetFirstToken())
		else
			assert(false, "unreachable")
		end
	end

	formatExpr = function(expr)
		local kind = expr.Type
		if kind == 'BinopExpr' then
			formatExpr(expr.Lhs)
			formatExpr(expr.Rhs)
			-- Every binary operator except concatenation gets surrounding spaces
			if expr.Token_Op.Source ~= '..' then
				padExpr(expr.Rhs)
				padToken(expr.Token_Op)
			end
		elseif kind == 'UnopExpr' then
			formatExpr(expr.Rhs)
		elseif kind == 'NumberLiteral' or kind == 'StringLiteral' or
			kind == 'NilLiteral' or kind == 'BooleanLiteral' or
			kind == 'VargLiteral' or kind == 'VariableExpr'
		then
			-- Single-token expressions: nothing to lay out
		elseif kind == 'FieldExpr' then
			formatExpr(expr.Base)
		elseif kind == 'IndexExpr' then
			formatExpr(expr.Base)
			formatExpr(expr.Index)
		elseif kind == 'MethodExpr' or kind == 'CallExpr' then
			formatExpr(expr.Base)
			local args = expr.FunctionArguments
			if args.CallType == 'ArgCall' then
				formatExprList(args.ArgList, true)
			elseif args.CallType == 'TableCall' then
				formatExpr(args.TableExpr)
			end
		elseif kind == 'FunctionLiteral' then
			padAllButFirst(expr.ArgList)
			formatBody(expr.Body, expr.Token_End)
		elseif kind == 'ParenExpr' then
			formatExpr(expr.Expression)
		elseif kind == 'TableLiteral' then
			-- Non-empty constructors get one indented line per entry
			if #expr.EntryList > 0 then
				indent()
				for _, entry in pairs(expr.EntryList) do
					formatTableEntry(entry)
				end
				undent()
				applyIndent(expr.Token_CloseBrace)
			end
		else
			assert(false, "unreachable, type: "..expr.Type..":"..FormatTable(expr))
		end
	end

	formatStat = function(stat)
		local kind = stat.Type
		if kind == 'StatList' then
			-- Each statement is laid out, then moved onto its own indented line
			for _, child in pairs(stat.StatementList) do
				formatStat(child)
				applyIndent(child:GetFirstToken())
			end
		elseif kind == 'BreakStat' then
			-- Nothing to lay out
		elseif kind == 'ReturnStat' then
			formatExprList(stat.ExprList)
		elseif kind == 'LocalVarStat' then
			padAll(stat.VarList)
			if stat.Token_Equals then
				padToken(stat.Token_Equals)
				formatExprList(stat.ExprList)
			end
		elseif kind == 'LocalFunctionStat' then
			local fn = stat.FunctionStat
			padToken(fn.Token_Function)
			padToken(fn.NameChain[1])
			padAllButFirst(fn.ArgList)
			formatBody(fn.Body, fn.Token_End)
		elseif kind == 'FunctionStat' then
			-- Only the head of `a.b:c` is separated from `function`
			local head = stat.NameChain[1]
			if head then
				padToken(head)
			end
			padAllButFirst(stat.ArgList)
			formatBody(stat.Body, stat.Token_End)
		elseif kind == 'RepeatStat' then
			formatBody(stat.Body, stat.Token_Until)
			formatExpr(stat.Condition)
			padExpr(stat.Condition)
		elseif kind == 'GenericForStat' then
			padAll(stat.VarList)
			padToken(stat.Token_In)
			formatExprList(stat.GeneratorList)
			padToken(stat.Token_Do)
			formatBody(stat.Body, stat.Token_End)
		elseif kind == 'NumericForStat' then
			padAll(stat.VarList)
			padToken(stat.Token_Equals)
			formatExprList(stat.RangeList)
			padToken(stat.Token_Do)
			formatBody(stat.Body, stat.Token_End)
		elseif kind == 'WhileStat' then
			formatExpr(stat.Condition)
			padExpr(stat.Condition)
			padToken(stat.Token_Do)
			formatBody(stat.Body, stat.Token_End)
		elseif kind == 'DoStat' then
			formatBody(stat.Body, stat.Token_End)
		elseif kind == 'IfStat' then
			formatExpr(stat.Condition)
			padExpr(stat.Condition)
			padToken(stat.Token_Then)
			-- Each body is closed by the token that starts the next clause,
			-- and the final body by `end`.
			local pendingBody = stat.Body
			for _, clause in pairs(stat.ElseClauseList) do
				formatBody(pendingBody, clause.Token)
				if clause.Condition then
					formatExpr(clause.Condition)
					padExpr(clause.Condition)
					padToken(clause.Token_Then)
				end
				pendingBody = clause.Body
			end
			formatBody(pendingBody, stat.Token_End)
		elseif kind == 'CallExprStat' then
			formatExpr(stat.Expression)
		elseif kind == 'AssignmentStat' then
			formatExprList(stat.Lhs, true)
			padToken(stat.Token_Equals)
			formatExprList(stat.Rhs)
		else
			assert(false, "unreachable")
		end
	end

	formatStat(ast)
end
--- Strips as much whitespace as possible off the tokens of an AST, in place,
-- without changing how the printed result lexes or parses.
-- Leading whitespace is cleared on every token. Where two adjacent tokens
-- would otherwise fuse into something different, a single space is kept
-- instead. Statement-separating semicolons are dropped unless they are
-- needed to keep two statements apart.
-- @param ast Root statement node (normally the top-level StatList).
local function StripAst(ast)
	local stripStat, stripExpr

	-- Remove all leading whitespace from a token.
	local function stript(token)
		token.LeadingWhite = ''
	end

	-- Make two adjacent tokens as close as possible: tokenB's leading
	-- whitespace becomes '' unless gluing it to tokenA would change how the
	-- source is lexed, in which case it becomes a single space:
	--   `- -x`       -> `--x`       would start a comment
	--   `a b`        -> `ab`        would merge two words/numbers into one
	--   `1 ..x`      -> `1..x`      the number literal would swallow the dots
	--   `a .. ...`   -> `a.....`    the dots would regroup into other operators
	--   `t[ [[s]] ]` -> `t[[[s]]]`  would open a long-bracket string
	-- (Two number literals are never adjacent in valid Lua, and the
	-- `f(x)\n(y)()` ambiguity is handled by the StatList semicolon logic.)
	local function joint(tokenA, tokenB)
		local lastCh = tokenA.Source:sub(-1, -1)
		local firstCh = tokenB.Source:sub(1, 1)
		local needsSeparator =
			(lastCh == '-' and firstCh == '-') or
			(AllIdentChars[lastCh] and AllIdentChars[firstCh]) or
			(lastCh == '.' and firstCh == '.') or
			(lastCh == '[' and firstCh == '[') or
			(firstCh == '.' and tokenA.Source:match("^%.?%d") ~= nil)
		if needsSeparator then
			tokenB.LeadingWhite = ' ' -- Use a separator
		else
			tokenB.LeadingWhite = '' -- Don't use a separator
		end
	end

	-- Join up a statement body and its opening / closing tokens
	local function bodyjoint(open, body, close)
		stripStat(body)
		stript(close)
		local bodyFirst = body:GetFirstToken()
		local bodyLast = body:GetLastToken()
		if bodyFirst then
			-- Body is non-empty, join body to open / close
			joint(open, bodyFirst)
			joint(bodyLast, close)
		else
			-- Body is empty, just join open and close token together
			joint(open, close)
		end
	end

	stripExpr = function(expr)
		if expr.Type == 'BinopExpr' then
			stripExpr(expr.Lhs)
			stript(expr.Token_Op)
			stripExpr(expr.Rhs)
			-- Handle the `a - -b` -/-> `a--b` case which would otherwise incorrectly generate a comment
			-- Also handles operators "or" / "and" which definitely need joining logic in a bunch of cases,
			-- and `1 .. x`, where the number would otherwise absorb the operator
			joint(expr.Token_Op, expr.Rhs:GetFirstToken())
			joint(expr.Lhs:GetLastToken(), expr.Token_Op)
		elseif expr.Type == 'UnopExpr' then
			stript(expr.Token_Op)
			stripExpr(expr.Rhs)
			-- Handle the `- -b` -/-> `--b` case which would otherwise incorrectly generate a comment
			joint(expr.Token_Op, expr.Rhs:GetFirstToken())
		elseif expr.Type == 'NumberLiteral' or expr.Type == 'StringLiteral' or
			expr.Type == 'NilLiteral' or expr.Type == 'BooleanLiteral' or
			expr.Type == 'VargLiteral'
		then
			-- Just strip the token
			stript(expr.Token)
		elseif expr.Type == 'FieldExpr' then
			stripExpr(expr.Base)
			stript(expr.Token_Dot)
			stript(expr.Field)
		elseif expr.Type == 'IndexExpr' then
			stripExpr(expr.Base)
			stript(expr.Token_OpenBracket)
			stripExpr(expr.Index)
			-- `t[ [[k]] ]` must not collapse into a `[[[` long-string opener
			joint(expr.Token_OpenBracket, expr.Index:GetFirstToken())
			stript(expr.Token_CloseBracket)
		elseif expr.Type == 'MethodExpr' or expr.Type == 'CallExpr' then
			stripExpr(expr.Base)
			if expr.Type == 'MethodExpr' then
				stript(expr.Token_Colon)
				stript(expr.Method)
			end
			if expr.FunctionArguments.CallType == 'StringCall' then
				stript(expr.FunctionArguments.Token)
			elseif expr.FunctionArguments.CallType == 'ArgCall' then
				stript(expr.FunctionArguments.Token_OpenParen)
				for index, argExpr in pairs(expr.FunctionArguments.ArgList) do
					stripExpr(argExpr)
					local sep = expr.FunctionArguments.Token_CommaList[index]
					if sep then
						stript(sep)
					end
				end
				stript(expr.FunctionArguments.Token_CloseParen)
			elseif expr.FunctionArguments.CallType == 'TableCall' then
				stripExpr(expr.FunctionArguments.TableExpr)
			end
		elseif expr.Type == 'FunctionLiteral' then
			stript(expr.Token_Function)
			stript(expr.Token_OpenParen)
			for index, arg in pairs(expr.ArgList) do
				stript(arg)
				local comma = expr.Token_ArgCommaList[index]
				if comma then
					stript(comma)
				end
			end
			stript(expr.Token_CloseParen)
			bodyjoint(expr.Token_CloseParen, expr.Body, expr.Token_End)
		elseif expr.Type == 'VariableExpr' then
			stript(expr.Token)
		elseif expr.Type == 'ParenExpr' then
			stript(expr.Token_OpenParen)
			stripExpr(expr.Expression)
			stript(expr.Token_CloseParen)
		elseif expr.Type == 'TableLiteral' then
			stript(expr.Token_OpenBrace)
			for index, entry in pairs(expr.EntryList) do
				if entry.EntryType == 'Field' then
					stript(entry.Field)
					stript(entry.Token_Equals)
					stripExpr(entry.Value)
				elseif entry.EntryType == 'Index' then
					stript(entry.Token_OpenBracket)
					stripExpr(entry.Index)
					-- `{[ [[k]] ] = v}` must not collapse into a `[[[` long-string opener
					joint(entry.Token_OpenBracket, entry.Index:GetFirstToken())
					stript(entry.Token_CloseBracket)
					stript(entry.Token_Equals)
					stripExpr(entry.Value)
				elseif entry.EntryType == 'Value' then
					stripExpr(entry.Value)
				else
					assert(false, "unreachable")
				end
				local sep = expr.Token_SeparatorList[index]
				if sep then
					stript(sep)
				end
			end
			stript(expr.Token_CloseBrace)
		else
			assert(false, "unreachable, type: "..expr.Type..":"..FormatTable(expr))
		end
	end

	stripStat = function(stat)
		if stat.Type == 'StatList' then
			-- Strip all surrounding whitespace on statement lists along with separating whitespace
			local list = stat.StatementList
			local semicolons = stat.SemicolonList
			for i = 1, #list do
				local chStat = list[i]
				-- Strip the statement and its whitespace
				stripStat(chStat)
				stript(chStat:GetFirstToken())
				-- If there was a previous statement, join them appropriately
				local lastChStat = list[i-1]
				if lastChStat then
					-- A semicolon is only needed when this statement starts with `(`:
					-- without it `a = b;(f)()` or `f();(g)()` would be parsed as a
					-- single call expression spanning both statements.
					if semicolons[i-1] and chStat:GetFirstToken().Source ~= '(' then
						semicolons[i-1] = nil
					end
					local semicolon = semicolons[i-1]
					if semicolon then
						-- The kept semicolon separates the statements by itself
						stript(semicolon)
					else
						joint(lastChStat:GetLastToken(), chStat:GetFirstToken())
					end
				end
			end
			-- A semi-colon is never needed on the last stat in a statlist:
			semicolons[#list] = nil
			-- The leading whitespace on the statlist should be stripped
			if #list > 0 then
				stript(list[1]:GetFirstToken())
			end
		elseif stat.Type == 'BreakStat' then
			stript(stat.Token_Break)
		elseif stat.Type == 'ReturnStat' then
			stript(stat.Token_Return)
			for index, expr in pairs(stat.ExprList) do
				stripExpr(expr)
				if stat.Token_CommaList[index] then
					stript(stat.Token_CommaList[index])
				end
			end
			if #stat.ExprList > 0 then
				joint(stat.Token_Return, stat.ExprList[1]:GetFirstToken())
			end
		elseif stat.Type == 'LocalVarStat' then
			stript(stat.Token_Local)
			for index, var in pairs(stat.VarList) do
				if index == 1 then
					joint(stat.Token_Local, var)
				else
					stript(var)
				end
				local comma = stat.Token_VarCommaList[index]
				if comma then
					stript(comma)
				end
			end
			if stat.Token_Equals then
				stript(stat.Token_Equals)
				for index, expr in pairs(stat.ExprList) do
					stripExpr(expr)
					local comma = stat.Token_ExprCommaList[index]
					if comma then
						stript(comma)
					end
				end
			end
		elseif stat.Type == 'LocalFunctionStat' then
			stript(stat.Token_Local)
			joint(stat.Token_Local, stat.FunctionStat.Token_Function)
			joint(stat.FunctionStat.Token_Function, stat.FunctionStat.NameChain[1])
			joint(stat.FunctionStat.NameChain[1], stat.FunctionStat.Token_OpenParen)
			for index, arg in pairs(stat.FunctionStat.ArgList) do
				stript(arg)
				local comma = stat.FunctionStat.Token_ArgCommaList[index]
				if comma then
					stript(comma)
				end
			end
			stript(stat.FunctionStat.Token_CloseParen)
			bodyjoint(stat.FunctionStat.Token_CloseParen, stat.FunctionStat.Body, stat.FunctionStat.Token_End)
		elseif stat.Type == 'FunctionStat' then
			stript(stat.Token_Function)
			for index, part in pairs(stat.NameChain) do
				if index == 1 then
					joint(stat.Token_Function, part)
				else
					stript(part)
				end
				local sep = stat.Token_NameChainSeparator[index]
				if sep then
					stript(sep)
				end
			end
			stript(stat.Token_OpenParen)
			for index, arg in pairs(stat.ArgList) do
				stript(arg)
				local comma = stat.Token_ArgCommaList[index]
				if comma then
					stript(comma)
				end
			end
			stript(stat.Token_CloseParen)
			bodyjoint(stat.Token_CloseParen, stat.Body, stat.Token_End)
		elseif stat.Type == 'RepeatStat' then
			stript(stat.Token_Repeat)
			bodyjoint(stat.Token_Repeat, stat.Body, stat.Token_Until)
			stripExpr(stat.Condition)
			joint(stat.Token_Until, stat.Condition:GetFirstToken())
		elseif stat.Type == 'GenericForStat' then
			stript(stat.Token_For)
			for index, var in pairs(stat.VarList) do
				if index == 1 then
					joint(stat.Token_For, var)
				else
					stript(var)
				end
				local sep = stat.Token_VarCommaList[index]
				if sep then
					stript(sep)
				end
			end
			joint(stat.VarList[#stat.VarList], stat.Token_In)
			for index, expr in pairs(stat.GeneratorList) do
				stripExpr(expr)
				if index == 1 then
					joint(stat.Token_In, expr:GetFirstToken())
				end
				local sep = stat.Token_GeneratorCommaList[index]
				if sep then
					stript(sep)
				end
			end
			joint(stat.GeneratorList[#stat.GeneratorList]:GetLastToken(), stat.Token_Do)
			bodyjoint(stat.Token_Do, stat.Body, stat.Token_End)
		elseif stat.Type == 'NumericForStat' then
			stript(stat.Token_For)
			for index, var in pairs(stat.VarList) do
				if index == 1 then
					joint(stat.Token_For, var)
				else
					stript(var)
				end
				local sep = stat.Token_VarCommaList[index]
				if sep then
					stript(sep)
				end
			end
			joint(stat.VarList[#stat.VarList], stat.Token_Equals)
			for index, expr in pairs(stat.RangeList) do
				stripExpr(expr)
				if index == 1 then
					joint(stat.Token_Equals, expr:GetFirstToken())
				end
				local sep = stat.Token_RangeCommaList[index]
				if sep then
					stript(sep)
				end
			end
			joint(stat.RangeList[#stat.RangeList]:GetLastToken(), stat.Token_Do)
			bodyjoint(stat.Token_Do, stat.Body, stat.Token_End)
		elseif stat.Type == 'WhileStat' then
			stript(stat.Token_While)
			stripExpr(stat.Condition)
			stript(stat.Token_Do)
			joint(stat.Token_While, stat.Condition:GetFirstToken())
			joint(stat.Condition:GetLastToken(), stat.Token_Do)
			bodyjoint(stat.Token_Do, stat.Body, stat.Token_End)
		elseif stat.Type == 'DoStat' then
			stript(stat.Token_Do)
			bodyjoint(stat.Token_Do, stat.Body, stat.Token_End)
		elseif stat.Type == 'IfStat' then
			stript(stat.Token_If)
			stripExpr(stat.Condition)
			joint(stat.Token_If, stat.Condition:GetFirstToken())
			joint(stat.Condition:GetLastToken(), stat.Token_Then)
			-- Each body is closed by the token starting the next clause, and
			-- the last body by `end`; bodyjoint strips each body exactly once.
			local lastBodyOpen = stat.Token_Then
			local lastBody = stat.Body
			for _, clause in pairs(stat.ElseClauseList) do
				bodyjoint(lastBodyOpen, lastBody, clause.Token)
				lastBodyOpen = clause.Token
				if clause.Condition then
					stripExpr(clause.Condition)
					joint(clause.Token, clause.Condition:GetFirstToken())
					joint(clause.Condition:GetLastToken(), clause.Token_Then)
					lastBodyOpen = clause.Token_Then
				end
				lastBody = clause.Body
			end
			bodyjoint(lastBodyOpen, lastBody, stat.Token_End)
		elseif stat.Type == 'CallExprStat' then
			stripExpr(stat.Expression)
		elseif stat.Type == 'AssignmentStat' then
			for index, ex in pairs(stat.Lhs) do
				stripExpr(ex)
				local sep = stat.Token_LhsSeparatorList[index]
				if sep then
					stript(sep)
				end
			end
			stript(stat.Token_Equals)
			for index, ex in pairs(stat.Rhs) do
				stripExpr(ex)
				local sep = stat.Token_RhsSeparatorList[index]
				if sep then
					stript(sep)
				end
			end
		else
			assert(false, "unreachable")
		end
	end

	stripStat(ast)
end
-- Counter consumed by genNextVarName.
local idGen = 0
-- Alphabets used by indexToVarName:
--   VarDigits      = a-z, A-Z, 0-9, '_'  (63 symbols, any identifier position)
--   VarStartDigits = a-z, A-Z            (52 symbols, first identifier position)
local VarDigits, VarStartDigits = {}, {}
do
  -- Append every character from `first` to `last` (inclusive, by byte code).
  local function appendCharRange(list, first, last)
    for code = first:byte(), last:byte() do
      list[#list + 1] = string.char(code)
    end
  end
  appendCharRange(VarDigits, 'a', 'z')
  appendCharRange(VarDigits, 'A', 'Z')
  appendCharRange(VarDigits, '0', '9')
  VarDigits[#VarDigits + 1] = '_'
  appendCharRange(VarStartDigits, 'a', 'z')
  appendCharRange(VarStartDigits, 'A', 'Z')
end
--- Encode a non-negative integer as a short Lua identifier.
-- The least-significant "digit" is taken from VarStartDigits (letters only),
-- so the identifier never starts with a digit or underscore. Every further
-- digit comes from VarDigits, and digits are appended least-significant
-- first. The result is NOT checked against Lua keywords; callers do that.
-- @param index integer >= 0
-- @return identifier string
local function indexToVarName(index)
  local startBase = #VarStartDigits
  local restBase = #VarDigits
  local low = index % startBase
  local name = ''
  name = name .. VarStartDigits[low + 1]
  local remaining = (index - low) / startBase
  while remaining > 0 do
    local digit = remaining % restBase
    name = name .. VarDigits[digit + 1]
    remaining = (remaining - digit) / restBase
  end
  return name
end
--- Return the identifier for the current value of idGen, then advance idGen.
-- The result may be a Lua keyword; genVarName filters those out.
local function genNextVarName()
  local current = idGen
  idGen = current + 1
  return indexToVarName(current)
end
--- Return the next generated identifier that is not a Lua keyword.
-- Keyword candidates are consumed and discarded.
local function genVarName()
  while true do
    local candidate = genNextVarName()
    if not Keywords[candidate] then
      return candidate
    end
  end
end
--- Rename variables to the shortest available identifiers.
-- Assigned globals are renamed first and receive the lowest name indices.
-- Locals are then renamed scope by scope. Each child scope starts numbering
-- after every index its enclosing scope handed out, so a renamed local never
-- reuses the name of a variable from an enclosing scope; sibling scopes
-- reuse the same indices. Lua keywords and external globals (globals that
-- are never assigned in this script) are never produced as new names.
-- @param globalScope table of global variable records (Name, AssignedTo,
--   and a Rename method)
-- @param rootScope root scope record (VariableList, ChildScopeList); the
--   FirstFreeName field is written on it and on every child scope
local function MinifyVariables(globalScope, rootScope)
  -- Globals that are read but never assigned in the script are defined by
  -- the environment: they keep their names, and no renamed variable may
  -- take one of those names.
  local externalGlobals = {}

  -- Assigned globals first receive unique temporary names `_TMP_<n>_`;
  -- their final names are handed out further down.
  local temporaryIndex = 0
  for _, var in pairs(globalScope) do
    if var.AssignedTo then
      var:Rename('_TMP_'..temporaryIndex..'_')
      temporaryIndex = temporaryIndex + 1
    else
      externalGlobals[var.Name] = true
    end
  end

  -- Return the first usable name at or after name index `index`, together
  -- with the index just past it. Keywords and external globals are skipped.
  local function nextUsableName(index)
    local name
    repeat
      name = indexToVarName(index)
      index = index + 1
    until not Keywords[name] and not externalGlobals[name]
    return name, index
  end

  -- Globals go first so they (visible everywhere) get the shortest names.
  -- TODO: Rename all vars based on frequency patterns, giving variables
  -- used more often shorter names.
  local nextFreeNameIndex = 0
  for _, var in pairs(globalScope) do
    if var.AssignedTo then
      local name
      name, nextFreeNameIndex = nextUsableName(nextFreeNameIndex)
      var:Rename(name)
    end
  end

  -- Locals: every scope continues from the index its parent reached after
  -- naming all of the parent's own variables.
  local function doRenameScope(scope)
    for _, var in pairs(scope.VariableList) do
      local name
      name, scope.FirstFreeName = nextUsableName(scope.FirstFreeName)
      var:Rename(name)
    end
    for _, childScope in pairs(scope.ChildScopeList) do
      childScope.FirstFreeName = scope.FirstFreeName
      doRenameScope(childScope)
    end
  end
  rootScope.FirstFreeName = nextFreeNameIndex
  doRenameScope(rootScope)
end
--- Alternative variable minifier (unfinished): greedily assigns each
-- renamable variable the lowest name slot that does not collide with an
-- already-renamed variable whose lifetime/references overlap it, so short
-- names can be reused across non-conflicting variables.
-- NOTE(review): the command-line `minify` path in this file calls
-- MinifyVariables, not this function, and this function ends with
-- commented-out scaffolding; treat it as experimental.
-- @param globalScope table of global variable records (reads Name,
--   AssignedTo, RenameList, ReferenceLocationList, Type, Scope; calls Rename)
-- @param rootScope root scope record (VariableList, ChildScopeList); local
--   records additionally read Scope.Depth, BeginLocation, EndLocation,
--   ScopeEndLocation
local function MinifyVariables_2(globalScope, rootScope)
-- Variable names and other names that are fixed, that we cannot use
-- Either these are Lua keywords, or globals that are not assigned to,
-- that is environmental globals that are assigned elsewhere beyond our
-- control.
local globalUsedNames = {}
for kw, _ in pairs(Keywords) do
globalUsedNames[kw] = true
end
-- Gather a list of all of the variables that we will rename
local allVariables = {}
-- NOTE(review): filled below but never read in this function.
local allLocalVariables = {}
do
-- Add applicable globals
for _, var in pairs(globalScope) do
if var.AssignedTo then
-- We can try to rename this global since it was assigned to
-- (and thus presumably initialized) in the script we are
-- minifying.
table.insert(allVariables, var)
else
-- We can't rename this global, mark it as an unusable name
-- and don't add it to the rename list
globalUsedNames[var.Name] = true
end
end
-- Recursively add locals, we can rename all of those
local function addFrom(scope)
for _, var in pairs(scope.VariableList) do
table.insert(allVariables, var)
table.insert(allLocalVariables, var)
end
for _, childScope in pairs(scope.ChildScopeList) do
addFrom(childScope)
end
end
addFrom(rootScope)
end
-- Add used name arrays to variables
-- (UsedNameArray[i] == true means name slot i is forbidden for that variable.)
for _, var in pairs(allVariables) do
var.UsedNameArray = {}
end
-- Sort the least used variables first
-- NOTE(review): the first variable processed always gets slot 1, i.e. the
-- first generated name, so this ascending order hands the earliest names to
-- the least-referenced variables. That looks backwards for minification;
-- confirm intent.
table.sort(allVariables, function(a, b)
return #a.RenameList < #b.RenameList
end)
-- Lazy generator for valid names to rename to
-- Maps a 1-based name slot `i` to a fixed name. Names are generated in the
-- order slots are first requested, skipping keywords and external globals.
local nextValidNameIndex = 0
local varNamesLazy = {}
local function varIndexToValidVarName(i)
local name = varNamesLazy[i]
if not name then
repeat
name = indexToVarName(nextValidNameIndex)
nextValidNameIndex = nextValidNameIndex + 1
until not globalUsedNames[name]
varNamesLazy[i] = name
end
return name
end
-- For each variable, go to rename it
for _, var in pairs(allVariables) do
-- Lazy... todo: Make this pair a proper for-each-pair-like set of loops
-- rather than using a renamed flag.
-- (The Renamed flag is left set on the variable records afterwards.)
var.Renamed = true
-- Find the first unused name
local i = 1
while var.UsedNameArray[i] do
i = i + 1
end
-- Rename the variable to that name
var:Rename(varIndexToValidVarName(i))
if var.Scope then
-- Now we need to mark the name as unusable by any variables:
-- 1) At the same depth that overlap lifetime with this one
-- 2) At a deeper level, which have a reference to this variable in their lifetimes
-- 3) At a shallower level, which are referenced during this variable's lifetime
for _, otherVar in pairs(allVariables) do
if not otherVar.Renamed then
if not otherVar.Scope or otherVar.Scope.Depth < var.Scope.Depth then
-- Check Global variable (Which is always at a shallower level)
-- or
-- Check case 3
-- The other var is at a shallower depth, is there a reference to it
-- during this variable's lifetime?
for _, refAt in pairs(otherVar.ReferenceLocationList) do
if refAt >= var.BeginLocation and refAt <= var.ScopeEndLocation then
-- Collide
otherVar.UsedNameArray[i] = true
break
end
end
elseif otherVar.Scope.Depth > var.Scope.Depth then
-- Check Case 2
-- The other var is at a greater depth, see if any of the references
-- to this variable are in the other var's lifetime.
for _, refAt in pairs(var.ReferenceLocationList) do
if refAt >= otherVar.BeginLocation and refAt <= otherVar.ScopeEndLocation then
-- Collide
otherVar.UsedNameArray[i] = true
break
end
end
else --otherVar.Scope.Depth must be equal to var.Scope.Depth
-- Check case 1
-- The two locals are in the same scope
-- Just check if the usage lifetimes overlap within that scope. That is, we
-- can shadow a local variable within the same scope as long as the usages
-- of the two locals do not overlap.
-- (Open-interval overlap test on [BeginLocation, EndLocation].)
if var.BeginLocation < otherVar.EndLocation and
var.EndLocation > otherVar.BeginLocation
then
otherVar.UsedNameArray[i] = true
end
end
end
end
else
-- This is a global var, all other globals can't collide with it, and
-- any local variable with a reference to this global in its lifetime
-- can't collide with it.
for _, otherVar in pairs(allVariables) do
if not otherVar.Renamed then
if otherVar.Type == 'Global' then
otherVar.UsedNameArray[i] = true
elseif otherVar.Type == 'Local' then
-- Other var is a local, see if there is a reference to this global within
-- that local's lifetime.
for _, refAt in pairs(var.ReferenceLocationList) do
if refAt >= otherVar.BeginLocation and refAt <= otherVar.ScopeEndLocation then
-- Collide
otherVar.UsedNameArray[i] = true
break
end
end
else
assert(false, "unreachable")
end
end
end
end
end
-- --
-- print("Total Variables: "..#allVariables)
-- print("Total Range: "..rootScope.BeginLocation.."-"..rootScope.EndLocation)
-- print("")
-- for _, var in pairs(allVariables) do
-- io.write("`"..var.Name.."':\n\t#symbols: "..#var.RenameList..
-- "\n\tassigned to: "..tostring(var.AssignedTo))
-- if var.Type == 'Local' then
-- io.write("\n\trange: "..var.BeginLocation.."-"..var.EndLocation)
-- io.write("\n\tlocal type: "..var.Info.Type)
-- end
-- io.write("\n\n")
-- end
-- -- First we want to rename all of the variables to unique temporaries, so that we can
-- -- easily use the scope::GetVar function to check whether renames are valid.
-- local temporaryIndex = 0
-- for _, var in pairs(allVariables) do
-- var:Rename('_TMP_'..temporaryIndex..'_')
-- temporaryIndex = temporaryIndex + 1
-- end
-- For each variable, we need to build a list of names that collide with it
--
--error()
end
--- Rename variables to predictable, find-and-replace friendly names, to aid
-- reverse engineering of minified code.
-- Assigned globals become `G_<n>`. Locals become `L_<n>_` followed by a
-- suffix for their declaration kind: `arg<Index>` (Argument), `func`
-- (LocalFunction), `forvar<Index>` (ForRange), or nothing otherwise.
-- Unassigned (external) globals keep their names. Any generated name equal
-- to one of them is skipped by advancing the counter, so renamed variables
-- never capture an environment global.
-- @param globalScope table of global variable records (Name, AssignedTo,
--   RenameList of setter functions)
-- @param rootScope root scope record (VariableList, ChildScopeList); local
--   records carry Info.Type / Info.Index and a RenameList
local function BeautifyVariables(globalScope, rootScope)
  -- Globals never assigned in the script come from the environment: their
  -- names are fixed, so generated names must avoid them.
  local externalGlobals = {}
  for _, var in pairs(globalScope) do
    if not var.AssignedTo then
      externalGlobals[var.Name] = true
    end
  end

  local localNumber = 1
  local globalNumber = 1

  -- Update the record and push the new name into every token referencing it.
  local function setVarName(var, name)
    var.Name = name
    for _, setter in pairs(var.RenameList) do
      setter(name)
    end
  end

  for _, var in pairs(globalScope) do
    if var.AssignedTo then
      local name
      repeat
        name = 'G_'..globalNumber
        globalNumber = globalNumber + 1
      until not externalGlobals[name]
      setVarName(var, name)
    end
  end

  -- Descriptive suffix appended after `L_<n>_` for a local's declaration kind.
  local function localSuffix(info)
    if info.Type == 'Argument' then
      return 'arg'..info.Index
    elseif info.Type == 'LocalFunction' then
      return 'func'
    elseif info.Type == 'ForRange' then
      return 'forvar'..info.Index
    end
    return ''
  end

  -- Number locals depth-first: a scope's own variables, then its children.
  local function modify(scope)
    for _, var in pairs(scope.VariableList) do
      local suffix = localSuffix(var.Info)
      local name
      repeat
        name = 'L_'..localNumber..'_'..suffix
        localNumber = localNumber + 1
      until not externalGlobals[name]
      setVarName(var, name)
    end
    for _, childScope in pairs(scope.ChildScopeList) do
      modify(childScope)
    end
  end
  modify(rootScope)
end
--- Abort with the command-line usage text.
-- Raised at level 0 so no file/line position is prepended to the message.
local function usageError()
  local usageLines = {
    "\nusage: minify <file> or unminify <file>\n",
    " The modified code will be printed to the stdout, pipe it to a file, the\n",
    " lua interpreter, or something else as desired EG:\n\n",
    " lua minify.lua minify input.lua > output.lua\n\n",
    " * minify will minify the code in the file.\n",
    " * unminify will beautify the code and replace the variable names with easily\n",
    " find-replacable ones to aide in reverse engineering minified code.\n",
  }
  error(table.concat(usageLines), 0)
end
-- Command-line entry point:
--   lua minify.lua minify <file>    -> minified source on stdout
--   lua minify.lua unminify <file>  -> beautified source on stdout
local args = {...}
-- Validate the whole command line before touching the file system or the
-- parser, so a bad mode always yields the usage text rather than an I/O or
-- parse error.
if #args ~= 2 or (args[1] ~= 'minify' and args[1] ~= 'unminify') then
  usageError()
end
local sourceFile = io.open(args[2], 'r')
if not sourceFile then
  error("Could not open the input file `" .. args[2] .. "`", 0)
end
local data = sourceFile:read('*all')
-- The whole input is in memory now; release the handle.
sourceFile:close()
local ast = CreateLuaParser(data)
local global_scope, root_scope = AddVariableInfo(ast)
local function minify(ast, global_scope, root_scope)
  MinifyVariables(global_scope, root_scope)
  StripAst(ast)
  PrintAst(ast)
end
local function beautify(ast, global_scope, root_scope)
  BeautifyVariables(global_scope, root_scope)
  FormatAst(ast)
  PrintAst(ast)
end
if args[1] == 'minify' then
  minify(ast, global_scope, root_scope)
else
  -- args[1] == 'unminify' (guaranteed by the check above)
  beautify(ast, global_scope, root_scope)
end