Skip to content

Commit

Permalink
reconfigured tensor dispatch #50 #81
Browse files Browse the repository at this point in the history
  • Loading branch information
chakravala committed Sep 17, 2020
1 parent 5228f28 commit 473b348
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 51 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Grassmann"
uuid = "4df31cd9-4c27-5bea-88d0-e6a7146666d8"
authors = ["Michael Reed"]
version = "0.7.1"
version = "0.7.2"

[deps]
AbstractTensors = "a8e43f4a-99b7-5565-8bf1-0165161caaea"
Expand Down
47 changes: 43 additions & 4 deletions src/algebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,10 @@ end

export , ,

(a::A,b::B) where {A<:TensorAlgebra,B<:TensorAlgebra} = ab
#⊗(a::A,b::B) where {A<:TensorAlgebra,B<:TensorAlgebra} = a∧b
(a::A,b::B) where {A<:TensorGraded,B<:TensorGraded} = Dyadic(a,b)
(a::A,b::B) where {A<:TensorGraded,B<:TensorGraded{V,0} where V} = a*b
(a::A,b::B) where {A<:TensorGraded{V,0} where V,B<:TensorGraded} = a*b

## regressive product: (L = grade(a) + grade(b); (-1)^(L*(L-mdims(V)))*⋆(⋆(a)∧⋆(b)))

Expand Down Expand Up @@ -202,20 +205,56 @@ outer(a::Leibniz.Derivation,b::Chain{V,1}) where V= outer(V(a),b)
outer(a::Chain{W},b::Leibniz.Derivation{T,1}) where {W,T} = outer(a,W(b))
outer(a::Chain{W},b::Chain{V,1}) where {W,V} = Chain{V,1}(a.*value(b))

contraction(a::Proj,b::TensorGraded) = a.v(a.vb)
contraction(a::Dyadic,b::TensorGraded) = a.x(a.yb)
contraction(a::TensorGraded,b::Dyadic) = (ab.x)b.y
contraction(a::TensorGraded,b::Proj) = (ab.v)b.v
contraction(a::Dyadic,b::Dyadic) = (a.x*(a.yb.x))b.y
contraction(a::Dyadic,b::Proj) = (a.x*(a.yb.v))b.v
contraction(a::Proj,b::Dyadic) = (a.v*(a.vb.x))b.y
contraction(a::Proj,b::Proj) = (a.v*(a.vb.v))b.v
contraction(a::Dyadic{V},b::TensorGraded{V,0}) where V = Dyadic{V}(a.x*b,a.y)
contraction(a::Proj{V},b::TensorGraded{V,0}) where V = valuetype(b)<:Complex ? Proj{V}(a.v*sqrt(b)) : Dyadic{V}(a.v*b,a.v)
contraction(a::Proj{V,<:Chain{V,1,<:TensorNested}},b::TensorGraded{V,0}) where V = Proj(Chain{V,1}(contraction.(value(a.v),b)))
contraction(a::Chain{W,1,<:Proj{V}},b::Chain{V,1}) where {W,V} = Chain{W,1}(value(a).⋅b)
contraction(a::Chain{W,1,<:Dyadic{V}},b::Chain{V,1}) where {W,V} = Chain{W,1}(value(a).⋅Ref(b))
contraction(a::Proj{W,<:Chain{W,1,<:TensorNested{V}}},b::Chain{V,1}) where {W,V} = a.v:b
contraction(a::Chain{W,G},b::Chain{V,1,<:Chain}) where {W,G,V} = Chain{V,1}(column(Ref(a).⋅value(b)))
contraction(a::Chain{W,G,<:Chain},b::Chain{V,1,<:Chain}) where {W,G,V} = Chain{V,1}(Ref(a).⋅value(b))
Base.:(:)(a::Chain{V,1,<:Chain},b::Chain{V,1,<:Chain}) where V = sum(value(a).⋅value(b))
Base.:(:)(a::Chain{W,1,<:Dyadic{V}},b::Chain{V,1}) where {W,V} = sum(value(a).⋅Ref(b))
Base.:(:)(a::Chain{W,1,<:Proj{V}},b::Chain{V,1}) where {W,V} = sum(broadcast(,value(a),Ref(b)))

