## TensorRules.jl

Macros to define custom adjoints for TensorOperations.jl
Author: ho-oto. Popularity: 15 stars. Last updated: 1 year ago. Started in: October 2020.

TensorRules.jl provides the macro @∇ (type ∇ as \nabla<tab>), which makes it possible to use automatic differentiation (AD) libraries (e.g., Zygote.jl, Diffractor.jl) with the @tensor and @tensoropt macros of TensorOperations.jl.

TensorRules.jl uses ChainRulesCore.jl to define custom adjoints, so it works with any AD library that supports ChainRulesCore.jl.

## How to use

    julia> using TensorOperations, TensorRules, Zygote;
    julia> function foo(a, b, c) # define function with Einstein summation
               # d_F = \sum_{A,B,C,D,E} a_{A,B,C} b_{C,D,E,F} c_{A,B,D,E}
               @tensor d[F] := a[A, B, C] * b[C, D, E, F] * c[A, B, D, E]
               return d[1]
           end;
    julia> a, b, c = randn(3, 4, 5), randn(5, 6, 7, 8), randn(3, 4, 6, 7);
    julia> gradient(foo, a, b, c); # try to obtain gradient of foo by Zygote
    ERROR: this intrinsic must be compiled to be called
    Stacktrace:
    ...
    julia> @∇ function foo(a, b, c) # use @∇
               @tensor d[F] := a[A, B, C] * b[C, D, E, F] * c[A, B, D, E]
               return d[1]
           end;
    julia> gradient(foo, a, b, c); # it works!
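As a sanity check on what the gradient of `foo` should contain, the contraction and its derivative with respect to `a` can be written with plain loops and compared against a finite difference. This is only a sketch in base Julia: the names `foo_loops` and `grad_a_loops` are hypothetical helpers introduced here, not part of TensorRules.jl, and the hand-derived formula assumes real-valued arrays.

```julia
# Plain-loop version of foo: d[1] = Σ a[A,B,C] * b[C,D,E,1] * c[A,B,D,E]
function foo_loops(a, b, c)
    total = 0.0
    for A in axes(a, 1), B in axes(a, 2), C in axes(a, 3),
        D in axes(b, 2), E in axes(b, 3)
        total += a[A, B, C] * b[C, D, E, 1] * c[A, B, D, E]
    end
    return total
end

# Hand-derived gradient with respect to a (hypothetical helper):
# ∂d[1]/∂a[A,B,C] = Σ_{D,E} b[C,D,E,1] * c[A,B,D,E]
function grad_a_loops(a, b, c)
    g = zeros(size(a))
    for A in axes(a, 1), B in axes(a, 2), C in axes(a, 3),
        D in axes(b, 2), E in axes(b, 3)
        g[A, B, C] += b[C, D, E, 1] * c[A, B, D, E]
    end
    return g
end

a, b, c = randn(3, 4, 5), randn(5, 6, 7, 8), randn(3, 4, 6, 7)
g = grad_a_loops(a, b, c)

# Central finite difference on a single entry of a
h = 1e-6
ap = copy(a); ap[2, 3, 4] += h
am = copy(a); am[2, 3, 4] -= h
fd = (foo_loops(ap, b, c) - foo_loops(am, b, c)) / (2h)

println("finite difference: ", fd, "  analytic: ", g[2, 3, 4])
```

The first element of the tuple returned by `gradient(foo, a, b, c)` is expected to match `g` up to floating-point error, since both describe the same derivative.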

## How it works

The strategy of TensorRules.jl is very similar to that of TensorGrad.jl.

@∇ converts functions that contain the @tensor or @tensoropt macro. First, @∇ detects the @tensor or @tensoropt expressions in the function definition and converts them to inlined functions. Then, @∇ defines custom adjoint rules for the generated functions.
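The detection step can be illustrated with a small expression walker written in base Julia. This is a simplified sketch, not the actual implementation of @∇ (which, per the notes on unsupported features, relies on MacroTools.jl): the function `extract_tensor_calls!` is a hypothetical name, and the sketch replaces each macro call with an argument-free placeholder call, whereas the real macro also extracts the operands and passes them as arguments.

```julia
# Hypothetical helper: walk `ex`, collect every @tensor / @tensoropt macro call
# into `found`, and replace it with a placeholder call named `<prefix>_<n>`.
function extract_tensor_calls!(found::Vector{Any}, ex, prefix::Symbol)
    ex isa Expr || return ex   # symbols, literals, line numbers are kept as-is
    if ex.head == :macrocall &&
       ex.args[1] in (Symbol("@tensor"), Symbol("@tensoropt"))
        push!(found, ex)
        return Expr(:call, Symbol(prefix, :_, length(found)))
    end
    return Expr(ex.head,
                (extract_tensor_calls!(found, arg, prefix) for arg in ex.args)...)
end

# Quoting does not expand macros, so TensorOperations.jl is not needed here.
body = quote
    @tensor d[F] := a[A, B, C] * b[C, D, E, F] * c[A, B, D, E]
    return d[1]
end

found = Any[]
rewritten = extract_tensor_calls!(found, body, :_foo)

println(length(found), " tensor expression(s) found")
println(rewritten)
```

Each entry of `found` is the original macro call, which is the information needed to generate both the inlined function and its adjoint rule.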

For example, the following definition

    @∇ function foo(a, b, c, d, e, f)
        @tensoropt !C x[A, B] := conj(a[A, C]) * sin.(b)[C, D] * c.d[D, B] + d * e[1, 2][A, B]
        x = x + f
        @tensor x[A, B] += a[A, C] * (a * a)[C, B]
        return x
    end

is converted to code equivalent to

    function foo(a, b, c, d, e, f)
        x = _foo_1(a, sin.(b), c.d, d, e[1, 2])
        x = x + f
        x += _foo_2(a, a * a)
        return x
    end

    @inline _foo_1(x1, x2, x3, x4, x5) =
        @tensoropt !C _[A, B] := conj(x1[A, C]) * x2[C, D] * x3[D, B] + x4 * x5[A, B]

    @inline _foo_2(x1, x2) = @tensor _[A, B] := x1[A, C] * x2[C, B]

    function rrule(::typeof(_foo_1), x1, x2, x3, x4, x5)
        f = _foo_1(x1, x2, x3, x4, x5)
        Px1, Px2, Px3, Px4, Px5 = ProjectTo(x1), ProjectTo(x2), ProjectTo(x3), ProjectTo(x4), ProjectTo(x5)
        function _foo_1_pullback(Δf)
            fnΔx1(Δf, x1, x2, x3, x4, x5) = @tensoropt !C _[A, C] := conj(Δf[A, B]) * x2[C, D] * x3[D, B]
            fnΔx1add!!(x, Δf, x1, x2, x3, x4, x5) = @tensoropt !C x[A, C] += conj(Δf[A, B]) * x2[C, D] * x3[D, B]
            fnΔx2(Δf, x1, x2, x3, x4, x5) = @tensoropt !C _[C, D] := conj(conj(x1[A, C]) * conj(Δf[A, B]) * x3[D, B])
            fnΔx2add!!(x, Δf, x1, x2, x3, x4, x5) = @tensoropt !C x[C, D] += conj(conj(x1[A, C]) * conj(Δf[A, B]) * x3[D, B])
            fnΔx3(Δf, x1, x2, x3, x4, x5) = @tensoropt !C _[D, B] := conj(conj(x1[A, C]) * x2[C, D] * conj(Δf[A, B]))
            fnΔx3add!!(x, Δf, x1, x2, x3, x4, x5) = @tensoropt !C x[D, B] += conj(conj(x1[A, C]) * x2[C, D] * conj(Δf[A, B]))
            fnΔx4(Δf, x1, x2, x3, x4, x5) = first(@tensoropt !C _[] := conj(conj(Δf[A, B]) * x5[A, B]))
            fnΔx5(Δf, x1, x2, x3, x4, x5) = @tensoropt !C _[A, B] := conj(x4 * conj(Δf[A, B]))
            fnΔx5add!!(x, Δf, x1, x2, x3, x4, x5) = @tensoropt !C x[A, B] += conj(x4 * conj(Δf[A, B]))
            Δx1 = InplaceableThunk(
                Thunk(() -> Px1(fnΔx1(Δf, x1, x2, x3, x4, x5))),
                x -> fnΔx1add!!(x, Δf, x1, x2, x3, x4, x5)
            )
            Δx2 = InplaceableThunk(
                Thunk(() -> Px2(fnΔx2(Δf, x1, x2, x3, x4, x5))),
                x -> fnΔx2add!!(x, Δf, x1, x2, x3, x4, x5)
            )
            Δx3 = InplaceableThunk(
                Thunk(() -> Px3(fnΔx3(Δf, x1, x2, x3, x4, x5))),
                x -> fnΔx3add!!(x, Δf, x1, x2, x3, x4, x5)
            )
            Δx4 = Thunk(() -> fnΔx4(Δf, x1, x2, x3, x4, x5))
            Δx5 = InplaceableThunk(
                Thunk(() -> Px5(fnΔx5(Δf, x1, x2, x3, x4, x5))),
                x -> fnΔx5add!!(x, Δf, x1, x2, x3, x4, x5)
            )
            return (NoTangent(), Δx1, Δx2, Δx3, Δx4, Δx5)
        end
        return f, _foo_1_pullback
    end

    function rrule(::typeof(_foo_2), x1, x2)
        ...
    end
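To make the pattern of the generated adjoints more concrete, the simpler contraction `_[A, B] := x1[A, C] * x2[C, B]` can be worked out by hand: the adjoint of each input is the same contraction with that input replaced by the incoming cotangent. The sketch below assumes real-valued matrices (so the conj calls drop out) and uses only LinearAlgebra. It is an illustration of the rule, not the code generated for `_foo_2`, and the helper `loss` is a hypothetical name.

```julia
using LinearAlgebra

x1, x2 = randn(3, 4), randn(4, 5)
Δy = randn(3, 5)   # incoming cotangent for y[A,B] := x1[A,C] * x2[C,B]

# Adjoints for real arrays:
#   Δx1[A,C] = Σ_B Δy[A,B] * x2[C,B]
#   Δx2[C,B] = Σ_A x1[A,C] * Δy[A,B]
Δx1 = Δy * transpose(x2)
Δx2 = transpose(x1) * Δy

# Scalar loss whose gradient with respect to the product is Δy (hypothetical helper)
loss(p, q) = dot(Δy, p * q)

# Central finite difference on one entry of x1
h = 1e-6
x1p = copy(x1); x1p[2, 3] += h
x1m = copy(x1); x1m[2, 3] -= h
fd = (loss(x1p, x2) - loss(x1m, x2)) / (2h)

println("finite difference: ", fd, "  analytic: ", Δx1[2, 3])
```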

Because the adjoints are wrapped in Thunk and InplaceableThunk, each of them is evaluated only if it is actually needed.
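The idea of deferred evaluation can be sketched without any dependency. The `LazyTangent` type below is a hypothetical stand-in written for this illustration, with `force` and `add_to!` as made-up names. In ChainRulesCore.jl, `Thunk` together with `unthunk` plays the role of the deferred value, and `InplaceableThunk` additionally carries an in-place accumulation function.

```julia
# Hypothetical stand-in for a lazily evaluated tangent.
struct LazyTangent{F,G}
    compute::F   # () -> freshly allocated tangent
    add!::G      # acc -> acc, with the tangent added in place
end

force(t::LazyTangent) = t.compute()          # allocating path
add_to!(acc, t::LazyTangent) = t.add!(acc)   # in-place path

ncalls = Ref(0)          # counts how often the allocating path runs
Δf = [1.0, 2.0, 3.0]

t = LazyTangent(
    () -> (ncalls[] += 1; 2 .* Δf),
    acc -> (acc .+= 2 .* Δf; acc),
)

println("evaluations after construction: ", ncalls[])

acc = zeros(3)
add_to!(acc, t)          # accumulates without allocating a new tangent
value = force(t)         # computes the tangent on demand

println("evaluations after force: ", ncalls[])
```

If an AD library never asks for a given input's gradient, the corresponding closure is never called, so the contraction behind it is never performed.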

## Unsupported features

• @∇ uses the @capture macro defined in MacroTools.jl to parse Expr objects. Because of a limitation of @capture, index notations based on :typed_vcat and :typed_hcat (A[a; b], A[a b]) are unsupported. Please use the A[a, b] style.
• Specifying the contraction order with ord=(...) or in NCON style is unsupported. Please use @tensoropt and specify the cost of each bond.
• Since Zygote.jl does not support in-place operations, @tensor A[] = ... cannot be used in the expression. Please use :=, += and -= instead.

## TODO

• support @tensor blocks (@tensor begin ... end)
• support higher-order differentiation (by applying @∇ to rrule and frule recursively)
• better support for InplaceableThunk
