diff --git a/internal/jimm/model.go b/internal/jimm/model.go index 0b60539b3..2027bbbea 100644 --- a/internal/jimm/model.go +++ b/internal/jimm/model.go @@ -4,6 +4,7 @@ package jimm import ( "context" + "database/sql" "fmt" "math/rand" "sort" @@ -686,6 +687,20 @@ func (j *JIMM) AddModel(ctx context.Context, user *openfga.User, args *ModelCrea return mi, nil } +func (j *JIMM) GetModel(ctx context.Context, uuid string) (dbmodel.Model, error) { + model := dbmodel.Model{ + UUID: sql.NullString{ + String: uuid, + Valid: uuid != "", + }, + } + if err := j.Database.GetModel(context.Background(), &model); err != nil { + zapctx.Error(ctx, "failed to find model", zap.String("uuid", uuid), zap.Error(err)) + return dbmodel.Model{}, fmt.Errorf("failed to get model: %s", err.Error()) + } + return model, nil +} + // ModelInfo returns the model info for the model with the given ModelTag. // The returned ModelInfo will be appropriate for the given user's // access-level on the model. If the model does not exist then the returned diff --git a/internal/jimm/model_test.go b/internal/jimm/model_test.go index f7e256803..76d0515e4 100644 --- a/internal/jimm/model_test.go +++ b/internal/jimm/model_test.go @@ -1182,6 +1182,72 @@ func assertConfig(config map[string]interface{}, fnc func(context.Context, *juju } +const getModelTestEnv = `clouds: +- name: test-cloud + type: test-provider + regions: + - name: test-cloud-region +cloud-credentials: +- owner: alice@canonical.com + name: cred-1 + cloud: test-cloud +controllers: +- name: controller-1 + uuid: 00000001-0000-0000-0000-000000000001 + cloud: test-cloud + region: test-cloud-region +models: +- name: model-1 + type: iaas + uuid: 00000002-0000-0000-0000-000000000001 + controller: controller-1 + default-series: warty + cloud: test-cloud + region: test-cloud-region + cloud-credential: cred-1 + owner: alice@canonical.com + life: alive + status: + status: available + info: "OK!" + since: 2020-02-20T20:02:20Z + sla: + level: unsupported + agent-version: 1.2.3 +` + +func TestGetModel(t *testing.T) { + ctx := context.Background() + c := qt.New(t) + + client, _, _, err := jimmtest.SetupTestOFGAClient(c.Name(), t.Name()) + c.Assert(err, qt.IsNil) + + j := &jimm.JIMM{ + UUID: uuid.NewString(), + OpenFGAClient: client, + Database: db.Database{ + DB: jimmtest.PostgresDB(c, nil), + }, + } + err = j.Database.Migrate(ctx, false) + c.Assert(err, qt.IsNil) + + env := jimmtest.ParseEnvironment(c, getModelTestEnv) + env.PopulateDBAndPermissions(c, j.ResourceTag(), j.Database, client) + + // Get model + model, err := j.GetModel(ctx, env.Models[0].UUID) + c.Assert(err, qt.IsNil) + c.Assert(model.UUID.String, qt.Equals, env.Models[0].UUID) + c.Assert(model.Name, qt.Equals, env.Models[0].Name) + c.Assert(model.ControllerID, qt.Equals, env.Models[0].DBObject(c, j.Database).ControllerID) + + // Get model that doesn't exist + _, err = j.GetModel(ctx, "fake-uuid") + c.Assert(err, qt.ErrorMatches, "failed to get model: model not found") +} + // Note that this env does not give the everyone user access to the model. const modelInfoTestEnv = `clouds: - name: test-cloud diff --git a/internal/jujuapi/streamproxy.go b/internal/jujuapi/streamproxy.go index c54788a38..2ef9f270a 100644 --- a/internal/jujuapi/streamproxy.go +++ b/internal/jujuapi/streamproxy.go @@ -3,7 +3,6 @@ package jujuapi import ( "context" - "database/sql" "fmt" "net/http" @@ -14,7 +13,6 @@ import ( "go.uber.org/zap" "github.com/canonical/jimm/v3/internal/auth" - "github.com/canonical/jimm/v3/internal/dbmodel" "github.com/canonical/jimm/v3/internal/errors" "github.com/canonical/jimm/v3/internal/jimmhttp" "github.com/canonical/jimm/v3/internal/openfga" @@ -58,23 +56,27 @@ func (s streamProxier) ServeWS(ctx context.Context, clientConn *websocket.Conn) zapctx.Error(ctx, "failed to write error message to client", zap.Error(err), zap.Any("client message", errResult)) } } + user, err := s.jimm.UserLogin(ctx, auth.SessionIdentityFromContext(ctx)) if err != nil { zapctx.Error(ctx, "user login error", zap.Error(err)) writeError(err.Error(), errors.CodeUnauthorized) return } + uuid, finalPath, err := modelInfoFromPath(jimmhttp.PathElementFromContext(ctx, "path")) if err != nil { zapctx.Error(ctx, "error parsing path", zap.Error(err)) writeError(fmt.Sprintf("error parsing path: %s", err.Error()), errors.CodeBadRequest) return } - model, err := s.getModel(ctx, uuid) + + model, err := s.jimm.GetModel(ctx, uuid) if err != nil { writeError(err.Error(), errors.CodeModelNotFound) return } + if ok, err := checkPermission(ctx, finalPath, user, model.ResourceTag()); err != nil { writeError(err.Error(), errors.CodeUnauthorized) return @@ -82,6 +84,7 @@ func (s streamProxier) ServeWS(ctx context.Context, clientConn *websocket.Conn) writeError(fmt.Sprintf("unauthorized access to endpoint: %s", finalPath), errors.CodeUnauthorized) return } + api, err := s.jimm.Dialer.Dial(ctx, &model.Controller, model.ResourceTag(), nil) if err != nil { zapctx.Error(ctx, "failed to dial controller", zap.Error(err)) @@ -89,27 +92,15 @@ func (s streamProxier) ServeWS(ctx context.Context, clientConn *websocket.Conn) return } defer api.Close() + controllerStream, err := api.ConnectStream(finalPath, nil) if err != nil { zapctx.Error(ctx, "failed to connect stream", zap.Error(err)) writeError(fmt.Sprintf("failed to connect stream: %s", err.Error()), errors.CodeConnectionFailed) return } - jimmRPC.ProxyStreams(ctx, clientConn, controllerStream) -} -func (s streamProxier) getModel(ctx context.Context, modelUUID string) (dbmodel.Model, error) { - model := dbmodel.Model{ - UUID: sql.NullString{ - String: modelUUID, - Valid: modelUUID != "", - }, - } - if err := s.jimm.Database.GetModel(context.Background(), &model); err != nil { - zapctx.Error(ctx, "failed to find model", zap.String("uuid", modelUUID), zap.Error(err)) - return dbmodel.Model{}, fmt.Errorf("failed to find model: %s", err.Error()) - } - return model, nil + jimmRPC.ProxyStreams(ctx, clientConn, controllerStream) } func checkPermission(ctx context.Context, path string, u *openfga.User, mt names.ModelTag) (bool, error) { diff --git a/internal/jujuclient/dial.go b/internal/jujuclient/dial.go index 821de3a2c..4ea3d8698 100644 --- a/internal/jujuclient/dial.go +++ b/internal/jujuclient/dial.go @@ -341,16 +341,17 @@ func (c *Connection) Context() context.Context { // when making the initial HTTP request. func (c *Connection) ConnectStream(path string, attrs url.Values) (base.Stream, error) { const op = errors.Op("jujuclient.ConnectStream") + modelTag, ok := c.ModelTag() + if !ok { + return nil, errors.E(op, "no model found") + } + user, pass, err := c.dialer.ControllerCredentialsStore.GetControllerCredentials(c.ctx, c.ctl.Name) if err != nil { return nil, errors.E(op, err) } requestHeader := jujuhttp.BasicAuthHeader(names.NewUserTag(user).String(), pass) - modelTag, ok := c.ModelTag() - if !ok { - return nil, errors.E(op, "no model found") - } conn, err := rpc.Dial(c.ctx, c.ctl, modelTag, path, requestHeader) if err != nil { return nil, errors.E(op, err)