diff --git a/mira/base.py b/mira/base.py
index 9e19117..0c638b1 100644
--- a/mira/base.py
+++ b/mira/base.py
@@ -14,13 +14,19 @@ GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see .
'''
+import sys
import importlib
import asyncio
+import logging
+from optparse import OptionParser
import aioxmpp
import toml
from mira.subscription import SubscriptionManager
+from mira.storage import StorageManager
+
+logger = logging.getLogger('mira.base')
def message_wrapper(to, body):
msg = aioxmpp.Message(
@@ -30,39 +36,38 @@ def message_wrapper(to, body):
return msg
class MiraBot:
- def __init__(self):
+ def __init__(self, config_path):
# Bot specific settings
- self._jid = ""
- self._password = ""
+ self._config = toml.load(config_path)
+
+ self._jid = aioxmpp.JID.fromstr(self._config['jid'])
+ self._password = self._config['password']
+ self._avatar = self._config.get('avatar', None)
self._client = None
- self._avatar = None
self._modules = {} # Module name -> module
- self._subscription_manager = SubscriptionManager()
+ self._storage_manager = StorageManager.get_instance(self._config.get('storage_path', './storage.json'))
+ self._subscription_manager = SubscriptionManager.get_instance()
- def load_config(self, path):
- data = toml.load(path)
-
- self._jid = aioxmpp.JID.fromstr(data['jid'])
- self._password = data['password']
- self._avatar = data.get('avatar', None)
-
- for module in data['modules']:
+ def _initialise_modules(self):
+ for module in self._config['modules']:
+ logger.debug("Initialising module %s" % (module))
mod = importlib.import_module(module['name'])
- self._modules[mod.NAME] = mod.get_instance(self, module)
- self._modules[mod.NAME].set_name(mod.NAME)
+ self._modules[mod.NAME] = mod.get_instance(self, config=module, name=mod.NAME)
async def connect(self):
self._client = aioxmpp.PresenceManagedClient(
self._jid,
aioxmpp.make_security_layer(self._password))
async with self._client.connected():
+ logger.info('Client connected')
self._client.stream.register_message_callback(
aioxmpp.MessageType.CHAT,
None,
self._on_message)
if self._avatar:
+ logger.info('Publishing avatar')
with open(self._avatar, 'rb') as avatar_file:
data = avatar_file.read()
avatar_set = aioxmpp.avatar.AvatarSet()
@@ -70,10 +75,17 @@ class MiraBot:
avatar_set.add_avatar_image('image/png', image_bytes=data)
avatar = self._client.summon(aioxmpp.avatar.AvatarService)
await avatar.publish_avatar_set(avatar_set)
+ logger.info('Avatar published')
+ logger.debug('Initialising modules')
+ self._initialise_modules()
+
while True:
await asyncio.sleep(1)
+ def _is_sender_local(self, from_):
+ return from_.domain == self._jid.domain
+
def _on_message(self, message):
# Automatically handles sending a message receipt and dealing
# with unwanted messages
@@ -87,9 +99,24 @@ class MiraBot:
self._client.enqueue(receipt)
if not cmd[0] in self._modules:
+ logger.debug('Received command for unknown module. Dropping')
self._client.enqueue(message_wrapper(message.from_, "Unbekannter Befehl"))
return
+ # Just drop messages that are not local when the module should
+ # be local only
+ if self._modules[cmd[0]]._restricted:
+ if self._modules[cmd[0]]._config['restrict_local']:
+ if not self._is_sender_local(message.from_):
+ logger.warning('Received a command from a non-local user to a'
+ ' module that is restricted to local users only')
+ return
+ elif self._modules[cmd[0]]._config['allowed_domains']:
+ if not message.from_.domain in self._modules[cmd[0]]._config['allowed_domains']:
+ logger.warning('Received a command from a non-whitelisted user to a'
+ ' module that is restricted to whitelisted users only')
+ return
+
self._modules[cmd[0]]._on_command(cmd[1:], message)
# Module Function: Send message
@@ -99,14 +126,17 @@ class MiraBot:
# Module Function: Send a message to @to with @body
def send_message_wrapper(self, to, body):
self.send_message(message_wrapper(to, body))
-
- # Module Function
- def get_subscription_manager(self):
- return self._subscription_manager
-
+
def main():
- bot = MiraBot()
- bot.load_config("./config.toml")
+ parser = OptionParser()
+ parser.add_option('-d', '--debug', dest='debug',
+ help='Enable debug logging', action='store_true')
+ (options, args) = parser.parse_args()
+
+ verbosity = logging.DEBUG if options.debug else logging.INFO
+ logging.basicConfig(stream=sys.stdout, level=verbosity)
+
+ bot = MiraBot("./config.toml")
loop = asyncio.get_event_loop()
loop.run_until_complete(bot.connect())
diff --git a/mira/module.py b/mira/module.py
index 174d480..1c00227 100644
--- a/mira/module.py
+++ b/mira/module.py
@@ -14,55 +14,70 @@ GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see .
'''
+import asyncio
+from functools import partial
+
+from mira.storage import StorageManager
+from mira.subscription import SubscriptionManager
+
+class ManagerWrapper:
+ '''
+ Wrapper class around {Storage, Subscription}Manager in order
+ to to expose those directly to the modules without allowing them
+ access to other modules.
+ '''
+ _name = ''
+ _manager = None
+
+ def __init__(self, name, manager):
+ self._name = name
+ self._manager = manager.get_instance()
+
+ def __getattr__(self, key):
+ if not key in dir(self._manager):
+ raise AttributeError("Attribute %s does not exist in wrapped"
+ " class %s" % (key, type(self._manager)))
+
+ return partial(getattr(self._manager, key), self._name)
class BaseModule:
- def __init__(self, base, config={}, subcommand_table={}):
- self._name = ''
+ def __init__(self, base, config={}, subcommand_table={}, name=''):
+ self._name = name
self._base = base
self._config = config
self._subcommand_table = subcommand_table
-
- def set_name(self, name):
- if self._name:
- raise Exception('Name change of module attempted!')
+ self._local_only = self.get_option('restrict_local', False)
+ self._restricted = ('restrict_local' in self._config or
+ 'allowed_domains' in self._config)
- self._name = name
+ self._stm = ManagerWrapper(self._name, StorageManager)
+ self._sum = ManagerWrapper(self._name, SubscriptionManager)
def get_option(self, key, default=None):
+ '''
+ Like dict.get(), but for the options from the bot configuration
+ file.
+
+ If key does not exist, then default will be returned.
+ '''
return self._config.get(key, default)
- # Used for access control
- # Returns True if @jid is allowed to access the command. False otherwise.
- # This is configured by either restrict_local or allowed_domains. restrict_local
- # takes precedence over allowed_domains
- def is_jid_allowed(self, jid):
- only_local = self.get_option('restrict_local')
- if only_local:
- return only_local
-
- domains = self.get_option('allowed_domains')
- if not domains:
- return True
-
- return jid.domain in domains
+ def send_message(self, to, body):
+ '''
+ A simple wrapper that sends a message with type='chat' to
+ @to with @body as the body
+ '''
+ self._base.send_message_wrapper(to, body)
- def get_subscriptions_for(self, jid):
- return self._base.get_subscription_manager.get_subscriptions_for(self._name, jid)
-
- def add_subscription_for(self, jid, keyword):
- return self._base.get_subscription_manager.add_subscription_for(self._name, jid, keyword)
-
- def remove_subscription_for(self, jid, keyword):
- return self._base.get_subscription_manager.remove_subscription_for(self._name, jid, keyword)
-
- def is_subscribed_to(self, jid, keyword):
- return self._base.get_subscription_manager.is_subscribed_to(self._name, jid, keyword)
-
def _on_command(self, cmd, msg):
+ def run(func):
+ loop = asyncio.get_event_loop()
+ loop.create_task(func(cmd, msg))
+
if not self._subcommand_table:
- self.on_command(cmd, msg)
+ run(self.on_command)
elif cmd and cmd[0] in self._subcommand_table:
- self._subcommand_table[cmd[0]](cmd[1:], msg)
+ run(self._subcommand_table[cmd[0]])
else:
if '*' in self._subcommand_table:
- self._subcommand_table['*'](cmd, msg)
+ run(self._subcommand_table['*'])
diff --git a/mira/modules/test.py b/mira/modules/test.py
index db43fac..88b9965 100644
--- a/mira/modules/test.py
+++ b/mira/modules/test.py
@@ -22,22 +22,22 @@ class TestModule(BaseModule):
__instance = None
@staticmethod
- def get_instance(base, config):
+ def get_instance(base, **kwargs):
if TestModule.__instance == None:
- TestModule(base, config)
+ TestModule(base, **kwargs)
return TestModule.__instance
- def __init__(self, base, config):
+ def __init__(self, base, **kwargs):
if TestModule.__instance != None:
raise Exception('Trying to init singleton twice')
- super().__init__(base, config)
+ super().__init__(base, **kwargs)
TestModule.__instance = self
- def on_command(self, cmd, msg):
+ async def on_command(self, cmd, msg):
greeting = self.get_option('greeting', 'OwO, %%user%%!').replace('%%user%%', str(msg.from_.bare()))
self._base.send_message_wrapper(msg.from_, greeting)
-def get_instance(base, config={}):
- return TestModule.get_instance(base, config)
+def get_instance(base, **kwargs):
+ return TestModule.get_instance(base, **kwargs)
diff --git a/mira/storage.py b/mira/storage.py
new file mode 100644
index 0000000..f435778
--- /dev/null
+++ b/mira/storage.py
@@ -0,0 +1,75 @@
+'''
+Copyright (C) 2021 Alexander "PapaTutuWawa"
+
+This program is free software: you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+This program is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with this program. If not, see .
+'''
+import os
+import json
+import logging
+
+logger = logging.getLogger('mira.storage.StorageManager')
+
+class StorageManager:
+ __instance = None
+
+ @staticmethod
+ def get_instance(file_location='/etc/mira/storage.json'):
+ if not StorageManager.__instance:
+ StorageManager(file_location=file_location)
+ return StorageManager.__instance
+
+ def __init__(self, file_location):
+ if StorageManager.__instance:
+ raise Exception('Trying to instanciate StorageManger twice')
+
+ self._data = {} # Module -> Section -> Data
+ self._file_location = file_location
+
+ logger.debug('Loading data from %s' % (file_location))
+ if os.path.exists(file_location):
+ with open(file_location, 'r') as f:
+ self._data = json.loads(f.read())
+
+ StorageManager.__instance = self
+
+ def get_data(self, module, section):
+ '''
+ Get the data stored for module @module under the section
+ @section. Returns {} if there is no data stored for @module.
+ '''
+ if not module in self._data:
+ logging.debug('get_data: module unknown in self._data')
+ logging.debug('module: "%s"' % (module))
+ return {}
+ if not section in self._data[module]:
+ logging.debug('get_data: section unknown in self._data[module]')
+ logging.debug('module: "%s", section: "%s"' % (module, section))
+ return {}
+
+ return self._data[module][section]
+
+ def set_data(self, module, section, data):
+ '''
+ Stores the data @data for @module under section @section.
+ Flushes the data to storage afterwards.
+ '''
+ if not module in self._data:
+ self._data[module] = {}
+ self._data[module][section] = data
+ self.__flush()
+
+ def __flush(self):
+ logger.debug('Flushing to storage')
+ with open(self._file_location, 'w') as f:
+ f.write(json.dumps(self._data))
diff --git a/mira/subscription.py b/mira/subscription.py
index ae84763..80a14ef 100644
--- a/mira/subscription.py
+++ b/mira/subscription.py
@@ -1,5 +1,24 @@
-#from collections import namedtuple
-# TODO: Allow storing data along with the keyword
+'''
+Copyright (C) 2021 Alexander "PapaTutuWawa"
+
+This program is free software: you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+This program is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with this program. If not, see .
+'''
+# TODO: Replace most of these with a query API
+import os
+import json
+
+from mira.storage import StorageManager
def append_or_insert(dict_, key, value):
if key in dict_ or dict_[key]:
@@ -14,26 +33,150 @@ def remove_or_delete(dict_, key, value):
dict_[key].remove(value)
class SubscriptionManager:
+ '''
+ This class is tasked with providing functions that simplify dealing
+ with subscriptions.
+
+ NOTE: This class must only by instanciated *after* the StorageManager
+ has been instanciated at least once. For modules, this is no
+ issue as they're only created after the manager classes are
+ ready.
+ '''
+ __instance = None
+
+ @staticmethod
+ def get_instance():
+ if not SubscriptionManager.__instance:
+ SubscriptionManager()
+ return SubscriptionManager.__instance
+
def __init__(self):
- self._subscriptions = {} # Module -> JID -> Keywords
+ self._sm = StorageManager.get_instance()
+ # Module -> JID -> Keywords
+ self._subscriptions = self._sm.get_data('_SubscriptionManager', 'subscriptions')
+
+ SubscriptionManager.__instance = self
def get_subscriptions_for(self, module, jid):
if not module in self._subscriptions:
- return None
+ return []
if not jid in self._subscriptions[module]:
- return None
+ return []
return self._subscriptions[module][jid]
+ def get_subscriptions_for_keyword(self, module, keyword):
+ '''
+ Returns an array of JIDs that are subscribed to the keyword of module
+ '''
+ if not module in self._subscriptions:
+ return []
+
+ tmp = []
+ for jid in self._subscriptions[module]:
+ if not keyword in self._subscriptions[module][jid]:
+ continue
+
+ data = self._subscriptions[module][jid][keyword]['data']
+ tmp.append((jid, data))
+ return tmp
+
+ def get_subscription_keywords(self, module):
+ '''Returns a list of subscribed keywords in module'''
+ if not module in self._subscriptions:
+ return []
+
+ tmp = []
+ for subscription in self._subscriptions[module].values():
+ tmp += list(subscription.keys())
+ return tmp
+
def is_subscribed_to(self, module, jid, keyword):
+ '''
+ Returns True if @jid is subscribed to @keyword within the context
+ of @module. False otherwise
+ '''
return keyword in self.get_subscriptions_for(module, jid)
- def __flush(self):
- # TODO
- pass
+ def is_subscribed_to_data(self, module, jid, keyword, item):
+ '''
+ Returns True if @jid is subscribed to the item @item inside
+ the keyword @keyword within the context of @module
+ '''
+ subscriptions = self.get_subscriptions_for(module, jid)
+ if not subscriptions:
+ return False
+ if not keyword in subscriptions:
+ return False
- def add_subscription_for(self, module, jid, keyword):
- append_or_insert(self._subscriptions[module], jid, keyword)
+ return item in subscriptions[keyword]
+
+ def add_subscription_for(self, module, jid, keyword, data={}):
+ '''
+ Adds a subscription to @keyword with data @data for @jid within
+ the context of @module.
+ '''
+ if not module in self._subscriptions:
+ self._subscriptions[module] = {}
+
+ append_or_insert(self._subscriptions[module], jid, {
+ 'keyword': keyword,
+ 'data': data
+ })
+ self.__flush()
+
+ def append_data_for_subscription(self, module, jid, keyword, item):
+ '''
+ Special helper function which appends item to the data field of
+ a subscription to @keyword from @jid within the context of @module.
+
+ If no subscription exists, then one will be created.
+
+ NOTE: This function expects data to be a dict with key 'data', which
+ must be an array, so it will fail if add_subscription_for has
+ been called beforehand with data equal to anything but an dict
+ with a 'data' key containing an array.
+ '''
+ if not module in self._subscriptions:
+ self._subscriptions[module] = {}
+ if not jid in self._subscriptions[module]:
+ self._subscriptions[module][jid] = {}
+ if not keyword in self._subscriptions[module][jid]:
+ self._subscriptions[module][jid][keyword] = {
+ 'data': [item]
+ }
+ self.__flush()
+ return
+
+ self._subscriptions[module][jid][keyword]['data'].append(item)
+ self.__flush()
def remove_subscription_for(self, module, jid, keyword):
+ '''
+ Removes a subscription to @keyword for @jid within the context
+ of @module
+ '''
remove_or_delete(self._subscriptions[module], jid, keyword)
+ self.__flush()
+
+ def remove_item_for_subscription(self, module, jid, keyword, item):
+ '''
+ The deletion counterpart of append_data_for_subscription.
+ '''
+ if not module in self._subscriptions:
+ return
+ if not jid in self._subscriptions[module]:
+ return
+ if not keyword in self._subscriptions[module][jid]:
+ return
+ if not item in self._subscriptions[module][jid][keyword]['data']:
+ return
+
+ self._subscriptions[module][jid][keyword]['data'].remove(item)
+ self.__flush()
+
+ def __flush(self):
+ '''
+ Write subscription data to disk. Just an interface to StorageManager
+ '''
+ self._sm.set_data('_StorageManager', 'subscriptions', self._subscriptions)