diff options
Diffstat (limited to 'Data/Libraries/Penlight/examples/symbols.lua')
-rw-r--r-- | Data/Libraries/Penlight/examples/symbols.lua | 223 |
1 files changed, 223 insertions, 0 deletions
diff --git a/Data/Libraries/Penlight/examples/symbols.lua b/Data/Libraries/Penlight/examples/symbols.lua new file mode 100644 index 0000000..e73c4ba --- /dev/null +++ b/Data/Libraries/Penlight/examples/symbols.lua @@ -0,0 +1,223 @@ +require 'pl' +utils.import 'pl.func' +local ops = require 'pl.operator' +local List = require 'pl.List' +local append,concat = table.insert,table.concat +local compare,find_if,compare_no_order,imap,reduce,count_map = tablex.compare,tablex.find_if,tablex.compare_no_order,tablex.imap,tablex.reduce,tablex.count_map +local unpack = table.unpack + +function bindval (self,val) + rawset(self,'value',val) +end + +local optable = ops.optable + +function sexpr (e) + if isPE(e) then + if e.op ~= 'X' then + local args = tablex.imap(sexpr,e) + return '('..e.op..' '..table.concat(args,' ')..')' + else + return e.repr + end + else + return tostring(e) + end +end + + +psexpr = compose(print,sexpr) + + + +function equals (e1,e2) + local p1,p2 = isPE(e1),isPE(e2) + if p1 ~= p2 then return false end -- different kinds of animals! + if p1 and p2 then -- both PEs + -- operators must be the same + if e1.op ~= e2.op then return false end + -- PHs are equal if their representations are equal + if e1.op == 'X' then return e1.repr == e2.repr + -- commutative operators + elseif e1.op == '+' or e1.op == '*' then + return compare_no_order(e1,e2,equals) + else + -- arguments must be the same + return compare(e1,e2,equals) + end + else -- fall back on simple equality for non PEs + return e1 == e2 + end +end + +-- run down an unbalanced operator chain (like a+b+c) and return the arguments {a,b,c} +function tcollect (op,e,ls) + if isPE(e) and e.op == op then + for i = 1,#e do + tcollect(op,e[i],ls) + end + else + ls:append(e) + return + end +end + +function rcollect (e) + local res = List() + tcollect(e.op,e,res) + return res +end + + +-- balance ensures that +/* chains are collected together, operates in-place. +-- thus (+(+ a b) c) or (+ a (+ b c)) becomes (+ a b c), order immaterial +function balance (e) + if isPE(e) and e.op ~= 'X' then + local op,args = e.op + if op == '+' or op == '*' then + args = rcollect(e) + else + args = imap(balance,e) + end + for i = 1,#args do + e[i] = args[i] + end + end + return e +end + +-- fold constants in an expression +function fold (e) + if isPE(e) then + if e.op == 'X' then + -- there could be _bound values_! + local val = rawget(e,'value') + return val and val or e + else + local op = e.op + local addmul = op == '*' or op == '+' + -- first fold all arguments + local args = imap(fold,e) + if not addmul and not find_if(args,isPE) then + -- no placeholders in these args, we can fold the expression. + local opfn = optable[op] + if opfn then + return opfn(unpack(args)) + else + return '?' + end + elseif addmul then + -- enforce a few rules for + and * + -- split the args into two classes, PE args and non-PE args. + local classes = List.partition(args,isPE) + local pe,npe = classes[true],classes[false] + if npe then -- there's at least one non PE argument + -- so fold them + if #npe == 1 then npe = npe[1] + else npe = npe:reduce(optable[op]) + end + -- if the result is a constant, return it + if not pe then return npe end + + -- either (* 1 x) => x or (* 1 x y ...) => (* x y ...) + if op == '*' then + if npe == 0 then return 0 + elseif npe == 1 then -- identity + if #pe == 1 then return pe[1] else npe = nil end + end + else -- special cases for + + if npe == 0 then -- identity + if #pe == 1 then return pe[1] else npe = nil end + end + end + end + -- build up the final arguments + local res = {} + if npe then append(res,npe) end + for val,count in pairs(count_map(pe,equals)) do + if count > 1 then + if op == '*' then val = val ^ count + else val = val * count + end + end + append(res,val) + end + if #res == 1 then return res[1] end + return PE{op=op,unpack(res)} + elseif op == '^' then + if args[2] == 1 then return args[1] end -- identity + if args[2] == 0 then return 1 end + end + return PE{op=op,unpack(args)} + end + else + return e + end +end + +function expand (e) + if isPE(e) and e.op == '*' and isPE(e[2]) and e[2].op == '+' then + local a,b = e[1],e[2] + return expand(b[1]*a) + expand(b[2]*a) + else + return e + end +end + +function isnumber (x) + return type(x) == 'number' +end + +-- does this PE contain a reference to x? +function references (e,x) + if isPE(e) then + if e.op == 'X' then return x.repr == e.repr + else + return find_if(e,references,x) + end + else + return false + end +end + +local function muli (args) + return PE{op='*',unpack(args)} +end + +local function addi (args) + return PE{op='+',unpack(args)} +end + +function diff (e,x) + if isPE(e) and references(e,x) then + local op = e.op + if op == 'X' then + return 1 + else + local a,b = e[1],e[2] + if op == '+' then -- differentiation is linear + local args = imap(diff,e,x) + return balance(addi(args)) + elseif op == '*' then -- product rule + local res,d,ee = {} + for i = 1,#e do + d = fold(diff(e[i],x)) + if d ~= 0 then + ee = {unpack(e)} + ee[i] = d + append(res,balance(muli(ee))) + end + end + if #res > 1 then return addi(res) + else return res[1] end + elseif op == '^' and isnumber(b) then -- power rule + return b*x^(b-1) + end + end + else + return 0 + end +end + + + |