/*
 * Copyright (c) 2019-2021,2024 LAAS/CNRS
 * All rights reserved.
 *
 * Redistribution  and  use  in  source  and binary  forms,  with  or  without
 * modification, are permitted provided that the following conditions are met:
 *
 *   1. Redistributions of  source  code must retain the  above copyright
 *      notice and this list of conditions.
 *   2. Redistributions in binary form must reproduce the above copyright
 *      notice and  this list of  conditions in the  documentation and/or
 *      other materials provided with the distribution.
 *
 * THE SOFTWARE  IS PROVIDED "AS IS"  AND THE AUTHOR  DISCLAIMS ALL WARRANTIES
 * WITH  REGARD   TO  THIS  SOFTWARE  INCLUDING  ALL   IMPLIED  WARRANTIES  OF
 * MERCHANTABILITY AND  FITNESS.  IN NO EVENT  SHALL THE AUTHOR  BE LIABLE FOR
 * ANY  SPECIAL, DIRECT,  INDIRECT, OR  CONSEQUENTIAL DAMAGES  OR  ANY DAMAGES
 * WHATSOEVER  RESULTING FROM  LOSS OF  USE, DATA  OR PROFITS,  WHETHER  IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR  OTHER TORTIOUS ACTION, ARISING OUT OF OR
 * IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 *
 *                                           Anthony Mallet on Sat Nov 23 2019
 */
#include "autoconf.h"

#include <stddef.h>
#include <string.h>

#include "tk3-paparazzi.h"
#include "tk3-boot.h"

/* USART protocol of the STM32 system bootloader (ST application note AN3155),
 * carried here over USB CDC instead of a USART:
 * http://www.st.com/web/en/resource/technical/document/
 *   application_note/CD00264342.pdf */

/* bootloader commands */
/* synchronisation byte: bootloader() waits for it, then answers STBL_ACK */
#define STBL_MAGIC	0x7f

/* positive / negative acknowledge bytes sent back to the host */
#define STBL_ACK	0x79
#define STBL_NACK	0x1f

/* command codes (also listed in the STBL_GET reply) */
#define STBL_GET	0x00	/* get version and supported commands */
#define STBL_GV		0x01	/* get version and read protection status */
#define STBL_GID	0x02	/* get ID */
#define STBL_RM		0x11	/* read memory */
#define STBL_GO		0x21	/* go */
#define STBL_WM		0x31	/* write memory */
#define STBL_EE		0x44	/* extended erase */

/* 16-bit value of a command frame: the command byte in the high byte and its
 * bitwise complement in the low byte, matching the two bytes received by
 * bootloader() combined as c[0] << 8 | c[1] */
#define STBL_CMD(c)	(((uint8_t)(c) << 8) | (uint8_t)~(c))


/* private data */

/* Read up to count bytes from io into buffer, giving up after timeout
 * milliseconds (0: no time limit); returns the number of bytes read. */
static ptrdiff_t	ioread(struct tk3iob *io, void *buffer, size_t count,
                                uint32_t timeout);


/* --- bootloader ---------------------------------------------------------- */

/* byte sequence that, when received instead of STBL_MAGIC while waiting for
 * the host, makes bootloader() return dfuaddr; the terminating NUL marks the
 * end of the match */
static const uint8_t dfumagic[] = "dfu";