+(a::Proj{V}...) where V = Proj(Chain(a...))
+(a::Dyadic{V}...) where V = Proj(Chain(a...))
+(a::TensorNested{V}...) where V = Proj(Chain(Dyadic.(a)...))
+(a::Proj{W,<:Chain{W,1,<:TensorNested{V}}} where W,b::TensorNested{V}) where V = +(value(a.v)...,b)
+(a::TensorNested{V},b::Proj{W,<:Chain{W,1,<:TensorNested{V}}} where W) where V = +(a,value(b.v)...)
+(a::Proj{M,<:Chain{M,1,<:TensorNested{V}}} where M,b::Proj{W,<:Chain{W,1,<:TensorNested{V}}} where W) where V = +(value(a.v)...,value(b.v)...)

-(a::TensorNested) where V = -1a
-(a::TensorNested,b::TensorNested) where V = a+(-b)
*(a::Number,b::TensorNested{V}) where V = (a*one(V))*b
*(a::TensorNested{V},b::Number) where V = a*(b*one(V))
@inline *(a::TensorGraded{V,0},b::TensorNested{V}) where V = ba
@inline *(a::TensorNested{V},b::TensorGraded{V,0}) where V = ab
@inline *(a::TensorGraded{V,0},b::Proj{V,<:Chain{V,1,<:TensorNested}}) where V = Proj{V}(a*b.v)
@inline *(a::Proj{V,<:Chain{V,1,<:TensorNested}},b::TensorGraded{V,0}) where V = Proj{V}(a.v*b)
Base.:(a::A,b::B) where {A<:TensorAlgebra,B<:TensorAlgebra} = ab

# dyadic identity element

Base.:+(g::Chain{V,1,<:Chain{V,1}},t::LinearAlgebra.UniformScaling{Bool}) where V = t+g
Base.:+(t::LinearAlgebra.UniformScaling,g::TensorNested) = t+DyadicChain(g)
Base.:+(g::TensorNested,t::LinearAlgebra.UniformScaling) = DyadicChain(g)+t
Base.:+(g::Chain{V,1,<:Chain{V,1}},t::LinearAlgebra.UniformScaling) where V = t+g
Base.:-(g::Chain{V,1,<:Chain{V,1}},t::LinearAlgebra.UniformScaling{Bool}) where V = t+g
Base.:-(g::Chain{V,1,<:Chain{V,1}},t::LinearAlgebra.UniformScaling) where V = t+g
Base.:-(t::LinearAlgebra.UniformScaling,g::TensorNested) = t-DyadicChain(g)
Base.:-(g::TensorNested,t::LinearAlgebra.UniformScaling) = DyadicChain(g)-t
@generated Base.:+(t::LinearAlgebra.UniformScaling{Bool},g::Chain{V,1,<:Chain{V,1}}) where V = :(Chain{V,1}($(getalgebra(V).b[Grassmann.list(2,mdims(V)+1)]).+value(g)))
@generated Base.:+(t::LinearAlgebra.UniformScaling,g::Chain{V,1,<:Chain{V,1}}) where V = :(Chain{V,1}(t.λ*$(getalgebra(V).b[Grassmann.list(2,mdims(V)+1)]).+value(g)))
@generated Base.:-(t::LinearAlgebra.UniformScaling{Bool},g::Chain{V,1,<:Chain{V,1}}) where V = :(Chain{V,1}($(getalgebra(V).b[Grassmann.list(2,mdims(V)+1)]).-value(g)))
@generated Base.:-(t::LinearAlgebra.UniformScaling,g::Chain{V,1,<:Chain{V,1}}) where V = :(Chain{V,1}(t.λ*$(getalgebra(V).b[Grassmann.list(2,mdims(V)+1)]).-value(g)))
@generated Base.:-(g::Chain{V,1,<:Chain{V,1}},t::LinearAlgebra.UniformScaling{Bool}) where V = :(Chain{V,1}(value(g).-$(getalgebra(V).b[Grassmann.list(2,mdims(V)+1)])))
@generated Base.:-(g::Chain{V,1,<:Chain{V,1}},t::LinearAlgebra.UniformScaling) where V = :(Chain{V,1}(value(g).-t.λ*$(getalgebra(V).b[Grassmann.list(2,mdims(V)+1)])))

