344 lines
8.3 KiB
Go
344 lines
8.3 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"encoding/pem"
|
|
"io/ioutil"
|
|
"math/big"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/go-acme/lego/v4/certcrypto"
|
|
"github.com/go-acme/lego/v4/certificate"
|
|
"github.com/go-acme/lego/v4/lego"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// AcmeChallengePathPrefix is the URL path prefix under which ACME HTTP-01
// challenge responses are served; the remainder of the path is the token.
const AcmeChallengePathPrefix = "/.well-known/acme-challenge/"
|
|
|
|
var (
|
|
// Well-known -> Challenge solution
|
|
runningChallenges = make(map[string]string)
|
|
Certificates = CertificatesCache{
|
|
Certificates: make(map[string]CertificateWrapper),
|
|
}
|
|
|
|
// To access requestingDomains, first acquire the lock.
|
|
requestingLock = sync.Mutex{}
|
|
|
|
// Domain -> _. Check if domain is a key here to see if we're already requeting
|
|
// a certificate for it.
|
|
requestingDomains = make(map[string]bool)
|
|
)
|
|
|
|
func lockIfUnlockedDomain(domain string) bool {
|
|
requestingLock.Lock()
|
|
defer requestingLock.Unlock()
|
|
|
|
_, found := requestingDomains[domain]
|
|
if !found {
|
|
requestingDomains[domain] = true
|
|
}
|
|
|
|
return found
|
|
}
|
|
|
|
func unlockDomain(domain string) {
|
|
requestingLock.Lock()
|
|
defer requestingLock.Unlock()
|
|
|
|
delete(requestingDomains, domain)
|
|
}
|
|
|
|
type CertificateWrapper struct {
|
|
TlsCertificate *tls.Certificate `json:"-"`
|
|
Domain string `json:"domain"`
|
|
NotAfter time.Time `json:"not_after"`
|
|
PrivateKeyEncoded string `json:"private_key"`
|
|
Certificate []byte `json:"certificate"`
|
|
IssuerCertificate []byte `json:"issuer_certificate"`
|
|
CertificateUrl string `json:"certificate_url"`
|
|
}
|
|
|
|
func (c *CertificateWrapper) GetPrivateKey() *rsa.PrivateKey {
|
|
data, _ := base64.StdEncoding.DecodeString(c.PrivateKeyEncoded)
|
|
pk, _ := certcrypto.ParsePEMPrivateKey(data);
|
|
|
|
return pk.(*rsa.PrivateKey)
|
|
}
|
|
|
|
type CertificatesCache struct {
|
|
FallbackCertificate *CertificateWrapper
|
|
Certificates map[string]CertificateWrapper
|
|
}
|
|
|
|
type CertificatesStore struct {
|
|
FallbackCertificate CertificateWrapper `json:"fallback"`
|
|
Certificates []CertificateWrapper `json:"certificates"`
|
|
}
|
|
|
|
func (c *CertificatesCache) toStoreData() string {
|
|
certs := make([]CertificateWrapper, 0)
|
|
for _, cert := range c.Certificates {
|
|
certs = append(certs, cert)
|
|
}
|
|
|
|
result, err := json.Marshal(CertificatesStore{
|
|
FallbackCertificate: *c.FallbackCertificate,
|
|
Certificates: certs,
|
|
})
|
|
if err != nil {
|
|
log.Errorf("Failed to Marshal cache: %v", err)
|
|
}
|
|
return string(result)
|
|
}
|
|
|
|
func (c *CertificateWrapper) initTlsCertificate() {
|
|
pk, _ := base64.StdEncoding.DecodeString(c.PrivateKeyEncoded)
|
|
tlsCert, _ := tls.X509KeyPair(
|
|
c.Certificate,
|
|
pk,
|
|
)
|
|
c.TlsCertificate = &tlsCert
|
|
}
|
|
|
|
func CertificateFromStoreData(rawJson string) CertificatesCache {
|
|
var store CertificatesStore
|
|
_ = json.Unmarshal([]byte(rawJson), &store)
|
|
|
|
store.FallbackCertificate.initTlsCertificate()
|
|
cache := CertificatesCache{
|
|
FallbackCertificate: &store.FallbackCertificate,
|
|
}
|
|
|
|
certs := make(map[string]CertificateWrapper)
|
|
for _, cert := range store.Certificates {
|
|
cert.initTlsCertificate()
|
|
certs[cert.Domain] = cert
|
|
}
|
|
cache.Certificates = certs
|
|
|
|
return cache
|
|
}
|
|
|
|
func LoadCertificateStoreFromFile(path string) error {
|
|
content, err := ioutil.ReadFile(path)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
Certificates = CertificateFromStoreData(string(content))
|
|
return nil
|
|
}
|
|
|
|
func FlushCertificateStoreToFile(path string) {
|
|
data := Certificates.toStoreData()
|
|
ioutil.WriteFile(path, []byte(data), 0600)
|
|
}
|
|
|
|
func InitialiseFallbackCert(pagesDomain string) error {
|
|
cert, err := fallbackCert(pagesDomain)
|
|
Certificates.FallbackCertificate = cert
|
|
return err
|
|
}
|
|
|
|
func fallbackCert(pagesDomain string) (*CertificateWrapper, error) {
|
|
key, err := certcrypto.GeneratePrivateKey(certcrypto.RSA2048)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
notAfter := time.Now().Add(time.Hour*24*7);
|
|
cert := x509.Certificate{
|
|
SerialNumber: big.NewInt(1),
|
|
Subject: pkix.Name{
|
|
CommonName: pagesDomain,
|
|
Organization: []string{"Pages Server"},
|
|
},
|
|
NotAfter: notAfter,
|
|
NotBefore: time.Now(),
|
|
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
|
BasicConstraintsValid: true,
|
|
}
|
|
certBytes, err := x509.CreateCertificate(
|
|
rand.Reader,
|
|
&cert,
|
|
&cert,
|
|
&key.(*rsa.PrivateKey).PublicKey,
|
|
key,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
out := &bytes.Buffer{}
|
|
err = pem.Encode(out, &pem.Block{
|
|
Bytes: certBytes,
|
|
Type: "CERTIFICATE",
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
outBytes := out.Bytes()
|
|
res := &certificate.Resource{
|
|
PrivateKey: certcrypto.PEMEncode(key),
|
|
Certificate: outBytes,
|
|
IssuerCertificate: outBytes,
|
|
Domain: pagesDomain,
|
|
}
|
|
tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &CertificateWrapper{
|
|
TlsCertificate: &tlsCertificate,
|
|
Domain: pagesDomain,
|
|
NotAfter: notAfter,
|
|
PrivateKeyEncoded: base64.StdEncoding.EncodeToString(certcrypto.PEMEncode(key)),
|
|
Certificate: outBytes,
|
|
IssuerCertificate: outBytes,
|
|
CertificateUrl: "localhost",
|
|
}, nil
|
|
}
|
|
|
|
func isCertStillValid(cert CertificateWrapper) bool {
|
|
return time.Now().Compare(cert.NotAfter) <= -1
|
|
}
|
|
|
|
func makeTlsConfig(pagesDomain, path string, acmeClient *lego.Client) *tls.Config {
|
|
return &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
GetCertificate: func (info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
// Validate that we should even care about this domain
|
|
if !strings.HasSuffix(info.ServerName, pagesDomain) {
|
|
// Note: We do not check err here because err != nil
|
|
// always implies that cname == "", which does not have
|
|
// pagesDomain as a suffix.
|
|
cname, _ := lookupCNAME(info.ServerName)
|
|
if !strings.HasSuffix(cname, pagesDomain) {
|
|
log.Warnf("Got ServerName for Domain %s that we're not responsible for", info.ServerName)
|
|
return Certificates.FallbackCertificate.TlsCertificate, nil
|
|
}
|
|
}
|
|
|
|
// If we want to access <user>.<pages domain>, then we can just
|
|
// use a wildcard certificate.
|
|
domain := info.ServerName
|
|
/*if strings.HasSuffix(info.ServerName, pagesDomain) {
|
|
domain = "*." + pagesDomain
|
|
}*/
|
|
|
|
cert, found := Certificates.Certificates[info.ServerName]
|
|
if found {
|
|
if isCertStillValid(cert) {
|
|
return cert.TlsCertificate, nil
|
|
} else {
|
|
// If we're already working on the domain,
|
|
// return the old certificate
|
|
if lockIfUnlockedDomain(domain) {
|
|
return cert.TlsCertificate, nil
|
|
}
|
|
defer unlockDomain(domain)
|
|
|
|
// TODO: Renew
|
|
log.Debugf("Certificate for %s expired, renewing", domain)
|
|
}
|
|
} else {
|
|
// Don't request if we're already requesting.
|
|
if lockIfUnlockedDomain(domain) {
|
|
return Certificates.FallbackCertificate.TlsCertificate, nil
|
|
}
|
|
defer unlockDomain(domain)
|
|
|
|
// Request new certificate
|
|
log.Debugf("Obtaining new certificate for %s...", domain)
|
|
err := ObtainNewCertificate(
|
|
[]string{domain},
|
|
path,
|
|
acmeClient,
|
|
)
|
|
if err != nil {
|
|
log.Errorf(
|
|
"Failed to get certificate for %s: %v",
|
|
domain,
|
|
err,
|
|
)
|
|
return Certificates.FallbackCertificate.TlsCertificate, nil
|
|
}
|
|
|
|
cert, _ = Certificates.Certificates[domain]
|
|
return cert.TlsCertificate, nil
|
|
}
|
|
|
|
log.Debugf("TLS ServerName: %s", info.ServerName)
|
|
return Certificates.FallbackCertificate.TlsCertificate, nil
|
|
},
|
|
NextProtos: []string{
|
|
"http/0.9",
|
|
"http/1.0",
|
|
"http/1.1",
|
|
"h2",
|
|
"h2c",
|
|
},
|
|
MinVersion: tls.VersionTLS12,
|
|
CipherSuites: []uint16{
|
|
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
|
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
|
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
|
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
|
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
|
|
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
|
},
|
|
}
|
|
}
|
|
|
|
func addChallenge(id, token string) {
|
|
runningChallenges[id] = token
|
|
}
|
|
|
|
func removeChallenge(id string) {
|
|
delete(runningChallenges, id)
|
|
}
|
|
|
|
func getChallenge(id string) string {
|
|
if value, found := runningChallenges[id]; found {
|
|
return value
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
func handleLetsEncryptChallenge(w http.ResponseWriter, req *http.Request) bool {
|
|
if !strings.HasPrefix(req.URL.Path, AcmeChallengePathPrefix) {
|
|
return false
|
|
}
|
|
|
|
log.Debug("Handling ACME challenge path")
|
|
id := strings.TrimPrefix(req.URL.Path, AcmeChallengePathPrefix)
|
|
challenge := getChallenge(id)
|
|
if id == "" {
|
|
w.WriteHeader(404)
|
|
return true
|
|
}
|
|
|
|
w.WriteHeader(200)
|
|
w.Write([]byte(challenge))
|
|
|
|
removeChallenge(id)
|
|
return true
|
|
}
|