Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
vici: add asynchronous session in python module
Add a new Python 3.6 package allowing vici to send asynchronous requests. This feature is only available in Python 3 so we add a condition in the setup.py in order to not break the Python 2 package.
- Loading branch information
1 parent
9189626
commit d5167d8
Showing
7 changed files
with
338 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .session import AsyncSession |
119 changes: 119 additions & 0 deletions
119
src/libcharon/plugins/vici/python/asyncvici/command_wrappers.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
class AsyncCommandWrappers: | ||
async def version(self): | ||
return await self.request("version") | ||
|
||
async def stats(self): | ||
return await self.request("stats") | ||
|
||
async def reload_settings(self): | ||
await self.request("reload-settings") | ||
|
||
async def initiate(self, sa): | ||
async for x in self.streamed_request("initiate", "control-log", sa): | ||
yield x | ||
|
||
async def terminate(self, sa): | ||
async for x in self.streamed_request("terminate", "control-log", sa): | ||
yield x | ||
|
||
async def rekey(self, sa): | ||
return await self.request("rekey", sa) | ||
|
||
async def redirect(self, sa): | ||
return await self.request("redirect", sa) | ||
|
||
async def install(self, policy): | ||
await self.request("install", policy) | ||
|
||
async def uninstall(self, policy): | ||
await self.request("uninstall", policy) | ||
|
||
async def list_sas(self, filters=None): | ||
async for x in self.streamed_request("list-sas", "list-sa", filters): | ||
yield x | ||
|
||
async def list_policies(self, filters=None): | ||
async for x in self.streamed_request( | ||
"list-policies", "list-policy", filters): | ||
yield x | ||
|
||
async def list_conns(self, filters=None): | ||
async for x in self.streamed_request( | ||
"list-conns", "list-conn", filters): | ||
yield x | ||
|
||
async def get_conns(self): | ||
return await self.request("get-conns") | ||
|
||
async def list_certs(self, filters=None): | ||
async for x in self.streamed_request( | ||
"list-certs", "list-cert", filters): | ||
yield | ||
|
||
async def list_authorities(self, filters=None): | ||
async for x in self.streamed_request( | ||
"list-authorities", "list-authority", filters): | ||
yield | ||
|
||
async def get_authorities(self): | ||
return await self.request("get-authorities") | ||
|
||
async def load_conn(self, connection): | ||
await self.request("load-conn", connection) | ||
|
||
async def unload_conn(self, name): | ||
await self.request("unload-conn", name) | ||
|
||
async def load_cert(self, certificate): | ||
await self.request("load-cert", certificate) | ||
|
||
async def load_key(self, private_key): | ||
return await self.request("load-key", private_key) | ||
|
||
async def unload_key(self, key_id): | ||
await self.request("unload-key", key_id) | ||
|
||
async def get_keys(self): | ||
return await self.request("get-keys") | ||
|
||
async def load_token(self, token): | ||
return await self.request("load-token", token) | ||
|
||
async def load_shared(self, secret): | ||
await self.request("load-shared", secret) | ||
|
||
async def unload_shared(self, identifier): | ||
await self.request("unload-shared", identifier) | ||
|
||
async def get_shared(self): | ||
return await self.request("get-shared") | ||
|
||
async def flush_certs(self, filter=None): | ||
await self.request("flush-certs", filter) | ||
|
||
async def clear_creds(self): | ||
await self.request("clear-creds") | ||
|
||
async def load_authority(self, ca): | ||
await self.request("load-authority", ca) | ||
|
||
async def unload_authority(self, ca): | ||
await self.request("unload-authority", ca) | ||
|
||
async def load_pool(self, pool): | ||
return await self.request("load-pool", pool) | ||
|
||
async def unload_pool(self, pool_name): | ||
await self.request("unload-pool", pool_name) | ||
|
||
async def get_pools(self, options=None): | ||
return await self.request("get-pools", options) | ||
|
||
async def get_algorithms(self): | ||
return await self.request("get-algorithms") | ||
|
||
async def get_counters(self, options=None): | ||
return await self.request("get-counters", options) | ||
|
||
async def reset_counters(self, options=None): | ||
await self.request("reset-counters", options) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import asyncio | ||
import struct | ||
import socket | ||
|
||
from vici.protocol import Transport | ||
|
||
|
||
class AsyncTransport(Transport): | ||
def __init__(self, sock, path): | ||
super().__init__(sock) | ||
self.path = path | ||
|
||
async def connect(self): | ||
await asyncio.get_event_loop().sock_connect( | ||
self.socket, self.path) | ||
|
||
async def send(self, packet): | ||
await asyncio.get_event_loop().sock_sendall( | ||
self.socket, | ||
struct.pack("!I", len(packet)) + packet | ||
) | ||
|
||
async def receive(self): | ||
raw_length = await self._recvall(self.HEADER_LENGTH) | ||
length, = struct.unpack("!I", raw_length) | ||
payload = await self._recvall(length) | ||
return payload | ||
|
||
async def _recvall(self, count): | ||
data = b"" | ||
while len(data) < count: | ||
buf = await asyncio.get_event_loop().sock_recv( | ||
self.socket, count - len(data)) | ||
if not buf: | ||
raise socket.error('Connection closed') | ||
data += buf | ||
return data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
import asyncio | ||
import socket | ||
|
||
from vici.exception import SessionException | ||
from vici.exception import CommandException | ||
from vici.exception import EventUnknownException | ||
from vici.protocol import Packet, Message | ||
|
||
from .command_wrappers import AsyncCommandWrappers | ||
from .protocol import AsyncTransport | ||
|
||
|
||
class AsyncSession(AsyncCommandWrappers): | ||
def __init__(self, sock=None, path="/var/run/charon.vici"): | ||
if sock is None: | ||
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) | ||
sock.setblocking(False) | ||
self.transport = AsyncTransport(sock, path) | ||
|
||
async def connect(self): | ||
await self.transport.connect() | ||
|
||
def close(self): | ||
self.transport.close() | ||
|
||
async def _communicate(self, packet): | ||
await self.transport.send(packet) | ||
return Packet.parse(await self.transport.receive()) | ||
|
||
async def _register_unregister(self, event_type, register): | ||
if register: | ||
packet = Packet.register_event(event_type) | ||
else: | ||
packet = Packet.unregister_event(event_type) | ||
response = await self._communicate(packet) | ||
if response.response_type == Packet.EVENT_UNKNOWN: | ||
raise EventUnknownException( | ||
"Unknown event type '{event}'".format(event=event_type) | ||
) | ||
elif response.response_type != Packet.EVENT_CONFIRM: | ||
raise SessionException( | ||
"Unexpected response type {type}, " | ||
"expected '{confirm}' (EVENT_CONFIRM)".format( | ||
type=response.response_type, | ||
confirm=Packet.EVENT_CONFIRM, | ||
) | ||
) | ||
|
||
async def request(self, command, message=None): | ||
if message is not None: | ||
message = Message.serialize(message) | ||
packet = Packet.request(command, message) | ||
response = await self._communicate(packet) | ||
|
||
if response.response_type != Packet.CMD_RESPONSE: | ||
raise SessionException( | ||
"Unexpected response type {type}, " | ||
"expected '{response}' (CMD_RESPONSE)".format( | ||
type=response.response_type, | ||
response=Packet.CMD_RESPONSE | ||
) | ||
) | ||
|
||
command_response = Message.deserialize(response.payload) | ||
if "success" in command_response: | ||
if command_response["success"] != b"yes": | ||
raise CommandException( | ||
"Command failed: {errmsg}".format( | ||
errmsg=command_response["errmsg"].decode("UTF-8") | ||
) | ||
) | ||
|
||
return command_response | ||
|
||
async def streamed_request(self, command, event_stream_type, message=None): | ||
if message is not None: | ||
message = Message.serialize(message) | ||
|
||
await self._register_unregister(event_stream_type, True) | ||
|
||
try: | ||
packet = Packet.request(command, message) | ||
await self.transport.send(packet) | ||
exited = False | ||
while True: | ||
response = Packet.parse(await self.transport.receive()) | ||
if response.response_type == Packet.EVENT: | ||
if not exited: | ||
try: | ||
yield Message.deserialize(response.payload) | ||
except GeneratorExit: | ||
exited = True | ||
else: | ||
break | ||
|
||
if response.response_type == Packet.CMD_RESPONSE: | ||
command_response = Message.deserialize(response.payload) | ||
else: | ||
raise SessionException( | ||
"Unexpected response type {type}, " | ||
"expected '{response}' (CMD_RESPONSE)".format( | ||
type=response.response_type, | ||
response=Packet.CMD_RESPONSE | ||
) | ||
) | ||
|
||
finally: | ||
await self._register_unregister(event_stream_type, False) | ||
|
||
# evaluate command result, if any | ||
if "success" in command_response: | ||
if command_response["success"] != b"yes": | ||
raise CommandException( | ||
"Command failed: {errmsg}".format( | ||
errmsg=command_response["errmsg"].decode("UTF-8") | ||
) | ||
) | ||
|
||
async def listen(self, event_types): | ||
for event_type in event_types: | ||
await self._register_unregister(event_type, True) | ||
|
||
try: | ||
while True: | ||
response = Packet.parse(await self.transport.receive()) | ||
if response.response_type == Packet.EVENT: | ||
try: | ||
msg = Message.deserialize(response.payload) | ||
yield response.event_type, msg | ||
except GeneratorExit: | ||
break | ||
|
||
finally: | ||
for event_type in event_types: | ||
await self._register_unregister(event_type, False) | ||
|
||
async def __aenter__(self): | ||
try: | ||
await self.connect() | ||
except Exception: | ||
self.close() | ||
raise | ||
return self | ||
|
||
async def __aexit__(self, exc_type, exc_value, traceback): | ||
self.close() | ||
|
||
def __del__(self): | ||
self.close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters