diff mbox series

[_Hashtable] Fix merge

Message ID f61f9026-94b3-497b-bbe5-807054399500@gmail.com
State New
Headers show
Series [_Hashtable] Fix merge | expand

Commit Message

François Dumont Oct. 19, 2023, 4:55 a.m. UTC
libstdc++: [_Hashtable] Do not reuse untrusted cached hash code

On merge reuse merged node cached hash code only if we are on the same 
type of
hash and this hash is stateless. Usage of function pointers or 
std::function as
hash functor will prevent this optimization.

libstdc++-v3/ChangeLog

     * include/bits/hashtable_policy.h
     (_Hash_code_base::_M_hash_code(const _Hash&, const 
_Hash_node_value<>&)): Remove.
     (_Hash_code_base::_M_hash_code<_H2>(const _H2&, const 
_Hash_node_value<>&)): Remove.
     * include/bits/hashtable.h
     (_M_src_hash_code<_H2>(const _H2&, const key_type&, const 
__node_value_type&)): New.
     (_M_merge_unique<>, _M_merge_multi<>): Use latter.
     * testsuite/23_containers/unordered_map/modifiers/merge.cc
     (test04, test05, test06): New test cases.

Tested under Linux x86_64, ok to commit ?

François

Comments

Jonathan Wakely Oct. 19, 2023, 8:05 a.m. UTC | #1
On Thursday, 19 October 2023, François Dumont <frs.dumont@gmail.com> wrote:
> libstdc++: [_Hashtable] Do not reuse untrusted cached hash code
>
> On merge reuse merged node cached hash code only if we are on the same
type of
> hash and this hash is stateless. Usage of function pointers or
std::function as
> hash functor will prevent this optimization.

I found this first sentence a little hard to parse. How about:

On merge, reuse a merged node's cached hash code only if we are on the same
type of
hash and this hash is stateless.


And for the second sentence, would it be clearer to say "will prevent
reusing cached hash codes" instead of "will prevent this optimization"?


And for the comment on the new function, I think this reads better:

"Only use the node's (possibly cached) hash code if its hash function _H2
matches _Hash. Otherwise recompute it using _Hash."

The code and tests look good, so if you're happy with the comment+changelog
suggestions, this is ok for trunk.

This seems like a bug fix that should be backported too, after some time on
trunk.


>
> libstdc++-v3/ChangeLog
>
>     * include/bits/hashtable_policy.h
>     (_Hash_code_base::_M_hash_code(const _Hash&, const
_Hash_node_value<>&)): Remove.
>     (_Hash_code_base::_M_hash_code<_H2>(const _H2&, const
_Hash_node_value<>&)): Remove.
>     * include/bits/hashtable.h
>     (_M_src_hash_code<_H2>(const _H2&, const key_type&, const
__node_value_type&)): New.
>     (_M_merge_unique<>, _M_merge_multi<>): Use latter.
>     * testsuite/23_containers/unordered_map/modifiers/merge.cc
>     (test04, test05, test06): New test cases.
>
> Tested under Linux x86_64, ok to commit ?
>
> François
>
>
François Dumont Oct. 19, 2023, 5:11 p.m. UTC | #2
Committed with the advised changes.

Ok, I'll backport next week.

Thanks

