// vi:set ft=cpp: -*- Mode: C++ -*-
/*
 * (c) 2014 Alexander Warg <alexander.warg@kernkonzept.com>
 *
 * License: see LICENSE.spdx (in this directory or the directories above)
 */
#pragma once
#pragma GCC system_header

#include <l4/sys/cxx/ipc_basics>
#include <l4/sys/cxx/ipc_iface>
#include <l4/sys/__typeinfo.h>
#include <stddef.h>

namespace L4 {
namespace Ipc {
namespace Msg {
namespace Detail {

/**
 * Compile-time size of a type in bytes.
 *
 * `size` is `sizeof(T)`. A dedicated specialization covers `void`, for
 * which `sizeof` cannot be applied.
 */
template<typename T>
struct Sizeof
{
  enum { size = sizeof(T) };
};

/**
 * Size of `void`, defined as 0 bytes.
 */
template<>
struct Sizeof<void>
{
  enum { size = 0 };
};

/**
 * Argument data structure for server-function arguments.
 *
 * This primary template is the variant without per-argument handling: it
 * stores nothing, leaves message offsets untouched, and finally performs
 * the call with all arguments that were handed to it.
 */
template<typename ...> struct Arg_pack
{
  /**
   * Nothing to read from the message.
   *
   * \param offset  Current position in the message.
   *
   * \return `offset`, unchanged.
   */
  template<typename DIR>
  unsigned get(char *, unsigned offset, unsigned)
  {
    return offset;
  }

  /**
   * Nothing to write to the message.
   *
   * \param offset  Current position in the message.
   *
   * \return `offset`, unchanged.
   */
  template<typename DIR>
  unsigned set(char *, unsigned offset, unsigned, long)
  {
    return offset;
  }

  /**
   * Invoke the callable `f` with the given arguments.
   *
   * \param f     Callable to invoke.
   * \param args  Arguments passed on to `f`.
   *
   * \return Result of `f(args...)`.
   */
  template<typename F, typename ...ARGS>
  long call(F f, ARGS ...args)
  {
    return f(args...);
  }

  /**
   * Invoke the server function on object `o` through a forwarder.
   *
   * A `FUNC::fwd<O>` object is created from `o` and its
   * `call<ARGS...>()` member is invoked with the given arguments.
   *
   * \param o     Server object.
   * \param args  Arguments passed on to the forwarder's `call`.
   *
   * \return Result of the forwarder's `call`.
   */
  template<typename O, typename FUNC, typename ...ARGS>
  long obj_call(O *o, ARGS ...args)
  {
    using Forwarder = typename FUNC::template fwd<O>;
    return Forwarder(o).template call<ARGS...>(args...);
  }
};

/**
 * Data member for server-function argument T.
 *
 * One link in the chain of server-side arguments: it stores the value for
 * argument T and derives from `Arg_pack<M...>`, which handles the
 * remaining arguments M.
 *
 * \tparam T         Argument type as used for `Svr_xmit` and `_Elem`.
 * \tparam SVR_TYPE  Type of the value stored for this argument.
 * \tparam M         Types of the remaining arguments.
 */
template<typename T, typename SVR_TYPE, typename ...M>
struct Svr_arg : Svr_xmit<T>, Arg_pack<M...>
{
  using Base = Arg_pack<M...>;

  using svr_type = SVR_TYPE;
  using svr_arg_type = typename _Elem<T>::svr_arg_type;

  /// Stored value of this argument.
  svr_type v;

