// vi:set ft=cpp: -*- Mode: C++ -*-
/*
 * Copyright (C) 2025 Kernkonzept GmbH.
 * Author(s): Martin Decky <martin.decky@kernkonzept.com>
 *
 * License: see LICENSE.spdx (in this directory or the directories above)
 */

/**
 * \file
 * \brief Splay tree
 *
 * A splay tree is a self-adjusting binary search tree suited to workloads
 * whose access pattern does not follow a uniform random distribution. In
 * practice, individual operations have a slightly higher overhead than in
 * statically balanced binary search trees, but the most frequently accessed
 * nodes stay close to the root of the tree.
 */

#pragma once

#include "std_ops"
#include "pair"

#include "bits/bst.h"
#include "bits/bst_iter.h"

namespace cxx {

/**
 * Node of a splay tree.
 */
class Splay_tree_node : public Bits::Bst_node
{
private:
  template<typename Node, typename Get_key, typename Compare>
  friend class Splay_tree;

  /// Alias for Direction.
  typedef Bits::Direction Dir;

  // We are a final BST node, hide interior.
  /**@{*/
  using Bits::Bst_node::next;
  using Bits::Bst_node::next_p;
  using Bits::Bst_node::rotate;
  /**@}*/

protected:
  /// Create an uninitialized node, this is what you should do.
  Splay_tree_node() = default;

private:
  Splay_tree_node(Splay_tree_node const &) = delete;
  Splay_tree_node(Splay_tree_node &&) = delete;

  /// Default copy assignment for friend Splay_tree.
  Splay_tree_node &operator = (Splay_tree_node const &) = default;

  /// Default move assignment for friend Splay_tree.
  Splay_tree_node &operator = (Splay_tree_node &&) = default;

  /// Create an initialized node (for internal stuff).
  explicit Splay_tree_node(bool) : Bits::Bst_node(true) {}
};

/**
 * Generic splay tree.
 *
 * This implementation does not provide any memory management. It is the
 * responsibility of the caller to allocate nodes before inserting them and
 * to free them when they are removed or when the tree is destroyed.
 *
 * Conversely, the caller must also ensure that nodes are removed from the tree
 * before they are destroyed.
 *
 * Lookups through find_node(), as well as insert() and remove(), restructure
 * (splay) the tree, so even plain lookups are non-const operations. Use
 * find_node_const() for a lookup that leaves the tree untouched.
 *
 * \tparam Node    The data type of the nodes (must inherit from Splay_tree_node).
 * \tparam Get_key The meta function to get the key value from a node.
 *                 The implementation uses `Get_key::key_of(ptr_to_node)`. The
 *                 type of the key values must be defined in `Get_key::Key_type`.
 * \tparam Compare Binary relation to establish a total order for the
 *                 nodes of the tree. `Compare()(l, r)` must return true if
 *                 the key \a l is smaller than the key \a r.
 */
template<typename Node, typename Get_key,
         typename Compare = Lt_functor<typename Get_key::Key_type>>
class Splay_tree : public Bits::Bst<Node, Get_key, Compare>
{
private:
  typedef Bits::Bst<Node, Get_key, Compare> Bst;

  /// Hide this from possible descendants.
  using Bst::_head;

  /// Provide access to keys of nodes.
  using Bst::k;

  /// Alias type for Direction values.
  typedef typename Splay_tree_node::Dir Dir;

  // The tree links intrusive nodes it does not own: no copying or moving.
  Splay_tree(Splay_tree const &) = delete;
  Splay_tree &operator = (Splay_tree const &) = delete;

  Splay_tree(Splay_tree &&) = delete;
  Splay_tree &operator = (Splay_tree &&) = delete;

public:
  ///@{
  typedef Node Node_type;
  typedef typename Bst::Key_type Key_type;
  typedef typename Bst::Key_param_type Key_param_type;
  ///@}

  // Grab iterator types from Bst.
  ///@{
  /// Forward iterator for the tree.
  typedef typename Bst::Iterator Iterator;

  /// Constant forward iterator for the tree.
  typedef typename Bst::Const_iterator Const_iterator;

  /// Backward iterator for the tree.
  typedef typename Bst::Rev_iterator Rev_iterator;

  /// Constant backward iterator for the tree.
  typedef typename Bst::Const_rev_iterator Const_rev_iterator;
  ///@}

  /**
   * Find the node with \a key.
   *
   * \note Due to the nature of splay trees, searching for an element rebuilds
   *       (splays) the tree. Therefore this is not a const method.
   *
   * \param key  Key of the element to search.
   *
   * \return Pointer to the node with the given \a key, or nullptr if \a key
   *         was not found.
   */
  Node *find_node(Key_param_type key);