On 19/10/2023 10:05, Jonathan Wakely wrote:
>
>
> On Thursday, 19 October 2023, François Dumont <frs.dumont@gmail.com> 
> wrote:
> > libstdc++: [_Hashtable] Do not reuse untrusted cached hash code
> >
> > On merge reuse merged node cached hash code only if we are on the 
> same type of
> > hash and this hash is stateless. Usage of function pointers or 
> std::function as
> > hash functor will prevent this optimization.
>
> I found this first sentence a little hard to parse. How about:
>
> On merge, reuse a merged node's cached hash code only if we are on the 
> same
> type of
> hash and this hash is stateless.
>
>
> And for the second sentence, would it be clearer to say "will prevent 
> reusing cached hash codes" instead of "will prevent this optimization"?
>
>
> And for the comment on the new function, I think this reads better:
>
> "Only use the node's (possibly cached) hash code if its hash function 
> _H2 matches _Hash. Otherwise recompute it using _Hash."
>
> The code and tests look good, so if you're happy with the 
> comment+changelog suggestions, this is ok for trunk.
>
> This seems like a bug fix that should be backported too, after some 
> time on trunk.
>
>
> >
> > libstdc++-v3/ChangeLog
> >
> >     * include/bits/hashtable_policy.h
> >     (_Hash_code_base::_M_hash_code(const _Hash&, const 
> _Hash_node_value<>&)): Remove.
> >     (_Hash_code_base::_M_hash_code<_H2>(const _H2&, const 
> _Hash_node_value<>&)): Remove.
> >     * include/bits/hashtable.h
> >     (_M_src_hash_code<_H2>(const _H2&, const key_type&, const 
> __node_value_type&)): New.
> >     (_M_merge_unique<>, _M_merge_multi<>): Use latter.
> >     * testsuite/23_containers/unordered_map/modifiers/merge.cc
> >     (test04, test05, test06): New test cases.
> >
> > Tested under Linux x86_64, ok to commit ?
> >
> > François
> >
> >
diff mbox series

Patch

diff --git a/libstdc++-v3/include/bits/hashtable.h b/libstdc++-v3/include/bits/hashtable.h
index 4c12dc895b2..f69acfe5213 100644
--- a/libstdc++-v3/include/bits/hashtable.h
+++ b/libstdc++-v3/include/bits/hashtable.h
@@ -1109,6 +1109,20 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
 	return { __n, this->_M_node_allocator() };
       }
 
+      // Check and if needed compute hash code using _Hash as __n _M_hash_code,
+      // if present, was computed using _H2.
+      template<typename _H2>
+	__hash_code
+	_M_src_hash_code(const _H2&, const key_type& __k,
+			 const __node_value_type& __src_n) const
+	{
+	  if constexpr (std::is_same_v<_H2, _Hash>)
+	    if constexpr (std::is_empty_v<_Hash>)
+	      return this->_M_hash_code(__src_n);
+
+	  return this->_M_hash_code(__k);
+	}
+
     public:
       // Extract a node.
       node_type
@@ -1146,7 +1160,7 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
 	      auto __pos = __i++;
 	      const key_type& __k = _ExtractKey{}(*__pos);
 	      __hash_code __code
