summary refs log tree commit diff
path: root/Data/Libraries/Penlight/examples/symbols.lua
blob: e73c4badd359f3323abf60a28ea2d9005209699d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
require 'pl'
utils.import 'pl.func'
local ops = require 'pl.operator'
local List = require 'pl.List'
local append,concat = table.insert,table.concat
local compare,find_if,compare_no_order,imap,reduce,count_map = tablex.compare,tablex.find_if,tablex.compare_no_order,tablex.imap,tablex.reduce,tablex.count_map
local unpack = table.unpack

--- Bind a concrete value to a placeholder.
-- The value is stored under the raw key 'value', bypassing any metamethods
-- on `self`.
-- @param self the table (placeholder) that receives the binding
-- @param val the value to bind
function bindval (self,val)
    rawset(self, "value", val)
end

-- operator lookup table from pl.operator; fold indexes it with a PE's `op`
-- string to get the function that implements that operator
local optable = ops.optable

--- Render an expression as an S-expression string.
-- A placeholder (op 'X') is shown as its `repr`; any other PE becomes
-- "(op arg1 arg2 ...)" with every argument rendered recursively; values
-- that are not PEs go through tostring.
-- @param e the expression or plain value to render
-- @return the textual form
function sexpr (e)
    if not isPE(e) then
        return tostring(e)
    end
    if e.op == 'X' then
        return e.repr
    end
    local parts = tablex.imap(sexpr,e)
    return '('..e.op..' '..table.concat(parts,' ')..')'
end