  /**
   * Run `Svr_xmit<T>::to_svr()` for this argument, then continue with the
   * remaining arguments.
   *
   * \param msg     Message buffer.
   * \param offset  Position in `msg` to start at.
   * \param limit   Limit passed through to `to_svr()`.
   *
   * \return Result of the base pack's `get()`, or the negative result of
   *         `to_svr()` if it failed and the argument is not optional.
   *
   * If `to_svr()` fails for an optional argument
   * (`_Elem<T>::Is_optional`), `v` is reset to a default-constructed
   * value and the remaining arguments are processed from the unchanged
   * `offset`.
   */
  template<typename DIR>
  int get(char *msg, unsigned offset, unsigned limit)
  {
    using Xmit = Svr_xmit<T>;
    int const next = Xmit::to_svr(msg, offset, limit, this->v,
                                  typename DIR::dir(), typename DIR::cls());
    if (L4_UNLIKELY(next < 0))
      {
        if (!_Elem<T>::Is_optional)
          return next;

        // optional argument: fall back to a default value and carry on
        // at the position where this argument would have started
        this->v = svr_type();
        return Base::template get<DIR>(msg, offset, limit);
      }

    return Base::template get<DIR>(msg, next, limit);
  }

  /**
   * Run `Svr_xmit<T>::from_svr()` for this argument, then continue with
   * the remaining arguments.
   *
   * \param msg     Message buffer.
   * \param offset  Position in `msg` to start at.
   * \param limit   Limit passed through to `from_svr()`.
   * \param ret     Return value of the server function.
   *
   * \return Result of the base pack's `set()`, or the negative result of
   *         `from_svr()` if it failed.
   */
  template<typename DIR>
  int set(char *msg, unsigned offset, unsigned limit, long ret)
  {
    using Xmit = Svr_xmit<T>;
    int const next = Xmit::from_svr(msg, offset, limit, ret, this->v,
                                    typename DIR::dir(),
                                    typename DIR::cls());
    if (L4_LIKELY(next >= 0))
      return Base::template set<DIR>(msg, next, limit, ret);

    return next;
  }

  /**
   * Append the stored value to the arguments collected so far and pass
   * the call on to the base pack.
   *
   * The value is passed as `svr_arg_type`.
   */
  template<typename F, typename ...ARGS>
  long call(F f, ARGS ...args)
  {
    return Base::template call<F, ARGS..., svr_arg_type>(f, args...,
                                                         this->v);
  }

  /**
   * Append the stored value to the arguments collected so far and pass
   * the object call on to the base pack.
   *
   * The value is passed as `svr_arg_type`.
   */
  template<typename O, typename FUNC, typename ...ARGS>
  long obj_call(O *o, ARGS ...args)
  {
    return Base::template obj_call<O, FUNC, ARGS..., svr_arg_type>(o, args...,
                                                                   this->v);
  }
};

/**
 * Variant of Svr_arg for an argument whose server-side type is `void`.
 *
 * No value is stored and nothing is added to the argument list of the
 * server function; every operation is passed straight on to the pack of
 * the remaining arguments.
 */
template<typename T, typename ...M>
struct Svr_arg<T, void, M...> : Arg_pack<M...>
{
  using Base = Arg_pack<M...>;

  /// Continue with the remaining arguments at the unchanged `offset`.
  template<typename DIR>
  int get(char *msg, unsigned offset, unsigned limit)
  {
    return Base::template get<DIR>(msg, offset, limit);
  }

  /// Continue with the remaining arguments at the unchanged `offset`.
  template<typename DIR>
  int set(char *msg, unsigned offset, unsigned limit, long ret)
  {
    return Base::template set<DIR>(msg, offset, limit, ret);
  }

  /// Pass the call on with the argument list as it is.
  template<typename F, typename ...ARGS>
  long call(F f, ARGS ...args)
  {
    return Base::template call<F, ARGS...>(f, args...);
  }

