rio/acme.go

344 lines
8.5 KiB
Go

package main
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/json"
"encoding/pem"
"io/ioutil"
"math/big"
"net/http"
"strings"
"sync"
"time"
"github.com/go-acme/lego/v4/certcrypto"
"github.com/go-acme/lego/v4/certificate"
"github.com/go-acme/lego/v4/lego"
log "github.com/sirupsen/logrus"
)
// AcmeChallengePathPrefix is the URL path prefix under which ACME HTTP-01
// challenge tokens are requested; the remainder of the path is the token.
const AcmeChallengePathPrefix = "/.well-known/acme-challenge/"
var (
	// runningChallenges maps an ACME challenge token (the path segment after
	// AcmeChallengePathPrefix) to the response body served for it.
	runningChallenges = make(map[string]string)

	// Certificates is the in-memory certificate cache consulted during TLS
	// handshakes. It starts without a fallback certificate.
	Certificates = CertificatesCache{
		Certificates: make(map[string]CertificateWrapper),
	}

	// requestingLock must be held while reading or writing requestingDomains.
	requestingLock sync.Mutex

	// requestingDomains records the domains for which a certificate request is
	// currently in progress. Only the presence of a key matters.
	requestingDomains = make(map[string]bool)
)
// lockIfUnlockedDomain tries to mark domain as having a certificate request in
// progress.
//
// Note the inverted sense of the result: it returns true if the domain was
// ALREADY marked, so the caller must not start another request. It returns
// false if this call placed the mark; the caller then owns it and must release
// it with unlockDomain.
func lockIfUnlockedDomain(domain string) bool {
	requestingLock.Lock()
	defer requestingLock.Unlock()

	if _, busy := requestingDomains[domain]; busy {
		return true
	}
	requestingDomains[domain] = true
	return false
}
// unlockDomain clears the in-progress mark placed by lockIfUnlockedDomain.
// Clearing a domain that is not marked is a no-op.
func unlockDomain(domain string) {
	requestingLock.Lock()
	delete(requestingDomains, domain)
	requestingLock.Unlock()
}
// CertificateWrapper is one cached certificate. It holds the persisted PEM data
// together with the parsed certificate served in TLS handshakes.
type CertificateWrapper struct {
	// TlsCertificate is the parsed certificate handed to crypto/tls. It is not
	// serialised, so it has to be rebuilt after loading from JSON.
	TlsCertificate *tls.Certificate `json:"-"`
	// Domain is the name the certificate is cached under.
	Domain string `json:"domain"`
	// NotAfter is the end of the certificate's validity period.
	NotAfter time.Time `json:"not_after"`
	// PrivateKeyEncoded is the PEM-encoded private key, additionally base64-encoded.
	PrivateKeyEncoded string `json:"private_key"`
	// Certificate is the PEM certificate data passed to tls.X509KeyPair.
	Certificate []byte `json:"certificate"`
	// IssuerCertificate is the PEM-encoded issuer certificate.
	IssuerCertificate []byte `json:"issuer_certificate"`
	// CertificateUrl is presumably the CA's URL for the certificate; it is
	// "localhost" for the self-signed fallback.
	CertificateUrl string `json:"certificate_url"`
}
// GetPrivateKey decodes PrivateKeyEncoded (base64 of a PEM-encoded key) and
// returns it as an RSA private key.
//
// It returns nil, rather than panicking, in three cases:
//   - the field is not valid base64;
//   - the decoded data cannot be parsed by certcrypto.ParsePEMPrivateKey;
//   - the key is not an RSA key (for example an ECDSA key).
//
// Callers must check for nil.
func (c *CertificateWrapper) GetPrivateKey() *rsa.PrivateKey {
	data, err := base64.StdEncoding.DecodeString(c.PrivateKeyEncoded)
	if err != nil {
		return nil
	}
	pk, err := certcrypto.ParsePEMPrivateKey(data)
	if err != nil {
		return nil
	}
	// A comma-ok assertion: a bare assertion would panic on a nil or non-RSA key.
	rsaKey, ok := pk.(*rsa.PrivateKey)
	if !ok {
		return nil
	}
	return rsaKey
}
// CertificatesCache is the in-memory certificate state: one fallback
// certificate plus per-domain certificates.
type CertificatesCache struct {
	// FallbackCertificate is served when no domain-specific certificate is
	// used. It is nil in the zero value, as in the initial global Certificates,
	// until a fallback is installed or loaded.
	FallbackCertificate *CertificateWrapper
	// Certificates maps a domain name to its cached certificate.
	Certificates map[string]CertificateWrapper
}
// CertificatesStore is the JSON (on-disk) form of a CertificatesCache. Unlike
// the cache, it stores the fallback by value and the per-domain certificates as
// a list instead of a map.
type CertificatesStore struct {
	FallbackCertificate CertificateWrapper   `json:"fallback"`
	Certificates        []CertificateWrapper `json:"certificates"`
}
func (c *CertificatesCache) toStoreData() string {
certs := make([]CertificateWrapper, 0)
for _, cert := range c.Certificates {
certs = append(certs, cert)
}
result, err := json.Marshal(CertificatesStore{
FallbackCertificate: *c.FallbackCertificate,
Certificates: certs,
})
if err != nil {
log.Errorf("Failed to Marshal cache: %v", err)
}
return string(result)
}
// initTlsCertificate rebuilds TlsCertificate from the stored PEM certificate
// and the base64-wrapped PEM private key. TlsCertificate is not part of the
// JSON form, so this has to run after loading.
//
// Errors are ignored. If decoding or key-pair parsing fails, TlsCertificate
// points at the zero tls.Certificate returned by tls.X509KeyPair; it is never
// set to nil.
func (c *CertificateWrapper) initTlsCertificate() {
	keyPEM, _ := base64.StdEncoding.DecodeString(c.PrivateKeyEncoded)
	pair, _ := tls.X509KeyPair(c.Certificate, keyPEM)
	c.TlsCertificate = &pair
}
// CertificateFromStoreData turns the JSON produced by toStoreData back into a
// CertificatesCache. It rebuilds every certificate's TlsCertificate and indexes
// the per-domain certificates by their Domain field. When two entries share a
// Domain, the later one wins.
//
// The unmarshal error is discarded: malformed input yields whatever was decoded,
// possibly nothing.
func CertificateFromStoreData(rawJson string) CertificatesCache {
	var store CertificatesStore
	_ = json.Unmarshal([]byte(rawJson), &store) // best effort, see above

	fallback := &store.FallbackCertificate
	fallback.initTlsCertificate()

	byDomain := make(map[string]CertificateWrapper)
	for i := range store.Certificates {
		entry := store.Certificates[i]
		entry.initTlsCertificate()
		byDomain[entry.Domain] = entry
	}

	return CertificatesCache{
		FallbackCertificate: fallback,
		Certificates:        byDomain,
	}
}
// LoadCertificateStoreFromFile replaces the global Certificates cache with the
// store read from path. If the file cannot be read, the read error is returned
// and the cache is left unchanged.
func LoadCertificateStoreFromFile(path string) error {
	content, readErr := ioutil.ReadFile(path)
	if readErr == nil {
		Certificates = CertificateFromStoreData(string(content))
	}
	return readErr
}
// FlushCertificateStoreToFile writes the global Certificates cache to path as
// JSON, with file mode 0600 because the data contains private keys.
//
// toStoreData returns "" when marshalling fails. In that case nothing is
// written, so a transient error cannot replace the existing store with an empty
// file and lose every key. Write errors are ignored, because the function has
// no error result to report them through.
//
// NOTE(review): the file is overwritten in place, so a crash mid-write can
// leave a truncated store. Consider writing to a temp file and renaming it.
func FlushCertificateStoreToFile(path string) {
	data := Certificates.toStoreData()
	if data == "" {
		return
	}
	_ = ioutil.WriteFile(path, []byte(data), 0600) // best effort, see above
}
// InitialiseFallbackCert generates a fresh self-signed certificate for
// pagesDomain (see fallbackCert) and installs it as
// Certificates.FallbackCertificate.
//
// On error the previously installed fallback, if any, is kept and the error is
// returned. Installing the nil result of a failed generation would make later
// dereferences of the fallback panic.
func InitialiseFallbackCert(pagesDomain string) error {
	cert, err := fallbackCert(pagesDomain)
	if err != nil {
		return err
	}
	Certificates.FallbackCertificate = cert
	return nil
}
func fallbackCert(pagesDomain string) (*CertificateWrapper, error) {
key, err := certcrypto.GeneratePrivateKey(certcrypto.RSA2048)
if err != nil {
return nil, err
}
notAfter := time.Now().Add(time.Hour * 24 * 7)
cert := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
CommonName: pagesDomain,
Organization: []string{"Pages Server"},
},
NotAfter: notAfter,
NotBefore: time.Now(),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
certBytes, err := x509.CreateCertificate(
rand.Reader,
&cert,
&cert,
&key.(*rsa.PrivateKey).PublicKey,
key,
)
if err != nil {
return nil, err
}
out := &bytes.Buffer{}
err = pem.Encode(out, &pem.Block{
Bytes: certBytes,
Type: "CERTIFICATE",
})
if err != nil {
return nil, err
}
outBytes := out.Bytes()
res := &certificate.Resource{
PrivateKey: certcrypto.PEMEncode(key),
Certificate: outBytes,
IssuerCertificate: outBytes,
Domain: pagesDomain,
}
tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey)
if err != nil {
return nil, err
}
return &CertificateWrapper{
TlsCertificate: &tlsCertificate,
Domain: pagesDomain,
NotAfter: notAfter,
PrivateKeyEncoded: base64.StdEncoding.EncodeToString(certcrypto.PEMEncode(key)),
Certificate: outBytes,
IssuerCertificate: outBytes,
CertificateUrl: "localhost",
}, nil
}
// isCertStillValid reports whether cert has not expired yet, i.e. whether the
// current time is strictly before cert.NotAfter. No renewal margin is applied.
func isCertStillValid(cert CertificateWrapper) bool {
	expiry := cert.NotAfter
	return time.Now().Before(expiry)
}
// makeTlsConfig returns the server-side TLS configuration for the pages server.
//
// The certificate for each handshake is chosen from the SNI server name:
//   - Fallback certificate: no server name, or a name that is neither
//     pagesDomain nor a subdomain of it and whose CNAME (via lookupCNAME)
//     is not inside pagesDomain either.
//   - A cached certificate that is still valid is served as is.
//   - An expired cached certificate: renewal is not implemented yet (TODO),
//     so the fallback is served. The old certificate is served instead
//     while another handshake holds the domain lock.
//   - No cached certificate: a certificate is requested synchronously via
//     ObtainNewCertificate with path and acmeClient. Concurrent handshakes
//     for the same name get the fallback meanwhile.
//
// If no fallback certificate is installed, the callback returns (nil, nil) so
// that crypto/tls fails the handshake instead of the callback panicking.
//
// NOTE(review): Certificates.Certificates is read here, while
// ObtainNewCertificate presumably adds to it from other handshakes. Concurrent
// map access needs a lock shared with that writer; confirm and guard.
func makeTlsConfig(pagesDomain, path string, acmeClient *lego.Client) *tls.Config {
	baseDomain := strings.TrimPrefix(pagesDomain, ".")

	// withinPagesDomain reports whether name is pagesDomain itself or a
	// subdomain of it. Requiring a "." boundary stops names like
	// "evil"+pagesDomain from matching. A trailing root dot, as found in DNS
	// answers, is ignored.
	withinPagesDomain := func(name string) bool {
		name = strings.TrimSuffix(name, ".")
		return name == baseDomain || strings.HasSuffix(name, "."+baseDomain)
	}

	// fallback returns the current fallback certificate, or nil if none is
	// installed. It reads the global on every call because the cache can be
	// replaced at runtime.
	fallback := func() *tls.Certificate {
		if Certificates.FallbackCertificate == nil {
			return nil
		}
		return Certificates.FallbackCertificate.TlsCertificate
	}

	// InsecureSkipVerify is deliberately not set: it only affects TLS clients.
	return &tls.Config{
		GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
			domain := info.ServerName
			if domain == "" {
				// No SNI: there is no name to look up or to request a certificate for.
				return fallback(), nil
			}

			// Validate that we should even care about this domain.
			if !withinPagesDomain(domain) {
				// The error is not checked. lookupCNAME is expected to return ""
				// on error, and "" is never within pagesDomain.
				// NOTE(review): confirm for lookupCNAME.
				cname, _ := lookupCNAME(domain)
				if !withinPagesDomain(cname) {
					log.Warnf("Got ServerName for Domain %s that we're not responsible for", domain)
					return fallback(), nil
				}
			}

			// TODO: serve <user>.<pagesDomain> from a single wildcard
			// certificate ("*." + pagesDomain) instead of one per name.

			if cert, found := Certificates.Certificates[domain]; found {
				if isCertStillValid(cert) {
					return cert.TlsCertificate, nil
				}
				// lockIfUnlockedDomain returns true when another handshake is
				// already working on this domain; keep serving the old
				// certificate meanwhile.
				if lockIfUnlockedDomain(domain) {
					return cert.TlsCertificate, nil
				}
				defer unlockDomain(domain)
				// TODO: Renew
				log.Debugf("Certificate for %s expired, renewing", domain)
				log.Debugf("TLS ServerName: %s", domain)
				return fallback(), nil
			}

			// Don't request if we're already requesting.
			if lockIfUnlockedDomain(domain) {
				return fallback(), nil
			}
			defer unlockDomain(domain)

			log.Debugf("Obtaining new certificate for %s...", domain)
			if err := ObtainNewCertificate([]string{domain}, path, acmeClient); err != nil {
				log.Errorf("Failed to get certificate for %s: %v", domain, err)
				return fallback(), nil
			}

			cert, found := Certificates.Certificates[domain]
			if !found || cert.TlsCertificate == nil {
				// Success was reported but nothing usable is cached. Returning a
				// nil certificate here would fail the handshake outright.
				log.Warnf("No certificate cached for %s after obtaining one", domain)
				return fallback(), nil
			}
			return cert.TlsCertificate, nil
		},
		// NOTE(review): crypto/tls negotiates ALPN in server preference order,
		// so "h2" is only chosen for clients that do not also offer an earlier
		// entry. "h2c" is not a valid protocol over TLS. Confirm this list is
		// intended.
		NextProtos: []string{
			"http/0.9",
			"http/1.0",
			"http/1.1",
			"h2",
			"h2c",
		},
		MinVersion: tls.VersionTLS12,
		CipherSuites: []uint16{
			tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
			tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
			tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
			tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
			tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
			tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
		},
	}
}
// runningChallengesLock guards runningChallenges. The map is accessed from HTTP
// handler goroutines (handleLetsEncryptChallenge) and, presumably, by the ACME
// solver while certificates are being obtained. Unsynchronised concurrent map
// access is a fatal runtime error in Go.
var runningChallengesLock sync.Mutex

