Make noisicaa.instrument_db pylint and mypy clean.

looper
Ben Niemann 5 years ago
parent fd9820d3e9
commit 794471af3c

@ -18,13 +18,14 @@
#
# @end:license
from .client import InstrumentDBClientMixin
from .client import InstrumentDBClient
from .instrument_description import (
InstrumentDescription,
Property,
parse_uri,
)
from .mutations import (
Mutation,
AddInstrumentDescription,
RemoveInstrumentDescription,
)

@ -20,71 +20,76 @@
#
# @end:license
# TODO: mypy-unclean
import asyncio
import logging
from typing import Dict, Set, List, Iterable # pylint: disable=unused-import
from noisicaa import core
from noisicaa.core import ipc
from . import mutations
from . import instrument_description
logger = logging.getLogger(__name__)
class InstrumentDBClientMixin(object):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._stub = None
self._session_id = None
self._instruments = {}
class InstrumentDBClient(object):
def __init__(self, event_loop: asyncio.AbstractEventLoop, server: ipc.Server) -> None:
self.event_loop = event_loop
self.server = server
self.listeners = core.CallbackRegistry()
self.__stub = None # type: ipc.Stub
self.__session_id = None # type: str
self.__instruments = {} # type: Dict[str, instrument_description.InstrumentDescription]
@property
def instruments(self):
def instruments(self) -> Iterable[instrument_description.InstrumentDescription]:
return sorted(
self._instruments.values(), key=lambda i: i.display_name.lower())
self.__instruments.values(), key=lambda i: i.display_name.lower())
def get_instrument_description(self, uri: str) -> instrument_description.InstrumentDescription:
return self.__instruments[uri]
def get_instrument_description(self, uri):
return self._instrument[uri]
async def setup(self) -> None:
self.server.add_command_handler('INSTRUMENTDB_MUTATIONS', self.__handle_mutation)
async def setup(self):
await super().setup()
self.server.add_command_handler(
'INSTRUMENTDB_MUTATIONS', self.handle_mutation)
async def cleanup(self) -> None:
self.server.remove_command_handler('INSTRUMENTDB_MUTATIONS')
async def connect(self, address, flags=None):
assert self._stub is None
self._stub = ipc.Stub(self.event_loop, address)
await self._stub.connect()
self._session_id = await self._stub.call(
async def connect(self, address: str, flags: Set[str] = None) -> None:
assert self.__stub is None
self.__stub = ipc.Stub(self.event_loop, address)
await self.__stub.connect()
self.__session_id = await self.__stub.call(
'START_SESSION', self.server.address, flags)
async def disconnect(self, shutdown=False):
if self._session_id is not None:
async def disconnect(self, shutdown: bool = False) -> None:
if self.__session_id is not None:
try:
await self._stub.call('END_SESSION', self._session_id)
await self.__stub.call('END_SESSION', self.__session_id)
except ipc.ConnectionClosed:
logger.info("Connection already closed.")
self._session_id = None
self.__session_id = None
if self._stub is not None:
if self.__stub is not None:
if shutdown:
await self.shutdown()
await self._stub.close()
self._stub = None
await self.__stub.close()
self.__stub = None
async def shutdown(self):
await self._stub.call('SHUTDOWN')
async def shutdown(self) -> None:
await self.__stub.call('SHUTDOWN')
async def start_scan(self):
return await self._stub.call('START_SCAN', self._session_id)
async def start_scan(self) -> None:
await self.__stub.call('START_SCAN', self.__session_id)
def handle_mutation(self, mutation_list):
def __handle_mutation(self, mutation_list: List[mutations.Mutation]) -> None:
for mutation in mutation_list:
logger.info("Mutation received: %s", mutation)
if isinstance(mutation, mutations.AddInstrumentDescription):
self._instruments[mutation.description.uri] = mutation.description
self.__instruments[mutation.description.uri] = mutation.description
else:
raise ValueError(mutation)

@ -30,29 +30,13 @@ from . import process
from . import client
class TestClientImpl(object):
def __init__(self, event_loop):
super().__init__()
self.event_loop = event_loop
self.server = ipc.Server(self.event_loop, 'client', socket_dir=TEST_OPTS.TMP_DIR)
async def setup(self):
await self.server.setup()
async def cleanup(self):
await self.server.cleanup()
class TestClient(client.InstrumentDBClientMixin, TestClientImpl):
pass
class InstrumentDBClientTest(unittest.AsyncTestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.process = None
self.process_task = None
self.client_server = None
self.client = None
async def setup_testcase(self):
@ -61,7 +45,10 @@ class InstrumentDBClientTest(unittest.AsyncTestCase):
await self.process.setup()
self.process_task = self.loop.create_task(self.process.run())
self.client = TestClient(self.loop)
self.client_server = ipc.Server(self.loop, 'client', socket_dir=TEST_OPTS.TMP_DIR)
await self.client_server.setup()
self.client = client.InstrumentDBClient(self.loop, self.client_server)
await self.client.setup()
await self.client.connect(self.process.server.address)
@ -69,6 +56,8 @@ class InstrumentDBClientTest(unittest.AsyncTestCase):
if self.client is not None:
await self.client.disconnect(shutdown=True)
await self.client.cleanup()
if self.client_server is not None:
await self.client_server.cleanup()
if self.process is not None:
if self.process_task is not None:
await self.process.shutdown()

@ -20,11 +20,9 @@
#
# @end:license
# TODO: mypy-unclean
import enum
import urllib.parse
from typing import Callable
from typing import Callable, Dict, Any
from noisicaa import node_db
@ -49,24 +47,30 @@ class Property(enum.Enum):
class InstrumentDescription(object):
def __init__(self, uri, path, display_name, properties):
def __init__(
self,
uri: str,
path: str,
display_name: str,
properties: Dict[Property, Any],
) -> None:
self.uri = uri
self.path = path
self.display_name = display_name
self.properties = properties
@property
def format(self):
def format(self) -> str:
return urllib.parse.urlparse(self.uri).scheme
def parse_uri(
uri: str,
get_node_description: Callable[[str], node_db.NodeDescription]) -> node_db.NodeDescription:
fmt, _, path, _, args, _ = urllib.parse.urlparse(uri)
fmt, _, path, _, query, _ = urllib.parse.urlparse(uri)
path = urllib.parse.unquote(path)
if args:
args = dict(urllib.parse.parse_qsl(args, strict_parsing=True))
if query:
args = dict(urllib.parse.parse_qsl(query, strict_parsing=True))
else:
args = {}

@ -20,9 +20,7 @@
#
# @end:license
# TODO: mypy-unclean
# TODO: pylint-unclean
import asyncio
import os
import os.path
import logging
@ -31,6 +29,7 @@ import queue
import sys
import threading
import time
from typing import Any, Callable, Dict, List, Set, Iterable # pylint: disable=unused-import
from noisicaa import core
from noisicaa import instrument_db
@ -48,51 +47,56 @@ class ScanAborted(Exception):
class InstrumentDB(object):
VERSION = 2
def __init__(self, event_loop, cache_dir):
self.event_loop = event_loop
self.cache_dir = cache_dir
def __init__(self, event_loop: asyncio.AbstractEventLoop, cache_dir: str) -> None:
self.listeners = core.CallbackRegistry()
self.instruments = None
self.file_map = None
self.last_scan_time = None
self.scan_thread = None
self.scan_commands = queue.Queue()
self.stopping = threading.Event()
self.__event_loop = event_loop
self.__cache_dir = cache_dir
self.__instruments = None # type: Dict[str, instrument_db.InstrumentDescription]
self.__file_map = None # type: Dict[str, float]
self.__last_scan_time = None # type: float
self.__scan_thread = None # type: threading.Thread
self.__scan_commands = queue.Queue() # type: queue.Queue
self.__stopping = threading.Event() # type: threading.Event
@property
def last_scan_time(self) -> float:
return self.__last_scan_time
def setup(self):
if not os.path.isdir(self.cache_dir):
os.makedirs(self.cache_dir)
def setup(self) -> None:
if not os.path.isdir(self.__cache_dir):
os.makedirs(self.__cache_dir)
cache_data = self.load_cache(self.cache_path, None)
cache_data = self.__load_cache(self.__cache_path, None)
if cache_data is not None:
logger.info("Loaded cached instrument database.")
self.instruments = cache_data['instruments']
self.file_map = cache_data['file_map']
self.last_scan_time = cache_data['last_scan_time']
logger.info("%d instruments.", len(self.instruments))
logger.info("last scan: %s.", time.ctime(self.last_scan_time))
self.__instruments = cache_data['instruments']
self.__file_map = cache_data['file_map']
self.__last_scan_time = cache_data['last_scan_time']
logger.info("%d instruments.", len(self.__instruments))
logger.info("last scan: %s.", time.ctime(self.__last_scan_time))
else:
logger.info("Starting with empty instrument database.")
self.instruments = {}
self.file_map = {}
self.last_scan_time = 0
self.scan_thread = threading.Thread(target=self.scan_main)
self.scan_thread.start()
def cleanup(self):
if self.scan_thread is not None:
self.scan_commands.put(('STOP',))
self.stopping.set()
self.scan_thread.join()
self.scan_thread = None
def add_mutations_listener(self, callback):
self.__instruments = {}
self.__file_map = {}
self.__last_scan_time = 0
self.__scan_thread = threading.Thread(target=self.__scan_main)
self.__scan_thread.start()
def cleanup(self) -> None:
if self.__scan_thread is not None:
self.__scan_commands.put(('STOP',))
self.__stopping.set()
self.__scan_thread.join()
self.__scan_thread = None
def add_mutations_listener(
self, callback: Callable[[List[instrument_db.Mutation]], None]) -> None:
return self.listeners.add('db_mutations', callback)
def load_cache(self, path, default):
def __load_cache(self, path: str, default: Dict[str, Any]) -> Dict[str, Any]:
if not os.path.isfile(path):
return default
@ -107,7 +111,7 @@ class InstrumentDB(object):
return cached.get('data', default)
def store_cache(self, path, data):
def __store_cache(self, path: str, data: Dict[str, Any]) -> None:
cached = {
'version': self.VERSION,
'data': data,
@ -119,30 +123,30 @@ class InstrumentDB(object):
os.replace(path + '.new', path)
@property
def cache_path(self):
return os.path.join(self.cache_dir, 'instrument_db.cache')
def __cache_path(self) -> str:
return os.path.join(self.__cache_dir, 'instrument_db.cache')
def publish_scan_state(self, state, *args):
self.event_loop.call_soon_threadsafe(
def __publish_scan_state(self, state: str, *args) -> None:
self.__event_loop.call_soon_threadsafe(
self.listeners.call, 'scan-state', state, *args)
def add_instruments(self, descriptions):
def __add_instruments(self, descriptions: List[instrument_db.InstrumentDescription]) -> None:
for description in descriptions:
self.instruments[description.uri] = description
self.__instruments[description.uri] = description
self.listeners.call(
'db_mutations',
[instrument_db.AddInstrumentDescription(description)
for description in descriptions])
def scan_main(self):
def __scan_main(self) -> None:
try:
while not self.stopping.is_set():
cmd, *args = self.scan_commands.get()
while not self.__stopping.is_set():
cmd, *args = self.__scan_commands.get()
if cmd == 'STOP':
break
elif cmd == 'SCAN':
self.do_scan(*args)
self.__do_scan(*args)
else:
raise ValueError(cmd)
@ -150,37 +154,37 @@ class InstrumentDB(object):
sys.stdout.flush()
sys.excepthook(*sys.exc_info())
sys.stderr.flush()
os._exit(1)
os._exit(1) # pylint: disable=protected-access
def update_cache(self):
def __update_cache(self) -> None:
cache_data = {
'instruments': self.instruments,
'file_map': self.file_map,
'last_scan_time': self.last_scan_time,
'instruments': self.__instruments,
'file_map': self.__file_map,
'last_scan_time': self.__last_scan_time,
}
self.store_cache(self.cache_path, cache_data)
self.__store_cache(self.__cache_path, cache_data)
def do_scan(self, search_paths, incremental):
def __do_scan(self, search_paths: List[str], incremental: bool) -> None:
try:
file_list = self.collect_files(search_paths, incremental)
self.scan_files(file_list)
self.last_scan_time = time.time()
self.update_cache()
file_list = self.__collect_files(search_paths, incremental)
self.__scan_files(file_list)
self.__last_scan_time = time.time()
self.__update_cache()
except ScanAborted:
logger.warning("Scan was aborted.")
self.publish_scan_state('aborted')
self.__publish_scan_state('aborted')
def collect_files(self, search_paths, incremental):
def __collect_files(self, search_paths: List[str], incremental: bool) -> List[str]:
logger.info("Collecting files (incremental=%s)", incremental)
self.publish_scan_state('prepare')
self.__publish_scan_state('prepare')
seen_files = set()
seen_files = set() # type: Set[str]
file_list = []
for root_path in search_paths:
logger.info("Collecting files from %s", root_path)
for dname, dirs, files in os.walk(root_path):
if self.stopping.is_set():
for dname, _, files in os.walk(root_path):
if self.__stopping.is_set():
raise ScanAborted
for fname in sorted(files):
@ -189,7 +193,7 @@ class InstrumentDB(object):
if path in seen_files:
continue
if incremental and os.path.getmtime(path) == self.file_map.get(path, -1):
if incremental and os.path.getmtime(path) == self.__file_map.get(path, -1):
continue
seen_files.add(path)
@ -202,7 +206,7 @@ class InstrumentDB(object):
return file_list
def scan_files(self, file_list):
def __scan_files(self, file_list: List[str]) -> None:
scanners = [
sample_scanner.SampleScanner(),
soundfont_scanner.SoundFontScanner(),
@ -210,32 +214,32 @@ class InstrumentDB(object):
batch = []
for idx, path in enumerate(file_list):
if self.stopping.is_set():
if self.__stopping.is_set():
raise ScanAborted
logger.info("Scanning file %s...", path)
self.publish_scan_state('scan', idx, len(file_list))
self.__publish_scan_state('scan', idx, len(file_list))
for scanner in scanners:
for description in scanner.scan(path):
batch.append(description)
if len(batch) > 10:
self.event_loop.call_soon_threadsafe(
self.add_instruments, list(batch))
self.__event_loop.call_soon_threadsafe(
self.__add_instruments, list(batch))
batch.clear()
self.file_map[path] = os.path.getmtime(path)
self.__file_map[path] = os.path.getmtime(path)
if batch:
self.event_loop.call_soon_threadsafe(
self.add_instruments, list(batch))
self.__event_loop.call_soon_threadsafe(
self.__add_instruments, list(batch))
batch.clear()
self.publish_scan_state('complete')
self.__publish_scan_state('complete')
def initial_mutations(self):
for uri, description in sorted(self.instruments.items()):
def initial_mutations(self) -> Iterable[instrument_db.Mutation]:
for _, description in sorted(self.__instruments.items()):
yield instrument_db.AddInstrumentDescription(description)
def start_scan(self, search_paths, incremental):
self.scan_commands.put(('SCAN', list(search_paths), incremental))
def start_scan(self, search_paths: List[str], incremental: bool) -> None:
self.__scan_commands.put(('SCAN', list(search_paths), incremental))

@ -31,8 +31,8 @@ from . import db
logger = logging.getLogger(__name__)
class NodeDBTest(unittest.AsyncTestCase):
async def test_foo(self):
class InstrumentDBTest(unittest.AsyncTestCase):
async def test_scan(self):
complete = asyncio.Event(loop=self.loop)
def state_listener(state, *args):
if state == 'complete':

@ -20,13 +20,14 @@
#
# @end:license
from typing import Iterable
import urllib.parse
class Scanner(object):
def __init__(self):
pass
from noisicaa import instrument_db
def make_uri(self, fmt, path, **kwargs):
class Scanner(object):
def make_uri(self, fmt: str, path: str, **kwargs) -> str:
return urllib.parse.urlunparse((
fmt,
None,
@ -35,5 +36,5 @@ class Scanner(object):
urllib.parse.urlencode(sorted((k, str(v)) for k, v in kwargs.items()), True),
None))
def scan(self, path):
def scan(self, path: str) -> Iterable[instrument_db.InstrumentDescription]:
raise NotImplementedError

@ -20,42 +20,39 @@
#
# @end:license
# TODO: mypy-unclean
import logging
import time
from typing import cast, List, Set # pylint: disable=unused-import
from noisicaa import constants
from noisicaa import core
from .private import db
from . import process_base
from . import mutations as mutations_lib # pylint: disable=unused-import
logger = logging.getLogger(__name__)
class Session(core.CallbackSessionMixin, core.SessionBase):
def __init__(self, client_address, flags, **kwargs):
def __init__(self, client_address: str, flags: Set[str], **kwargs) -> None:
super().__init__(callback_address=client_address, **kwargs)
self.flags = flags or set()
self.pending_mutations = []
async def setup(self):
await super().setup()
self.__flags = flags or set()
self.__pending_mutations = [] # type: List[mutations_lib.Mutation]
def callback_connected(self):
def callback_connected(self) -> None:
logger.info(
"Client callback connection established, sending %d pending mutations.",
len(self.pending_mutations))
self.publish_mutations(self.pending_mutations)
self.pending_mutations.clear()
len(self.__pending_mutations))
self.publish_mutations(self.__pending_mutations)
self.__pending_mutations.clear()
def publish_mutations(self, mutations):
def publish_mutations(self, mutations: List[mutations_lib.Mutation]) -> None:
if not mutations:
return
if not self.callback_alive:
self.pending_mutations.extend(mutations)
self.__pending_mutations.extend(mutations)
return
self.async_callback('INSTRUMENTDB_MUTATIONS', list(mutations))
@ -64,15 +61,16 @@ class Session(core.CallbackSessionMixin, core.SessionBase):
class InstrumentDBProcess(process_base.InstrumentDBProcessBase):
session_cls = Session
def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self.db = None
self.db = None # type: db.InstrumentDB
self.search_paths = [
'/usr/share/sounds/sf2/',
'/data/instruments/',
]
async def setup(self):
async def setup(self) -> None:
await super().setup()
self.db = db.InstrumentDB(self.event_loop, constants.CACHE_DIR)
@ -81,25 +79,27 @@ class InstrumentDBProcess(process_base.InstrumentDBProcessBase):
if time.time() - self.db.last_scan_time > 3600:
self.db.start_scan(self.search_paths, True)
async def cleanup(self):
async def cleanup(self) -> None:
if self.db is not None:
self.db.cleanup()
self.db = None
await super().cleanup()
def publish_mutations(self, mutations):
def publish_mutations(self, mutations: List[mutations_lib.Mutation]) -> None:
for session in self.sessions:
session = cast(Session, session)
session.publish_mutations(mutations)
async def session_started(self, session):
async def session_started(self, session: core.SessionBase) -> None:
session = cast(Session, session)
# Send initial mutations to build up the current pipeline
# state.
session.publish_mutations(list(self.db.initial_mutations()))
async def handle_start_scan(self, session_id):
async def handle_start_scan(self, session_id: str) -> None:
self.get_session(session_id)
return self.db.start_scan(self.search_paths, True)
self.db.start_scan(self.search_paths, True)
class InstrumentDBSubprocess(core.SubprocessMixin, InstrumentDBProcess):

@ -87,22 +87,6 @@ class AudioProcClient(
self.__app.onPipelineStatus(status)
class InstrumentDBClientImpl(object):
def __init__(self, event_loop, server):
super().__init__()
self.event_loop = event_loop
self.server = server
async def setup(self):
pass
async def cleanup(self):
pass
class InstrumentDBClient(instrument_db.InstrumentDBClientMixin, InstrumentDBClientImpl):
pass
class BaseEditorApp(object):
def __init__(self, *, process, runtime_settings, settings=None):
self.__context = ui_base.CommonContext(app=self)
@ -207,7 +191,7 @@ class BaseEditorApp(object):
instrument_db_address = await self.process.manager.call(
'CREATE_INSTRUMENT_DB_PROCESS')
self.instrument_db = InstrumentDBClient(
self.instrument_db = instrument_db.InstrumentDBClient(
self.process.event_loop, self.process.server)
await self.instrument_db.setup()
await self.instrument_db.connect(instrument_db_address)

Loading…
Cancel
Save