  /// Pass the object call on with the argument list as it is.
  template<typename O, typename FUNC, typename ...ARGS>
  long obj_call(O *o, ARGS ...args)
  {
    return Base::template obj_call<O, FUNC, ARGS...>(o, args...);
  }
};

/**
 * Argument pack with at least one argument.
 *
 * Handling of the first argument A is provided by Svr_arg, selected via
 * the server-side type `_Elem<A>::svr_type`; the remaining arguments M
 * are passed along to it.
 */
template<typename A, typename ...M>
struct Arg_pack<A, M...>
: Svr_arg<A, typename _Elem<A>::svr_type, M...>
{
};

} // namespace Detail

//---------------------------------------------------------------------
/**
 * Server-side RPC arguments data structure used to provide arguments
 * to the server-side implementation of an RPC function.
 *
 * \tparam IPC_TYPE  Function type of the RPC, of the form `R (ARGS...)`.
 */
template<typename IPC_TYPE> struct Svr_arg_pack;

/**
 * Server-side argument pack for an RPC of function type `R (ARGS...)`.
 *
 * Wraps `Detail::Arg_pack<ARGS...>` and accepts the message buffer as an
 * untyped pointer.
 */
template<typename R, typename ...ARGS>
struct Svr_arg_pack<R (ARGS...)> : Detail::Arg_pack<ARGS...>
{
  using Base = Detail::Arg_pack<ARGS...>;

  /**
   * Forward to `Base::get<DIR>()` with `msg` viewed as a byte buffer.
   *
   * \param msg     Message buffer.
   * \param offset  Position in `msg` to start at.
   * \param limit   Limit passed through to the base pack.
   *
   * \return Result of `Base::get<DIR>()`.
   */
  template<typename DIR>
  int get(void *msg, unsigned offset, unsigned limit)
  {
    return Base::template get<DIR>(static_cast<char *>(msg), offset, limit);
  }

