86 lines
2.1 KiB
Python
86 lines
2.1 KiB
Python
import uuid
|
|
import os
|
|
import urllib.parse
|
|
|
|
from fastapi import FastAPI, HTTPException
|
|
from pydantic import BaseModel
|
|
import requests
|
|
|
|
|
|
class SessionRequest(BaseModel):
    """
    Response of ``GET /.well-known/xmpp/oidc``: a freshly opened login session.

    ``auth_url`` is meant to be shown to the user. ``session_id`` is the value
    to pass as ``sid`` to ``GET /.well-known/xmpp/token`` afterwards.
    """

    # Internal session ID; it is also sent to the OIDC provider as the OAuth
    # ``state`` parameter and is the key used in the server's session map.
    session_id: str

    # URL to show the user to perform authentication: the provider's
    # authorization endpoint with the request encoded in the query string.
    auth_url: str
|
|
|
|
class TokenResponse(BaseModel):
    """Response of ``GET /.well-known/xmpp/token``: the token obtained for a session."""

    # Access token taken from the ``access_token`` field of the OIDC
    # provider's token-endpoint response.
    token: str
|
|
|
|
# Service configuration, read from the environment at import time. A missing
# variable raises KeyError, so the service refuses to start half-configured.

# Base URL of the OIDC provider; "/authorize/" and "/token/" are appended to it.
# A trailing slash is stripped so those joins never produce a "//" path.
OIDC_BASE_URI: str = os.environ["OIDC_BASE_URI"].rstrip("/")

# Redirect URI sent to the provider in both the authorization and the token
# request. NOTE(review): presumably this points at the
# /.well-known/xmpp/callback endpoint — confirm with the deployment config.
REDIRECT_URI: str = os.environ["REDIRECT_URI"]

# OAuth client credentials. The id is sent in both requests; the secret is
# sent only to the token endpoint.
CLIENT_ID: str = os.environ["CLIENT_ID"]

CLIENT_SECRET: str = os.environ["CLIENT_SECRET"]

app = FastAPI()

# Login sessions keyed by session id (the OAuth ``state`` value).
# - Value is None while the user has not finished authenticating.
# - It becomes the provider's access token once the callback runs.
# - An entry is removed only when its token is collected.
# NOTE(review): this is plain in-process memory. It is not shared between
# worker processes and is lost on restart; confirm a single-worker deployment.
sessions: dict[str, str | None] = {}
|
|
|
|
@app.get("/.well-known/xmpp/oidc")
def request_session(provider: str) -> SessionRequest:
    """
    Open a new login session and return the URL where the user authenticates.

    A random session id is registered in ``sessions`` with no token yet. It is
    sent to the OIDC provider as the ``state`` of an authorization-code
    request, so the callback can file the resulting token under that id.

    :param provider: Requested identity provider. Currently ignored: the
        provider configured through ``OIDC_BASE_URI`` is always used.
    :returns: The new session id together with the authorization URL.
    """
    # TODO: Actually use provider

    session_id = uuid.uuid4().hex
    sessions[session_id] = None  # pending until the callback stores a token

    query = urllib.parse.urlencode(
        {
            "response_type": "code",
            "scope": "openid",
            "client_id": CLIENT_ID,
            "state": session_id,
            "redirect_uri": REDIRECT_URI,
        }
    )
    authorize_endpoint = f"{OIDC_BASE_URI}/authorize/"

    return SessionRequest(
        session_id=session_id,
        auth_url=f"{authorize_endpoint}?{query}",
    )
|
|
|
|
@app.get("/.well-known/xmpp/token")
def request_token(sid: str) -> TokenResponse:
    """
    Hand out the access token obtained for session ``sid``, exactly once.

    :param sid: Session id previously returned by ``/.well-known/xmpp/oidc``.
    :returns: The stored token. The session entry is removed, so a second
        request for the same ``sid`` gets 404.
    :raises HTTPException: 404 if the session is unknown or still waiting for
        the provider. A pending session is left in place so the client can
        ask again later.
    """
    token = sessions.get(sid)
    if token is None:
        raise HTTPException(404)

    del sessions[sid]
    return TokenResponse(token=token)
|
|
|
|
@app.get("/.well-known/xmpp/callback")
def oidc_callback(code: str, state: str) -> None:
    """
    Callback for the OIDC redirect after the user has authenticated.

    Exchanges the authorization ``code`` for an access token at the provider's
    token endpoint. The token is stored under the session named by ``state``,
    where ``/.well-known/xmpp/token`` can later collect it.

    :param code: Authorization code issued by the provider.
    :param state: Session id that ``/.well-known/xmpp/oidc`` sent as ``state``.
    :raises HTTPException: 400 if ``state`` is not a session that is still
        waiting for a token. 500 if the token request fails, times out, or
        returns a response without an ``access_token``.
    """
    # Accept the redirect only for a session we issued that has no token yet.
    # Without this check anyone could call this endpoint with an arbitrary
    # state. That would grow ``sessions`` without bound or overwrite another
    # session's token, which is the CSRF problem the OAuth ``state`` parameter
    # exists to prevent.
    if state not in sessions or sessions[state] is not None:
        raise HTTPException(400)

    try:
        token_req = requests.post(
            f"{OIDC_BASE_URI}/token/",
            data={
                "grant_type": "authorization_code",
                "code": code,
                "redirect_uri": REDIRECT_URI,
                "client_id": CLIENT_ID,
                "client_secret": CLIENT_SECRET,
            },
            # requests waits indefinitely by default; don't let an
            # unresponsive provider pin this worker forever.
            timeout=10,
        )
    except requests.RequestException as exc:
        print(f"Token request failed: {exc!r}")
        raise HTTPException(500) from exc

    if not token_req.ok:
        print(f"Status Code: {token_req.status_code}")
        print(f"Body: {token_req.text}")
        raise HTTPException(500)

    try:
        access_token = token_req.json()["access_token"]
    except (ValueError, KeyError, TypeError) as exc:
        # ValueError: body is not JSON. KeyError/TypeError: JSON without an
        # "access_token" member. The body itself is not printed because it
        # may carry other credentials.
        print(f"Malformed token response: {exc!r}")
        raise HTTPException(500) from exc

    sessions[state] = access_token
|