/*
 * Copyright (c) 2019-2020,2024 LAAS/CNRS
 * All rights reserved.
 *
 * Redistribution  and  use  in  source  and binary  forms,  with  or  without
 * modification, are permitted provided that the following conditions are met:
 *
 *   1. Redistributions of  source  code must retain the  above copyright
 *      notice and this list of conditions.
 *   2. Redistributions in binary form must reproduce the above copyright
 *      notice and  this list of  conditions in the  documentation and/or
 *      other materials provided with the distribution.
 *
 * THE SOFTWARE  IS PROVIDED "AS IS"  AND THE AUTHOR  DISCLAIMS ALL WARRANTIES
 * WITH  REGARD   TO  THIS  SOFTWARE  INCLUDING  ALL   IMPLIED  WARRANTIES  OF
 * MERCHANTABILITY AND  FITNESS.  IN NO EVENT  SHALL THE AUTHOR  BE LIABLE FOR
 * ANY  SPECIAL, DIRECT,  INDIRECT, OR  CONSEQUENTIAL DAMAGES  OR  ANY DAMAGES
 * WHATSOEVER  RESULTING FROM  LOSS OF  USE, DATA  OR PROFITS,  WHETHER  IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR  OTHER TORTIOUS ACTION, ARISING OUT OF OR
 * IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 *
 *                                           Anthony Mallet on Wed Dec  4 2019
 */
#include "ac_tk3_flash.h"

#include <sys/ioctl.h>
#include <err.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <termios.h>
#include <unistd.h>

#include "flash.h"

/* Table of supported microcontrollers. Entries are matched against the
 * product id returned by the bootloader, and the table ends with an entry
 * whose id is 0.
 *
 * `pgsz' describes the flash layout, in increasing address order starting at
 * `flash', as a list of (number of pages, page size in bytes) pairs. The list
 * ends with a 0 page count. */
static const struct stm32_dev {
  uint16_t id;           /* product id */
  const char *dev;       /* device name, for display */
  uintptr_t flash;       /* address of the first flash page */
  const uint32_t *pgsz;  /* (count, size) pairs, 0-terminated */
} stm32_devs[] = {
  {
    .id = 0x451,
    .dev = "STM32F76xxx/77xxx",
    .flash = 0x8000000,
    .pgsz = (const uint32_t []){
      4, 32 * 1024,
      1, 128 * 1024,
      7, 256 * 1024,
      0
    }
  },
  {
    .id = 0x450,
    .dev = "STM32H742xxx/743xxx/753xxx/750xxx",
    .flash = 0x8000000,
    .pgsz = (const uint32_t []){
      16, 128 * 1024,
      0
    }
  },
  { .id = 0 }
};

/* Forward declarations of the helpers private to this file. */
static const struct stm32_dev *stm32_connect(int fd);
static char *stm32_read_settings(int fd, const struct stm32_dev *dev,
                                 uintptr_t *paddr, size_t *sz);
static int stm32_erase(int fd, const struct stm32_dev *dev,
                       uintptr_t paddr, size_t size);
static int stm32_write(int fd, const struct stm32_dev *dev,
                       char *text, uintptr_t paddr, size_t size);
static int stm32_read(int fd, char *text, uintptr_t paddr, size_t size);


/* --- stm32_flash --------------------------------------------------------- */

/* Program the loadable contents of the ELF file `exe' into the flash of the
 * microcontroller reachable through the serial descriptor `fd'.
 *
 * `init_p' selects what happens to the settings area:
 *  - SET_KEEP: the settings currently in flash are read back and must start
 *    and end with TK3_MAGIC; the settings from `exe' are not written by the
 *    dedicated settings step.
 *  - SET_MERGE: the settings found in flash are merged with merge_settings()
 *    into the ones loaded from `exe' before these are written.
 *  - SET_RESET: the settings loaded from `exe' are written unchanged.
 *
 * `serial' is only forwarded to print_settings().
 *
 * Returns 0 on success and -1 when a read or write transfer with the
 * bootloader fails. Exits with status 2 when the ELF file cannot be used,
 * the bootloader cannot be reached, the stored settings are invalid or the
 * flash cannot be erased. */
