Skip to content

Commit

Permalink
Re-organize queries
Browse files Browse the repository at this point in the history
With the goal to further simplify the code, this commit applies the
single responsibility principle to the query method. This is done by
removing the validation extra-functionality into a read_response method
which is now called after the query. The read_response method accepts
the expected response as optional argument so that the response can be
validated in a uniform way.

Additionally resolved several issues reported by static code analysis.
  • Loading branch information
clssn committed Mar 10, 2024
1 parent 7d85842 commit 59784df
Showing 1 changed file with 81 additions and 90 deletions.
171 changes: 81 additions & 90 deletions src/numato_gpio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Python API for Numato USB GPIO devices."""

import serial
import threading
from typing import Optional

import serial

# Edge detection
RISING = 1
Expand All @@ -19,8 +21,8 @@
devices = dict()


class NumatoGpioError(Exception):
pass
class NumatoGpioError(RuntimeError):
"""Generic error during GPIO processing."""


DISCOVER_LOCK = threading.RLock()
Expand Down Expand Up @@ -49,7 +51,7 @@ def discover(dev_files=DEFAULT_DEVICES):
with DISCOVER_LOCK:
# remove disconnected
for dev_id, dev in list(devices.items()):
if not (dev._ser and dev._ser.is_open):
if not dev.connected:
del devices[dev_id]

# discover newly connected
Expand Down Expand Up @@ -112,27 +114,32 @@ def __init__(self, device="/dev/ttyACM0"):
self._write(b"gpio notify off\r")
self._drain_ser_buffer()
self._poll_thread.start()
self._MASK_ALL_PORTS = 2 ** self.ports - 1
self._HEX_DIGITS = self.ports // 4
self._mask_all_ports = 2**self.ports - 1
self._hex_digits = self.ports // 4
self._callback = [0] * self.ports
self._edge = [None] * self.ports
self._ver = None
try:
_ = self.id
_ = self.ver
self.iodir = self._MASK_ALL_PORTS # resets iomask as well
self.iodir = self._mask_all_ports # resets iomask as well
self.notify = False
except NumatoGpioError as err:
raise NumatoGpioError(
f"Device {self.dev_file} doesn't answer like a numato device: {err}"
)
) from err

@property
def connected(self) -> bool:
"""Determine whether a serial connection to the device is established."""
return self._ser and self._ser.is_open

@property
def ver(self):
"""Return the device's version string."""
if self._ver is None:
with self._rw_lock:
self._ver = self._query_string(f"ver")
self._ver = self._query_string("ver")
return self._ver

@property
Expand All @@ -146,7 +153,8 @@ def id(self):
@id.setter
def id(self, new_id):
"""Re-program the device id to the value in new_id."""
self._query(f"id set {new_id:08x}", expected=">")
self._query(f"id set {new_id:08x}")
self._read_response("")
self._id = new_id

