fix: Rework SubscriptionManager
- Give functions better names - Change how these functions behave - Add tests (!) for the SubscriptionManager - Format using black
This commit is contained in:
parent
fad4541132
commit
34d001b5bc
@ -1,4 +1,4 @@
|
|||||||
'''
|
"""
|
||||||
Copyright (C) 2021 Alexander "PapaTutuWawa"
|
Copyright (C) 2021 Alexander "PapaTutuWawa"
|
||||||
|
|
||||||
This program is free software: you can redistribute it and/or modify
|
This program is free software: you can redistribute it and/or modify
|
||||||
@ -13,4 +13,4 @@ GNU General Public License for more details.
|
|||||||
|
|
||||||
You should have received a copy of the GNU General Public License
|
You should have received a copy of the GNU General Public License
|
||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
'''
|
"""
|
||||||
|
100
mira/base.py
100
mira/base.py
@ -1,4 +1,4 @@
|
|||||||
'''
|
"""
|
||||||
Copyright (C) 2021 Alexander "PapaTutuWawa"
|
Copyright (C) 2021 Alexander "PapaTutuWawa"
|
||||||
|
|
||||||
This program is free software: you can redistribute it and/or modify
|
This program is free software: you can redistribute it and/or modify
|
||||||
@ -13,7 +13,7 @@ GNU General Public License for more details.
|
|||||||
|
|
||||||
You should have received a copy of the GNU General Public License
|
You should have received a copy of the GNU General Public License
|
||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
'''
|
"""
|
||||||
import sys
|
import sys
|
||||||
import importlib
|
import importlib
|
||||||
import asyncio
|
import asyncio
|
||||||
@ -26,58 +26,61 @@ import toml
|
|||||||
from mira.subscription import SubscriptionManager
|
from mira.subscription import SubscriptionManager
|
||||||
from mira.storage import StorageManager
|
from mira.storage import StorageManager
|
||||||
|
|
||||||
logger = logging.getLogger('mira.base')
|
logger = logging.getLogger("mira.base")
|
||||||
|
|
||||||
|
|
||||||
def message_wrapper(to, body):
|
def message_wrapper(to, body):
|
||||||
msg = aioxmpp.Message(
|
msg = aioxmpp.Message(type_=aioxmpp.MessageType.CHAT, to=to)
|
||||||
type_=aioxmpp.MessageType.CHAT,
|
|
||||||
to=to)
|
|
||||||
msg.body[None] = body
|
msg.body[None] = body
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
|
||||||
class MiraBot:
|
class MiraBot:
|
||||||
def __init__(self, config_path):
|
def __init__(self, config_path):
|
||||||
# Bot specific settings
|
# Bot specific settings
|
||||||
self._config = toml.load(config_path)
|
self._config = toml.load(config_path)
|
||||||
|
|
||||||
self._jid = aioxmpp.JID.fromstr(self._config['jid'])
|
self._jid = aioxmpp.JID.fromstr(self._config["jid"])
|
||||||
self._password = self._config['password']
|
self._password = self._config["password"]
|
||||||
self._avatar = self._config.get('avatar', None)
|
self._avatar = self._config.get("avatar", None)
|
||||||
self._client = None
|
self._client = None
|
||||||
|
|
||||||
self._modules = {} # Module name -> module
|
self._modules = {} # Module name -> module
|
||||||
self._storage_manager = StorageManager.get_instance(self._config.get('storage_path', '/etc/mira/storage.json'))
|
self._storage_manager = StorageManager.get_instance(
|
||||||
|
self._config.get("storage_path", "/etc/mira/storage.json")
|
||||||
|
)
|
||||||
self._subscription_manager = SubscriptionManager.get_instance()
|
self._subscription_manager = SubscriptionManager.get_instance()
|
||||||
|
|
||||||
def _initialise_modules(self):
|
def _initialise_modules(self):
|
||||||
for module in self._config['modules']:
|
for module in self._config["modules"]:
|
||||||
logger.debug("Initialising module %s" % (module['name']))
|
logger.debug("Initialising module %s" % (module["name"]))
|
||||||
mod = importlib.import_module(module['name'])
|
mod = importlib.import_module(module["name"])
|
||||||
self._modules[mod.NAME] = mod.get_instance(self, config=module, name=mod.NAME)
|
self._modules[mod.NAME] = mod.get_instance(
|
||||||
|
self, config=module, name=mod.NAME
|
||||||
|
)
|
||||||
|
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
self._client = aioxmpp.PresenceManagedClient(
|
self._client = aioxmpp.PresenceManagedClient(
|
||||||
self._jid,
|
self._jid, aioxmpp.make_security_layer(self._password)
|
||||||
aioxmpp.make_security_layer(self._password))
|
)
|
||||||
async with self._client.connected():
|
async with self._client.connected():
|
||||||
logger.info('Client connected')
|
logger.info("Client connected")
|
||||||
self._client.stream.register_message_callback(
|
self._client.stream.register_message_callback(
|
||||||
aioxmpp.MessageType.CHAT,
|
aioxmpp.MessageType.CHAT, None, self._on_message
|
||||||
None,
|
)
|
||||||
self._on_message)
|
|
||||||
|
|
||||||
if self._avatar:
|
if self._avatar:
|
||||||
logger.info('Publishing avatar')
|
logger.info("Publishing avatar")
|
||||||
with open(self._avatar, 'rb') as avatar_file:
|
with open(self._avatar, "rb") as avatar_file:
|
||||||
data = avatar_file.read()
|
data = avatar_file.read()
|
||||||
avatar_set = aioxmpp.avatar.AvatarSet()
|
avatar_set = aioxmpp.avatar.AvatarSet()
|
||||||
# TODO: Detect MIME type
|
# TODO: Detect MIME type
|
||||||
avatar_set.add_avatar_image('image/png', image_bytes=data)
|
avatar_set.add_avatar_image("image/png", image_bytes=data)
|
||||||
avatar = self._client.summon(aioxmpp.avatar.AvatarService)
|
avatar = self._client.summon(aioxmpp.avatar.AvatarService)
|
||||||
await avatar.publish_avatar_set(avatar_set)
|
await avatar.publish_avatar_set(avatar_set)
|
||||||
logger.info('Avatar published')
|
logger.info("Avatar published")
|
||||||
|
|
||||||
logger.debug('Initialising modules')
|
logger.debug("Initialising modules")
|
||||||
self._initialise_modules()
|
self._initialise_modules()
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@ -89,32 +92,38 @@ class MiraBot:
|
|||||||
def _on_message(self, message):
|
def _on_message(self, message):
|
||||||
# Automatically handles sending a message receipt and dealing
|
# Automatically handles sending a message receipt and dealing
|
||||||
# with unwanted messages
|
# with unwanted messages
|
||||||
if (message.type_ != aioxmpp.MessageType.CHAT or
|
if message.type_ != aioxmpp.MessageType.CHAT or not message.body:
|
||||||
not message.body):
|
|
||||||
return
|
return
|
||||||
|
|
||||||
cmd = str(message.body.any()).split(' ')
|
cmd = str(message.body.any()).split(" ")
|
||||||
|
|
||||||
receipt = aioxmpp.mdr.compose_receipt(message)
|
receipt = aioxmpp.mdr.compose_receipt(message)
|
||||||
self._client.enqueue(receipt)
|
self._client.enqueue(receipt)
|
||||||
|
|
||||||
if not cmd[0] in self._modules:
|
if not cmd[0] in self._modules:
|
||||||
logger.debug('Received command for unknown module. Dropping')
|
logger.debug("Received command for unknown module. Dropping")
|
||||||
self._client.enqueue(message_wrapper(message.from_, "Unbekannter Befehl"))
|
self._client.enqueue(message_wrapper(message.from_, "Unbekannter Befehl"))
|
||||||
return
|
return
|
||||||
|
|
||||||
# Just drop messages that are not local when the module should
|
# Just drop messages that are not local when the module should
|
||||||
# be local only
|
# be local only
|
||||||
if self._modules[cmd[0]]._restricted:
|
if self._modules[cmd[0]]._restricted:
|
||||||
if self._modules[cmd[0]]._config['restrict_local']:
|
if self._modules[cmd[0]]._config["restrict_local"]:
|
||||||
if not self._is_sender_local(message.from_):
|
if not self._is_sender_local(message.from_):
|
||||||
logger.warning('Received a command from a non-local user to a'
|
logger.warning(
|
||||||
' module that is restricted to local users only')
|
"Received a command from a non-local user to a"
|
||||||
|
" module that is restricted to local users only"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
elif self._modules[cmd[0]]._config['allowed_domains']:
|
elif self._modules[cmd[0]]._config["allowed_domains"]:
|
||||||
if not message.from_.domain in self._modules[cmd[0]]._config['allowed_domains']:
|
if (
|
||||||
logger.warning('Received a command from a non-whitelisted user to a'
|
not message.from_.domain
|
||||||
' module that is restricted to whitelisted users only')
|
in self._modules[cmd[0]]._config["allowed_domains"]
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
"Received a command from a non-whitelisted user to a"
|
||||||
|
" module that is restricted to whitelisted users only"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
self._modules[cmd[0]]._base_on_command(cmd[1:], message)
|
self._modules[cmd[0]]._base_on_command(cmd[1:], message)
|
||||||
@ -127,18 +136,25 @@ class MiraBot:
|
|||||||
def send_message_wrapper(self, to, body):
|
def send_message_wrapper(self, to, body):
|
||||||
self.send_message(message_wrapper(to, body))
|
self.send_message(message_wrapper(to, body))
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = OptionParser()
|
parser = OptionParser()
|
||||||
parser.add_option('-d', '--debug', dest='debug',
|
parser.add_option(
|
||||||
help='Enable debug logging', action='store_true')
|
"-d", "--debug", dest="debug", help="Enable debug logging", action="store_true"
|
||||||
parser.add_option('-c', '--config', dest='config', help='Location of the config.toml',
|
)
|
||||||
default='/etc/mira/config.toml')
|
parser.add_option(
|
||||||
|
"-c",
|
||||||
|
"--config",
|
||||||
|
dest="config",
|
||||||
|
help="Location of the config.toml",
|
||||||
|
default="/etc/mira/config.toml",
|
||||||
|
)
|
||||||
(options, args) = parser.parse_args()
|
(options, args) = parser.parse_args()
|
||||||
|
|
||||||
verbosity = logging.DEBUG if options.debug else logging.INFO
|
verbosity = logging.DEBUG if options.debug else logging.INFO
|
||||||
logging.basicConfig(stream=sys.stdout, level=verbosity)
|
logging.basicConfig(stream=sys.stdout, level=verbosity)
|
||||||
|
|
||||||
logging.info('Loading config from %s' % (options.config))
|
logging.info("Loading config from %s" % (options.config))
|
||||||
bot = MiraBot(options.config)
|
bot = MiraBot(options.config)
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
'''
|
"""
|
||||||
Copyright (C) 2021 Alexander "PapaTutuWawa"
|
Copyright (C) 2021 Alexander "PapaTutuWawa"
|
||||||
|
|
||||||
This program is free software: you can redistribute it and/or modify
|
This program is free software: you can redistribute it and/or modify
|
||||||
@ -13,7 +13,7 @@ GNU General Public License for more details.
|
|||||||
|
|
||||||
You should have received a copy of the GNU General Public License
|
You should have received a copy of the GNU General Public License
|
||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
'''
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import logging
|
import logging
|
||||||
@ -21,15 +21,17 @@ import logging
|
|||||||
from mira.storage import StorageManager
|
from mira.storage import StorageManager
|
||||||
from mira.subscription import SubscriptionManager
|
from mira.subscription import SubscriptionManager
|
||||||
|
|
||||||
logger = logging.getLogger('mira.module')
|
logger = logging.getLogger("mira.module")
|
||||||
|
|
||||||
|
|
||||||
class ManagerWrapper:
|
class ManagerWrapper:
|
||||||
'''
|
"""
|
||||||
Wrapper class around {Storage, Subscription}Manager in order
|
Wrapper class around {Storage, Subscription}Manager in order
|
||||||
to to expose those directly to the modules without allowing them
|
to to expose those directly to the modules without allowing them
|
||||||
access to other modules.
|
access to other modules.
|
||||||
'''
|
"""
|
||||||
_name = ''
|
|
||||||
|
_name = ""
|
||||||
_manager = None
|
_manager = None
|
||||||
|
|
||||||
def __init__(self, name, manager):
|
def __init__(self, name, manager):
|
||||||
@ -38,40 +40,44 @@ class ManagerWrapper:
|
|||||||
|
|
||||||
def __getattr__(self, key):
|
def __getattr__(self, key):
|
||||||
if not key in dir(self._manager):
|
if not key in dir(self._manager):
|
||||||
raise AttributeError("Attribute %s does not exist in wrapped"
|
raise AttributeError(
|
||||||
" class %s" % (key, type(self._manager)))
|
"Attribute %s does not exist in wrapped"
|
||||||
|
" class %s" % (key, type(self._manager))
|
||||||
|
)
|
||||||
|
|
||||||
return partial(getattr(self._manager, key), self._name)
|
return partial(getattr(self._manager, key), self._name)
|
||||||
|
|
||||||
|
|
||||||
class BaseModule:
|
class BaseModule:
|
||||||
def __init__(self, base, config={}, subcommand_table={}, name=''):
|
def __init__(self, base, config={}, subcommand_table={}, name=""):
|
||||||
self._name = name
|
self._name = name
|
||||||
self._base = base
|
self._base = base
|
||||||
self._config = config
|
self._config = config
|
||||||
self._subcommand_table = subcommand_table
|
self._subcommand_table = subcommand_table
|
||||||
self._local_only = self.get_option('restrict_local', False)
|
self._local_only = self.get_option("restrict_local", False)
|
||||||
self._restricted = ('restrict_local' in self._config or
|
self._restricted = (
|
||||||
'allowed_domains' in self._config)
|
"restrict_local" in self._config or "allowed_domains" in self._config
|
||||||
|
)
|
||||||
|
|
||||||
self._stm = ManagerWrapper(self._name, StorageManager)
|
self._stm = ManagerWrapper(self._name, StorageManager)
|
||||||
self._sum = ManagerWrapper(self._name, SubscriptionManager)
|
self._sum = ManagerWrapper(self._name, SubscriptionManager)
|
||||||
|
|
||||||
logger.debug('Init of %s done' % (self._name))
|
logger.debug("Init of %s done" % (self._name))
|
||||||
|
|
||||||
def get_option(self, key, default=None):
|
def get_option(self, key, default=None):
|
||||||
'''
|
"""
|
||||||
Like dict.get(), but for the options from the bot configuration
|
Like dict.get(), but for the options from the bot configuration
|
||||||
file.
|
file.
|
||||||
|
|
||||||
If key does not exist, then default will be returned.
|
If key does not exist, then default will be returned.
|
||||||
'''
|
"""
|
||||||
return self._config.get(key, default)
|
return self._config.get(key, default)
|
||||||
|
|
||||||
def send_message(self, to, body):
|
def send_message(self, to, body):
|
||||||
'''
|
"""
|
||||||
A simple wrapper that sends a message with type='chat' to
|
A simple wrapper that sends a message with type='chat' to
|
||||||
@to with @body as the body
|
@to with @body as the body
|
||||||
'''
|
"""
|
||||||
self._base.send_message_wrapper(to, body)
|
self._base.send_message_wrapper(to, body)
|
||||||
|
|
||||||
def _base_on_command(self, cmd, msg):
|
def _base_on_command(self, cmd, msg):
|
||||||
@ -79,12 +85,12 @@ class BaseModule:
|
|||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
loop.create_task(func(cmd, msg))
|
loop.create_task(func(cmd, msg))
|
||||||
|
|
||||||
logger.debug('Received command: %s' % (str(cmd)))
|
logger.debug("Received command: %s" % (str(cmd)))
|
||||||
|
|
||||||
if not self._subcommand_table:
|
if not self._subcommand_table:
|
||||||
run(self.on_command)
|
run(self.on_command)
|
||||||
elif cmd and cmd[0] in self._subcommand_table:
|
elif cmd and cmd[0] in self._subcommand_table:
|
||||||
run(self._subcommand_table[cmd[0]])
|
run(self._subcommand_table[cmd[0]])
|
||||||
else:
|
else:
|
||||||
if '*' in self._subcommand_table:
|
if "*" in self._subcommand_table:
|
||||||
run(self._subcommand_table['*'])
|
run(self._subcommand_table["*"])
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
'''
|
"""
|
||||||
Copyright (C) 2021 Alexander "PapaTutuWawa"
|
Copyright (C) 2021 Alexander "PapaTutuWawa"
|
||||||
|
|
||||||
This program is free software: you can redistribute it and/or modify
|
This program is free software: you can redistribute it and/or modify
|
||||||
@ -13,4 +13,4 @@ GNU General Public License for more details.
|
|||||||
|
|
||||||
You should have received a copy of the GNU General Public License
|
You should have received a copy of the GNU General Public License
|
||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
'''
|
"""
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
'''
|
"""
|
||||||
Copyright (C) 2021 Alexander "PapaTutuWawa"
|
Copyright (C) 2021 Alexander "PapaTutuWawa"
|
||||||
|
|
||||||
This program is free software: you can redistribute it and/or modify
|
This program is free software: you can redistribute it and/or modify
|
||||||
@ -13,10 +13,11 @@ GNU General Public License for more details.
|
|||||||
|
|
||||||
You should have received a copy of the GNU General Public License
|
You should have received a copy of the GNU General Public License
|
||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
'''
|
"""
|
||||||
from mira.module import BaseModule
|
from mira.module import BaseModule
|
||||||
|
|
||||||
NAME = 'test'
|
NAME = "test"
|
||||||
|
|
||||||
|
|
||||||
class TestModule(BaseModule):
|
class TestModule(BaseModule):
|
||||||
__instance = None
|
__instance = None
|
||||||
@ -30,14 +31,17 @@ class TestModule(BaseModule):
|
|||||||
|
|
||||||
def __init__(self, base, **kwargs):
|
def __init__(self, base, **kwargs):
|
||||||
if TestModule.__instance != None:
|
if TestModule.__instance != None:
|
||||||
raise Exception('Trying to init singleton twice')
|
raise Exception("Trying to init singleton twice")
|
||||||
|
|
||||||
super().__init__(base, **kwargs)
|
super().__init__(base, **kwargs)
|
||||||
TestModule.__instance = self
|
TestModule.__instance = self
|
||||||
|
|
||||||
async def on_command(self, cmd, msg):
|
async def on_command(self, cmd, msg):
|
||||||
greeting = self.get_option('greeting', 'OwO, %%user%%!').replace('%%user%%', str(msg.from_.bare()))
|
greeting = self.get_option("greeting", "OwO, %%user%%!").replace(
|
||||||
|
"%%user%%", str(msg.from_.bare())
|
||||||
|
)
|
||||||
self.send_message(msg.from_, greeting)
|
self.send_message(msg.from_, greeting)
|
||||||
|
|
||||||
|
|
||||||
def get_instance(base, **kwargs):
|
def get_instance(base, **kwargs):
|
||||||
return TestModule.get_instance(base, **kwargs)
|
return TestModule.get_instance(base, **kwargs)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
'''
|
"""
|
||||||
Copyright (C) 2021 Alexander "PapaTutuWawa"
|
Copyright (C) 2021 Alexander "PapaTutuWawa"
|
||||||
|
|
||||||
This program is free software: you can redistribute it and/or modify
|
This program is free software: you can redistribute it and/or modify
|
||||||
@ -13,63 +13,64 @@ GNU General Public License for more details.
|
|||||||
|
|
||||||
You should have received a copy of the GNU General Public License
|
You should have received a copy of the GNU General Public License
|
||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
'''
|
"""
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger('mira.storage.StorageManager')
|
logger = logging.getLogger("mira.storage.StorageManager")
|
||||||
|
|
||||||
|
|
||||||
class StorageManager:
|
class StorageManager:
|
||||||
__instance = None
|
__instance = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_instance(file_location='/etc/mira/storage.json'):
|
def get_instance(file_location="/etc/mira/storage.json"):
|
||||||
if not StorageManager.__instance:
|
if not StorageManager.__instance:
|
||||||
StorageManager(file_location=file_location)
|
StorageManager(file_location=file_location)
|
||||||
return StorageManager.__instance
|
return StorageManager.__instance
|
||||||
|
|
||||||
def __init__(self, file_location):
|
def __init__(self, file_location):
|
||||||
if StorageManager.__instance:
|
if StorageManager.__instance:
|
||||||
raise Exception('Trying to instanciate StorageManger twice')
|
raise Exception("Trying to instanciate StorageManger twice")
|
||||||
|
|
||||||
self._data = {} # Module -> Section -> Data
|
self._data = {} # Module -> Section -> Data
|
||||||
self._file_location = file_location
|
self._file_location = file_location
|
||||||
|
|
||||||
logger.debug('Loading data from %s' % (file_location))
|
logger.debug("Loading data from %s" % (file_location))
|
||||||
if os.path.exists(file_location):
|
if os.path.exists(file_location):
|
||||||
with open(file_location, 'r') as f:
|
with open(file_location, "r") as f:
|
||||||
self._data = json.loads(f.read())
|
self._data = json.loads(f.read())
|
||||||
|
|
||||||
StorageManager.__instance = self
|
StorageManager.__instance = self
|
||||||
|
|
||||||
def get_data(self, module, section):
|
def get_data(self, module, section):
|
||||||
'''
|
"""
|
||||||
Get the data stored for module @module under the section
|
Get the data stored for module @module under the section
|
||||||
@section. Returns {} if there is no data stored for @module.
|
@section. Returns {} if there is no data stored for @module.
|
||||||
'''
|
"""
|
||||||
if not module in self._data:
|
if not module in self._data:
|
||||||
logging.debug('get_data: module unknown in self._data')
|
logging.debug("get_data: module unknown in self._data")
|
||||||
logging.debug('module: "%s"' % (module))
|
logging.debug('module: "%s"' % (module))
|
||||||
return {}
|
return {}
|
||||||
if not section in self._data[module]:
|
if not section in self._data[module]:
|
||||||
logging.debug('get_data: section unknown in self._data[module]')
|
logging.debug("get_data: section unknown in self._data[module]")
|
||||||
logging.debug('module: "%s", section: "%s"' % (module, section))
|
logging.debug('module: "%s", section: "%s"' % (module, section))
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
return self._data[module][section]
|
return self._data[module][section]
|
||||||
|
|
||||||
def set_data(self, module, section, data):
|
def set_data(self, module, section, data):
|
||||||
'''
|
"""
|
||||||
Stores the data @data for @module under section @section.
|
Stores the data @data for @module under section @section.
|
||||||
Flushes the data to storage afterwards.
|
Flushes the data to storage afterwards.
|
||||||
'''
|
"""
|
||||||
if not module in self._data:
|
if not module in self._data:
|
||||||
self._data[module] = {}
|
self._data[module] = {}
|
||||||
self._data[module][section] = data
|
self._data[module][section] = data
|
||||||
self.__flush()
|
self.__flush()
|
||||||
|
|
||||||
def __flush(self):
|
def __flush(self):
|
||||||
logger.debug('Flushing to storage')
|
logger.debug("Flushing to storage")
|
||||||
with open(self._file_location, 'w') as f:
|
with open(self._file_location, "w") as f:
|
||||||
f.write(json.dumps(self._data))
|
f.write(json.dumps(self._data))
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
'''
|
"""
|
||||||
Copyright (C) 2021 Alexander "PapaTutuWawa"
|
Copyright (C) 2021 Alexander "PapaTutuWawa"
|
||||||
|
|
||||||
This program is free software: you can redistribute it and/or modify
|
This program is free software: you can redistribute it and/or modify
|
||||||
@ -13,21 +13,23 @@ GNU General Public License for more details.
|
|||||||
|
|
||||||
You should have received a copy of the GNU General Public License
|
You should have received a copy of the GNU General Public License
|
||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
'''
|
"""
|
||||||
# TODO: Replace most of these with a query API
|
# TODO: Replace most of these with a query API
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from mira.storage import StorageManager
|
from mira.storage import StorageManager
|
||||||
|
|
||||||
|
|
||||||
def append_or_insert(dict_, key, value):
|
def append_or_insert(dict_, key, value):
|
||||||
if key in dict_:
|
if key in dict_:
|
||||||
dict_[key].append(value)
|
dict_[key].append(value)
|
||||||
else:
|
else:
|
||||||
dict_[key] = [value]
|
dict_[key] = [value]
|
||||||
|
|
||||||
|
|
||||||
class SubscriptionManager:
|
class SubscriptionManager:
|
||||||
'''
|
"""
|
||||||
This class is tasked with providing functions that simplify dealing
|
This class is tasked with providing functions that simplify dealing
|
||||||
with subscriptions.
|
with subscriptions.
|
||||||
|
|
||||||
@ -35,7 +37,8 @@ class SubscriptionManager:
|
|||||||
has been instanciated at least once. For modules, this is no
|
has been instanciated at least once. For modules, this is no
|
||||||
issue as they're only created after the manager classes are
|
issue as they're only created after the manager classes are
|
||||||
ready.
|
ready.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
__instance = None
|
__instance = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -44,18 +47,25 @@ class SubscriptionManager:
|
|||||||
SubscriptionManager()
|
SubscriptionManager()
|
||||||
return SubscriptionManager.__instance
|
return SubscriptionManager.__instance
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, subscriptions={}, sm=None):
|
||||||
self._sm = StorageManager.get_instance()
|
|
||||||
# Module -> JID -> Keywords
|
# Module -> JID -> Keywords
|
||||||
self._subscriptions = self._sm.get_data('_SubscriptionManager', 'subscriptions')
|
if subscriptions:
|
||||||
|
# NOTE: This is just for testing
|
||||||
|
self._subscriptions = subscriptions
|
||||||
|
self._sm = sm
|
||||||
|
else:
|
||||||
|
self._sm = StorageManager.get_instance()
|
||||||
|
self._subscriptions = self._sm.get_data(
|
||||||
|
"_SubscriptionManager", "subscriptions"
|
||||||
|
)
|
||||||
|
|
||||||
SubscriptionManager.__instance = self
|
SubscriptionManager.__instance = self
|
||||||
|
|
||||||
def get_subscriptions_for(self, module, jid):
|
def get_subscriptions_for_jid(self, module, jid):
|
||||||
'''
|
"""
|
||||||
Returns a dictionary keyword -> data which represents
|
Returns a dictionary keyword -> data which represents
|
||||||
every subscription a jid has in the context of @module.
|
every subscription a jid has in the context of @module.
|
||||||
'''
|
"""
|
||||||
if not module in self._subscriptions:
|
if not module in self._subscriptions:
|
||||||
return []
|
return []
|
||||||
if not jid in self._subscriptions[module]:
|
if not jid in self._subscriptions[module]:
|
||||||
@ -64,41 +74,47 @@ class SubscriptionManager:
|
|||||||
return self._subscriptions[module][jid]
|
return self._subscriptions[module][jid]
|
||||||
|
|
||||||
def get_subscriptions_for_keyword(self, module, keyword):
|
def get_subscriptions_for_keyword(self, module, keyword):
|
||||||
'''
|
"""
|
||||||
Returns an array of JIDs that are subscribed to the keyword of module
|
Returns a dictionary JID -> Data for JIDs that hava a subscription to
|
||||||
'''
|
the keyword @keyword in the context of module.
|
||||||
|
"""
|
||||||
if not module in self._subscriptions:
|
if not module in self._subscriptions:
|
||||||
return []
|
return {}
|
||||||
|
|
||||||
tmp = []
|
tmp = {}
|
||||||
for jid in self._subscriptions[module]:
|
for jid in self._subscriptions[module]:
|
||||||
if not keyword in self._subscriptions[module][jid]:
|
if not keyword in self._subscriptions[module][jid]:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
data = self._subscriptions[module][jid][keyword]['data']
|
data = self._subscriptions[module][jid][keyword]["data"]
|
||||||
tmp.append((jid, data))
|
tmp[jid] = data
|
||||||
|
|
||||||
return tmp
|
return tmp
|
||||||
|
|
||||||
def get_subscriptions_for_keywords(self, module, keywords):
|
def get_subscriptions_for_keywords(self, module, keywords):
|
||||||
'''
|
"""
|
||||||
Returns an array of JIDs that are subscribed to at least one of the keywords
|
Returns a dictionary of form JID -> keyword -> data of JIDs that are
|
||||||
of module
|
subscribed to at least one of the keywords in @keywords within the context
|
||||||
'''
|
of @module.
|
||||||
|
"""
|
||||||
if not module in self._subscriptions:
|
if not module in self._subscriptions:
|
||||||
return []
|
return {}
|
||||||
|
|
||||||
tmp = []
|
tmp = {}
|
||||||
keyword_set = set(keywords)
|
keyword_set = set(keywords)
|
||||||
for jid in self._subscriptions[module]:
|
for jid in self._subscriptions[module]:
|
||||||
if set(self._subscriptions[module][jid].keys()) & keyword_set:
|
union = set(self._subscriptions[module][jid].keys()) & keyword_set
|
||||||
|
if not union:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
data = self._subscriptions[module][jid][keyword]['data']
|
if not jid in tmp:
|
||||||
tmp.append((jid, data))
|
tmp[jid] = {}
|
||||||
|
for keyword in union:
|
||||||
|
tmp[jid][keyword] = self._subscriptions[module][jid][keyword]["data"]
|
||||||
return tmp
|
return tmp
|
||||||
|
|
||||||
def get_subscription_keywords(self, module):
|
def get_subscription_keywords(self, module):
|
||||||
'''Returns a list of subscribed keywords in module'''
|
"""Returns a list of subscribed keywords in module"""
|
||||||
if not module in self._subscriptions:
|
if not module in self._subscriptions:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@ -107,59 +123,57 @@ class SubscriptionManager:
|
|||||||
tmp += list(subscription.keys())
|
tmp += list(subscription.keys())
|
||||||
return tmp
|
return tmp
|
||||||
|
|
||||||
def is_subscribed_to(self, module, jid, keyword):
|
def is_subscribed_to_keyword(self, module, jid, keyword):
|
||||||
'''
|
"""
|
||||||
Returns True if @jid is subscribed to @keyword within the context
|
Returns True if @jid is subscribed to @keyword within the context
|
||||||
of @module. False otherwise
|
of @module. False otherwise
|
||||||
'''
|
"""
|
||||||
return keyword in self.get_subscriptions_for(module, jid)
|
return keyword in self.get_subscriptions_for_jid(module, jid)
|
||||||
|
|
||||||
def is_subscribed_to_data(self, module, jid, keyword, item):
|
def is_subscribed_to_data(self, module, jid, keyword, item):
|
||||||
'''
|
"""
|
||||||
Returns True if @jid is subscribed to the item @item inside
|
Returns True if @jid is subscribed to the item @item inside
|
||||||
the keyword @keyword within the context of @module
|
the keyword @keyword within the context of @module
|
||||||
'''
|
"""
|
||||||
subscriptions = self.get_subscriptions_for(module, jid)
|
subscriptions = self.get_subscriptions_for_jid(module, jid)
|
||||||
if not subscriptions:
|
if not subscriptions:
|
||||||
return False
|
return False
|
||||||
if not keyword in subscriptions:
|
if not keyword in subscriptions:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return item in subscriptions[keyword]['data']
|
return item in subscriptions[keyword]["data"]
|
||||||
|
|
||||||
def is_subscribed_to_data_one(self, module, jid, keyword, func):
|
def is_subscribed_to_data_func(self, module, jid, keyword, func):
|
||||||
'''
|
"""
|
||||||
Like is_subscribed_to_data, but returns True if there is at
|
Like is_subscribed_to_data, but returns True if there is at
|
||||||
least one item for which func returns True.
|
least one item for which func returns True.
|
||||||
'''
|
"""
|
||||||
subscriptions = self.get_subscriptions_for(module, jid)
|
subscriptions = self.get_subscriptions_for_jid(module, jid)
|
||||||
if not subscriptions:
|
if not subscriptions:
|
||||||
return False
|
return False
|
||||||
if not keyword in subscriptions:
|
if not keyword in subscriptions:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
for item in subscriptions[keyword]['data']:
|
for item in subscriptions[keyword]["data"]:
|
||||||
if func(item):
|
if func(item):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def add_subscription_for(self, module, jid, keyword, data={}):
|
def add_subscription(self, module, jid, keyword, data={}):
|
||||||
'''
|
"""
|
||||||
Adds a subscription to @keyword with data @data for @jid within
|
Adds a subscription to @keyword with data @data for @jid within
|
||||||
the context of @module.
|
the context of @module.
|
||||||
'''
|
"""
|
||||||
if not module in self._subscriptions:
|
if not module in self._subscriptions:
|
||||||
self._subscriptions[module] = {}
|
self._subscriptions[module] = {}
|
||||||
if not jid in self._subscriptions[module]:
|
if not jid in self._subscriptions[module]:
|
||||||
self._subscriptions[module][jid] = {}
|
self._subscriptions[module][jid] = {}
|
||||||
|
|
||||||
self._subscriptions[module][jid][keyword] = {
|
self._subscriptions[module][jid][keyword] = {"data": data}
|
||||||
'data': data
|
|
||||||
}
|
|
||||||
self.__flush()
|
self.__flush()
|
||||||
|
|
||||||
def append_data_for_subscription(self, module, jid, keyword, item):
|
def append_subscription_data(self, module, jid, keyword, item):
|
||||||
'''
|
"""
|
||||||
Special helper function which appends item to the data field of
|
Special helper function which appends item to the data field of
|
||||||
a subscription to @keyword from @jid within the context of @module.
|
a subscription to @keyword from @jid within the context of @module.
|
||||||
|
|
||||||
@ -169,26 +183,24 @@ class SubscriptionManager:
|
|||||||
must be an array, so it will fail if add_subscription_for has
|
must be an array, so it will fail if add_subscription_for has
|
||||||
been called beforehand with data equal to anything but an dict
|
been called beforehand with data equal to anything but an dict
|
||||||
with a 'data' key containing an array.
|
with a 'data' key containing an array.
|
||||||
'''
|
"""
|
||||||
if not module in self._subscriptions:
|
if not module in self._subscriptions:
|
||||||
self._subscriptions[module] = {}
|
self._subscriptions[module] = {}
|
||||||
if not jid in self._subscriptions[module]:
|
if not jid in self._subscriptions[module]:
|
||||||
self._subscriptions[module][jid] = {}
|
self._subscriptions[module][jid] = {}
|
||||||
if not keyword in self._subscriptions[module][jid]:
|
if not keyword in self._subscriptions[module][jid]:
|
||||||
self._subscriptions[module][jid][keyword] = {
|
self._subscriptions[module][jid][keyword] = {"data": [item]}
|
||||||
'data': [item]
|
|
||||||
}
|
|
||||||
self.__flush()
|
self.__flush()
|
||||||
return
|
return
|
||||||
|
|
||||||
self._subscriptions[module][jid][keyword]['data'].append(item)
|
self._subscriptions[module][jid][keyword]["data"].append(item)
|
||||||
self.__flush()
|
self.__flush()
|
||||||
|
|
||||||
def remove_subscription_for(self, module, jid, keyword):
|
def remove_subscription(self, module, jid, keyword):
|
||||||
'''
|
"""
|
||||||
Removes a subscription to @keyword for @jid within the context
|
Removes a subscription to @keyword for @jid within the context
|
||||||
of @module
|
of @module
|
||||||
'''
|
"""
|
||||||
del self._subscriptions[module][jid][keyword]
|
del self._subscriptions[module][jid][keyword]
|
||||||
|
|
||||||
if not self._subscriptions[module][jid]:
|
if not self._subscriptions[module][jid]:
|
||||||
@ -198,20 +210,18 @@ class SubscriptionManager:
|
|||||||
|
|
||||||
self.__flush()
|
self.__flush()
|
||||||
|
|
||||||
def remove_item_for_subscription(self, module, jid, keyword, item, flush=True):
|
def remove_subscription_data_item(self, module, jid, keyword, item, flush=True):
|
||||||
'''
|
"""
|
||||||
The deletion counterpart of append_data_for_subscription.
|
The deletion counterpart of append_data_for_subscription.
|
||||||
'''
|
"""
|
||||||
self.filter_items_for_subscription(module,
|
self.filter_subscription_data_items(
|
||||||
jid,
|
module, jid, keyword, func=lambda x: x != item, flush=flush
|
||||||
keyword,
|
)
|
||||||
func=lambda x: x == item,
|
|
||||||
flush=flush)
|
|
||||||
|
|
||||||
def filter_items_for_subscription(self, module, jid, keyword, func, flush=True):
|
def filter_subscription_data_items(self, module, jid, keyword, func, flush=True):
|
||||||
'''
|
"""
|
||||||
remove_item_for_subscription but for multiple items
|
remove_item_for_subscription but for multiple items
|
||||||
'''
|
"""
|
||||||
if not module in self._subscriptions:
|
if not module in self._subscriptions:
|
||||||
return
|
return
|
||||||
if not jid in self._subscriptions[module]:
|
if not jid in self._subscriptions[module]:
|
||||||
@ -219,10 +229,11 @@ class SubscriptionManager:
|
|||||||
if not keyword in self._subscriptions[module][jid]:
|
if not keyword in self._subscriptions[module][jid]:
|
||||||
return
|
return
|
||||||
|
|
||||||
self._subscriptions[module][jid][keyword]['data'] = list(filter(func,
|
self._subscriptions[module][jid][keyword]["data"] = list(
|
||||||
self._subscriptions[module][jid][keyword]['data']))
|
filter(func, self._subscriptions[module][jid][keyword]["data"])
|
||||||
|
)
|
||||||
|
|
||||||
if not self._subscriptions[module][jid][keyword]['data']:
|
if not self._subscriptions[module][jid][keyword]["data"]:
|
||||||
del self._subscriptions[module][jid][keyword]
|
del self._subscriptions[module][jid][keyword]
|
||||||
if not self._subscriptions[module][jid]:
|
if not self._subscriptions[module][jid]:
|
||||||
del self._subscriptions[module][jid]
|
del self._subscriptions[module][jid]
|
||||||
@ -233,7 +244,7 @@ class SubscriptionManager:
|
|||||||
self.__flush()
|
self.__flush()
|
||||||
|
|
||||||
def __flush(self):
|
def __flush(self):
|
||||||
'''
|
"""
|
||||||
Write subscription data to disk. Just an interface to StorageManager
|
Write subscription data to disk. Just an interface to StorageManager
|
||||||
'''
|
"""
|
||||||
self._sm.set_data('_SubscriptionManager', 'subscriptions', self._subscriptions)
|
self._sm.set_data("_SubscriptionManager", "subscriptions", self._subscriptions)
|
||||||
|
31
setup.py
31
setup.py
@ -1,22 +1,31 @@
|
|||||||
from setuptools import setup, find_packages
|
from setuptools import setup, find_packages
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name = 'mira',
|
name = "mira",
|
||||||
version = '0.2.0',
|
version = "0.3.0",
|
||||||
description = 'A command-base XMPP bot framework',
|
description = "A command-base XMPP bot framework",
|
||||||
url = 'https://git.polynom.me/PapaTutuWawa/mira',
|
url = "https://git.polynom.me/PapaTutuWawa/mira",
|
||||||
author = 'Alexander "PapaTutuWawa"',
|
author = "Alexander \"PapaTutuWawa\"",
|
||||||
author_email = 'papatutuwawa <at> polynom.me',
|
author_email = "papatutuwawa <at> polynom.me",
|
||||||
license = 'GPLv3',
|
license = "GPLv3",
|
||||||
packages = find_packages(),
|
packages = find_packages(),
|
||||||
install_requires = [
|
install_requires = [
|
||||||
'aioxmpp>=0.12.0',
|
"aioxmpp>=0.12.0",
|
||||||
'toml>=0.10.2'
|
"toml>=0.10.2"
|
||||||
|
],
|
||||||
|
extra_require = {
|
||||||
|
"dev": [
|
||||||
|
"pytest",
|
||||||
|
"black"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
tests_require = [
|
||||||
|
"pytest"
|
||||||
],
|
],
|
||||||
zip_safe=True,
|
zip_safe=True,
|
||||||
entry_points={
|
entry_points={
|
||||||
'console_scripts': [
|
"console_scripts": [
|
||||||
'mira = mira.base:main'
|
"mira = mira.base:main"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
156
tests/test_subscription.py
Normal file
156
tests/test_subscription.py
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
from mira.subscription import SubscriptionManager
|
||||||
|
|
||||||
|
class MockStorageManager:
|
||||||
|
'''
|
||||||
|
The SubscriptionManager requieres the StorageManager, but we don't
|
||||||
|
need it for the tests. So just stub it.
|
||||||
|
'''
|
||||||
|
|
||||||
|
def set_data(self, module, section, data):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_sum():
|
||||||
|
return SubscriptionManager({
|
||||||
|
'test': {
|
||||||
|
'a@localhost': {
|
||||||
|
'thing1': {
|
||||||
|
'data': 42
|
||||||
|
},
|
||||||
|
'thing2': {
|
||||||
|
'data': 100
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'b@localhost': {
|
||||||
|
'thing2': {
|
||||||
|
'data': 89
|
||||||
|
},
|
||||||
|
'thing3': {
|
||||||
|
'data': [1, 2, 4]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'd@localhost': {
|
||||||
|
'thing1': {
|
||||||
|
'data': {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, MockStorageManager())
|
||||||
|
|
||||||
|
def test_get_subscriptions_for_jid():
|
||||||
|
sum = get_sum()
|
||||||
|
|
||||||
|
assert sum.get_subscriptions_for_jid('prod', 'a@localhost') == []
|
||||||
|
assert sum.get_subscriptions_for_jid('test', 'z@localhost') == []
|
||||||
|
|
||||||
|
subs = sum.get_subscriptions_for_jid('test', 'a@localhost')
|
||||||
|
assert len(subs.keys()) == 2
|
||||||
|
assert 'thing1' in subs and subs['thing1']['data'] == 42
|
||||||
|
assert 'thing2' in subs and subs['thing2']['data'] == 100
|
||||||
|
|
||||||
|
def test_get_subscriptions_for_keyword():
|
||||||
|
sum = get_sum()
|
||||||
|
|
||||||
|
assert sum.get_subscriptions_for_keyword('prod', 'thing1') == {}
|
||||||
|
assert sum.get_subscriptions_for_keyword('test', 'thing4') == {}
|
||||||
|
|
||||||
|
subs = sum.get_subscriptions_for_keyword('test', 'thing2')
|
||||||
|
assert 'a@localhost' in subs and subs['a@localhost'] == 100
|
||||||
|
assert 'b@localhost' in subs and subs['b@localhost'] == 89
|
||||||
|
|
||||||
|
def test_get_subscriptions_for_keywords():
|
||||||
|
sum = get_sum()
|
||||||
|
|
||||||
|
assert sum.get_subscriptions_for_keywords('prod', 'thing1') == {}
|
||||||
|
assert sum.get_subscriptions_for_keywords('test', 'thing4') == {}
|
||||||
|
|
||||||
|
subs = sum.get_subscriptions_for_keywords('test', ['thing2', 'thing3'])
|
||||||
|
assert 'a@localhost' in subs and 'thing2'in subs['a@localhost'] and subs['a@localhost']['thing2'] == 100
|
||||||
|
assert not 'thing3' in subs['a@localhost']
|
||||||
|
assert 'b@localhost' in subs and 'thing2' in subs['b@localhost'] and subs['b@localhost']['thing2'] == 89
|
||||||
|
assert 'b@localhost' in subs and 'thing3' in subs['b@localhost'] and subs['b@localhost']['thing3'] == [1, 2, 4]
|
||||||
|
|
||||||
|
def test_get_subscription_keywords():
|
||||||
|
sum = get_sum()
|
||||||
|
|
||||||
|
assert sum.get_subscription_keywords('prod') == []
|
||||||
|
assert not set(sum.get_subscription_keywords('test')) - set(['thing1', 'thing2', 'thing3'])
|
||||||
|
|
||||||
|
def test_is_subscribed_to_keyword():
|
||||||
|
sum = get_sum()
|
||||||
|
|
||||||
|
assert not sum.is_subscribed_to_keyword('prod', 'a@localhost', 'thing1')
|
||||||
|
assert not sum.is_subscribed_to_keyword('test', 'a@localhost', 'thing4')
|
||||||
|
assert sum.is_subscribed_to_keyword('test', 'a@localhost', 'thing1')
|
||||||
|
|
||||||
|
def test_is_subscribed_to_data():
|
||||||
|
sum = get_sum()
|
||||||
|
|
||||||
|
assert not sum.is_subscribed_to_data('prod', 'b@localhost', 'thing1', 1)
|
||||||
|
assert not sum.is_subscribed_to_data('test', 'b@localhost', 'thing4', 1)
|
||||||
|
assert not sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 10)
|
||||||
|
assert sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 1)
|
||||||
|
|
||||||
|
def test_is_subscribed_to_data_func():
|
||||||
|
sum = get_sum()
|
||||||
|
|
||||||
|
func1 = lambda x: x % 2 == 0
|
||||||
|
func2 = lambda x: x == 10
|
||||||
|
|
||||||
|
assert not sum.is_subscribed_to_data_func('prod', 'b@localhost', 'thing1', func1)
|
||||||
|
assert not sum.is_subscribed_to_data_func('test', 'b@localhost', 'thing4', func1)
|
||||||
|
assert not sum.is_subscribed_to_data_func('test', 'b@localhost', 'thing3', func2)
|
||||||
|
assert sum.is_subscribed_to_data_func('test', 'b@localhost', 'thing3', func1)
|
||||||
|
|
||||||
|
def test_add_subscription():
|
||||||
|
sum = get_sum()
|
||||||
|
|
||||||
|
sum.add_subscription('test', 'c@localhost', 'thing1')
|
||||||
|
assert sum.is_subscribed_to_keyword('test', 'c@localhost', 'thing1')
|
||||||
|
|
||||||
|
sum.add_subscription('test', 'a@localhost', 'thing4')
|
||||||
|
assert sum.is_subscribed_to_keyword('test', 'a@localhost', 'thing4')
|
||||||
|
|
||||||
|
sum.add_subscription('prod', 'a@localhost', 'thing4')
|
||||||
|
assert sum.is_subscribed_to_keyword('prod', 'a@localhost', 'thing4')
|
||||||
|
|
||||||
|
sum.add_subscription('prod', 'a@localhost', 'thing5', 60)
|
||||||
|
subs = sum.get_subscriptions_for_jid('prod', 'a@localhost')
|
||||||
|
assert sum.is_subscribed_to_keyword('prod', 'a@localhost', 'thing5')
|
||||||
|
assert subs and 'thing5' in subs and subs['thing5']['data'] == 60
|
||||||
|
|
||||||
|
def test_append_subscription_data():
|
||||||
|
sum = get_sum()
|
||||||
|
|
||||||
|
sum.add_subscription('test', 'c@localhost', 'thing1', [])
|
||||||
|
sum.append_subscription_data('test', 'c@localhost', 'thing1', 1)
|
||||||
|
subs = sum.get_subscriptions_for_jid('test', 'c@localhost')
|
||||||
|
assert sum.is_subscribed_to_keyword('test', 'c@localhost', 'thing1')
|
||||||
|
assert subs and 'thing1' in subs and subs['thing1']['data'] == [1]
|
||||||
|
|
||||||
|
sum.append_subscription_data('test', 'c@localhost', 'thing1', 5)
|
||||||
|
assert subs['thing1']['data'] == [1, 5]
|
||||||
|
|
||||||
|
def test_remove_subscription():
|
||||||
|
sum = get_sum()
|
||||||
|
|
||||||
|
sum.remove_subscription('test', 'd@localhost', 'thing1')
|
||||||
|
assert not sum.is_subscribed_to_keyword('test', 'd@localhost', 'thing1')
|
||||||
|
assert not sum.get_subscriptions_for_jid('test', 'd@localhost')
|
||||||
|
|
||||||
|
def test_filter_subscription_data_items():
|
||||||
|
sum = get_sum()
|
||||||
|
|
||||||
|
func = lambda x: not x % 2 == 0
|
||||||
|
|
||||||
|
sum.filter_subscription_data_items('test', 'b@localhost', 'thing3', func)
|
||||||
|
assert sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 1)
|
||||||
|
assert not sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 2)
|
||||||
|
assert not sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 4)
|
||||||
|
|
||||||
|
def test_remove_subscription_data_item():
|
||||||
|
sum = get_sum()
|
||||||
|
|
||||||
|
sum.remove_subscription_data_item('test', 'b@localhost', 'thing3', 4)
|
||||||
|
assert sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 1)
|
||||||
|
assert sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 2)
|
||||||
|
assert not sum.is_subscribed_to_data('test', 'b@localhost', 'thing3', 4)
|
Loading…
Reference in New Issue
Block a user