85 lines
8.4 KiB
MQL5
85 lines
8.4 KiB
MQL5
//+------------------------------------------------------------------+
|
|
//| lossfunction.mqh |
|
|
//| Copyright 2021, MetaQuotes Ltd. |
|
|
//| https://www.mql5.com |
|
|
//+------------------------------------------------------------------+
|
|
#property copyright "Copyright 2021, MetaQuotes Ltd."
|
|
#property link "https://www.mql5.com"
|
|
//+------------------------------------------------------------------+
|
|
//| Connect libraries |
|
|
//+------------------------------------------------------------------+
|
|
#include "defines.mqh"
|
|
#include "buffer.mqh"
|
|
//+------------------------------------------------------------------+
|
|
//| Class CLossFunction |
|
|
//| Purpose: Base class for working with loss functions |
|
|
//+------------------------------------------------------------------+
|
|
class CLossFunction : public CObject
  {
protected:
   //--- construction is reserved for derived classes, so a bare CLossFunction object cannot be created
   CLossFunction(void) { }
   //--- copy construction is disabled
   CLossFunction(const CLossFunction &loss) = delete;

public:
   //--- identifier of the concrete loss function; every derived class has to provide it
   virtual ENUM_LOSS_FUNCTION LossFunction(void) const = 0;

   //+------------------------------------------------------------------+
   //| Validate the buffers handed to Calculate                         |
   //| Returns true only when both pointers pass the boolean check and  |
   //| the two buffers report equal Rows() and equal Cols()             |
   //+------------------------------------------------------------------+
   bool CheckParameters(const CBufferType *calculated, const CBufferType *target) const
     {
      //--- both buffers must be supplied
      if(!(calculated && target))
         return false;
      //--- row counts must agree
      if(calculated.Rows() != target.Rows())
         return false;
      //--- column counts decide the final answer
      return(calculated.Cols() == target.Cols());
     }

   //+------------------------------------------------------------------+
   //| Compute the error between calculated and target values           |
   //| Returns FLT_MAX when CheckParameters rejects the arguments,      |
   //| otherwise the value produced by the matrix Loss() call using     |
   //| the loss type reported by LossFunction()                         |
   //+------------------------------------------------------------------+
   virtual TYPE Calculate(const CBufferType *calculated, const CBufferType *target)
     {
      //--- valid input: delegate to the matrix loss computation
      if(CheckParameters(calculated, target))
         return calculated.m_mMatrix.Loss(target.m_mMatrix, LossFunction());
      //--- invalid input: signal failure with the sentinel value
      return FLT_MAX;
     }
  };
|
|
//+------------------------------------------------------------------+
|
|
//| Mean squared error (MSE) |
|
|
//+------------------------------------------------------------------+
|
|
class CLoss_MSE final : public CLossFunction
  {
public:
   //--- report LOSS_MSE as this class's loss function type
   virtual ENUM_LOSS_FUNCTION LossFunction(void) const override { return(LOSS_MSE); }
  };
|
|
//+------------------------------------------------------------------+
|
|
//| Mean absolute error (MAE) |
|
|
//+------------------------------------------------------------------+
|
|
class CLoss_MAE final : public CLossFunction
  {
public:
   //--- report LOSS_MAE as this class's loss function type
   virtual ENUM_LOSS_FUNCTION LossFunction(void) const override { return(LOSS_MAE); }
  };
|
|
//+------------------------------------------------------------------+
|
|
//| Binary Crossentropy error (BCE) |
|
|
//+------------------------------------------------------------------+
|
|
class CLoss_BCE final : public CLossFunction
  {
public:
   //--- report LOSS_BCE as this class's loss function type
   virtual ENUM_LOSS_FUNCTION LossFunction(void) const override { return(LOSS_BCE); }
  };
|
|
//+------------------------------------------------------------------+
|