@property
Expand All @@ -169,7 +177,7 @@ def setup(self, port, direction):
"""Set up a single port as input or output port."""
self._check_port_range(port)
with self._rw_lock:
new_iodir = (self._iodir & ((1 << port) ^ self._MASK_ALL_PORTS)) | (
new_iodir = (self._iodir & ((1 << port) ^ self._mask_all_ports)) | (
(0 if not direction else 1) << port
)
self.iodir = new_iodir
Expand All @@ -181,8 +189,8 @@ def cleanup(self):
output port when re-connected to e.g. a grounded input signal.
"""
with self._rw_lock:
self.iomask = self._MASK_ALL_PORTS
self.iodir = self._MASK_ALL_PORTS
self.iomask = self._mask_all_ports
self.iodir = self._mask_all_ports
self.notify = False
self._ser.close()

Expand All @@ -195,7 +203,7 @@ def write(self, port, value):
with self._rw_lock:
if (self._iodir >> port) & 1:
raise NumatoGpioError("Can't write to input port")
self._state = (self._state & ((1 << port) ^ self._MASK_ALL_PORTS)) | (
self._state = (self._state & ((1 << port) ^ self._mask_all_ports)) | (
(0 if not value else 1) << port
)
self.writeall(self._state)
Expand All @@ -220,25 +228,24 @@ def adc_read(self, adc_port):
"that port does not provide an ADC."
)
with self._rw_lock:
"""On devices with more than 32 ports, adc read command **only**
accepts two-digit numbers with leading zero.
This is vaguely described at the end of "The Command Set"
in the documentation:
https://numato.com/docs/64-channel-usb-gpio-module-analog-inputs/
https://numato.com/docs/128-channel-usb-gpio-module-with-analog-inputs/
"""
# On devices with more than 32 ports, adc read command **only**
# accepts two-digit numbers with leading zero.
#
# This is vaguely described at the end of "The Command Set"
# in the documentation:
# https://numato.com/docs/64-channel-usb-gpio-module-analog-inputs/
# https://numato.com/docs/128-channel-usb-gpio-module-with-analog-inputs/
digits = 2 if self.ports > 32 else 1
query = f"adc read {adc_port:0{digits}}"
self._query(query)
resp = self._read_until(">")
try:
return int(resp[:-1])
except ValueError:
resp = self._read_response()
return int(resp)
except ValueError as err:
raise NumatoGpioError(
f"Query '{repr(query)}' returned unexpected result {repr(resp)}. "
"Expected 10 bit decimal integer."
)
) from err

@property
def can_notify(self):
Expand All @@ -254,13 +261,12 @@ def notify(self):

if not hasattr(self, "_notify"):
query = "gpio notify get"
expected = "gpio notify "
with self._rw_lock:
self._query(query, expected=expected)
response = self._read_until(">")
if response.startswith("enabled"):
self._query(query)
response = self._read_response()
if response == "gpio notify enabled":
self._notify = True
elif response.startswith("disabled"):
elif response == "gpio notify disabled":
self._notify = False
else:
raise NumatoGpioError(
Expand All @@ -282,12 +288,13 @@ def notify(self, enable):
return

query = f"gpio notify {'on' if enable else 'off'}"
expected_response = (
f"gpio notify {'enabled' if enable else 'disabled'}>"
)
expected_response = f"gpio notify {'enabled' if enable else 'disabled'}"

with self._rw_lock:
self._query(query, expected=expected_response)

self._query(query)
self._read_response(expected_response)

self._notify = enable

def add_event_detect(self, port, callback, edge=BOTH):
Expand Down Expand Up @@ -324,16 +331,15 @@ def iomask(self, mask):
each call to iodir.
"""
with self._rw_lock:
self._query(
"gpio iomask {:0{dgts}x}".format(mask, dgts=self._HEX_DIGITS),
expected=">",
)
self._query(f"gpio iomask {mask:0{self._hex_digits}x}")
self._read_response("")
self._iomask = mask

@property
def iodir(self):
"""Get the I/O direction of the device's ports."""
if not hasattr(self, "_iodir"):
self._iodir = self._MASK_ALL_PORTS
self._iodir = self._mask_all_ports
return self._iodir

@iodir.setter
Expand All @@ -345,12 +351,12 @@ def iodir(self, direction):
inputs from being written to.
"""
with self._rw_lock:
self.iomask = self._MASK_ALL_PORTS
self.iomask = self._mask_all_ports
self._query(
"gpio iodir {:0{dgts}x}".format(direction, dgts=self._HEX_DIGITS),
expected=">",
f"gpio iodir {direction:0{self._hex_digits}x}",
)
self.iomask = direction ^ self._MASK_ALL_PORTS
self._read_response("")
self.iomask = direction ^ self._mask_all_ports
self._iodir = direction

def readall(self):
Expand All @@ -373,41 +379,31 @@ def writeall(self, bits):
"""
with self._rw_lock:
self._state = bits & ~self._iodir
self._query(
"gpio writeall {:0{dgts}x}".format(self._state, dgts=self._HEX_DIGITS),
expected=">",
)
self._query(f"gpio writeall {self._state:0{self._hex_digits}x}")
self._read_response("")

EOL_BYTES = b"\r\n"

def _remove_eol(self, sequence: bytes) -> bytes:
return bytes(x for x in sequence if x not in self.EOL_BYTES)

def _query(self, query, expected=None):
def _query(self, query):
with self._rw_lock:
self._write(f"{query}\r".encode())
expected_echo = f"{query}"
try:
self._read_string(expected_echo)
self._read_string(query)
except NumatoGpioError as err:
raise NumatoGpioError(
f"Query {repr(query)} returned unexpected echo {repr(str(err))}"
)
if not expected:
return
try:
self._read_string(expected)
except NumatoGpioError as err:
raise NumatoGpioError(
f"Query {repr(query)} returned unexpected response {repr(str(err))}"
)
) from err

