/*
 * Copyright (c) 2014-2015,2017,2019-2020,2024 LAAS/CNRS
 * All rights reserved.
 *
 * Redistribution  and  use  in  source  and binary  forms,  with  or  without
 * modification, are permitted provided that the following conditions are met:
 *
 *   1. Redistributions of  source  code must retain the  above copyright
 *      notice and this list of conditions.
 *   2. Redistributions in binary form must reproduce the above copyright
 *      notice and  this list of  conditions in the  documentation and/or
 *      other materials provided with the distribution.
 *
 * THE SOFTWARE  IS PROVIDED "AS IS"  AND THE AUTHOR  DISCLAIMS ALL WARRANTIES
 * WITH  REGARD   TO  THIS  SOFTWARE  INCLUDING  ALL   IMPLIED  WARRANTIES  OF
 * MERCHANTABILITY AND  FITNESS.  IN NO EVENT  SHALL THE AUTHOR  BE LIABLE FOR
 * ANY  SPECIAL, DIRECT,  INDIRECT, OR  CONSEQUENTIAL DAMAGES  OR  ANY DAMAGES
 * WHATSOEVER  RESULTING FROM  LOSS OF  USE, DATA  OR PROFITS,  WHETHER  IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR  OTHER TORTIOUS ACTION, ARISING OUT OF OR
 * IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 *
 *                                           Anthony Mallet on Thu Nov 20 2014
 */
#include "acheader.h"

#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>

#include <avr/interrupt.h>
#include <avr/pgmspace.h>
#include <util/atomic.h>
#include <util/twi.h>

#include "tk3-mikrokopter.h"

/* Communication health indicators. Each one is set to 0xff when a problem is
 * detected: errors for dropped or corrupted bytes, warnings for ring buffers
 * that are nearly full. tk3_recv() decrements each one by one per call, so
 * they decay back to 0 once the problem goes away. */
uint8_t tk3_comm_errs, tk3_comm_warns;

/* --- local data ---------------------------------------------------------- */

/* Internal ring buffers for message reception/transmission.
 * In each ring, the writer advances the *_w index and the reader advances
 * the *_r index. The ring is empty when both indexes are equal, and one slot
 * is always left unused so that a full ring can be told apart from an empty
 * one. */
struct tk3_iobuffer {
  /* reception ring: filled by interrupt handlers, drained by tk3_io_recv() */
  volatile uint8_t recvbuf[64];
  volatile uint8_t recv_r;
  volatile uint8_t recv_w;

  /* NOTE(review): tsflag and recv_ts are not referenced in this file */
  uint8_t tsflag;
  volatile tk3_time recv_ts;

  /* transmission ring: filled by tk3_putchar(), drained by tk3_send() (UART)
   * or the TWI interrupt */
  volatile uint8_t sendbuf[64];
  volatile uint8_t send_r;
  volatile uint8_t send_w;
};

/* one pair of rings per UART unit, and one for the TWI bus */
static struct tk3_iobuffer uart[TK3_UART_N];
static struct tk3_iobuffer twi;

#ifdef TK3_TWI_MASTER
/* number of BLs polled round robin by the master (set by tk3_twi_init) */
static uint8_t twi_nmotors;
#endif

#ifdef TK3_TWI_SLAVE
/* own slave address, written into TWAR only while send data is pending */
static uint8_t twi_addr;
#endif

/* ring occupancy checks (non-zero when close to overflowing) */
static uint8_t tk3_sndbuffer_full(const struct tk3_iobuffer *io);
static uint8_t tk3_rcvbuffer_full(const struct tk3_iobuffer *io);

/* decode received bytes into '^' ... '$' framed messages */
static uint8_t tk3_io_recv(struct tk3_iobuffer *io, struct tk3_iorecv *iorecv);

/* byte level helpers: escape one payload byte, push one byte into the send
 * ring, push one byte into the receive ring */
static void tk3_encodechar(char c, struct tk3_iobuffer *io);
static void tk3_putchar(char c, struct tk3_iobuffer *io);
static void tk3_getchar(char c, struct tk3_iobuffer *io);


