refactor: Huge refactor
- Storage and Subscriptions are Singletons - SubscriptionManager stores data in StorageManager - StorageManager and SubscriptionManager are now persistent - The module template is now much simpler - Added logging (Debug logging via --debug)
This commit is contained in:
parent
4d8436df67
commit
e9900ee9b6
74
mira/base.py
74
mira/base.py
@ -14,13 +14,19 @@ GNU General Public License for more details.
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
'''
|
||||
import sys
|
||||
import importlib
|
||||
import asyncio
|
||||
import logging
|
||||
from optparse import OptionParser
|
||||
|
||||
import aioxmpp
|
||||
import toml
|
||||
|
||||
from mira.subscription import SubscriptionManager
|
||||
from mira.storage import StorageManager
|
||||
|
||||
logger = logging.getLogger('mira.base')
|
||||
|
||||
def message_wrapper(to, body):
|
||||
msg = aioxmpp.Message(
|
||||
@ -30,39 +36,38 @@ def message_wrapper(to, body):
|
||||
return msg
|
||||
|
||||
class MiraBot:
|
||||
def __init__(self):
|
||||
def __init__(self, config_path):
|
||||
# Bot specific settings
|
||||
self._jid = ""
|
||||
self._password = ""
|
||||
self._config = toml.load(config_path)
|
||||
|
||||
self._jid = aioxmpp.JID.fromstr(self._config['jid'])
|
||||
self._password = self._config['password']
|
||||
self._avatar = self._config.get('avatar', None)
|
||||
self._client = None
|
||||
self._avatar = None
|
||||
|
||||
self._modules = {} # Module name -> module
|
||||
self._subscription_manager = SubscriptionManager()
|
||||
self._storage_manager = StorageManager.get_instance(self._config.get('storage_path', './storage.json'))
|
||||
self._subscription_manager = SubscriptionManager.get_instance()
|
||||
|
||||
def load_config(self, path):
|
||||
data = toml.load(path)
|
||||
|
||||
self._jid = aioxmpp.JID.fromstr(data['jid'])
|
||||
self._password = data['password']
|
||||
self._avatar = data.get('avatar', None)
|
||||
|
||||
for module in data['modules']:
|
||||
def _initialise_modules(self):
|
||||
for module in self._config['modules']:
|
||||
logger.debug("Initialising module %s" % (module))
|
||||
mod = importlib.import_module(module['name'])
|
||||
self._modules[mod.NAME] = mod.get_instance(self, module)
|
||||
self._modules[mod.NAME].set_name(mod.NAME)
|
||||
self._modules[mod.NAME] = mod.get_instance(self, config=module, name=mod.NAME)
|
||||
|
||||
async def connect(self):
|
||||
self._client = aioxmpp.PresenceManagedClient(
|
||||
self._jid,
|
||||
aioxmpp.make_security_layer(self._password))
|
||||
async with self._client.connected():
|
||||
logger.info('Client connected')
|
||||
self._client.stream.register_message_callback(
|
||||
aioxmpp.MessageType.CHAT,
|
||||
None,
|
||||
self._on_message)
|
||||
|
||||
if self._avatar:
|
||||
logger.info('Publishing avatar')
|
||||
with open(self._avatar, 'rb') as avatar_file:
|
||||
data = avatar_file.read()
|
||||
avatar_set = aioxmpp.avatar.AvatarSet()
|
||||
@ -70,10 +75,17 @@ class MiraBot:
|
||||
avatar_set.add_avatar_image('image/png', image_bytes=data)
|
||||
avatar = self._client.summon(aioxmpp.avatar.AvatarService)
|
||||
await avatar.publish_avatar_set(avatar_set)
|
||||
logger.info('Avatar published')
|
||||
|
||||
logger.debug('Initialising modules')
|
||||
self._initialise_modules()
|
||||
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
def _is_sender_local(self, from_):
|
||||
return from_.domain == self._jid.domain
|
||||
|
||||
def _on_message(self, message):
|
||||
# Automatically handles sending a message receipt and dealing
|
||||
# with unwanted messages
|
||||
@ -87,9 +99,24 @@ class MiraBot:
|
||||
self._client.enqueue(receipt)
|
||||
|
||||
if not cmd[0] in self._modules:
|
||||
logger.debug('Received command for unknown module. Dropping')
|
||||
self._client.enqueue(message_wrapper(message.from_, "Unbekannter Befehl"))
|
||||
return
|
||||
|
||||
# Just drop messages that are not local when the module should
|
||||
# be local only
|
||||
if self._modules[cmd[0]]._restricted:
|
||||
if self._modules[cmd[0]]._config['restrict_local']:
|
||||
if not self._is_sender_local(message.from_):
|
||||
logger.warning('Received a command from a non-local user to a'
|
||||
' module that is restricted to local users only')
|
||||
return
|
||||
elif self._modules[cmd[0]]._config['allowed_domains']:
|
||||
if not message.from_.domain in self._modules[cmd[0]]._config['allowed_domains']:
|
||||
logger.warning('Received a command from a non-whitelisted user to a'
|
||||
' module that is restricted to whitelisted users only')
|
||||
return
|
||||
|
||||
self._modules[cmd[0]]._on_command(cmd[1:], message)
|
||||
|
||||
# Module Function: Send message
|
||||
@ -99,14 +126,17 @@ class MiraBot:
|
||||
# Module Function: Send a message to @to with @body
|
||||
def send_message_wrapper(self, to, body):
|
||||
self.send_message(message_wrapper(to, body))
|
||||
|
||||
# Module Function
|
||||
def get_subscription_manager(self):
|
||||
return self._subscription_manager
|
||||
|
||||
|
||||
def main():
|
||||
bot = MiraBot()
|
||||
bot.load_config("./config.toml")
|
||||
parser = OptionParser()
|
||||
parser.add_option('-d', '--debug', dest='debug',
|
||||
help='Enable debug logging', action='store_true')
|
||||
(options, args) = parser.parse_args()
|
||||
|
||||
verbosity = logging.DEBUG if options.debug else logging.INFO
|
||||
logging.basicConfig(stream=sys.stdout, level=verbosity)
|
||||
|
||||
bot = MiraBot("./config.toml")
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(bot.connect())
|
||||
|
@ -14,55 +14,70 @@ GNU General Public License for more details.
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
'''
|
||||
import asyncio
|
||||
from functools import partial
|
||||
|
||||
from mira.storage import StorageManager
|
||||
from mira.subscription import SubscriptionManager
|
||||
|
||||
class ManagerWrapper:
|
||||
'''
|
||||
Wrapper class around {Storage, Subscription}Manager in order
|
||||
to to expose those directly to the modules without allowing them
|
||||
access to other modules.
|
||||
'''
|
||||
_name = ''
|
||||
_manager = None
|
||||
|
||||
def __init__(self, name, manager):
|
||||
self._name = name
|
||||
self._manager = manager.get_instance()
|
||||
|
||||
def __getattr__(self, key):
|
||||
if not key in dir(self._manager):
|
||||
raise AttributeError("Attribute %s does not exist in wrapped"
|
||||
" class %s" % (key, type(self._manager)))
|
||||
|
||||
return partial(getattr(self._manager, key), self._name)
|
||||
|
||||
class BaseModule:
|
||||
def __init__(self, base, config={}, subcommand_table={}):
|
||||
self._name = ''
|
||||
def __init__(self, base, config={}, subcommand_table={}, name=''):
|
||||
self._name = name
|
||||
self._base = base
|
||||
self._config = config
|
||||
self._subcommand_table = subcommand_table
|
||||
|
||||
def set_name(self, name):
|
||||
if self._name:
|
||||
raise Exception('Name change of module attempted!')
|
||||
self._local_only = self.get_option('restrict_local', False)
|
||||
self._restricted = ('restrict_local' in self._config or
|
||||
'allowed_domains' in self._config)
|
||||
|
||||
self._name = name
|
||||
self._stm = ManagerWrapper(self._name, StorageManager)
|
||||
self._sum = ManagerWrapper(self._name, SubscriptionManager)
|
||||
|
||||
def get_option(self, key, default=None):
|
||||
'''
|
||||
Like dict.get(), but for the options from the bot configuration
|
||||
file.
|
||||
|
||||
If key does not exist, then default will be returned.
|
||||
'''
|
||||
return self._config.get(key, default)
|
||||
|
||||
# Used for access control
|
||||
# Returns True if @jid is allowed to access the command. False otherwise.
|
||||
# This is configured by either restrict_local or allowed_domains. restrict_local
|
||||
# takes precedence over allowed_domains
|
||||
def is_jid_allowed(self, jid):
|
||||
only_local = self.get_option('restrict_local')
|
||||
if only_local:
|
||||
return only_local
|
||||
|
||||
domains = self.get_option('allowed_domains')
|
||||
if not domains:
|
||||
return True
|
||||
|
||||
return jid.domain in domains
|
||||
def send_message(self, to, body):
|
||||
'''
|
||||
A simple wrapper that sends a message with type='chat' to
|
||||
@to with @body as the body
|
||||
'''
|
||||
self._base.send_message_wrapper(to, body)
|
||||
|
||||
def get_subscriptions_for(self, jid):
|
||||
return self._base.get_subscription_manager.get_subscriptions_for(self._name, jid)
|
||||
|
||||
def add_subscription_for(self, jid, keyword):
|
||||
return self._base.get_subscription_manager.add_subscription_for(self._name, jid, keyword)
|
||||
|
||||
def remove_subscription_for(self, jid, keyword):
|
||||
return self._base.get_subscription_manager.remove_subscription_for(self._name, jid, keyword)
|
||||
|
||||
def is_subscribed_to(self, jid, keyword):
|
||||
return self._base.get_subscription_manager.is_subscribed_to(self._name, jid, keyword)
|
||||
|
||||
def _on_command(self, cmd, msg):
|
||||
def run(func):
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.create_task(func(cmd, msg))
|
||||
|
||||
if not self._subcommand_table:
|
||||
self.on_command(cmd, msg)
|
||||
run(self.on_command)
|
||||
elif cmd and cmd[0] in self._subcommand_table:
|
||||
self._subcommand_table[cmd[0]](cmd[1:], msg)
|
||||
run(self._subcommand_table[cmd[0]])
|
||||
else:
|
||||
if '*' in self._subcommand_table:
|
||||
self._subcommand_table['*'](cmd, msg)
|
||||
run(self._subcommand_table['*'])
|
||||
|
@ -22,22 +22,22 @@ class TestModule(BaseModule):
|
||||
__instance = None
|
||||
|
||||
@staticmethod
|
||||
def get_instance(base, config):
|
||||
def get_instance(base, **kwargs):
|
||||
if TestModule.__instance == None:
|
||||
TestModule(base, config)
|
||||
TestModule(base, **kwargs)
|
||||
|
||||
return TestModule.__instance
|
||||
|
||||
def __init__(self, base, config):
|
||||
def __init__(self, base, **kwargs):
|
||||
if TestModule.__instance != None:
|
||||
raise Exception('Trying to init singleton twice')
|
||||
|
||||
super().__init__(base, config)
|
||||
super().__init__(base, **kwargs)
|
||||
TestModule.__instance = self
|
||||
|
||||
def on_command(self, cmd, msg):
|
||||
async def on_command(self, cmd, msg):
|
||||
greeting = self.get_option('greeting', 'OwO, %%user%%!').replace('%%user%%', str(msg.from_.bare()))
|
||||
self._base.send_message_wrapper(msg.from_, greeting)
|
||||
|
||||
def get_instance(base, config={}):
|
||||
return TestModule.get_instance(base, config)
|
||||
def get_instance(base, **kwargs):
|
||||
return TestModule.get_instance(base, **kwargs)
|
||||
|
75
mira/storage.py
Normal file
75
mira/storage.py
Normal file
@ -0,0 +1,75 @@
|
||||
'''
|
||||
Copyright (C) 2021 Alexander "PapaTutuWawa"
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
'''
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger('mira.storage.StorageManager')
|
||||
|
||||
class StorageManager:
|
||||
__instance = None
|
||||
|
||||
@staticmethod
|
||||
def get_instance(file_location='/etc/mira/storage.json'):
|
||||
if not StorageManager.__instance:
|
||||
StorageManager(file_location=file_location)
|
||||
return StorageManager.__instance
|
||||
|
||||
def __init__(self, file_location):
|
||||
if StorageManager.__instance:
|
||||
raise Exception('Trying to instanciate StorageManger twice')
|
||||
|
||||
self._data = {} # Module -> Section -> Data
|
||||
self._file_location = file_location
|
||||
|
||||
logger.debug('Loading data from %s' % (file_location))
|
||||
if os.path.exists(file_location):
|
||||
with open(file_location, 'r') as f:
|
||||
self._data = json.loads(f.read())
|
||||
|
||||
StorageManager.__instance = self
|
||||
|
||||
def get_data(self, module, section):
|
||||
'''
|
||||
Get the data stored for module @module under the section
|
||||
@section. Returns {} if there is no data stored for @module.
|
||||
'''
|
||||
if not module in self._data:
|
||||
logging.debug('get_data: module unknown in self._data')
|
||||
logging.debug('module: "%s"' % (module))
|
||||
return {}
|
||||
if not section in self._data[module]:
|
||||
logging.debug('get_data: section unknown in self._data[module]')
|
||||
logging.debug('module: "%s", section: "%s"' % (module, section))
|
||||
return {}
|
||||
|
||||
return self._data[module][section]
|
||||
|
||||
def set_data(self, module, section, data):
|
||||
'''
|
||||
Stores the data @data for @module under section @section.
|
||||
Flushes the data to storage afterwards.
|
||||
'''
|
||||
if not module in self._data:
|
||||
self._data[module] = {}
|
||||
self._data[module][section] = data
|
||||
self.__flush()
|
||||
|
||||
def __flush(self):
|
||||
logger.debug('Flushing to storage')
|
||||
with open(self._file_location, 'w') as f:
|
||||
f.write(json.dumps(self._data))
|
@ -1,5 +1,24 @@
|
||||
#from collections import namedtuple
|
||||
# TODO: Allow storing data along with the keyword
|
||||
'''
|
||||
Copyright (C) 2021 Alexander "PapaTutuWawa"
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
'''
|
||||
# TODO: Replace most of these with a query API
|
||||
import os
|
||||
import json
|
||||
|
||||
from mira.storage import StorageManager
|
||||
|
||||
def append_or_insert(dict_, key, value):
|
||||
if key in dict_ or dict_[key]:
|
||||
@ -14,26 +33,150 @@ def remove_or_delete(dict_, key, value):
|
||||
dict_[key].remove(value)
|
||||
|
||||
class SubscriptionManager:
|
||||
'''
|
||||
This class is tasked with providing functions that simplify dealing
|
||||
with subscriptions.
|
||||
|
||||
NOTE: This class must only by instanciated *after* the StorageManager
|
||||
has been instanciated at least once. For modules, this is no
|
||||
issue as they're only created after the manager classes are
|
||||
ready.
|
||||
'''
|
||||
__instance = None
|
||||
|
||||
@staticmethod
|
||||
def get_instance():
|
||||
if not SubscriptionManager.__instance:
|
||||
SubscriptionManager()
|
||||
return SubscriptionManager.__instance
|
||||
|
||||
def __init__(self):
|
||||
self._subscriptions = {} # Module -> JID -> Keywords
|
||||
self._sm = StorageManager.get_instance()
|
||||
# Module -> JID -> Keywords
|
||||
self._subscriptions = self._sm.get_data('_SubscriptionManager', 'subscriptions')
|
||||
|
||||
SubscriptionManager.__instance = self
|
||||
|
||||
def get_subscriptions_for(self, module, jid):
|
||||
if not module in self._subscriptions:
|
||||
return None
|
||||
return []
|
||||
if not jid in self._subscriptions[module]:
|
||||
return None
|
||||
return []
|
||||
|
||||
return self._subscriptions[module][jid]
|
||||
|
||||
def get_subscriptions_for_keyword(self, module, keyword):
|
||||
'''
|
||||
Returns an array of JIDs that are subscribed to the keyword of module
|
||||
'''
|
||||
if not module in self._subscriptions:
|
||||
return []
|
||||
|
||||
tmp = []
|
||||
for jid in self._subscriptions[module]:
|
||||
if not keyword in self._subscriptions[module][jid]:
|
||||
continue
|
||||
|
||||
data = self._subscriptions[module][jid][keyword]['data']
|
||||
tmp.append((jid, data))
|
||||
return tmp
|
||||
|
||||
def get_subscription_keywords(self, module):
|
||||
'''Returns a list of subscribed keywords in module'''
|
||||
if not module in self._subscriptions:
|
||||
return []
|
||||
|
||||
tmp = []
|
||||
for subscription in self._subscriptions[module].values():
|
||||
tmp += list(subscription.keys())
|
||||
return tmp
|
||||
|
||||
def is_subscribed_to(self, module, jid, keyword):
|
||||
'''
|
||||
Returns True if @jid is subscribed to @keyword within the context
|
||||
of @module. False otherwise
|
||||
'''
|
||||
return keyword in self.get_subscriptions_for(module, jid)
|
||||
|
||||
def __flush(self):
|
||||
# TODO
|
||||
pass
|
||||
def is_subscribed_to_data(self, module, jid, keyword, item):
|
||||
'''
|
||||
Returns True if @jid is subscribed to the item @item inside
|
||||
the keyword @keyword within the context of @module
|
||||
'''
|
||||
subscriptions = self.get_subscriptions_for(module, jid)
|
||||
if not subscriptions:
|
||||
return False
|
||||
if not keyword in subscriptions:
|
||||
return False
|
||||
|
||||
def add_subscription_for(self, module, jid, keyword):
|
||||
append_or_insert(self._subscriptions[module], jid, keyword)
|
||||
return item in subscriptions[keyword]
|
||||
|
||||
def add_subscription_for(self, module, jid, keyword, data={}):
|
||||
'''
|
||||
Adds a subscription to @keyword with data @data for @jid within
|
||||
the context of @module.
|
||||
'''
|
||||
if not module in self._subscriptions:
|
||||
self._subscriptions[module] = {}
|
||||
|
||||
append_or_insert(self._subscriptions[module], jid, {
|
||||
'keyword': keyword,
|
||||
'data': data
|
||||
})
|
||||
self.__flush()
|
||||
|
||||
def append_data_for_subscription(self, module, jid, keyword, item):
|
||||
'''
|
||||
Special helper function which appends item to the data field of
|
||||
a subscription to @keyword from @jid within the context of @module.
|
||||
|
||||
If no subscription exists, then one will be created.
|
||||
|
||||
NOTE: This function expects data to be a dict with key 'data', which
|
||||
must be an array, so it will fail if add_subscription_for has
|
||||
been called beforehand with data equal to anything but an dict
|
||||
with a 'data' key containing an array.
|
||||
'''
|
||||
if not module in self._subscriptions:
|
||||
self._subscriptions[module] = {}
|
||||
if not jid in self._subscriptions[module]:
|
||||
self._subscriptions[module][jid] = {}
|
||||
if not keyword in self._subscriptions[module][jid]:
|
||||
self._subscriptions[module][jid][keyword] = {
|
||||
'data': [item]
|
||||
}
|
||||
self.__flush()
|
||||
return
|
||||
|
||||
self._subscriptions[module][jid][keyword]['data'].append(item)
|
||||
self.__flush()
|
||||
|
||||
def remove_subscription_for(self, module, jid, keyword):
|
||||
'''
|
||||
Removes a subscription to @keyword for @jid within the context
|
||||
of @module
|
||||
'''
|
||||
remove_or_delete(self._subscriptions[module], jid, keyword)
|
||||
self.__flush()
|
||||
|
||||
def remove_item_for_subscription(self, module, jid, keyword, item):
|
||||
'''
|
||||
The deletion counterpart of append_data_for_subscription.
|
||||
'''
|
||||
if not module in self._subscriptions:
|
||||
return
|
||||
if not jid in self._subscriptions[module]:
|
||||
return
|
||||
if not keyword in self._subscriptions[module][jid]:
|
||||
return
|
||||
if not item in self._subscriptions[module][jid][keyword]['data']:
|
||||
return
|
||||
|
||||
self._subscriptions[module][jid][keyword]['data'].remove(item)
|
||||
self.__flush()
|
||||
|
||||
def __flush(self):
|
||||
'''
|
||||
Write subscription data to disk. Just an interface to StorageManager
|
||||
'''
|
||||
self._sm.set_data('_StorageManager', 'subscriptions', self._subscriptions)
|
||||
|
Loading…
Reference in New Issue
Block a user