/*
 * Copyright (c) 2021 LAAS/CNRS
 * All rights reserved.
 *
 * Redistribution  and  use  in  source  and binary  forms,  with  or  without
 * modification, are permitted provided that the following conditions are met:
 *
 *   1. Redistributions of  source  code must retain the  above copyright
 *      notice and this list of conditions.
 *   2. Redistributions in binary form must reproduce the above copyright
 *      notice and  this list of  conditions in the  documentation and/or
 *      other materials provided with the distribution.
 *
 * THE SOFTWARE  IS PROVIDED "AS IS"  AND THE AUTHOR  DISCLAIMS ALL WARRANTIES
 * WITH  REGARD   TO  THIS  SOFTWARE  INCLUDING  ALL   IMPLIED  WARRANTIES  OF
 * MERCHANTABILITY AND  FITNESS.  IN NO EVENT  SHALL THE AUTHOR  BE LIABLE FOR
 * ANY  SPECIAL, DIRECT,  INDIRECT, OR  CONSEQUENTIAL DAMAGES  OR  ANY DAMAGES
 * WHATSOEVER  RESULTING FROM  LOSS OF  USE, DATA  OR PROFITS,  WHETHER  IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR  OTHER TORTIOUS ACTION, ARISING OUT OF OR
 * IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 *
 *                                           Anthony Mallet on Wed Sep 29 2021
 */
#include "autoconf.h"

#include <sys/stat.h>
#include <sys/types.h>
#include <sys/uio.h>

#include <err.h>
#include <errno.h>
#include <fcntl.h>
#include <poll.h>
#include <stdarg.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <termios.h>
#include <unistd.h>

#ifdef __linux__
# include <libudev.h>
#endif

#include "tk3-mux.h"



/* --- tk_open_tty --------------------------------------------------------- */

static const char *	usb_serial_to_tty(const char *serial);

/* Open and configure a serial line.
 *
 * device is either a tty path or a USB serial string (optionally suffixed by
 * ".<interface>"), which is first resolved to a tty node with
 * usb_serial_to_tty(). The line is configured with 8 data bits, no parity,
 * one stop bit, no canonical processing nor echo, and VMIN = VTIME = 0 so
 * that reads return immediately. The baud rate is left unchanged. Pending
 * input/output is discarded.
 *
 * Returns a newly allocated channel (release with tk_close_tty()), or NULL
 * with errno set on failure.
 */
struct tk_chan_s *
tk_open_tty(const char *device)
{
  struct tk_chan_s *chan;
  const char *path;
  struct termios t;
  int fd = -1;
  int e;

  /* try to match a serial id first; path is malloc()ed when non-NULL */
  path = usb_serial_to_tty(device);
  if (path) device = path;

  /* open */
  fd = open(device, O_RDWR | O_NOCTTY);
  if (fd < 0) goto err; /* must still release path */
  if (!isatty(fd)) { errno = ENOTTY; goto err; }

  /* configure line discipline */
  if (tcgetattr(fd, &t)) goto err;

  t.c_iflag = IGNBRK;
  t.c_oflag = 0;
  t.c_lflag = 0;
  t.c_cflag &= ~(CSIZE | CSTOPB | PARENB | PARODD);
  t.c_cflag |= CS8 | CREAD | CLOCAL;
  t.c_cc[VMIN] = 0;
  t.c_cc[VTIME] = 0;

  if (tcsetattr(fd, TCSANOW, &t)) goto err;

  /* discard any pending data (best effort) */
  tcflush(fd, TCIOFLUSH);

  /* return handle */
  chan = malloc(sizeof(*chan));
  if (!chan) goto err;

  *chan = (struct tk_chan_s){ .fd = fd };
  snprintf(chan->path, sizeof(chan->path), "%s", device);

  /* device may point into path: free only after it has been copied */
  free((void *)path);
  return chan;

err:
  /* keep the errno of the failing call: close() could overwrite it */
  e = errno;
  free((void *)path);
  if (fd >= 0) close(fd);
  errno = e;
  return NULL;
}


/* --- tk_close_tty -------------------------------------------------------- */

