diff mbox series

skip Cholesky decomposition in is>>n_mv_dist

Message ID orzhkiixqj.fsf@lxoliva.fsfla.org
State New
Headers show
Series skip Cholesky decomposition in is>>n_mv_dist | expand

Commit Message

Alexandre Oliva Aug. 9, 2019, 7:50 a.m. UTC
normal_mv_distribution maintains the variance-covariance matrix param
in Cholesky-decomposed form.  Existing param_type constructors, when
taking a full or lower-triangle varcov matrix, perform Cholesky
decomposition to convert it to the internal representation.  This
internal representation is visible both in the varcov() result, and in
the streamed-out representation of a normal_mv_distribution object.

The problem is that when that representation is streamed back in, the
read-back decomposed varcov matrix is used as a lower-triangle
non-decomposed varcov matrix, and it undergoes Cholesky decomposition
again.  So, each cycle of stream-out/stream-in changes the varcov
matrix to its "square root", instead of restoring the original
params.

This patch includes Corentin's changes that introduce verification in
testsuite/ext/random/normal_mv_distribution/operators/serialize.cc and
other similar tests that the object read back in compares equal to the
written-out object: the modified tests pass only if (u == v).

This patch also fixes the error exposed by his change, introducing an
alternate private constructor for param_type, used only by operator>>.

Tested on x86_64-linux-gnu.  Ok to install?


for  libstdc++-v3/ChangeLog

	* include/ext/random
	(normal_mv_distribution::param_type::param_type): New private
	ctor taking a decomposed varcov matrix, for use by...
	(operator>>): ... this, befriended.
	* include/ext/random.tcc (operator>>): Use it.
	(normal_mv_distribution::param_type::_M_init_lower): Adjust
	member function name in exception message.

for  libstdc++-v3/ChangeLog
from  Corentin Gay  <gay@adacore.com>

	* testsuite/ext/random/beta_distribution/operators/serialize.cc,
	testsuite/ext/random/hypergeometric_distribution/operators/serialize.cc,
	testsuite/ext/random/normal_mv_distribution/operators/serialize.cc,
	testsuite/ext/random/triangular_distribution/operators/serialize.cc,
	testsuite/ext/random/von_mises_distribution/operators/serialize.cc:
	Add call to `VERIFY`.
---
 libstdc++-v3/include/ext/random                    |   15 +++++++++++++++
 libstdc++-v3/include/ext/random.tcc                |    8 +++++---
 .../beta_distribution/operators/serialize.cc       |    2 ++
 .../operators/serialize.cc                         |    1 +
 .../normal_mv_distribution/operators/serialize.cc  |    2 ++
 .../triangular_distribution/operators/serialize.cc |    2 ++
 .../von_mises_distribution/operators/serialize.cc  |    2 ++
 7 files changed, 29 insertions(+), 3 deletions(-)

Comments

Ulrich Drepper Aug. 9, 2019, 8:20 a.m. UTC | #1
On Fri, Aug 9, 2019 at 9:50 AM Alexandre Oliva <oliva@adacore.com> wrote:

> normal_mv_distribution maintains the variance-covariance matrix param
> in Cholesky-decomposed form.  Existing param_type constructors, when
> taking a full or lower-triangle varcov matrix, perform Cholesky
> decomposition to convert it to the internal representation.  This
> internal representation is visible both in the varcov() result, and in
> the streamed-out representation of a normal_mv_distribution object.
>
> […]
>


> Tested on x86_64-linux-gnu.  Ok to install?
>

Yes.  Thanks.
Jonathan Wakely Aug. 9, 2019, 10:26 a.m. UTC | #2
On 09/08/19 10:20 +0200, Ulrich Drepper wrote:
>On Fri, Aug 9, 2019 at 9:50 AM Alexandre Oliva <oliva@adacore.com> wrote:
>
>> normal_mv_distribution maintains the variance-covariance matrix param
>> in Cholesky-decomposed form.  Existing param_type constructors, when
>> taking a full or lower-triangle varcov matrix, perform Cholesky
>> decomposition to convert it to the internal representation.  This
>> internal representation is visible both in the varcov() result, and in
>> the streamed-out representation of a normal_mv_distribution object.
>>
>> […]
>>
>
>
>> Tested on x86_64-linux-gnu.  Ok to install?
>>
>
>Yes.  Thanks.

