fix: Rework SubscriptionManager

- Give functions better names
- Change how these functions behave
- Add tests (!) for the SubscriptionManager
- Format using black
This commit is contained in:
PapaTutuWawa 2021-06-15 12:41:48 +02:00
parent fad4541132
commit 34d001b5bc
9 changed files with 386 additions and 183 deletions

View File

@ -1,4 +1,4 @@
''' """
Copyright (C) 2021 Alexander "PapaTutuWawa" Copyright (C) 2021 Alexander "PapaTutuWawa"
This program is free software: you can redistribute it and/or modify This program is free software: you can redistribute it and/or modify
@ -13,4 +13,4 @@ GNU General Public License for more details.
You should have received a copy of the GNU General Public License You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>. along with this program. If not, see <https://www.gnu.org/licenses/>.
''' """

View File

@ -1,4 +1,4 @@
''' """
Copyright (C) 2021 Alexander "PapaTutuWawa" Copyright (C) 2021 Alexander "PapaTutuWawa"
This program is free software: you can redistribute it and/or modify This program is free software: you can redistribute it and/or modify
@ -13,7 +13,7 @@ GNU General Public License for more details.
You should have received a copy of the GNU General Public License You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>. along with this program. If not, see <https://www.gnu.org/licenses/>.
''' """
import sys import sys
import importlib import importlib
import asyncio import asyncio
@ -26,58 +26,61 @@ import toml
from mira.subscription import SubscriptionManager from mira.subscription import SubscriptionManager
from mira.storage import StorageManager from mira.storage import StorageManager
logger = logging.getLogger('mira.base') logger = logging.getLogger("mira.base")
def message_wrapper(to, body): def message_wrapper(to, body):
msg = aioxmpp.Message( msg = aioxmpp.Message(type_=aioxmpp.MessageType.CHAT, to=to)
type_=aioxmpp.MessageType.CHAT,
to=to)
msg.body[None] = body msg.body[None] = body
return msg return msg
class MiraBot: class MiraBot:
def __init__(self, config_path): def __init__(self, config_path):
# Bot specific settings # Bot specific settings
self._config = toml.load(config_path) self._config = toml.load(config_path)
self._jid = aioxmpp.JID.fromstr(self._config['jid']) self._jid = aioxmpp.JID.fromstr(self._config["jid"])
self._password = self._config['password'] self._password = self._config["password"]
self._avatar = self._config.get('avatar', None) self._avatar = self._config.get("avatar", None)
self._client = None self._client = None
self._modules = {} # Module name -> module self._modules = {} # Module name -> module
self._storage_manager = StorageManager.get_instance(self._config.get('storage_path', '/etc/mira/storage.json')) self._storage_manager = StorageManager.get_instance(
self._config.get("storage_path", "/etc/mira/storage.json")
)
self._subscription_manager = SubscriptionManager.get_instance() self._subscription_manager = SubscriptionManager.get_instance()
def _initialise_modules(self): def _initialise_modules(self):
for module in self._config['modules']: for module in self._config["modules"]:
logger.debug("Initialising module %s" % (module['name'])) logger.debug("Initialising module %s" % (module["name"]))
mod = importlib.import_module(module['name']) mod = importlib.import_module(module["name"])
self._modules[mod.NAME] = mod.get_instance(self, config=module, name=mod.NAME) self._modules[mod.NAME] = mod.get_instance(
self, config=module, name=mod.NAME
)
async def connect(self): async def connect(self):
self._client = aioxmpp.PresenceManagedClient( self._client = aioxmpp.PresenceManagedClient(
self._jid, self._jid, aioxmpp.make_security_layer(self._password)
aioxmpp.make_security_layer(self._password)) )
async with self._client.connected(): async with self._client.connected():
logger.info('Client connected') logger.info("Client connected")
self._client.stream.register_message_callback( self._client.stream.register_message_callback(
aioxmpp.MessageType.CHAT, aioxmpp.MessageType.CHAT, None, self._on_message
None, )
self._on_message)
if self._avatar: if self._avatar:
logger.info('Publishing avatar') logger.info("Publishing avatar")
with open(self._avatar, 'rb') as avatar_file: with open(self._avatar, "rb") as avatar_file:
data = avatar_file.read() data = avatar_file.read()
avatar_set = aioxmpp.avatar.AvatarSet() avatar_set = aioxmpp.avatar.AvatarSet()
# TODO: Detect MIME type # TODO: Detect MIME type
avatar_set.add_avatar_image('image/png', image_bytes=data) avatar_set.add_avatar_image("image/png", image_bytes=data)
avatar = self._client.summon(aioxmpp.avatar.AvatarService) avatar = self._client.summon(aioxmpp.avatar.AvatarService)
await avatar.publish_avatar_set(avatar_set) await avatar.publish_avatar_set(avatar_set)
logger.info('Avatar published') logger.info("Avatar published")
logger.debug('Initialising modules') logger.debug("Initialising modules")
self._initialise_modules() self._initialise_modules()
while True: while True:
@ -89,32 +92,38 @@ class MiraBot:
def _on_message(self, message): def _on_message(self, message):
# Automatically handles sending a message receipt and dealing # Automatically handles sending a message receipt and dealing
# with unwanted messages # with unwanted messages
if (message.type_ != aioxmpp.MessageType.CHAT or if message.type_ != aioxmpp.MessageType.CHAT or not message.body:
not message.body):
return return
cmd = str(message.body.any()).split(' ') cmd = str(message.body.any()).split(" ")
receipt = aioxmpp.mdr.compose_receipt(message) receipt = aioxmpp.mdr.compose_receipt(message)
self._client.enqueue(receipt) self._client.enqueue(receipt)
if not cmd[0] in self._modules: if not cmd[0] in self._modules:
logger.debug('Received command for unknown module. Dropping') logger.debug("Received command for unknown module. Dropping")
self._client.enqueue(message_wrapper(message.from_, "Unbekannter Befehl")) self._client.enqueue(message_wrapper(message.from_, "Unbekannter Befehl"))
return return
# Just drop messages that are not local when the module should # Just drop messages that are not local when the module should
# be local only # be local only
if self._modules[cmd[0]]._restricted: if self._modules[cmd[0]]._restricted:
if self._modules[cmd[0]]._config['restrict_local']: if self._modules[cmd[0]]._config["restrict_local"]:
if not self._is_sender_local(message.from_): if not self._is_sender_local(message.from_):
logger.warning('Received a command from a non-local user to a' logger.warning(
' module that is restricted to local users only') "Received a command from a non-local user to a"
" module that is restricted to local users only"
)
return return
elif self._modules[cmd[0]]._config['allowed_domains']: elif self._modules[cmd[0]]._config["allowed_domains"]:
if not message.from_.domain in self._modules[cmd[0]]._config['allowed_domains']: if (
logger.warning('Received a command from a non-whitelisted user to a' not message.from_.domain
' module that is restricted to whitelisted users only') in self._modules[cmd[0]]._config["allowed_domains"]
):
logger.warning(
"Received a command from a non-whitelisted user to a"
" module that is restricted to whitelisted users only"
)
return return
self._modules[cmd[0]]._base_on_command(cmd[1:], message) self._modules[cmd[0]]._base_on_command(cmd[1:], message)
@ -127,18 +136,25 @@ class MiraBot:
def send_message_wrapper(self, to, body): def send_message_wrapper(self, to, body):
self.send_message(message_wrapper(to, body)) self.send_message(message_wrapper(to, body))
def main(): def main():
parser = OptionParser() parser = OptionParser()
parser.add_option('-d', '--debug', dest='debug', parser.add_option(
help='Enable debug logging', action='store_true') "-d", "--debug", dest="debug", help="Enable debug logging", action="store_true"
parser.add_option('-c', '--config', dest='config', help='Location of the config.toml', )
default='/etc/mira/config.toml') parser.add_option(
"-c",
"--config",
dest="config",
help="Location of the config.toml",
default="/etc/mira/config.toml",
)
(options, args) = parser.parse_args() (options, args) = parser.parse_args()
verbosity = logging.DEBUG if options.debug else logging.INFO verbosity = logging.DEBUG if options.debug else logging.INFO
logging.basicConfig(stream=sys.stdout, level=verbosity) logging.basicConfig(stream=sys.stdout, level=verbosity)
logging.info('Loading config from %s' % (options.config)) logging.info("Loading config from %s" % (options.config))
bot = MiraBot(options.config) bot = MiraBot(options.config)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()

View File

@ -1,4 +1,4 @@
''' """
Copyright (C) 2021 Alexander "PapaTutuWawa" Copyright (C) 2021 Alexander "PapaTutuWawa"
This program is free software: you can redistribute it and/or modify This program is free software: you can redistribute it and/or modify
@ -13,7 +13,7 @@ GNU General Public License for more details.
You should have received a copy of the GNU General Public License You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>. along with this program. If not, see <https://www.gnu.org/licenses/>.
''' """
import asyncio import asyncio
from functools import partial from functools import partial
import logging import logging
@ -21,15 +21,17 @@ import logging
from mira.storage import StorageManager from mira.storage import StorageManager
from mira.subscription import SubscriptionManager from mira.subscription import SubscriptionManager
logger = logging.getLogger('mira.module') logger = logging.getLogger("mira.module")
class ManagerWrapper: class ManagerWrapper:
''' """
Wrapper class around {Storage, Subscription}Manager in order Wrapper class around {Storage, Subscription}Manager in order
to to expose those directly to the modules without allowing them to to expose those directly to the modules without allowing them
access to other modules. access to other modules.
''' """
_name = ''
_name = ""
_manager = None _manager = None
def __init__(self, name, manager): def __init__(self, name, manager):
@ -38,40 +40,44 @@ class ManagerWrapper:
def __getattr__(self, key): def __getattr__(self, key):
if not key in dir(self._manager): if not key in dir(self._manager):
raise AttributeError("Attribute %s does not exist in wrapped" raise AttributeError(
" class %s" % (key, type(self._manager))) "Attribute %s does not exist in wrapped"
" class %s" % (key, type(self._manager))
)
return partial(getattr(self._manager, key), self._name) return partial(getattr(self._manager, key), self._name)
class BaseModule: class BaseModule:
def __init__(self, base, config={}, subcommand_table={}, name=''): def __init__(self, base, config={}, subcommand_table={}, name=""):
self._name = name self._name = name
self._base = base self._base = base
self._config = config self._config = config
self._subcommand_table = subcommand_table self._subcommand_table = subcommand_table
self._local_only = self.get_option('restrict_local', False) self._local_only = self.get_option("restrict_local", False)
self._restricted = ('restrict_local' in self._config or self._restricted = (
'allowed_domains' in self._config) "restrict_local" in self._config or "allowed_domains" in self._config
)
self._stm = ManagerWrapper(self._name, StorageManager) self._stm = ManagerWrapper(self._name, StorageManager)
self._sum = ManagerWrapper(self._name, SubscriptionManager) self._sum = ManagerWrapper(self._name, SubscriptionManager)
logger.debug('Init of %s done' % (self._name)) logger.debug("Init of %s done" % (self._name))
def get_option(self, key, default=None): def get_option(self, key, default=None):
''' """
Like dict.get(), but for the options from the bot configuration Like dict.get(), but for the options from the bot configuration
file. file.
If key does not exist, then default will be returned. If key does not exist, then default will be returned.
''' """
return self._config.get(key, default) return self._config.get(key, default)
def send_message(self, to, body): def send_message(self, to, body):
''' """
A simple wrapper that sends a message with type='chat' to A simple wrapper that sends a message with type='chat' to
@to with @body as the body @to with @body as the body
''' """
self._base.send_message_wrapper(to, body) self._base.send_message_wrapper(to, body)
def _base_on_command(self, cmd, msg): def _base_on_command(self, cmd, msg):
@ -79,12 +85,12 @@ class BaseModule:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
loop.create_task(func(cmd, msg)) loop.create_task(func(cmd, msg))
logger.debug('Received command: %s' % (str(cmd))) logger.debug("Received command: %s" % (str(cmd)))
if not self._subcommand_table: if not self._subcommand_table:
run(self.on_command) run(self.on_command)
elif cmd and cmd[0] in self._subcommand_table: elif cmd and cmd[0] in self._subcommand_table:
run(self._subcommand_table[cmd[0]]) run(self._subcommand_table[cmd[0]])
else: else:
if '*' in self._subcommand_table: if "*" in self._subcommand_table:
run(self._subcommand_table['*']) run(self._subcommand_table["*"])

