import numpy as np
import serial
import time
import h5py
from datetime import datetime
from datetime import date
import cv2

def getUnixTimestamp():
    """Return the current Unix time in seconds, with a fractional sub-second part.

    The value is taken from ``time.time()``, which is measured from the UTC
    epoch. Building the timestamp from a naive ``datetime.now()`` instead
    would shift it by the machine's local UTC offset.

    Returns:
        np.float64: seconds since 1970-01-01 00:00:00 UTC. A NumPy scalar is
        returned (rather than a plain ``float``) so the result exposes
        ``.dtype``.
    """
    return np.float64(time.time())


def readPressure(ser):
    """Request one frame from the sensor and decode it into a 32x32 array.

    The input buffer is cleared, the request byte ``'a'`` is written, and
    ``2 * 32 * 32`` bytes are read back. Each consecutive byte pair is
    combined as ``first * 32 + second``.

    Args:
        ser: serial-port-like object providing ``reset_input_buffer()``,
            ``write(bytes)`` and ``read(n)``.

    Returns:
        A ``uint16`` array of shape (32, 32) (the transposed row-major
        layout of the decoded values), or ``None`` when fewer bytes than
        expected were received.
    """
    rows, cols = 32, 32
    expected_bytes = 2 * cols * rows

    ser.reset_input_buffer()  # Remove the confirmation 'w' sent by the sensor
    ser.write('a'.encode('utf-8'))  # Request data from the sensor
    raw = ser.read(expected_bytes)

    # An incomplete frame cannot be decoded; let the caller decide what to do.
    if len(raw) != expected_bytes:
        return None

    byte_pairs = np.frombuffer(raw, dtype=np.uint8).astype(np.uint16).reshape(-1, 2)
    values = byte_pairs[:, 0] * 32 + byte_pairs[:, 1]
    return values.reshape(rows, cols).transpose(1, 0)


def append_data(f, init, block_size, fc, ts, reading):
    """Write one frame (timestamp and 32x32 reading) at index ``fc`` of an HDF5 file.

    On the first call (``init`` false) the datasets ``frame_count``, ``ts``
    and ``pressure`` are created with ``block_size`` rows. Whenever ``fc``
    falls outside the current row count, ``ts`` and ``pressure`` are grown
    by at least ``block_size`` rows. The file is flushed after every frame.

    Args:
        f: open, writable h5py file (or group) object.
        init: False on the first call so the datasets get created; True
            afterwards.
        block_size: number of rows allocated initially and added on growth.
        fc: row index at which this frame is stored; also written to
            ``frame_count[0]``.
        ts: timestamp of the frame; its dtype defines the ``ts`` dataset dtype.
        reading: array holding 32*32 values.
    """
    if not init:
        f.create_dataset('frame_count', (1,), dtype=np.uint32)
        # maxshape=None on axis 0 makes the datasets unlimited along the frame
        # axis; without it resize() cannot grow them past the initial size.
        f.create_dataset('ts', (block_size,), maxshape=(None,),
                         dtype=np.asarray(ts).dtype, chunks=True)
        f.create_dataset('pressure', (block_size, 32, 32), maxshape=(None, 32, 32),
                         dtype=np.int32, chunks=True)

    # Check size. Use >= rather than ==: fc may step past the boundary when
    # the caller advances it without writing (e.g. a dropped frame).
    oldSize = f['ts'].shape[0]
    if fc >= oldSize:
        newSize = max(oldSize + block_size, fc + 1)
        f['ts'].resize(newSize, axis=0)
        f['pressure'].resize(newSize, axis=0)

    f['frame_count'][0] = fc
    f['ts'][fc] = ts
    f['pressure'][fc] = reading.reshape(32, 32)

    f.flush()

def viz_data(pressure, pressure_min, pressure_max, use_log):
    """Render a pressure frame as a 500x500 JET-colormapped image.

    Values are scaled linearly so that ``pressure_min`` maps to 0 and
    ``pressure_max`` maps to 1, then clipped to [0, 1]. With ``use_log`` the
    scaled value ``p`` is replaced by ``log(p + 1) / log(2)``.

    Args:
        pressure: array of raw readings.
        pressure_min: raw value mapped to the bottom of the color scale.
        pressure_max: raw value mapped to the top of the color scale.
        use_log: apply the logarithmic curve when true.

    Returns:
        The color image produced by ``cv2.applyColorMap`` after resizing to
        (500, 500).
    """
    span = pressure_max - pressure_min
    scaled = np.clip((pressure.astype(np.float32) - pressure_min) / span, 0, 1)

    if use_log:
        scaled = np.log(scaled + 1) / np.log(2.0)

    # Clip again before quantising so rounding in the log step cannot overflow uint8.
    gray = (np.clip(scaled, 0, 1) * 255).astype('uint8')
    colored = cv2.applyColorMap(gray, cv2.COLORMAP_JET)
    return cv2.resize(colored, (500, 500))


def main():
    """Record pressure frames from the serial sensor into an HDF5 file.

    Frames are requested in a loop for as long as the serial port stays
    open. Each successfully read frame is appended to the file and, when
    visualisation is enabled, shown in an OpenCV window; pressing Esc in
    that window stops the recording.

    Raises:
        RuntimeError: if the serial port reports that it is not open.
    """
    path = './recordings' + str(time.time()) + '.hdf5'

    block_size = 1024
    fc = 0
    pressure_min = 500
    pressure_max = 700
    use_log = True
    viz = True

    port = 'COM3'  # update serial port
    baudrate = 500000

    # Context managers close the port and the HDF5 file even when the loop is
    # left through an exception (e.g. Ctrl+C). The port is opened first so a
    # failed open does not leave an empty recording file behind.
    with serial.Serial(port, baudrate=baudrate, timeout=1.0) as ser, \
            h5py.File(path, 'w') as f:
        # Explicit raise instead of assert: asserts are stripped under -O.
        if not ser.is_open:
            raise RuntimeError('Failed to open COM port!')

        init = False
        try:
            while ser.is_open:
                # fc advances on every request, including failed reads, so a
                # dropped frame leaves an unwritten row at its index.
                fc += 1
                ts = getUnixTimestamp()
                reading = readPressure(ser)

                if reading is None:
                    continue

                append_data(f, init, block_size, fc, ts, reading)
                init = True

                if viz:
                    im = viz_data(reading, pressure_min, pressure_max, use_log)
                    cv2.imshow('VizualizerTouch', im)
                    if cv2.waitKey(1) & 0xff == 27:  # 27 == Esc
                        break
        finally:
            if viz:
                cv2.destroyAllWindows()


# Start recording only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()