import time
import numpy as np
import logging
from itertools import cycle
from rich.progress import Progress
import rich
import threading
import _thread as thread
import queue
from typing import Optional, Union, Dict, Any
import psutil
from .utils.tui import rich_information
from . import config
from .utils.config import readconfig, undefaultify
from .utils.sound import parse_table, load_sounds, build_playlist
from .services.ThuAZeroService import THUA
from .services.DAQZeroService import DAQ
from .services.GCMZeroService import GCM
from .services.NICounterZeroService import NIC
def timed(fn, s, *args, **kwargs):
    """Call ``fn(*args, **kwargs)`` and interrupt it after ``s`` seconds.

    A ``threading.Timer`` calls ``_thread.interrupt_main`` once ``s`` seconds
    have elapsed, which raises ``KeyboardInterrupt`` in the main thread. The
    interrupt is swallowed here so that a timeout is silent.

    Note:
        ``interrupt_main`` always targets the *main* thread, so the timeout
        only aborts ``fn`` when ``timed`` itself runs in the main thread.
        A genuine Ctrl-C arriving while ``fn`` runs is indistinguishable from
        a timeout and is reported the same way.

    Args:
        fn: Callable to execute.
        s: Timeout in seconds.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        The return value of ``fn``, or the sentinel ``1`` if the call timed
        out or raised any exception.
    """
    timer = threading.Timer(s, thread.interrupt_main)
    timer.start()
    try:
        return fn(*args, **kwargs)
    except KeyboardInterrupt:
        # Expected path: the timer fired. Timeouts are deliberately silent.
        return 1
    except BaseException:
        # Callers rely on getting the sentinel instead of an exception, so the
        # error is still swallowed - but leave a trace for debugging.
        logging.debug("timed(): call raised; returning sentinel 1.", exc_info=True)
        return 1
    finally:
        timer.cancel()
def kill_child_processes():
    """Terminate all (recursive) child processes of the current process.

    Every child is first asked to terminate. Children that are still alive
    after a 3 second grace period are killed. Children that have already
    exited by the time they are signalled are skipped, so one vanished child
    does not abort the cleanup of the remaining ones.
    """
    try:
        parent = psutil.Process()
    except psutil.NoSuchProcess:
        return

    children = parent.children(recursive=True)
    for child in children:
        try:
            child.terminate()  # friendly termination
        except psutil.NoSuchProcess:
            pass  # child exited on its own in the meantime

    _, still_alive = psutil.wait_procs(children, timeout=3)
    for child in still_alive:
        try:
            child.kill()  # unfriendly termination
        except psutil.NoSuchProcess:
            pass  # child exited between wait_procs() and kill()
