120 lines
9.6 KiB
MQL5
120 lines
9.6 KiB
MQL5
//+------------------------------------------------------------------+
|
|
//| lossfunction.mqh |
|
|
//| Copyright 2021, MetaQuotes Ltd. |
|
|
//| https://www.mql5.com |
|
|
//+------------------------------------------------------------------+
|
|
#property copyright "Copyright 2021, MetaQuotes Ltd."
|
|
#property link "https://www.mql5.com"
|
|
//+------------------------------------------------------------------+
|
|
//| Подключаем библиотеки |
|
|
//+------------------------------------------------------------------+
|
|
#include "defines.mqh"
|
|
#include "bufferdouble.mqh"
|
|
//+------------------------------------------------------------------+
|
|
//| Class CLossFunction |
|
|
//| Назначение: Класс работы с функциями потерь |
|
|
//+------------------------------------------------------------------+
|
|
class CLossFunction : public CObject
|
|
{
|
|
protected:
|
|
ENUM_LOSS_FUNCTION eFunction;
|
|
|
|
virtual double MAD(CArrayDouble *calculated, CArrayDouble *target);
|
|
virtual double MSE(CArrayDouble *calculated, CArrayDouble *target);
|
|
virtual double LogLoss(CArrayDouble *calculated, CArrayDouble *target);
|
|
|
|
public:
|
|
CLossFunction(void);
|
|
~CLossFunction(void) {};
|
|
//---
|
|
void LossFunction(ENUM_LOSS_FUNCTION function) { eFunction = function; }
|
|
ENUM_LOSS_FUNCTION LossFunction(void) { return(eFunction); }
|
|
//---
|
|
virtual double CaclFunction(CArrayDouble *calculated, CArrayDouble *target);
|
|
};
|
|
//+------------------------------------------------------------------+
|
|
//| Конструктор |
|
|
//+------------------------------------------------------------------+
|
|
CLossFunction::CLossFunction(void) : eFunction(ENUM_LOSS_LogLoss)
|
|
{}
|
|
//+------------------------------------------------------------------+
|
|
//| Диспетчерский метод определения значения функции потерь |
|
|
//+------------------------------------------------------------------+
|
|
double CLossFunction::CaclFunction(CArrayDouble *calculated, CArrayDouble *target)
|
|
{
|
|
double result = DBL_MAX;
|
|
//---
|
|
switch(eFunction)
|
|
{
|
|
case ENUM_LOSS_MAD:
|
|
result = MAD(calculated, target);
|
|
break;
|
|
case ENUM_LOSS_MSE:
|
|
result = MSE(calculated, target);
|
|
break;
|
|
case ENUM_LOSS_LogLoss:
|
|
result = LogLoss(calculated, target);
|
|
break;
|
|
default:
|
|
break;
|
|
}
|
|
//---
|
|
return result;
|
|
}
|
|
//+------------------------------------------------------------------+
|
|
//| Метод определения абсолютного среднего отклонения |
|
|
//+------------------------------------------------------------------+
|
|
double CLossFunction::MAD(CArrayDouble *calculated, CArrayDouble *target)
|
|
{
|
|
double result = DBL_MAX;
|
|
//---
|
|
if(CheckPointer(calculated) == POINTER_INVALID || CheckPointer(target) == POINTER_INVALID ||
|
|
calculated.Total() > target.Total())
|
|
return result;
|
|
//---
|
|
result = 0;
|
|
int total = calculated.Total();
|
|
for(int i = 0; i < total; i++)
|
|
result += MathAbs(calculated.At(i) - target.At(i));
|
|
result /= total;
|
|
//---
|
|
return result;
|
|
}
|
|
//+------------------------------------------------------------------+
|
|
//| Метод определения среднеквадратического отклонения |
|
|
//+------------------------------------------------------------------+
|
|
double CLossFunction::MSE(CArrayDouble *calculated, CArrayDouble *target)
|
|
{
|
|
double result = DBL_MAX;
|
|
//---
|
|
if(CheckPointer(calculated) == POINTER_INVALID || CheckPointer(target) == POINTER_INVALID ||
|
|
calculated.Total() > target.Total())
|
|
return result;
|
|
//---
|
|
result = 0;
|
|
int total = calculated.Total();
|
|
for(int i = 0; i < total; i++)
|
|
result += MathPow(calculated.At(i) - target.At(i), 2);
|
|
result /= total;
|
|
//---
|
|
return result;
|
|
}
|
|
//+------------------------------------------------------------------+
|
|
//| Кросс-энтропия |
|
|
//+------------------------------------------------------------------+
|
|
double CLossFunction::LogLoss(CArrayDouble *calculated, CArrayDouble *target)
|
|
{
|
|
double result = DBL_MAX;
|
|
//---
|
|
if(CheckPointer(calculated) == POINTER_INVALID || CheckPointer(target) == POINTER_INVALID ||
|
|
calculated.Total() > target.Total())
|
|
return result;
|
|
//---
|
|
result = 0;
|
|
int total = target.Total();
|
|
for(int i = 0; i < total; i++)
|
|
result -= target.At(i) * MathLog(calculated.At(i));
|
|
//---
|
|
return result;
|
|
}
|
|
//+------------------------------------------------------------------+
|