  /**
   * Forward to `Base::set<DIR>()` with `msg` viewed as a byte buffer.
   *
   * \param msg     Message buffer.
   * \param offset  Position in `msg` to start at.
   * \param limit   Limit passed through to the base pack.
   * \param ret     Return value of the server function.
   *
   * \return Result of `Base::set<DIR>()`.
   */
  template<typename DIR>
  int set(void *msg, unsigned offset, unsigned limit, long ret)
  {
    return Base::template set<DIR>(static_cast<char *>(msg), offset, limit,
                                   ret);
  }
};

/**
 * Handle an incoming RPC call: collect the arguments from the message
 * registers, invoke the server function for `o` via the forwarder of
 * `IPC_TYPE::rpc`, and assemble the reply.
 *
 * \tparam IPC_TYPE  RPC description providing `rpc` and `opcode_type`.
 * \tparam O         Type of the server object.
 * \tparam ARGS      Types of the additional arguments in `args`.
 *
 * \param o     Server object the call is forwarded to.
 * \param utcb  UTCB whose message registers hold the request and receive
 *              the reply.
 * \param tag   Message tag of the incoming request.
 * \param args  Additional arguments, passed to the server function ahead
 *              of the RPC arguments.
 *
 * \return Message tag for the reply:
 *         - `Short_err` (-L4_EMSGTOOSHORT for calls, -L4_ENOREPLY
 *           otherwise) if `tag` describes more than `Mr_words` words or
 *           collecting the arguments fails,
 *         - -L4_ENOREPLY after running the server function if the RPC is
 *           not a call,
 *         - the negative result of the server function, without payload,
 *         - -L4_EMSGTOOLONG if writing the reply data or items fails,
 *         - otherwise the result of the server function as label together
 *           with the number of reply words and items.
 */
template<typename IPC_TYPE, typename O, typename ...ARGS>
static l4_msgtag_t
handle_svr_obj_call(O *o, l4_utcb_t *utcb, l4_msgtag_t tag, ARGS ...args)
{
  using Rpc = typename IPC_TYPE::rpc;
  using Pack = Svr_arg_pack<typename Rpc::ipc_type>;
  enum
  {
    Do_reply  = Rpc::flags_type::Is_call,
    Short_err = Do_reply ? -L4_EMSGTOOSHORT : -L4_ENOREPLY,
  };

  // XXX: send a reply or just do not reply in case of a cheating client
  if (L4_UNLIKELY(tag.words() + tag.items() * Item_words > Mr_words))
    return l4_msgtag(Short_err, 0, 0, 0);

  // storage for all arguments of this RPC
  Pack pack;
  l4_msg_regs_t *mrs = l4_utcb_mr_u(utcb);

  // input data starts right behind the opcode (if there is one)
  int const opcode_bytes = Detail::Sizeof<typename IPC_TYPE::opcode_type>::size;
  unsigned const in_bytes = tag.words() * Word_bytes;

  int const data_end
    = pack.template get<Do_in_data>(&mrs->mr[0], opcode_bytes, in_bytes);
  if (L4_UNLIKELY(data_end < 0))
    return l4_msgtag(Short_err, 0, 0, 0);

  int const out_end = pack.template get<Do_out_data>(mrs->mr, 0, Mr_bytes);
  if (L4_UNLIKELY(out_end < 0))
    return l4_msgtag(Short_err, 0, 0, 0);

  // input items follow the untyped words of the message
  int const items_end
    = pack.template get<Do_in_items>(&mrs->mr[tag.words()], 0,
                                     tag.items() * Item_bytes);
  if (L4_UNLIKELY(items_end < 0))
    return l4_msgtag(Short_err, 0, 0, 0);

  // Empty asm with the message registers as memory output: the compiler
  // has to treat mrs->mr as modified from here on.
  // NOTE(review): presumably orders the argument reads above against the
  // server function below -- confirm.
  asm volatile ("" : "=m" (mrs->mr));

  // call the server function
  long const ret = pack.template obj_call<O, Rpc, ARGS...>(o, args...);

  if (!Do_reply)
    return l4_msgtag(-L4_ENOREPLY, 0, 0, 0);

  // by convention a negative return value means there is no reply data
  if (L4_UNLIKELY(ret < 0))
    return l4_msgtag(ret, 0, 0, 0);

  // write the reply data produced by the server function
  int const data_len = pack.template set<Do_out_data>(mrs->mr, 0, Mr_bytes, ret);
  if (L4_UNLIKELY(data_len < 0))
    return l4_msgtag(-L4_EMSGTOOLONG, 0, 0, 0);

  // reply items go behind the data, rounded up to whole words
  unsigned const words = (data_len + Word_bytes - 1) / Word_bytes;
  int const items_len
    = pack.template set<Do_out_items>(&mrs->mr[words], 0,
                                      Mr_bytes - words * Word_bytes,
                                      ret);
  if (L4_UNLIKELY(items_len < 0))
    return l4_msgtag(-L4_EMSGTOOLONG, 0, 0, 0);

  unsigned const items = items_len / Item_bytes;
  return l4_msgtag(ret, words, items, 0);
}

//-------------------------------------------------------------------------

/**
 * Dispatcher for incoming RPC messages.
 *
 * Declared here so that the specializations can precede the definition of
 * the primary template further down in this file.
 *
 * \tparam RPCS         RPC list (or raw IPC marker) to dispatch to.
 * \tparam OPCODE_TYPE  Type of the opcode read from the message; `void`
 *                      selects the specializations that read no opcode.
 */
template<typename RPCS, typename OPCODE_TYPE>
struct Dispatch_call;

/**
 * Dispatcher for `L4::Typeid::Raw_ipc` interfaces without an opcode.
 *
 * The message is not decoded here; it is handed to the object's
 * `op_dispatch()` as it is.
 */
template<typename CLASS>
struct Dispatch_call<L4::Typeid::Raw_ipc<CLASS>, void>
{
  /**
   * Forward the message to `o->op_dispatch()`.
   *
   * \param o     Server object.
   * \param utcb  UTCB of the incoming message.
   * \param tag   Message tag of the incoming message.
   * \param a     Further arguments, passed through unchanged.
   *
   * \return Result of `o->op_dispatch(utcb, tag, a...)`.
   */
  template<typename OBJ, typename ...ARGS>
  static l4_msgtag_t
  call(OBJ *o, l4_utcb_t *utcb, l4_msgtag_t tag, ARGS ...a)
  {
    return o->op_dispatch(utcb, tag, a...);
  }
};

/**
 * Dispatcher for an RPC description without an opcode.
 *
 * No opcode is read from the message: after the rights check the call
 * goes directly to `handle_svr_obj_call<RPCS>()`.
 */
template<typename RPCS>
struct Dispatch_call<RPCS, void>
{
  /// Rights required by the RPC, restricted to the lowest two bits.
  constexpr static unsigned rmask()
  {
    return RPCS::rpc::flags_type::Rights & 3UL;
  }