/* --- tk3_uart_init ------------------------------------------------------- */

uint8_t
tk3_uart_init(uint8_t unit, uint32_t baud)
{
  uint16_t b;

  /* initialize buffers */
  uart[unit].recv_r = uart[unit].recv_w = 0;
  uart[unit].send_r = uart[unit].send_w = 0;

  /* compute baudrate (round to nearest integer */
  b = F_CPU / baud / 8;
  if (F_CPU % (baud * 8) <= baud * 4) b--;

  /* helper macros */
#define TK3_UART_INIT(n)                                        \
  do {                                                          \
    /* set baudrate */                                          \
    TK3_REG(UBRR, n,) = b;                                      \
                                                                \
    /* double clock speed */                                    \
    TK3_REG(UCSR, n, A) = (1 << U2X0);                          \
                                                                \
    /* enable UARTn for transmission and reception */           \
    TK3_REG(UCSR, n, B) =                                       \
      (1 << RXCIE0) | (1 << RXEN0) | (1 << TXEN0);              \
                                                                \
    /* 8/N/1 */                                                 \
    TK3_REG(UCSR, n, C) = (1 << UCSZ01) | (1 << UCSZ00);        \
  } while(0)

  /* initialize given unit */
  switch(unit) {
    case 0: TK3_UART_INIT(0); break;
#if TK3_UART_N > 1
    case 1: TK3_UART_INIT(1); break;
#endif
    default: return -1;
  }

  return 0;
}


/* --- tk3_twi_init -------------------------------------------------------- */

/* Initialize the TWI unit and the twi rings.
 *
 * The meaning of addr depends on the build role:
 * - TK3_TWI_SLAVE: our own slave address. It is only written into TWAR by
 *   tk3_send() while data is pending.
 * - TK3_TWI_MASTER: the number of BLs to poll round robin.
 * - Neither role: addr is ignored.
 * TWAR initially only enables general call recognition (TWGCE). */
void
tk3_twi_init(uint8_t addr)
{
  /* initialize buffers */
  twi.recv_r = 0;
  twi.recv_w = 0;
  twi.send_r = 0;
  twi.send_w = 0;

  /* own address is programmed when data is to be sent. The conditions match
   * those under which twi_addr / twi_nmotors are declared. */
#ifdef TK3_TWI_SLAVE
  twi_addr = addr;
#elif defined(TK3_TWI_MASTER)
  twi_nmotors = addr;
#else
  (void)addr;
#endif
  TK3_TW(AR) = (1 << TWGCE);

  /* 500kHz bus - max speed for 8MHz BLs.
   * SCL = F_CPU / (16 + 2 * TWBR * prescaler), with the prescaler cleared
   * to 1 here */
  TK3_TW(SR) &= ~(( 1 << TWPS1) | (1 << TWPS0));
  TK3_TW(BR_) = ((F_CPU / 500000) - 16) / (2 * 1);

  /* enable, with acknowledge and interrupt */
  TK3_TW(CR) = (1 << TWEA) | (1 << TWEN) | (1 << TWIE);
#ifdef TK3_TWI_MASTER
  /* start continuously polling for slave messages */
  TK3_TW(CR) |= (1 << TWINT) | (1 << TWSTA);
#endif
}


/* --- tk3_sndbuffer_full / tk3_rcvbuffer_full ----------------------------- */

/* Return non-zero when the send ring of io holds more than
 * sizeof(io->sendbuf) - 8 pending bytes, i.e. when it is close to
 * overflowing. */
static inline uint8_t
tk3_sndbuffer_full(const struct tk3_iobuffer *io)
{
  uint8_t l, r, w;

  r = io->send_r;
  w = io->send_w;

  /* pending bytes; add the ring size back when the write index wrapped */
  l = (w < r) ? sizeof(io->sendbuf) : 0;
  l += w;
  l -= r;

  /* the threshold is relative to the send ring itself, so it stays correct
   * even if sendbuf and recvbuf are ever given different sizes */
  return l > sizeof(io->sendbuf) - 8;
}

/* Return non-zero when the receive ring of io holds more than
 * sizeof(io->recvbuf) - 8 pending bytes, i.e. when it is close to
 * overflowing. */
