diff --git a/mira/__init__.py b/mira/__init__.py
index 16b99cb..80676f3 100644
--- a/mira/__init__.py
+++ b/mira/__init__.py
@@ -1,4 +1,4 @@
-'''
+"""
Copyright (C) 2021 Alexander "PapaTutuWawa"
This program is free software: you can redistribute it and/or modify
@@ -13,4 +13,4 @@ GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see .
-'''
+"""
diff --git a/mira/base.py b/mira/base.py
index 2c3c32f..477d939 100644
--- a/mira/base.py
+++ b/mira/base.py
@@ -1,4 +1,4 @@
-'''
+"""
Copyright (C) 2021 Alexander "PapaTutuWawa"
This program is free software: you can redistribute it and/or modify
@@ -13,7 +13,7 @@ GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see .
-'''
+"""
import sys
import importlib
import asyncio
@@ -26,95 +26,104 @@ import toml
from mira.subscription import SubscriptionManager
from mira.storage import StorageManager
-logger = logging.getLogger('mira.base')
+logger = logging.getLogger("mira.base")
+
def message_wrapper(to, body):
- msg = aioxmpp.Message(
- type_=aioxmpp.MessageType.CHAT,
- to=to)
+ msg = aioxmpp.Message(type_=aioxmpp.MessageType.CHAT, to=to)
msg.body[None] = body
return msg
+
class MiraBot:
def __init__(self, config_path):
# Bot specific settings
self._config = toml.load(config_path)
- self._jid = aioxmpp.JID.fromstr(self._config['jid'])
- self._password = self._config['password']
- self._avatar = self._config.get('avatar', None)
+ self._jid = aioxmpp.JID.fromstr(self._config["jid"])
+ self._password = self._config["password"]
+ self._avatar = self._config.get("avatar", None)
self._client = None
- self._modules = {} # Module name -> module
- self._storage_manager = StorageManager.get_instance(self._config.get('storage_path', '/etc/mira/storage.json'))
+ self._modules = {} # Module name -> module
+ self._storage_manager = StorageManager.get_instance(
+ self._config.get("storage_path", "/etc/mira/storage.json")
+ )
self._subscription_manager = SubscriptionManager.get_instance()
def _initialise_modules(self):
- for module in self._config['modules']:
- logger.debug("Initialising module %s" % (module['name']))
- mod = importlib.import_module(module['name'])
- self._modules[mod.NAME] = mod.get_instance(self, config=module, name=mod.NAME)
+ for module in self._config["modules"]:
+ logger.debug("Initialising module %s" % (module["name"]))
+ mod = importlib.import_module(module["name"])
+ self._modules[mod.NAME] = mod.get_instance(
+ self, config=module, name=mod.NAME
+ )
async def connect(self):
self._client = aioxmpp.PresenceManagedClient(
- self._jid,
- aioxmpp.make_security_layer(self._password))
+ self._jid, aioxmpp.make_security_layer(self._password)
+ )
async with self._client.connected():
- logger.info('Client connected')
+ logger.info("Client connected")
self._client.stream.register_message_callback(
- aioxmpp.MessageType.CHAT,
- None,
- self._on_message)
+ aioxmpp.MessageType.CHAT, None, self._on_message
+ )
if self._avatar:
- logger.info('Publishing avatar')
- with open(self._avatar, 'rb') as avatar_file:
+ logger.info("Publishing avatar")
+ with open(self._avatar, "rb") as avatar_file:
data = avatar_file.read()
avatar_set = aioxmpp.avatar.AvatarSet()
# TODO: Detect MIME type
- avatar_set.add_avatar_image('image/png', image_bytes=data)
+ avatar_set.add_avatar_image("image/png", image_bytes=data)
avatar = self._client.summon(aioxmpp.avatar.AvatarService)
await avatar.publish_avatar_set(avatar_set)
- logger.info('Avatar published')
+ logger.info("Avatar published")
- logger.debug('Initialising modules')
+ logger.debug("Initialising modules")
self._initialise_modules()
-
+
while True:
await asyncio.sleep(1)
def _is_sender_local(self, from_):
return from_.domain == self._jid.domain
-
+
def _on_message(self, message):
# Automatically handles sending a message receipt and dealing
# with unwanted messages
- if (message.type_ != aioxmpp.MessageType.CHAT or
- not message.body):
+ if message.type_ != aioxmpp.MessageType.CHAT or not message.body:
return
- cmd = str(message.body.any()).split(' ')
+ cmd = str(message.body.any()).split(" ")
receipt = aioxmpp.mdr.compose_receipt(message)
self._client.enqueue(receipt)
if not cmd[0] in self._modules:
- logger.debug('Received command for unknown module. Dropping')
+ logger.debug("Received command for unknown module. Dropping")
self._client.enqueue(message_wrapper(message.from_, "Unbekannter Befehl"))
return
# Just drop messages that are not local when the module should
# be local only
if self._modules[cmd[0]]._restricted:
- if self._modules[cmd[0]]._config['restrict_local']:
+ if self._modules[cmd[0]]._config["restrict_local"]:
if not self._is_sender_local(message.from_):
- logger.warning('Received a command from a non-local user to a'
- ' module that is restricted to local users only')
+ logger.warning(
+ "Received a command from a non-local user to a"
+ " module that is restricted to local users only"
+ )
return
- elif self._modules[cmd[0]]._config['allowed_domains']:
- if not message.from_.domain in self._modules[cmd[0]]._config['allowed_domains']:
- logger.warning('Received a command from a non-whitelisted user to a'
- ' module that is restricted to whitelisted users only')
+ elif self._modules[cmd[0]]._config["allowed_domains"]:
+ if (
+ not message.from_.domain
+ in self._modules[cmd[0]]._config["allowed_domains"]
+ ):
+ logger.warning(
+ "Received a command from a non-whitelisted user to a"
+ " module that is restricted to whitelisted users only"
+ )
return
self._modules[cmd[0]]._base_on_command(cmd[1:], message)
@@ -126,19 +135,26 @@ class MiraBot:
# Module Function: Send a message to @to with @body
def send_message_wrapper(self, to, body):
self.send_message(message_wrapper(to, body))
-
+
+
def main():
parser = OptionParser()
- parser.add_option('-d', '--debug', dest='debug',
- help='Enable debug logging', action='store_true')
- parser.add_option('-c', '--config', dest='config', help='Location of the config.toml',
- default='/etc/mira/config.toml')
+ parser.add_option(
+ "-d", "--debug", dest="debug", help="Enable debug logging", action="store_true"
+ )
+ parser.add_option(
+ "-c",
+ "--config",
+ dest="config",
+ help="Location of the config.toml",
+ default="/etc/mira/config.toml",
+ )
(options, args) = parser.parse_args()
verbosity = logging.DEBUG if options.debug else logging.INFO
logging.basicConfig(stream=sys.stdout, level=verbosity)
- logging.info('Loading config from %s' % (options.config))
+ logging.info("Loading config from %s" % (options.config))
bot = MiraBot(options.config)
loop = asyncio.get_event_loop()
diff --git a/mira/module.py b/mira/module.py
index 9c3ae59..545a95e 100644
--- a/mira/module.py
+++ b/mira/module.py
@@ -1,4 +1,4 @@
-'''
+"""
Copyright (C) 2021 Alexander "PapaTutuWawa"
This program is free software: you can redistribute it and/or modify
@@ -13,7 +13,7 @@ GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see .
-'''
+"""
import asyncio
from functools import partial
import logging
@@ -21,15 +21,17 @@ import logging
from mira.storage import StorageManager
from mira.subscription import SubscriptionManager
-logger = logging.getLogger('mira.module')
+logger = logging.getLogger("mira.module")
+
class ManagerWrapper:
- '''
+ """
Wrapper class around {Storage, Subscription}Manager in order
to to expose those directly to the modules without allowing them
access to other modules.
- '''
- _name = ''
+ """
+
+ _name = ""
_manager = None
def __init__(self, name, manager):
@@ -38,53 +40,57 @@ class ManagerWrapper:
def __getattr__(self, key):
if not key in dir(self._manager):
- raise AttributeError("Attribute %s does not exist in wrapped"
- " class %s" % (key, type(self._manager)))
+ raise AttributeError(
+ "Attribute %s does not exist in wrapped"
+ " class %s" % (key, type(self._manager))
+ )
return partial(getattr(self._manager, key), self._name)
+
class BaseModule:
- def __init__(self, base, config={}, subcommand_table={}, name=''):
+ def __init__(self, base, config={}, subcommand_table={}, name=""):
self._name = name
self._base = base
self._config = config
self._subcommand_table = subcommand_table
- self._local_only = self.get_option('restrict_local', False)
- self._restricted = ('restrict_local' in self._config or
- 'allowed_domains' in self._config)
+ self._local_only = self.get_option("restrict_local", False)
+ self._restricted = (
+ "restrict_local" in self._config or "allowed_domains" in self._config
+ )
self._stm = ManagerWrapper(self._name, StorageManager)
self._sum = ManagerWrapper(self._name, SubscriptionManager)
- logger.debug('Init of %s done' % (self._name))
-
+ logger.debug("Init of %s done" % (self._name))
+
def get_option(self, key, default=None):
- '''
+ """
Like dict.get(), but for the options from the bot configuration
file.
If key does not exist, then default will be returned.
- '''
+ """
return self._config.get(key, default)
def send_message(self, to, body):
- '''
+ """
A simple wrapper that sends a message with type='chat' to
@to with @body as the body
- '''
+ """
self._base.send_message_wrapper(to, body)
-
+
def _base_on_command(self, cmd, msg):
def run(func):
loop = asyncio.get_event_loop()
loop.create_task(func(cmd, msg))
- logger.debug('Received command: %s' % (str(cmd)))
-
+ logger.debug("Received command: %s" % (str(cmd)))
+
if not self._subcommand_table:
run(self.on_command)
elif cmd and cmd[0] in self._subcommand_table:
run(self._subcommand_table[cmd[0]])
else:
- if '*' in self._subcommand_table:
- run(self._subcommand_table['*'])
+ if "*" in self._subcommand_table:
+ run(self._subcommand_table["*"])
diff --git a/mira/modules/__init__.py b/mira/modules/__init__.py
index 16b99cb..80676f3 100644
--- a/mira/modules/__init__.py
+++ b/mira/modules/__init__.py
@@ -1,4 +1,4 @@
-'''
+"""
Copyright (C) 2021 Alexander "PapaTutuWawa"
This program is free software: you can redistribute it and/or modify
@@ -13,4 +13,4 @@ GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see .
-'''
+"""
diff --git a/mira/modules/test.py b/mira/modules/test.py
index fa92b84..f0dcb2c 100644
--- a/mira/modules/test.py
+++ b/mira/modules/test.py
@@ -1,4 +1,4 @@
-'''
+"""
Copyright (C) 2021 Alexander "PapaTutuWawa"
This program is free software: you can redistribute it and/or modify
@@ -13,10 +13,11 @@ GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see .
-'''
+"""
from mira.module import BaseModule
-NAME = 'test'
+NAME = "test"
+
class TestModule(BaseModule):
__instance = None
@@ -27,17 +28,20 @@ class TestModule(BaseModule):
TestModule(base, **kwargs)
return TestModule.__instance
-
+
def __init__(self, base, **kwargs):
if TestModule.__instance != None:
- raise Exception('Trying to init singleton twice')
+ raise Exception("Trying to init singleton twice")
super().__init__(base, **kwargs)
TestModule.__instance = self
async def on_command(self, cmd, msg):
- greeting = self.get_option('greeting', 'OwO, %%user%%!').replace('%%user%%', str(msg.from_.bare()))
+ greeting = self.get_option("greeting", "OwO, %%user%%!").replace(
+ "%%user%%", str(msg.from_.bare())
+ )
self.send_message(msg.from_, greeting)
+
def get_instance(base, **kwargs):
return TestModule.get_instance(base, **kwargs)
diff --git a/mira/storage.py b/mira/storage.py
index f435778..65c3941 100644
--- a/mira/storage.py
+++ b/mira/storage.py
@@ -1,4 +1,4 @@
-'''
+"""
Copyright (C) 2021 Alexander "PapaTutuWawa"
This program is free software: you can redistribute it and/or modify
@@ -13,63 +13,64 @@ GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see .
-'''
+"""
import os
import json
import logging
-logger = logging.getLogger('mira.storage.StorageManager')
+logger = logging.getLogger("mira.storage.StorageManager")
+
class StorageManager:
__instance = None
@staticmethod
- def get_instance(file_location='/etc/mira/storage.json'):
+ def get_instance(file_location="/etc/mira/storage.json"):
if not StorageManager.__instance:
StorageManager(file_location=file_location)
return StorageManager.__instance
def __init__(self, file_location):
if StorageManager.__instance:
- raise Exception('Trying to instanciate StorageManger twice')
+ raise Exception("Trying to instanciate StorageManger twice")
- self._data = {} # Module -> Section -> Data
+ self._data = {} # Module -> Section -> Data
self._file_location = file_location
- logger.debug('Loading data from %s' % (file_location))
+ logger.debug("Loading data from %s" % (file_location))
if os.path.exists(file_location):
- with open(file_location, 'r') as f:
+ with open(file_location, "r") as f:
self._data = json.loads(f.read())
StorageManager.__instance = self
def get_data(self, module, section):
- '''
+ """
Get the data stored for module @module under the section
@section. Returns {} if there is no data stored for @module.
- '''
+ """
if not module in self._data:
- logging.debug('get_data: module unknown in self._data')
+ logging.debug("get_data: module unknown in self._data")
logging.debug('module: "%s"' % (module))
return {}
if not section in self._data[module]:
- logging.debug('get_data: section unknown in self._data[module]')
+ logging.debug("get_data: section unknown in self._data[module]")
logging.debug('module: "%s", section: "%s"' % (module, section))
return {}
return self._data[module][section]
def set_data(self, module, section, data):
- '''
+ """
Stores the data @data for @module under section @section.
Flushes the data to storage afterwards.
- '''
+ """
if not module in self._data:
self._data[module] = {}
self._data[module][section] = data
self.__flush()
def __flush(self):
- logger.debug('Flushing to storage')
- with open(self._file_location, 'w') as f:
+ logger.debug("Flushing to storage")
+ with open(self._file_location, "w") as f:
f.write(json.dumps(self._data))
diff --git a/mira/subscription.py b/mira/subscription.py
index 838255c..4e53ed9 100644
--- a/mira/subscription.py
+++ b/mira/subscription.py
@@ -1,4 +1,4 @@
-'''
+"""
Copyright (C) 2021 Alexander "PapaTutuWawa"
This program is free software: you can redistribute it and/or modify
@@ -13,21 +13,23 @@ GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see .
-'''
+"""
# TODO: Replace most of these with a query API
import os
import json
from mira.storage import StorageManager
+
def append_or_insert(dict_, key, value):
if key in dict_:
dict_[key].append(value)
else:
dict_[key] = [value]
+
class SubscriptionManager:
- '''
+ """
This class is tasked with providing functions that simplify dealing
with subscriptions.
@@ -35,7 +37,8 @@ class SubscriptionManager:
has been instanciated at least once. For modules, this is no
issue as they're only created after the manager classes are
ready.
- '''
+ """
+
__instance = None
@staticmethod
@@ -43,19 +46,26 @@ class SubscriptionManager:
if not SubscriptionManager.__instance:
SubscriptionManager()
return SubscriptionManager.__instance
-
- def __init__(self):
- self._sm = StorageManager.get_instance()
+
+ def __init__(self, subscriptions={}, sm=None):
# Module -> JID -> Keywords
- self._subscriptions = self._sm.get_data('_SubscriptionManager', 'subscriptions')
+ if subscriptions:
+ # NOTE: This is just for testing
+ self._subscriptions = subscriptions
+ self._sm = sm
+ else:
+ self._sm = StorageManager.get_instance()
+ self._subscriptions = self._sm.get_data(
+ "_SubscriptionManager", "subscriptions"
+ )
SubscriptionManager.__instance = self
- def get_subscriptions_for(self, module, jid):
- '''
+ def get_subscriptions_for_jid(self, module, jid):
+ """
Returns a dictionary keyword -> data which represents
every subscription a jid has in the context of @module.
- '''
+ """
if not module in self._subscriptions:
return []
if not jid in self._subscriptions[module]:
@@ -64,41 +74,47 @@ class SubscriptionManager:
return self._subscriptions[module][jid]
def get_subscriptions_for_keyword(self, module, keyword):
- '''
- Returns an array of JIDs that are subscribed to the keyword of module
- '''
+ """
+ Returns a dictionary JID -> Data for JIDs that hava a subscription to
+ the keyword @keyword in the context of module.
+ """
if not module in self._subscriptions:
- return []
+ return {}
- tmp = []
+ tmp = {}
for jid in self._subscriptions[module]:
if not keyword in self._subscriptions[module][jid]:
continue
- data = self._subscriptions[module][jid][keyword]['data']
- tmp.append((jid, data))
+ data = self._subscriptions[module][jid][keyword]["data"]
+ tmp[jid] = data
+
return tmp
def get_subscriptions_for_keywords(self, module, keywords):
- '''
- Returns an array of JIDs that are subscribed to at least one of the keywords
- of module
- '''
+ """
+ Returns a dictionary of form JID -> keyword -> data of JIDs that are
+ subscribed to at least one of the keywords in @keywords within the context
+ of @module.
+ """
if not module in self._subscriptions:
- return []
+ return {}
- tmp = []
+ tmp = {}
keyword_set = set(keywords)
for jid in self._subscriptions[module]:
- if set(self._subscriptions[module][jid].keys()) & keyword_set:
+ union = set(self._subscriptions[module][jid].keys()) & keyword_set
+ if not union:
continue
- data = self._subscriptions[module][jid][keyword]['data']
- tmp.append((jid, data))
+ if not jid in tmp:
+ tmp[jid] = {}
+ for keyword in union:
+ tmp[jid][keyword] = self._subscriptions[module][jid][keyword]["data"]
return tmp
-
+
def get_subscription_keywords(self, module):
- '''Returns a list of subscribed keywords in module'''
+ """Returns a list of subscribed keywords in module"""
if not module in self._subscriptions:
return []
@@ -106,60 +122,58 @@ class SubscriptionManager:
for subscription in self._subscriptions[module].values():
tmp += list(subscription.keys())
return tmp
-
- def is_subscribed_to(self, module, jid, keyword):
- '''
+
+ def is_subscribed_to_keyword(self, module, jid, keyword):
+ """
Returns True if @jid is subscribed to @keyword within the context
of @module. False otherwise
- '''
- return keyword in self.get_subscriptions_for(module, jid)
+ """
+ return keyword in self.get_subscriptions_for_jid(module, jid)
def is_subscribed_to_data(self, module, jid, keyword, item):
- '''
+ """
Returns True if @jid is subscribed to the item @item inside
the keyword @keyword within the context of @module
- '''
- subscriptions = self.get_subscriptions_for(module, jid)
+ """
+ subscriptions = self.get_subscriptions_for_jid(module, jid)
if not subscriptions:
return False
if not keyword in subscriptions:
return False
- return item in subscriptions[keyword]['data']
+ return item in subscriptions[keyword]["data"]
- def is_subscribed_to_data_one(self, module, jid, keyword, func):
- '''
+ def is_subscribed_to_data_func(self, module, jid, keyword, func):
+ """
Like is_subscribed_to_data, but returns True if there is at
least one item for which func returns True.
- '''
- subscriptions = self.get_subscriptions_for(module, jid)
+ """
+ subscriptions = self.get_subscriptions_for_jid(module, jid)
if not subscriptions:
return False
if not keyword in subscriptions:
return False
- for item in subscriptions[keyword]['data']:
+ for item in subscriptions[keyword]["data"]:
if func(item):
return True
return False
-
- def add_subscription_for(self, module, jid, keyword, data={}):
- '''
+
+ def add_subscription(self, module, jid, keyword, data={}):
+ """
Adds a subscription to @keyword with data @data for @jid within
the context of @module.
- '''
+ """
if not module in self._subscriptions:
self._subscriptions[module] = {}
if not jid in self._subscriptions[module]:
self._subscriptions[module][jid] = {}
- self._subscriptions[module][jid][keyword] = {
- 'data': data
- }
+ self._subscriptions[module][jid][keyword] = {"data": data}
self.__flush()
- def append_data_for_subscription(self, module, jid, keyword, item):
- '''
+ def append_subscription_data(self, module, jid, keyword, item):
+ """
Special helper function which appends item to the data field of
a subscription to @keyword from @jid within the context of @module.
@@ -169,26 +183,24 @@ class SubscriptionManager:
must be an array, so it will fail if add_subscription_for has
been called beforehand with data equal to anything but an dict
with a 'data' key containing an array.
- '''
+ """
if not module in self._subscriptions:
self._subscriptions[module] = {}
if not jid in self._subscriptions[module]:
self._subscriptions[module][jid] = {}
if not keyword in self._subscriptions[module][jid]:
- self._subscriptions[module][jid][keyword] = {
- 'data': [item]
- }
+ self._subscriptions[module][jid][keyword] = {"data": [item]}
self.__flush()
return
- self._subscriptions[module][jid][keyword]['data'].append(item)
+ self._subscriptions[module][jid][keyword]["data"].append(item)
self.__flush()
- def remove_subscription_for(self, module, jid, keyword):
- '''
+ def remove_subscription(self, module, jid, keyword):
+ """
Removes a subscription to @keyword for @jid within the context
of @module
- '''
+ """
del self._subscriptions[module][jid][keyword]
if not self._subscriptions[module][jid]:
@@ -198,20 +210,18 @@ class SubscriptionManager:
self.__flush()
- def remove_item_for_subscription(self, module, jid, keyword, item, flush=True):
- '''
+ def remove_subscription_data_item(self, module, jid, keyword, item, flush=True):
+ """
The deletion counterpart of append_data_for_subscription.
- '''
- self.filter_items_for_subscription(module,
- jid,
- keyword,
- func=lambda x: x == item,
- flush=flush)
+ """
+ self.filter_subscription_data_items(
+ module, jid, keyword, func=lambda x: x != item, flush=flush
+ )
- def filter_items_for_subscription(self, module, jid, keyword, func, flush=True):
- '''
+ def filter_subscription_data_items(self, module, jid, keyword, func, flush=True):
+ """
remove_item_for_subscription but for multiple items
- '''
+ """
if not module in self._subscriptions:
return
if not jid in self._subscriptions[module]:
@@ -219,10 +229,11 @@ class SubscriptionManager:
if not keyword in self._subscriptions[module][jid]:
return
- self._subscriptions[module][jid][keyword]['data'] = list(filter(func,
- self._subscriptions[module][jid][keyword]['data']))
+ self._subscriptions[module][jid][keyword]["data"] = list(
+ filter(func, self._subscriptions[module][jid][keyword]["data"])
+ )
- if not self._subscriptions[module][jid][keyword]['data']:
+ if not self._subscriptions[module][jid][keyword]["data"]:
del self._subscriptions[module][jid][keyword]
if not self._subscriptions[module][jid]:
del self._subscriptions[module][jid]
@@ -231,9 +242,9 @@ class SubscriptionManager:
if flush:
self.__flush()
-
+
def __flush(self):
- '''
+ """
Write subscription data to disk. Just an interface to StorageManager
- '''
- self._sm.set_data('_SubscriptionManager', 'subscriptions', self._subscriptions)
+ """
+ self._sm.set_data("_SubscriptionManager", "subscriptions", self._subscriptions)
diff --git a/setup.py b/setup.py
index dd16a4c..c8011b5 100644
--- a/setup.py
+++ b/setup.py
@@ -1,22 +1,31 @@
from setuptools import setup, find_packages
setup(
- name = 'mira',
- version = '0.2.0',
- description = 'A command-base XMPP bot framework',
- url = 'https://git.polynom.me/PapaTutuWawa/mira',
- author = 'Alexander "PapaTutuWawa"',
- author_email = 'papatutuwawa polynom.me',
- license = 'GPLv3',
+ name = "mira",
+ version = "0.3.0",
+ description = "A command-base XMPP bot framework",
+ url = "https://git.polynom.me/PapaTutuWawa/mira",
+ author = "Alexander \"PapaTutuWawa\"",
+ author_email = "papatutuwawa polynom.me",
+ license = "GPLv3",
packages = find_packages(),
install_requires = [
- 'aioxmpp>=0.12.0',
- 'toml>=0.10.2'
+ "aioxmpp>=0.12.0",
+ "toml>=0.10.2"
+ ],
+ extra_require = {
+ "dev": [
+ "pytest",
+ "black"
+ ]
+ },
+ tests_require = [
+ "pytest"
],
zip_safe=True,
entry_points={
- 'console_scripts': [
- 'mira = mira.base:main'
+ "console_scripts": [
+ "mira = mira.base:main"
]
}
)
diff --git a/tests/test_subscription.py b/tests/test_subscription.py
new file mode 100644
index 0000000..2b60f51
--- /dev/null
+++ b/tests/test_subscription.py
@@ -0,0 +1,156 @@
+from mira.subscription import SubscriptionManager
+
+class MockStorageManager:
+ '''
+ The SubscriptionManager requieres the StorageManager, but we don't
+ need it for the tests. So just stub it.
+ '''
+
+ def set_data(self, module, section, data):
+ pass
+
+def get_sum():
+ return SubscriptionManager({
+ 'test': {
+ 'a@localhost': {
+ 'thing1': {
+ 'data': 42
+ },
+ 'thing2': {
+ 'data': 100
+ }
+ },
+ 'b@localhost': {
+ 'thing2': {
+ 'data': 89
+ },
+ 'thing3': {
+ 'data': [1, 2, 4]
+ }
+ },
+ 'd@localhost': {
+ 'thing1': {
+ 'data': {}
+ }
+ }
+ }
+ }, MockStorageManager())
+
+def test_get_subscriptions_for_jid():
+ sum = get_sum()
+
+ assert sum.get_subscriptions_for_jid('prod', 'a@localhost') == []
+ assert sum.get_subscriptions_for_jid('test', 'z@localhost') == []
+
+ subs = sum.get_subscriptions_for_jid('test', 'a@localhost')
+ assert len(subs.keys()) == 2
+ assert 'thing1' in subs and subs['thing1']['data'] == 42
+ assert 'thing2' in subs and subs['thing2']['data'] == 100
+
+def test_get_subscriptions_for_keyword():
+ sum = get_sum()
+
+ assert sum.get_subscriptions_for_keyword('prod', 'thing1') == {}
+ assert sum.get_subscriptions_for_keyword('test', 'thing4') == {}
+
+ subs = sum.get_subscriptions_for_keyword('test', 'thing2')
+ assert 'a@localhost' in subs and subs['a@localhost'] == 100
+ assert 'b@localhost' in subs and subs['b@localhost'] == 89
+
+def test_get_subscriptions_for_keywords():
+ sum = get_sum()
+
+ assert sum.get_subscriptions_for_keywords('prod', 'thing1') == {}
+ assert sum.get_subscriptions_for_keywords('test', 'thing4') == {}
+
+ subs = sum.get_subscriptions_for_keywords('test', ['thing2', 'thing3'])
+ assert 'a@localhost' in subs and 'thing2'in subs['a@localhost'] and subs['a@localhost']['thing2'] == 100
+ assert not 'thing3' in subs['a@localhost']
+ assert 'b@localhost' in subs and 'thing2' in subs['b@localhost'] and subs['b@localhost']['thing2'] == 89
+ assert 'b@localhost' in subs and 'thing3' in subs['b@localhost'] and subs['b@localhost']['thing3'] == [1, 2, 4]
+
+def test_get_subscription_keywords():
+ sum = get_sum()
+
+ assert sum.get_subscription_keywords('prod') == []
+ assert not set(sum.get_subscription_keywords('test')) - set(['thing1', 'thing2', 'thing3'])
+
+def test_is_subscribed_to_keyword():
+ sum = get_sum()
+
+ assert not sum.is_subscribed_to_keyword('prod', 'a@localhost', 'thing1')
+ assert not sum.is_subscribed_to_keyword('test', 'a@localhost', 'thing4')
+ assert sum.is_subscribed_to_keyword('test', 'a@localhost', 'thing1')
+
+def test_is_subscribed_to_data():
+ sum = get_sum()
+
+ assert not sum.is_subscribed_to_data('prod', 'b@localhost', 'thing1', 1)
+ assert not sum.is_subscribed_to_data('test', 'b@localhost', 'thing4', 1)
+ assert not sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 10)
+ assert sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 1)
+
+def test_is_subscribed_to_data_func():
+ sum = get_sum()
+
+ func1 = lambda x: x % 2 == 0
+ func2 = lambda x: x == 10
+
+ assert not sum.is_subscribed_to_data_func('prod', 'b@localhost', 'thing1', func1)
+ assert not sum.is_subscribed_to_data_func('test', 'b@localhost', 'thing4', func1)
+ assert not sum.is_subscribed_to_data_func('test', 'b@localhost', 'thing3', func2)
+ assert sum.is_subscribed_to_data_func('test', 'b@localhost', 'thing3', func1)
+
+def test_add_subscription():
+ sum = get_sum()
+
+ sum.add_subscription('test', 'c@localhost', 'thing1')
+ assert sum.is_subscribed_to_keyword('test', 'c@localhost', 'thing1')
+
+ sum.add_subscription('test', 'a@localhost', 'thing4')
+ assert sum.is_subscribed_to_keyword('test', 'a@localhost', 'thing4')
+
+ sum.add_subscription('prod', 'a@localhost', 'thing4')
+ assert sum.is_subscribed_to_keyword('prod', 'a@localhost', 'thing4')
+
+ sum.add_subscription('prod', 'a@localhost', 'thing5', 60)
+ subs = sum.get_subscriptions_for_jid('prod', 'a@localhost')
+ assert sum.is_subscribed_to_keyword('prod', 'a@localhost', 'thing5')
+ assert subs and 'thing5' in subs and subs['thing5']['data'] == 60
+
+def test_append_subscription_data():
+ sum = get_sum()
+
+ sum.add_subscription('test', 'c@localhost', 'thing1', [])
+ sum.append_subscription_data('test', 'c@localhost', 'thing1', 1)
+ subs = sum.get_subscriptions_for_jid('test', 'c@localhost')
+ assert sum.is_subscribed_to_keyword('test', 'c@localhost', 'thing1')
+ assert subs and 'thing1' in subs and subs['thing1']['data'] == [1]
+
+ sum.append_subscription_data('test', 'c@localhost', 'thing1', 5)
+ assert subs['thing1']['data'] == [1, 5]
+
+def test_remove_subscription():
+ sum = get_sum()
+
+ sum.remove_subscription('test', 'd@localhost', 'thing1')
+ assert not sum.is_subscribed_to_keyword('test', 'd@localhost', 'thing1')
+ assert not sum.get_subscriptions_for_jid('test', 'd@localhost')
+
+def test_filter_subscription_data_items():
+ sum = get_sum()
+
+ func = lambda x: not x % 2 == 0
+
+ sum.filter_subscription_data_items('test', 'b@localhost', 'thing3', func)
+ assert sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 1)
+ assert not sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 2)
+ assert not sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 4)
+
+def test_remove_subscription_data_item():
+ sum = get_sum()
+
+ sum.remove_subscription_data_item('test', 'b@localhost', 'thing3', 4)
+ assert sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 1)
+ assert sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 2)
+ assert not sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 4)