  /**
   * Check the caller's rights and handle the RPC.
   *
   * \param o       Server object.
   * \param utcb    UTCB of the incoming message.
   * \param tag     Message tag of the incoming message.
   * \param rights  Rights of the caller.
   * \param a       Further arguments for the server function.
   *
   * \return A tag with label -L4_EPERM if `rights` lacks any bit of
   *         `rmask()`, otherwise the result of `handle_svr_obj_call()`.
   */
  template<typename OBJ, typename ...ARGS>
  static l4_msgtag_t
  call(OBJ *o, l4_utcb_t *utcb, l4_msgtag_t tag, unsigned rights, ARGS ...a)
  {
    using Rights = L4::Typeid::Rights<typename RPCS::rpc::class_type>;

    bool const permitted = (rights & rmask()) == rmask();
    if (!permitted)
      return l4_msgtag(-L4_EPERM, 0, 0, 0);

    // the server function gets the rights as typed first argument
    return handle_svr_obj_call<RPCS>(o, utcb, tag, Rights(rights), a...);
  }
};

/**
 * Dispatcher for a list of RPCs that are selected by an opcode.
 *
 * `call()` reads the opcode from the message, `_call()` compares it with
 * `RPCS::Opcode` and either handles the RPC or moves on to `RPCS::next`.
 */
template<typename RPCS, typename OPCODE_TYPE>
struct Dispatch_call
{
  /// Rights required by the RPC, restricted to the lowest two bits.
  constexpr static unsigned rmask()
  {
    return RPCS::rpc::flags_type::Rights & 3UL;
  }

  /**
   * Handle the message if this RPC is responsible for opcode `op`,
   * otherwise try the next RPC in the list.
   *
   * This RPC is responsible if its `opcode_type` is `void` or if
   * `RPCS::Opcode` equals `op`.
   *
   * \param o       Server object.
   * \param utcb    UTCB of the incoming message.
   * \param tag     Message tag of the incoming message.
   * \param rights  Rights of the caller.
   * \param op      Opcode read from the message.
   * \param a       Further arguments for the server function.
   *
   * \return A tag with label -L4_EPERM if this RPC is responsible but
   *         `rights` lacks any bit of `rmask()`; the result of
   *         `handle_svr_obj_call()` if it is responsible and permitted;
   *         otherwise the result of `_call()` for `RPCS::next`.
   */
  template<typename OBJ, typename ...ARGS>
  static l4_msgtag_t
  _call(OBJ *o, l4_utcb_t *utcb, l4_msgtag_t tag, unsigned rights, OPCODE_TYPE op, ARGS ...a)
  {
    using Next = Dispatch_call<typename RPCS::next, OPCODE_TYPE>;
    using Rights = L4::Typeid::Rights<typename RPCS::rpc::class_type>;

    bool const responsible
      = L4::Types::Same<typename RPCS::opcode_type, void>::value
        || RPCS::Opcode == op;

    if (!responsible)
      return Next::template _call<OBJ, ARGS...>(o, utcb, tag, rights, op,
                                                a...);

    if ((rights & rmask()) != rmask())
      return l4_msgtag(-L4_EPERM, 0, 0, 0);

    // the server function gets the rights as typed first argument
    return handle_svr_obj_call<RPCS>(o, utcb, tag, Rights(rights), a...);
  }

