00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032 #ifndef C_VPOLICYFUNCTIONLEARNER__H
00033 #define C_VPOLICYFUNCTIONLEARNER__H
00034
00035
00036 #include "cagentlistener.h"
00037 #include "newmat/newmat.h"
00038
00039
00040
00041
// Forward declarations of the project types that CVPolicyLearner refers to.
// Every use of these types in this header is through a pointer, so the
// incomplete (forward-declared) type is sufficient here.
00042 class CFeatureList;
00043 class CAbstractVFunction;
00044 class CVFunctionInputDerivationCalculator;
00045 class CContinuousActionGradientPolicy;
00046 class CCAGradientPolicyInputDerivationCalculator;
00047 class CContinuousActionData;
// NOTE(review): CVPolicyLearner declares a protected typedef that is also
// named CStateGradient; inside that class the typedef hides this file-scope
// class, so no declaration visible in this header refers to this one.
// Confirm whether other headers rely on it before removing it.
00048 class CStateGradient;
00049 class CStateReward;
00050 class CTransitionFunction;
00051 class CTransitionFunctionInputDerivationCalculator;
00052 class CStateCollectionImpl;
// Repeats the forward declaration of CFeatureList made at the top of this
// group; redundant but harmless.
00053 class CFeatureList;
00054 class CStateModifier;
00055 class CStateCollection;
00056
00057
// CVPolicyLearner: a CSemiMDPRewardListener that keeps pointers to a value
// function, a continuous-action gradient policy, a reward function and a
// transition function, each paired with an input-derivation calculator, plus
// per-step histories of states, reward derivatives and actions.
//
// NOTE(review): judging by the member and method names, the class computes a
// policy gradient from derivatives of the V-function, reward and model over a
// window of past steps -- confirm against the implementation file.
// NOTE(review): std::list is used below, but this header does not include
// <list> itself; presumably it arrives through cagentlistener.h -- confirm.
// NOTE(review): all data members are raw pointers and a destructor is
// declared, but no copy constructor or copy assignment operator is declared
// here; check the implementation for ownership before copying instances.
00058 class CVPolicyLearner : public CSemiMDPRewardListener
00059 {
00060 protected:
// List of feature-list pointers. Inside this class the name CStateGradient
// refers to this typedef, not to the file-scope class of the same name.
00061 typedef std::list<CFeatureList *> CStateGradient;
00062
// Value function and the calculator for its derivative with respect to its
// input (both set up outside this header; see the constructor parameters).
00063 CAbstractVFunction *vFunction;
00064 CVFunctionInputDerivationCalculator *vFunctionInputDerivation;
00065
// Gradient policy and the calculator for its derivative with respect to its
// input.
00066 CContinuousActionGradientPolicy *gradientPolicy;
00067 CCAGradientPolicyInputDerivationCalculator *policydInput;
00068
// Work vectors and matrices; ColumnVector and Matrix presumably come from
// newmat/newmat.h. Dimensions are not visible in this header.
00069 ColumnVector *dReward;
00070 ColumnVector *dVFunction;
00071 Matrix *dPolicy;
00072 Matrix *dModelInput;
00073
// Continuous action data object; presumably scratch storage -- TODO confirm.
00074 CContinuousActionData *data;
00075
// List of pointers to CStateGradient lists (the typedef above).
00076 std::list<CStateGradient *> *stateGradients;
00077
// Individual CStateGradient lists; how they relate to stateGradients is
// defined in the implementation file.
00078 CStateGradient *stateGradient1;
00079 CStateGradient *stateGradient2;
00080 CStateGradient *dModelGradient;
00081
// Reward function, transition (dynamics) function and the calculator for the
// transition function's derivative with respect to its input.
00082 CStateReward *rewardFunction;
00083 CTransitionFunction *dynModel;
00084 CTransitionFunctionInputDerivationCalculator *dynModeldInput;
00085
00086
// State collection; by its name a temporary working collection -- TODO confirm.
00087 CStateCollectionImpl *tempStateCol;
00088
// Feature list member named policyGradient. Note that calculateGradient()
// also takes a parameter with this name, which hides this member inside that
// function.
00089 CFeatureList *policyGradient;
00090
// Helper taking two CStateGradient lists, the current state collection and an
// action data object. The parameters stateGradient1/stateGradient2 and data
// hide the members of the same names. Which list is input and which is output
// cannot be told from this declaration.
00091 void getDNextState(CStateGradient *stateGradient1, CStateGradient *stateGradient2, CStateCollection *currentState, CContinuousActionData *data);
// Helper taking a matrix, one feature list, an integer index and a list of
// feature-list pointers (newFeatures). Presumably writes the product into
// newFeatures -- confirm against the implementation.
00092 void multMatrixFeatureList(Matrix *matrix, CFeatureList *features, int index, std::list<CFeatureList *> *newFeatures);
00093
// Lists of past states, past reward-derivative vectors and past actions.
00094 std::list<CStateCollectionImpl *> *pastStates;
00095 std::list<ColumnVector *> *pastDRewards;
00096 std::list<CContinuousActionData *> *pastActions;
00097
// Lists with the same element types as the three "past" lists above;
// by their names they look like pools of reusable objects -- TODO confirm.
00098 std::list<CStateCollectionImpl *> *statesResource;
00099 std::list<ColumnVector *> *rewardsResource;
00100 std::list<CContinuousActionData *> *actionsResource;
00101
// List of state modifiers, passed in through the constructor.
00102 std::list<CStateModifier *> *stateModifiers;
00103
00104
00105 public:
// Constructor. Takes the reward function, the transition function and its
// input-derivation calculator, the V-function and its input-derivation
// calculator, the gradient policy and its input-derivation calculator, the
// list of state modifiers, and nForwardView.
// NOTE(review): no member named nForwardView is declared in this class;
// how the value is used is only visible in the implementation file.
00106 CVPolicyLearner(CStateReward *rewardFunction, CTransitionFunction *dynModel, CTransitionFunctionInputDerivationCalculator *dynModeldInput,CAbstractVFunction *vFunction, CVFunctionInputDerivationCalculator *vFunctionInputDerivation, CContinuousActionGradientPolicy *gradientPolicy, CCAGradientPolicyInputDerivationCalculator *policydInput, std::list<CStateModifier *> *stateModifiers, int nForwardView);
// Virtual destructor; which of the pointed-to objects it releases is not
// visible from this header.
00107 virtual ~CVPolicyLearner();
00108
// Per-step callback receiving the old state, the action, the scalar reward
// and the next state. Presumably overrides the corresponding virtual of
// CSemiMDPRewardListener -- confirm in cagentlistener.h.
00109 virtual void nextStep(CStateCollection *oldState, CAction *action, double reward, CStateCollection *nextState);
00110
// Episode-start callback; takes no arguments.
00111 virtual void newEpisode();
00112
// Takes lists of states, reward-derivative vectors and action data, plus a
// feature list named policyGradient (which hides the member of that name).
// Presumably the result is written into policyGradient -- TODO confirm.
00113 void calculateGradient(std::list<CStateCollectionImpl *> *states, std::list<ColumnVector *> *Drewards, std::list<CContinuousActionData *> *actionDatas, CFeatureList *policyGradient);
00114
00115 };
00116
00117 #endif
00118