require 'pl'
utils.import 'pl.func'
local ops = require 'pl.operator'
local List = require 'pl.List'
local append,concat = table.insert,table.concat
local compare,find_if,compare_no_order,imap,reduce,count_map =
    tablex.compare,tablex.find_if,tablex.compare_no_order,
    tablex.imap,tablex.reduce,tablex.count_map
-- Lua 5.1/LuaJIT keep `unpack` as a global; 5.2+ moved it to `table.unpack`.
local unpack = table.unpack or unpack

--- Bind a concrete value to a placeholder expression.
-- Stored with rawset so the PE metatable is bypassed; `fold` later
-- reads it back with rawget and substitutes it for the placeholder.
function bindval (self,val)
    rawset(self,'value',val)
end

local optable = ops.optable

--- Render an expression as an S-expression string, e.g. `(+ x 1)`.
-- Placeholders (op 'X') render as their `repr`; non-PE values via tostring.
function sexpr (e)
    if isPE(e) then
        if e.op ~= 'X' then
            local args = tablex.imap(sexpr,e)
            return '('..e.op..' '..table.concat(args,' ')..')'
        else
            return e.repr
        end
    else
        return tostring(e)
    end
end

--- Print the S-expression form of an expression.
psexpr = compose(print,sexpr)

--- Structural equality of two expressions.
-- `+` and `*` are treated as commutative (argument order ignored);
-- every other operator compares its arguments positionally.
function equals (e1,e2)
    local p1,p2 = isPE(e1),isPE(e2)
    if p1 ~= p2 then return false end -- different kinds of animals!
    if p1 and p2 then -- both PEs
        -- operators must be the same
        if e1.op ~= e2.op then return false end
        -- PHs are equal if their representations are equal
        if e1.op == 'X' then
            return e1.repr == e2.repr
        -- commutative operators
        elseif e1.op == '+' or e1.op == '*' then
            return compare_no_order(e1,e2,equals)
        else
            -- arguments must be the same
            return compare(e1,e2,equals)
        end
    else
        -- fall back on simple equality for non PEs
        return e1 == e2
    end
end

--- Run down an unbalanced operator chain (like a+b+c), appending every
-- argument that is not itself an `op` node ({a,b,c}) to the List `ls`.
function tcollect (op,e,ls)
    if isPE(e) and e.op == op then
        for i = 1,#e do
            tcollect(op,e[i],ls)
        end
    else
        ls:append(e)
    end
end

--- Collect the flattened arguments of `e`'s own operator chain into a new List.
function rcollect (e)
    local res = List()
    tcollect(e.op,e,res)
    return res
end

