# -*- coding: utf-8 -*-
"""
Virtual DAQ Class.

Description:
DAQ emulation class that plays back recorded DAQ datasets transparently,
as a stand-in for a live DAQ.

Author: Jason Merlo
Maintainer: Jason Merlo (merlojas@msu.edu)
"""
from pyratk.acquisition import daq   # Extention of DAQ object
import threading                     # Used for creating thread and sync events
import time
import h5py
from pyratk.datatypes.ts_data import TimeSeries
from pyratk.datatypes.motion import StateMatrix
from pyqtgraph import QtCore
class VirtualDAQ(daq.DAQ):
    """Emulate a DAQ by playing back a recorded HDF5 dataset.

    Chunks are read from an h5py dataset selected with ``load_dataset`` and
    published through ``ts_buffer`` and ``data_available_signal``. An
    optional trajectory dataset (``load_trajectory``) is played back in step
    with the data, into ``ts_trajectory``.
    """

    def __init__(self):
        """Create a virtual DAQ with placeholder attributes until a dataset is loaded."""
        super().__init__()

        # Placeholder attributes, replaced by load_dataset()
        self.sample_rate = 1
        self.sample_chunk_size = 1
        self.daq_type = "None"
        self.num_channels = 1
        self.sample_period = self.sample_chunk_size / self.sample_rate

    @staticmethod
    def _decode_attr(value):
        """Return an HDF5 string attribute as ``str``.

        h5py may hand string attributes back either as ``bytes`` or already
        as ``str`` (depending on its version and how the file was written);
        only ``bytes`` needs decoding.
        """
        if isinstance(value, bytes):
            return value.decode('utf-8')
        return str(value)

    # === DATA SOURCES ==================================================
    def load_dataset(self, ds):
        """Select the dataset to play back and load its attributes.

        Args:
            ds: h5py dataset carrying the attributes ``sample_rate``,
                ``sample_size``, ``daq_type`` and ``num_channels``.

        Raises:
            TypeError: if ``ds`` is not an h5py dataset.
            KeyError: (from h5py) if a required attribute is missing; in
                that case no attribute of this object is modified.
        """
        if not isinstance(ds, h5py.Dataset):
            raise TypeError(
                "load_dataset expects a h5py dataset type, got {}".format(
                    type(ds)))

        # Read everything first so a missing attribute leaves the previous
        # configuration intact instead of half-updated.
        attrs = ds.attrs
        sample_rate = attrs["sample_rate"]
        sample_chunk_size = attrs["sample_size"]
        daq_type = self._decode_attr(attrs["daq_type"])
        num_channels = attrs["num_channels"]

        self.ds = ds
        self.sample_rate = sample_rate
        self.sample_chunk_size = sample_chunk_size
        self.daq_type = daq_type
        self.num_channels = num_channels
        self.sample_period = self.sample_chunk_size / self.sample_rate

        # Data history: 4096 entries, each shaped (channels, chunk size)
        length = 4096
        shape = (self.num_channels, self.sample_chunk_size)
        self.ts_buffer = TimeSeries(length, shape)

        print('(VirtualDAQ) Loaded dataset:', ds.name)
        print('(VirtualDAQ) Sample period:', self.sample_period)

    def load_trajectory(self, ts):
        """Attach a trajectory dataset to be played back alongside the data.

        Args:
            ts: trajectory dataset; its ``coordinate_type`` attribute and its
                last axis are read by ``_append_trajectory``.
        """
        self.ts = ts

        # Trajectory history: 4096 entries of state-matrix shape (3, 3)
        length = 4096
        shape = (3, 3)
        self.ts_trajectory = TimeSeries(length, shape)

    def get_samples(self, stride=1, loop=-1, playback_speed=1.0):
        """Read the current chunk, publish it, and advance the play head.

        Args:
            stride: number of dataset rows to advance after this read.
            loop: -1 or 1 to sleep ``sample_period * playback_speed`` before
                publishing; 0 to single-step with no delay.
            playback_speed: multiplier applied to the sleep time.
                NOTE(review): because it multiplies the delay, values > 1
                slow playback down — confirm this is the intended meaning.

        Returns:
            ``(sample_index + stride) % len(ds) / stride < 1.0`` evaluated
            after advancing. NOTE(review): for ``stride`` smaller than the
            dataset length this is True exactly when the next call reads the
            last chunk before wrapping to index 0 — confirm what callers
            expect.

        Raises:
            RuntimeError: if no dataset has been loaded.
            ValueError: if ``loop`` is not -1, 0 or 1 (checked before any
                state is touched).
        """
        ds = getattr(self, 'ds', None)
        if not ds:
            raise RuntimeError(
                "(VirtualDAQ) Dataset source must be set to get samples")
        if loop not in (-1, 0, 1):
            raise ValueError("Value must be -1, 0, or 1.")

        # Best effort: on an out-of-range index the previously read
        # self.data is reused, and the advance step below wraps to 0.
        try:
            self.data = ds[self.sample_index]
        except IndexError:
            print("Invalid sample index:", self.sample_index)

        # Pace playback in real time, or just report a manual step
        if loop == 0:
            print('Stepped:', stride)
        else:
            time.sleep(self.sample_period * playback_speed)

        # Update every buffer exactly once *before* announcing the chunk, so
        # a handler reading ts_buffer / ts_trajectory already finds it there.
        if hasattr(self, 'ts'):
            self._append_trajectory(self.sample_index)
        self.ts_buffer.append(self.data)
        self.data_available_signal.emit((self.data, self.sample_index))

        # Advance; wrap to the beginning at the end of the dataset
        next_index = self.sample_index + stride
        if next_index < ds.shape[0]:
            self.sample_index = next_index
        else:
            self.sample_index = 0
            if hasattr(self, 'ts'):
                # Restart the trajectory history from the first chunk
                self.ts_trajectory.clear()
                self._append_trajectory(self.sample_index)
            self.reset_signal.emit()

        return (self.sample_index + stride) % ds.shape[0] / stride < 1.0

    def _append_trajectory(self, index):
        """Append the recorded trajectory state for chunk ``index``.

        Reads the trajectory at position ``index * sample_chunk_size`` along
        its last axis (presumably one entry per raw sample — confirm),
        converts it with ``StateMatrix`` and appends ``get_state().q`` to
        ``ts_trajectory`` (also stored in ``self.trajectory_data``). An
        out-of-range index is reported and skipped, leaving the previous
        ``trajectory_data`` in place.
        """
        coordinate_type = self._decode_attr(self.ts.attrs['coordinate_type'])

        try:
            state = StateMatrix(self.ts[..., index * self.sample_chunk_size],
                                coordinate_type=coordinate_type)
        except IndexError:
            print("Invalid trajectory sample index:", index)
            return

        self.trajectory_data = state.get_state().q
        self.ts_trajectory.append(self.trajectory_data)

    def reset(self):
        """Stop sampling, rewind to the first chunk, and start sampling again."""
        self.close()
        # Rewind first so the trajectory is re-seeded from chunk 0
        self.sample_index = 0
        if hasattr(self, 'ts'):
            self.ts_trajectory.clear()
            self._append_trajectory(self.sample_index)
        self.run()

    # === SAMPLING ======================================================
    def sample_loop(self):
        """Call get_samples() repeatedly until ``running`` is cleared.

        get_samples() already appends each chunk to ``ts_buffer`` and emits
        ``data_available_signal``; this loop only keeps the running sample
        counter.
        """
        while self.running:
            if self.paused:
                time.sleep(0.1)  # poll the pause flag every 100 ms
                continue

            self.get_samples()
            self.sample_num += 1

        print("Sampling thread stopped.")

    def run(self):
        """Spawn the sampling thread unless one is already running."""
        if self.running:
            print('Warning: Not starting new sampling thread; '
                  'sampling thread already running!')
            return

        self.running = True
        self.t_sampling = threading.Thread(target=self.sample_loop)

        print('Staring sampling thread')
        try:
            self.t_sampling.start()
            self.paused = False
        except RuntimeError as e:
            # Leave the object restartable if the thread failed to start
            self.running = False
            print('Error starting sampling thread: ', e)

    def start(self):
        """Alias for run()."""
        self.run()

    def close(self):
        """Stop the sampling thread (if alive), then close the base DAQ."""
        thread = getattr(self, 't_sampling', None)
        if thread is not None and thread.is_alive():
            print("Stopping sampling thread...")
            # sample_loop() only returns once `running` is cleared; joining
            # without clearing it first would block forever.
            self.running = False
            try:
                thread.join()
            except RuntimeError as e:
                # e.g. close() invoked from the sampling thread itself
                print("Error closing sampling thread: ", e)

        super().close()