int
stm32_flash(const char *exe, int fd, const char *serial, enum settype init_p)
{
  const struct stm32_dev *dev;
  uintptr_t paddr, tk3paddr;
  size_t size, tk3size;
  char *data, *tk3;
  int status = -1;
  Elf *elf;

  elf = elf_init(exe, EM_ARM);
  if (!elf) exit(2);

  dev = stm32_connect(fd);
  if (!dev) exit(2);

  /* get flash & settings data: both lookups start from the flash base
   * address with a 0xffffff size and get the actual address and size back
   * through `paddr'/`size' (presumably a search window, see
   * elf_loadable_data()) */
  printf("Retrieving load data from %s\n", exe);
  paddr = dev->flash;
  size = 0xffffff;
  data = elf_loadable_data(elf, &paddr, &size, LD_DATA);
  if (!data) exit(2);

  tk3paddr = dev->flash;
  tk3size = 0xffffff;
  tk3 = elf_loadable_data(elf, &tk3paddr, &tk3size, LD_SETTINGS);
  /* missing settings are only tolerated when they are not going to be
   * written */
  if (!tk3 && init_p != SET_KEEP) exit(2);

  /* check old settings */
  switch (init_p) {
    case SET_KEEP:
      if (!tk3) break;

      /* overwrite the buffer with what is in flash, for checking and for
       * the final print_settings() */
      printf("Preserving settings in 0x%lx-0x%lx\n",
             (unsigned long)tk3paddr,
             (unsigned long)(tk3paddr + tk3size - 1));
      if (stm32_read(fd, tk3, tk3paddr, tk3size)) goto done;

      if (strcmp(tk3, TK3_MAGIC)) {
        warnx("invalid settings data, please flash with option -i");
        exit(2);
      }
      if (strcmp(tk3 + tk3size - sizeof(TK3_MAGIC), TK3_MAGIC)) {
        warnx("outdated settings data, please flash with option -i or -u");
        exit(2);
      }
      break;

    case SET_MERGE: {
      uintptr_t opaddr;
      char *oset;
      size_t osz;

      /* read old settings */
      oset = stm32_read_settings(fd, dev, &opaddr, &osz);
      if (!oset) goto done;

      /* merge old and new settings */
      printf("Updating settings in 0x%lx-0x%lx\n",
             (unsigned long)tk3paddr,
             (unsigned long)(tk3paddr + tk3size - 1));
      merge_settings(tk3, oset, 1/*drop*/);
      free(oset);
      break;
    }

    case SET_RESET:
      break;
  }

  /* write flash data */
  if (stm32_erase(fd, dev, paddr, size)) /* write protection? */ exit(2);
  if (stm32_write(fd, dev, data, paddr, size)) goto done;

  /* write settings data, unless the settings area lies entirely inside the
   * range that was just erased and written with `data' */
  if (init_p != SET_KEEP && tk3 &&
      (tk3paddr < paddr || tk3paddr + tk3size > paddr + size)) {
    if (stm32_erase(fd, dev, tk3paddr, tk3size)) exit(2);
    if (stm32_write(fd, dev, tk3, tk3paddr, tk3size)) goto done;
  }

  if (tk3) print_settings(serial, tk3 + sizeof(TK3_MAGIC));
  status = 0;

done:
  /* release the buffers on the error paths too */
  free(tk3);
  free(data);
  return status;
}


/* --- stm32_params --------------------------------------------------------- */

/* Update individual settings stored in the flash of the microcontroller
 * reachable through the serial descriptor `fd'.
 *
 * `argv' is a NULL-terminated list read as consecutive pairs that are
 * handed to set_char_settings(); a trailing element without a partner is
 * ignored. When at least one pair was processed, the settings area is erased
 * and rewritten. The resulting settings are always printed with
 * print_settings(), to which `serial' is forwarded.
 *
 * Returns 0 on success and -1 when the settings cannot be read or written.
 * Exits with status 2 when the bootloader cannot be reached, a setting is
 * rejected by set_char_settings() or the flash cannot be erased. */