-		= this->_M_hash_code(__src.hash_function(), *__pos._M_cur);
+		= _M_src_hash_code(__src.hash_function(), __k, *__pos._M_cur);
 	      size_type __bkt = _M_bucket_index(__code);
 	      if (_M_find_node(__bkt, __k, __code) == nullptr)
 		{
@@ -1174,8 +1188,9 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
 	  for (auto __i = __src.cbegin(), __end = __src.cend(); __i != __end;)
 	    {
 	      auto __pos = __i++;
+	      const key_type& __k = _ExtractKey{}(*__pos);
 	      __hash_code __code
-		= this->_M_hash_code(__src.hash_function(), *__pos._M_cur);
+		= _M_src_hash_code(__src.hash_function(), __k, *__pos._M_cur);
 	      auto __nh = __src.extract(__pos);
 	      __hint = _M_insert_multi_node(__hint, __code, __nh._M_ptr)._M_cur;
 	      __nh._M_ptr = nullptr;
diff --git a/libstdc++-v3/include/bits/hashtable_policy.h b/libstdc++-v3/include/bits/hashtable_policy.h
index 86b32fb15f2..5d162463dc3 100644
--- a/libstdc++-v3/include/bits/hashtable_policy.h
+++ b/libstdc++-v3/include/bits/hashtable_policy.h
@@ -1319,19 +1319,6 @@  namespace __detail
 	  return _M_hash()(__k);
 	}
 
-      __hash_code
-      _M_hash_code(const _Hash&,
-		   const _Hash_node_value<_Value, true>& __n) const
-      { return __n._M_hash_code; }
-
-      // Compute hash code using _Hash as __n _M_hash_code, if present, was
-      // computed using _H2.
-      template<typename _H2>
-	__hash_code
-	_M_hash_code(const _H2&,
-		const _Hash_node_value<_Value, __cache_hash_code>& __n) const
-	{ return _M_hash_code(_ExtractKey{}(__n._M_v())); }
-
       __hash_code
       _M_hash_code(const _Hash_node_value<_Value, false>& __n) const
       { return _M_hash_code(_ExtractKey{}(__n._M_v())); }
diff --git a/libstdc++-v3/testsuite/23_containers/unordered_map/modifiers/merge.cc b/libstdc++-v3/testsuite/23_containers/unordered_map/modifiers/merge.cc
index b140ce452aa..c051b58137a 100644
--- a/libstdc++-v3/testsuite/23_containers/unordered_map/modifiers/merge.cc
+++ b/libstdc++-v3/testsuite/23_containers/unordered_map/modifiers/merge.cc
@@ -17,15 +17,29 @@ 
 
 // { dg-do run { target c++17 } }
 
+#include <string>
+#include <functional>
 #include <unordered_map>
 #include <algorithm>
 #include <testsuite_hooks.h>
 
 using test_type = std::unordered_map<int, int>;
 
-struct hash {
-  auto operator()(int i) const noexcept { return ~std::hash<int>()(i); }
-};
+template<typename T>
+  struct xhash
+  {
+    auto operator()(const T& i) const noexcept
+    { return ~std::hash<T>()(i); }
+  };
+
+
+namespace std
+{
+  template<typename T>
+    struct __is_fast_hash<xhash<T>> : __is_fast_hash<std::hash<T>>
+    { };
+}
+
 struct equal : std::equal_to<> { };
 
 template<typename C1, typename C2>
@@ -64,7 +78,7 @@  test02()
 {
   const test_type c0{ {1, 10}, {2, 20}, {3, 30} };
   test_type c1 = c0;
-  std::unordered_map<int, int, hash, equal> c2( c0.begin(), c0.end() );
+  std::unordered_map<int, int, xhash<int>, equal> c2( c0.begin(), c0.end() );
 
   c1.merge(c2);
   VERIFY( c1 == c0 );
@@ -89,7 +103,7 @@  test03()
 {
   const test_type c0{ {1, 10}, {2, 20}, {3, 30} };
   test_type c1 = c0;
-  std::unordered_multimap<int, int, hash, equal> c2( c0.begin(), c0.end() );
+  std::unordered_multimap<int, int, xhash<int>, equal> c2( c0.begin(), c0.end() );
   c1.merge(c2);
   VERIFY( c1 == c0 );
   VERIFY( equal_elements(c2, c0) );
@@ -125,10 +139,164 @@  test03()
   VERIFY( c2.empty() );
 }
 
+void
+test04()
+{
+  const std::unordered_map<std::string, int> c0
+    { {"one", 10}, {"two", 20}, {"three", 30} };
+
+  std::unordered_map<std::string, int> c1 = c0;
+  std::unordered_multimap<std::string, int> c2( c0.begin(), c0.end() );
+  c1.merge(c2);
+  VERIFY( c1 == c0 );
+  VERIFY( equal_elements(c2, c0) );
+
+  c1.clear();
+  c1.merge(c2);
+  VERIFY( c1 == c0 );
+  VERIFY( c2.empty() );
+
+  c2.merge(c1);
+  VERIFY( c1.empty() );
+  VERIFY( equal_elements(c2, c0) );
+
+  c1 = c0;
+  c2.merge(c1);
+  VERIFY( c1.empty() );
+  VERIFY( c2.size() == (2 * c0.size()) );
+  VERIFY( c2.count("one") == 2 );
+  VERIFY( c2.count("two") == 2 );
+  VERIFY( c2.count("three") == 2 );
+
+  c1.merge(c2);
+  VERIFY( c1 == c0 );
+  VERIFY( equal_elements(c2, c0) );
+
+  c1.merge(std::move(c2));
+  VERIFY( c1 == c0 );
+  VERIFY( equal_elements(c2, c0) );
+
+  c1.clear();
+  c1.merge(std::move(c2));
+  VERIFY( c1 == c0 );
+  VERIFY( c2.empty() );
+}
+
+void
+test05()
+{
+  const std::unordered_map<std::string, int> c0
+    { {"one", 10}, {"two", 20}, {"three", 30} };
+
+  std::unordered_map<std::string, int> c1 = c0;
+  std::unordered_multimap<std::string, int, xhash<std::string>, equal> c2( c0.begin(), c0.end() );
+  c1.merge(c2);
+  VERIFY( c1 == c0 );
+  VERIFY( equal_elements(c2, c0) );
+
+  c1.clear();
+  c1.merge(c2);
+  VERIFY( c1 == c0 );
+  VERIFY( c2.empty() );
+
+  c2.merge(c1);
+  VERIFY( c1.empty() );
+  VERIFY( equal_elements(c2, c0) );
+
+  c1 = c0;
+  c2.merge(c1);
+  VERIFY( c1.empty() );
+  VERIFY( c2.size() == (2 * c0.size()) );
+  VERIFY( c2.count("one") == 2 );
+  VERIFY( c2.count("two") == 2 );
+  VERIFY( c2.count("three") == 2 );
+
+  c1.merge(c2);
+  VERIFY( c1 == c0 );
+  VERIFY( equal_elements(c2, c0) );
+
+  c1.merge(std::move(c2));
+  VERIFY( c1 == c0 );
+  VERIFY( equal_elements(c2, c0) );
+
+  c1.clear();
+  c1.merge(std::move(c2));
+  VERIFY( c1 == c0 );
+  VERIFY( c2.empty() );
+}
+
+template<typename T>
+  using hash_f =
+    std::function<std::size_t(const T&)>;
+
+std::size_t
+hash_func(const std::string& str)
+{ return std::hash<std::string>{}(str);  }
+
+std::size_t
+xhash_func(const std::string& str)
+{ return xhash<std::string>{}(str); }
+
+namespace std
+{
+  template<typename T>
+    struct __is_fast_hash<hash_f<T>> : __is_fast_hash<std::hash<T>>
+    { };
+}
+
+void
+test06()
+{
+  const std::unordered_map<std::string, int, hash_f<std::string>, equal>
+    c0({ {"one", 10}, {"two", 20}, {"three", 30} }, 3, &hash_func);
+
+  std::unordered_map<std::string, int, hash_f<std::string>, equal>
+    c1(3, &hash_func);
+  c1 = c0;
+  std::unordered_multimap<std::string, int, hash_f<std::string>, equal>
+    c2(c0.begin(), c0.end(), 3, &xhash_func);
+  c1.merge(c2);
+  VERIFY( c1 == c0 );
+  VERIFY( equal_elements(c2, c0) );
+
+  c1.clear();
+  c1.merge(c2);
+  VERIFY( c1 == c0 );
+  VERIFY( c2.empty() );
+
+  c2.merge(c1);
+  VERIFY( c1.empty() );
+  VERIFY( equal_elements(c2, c0) );
+
+  c1 = c0;
+  c2.merge(c1);
+  VERIFY( c1.empty() );
+  VERIFY( c2.size() == (2 * c0.size()) );
+  VERIFY( c2.count("one") == 2 );
+  VERIFY( c2.count("two") == 2 );
+  VERIFY( c2.count("three") == 2 );
+
+  c1.merge(c2);
+  VERIFY( c1 == c0 );
+  VERIFY( equal_elements(c2, c0) );
+
+  c1.merge(std::move(c2));
+  VERIFY( c1 == c0 );
+  VERIFY( equal_elements(c2, c0) );
+
+  c1.clear();
+  c1.merge(std::move(c2));
+  VERIFY( c1 == c0 );
+  VERIFY( c2.empty() );
+}
+
 int
 main()
 {
   test01();
   test02();
   test03();
+  test04();
+  test05();
+  test06();
 }