#!/usr/bin/env python3
#
# SPDX-License-Identifier: GPL-2.0+
#
# Copyright (c) 2022 Linaro.
# Copyright (c) 2022 Ying-Chun Liu (PaulLiu) <paul.liu@linaro.org>
#
# Author: Ying-Chun Liu (PaulLiu) <paul.liu@linaro.org>
#

import fcntl
import struct
import os
import argparse
import sys

def getPartitionSize(device_path):
    """ Get Partition Size

    Block devices are queried with the BLKGETSIZE64 ioctl.  Regular files
    (e.g. a dumped misc partition image) do not support that ioctl, so
    their size is taken from stat instead.

    Parameters:
    device_path (string): path to the device or to a regular image file

    Returns:
    int: the size of the device in bytes

    Raises:
    OSError: if the path cannot be opened or the ioctl fails

    """
    import stat

    # BLKGETSIZE64 is _IOR(0x12, 114, size_t) in the asm-generic ioctl
    # encoding (x86, ARM): read direction 0x80000000, argument size in
    # bits 16..29, type 0x12, number 114 (0x72).  size_t is 8 bytes on
    # 64-bit builds (0x80081272) but 4 bytes on 32-bit ones (0x80041272),
    # so derive the size field instead of hard-coding it.
    # NOTE(review): other architectures (e.g. PowerPC, MIPS) use a
    # different direction encoding — confirm if those are ever targeted.
    req = 0x80001272 | (struct.calcsize('N') << 16)

    with open(device_path, 'rb') as dev:
        st = os.fstat(dev.fileno())
        if stat.S_ISREG(st.st_mode):
            return st.st_size
        # The kernel always stores the size as an unsigned 64-bit integer.
        buf = fcntl.ioctl(dev.fileno(), req, b'\0' * 8)

    # 'Q' is 8 bytes on every platform; native 'L' is only 4 bytes on
    # 32-bit builds and would fail to unpack the 8-byte buffer.
    return struct.unpack('Q', buf)[0]

def readData(device_path, offset, length):
    """ Read Data from device

    Parameters:
    device_path (string): path to the device
    offset (int): byte offset from the start of the device
    length (int): number of bytes to read

    Returns:
    bytes: the data; shorter than length only when the end of the device
    is reached.  None if seeking to offset fails.

    """
    fd = os.open(device_path, os.O_RDONLY)
    # try/finally guarantees the descriptor is closed on every path,
    # including the early return when the seek fails.
    try:
        try:
            os.lseek(fd, offset, os.SEEK_SET)
        except OSError:
            return None

        chunks = []
        remaining = length
        # os.read may return fewer bytes than requested; keep reading
        # until the requested length is collected or EOF is reached.
        while remaining > 0:
            chunk = os.read(fd, remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)
    finally:
        os.close(fd)

def writeData(device_path, offset, data):
    """ Write data to the device

    Parameters:
    device_path (string): path to the device
    offset (int): byte offset from the start of the device
    data (bytes-like): data to write

    Returns:
    int: number of bytes written (equal to the size of data unless the
    device stops accepting bytes).  None if seeking to offset fails.

    """
    fd = os.open(device_path, os.O_RDWR)
    # try/finally guarantees the descriptor is closed on every path,
    # including the early return when the seek fails.
    try:
        try:
            os.lseek(fd, offset, os.SEEK_SET)
        except OSError:
            return None

        view = memoryview(data).cast("B")
        written = 0
        # os.write may perform a partial write; loop until everything is
        # written (stop if the device accepts nothing, to avoid spinning).
        while written < len(view):
            n = os.write(fd, view[written:])
            if n == 0:
                break
            written += n
        # Flush to the device so the control block survives an immediate
        # reboot rather than sitting in the page cache.
        os.fsync(fd)
        return written
    finally:
        os.close(fd)