// addChallenge registers token as the response to serve for challenge id,
// replacing any previous value.
func addChallenge(id, token string) {
	runningChallengesLock.Lock()
	defer runningChallengesLock.Unlock()
	runningChallenges[id] = token
}

// removeChallenge forgets challenge id. Unknown ids are ignored.
func removeChallenge(id string) {
	runningChallengesLock.Lock()
	defer runningChallengesLock.Unlock()
	delete(runningChallenges, id)
}

// getChallenge returns the response registered for id, or "" if there is none.
func getChallenge(id string) string {
	runningChallengesLock.Lock()
	defer runningChallengesLock.Unlock()
	return runningChallenges[id]
}
// handleLetsEncryptChallenge answers ACME HTTP-01 challenge requests.
//
// If the request path is not under AcmeChallengePathPrefix, it returns false
// without touching w, so the caller can handle the request normally. Otherwise
// it responds and returns true:
//   - 200 with the registered response for a known token;
//   - 404 for an empty or unknown token.
//
// The token is not removed after it is served. The CA may fetch the same token
// several times (Let's Encrypt validates from multiple network perspectives),
// and deleting it on the first fetch makes the later fetches fail. Removal is
// left to whoever registered the token, presumably via removeChallenge during
// ACME cleanup; TODO confirm.
func handleLetsEncryptChallenge(w http.ResponseWriter, req *http.Request) bool {
	if !strings.HasPrefix(req.URL.Path, AcmeChallengePathPrefix) {
		return false
	}
	log.Debug("Handling ACME challenge path")

	id := strings.TrimPrefix(req.URL.Path, AcmeChallengePathPrefix)
	challenge := ""
	if id != "" {
		challenge = getChallenge(id)
	}
	if challenge == "" {
		w.WriteHeader(http.StatusNotFound)
		return true
	}

	w.WriteHeader(http.StatusOK)
	_, _ = w.Write([]byte(challenge)) // the client may have gone away; nothing to do
	return true
}