static inline uint8_t
tk3_rcvbuffer_full(const struct tk3_iobuffer *io)
{
  const uint8_t rd = io->recv_r;
  const uint8_t wr = io->recv_w;
  uint8_t used;

  /* number of pending bytes, taking a wrapped write index into account */
  if (wr >= rd)
    used = wr - rd;
  else
    used = sizeof(io->recvbuf) - rd + wr;

  return used > sizeof(io->recvbuf) - 8;
}


/* --- tk3_log ------------------------------------------------------------- */

/* Queue one framed message on channel: '^', the encoded payload, then '$'.
 *
 * fmt lives in program memory (it is read with pgm_read_byte). Plain
 * characters are sent as is (escaped if needed). The conversions below fetch
 * the next variadic argument and send it most significant byte first:
 *   %1  one byte (int argument)
 *   %2  16 bits
 *   %4  32 bits (uint32_t argument)
 * Any other character after '%' is consumed without output. Unknown channels
 * are ignored. */
void
tk3_log(enum tk3_channel channel, const char *fmt, ...)
{
  struct tk3_iobuffer *io;
  va_list ap;
  char c;

  switch(channel) {
    case TK3_UART0: io = &uart[0]; break;
#if TK3_UART_N > 1
    case TK3_UART1: io = &uart[1]; break;
#endif
    case TK3_TWI: io = &twi; break;
    default: return;
  }

  va_start(ap, fmt);
  tk3_putchar('^', io);
  while((c = pgm_read_byte(fmt++))) {
    switch(c) {
      case '%': {
        switch(pgm_read_byte(fmt++)) {
          case '\0':
            /* fmt ends with a lone '%': step back onto the terminator so
             * the loop stops instead of reading past the end of fmt */
            fmt--;
            break;

          case '1':
            tk3_encodechar(va_arg(ap, int), io);
            break;

          case '2': {
            /* a uint16_t argument undergoes the default argument
             * promotions, so fetch it as unsigned int (the very type of
             * uint16_t on AVR) and truncate */
            uint16_t x = va_arg(ap, unsigned int);
            tk3_encodechar((x >> 8) & 0xff, io);
            tk3_encodechar(x & 0xff, io);
            break;
          }

          case '4': {
            uint32_t x = va_arg(ap, uint32_t);
            tk3_encodechar((x >> 24) & 0xff, io);
            tk3_encodechar((x >> 16) & 0xff, io);
            tk3_encodechar((x >> 8) & 0xff, io);
            tk3_encodechar(x & 0xff, io);
            break;
          }

          default:
            /* unsupported conversion: silently ignored */
            break;
        }
        break;
      }

      default:
        tk3_encodechar(c, io);
    }
  }
  tk3_putchar('$', io);
  va_end(ap);
}


/* --- tk3_log_buffer ------------------------------------------------------ */

/* Queue len raw bytes from buffer on channel as one framed message:
 * '^', each byte escaped as needed, then '$'. Unknown channels are
 * ignored. */
void
tk3_log_buffer(enum tk3_channel channel, const uint8_t *buffer, uint8_t len)
{
  struct tk3_iobuffer *io;
  uint8_t i;

  if (channel == TK3_UART0)
    io = &uart[0];
#if TK3_UART_N > 1
  else if (channel == TK3_UART1)
    io = &uart[1];
#endif
  else if (channel == TK3_TWI)
    io = &twi;
  else
    return;

  tk3_putchar('^', io);
  for (i = 0; i < len; i++)
    tk3_encodechar(buffer[i], io);
  tk3_putchar('$', io);
}


/* --- tk3_recv ------------------------------------------------------------ */