View File

@ -1,4 +1,4 @@
''' """
Copyright (C) 2021 Alexander "PapaTutuWawa" Copyright (C) 2021 Alexander "PapaTutuWawa"
This program is free software: you can redistribute it and/or modify This program is free software: you can redistribute it and/or modify
@ -13,4 +13,4 @@ GNU General Public License for more details.
You should have received a copy of the GNU General Public License You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>. along with this program. If not, see <https://www.gnu.org/licenses/>.
''' """

View File

@ -1,4 +1,4 @@
''' """
Copyright (C) 2021 Alexander "PapaTutuWawa" Copyright (C) 2021 Alexander "PapaTutuWawa"
This program is free software: you can redistribute it and/or modify This program is free software: you can redistribute it and/or modify
@ -13,10 +13,11 @@ GNU General Public License for more details.
You should have received a copy of the GNU General Public License You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>. along with this program. If not, see <https://www.gnu.org/licenses/>.
''' """
from mira.module import BaseModule from mira.module import BaseModule
NAME = 'test' NAME = "test"
class TestModule(BaseModule): class TestModule(BaseModule):
__instance = None __instance = None
@ -30,14 +31,17 @@ class TestModule(BaseModule):
def __init__(self, base, **kwargs): def __init__(self, base, **kwargs):
if TestModule.__instance != None: if TestModule.__instance != None:
raise Exception('Trying to init singleton twice') raise Exception("Trying to init singleton twice")
super().__init__(base, **kwargs) super().__init__(base, **kwargs)
TestModule.__instance = self TestModule.__instance = self
async def on_command(self, cmd, msg): async def on_command(self, cmd, msg):
greeting = self.get_option('greeting', 'OwO, %%user%%!').replace('%%user%%', str(msg.from_.bare())) greeting = self.get_option("greeting", "OwO, %%user%%!").replace(
"%%user%%", str(msg.from_.bare())
)
self.send_message(msg.from_, greeting) self.send_message(msg.from_, greeting)
def get_instance(base, **kwargs): def get_instance(base, **kwargs):
return TestModule.get_instance(base, **kwargs) return TestModule.get_instance(base, **kwargs)

View File

@ -1,4 +1,4 @@
''' """
Copyright (C) 2021 Alexander "PapaTutuWawa" Copyright (C) 2021 Alexander "PapaTutuWawa"
This program is free software: you can redistribute it and/or modify This program is free software: you can redistribute it and/or modify
@ -13,63 +13,64 @@ GNU General Public License for more details.
You should have received a copy of the GNU General Public License You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>. along with this program. If not, see <https://www.gnu.org/licenses/>.
''' """
import os import os
import json import json
import logging import logging
logger = logging.getLogger('mira.storage.StorageManager') logger = logging.getLogger("mira.storage.StorageManager")
class StorageManager: class StorageManager:
__instance = None __instance = None
@staticmethod @staticmethod
def get_instance(file_location='/etc/mira/storage.json'): def get_instance(file_location="/etc/mira/storage.json"):
if not StorageManager.__instance: if not StorageManager.__instance:
StorageManager(file_location=file_location) StorageManager(file_location=file_location)
return StorageManager.__instance return StorageManager.__instance
def __init__(self, file_location): def __init__(self, file_location):
if StorageManager.__instance: if StorageManager.__instance:
raise Exception('Trying to instanciate StorageManger twice') raise Exception("Trying to instanciate StorageManger twice")
self._data = {} # Module -> Section -> Data self._data = {} # Module -> Section -> Data
self._file_location = file_location self._file_location = file_location
logger.debug('Loading data from %s' % (file_location)) logger.debug("Loading data from %s" % (file_location))
if os.path.exists(file_location): if os.path.exists(file_location):
with open(file_location, 'r') as f: with open(file_location, "r") as f:
self._data = json.loads(f.read()) self._data = json.loads(f.read())
StorageManager.__instance = self StorageManager.__instance = self
def get_data(self, module, section): def get_data(self, module, section):
''' """
Get the data stored for module @module under the section Get the data stored for module @module under the section
@section. Returns {} if there is no data stored for @module. @section. Returns {} if there is no data stored for @module.
''' """
if not module in self._data: if not module in self._data:
logging.debug('get_data: module unknown in self._data') logging.debug("get_data: module unknown in self._data")
logging.debug('module: "%s"' % (module)) logging.debug('module: "%s"' % (module))
return {} return {}
if not section in self._data[module]: if not section in self._data[module]:
logging.debug('get_data: section unknown in self._data[module]') logging.debug("get_data: section unknown in self._data[module]")
logging.debug('module: "%s", section: "%s"' % (module, section)) logging.debug('module: "%s", section: "%s"' % (module, section))
return {} return {}
return self._data[module][section] return self._data[module][section]
def set_data(self, module, section, data): def set_data(self, module, section, data):
''' """
Stores the data @data for @module under section @section. Stores the data @data for @module under section @section.
Flushes the data to storage afterwards. Flushes the data to storage afterwards.
''' """
if not module in self._data: if not module in self._data:
self._data[module] = {} self._data[module] = {}
self._data[module][section] = data self._data[module][section] = data
self.__flush() self.__flush()
def __flush(self): def __flush(self):
logger.debug('Flushing to storage') logger.debug("Flushing to storage")
with open(self._file_location, 'w') as f: with open(self._file_location, "w") as f:
f.write(json.dumps(self._data)) f.write(json.dumps(self._data))

View File

@ -1,4 +1,4 @@
''' """
Copyright (C) 2021 Alexander "PapaTutuWawa" Copyright (C) 2021 Alexander "PapaTutuWawa"
This program is free software: you can redistribute it and/or modify This program is free software: you can redistribute it and/or modify
@ -13,21 +13,23 @@ GNU General Public License for more details.
You should have received a copy of the GNU General Public License You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>. along with this program. If not, see <https://www.gnu.org/licenses/>.
''' """
# TODO: Replace most of these with a query API # TODO: Replace most of these with a query API
import os import os
import json import json
from mira.storage import StorageManager from mira.storage import StorageManager
def append_or_insert(dict_, key, value): def append_or_insert(dict_, key, value):
if key in dict_: if key in dict_:
dict_[key].append(value) dict_[key].append(value)
else: else:
dict_[key] = [value] dict_[key] = [value]
class SubscriptionManager: class SubscriptionManager:
''' """
This class is tasked with providing functions that simplify dealing This class is tasked with providing functions that simplify dealing
with subscriptions. with subscriptions.
@ -35,7 +37,8 @@ class SubscriptionManager:
has been instanciated at least once. For modules, this is no has been instanciated at least once. For modules, this is no
issue as they're only created after the manager classes are issue as they're only created after the manager classes are
ready. ready.
''' """
__instance = None __instance = None
@staticmethod @staticmethod
@ -44,18 +47,25 @@ class SubscriptionManager:
SubscriptionManager() SubscriptionManager()
return SubscriptionManager.__instance return SubscriptionManager.__instance
def __init__(self): def __init__(self, subscriptions={}, sm=None):
self._sm = StorageManager.get_instance()
# Module -> JID -> Keywords # Module -> JID -> Keywords
self._subscriptions = self._sm.get_data('_SubscriptionManager', 'subscriptions') if subscriptions:
# NOTE: This is just for testing
self._subscriptions = subscriptions
self._sm = sm
else:
self._sm = StorageManager.get_instance()
self._subscriptions = self._sm.get_data(
"_SubscriptionManager", "subscriptions"
)
SubscriptionManager.__instance = self SubscriptionManager.__instance = self
def get_subscriptions_for(self, module, jid): def get_subscriptions_for_jid(self, module, jid):
''' """
Returns a dictionary keyword -> data which represents Returns a dictionary keyword -> data which represents
every subscription a jid has in the context of @module. every subscription a jid has in the context of @module.
''' """
if not module in self._subscriptions: if not module in self._subscriptions:
return [] return []
if not jid in self._subscriptions[module]: if not jid in self._subscriptions[module]:
@ -64,41 +74,47 @@ class SubscriptionManager:
return self._subscriptions[module][jid] return self._subscriptions[module][jid]
def get_subscriptions_for_keyword(self, module, keyword): def get_subscriptions_for_keyword(self, module, keyword):
''' """
Returns an array of JIDs that are subscribed to the keyword of module Returns a dictionary JID -> Data for JIDs that hava a subscription to
''' the keyword @keyword in the context of module.
"""
if not module in self._subscriptions: if not module in self._subscriptions:
return [] return {}
tmp = [] tmp = {}
for jid in self._subscriptions[module]: for jid in self._subscriptions[module]:
if not keyword in self._subscriptions[module][jid]: if not keyword in self._subscriptions[module][jid]:
continue continue
data = self._subscriptions[module][jid][keyword]['data'] data = self._subscriptions[module][jid][keyword]["data"]
tmp.append((jid, data)) tmp[jid] = data
return tmp return tmp
def get_subscriptions_for_keywords(self, module, keywords): def get_subscriptions_for_keywords(self, module, keywords):
''' """
Returns an array of JIDs that are subscribed to at least one of the keywords Returns a dictionary of form JID -> keyword -> data of JIDs that are
of module subscribed to at least one of the keywords in @keywords within the context
''' of @module.
"""
if not module in self._subscriptions: if not module in self._subscriptions:
return [] return {}
tmp = [] tmp = {}
keyword_set = set(keywords) keyword_set = set(keywords)
for jid in self._subscriptions[module]: for jid in self._subscriptions[module]:
if set(self._subscriptions[module][jid].keys()) & keyword_set: union = set(self._subscriptions[module][jid].keys()) & keyword_set
if not union:
continue continue
data = self._subscriptions[module][jid][keyword]['data'] if not jid in tmp:
tmp.append((jid, data)) tmp[jid] = {}
for keyword in union:
tmp[jid][keyword] = self._subscriptions[module][jid][keyword]["data"]
return tmp return tmp
def get_subscription_keywords(self, module): def get_subscription_keywords(self, module):
'''Returns a list of subscribed keywords in module''' """Returns a list of subscribed keywords in module"""
if not module in self._subscriptions: if not module in self._subscriptions:
return [] return []
@ -107,59 +123,57 @@ class SubscriptionManager:
tmp += list(subscription.keys()) tmp += list(subscription.keys())
return tmp return tmp
def is_subscribed_to(self, module, jid, keyword): def is_subscribed_to_keyword(self, module, jid, keyword):
''' """
Returns True if @jid is subscribed to @keyword within the context Returns True if @jid is subscribed to @keyword within the context
of @module. False otherwise of @module. False otherwise
''' """
return keyword in self.get_subscriptions_for(module, jid) return keyword in self.get_subscriptions_for_jid(module, jid)
def is_subscribed_to_data(self, module, jid, keyword, item): def is_subscribed_to_data(self, module, jid, keyword, item):
''' """
Returns True if @jid is subscribed to the item @item inside Returns True if @jid is subscribed to the item @item inside
the keyword @keyword within the context of @module the keyword @keyword within the context of @module
''' """
subscriptions = self.get_subscriptions_for(module, jid) subscriptions = self.get_subscriptions_for_jid(module, jid)
if not subscriptions: if not subscriptions:
return False return False
if not keyword in subscriptions: if not keyword in subscriptions:
return False return False
return item in subscriptions[keyword]['data'] return item in subscriptions[keyword]["data"]
def is_subscribed_to_data_one(self, module, jid, keyword, func): def is_subscribed_to_data_func(self, module, jid, keyword, func):
''' """
Like is_subscribed_to_data, but returns True if there is at Like is_subscribed_to_data, but returns True if there is at
least one item for which func returns True. least one item for which func returns True.
''' """
subscriptions = self.get_subscriptions_for(module, jid) subscriptions = self.get_subscriptions_for_jid(module, jid)
if not subscriptions: if not subscriptions:
return False return False
if not keyword in subscriptions: if not keyword in subscriptions:
return False return False
for item in subscriptions[keyword]['data']: for item in subscriptions[keyword]["data"]:
if func(item): if func(item):
return True return True
return False return False
def add_subscription_for(self, module, jid, keyword, data={}): def add_subscription(self, module, jid, keyword, data={}):
''' """
Adds a subscription to @keyword with data @data for @jid within Adds a subscription to @keyword with data @data for @jid within
the context of @module. the context of @module.
''' """
if not module in self._subscriptions: if not module in self._subscriptions:
self._subscriptions[module] = {} self._subscriptions[module] = {}
if not jid in self._subscriptions[module]: if not jid in self._subscriptions[module]:
self._subscriptions[module][jid] = {} self._subscriptions[module][jid] = {}
self._subscriptions[module][jid][keyword] = { self._subscriptions[module][jid][keyword] = {"data": data}
'data': data
}
self.__flush() self.__flush()
def append_data_for_subscription(self, module, jid, keyword, item): def append_subscription_data(self, module, jid, keyword, item):
''' """
Special helper function which appends item to the data field of Special helper function which appends item to the data field of
a subscription to @keyword from @jid within the context of @module. a subscription to @keyword from @jid within the context of @module.
@ -169,26 +183,24 @@ class SubscriptionManager:
must be an array, so it will fail if add_subscription_for has must be an array, so it will fail if add_subscription_for has
been called beforehand with data equal to anything but an dict been called beforehand with data equal to anything but an dict
with a 'data' key containing an array. with a 'data' key containing an array.
''' """
if not module in self._subscriptions: if not module in self._subscriptions:
self._subscriptions[module] = {} self._subscriptions[module] = {}
if not jid in self._subscriptions[module]: if not jid in self._subscriptions[module]:
self._subscriptions[module][jid] = {} self._subscriptions[module][jid] = {}
if not keyword in self._subscriptions[module][jid]: if not keyword in self._subscriptions[module][jid]:
self._subscriptions[module][jid][keyword] = { self._subscriptions[module][jid][keyword] = {"data": [item]}
'data': [item]
}
self.__flush() self.__flush()
return return
self._subscriptions[module][jid][keyword]['data'].append(item) self._subscriptions[module][jid][keyword]["data"].append(item)
self.__flush() self.__flush()
def remove_subscription_for(self, module, jid, keyword): def remove_subscription(self, module, jid, keyword):
''' """
Removes a subscription to @keyword for @jid within the context Removes a subscription to @keyword for @jid within the context
of @module of @module
''' """
del self._subscriptions[module][jid][keyword] del self._subscriptions[module][jid][keyword]
if not self._subscriptions[module][jid]: if not self._subscriptions[module][jid]:
@ -198,20 +210,18 @@ class SubscriptionManager:
self.__flush() self.__flush()
def remove_item_for_subscription(self, module, jid, keyword, item, flush=True): def remove_subscription_data_item(self, module, jid, keyword, item, flush=True):
''' """
The deletion counterpart of append_data_for_subscription. The deletion counterpart of append_data_for_subscription.
''' """
self.filter_items_for_subscription(module, self.filter_subscription_data_items(
jid, module, jid, keyword, func=lambda x: x != item, flush=flush
keyword, )
func=lambda x: x == item,
flush=flush)
def filter_items_for_subscription(self, module, jid, keyword, func, flush=True): def filter_subscription_data_items(self, module, jid, keyword, func, flush=True):
''' """
remove_item_for_subscription but for multiple items remove_item_for_subscription but for multiple items
''' """
if not module in self._subscriptions: if not module in self._subscriptions:
return return
if not jid in self._subscriptions[module]: if not jid in self._subscriptions[module]:
@ -219,10 +229,11 @@ class SubscriptionManager:
if not keyword in self._subscriptions[module][jid]: if not keyword in self._subscriptions[module][jid]:
return return
self._subscriptions[module][jid][keyword]['data'] = list(filter(func, self._subscriptions[module][jid][keyword]["data"] = list(
self._subscriptions[module][jid][keyword]['data'])) filter(func, self._subscriptions[module][jid][keyword]["data"])
)
if not self._subscriptions[module][jid][keyword]['data']: if not self._subscriptions[module][jid][keyword]["data"]:
del self._subscriptions[module][jid][keyword] del self._subscriptions[module][jid][keyword]
if not self._subscriptions[module][jid]: if not self._subscriptions[module][jid]:
del self._subscriptions[module][jid] del self._subscriptions[module][jid]
@ -233,7 +244,7 @@ class SubscriptionManager:
self.__flush() self.__flush()
def __flush(self): def __flush(self):
''' """
Write subscription data to disk. Just an interface to StorageManager Write subscription data to disk. Just an interface to StorageManager
''' """
self._sm.set_data('_SubscriptionManager', 'subscriptions', self._subscriptions) self._sm.set_data("_SubscriptionManager", "subscriptions", self._subscriptions)

View File

@ -1,22 +1,31 @@
from setuptools import setup, find_packages from setuptools import setup, find_packages
setup( setup(
name = 'mira', name = "mira",
version = '0.2.0', version = "0.3.0",
description = 'A command-base XMPP bot framework', description = "A command-base XMPP bot framework",
url = 'https://git.polynom.me/PapaTutuWawa/mira', url = "https://git.polynom.me/PapaTutuWawa/mira",
author = 'Alexander "PapaTutuWawa"', author = "Alexander \"PapaTutuWawa\"",
author_email = 'papatutuwawa <at> polynom.me', author_email = "papatutuwawa <at> polynom.me",
license = 'GPLv3', license = "GPLv3",
packages = find_packages(), packages = find_packages(),
install_requires = [ install_requires = [
'aioxmpp>=0.12.0', "aioxmpp>=0.12.0",
'toml>=0.10.2' "toml>=0.10.2"
],
extra_require = {
"dev": [
"pytest",
"black"
]
},
tests_require = [
"pytest"
], ],
zip_safe=True, zip_safe=True,
entry_points={ entry_points={
'console_scripts': [ "console_scripts": [
'mira = mira.base:main' "mira = mira.base:main"
] ]
} }
) )

156
tests/test_subscription.py Normal file
View File

@ -0,0 +1,156 @@
from mira.subscription import SubscriptionManager
class MockStorageManager:
'''
The SubscriptionManager requieres the StorageManager, but we don't
need it for the tests. So just stub it.
'''
def set_data(self, module, section, data):
pass
def get_sum():
return SubscriptionManager({
'test': {
'a@localhost': {
'thing1': {
'data': 42
},
'thing2': {
'data': 100
}
},
'b@localhost': {
'thing2': {
'data': 89
},
'thing3': {
'data': [1, 2, 4]
}
},
'd@localhost': {
'thing1': {
'data': {}
}
}
}
}, MockStorageManager())
def test_get_subscriptions_for_jid():
sum = get_sum()
assert sum.get_subscriptions_for_jid('prod', 'a@localhost') == []
assert sum.get_subscriptions_for_jid('test', 'z@localhost') == []
subs = sum.get_subscriptions_for_jid('test', 'a@localhost')
assert len(subs.keys()) == 2
assert 'thing1' in subs and subs['thing1']['data'] == 42
assert 'thing2' in subs and subs['thing2']['data'] == 100
def test_get_subscriptions_for_keyword():
sum = get_sum()
assert sum.get_subscriptions_for_keyword('prod', 'thing1') == {}
assert sum.get_subscriptions_for_keyword('test', 'thing4') == {}
subs = sum.get_subscriptions_for_keyword('test', 'thing2')
assert 'a@localhost' in subs and subs['a@localhost'] == 100
assert 'b@localhost' in subs and subs['b@localhost'] == 89
def test_get_subscriptions_for_keywords():
sum = get_sum()
assert sum.get_subscriptions_for_keywords('prod', 'thing1') == {}
assert sum.get_subscriptions_for_keywords('test', 'thing4') == {}
subs = sum.get_subscriptions_for_keywords('test', ['thing2', 'thing3'])
assert 'a@localhost' in subs and 'thing2'in subs['a@localhost'] and subs['a@localhost']['thing2'] == 100
assert not 'thing3' in subs['a@localhost']
assert 'b@localhost' in subs and 'thing2' in subs['b@localhost'] and subs['b@localhost']['thing2'] == 89
assert 'b@localhost' in subs and 'thing3' in subs['b@localhost'] and subs['b@localhost']['thing3'] == [1, 2, 4]
def test_get_subscription_keywords():
sum = get_sum()
assert sum.get_subscription_keywords('prod') == []
assert not set(sum.get_subscription_keywords('test')) - set(['thing1', 'thing2', 'thing3'])
def test_is_subscribed_to_keyword():
sum = get_sum()
assert not sum.is_subscribed_to_keyword('prod', 'a@localhost', 'thing1')
assert not sum.is_subscribed_to_keyword('test', 'a@localhost', 'thing4')
assert sum.is_subscribed_to_keyword('test', 'a@localhost', 'thing1')
def test_is_subscribed_to_data():
sum = get_sum()
assert not sum.is_subscribed_to_data('prod', 'b@localhost', 'thing1', 1)
assert not sum.is_subscribed_to_data('test', 'b@localhost', 'thing4', 1)
assert not sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 10)
assert sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 1)
def test_is_subscribed_to_data_func():
sum = get_sum()
func1 = lambda x: x % 2 == 0
func2 = lambda x: x == 10
assert not sum.is_subscribed_to_data_func('prod', 'b@localhost', 'thing1', func1)
assert not sum.is_subscribed_to_data_func('test', 'b@localhost', 'thing4', func1)
assert not sum.is_subscribed_to_data_func('test', 'b@localhost', 'thing3', func2)
assert sum.is_subscribed_to_data_func('test', 'b@localhost', 'thing3', func1)
def test_add_subscription():
sum = get_sum()
sum.add_subscription('test', 'c@localhost', 'thing1')
assert sum.is_subscribed_to_keyword('test', 'c@localhost', 'thing1')
sum.add_subscription('test', 'a@localhost', 'thing4')
assert sum.is_subscribed_to_keyword('test', 'a@localhost', 'thing4')
sum.add_subscription('prod', 'a@localhost', 'thing4')
assert sum.is_subscribed_to_keyword('prod', 'a@localhost', 'thing4')
sum.add_subscription('prod', 'a@localhost', 'thing5', 60)
subs = sum.get_subscriptions_for_jid('prod', 'a@localhost')
assert sum.is_subscribed_to_keyword('prod', 'a@localhost', 'thing5')
assert subs and 'thing5' in subs and subs['thing5']['data'] == 60
def test_append_subscription_data():
sum = get_sum()
sum.add_subscription('test', 'c@localhost', 'thing1', [])
sum.append_subscription_data('test', 'c@localhost', 'thing1', 1)
subs = sum.get_subscriptions_for_jid('test', 'c@localhost')
assert sum.is_subscribed_to_keyword('test', 'c@localhost', 'thing1')
assert subs and 'thing1' in subs and subs['thing1']['data'] == [1]
sum.append_subscription_data('test', 'c@localhost', 'thing1', 5)
assert subs['thing1']['data'] == [1, 5]
def test_remove_subscription():
sum = get_sum()
sum.remove_subscription('test', 'd@localhost', 'thing1')
assert not sum.is_subscribed_to_keyword('test', 'd@localhost', 'thing1')
assert not sum.get_subscriptions_for_jid('test', 'd@localhost')
def test_filter_subscription_data_items():
sum = get_sum()
func = lambda x: not x % 2 == 0
sum.filter_subscription_data_items('test', 'b@localhost', 'thing3', func)
assert sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 1)
assert not sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 2)
assert not sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 4)
def test_remove_subscription_data_item():
sum = get_sum()
sum.remove_subscription_data_item('test', 'b@localhost', 'thing3', 4)
assert sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 1)
assert sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 2)
assert not sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 4)