def _write(self, query):
try:
with self._rw_lock:
self._ser.write(query)
except serial.serialutil.SerialException:
except serial.serialutil.SerialException as err:
self._ser.close()
raise NumatoGpioError("Serial communication failure")
raise NumatoGpioError("Serial communication failure") from err

def _read_string(self, expected):
string = self._read(len(expected.encode()))
Expand All @@ -423,7 +419,7 @@ def _query_string(self, query: str) -> str:
"""
with self._rw_lock:
self._query(query)
response = self._read_until(">")[:-1]
response = self._read_response()
return response

def _read_int(self, query, bits):
Expand All @@ -433,23 +429,27 @@ def _read_int(self, query, bits):
try:
val = int(response, 16)
self._read_string(">")
except ValueError:
except ValueError as err:
raise NumatoGpioError(
f"Query '{repr(query)}' returned unexpected result "
f"{repr(response)}. Expected a {bits} bit integer in "
"hexadecimal notation."
)
) from err
except NumatoGpioError as err:
raise NumatoGpioError(
f"Unexpected string {repr(str(err))} after successful query "
f"{repr(query)}"
)
) from err
return val

def _read_until(self, end_str):
def _read_response(self, expected: Optional[str] = None):
response = ""
while not response.endswith(end_str):
response += self._read(1)
while (read_byte := self._read(1)) != ">":
response += read_byte
if expected and expected.lower() != response.lower():
raise NumatoGpioError(
f"Expected response {repr(expected)}, got {repr(response)}"
)
return response

def _read(self, num):
Expand All @@ -460,7 +460,7 @@ def _read(self, num):
self._can_read.release()
return response

def _ser_read(self, num_bytes: int) -> bytes:
def _serial_read(self, num_bytes: int) -> bytes:
response = self._ser.read(num_bytes)
return self._remove_eol(response)

Expand All @@ -474,12 +474,12 @@ def _read_notification(self):
^ ^ ^ ^
start previous value new value iodir mask
"""
self._ser_read(1)
current_value = int(self._ser_read(self.ports // 4), 16)
self._ser_read(1)
previous_value = int(self._ser_read(self.ports // 4), 16)
self._ser_read(1)
_ = int(self._ser_read(self.ports // 4), 16) # read and discard iodir
self._serial_read(1)
current_value = int(self._serial_read(self.ports // 4), 16)
self._serial_read(1)
previous_value = int(self._serial_read(self.ports // 4), 16)
self._serial_read(1)
_ = int(self._serial_read(self.ports // 4), 16) # read and discard iodir

assert current_value is not None and previous_value is not None
edges = current_value ^ previous_value
Expand Down Expand Up @@ -514,7 +514,7 @@ def _poll(self): # noqa: C901
try:
while self._ser and self._ser.is_open:

if not (b := self._ser_read(1).decode()):
if not (b := self._serial_read(1).decode()):
continue

if b != "#":
Expand All @@ -529,29 +529,20 @@ def _poll(self): # noqa: C901
except (TypeError, serial.serialutil.SerialException):
self._ser.close() # ends the polling loop and its thread


def _check_port_range(self, port):
if port not in range(self.ports):
raise NumatoGpioError(f"Port number {port} out of range.")

def _drain_ser_buffer(self):
while self._ser_read(DEVICE_BUFFER_SIZE):
while self._serial_read(DEVICE_BUFFER_SIZE):
pass

def __str__(self):
"""Return human readable string of the device's curent state."""
return (
"dev: {} | id: {} | ver: {} | ports: {} | iodir: 0x{:0{dgts}x} | "
"iomask: 0x{:0{dgts}x} | state: 0x{:0{dgts}x}".format(
self.dev_file,
self.id,
self.ver,
self.ports,
self.iodir,
self.iomask,
self._state,
dgts=self._HEX_DIGITS,
)
f"dev: {self.dev_file} | id: {self.id} | ver: {self.ver} | ports: {self.ports}"
f" | iodir: 0x{self.iodir:0{self._hex_digits}x} | "
f"iomask: 0x{self.iomask:0{self._hex_digits}x} | state: 0x{self._state:0{self._hex_digits}x}"
)

ADC_RESOLUTION = {
Expand Down

0 comments on commit 59784df

Please sign in to comment.