  /**
   * Find node with \a key (constant variant).
   *
   * Since a splay tree is a binary search tree, we can have a const method
   * variant of the find operation that uses the generic BST find algorithm
   * for cases where it might be useful (a read-only object instance, a
   * likelihood of a negative search result which should not splay the tree,
   * etc.). Unlike find_node(), this method does not modify the tree.
   *
   * \note Splaying the tree on lookup is actually essential for achieving
   *       optimal amortized time complexity of the data structure. Thus
   *       this const method should be used only when strictly necessary.
   *
   * \param key  Key of the element to search.
   *
   * \return Pointer to the node with the given \a key, or nullptr if \a key
   *         was not found.
   */
  Node *find_node_const(Key_param_type key) const
  {
    // Plain (non-splaying) BST lookup. The result must be returned; falling
    // off the end of a value-returning function is undefined behavior.
    return Bst::find_node(key);
  }

  /**
   * Insert new node into the splay tree.
   *
   * On success the new node becomes the root of the tree. On a duplicate the
   * existing node with the same key is splayed to the root instead.
   *
   * \param new_node  Pointer to the new node.
   *
   * \return Pair, with \a second set to `true` and \a first pointing to
   *         \a new_node, on success. If there is already a node with the same
   *         key, then \a first points to this node and \a second is `false`.
   */
  Pair<Node *, bool> insert(Node *new_node);

  /**
   * Remove the node with \a key from the splay tree.
   *
   * \param key  Key of the node to remove.
   *
   * \return Pointer to the removed node on success, or nullptr if no node with
   *         the \a key exists.
   */
  Node *remove(Key_param_type key);

  /**
   * Alias for remove().
   */
  Node *erase(Key_param_type key) { return remove(key); }

  /// Create an empty splay tree.
  Splay_tree() = default;