void
tk_close_tty(struct tk_chan_s **chan)
{
  /* Release a channel obtained from tk_open_tty().
   *
   * Output already handed to the descriptor is drained with tcdrain() before
   * it is closed. *chan is freed and reset to NULL, so a NULL *chan is a
   * harmless no-op.
   */
  struct tk_chan_s *ch = *chan;

  if (ch == NULL)
    return;

  if (ch->fd >= 0) {
    int fd = ch->fd;

    tcdrain(fd);
    close(fd);
  }

  free(ch);
  *chan = NULL;
}


/* --- usb_serial_to_tty --------------------------------------------------- */

/* Return a tty device node matching the "serial" string.
 *
 * serial is compared with the "serial" sysfs attribute found on the parents
 * of each tty device. It must start with that device serial, optionally
 * followed by ".<n>" selecting USB interface number n (decimal). The suffix
 * is mandatory when the device has several interfaces; when given for a
 * single-interface device it must match that interface number.
 *
 * Returns a strdup()ed path of the first matching device node, to be freed
 * by the caller, or NULL if nothing matches (always NULL on non-Linux).
 */
static const char *
usb_serial_to_tty(const char *serial)
{
#ifdef __linux__
  struct udev *udev;
  struct udev_enumerate *scan = NULL;
  struct udev_list_entry *ttys, *tty;
  struct udev_device *dev = NULL, *p;
  const char *path = NULL;
  const char *devnode, *s;
  size_t plen;

  udev = udev_new();
  if (!udev) return NULL;

  /* iterate over tty devices */
  scan = udev_enumerate_new(udev);
  if (!scan) goto done;
  if (udev_enumerate_add_match_subsystem(scan, "tty")) goto done;
  if (udev_enumerate_scan_devices(scan)) goto done;

  ttys = udev_enumerate_get_list_entry(scan);
  udev_list_entry_foreach(tty, ttys) {
    const char *sysfs, *pserialstr, *pifacestr, *pnifacestr;
    int suf, piface, pniface;

    /* get sysfs entry for the device and create a corresponding udev_device */
    if (dev) udev_device_unref(dev);
    sysfs = udev_list_entry_get_name(tty);
    dev = udev_device_new_from_syspath(udev, sysfs);
    if (!dev) continue;

    /* iterate over parents and look for serial or bInterfaceNumber for usb;
     * parents are owned by dev and need no unref */
    pserialstr = pifacestr = pnifacestr = NULL;
    for(p = udev_device_get_parent(dev); p; p = udev_device_get_parent(p)) {
      if (!pserialstr)
        pserialstr = udev_device_get_sysattr_value(p, "serial");
      if (!pifacestr)
        pifacestr = udev_device_get_sysattr_value(p, "bInterfaceNumber");
      if (!pnifacestr)
        pnifacestr = udev_device_get_sysattr_value(p, "bNumInterfaces");
    }
    /* the kernel prints bInterfaceNumber in hex ("%02x") but bNumInterfaces
     * in decimal: parse accordingly so that interfaces >= 10 are not
     * mistaken for lower ones */
    piface = pifacestr ? (int)strtol(pifacestr, NULL, 16) : 0;
    pniface = pnifacestr ? atoi(pnifacestr) : 0;

    /* check match */
    if (!pserialstr) continue; /* no device serial */
    plen = strlen(pserialstr);
    if (strncmp(serial, pserialstr, plen)) continue; /* no prefix match */
    s = serial + plen; /* advance after prefix */
    if (pniface > 1 && *s == '\0') { /* ambiguous match */
      warnx(
        "serial %s matches %s.%d but is ambiguous", serial, pserialstr, piface);
      continue;
    }
    if (sscanf(s, ".%d", &suf) != 1) suf = -1;
    if (pniface < 2 && *s && suf != piface) continue; /* no suffix match */
    if (pniface > 1 && suf != piface) continue; /* no suffix match */

    /* got a match, return the tty path (stop here: continuing would
     * overwrite and leak it on a further match) */
    devnode = udev_device_get_devnode(dev);
    if (!devnode) continue;
    path = strdup(devnode);
    break;
  }
  if (dev) udev_device_unref(dev);

done:
  if (scan) udev_enumerate_unref(scan);
  if (udev) udev_unref(udev);
  return path;

#else
  (void)serial;
  return NULL; /* if needed, implement this for other OSes */
#endif
}


/* --- tk_recv_msg --------------------------------------------------------- */

