1 change: 1 addition & 0 deletions docs/src/api.md
@@ -30,6 +30,7 @@ Optimisers.AccumGrad
Optimisers.ClipGrad
Optimisers.ClipNorm
Optimisers.MixedPrecision
Optimisers.add_mixed_precision
Optimisers.OptimiserChain
Optimisers.SignDecay
Optimisers.WeightDecay
2 changes: 1 addition & 1 deletion src/Optimisers.jl
@@ -24,7 +24,7 @@ include("rules.jl")
export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
       AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
       WeightDecay, SignDecay, ClipGrad, ClipNorm, OptimiserChain, Lion,
       AccumGrad, MixedPrecision
       AccumGrad, MixedPrecision, add_mixed_precision

VERSION >= v"1.11.0-DEV.469" && eval(Meta.parse("public apply!, init, setup, update, update!"))

43 changes: 43 additions & 0 deletions src/rules.jl
@@ -925,3 +925,46 @@ end

adjust(o::MixedPrecision{T}, eta::Real) where T = MixedPrecision(T, adjust(o.rule, eta))
adjust(o::MixedPrecision{T}; kw...) where T = MixedPrecision(T, adjust(o.rule; kw...))


"""
add_mixed_precision([T], tree, model) -> new_tree

Add mixed precision to an existing optimiser state `tree` built for `model`.
If `T` is not provided, `Float32` is used.

Each leaf of the returned tree contains a `MixedPrecision` rule wrapping the original rule;
the existing optimiser states are preserved and converted to type `T`.
"""
add_mixed_precision(tree, model) = add_mixed_precision(Float32, tree, model)

function add_mixed_precision(T, tree, model)
    cache = IdDict()
    tree = _add_mixed_precision(T, tree, model; cache)
    isempty(cache) && @warn "add_mixed_precision found no trainable parameters in this model"
    return tree
end

function _add_mixed_precision(T, tree, x; cache)
    ch, re = functor(tree)
    return mapvalue((ti, xi) -> _add_mixed_precision(T, ti, xi; cache), ch, _trainable(x))
end

function _add_mixed_precision(T, tree::Optimisers.Leaf, x; cache)
    haskey(cache, tree) && return cache[tree]
    fT(z) = z isa AbstractFloat || isnumeric(z) ? T.(z) : z
    if !(tree.rule isa MixedPrecision{T})
        if tree.rule isa MixedPrecision # different type
            rulenew = MixedPrecision(T, tree.rule.rule)
            statenew = fmap(fT, tree.state)
        else
            rulenew = MixedPrecision(T, tree.rule)
            statenew = (T.(x), fmap(fT, tree.state))
        end
        treenew = Leaf(rulenew, statenew, tree.frozen)
    else
        treenew = tree
    end
    cache[tree] = treenew
    return treenew
end
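
For orientation, a minimal usage sketch of the new function (not part of the diff; the `Toy` struct, its `Functors.@functor` registration, and the Float16 sizes are illustrative assumptions):

using Optimisers, Functors

struct Toy{W,B}   # hypothetical two-parameter container, not from the package
    w::W
    b::B
end
Functors.@functor Toy

model = Toy(rand(Float16, 3, 3), rand(Float16, 3))

state = Optimisers.setup(AdamW(), model)              # low-precision optimiser state
state = add_mixed_precision(state, model)             # wrap each rule; Float32 master copies
state64 = add_mixed_precision(Float64, state, model)  # or choose the precision explicitly

Calling it again with a different `T` replaces the existing `MixedPrecision` wrapper rather than nesting a second one, as the tests below check.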
28 changes: 28 additions & 0 deletions test/rules.jl
@@ -286,3 +286,31 @@ end

    @test_throws ArgumentError OptimiserChain(MixedPrecision(Adam()))
end

@testset "add_mixed_precision" begin
d = rand(Float16, 2,2)
d2 = rand(Float16, 2)
model = Foo(Foo(d, d2), d)
opt_state = Optimisers.setup(AdamW(), model)
@test opt_state.x.x === opt_state.y
@test opt_state.x.y.state[1] isa Vector{Float16}
@test opt_state.x.y.state[2] isa Vector{Float16}
@test opt_state.x.y.state[3] isa Tuple{Float16, Float16}

opt_state_new = add_mixed_precision(opt_state, model)

@test opt_state_new.x.x.rule isa MixedPrecision{Float32}
@test opt_state_new.x.x === opt_state_new.y
@test opt_state_new.x.x.state[1] isa Matrix{Float32}
@test opt_state_new.x.x.state[1] ≈ model.x.x
@test opt_state_new.x.y.state[2][1] isa Vector{Float32}
@test opt_state_new.x.y.state[2][2] isa Vector{Float32}
@test opt_state_new.x.y.state[2][3] isa Tuple{Float32, Float32}

opt_state_new2 = add_mixed_precision(Float64, opt_state_new, model)

@test opt_state_new2.x.x.rule isa MixedPrecision{Float64} # MixedPrecision{Float32} replaced
@test opt_state_new2.x.x.rule.rule isa AdamW # no nesting of MixedPrecision
@test opt_state_new2.x.x.state[1] isa Matrix{Float64}
@test opt_state_new2.x.x.state[2][1] isa Matrix{Float64}
end
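
For completeness, a hedged sketch of how the upgraded state from the test above might be used in an ordinary update step (not part of the diff; the all-ones gradient values are illustrative and mirror the `Foo(Foo(d, d2), d)` structure):

grads = (x = (x = ones(Float16, 2, 2), y = ones(Float16, 2)), y = ones(Float16, 2, 2))
opt_state_new, model = Optimisers.update!(opt_state_new, model, grads)

The gradients can stay in Float16, while the `MixedPrecision` leaves keep the higher-precision master copies seen in `state[1]` above.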