refactor: Huge refactor

- Storage and Subscriptions are Singletons
- SubscriptionManager stores data in StorageManager
- StorageManager and SubscriptionManager are now persistent
- The module template is now much simpler
- Added logging (Debug logging via --debug)
This commit is contained in:
PapaTutuWawa 2021-06-12 20:51:45 +02:00
parent 4d8436df67
commit e9900ee9b6
5 changed files with 338 additions and 75 deletions

View File

@ -14,13 +14,19 @@ GNU General Public License for more details.
You should have received a copy of the GNU General Public License You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>. along with this program. If not, see <https://www.gnu.org/licenses/>.
''' '''
import sys
import importlib import importlib
import asyncio import asyncio
import logging
from optparse import OptionParser
import aioxmpp import aioxmpp
import toml import toml
from mira.subscription import SubscriptionManager from mira.subscription import SubscriptionManager
from mira.storage import StorageManager
logger = logging.getLogger('mira.base')
def message_wrapper(to, body): def message_wrapper(to, body):
msg = aioxmpp.Message( msg = aioxmpp.Message(
@ -30,39 +36,38 @@ def message_wrapper(to, body):
return msg return msg
class MiraBot: class MiraBot:
def __init__(self): def __init__(self, config_path):
# Bot specific settings # Bot specific settings
self._jid = "" self._config = toml.load(config_path)
self._password = ""
self._jid = aioxmpp.JID.fromstr(self._config['jid'])
self._password = self._config['password']
self._avatar = self._config.get('avatar', None)
self._client = None self._client = None
self._avatar = None
self._modules = {} # Module name -> module self._modules = {} # Module name -> module
self._subscription_manager = SubscriptionManager() self._storage_manager = StorageManager.get_instance(self._config.get('storage_path', './storage.json'))
self._subscription_manager = SubscriptionManager.get_instance()
def load_config(self, path): def _initialise_modules(self):
data = toml.load(path) for module in self._config['modules']:
logger.debug("Initialising module %s" % (module))
self._jid = aioxmpp.JID.fromstr(data['jid'])
self._password = data['password']
self._avatar = data.get('avatar', None)
for module in data['modules']:
mod = importlib.import_module(module['name']) mod = importlib.import_module(module['name'])
self._modules[mod.NAME] = mod.get_instance(self, module) self._modules[mod.NAME] = mod.get_instance(self, config=module, name=mod.NAME)
self._modules[mod.NAME].set_name(mod.NAME)
async def connect(self): async def connect(self):
self._client = aioxmpp.PresenceManagedClient( self._client = aioxmpp.PresenceManagedClient(
self._jid, self._jid,
aioxmpp.make_security_layer(self._password)) aioxmpp.make_security_layer(self._password))
async with self._client.connected(): async with self._client.connected():
logger.info('Client connected')
self._client.stream.register_message_callback( self._client.stream.register_message_callback(
aioxmpp.MessageType.CHAT, aioxmpp.MessageType.CHAT,
None, None,
self._on_message) self._on_message)
if self._avatar: if self._avatar:
logger.info('Publishing avatar')
with open(self._avatar, 'rb') as avatar_file: with open(self._avatar, 'rb') as avatar_file:
data = avatar_file.read() data = avatar_file.read()
avatar_set = aioxmpp.avatar.AvatarSet() avatar_set = aioxmpp.avatar.AvatarSet()
@ -70,10 +75,17 @@ class MiraBot:
avatar_set.add_avatar_image('image/png', image_bytes=data) avatar_set.add_avatar_image('image/png', image_bytes=data)
avatar = self._client.summon(aioxmpp.avatar.AvatarService) avatar = self._client.summon(aioxmpp.avatar.AvatarService)
await avatar.publish_avatar_set(avatar_set) await avatar.publish_avatar_set(avatar_set)
logger.info('Avatar published')
logger.debug('Initialising modules')
self._initialise_modules()
while True: while True:
await asyncio.sleep(1) await asyncio.sleep(1)
def _is_sender_local(self, from_):
return from_.domain == self._jid.domain
def _on_message(self, message): def _on_message(self, message):
# Automatically handles sending a message receipt and dealing # Automatically handles sending a message receipt and dealing
# with unwanted messages # with unwanted messages
@ -87,9 +99,24 @@ class MiraBot:
self._client.enqueue(receipt) self._client.enqueue(receipt)
if not cmd[0] in self._modules: if not cmd[0] in self._modules:
logger.debug('Received command for unknown module. Dropping')
self._client.enqueue(message_wrapper(message.from_, "Unbekannter Befehl")) self._client.enqueue(message_wrapper(message.from_, "Unbekannter Befehl"))
return return
# Just drop messages that are not local when the module should
# be local only
if self._modules[cmd[0]]._restricted:
if self._modules[cmd[0]]._config['restrict_local']:
if not self._is_sender_local(message.from_):
logger.warning('Received a command from a non-local user to a'
' module that is restricted to local users only')
return
elif self._modules[cmd[0]]._config['allowed_domains']:
if not message.from_.domain in self._modules[cmd[0]]._config['allowed_domains']:
logger.warning('Received a command from a non-whitelisted user to a'
' module that is restricted to whitelisted users only')
return
self._modules[cmd[0]]._on_command(cmd[1:], message) self._modules[cmd[0]]._on_command(cmd[1:], message)
# Module Function: Send message # Module Function: Send message
@ -100,13 +127,16 @@ class MiraBot:
def send_message_wrapper(self, to, body): def send_message_wrapper(self, to, body):
self.send_message(message_wrapper(to, body)) self.send_message(message_wrapper(to, body))
# Module Function
def get_subscription_manager(self):
return self._subscription_manager
def main(): def main():
bot = MiraBot() parser = OptionParser()
bot.load_config("./config.toml") parser.add_option('-d', '--debug', dest='debug',
help='Enable debug logging', action='store_true')
(options, args) = parser.parse_args()
verbosity = logging.DEBUG if options.debug else logging.INFO
logging.basicConfig(stream=sys.stdout, level=verbosity)
bot = MiraBot("./config.toml")
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
loop.run_until_complete(bot.connect()) loop.run_until_complete(bot.connect())

View File

@ -14,55 +14,70 @@ GNU General Public License for more details.
You should have received a copy of the GNU General Public License You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>. along with this program. If not, see <https://www.gnu.org/licenses/>.
''' '''
import asyncio
from functools import partial
from mira.storage import StorageManager
from mira.subscription import SubscriptionManager
class ManagerWrapper:
'''
Wrapper class around {Storage, Subscription}Manager in order
to to expose those directly to the modules without allowing them
access to other modules.
'''
_name = ''
_manager = None
def __init__(self, name, manager):
self._name = name
self._manager = manager.get_instance()
def __getattr__(self, key):
if not key in dir(self._manager):
raise AttributeError("Attribute %s does not exist in wrapped"
" class %s" % (key, type(self._manager)))
return partial(getattr(self._manager, key), self._name)
class BaseModule: class BaseModule:
def __init__(self, base, config={}, subcommand_table={}): def __init__(self, base, config={}, subcommand_table={}, name=''):
self._name = '' self._name = name
self._base = base self._base = base
self._config = config self._config = config
self._subcommand_table = subcommand_table self._subcommand_table = subcommand_table
self._local_only = self.get_option('restrict_local', False)
self._restricted = ('restrict_local' in self._config or
'allowed_domains' in self._config)
def set_name(self, name): self._stm = ManagerWrapper(self._name, StorageManager)
if self._name: self._sum = ManagerWrapper(self._name, SubscriptionManager)
raise Exception('Name change of module attempted!')
self._name = name
def get_option(self, key, default=None): def get_option(self, key, default=None):
'''
Like dict.get(), but for the options from the bot configuration
file.
If key does not exist, then default will be returned.
'''
return self._config.get(key, default) return self._config.get(key, default)
# Used for access control def send_message(self, to, body):
# Returns True if @jid is allowed to access the command. False otherwise. '''
# This is configured by either restrict_local or allowed_domains. restrict_local A simple wrapper that sends a message with type='chat' to
# takes precedence over allowed_domains @to with @body as the body
def is_jid_allowed(self, jid): '''
only_local = self.get_option('restrict_local') self._base.send_message_wrapper(to, body)
if only_local:
return only_local
domains = self.get_option('allowed_domains')
if not domains:
return True
return jid.domain in domains
def get_subscriptions_for(self, jid):
return self._base.get_subscription_manager.get_subscriptions_for(self._name, jid)
def add_subscription_for(self, jid, keyword):
return self._base.get_subscription_manager.add_subscription_for(self._name, jid, keyword)
def remove_subscription_for(self, jid, keyword):
return self._base.get_subscription_manager.remove_subscription_for(self._name, jid, keyword)
def is_subscribed_to(self, jid, keyword):
return self._base.get_subscription_manager.is_subscribed_to(self._name, jid, keyword)
def _on_command(self, cmd, msg): def _on_command(self, cmd, msg):
def run(func):
loop = asyncio.get_event_loop()
loop.create_task(func(cmd, msg))
if not self._subcommand_table: if not self._subcommand_table:
self.on_command(cmd, msg) run(self.on_command)
elif cmd and cmd[0] in self._subcommand_table: elif cmd and cmd[0] in self._subcommand_table:
self._subcommand_table[cmd[0]](cmd[1:], msg) run(self._subcommand_table[cmd[0]])
else: else:
if '*' in self._subcommand_table: if '*' in self._subcommand_table:
self._subcommand_table['*'](cmd, msg) run(self._subcommand_table['*'])

View File

@ -22,22 +22,22 @@ class TestModule(BaseModule):
__instance = None __instance = None
@staticmethod @staticmethod
def get_instance(base, config): def get_instance(base, **kwargs):
if TestModule.__instance == None: if TestModule.__instance == None:
TestModule(base, config) TestModule(base, **kwargs)
return TestModule.__instance return TestModule.__instance
def __init__(self, base, config): def __init__(self, base, **kwargs):
if TestModule.__instance != None: if TestModule.__instance != None:
raise Exception('Trying to init singleton twice') raise Exception('Trying to init singleton twice')
super().__init__(base, config) super().__init__(base, **kwargs)
TestModule.__instance = self TestModule.__instance = self
def on_command(self, cmd, msg): async def on_command(self, cmd, msg):
greeting = self.get_option('greeting', 'OwO, %%user%%!').replace('%%user%%', str(msg.from_.bare())) greeting = self.get_option('greeting', 'OwO, %%user%%!').replace('%%user%%', str(msg.from_.bare()))
self._base.send_message_wrapper(msg.from_, greeting) self._base.send_message_wrapper(msg.from_, greeting)
def get_instance(base, config={}): def get_instance(base, **kwargs):
return TestModule.get_instance(base, config) return TestModule.get_instance(base, **kwargs)

75
mira/storage.py Normal file
View File

@ -0,0 +1,75 @@
'''
Copyright (C) 2021 Alexander "PapaTutuWawa"
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
'''
import os
import json
import logging
logger = logging.getLogger('mira.storage.StorageManager')
class StorageManager:
__instance = None
@staticmethod
def get_instance(file_location='/etc/mira/storage.json'):
if not StorageManager.__instance:
StorageManager(file_location=file_location)
return StorageManager.__instance
def __init__(self, file_location):
if StorageManager.__instance:
raise Exception('Trying to instanciate StorageManger twice')
self._data = {} # Module -> Section -> Data
self._file_location = file_location
logger.debug('Loading data from %s' % (file_location))
if os.path.exists(file_location):
with open(file_location, 'r') as f:
self._data = json.loads(f.read())
StorageManager.__instance = self
def get_data(self, module, section):
'''
Get the data stored for module @module under the section
@section. Returns {} if there is no data stored for @module.
'''
if not module in self._data:
logging.debug('get_data: module unknown in self._data')
logging.debug('module: "%s"' % (module))
return {}
if not section in self._data[module]:
logging.debug('get_data: section unknown in self._data[module]')
logging.debug('module: "%s", section: "%s"' % (module, section))
return {}
return self._data[module][section]
def set_data(self, module, section, data):
'''
Stores the data @data for @module under section @section.
Flushes the data to storage afterwards.
'''
if not module in self._data:
self._data[module] = {}
self._data[module][section] = data
self.__flush()
def __flush(self):
logger.debug('Flushing to storage')
with open(self._file_location, 'w') as f:
f.write(json.dumps(self._data))

View File

@ -1,5 +1,24 @@
#from collections import namedtuple '''
# TODO: Allow storing data along with the keyword Copyright (C) 2021 Alexander "PapaTutuWawa"
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
'''
# TODO: Replace most of these with a query API
import os
import json
from mira.storage import StorageManager
def append_or_insert(dict_, key, value): def append_or_insert(dict_, key, value):
if key in dict_ or dict_[key]: if key in dict_ or dict_[key]:
@ -14,26 +33,150 @@ def remove_or_delete(dict_, key, value):
dict_[key].remove(value) dict_[key].remove(value)
class SubscriptionManager: class SubscriptionManager:
'''
This class is tasked with providing functions that simplify dealing
with subscriptions.
NOTE: This class must only by instanciated *after* the StorageManager
has been instanciated at least once. For modules, this is no
issue as they're only created after the manager classes are
ready.
'''
__instance = None
@staticmethod
def get_instance():
if not SubscriptionManager.__instance:
SubscriptionManager()
return SubscriptionManager.__instance
def __init__(self): def __init__(self):
self._subscriptions = {} # Module -> JID -> Keywords self._sm = StorageManager.get_instance()
# Module -> JID -> Keywords
self._subscriptions = self._sm.get_data('_SubscriptionManager', 'subscriptions')
SubscriptionManager.__instance = self
def get_subscriptions_for(self, module, jid): def get_subscriptions_for(self, module, jid):
if not module in self._subscriptions: if not module in self._subscriptions:
return None return []
if not jid in self._subscriptions[module]: if not jid in self._subscriptions[module]:
return None return []
return self._subscriptions[module][jid] return self._subscriptions[module][jid]
def get_subscriptions_for_keyword(self, module, keyword):
'''
Returns an array of JIDs that are subscribed to the keyword of module
'''
if not module in self._subscriptions:
return []
tmp = []
for jid in self._subscriptions[module]:
if not keyword in self._subscriptions[module][jid]:
continue
data = self._subscriptions[module][jid][keyword]['data']
tmp.append((jid, data))
return tmp
def get_subscription_keywords(self, module):
'''Returns a list of subscribed keywords in module'''
if not module in self._subscriptions:
return []
tmp = []
for subscription in self._subscriptions[module].values():
tmp += list(subscription.keys())
return tmp
def is_subscribed_to(self, module, jid, keyword): def is_subscribed_to(self, module, jid, keyword):
'''
Returns True if @jid is subscribed to @keyword within the context
of @module. False otherwise
'''
return keyword in self.get_subscriptions_for(module, jid) return keyword in self.get_subscriptions_for(module, jid)
def __flush(self): def is_subscribed_to_data(self, module, jid, keyword, item):
# TODO '''
pass Returns True if @jid is subscribed to the item @item inside
the keyword @keyword within the context of @module
'''
subscriptions = self.get_subscriptions_for(module, jid)
if not subscriptions:
return False
if not keyword in subscriptions:
return False
def add_subscription_for(self, module, jid, keyword): return item in subscriptions[keyword]
append_or_insert(self._subscriptions[module], jid, keyword)
def add_subscription_for(self, module, jid, keyword, data={}):
'''
Adds a subscription to @keyword with data @data for @jid within
the context of @module.
'''
if not module in self._subscriptions:
self._subscriptions[module] = {}
append_or_insert(self._subscriptions[module], jid, {
'keyword': keyword,
'data': data
})
self.__flush()
def append_data_for_subscription(self, module, jid, keyword, item):
'''
Special helper function which appends item to the data field of
a subscription to @keyword from @jid within the context of @module.
If no subscription exists, then one will be created.
NOTE: This function expects data to be a dict with key 'data', which
must be an array, so it will fail if add_subscription_for has
been called beforehand with data equal to anything but an dict
with a 'data' key containing an array.
'''
if not module in self._subscriptions:
self._subscriptions[module] = {}
if not jid in self._subscriptions[module]:
self._subscriptions[module][jid] = {}
if not keyword in self._subscriptions[module][jid]:
self._subscriptions[module][jid][keyword] = {
'data': [item]
}
self.__flush()
return
self._subscriptions[module][jid][keyword]['data'].append(item)
self.__flush()
def remove_subscription_for(self, module, jid, keyword): def remove_subscription_for(self, module, jid, keyword):
'''
Removes a subscription to @keyword for @jid within the context
of @module
'''
remove_or_delete(self._subscriptions[module], jid, keyword) remove_or_delete(self._subscriptions[module], jid, keyword)
self.__flush()
def remove_item_for_subscription(self, module, jid, keyword, item):
'''
The deletion counterpart of append_data_for_subscription.
'''
if not module in self._subscriptions:
return
if not jid in self._subscriptions[module]:
return
if not keyword in self._subscriptions[module][jid]:
return
if not item in self._subscriptions[module][jid][keyword]['data']:
return
self._subscriptions[module][jid][keyword]['data'].remove(item)
self.__flush()
def __flush(self):
'''
Write subscription data to disk. Just an interface to StorageManager
'''
self._sm.set_data('_StorageManager', 'subscriptions', self._subscriptions)