mira/mira/base.py

147 lines
5.5 KiB
Python

'''
Copyright (C) 2021 Alexander "PapaTutuWawa"
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
'''
import sys
import importlib
import asyncio
import logging
from optparse import OptionParser
import aioxmpp
import toml
from mira.subscription import SubscriptionManager
from mira.storage import StorageManager
# Module-wide logger. The name is a fixed string rather than __name__, so it
# stays "mira.base" even when this file is executed directly as a script.
logger = logging.getLogger(name='mira.base')
def message_wrapper(to, body):
    """Create a chat-type ``aioxmpp.Message`` addressed to *to*.

    The text *body* is stored under the ``None`` key of the message's body
    map (presumably the untagged / default-language body in aioxmpp).

    :param to: recipient address, passed straight to ``aioxmpp.Message``.
    :param body: message text.
    :return: the newly built, not yet sent, message.
    """
    chat_message = aioxmpp.Message(to=to, type_=aioxmpp.MessageType.CHAT)
    chat_message.body[None] = body
    return chat_message
class MiraBot:
    """XMPP bot that dispatches incoming chat messages to loaded modules.

    The first space-separated word of a chat message selects a module (by
    the module's ``NAME``); the remaining words are handed to that module's
    ``_on_command`` together with the original message.
    """

    def __init__(self, config_path):
        """Load the TOML configuration at *config_path* and set up managers.

        Reads the required ``jid`` and ``password`` keys and the optional
        ``avatar`` and ``storage_path`` keys. No connection is opened here;
        see :meth:`connect`.
        """
        # Bot specific settings
        self._config = toml.load(config_path)
        self._jid = aioxmpp.JID.fromstr(self._config['jid'])
        self._password = self._config['password']
        self._avatar = self._config.get('avatar', None)
        self._client = None
        self._modules = {}  # Module name -> module instance
        self._storage_manager = StorageManager.get_instance(
            self._config.get('storage_path', '/etc/mira/storage.json'))
        self._subscription_manager = SubscriptionManager.get_instance()

    def _initialise_modules(self):
        """Import and instantiate every entry of the ``modules`` config list.

        Each entry's ``name`` is imported as a Python module. That module's
        ``get_instance`` is called with this bot, the entry itself as
        ``config`` and the module's ``NAME``. The result is registered in
        ``self._modules`` under ``NAME``.
        """
        for module_config in self._config['modules']:
            logger.debug('Initialising module %s', module_config)
            mod = importlib.import_module(module_config['name'])
            self._modules[mod.NAME] = mod.get_instance(
                self, config=module_config, name=mod.NAME)

    async def connect(self):
        """Connect to the XMPP server and serve messages indefinitely.

        Once connected, this registers the chat message handler, publishes
        the configured avatar (if any) and loads the configured modules.
        The coroutine then loops forever and never returns normally.
        """
        self._client = aioxmpp.PresenceManagedClient(
            self._jid,
            aioxmpp.make_security_layer(self._password))
        async with self._client.connected():
            logger.info('Client connected')
            self._client.stream.register_message_callback(
                aioxmpp.MessageType.CHAT,
                None,
                self._on_message)

            if self._avatar:
                logger.info('Publishing avatar')
                with open(self._avatar, 'rb') as avatar_file:
                    data = avatar_file.read()
                avatar_set = aioxmpp.avatar.AvatarSet()
                # TODO: Detect MIME type. The file is always announced as
                # image/png, whatever its real format is.
                avatar_set.add_avatar_image('image/png', image_bytes=data)
                avatar = self._client.summon(aioxmpp.avatar.AvatarService)
                await avatar.publish_avatar_set(avatar_set)
                logger.info('Avatar published')

            logger.debug('Initialising modules')
            self._initialise_modules()

            # Block forever so the connected() context (and thus the
            # session) stays open.
            while True:
                await asyncio.sleep(1)

    def _is_sender_local(self, from_):
        """Return True if *from_* has the same domain as the bot's own JID."""
        return from_.domain == self._jid.domain

    def _on_message(self, message):
        """Handle an incoming message and dispatch it to the matching module.

        Non-chat and body-less messages are ignored. Otherwise a delivery
        receipt is sent when one can be composed, and unknown commands are
        answered with "Unbekannter Befehl" (German: "unknown command").
        Per-module sender restrictions are enforced before calling the
        module's ``_on_command(args, message)``.
        """
        # Automatically handles sending a message receipt and dealing
        # with unwanted messages
        if message.type_ != aioxmpp.MessageType.CHAT or not message.body:
            return

        cmd = str(message.body.any()).split(' ')

        try:
            receipt = aioxmpp.mdr.compose_receipt(message)
        except ValueError:
            # compose_receipt refuses some messages (e.g. ones without an
            # id). Don't let that abort command handling.
            logger.debug('Could not compose a receipt for incoming message')
        else:
            self._client.enqueue(receipt)

        if cmd[0] not in self._modules:
            logger.debug('Received command for unknown module. Dropping')
            self._client.enqueue(message_wrapper(message.from_, "Unbekannter Befehl"))
            return
        module = self._modules[cmd[0]]

        # Just drop messages that are not local when the module should
        # be local only
        if module._restricted:
            # Assumes module._config is the mapping passed as ``config`` in
            # _initialise_modules; missing keys mean "not configured" rather
            # than raising KeyError on every message.
            module_config = module._config
            if module_config.get('restrict_local', False):
                if not self._is_sender_local(message.from_):
                    logger.warning('Received a command from a non-local user to a'
                                   ' module that is restricted to local users only')
                    return
            elif module_config.get('allowed_domains'):
                if message.from_.domain not in module_config['allowed_domains']:
                    logger.warning('Received a command from a non-whitelisted user to a'
                                   ' module that is restricted to whitelisted users only')
                    return
            # NOTE(review): a restricted module with neither option set
            # accepts everyone (fails open) -- confirm this is intended.

        module._on_command(cmd[1:], message)

    # Module Function: Send message
    def send_message(self, message):
        """Queue an already-built stanza *message* on the XMPP client."""
        self._client.enqueue(message)

    # Module Function: Send a message to @to with @body
    def send_message_wrapper(self, to, body):
        """Build a chat message to *to* with text *body* and queue it."""
        self.send_message(message_wrapper(to, body))
def main():
    """Command-line entry point: parse options, configure logging, run the bot.

    Options:
        -d / --debug   enable DEBUG-level logging (INFO otherwise)
        -c / --config  path to the TOML configuration file
                       (default: /etc/mira/config.toml)

    Runs :meth:`MiraBot.connect` on a fresh event loop until it finishes or
    raises. The loop is closed in every case.
    """
    parser = OptionParser()
    parser.add_option('-d', '--debug', dest='debug',
                      help='Enable debug logging', action='store_true')
    # Default to the TOML configuration file, not the JSON storage file
    # (which is a separate setting, ``storage_path``, in the config itself).
    parser.add_option('-c', '--config', dest='config',
                      help='Location of the config.toml',
                      default='/etc/mira/config.toml')
    options, _ = parser.parse_args()

    verbosity = logging.DEBUG if options.debug else logging.INFO
    logging.basicConfig(stream=sys.stdout, level=verbosity)

    logging.info('Loading config from %s', options.config)
    bot = MiraBot(options.config)

    # Create the loop explicitly: asyncio.get_event_loop() without a running
    # loop is deprecated. Register it as current so library code calling
    # get_event_loop() sees the same loop.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(bot.connect())
    finally:
        loop.close()