Reinforcement Learning Toolbox 2.0
last updated:
General
Documentation
Manual
Tutorial
Class Reference
Master Thesis
Examples
Related Papers
Downloads
Links
News
mailto:webmaster
Main Page     Class Hierarchy   Compound List   File List   Compound Members   File Members

ctorchvfunction.h

Go to the documentation of this file.
00001 // Copyright (C) 2003
00002 // Gerhard Neumann (gneumann@gmx.net)
00003 // Stephan Neumann (sneumann@gmx.net) 
00004 //                
00005 // This file is part of RL Toolbox.
00006 // http://www.igi.tugraz.at/ril_toolbox
00007 //
00008 // All rights reserved.
00009 // 
00010 // Redistribution and use in source and binary forms, with or without
00011 // modification, are permitted provided that the following conditions
00012 // are met:
00013 // 1. Redistributions of source code must retain the above copyright
00014 //    notice, this list of conditions and the following disclaimer.
00015 // 2. Redistributions in binary form must reproduce the above copyright
00016 //    notice, this list of conditions and the following disclaimer in the
00017 //    documentation and/or other materials provided with the distribution.
00018 // 3. The name of the author may not be used to endorse or promote products
00019 //    derived from this software without specific prior written permission.
00020 // 
00021 // THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
00022 // IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
00023 // OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
00024 // IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
00025 // INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
00026 // NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
00027 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
00028 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00029 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
00030 // THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00031 
00032 #ifndef CVTORCHFUNCTION_H
00033 #define CVTORCHFUNCTION_H
00034 
00035 #include "ConnectedMachine.h"
00036 #include "MLP.h"
00037 #include "cvfunction.h"
00038 #include "ccontinuousactions.h"
00039 
00040 
00041 using Torch::Sequence;
00042 using Torch::Parameters;
00043 using Torch::Machine;
00044 using Torch::GradientMachine;
00045 
00047 
// CTorchFunction: holds a pointer to a Torch `Machine` together with a
// `Sequence` pointer named `input`. Only declarations are visible in this
// header; the method bodies live in the implementation file.
// NOTE(review): ownership of `machine` and `input` (who allocates and who
// frees them) cannot be determined from this header - confirm in the .cpp.
00056 class CTorchFunction
00057 {
00058 protected:
              // Torch machine associated with this object (set from the constructor
              // argument, presumably - TODO confirm).
00060         Machine *machine;
00061                 
              // Torch sequence; by its name presumably the buffer fed to `machine`
              // - TODO confirm against the implementation.
00063         Sequence *input;
00064 
00065 public:
              // Constructs the wrapper from an existing Torch machine pointer.
00067         CTorchFunction(Machine *machine);
              // Virtual destructor: the class is meant to be used polymorphically.
00068         virtual ~CTorchFunction();
00069 
00070         
00071 
              // Returns a Torch machine pointer (presumably the `machine` member).
00073         virtual Machine *getMachine();
00074 
              // Returns a single double computed for the given Torch sequence `state`.
              // NOTE(review): presumably a forward pass of `machine` on `state`;
              // which output element is returned is not visible here.
00075         virtual double getValueFromMachine(Sequence *state);
00076 };
00077 
00079 
// CTorchGradientFunction: combines CTorchFunction with the toolbox's
// CGradientFunction interface (both are public bases) and additionally keeps
// a pointer to a Torch `GradientMachine`.
// NOTE(review): the exact contract of each virtual below is defined by the
// base classes, which are not visible in this header.
00086 class CTorchGradientFunction : public CTorchFunction, public CGradientFunction
00087 {
00088 protected:
              // Torch sequence named `alpha`; its role is not visible in this header
              // - TODO confirm in the implementation.
00089         Sequence *alpha;
              // The gradient-capable Torch machine (see set/getGradientMachine()).
00091         GradientMachine *gradientMachine;
00092 
              // Adaptive eta (learning-rate) calculator, judging by the type name;
              // how it is used is not shown here.
00093         CAdaptiveEtaCalculator *localEtaCalc;
00094 public:
              // Constructs from input/output dimensions only.
              // NOTE(review): presumably builds a machine internally (MLP.h is
              // included by this header) - confirm in the .cpp.
00096         CTorchGradientFunction(int numInputs, int numOutputs);
              // Constructs around an existing Torch gradient machine.
00097         CTorchGradientFunction(GradientMachine *machine);
00098         virtual ~CTorchGradientFunction();
00099 
              // Resets the function's data; what exactly is reset is defined in the
              // implementation.
00101         virtual void resetData();
00102 
              // Applies a weight update described by a feature list of gradient entries.
00103         virtual void updateWeights(CFeatureList *gradientFeatures);
00104 
              // Returns the number of weights as an int.
00105         virtual int getNumWeights();
00106 
              // The three `...Pre` methods take ColumnVector/Matrix arguments.
              // NOTE(review): presumably these are the hooks that CGradientFunction's
              // public wrappers call - verify against cvfunction.h / the base class.
00107         virtual void getInputDerivationPre(ColumnVector *input, Matrix *targetVector);
00108         virtual void getFunctionValuePre(ColumnVector *input, ColumnVector *output);
00109 
00110 
              // Read / write the weights through a raw double array.
              // NOTE(review): the array is presumably expected to hold at least
              // getNumWeights() elements - no size is passed, so confirm at call sites.
00111         virtual void getWeights(double *parameters);
00112         virtual void setWeights(double *parameters);
00113 
00114         virtual void getGradientPre(ColumnVector *input, ColumnVector *outputErrors, CFeatureList *gradientFeatures);
00115 
              // Non-virtual accessor pair for the gradient machine pointer.
00116         void setGradientMachine(GradientMachine *gradientMachine);
00117         GradientMachine *getGradientMachine();
00118 };
00118 
// CTorchGradientEtaCalculator: a CIndividualEtaCalculator whose only declared
// member is a constructor taking a Torch gradient machine.
// NOTE(review): the constructor is not marked `explicit`, so a
// `GradientMachine *` converts implicitly to this type - confirm that is wanted.
00119 class CTorchGradientEtaCalculator : public CIndividualEtaCalculator
00120 {
00121 public:
              // Builds the calculator from the given gradient machine; what is read
              // from the machine is not visible in this header.
00122         CTorchGradientEtaCalculator(GradientMachine *gradientMachine);
00123 };
00124 
// CTorchVFunction: a CAbstractVFunction that refers to a CTorchFunction and
// keeps its own Torch `Sequence` pointer named `input`.
00125 class CTorchVFunction :  public CAbstractVFunction
00126 {
00127 protected:
00129 
              // Takes a toolbox state and a Torch sequence.
              // NOTE(review): presumably copies `state` into `input`; note that the
              // parameter name `input` is the same as the member declared below, so
              // a definition using this name would shadow the member.
00131         void getInputSequence(CState *state, Sequence *input);
00132 
              // The wrapped Torch function (constructor argument, presumably).
              // Ownership is not visible here.
00133         CTorchFunction *torchFunction;
00134 
              // Torch sequence member; presumably the conversion buffer for states.
00135         Sequence *input;
00136 
00137 public:
              // Constructs from a Torch function wrapper and the state properties.
00139         CTorchVFunction(CTorchFunction *torchFunction, CStateProperties *properties);
00140         virtual ~CTorchVFunction();
00141 
              // Returns the value of `state` as a double.
00143         virtual double getValue(CState *state);
00144 
00145 };
00146 
00147 
00149 
// CVFunctionFromGradientFunction: exposes a CGradientFunction through the
// CGradientVFunction and CVFunctionInputDerivationCalculator interfaces
// (both public bases). Only declarations are visible here.
00155 class CVFunctionFromGradientFunction : public CGradientVFunction, public CVFunctionInputDerivationCalculator
00156 {
00157 protected:
              // The gradient function this V-function refers to. Ownership is not
              // visible in this header.
00159         CGradientFunction *gradientFunction;
00160 
              // Vector/matrix members; judging by their names they are scratch
              // buffers for input, output error and input derivative - TODO confirm.
00161         ColumnVector *input;
00162         ColumnVector *outputError;
00163         Matrix *inputDerivation;
00164 
              // Protected weight update taking a feature list of gradient entries.
00165         virtual void updateWeights(CFeatureList *gradientFeatures);
              // Takes a state and a column vector; presumably fills `sequence` from
              // `state` - confirm in the implementation.
00166         void getInputSequence(CState *state, ColumnVector *sequence);
00167 
00168 public:
              // Constructs from a gradient function and the state properties.
00170         CVFunctionFromGradientFunction(CGradientFunction *gradientFunction, CStateProperties *properties);
00171         virtual ~CVFunctionFromGradientFunction();
00172 
00174 
              // Sets the value associated with `state` to `value` (mechanism not shown).
00175         virtual void setValue(CState *state, double value);
00177 
00178 
00180         virtual void resetData();
00181 
              // Returns the value of `state` as a double.
00183         virtual double getValue(CState *state);
00184 
00185         //virtual CStateProperties *getGradientCalculator();
00186 
              // NOTE(review): the parameter names (`originalState`, `modifiedState`)
              // look copied from another signature; the second argument is a
              // CFeatureList, presumably receiving the gradient - confirm.
00187         virtual void getGradient(CStateCollection *originalState, CFeatureList *modifiedState);
00188 
00189         virtual int getNumWeights();
00190 
              // Returns a pointer to an e-traces object.
              // NOTE(review): whether the caller owns the returned object is not
              // visible here.
00191         virtual CAbstractVETraces *getStandardETraces();
00192         
              // Declared without `virtual` in this class.
              // NOTE(review): presumably implements the
              // CVFunctionInputDerivationCalculator interface - verify the base
              // declaration.
00193         void getInputDerivation(CStateCollection *originalState, ColumnVector *targetVector);
00194 
              // Read / write the weights through a raw double array (no size passed).
00195         virtual void getWeights(double *parameters);
00196         virtual void setWeights(double *parameters); 
00197 
00198 };
00199 
// CQFunctionFromGradientFunction: a CContinuousActionQFunction that refers to
// a CGradientFunction. Only declarations are visible here.
// NOTE(review): the second base, CStateObject, has no access specifier; in a
// `class` that makes it a PRIVATE base, unlike the sibling classes in this
// header which list every base as `public`. Confirm this is intended.
00200 class CQFunctionFromGradientFunction : public CContinuousActionQFunction, CStateObject
00201 {
00202 protected:
              // The gradient function this Q-function refers to. Ownership is not
              // visible in this header.
00203         CGradientFunction *gradientFunction;
              // Column vectors; by name presumably scratch buffers - TODO confirm.
00204         ColumnVector *input;
00205         ColumnVector *outputError;
00206 
              // Action set member; presumably the `actions` constructor argument.
00207         CActionSet *staticActions;
00208 
              // Takes a column vector, a state and continuous-action data.
              // NOTE(review): presumably writes state and action into `input`; the
              // parameter name is the same as the `input` member above.
00209         void getInputSequence(ColumnVector *input, CState *state, CContinuousActionData *data);
00210         virtual void updateWeights(CFeatureList *gradientFeatures);
00211 
00212 public:
              // Constructs from a continuous action, a gradient function, an action
              // set and the state properties.
00213         CQFunctionFromGradientFunction(CContinuousAction *contAction, CGradientFunction *torchGradientFunction, CActionSet *actions, CStateProperties *properties);
00214         virtual ~CQFunctionFromGradientFunction();
00215 
              // Takes a state collection and an action-data object; presumably writes
              // the best action into `actionData` - confirm in the implementation.
00216         virtual void getBestContinuousAction(CStateCollection *state, CContinuousActionData *actionData);
00217 
              // Update / set / get the value for a (state, continuous action) pair.
              // `td` is presumably a temporal-difference error - TODO confirm.
00218         virtual void updateCAValue(CStateCollection *state, CContinuousActionData *data, double td);
00219         virtual void setCAValue(CStateCollection *state, CContinuousActionData *data, double qValue); 
00220         virtual double getCAValue(CStateCollection *state, CContinuousActionData *data);
00221 
00222 
              // Takes a feature list `gradient`, presumably as the output argument.
00223         virtual void getCAGradient(CStateCollection *state, CContinuousActionData *data, CFeatureList *gradient);
00224         virtual int getNumWeights();
00225 
              // Read / write the weights through a raw double array (no size passed).
00226         virtual void getWeights(double *parameters);
00227         virtual void setWeights(double *parameters);
00228 
00229         virtual void resetData();
00230 };
00231 
00232 
00233 
00234 #endif