## Copyright (C) 2006 Michel D. Schmid
##
##
## This program is free software; you can redistribute it and/or modify it
## under the terms of the GNU General Public License as published by
## the Free Software Foundation; either version 2, or (at your option)
## any later version.
##
## This program is distributed in the hope that it will be useful, but
## WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
## General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program; see the file COPYING. If not, see
## .
## -*- texinfo -*-
## @deftypefn {Function File} address@hidden = train (@var{MLPnet},@var{mInputN},@var{mOutput},@var{[]},@var{[]},@var{VV})
## A neural feed-forward network will be trained with @code{train}
##
## @example
## [net,tr,out,E] = train(MLPnet,mInputN,mOutput,[],[],VV);
## @end example
## @noindent
##
## @example
## left side arguments:
## net: the trained network of the net structure @code{MLPnet}
## @end example
## @noindent
##
## @example
## right side arguments:
## MLPnet : the untrained network, created with @code{newff}
## mInputN: normalized input matrix
## mOutput: output matrix (normalized or not)
## [] : unused parameter
## [] : unused parameter
## VV : validize structure
## @end example
## @end deftypefn
## @seealso{newff,prestd,trastd}
## Author: Michel D. Schmid
## Comments: see in "A neural network toolbox for Octave User's Guide" [4]
## for variable naming... there have inputs or targets only one letter,
## e.g. for inputs is P written. To write a program, this is stupid, you can't
## search for 1 letter variable... that's why it is written here like Pp, or Tt
## instead only P or T.
function [net] = train(net,Pp,Tt,notUsed1,notUsed2,VV)
## check range of input arguments
error(nargchk(3,6,nargin))
## set defaults
doValidation = 0;
if nargin==6
# doValidation=1;
## check if VV is in MATLAB(TM) notation
[VV, doValidation] = checkVV(VV);
endif
## check input args
checkInputArgs(net,Pp,Tt)
## nargin ...
switch(nargin)
case 3
[Pp,Tt] = trainArgs(net,Pp,Tt);
VV = [];
case 6
[Pp,Tt] = trainArgs(net,Pp,Tt);
if isempty(VV)
VV = [];
else
if !isfield(VV,"Pp")
error("VV.Pp must be defined or VV must be [].")
endif
if !isfield(VV,"Tt")
error("VV.Tt must be defined or VV must be [].")
endif
[VV.Pp,VV.Tt] = trainArgs(net,VV.Pp,VV.Tt);
endif
otherwise
error("train: impossible code execution in switch(nargin)")
endswitch
## so now, let's start training the network
##===========================================
## first let's check if a train function is defined ...
if isempty(net.trainFcn)
error("train:net.trainFcn not defined")
endif
## calculate input matrix Im
[nRowsInputs, nColumnsInputs] = size(Pp);
Im = ones(nRowsInputs,nColumnsInputs).*Pp{1,1};
if (doValidation)
[nRowsVal, nColumnsVal] = size(VV.Pp);
VV.Im = ones(nRowsVal,nColumnsVal).*VV.Pp{1,1};
endif
## make it MATLAB(TM) compatible
nLayers = net.numLayers;
Tt{nLayers,1} = Tt{1,1};
Tt{1,1} = [];
if (!isempty(VV))
VV.Tt{nLayers,1} = VV.Tt{1,1};
VV.Tt{1,1} = [];
endif
## which training algorithm should be used
switch(net.trainFcn)
case "trainlm"
if !strcmp(net.performFcn,"mse")
error("Levenberg-Marquardt algorithm is defined with the MSE performance function, so please set MSE in NEWFF!")
endif
net = __trainlm(net,Im,Pp,Tt,VV);
otherwise
error("train algorithm argument is not valid!")
endswitch
# =======================================================
#
# additional check functions...
#
# =======================================================
function checkInputArgs(net,Pp,Tt)
## check "net", must be a net structure
if !__checknetstruct(net)
error("Structure doesn't seem to be a neural network!")
endif
## check Pp (inputs)
nInputSize = net.inputs{1}.size; #only one exists
[nRowsPp, nColumnsPp] = size(Pp);
if ( (nColumnsPp>0) )
if ( nInputSize==nRowsPp )
# seems to be everything i.o.
else
error("Number of rows must be the same, like in net.inputs.size defined!")
endif
else
error("At least one column must exist")
endif
## check Tt (targets)
[nRowsTt, nColumnsTt] = size(Tt);
if ( (nRowsTt | nColumnsTt)==0 )
error("No targets defined!")
elseif ( nColumnsTt!=nColumnsPp )
error("Inputs and targets must have the same number of data sets (columns).")
elseif ( net.layers{net.numLayers}.size!=nRowsTt )
error("Defined number of output neurons are not identically to targets data sets!")
endif
endfunction
# -------------------------------------------------------
function [Pp,Tt] = trainArgs(net,Pp,Tt);
## check number of arguments
error(nargchk(3,3,nargin));
[PpRows, PpColumns] = size(Pp);
Pp = mat2cell(Pp,PpRows,PpColumns); # mat2cell is the reason
# why octave-2.9.5 doesn't work
# octave-2.9.x with x>=6 should be
# ok
[TtRows, TtColumns] = size(Tt);
Tt = mat2cell(Tt,TtRows,TtColumns);
endfunction
# -------------------------------------------------------
function [VV, doValidation] = checkVV(VV)
## check number of arguments
error(nargchk(1,1,nargin));
if (isempty(VV))
doValidation = 0;
else
doValidation = 1;
## check if MATLAB(TM) naming convention is used
if isfield(VV,"P")
VV.Pp = VV.P;
VV.P = [];
elseif !isfield(VV,"Pp")
error("VV is defined but inside exist no VV.P or VV.Pp")
endif
if isfield(VV,"T")
VV.Tt = VV.T;
VV.T = [];
elseif !isfield(VV,"Tt")
error("VV is defined but inside exist no VV.TP or VV.Tt")
endif
endif
endfunction
# ============================================================
endfunction