# -*- coding: utf-8 -*-
"""
Data Manager Class.

Controls data source selection.

Author: Jason Merlo
Maintainer: Jason Merlo (merlojas@msu.edu)
"""
import h5py  # Used for the HDF5 database
from pyratk.acquisition.mux_buffer import MuxBuffer
from pyratk.acquisition.virtual_daq import VirtualDAQ

# COMPRESSION OPTIONS
COMPRESSION = "gzip"
COMPRESSION_OPTS = 9

class DataManager(MuxBuffer):
    """
    DataManager class; extends MuxBuffer class.

    Handles multiple data input sources and aggregates them into one data
    "mux" which can select from the various input sources added.
    """

    def __init__(self, db="default.hdf5", daq=None):
        """
        Initialize DataManager class.

        Arguments:
            db (optional)
                Database file to save to and load from
            daq (optional)
                Hardware DAQ source (stored for later use)
        """
        self.db = None
        self.samples = None   # Pointer to the samples group
        self.labels = None    # Pointer to the labels group
        self.subjects = None  # Pointer to the subjects group

        # MuxBuffer attribute initialization
        self.source_list = []

        # Open virtual DAQ for playback
        self.virt_daq = VirtualDAQ()
        self.add_source(self.virt_daq)

        self.daq = daq

        # DEVEL/DEBUG
        try:
            self.open_database(db)
        except Exception as e:
            print("(DataManager) Error opening debug database:", e)

    # === DATABASE ============================================================
    def open_database(self, db_path, mode='a'):
        """
        Accept the file path of a database and open it.

        db_path: string
            Path of the file to open
        mode: string
            File access mode (default: 'a')
        """
        # If a database is already open, close it before opening the new one
        if self.db is not None:
            self.db.close()

        # Attempt to open or create a new file with the specified name
        try:
            self.db = h5py.File(db_path, mode)
        except OSError as e:
            print("(DataManager) Error opening HDF5 file:", e)

        # Attempt to open the 'samples', 'labels', and 'subjects' groups
        # Create new groups on KeyError
        try:
            self.samples = self.db["samples"]
        except KeyError:
            print("No 'samples' group found. Creating 'samples'...")
            self.samples = self.db.create_group("samples")
        except Exception as e:
            # Handle all other exceptions
            print(e)

        try:
            self.labels = self.db["labels"]
        except KeyError:
            print("No 'labels' group found. Creating 'labels'...")
            self.labels = self.db.create_group("labels")
        except Exception as e:
            # Handle all other exceptions
            print(e)

        try:
            self.subjects = self.db["subjects"]
        except KeyError:
            print("No 'subjects' group found. Creating 'subjects'...")
            self.subjects = self.db.create_group("subjects")
        except Exception as e:
            # Handle all other exceptions
            print(e)
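
    # Sketch of the resulting HDF5 layout once a database has been opened and
    # samples saved (group names come from the code above; the sample, label,
    # and subject names below are illustrative only):
    #
    #   /samples/<sample_name>                  raw time-series buffer data
    #   /labels/<label_name>/<sample_name>      hard link to /samples/<sample_name>
    #   /subjects/<subject_name>/<sample_name>  hard link to /samples/<sample_name>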

    # === DATASET CONTROL =====================================================
    def load_dataset(self, ds):
        """Load dataset into the VirtualDAQ and select it as the active source."""
        self.virt_daq.load_dataset(ds)
        self.set_source(self.virt_daq)

    def get_datasets(self):
        """Return a list of all dataset objects in the 'samples' group."""
        keys = []
        if self.db:
            for key in self.samples:
                keys.append(self.samples[key])
            return keys
        else:
            print("(DataManager) Database must be loaded before datasets are read")

    def delete_dataset(self, ds):
        """Remove dataset from the database."""
        try:
            del self.db[ds.name]
        except Exception as e:
            print("Error deleting dataset:", e)

    def save_buffer(self, name, labels, subject, notes):
        """
        Write buffer contents to a dataset with the specified 'name'.

        If no name is provided, 'sample_n' will be used, where n is the index
        of the sample relative to 0.
        """
        if "/samples/{:}".format(name) in self.db:
            ds = self.samples[name]
        else:
            # Dataset does not exist; create a new entry
            try:
                # Save buffer data
                ds = self.samples.create_dataset(
                    name, data=self.source.ts_buffer,
                    compression=COMPRESSION,
                    compression_opts=COMPRESSION_OPTS)
            except Exception as e:
                print("(DataManager) Error saving dataset:", e)

        # Create hard links to class label group and subject group
        if labels is None:
            labels_str = ''
        else:
            labels_str = []  # used for saving labels to attribute
            for label in labels:
                # Get name of label (labels may be strings or h5py groups)
                if type(label) is str:
                    label_name = label
                else:
                    label_name = label.name
                # print('Checking if',
                #       "labels/{:}/{:}".format(label_name, name), 'exists')
                if "{:}/{:}".format(label_name, name) not in self.labels:
                    if label_name not in self.labels:
                        self.labels.create_group(label_name)
                    self.labels[label_name][name] = ds
                labels_str.append(label_name.split('/')[-1])
            labels_str = ','.join(labels_str)

        if subject:
            # Get name of subject (may be a string or an h5py group)
            if type(subject) is str:
                subject_name = subject
            else:
                subject_name = subject.name
            if "{:}/{:}".format(subject_name, name) not in self.subjects:
                if subject_name not in self.subjects:
                    self.subjects.create_group(subject_name)
                self.subjects[subject_name][name] = ds

        if notes is None:
            notes = ''

        try:
            # Save attribute data
            attrs = self.samples[name].attrs
            attrs.create("sample_rate", self.source.sample_rate)
            attrs.create("sample_size", self.source.sample_size)
            attrs.create("daq_type", self.source.daq_type.encode('utf8'))
            attrs.create("num_channels", self.source.num_channels)
            attrs.create("label", labels_str.encode('utf8'))
            if subject:
                if type(subject) is str:
                    subject_name = subject
                else:
                    subject_name = subject.name
                attrs.create("subject",
                             subject_name.split('/')[-1].encode('utf8'))
            attrs.create("notes", notes.encode('utf8'))
        except Exception as e:
            print("(DataManager) Error saving attributes:", e)

    def remove_attributes(self, name, labels, subject):
        """
        Remove the hard links created in the attribute groups.

        Currently the only hard-linked attributes are: labels, subject
        """
        # Attempt to open dataset
        if "/samples/{:}".format(name) in self.db:
            ds = self.samples[name]
        else:
            print("(DataManager) Error attempting to remove attributes.")
            print("(DataManager) Dataset does not exist.")

        # Attempt to delete sample from label groups
        for label in labels:
            try:
                self.labels[label].pop(name)
            except KeyError:
                print("(DataManager) Error attempting to remove sample "
                      "hard link in label group.")
                print("(DataManager) Sample does not exist in dataset.")
                print("Name: {:}".format(name))

        # Attempt to delete sample from subject group
        if subject:
            try:
                self.subjects[subject].pop(name)
            except KeyError:
                print("(DataManager) Error attempting to remove sample "
                      "hard link in subject group.")
                print("(DataManager) Sample does not exist in dataset.")
                print("Name: {:}".format(name))

    # --- labels --- #
    def get_labels(self):
        """Return a list of all label objects in the 'labels' group."""
        keys = []
        if self.db:
            for key in self.labels:
                keys.append(self.labels[key])
            return keys
        else:
            print("(DataManager) Database must be loaded before labels are read")

    def add_label(self, label):
        """Add a label group to the database."""
        try:
            self.labels.create_group(label)
        except Exception as e:
            # Handle all exceptions
            print('(DataManager)', e)

    def remove_label(self):
        """TODO: implement remove_label."""
        pass

    # --- subjects --- #
    def get_subjects(self):
        """Return a list of all subject objects in the database."""
        keys = []
        if self.db:
            for key in self.subjects:
                keys.append(self.subjects[key])
            return keys
        else:
            print("(DataManager) Database must be loaded before subjects are read")

    def add_subject(self, subject):
        """Add a subject group to the database."""
        try:
            self.subjects.create_group(subject)
        except Exception as e:
            # Handle all exceptions
            print('(DataManager)', e)

    def remove_subject(self):
        """TODO: implement remove_subject."""
        pass

    # === DATA CONTROL ========================================================
    def reset(self):
        """Reset the DAQ manager; clear all data and graphs."""
        self.source.reset()

    def pause_toggle(self):
        """Toggle the paused state of the DAQ manager."""
        # Virtual DAQ needs a dataset loaded before running
        if self.source is not self.virt_daq or self.virt_daq.ds is not None:
            if self.paused is True:
                # self.reset()  # used for arbitrary dt
                self.paused = False
            else:
                self.paused = True

    def close(self):
        """Close all data sources registered with the DAQ manager."""
        for source in self.source_list:
            source.close()

    # === PROPERTIES ==========================================================
    @property
    def reset_flag(self):
        """Proxy the reset flag of the currently selected source."""
        return self.source.reset_flag

    @reset_flag.setter
    def reset_flag(self, a):
        self.source.reset_flag = a

    @property
    def ts_buffer(self):
        """Proxy the time-series buffer of the currently selected source."""
        return self.source.ts_buffer
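

# Minimal usage sketch (not part of the original file): exercises only the
# playback path through the VirtualDAQ. 'example.hdf5' is an illustrative
# file name, and at least one recorded sample is assumed to exist in it.
if __name__ == '__main__':
    manager = DataManager(db="example.hdf5")
    datasets = manager.get_datasets()
    if datasets:
        # Load the first recorded sample into the VirtualDAQ for playback
        manager.load_dataset(datasets[0])
    manager.close()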