  /**
   * Read the opcode from the message registers and dispatch on it.
   *
   * \param o       Server object.
   * \param utcb    UTCB of the incoming message.
   * \param tag     Message tag of the incoming message.
   * \param rights  Rights of the caller.
   * \param a       Further arguments for the server function.
   *
   * \return A tag with label -L4_EMSGTOOSHORT if the opcode cannot be
   *         read, otherwise the result of `_call()`.
   */
  template<typename OBJ, typename ...ARGS>
  static l4_msgtag_t
  call(OBJ *o, l4_utcb_t *utcb, l4_msgtag_t tag, unsigned rights, ARGS ...a)
  {
    using Xmit = Svr_xmit<OPCODE_TYPE>;

    OPCODE_TYPE op;
    unsigned const in_bytes = tag.words() * Word_bytes;
    char *const msg = reinterpret_cast<char *>(l4_utcb_mr_u(utcb)->mr);

    // the opcode is located at the very beginning of the message
    int const res = Xmit::to_svr(msg, 0, in_bytes, op, Dir_in(), Cls_data());
    if (L4_UNLIKELY(res < 0))
      return l4_msgtag(-L4_EMSGTOOSHORT, 0, 0, 0);

    return _call<OBJ, ARGS...>(o, utcb, tag, rights, op, a...);
  }
};

/**
 * Dispatcher for `Typeid::Detail::Rpcs_end`.
 *
 * There is no RPC left that could handle the message, so both entry
 * points answer with -L4_ENOSYS.
 */
template<>
struct Dispatch_call<Typeid::Detail::Rpcs_end, void>
{
  /**
   * Opcode-based entry point.
   *
   * \return A tag with label -L4_ENOSYS; all parameters are ignored.
   */
  template<typename OBJ, typename ...ARGS>
  static l4_msgtag_t
  _call(OBJ *, l4_utcb_t *, l4_msgtag_t, unsigned, int, ARGS ...)
  {
    return l4_msgtag(-L4_ENOSYS, 0, 0, 0);
  }

  /**
   * Entry point without opcode.
   *
   * \return A tag with label -L4_ENOSYS; all parameters are ignored.
   */
  template<typename OBJ, typename ...ARGS>
  static l4_msgtag_t
  call(OBJ *, l4_utcb_t *, l4_msgtag_t, unsigned, ARGS ...)
  {
    return l4_msgtag(-L4_ENOSYS, 0, 0, 0);
  }
};

/**
 * Dispatcher for `Typeid::Detail::Rpcs_end` with an arbitrary opcode
 * type.
 *
 * Inherits `_call()` and `call()` from the `void` variant.
 */
template<typename OPCODE_TYPE>
struct Dispatch_call<Typeid::Detail::Rpcs_end, OPCODE_TYPE>
: Dispatch_call<Typeid::Detail::Rpcs_end, void>
{
};

/**
 * Dispatch an incoming message to the RPCs described by `RPCS`.
 *
 * Selects the `Dispatch_call` variant for `RPCS::type` and
 * `RPCS::opcode_type` and invokes its `call()`.
 *
 * \tparam RPCS  Description providing `type` and `opcode_type`.
 *
 * \param o       Server object.
 * \param utcb    UTCB of the incoming message.
 * \param tag     Message tag of the incoming message.
 * \param rights  Rights of the caller.
 * \param a       Further arguments, passed through unchanged.
 *
 * \return Message tag returned by the selected dispatcher.
 */
template<typename RPCS, typename OBJ, typename ...ARGS>
static l4_msgtag_t
dispatch_call(OBJ *o, l4_utcb_t *utcb, l4_msgtag_t tag, unsigned rights, ARGS ...a)
{
  using Dispatcher
    = Dispatch_call<typename RPCS::type, typename RPCS::opcode_type>;

  return Dispatcher::template call<OBJ, ARGS...>(o, utcb, tag, rights, a...);
}

} // namespace Msg
} // namespace Ipc
} // namespace L4