uintptr_t
bootloader(struct tk3iob *in, struct tk3iob *out)
{
  int32_t tk3set_bltout;
  const uint8_t *dfu;
  uint8_t c[16];

  /* wait for magic byte or timeout */
  dfu = dfumagic;
  tk3set_bltout = tk3set_get(TK3SET_BLTOUT, 5);
  tk3fb_on(TK3FB_ALL);

  while(1) {
    if (ioread(in, c, 1, tk3set_bltout * 1000) == 0) return -1;
    if (*c == STBL_MAGIC) break;
    if (*c == *dfu) { if (!*++dfu) return dfuaddr; } else dfu = dfumagic;
  }
  tk3iob_putc(out, STBL_ACK);

  tk3fb_off(TK3FB_ALL);

  /* wait for commands */
  tk3fb_on(TK3FB_BL);
  while(1) {
    if (!ioread(in, c, 2, 5000)) continue;

    switch (c[0] << 8 | c[1]) {
      case STBL_CMD(STBL_GET): {
        static const uint8_t reply[] = {
          STBL_ACK, 7, 0x93,
          STBL_GET, STBL_GV, STBL_GID, STBL_RM, STBL_GO, STBL_WM, STBL_EE,
          STBL_ACK
        };

        tk3iob_write(out, reply, sizeof(reply));
        break;
      }

      case STBL_CMD(STBL_GV):
        tk3iob_write(out, (uint8_t []){
            STBL_ACK,
            0x93, 0, 0,
            STBL_ACK
            }, 5);
        break;

      case STBL_CMD(STBL_GID):		/* --- get ID ---------------------- */
        tk3iob_write(out, (uint8_t []){
            STBL_ACK,
            1, (DBGMCU->IDCODE & 0xfff) >> 8, DBGMCU->IDCODE & 0xff,
            STBL_ACK
            }, 5);
        break;

      case STBL_CMD(STBL_RM): {		/* --- read memory ----------------- */
        const struct memregion_s *m;
        uintptr_t addr;
        size_t n, len;

        tk3iob_putc(out, STBL_ACK);

        /* address */
        if (ioread(in, c, 5, 100) != 5) goto failed;
        if (c[0] ^ c[1] ^ c[2] ^ c[3] ^ c[4]) goto failed;

        addr = c[0] << 24 | c[1] << 16 | c[2] << 8 | c[3];
        for(m = memregions; m->end; m++)
          if (addr >= m->start && addr < m->end) break;
        if (!m->end) goto failed;

        tk3iob_putc(out, STBL_ACK);

        /* length */
        if (ioread(in, c, 2, 100) != 2) goto failed;
        if (c[0] != (uint8_t)~c[1]) goto failed;

        len = c[0] + 1;
        while (addr + len > m->end) {
          if (!m[1].end) goto failed;
          if (m->end + 1 > m[1].start) goto failed;
          m++;
        }

        tk3iob_putc(out, STBL_ACK);

        /* send data */
        do {
          n = tk3iob_write(out, (void *)addr, len);
          if (!n) __WFI();

          addr += n;
          len -= n;
        } while(len);

        break;
      }

      case STBL_CMD(STBL_GO): {		/* --- go -------------------------- */
        const struct memregion_s *m;
        uintptr_t addr;

        tk3iob_putc(out, STBL_ACK);

        /* address */
        if (ioread(in, c, 5, 100) != 5) goto failed;
        if (c[0] ^ c[1] ^ c[2] ^ c[3] ^ c[4]) goto failed;

        addr = c[0] << 24 | c[1] << 16 | c[2] << 8 | c[3];
        for(m = memregions; m->end; m++)
          if (addr >= m->start && addr < m->end) break;
        if (!m->end) goto failed;

        tk3iob_putc(out, STBL_ACK);
        return addr;
      }

      case STBL_CMD(STBL_WM): {		/* --- write memory ---------------- */
        const struct memregion_s *m;
        static uint8_t data[256];
        uintptr_t addr;
        size_t i;
        int s;

        tk3iob_putc(out, STBL_ACK);

        /* address */
        if (ioread(in, c, 5, 100) != 5) goto failed;
        if (c[0] ^ c[1] ^ c[2] ^ c[3] ^ c[4]) goto failed;

        addr = c[0] << 24 | c[1] << 16 | c[2] << 8 | c[3];
        for(m = memregions; m->end; m++)
          if (addr >= m->start && addr < m->end) break;
        if (!m->end) goto failed;

        tk3iob_putc(out, STBL_ACK);

        /* length, data, checksum  */
        if (ioread(in, c, 1, 100) != 1) goto failed;
        if (ioread(in, data, c[0] + 1, 100) != *c+1) goto failed;
        if (ioread(in, &c[1], 1, 100) != 1) goto failed;

        for (i = 0; i < *c + 1; i++) c[1] ^= data[i];
        if (c[1] ^ c[0]) goto failed;

        /* write */
        unlock_flash();
        s = write_flash(addr, data, c[0] + 1);
        lock_flash();

        tk3iob_putc(out, s ? STBL_NACK : STBL_ACK);
        break;
      }

      case STBL_CMD(STBL_EE): {		/* --- extended erase -------------- */
        uint64_t sectors = 0;
        uint16_t sector;
        size_t n;
        int s;

        tk3iob_putc(out, STBL_ACK);

        /* sectors */
        if (ioread(in, c, 2, 100) != 2) goto failed;
        n = c[0] << 8 | c[1];

        if ((n & 0xfff0) == 0xfff0) { /* mass erase */

          if (ioread(in, &c[2], 1, 100) != 1) goto failed;
          if (c[0] ^ c[1] ^ c[2]) goto failed;
          switch (n) {
            case 0xffff:
              for(const struct memregion_s *m = memregions; m->end; m++)
                if (m->kind & MEM_E) sectors |= 1 << m->sector;
              break;

            default: goto failed;
          }

        } else { /* sector erase */

          /* sectors list */
          c[3] = 0;
          do {
            if (ioread(in, c, 2, 100) != 2) goto failed;
            c[3] ^= c[0] ^ c[1];
            sector = c[0] << 8 | c[1];
            if (sector < 8 * sizeof(sectors)) sectors |= 1 << sector;
          } while(n--);

          if (ioread(in, c, 1, 100) != 1) goto failed;
          if (c[3] ^ c[0]) goto failed;
        }

        /* erase */
        unlock_flash();
        for (n = 0, s = 0; n < 8 * sizeof(sectors); n++)
          if (sectors & (1 << n)) s |= erase_flash(n);
        lock_flash();

        tk3iob_putc(out, s ? STBL_NACK : STBL_ACK);
        break;
      }

      failed:
      default:
        tk3iob_putc(out, STBL_NACK);
        break;
    }
  }

  return -1;
}


/* --- ioread -------------------------------------------------------------- */

static ptrdiff_t
ioread(struct tk3iob *io, void *buffer, size_t count, uint32_t timeout)
{
  systime_t deadline =
    osalOsGetSystemTimeX() + (timeout * OSAL_ST_FREQUENCY + 999)/1000;
  size_t s, n = 0;

  while((osalOsGetSystemTimeX() < deadline || !timeout) && count > 0) {
    __WFI();

    s = tk3iob_read(io, buffer, count);
    if (s) {
      n += s;
      buffer += s;
      count -= s;
    }
  }

  return n;
}