## cross product

Expand Down
31 changes: 12 additions & 19 deletions src/composite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ end
function Base.exp(t::T) where T<:TensorGraded
S,B = T<:SubManifold,T<:TensorTerm
if B && isnull(t)
return one(V)
return one(Manifold(t))
elseif isR301(Manifold(t)) && grade(t)==2 # && abs(t[0])<1e-9 && !options.over
u = sqrt(abs(abs2(t)[1]))
u<1e-5 && (return 1+t)
Expand Down Expand Up @@ -294,12 +294,12 @@ for (logfast,expf) ∈ ((:log_fast,:exp),(:logh_fast,:exph))
@eval function $logfast(t::T) where T<:TensorAlgebra
V = Manifold(t)
term = zero(V)
norm = FixedVector{2}(0.,0.)
nrm = FixedVector{2}(0.,0.)
while true
en = $expf(term)
term -= 2(en-t)/(en+t)
@inbounds norm .= (norm[2],norm(term))
@inbounds norm[1] norm[2] && break
@inbounds nrm .= (nrm[2],norm(term))
@inbounds nrm[1] nrm[2] && break
end
return term
end
Expand Down Expand Up @@ -678,39 +678,32 @@ Base.rand(::AbstractRNG,::SamplerType{MultiVector}) = rand(MultiVector{rand(Mani
Base.rand(::AbstractRNG,::SamplerType{MultiVector{V}}) where V = MultiVector{V}(DirectSum.orand(svec(mdims(V),Float64)))
Base.rand(::AbstractRNG,::SamplerType{MultiVector{V,T}}) where {V,T} = MultiVector{V}(rand(svec(mdims(V),T)))

export Orthotope, Orthogrid
export Orthogrid

struct Orthotope{V,T}
min::Chain{V,1,T}
max::Chain{V,1,T}
end

(::Base.Colon)(min::Chain{V,1,T},max::Chain{V,1,T}) where {V,T} = Orthotope{V,T}(min,max)

struct Orthogrid{V,T} # <: TensorGraded{V,1} mess up collect?
x::Orthotope{V,T}
@computed struct Orthogrid{V,T} # <: TensorGraded{V,1} mess up collect?
v::Dyadic{V,Chain{V,1,T,mdims(V)},Chain{V,1,T,mdims(V)}}
n::Chain{V,1,Int}
s::Chain{V,1,Float64}
end

Orthogrid{V,T}(x,n) where {V,T} = Orthogrid{V,T}(x,n,Chain{V,1}(value(x.max-x.min)./(value(n)-1)))
Orthogrid{V,T}(v,n) where {V,T} = Orthogrid{V,T}(v,n,Chain{V,1}(value(v.x-v.y)./(value(n)-1)))

Base.show(io::IO,t::Orthogrid) = println('(',t.x.min,"):(",t.s,"):(",t.x.max,')')
Base.show(io::IO,t::Orthogrid) = println('(',t.v.x,"):(",t.s,"):(",t.v.y,')')

zeroinf(f) = iszero(f) ? Inf : f

(::Base.Colon)(min::Chain{V,1,T},step::Chain{V,1,T},max::Chain{V,1,T}) where {V,T} = Orthogrid{V,T}(min:max,Chain{V,1}(Int.(round.(value(max-min)./zeroinf.(value(step))))+1),step)
(::Base.Colon)(min::Chain{V,1,T},step::Chain{V,1,T},max::Chain{V,1,T}) where {V,T} = Orthogrid{V,T}(minmax,Chain{V,1}(Int.(round.(value(max-min)./zeroinf.(value(step)))).+1),step)

Base.iterate(t::Orthogrid) = (getindex(t,1),1)
Base.iterate(t::Orthogrid,state) = (s=state+1; slength(t) ? (getindex(t,s),s) : nothing)
@pure Base.eltype(::Type{Orthogrid{V,T}}) where {V,T} = Chain{V,1,T,mdims(V)}
@pure Base.step(t::Orthogrid) = value(t.s)
@pure Base.size(t::Orthogrid) = value(t.n).data
@pure Base.size(t::Orthogrid) = value(t.n).v
@pure Base.length(t::Orthogrid) = prod(size(t))
@pure Base.lastindex(t::Orthogrid) = length(t)
@pure Base.lastindex(t::Orthogrid,i::Int) = size(t)[i]
@pure Base.getindex(t::Orthogrid,i::CartesianIndex) = getindex(t,i.I...)
@pure Base.getindex(t::Orthogrid{V},i::Vararg{Int}) where V = Chain{V,1}(value(t.x.min)+(Values(i)-1).*step(t))
@pure Base.getindex(t::Orthogrid{V},i::Vararg{Int}) where V = Chain{V,1}(value(t.v.x)+(Values(i).-1).*step(t))

Base.IndexStyle(::Orthogrid) = IndexCartesian()
function Base.getindex(A::Orthogrid, I::Int)
Expand Down
50 changes: 29 additions & 21 deletions src/forms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,27 @@
W = Manifold(w)
if isbasis(W)
if Q == V
if G == M == 1
x = bits(W)
X = isdyadic(V) ? x>>Int(mdims(V)/2) : x
Y = 0X ? X : x
out = :(@inbounds b.v[bladeindex($(mdims(V)),Y)])
return :(Simplex{V}(V[intlog(Y)+1] ? -($out) : $out,SubManifold{V}()))
elseif G == 1 && M == 2
(!isdyadic(V)) && :(throw(error("wrong basis")))
ib,(m1,m2) = indexbasis(N,1),DirectSum.eval_shift(W)
:(@inbounds $(V[m2] ? :(-(b.v[m2])) : :(b.v[m2]))*getbasis(V,ib[m1]))
if isdyadic(V)
if G == M == 1
x = bits(W)
X = isdyadic(V) ? x>>Int(mdims(V)/2) : x
Y = 0X ? X : x
out = :(@inbounds b.v[bladeindex($(mdims(V)),Y)])
return :(Simplex{V}(V[intlog(Y)+1] ? -($out) : $out,SubManifold{V}()))
elseif G == 1 && M == 2
ib,(m1,m2) = indexbasis(N,1),DirectSum.eval_shift(W)
:(@inbounds $(V[m2] ? :(-(b.v[m2])) : :(b.v[m2]))*getbasis(V,ib[m1]))
else
:(throw(error("not yet possible")))
end
else
:(throw(error("not yet possible")))
return :(contraction(w,b))
end
else
:(interform(w,b))
return :(interform(w,b))
end
elseif V==W
return :b
return V===W ? :b : :(Chain{w,G,T}(value(b)))
elseif WV
if G == 1
ind = Values{mdims(W),Int}(indices(bits(W),mdims(V)))
Expand Down Expand Up @@ -85,9 +88,10 @@ end
(W::Signature)(b::MultiVector{V,T}) where {V,T} = SubManifold(W)(b)
function (W::SubManifold{Q,M,S})(m::MultiVector{V,T}) where {Q,M,V,S,T}
if isbasis(W)
throw(error("MultiVector forms not yet supported"))
isdyadic(V) && throw(error("MultiVector forms not yet supported"))
return V==W ? contraction(W,m) : interform(W,m)
elseif V==W
return m
return V===W ? m : MultiVector{W,T}(value(m))
elseif WV
out,N = zeros(choicevec(M,valuetype(m))),mdims(V)
bs = binomsum_set(N)
Expand Down Expand Up @@ -185,9 +189,10 @@ end
## Chain forms

(a::Chain)(b::T) where {T<:TensorAlgebra} = interform(a,b)
(a::Chain{V,1,<:Manifold} where V)(b::T) where {T<:TensorAlgebra} = contraction(a,b)
@eval begin
function (a::Chain{V,2})(b::Chain{V,1}) where V
(!isdyadic(V)) && throw(error("wrong basis"))
(!isdyadic(V)) && (return contraction(a,b))
$(insert_expr((:N,:t,:df,:di))...)
out = zero(mvec(N,1,t))
for Q 1:Int(N/2)
Expand All @@ -198,7 +203,7 @@ end
end
return Chain{V,1}(out)
end
function Chain{V,T}(b::Matrix{T}) where {V,T}
function Chain{V,T}(b::AbstractMatrix{T}) where {V,T}
(!isdyadic(V)) && throw(error("$V does not support this conversion"))
$(insert_expr((:N,:M))...)
size(b) (M,M) && throw(error("dimension mismatch"))
Expand All @@ -216,6 +221,7 @@ end
# more forms

function (a::Chain{V,1})(b::SubManifold{V,1}) where V
(!isdyadic(V)) && (return contraction(a,b))
x = bits(b)
X = isdyadic(V) ? x<<Int(mdims(V)/2) : x
Y = X>2^mdims(V) ? x : X
Expand All @@ -224,7 +230,7 @@ function (a::Chain{V,1})(b::SubManifold{V,1}) where V
end
@eval begin
function (a::Chain{V,2,T})(b::SubManifold{V,1}) where {V,T}
(!isdyadic(V)) && throw(error("wrong basis"))
(!isdyadic(V)) && (return contraction(a,b))
$(insert_expr((:N,:df,:di))...)
Q = bladeindex(N,bits(b))
@inbounds m,val = df[Q][1],df[Q][2]*value(b)
Expand All @@ -235,6 +241,7 @@ end
return Chain{V,1}(out)
end
function (a::Chain{V,1})(b::Simplex{V,1}) where V
(!isdyadic(V)) && (return contraction(a,b))
$(insert_expr((:t,))...)
x = bits(b)
X = isdyadic(V) ? x<<Int(mdims(V)/2) : x
Expand All @@ -243,6 +250,7 @@ end
Simplex{V}((V[intlog(x)+1]*out*b.v)::t,SubManifold{V}())
end
function (a::Simplex{V,1})(b::Chain{V,1}) where V
(!isdyadic(V)) && (return contraction(a,b))
$(insert_expr((:t,))...)
x = bits(a)
X = isdyadic(V) ? x>>Int(mdims(V)/2) : x
Expand All @@ -251,13 +259,13 @@ end
Simplex{V}((a.v*V[intlog(Y)+1]*out)::t,SubManifold{V}())
end
function (a::Simplex{V,2})(b::Chain{V,1}) where V
(!isdyadic(V)) && throw(error("wrong basis"))
(!isdyadic(V)) && (return contraction(a,b))
$(insert_expr((:N,:t))...)
ib,(m1,m2) = indexbasis(N,1),DirectSum.eval_shift(a)
@inbounds ((V[m2]*a.v*b.v[m2])::t)*getbasis(V,ib[m1])
end
function (a::Chain{V,2})(b::Simplex{V,1}) where V
(!isdyadic(V)) && throw(error("wrong basis"))
(!isdyadic(V)) && (return contraction(a,b))
$(insert_expr((:N,:t,:df,:di))...)
Q = bladeindex(N,bits(b))
out = zero(mvec(N,1,T))
Expand All @@ -268,6 +276,7 @@ end
return Chain{V,1}(out)
end
function (a::Chain{V,1})(b::Chain{V,1}) where V
(!isdyadic(V)) && (return contraction(a,b))
$(insert_expr((:N,:M,:t,:df))...)
out = zero(t)
for Q 1:M
Expand All @@ -276,4 +285,3 @@ end
return Simplex{V}(out::t,SubManifold{V}())
end
end

0 comments on commit 473b348

Please sign in to comment.