int
stm32_params(int fd, const char *serial, char **argv)
{
  const struct stm32_dev *dev;
  uintptr_t tk3paddr;
  size_t tk3sz;
  char *tk3;
  int status = -1;
  int i;

  /* connect to the uc */
  dev = stm32_connect(fd);
  if (!dev) exit(2);

  /* read old settings */
  tk3 = stm32_read_settings(fd, dev, &tk3paddr, &tk3sz);
  if (!tk3) return -1;

  /* merge old and new settings */
  for (i = 0; argv[i] && argv[i+1]; i += 2) {
    if (set_char_settings(tk3, argv[i], argv[i+1])) exit(2);
  }

  /* re-flash, only if something was changed */
  if (i) {
    if (stm32_erase(fd, dev, tk3paddr, tk3sz)) exit(2);
    if (stm32_write(fd, dev, tk3, tk3paddr, tk3sz)) goto done;
  }

  print_settings(serial, tk3 + sizeof(TK3_MAGIC));
  status = 0;

done:
  /* the buffer allocated by stm32_read_settings() belongs to us */
  free(tk3);
  return status;
}


/* --- stm32_read_settings ------------------------------------------------- */

/* Read the existing settings from flash.
 *
 * Every flash page of `dev' is probed, in address order, for TK3_MAGIC at
 * its very beginning. The first matching page is read entirely and its
 * content is then walked to find the actual length of the settings: after the
 * leading magic comes a sequence of records, each made of a NUL-terminated
 * string, four bytes and another NUL-terminated string. The sequence ends
 * with a NUL byte, followed after padding to a multiple of 4 bytes by a
 * trailing TK3_MAGIC.
 *
 * On success, returns a malloc()ed buffer holding the whole page, to be
 * released with free() by the caller, and stores the page address in
 * `*paddr' and the size of the settings (both magics included) in `*sz'.
 * Returns NULL when a transfer with the bootloader fails. Exits with status
 * 2 when no settings are found, when they are malformed or on memory
 * exhaustion. */

static char *
stm32_read_settings(int fd, const struct stm32_dev *dev,
                    uintptr_t *paddr, size_t *sz)
{
  char data[sizeof(TK3_MAGIC)];
  const uint32_t *page;
  char *set, *p, *end, *nul;
  uint32_t i;

  printf("Searching for existing settings\n");

  /* read flash to find TK3_MAGIC at the beginning of a page */
  *sz = 0;
  for (page = dev->pgsz, *paddr = dev->flash; *page && !*sz; page+=2)
    for (i = *page; i; *paddr += page[1], i--)  {

      if (stm32_read(fd, data, *paddr, sizeof(TK3_MAGIC))) return NULL;
      if (!strcmp(data, TK3_MAGIC)) {
        /* leave *paddr on the matching page */
        *sz = page[1];
        break;
      }
    }
  if (!*sz) {
    warnx("settings data not found, please flash with option -i");
    exit(2);
  }

  /* read old settings */
  set = malloc(*sz);
  if (!set) { warnx("out of memory"); exit(2); }
  if (stm32_read(fd, set, *paddr, *sz)) {
    free(set);
    return NULL;
  }

  /* get actual size: `end' is the last position at which the trailing magic
   * still fits in the page, and no record may reach it. All string scans are
   * bounded by `end' so that corrupted flash content cannot make the walk
   * run past the buffer. */
  end = set + *sz - sizeof(TK3_MAGIC);
  p = set + sizeof(TK3_MAGIC);
  while (p < end && *p) {
    /* first string, which must leave room for the four bytes after it */
    nul = memchr(p, '\0', (size_t)(end - p));
    if (!nul || end - nul <= 1 + 4) { p = end; break; }
    p = nul + 1 + 4;

    /* second string */
    nul = memchr(p, '\0', (size_t)(end - p));
    if (!nul) { p = end; break; }
    p = nul + 1;
  }
  if (p < end) {
    /* stopped on the terminating NUL: skip it and align 4 */
    p++;
    p = set + ((p - set + 3) & ~3);
  } else
    p = NULL; /* the records overran the page */

  if (!p || p > end || strcmp(p, TK3_MAGIC)) {
    warnx("invalid settings data, please flash with option -i");
    exit(2);
  }
  *sz = p - set + sizeof(TK3_MAGIC);

  printf("Found settings data at 0x%lx-0x%lx, %zu bytes\n",
         (unsigned long)*paddr, (unsigned long)(*paddr + *sz - 1), *sz);
  return set;
}