If the operator>> is a friend it can just write straight to the array
members of the param_type object:

diff --git a/libstdc++-v3/include/ext/random.tcc b/libstdc++-v3/include/ext/random.tcc
index 31dc33a2555..77abdd9a1de 100644
--- a/libstdc++-v3/include/ext/random.tcc
+++ b/libstdc++-v3/include/ext/random.tcc
@@ -700,18 +700,15 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
       const typename __ios_base::fmtflags __flags = __is.flags();
       __is.flags(__ios_base::dec | __ios_base::skipws);

-      std::array<_RealType, _Dimen> __mean;
-      for (auto& __it : __mean)
+      typename normal_mv_distribution<_Dimen, _RealType>::param_type __param;
+      for (auto& __it : __param._M_mean)
        __is >> __it;
-      std::array<_RealType, _Dimen * (_Dimen + 1) / 2> __varcov;
-      for (auto& __it : __varcov)
+      for (auto& __it : __param._M_t)
        __is >> __it;

       __is >> __x._M_nd;

-      __x.param(typename normal_mv_distribution<_Dimen, _RealType>::
-               param_type(__mean.begin(), __mean.end(),
-                          __varcov.begin(), __varcov.end()));
+      __x.param(__param);

       __is.flags(__flags);
       return __is;


The default constructor for param_type() will pointlessly fill the
arrays that are about to be overwritten though, so maybe this isn't an
improvement.
diff mbox series

Patch

diff --git a/libstdc++-v3/include/ext/random b/libstdc++-v3/include/ext/random
index 41a2962c8f6e5..d5574e02ba02c 100644
--- a/libstdc++-v3/include/ext/random
+++ b/libstdc++-v3/include/ext/random
@@ -752,6 +752,21 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
 				_InputIterator2 __varbegin,
 				_InputIterator2 __varend);
 
+	// param_type constructors apply Cholesky decomposition to the
+	// varcov matrix in _M_init_full and _M_init_lower, but the
+	// varcov matrix output ot a stream is already decomposed, so
+	// we need means to restore it as-is when reading it back in.
+	template<size_t _Dimen1, typename _RealType1,
+		 typename _CharT, typename _Traits>
+	friend std::basic_istream<_CharT, _Traits>&
+	operator>>(std::basic_istream<_CharT, _Traits>& __is,
+		   __gnu_cxx::normal_mv_distribution<_Dimen1, _RealType1>&
+		   __x);
+	param_type(std::array<_RealType, _Dimen> const &__mean,
+		   std::array<_RealType, _M_t_size> const &__varcov)
+	  : _M_mean (__mean), _M_t (__varcov)
+	{}
+
 	std::array<_RealType, _Dimen> _M_mean;
 	std::array<_RealType, _M_t_size> _M_t;
       };
diff --git a/libstdc++-v3/include/ext/random.tcc b/libstdc++-v3/include/ext/random.tcc
index 31dc33a2555ed..a8a49a3a9fa6a 100644
--- a/libstdc++-v3/include/ext/random.tcc
+++ b/libstdc++-v3/include/ext/random.tcc
@@ -581,7 +581,7 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
 	    __sum = *__varcovbegin++ - __sum;
 	    if (__builtin_expect(__sum <= _RealType(0), 0))
 	      std::__throw_runtime_error(__N("normal_mv_distribution::"
-					     "param_type::_M_init_full"));
+					     "param_type::_M_init_lower"));
 	    *__w++ = std::sqrt(__sum);
 	  }
       }
@@ -709,9 +709,11 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
 
       __is >> __x._M_nd;
 
