Skip to content

Commit

Permalink
Convert to a class
Browse files Browse the repository at this point in the history
  • Loading branch information
lpsinger committed Jan 6, 2023
1 parent 1f6d8a7 commit 2c13f05
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 47 deletions.
66 changes: 30 additions & 36 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import type {
} from '@aws-sdk/lib-dynamodb'
import type { NativeAttributeValue } from '@aws-sdk/util-dynamodb'

export interface dynamoDBAutoIncrementProps {
export interface DynamoDBAutoIncrementProps {
/** a DynamoDB document client instance */
doc: DynamoDBDocument

Expand Down Expand Up @@ -59,78 +59,76 @@ export interface dynamoDBAutoIncrementProps {
* })
* ```
*/
export function dynamoDBAutoIncrement({
doc,
counterTableName,
counterTableKey,
counterTableAttributeName,
tableName,
tableAttributeName,
initialValue,
dangerously,
}: dynamoDBAutoIncrementProps) {
async function getLast(): Promise<number | undefined> {
export class DynamoDBAutoIncrement {
constructor(readonly props: DynamoDBAutoIncrementProps) {}

async getLast(): Promise<number | undefined> {
return (
(
await doc.get({
AttributesToGet: [counterTableAttributeName],
Key: counterTableKey,
TableName: counterTableName,
await this.props.doc.get({
AttributesToGet: [this.props.counterTableAttributeName],
Key: this.props.counterTableKey,
TableName: this.props.counterTableName,
})
).Item?.[counterTableAttributeName] ?? undefined
).Item?.[this.props.counterTableAttributeName] ?? undefined
)
}

async function put(item: Record<string, NativeAttributeValue>) {
async put(item: Record<string, NativeAttributeValue>) {
for (;;) {
const counter = await getLast()
const counter = await this.getLast()

let nextCounter: number
let Update: UpdateCommandInput & { UpdateExpression: string }

if (counter === undefined) {
nextCounter = initialValue
nextCounter = this.props.initialValue
Update = {
ConditionExpression: 'attribute_not_exists(#counter)',
ExpressionAttributeNames: {
'#counter': counterTableAttributeName,
'#counter': this.props.counterTableAttributeName,
},
ExpressionAttributeValues: {
':nextCounter': nextCounter,
},
Key: counterTableKey,
TableName: counterTableName,
Key: this.props.counterTableKey,
TableName: this.props.counterTableName,
UpdateExpression: 'SET #counter = :nextCounter',
}
} else {
nextCounter = counter + 1
Update = {
ConditionExpression: '#counter = :counter',
ExpressionAttributeNames: {
'#counter': counterTableAttributeName,
'#counter': this.props.counterTableAttributeName,
},
ExpressionAttributeValues: {
':counter': counter,
':nextCounter': nextCounter,
},
Key: counterTableKey,
TableName: counterTableName,
Key: this.props.counterTableKey,
TableName: this.props.counterTableName,
UpdateExpression: 'SET #counter = :nextCounter',
}
}

const Put: PutCommandInput = {
ConditionExpression: 'attribute_not_exists(#counter)',
ExpressionAttributeNames: { '#counter': tableAttributeName },
Item: { [tableAttributeName]: nextCounter, ...item },
TableName: tableName,
ExpressionAttributeNames: { '#counter': this.props.tableAttributeName },
Item: { [this.props.tableAttributeName]: nextCounter, ...item },
TableName: this.props.tableName,
}

if (dangerously) {
await Promise.all([doc.update(Update), doc.put(Put)])
if (this.props.dangerously) {
await Promise.all([
this.props.doc.update(Update),
this.props.doc.put(Put),
])
} else {
try {
await doc.transactWrite({ TransactItems: [{ Update }, { Put }] })
await this.props.doc.transactWrite({
TransactItems: [{ Update }, { Put }],
})
} catch (e) {
if (e instanceof TransactionCanceledException) {
continue
Expand All @@ -143,8 +141,4 @@ export function dynamoDBAutoIncrement({
return nextCounter
}
}

put.getLast = getLast

return put
}
23 changes: 12 additions & 11 deletions src/test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ import {
DynamoDB,
} from '@aws-sdk/client-dynamodb'
import { DynamoDBDocument } from '@aws-sdk/lib-dynamodb'
import type { dynamoDBAutoIncrementProps } from '.'
import { dynamoDBAutoIncrement } from '.'
import type { DynamoDBAutoIncrementProps } from '.'
import { DynamoDBAutoIncrement } from '.'

let doc: DynamoDBDocument
let autoincrement: ReturnType<typeof dynamoDBAutoIncrement>
let autoincrementDangerously: ReturnType<typeof dynamoDBAutoIncrement>
let autoincrement: DynamoDBAutoIncrement
let autoincrementDangerously: DynamoDBAutoIncrement
const N = 20

beforeAll(async () => {
Expand All @@ -19,7 +19,7 @@ beforeAll(async () => {
region: '-',
})
)
const options: dynamoDBAutoIncrementProps = {
const options: DynamoDBAutoIncrementProps = {
doc,
counterTableName: 'autoincrement',
counterTableKey: { tableName: 'widgets' },
Expand All @@ -28,8 +28,8 @@ beforeAll(async () => {
tableAttributeName: 'widgetID',
initialValue: 1,
}
autoincrement = dynamoDBAutoIncrement(options)
autoincrementDangerously = dynamoDBAutoIncrement({
autoincrement = new DynamoDBAutoIncrement(options)
autoincrementDangerously = new DynamoDBAutoIncrement({
...options,
dangerously: true,
})
Expand Down Expand Up @@ -71,7 +71,7 @@ describe('dynamoDBAutoIncrement', () => {
nextID = lastID + 1
}

const result = await autoincrement({ widgetName: 'runcible spoon' })
const result = await autoincrement.put({ widgetName: 'runcible spoon' })
expect(result).toEqual(nextID)

expect(await autoincrement.getLast()).toEqual(nextID)
Expand All @@ -96,7 +96,7 @@ describe('dynamoDBAutoIncrement', () => {

test('correctly handles a large number of parallel puts', async () => {
const ids = Array.from(Array(N).keys()).map((i) => i + 1)
const result = await Promise.all(ids.map(() => autoincrement({})))
const result = await Promise.all(ids.map(() => autoincrement.put({})))
expect(result.sort()).toEqual(ids.sort())
})
})
Expand All @@ -106,15 +106,16 @@ describe('dynamoDBAutoIncrement dangerously', () => {
const ids = Array.from(Array(N).keys()).map((i) => i + 1)
const result: number[] = []
for (const item of ids) {
result.push(await autoincrementDangerously({ widgetName: item }))
result.push(await autoincrementDangerously.put({ widgetName: item }))
}
expect(result.sort()).toEqual(ids.sort())
})

test('fails on a large number of parallel puts', async () => {
const ids = Array.from(Array(N).keys()).map((i) => i + 1)
await expect(
async () => await Promise.all(ids.map(() => autoincrementDangerously({})))
async () =>
await Promise.all(ids.map(() => autoincrementDangerously.put({})))
).rejects.toThrow(ConditionalCheckFailedException)
})
})

0 comments on commit 2c13f05

Please sign in to comment.