Skip to content

Commit

Permalink
Add x-remote header auth support
Browse files Browse the repository at this point in the history
  • Loading branch information
sd109 committed Aug 30, 2024
1 parent 336c41b commit 0777c6b
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 29 deletions.
12 changes: 6 additions & 6 deletions backend/danswer/auth/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,12 +204,12 @@ async def create(
) -> models.UP:
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
if hasattr(user_create, "role"):
user_count = await get_user_count()
if user_count == 0 or user_create.email in get_default_admin_user_emails():
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
# if hasattr(user_create, "role"):
# user_count = await get_user_count()
# if user_count == 0 or user_create.email in get_default_admin_user_emails():
# user_create.role = UserRole.ADMIN
# else:
# user_create.role = UserRole.BASIC
return await super().create(user_create, safe=safe, request=request) # type: ignore

async def oauth_callback(
Expand Down
10 changes: 5 additions & 5 deletions backend/danswer/db/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ async def get_user_count() -> int:
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase):
async def create(self, create_dict: Dict[str, Any]) -> UP:
user_count = await get_user_count()
if user_count == 0 or create_dict["email"] in get_default_admin_user_emails():
create_dict["role"] = UserRole.ADMIN
else:
create_dict["role"] = UserRole.BASIC
# user_count = await get_user_count()
# if user_count == 0 or create_dict["email"] in get_default_admin_user_emails():
# create_dict["role"] = UserRole.ADMIN
# else:
# create_dict["role"] = UserRole.BASIC
return await super().create(create_dict)


Expand Down
69 changes: 69 additions & 0 deletions web/src/app/auth/login/HeaderLogin.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"use client";

import { usePopup } from "@/components/admin/connectors/Popup";
import { basicLogin, basicSignup } from "@/lib/user";
import { useRouter } from "next/navigation";
import { useEffect } from "react";
import { Spinner } from "@/components/Spinner";

export function HeaderLoginLoading({
user, groups
}: {
user: string;
groups: string[];
}) {
console.log(user, groups);

const router = useRouter();
const { popup, setPopup } = usePopup();
const email = `${user}@default.com`;
const password = `not-used-${window.crypto.randomUUID()}`
const role = groups.includes("/admins") ? "admin" : "basic"

async function tryLogin() {
// TODO: Update user role here if groups have changed?

// TODO: Use other API endpoints here to update user roles
// and check for existence instead of attempting sign up
// Endpoints:
// - /api/manage/users
// - /api/manage/promote-user-to-admin (auth required)
// - /api/manage/demote-admin-to-user (auth required)

// signup every time.
// Ensure user exists
const response = await basicSignup(email, password, role);
if (!response.ok) {
const errorDetail = (await response.json()).detail;

if (errorDetail !== "REGISTER_USER_ALREADY_EXISTS") {
setPopup({
type: "error",
message: `Failed to sign up - ${errorDetail}`,
});
}
}
// Login as user
const loginResponse = await basicLogin(email, password);
if (loginResponse.ok) {
router.push("/");
} else {
const errorDetail = (await loginResponse.json()).detail;
setPopup({
type: "error",
message: `Failed to login - ${errorDetail}`,
});
}
}

useEffect(() => {
tryLogin()
}, []);

return (
<>
{popup}
<Spinner />
</>
);
}
41 changes: 24 additions & 17 deletions web/src/app/auth/login/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import Link from "next/link";
import { Logo } from "@/components/Logo";
import { LoginText } from "./LoginText";
import { getSecondsUntilExpiration } from "@/lib/time";
import { headers } from 'next/headers';
import { HeaderLoginLoading } from "./HeaderLogin";

const Page = async ({
searchParams,
Expand Down Expand Up @@ -69,6 +71,9 @@ const Page = async ({
return redirect(authUrl);
}

const userHeader = headers().get('x-remote-user');
const groupsHeader = headers().get('x-remote-group');

return (
<main>
<div className="absolute top-10x w-full">
Expand All @@ -90,23 +95,25 @@ const Page = async ({
</>
)}
{authTypeMetadata?.authType === "basic" && (
<Card className="mt-4 w-96">
<div className="flex">
<Title className="mb-2 mx-auto font-bold">
<LoginText />
</Title>
</div>
<EmailPasswordForm />
<div className="flex">
<Text className="mt-4 mx-auto">
Don&apos;t have an account?{" "}
<Link href="/auth/signup" className="text-link font-medium">
Create an account
</Link>
</Text>
</div>
</Card>
)}
(userHeader && groupsHeader) ?
<HeaderLoginLoading user={userHeader} groups={groupsHeader.split(',')} /> : (
<Card className="mt-4 w-96">
<div className="flex">
<Title className="mb-2 mx-auto font-bold">
<LoginText />
</Title>
</div>
<EmailPasswordForm />
<div className="flex">
<Text className="mt-4 mx-auto">
Don&apos;t have an account?{" "}
<Link href="/auth/signup" className="text-link font-medium">
Create an account
</Link>
</Text>
</div>
</Card>
))}
</div>
</div>
</main>
Expand Down
3 changes: 2 additions & 1 deletion web/src/lib/user.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ export const basicLogin = async (
return response;
};

export const basicSignup = async (email: string, password: string) => {
export const basicSignup = async (email: string, password: string, role = "basic") => {
const response = await fetch("/api/auth/register", {
method: "POST",
credentials: "include",
Expand All @@ -53,6 +53,7 @@ export const basicSignup = async (email: string, password: string) => {
email,
username: email,
password,
role,
}),
});
return response;
Expand Down

0 comments on commit 0777c6b

Please sign in to comment.