def client(
    protocolfile: str,
    playlistfile: Optional[str] = None,
    *,
    host: str = "localhost",
    save_prefix: Optional[str] = None,
    show_progress: bool = True,
    debug: bool = False,
    preview: bool = False,
    _stop_event: Optional[threading.Event] = None,
    _done_event: Optional[threading.Event] = None,
    _queue: Optional[queue.Queue] = None,
) -> Optional[Dict[str, Any]]:
    """Set up and start all services of an experiment.

    Reads the protocol, creates and configures every service listed in the
    protocol's ``use_services`` entry (THUA, GCM*, DAQ*, NIC), starts the
    camera-type services first, then the NI counter, and finally the DAQ
    services.

    Args:
        protocolfile (str): Path of the protocol file, read with ``readconfig``.
        playlistfile (Optional[str]): Playlist table passed to ``parse_table``.
            Only read when a DAQ service is used (and not in preview mode).
        host (str): Stored as the default ``host`` of the configuration. The
            protocol section of a service can override it.
        save_prefix (Optional[str]): Name of the experiment, used for the save
            folder and file names. Defaults to ``"<host>-<YYYYmmdd_HHMMSS>"``.
        show_progress (bool): If True, hand over to ``cli_progress`` after
            starting the services. If False, return the services instead.
        debug (bool): Passed as ``new_console`` when creating the GCM, DAQ and
            NIC services.
        preview (bool): Preview mode. THUA and DAQ services are skipped, the
            cameras get a ``disp_fast`` callback and a duration of 1_000_000,
            and no local camera log is initialised.
        _stop_event (threading.Event, optional): Used to stop the task from an outside thread. Defaults to None.
        _done_event (threading.Event, optional): Set to signal that the task is done/stopped to an outside thread. Defaults to None.
        _queue (queue.Queue, optional): Signal the expected duration of the task to outside funs. Defaults to None.

    Returns:
        The dictionary of started services (keyed by service name) if
        ``show_progress`` is False, otherwise None.
    """
    # load config/protocols
    prot = readconfig(protocolfile)
    logging.debug(prot)

    # NOTE: `defaults` is the package-level `config` object itself (no copy),
    # so the assignments below modify it in place.
    defaults = config
    defaults["host"] = host
    if defaults["python_exe"] is None:
        defaults["python_exe"] = "python"
    if defaults["serializer"] is None:
        defaults["serializer"] = "pickle"
    rich.print(defaults)

    # unique file name for video and node-local logs
    if save_prefix is None:
        save_prefix = f"{defaults['host']}-{time.strftime('%Y%m%d_%H%M%S')}"
    logging.info(f"Saving as {save_prefix}.")

    new_console = debug
    services = {}

    # NOTE(review): THUA and GCM are configured with prot["maxduration"] + 10
    # *before* the DAQ section below replaces a maxduration of -1 with the
    # playlist length - confirm this ordering is intended.
    if "THUA" in prot["use_services"] and not preview:
        this = defaults.copy()
        # update `this` with service specific host params
        if "host" in prot["THUA"]:
            this.update(prot["THUA"]["host"])
        thua = THUA.make(
            this["serializer"],
            this["user"],
            this["host"],
            this["python_exe"],
        )
        thua.setup(prot["THUA"]["port"], prot["THUA"]["interval"], prot["maxduration"] + 10)
        thua.init_local_logger(f"{this['savefolder']}/{save_prefix}/{save_prefix}_thu.log")
        services["THUA"] = thua

    # --- cameras: one service per key containing "GCM" ---
    gcm_keys = [key for key in prot["use_services"] if "GCM" in key]
    for gcm_cnt, gcm_key in enumerate(gcm_keys):
        this = defaults.copy()
        this.update(prot[gcm_key])
        host_is_remote = "host" in prot[gcm_key]
        if "port" not in prot[gcm_key]:
            # give every camera its own port
            prot[gcm_key]["port"] = GCM.SERVICE_PORT + gcm_cnt

        gcm = GCM.make(
            this["serializer"],
            this["user"],
            this["host"],
            this["python_exe"],
            host_is_remote=host_is_remote,
            new_console=new_console,
            port=prot[gcm_key]["port"],
        )

        cam_params = undefaultify(prot[gcm_key])
        if preview:
            maxduration = 1_000_000
            cam_params["callbacks"] = {"disp_fast": None}
        else:
            maxduration = prot["maxduration"] + 10

        # the first camera gets no suffix, later ones "_2", "_3", ...
        save_suffix = f"_{gcm_cnt+1}" if gcm_cnt > 0 else ""
        save_base = f"{this['savefolder']}/{save_prefix}/{save_prefix}{save_suffix}"
        gcm.setup(save_base, maxduration, cam_params)
        if not preview:
            gcm.init_local_logger(f"{save_base}_gcm.log")
        services[gcm_key] = gcm

    # --- DAQ: one service per key containing "DAQ" (none in preview mode) ---
    daq_keys = [] if preview else [key for key in prot["use_services"] if "DAQ" in key]
    for daq_cnt, daq_key in enumerate(daq_keys):
        this = defaults.copy()
        this.update(prot[daq_key])
        if "device" not in prot[daq_key]:
            prot[daq_key]["device"] = "Dev1"
        if "port" not in prot[daq_key]:
            prot[daq_key]["port"] = DAQ.SERVICE_PORT + daq_cnt

        if this["host"] in config["ATTENUATION"]:  # use node specific attenuation data
            attenuation = config["ATTENUATION"][this["host"]]
            logging.info(f"Using attenuation data specific to {this['host']}.")
        else:
            attenuation = config["ATTENUATION"]

        # Load/generate all stimuli specified in playlist
        fs = prot[daq_key]["samplingrate"]
        playlist = parse_table(playlistfile)
        sounds = load_sounds(
            playlist,
            fs,
            attenuation=attenuation,
            LEDamp=prot[daq_key]["ledamp"],
            stimfolder=config["stimfolder"],
        )
        sounds = [sound.astype(np.float64) for sound in sounds]

        # Generate stimulus sequence (shuffle, loop playlist)
        playlist_items, totallen = build_playlist(sounds, prot["maxduration"], fs, shuffle=prot[daq_key]["shuffle"])
        if prot["maxduration"] == -1:
            logging.info(f"Setting maxduration from playlist to {totallen}.")
            prot["maxduration"] = totallen
        playlist_items = cycle(playlist_items)

        # split analog and digital outputs
        # TODO: catch errors if channel numbers are inconsistent - sounds[ii].shape[-1] should be nb_analog+nb_digital
        digital_chans_out = prot[daq_key]["digital_chans_out"]
        nb_digital_chans_out = 0 if digital_chans_out is None else len(digital_chans_out)
        if nb_digital_chans_out > 0:
            digital_data = [snd[:, -nb_digital_chans_out:].astype(np.uint8) for snd in sounds]
            analog_data = [snd[:, :-nb_digital_chans_out] for snd in sounds]  # remove digital traces from stimset
        else:
            # Also covers an empty channel list: slicing with -0 would hand
            # all channels to the digital output and none to the analog one.
            digital_data = None
            analog_data = sounds

        daq = DAQ.make(
            this["serializer"],
            this["user"],
            this["host"],
            this["python_exe"],
            new_console=new_console,
            port=prot[daq_key]["port"],
        )
        save_suffix = f"_{daq_cnt+1}" if daq_cnt > 0 else ""
        save_base = f"{this['savefolder']}/{save_prefix}/{save_prefix}{save_suffix}"
        daq.setup(
            save_base,
            playlist_items,
            playlist,
            prot["maxduration"],
            fs,
            dev_name=prot[daq_key]["device"],
            clock_source=prot[daq_key]["clock_source"],
            nb_inputsamples_per_cycle=prot[daq_key]["nb_inputsamples_per_cycle"],
            analog_chans_in=prot[daq_key]["analog_chans_in"],
            analog_chans_in_limits=None,
            analog_chans_in_terminals=None,
            analog_chans_out=prot[daq_key]["analog_chans_out"],
            analog_chans_out_limits=None,
            analog_data_out=analog_data,
            digital_chans_out=prot[daq_key]["digital_chans_out"],
            digital_data_out=digital_data,
            metadata={
                "analog_chans_in_info": prot[daq_key]["analog_chans_in_info"],
                "analog_chans_out_info": prot[daq_key]["analog_chans_out_info"],
                "digitial_chans_out_info": prot[daq_key]["digitial_chans_out_info"],
            },
            params=undefaultify(prot[daq_key]),
        )
        daq.init_local_logger(f"{save_base}_daq.log")
        services[daq_key] = daq

    # --- NI counter ---
    # The counter service is not added to `services`; it is started
    # separately below.
    if "NIC" in prot["use_services"]:
        this = defaults.copy()
        this.update(prot["NIC"])
        # update `this` with service specific host params
        if "host" in prot["NIC"]:
            this.update(prot["NIC"]["host"])
        nic = NIC.make(
            this["serializer"],
            this["user"],
            this["host"],
            this["python_exe"],
            new_console=new_console,
            port=prot["NIC"]["port"],
        )
        nic_params = undefaultify(prot["NIC"])
        nic.setup(
            nic_params["output_channel"],
            prot["maxduration"] + 10,
            nic_params["frequency"],
            nic_params["duty_cycle"],
            nic_params,
        )
        # Own log file: must not depend on the `save_suffix` left over from the
        # camera/DAQ loops (undefined if neither ran) nor reuse the DAQ log name.
        nic.init_local_logger(f"{this['savefolder']}/{save_prefix}/{save_prefix}_nic.log")

    # display config info
    for key, s in services.items():
        rich_information(s.information(), prefix=key)

    logging.info("Starting services")
    # First, start video services - this will start acquisition or, if external triggering is enabled, arm the cameras to wait for the triggers
    # NOTE(review): with "+ 5" and nothing started below, the wait loop further
    # down lasts ~10 s rather than being skipped - confirm the sign is intended.
    time_last_cam_started = time.time() + 5  # in case no cam was initialized
    for service_name, service in services.items():
        if "GCM" in service_name or "THUA" in service_name:
            logging.info(f" {service_name}.")
            service.start()
            time_last_cam_started = time.time()
    # NOTE(review): placement of this pause (after the loop) was reconstructed
    # from a source without indentation - confirm.
    time.sleep(0.5)

    # start the counter task for triggering frames
    if "NIC" in prot["use_services"]:
        logging.info(" NI Counter service.")
        nic.start()
        time_last_cam_started = time.time()

    # Wait 5 seconds for cams to run
    if daq_keys:
        while time.time() - time_last_cam_started < 5:
            time.sleep(0.1)

    # Start DAQ services
    for service_name, service in services.items():
        if "DAQ" in service_name:
            logging.info(f" {service_name}.")
            service.start()
    logging.info("All services started.")

    if not show_progress:
        return services

    # longest expected duration across all services (at least 0)
    total = max([0] + [service.progress()["total"] for service in services.values()])
    if _queue is not None:
        _queue.put(total)
    cli_progress(services, save_prefix, _stop_event, _done_event)
    return None
