diff --git a/mira/base.py b/mira/base.py index 9e19117..0c638b1 100644 --- a/mira/base.py +++ b/mira/base.py @@ -14,13 +14,19 @@ GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . ''' +import sys import importlib import asyncio +import logging +from optparse import OptionParser import aioxmpp import toml from mira.subscription import SubscriptionManager +from mira.storage import StorageManager + +logger = logging.getLogger('mira.base') def message_wrapper(to, body): msg = aioxmpp.Message( @@ -30,39 +36,38 @@ def message_wrapper(to, body): return msg class MiraBot: - def __init__(self): + def __init__(self, config_path): # Bot specific settings - self._jid = "" - self._password = "" + self._config = toml.load(config_path) + + self._jid = aioxmpp.JID.fromstr(self._config['jid']) + self._password = self._config['password'] + self._avatar = self._config.get('avatar', None) self._client = None - self._avatar = None self._modules = {} # Module name -> module - self._subscription_manager = SubscriptionManager() + self._storage_manager = StorageManager.get_instance(self._config.get('storage_path', './storage.json')) + self._subscription_manager = SubscriptionManager.get_instance() - def load_config(self, path): - data = toml.load(path) - - self._jid = aioxmpp.JID.fromstr(data['jid']) - self._password = data['password'] - self._avatar = data.get('avatar', None) - - for module in data['modules']: + def _initialise_modules(self): + for module in self._config['modules']: + logger.debug("Initialising module %s" % (module)) mod = importlib.import_module(module['name']) - self._modules[mod.NAME] = mod.get_instance(self, module) - self._modules[mod.NAME].set_name(mod.NAME) + self._modules[mod.NAME] = mod.get_instance(self, config=module, name=mod.NAME) async def connect(self): self._client = aioxmpp.PresenceManagedClient( self._jid, aioxmpp.make_security_layer(self._password)) async with self._client.connected(): + logger.info('Client connected') self._client.stream.register_message_callback( aioxmpp.MessageType.CHAT, None, self._on_message) if self._avatar: + logger.info('Publishing avatar') with open(self._avatar, 'rb') as avatar_file: data = avatar_file.read() avatar_set = aioxmpp.avatar.AvatarSet() @@ -70,10 +75,17 @@ class MiraBot: avatar_set.add_avatar_image('image/png', image_bytes=data) avatar = self._client.summon(aioxmpp.avatar.AvatarService) await avatar.publish_avatar_set(avatar_set) + logger.info('Avatar published') + logger.debug('Initialising modules') + self._initialise_modules() + while True: await asyncio.sleep(1) + def _is_sender_local(self, from_): + return from_.domain == self._jid.domain + def _on_message(self, message): # Automatically handles sending a message receipt and dealing # with unwanted messages @@ -87,9 +99,24 @@ class MiraBot: self._client.enqueue(receipt) if not cmd[0] in self._modules: + logger.debug('Received command for unknown module. Dropping') self._client.enqueue(message_wrapper(message.from_, "Unbekannter Befehl")) return + # Just drop messages that are not local when the module should + # be local only + if self._modules[cmd[0]]._restricted: + if self._modules[cmd[0]]._config['restrict_local']: + if not self._is_sender_local(message.from_): + logger.warning('Received a command from a non-local user to a' + ' module that is restricted to local users only') + return + elif self._modules[cmd[0]]._config['allowed_domains']: + if not message.from_.domain in self._modules[cmd[0]]._config['allowed_domains']: + logger.warning('Received a command from a non-whitelisted user to a' + ' module that is restricted to whitelisted users only') + return + self._modules[cmd[0]]._on_command(cmd[1:], message) # Module Function: Send message @@ -99,14 +126,17 @@ class MiraBot: # Module Function: Send a message to @to with @body def send_message_wrapper(self, to, body): self.send_message(message_wrapper(to, body)) - - # Module Function - def get_subscription_manager(self): - return self._subscription_manager - + def main(): - bot = MiraBot() - bot.load_config("./config.toml") + parser = OptionParser() + parser.add_option('-d', '--debug', dest='debug', + help='Enable debug logging', action='store_true') + (options, args) = parser.parse_args() + + verbosity = logging.DEBUG if options.debug else logging.INFO + logging.basicConfig(stream=sys.stdout, level=verbosity) + + bot = MiraBot("./config.toml") loop = asyncio.get_event_loop() loop.run_until_complete(bot.connect()) diff --git a/mira/module.py b/mira/module.py index 174d480..1c00227 100644 --- a/mira/module.py +++ b/mira/module.py @@ -14,55 +14,70 @@ GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . ''' +import asyncio +from functools import partial + +from mira.storage import StorageManager +from mira.subscription import SubscriptionManager + +class ManagerWrapper: + ''' + Wrapper class around {Storage, Subscription}Manager in order + to to expose those directly to the modules without allowing them + access to other modules. + ''' + _name = '' + _manager = None + + def __init__(self, name, manager): + self._name = name + self._manager = manager.get_instance() + + def __getattr__(self, key): + if not key in dir(self._manager): + raise AttributeError("Attribute %s does not exist in wrapped" + " class %s" % (key, type(self._manager))) + + return partial(getattr(self._manager, key), self._name) class BaseModule: - def __init__(self, base, config={}, subcommand_table={}): - self._name = '' + def __init__(self, base, config={}, subcommand_table={}, name=''): + self._name = name self._base = base self._config = config self._subcommand_table = subcommand_table - - def set_name(self, name): - if self._name: - raise Exception('Name change of module attempted!') + self._local_only = self.get_option('restrict_local', False) + self._restricted = ('restrict_local' in self._config or + 'allowed_domains' in self._config) - self._name = name + self._stm = ManagerWrapper(self._name, StorageManager) + self._sum = ManagerWrapper(self._name, SubscriptionManager) def get_option(self, key, default=None): + ''' + Like dict.get(), but for the options from the bot configuration + file. + + If key does not exist, then default will be returned. + ''' return self._config.get(key, default) - # Used for access control - # Returns True if @jid is allowed to access the command. False otherwise. - # This is configured by either restrict_local or allowed_domains. restrict_local - # takes precedence over allowed_domains - def is_jid_allowed(self, jid): - only_local = self.get_option('restrict_local') - if only_local: - return only_local - - domains = self.get_option('allowed_domains') - if not domains: - return True - - return jid.domain in domains + def send_message(self, to, body): + ''' + A simple wrapper that sends a message with type='chat' to + @to with @body as the body + ''' + self._base.send_message_wrapper(to, body) - def get_subscriptions_for(self, jid): - return self._base.get_subscription_manager.get_subscriptions_for(self._name, jid) - - def add_subscription_for(self, jid, keyword): - return self._base.get_subscription_manager.add_subscription_for(self._name, jid, keyword) - - def remove_subscription_for(self, jid, keyword): - return self._base.get_subscription_manager.remove_subscription_for(self._name, jid, keyword) - - def is_subscribed_to(self, jid, keyword): - return self._base.get_subscription_manager.is_subscribed_to(self._name, jid, keyword) - def _on_command(self, cmd, msg): + def run(func): + loop = asyncio.get_event_loop() + loop.create_task(func(cmd, msg)) + if not self._subcommand_table: - self.on_command(cmd, msg) + run(self.on_command) elif cmd and cmd[0] in self._subcommand_table: - self._subcommand_table[cmd[0]](cmd[1:], msg) + run(self._subcommand_table[cmd[0]]) else: if '*' in self._subcommand_table: - self._subcommand_table['*'](cmd, msg) + run(self._subcommand_table['*']) diff --git a/mira/modules/test.py b/mira/modules/test.py index db43fac..88b9965 100644 --- a/mira/modules/test.py +++ b/mira/modules/test.py @@ -22,22 +22,22 @@ class TestModule(BaseModule): __instance = None @staticmethod - def get_instance(base, config): + def get_instance(base, **kwargs): if TestModule.__instance == None: - TestModule(base, config) + TestModule(base, **kwargs) return TestModule.__instance - def __init__(self, base, config): + def __init__(self, base, **kwargs): if TestModule.__instance != None: raise Exception('Trying to init singleton twice') - super().__init__(base, config) + super().__init__(base, **kwargs) TestModule.__instance = self - def on_command(self, cmd, msg): + async def on_command(self, cmd, msg): greeting = self.get_option('greeting', 'OwO, %%user%%!').replace('%%user%%', str(msg.from_.bare())) self._base.send_message_wrapper(msg.from_, greeting) -def get_instance(base, config={}): - return TestModule.get_instance(base, config) +def get_instance(base, **kwargs): + return TestModule.get_instance(base, **kwargs) diff --git a/mira/storage.py b/mira/storage.py new file mode 100644 index 0000000..f435778 --- /dev/null +++ b/mira/storage.py @@ -0,0 +1,75 @@ +''' +Copyright (C) 2021 Alexander "PapaTutuWawa" + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with this program. If not, see . +''' +import os +import json +import logging + +logger = logging.getLogger('mira.storage.StorageManager') + +class StorageManager: + __instance = None + + @staticmethod + def get_instance(file_location='/etc/mira/storage.json'): + if not StorageManager.__instance: + StorageManager(file_location=file_location) + return StorageManager.__instance + + def __init__(self, file_location): + if StorageManager.__instance: + raise Exception('Trying to instanciate StorageManger twice') + + self._data = {} # Module -> Section -> Data + self._file_location = file_location + + logger.debug('Loading data from %s' % (file_location)) + if os.path.exists(file_location): + with open(file_location, 'r') as f: + self._data = json.loads(f.read()) + + StorageManager.__instance = self + + def get_data(self, module, section): + ''' + Get the data stored for module @module under the section + @section. Returns {} if there is no data stored for @module. + ''' + if not module in self._data: + logging.debug('get_data: module unknown in self._data') + logging.debug('module: "%s"' % (module)) + return {} + if not section in self._data[module]: + logging.debug('get_data: section unknown in self._data[module]') + logging.debug('module: "%s", section: "%s"' % (module, section)) + return {} + + return self._data[module][section] + + def set_data(self, module, section, data): + ''' + Stores the data @data for @module under section @section. + Flushes the data to storage afterwards. + ''' + if not module in self._data: + self._data[module] = {} + self._data[module][section] = data + self.__flush() + + def __flush(self): + logger.debug('Flushing to storage') + with open(self._file_location, 'w') as f: + f.write(json.dumps(self._data)) diff --git a/mira/subscription.py b/mira/subscription.py index ae84763..80a14ef 100644 --- a/mira/subscription.py +++ b/mira/subscription.py @@ -1,5 +1,24 @@ -#from collections import namedtuple -# TODO: Allow storing data along with the keyword +''' +Copyright (C) 2021 Alexander "PapaTutuWawa" + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with this program. If not, see . +''' +# TODO: Replace most of these with a query API +import os +import json + +from mira.storage import StorageManager def append_or_insert(dict_, key, value): if key in dict_ or dict_[key]: @@ -14,26 +33,150 @@ def remove_or_delete(dict_, key, value): dict_[key].remove(value) class SubscriptionManager: + ''' + This class is tasked with providing functions that simplify dealing + with subscriptions. + + NOTE: This class must only by instanciated *after* the StorageManager + has been instanciated at least once. For modules, this is no + issue as they're only created after the manager classes are + ready. + ''' + __instance = None + + @staticmethod + def get_instance(): + if not SubscriptionManager.__instance: + SubscriptionManager() + return SubscriptionManager.__instance + def __init__(self): - self._subscriptions = {} # Module -> JID -> Keywords + self._sm = StorageManager.get_instance() + # Module -> JID -> Keywords + self._subscriptions = self._sm.get_data('_SubscriptionManager', 'subscriptions') + + SubscriptionManager.__instance = self def get_subscriptions_for(self, module, jid): if not module in self._subscriptions: - return None + return [] if not jid in self._subscriptions[module]: - return None + return [] return self._subscriptions[module][jid] + def get_subscriptions_for_keyword(self, module, keyword): + ''' + Returns an array of JIDs that are subscribed to the keyword of module + ''' + if not module in self._subscriptions: + return [] + + tmp = [] + for jid in self._subscriptions[module]: + if not keyword in self._subscriptions[module][jid]: + continue + + data = self._subscriptions[module][jid][keyword]['data'] + tmp.append((jid, data)) + return tmp + + def get_subscription_keywords(self, module): + '''Returns a list of subscribed keywords in module''' + if not module in self._subscriptions: + return [] + + tmp = [] + for subscription in self._subscriptions[module].values(): + tmp += list(subscription.keys()) + return tmp + def is_subscribed_to(self, module, jid, keyword): + ''' + Returns True if @jid is subscribed to @keyword within the context + of @module. False otherwise + ''' return keyword in self.get_subscriptions_for(module, jid) - def __flush(self): - # TODO - pass + def is_subscribed_to_data(self, module, jid, keyword, item): + ''' + Returns True if @jid is subscribed to the item @item inside + the keyword @keyword within the context of @module + ''' + subscriptions = self.get_subscriptions_for(module, jid) + if not subscriptions: + return False + if not keyword in subscriptions: + return False - def add_subscription_for(self, module, jid, keyword): - append_or_insert(self._subscriptions[module], jid, keyword) + return item in subscriptions[keyword] + + def add_subscription_for(self, module, jid, keyword, data={}): + ''' + Adds a subscription to @keyword with data @data for @jid within + the context of @module. + ''' + if not module in self._subscriptions: + self._subscriptions[module] = {} + + append_or_insert(self._subscriptions[module], jid, { + 'keyword': keyword, + 'data': data + }) + self.__flush() + + def append_data_for_subscription(self, module, jid, keyword, item): + ''' + Special helper function which appends item to the data field of + a subscription to @keyword from @jid within the context of @module. + + If no subscription exists, then one will be created. + + NOTE: This function expects data to be a dict with key 'data', which + must be an array, so it will fail if add_subscription_for has + been called beforehand with data equal to anything but an dict + with a 'data' key containing an array. + ''' + if not module in self._subscriptions: + self._subscriptions[module] = {} + if not jid in self._subscriptions[module]: + self._subscriptions[module][jid] = {} + if not keyword in self._subscriptions[module][jid]: + self._subscriptions[module][jid][keyword] = { + 'data': [item] + } + self.__flush() + return + + self._subscriptions[module][jid][keyword]['data'].append(item) + self.__flush() def remove_subscription_for(self, module, jid, keyword): + ''' + Removes a subscription to @keyword for @jid within the context + of @module + ''' remove_or_delete(self._subscriptions[module], jid, keyword) + self.__flush() + + def remove_item_for_subscription(self, module, jid, keyword, item): + ''' + The deletion counterpart of append_data_for_subscription. + ''' + if not module in self._subscriptions: + return + if not jid in self._subscriptions[module]: + return + if not keyword in self._subscriptions[module][jid]: + return + if not item in self._subscriptions[module][jid][keyword]['data']: + return + + self._subscriptions[module][jid][keyword]['data'].remove(item) + self.__flush() + + def __flush(self): + ''' + Write subscription data to disk. Just an interface to StorageManager + ''' + self._sm.set_data('_StorageManager', 'subscriptions', self._subscriptions)