+      // The param_type temporary is built with a private constructor,
+      // to skip the Cholesky decomposition that would be performed
+      // otherwise.
       __x.param(typename normal_mv_distribution<_Dimen, _RealType>::
-		param_type(__mean.begin(), __mean.end(),
-			   __varcov.begin(), __varcov.end()));
+		param_type(__mean, __varcov));
 
       __is.flags(__flags);
       return __is;
diff --git a/libstdc++-v3/testsuite/ext/random/beta_distribution/operators/serialize.cc b/libstdc++-v3/testsuite/ext/random/beta_distribution/operators/serialize.cc
index b05417156d191..a4925fc1c41be 100644
--- a/libstdc++-v3/testsuite/ext/random/beta_distribution/operators/serialize.cc
+++ b/libstdc++-v3/testsuite/ext/random/beta_distribution/operators/serialize.cc
@@ -23,6 +23,7 @@ 
 
 #include <ext/random>
 #include <sstream>
+#include <testsuite_hooks.h>
 
 void
 test01()
@@ -35,6 +36,7 @@  test01()
   str << u;
 
   str >> v;
+  VERIFY( u == v );
 }
 
 int main()
diff --git a/libstdc++-v3/testsuite/ext/random/hypergeometric_distribution/operators/serialize.cc b/libstdc++-v3/testsuite/ext/random/hypergeometric_distribution/operators/serialize.cc
index 9c2cc46ac1ce0..e9077b2c58d65 100644
--- a/libstdc++-v3/testsuite/ext/random/hypergeometric_distribution/operators/serialize.cc
+++ b/libstdc++-v3/testsuite/ext/random/hypergeometric_distribution/operators/serialize.cc
@@ -38,6 +38,7 @@  test01()
   str << u;
 
   str >> v;
+  VERIFY( u == v );
 }
 
 int
diff --git a/libstdc++-v3/testsuite/ext/random/normal_mv_distribution/operators/serialize.cc b/libstdc++-v3/testsuite/ext/random/normal_mv_distribution/operators/serialize.cc
index 8d83f9e6966d2..f5fbc42a686f0 100644
--- a/libstdc++-v3/testsuite/ext/random/normal_mv_distribution/operators/serialize.cc
+++ b/libstdc++-v3/testsuite/ext/random/normal_mv_distribution/operators/serialize.cc
@@ -23,6 +23,7 @@ 
 
 #include <ext/random>
 #include <sstream>
+#include <testsuite_hooks.h>
 
 void
 test01()
@@ -35,6 +36,7 @@  test01()
   str << u;
 
   str >> v;
+  VERIFY( u == v );
 }
 
 int main()
diff --git a/libstdc++-v3/testsuite/ext/random/triangular_distribution/operators/serialize.cc b/libstdc++-v3/testsuite/ext/random/triangular_distribution/operators/serialize.cc
index cf17fea8b03ff..75e16cf0437cb 100644
--- a/libstdc++-v3/testsuite/ext/random/triangular_distribution/operators/serialize.cc
+++ b/libstdc++-v3/testsuite/ext/random/triangular_distribution/operators/serialize.cc
@@ -23,6 +23,7 @@ 
 
 #include <ext/random>
 #include <sstream>
+#include <testsuite_hooks.h>
 
 void
 test01()
@@ -35,6 +36,7 @@  test01()
   str << u;
 
   str >> v;
+  VERIFY( u == v );
 }
 
 int main()
diff --git a/libstdc++-v3/testsuite/ext/random/von_mises_distribution/operators/serialize.cc b/libstdc++-v3/testsuite/ext/random/von_mises_distribution/operators/serialize.cc
index f3d7912e314ba..b32a31dee6421 100644
--- a/libstdc++-v3/testsuite/ext/random/von_mises_distribution/operators/serialize.cc
+++ b/libstdc++-v3/testsuite/ext/random/von_mises_distribution/operators/serialize.cc
@@ -23,6 +23,7 @@ 
 
 #include <ext/random>
 #include <sstream>
+#include <testsuite_hooks.h>
 
 void
 test01()
@@ -35,6 +36,7 @@  test01()
   str << u;
 
   str >> v;
+  VERIFY( u == v );
 }
 
 int main()