diff --git a/mira/__init__.py b/mira/__init__.py index 16b99cb..80676f3 100644 --- a/mira/__init__.py +++ b/mira/__init__.py @@ -1,4 +1,4 @@ -''' +""" Copyright (C) 2021 Alexander "PapaTutuWawa" This program is free software: you can redistribute it and/or modify @@ -13,4 +13,4 @@ GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . -''' +""" diff --git a/mira/base.py b/mira/base.py index 2c3c32f..477d939 100644 --- a/mira/base.py +++ b/mira/base.py @@ -1,4 +1,4 @@ -''' +""" Copyright (C) 2021 Alexander "PapaTutuWawa" This program is free software: you can redistribute it and/or modify @@ -13,7 +13,7 @@ GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . -''' +""" import sys import importlib import asyncio @@ -26,95 +26,104 @@ import toml from mira.subscription import SubscriptionManager from mira.storage import StorageManager -logger = logging.getLogger('mira.base') +logger = logging.getLogger("mira.base") + def message_wrapper(to, body): - msg = aioxmpp.Message( - type_=aioxmpp.MessageType.CHAT, - to=to) + msg = aioxmpp.Message(type_=aioxmpp.MessageType.CHAT, to=to) msg.body[None] = body return msg + class MiraBot: def __init__(self, config_path): # Bot specific settings self._config = toml.load(config_path) - self._jid = aioxmpp.JID.fromstr(self._config['jid']) - self._password = self._config['password'] - self._avatar = self._config.get('avatar', None) + self._jid = aioxmpp.JID.fromstr(self._config["jid"]) + self._password = self._config["password"] + self._avatar = self._config.get("avatar", None) self._client = None - self._modules = {} # Module name -> module - self._storage_manager = StorageManager.get_instance(self._config.get('storage_path', '/etc/mira/storage.json')) + self._modules = {} # Module name -> module + self._storage_manager = StorageManager.get_instance( + self._config.get("storage_path", "/etc/mira/storage.json") + ) self._subscription_manager = SubscriptionManager.get_instance() def _initialise_modules(self): - for module in self._config['modules']: - logger.debug("Initialising module %s" % (module['name'])) - mod = importlib.import_module(module['name']) - self._modules[mod.NAME] = mod.get_instance(self, config=module, name=mod.NAME) + for module in self._config["modules"]: + logger.debug("Initialising module %s" % (module["name"])) + mod = importlib.import_module(module["name"]) + self._modules[mod.NAME] = mod.get_instance( + self, config=module, name=mod.NAME + ) async def connect(self): self._client = aioxmpp.PresenceManagedClient( - self._jid, - aioxmpp.make_security_layer(self._password)) + self._jid, aioxmpp.make_security_layer(self._password) + ) async with self._client.connected(): - logger.info('Client connected') + logger.info("Client connected") self._client.stream.register_message_callback( - aioxmpp.MessageType.CHAT, - None, - self._on_message) + aioxmpp.MessageType.CHAT, None, self._on_message + ) if self._avatar: - logger.info('Publishing avatar') - with open(self._avatar, 'rb') as avatar_file: + logger.info("Publishing avatar") + with open(self._avatar, "rb") as avatar_file: data = avatar_file.read() avatar_set = aioxmpp.avatar.AvatarSet() # TODO: Detect MIME type - avatar_set.add_avatar_image('image/png', image_bytes=data) + avatar_set.add_avatar_image("image/png", image_bytes=data) avatar = self._client.summon(aioxmpp.avatar.AvatarService) await avatar.publish_avatar_set(avatar_set) - logger.info('Avatar published') + logger.info("Avatar published") - logger.debug('Initialising modules') + logger.debug("Initialising modules") self._initialise_modules() - + while True: await asyncio.sleep(1) def _is_sender_local(self, from_): return from_.domain == self._jid.domain - + def _on_message(self, message): # Automatically handles sending a message receipt and dealing # with unwanted messages - if (message.type_ != aioxmpp.MessageType.CHAT or - not message.body): + if message.type_ != aioxmpp.MessageType.CHAT or not message.body: return - cmd = str(message.body.any()).split(' ') + cmd = str(message.body.any()).split(" ") receipt = aioxmpp.mdr.compose_receipt(message) self._client.enqueue(receipt) if not cmd[0] in self._modules: - logger.debug('Received command for unknown module. Dropping') + logger.debug("Received command for unknown module. Dropping") self._client.enqueue(message_wrapper(message.from_, "Unbekannter Befehl")) return # Just drop messages that are not local when the module should # be local only if self._modules[cmd[0]]._restricted: - if self._modules[cmd[0]]._config['restrict_local']: + if self._modules[cmd[0]]._config["restrict_local"]: if not self._is_sender_local(message.from_): - logger.warning('Received a command from a non-local user to a' - ' module that is restricted to local users only') + logger.warning( + "Received a command from a non-local user to a" + " module that is restricted to local users only" + ) return - elif self._modules[cmd[0]]._config['allowed_domains']: - if not message.from_.domain in self._modules[cmd[0]]._config['allowed_domains']: - logger.warning('Received a command from a non-whitelisted user to a' - ' module that is restricted to whitelisted users only') + elif self._modules[cmd[0]]._config["allowed_domains"]: + if ( + not message.from_.domain + in self._modules[cmd[0]]._config["allowed_domains"] + ): + logger.warning( + "Received a command from a non-whitelisted user to a" + " module that is restricted to whitelisted users only" + ) return self._modules[cmd[0]]._base_on_command(cmd[1:], message) @@ -126,19 +135,26 @@ class MiraBot: # Module Function: Send a message to @to with @body def send_message_wrapper(self, to, body): self.send_message(message_wrapper(to, body)) - + + def main(): parser = OptionParser() - parser.add_option('-d', '--debug', dest='debug', - help='Enable debug logging', action='store_true') - parser.add_option('-c', '--config', dest='config', help='Location of the config.toml', - default='/etc/mira/config.toml') + parser.add_option( + "-d", "--debug", dest="debug", help="Enable debug logging", action="store_true" + ) + parser.add_option( + "-c", + "--config", + dest="config", + help="Location of the config.toml", + default="/etc/mira/config.toml", + ) (options, args) = parser.parse_args() verbosity = logging.DEBUG if options.debug else logging.INFO logging.basicConfig(stream=sys.stdout, level=verbosity) - logging.info('Loading config from %s' % (options.config)) + logging.info("Loading config from %s" % (options.config)) bot = MiraBot(options.config) loop = asyncio.get_event_loop() diff --git a/mira/module.py b/mira/module.py index 9c3ae59..545a95e 100644 --- a/mira/module.py +++ b/mira/module.py @@ -1,4 +1,4 @@ -''' +""" Copyright (C) 2021 Alexander "PapaTutuWawa" This program is free software: you can redistribute it and/or modify @@ -13,7 +13,7 @@ GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . -''' +""" import asyncio from functools import partial import logging @@ -21,15 +21,17 @@ import logging from mira.storage import StorageManager from mira.subscription import SubscriptionManager -logger = logging.getLogger('mira.module') +logger = logging.getLogger("mira.module") + class ManagerWrapper: - ''' + """ Wrapper class around {Storage, Subscription}Manager in order to to expose those directly to the modules without allowing them access to other modules. - ''' - _name = '' + """ + + _name = "" _manager = None def __init__(self, name, manager): @@ -38,53 +40,57 @@ class ManagerWrapper: def __getattr__(self, key): if not key in dir(self._manager): - raise AttributeError("Attribute %s does not exist in wrapped" - " class %s" % (key, type(self._manager))) + raise AttributeError( + "Attribute %s does not exist in wrapped" + " class %s" % (key, type(self._manager)) + ) return partial(getattr(self._manager, key), self._name) + class BaseModule: - def __init__(self, base, config={}, subcommand_table={}, name=''): + def __init__(self, base, config={}, subcommand_table={}, name=""): self._name = name self._base = base self._config = config self._subcommand_table = subcommand_table - self._local_only = self.get_option('restrict_local', False) - self._restricted = ('restrict_local' in self._config or - 'allowed_domains' in self._config) + self._local_only = self.get_option("restrict_local", False) + self._restricted = ( + "restrict_local" in self._config or "allowed_domains" in self._config + ) self._stm = ManagerWrapper(self._name, StorageManager) self._sum = ManagerWrapper(self._name, SubscriptionManager) - logger.debug('Init of %s done' % (self._name)) - + logger.debug("Init of %s done" % (self._name)) + def get_option(self, key, default=None): - ''' + """ Like dict.get(), but for the options from the bot configuration file. If key does not exist, then default will be returned. - ''' + """ return self._config.get(key, default) def send_message(self, to, body): - ''' + """ A simple wrapper that sends a message with type='chat' to @to with @body as the body - ''' + """ self._base.send_message_wrapper(to, body) - + def _base_on_command(self, cmd, msg): def run(func): loop = asyncio.get_event_loop() loop.create_task(func(cmd, msg)) - logger.debug('Received command: %s' % (str(cmd))) - + logger.debug("Received command: %s" % (str(cmd))) + if not self._subcommand_table: run(self.on_command) elif cmd and cmd[0] in self._subcommand_table: run(self._subcommand_table[cmd[0]]) else: - if '*' in self._subcommand_table: - run(self._subcommand_table['*']) + if "*" in self._subcommand_table: + run(self._subcommand_table["*"]) diff --git a/mira/modules/__init__.py b/mira/modules/__init__.py index 16b99cb..80676f3 100644 --- a/mira/modules/__init__.py +++ b/mira/modules/__init__.py @@ -1,4 +1,4 @@ -''' +""" Copyright (C) 2021 Alexander "PapaTutuWawa" This program is free software: you can redistribute it and/or modify @@ -13,4 +13,4 @@ GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . -''' +""" diff --git a/mira/modules/test.py b/mira/modules/test.py index fa92b84..f0dcb2c 100644 --- a/mira/modules/test.py +++ b/mira/modules/test.py @@ -1,4 +1,4 @@ -''' +""" Copyright (C) 2021 Alexander "PapaTutuWawa" This program is free software: you can redistribute it and/or modify @@ -13,10 +13,11 @@ GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . -''' +""" from mira.module import BaseModule -NAME = 'test' +NAME = "test" + class TestModule(BaseModule): __instance = None @@ -27,17 +28,20 @@ class TestModule(BaseModule): TestModule(base, **kwargs) return TestModule.__instance - + def __init__(self, base, **kwargs): if TestModule.__instance != None: - raise Exception('Trying to init singleton twice') + raise Exception("Trying to init singleton twice") super().__init__(base, **kwargs) TestModule.__instance = self async def on_command(self, cmd, msg): - greeting = self.get_option('greeting', 'OwO, %%user%%!').replace('%%user%%', str(msg.from_.bare())) + greeting = self.get_option("greeting", "OwO, %%user%%!").replace( + "%%user%%", str(msg.from_.bare()) + ) self.send_message(msg.from_, greeting) + def get_instance(base, **kwargs): return TestModule.get_instance(base, **kwargs) diff --git a/mira/storage.py b/mira/storage.py index f435778..65c3941 100644 --- a/mira/storage.py +++ b/mira/storage.py @@ -1,4 +1,4 @@ -''' +""" Copyright (C) 2021 Alexander "PapaTutuWawa" This program is free software: you can redistribute it and/or modify @@ -13,63 +13,64 @@ GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . -''' +""" import os import json import logging -logger = logging.getLogger('mira.storage.StorageManager') +logger = logging.getLogger("mira.storage.StorageManager") + class StorageManager: __instance = None @staticmethod - def get_instance(file_location='/etc/mira/storage.json'): + def get_instance(file_location="/etc/mira/storage.json"): if not StorageManager.__instance: StorageManager(file_location=file_location) return StorageManager.__instance def __init__(self, file_location): if StorageManager.__instance: - raise Exception('Trying to instanciate StorageManger twice') + raise Exception("Trying to instanciate StorageManger twice") - self._data = {} # Module -> Section -> Data + self._data = {} # Module -> Section -> Data self._file_location = file_location - logger.debug('Loading data from %s' % (file_location)) + logger.debug("Loading data from %s" % (file_location)) if os.path.exists(file_location): - with open(file_location, 'r') as f: + with open(file_location, "r") as f: self._data = json.loads(f.read()) StorageManager.__instance = self def get_data(self, module, section): - ''' + """ Get the data stored for module @module under the section @section. Returns {} if there is no data stored for @module. - ''' + """ if not module in self._data: - logging.debug('get_data: module unknown in self._data') + logging.debug("get_data: module unknown in self._data") logging.debug('module: "%s"' % (module)) return {} if not section in self._data[module]: - logging.debug('get_data: section unknown in self._data[module]') + logging.debug("get_data: section unknown in self._data[module]") logging.debug('module: "%s", section: "%s"' % (module, section)) return {} return self._data[module][section] def set_data(self, module, section, data): - ''' + """ Stores the data @data for @module under section @section. Flushes the data to storage afterwards. - ''' + """ if not module in self._data: self._data[module] = {} self._data[module][section] = data self.__flush() def __flush(self): - logger.debug('Flushing to storage') - with open(self._file_location, 'w') as f: + logger.debug("Flushing to storage") + with open(self._file_location, "w") as f: f.write(json.dumps(self._data)) diff --git a/mira/subscription.py b/mira/subscription.py index 838255c..4e53ed9 100644 --- a/mira/subscription.py +++ b/mira/subscription.py @@ -1,4 +1,4 @@ -''' +""" Copyright (C) 2021 Alexander "PapaTutuWawa" This program is free software: you can redistribute it and/or modify @@ -13,21 +13,23 @@ GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . -''' +""" # TODO: Replace most of these with a query API import os import json from mira.storage import StorageManager + def append_or_insert(dict_, key, value): if key in dict_: dict_[key].append(value) else: dict_[key] = [value] + class SubscriptionManager: - ''' + """ This class is tasked with providing functions that simplify dealing with subscriptions. @@ -35,7 +37,8 @@ class SubscriptionManager: has been instanciated at least once. For modules, this is no issue as they're only created after the manager classes are ready. - ''' + """ + __instance = None @staticmethod @@ -43,19 +46,26 @@ class SubscriptionManager: if not SubscriptionManager.__instance: SubscriptionManager() return SubscriptionManager.__instance - - def __init__(self): - self._sm = StorageManager.get_instance() + + def __init__(self, subscriptions={}, sm=None): # Module -> JID -> Keywords - self._subscriptions = self._sm.get_data('_SubscriptionManager', 'subscriptions') + if subscriptions: + # NOTE: This is just for testing + self._subscriptions = subscriptions + self._sm = sm + else: + self._sm = StorageManager.get_instance() + self._subscriptions = self._sm.get_data( + "_SubscriptionManager", "subscriptions" + ) SubscriptionManager.__instance = self - def get_subscriptions_for(self, module, jid): - ''' + def get_subscriptions_for_jid(self, module, jid): + """ Returns a dictionary keyword -> data which represents every subscription a jid has in the context of @module. - ''' + """ if not module in self._subscriptions: return [] if not jid in self._subscriptions[module]: @@ -64,41 +74,47 @@ class SubscriptionManager: return self._subscriptions[module][jid] def get_subscriptions_for_keyword(self, module, keyword): - ''' - Returns an array of JIDs that are subscribed to the keyword of module - ''' + """ + Returns a dictionary JID -> Data for JIDs that hava a subscription to + the keyword @keyword in the context of module. + """ if not module in self._subscriptions: - return [] + return {} - tmp = [] + tmp = {} for jid in self._subscriptions[module]: if not keyword in self._subscriptions[module][jid]: continue - data = self._subscriptions[module][jid][keyword]['data'] - tmp.append((jid, data)) + data = self._subscriptions[module][jid][keyword]["data"] + tmp[jid] = data + return tmp def get_subscriptions_for_keywords(self, module, keywords): - ''' - Returns an array of JIDs that are subscribed to at least one of the keywords - of module - ''' + """ + Returns a dictionary of form JID -> keyword -> data of JIDs that are + subscribed to at least one of the keywords in @keywords within the context + of @module. + """ if not module in self._subscriptions: - return [] + return {} - tmp = [] + tmp = {} keyword_set = set(keywords) for jid in self._subscriptions[module]: - if set(self._subscriptions[module][jid].keys()) & keyword_set: + union = set(self._subscriptions[module][jid].keys()) & keyword_set + if not union: continue - data = self._subscriptions[module][jid][keyword]['data'] - tmp.append((jid, data)) + if not jid in tmp: + tmp[jid] = {} + for keyword in union: + tmp[jid][keyword] = self._subscriptions[module][jid][keyword]["data"] return tmp - + def get_subscription_keywords(self, module): - '''Returns a list of subscribed keywords in module''' + """Returns a list of subscribed keywords in module""" if not module in self._subscriptions: return [] @@ -106,60 +122,58 @@ class SubscriptionManager: for subscription in self._subscriptions[module].values(): tmp += list(subscription.keys()) return tmp - - def is_subscribed_to(self, module, jid, keyword): - ''' + + def is_subscribed_to_keyword(self, module, jid, keyword): + """ Returns True if @jid is subscribed to @keyword within the context of @module. False otherwise - ''' - return keyword in self.get_subscriptions_for(module, jid) + """ + return keyword in self.get_subscriptions_for_jid(module, jid) def is_subscribed_to_data(self, module, jid, keyword, item): - ''' + """ Returns True if @jid is subscribed to the item @item inside the keyword @keyword within the context of @module - ''' - subscriptions = self.get_subscriptions_for(module, jid) + """ + subscriptions = self.get_subscriptions_for_jid(module, jid) if not subscriptions: return False if not keyword in subscriptions: return False - return item in subscriptions[keyword]['data'] + return item in subscriptions[keyword]["data"] - def is_subscribed_to_data_one(self, module, jid, keyword, func): - ''' + def is_subscribed_to_data_func(self, module, jid, keyword, func): + """ Like is_subscribed_to_data, but returns True if there is at least one item for which func returns True. - ''' - subscriptions = self.get_subscriptions_for(module, jid) + """ + subscriptions = self.get_subscriptions_for_jid(module, jid) if not subscriptions: return False if not keyword in subscriptions: return False - for item in subscriptions[keyword]['data']: + for item in subscriptions[keyword]["data"]: if func(item): return True return False - - def add_subscription_for(self, module, jid, keyword, data={}): - ''' + + def add_subscription(self, module, jid, keyword, data={}): + """ Adds a subscription to @keyword with data @data for @jid within the context of @module. - ''' + """ if not module in self._subscriptions: self._subscriptions[module] = {} if not jid in self._subscriptions[module]: self._subscriptions[module][jid] = {} - self._subscriptions[module][jid][keyword] = { - 'data': data - } + self._subscriptions[module][jid][keyword] = {"data": data} self.__flush() - def append_data_for_subscription(self, module, jid, keyword, item): - ''' + def append_subscription_data(self, module, jid, keyword, item): + """ Special helper function which appends item to the data field of a subscription to @keyword from @jid within the context of @module. @@ -169,26 +183,24 @@ class SubscriptionManager: must be an array, so it will fail if add_subscription_for has been called beforehand with data equal to anything but an dict with a 'data' key containing an array. - ''' + """ if not module in self._subscriptions: self._subscriptions[module] = {} if not jid in self._subscriptions[module]: self._subscriptions[module][jid] = {} if not keyword in self._subscriptions[module][jid]: - self._subscriptions[module][jid][keyword] = { - 'data': [item] - } + self._subscriptions[module][jid][keyword] = {"data": [item]} self.__flush() return - self._subscriptions[module][jid][keyword]['data'].append(item) + self._subscriptions[module][jid][keyword]["data"].append(item) self.__flush() - def remove_subscription_for(self, module, jid, keyword): - ''' + def remove_subscription(self, module, jid, keyword): + """ Removes a subscription to @keyword for @jid within the context of @module - ''' + """ del self._subscriptions[module][jid][keyword] if not self._subscriptions[module][jid]: @@ -198,20 +210,18 @@ class SubscriptionManager: self.__flush() - def remove_item_for_subscription(self, module, jid, keyword, item, flush=True): - ''' + def remove_subscription_data_item(self, module, jid, keyword, item, flush=True): + """ The deletion counterpart of append_data_for_subscription. - ''' - self.filter_items_for_subscription(module, - jid, - keyword, - func=lambda x: x == item, - flush=flush) + """ + self.filter_subscription_data_items( + module, jid, keyword, func=lambda x: x != item, flush=flush + ) - def filter_items_for_subscription(self, module, jid, keyword, func, flush=True): - ''' + def filter_subscription_data_items(self, module, jid, keyword, func, flush=True): + """ remove_item_for_subscription but for multiple items - ''' + """ if not module in self._subscriptions: return if not jid in self._subscriptions[module]: @@ -219,10 +229,11 @@ class SubscriptionManager: if not keyword in self._subscriptions[module][jid]: return - self._subscriptions[module][jid][keyword]['data'] = list(filter(func, - self._subscriptions[module][jid][keyword]['data'])) + self._subscriptions[module][jid][keyword]["data"] = list( + filter(func, self._subscriptions[module][jid][keyword]["data"]) + ) - if not self._subscriptions[module][jid][keyword]['data']: + if not self._subscriptions[module][jid][keyword]["data"]: del self._subscriptions[module][jid][keyword] if not self._subscriptions[module][jid]: del self._subscriptions[module][jid] @@ -231,9 +242,9 @@ class SubscriptionManager: if flush: self.__flush() - + def __flush(self): - ''' + """ Write subscription data to disk. Just an interface to StorageManager - ''' - self._sm.set_data('_SubscriptionManager', 'subscriptions', self._subscriptions) + """ + self._sm.set_data("_SubscriptionManager", "subscriptions", self._subscriptions) diff --git a/setup.py b/setup.py index dd16a4c..c8011b5 100644 --- a/setup.py +++ b/setup.py @@ -1,22 +1,31 @@ from setuptools import setup, find_packages setup( - name = 'mira', - version = '0.2.0', - description = 'A command-base XMPP bot framework', - url = 'https://git.polynom.me/PapaTutuWawa/mira', - author = 'Alexander "PapaTutuWawa"', - author_email = 'papatutuwawa polynom.me', - license = 'GPLv3', + name = "mira", + version = "0.3.0", + description = "A command-base XMPP bot framework", + url = "https://git.polynom.me/PapaTutuWawa/mira", + author = "Alexander \"PapaTutuWawa\"", + author_email = "papatutuwawa polynom.me", + license = "GPLv3", packages = find_packages(), install_requires = [ - 'aioxmpp>=0.12.0', - 'toml>=0.10.2' + "aioxmpp>=0.12.0", + "toml>=0.10.2" + ], + extra_require = { + "dev": [ + "pytest", + "black" + ] + }, + tests_require = [ + "pytest" ], zip_safe=True, entry_points={ - 'console_scripts': [ - 'mira = mira.base:main' + "console_scripts": [ + "mira = mira.base:main" ] } ) diff --git a/tests/test_subscription.py b/tests/test_subscription.py new file mode 100644 index 0000000..2b60f51 --- /dev/null +++ b/tests/test_subscription.py @@ -0,0 +1,156 @@ +from mira.subscription import SubscriptionManager + +class MockStorageManager: + ''' + The SubscriptionManager requieres the StorageManager, but we don't + need it for the tests. So just stub it. + ''' + + def set_data(self, module, section, data): + pass + +def get_sum(): + return SubscriptionManager({ + 'test': { + 'a@localhost': { + 'thing1': { + 'data': 42 + }, + 'thing2': { + 'data': 100 + } + }, + 'b@localhost': { + 'thing2': { + 'data': 89 + }, + 'thing3': { + 'data': [1, 2, 4] + } + }, + 'd@localhost': { + 'thing1': { + 'data': {} + } + } + } + }, MockStorageManager()) + +def test_get_subscriptions_for_jid(): + sum = get_sum() + + assert sum.get_subscriptions_for_jid('prod', 'a@localhost') == [] + assert sum.get_subscriptions_for_jid('test', 'z@localhost') == [] + + subs = sum.get_subscriptions_for_jid('test', 'a@localhost') + assert len(subs.keys()) == 2 + assert 'thing1' in subs and subs['thing1']['data'] == 42 + assert 'thing2' in subs and subs['thing2']['data'] == 100 + +def test_get_subscriptions_for_keyword(): + sum = get_sum() + + assert sum.get_subscriptions_for_keyword('prod', 'thing1') == {} + assert sum.get_subscriptions_for_keyword('test', 'thing4') == {} + + subs = sum.get_subscriptions_for_keyword('test', 'thing2') + assert 'a@localhost' in subs and subs['a@localhost'] == 100 + assert 'b@localhost' in subs and subs['b@localhost'] == 89 + +def test_get_subscriptions_for_keywords(): + sum = get_sum() + + assert sum.get_subscriptions_for_keywords('prod', 'thing1') == {} + assert sum.get_subscriptions_for_keywords('test', 'thing4') == {} + + subs = sum.get_subscriptions_for_keywords('test', ['thing2', 'thing3']) + assert 'a@localhost' in subs and 'thing2'in subs['a@localhost'] and subs['a@localhost']['thing2'] == 100 + assert not 'thing3' in subs['a@localhost'] + assert 'b@localhost' in subs and 'thing2' in subs['b@localhost'] and subs['b@localhost']['thing2'] == 89 + assert 'b@localhost' in subs and 'thing3' in subs['b@localhost'] and subs['b@localhost']['thing3'] == [1, 2, 4] + +def test_get_subscription_keywords(): + sum = get_sum() + + assert sum.get_subscription_keywords('prod') == [] + assert not set(sum.get_subscription_keywords('test')) - set(['thing1', 'thing2', 'thing3']) + +def test_is_subscribed_to_keyword(): + sum = get_sum() + + assert not sum.is_subscribed_to_keyword('prod', 'a@localhost', 'thing1') + assert not sum.is_subscribed_to_keyword('test', 'a@localhost', 'thing4') + assert sum.is_subscribed_to_keyword('test', 'a@localhost', 'thing1') + +def test_is_subscribed_to_data(): + sum = get_sum() + + assert not sum.is_subscribed_to_data('prod', 'b@localhost', 'thing1', 1) + assert not sum.is_subscribed_to_data('test', 'b@localhost', 'thing4', 1) + assert not sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 10) + assert sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 1) + +def test_is_subscribed_to_data_func(): + sum = get_sum() + + func1 = lambda x: x % 2 == 0 + func2 = lambda x: x == 10 + + assert not sum.is_subscribed_to_data_func('prod', 'b@localhost', 'thing1', func1) + assert not sum.is_subscribed_to_data_func('test', 'b@localhost', 'thing4', func1) + assert not sum.is_subscribed_to_data_func('test', 'b@localhost', 'thing3', func2) + assert sum.is_subscribed_to_data_func('test', 'b@localhost', 'thing3', func1) + +def test_add_subscription(): + sum = get_sum() + + sum.add_subscription('test', 'c@localhost', 'thing1') + assert sum.is_subscribed_to_keyword('test', 'c@localhost', 'thing1') + + sum.add_subscription('test', 'a@localhost', 'thing4') + assert sum.is_subscribed_to_keyword('test', 'a@localhost', 'thing4') + + sum.add_subscription('prod', 'a@localhost', 'thing4') + assert sum.is_subscribed_to_keyword('prod', 'a@localhost', 'thing4') + + sum.add_subscription('prod', 'a@localhost', 'thing5', 60) + subs = sum.get_subscriptions_for_jid('prod', 'a@localhost') + assert sum.is_subscribed_to_keyword('prod', 'a@localhost', 'thing5') + assert subs and 'thing5' in subs and subs['thing5']['data'] == 60 + +def test_append_subscription_data(): + sum = get_sum() + + sum.add_subscription('test', 'c@localhost', 'thing1', []) + sum.append_subscription_data('test', 'c@localhost', 'thing1', 1) + subs = sum.get_subscriptions_for_jid('test', 'c@localhost') + assert sum.is_subscribed_to_keyword('test', 'c@localhost', 'thing1') + assert subs and 'thing1' in subs and subs['thing1']['data'] == [1] + + sum.append_subscription_data('test', 'c@localhost', 'thing1', 5) + assert subs['thing1']['data'] == [1, 5] + +def test_remove_subscription(): + sum = get_sum() + + sum.remove_subscription('test', 'd@localhost', 'thing1') + assert not sum.is_subscribed_to_keyword('test', 'd@localhost', 'thing1') + assert not sum.get_subscriptions_for_jid('test', 'd@localhost') + +def test_filter_subscription_data_items(): + sum = get_sum() + + func = lambda x: not x % 2 == 0 + + sum.filter_subscription_data_items('test', 'b@localhost', 'thing3', func) + assert sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 1) + assert not sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 2) + assert not sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 4) + +def test_remove_subscription_data_item(): + sum = get_sum() + + sum.remove_subscription_data_item('test', 'b@localhost', 'thing3', 4) + assert sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 1) + assert sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 2) + assert not sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 4)