/* Decode frames received on chan.
 *
 * Incoming bytes are stored in the chan->rb ring buffer and decoded one at a
 * time: '^' starts a frame, '$' terminates it, '!' aborts it and '\' escapes
 * the next byte, which is transmitted bit-inverted. The payload accumulates
 * in chan->msg/chan->len; frames that do not fit in chan->msg are dropped,
 * and so are empty frames. Frames whose first byte is 'N', 'A' or 'E' carry
 * hardware info/warning/error text and are only logged. Any other frame is
 * left in chan->msg for the caller, chan->revent is set and 1 is returned.
 *
 * If recv is false, only already buffered bytes are decoded. Otherwise the
 * descriptor is read until a frame completes or no more data is available.
 *
 * returns: 0: no message, -1: error, 1: complete msg
 */

int
tk_recv_msg(struct tk_chan_s *chan, bool recv)
{
  struct iovec iov[2];
  ssize_t s;
  uint8_t c;

  for(;;) {
    /* process pending messages */
    while(chan->rb.r != chan->rb.w) {
      c = chan->rb.buf[chan->rb.r];
      chan->rb.r = (chan->rb.r + 1) % sizeof(chan->rb.buf);

      switch(c) {
        case '^': /* start of frame */
          chan->start = true;
          chan->escape = false;
          chan->len = 0;
          break;

        case '$': /* end of frame */
          if (!chan->start) break;
          chan->start = false;

          /* an empty frame has no type byte: msg[0] would be stale data and
           * len - 1 a negative precision, making %.*s read msg up to a NUL
           * that may not exist. Drop it. */
          if (chan->len == 0) break;

          switch(chan->msg[0]) {
            case 'N': /* info messages */
              warnx("hardware info: %.*s",
                    (int)chan->len - 1, (const char *)&chan->msg[1]);
              break;
            case 'A': /* warning messages */
              warnx("hardware warning: %.*s",
                    (int)chan->len - 1, (const char *)&chan->msg[1]);
              break;
            case 'E': /* error messages */
              warnx("hardware error: %.*s",
                    (int)chan->len - 1, (const char *)&chan->msg[1]);
              break;

            default:
              chan->revent = true;
              return 1;
          }
          break;

        case '!': /* abort current frame */
          chan->start = false;
          break;

        case '\\': /* next byte is escaped */
          chan->escape = true;
          break;

        default:
          if (!chan->start) break;
          if (chan->len >= sizeof(chan->msg)) {
            /* frame too long: drop it */
            chan->start = false;
            break;
          }

          if (chan->escape) {
            c = ~c;
            chan->escape = false;
          }
          chan->msg[chan->len++] = c;
          break;
      }
    }
    if (!recv) return 0;

    /* feed the ring buffer, keeping one slot free so that r == w always
     * means "empty" */
    iov[0].iov_base = chan->rb.buf + chan->rb.w;
    iov[1].iov_base = chan->rb.buf;

    if (chan->rb.r > chan->rb.w) {
      iov[0].iov_len = chan->rb.r - chan->rb.w - 1;
      iov[1].iov_len = 0;
    } else if (chan->rb.r > 0) {
      iov[0].iov_len = sizeof(chan->rb.buf) - chan->rb.w;
      iov[1].iov_len = chan->rb.r - 1;
    } else {
      iov[0].iov_len = sizeof(chan->rb.buf) - chan->rb.w - 1;
      iov[1].iov_len = 0;
    }

    /* the buffer was just drained so there is always room here; never spin
     * on a full buffer anyway */
    if (!iov[0].iov_len && !iov[1].iov_len) return 0;

    do {
      s = readv(chan->fd, iov, 2);
    } while(s < 0 && errno == EINTR);

    if (s < 0)
      return (errno == EAGAIN || errno == EWOULDBLOCK) ? 0 : -1;
    if (s == 0)
      return 0; /* no data available */

    chan->rb.w = (chan->rb.w + s) % sizeof(chan->rb.buf);
  }
}


/* --- tk_send_msg --------------------------------------------------------- */

static int	tk_encode(struct tk_chan_s *chan, char c);

/* Send one '^' ... '$' framed message on chan.
 *
 * The payload is built from fmt. Ordinary characters are sent as-is, escaped
 * if needed. The following conversions consume variadic arguments:
 *   %c  one byte, passed as an (integer-promoted) unsigned int
 *   %b  a byte buffer: a uint8_t * followed by its size_t length
 * Other conversions, and a lone trailing '%', are ignored.
 *
 * Returns 0 on success, or -1 as soon as a write fails (errno from the
 * failing write). On error the frame is left unterminated; presumably the
 * peer resynchronizes on the next '^', as tk_recv_msg() does.
 */
