Skip to content
This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

Implement token-authenticated registration (MSC3231) #3391

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions clientapi/auth/authtypes/logintypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ const (
LoginTypeRecaptcha = "m.login.recaptcha"
LoginTypeApplicationService = "m.login.application_service"
LoginTypeToken = "m.login.token"
LoginTypeRegistrationToken = "m.login.registration_token"
)
68 changes: 65 additions & 3 deletions clientapi/routing/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,9 @@ type authDict struct {

// Recaptcha
Response string `json:"response"`

// Registration token
Token string `json:"token"`
// TODO: Lots of custom keys depending on the type
}

Expand Down Expand Up @@ -272,9 +275,12 @@ type recaptchaResponse struct {
}

var (
ErrInvalidCaptcha = errors.New("invalid captcha response")
ErrMissingResponse = errors.New("captcha response is required")
ErrCaptchaDisabled = errors.New("captcha registration is disabled")
ErrInvalidCaptcha = errors.New("invalid captcha response")
ErrMissingResponse = errors.New("captcha response is required")
ErrCaptchaDisabled = errors.New("captcha registration is disabled")
ErrRegistrationTokenDisabled = errors.New("token registration is disabled")
ErrMissingToken = errors.New("registration token is required")
ErrInvalidToken = errors.New("invalid registration token")
)

// validateRecaptcha returns an error response if the captcha response is invalid
Expand Down Expand Up @@ -326,6 +332,43 @@ func validateRecaptcha(
return nil
}

// authenticateToken returns an error response if the token is invalid
func authenticateToken(
req *http.Request,
userAPI userapi.ClientUserAPI,
cfg *config.ClientAPI,
token string,
) error {
if !cfg.RegistrationRequiresToken {
return ErrRegistrationTokenDisabled
}

if token == "" {
return ErrMissingToken
}

registrationToken, err := userAPI.ValidateRegistrationToken(req.Context(), token)

if err != nil {
return err
}

if registrationToken == nil {
return ErrInvalidToken
}

// Decrease available uses
newAttributes := make(map[string]interface{})
newAttributes["usesAllowed"] = *registrationToken.UsesAllowed - 1
_, updateErr := userAPI.PerformAdminUpdateRegistrationToken(req.Context(), token, newAttributes)

if updateErr != nil {
return updateErr
}

return nil
}

// UserIDIsWithinApplicationServiceNamespace checks to see if a given userID
// falls within any of the namespaces of a given Application Service. If no
// Application Service is given, it will check to see if it matches any
Expand Down Expand Up @@ -733,6 +776,25 @@ func handleRegistrationFlow(
// Add Recaptcha to the list of completed registration stages
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)

case authtypes.LoginTypeRegistrationToken:
// Check given token response
err := authenticateToken(req, userAPI, cfg, r.Auth.Token)
switch err {
case ErrRegistrationTokenDisabled:
return util.JSONResponse{Code: http.StatusForbidden, JSON: spec.Unknown(err.Error())}
case ErrMissingToken:
return util.JSONResponse{Code: http.StatusBadRequest, JSON: spec.BadJSON(err.Error())}
case ErrInvalidToken:
return util.JSONResponse{Code: http.StatusUnauthorized, JSON: spec.BadJSON(err.Error())}
case nil:
default:
util.GetLogger(req.Context()).WithError(err).Error("failed to validate token")
return util.JSONResponse{Code: http.StatusInternalServerError, JSON: spec.InternalServerError{}}
}

// Add RegistrationToken to the list of completed registration stages
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRegistrationToken)

case authtypes.LoginTypeDummy:
// there is nothing to do
// Add Dummy to the list of completed registration stages
Expand Down
4 changes: 4 additions & 0 deletions setup/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,10 @@ func (config *Dendrite) Derive() error {
config.Derived.Registration.Flows = []authtypes.Flow{
{Stages: []authtypes.LoginType{authtypes.LoginTypeRecaptcha}},
}
} else if config.ClientAPI.RegistrationRequiresToken {
config.Derived.Registration.Flows = []authtypes.Flow{
{Stages: []authtypes.LoginType{authtypes.LoginTypeRegistrationToken}},
}
} else {
config.Derived.Registration.Flows = []authtypes.Flow{
{Stages: []authtypes.LoginType{authtypes.LoginTypeDummy}},
Expand Down
1 change: 1 addition & 0 deletions userapi/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ type ClientUserAPI interface {
QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error
PerformForgetThreePID(ctx context.Context, req *PerformForgetThreePIDRequest, res *struct{}) error
PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error
ValidateRegistrationToken(ctx context.Context, registrationToken string) (*clientapi.RegistrationToken, error)
}

type KeyBackupAPI interface {
Expand Down
10 changes: 10 additions & 0 deletions userapi/internal/user_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -978,3 +978,13 @@ func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, re
}

const pushRulesAccountDataType = "m.push_rules"

func (a *UserInternalAPI) ValidateRegistrationToken(ctx context.Context, token string) (*clientapi.RegistrationToken, error) {
registrationToken, _ := a.DB.GetRegistrationToken(ctx, token)

if registrationToken == nil || *registrationToken.UsesAllowed == 0 || *registrationToken.ExpiryTime < int64(spec.AsTimestamp(time.Now())) {
return nil, nil
}

return registrationToken, nil
}