Skip to content


Folders and files

Last commit message
Last commit date

Latest commit



36 Commits

Repository files navigation

mercury-ad: Mercury module for automatic differentiation


A Mercury module for automatic differentiation that includes forward and backward differentiation.

The module adapts the approach presented in AD-Rosetta-Stone. Interestingly, most of the functional approaches for backward differentiation described there use references to update a tape in place, whereas this library implements a purer functional approach for the fan-out and reverse phases. [For the technical details: I added an integer tag to each tape and then collect the sensitivity values to extract the gradients.]

The module also includes some simple optimisers and implements one of the extended examples (mlp_R) given in AD-Rosetta-Stone.

The general type is ad_number, which includes base(float) for numerical values. Full documentation is given below.

Test examples

Some test code is here:

:- module test_ad.

:- interface.
:- import_module io.
:- pred main(io::di, io::uo) is det.

:- implementation.
:- use_module math.
:- import_module ad.
:- import_module ad.v.
:- import_module list.
:- import_module float.

%% Entry point of the test program. For each example it prints the line
%% "Expected: ", then a reference value computed directly with float
%% arithmetic, then the value computed by the ad module.
%% NOTE(review): this listing has lost several lines (each gap is marked
%% below), so it does not compile as shown; restore it from test_ad.m.
main(!IO) :-
    %% d/dx exp(2x) at x = 1, forward mode; reference value is 2*exp(2).
    derivative_F(func(X) = exp(base(2.0)*X), base(1.0), GradF),
    print_line("Expected: ", !IO), print_line(base(math.exp(2.0)*2.0), !IO),
    print_line(GradF, !IO),
    %% Gradient of exp(2a)+b at (1,3), backward mode; the input vector is
    %% built from floats with from_list.
    gradient_R(func(List) = Y :-
		   (if List=[A,B] then Y=exp(base(2.0)*A)+B else Y = base(0.0)),
		   from_list([1.0,3.0]), Grad0),
    print_line("Expected: ", !IO), print_line([base(math.exp(2.0)*2.0),base(1.0)], !IO),
    print_line(Grad0, !IO),
    %% Gradient of b+a^3 at (1.1,2.3), backward mode; reference is [3a^2, 1].
    gradient_R(func(List) = Y :-
		   (if List=[A,B] then Y=B+A*A*A else Y = base(0.0)),
		   [base(1.1),base(2.3)], Grad),
    print_line("Expected: ", !IO), print_line([base(3.0*1.1*1.1),base(1.0)], !IO),
    print_line(Grad, !IO),
    %% Gradient of exp(b+a^3) at (1.1,2.3), backward mode.
    gradient_R(func(List) = Y :-
		   (if List=[A,B] then Y=exp(B+A*A*A) else Y = base(0.0)),
		   [base(1.1),base(2.3)], Grad2),
    print_line("Expected: ", !IO),
    %% NOTE(review): the opening of a print_line([...  call is missing here;
    %% the next line closes a list and a call that are never opened.
		base(math.exp(2.3+1.1*1.1*1.1))], !IO),
    print_line(Grad2, !IO),
    %% Same function and point as above, forward mode.
    gradient_F(func(List) = Y :-
		   (if List=[A,B] then Y=exp(B+A*A*A) else Y = base(0.0)),
		   [base(1.1),base(2.3)], Grad3),
    print_line("Expected: ", !IO),
    %% NOTE(review): same gap as above -- the start of the print_line call
    %% that the next line closes is missing.
		base(math.exp(2.3+1.1*1.1*1.1))], !IO),
    print_line(Grad3, !IO),
    %% Minimise a^2 + (b-1)^2, forward mode; the expected minimiser is (0,1).
    multivariate_argmin_F(func(AB) = Y :-
			      if AB = [A,B]
				      then Y = A*A+(B-base(1.0))*(B-base(1.0))
								  else Y=base(0.0),
    %% NOTE(review): the call above is never closed -- its remaining
    %% arguments (start point and result variable) are missing, and the
    %% result is never printed.
    print_line("Expected: ", !IO),
    print_line([base(0.0),base(1.0)], !IO),
    %% Same minimisation, backward mode.
    multivariate_argmin_R(func(AB) = Y :-
			      if AB = [A,B]
				      then Y = A*A+(B-base(1.0))*(B-base(1.0))
								  else Y=base(0.0),
    %% NOTE(review): same gap as for multivariate_argmin_F above.
    print_line("Expected: ", !IO),
    print_line([base(0.0),base(1.0)], !IO),
    %% NOTE(review): the next two "Expected" pairs have no preceding call;
    %% presumably these were the rosenbrock minimisations (_F and _R) -- confirm.
    print_line("Expected: ", !IO),
    print_line([base(1.0),base(1.0)], !IO),
    print_line("Expected: ", !IO),
    print_line([base(1.0),base(1.0)], !IO),
    %% Gradient of atan2(a,b) at (0.5,0.6), forward mode.
    gradient_F(func(List) = Y :-
		   (if List=[A,B] then Y=atan2(A,B) else Y = base(0.0)),
		   [base(0.5),base(0.6)], Y8),
    print_line("Expected: ", !IO),
    %% Reference values come from fdiff (finite differences) on math.atan2,
    %% varying one argument at a time.
    Y8_Y = base(fdiff(func(Yi)=math.atan2(Yi,0.6),0.5)),
    Y8_X = base(fdiff(func(Xi)=math.atan2(0.5,Xi),0.6)),
    %% NOTE(review): the clause ends on a comma below; the final goal
    %% (presumably printing Y8) and the terminating full stop are missing.
    print_line([Y8_Y, Y8_X], !IO),

%% rosenbrock(In) is the two-dimensional Rosenbrock function
%%     (A-X)*(A-X) + B*(Y-X*X)*(Y-X*X)    with A = 1 and B = 100,
%% evaluated at In = [X,Y]. Any list that does not have exactly two
%% elements yields base(0.0), so the function is total (det).
:- func rosenbrock(v_ad_number) = ad_number.
rosenbrock(In) = Result :-
    % The branches must be separated explicitly; without the separator the
    % two unifications of Result would form a single conjunction.
    ( if In = [X,Y] then
        A = base(1.0),
        B = base(100.0),
        Result = (A-X)*(A-X)+B*(Y-X*X)*(Y-X*X)
    else
        Result = base(0.0)
    ).

Running the code and getting output:

mmc --make test_ad libad && ./test_ad
[base(14.7781121978613), base(1.0)]
[base(14.7781121978613), base(1.0)]
[base(3.630000000000001), base(1.0)]
[base(3.630000000000001), base(1.0)]
[base(137.0344903162743), base(37.75054829649429)]
[base(137.0344903162743), base(37.75054829649429)]
[base(137.0344903162743), base(37.75054829649429)]
[base(137.0344903162743), base(37.75054829649429)]
[base(0.0), base(1.0)]
[base(-4.2936212903462384e-09), base(0.9999999957063787)]
[base(0.0), base(1.0)]
[base(-4.2936212903462384e-09), base(0.9999999957063787)]
[base(1.0), base(1.0)]
[base(0.9999914554400818), base(0.9999828755368568)]
[base(1.0), base(1.0)]
[base(0.9999914554355536), base(0.9999828755391909)]
[base(0.9836065574087004), base(-0.8196721312081489)]
[base(0.9836065573770492), base(-0.819672131147541)]


% Copyright (C) 2023 Mark Clements.
% This file is distributed under the terms specified in LICENSE.
% File: ad.m.
% Authors: mclements
% Stability: low.
% This module defines backward and forward automatic
% differentiation

:- module ad.
:- interface.
:- import_module list.
:- import_module float.

    %% main representation type
    %% An ad_number is one of three alternatives:
    %%   dual_number(Epsilon, Value, Derivative)
    %%   tape(Order, Epsilon, Value, Factors, Tapes, Fanout, Sensitivity)
    %%   base(Float)  -- a plain numerical value
    %% Value, Derivative and Sensitivity are themselves ad_numbers, so the
    %% representation nests.
:- type ad_number --->
   dual_number(int,       % epsilon (used for order of derivative)
	       ad_number, % value
	       ad_number) % derivative
   ;
   tape(int,              % variable order (new)
	int,              % epsilon (used for order of derivative)
	ad_number,        % value
	list(ad_number),  % factors
	list(ad_number),  % tape
	int,              % fanout
	ad_number)        % sensitivity
   ;
   base(float).           % plain numerical value

    %% vector of ad_numbers (represented as a list)
:- type v_ad_number == list(ad_number).
    %% matrix of ad_numbers (represented as a list of lists)
:- type m_ad_number == list(list(ad_number)).
    %% vector of floats (represented as a list)
:- type v_float == list(float).
    %% matrix of floats (represented as a list of lists)
:- type m_float == list(list(float)).

    %% make_dual_number(Tag, Value, Derivative) constructs a dual_number
:- func make_dual_number(int,ad_number,ad_number) = ad_number.
    %% make_tape(Tag, Epsilon, Value, Factors, Tapes) constructs a tape
    %% NOTE(review): the tape constructor has two further fields (fanout and
    %% sensitivity) that are not arguments here; presumably they are given
    %% initial values by the implementation -- confirm.
:- func make_tape(int, int, ad_number, v_ad_number,
		  v_ad_number) = ad_number.

%% defined functions and predicates for differentiation
%% Arithmetic operators on ad_numbers.
:- func (ad_number::in) + (ad_number::in) = (ad_number::out) is det.
:- func (ad_number::in) - (ad_number::in) = (ad_number::out) is det.
:- func (ad_number::in) * (ad_number::in) = (ad_number::out) is det.
:- func (ad_number::in) / (ad_number::in) = (ad_number::out) is det.
%% NOTE(review): the argument order of pow/2 (and of log/2 below) is not
%% documented here -- confirm against the implementation.
:- func pow(ad_number, ad_number) = ad_number.
%% Comparison predicates on ad_numbers; each either succeeds or fails.
:- pred (ad_number::in) < (ad_number::in) is semidet.
:- pred (ad_number::in) =< (ad_number::in) is semidet.
:- pred (ad_number::in) > (ad_number::in) is semidet.
:- pred (ad_number::in) >= (ad_number::in) is semidet.
:- pred (ad_number::in) == (ad_number::in) is semidet. % equality
%% Elementary functions on ad_numbers.
:- func exp(ad_number) = ad_number is det.
:- func ln(ad_number) = ad_number is det.
:- func log2(ad_number) = ad_number is det.
:- func log10(ad_number) = ad_number is det.
:- func log(ad_number,ad_number) = ad_number is det.
:- func sqrt(ad_number) = ad_number is det.
:- func sin(ad_number) = ad_number is det.
:- func cos(ad_number) = ad_number is det.
:- func tan(ad_number) = ad_number is det.
:- func asin(ad_number) = ad_number is det.
:- func acos(ad_number) = ad_number is det.
:- func atan(ad_number) = ad_number is det.
:- func atan2(ad_number,ad_number) = ad_number is det.
:- func sinh(ad_number) = ad_number is det.
:- func cosh(ad_number) = ad_number is det.
:- func tanh(ad_number) = ad_number is det.
:- func abs(ad_number) = ad_number is det.
%% TODO: add further functions and operators

    %% derivative_F(F,Theta,Derivative,!Epsilon) takes a function F and an initial value Theta,
    %% and returns the Derivative, with input and output for Epsilon (used for the order of derivative).
    %% Uses forward differentiation.
:- pred derivative_F((func(ad_number) = ad_number)::in, ad_number::in, ad_number::out,
		     int::in, int::out) is det.
    %% derivative_F(F,Theta,Derivative) takes a function F and an initial value Theta,
    %% and returns the Derivative, assuming the default derivative count.
    %% Uses forward differentiation.
:- pred derivative_F((func(ad_number) = ad_number)::in, ad_number::in, ad_number::out) is det.

    %% gradient_F(F,Theta,Gradient) takes a function F and initial values Theta,
    %% and returns the Gradient, assuming the default derivative count.
    %% Uses forward differentiation.
:- pred gradient_F((func(v_ad_number) = ad_number)::in,
		   v_ad_number::in, v_ad_number::out) is det.
    %% gradient_F(F,Theta,Gradient,!Epsilon) takes a function F and initial values Theta,
    %% and returns the Gradient, with input and output for Epsilon (used for the order of derivative).
    %% Uses forward differentiation.
:- pred gradient_F((func(v_ad_number) = ad_number)::in,
		   v_ad_number::in, v_ad_number::out,
		  int::in, int::out) is det.

    %% gradient_R(F,Theta,Gradient,!Epsilon) takes a function F and initial values Theta,
    %% and returns the Gradient, with input and output for Epsilon (used for the order of derivative).
    %% Uses backward differentiation.
:- pred gradient_R((func(v_ad_number) = ad_number)::in,
		   v_ad_number::in, v_ad_number::out,
		   int::in, int::out) is det.
    %% gradient_R(F,Theta,Gradient) takes a function F and initial values Theta,
    %% and returns the Gradient, assuming the default derivative count.
    %% Uses backward differentiation.
:- pred gradient_R((func(v_ad_number) = ad_number)::in,
		   v_ad_number::in, v_ad_number::out) is det.

    %% gradient_ascent_F(F,Theta,Iterations,Eta,{Final,Objective,Derivatives})
    %% takes a function F, initial values Theta, the number of Iterations and the change Eta
    %% (presumably the step size), and calculates the *maximum*, returning the Final
    %% parameters, the Objective and the Derivatives.
    %% Uses forward differentiation.
    %% NOTE(review): the declaration below lists only F and the result tuple; the
    %% Theta, Iterations and Eta arguments described above are missing -- restore
    %% them from ad.m.
:- pred gradient_ascent_F((func(v_ad_number) = ad_number)::in,
			   {v_ad_number, ad_number, v_ad_number}::out) is det.
    %% gradient_ascent_R(F,Theta,Iterations,Eta,{Final,Objective,Derivatives})
    %% takes a function F, initial values Theta, the number of Iterations and the change Eta
    %% (presumably the step size), and calculates the *maximum*, returning the Final
    %% parameters, the Objective and the Derivatives.
    %% Uses backward differentiation.
    %% NOTE(review): the declaration below lists only F and the result tuple; the
    %% Theta, Iterations and Eta arguments described above are missing -- restore
    %% them from ad.m.
:- pred gradient_ascent_R((func(v_ad_number) = ad_number)::in,
			   {v_ad_number, ad_number, v_ad_number}::out) is det.

    %% multivariate_argmin_F(F,Theta,Final)
    %% takes a function F and initial values Theta
    %% and calculates the Final values for the *minimum*.
    %% Theta is the start point handed to F, hence a v_ad_number.
    %% Uses forward differentiation.
:- pred multivariate_argmin_F((func(v_ad_number) = ad_number)::in,
			      v_ad_number::in,
			      v_ad_number::out) is det.
    %% multivariate_argmin_R(F,Theta,Final)
    %% takes a function F and initial values Theta
    %% and calculates the Final values for the *minimum*.
    %% Uses backward differentiation.
:- pred multivariate_argmin_R((func(v_ad_number) = ad_number)::in,
			      v_ad_number::in,
			      v_ad_number::out) is det.

    %% multivariate_argmax_F(F,Theta,Final)
    %% takes a function F and initial values Theta
    %% and calculates the Final values for the *maximum*.
    %% Uses forward differentiation.
:- pred multivariate_argmax_F((func(v_ad_number) = ad_number)::in,
			      v_ad_number::in,
			      v_ad_number::out) is det.
    %% multivariate_argmax_R(F,Theta,Final)
    %% takes a function F and initial values Theta
    %% and calculates the Final values for the *maximum*.
    %% Uses backward differentiation.
:- pred multivariate_argmax_R((func(v_ad_number) = ad_number)::in,
			      v_ad_number::in,
			      v_ad_number::out) is det.

    %% multivariate_max_F(F,Theta,Value)
    %% takes a function F and initial values Theta
    %% and calculates the *maximum* Value.
    %% Uses forward differentiation.
:- pred multivariate_max_F((func(v_ad_number) = ad_number)::in,
			   v_ad_number::in,
			   ad_number::out) is det.
    %% multivariate_max_R(F,Theta,Value)
    %% takes a function F and initial values Theta
    %% and calculates the *maximum* Value.
    %% Uses backward differentiation.
:- pred multivariate_max_R((func(v_ad_number) = ad_number)::in,
			   v_ad_number::in,
			   ad_number::out) is det.

%% Some common utilities
    %% sqr(X) = X*X
:- func sqr(ad_number) = ad_number.
    %% map_n(F,N) = map(F, 1..N).
    %% NOTE(review): reconstructed from a garbled comment ("=, 1..N)") -- confirm
    %% the index range against the implementation.
:- func map_n(func(int) = ad_number, int) = v_ad_number.
    %% vplus(X,Y) = X + Y
:- func vplus(v_ad_number, v_ad_number) = v_ad_number.
    %% vminus(X,Y) = X - Y
:- func vminus(v_ad_number, v_ad_number) = v_ad_number.
    %% ktimesv(K,V) = K*V
:- func ktimesv(ad_number, v_ad_number) = v_ad_number.
    %% magnitude_squared(V) = sum_i(V[i]*V[i])
:- func magnitude_squared(v_ad_number) = ad_number.
    %% magnitude(V) = sqrt(sum_i(V[i]*V[i]))
:- func magnitude(v_ad_number) = ad_number.
    %% distance_squared(X,Y) = magnitude_squared(X-Y)
:- func distance_squared(v_ad_number,v_ad_number) = ad_number.
    %% distance(X,Y) = magnitude(X-Y)
:- func distance(v_ad_number,v_ad_number) = ad_number.
    %% fdiff(F,X,Eps) gives the first derivative for F at X using a
    %% symmetric finite difference using Eps
:- func fdiff(func(float)=float,float,float) = float.
    %% fdiff(F,X) = fdiff(F,X,1.0e-5)
:- func fdiff(func(float)=float,float) = float.

%% submodule for operations and functions on v_ad_number
:- module ad.v.
:- interface.
    %% Addition of two vectors
:- func (v_ad_number::in) + (v_ad_number::in) = (v_ad_number::out) is det.
    %% Subtraction of two vectors
:- func (v_ad_number::in) - (v_ad_number::in) = (v_ad_number::out) is det.
    %% multiplication of a vector by a scalar
:- func (ad_number::in) * (v_ad_number::in) = (v_ad_number::out) is det.
    %% convert from a vector of floats
:- func from_list(v_float) = v_ad_number.
    %% convert to a vector of floats
:- func to_list(v_ad_number) = v_float is det.
:- end_module ad.v.

%% submodule for operations and functions on m_ad_number
:- module ad.m.
:- interface.
    %% Addition of two matrices
:- func (m_ad_number::in) + (m_ad_number::in) = (m_ad_number::out) is det.
    %% Subtraction of two matrices
:- func (m_ad_number::in) - (m_ad_number::in) = (m_ad_number::out) is det.
    %% convert from a matrix of floats
:- func from_lists(m_float) = m_ad_number.
    %% convert to a matrix of floats
:- func to_lists(m_ad_number) = m_float is det.
:- end_module ad.m.

    %% determine_fanout(Tape) is the fanout operation for backward differentiation
:- func determine_fanout(ad_number) = ad_number.
    %% reverse_phase(Sensitivity,Tape) is the reverse phase for backward differentiation
:- func reverse_phase(ad_number, ad_number) = ad_number.
    %% extract_gradients(Tape) extracts the gradients as a vector
:- func extract_gradients(ad_number) = v_ad_number.
    %% to_float(Ad_number) returns a float representation
:- func to_float(ad_number) = float.


Mercury library for automatic differentiation








No releases published


No packages published