Skip to content

Commit

Permalink
♻️ refactor: refactor the jwt code (#4322)
Browse files Browse the repository at this point in the history
* ♻️ refactor: refactor the jwt code

* clean

* fix lint
  • Loading branch information
arvinxx authored Oct 11, 2024
1 parent 16bce6d commit b7258b9
Show file tree
Hide file tree
Showing 11 changed files with 113 additions and 80 deletions.
8 changes: 6 additions & 2 deletions src/app/(backend)/api/chat/[provider]/route.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
import { getAuth } from '@clerk/nextjs/server';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';

import { checkAuthMethod, getJWTPayload } from '@/app/(backend)/middleware/auth/utils';
import { checkAuthMethod } from '@/app/(backend)/middleware/auth/utils';
import { LOBE_CHAT_AUTH_HEADER, OAUTH_AUTHORIZED } from '@/const/auth';
import { AgentRuntime, LobeRuntimeAI } from '@/libs/agent-runtime';
import { ChatErrorType } from '@/types/fetch';
import { getJWTPayload } from '@/utils/server/jwt';

import { POST } from './route';

Expand All @@ -14,10 +15,13 @@ vi.mock('@clerk/nextjs/server', () => ({
}));

vi.mock('@/app/(backend)/middleware/auth/utils', () => ({
getJWTPayload: vi.fn(),
checkAuthMethod: vi.fn(),
}));

vi.mock('@/utils/server/jwt', () => ({
getJWTPayload: vi.fn(),
}));

// 定义一个变量来存储 enableAuth 的值
let enableClerk = false;

Expand Down
7 changes: 5 additions & 2 deletions src/app/(backend)/middleware/auth/index.test.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import { getAuth } from '@clerk/nextjs/server';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';

import { AgentRuntimeError } from '@/libs/agent-runtime';
import { ChatErrorType } from '@/types/fetch';
import { createErrorResponse } from '@/utils/errorResponse';
import { getJWTPayload } from '@/utils/server/jwt';

import { RequestHandler, checkAuth } from './index';
import { checkAuthMethod, getJWTPayload } from './utils';
import { checkAuthMethod } from './utils';

vi.mock('@clerk/nextjs/server', () => ({
getAuth: vi.fn(),
Expand All @@ -18,6 +18,9 @@ vi.mock('@/utils/errorResponse', () => ({

vi.mock('./utils', () => ({
checkAuthMethod: vi.fn(),
}));

vi.mock('@/utils/server/jwt', () => ({
getJWTPayload: vi.fn(),
}));

Expand Down
3 changes: 2 additions & 1 deletion src/app/(backend)/middleware/auth/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ import { JWTPayload, LOBE_CHAT_AUTH_HEADER, OAUTH_AUTHORIZED, enableClerk } from
import { AgentRuntime, AgentRuntimeError, ChatCompletionErrorPayload } from '@/libs/agent-runtime';
import { ChatErrorType } from '@/types/fetch';
import { createErrorResponse } from '@/utils/errorResponse';
import { getJWTPayload } from '@/utils/server/jwt';

import { checkAuthMethod, getJWTPayload } from './utils';
import { checkAuthMethod } from './utils';

type CreateRuntime = (jwtPayload: JWTPayload) => AgentRuntime;
type RequestOptions = { createRuntime?: CreateRuntime; params: { provider: string } };
Expand Down
39 changes: 1 addition & 38 deletions src/app/(backend)/middleware/auth/utils.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@ import { type AuthObject } from '@clerk/backend';
import { beforeEach, describe, expect, it, vi } from 'vitest';

import { getAppConfig } from '@/config/app';
import { NON_HTTP_PREFIX } from '@/const/auth';

import { checkAuthMethod, getJWTPayload } from './utils';
import { checkAuthMethod } from './utils';

let enableClerkMock = false;
let enableNextAuthMock = false;
Expand All @@ -27,42 +26,6 @@ vi.mock('@/config/app', () => ({
getAppConfig: vi.fn(),
}));

describe('getJWTPayload', () => {
it('should parse JWT payload for non-HTTPS token', async () => {
const token = `${NON_HTTP_PREFIX}.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ`;
const payload = await getJWTPayload(token);
expect(payload).toEqual({
sub: '1234567890',
name: 'John Doe',
iat: 1516239022,
});
});

it('should verify and parse JWT payload for HTTPS token', async () => {
const token =
'eyJhbGciOiJIUzI1NiJ9.eyJhY2Nlc3NDb2RlIjoiIiwidXNlcklkIjoiMDAxMzYyYzMtNDhjNS00NjM1LWJkM2ItODM3YmZmZjU4ZmMwIiwiYXBpS2V5IjoiYWJjIiwiZW5kcG9pbnQiOiJhYmMiLCJpYXQiOjE3MTY4MDIyMjUsImV4cCI6MTAwMDAwMDAwMDE3MTY4MDIwMDB9.FF0FxsE8Cajs-_hv5GD0TNUDwvekAkI9l_LL_IOPdGQ';
const payload = await getJWTPayload(token);
expect(payload).toEqual({
accessCode: '',
apiKey: 'abc',
endpoint: 'abc',
exp: 10000000001716802000,
iat: 1716802225,
userId: '001362c3-48c5-4635-bd3b-837bfff58fc0',
});
});

it('should not verify success and parse JWT payload for dated token', async () => {
const token =
'eyJhbGciOiJIUzI1NiJ9.eyJhY2Nlc3NDb2RlIjoiIiwidXNlcklkIjoiYWY3M2JhODktZjFhMy00YjliLWEwM2QtZGViZmZlMzE4NmQxIiwiYXBpS2V5IjoiYWJjIiwiZW5kcG9pbnQiOiJhYmMiLCJpYXQiOjE3MTY3OTk5ODAsImV4cCI6MTcxNjgwMDA4MH0.8AGFsLcwyrQG82kVUYOGFXHIwihm2n16ctyArKW9100';
try {
await getJWTPayload(token);
} catch (e) {
expect(e).toEqual(new TypeError('"exp" claim timestamp check failed'));
}
});
});

describe('checkAuthMethod', () => {
beforeEach(() => {
vi.mocked(getAppConfig).mockReturnValue({
Expand Down
34 changes: 1 addition & 33 deletions src/app/(backend)/middleware/auth/utils.ts
Original file line number Diff line number Diff line change
@@ -1,42 +1,10 @@
import { type AuthObject } from '@clerk/backend';
import { importJWK, jwtVerify } from 'jose';

import { getAppConfig } from '@/config/app';
import {
JWTPayload,
JWT_SECRET_KEY,
NON_HTTP_PREFIX,
enableClerk,
enableNextAuth,
} from '@/const/auth';
import { enableClerk, enableNextAuth } from '@/const/auth';
import { AgentRuntimeError } from '@/libs/agent-runtime';
import { ChatErrorType } from '@/types/fetch';

export const getJWTPayload = async (token: string): Promise<JWTPayload> => {
//如果是 HTTP 协议发起的请求,直接解析 token
// 这是一个非常 hack 的解决方案,未来要找更好的解决方案来处理这个问题
// refs: https:/lobehub/lobe-chat/pull/1238
if (token.startsWith(NON_HTTP_PREFIX)) {
const jwtParts = token.split('.');

const payload = jwtParts[1];

return JSON.parse(atob(payload));
}

const encoder = new TextEncoder();
const secretKey = await crypto.subtle.digest('SHA-256', encoder.encode(JWT_SECRET_KEY));

const jwkSecretKey = await importJWK(
{ k: Buffer.from(secretKey).toString('base64'), kty: 'oct' },
'HS256',
);

const { payload } = await jwtVerify(token, jwkSecretKey);

return payload as JWTPayload;
};

interface CheckAuthParams {
accessCode?: string;
apiKey?: string;
Expand Down
2 changes: 1 addition & 1 deletion src/app/(backend)/webapi/plugin/gateway/route.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import { PluginRequestPayload } from '@lobehub/chat-plugin-sdk';
import { createGatewayOnEdgeRuntime } from '@lobehub/chat-plugins-gateway';

import { getJWTPayload } from '@/app/(backend)/middleware/auth/utils';
import { getAppConfig } from '@/config/app';
import { LOBE_CHAT_AUTH_HEADER, OAUTH_AUTHORIZED, enableNextAuth } from '@/const/auth';
import { LOBE_CHAT_TRACE_ID, TraceNameMap } from '@/const/trace';
import { AgentRuntimeError } from '@/libs/agent-runtime';
import { TraceClient } from '@/libs/traces';
import { ChatErrorType, ErrorType } from '@/types/fetch';
import { createErrorResponse } from '@/utils/errorResponse';
import { getJWTPayload } from '@/utils/server/jwt';
import { getTracePayload } from '@/utils/trace';

import { parserPluginSettings } from './settings';
Expand Down
2 changes: 1 addition & 1 deletion src/libs/trpc/middleware/jwtPayload.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import { TRPCError } from '@trpc/server';
import { beforeEach, describe, expect, it, vi } from 'vitest';

import * as utils from '@/app/(backend)/middleware/auth/utils';
import { createCallerFactory } from '@/libs/trpc';
import { trpc } from '@/libs/trpc/init';
import { AuthContext, createContextInner } from '@/server/context';
import * as utils from '@/utils/server/jwt';

import { jwtPayloadChecker } from './jwtPayload';

Expand Down
2 changes: 1 addition & 1 deletion src/libs/trpc/middleware/jwtPayload.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { TRPCError } from '@trpc/server';

import { getJWTPayload } from '@/app/(backend)/middleware/auth/utils';
import { trpc } from '@/libs/trpc/init';
import { getJWTPayload } from '@/utils/server/jwt';

export const jwtPayloadChecker = trpc.middleware(async (opts) => {
const { ctx } = opts;
Expand Down
2 changes: 1 addition & 1 deletion src/libs/trpc/middleware/keyVaults.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { TRPCError } from '@trpc/server';

import { getJWTPayload } from '@/app/(backend)/middleware/auth/utils';
import { trpc } from '@/libs/trpc/init';
import { getJWTPayload } from '@/utils/server/jwt';

export const keyVaults = trpc.middleware(async (opts) => {
const { ctx } = opts;
Expand Down
62 changes: 62 additions & 0 deletions src/utils/server/jwt.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import { describe, expect, it, vi } from 'vitest';

import { NON_HTTP_PREFIX } from '@/const/auth';

import { getJWTPayload } from './jwt';

let enableClerkMock = false;
let enableNextAuthMock = false;

vi.mock('@/const/auth', async (importOriginal) => {
const data = await importOriginal();

return {
...(data as any),
get enableClerk() {
return enableClerkMock;
},
get enableNextAuth() {
return enableNextAuthMock;
},
};
});

vi.mock('@/config/app', () => ({
getAppConfig: vi.fn(),
}));

describe('getJWTPayload', () => {
it('should parse JWT payload for non-HTTPS token', async () => {
const token = `${NON_HTTP_PREFIX}.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ`;
const payload = await getJWTPayload(token);
expect(payload).toEqual({
sub: '1234567890',
name: 'John Doe',
iat: 1516239022,
});
});

it('should verify and parse JWT payload for HTTPS token', async () => {
const token =
'eyJhbGciOiJIUzI1NiJ9.eyJhY2Nlc3NDb2RlIjoiIiwidXNlcklkIjoiMDAxMzYyYzMtNDhjNS00NjM1LWJkM2ItODM3YmZmZjU4ZmMwIiwiYXBpS2V5IjoiYWJjIiwiZW5kcG9pbnQiOiJhYmMiLCJpYXQiOjE3MTY4MDIyMjUsImV4cCI6MTAwMDAwMDAwMDE3MTY4MDIwMDB9.FF0FxsE8Cajs-_hv5GD0TNUDwvekAkI9l_LL_IOPdGQ';
const payload = await getJWTPayload(token);
expect(payload).toEqual({
accessCode: '',
apiKey: 'abc',
endpoint: 'abc',
exp: 10000000001716802000,
iat: 1716802225,
userId: '001362c3-48c5-4635-bd3b-837bfff58fc0',
});
});

it('should not verify success and parse JWT payload for dated token', async () => {
const token =
'eyJhbGciOiJIUzI1NiJ9.eyJhY2Nlc3NDb2RlIjoiIiwidXNlcklkIjoiYWY3M2JhODktZjFhMy00YjliLWEwM2QtZGViZmZlMzE4NmQxIiwiYXBpS2V5IjoiYWJjIiwiZW5kcG9pbnQiOiJhYmMiLCJpYXQiOjE3MTY3OTk5ODAsImV4cCI6MTcxNjgwMDA4MH0.8AGFsLcwyrQG82kVUYOGFXHIwihm2n16ctyArKW9100';
try {
await getJWTPayload(token);
} catch (e) {
expect(e).toEqual(new TypeError('"exp" claim timestamp check failed'));
}
});
});
32 changes: 32 additions & 0 deletions src/utils/server/jwt.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import { importJWK, jwtVerify } from 'jose';

import {
JWTPayload,
JWT_SECRET_KEY,
NON_HTTP_PREFIX,
} from '@/const/auth';

export const getJWTPayload = async (token: string): Promise<JWTPayload> => {
//如果是 HTTP 协议发起的请求,直接解析 token
// 这是一个非常 hack 的解决方案,未来要找更好的解决方案来处理这个问题
// refs: https:/lobehub/lobe-chat/pull/1238
if (token.startsWith(NON_HTTP_PREFIX)) {
const jwtParts = token.split('.');

const payload = jwtParts[1];

return JSON.parse(atob(payload));
}

const encoder = new TextEncoder();
const secretKey = await crypto.subtle.digest('SHA-256', encoder.encode(JWT_SECRET_KEY));

const jwkSecretKey = await importJWK(
{ k: Buffer.from(secretKey).toString('base64'), kty: 'oct' },
'HS256',
);

const { payload } = await jwtVerify(token, jwkSecretKey);

return payload as JWTPayload;
};

0 comments on commit b7258b9

Please sign in to comment.