/* --- stm32_connect ------------------------------------------------------- */

/* Connect to the bootloader.
 *
 * Synchronizes with the bootloader on the serial descriptor `fd', checks
 * that it implements the read memory, go, write memory and extended erase
 * commands, then reads the product id and looks it up in stm32_devs[].
 *
 * Returns the matching device description, or NULL after printing a warning
 * when the bootloader does not answer as expected or the device is not
 * supported. */

static const struct stm32_dev *
stm32_connect(int fd)
{
  const struct stm32_dev *dev;
  uint32_t supported;
  uint8_t data[32];	/* unsigned, so that bytes >= 0x80 stay positive */
  unsigned int pid;
  ssize_t i, l, n;
  uint8_t v;

  /* drain any remaining input */
  while (read_serial(fd, data, sizeof(data), 100) > 0);

  /* (re)synchronize: send 0x7f until the bootloader replies 0x1f (its NACK
   * byte), giving up after three attempts. Only a byte that was actually
   * received is looked at. */
  for (i = 0, l = 0; l < 1 || data[0] != 0x1f; i++) {
    if (i > 2) { warnx("bootloader not responding"); return NULL; }
    write_serial(fd, "\x7f", 1);
    l = read_serial(fd, data, 1, 100);
  }

  /* get version and supported commands */
  write_serial(fd, "\0\xff", 2);
  l = read_serial(fd, data, 2, 1000);
  if (l < 2 || data[0] != 0x79) {
    warnx("cannot read bootloader version");
    return NULL;
  }

  /* data[1] announces how much follows: the version byte, data[1] command
   * codes and a final 0x79. This length comes from the device and must not
   * be trusted to fit in the buffer. */
  n = 2 + data[1];
  if ((size_t)n > sizeof(data) - 2) {
    warnx("cannot read bootloader version");
    return NULL;
  }
  l = read_serial(fd, &data[2], n, 1000);
  if (l != n || data[l+1] != 0x79) {
    warnx("cannot read bootloader version");
    return NULL;
  }

  /* check supported commands */
  v = data[2];
  supported = 0;
  for (i = 3; i < 3 + data[1]; i++)
    switch(data[i]) {
      case 0x11: /* read memory */	supported |= 0x01; break;
      case 0x21: /* go */		supported |= 0x02; break;
      case 0x31: /* write memory */	supported |= 0x04; break;
      case 0x44: /* extended erase */	supported |= 0x08; break;
    }
  if (supported != 0xf) {
    warnx("bootloader does not support all required commands");
    return NULL;
  }

  /* get pid: the reply is 0x79, a length byte of 1, the two bytes of the
   * product id (most significant first) and 0x79 */
  write_serial(fd, "\2\xfd", 2);
  l = read_serial(fd, data, sizeof(data), 300);
  if (l != 5 || data[0] != 0x79 || data[1] != 1 || data[4] != 0x79) {
    warnx("cannot read product id");
    return NULL;
  }
  pid = (unsigned int)data[2] << 8 | data[3];

  for (dev = stm32_devs; dev->id; dev++)
    if (dev->id == pid) {
      /* the version byte holds one digit in each nibble */
      printf("%s bootloader version %u.%u\n",
             dev->dev, (unsigned int)(v >> 4) & 0xf, (unsigned int)v & 0xf);
      return dev;
    }
  warnx("unsupported product id 0x%x", pid);
  return NULL;
}


/* --- stm32_erase --------------------------------------------------------- */