  /**
   * Destroy the tree.
   *
   * Calls remove_all() with a no-op callback. The nodes themselves are not
   * freed, because the tree does no memory management.
   */
  ~Splay_tree() noexcept
  { this->remove_all([](Node *){}); }

private:
  /**
   * Splay tree around the specified key.
   *
   * Top-down splay algorithm that rebuilds the tree around the specified key.
   * In case the tree contains the node with the given key, it will become the
   * new root of the tree.
   *
   * \param[in,out] root  Reference to the root node of the tree to splay. Note
   *                      that this is an output argument because the tree is
   *                      rebuilt. The root node is assumed to be non-null.
   * \param[in]     key   Key to splay the tree around.
   *
   * \retval true   The tree contains the node with the given key and it has
   *                become the root of the tree.
   * \retval false  The tree does not contain the node with the given key. The
   *                tree has been rebuilt regardless.
   */
  bool splay(Bits::Bst_node *&root, Key_param_type key);
};

//----------------------------------------------------------------------------
/* Implementation of splay tree */

template<typename Node, typename Get_key, class Compare>
bool
Splay_tree<Node, Get_key, Compare>::splay(Bits::Bst_node *&root,
                                          Key_param_type key)
{
  /*
   * Top-down splaying: while descending towards the key, the nodes passed on
   * the search path are moved into two auxiliary trees.
   *
   * - The "small" tree collects nodes with keys smaller than the key. A new
   *   node is always attached as the right child of its current maximum
   *   (small_max).
   * - The "large" tree collects nodes with keys larger than the key. A new
   *   node is always attached as the left child of its current minimum
   *   (large_min).
   *
   * The right link of small_max and the left link of large_min may still
   * point into the part of the tree that is being searched. Those links are
   * overwritten when the next node is attached or during the final
   * reassembly below.
   */
  Bits::Bst_node *small_top = nullptr;
  Bits::Bst_node *small_max = nullptr;
  Bits::Bst_node *large_top = nullptr;
  Bits::Bst_node *large_min = nullptr;

  bool hit = false;

  for (;;)
    {
      auto const where = Bst::dir(key, root);

      if (where == Dir::N)
        {
          // The current root holds exactly the requested key.
          hit = true;
          break;
        }

      if (where == Dir::L)
        {
          // key < root: continue in the left subtree.
          auto child = Splay_tree_node::next(root, Dir::L);
          if (!child)
            break; // Nothing further left: root is closest to the key.

          if (Bst::dir(key, child) == Dir::L)
            {
              /*
               * key < child as well (zig-zig): rotate right around root so
               * that child takes its place.
               *
               *         root              child
               *         /  \              /   \
               *    child    B    -->     A     root
               *    /   \                       /  \
               *   A     C                     C    B
               */
              Splay_tree_node::next(root, Dir::L,
                                    Splay_tree_node::next(child, Dir::R));
              Splay_tree_node::next(child, Dir::R, root);
              root = child;

              if (!Splay_tree_node::next(root, Dir::L))
                break; // Nothing left of the new root: it is closest.
            }

          // Attach root to the large tree as its new minimum.
          if (large_top)
            Splay_tree_node::next(large_min, Dir::L, root);
          else
            large_top = root;
          large_min = root;

          root = Splay_tree_node::next(root, Dir::L);
        }
      else
        {
          // key > root: continue in the right subtree.
          auto child = Splay_tree_node::next(root, Dir::R);
          if (!child)
            break; // Nothing further right: root is closest to the key.

          if (Bst::dir(key, child) == Dir::R)
            {
              /*
               * key > child as well (zig-zig): rotate left around root so
               * that child takes its place.
               *
               *   root                    child
               *   /  \                    /   \
               *  B    child     -->   root     A
               *       /   \           /  \
               *      C     A         B    C
               */
              Splay_tree_node::next(root, Dir::R,
                                    Splay_tree_node::next(child, Dir::L));
              Splay_tree_node::next(child, Dir::L, root);
              root = child;

              if (!Splay_tree_node::next(root, Dir::R))
                break; // Nothing right of the new root: it is closest.
            }

          // Attach root to the small tree as its new maximum.
          if (small_top)
            Splay_tree_node::next(small_max, Dir::R, root);
          else
            small_top = root;
          small_max = root;

          root = Splay_tree_node::next(root, Dir::R);
        }
    }

  /*
   * Reassemble: the remaining children of root are moved below the ends of
   * the auxiliary trees, and the auxiliary trees become the new children of
   * root.
   *
   *   root                 root
   *   /  \   -->           /  \
   *  A    B       small_top    large_top
   *                      .      .
   *               small_max    large_min
   *                       \    /
   *                        A  B
   */
  if (small_top)
    {
      Splay_tree_node::next(small_max, Dir::R,
                            Splay_tree_node::next(root, Dir::L));
      Splay_tree_node::next(root, Dir::L, small_top);
    }

  if (large_top)
    {
      Splay_tree_node::next(large_min, Dir::L,
                            Splay_tree_node::next(root, Dir::R));
      Splay_tree_node::next(root, Dir::R, large_top);
    }

  return hit;
}

template<typename Node, typename Get_key, class Compare>
inline
Node *Splay_tree<Node, Get_key, Compare>::find_node(Key_param_type key)
{
  /*
   * splay() must not be called on an empty tree, hence the short-circuit.
   * When the key is present, splaying moves its node to the root, so the
   * match is simply the new head of the tree.
   */
  bool const present = _head && splay(_head, key);
  return present ? Bst::head() : nullptr;
}

template<typename Node, typename Get_key, class Compare>
Pair<Node *, bool>
Splay_tree<Node, Get_key, Compare>::insert(Node *new_node)
{
  // Reset the splay-tree link fields of the node before it joins the tree.
  *static_cast<Splay_tree_node *>(new_node) = Splay_tree_node(true);

  if (_head)
    {
      auto new_key = Get_key::key_of(new_node);

      /*
       * Splaying brings an existing node with the same key to the root. In
       * that case the new node is not inserted (no duplicates).
       */
      if (splay(_head, new_key))
        return pair(Bst::head(), false);

      /*
       * Otherwise the root is the node closest to new_key, and everything in
       * the root's subtree on new_key's side lies on the same side of
       * new_key. The new node therefore adopts that subtree and takes the
       * old root as its other child.
       */
      if (Bst::dir(new_key, _head) == Dir::L)
        {
          //     _head            new_node
          //     /   \            /      \
          // left     B  -->  left        _head
          //                                  \
          //                                   B
          Splay_tree_node::next(new_node, Dir::L,
                                Splay_tree_node::next(_head, Dir::L));
          Splay_tree_node::next(_head, Dir::L, nullptr);
          Splay_tree_node::next(new_node, Dir::R, _head);
        }
      else
        {
          //   _head                  new_node
          //   /   \                  /      \
          //  B     right  -->   _head        right
          //                     /
          //                    B
          Splay_tree_node::next(new_node, Dir::R,
                                Splay_tree_node::next(_head, Dir::R));
          Splay_tree_node::next(_head, Dir::R, nullptr);
          Splay_tree_node::next(new_node, Dir::L, _head);
        }
    }

  // Either the tree was empty or the new node was spliced in above the old
  // root: in both cases it becomes the new root.
  _head = new_node;
  return pair(new_node, true);
}

template<typename Node, typename Get_key, class Compare>
inline
Node *Splay_tree<Node, Get_key, Compare>::remove(Key_param_type key)
{
  /*
   * Nothing to remove from an empty tree. Otherwise, splay around the key;
   * without an exact match there is no node to remove either.
   */
  if (!_head || !splay(_head, key))
    return nullptr;

  // The matching node is now the root; remember it for the caller.
  auto removed = Bst::head();

  auto lower = Splay_tree_node::next(_head, Dir::L);
  auto upper = Splay_tree_node::next(_head, Dir::R);

  if (!lower)
    {
      // No smaller keys: the right subtree alone forms the new tree.
      _head = upper;
      return removed;
    }

  /*
   * Every key in the left subtree is smaller than the removed key, so
   * splaying that subtree around the key moves its maximum to the top. The
   * maximum has no right child, so the right subtree can be attached there.
   *
   *     _head                     max_left
   *     /   \       -->           /      \
   * left     right       rest_left        right
   */
  splay(lower, key);
  Splay_tree_node::next(lower, Dir::R, upper);
  _head = lower;

  return removed;
}

}