struct tk3_iorecv *
tk3_recv(void)
{
  static struct tk3_iorecv io_twi = { .state.channel = TK3_TWI };
  static struct tk3_iorecv io_uart0 = { .state.channel = TK3_UART0 };
#if TK3_UART_N > 1
  static struct tk3_iorecv io_uart1 = { .state.channel = TK3_UART1 };
#endif

  if (tk3_comm_warns) tk3_comm_warns--;
  if (tk3_comm_errs) tk3_comm_errs--;

  if (tk3_rcvbuffer_full(&twi)) tk3_comm_warns = 0xff;
  if (tk3_io_recv(&twi, &io_twi)) return &io_twi;

  if (tk3_rcvbuffer_full(&uart[0])) tk3_comm_warns = 0xff;
  if (tk3_io_recv(&uart[0], &io_uart0)) return &io_uart0;

#if TK3_UART_N > 1
  if (tk3_rcvbuffer_full(&uart[1])) tk3_comm_warns = 0xff;
  if (tk3_io_recv(&uart[1], &io_uart1)) return &io_uart1;
#endif
  return NULL;
}

/* Consume bytes from io's receive ring and feed them to the frame decoder
 * in iorecv.
 *
 * Framing: '^' starts a message, '$' ends it, '!' aborts it, and '\\' means
 * the next payload byte was sent complemented. Messages longer than
 * iorecv->data are dropped. Returns 1 as soon as a complete message is
 * available in iorecv->data[0 .. len-1]; the remaining bytes stay queued for
 * the next call. Returns 0 once the ring is empty. */
static uint8_t
tk3_io_recv(struct tk3_iobuffer *io, struct tk3_iorecv *iorecv)
{
  uint8_t pos = io->recv_r;

  while (pos != io->recv_w) {
    char ch = io->recvbuf[pos];

    /* release the slot right away so the producer can reuse it */
    pos = (pos + 1) % sizeof(io->recvbuf);
    io->recv_r = pos;

    if (ch == '^') {
      /* start of frame: forget any partial message */
      iorecv->state.start = 1;
      iorecv->state.escape = 0;
      iorecv->len = 0;
    } else if (ch == '$') {
      /* end of frame: only meaningful after a start */
      if (iorecv->state.start) {
        iorecv->state.start = 0;
        return 1;
      }
    } else if (ch == '!') {
      /* abort marker: drop the frame in progress */
      iorecv->state.start = 0;
    } else if (ch == '\\') {
      iorecv->state.escape = 1;
    } else if (iorecv->state.start) {
      if (iorecv->len >= sizeof(iorecv->data)) {
        /* payload does not fit: drop the frame */
        iorecv->state.start = 0;
      } else {
        if (iorecv->state.escape) {
          iorecv->state.escape = 0;
          ch = ~ch;
        }
        iorecv->data[iorecv->len++] = ch;
      }
    }
  }

  return 0;
}


/* --- tk3_send ------------------------------------------------------------ */

/* Flush pending output.
 *
 * TWI transmission is interrupt driven: here a slave only makes itself
 * addressable when it has data to send. The UART send rings are drained
 * synchronously, busy-waiting on each data register. Nearly full send rings
 * raise the warning indicator. */