/* Erase every flash page of `dev' that overlaps the `size' bytes starting at
 * address `paddr', one extended erase command per page. Pages are numbered
 * from 0 at the start of flash, across all the page groups of dev->pgsz.
 * An empty region (`size' == 0) erases nothing.
 *
 * Progress is printed on stdout. Returns 0 on success, -1 after printing a
 * warning when the bootloader does not acknowledge a command. */
static int
stm32_erase(int fd, const struct stm32_dev *dev, uintptr_t paddr, size_t size)
{
  const uint32_t *page;
  uintptr_t addr;
  uint8_t data[32];
  uint32_t i, s;
  ssize_t l;
  int erase;

  /* a zero-length region overlaps no page, even when it starts in the
   * middle of one */
  if (!size) return 0;

  erase = 0;
  for (page = dev->pgsz, addr = dev->flash, s = 0; *page; page+=2)
    for (i = *page; i; addr+=page[1], i--, s++)  {
      /* skip pages entirely before or after the region */
      if (addr + page[1] <= paddr || addr >= paddr + size) continue;

      if (!erase) {
        printf("Erasing flash page %u", (unsigned int)s);
        erase = 1;
      } else
        printf(", %u", (unsigned int)s);
      fflush(stdout);

      /* extended erase command */
      write_serial(fd, "\x44\xbb", 2);
      l = read_serial(fd, data, 1, 1000);
      if (l < 1 || *data != 0x79) {
        printf("...\n"); warnx("cannot erase flash"); return -1;
      }

      /* two zero bytes (a single page is erased per command), the page
       * number with its most significant byte first, and the xor of all
       * those bytes */
      data[0] = 0;
      data[1] = 0;
      data[2] = (s >> 8) & 0xff;
      data[3] = s & 0xff;
      data[4] = data[2] ^ data[3];
      write_serial(fd, data, 5);
      l = read_serial(fd, data, 1, 1000);
      if (l < 1 || *data != 0x79) {
        printf("...\n"); warnx("cannot erase flash"); return -1;
      }
    }

  if (erase) printf(".\n");
  return 0;
}


/* --- stm32_write ---------------------------------------------------------- */

/* Write the `size' bytes at `text' to the memory of the microcontroller,
 * starting at address `paddr'. Data is sent in chunks of at most 256 bytes,
 * one write memory command per chunk; each of the three steps of a command
 * must be acknowledged by a 0x79 byte. `dev' is not used.
 *
 * Progress is printed on stdout. Returns 0 on success, -1 after printing a
 * warning as soon as one step is not acknowledged. */
static int
stm32_write(int fd, const struct stm32_dev *dev,
            char *text, uintptr_t paddr, size_t size)
{
  uint8_t frame[32];
  uintptr_t cur;
  uint32_t chunk, k;
  size_t done;
  ssize_t got;

  printf("Writing data to ...");
  fflush(stdout);

  done = 0;
  cur = paddr;
  while (cur < paddr + size) {
    /* the last chunk may be shorter */
    if (cur + 256 > paddr + size)
      chunk = paddr + size - cur;
    else
      chunk = 256;

    printf("\rWriting data to 0x%lx %d%%",
           cur, (int)((double)done/size*100));
    fflush(stdout);

    /* write command */
    write_serial(fd, "\x31\xce", 2);
    got = read_serial(fd, frame, 1, 1000);
    if (got < 1 || frame[0] != 0x79) goto fail;

    /* destination address, most significant byte first, and xor checksum */
    frame[0] = (cur >> 24) & 0xff;
    frame[1] = (cur >> 16) & 0xff;
    frame[2] = (cur >> 8) & 0xff;
    frame[3] = cur & 0xff;
    frame[4] = frame[0] ^ frame[1] ^ frame[2] ^ frame[3];
    write_serial(fd, frame, 5);
    got = read_serial(fd, frame, 1, 1000);
    if (got < 1 || frame[0] != 0x79) goto fail;

    /* byte count minus one, the data, then the xor of all of these */
    frame[0] = chunk - 1;
    write_serial(fd, frame, 1);
    write_serial(fd, text, chunk);
    for (k = 0; k < chunk; k++)
      frame[0] ^= text[k];
    write_serial(fd, frame, 1);
    got = read_serial(fd, frame, 1, 1000);
    if (got < 1 || frame[0] != 0x79) goto fail;

    done += chunk;
    cur += 256;
    text += 256;
  }

  printf("\rWriting data to 0x%lx-0x%lx, %zu bytes\n",
         paddr, paddr + size - 1, size);
  return 0;

fail:
  printf("\n");
  warnx("cannot write flash");
  return -1;
}


