Allow pure proto requests.

No more pickle.
About 15% faster for small messages, but (surprisingly) slower for large messages...
looper
Ben Niemann 2019-02-01 00:59:38 +01:00
parent ce16b966f2
commit 22550078e6
3 changed files with 85 additions and 26 deletions

View File

@ -30,9 +30,11 @@ import pickle
import pprint
import time
import traceback
from typing import cast, Any, Optional, Dict, Callable
from typing import cast, Any, Optional, Dict, Tuple, Callable, Type
import uuid
from google.protobuf import message as protobuf
from . import stats
logger = logging.getLogger(__name__)
@ -155,7 +157,7 @@ class Server(object):
self.__next_connection_id = 0
self.__server = None # type: asyncio.AbstractServer
self.__command_handlers = {} # type: Dict[str, Callable]
self.__command_handlers = {} # type: Dict[str, Tuple[Callable, Type[protobuf.Message], Type[protobuf.Message]]
self.__command_log_levels = {} # type: Dict[str, int]
self.stat_bytes_sent = None # type: stats.Counter
@ -166,9 +168,16 @@ class Server(object):
return self.__server is None
def add_command_handler(
self, cmd: str, handler: Callable[..., Any], log_level: Optional[int] = -1) -> None:
self,
cmd: str,
handler: Callable[..., Any],
request_cls: Type[protobuf.Message] = None,
response_cls: Type[protobuf.Message] = None,
*,
log_level: Optional[int] = -1
) -> None:
assert cmd not in self.__command_handlers
self.__command_handlers[cmd] = handler
self.__command_handlers[cmd] = (handler, request_cls, response_cls)
if log_level is not None:
self.__command_log_levels[cmd] = log_level
@ -245,28 +254,43 @@ class Server(object):
async def handle_command(self, command: str, payload: bytes) -> bytes:
try:
handler = self.__command_handlers[command]
handler, request_cls, response_cls = self.__command_handlers[command]
args, kwargs = self.deserialize(payload)
if request_cls is not None:
request = request_cls()
request.ParseFromString(payload)
response = response_cls()
log_level = self.__command_log_levels.get(command, logging.INFO)
if log_level >= 0:
logger.log(
log_level,
"%s(%s%s)",
command,
', '.join(str(a) for a in args),
''.join(', %s=%r' % (k, v)
for k, v in sorted(kwargs.items())))
if asyncio.iscoroutinefunction(handler):
await handler(request, response)
else:
handler(request, response)
return b'OK:' + response.SerializeToString()
if asyncio.iscoroutinefunction(handler):
result = await handler(*args, **kwargs)
else:
result = handler(*args, **kwargs)
if result is not None:
return b'OK:' + self.serialize(result)
else:
return b'OK'
args, kwargs = self.deserialize(payload)
log_level = self.__command_log_levels.get(command, logging.INFO)
if log_level >= 0:
logger.log(
log_level,
"%s(%s%s)",
command,
', '.join(str(a) for a in args),
''.join(', %s=%r' % (k, v)
for k, v in sorted(kwargs.items())))
if asyncio.iscoroutinefunction(handler):
result = await handler(*args, **kwargs)
else:
result = handler(*args, **kwargs)
if result is not None:
return b'OK:' + self.serialize(result)
else:
return b'OK'
except Exception: # pylint: disable=broad-except
return b'EXC:' + str(traceback.format_exc()).encode('utf-8')
@ -474,6 +498,24 @@ class Stub(object):
else:
raise InvalidResponseError(response)
async def proto_call(
self, cmd: str, request: protobuf.Message, response: protobuf.Message
) -> None:
payload = request.SerializeToString()
response_container = ResponseContainer(self.__event_loop)
self.__command_queue.put_nowait((cmd.encode('ascii'), payload, response_container))
serialized_response = await response_container.wait()
if serialized_response is self.CLOSE_SENTINEL:
raise ConnectionClosed(self.id)
elif serialized_response.startswith(b'OK:'):
response.ParseFromString(serialized_response[3:])
elif serialized_response.startswith(b'EXC:'):
raise RemoteException(self.__server_address, serialized_response[4:].decode('utf-8'))
else:
raise InvalidResponseError(serialized_response)
def call_sync(self, cmd: str, payload: bytes = b'') -> Any:
return self.__event_loop.run_until_complete(self.call(cmd, payload))

View File

@ -27,6 +27,7 @@ package noisicaa.pb;
import "noisicaa/audioproc/public/musical_time.proto";
message TestRequest {
optional int32 num = 2;
repeated MusicalTime t = 1;
}

View File

@ -74,17 +74,32 @@ class IPCTest(unittest.AsyncTestCase):
async with ipc.Stub(self.loop, server.address) as stub:
self.assertEqual(await stub.call('foo', 3), 4)
async def test_proto_message(self):
async with ipc.Server(self.loop, name='test', socket_dir=TEST_OPTS.TMP_DIR) as server:
async def handler(request, response):
response.num = request.num + 1
server.add_command_handler(
'foo', handler, ipc_test_pb2.TestRequest, ipc_test_pb2.TestResponse)
async with ipc.Stub(self.loop, server.address) as stub:
request = ipc_test_pb2.TestRequest()
request.num = 3
response = ipc_test_pb2.TestResponse()
await stub.proto_call('foo', request, response)
self.assertEqual(response.num, 4)
class TestSubprocess(process_manager.SubprocessMixin, process_manager.ProcessBase):
async def run(self):
quit_event = asyncio.Event(loop=self.event_loop)
self.server.add_command_handler('foo', self.msg_handler)
self.server.add_command_handler(
'foo', self.msg_handler, ipc_test_pb2.TestRequest, ipc_test_pb2.TestResponse)
self.server.add_command_handler('quit', quit_event.set)
await quit_event.wait()
async def msg_handler(self, msg):
return ipc_test_pb2.TestResponse(num=2)
async def msg_handler(self, request, response):
response.num = 2
class IPCPerfTest(unittest.AsyncTestCase):
@ -126,7 +141,8 @@ class IPCPerfTest(unittest.AsyncTestCase):
wt0 = time.perf_counter()
ct0 = time.clock_gettime(time.CLOCK_PROCESS_CPUTIME_ID)
for _ in range(num_requests):
await self.stub.call('foo', request)
response = ipc_test_pb2.TestResponse()
await self.stub.proto_call('foo', request, response)
wt = time.perf_counter() - wt0
ct = time.clock_gettime(time.CLOCK_PROCESS_CPUTIME_ID) - ct0
passes.append((wt, ct))