-- print an expression in S-expression form; built by composing print with
-- sexpr (compose is assumed to be pl.func's, made global by utils.import)
psexpr = compose(print,sexpr)



--- Structural equality of two expressions.
-- A PE never equals a non-PE. Two PEs need the same operator; placeholders
-- are then compared by `repr`, the arguments of the commutative operators
-- + and * are compared ignoring order, and all other arguments are compared
-- position by position. Two non-PE values are compared with ==.
-- @param e1 first expression or value
-- @param e2 second expression or value
-- @return truthy when the two are considered equal
function equals (e1,e2)
    local pe1,pe2 = isPE(e1),isPE(e2)
    -- different kinds of animals!
    if pe1 ~= pe2 then return false end
    -- neither is a PE: plain equality is all there is
    if not (pe1 and pe2) then return e1 == e2 end

    local op = e1.op
    if op ~= e2.op then return false end
    if op == 'X' then
        return e1.repr == e2.repr
    end
    if op == '+' or op == '*' then
        return compare_no_order(e1,e2,equals)
    end
    return compare(e1,e2,equals)
end

-- walk an unbalanced operator chain (like a+b+c) and append its arguments
-- a, b, c to ls. Nested PEs with the same operator are descended into;
-- everything else is appended as a leaf.
-- @param op the operator string that defines the chain
-- @param e the expression being visited
-- @param ls receiver of the leaves; it must provide an append method
function tcollect (op,e,ls)
    local in_chain = isPE(e) and e.op == op
    if not in_chain then
        ls:append(e)
        return
    end
    for idx = 1,#e do
        tcollect(op,e[idx],ls)
    end
end

--- Gather the arguments of the operator chain rooted at `e`.
-- The operator of `e` itself decides which nested nodes are flattened.
-- @param e the PE at the root of the chain
-- @return a new List holding the collected arguments
function rcollect (e)
    local collected = List()
    local chain_op = e.op
    tcollect(chain_op,e,collected)
    return collected
end


-- balance ensures that +/* chains are collected together, operates in-place.
-- thus (+(+ a b) c) or (+ a (+ b c)) becomes (+ a b c), order immaterial.
-- Sub-expressions are balanced as well, whatever operator encloses them, so
-- (+ a (* (* b c) d)) becomes (+ a (* b c d)).
-- @param e any expression; non-PEs and placeholders are returned untouched
-- @return e itself, after its arguments have been rewritten in place
function balance (e)
    if isPE(e) and e.op ~= 'X' then
        local op,args = e.op
        if op == '+' or op == '*' then
            args = rcollect(e)
            -- the collected arguments are never `op` nodes themselves, but
            -- they may hold chains of another operator; balance works in
            -- place, so calling it on each argument is enough
            for i = 1,#args do
                balance(args[i])
            end
        else
            args = imap(balance,e)
        end
        for i = 1,#args do
            e[i] = args[i]
        end
    end
    return e
end

-- fold constants in an expression.
-- Builds and returns a simplified expression (or a plain value); e itself is
-- only read. The rules, in the order the code applies them:
--  * a placeholder ('X') yields its raw 'value' field when that is set,
--    otherwise the placeholder itself
--  * every argument is folded first
--  * for operators other than + and *: when no folded argument is a PE, the
--    function optable[op] is applied to them ('?' if optable has no entry)
--  * for + and *: the non-PE arguments are combined into one constant, the
--    identity constant (0 for +, 1 for *) is dropped, a constant 0 factor
--    makes the whole product 0, and PE arguments that are `equals` are merged
--    into val*count (for +) or val^count (for *)
--  * for ^ with PE arguments: exponent 1 gives the base, exponent 0 gives 1
--  * anything else is rebuilt as a PE from the folded arguments
--  * values that are not PEs are returned as they are
function fold (e)
    if isPE(e) then
        if e.op == 'X' then
            -- there could be _bound values_!
            local val = rawget(e,'value')
            return val and val or e
        else
            local op = e.op
            local addmul = op == '*' or op == '+'
            -- first fold all arguments
            local args = imap(fold,e)
            if not addmul and not find_if(args,isPE) then
                -- no placeholders in these args, we can fold the expression.
                local opfn = optable[op]
                if opfn then
                    return opfn(unpack(args))
                else
                    -- no function known for this operator
                    return '?'
                end
            elseif addmul then
                -- enforce a few rules for + and *
                -- split the args into two classes, PE args and non-PE args.
                -- (either class may be missing, hence the nil checks below)
                local classes = List.partition(args,isPE)
                local pe,npe = classes[true],classes[false]
                if npe then -- there's at least one non PE argument
                    -- so fold them; from here on npe is a single constant
                    if #npe == 1 then npe = npe[1]
                    else npe = npe:reduce(optable[op])
                    end
                    -- if the result is a constant, return it
                    if not pe then return npe end

                    -- either (* 1 x) => x or (* 1 x y ...) => (* x y ...)
                    if op == '*' then
                        -- anything times zero is zero
                        if npe == 0 then return 0
                        elseif npe == 1 then -- identity
                            -- npe = nil drops the constant from the result
                            if #pe == 1 then return pe[1] else npe = nil end
                        end
                    else -- special cases for +
                        if npe == 0 then -- identity
                            if #pe == 1 then return pe[1] else npe = nil end
                        end
                    end
                end
                -- build up the final arguments
                local res = {}
                if npe then append(res,npe) end
                -- count_map groups the PE args that are `equals`; a repeated
                -- arg becomes one power (for *) or one multiple (for +)
                for val,count in pairs(count_map(pe,equals)) do
                    if count > 1 then
                        if op == '*' then val = val ^ count
                        else val = val * count
                        end
                    end
                    append(res,val)
                end
                -- a single remaining argument needs no operator around it
                if #res == 1 then return res[1] end
                return PE{op=op,unpack(res)}
            elseif op == '^' then
                if args[2] == 1 then return args[1] end -- identity
                if args[2] == 0 then return 1 end
            end
            -- nothing simplified at this level: rebuild with the folded args
            return PE{op=op,unpack(args)}
        end
    else
        return e
    end
end

--- Distribute a product over a sum that appears as its second factor.
-- (* a (+ b c ...) rest...) becomes expand(b*a*rest...) + expand(c*a*rest...)
-- + ... Every term of the sum and every further factor of the product is
-- kept, so n-ary nodes (as produced by balance) expand without losing
-- anything. For the plain binary case (* a (+ b c)) the result is
-- expand(b*a) + expand(c*a).
-- @param e the expression to expand
-- @return the expanded expression, or e unchanged when it is not a product
-- whose second factor is a sum
function expand (e)
    if not (isPE(e) and e.op == '*' and isPE(e[2]) and e[2].op == '+') then
        return e
    end
    local a,b = e[1],e[2]
    local sum
    for i = 1,#b do
        local term = b[i]*a
        -- carry along any factors beyond the first two
        for j = 3,#e do
            term = term*e[j]
        end
        term = expand(term)
        if sum == nil then
            sum = term
        else
            sum = sum + term
        end
    end
    -- a sum node without terms gives nothing to distribute over
    if sum == nil then return e end
    return sum
end

--- Is `x` a Lua number?
-- @param x any value
-- @return true when type(x) is 'number', false otherwise
function isnumber (x)
    local kind = type(x)
    return kind == "number"
end

-- does the expression e mention the placeholder x?
-- A placeholder matches when its repr equals x.repr; any other PE is
-- searched with find_if, applying this same test to its entries; values
-- that are not PEs never reference anything.
-- @param e the expression to inspect
-- @param x the placeholder to look for
-- @return a truthy value when x occurs in e, a falsy one otherwise
function references (e,x)
    if not isPE(e) then
        return false
    end
    if e.op == 'X' then
        return x.repr == e.repr
    end
    return find_if(e,references,x)
end

--- Build a '*' PE whose arguments are the elements of `args`.
-- @param args array of factors
-- @return the new product expression
local function muli (args)
    local node = {op='*',unpack(args)}
    return PE(node)
end

--- Build a '+' PE whose arguments are the elements of `args`.
-- @param args array of terms
-- @return the new sum expression
local function addi (args)
    local node = {op='+',unpack(args)}
    return PE(node)
end

--- Symbolic derivative of the expression `e` with respect to placeholder `x`.
-- Sums are differentiated term by term, products with the product rule,
-- and powers with a numeric exponent with the power rule combined with the
-- chain rule: d(a^b) = b*a^(b-1) * da. When the base is x itself this is
-- b*x^(b-1).
-- @param e the expression to differentiate
-- @param x the placeholder to differentiate against
-- @return 0 when e is not a PE or does not reference x, 1 for x itself, the
-- derivative expression for the handled operators, and nil for operators
-- that are not handled here
function diff (e,x)
    if not (isPE(e) and references(e,x)) then
        return 0
    end
    local op = e.op
    if op == 'X' then
        return 1
    end
    local a,b = e[1],e[2]
    if op == '+' then -- differentiation is linear
        local args = imap(diff,e,x)
        return balance(addi(args))
    elseif op == '*' then -- product rule
        local res,d,ee = {}
        for i = 1,#e do
            d = fold(diff(e[i],x))
            -- factors with a zero derivative contribute no term
            if d ~= 0 then
                -- copy of the product with factor i replaced by its derivative
                ee = {unpack(e)}
                ee[i] = d
                append(res,balance(muli(ee)))
            end
        end
        if #res > 1 then return addi(res)
        else return res[1] end
    elseif op == '^' and isnumber(b) then -- power rule
        -- the base may be any expression in x, not just x itself, so its
        -- own derivative has to be multiplied in (chain rule)
        local da = fold(diff(a,x))
        -- base uses an operator that is not handled: pass that on
        if da == nil then return nil end
        local power = b*a^(b-1)
        if da == 1 then return power end
        return power*da
    end
    -- operators not handled above fall through and yield nil
end