summaryrefslogtreecommitdiff
path: root/Data/Libraries/Penlight/examples/symbols.lua
diff options
context:
space:
mode:
Diffstat (limited to 'Data/Libraries/Penlight/examples/symbols.lua')
-rw-r--r--Data/Libraries/Penlight/examples/symbols.lua223
1 files changed, 223 insertions, 0 deletions
diff --git a/Data/Libraries/Penlight/examples/symbols.lua b/Data/Libraries/Penlight/examples/symbols.lua
new file mode 100644
index 0000000..e73c4ba
--- /dev/null
+++ b/Data/Libraries/Penlight/examples/symbols.lua
@@ -0,0 +1,223 @@
+require 'pl'
+utils.import 'pl.func'
+local ops = require 'pl.operator'
+local List = require 'pl.List'
+local append,concat = table.insert,table.concat
+local compare,find_if,compare_no_order,imap,reduce,count_map = tablex.compare,tablex.find_if,tablex.compare_no_order,tablex.imap,tablex.reduce,tablex.count_map
+local unpack = table.unpack
+
+function bindval (self,val)
+ rawset(self,'value',val)
+end
+
+local optable = ops.optable
+
+function sexpr (e)
+ if isPE(e) then
+ if e.op ~= 'X' then
+ local args = tablex.imap(sexpr,e)
+ return '('..e.op..' '..table.concat(args,' ')..')'
+ else
+ return e.repr
+ end
+ else
+ return tostring(e)
+ end
+end
+
+
+psexpr = compose(print,sexpr)
+
+
+
+function equals (e1,e2)
+ local p1,p2 = isPE(e1),isPE(e2)
+ if p1 ~= p2 then return false end -- different kinds of animals!
+ if p1 and p2 then -- both PEs
+ -- operators must be the same
+ if e1.op ~= e2.op then return false end
+ -- PHs are equal if their representations are equal
+ if e1.op == 'X' then return e1.repr == e2.repr
+ -- commutative operators
+ elseif e1.op == '+' or e1.op == '*' then
+ return compare_no_order(e1,e2,equals)
+ else
+ -- arguments must be the same
+ return compare(e1,e2,equals)
+ end
+ else -- fall back on simple equality for non PEs
+ return e1 == e2
+ end
+end
+
+-- run down an unbalanced operator chain (like a+b+c) and return the arguments {a,b,c}
+function tcollect (op,e,ls)
+ if isPE(e) and e.op == op then
+ for i = 1,#e do
+ tcollect(op,e[i],ls)
+ end
+ else
+ ls:append(e)
+ return
+ end
+end
+
+function rcollect (e)
+ local res = List()
+ tcollect(e.op,e,res)
+ return res
+end
+
+
+-- balance ensures that +/* chains are collected together, operates in-place.
+-- thus (+(+ a b) c) or (+ a (+ b c)) becomes (+ a b c), order immaterial
+function balance (e)
+ if isPE(e) and e.op ~= 'X' then
+ local op,args = e.op
+ if op == '+' or op == '*' then
+ args = rcollect(e)
+ else
+ args = imap(balance,e)
+ end
+ for i = 1,#args do
+ e[i] = args[i]
+ end
+ end
+ return e
+end
+
+-- fold constants in an expression
+function fold (e)
+ if isPE(e) then
+ if e.op == 'X' then
+ -- there could be _bound values_!
+ local val = rawget(e,'value')
+ return val and val or e
+ else
+ local op = e.op
+ local addmul = op == '*' or op == '+'
+ -- first fold all arguments
+ local args = imap(fold,e)
+ if not addmul and not find_if(args,isPE) then
+ -- no placeholders in these args, we can fold the expression.
+ local opfn = optable[op]
+ if opfn then
+ return opfn(unpack(args))
+ else
+ return '?'
+ end
+ elseif addmul then
+ -- enforce a few rules for + and *
+ -- split the args into two classes, PE args and non-PE args.
+ local classes = List.partition(args,isPE)
+ local pe,npe = classes[true],classes[false]
+ if npe then -- there's at least one non PE argument
+ -- so fold them
+ if #npe == 1 then npe = npe[1]
+ else npe = npe:reduce(optable[op])
+ end
+ -- if the result is a constant, return it
+ if not pe then return npe end
+
+ -- either (* 1 x) => x or (* 1 x y ...) => (* x y ...)
+ if op == '*' then
+ if npe == 0 then return 0
+ elseif npe == 1 then -- identity
+ if #pe == 1 then return pe[1] else npe = nil end
+ end
+ else -- special cases for +
+ if npe == 0 then -- identity
+ if #pe == 1 then return pe[1] else npe = nil end
+ end
+ end
+ end
+ -- build up the final arguments
+ local res = {}
+ if npe then append(res,npe) end
+ for val,count in pairs(count_map(pe,equals)) do
+ if count > 1 then
+ if op == '*' then val = val ^ count
+ else val = val * count
+ end
+ end
+ append(res,val)
+ end
+ if #res == 1 then return res[1] end
+ return PE{op=op,unpack(res)}
+ elseif op == '^' then
+ if args[2] == 1 then return args[1] end -- identity
+ if args[2] == 0 then return 1 end
+ end
+ return PE{op=op,unpack(args)}
+ end
+ else
+ return e
+ end
+end
+
+function expand (e)
+ if isPE(e) and e.op == '*' and isPE(e[2]) and e[2].op == '+' then
+ local a,b = e[1],e[2]
+ return expand(b[1]*a) + expand(b[2]*a)
+ else
+ return e
+ end
+end
+
+function isnumber (x)
+ return type(x) == 'number'
+end
+
+-- does this PE contain a reference to x?
+function references (e,x)
+ if isPE(e) then
+ if e.op == 'X' then return x.repr == e.repr
+ else
+ return find_if(e,references,x)
+ end
+ else
+ return false
+ end
+end
+
+local function muli (args)
+ return PE{op='*',unpack(args)}
+end
+
+local function addi (args)
+ return PE{op='+',unpack(args)}
+end
+
+function diff (e,x)
+ if isPE(e) and references(e,x) then
+ local op = e.op
+ if op == 'X' then
+ return 1
+ else
+ local a,b = e[1],e[2]
+ if op == '+' then -- differentiation is linear
+ local args = imap(diff,e,x)
+ return balance(addi(args))
+ elseif op == '*' then -- product rule
+ local res,d,ee = {}
+ for i = 1,#e do
+ d = fold(diff(e[i],x))
+ if d ~= 0 then
+ ee = {unpack(e)}
+ ee[i] = d
+ append(res,balance(muli(ee)))
+ end
+ end
+ if #res > 1 then return addi(res)
+ else return res[1] end
+ elseif op == '^' and isnumber(b) then -- power rule
+ return b*x^(b-1)
+ end
+ end
+ else
+ return 0
+ end
+end
+
+
+