package controllers import ( "bytes" "encoding/base32" "encoding/base64" "fmt" "image/png" "git.readonly.ch/bouzoure/popvaud-people/helpers" "git.readonly.ch/bouzoure/popvaud-people/models" "github.com/gofiber/fiber/v2" "github.com/pquerna/otp/totp" ) func TotpEnrollPage(c *fiber.Ctx) error { db, err := helpers.GetDatabase() if err != nil { return err } userid, err := helpers.GetSessionUserId(c) if err != nil { return err } var user models.User result := db.First(&user, "id = ?", userid) if result.Error != nil { return result.Error } if user.TotpSecret.Valid { return fiber.NewError(fiber.StatusForbidden, "Forbidden") } sess, err := helpers.GetSessionStore(c) if err != nil { return err } options := totp.GenerateOpts{ Issuer: "POP Vaud", AccountName: user.Email, } existingSecret := sess.Get("totp-enroll-secret") if existingSecret != nil { var b32NoPadding = base32.StdEncoding.WithPadding(base32.NoPadding) options.Secret, err = b32NoPadding.DecodeString(existingSecret.(string)) if err != nil { return err } } key, err := totp.Generate(options) if err != nil { return err } img, err := key.Image(200, 200) if err != nil { return err } var buf bytes.Buffer err = png.Encode(&buf, img) if err != nil { return err } imgBase64 := fmt.Sprintf( "data:image/png;base64,%s", base64.StdEncoding.EncodeToString(buf.Bytes()), ) var mfaError string if c.Method() == "POST" { otp := c.FormValue("otp") if totp.Validate(otp, key.Secret()) { err = user.TotpSecret.Scan(key.Secret()) if err != nil { return err } result = db.Save(&user) if result.Error != nil { return err } sess.Set("totp-verified", "yes") redirectId := c.Query("redirect") redirectUrl := "/" if len(redirectId) > 0 { redirectKey := fmt.Sprintf("redirect-%s", redirectId) redirectVal := sess.Get(redirectKey) if redirectVal != nil { redirectUrl = redirectVal.(string) } } err = sess.Save() if err != nil { return err } return c.Redirect(redirectUrl) } else { mfaError = "Code temporaire invalide" } } sess.Set("totp-enroll-secret", key.Secret()) err = sess.Save() if err != nil { return err } return c.Render("totp_enroll", fiber.Map{ "PageTitle": "Enregistrement multifacteur (TOTP)", "QrCode": imgBase64, "Secret": key.Secret(), "MfaError": mfaError, }) } func TotpVerifyPage(c *fiber.Ctx) error { db, err := helpers.GetDatabase() if err != nil { return err } sess, err := helpers.GetSessionStore(c) if err != nil { return err } totpVerified := sess.Get("totp-verified") if totpVerified != nil { return fiber.NewError(fiber.StatusForbidden, "Forbidden") } userid, err := helpers.GetSessionUserId(c) if err != nil { return err } var user models.User result := db.First(&user, "id = ?", userid) if result.Error != nil { return result.Error } if !user.TotpSecret.Valid { return fiber.NewError(fiber.StatusForbidden, "Forbidden") } var mfaError string if c.Method() == "POST" { otp := c.FormValue("otp") if totp.Validate(otp, user.TotpSecret.String) { redirectId := c.Query("redirect") redirectUrl := "/" if len(redirectId) > 0 { redirectKey := fmt.Sprintf("redirect-%s", redirectId) redirectVal := sess.Get(redirectKey) if redirectVal != nil { redirectUrl = redirectVal.(string) } } sess.Set("totp-verified", "yes") sess.Save() return c.Redirect(redirectUrl) } else { mfaError = "Code temporaire invalide" } } return c.Render("totp_verify", fiber.Map{ "PageTitle": "Vérification multifacteur (TOTP)", "MfaError": mfaError, }) }