def cli_progress(
    services: Dict[str, Any],
    save_prefix: str,
    stop_event: Optional[threading.Event] = None,
    done_event: Optional[threading.Event] = None,
):
    """Display progress bars for running services and clean up afterwards.

    Polls ``progress()`` of every service about once per second and shows one
    progress bar per service. The loop ends when all bars are finished or when
    ``stop_event`` is set. In the latter case ``finish()`` is called on every
    service except "THUA". Finally, all child processes are terminated.

    Args:
        services (Dict[str, Any]): Dictionary of initialized services, keyed by
            service name.
        save_prefix (str): Name of the expt. Only used for the final log message.
        stop_event (threading.Event, optional): Used to stop the task from an outside thread. Defaults to None.
        done_event (threading.Event, optional): Set to signal to an outside
            thread that the task is done. It is not set when the run was
            stopped via ``stop_event``. Defaults to None.
    """

    def stop_requested() -> bool:
        return stop_event is not None and stop_event.is_set()

    stopped_prematurely = False
    with Progress() as progress:
        tasks = {
            service_name: progress.add_task(f"[red]{service_name}", total=service.progress()["total"])
            for service_name, service in services.items()
        }

        while not progress.finished:
            for task_name, task_id in tasks.items():
                if stop_requested():
                    break
                if progress._tasks[task_id].finished:
                    continue
                try:
                    # `timed` returns the sentinel 1 on timeout/failure; the
                    # lookups below then raise and end up in the handler.
                    p = timed(services[task_name].progress, 5)
                    description = None
                    if "framenumber" in p:
                        description = f"{task_name} {p['framenumber_delta'] / p['elapsed_delta']: 7.2f} fps"
                    progress.update(task_id, completed=p["elapsed"], description=description)
                except BaseException:
                    # Kept deliberately broad (the original used a bare except).
                    # If call times out, stop progress display - this will stop
                    # the display whenever a task times out - not necessarily
                    # when a task is done.
                    progress.stop_task(task_id)
            # NOTE(review): placement of this pause (once per polling round)
            # was reconstructed from a source without indentation - confirm.
            time.sleep(1)

            if stop_requested():
                logging.info("Received STOP signal. Cancelling jobs:")
                for task_id in tasks.values():
                    progress.stop_task(task_id)
                stopped_prematurely = True
                break
        # NOTE(review): reconstructed as a final pause before the progress
        # display is closed - confirm.
        time.sleep(1)

    if stopped_prematurely:
        logging.info("Finishing jobs.")
        for service_name, service in services.items():
            logging.info(f" {service_name}")
            if service_name == "THUA":
                continue
            try:
                service.finish()
            except Exception as e:
                # best effort: keep finishing the remaining services
                logging.warning(" Failed.")
                print(e)
        # NOTE(review): reconstructed as running once after the loop above,
        # not once per service - confirm.
        logging.info(" done.")
        time.sleep(4)

    # Signal regular completion. Previously this was skipped whenever no
    # stop_event was supplied, so a waiter on done_event was never released.
    if done_event is not None and not stop_requested():
        done_event.set()

    logging.info("Cleaning up jobs.")
    kill_child_processes()
    logging.info(f"Done with experiment {save_prefix}.")