/* --- stm32_read ---------------------------------------------------------- */

/* Read `size' bytes from the memory of the microcontroller, starting at
 * address `paddr', into the buffer `text'. Data is fetched in chunks of at
 * most 256 bytes, one read memory command per chunk; each of the three steps
 * of a command must be acknowledged by a 0x79 byte before the data comes.
 *
 * Progress is printed on stdout. Returns 0 on success, -1 after printing a
 * warning as soon as one step is not acknowledged or a chunk is short. */
static int
stm32_read(int fd, char *text, uintptr_t paddr, size_t size)
{
  uint8_t frame[32];
  uintptr_t cur;
  uint32_t chunk;
  size_t done;
  ssize_t got;

  printf("Reading data from ...");
  fflush(stdout);

  done = 0;
  cur = paddr;
  while (cur < paddr + size) {
    /* the last chunk may be shorter */
    if (cur + 256 > paddr + size)
      chunk = paddr + size - cur;
    else
      chunk = 256;

    printf("\rReading data from 0x%lx %d%%",
           cur, (int)((double)done/size*100));
    fflush(stdout);

    /* read command */
    write_serial(fd, "\x11\xee", 2);
    got = read_serial(fd, frame, 1, 1000);
    if (got < 1 || frame[0] != 0x79) goto fail;

    /* source address, most significant byte first, and xor checksum */
    frame[0] = (cur >> 24) & 0xff;
    frame[1] = (cur >> 16) & 0xff;
    frame[2] = (cur >> 8) & 0xff;
    frame[3] = cur & 0xff;
    frame[4] = frame[0] ^ frame[1] ^ frame[2] ^ frame[3];
    write_serial(fd, frame, 5);
    got = read_serial(fd, frame, 1, 1000);
    if (got < 1 || frame[0] != 0x79) goto fail;

    /* byte count minus one, and its complement */
    frame[0] = chunk - 1;
    frame[1] = (uint8_t)~(chunk - 1);
    write_serial(fd, frame, 2);
    got = read_serial(fd, frame, 1, 1000);
    if (got < 1 || frame[0] != 0x79) goto fail;

    /* the data itself */
    got = read_serial(fd, text, chunk, 1000);
    if (got != chunk) goto fail;

    done += chunk;
    cur += 256;
    text += 256;
  }

  printf("\rReading data from 0x%lx-0x%lx, %zu bytes\n",
         paddr, paddr + size - 1, size);
  return 0;

fail:
  printf("\n");
  warnx("cannot read flash");
  return -1;
}


/* --- stm32_reset --------------------------------------------------------- */

/* Leave the bootloader by sending it a go command targeting address
 * 0x8000000. Returns 0 once the address has been sent, whatever the
 * bootloader answers to it, or -1 after printing a warning when the command
 * itself is not acknowledged by a 0x79 byte. */
int
stm32_reset(int fd)
{
  uint8_t reply[5];
  ssize_t got;

  /* go command */
  write_serial(fd, "\x21\xde", 2);
  got = read_serial(fd, reply, 1, 1000);
  if (got >= 1 && reply[0] == 0x79) {
    /* reboot at 0x8000000 (tk3 bootloader): address bytes, then checksum */
    write_serial(fd, "\x8\0\0\0\x8", 5);

    /* consume the reply, but don't check it */
    (void)read_serial(fd, reply, 1, 1000);
    return 0;
  }

  warnx("cannot exit bootloader");
  return -1;
}