void
tk3_send()
{
  /* twi */
  if (tk3_sndbuffer_full(&twi)) tk3_comm_warns = 0xff;
#ifdef TK3_TWI_SLAVE
  /* listen to our address to reply to master polls */
  if (twi.send_r != twi.send_w && (TK3_TW(AR) >> 1) != twi_addr)
    TK3_TW(AR) = (twi_addr << 1) | (1 << TWGCE);
#endif

  /* uart (synchronous) */
  if (tk3_sndbuffer_full(&uart[0])) tk3_comm_warns = 0xff;
#if TK3_UART_N > 1
  if (tk3_sndbuffer_full(&uart[1])) tk3_comm_warns = 0xff;
#endif

  /* drain the send ring of UART n, one byte per empty data register */
#define TK3_UART_SEND(n)                                        \
  do {                                                          \
    uint8_t pos, byte;                                          \
                                                                \
    for (pos = uart[n].send_r; pos != uart[n].send_w;) {        \
      byte = uart[n].sendbuf[pos];                              \
      pos = (pos + 1) % sizeof(uart[n].sendbuf);                \
      uart[n].send_r = pos;                                     \
                                                                \
      /* wait until the transmit data register is empty */     \
      while (!(UCSR ## n ## A & (1 << UDRE ## n)))              \
        ;                                                       \
      UDR ## n = byte;                                          \
    }                                                           \
  } while(0)

  TK3_UART_SEND(0);
#if TK3_UART_N > 1
  TK3_UART_SEND(1);
#endif
}


/* --- tk3_encodechar ------------------------------------------------------ */

/* Queue one payload byte on io. The framing characters '^', '$', '!' and
 * '\\' cannot appear as is, so they are sent as '\\' followed by their
 * bitwise complement. */
static inline void
tk3_encodechar(char c, struct tk3_iobuffer *io)
{
  const uint8_t reserved =
    c == '^' || c == '$' || c == '!' || c == '\\';

  if (reserved) {
    tk3_putchar('\\', io);
    tk3_putchar(~c, io);
  } else
    tk3_putchar(c, io);
}


/* --- tk3_putchar --------------------------------------------------------- */

/* Append byte c to io's send ring. If the ring is full, the byte is dropped
 * and the error indicator is raised. */
static inline void
tk3_putchar(char c, struct tk3_iobuffer *io)
{
  const uint8_t slot = io->send_w;
  const uint8_t next = (slot + 1) % sizeof(io->sendbuf);

  if (next != io->send_r) {
    io->sendbuf[slot] = c;
    io->send_w = next; /* publish the byte only once it is stored */
  } else {
    /* discard byte in case of buffer overflow */
    tk3_comm_errs = 0xff;
  }
}


/* --- tk3_getchar --------------------------------------------------------- */

/* Append received byte c to io's receive ring. It is called from the USART
 * and TWI interrupt handlers in this file. If the ring is full, the byte is
 * dropped and the error indicator is raised. */
static inline __attribute__((always_inline)) void
tk3_getchar(char c, struct tk3_iobuffer *io)
{
  const uint8_t slot = io->recv_w;
  const uint8_t next = (slot + 1) % sizeof(io->recvbuf);

  if (next == io->recv_r) {
    /* discard byte in case of buffer overflow */
    tk3_comm_errs = 0xff;
  } else {
    io->recvbuf[slot] = c;
    io->recv_w = next; /* publish the byte only once it is stored */
  }
}


/* --- USARTn_RX_vect ------------------------------------------------------ */

/* Common body of the USART receive interrupts. It reads one byte from
 * udr and pushes it into uart[unit]'s receive ring. On a framing error or
 * data overrun, the byte is replaced by the '!' abort marker, so that the
 * decoder drops the frame in progress, and the error indicator is raised. */
static inline __attribute__((always_inline)) void
usart_rx_vect(const uint8_t unit,
              volatile uint8_t * const ucsr, volatile uint8_t * const udr)
{
  /* the status register must be sampled before the data register is read */
  const uint8_t flags = *ucsr;
  uint8_t byte = *udr;

  if (flags & ((1 << FE0) | (1 << DOR0))) {
    tk3_comm_errs = 0xff;
    byte = '!';
  }

  tk3_getchar(byte, &uart[unit]);
}

/* USART receive interrupts. Depending on the part, the vector of the first
 * unit is named USART_RX_vect or USART0_RX_vect. */
#ifdef USART_RX_vect
ISR(USART_RX_vect)
{
  usart_rx_vect(0, &UCSR0A, &UDR0);
}
#else
ISR(USART0_RX_vect)
{
  usart_rx_vect(0, &UCSR0A, &UDR0);
}
#endif

#if TK3_UART_N > 1
ISR(USART1_RX_vect)
{
  usart_rx_vect(1, &UCSR1A, &UDR1);
}
#endif


/* --- TWI_vect ------------------------------------------------------------ */

/* TWI interrupt vector name for the bus selected by TK3_TWI_BUS. The extra
 * macro level makes the preprocessor expand TK3_TWI_BUS before the ##
 * pasting, so the result is TWI<bus>_vect. */
#define TK3_TW_vect_(n)	TWI ## n ## _vect
#define TK3_TW_vect(n)	TK3_TW_vect_(n)

/* TWI state machine, run once per TWINT event.
 *
 * Master (TK3_TWI_MASTER): after each (repeated) start, pending send data is
 * broadcast with a general call. Otherwise, the BLs at addresses
 * TK3TWI_BASEADDR + 1 .. TK3TWI_BASEADDR + twi_nmotors are polled round
 * robin, and their replies are pushed into the twi receive ring. Outside of
 * bus error recovery, no stop condition is ever issued: every transfer ends
 * with a repeated start.
 *
 * Slave (TK3_TWI_SLAVE): general call data is pushed into the twi receive
 * ring. Master polls to our own address are answered from the twi send
 * ring, one '^' ... '$' message per poll. */
ISR(TK3_TW_vect(TK3_TWI_BUS))
{
/* TWCR value that clears TWINT (by writing a one to it) so that the hardware
 * proceeds, while keeping the TWI unit and its interrupt enabled */
#define TWCR_CLRINT ((1 << TWINT) | (1 << TWEN) | (1 << TWIE))

  switch (TK3_TW(SR) & TW_STATUS_MASK) {
#ifdef TK3_TWI_MASTER
    /* bus start */

    case TW_START:
    case TW_REP_START: {
      /* start condition transmitted */
      /* or repeated start condition transmitted */

      if (twi.send_r != twi.send_w) {
        /* highest priority is to broadcast data to the BLs */
        TK3_TW(DR) = 0 | TW_WRITE; /* general call */
      } else {
        /* otherwise, do round robin polling of the BLs */
        /* polling order: 1, twi_nmotors, twi_nmotors - 1, ..., 2, 1, ...
         * NOTE(review): with twi_nmotors == 0 the counter would wrap through
         * every 8-bit value - confirm it is always >= 1 in master mode */
        static uint8_t sladdr = 1;

        TK3_TW(DR) =
          (uint8_t)((TK3TWI_BASEADDR + sladdr) << 1) | TW_READ;
        if (!--sladdr) sladdr = twi_nmotors;
      }
      TK3_TW(CR) = TWCR_CLRINT;
      break;
    }

      /* master transmitter mode */

    case TW_MT_SLA_NACK:
    case TW_MT_ARB_LOST:
      /* SLA+W transmitted, NACK received */
      /* arbitration lost in SLA+W or data */
      /* Writing TWINT alone also clears TWEN, which aborts the transfer and
       * resets the TWI unit. Bytes not yet loaded into TWDR stay in the send
       * ring and are broadcast again after the restart. */
      TK3_TW(CR) = (1 << TWINT); /* reset TWI internal state */
      while(TK3_TW(CR) & (1 << TWINT));

      TK3_TW(CR) = TWCR_CLRINT | (1 << TWSTA); /* restart */
      break;

    case TW_MT_SLA_ACK:
    case TW_MT_DATA_ACK:
    case TW_MT_DATA_NACK: {
      /* SLA+W transmitted, ACK received */
      /* data transmitted, ACK received */
      /* data transmitted, NACK received */
      /* Send the next queued byte (TWDR is loaded before TWINT is cleared),
       * or issue a repeated start once the send ring is empty. */
      uint8_t r = twi.send_r;
      if (r != twi.send_w) {
        TK3_TW(DR) = twi.sendbuf[r];
        TK3_TW(CR) = TWCR_CLRINT;
        twi.send_r = (r + 1) % sizeof(twi.sendbuf);
      } else
        TK3_TW(CR) = TWCR_CLRINT | (1 << TWSTA); /* repeated start */
      break;
    }

      /* master receiver mode */

    case TW_MR_SLA_NACK:
      /* SLA+R transmitted, NACK received */
      /* presumably the polled BL has nothing to send: a slave built from
       * this file only listens to its own address while data is pending */
      TK3_TW(CR) = TWCR_CLRINT | (1 << TWSTA); /* repeated start */
      break;

    case TW_MR_SLA_ACK:
      /* SLA+R transmitted, ACK received */
      TK3_TW(CR) = TWCR_CLRINT | (1 << TWEA); /* ack next byte */
      break;

    case TW_MR_DATA_ACK: {
      /* data received, ACK returned */

      /* the strategy is to nack as soon as an end-of-message marker is
       * seen, because there is only one reception buffer and incomplete
       * messages from different BLs would not interleave correctly.
       * This makes an overhead of one byte,but any other strategy
       * (e.g. sending the buffer length ahead) will also consume extra bytes.
       * The BL will send fake 0xff after the '$' (normally only once). */
      uint8_t c = TK3_TW(DR);
      if (c == '$')
        TK3_TW(CR) = TWCR_CLRINT; /* nack next byte */
      else
        TK3_TW(CR) = TWCR_CLRINT | (1 << TWEA); /* ack next byte */

      tk3_getchar(c, &twi);
      break;
    }

    case TW_MR_DATA_NACK:
      /* data received, NACK returned */

      /* this extra byte is the fake 0xff from the slave, sent after the
       * end-of-message. It can be ignored. */
      TK3_TW(CR) = TWCR_CLRINT | (1 << TWSTA); /* repeated start */
      break;

#endif /* TK3_TWI_MASTER */

#ifdef TK3_TWI_SLAVE

      /* slave transmitter mode */

    case TW_ST_SLA_ACK:
    case TW_ST_DATA_ACK: {
      /* SLA+R received, ACK returned */
      /* data transmitted, ACK received */
      uint8_t r, c;

      /* send a fake '$' whenever the send buffer is empty. This should not
       * happen. */
      r = twi.send_r;
      if (r != twi.send_w) {
        c = twi.sendbuf[r];
        twi.send_r = (r + 1) % sizeof(twi.sendbuf);
      } else
        c = '$';

      TK3_TW(DR) = c;
      if (c == '$')
        TK3_TW(CR) = TWCR_CLRINT; /* stop sending after first message */
      else
        TK3_TW(CR) = TWCR_CLRINT | (1 << TWEA);
      break;
    }

    case TW_ST_DATA_NACK:
    case TW_ST_LAST_DATA:
      /* data transmitted, NACK received */
      /* last data byte transmitted, ACK received */

      if (twi.send_r == twi.send_w) {
        /* disconnect from the bus by resetting our address, so that the master
         * gets a nack immediately when polling us. It is still required to
         * listen to general calls, so TWEA cannot just be disabled */
        TK3_TW(AR) = 0 | (1 << TWGCE);
      }
      TK3_TW(CR) = TWCR_CLRINT | (1 << TWEA);
      break;


      /* slave receiver mode */

    case TW_SR_GCALL_ACK:
      /* general call received, ACK returned */
      TK3_TW(CR) = TWCR_CLRINT | (1 << TWEA);
      break;

    case TW_SR_GCALL_DATA_ACK:
    case TW_SR_GCALL_DATA_NACK:
      /* or general call data received, ACK returned */
      /* or general call data received, NACK returned */
      tk3_getchar(TK3_TW(DR), &twi);
      TK3_TW(CR) = TWCR_CLRINT | (1 << TWEA);
      break;

    case TW_SR_STOP:
      /* stop or repeated start condition received while selected */
      TK3_TW(CR) = TWCR_CLRINT | (1 << TWEA);
      break;
#endif /* TK3_TWI_SLAVE */

      /* error conditions */

    case TW_NO_INFO:
      /* no state information available */
      /* TWINT is not set in this state: there is nothing to acknowledge */
      break;

    case TW_BUS_ERROR:
      /* illegal start or stop condition */
      /* Recovery: set TWSTO while clearing TWINT, which releases the bus
       * lines, then wait for the hardware to clear TWSTO. */
      TK3_TW(CR) = (1 << TWINT) | (1 << TWEN) | (1 << TWSTO);
      while(TK3_TW(CR) & (1 << TWSTO));

      /* restart */
#ifdef TK3_TWI_SLAVE
      TK3_TW(CR) = (1 << TWEN) | (1 << TWIE) | (1 << TWEA);
#else
      TK3_TW(CR) = (1 << TWEN) | (1 << TWIE) | (1 << TWSTA);
#endif
      break;

    default:
      /* this must not happen by construction */
      /* NOTE(review): any unhandled status code spins here forever inside
       * the ISR; presumably a watchdog is relied upon to recover - confirm */
      while(1);
  }
}
