147 lines
5.5 KiB
Python
147 lines
5.5 KiB
Python
'''
Copyright (C) 2021 Alexander "PapaTutuWawa"

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
'''
|
|
import sys
|
|
import importlib
|
|
import asyncio
|
|
import logging
|
|
from optparse import OptionParser
|
|
|
|
import aioxmpp
|
|
import toml
|
|
|
|
from mira.subscription import SubscriptionManager
|
|
from mira.storage import StorageManager
|
|
|
|
# Module-level logger, registered under the name 'mira.base'
logger = logging.getLogger('mira.base')
|
|
|
|
def message_wrapper(to, body):
    '''
    Build a chat-type aioxmpp message addressed to @to.

    @body is stored as the message body under the language key None.
    Returns the newly created aioxmpp.Message; nothing is sent here.
    '''
    chat_message = aioxmpp.Message(
        to=to,
        type_=aioxmpp.MessageType.CHAT,
    )
    chat_message.body[None] = body
    return chat_message
|
|
|
|
class MiraBot:
    '''
    XMPP bot that loads command modules and dispatches chat messages to them.

    The first space-separated word of an incoming chat message selects the
    module (looked up by the module's NAME); the remaining words are passed
    to that module's _on_command() together with the original message.
    '''

    def __init__(self, config_path):
        '''
        Load the TOML configuration at @config_path and set up the managers.

        Reads the keys 'jid' and 'password' (required) as well as 'avatar'
        and 'storage_path' (optional). The XMPP client itself is only
        created by connect().
        '''
        # Bot specific settings
        self._config = toml.load(config_path)

        self._jid = aioxmpp.JID.fromstr(self._config['jid'])
        self._password = self._config['password']
        self._avatar = self._config.get('avatar', None)
        self._client = None

        self._modules = {}  # Module name -> module
        self._storage_manager = StorageManager.get_instance(
            self._config.get('storage_path', '/etc/mira/storage.json'))
        self._subscription_manager = SubscriptionManager.get_instance()

    def _initialise_modules(self):
        '''
        Import every entry of the 'modules' config list and instantiate it.

        Each entry's 'name' is imported as a Python module; the instance
        returned by its get_instance() is stored under the module's NAME.
        '''
        for module in self._config['modules']:
            logger.debug('Initialising module %s', module)
            mod = importlib.import_module(module['name'])
            self._modules[mod.NAME] = mod.get_instance(
                self, config=module, name=mod.NAME)

    async def connect(self):
        '''
        Connect to the XMPP server and keep the bot running.

        Registers the chat message callback, publishes the avatar if one is
        configured, initialises the modules and then sleeps in an endless
        loop, so this coroutine does not return on its own.
        '''
        self._client = aioxmpp.PresenceManagedClient(
            self._jid,
            aioxmpp.make_security_layer(self._password))
        async with self._client.connected():
            logger.info('Client connected')
            self._client.stream.register_message_callback(
                aioxmpp.MessageType.CHAT,
                None,
                self._on_message)

            if self._avatar:
                logger.info('Publishing avatar')
                with open(self._avatar, 'rb') as avatar_file:
                    data = avatar_file.read()
                    avatar_set = aioxmpp.avatar.AvatarSet()
                    # TODO: Detect MIME type
                    avatar_set.add_avatar_image('image/png', image_bytes=data)
                    avatar = self._client.summon(aioxmpp.avatar.AvatarService)
                    await avatar.publish_avatar_set(avatar_set)
                    logger.info('Avatar published')

            logger.debug('Initialising modules')
            self._initialise_modules()

            # Stay inside the "connected" context forever; leaving it would
            # end the async with block.
            while True:
                await asyncio.sleep(1)

    def _is_sender_local(self, from_):
        '''Return True if @from_ has the same domain as the bot's own JID.'''
        return from_.domain == self._jid.domain

    def _access_denial_reason(self, module, from_):
        '''
        Decide whether @from_ may send commands to @module.

        Returns None when the sender is allowed, otherwise the warning text
        describing why the command is dropped.

        The keys 'restrict_local' and 'allowed_domains' are read with .get()
        so that a restricted module which configures only one of them does
        not raise KeyError. This assumes module._config is the dict that
        was passed as config= on instantiation -- TODO confirm against the
        module base class.
        '''
        if not module._restricted:
            return None

        config = module._config
        if config.get('restrict_local'):
            if self._is_sender_local(from_):
                return None
            return ('Received a command from a non-local user to a'
                    ' module that is restricted to local users only')

        allowed_domains = config.get('allowed_domains')
        if allowed_domains and from_.domain not in allowed_domains:
            return ('Received a command from a non-whitelisted user to a'
                    ' module that is restricted to whitelisted users only')
        return None

    def _on_message(self, message):
        '''
        Callback for incoming messages: acknowledge and dispatch a command.

        Non-chat messages and messages without a body are ignored. For all
        others a receipt is enqueued, then the first word of the body is
        used to look up the target module. Unknown commands are answered
        with an error message; commands from senders that the module's
        restrictions do not allow are dropped with a warning.
        '''
        # Automatically handles sending a message receipt and dealing
        # with unwanted messages
        if (message.type_ != aioxmpp.MessageType.CHAT or
                not message.body):
            return

        cmd = str(message.body.any()).split(' ')

        receipt = aioxmpp.mdr.compose_receipt(message)
        self._client.enqueue(receipt)

        name = cmd[0]
        if name not in self._modules:
            logger.debug('Received command for unknown module. Dropping')
            self._client.enqueue(message_wrapper(message.from_, "Unbekannter Befehl"))
            return

        module = self._modules[name]

        # Just drop messages that are not local (or not whitelisted) when
        # the module is restricted accordingly
        reason = self._access_denial_reason(module, message.from_)
        if reason is not None:
            logger.warning(reason)
            return

        module._on_command(cmd[1:], message)

    # Module Function: Send message
    def send_message(self, message):
        '''Enqueue the already constructed @message on the XMPP client.'''
        self._client.enqueue(message)

    # Module Function: Send a message to @to with @body
    def send_message_wrapper(self, to, body):
        '''Build a chat message to @to containing @body and send it.'''
        self.send_message(message_wrapper(to, body))
|
|
|
|
def main():
    '''
    Command line entry point: parse options, set up logging, run the bot.

    Options:
      -d, --debug   enable debug logging (default level is INFO)
      -c, --config  location of the TOML configuration file

    Runs MiraBot.connect() on the event loop until it finishes or raises;
    the loop is closed in either case.
    '''
    parser = OptionParser()
    parser.add_option('-d', '--debug', dest='debug',
                      help='Enable debug logging', action='store_true')
    # The default must point at the TOML config. '/etc/mira/storage.json'
    # is the default location of the storage file (see 'storage_path'),
    # not of the configuration.
    parser.add_option('-c', '--config', dest='config',
                      help='Location of the config.toml',
                      default='/etc/mira/config.toml')
    options, _ = parser.parse_args()

    verbosity = logging.DEBUG if options.debug else logging.INFO
    logging.basicConfig(stream=sys.stdout, level=verbosity)

    logging.info('Loading config from %s', options.config)
    bot = MiraBot(options.config)

    loop = asyncio.get_event_loop()
    try:
        loop.run_until_complete(bot.connect())
    finally:
        # Close the loop even when connect() raises or is interrupted.
        loop.close()
|