Recursive call signature keeps changing
I'm going to implement a program that uses recursion quite a bit. So before I start getting stack overflow exceptions, I figured it would be nice to have a trampoline implemented and use it when needed.
The first attempt I made was with a factorial. Here's the code:
callable(f) = !isempty(methods(f))
function trampoline(f, arg1, arg2)
    v = f(arg1, arg2)
    while callable(v)
        v = v()
    end
    return v
end
function factorial(n, continuation)
    if n == 1
        continuation(1)
    else
        (() -> factorial(n-1, (z -> (() -> continuation(n*z)))))
    end
end
function cont(x)
    x
end
I also wrote a naive factorial to confirm that the trampoline really does prevent a stack overflow:
function factorial_overflow(n)
    if n == 1
        1
    else
        n*factorial_overflow(n-1)
    end
end
Results:
julia> factorial_overflow(140000)
ERROR: StackOverflowError:
#JITing with a small input
julia> trampoline(factorial, 10, cont)
3628800
#Testing
julia> trampoline(factorial, 140000, cont)
0
So yes, I am avoiding the StackOverflowError. And yes, I know the result is nonsense because of integer overflow, but here I only cared about the stack. The production version will of course be fixed.
(Also, I know there is a built-in for factorial. I wouldn't use either of these functions; I wrote them only to test my trampoline.)
The trampoline version takes a long time the first time it runs and then gets fast ... when calculating the same or lower values. If I then ran trampoline(factorial, 150000, cont), I would get compile times again.
It seems to me (educated guess) that I am JITing many different signatures for factorial: one for each closure created.
My question is, can I avoid this?
I think the problem is that each closure has its own type, which is specialized on the types of the captured variables. To avoid this specialization, you can instead use functors that are not fully specialized:
struct L1
    f
    n::Int
    z::Int
end
(o::L1)() = o.f(o.n*o.z)
struct L2
    f
    n::Int
end
(o::L2)(z) = L1(o.f, o.n, z)
struct Factorial
    f
    c
    n::Int
end
(o::Factorial)() = o.f(o.n-1, L2(o.c, o.n))
callable(f) = false
callable(f::Union{Factorial, L1, L2}) = true
function myfactorial(n, continuation)
    if n == 1
        continuation(1)
    else
        Factorial(myfactorial, continuation, n)
    end
end
function cont(x)
    x
end
function trampoline(f, arg1, arg2)
    v = f(arg1, arg2)
    while callable(v)
        v = v()
    end
    return v
end
Note that the fields holding functions are left untyped. The function now runs much faster on the first run:
julia> @time trampoline(myfactorial, 10, cont)
0.020673 seconds (4.24 k allocations: 264.427 KiB)
3628800
julia> @time trampoline(myfactorial, 10, cont)
0.000009 seconds (37 allocations: 1.094 KiB)
3628800
julia> @time trampoline(myfactorial, 14000, cont)
0.001277 seconds (55.55 k allocations: 1.489 MiB)
0
julia> @time trampoline(myfactorial, 14000, cont)
0.001197 seconds (55.55 k allocations: 1.489 MiB)
0
I just translated each closure of the code into a corresponding functor. It might not be necessary and there are probably better solutions, but it works and hopefully demonstrates the approach.
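As one possible generalization of the same idea, the per-closure functors could be replaced by two generic wrappers with untyped fields. This is only a sketch under assumptions: the names Thunk, Cont, fact_cps, run_trampoline, and finish are hypothetical and do not appear in the code above, and I have not benchmarked it against the functor version.

```julia
# Hypothetical names throughout: Thunk, Cont, fact_cps, run_trampoline, and
# finish are illustrations, not part of the original code.

# A deferred computation. The field is untyped on purpose, so Thunk is one
# concrete type no matter which closure it wraps.
struct Thunk
    f
end
(t::Thunk)() = t.f()

# A continuation wrapper, also with an untyped field.
struct Cont
    f
end
(c::Cont)(z) = c.f(z)

function fact_cps(n, k)
    if n == 1
        k(1)
    else
        # k is either the initial continuation or a Cont, so the closures
        # created here only ever have a small, fixed set of types.
        Thunk(() -> fact_cps(n - 1, Cont(z -> Thunk(() -> k(n * z)))))
    end
end

# Keep calling as long as the result is a deferred computation.
function run_trampoline(f, args...)
    v = f(args...)
    while v isa Thunk
        v = v()
    end
    return v
end

finish(x) = x

run_trampoline(fact_cps, 10, finish)  # 3628800
```

Because every continuation after the first is wrapped in Cont, the closure types no longer nest, at the cost of a dynamic dispatch each time an untyped field is called.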
Edit:
To clarify the reason for the slowdown, you can use:
function factorial(n, continuation)
    if n == 1
        continuation(1)
    else
        tmp = (z -> (() -> continuation(n*z)))
        @show typeof(tmp)
        (() -> factorial(n-1, tmp))
    end
end
Output:
julia> trampoline(factorial, 10, cont)
typeof(tmp) = ##31#34{Int64,#cont}
typeof(tmp) = ##31#34{Int64,##31#34{Int64,#cont}}
typeof(tmp) = ##31#34{Int64,##31#34{Int64,##31#34{Int64,#cont}}}
typeof(tmp) = ##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,#cont}}}}
typeof(tmp) = ##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,#cont}}}}}
typeof(tmp) = ##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,#cont}}}}}}
typeof(tmp) = ##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,#cont}}}}}}}
typeof(tmp) = ##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,#cont}}}}}}}}
typeof(tmp) = ##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,#cont}}}}}}}}}
3628800
tmp is a closure. Its auto-generated type ##31#34 looks like
struct Tmp{T,F}
    n::T
    continuation::F
end
The specialization on the type F of the continuation field causes the long compilation times.
By using L2 instead, which is not specialized on the corresponding field f, the continuation argument to myfactorial is always of type L2 (after the initial call with cont), and the problem is avoided.
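The difference between a type-parameterized field and an untyped field can be seen in isolation with a minimal sketch. Spec and Unspec are hypothetical names used only for this illustration; Spec stands in for what a closure type looks like, Unspec for a functor such as L2.

```julia
# Hypothetical illustration types, not part of the code above.

# Parametric field: the wrapper's type records the type of what it wraps,
# which is how a closure type records its captured variables.
struct Spec{F}
    f::F
end

# Untyped field: a single concrete type regardless of the contents.
struct Unspec
    f
end

a = Spec(Spec(Spec(identity)))
b = Unspec(Unspec(Unspec(identity)))

println(typeof(a))  # grows with every level of nesting
println(typeof(b))  # stays Unspec at any depth
```

Every additional level of Spec is a new type, so any method receiving it has to be compiled again, whereas a method receiving an Unspec is compiled once. The price is that calls through the untyped field are dispatched at run time.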