Skip to content
Snippets Groups Projects
Commit 42936187 authored by Merlo, Jason's avatar Merlo, Jason
Browse files

Fixed CSV export, added polar tracker widget

parent 7841fbb3
No related branches found
No related tags found
No related merge requests found
...@@ -15,6 +15,7 @@ from pyqtgraph import QtCore ...@@ -15,6 +15,7 @@ from pyqtgraph import QtCore
from pyratk.formatting import warning from pyratk.formatting import warning
import numpy as np import numpy as np
import datetime import datetime
import os
# COMPRESSION OPTIONS # COMPRESSION OPTIONS
...@@ -296,9 +297,18 @@ class DataManager(MuxBuffer): ...@@ -296,9 +297,18 @@ class DataManager(MuxBuffer):
Will save trajectory_samples as ground truth if it exists. Will save trajectory_samples as ground truth if it exists.
""" """
print('source: ', type(self.source))
print('vdaq: ', type(VirtualDAQ))
if type(self.source) == type(self.virt_daq):
print('(data_mgr) Pre-loading buffer...')
self.source.load_buffer()
# Reshape array to hold contiguous samples, not chunks # Reshape array to hold contiguous samples, not chunks
data = self.source.ts_buffer.data data = self.source.ts_buffer.data
print('data.shape:', data.shape)
chunk_size = data.shape[2] chunk_size = data.shape[2]
num_chunks = data.shape[0] num_chunks = data.shape[0]
num_channels = data.shape[1] num_channels = data.shape[1]
...@@ -316,8 +326,8 @@ class DataManager(MuxBuffer): ...@@ -316,8 +326,8 @@ class DataManager(MuxBuffer):
# Add timestamp data per sample (to match NI readout) # Add timestamp data per sample (to match NI readout)
start = 0 start = 0
step = self.source.sample_interval step = self.source.sample_period
stop = new_data.shape[0] * step stop = num_chunks * step
sample_times = np.array([np.linspace(start, stop, new_data.shape[0])]).transpose() sample_times = np.array([np.linspace(start, stop, new_data.shape[0])]).transpose()
print('sample_times.shape:', sample_times.shape) print('sample_times.shape:', sample_times.shape)
...@@ -335,7 +345,7 @@ class DataManager(MuxBuffer): ...@@ -335,7 +345,7 @@ class DataManager(MuxBuffer):
header_1 = ('Timestamp', date_str) * num_channels header_1 = ('Timestamp', date_str) * num_channels
header.append(','.join(header_1)) header.append(','.join(header_1))
header_2 = ('Interval', str(self.source.sample_interval))\ header_2 = ('Interval', str(self.source.sample_period))\
* num_channels * num_channels
header.append(','.join(header_2)) header.append(','.join(header_2))
...@@ -364,6 +374,9 @@ class DataManager(MuxBuffer): ...@@ -364,6 +374,9 @@ class DataManager(MuxBuffer):
# Does not exist, create new entry # Does not exist, create new entry
try: try:
samples_name = '{}.csv'.format(name) samples_name = '{}.csv'.format(name)
base_name = '/'.join(name.split('/')[:-1])
if not os.path.exists(base_name):
os.makedirs(base_name)
np.savetxt(samples_name, new_data, np.savetxt(samples_name, new_data,
header=header, fmt='%g', delimiter=',', comments='') header=header, fmt='%g', delimiter=',', comments='')
except Exception as e: except Exception as e:
......
...@@ -46,8 +46,11 @@ class VirtualDAQ(daq.DAQ): ...@@ -46,8 +46,11 @@ class VirtualDAQ(daq.DAQ):
self.sample_chunk_size = ds.attrs["sample_size"] self.sample_chunk_size = ds.attrs["sample_size"]
self.daq_type = ds.attrs["daq_type"].decode('utf-8') self.daq_type = ds.attrs["daq_type"].decode('utf-8')
self.num_channels = ds.attrs["num_channels"] self.num_channels = ds.attrs["num_channels"]
# Sample period (time between sample chunks being returned from DAQ)
self.sample_period = self.sample_chunk_size / self.sample_rate self.sample_period = self.sample_chunk_size / self.sample_rate
# Create data buffers # Create data buffers
length = 4096 length = 4096
shape = (self.num_channels, self.sample_chunk_size) shape = (self.num_channels, self.sample_chunk_size)
...@@ -66,6 +69,36 @@ class VirtualDAQ(daq.DAQ): ...@@ -66,6 +69,36 @@ class VirtualDAQ(daq.DAQ):
shape = (3, 3) # State matrix shape shape = (3, 3) # State matrix shape
self.ts_trajectory = TimeSeries(length, shape) self.ts_trajectory = TimeSeries(length, shape)
def load_buffer(self):
"""
Pre-load all samples into buffer.
Required for exporting dataset to a CSV. May also improve performance.
"""
print('Preloading buffer...')
if self.ds:
next_index = 0
while next_index < self.ds.shape[0]:
# Read in samples from dataset
try:
self.data = self.ds[self.sample_index]
except IndexError:
print("Invalid sample index:", self.sample_index)
if hasattr(self, 'ts'):
self._append_trajectory(self.sample_index)
self.ts_trajectory.append(self.trajectory_data)
self.ts_buffer.append(self.data)
next_index += 1
if next_index % 100 == 0:
progress = 100 * next_index / self.ds.shape[0]
print('Progress: {:.2f} %'.format(progress))
else:
raise Exception('No dataset loaded.')
def get_samples(self, stride=1, loop=-1, playback_speed=30): def get_samples(self, stride=1, loop=-1, playback_speed=30):
"""Read sample from dataset at sampled speed, or one-by-one.""" """Read sample from dataset at sampled speed, or one-by-one."""
......
# -*- coding: utf-8 -*-
"""
Radar datatypes
Author: Jason Merlo
Maintainer: Jason Merlo (merlojas@msu.edu)
"""
from collections import namedtuple
from dataclasses import dataclass
from pyratk.datatypes.geometry import Point
TransmitterTuple = namedtuple('Transmitter', ['location', 'pulses'])
ReceiverTuple = namedtuple('Receiver', ['daq_index', 'location'])
# Detection = namedtuple('Detection', ['location', 'power', 'velocity', 'doppler'])
@dataclass
class Detection:
'''Class for keeping track of an item in inventory.'''
location: Point = Point()
power: float = 0
velocity: Point = Point()
doppler: float = 0
...@@ -94,7 +94,7 @@ class Receiver(object): ...@@ -94,7 +94,7 @@ class Receiver(object):
self.data = None self.data = None
self.fast_fft_len=int(round(self.daq.sample_rate * self.transmitter.pulses[0].delay)) self.fast_fft_len=int(round(self.daq.sample_rate * self.transmitter.pulses[0].delay))
self.mti_window = np.transpose(np.tile(np.fft.fftshift(signal.windows.chebwin(self.slow_fft_len,at=60)),self.fast_fft_size).reshape((-1,self.slow_fft_len))) self.mti_window = np.transpose(np.tile(np.fft.fftshift(signal.windows.chebwin(self.slow_fft_len,at=60)),self.fast_fft_size).reshape((-1,self.slow_fft_len)))
def connect_signals(self): def connect_signals(self):
# self.daq.reset_signal.connect(self.reset) # self.daq.reset_signal.connect(self.reset)
...@@ -192,14 +192,14 @@ class Receiver(object): ...@@ -192,14 +192,14 @@ class Receiver(object):
# print('dc.shape',dc.shape) # print('dc.shape',dc.shape)
self.fft_mat = self.compute_fft2(self.datacube[-1], (self.slow_fft_size, self.fast_fft_size)) self.fft_mat = self.compute_fft2(self.datacube[-1], (self.slow_fft_size, self.fast_fft_size))
#self.fft_mat=np.multiply(self.fft_mat,self.mti_window) self.fft_mat=np.multiply(self.fft_mat,self.mti_window)
# print('fft_mat.shape', self.fft_mat.shape) # print('fft_mat.shape', self.fft_mat.shape)
if self.datacube[-1].shape == self.datacube[-2].shape: # if self.datacube[-1].shape == self.datacube[-2].shape:
if hasattr(self, 'zero_fft_mat'): # if hasattr(self, 'zero_fft_mat'):
self.fft_mat -= self.zero_fft_mat # self.fft_mat -= self.zero_fft_mat
else: # else:
self.zero_fft_mat = self.fft_mat # self.zero_fft_mat = self.fft_mat
# Power Thresholding # Power Thresholding
# if self.cfft_data[vmax_bin] < POWER_THRESHOLD: # if self.cfft_data[vmax_bin] < POWER_THRESHOLD:
......
# -*- coding: utf-8 -*-
"""
AP-S Tracker Class.
Author: Jason Merlo
Maintainer: Jason Merlo (merlojas@msu.edu)
"""
import numpy as np # Storing data
from pyratk.datatypes.ts_data import TimeSeries # storing data
from pyratk.datatypes.motion import StateMatrix
from pyratk.datatypes.geometry import Point
from pyratk.datatypes.radar import Detection
class ApsTracker(object):
"""Class to track detections using 4 doppler measurements."""
# === INITIALIZATION METHODS ============================================= #
def __init__(self, daq, receiver_array):
"""
Initialize tracker class.
"""
self.valid_constraints = {1: ['x', 'y', 'z'],
2: ['xy', 'xz', 'yz'],
3: []}
# copy arguments into attributes
self.daq = daq
self.receiver_array = receiver_array
self.detections = []
# Configure control signals
self.connect_control_signals()
def connect_control_signals(self):
"""Initialize control signals."""
self.receiver_array.data_available_signal.connect(self.update)
self.daq.reset_signal.connect(self.reset)
# ====== CONTROL METHODS ================================================= #
def update(self):
"""
Update position of track based on new data.
Called by data_available_signal signal in DAQ.
"""
self.detections.clear()
# Add new Detection objects to detections list
# loc is cylindrical (R, theta, Z), but Z is ignored by plot
R = np.random.rand() * 15
theta = np.random.rand() * np.pi
loc = Point(R, theta, 0.0)
new_detection = Detection(loc)
self.detections.append(new_detection)
def reset(self):
"""Reset all temporal elements."""
print("(tracker.py) Resetting tracker...")
self.detections.clear()
# class TrackerEvaluator(Object):
# def __init__():
# -*- coding: utf-8 -*-
"""
PolarTracker Widget Class.
Contains parametric graph capable of plotting a tracked object's path in the
polar coordinate system.
Author: Jason Merlo, Stavros Vakalis
Maintainer: Jason Merlo (merlojas@msu.edu)
"""
import pyqtgraph as pg
from pyqtgraph.Qt import QtGui, QtCore
import numpy as np # Used for numerical operations
import platform # Get OS for DPI scaling
from pyratk.datatypes.geometry import Point, Circle
class PolarTrackerWidget(pg.GraphicsLayoutWidget):
def __init__(self, tracker, max_range=20):
super().__init__()
"""
Initialize polar tracker widget.
tracker - Tracker object
Note: Tracker requires list of namedtouples named `detections`.
The namedtouple must contain:
- location (Point): location of detection
- power (float): power of detection
- doppler (float): Doppler velocity of detection
- velocity (Point): Velocity vector (if tracked)
"""
# Copy arguments to member variables
self.tracker = tracker
self.max_range = max_range
# Add plots to layout
self.plot = self.addPlot()
# Add polar grid lines
self.plot.addLine(x=0, pen=0.2)
self.plot.addLine(y=0, pen=0.2)
for r in range(2, self.max_range*2, 2):
circle = pg.QtGui.QGraphicsEllipseItem(-r, -r, r * 2, r * 2)
circle.setPen(pg.mkPen(0.2))
self.plot.addItem(circle)
# Add radar location marker plot
# self.radar_loc_plot = pg.ScatterPlotItem(
# size=10, pen=pg.mkPen(None), brush=pg.mkBrush(255, 255, 0, 255))
# for radar in self.tracker.radar_array:
# loc = (radar.loc.x, radar.loc.y)
# self.radar_loc_plot.addPoints(pos=[loc])
# self.plot.addItem(self.radar_loc_plot)
# Add radar detection marker plot
self.det_loc_plot = pg.ScatterPlotItem(
size=10, pen=pg.mkPen(None), brush=pg.mkBrush(255, 255, 0, 255))
self.plot.addItem(self.det_loc_plot)
# Set up plot
self.plot.setLimits(yMin=0)
self.plot.setRange(yRange=[0, self.max_range], xRange=[-self.max_range, self.max_range])
self.plot.setAspectLocked(True)
# xMin=-self.max_range, xMax=self.max_range)
self.plot.setLabel('left', text='Downrange', units='m')
self.plot.setLabel('bottom', text='Crossrange', units='m')
self.plot.setTitle('Polar Tracker')
# Remove extra margins around plot
self.ci.layout.setContentsMargins(0, 0, 0, 0)
def update(self):
'''
Draw detections on graph.
'''
self.det_loc_plot.clear()
for det in self.tracker.detections:
R = det.location.p[0]
theta = det.location.p[1]
x = R * np.cos(theta)
y = R * np.sin(theta)
self.det_loc_plot.addPoints(pos=[Point(x, y)])
def reset(self):
# self.tracker.reset()
self.update()
# === UTILITY FUNCTIONS ===================================================
def draw_circle(self, curve, cir, num_pts=100, color="AAFFFF16"):
'''
adds a Circle, c, to the plot
'''
x_list = []
y_list = []
for i in range(num_pts):
ang = 2 * np.pi * (i / num_pts)
x = (np.cos(ang) * cir.r) + cir.c.x
y = (np.sin(ang) * cir.r) + cir.c.y
x_list.append(x)
y_list.append(y)
# append first point to end to 'close' circle
x_list.append(x_list[0])
y_list.append(y_list[0])
x_pts = np.array(x_list)
y_pts = np.array(y_list)
curve.setData(x=x_pts, y=y_pts, pen=pg.mkPen(
{'color': color, 'width': 3}))
def draw_triangle(self, curve, pts, color="AAFFFF16"):
"""Create triangle object from points."""
curve.clear()
for pt in pts:
curve.append(pt)
curve.append(pts[0])
def ppm(self):
'''
pixels per meter
'''
os = platform.system().lower()
if os == 'windows':
pixels = self.frameGeometry().width() - 55.75
meters = self.plot.vb.viewRange(
)[0][1] - self.plot.vb.viewRange()[0][0]
elif os == 'darwin':
pixels = self.frameGeometry().width() - 55.75
meters = self.plot.vb.viewRange(
)[0][1] - self.plot.vb.viewRange()[0][0]
return pixels / meters
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment