11module ParametricExpressionModule
22
33using DispatchDoctor: @stable , @unstable
4- using ChainRulesCore: ChainRulesCore, NoTangent, @thunk
4+ using ChainRulesCore: ChainRulesCore as CRC , NoTangent, @thunk
55
66using .. OperatorEnumModule: AbstractOperatorEnum, OperatorEnum
77using .. NodeModule: AbstractExpressionNode, Node, tree_mapreduce
@@ -20,6 +20,7 @@ import ..StringsModule: string_tree
2020import .. EvaluateModule: eval_tree_array
2121import .. EvaluateDerivativeModule: eval_grad_tree_array
2222import .. EvaluationHelpersModule: _grad_evaluator
23+ import .. ChainRulesModule: extract_gradient
2324import .. ExpressionModule:
2425 get_contents,
2526 get_metadata,
@@ -207,7 +208,7 @@ has_constants(ex::ParametricExpression) = _interface_error()
207208has_operators (ex:: ParametricExpression ) = has_operators (get_tree (ex))
208209function get_constants (ex:: ParametricExpression{T} ) where {T}
209210 constants, constant_refs = get_constants (get_tree (ex))
210- parameters = ex . metadata . parameters
211+ parameters = get_metadata (ex) . parameters
211212 flat_parameters = parameters[:]
212213 num_constants = length (constants)
213214 num_parameters = length (flat_parameters)
@@ -218,9 +219,27 @@ function set_constants!(ex::ParametricExpression{T}, x, refs) where {T}
218219 # First, set the usual constants
219220 set_constants! (get_tree (ex), @view (x[1 : (refs. num_constants)]), refs. constant_refs)
220221 # Then, copy in the parameters
221- ex . metadata . parameters[:] .= @view (x[(refs. num_constants + 1 ): end ])
222+ get_metadata (ex) . parameters[:] .= @view (x[(refs. num_constants + 1 ): end ])
222223 return ex
223224end
225+ function extract_gradient (
226+ gradient: :@NamedTuple {
227+ tree:: NT ,
228+ metadata: :@NamedTuple {
229+ _data: :@NamedTuple {
230+ operators:: Nothing ,
231+ variable_names:: Nothing ,
232+ parameters:: PARAM ,
233+ parameter_names:: Nothing ,
234+ }
235+ }
236+ },
237+ ex:: ParametricExpression{T,N} ,
238+ ) where {T,N<: ParametricNode{T} ,NT<: NodeTangent{T,N} ,PARAM<: AbstractMatrix{T} }
239+ d_constants = extract_gradient (gradient. tree, get_tree (ex))
240+ d_params = gradient. metadata. _data. parameters[:]
241+ return vcat (d_constants, d_params) # Same shape as `get_constants`
242+ end
224243
225244function Base. convert (:: Type{Node} , ex:: ParametricExpression{T} ) where {T}
226245 num_params = UInt16 (size (ex. metadata. parameters, 1 ))
@@ -238,9 +257,7 @@ function Base.convert(::Type{Node}, ex::ParametricExpression{T}) where {T}
238257 Node{T},
239258 )
240259end
241- function ChainRulesCore. rrule (
242- :: typeof (convert), :: Type{Node} , ex:: ParametricExpression{T}
243- ) where {T}
260+ function CRC. rrule (:: typeof (convert), :: Type{Node} , ex:: ParametricExpression{T} ) where {T}
244261 tree = get_contents (ex)
245262 primal = convert (Node, ex)
246263 pullback = let tree = tree
0 commit comments