NeuroNetworksBook/Include/realization/lossfunction.mqh

125 lines
12 KiB
MQL5
Raw Permalink Normal View History

2025-05-30 16:12:34 +02:00
//+------------------------------------------------------------------+
//| lossfunction.mqh |
//| Copyright 2021, MetaQuotes Ltd. |
//| https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2021, MetaQuotes Ltd."
#property link "https://www.mql5.com"
//+------------------------------------------------------------------+
//| Include libraries                                                |
//+------------------------------------------------------------------+
#include "defines.mqh"
#include "bufferdouble.mqh"
//+------------------------------------------------------------------+
//| Class CLossFunction |
//| Purpose: base class for working with loss functions              |
//+------------------------------------------------------------------+
class CLossFunction : public CObject
  {
protected:
   //--- The constructor is protected, so CLossFunction itself cannot be
   //--- instantiated; the class is only meant to be inherited from.
                     CLossFunction(void) { }
   //--- copying is prohibited
                     CLossFunction(const CLossFunction &loss)=delete;

public:
   //--- identifier of the loss function implemented by a concrete subclass
   virtual ENUM_LOSS_FUNCTION LossFunction(void) const=0;
   //--- loss value for the given calculated (predicted) and target buffers
   virtual double    Calculate(const CBufferDouble *calculated,const CBufferDouble *target)=0;
   //+------------------------------------------------------------------+
   //| Helper for Calculate implementations: validates the arguments.   |
   //| Returns true only when both pointers are set and both buffers    |
   //| have the same number of rows and the same number of columns.     |
   //+------------------------------------------------------------------+
   bool              CheckParameters(const CBufferDouble *calculated,const CBufferDouble *target) const
     {
      //--- both buffers must be supplied
      if(!calculated || !target)
         return(false);
      //--- row counts must agree
      if(calculated.Rows()!=target.Rows())
         return(false);
      //--- and so must column counts
      return(calculated.Cols()==target.Cols());
     }
  };
//+------------------------------------------------------------------+
//| Mean squared error (MSE) |
//+------------------------------------------------------------------+
class CLoss_MSE final : public CLossFunction
  {
public:
   //+------------------------------------------------------------------+
   //| Returns the identifier of this loss function (LOSS_MSE)          |
   //+------------------------------------------------------------------+
   virtual ENUM_LOSS_FUNCTION LossFunction(void) const override
     {
      return LOSS_MSE;
     }
   //+------------------------------------------------------------------+
   //| Mean squared error: the mean, over all elements, of              |
   //| (calculated - target)^2.                                         |
   //| Returns DBL_MAX when the arguments fail CheckParameters or when  |
   //| the buffers contain no elements (the mean is undefined there).   |
   //+------------------------------------------------------------------+
   virtual double Calculate(const CBufferDouble *calculated,const CBufferDouble *target) override
     {
      //--- check the parameters
      if(!CheckParameters(calculated,target))
         return DBL_MAX;
      //--- element-wise difference between prediction and target
      MATRIX m=calculated.m_mMatrix - target.m_mMatrix;
      ulong total=m.Rows()*m.Cols();
      //--- empty buffers: avoid dividing by zero below
      if(total==0)
         return DBL_MAX;
      //--- square every element
      m=m*m;
      //--- average over all elements
      return m.Sum()/(double)total;
     }
  };
//+------------------------------------------------------------------+
//| Mean absolute error (MAE)                                        |
//+------------------------------------------------------------------+
class CLoss_MAD final : public CLossFunction
  {
public:
   //+------------------------------------------------------------------+
   //| Returns the identifier of this loss function (LOSS_MAE)          |
   //+------------------------------------------------------------------+
   virtual ENUM_LOSS_FUNCTION LossFunction(void) const override
     {
      return LOSS_MAE;
     }
   //+------------------------------------------------------------------+
   //| Mean absolute error: the mean, over all elements, of             |
   //| |calculated - target|.                                           |
   //| Returns DBL_MAX when the arguments fail CheckParameters or when  |
   //| the buffers contain no elements (the mean is undefined there).   |
   //+------------------------------------------------------------------+
   virtual double Calculate(const CBufferDouble *calculated,const CBufferDouble *target) override
     {
      //--- check the parameters
      if(!CheckParameters(calculated,target))
         return DBL_MAX;
      //--- element-wise difference between prediction and target
      MATRIX m=calculated.m_mMatrix - target.m_mMatrix;
      ulong rows=m.Rows();
      ulong cols=m.Cols();
      //--- empty buffers: avoid dividing by zero below
      if(rows==0 || cols==0)
         return DBL_MAX;
      //--- replace every element with its absolute value
      for(ulong r=0; r<rows; r++)
         for(ulong c=0; c<cols; c++)
            m[r, c]=MathAbs(m[r, c]);
      //--- average over all elements
      return m.Sum()/(double)(rows*cols);
     }
  };
//+------------------------------------------------------------------+
//| Binary cross-entropy (BCE)                                       |
//+------------------------------------------------------------------+
class CLoss_BCE final : public CLossFunction
  {
public:
   //+------------------------------------------------------------------+
   //| Returns the identifier of this loss function (LOSS_BCE)          |
   //+------------------------------------------------------------------+
   virtual ENUM_LOSS_FUNCTION LossFunction(void) const override
     {
      return LOSS_BCE;
     }
   //+------------------------------------------------------------------+
   //| Binary cross-entropy, averaged over all elements:                |
   //|   -mean( t*log(p) + (1-t)*log(1-p) )                             |
   //| where p is the calculated value and t the target value.          |
   //| p is clipped to [DBL_EPSILON, 1-DBL_EPSILON] so MathLog never    |
   //| receives 0 (which would give -inf, and NaN when multiplied by 0).|
   //| Returns DBL_MAX when the arguments fail CheckParameters or when  |
   //| the buffers contain no elements.                                 |
   //+------------------------------------------------------------------+
   virtual double Calculate(const CBufferDouble *calculated,const CBufferDouble *target) override
     {
      //--- check the parameters
      if(!CheckParameters(calculated,target))
         return DBL_MAX;
      ulong rows=calculated.m_mMatrix.Rows();
      ulong cols=calculated.m_mMatrix.Cols();
      //--- empty buffers: avoid dividing by zero below
      if(rows==0 || cols==0)
         return DBL_MAX;
      //--- clipping bound that keeps both log arguments strictly positive
      const double eps=DBL_EPSILON;
      double result=0;
      for(ulong r=0; r<rows; r++)
         for(ulong c=0; c<cols; c++)
           {
            double p=MathMin(MathMax(calculated.m_mMatrix[r, c],eps),1.0-eps);
            double t=target.m_mMatrix[r, c];
            //--- both the positive (t) and the negative (1-t) terms contribute
            result-=t*MathLog(p)+(1.0-t)*MathLog(1.0-p);
           }
      //--- average over all elements, consistent with the MSE/MAE losses
      return result/(double)(rows*cols);
     }
  };
//+------------------------------------------------------------------+