def outputReadData(buf, offset=0):
    """ dump the data

    Print buf as a hex dump, 16 bytes per row.  Each row holds the address
    (row start plus offset) as 8 hex digits, the bytes in hex (padded to
    16 columns on a short final row), and the bytes as text, where values
    32..126 are shown as characters and everything else as '.'.
    The format of the output is similar to U-boot's "bcb dump" command.

    Parameters:
    buf (bytesarray): the data
    offset (int): value added to every printed address

    """
    row_width = 16
    for row_start in range(0, len(buf), row_width):
        row = buf[row_start:row_start + row_width]
        hex_part = "".join("%02x " % byte for byte in row).ljust(3 * row_width)
        text_part = "".join(
            chr(byte) if 32 <= byte <= 126 else "." for byte in row)
        print("%08x: %s %s" % (row_start + offset, hex_part, text_part))

# Byte layout of the control block at the start of the misc partition:
# each entry maps a field name to its byte offset and size.  The fields are
# contiguous and total 2048 bytes.  This appears to mirror Android's
# `struct bootloader_message` (command[32], status[32], recovery[768],
# stage[32], reserved[1184]) — NOTE(review): confirm against the
# bootloader in use.
fields = {
    name: {"offset": start, "length": size}
    for name, start, size in (
        ("command", 0, 32),
        ("status", 32, 32),
        ("recovery", 64, 768),
        ("stage", 832, 32),
        ("reserved", 864, 1184),
    )
}

if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument("device_path",
                        help="Device path, should be misc partition.")
    subparsers = parser.add_subparsers(title="subcommands",
                                       dest="subcommand",
                                       help="sub-commands help")
    # Without this, running with no sub-command silently did nothing.
    subparsers.required = True
    # Restricting field names lets argparse reject typos with a usage
    # message instead of a KeyError traceback.
    field_names = list(fields)
    read_parser = subparsers.add_parser("read")
    read_parser.add_argument("field", choices=field_names,
                             help="field to read")
    write_parser = subparsers.add_parser("write")
    write_parser.add_argument("field", choices=field_names,
                              help="field to write")
    write_parser.add_argument("data",
                              help="data to be written; ':' is stored as a "
                                   "line feed")
    clear_parser = subparsers.add_parser("clear")
    clear_parser.add_argument("field", choices=field_names,
                              help="field to zero-fill")
    dump_parser = subparsers.add_parser("dump")
    dump_parser.add_argument("field", choices=field_names,
                             help="field to hex-dump")

    args = parser.parse_args()

    device_path = args.device_path

    # The partition must hold the whole control block (end of the last field).
    required_size = max(f["offset"] + f["length"] for f in fields.values())
    partitionSize = getPartitionSize(device_path)
    if partitionSize < required_size:
        print("Partition size %d too small." % partitionSize, file=sys.stderr)
        sys.exit(1)

    field_offset = fields[args.field]["offset"]
    field_length = fields[args.field]["length"]

    if args.subcommand in ("read", "dump"):
        buf = readData(device_path, field_offset, field_length)
        if buf is None:
            print("Error: cannot read from %s" % device_path, file=sys.stderr)
            sys.exit(3)
        if args.subcommand == "read":
            # Fields are NUL-padded: drop the trailing zeros, and do not crash
            # on a field that holds non-UTF-8 garbage.
            print(buf.rstrip(b"\0").decode("utf-8", errors="replace"))
        else:
            outputReadData(buf, offset=field_offset)
    else:
        if args.subcommand == "write":
            data_orig = args.data.encode("utf-8")
            if len(data_orig) > field_length:
                print("Error: data length too long", file=sys.stderr)
                sys.exit(2)
            # ':' is accepted on the command line as a stand-in for the line
            # feed used as a separator; the rest of the field is zero-filled.
            payload = data_orig.replace(b":", b"\n").ljust(field_length, b"\0")
        else:  # "clear"
            payload = b"\0" * field_length
        written = writeData(device_path, field_offset, payload)
        if written != len(payload):
            print("Error: cannot write to %s" % device_path, file=sys.stderr)
            sys.exit(3)