int
tk_send_msg(struct tk_chan_s *chan, const char *fmt, ...)
{
  va_list ap;
  int err = 0;
  char c;

  if (tk_putc(chan, '^')) return -1;

  va_start(ap, fmt);
  while(!err && (c = *fmt++)) {
    if (c != '%') {
      err = tk_encode(chan, c);
      continue;
    }

    /* conversion: stop on a trailing '%' instead of stepping fmt past the
     * terminating NUL */
    c = *fmt;
    if (!c) break;
    fmt++;

    switch(c) {
      case 'c': {
        uint8_t x = va_arg(ap, unsigned int);
        err = tk_encode(chan, x);
        break;
      }

      case 'b': {
        const uint8_t *x = va_arg(ap, uint8_t *);
        size_t s = va_arg(ap, size_t);
        while(!err && s--) err = tk_encode(chan, *x++);
        break;
      }

      default: /* unsupported conversion: ignored */
        break;
    }
  }
  va_end(ap);

  if (err) return -1;
  return tk_putc(chan, '$');
}

static int
tk_encode(struct tk_chan_s *chan, char c)
{
  /* Queue one payload byte, escaping the framing characters.
   *
   * '^', '$', '\' and '!' are sent as a '\' followed by the bit-inverted
   * byte (tk_recv_msg() inverts it back). Returns 0 on success, -1 when
   * tk_putc() fails.
   */
  bool framing = (c == '^' || c == '$' || c == '\\' || c == '!');

  if (!framing)
    return tk_putc(chan, c);

  if (tk_putc(chan, '\\') != 0)
    return -1;
  return tk_putc(chan, ~c);
}


/* --- tk_putc ------------------------------------------------------------- */

/* Write out the whole content of the chan->wb ring buffer.
 * Returns 0 once the buffer is empty, or -1 on write error with errno set by
 * writev() and the unwritten bytes still buffered. EINTR is retried. */
static int
tk_wb_flush(struct tk_chan_s *chan)
{
  struct iovec iov[2];
  ssize_t s;

  while(chan->wb.w != chan->wb.r) {
    iov[0].iov_base = chan->wb.buf + chan->wb.r;
    iov[1].iov_base = chan->wb.buf;

    if (chan->wb.w > chan->wb.r) {
      /* contiguous data */
      iov[0].iov_len = chan->wb.w - chan->wb.r;
      iov[1].iov_len = 0;
    } else {
      /* data wraps around the end of the buffer */
      iov[0].iov_len = sizeof(chan->wb.buf) - chan->wb.r;
      iov[1].iov_len = chan->wb.w;
    }

    s = writev(chan->fd, iov, 2);
    if (s < 0) {
      if (errno == EINTR) continue;
      return -1;
    }

    /* partial writes are fine: loop until everything is out */
    chan->wb.r = (chan->wb.r + s) % sizeof(chan->wb.buf);
  }

  return 0;
}

/* Queue one raw byte (no escaping) in the chan->wb ring buffer.
 *
 * The buffer is flushed to the descriptor when it becomes full, or when c is
 * the '$' end-of-frame marker so that complete frames go out immediately.
 * Returns 0 on success, -1 on write error. On error, already buffered bytes
 * are kept for a later retry. c itself is queued unless the buffer was still
 * full from a previous failed flush.
 */
int
tk_putc(struct tk_chan_s *chan, char c)
{
  /* one slot is kept free to tell "full" from "empty". If a previous flush
   * failed and left the buffer full, storing c would make w == r and
   * silently discard everything: retry the flush first. */
  if ((chan->wb.w + 1) % sizeof(chan->wb.buf) == chan->wb.r &&
      tk_wb_flush(chan))
    return -1;

  /* store char */
  chan->wb.buf[chan->wb.w] = c;
  chan->wb.w = (chan->wb.w + 1) % sizeof(chan->wb.buf);

  if (c == '$' || (chan->wb.w + 1) % sizeof(chan->wb.buf) == chan->wb.r)
    return tk_wb_flush(chan);

  return 0;
}