--- Balance ensures that +/* chains are collected together; operates in place.
-- Thus (+ (+ a b) c) or (+ a (+ b c)) becomes (+ a b c), order immaterial.
-- Sub-expressions are balanced recursively, including the leaves of a
-- +/* chain, so (+ (* a (* b c)) d) becomes (+ (* a b c) d).
-- Returns `e`; non-PEs and placeholders are returned unchanged.
function balance (e)
    if isPE(e) and e.op ~= 'X' then
        local op,args = e.op
        if op == '+' or op == '*' then
            -- flatten the chain, then balance each collected leaf as well
            args = imap(balance,rcollect(e))
        else
            args = imap(balance,e)
        end
        for i = 1,#args do
            e[i] = args[i]
        end
    end
    return e
end

--- Fold constants in an expression.
-- * Placeholders with a bound value (see `bindval`) are replaced by it.
-- * Operators other than +/* whose arguments are all constants are
--   evaluated through `ops.optable` ('?' if the operator has no entry).
-- * For + and *, constant arguments are combined; identities (0 for +,
--   1 for *) are dropped, 0 annihilates *, and repeated equal arguments
--   are merged (x*x => x^2, x+x => x*2).
-- * For ^, x^1 => x and x^0 => 1.
function fold (e)
    if isPE(e) then
        if e.op == 'X' then
            -- there could be _bound values_!
            local val = rawget(e,'value')
            return val or e
        else
            local op = e.op
            local addmul = op == '*' or op == '+'
            -- first fold all arguments
            local args = imap(fold,e)
            if not addmul and not find_if(args,isPE) then
                -- no placeholders in these args, we can fold the expression.
                local opfn = optable[op]
                if opfn then
                    return opfn(unpack(args))
                else
                    return '?'
                end
            elseif addmul then
                -- enforce a few rules for + and *
                -- split the args into two classes, PE args and non-PE args.
                local classes = List.partition(args,isPE)
                local pe,npe = classes[true],classes[false]
                if npe then -- there's at least one non PE argument
                    -- so fold them
                    if #npe == 1 then
                        npe = npe[1]
                    else
                        npe = npe:reduce(optable[op])
                    end
                    -- if the result is a constant, return it
                    if not pe then return npe end

                    -- either (* 1 x) => x or (* 1 x y ...) => (* x y ...)
                    if op == '*' then
                        if npe == 0 then
                            return 0
                        elseif npe == 1 then -- identity
                            if #pe == 1 then return pe[1] else npe = nil end
                        end
                    else -- special cases for +
                        if npe == 0 then -- identity
                            if #pe == 1 then return pe[1] else npe = nil end
                        end
                    end
                end
                -- build up the final arguments
                local res = {}
                if npe then append(res,npe) end
                -- merge equal PE arguments: x*x => x^2, x+x => x*2
                for val,count in pairs(count_map(pe,equals)) do
                    if count > 1 then
                        if op == '*' then
                            val = val ^ count
                        else
                            val = val * count
                        end
                    end
                    append(res,val)
                end
                if #res == 1 then return res[1] end
                return PE{op=op,unpack(res)}
            elseif op == '^' then
                if args[2] == 1 then return args[1] end -- identity
                if args[2] == 0 then return 1 end
            end
            return PE{op=op,unpack(args)}
        end
    else
        return e
    end
end

--- Distribute a product over a sum: a*(p+q+...) => p*a + q*a + ...
-- Only fires when the second factor is a `+` node. Each generated term is
-- passed through `expand` again. Because the original multiplier ends up
-- in second position, a sum there is distributed too.
-- Balanced nodes may be n-ary, so every summand of the sum is distributed,
-- and factors beyond the second are kept as part of the multiplier.
function expand (e)
    if isPE(e) and e.op == '*' and isPE(e[2]) and e[2].op == '+' then
        local a,b = e[1],e[2]
        -- keep any extra factors of an n-ary product (e[3], e[4], ...)
        for i = 3,#e do
            a = a*e[i]
        end
        -- distribute over every summand of a possibly n-ary sum
        local res = expand(b[1]*a)
        for i = 2,#b do
            res = res + expand(b[i]*a)
        end
        return res
    else
        return e
    end
end

--- True when `x` is a Lua number.
function isnumber (x)
    return type(x) == 'number'
end

-- does this PE contain a reference to x?
-- Returns a truthy value (an array index from find_if, or true) when some
-- placeholder inside `e` has the same `repr` as placeholder `x`;
-- otherwise false/nil.
function references (e,x)
    if isPE(e) then
        if e.op == 'X' then
            return x.repr == e.repr
        else
            return find_if(e,references,x)
        end
    else
        return false
    end
end

-- Build an n-ary product PE from an array of arguments.
local function muli (args) return PE{op='*',unpack(args)} end

-- Build an n-ary sum PE from an array of arguments.
local function addi (args) return PE{op='+',unpack(args)} end

--- Symbolic derivative of expression `e` with respect to placeholder `x`.
-- Supported:
-- * `+` and `-` (binary or unary), which are linear.
-- * n-ary `*`, via the product rule.
-- * `^` with a numeric exponent, via the power rule combined with the
--   chain rule.
-- Anything that does not reference `x` differentiates to 0. Any other
-- operator falls through and yields nil.
-- @param e expression (PE or plain value)
-- @param x placeholder PE to differentiate with respect to
-- @return derivative expression, number, or nil if unsupported
function diff (e,x)
    if isPE(e) and references(e,x) then
        local op = e.op
        if op == 'X' then
            return 1
        else
            local a,b = e[1],e[2]
            if op == '+' then
                -- differentiation is linear
                local args = imap(diff,e,x)
                return balance(addi(args))
            elseif op == '-' then
                -- also linear: d(-a) = -da, d(a-b) = da - db
                -- (0 - da avoids relying on a unary-minus metamethod)
                if #e == 1 then return 0 - diff(a,x) end
                return diff(a,x) - diff(b,x)
            elseif op == '*' then
                -- product rule: sum over i of (e with factor i replaced by its derivative)
                local res = {}
                for i = 1,#e do
                    local d = fold(diff(e[i],x))
                    if d ~= 0 then
                        local ee = {unpack(e)}
                        ee[i] = d
                        append(res,balance(muli(ee)))
                    end
                end
                -- every term vanished (e.g. 2*x^0): the derivative is 0, not nil
                if #res == 0 then return 0 end
                if #res > 1 then return addi(res) end
                return res[1]
            elseif op == '^' and isnumber(b) then
                -- power rule with chain rule: d(a^b) = b*a^(b-1) * da
                local da = fold(diff(a,x))
                if da == nil then return nil end -- base uses an unsupported op
                if da == 0 then return 0 end
                local res = b*a^(b-1)
                if da == 1 then return res end -- base is x itself
                return res*da
            end
        end
    else
        return 0
    end
end