diff --git a/integration-tests/api/__tests__/admin/product.js b/integration-tests/api/__tests__/admin/product.js index 6b7bdea95b584..4761699e1d6cc 100644 --- a/integration-tests/api/__tests__/admin/product.js +++ b/integration-tests/api/__tests__/admin/product.js @@ -1866,25 +1866,28 @@ describe("/admin/products", () => { expect(res.status).toEqual(200) - expect(insertedVariant.prices).toEqual([ - expect.objectContaining({ - currency_code: "usd", - amount: 100, - min_quantity: null, - max_quantity: null, - variant_id: insertedVariant.id, - region_id: null, - }), - expect.objectContaining({ - currency_code: "usd", - amount: 200, - min_quantity: null, - max_quantity: null, - price_list_id: null, - variant_id: insertedVariant.id, - region_id: "test-region", - }), - ]) + expect(insertedVariant.prices).toHaveLength(2) + expect(insertedVariant.prices).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + currency_code: "usd", + amount: 100, + min_quantity: null, + max_quantity: null, + variant_id: insertedVariant.id, + region_id: null, + }), + expect.objectContaining({ + currency_code: "usd", + amount: 200, + min_quantity: null, + max_quantity: null, + price_list_id: null, + variant_id: insertedVariant.id, + region_id: "test-region", + }), + ]) + ) }) }) diff --git a/integration-tests/api/__tests__/store/cart/cart.js b/integration-tests/api/__tests__/store/cart/cart.js index 09696de85f390..061befb0efdea 100644 --- a/integration-tests/api/__tests__/store/cart/cart.js +++ b/integration-tests/api/__tests__/store/cart/cart.js @@ -66,7 +66,9 @@ describe("/store/carts", () => { tax_rate: 0, }) await manager.query( - `UPDATE "country" SET region_id='region' WHERE iso_2 = 'us'` + `UPDATE "country" + SET region_id='region' + WHERE iso_2 = 'us'` ) }) @@ -88,9 +90,12 @@ describe("/store/carts", () => { const api = useApi() await dbConnection.manager.query( - `UPDATE "country" SET region_id=null WHERE iso_2 = 'us'` + `UPDATE "country" + SET region_id=null + WHERE iso_2 = 'us'` ) - await dbConnection.manager.query(`DELETE from region`) + await dbConnection.manager.query(`DELETE + from region`) try { await api.post("/store/carts") @@ -1679,7 +1684,9 @@ describe("/store/carts", () => { const manager = dbConnection.manager const api = useApi() await manager.query( - `UPDATE "cart" SET completed_at=current_timestamp WHERE id = 'test-cart-2'` + `UPDATE "cart" + SET completed_at=current_timestamp + WHERE id = 'test-cart-2'` ) try { await api.post(`/store/carts/test-cart-2/complete-cart`) @@ -1982,7 +1989,8 @@ describe("/store/carts", () => { try { await cartSeeder(dbConnection) await dbConnection.manager.query( - `INSERT INTO "cart_discounts" (cart_id, discount_id) VALUES ('test-cart', 'free-shipping')` + `INSERT INTO "cart_discounts" (cart_id, discount_id) + VALUES ('test-cart', 'free-shipping')` ) } catch (err) { console.log(err) diff --git a/integration-tests/plugins/__tests__/medusa-plugin-sendgrid/__snapshots__/index.js.snap b/integration-tests/plugins/__tests__/medusa-plugin-sendgrid/__snapshots__/index.js.snap index cd593db1b8bbb..a34bbefd16f0e 100644 --- a/integration-tests/plugins/__tests__/medusa-plugin-sendgrid/__snapshots__/index.js.snap +++ b/integration-tests/plugins/__tests__/medusa-plugin-sendgrid/__snapshots__/index.js.snap @@ -432,7 +432,6 @@ Object { "claim_order_id": null, "created_at": Any, "description": "", - "discount_total": 0, "fulfilled_quantity": 2, "has_shipping": null, "id": Any, @@ -440,13 +439,10 @@ Object { "is_return": false, "metadata": null, "order_id": Any, - "original_tax_total": 400, - "original_total": 2400, "quantity": 2, "returned_quantity": 1, "shipped_quantity": 2, "should_merge": true, - "subtotal": 2000, "swap_id": null, "tax_lines": Array [ Object { @@ -460,10 +456,8 @@ Object { "updated_at": Any, }, ], - "tax_total": 400, "thumbnail": "", "title": "Intelligent Plastic Chips", - "total": 2400, "unit_price": 1000, "updated_at": Any, "variant": Object { @@ -774,7 +768,6 @@ Object { "claim_order_id": null, "created_at": Any, "description": "", - "discount_total": 0, "fulfilled_quantity": null, "has_shipping": null, "id": "test-item", @@ -782,14 +775,11 @@ Object { "is_return": false, "metadata": null, "order_id": Any, - "original_tax_total": 400, - "original_total": 2400, "price": "10.00 USD", "quantity": 2, "returned_quantity": null, "shipped_quantity": null, "should_merge": true, - "subtotal": 2000, "swap_id": null, "tax_lines": Array [ Object { @@ -803,10 +793,8 @@ Object { "updated_at": Any, }, ], - "tax_total": 400, "thumbnail": null, "title": "Intelligent Plastic Chips", - "total": 2400, "unit_price": 1000, "updated_at": Any, "variant": Object { @@ -1003,7 +991,6 @@ Object { "claim_order_id": null, "created_at": Any, "description": "", - "discount_total": 0, "discounted_price": "12.00 USD", "fulfilled_quantity": 2, "has_shipping": null, @@ -1012,14 +999,11 @@ Object { "is_return": false, "metadata": null, "order_id": Any, - "original_tax_total": 400, - "original_total": 2400, "price": "12.00 USD", "quantity": 2, "returned_quantity": null, "shipped_quantity": 2, "should_merge": true, - "subtotal": 2000, "swap_id": null, "tax_lines": Array [ Object { @@ -1033,10 +1017,8 @@ Object { "updated_at": Any, }, ], - "tax_total": 400, "thumbnail": null, "title": "Intelligent Plastic Chips", - "total": 2400, "totals": Object { "discount_total": 0, "original_tax_total": 400, @@ -1276,7 +1258,6 @@ Object { "claim_order_id": null, "created_at": Any, "description": "", - "discount_total": 0, "fulfilled_quantity": 2, "has_shipping": null, "id": "test-item", @@ -1284,13 +1265,10 @@ Object { "is_return": false, "metadata": null, "order_id": Any, - "original_tax_total": 400, - "original_total": 2400, "quantity": 2, "returned_quantity": null, "shipped_quantity": 2, "should_merge": true, - "subtotal": 2000, "swap_id": null, "tax_lines": Array [ Object { @@ -1304,10 +1282,8 @@ Object { "updated_at": Any, }, ], - "tax_total": 400, "thumbnail": "", "title": "Intelligent Plastic Chips", - "total": 2400, "unit_price": 1000, "updated_at": Any, "variant": Object { @@ -1603,7 +1579,6 @@ Object { "claim_order_id": null, "created_at": Any, "description": "", - "discount_total": 0, "fulfilled_quantity": 2, "has_shipping": null, "id": Any, @@ -1611,13 +1586,10 @@ Object { "is_return": false, "metadata": null, "order_id": Any, - "original_tax_total": 400, - "original_total": 2400, "quantity": 2, "returned_quantity": null, "shipped_quantity": 2, "should_merge": true, - "subtotal": 2000, "swap_id": null, "tax_lines": Array [ Object { @@ -1631,10 +1603,8 @@ Object { "updated_at": Any, }, ], - "tax_total": 400, "thumbnail": "", "title": "Intelligent Plastic Chips", - "total": 2400, "unit_price": 1000, "updated_at": Any, "variant": Object { diff --git a/packages/medusa/src/api/routes/admin/draft-orders/get-draft-order.ts b/packages/medusa/src/api/routes/admin/draft-orders/get-draft-order.ts index c6e8d8bd001bd..27f8a3e5835a4 100644 --- a/packages/medusa/src/api/routes/admin/draft-orders/get-draft-order.ts +++ b/packages/medusa/src/api/routes/admin/draft-orders/get-draft-order.ts @@ -71,9 +71,15 @@ export default async (req, res) => { relations: defaultAdminDraftOrdersRelations, }) - draftOrder.cart = await cartService.retrieveWithTotals(draftOrder.cart_id, { - relations: defaultAdminDraftOrdersCartRelations, - }) + draftOrder.cart = await cartService.retrieveWithTotals( + draftOrder.cart_id, + { + relations: defaultAdminDraftOrdersCartRelations, + }, + { + force_taxes: true, + } + ) res.json({ draft_order: draftOrder }) } diff --git a/packages/medusa/src/api/routes/admin/draft-orders/register-payment.ts b/packages/medusa/src/api/routes/admin/draft-orders/register-payment.ts index 53fa81d9ec6a1..a8a33fd5285e6 100644 --- a/packages/medusa/src/api/routes/admin/draft-orders/register-payment.ts +++ b/packages/medusa/src/api/routes/admin/draft-orders/register-payment.ts @@ -83,16 +83,7 @@ export default async (req, res) => { const cart = await cartService .withTransaction(manager) - .retrieve(draftOrder.cart_id, { - select: ["total"], - relations: [ - "discounts", - "discounts.rule", - "shipping_methods", - "region", - "items", - ], - }) + .retrieveWithTotals(draftOrder.cart_id) await paymentProviderService .withTransaction(manager) diff --git a/packages/medusa/src/api/routes/admin/orders/__tests__/get-order.js b/packages/medusa/src/api/routes/admin/orders/__tests__/get-order.js index 76b93b80a37f9..9bf02e2413204 100644 --- a/packages/medusa/src/api/routes/admin/orders/__tests__/get-order.js +++ b/packages/medusa/src/api/routes/admin/orders/__tests__/get-order.js @@ -26,11 +26,23 @@ describe("GET /admin/orders", () => { }) it("calls orderService retrieve", () => { - expect(OrderServiceMock.retrieve).toHaveBeenCalledTimes(1) - expect(OrderServiceMock.retrieve).toHaveBeenCalledWith( + expect(OrderServiceMock.retrieveWithTotals).toHaveBeenCalledTimes(1) + expect(OrderServiceMock.retrieveWithTotals).toHaveBeenCalledWith( IdMap.getId("test-order"), { - select: defaultAdminOrdersFields, + select: defaultAdminOrdersFields.filter((field) => { + return ![ + "shipping_total", + "discount_total", + "tax_total", + "refunded_total", + "total", + "subtotal", + "refundable_amount", + "gift_card_total", + "gift_card_tax_total", + ].includes(field) + }), relations: defaultAdminOrdersRelations, } ) diff --git a/packages/medusa/src/api/routes/admin/orders/get-order.ts b/packages/medusa/src/api/routes/admin/orders/get-order.ts index cb51c74847e81..86c8133311e34 100644 --- a/packages/medusa/src/api/routes/admin/orders/get-order.ts +++ b/packages/medusa/src/api/routes/admin/orders/get-order.ts @@ -56,7 +56,7 @@ export default async (req, res) => { const orderService: OrderService = req.scope.resolve("orderService") - const order = await orderService.retrieve(id, req.retrieveConfig) + const order = await orderService.retrieveWithTotals(id, req.retrieveConfig) res.json({ order }) } diff --git a/packages/medusa/src/api/routes/admin/orders/index.ts b/packages/medusa/src/api/routes/admin/orders/index.ts index 7805e0fbd5409..a91c1df6d41ad 100644 --- a/packages/medusa/src/api/routes/admin/orders/index.ts +++ b/packages/medusa/src/api/routes/admin/orders/index.ts @@ -43,7 +43,19 @@ export default (app, featureFlagRouter: FlagRouter) => { "/:id", transformQuery(FindParams, { defaultRelations: relations, - defaultFields: defaultAdminOrdersFields, + defaultFields: defaultAdminOrdersFields.filter((field) => { + return ![ + "shipping_total", + "discount_total", + "tax_total", + "refunded_total", + "total", + "subtotal", + "refundable_amount", + "gift_card_total", + "gift_card_tax_total", + ].includes(field) + }), allowedFields: allowedAdminOrdersFields, allowedRelations: allowedAdminOrdersRelations, isList: false, diff --git a/packages/medusa/src/services/__fixtures__/new-totals.ts b/packages/medusa/src/services/__fixtures__/new-totals.ts new file mode 100644 index 0000000000000..0784e218c88c5 --- /dev/null +++ b/packages/medusa/src/services/__fixtures__/new-totals.ts @@ -0,0 +1,72 @@ +import { asClass, asValue, createContainer } from "awilix" +import { IdMap, MockManager } from "medusa-test-utils" +import { taxProviderServiceMock } from "../__mocks__/tax-provider" +import { FlagRouter } from "../../utils/flag-router" +import { + GiftCard, + LineItem, + LineItemTaxLine, + ShippingMethod, + ShippingMethodTaxLine, +} from "../../models" +import TaxCalculationStrategy from "../../strategies/tax-calculation" + +export const defaultContainerMock = createContainer() +defaultContainerMock.register("manager", asValue(MockManager)) +defaultContainerMock.register( + "taxProviderService", + asValue(taxProviderServiceMock) +) +defaultContainerMock.register("featureFlagRouter", asValue(new FlagRouter({}))) +defaultContainerMock.register( + "taxCalculationStrategy", + asClass(TaxCalculationStrategy) +) + +export const lineItems = [ + { + id: IdMap.getId("item_1_with_tax_lines"), + cart_id: "", + order_id: "", + swap_id: "", + claim_order_id: "", + title: "title", + description: "description", + unit_price: 1000, + quantity: 1, + tax_lines: [ + { + id: IdMap.getId("item_1_with_tax_lines_tax_line_1"), + item_id: IdMap.getId("item_1_with_tax_lines"), + rate: 20, + name: "default", + code: "default", + }, + ] as LineItemTaxLine[], + }, +] as LineItem[] + +export const shippingMethods = [ + { + id: IdMap.getId("sm_1_with_tax_lines"), + price: 1000, + tax_lines: [ + { + id: IdMap.getId("sm_1_with_tax_lines_tax_line_1"), + shipping_method_id: IdMap.getId("sm_1_with_tax_lines"), + rate: 20, + name: "default", + code: "default", + }, + ] as ShippingMethodTaxLine[], + }, +] as ShippingMethod[] + +export const giftCards = [ + { + id: IdMap.getId("gift_card_1"), + code: "CODE", + value: 10000, + balance: 10000, + }, +] as GiftCard[] diff --git a/packages/medusa/src/services/__mocks__/new-totals.js b/packages/medusa/src/services/__mocks__/new-totals.js new file mode 100644 index 0000000000000..bae82eaf1cb12 --- /dev/null +++ b/packages/medusa/src/services/__mocks__/new-totals.js @@ -0,0 +1,25 @@ +export const newTotalsServiceMock = { + withTransaction: function () { + return this + }, + getLineItemTotals: jest.fn().mockImplementation(() => { + return Promise.resolve({}) + }), + getGiftCardTotals: jest.fn().mockImplementation((order, lineItems) => { + return Promise.resolve({}) + }), + getGiftCardTransactionsTotals: jest + .fn() + .mockImplementation((order, lineItems) => { + return Promise.resolve({}) + }), + getShippingMethodTotals: jest.fn().mockImplementation((order, lineItems) => { + return Promise.resolve({}) + }), +} + +const mock = jest.fn().mockImplementation(() => { + return newTotalsServiceMock +}) + +export default mock diff --git a/packages/medusa/src/services/__mocks__/order.js b/packages/medusa/src/services/__mocks__/order.js index 804c6864fa05c..3cb70ecd85b60 100644 --- a/packages/medusa/src/services/__mocks__/order.js +++ b/packages/medusa/src/services/__mocks__/order.js @@ -170,6 +170,15 @@ export const OrderServiceMock = { } return Promise.resolve(undefined) }), + retrieveWithTotals: jest.fn().mockImplementation((orderId) => { + if (orderId === IdMap.getId("test-order")) { + return Promise.resolve(orders.testOrder) + } + if (orderId === IdMap.getId("processed-order")) { + return Promise.resolve(orders.processedOrder) + } + return Promise.resolve(undefined) + }), retrieveByCartId: jest.fn().mockImplementation((cartId) => { return Promise.resolve({ id: IdMap.getId("test-order") }) }), diff --git a/packages/medusa/src/services/__mocks__/tax-provider.js b/packages/medusa/src/services/__mocks__/tax-provider.js index 2d46ba22e3fb9..c9581ff63c58c 100644 --- a/packages/medusa/src/services/__mocks__/tax-provider.js +++ b/packages/medusa/src/services/__mocks__/tax-provider.js @@ -8,6 +8,15 @@ export const taxProviderServiceMock = { clearLineItemsTaxLines: jest.fn().mockImplementation((_) => { return Promise.resolve() }), + getTaxLines: jest.fn().mockImplementation((_) => { + return Promise.resolve([]) + }), + getTaxLinesMap: jest.fn().mockImplementation((_) => { + return Promise.resolve({ + lineItemsTaxLines: {}, + shippingMethodsTaxLines: {}, + }) + }), } const mock = jest.fn().mockImplementation(() => { diff --git a/packages/medusa/src/services/__tests__/cart.js b/packages/medusa/src/services/__tests__/cart.js index fa00b3e7309b2..c84e76d391cff 100644 --- a/packages/medusa/src/services/__tests__/cart.js +++ b/packages/medusa/src/services/__tests__/cart.js @@ -6,6 +6,7 @@ import { InventoryServiceMock } from "../__mocks__/inventory" import { LineItemAdjustmentServiceMock } from "../__mocks__/line-item-adjustment" import { FlagRouter } from "../../utils/flag-router" import { taxProviderServiceMock } from "../__mocks__/tax-provider" +import { newTotalsServiceMock } from "../__mocks__/new-totals" const eventBusService = { emit: jest.fn(), @@ -59,6 +60,7 @@ describe("CartService", () => { totalsService, cartRepository, taxProviderService: taxProviderServiceMock, + newTotalsService: newTotalsServiceMock, featureFlagRouter: new FlagRouter({}), }) result = await cartService.retrieve(IdMap.getId("emptyCart")) @@ -93,6 +95,7 @@ describe("CartService", () => { cartRepository, eventBusService, taxProviderService: taxProviderServiceMock, + newTotalsService: newTotalsServiceMock, featureFlagRouter: new FlagRouter({}), }) @@ -180,6 +183,7 @@ describe("CartService", () => { totalsService, cartRepository, customerService, + newTotalsService: newTotalsServiceMock, regionService, eventBusService, taxProviderService: taxProviderServiceMock, @@ -351,6 +355,7 @@ describe("CartService", () => { cartRepository, lineItemService, lineItemRepository: MockRepository(), + newTotalsService: newTotalsServiceMock, eventBusService, shippingOptionService, inventoryService, @@ -585,6 +590,7 @@ describe("CartService", () => { cartRepository, lineItemService, lineItemRepository: MockRepository(), + newTotalsService: newTotalsServiceMock, eventBusService, shippingOptionService, inventoryService, @@ -683,6 +689,7 @@ describe("CartService", () => { cartRepository, lineItemService, lineItemRepository: MockRepository(), + newTotalsService: newTotalsServiceMock, shippingOptionService, eventBusService, lineItemAdjustmentService: LineItemAdjustmentServiceMock, @@ -794,6 +801,7 @@ describe("CartService", () => { totalsService, eventBusService, taxProviderService: taxProviderServiceMock, + newTotalsService: newTotalsServiceMock, featureFlagRouter: new FlagRouter({}), }) @@ -880,6 +888,7 @@ describe("CartService", () => { cartRepository, lineItemService, eventBusService, + newTotalsService: newTotalsServiceMock, inventoryService, lineItemAdjustmentService: LineItemAdjustmentServiceMock, taxProviderService: taxProviderServiceMock, @@ -965,6 +974,7 @@ describe("CartService", () => { cartRepository, eventBusService, customerService, + newTotalsService: newTotalsServiceMock, taxProviderService: taxProviderServiceMock, featureFlagRouter: new FlagRouter({}), }) @@ -1041,6 +1051,7 @@ describe("CartService", () => { cartRepository, addressRepository, eventBusService, + newTotalsService: newTotalsServiceMock, taxProviderService: taxProviderServiceMock, featureFlagRouter: new FlagRouter({}), }) @@ -1101,6 +1112,7 @@ describe("CartService", () => { totalsService, cartRepository, eventBusService, + newTotalsService: newTotalsServiceMock, taxProviderService: taxProviderServiceMock, featureFlagRouter: new FlagRouter({}), }) @@ -1247,6 +1259,7 @@ describe("CartService", () => { addressRepository, totalsService, cartRepository, + newTotalsService: newTotalsServiceMock, regionService, lineItemService, lineItemAdjustmentService: LineItemAdjustmentServiceMock, @@ -1343,6 +1356,7 @@ describe("CartService", () => { cartRepository, eventBusService, taxProviderService: taxProviderServiceMock, + newTotalsService: newTotalsServiceMock, featureFlagRouter: new FlagRouter({}), }) @@ -1467,6 +1481,7 @@ describe("CartService", () => { paymentProviderService, eventBusService, taxProviderService: taxProviderServiceMock, + newTotalsService: newTotalsServiceMock, featureFlagRouter: new FlagRouter({}), }) @@ -1658,6 +1673,7 @@ describe("CartService", () => { lineItemService, eventBusService, customShippingOptionService, + newTotalsService: newTotalsServiceMock, taxProviderService: taxProviderServiceMock, featureFlagRouter: new FlagRouter({}), }) @@ -2015,6 +2031,7 @@ describe("CartService", () => { eventBusService, lineItemAdjustmentService: LineItemAdjustmentServiceMock, taxProviderService: taxProviderServiceMock, + newTotalsService: newTotalsServiceMock, featureFlagRouter: new FlagRouter({}), }) @@ -2289,6 +2306,7 @@ describe("CartService", () => { cartRepository, eventBusService, taxProviderService: taxProviderServiceMock, + newTotalsService: newTotalsServiceMock, featureFlagRouter: new FlagRouter({}), }) diff --git a/packages/medusa/src/services/__tests__/new-totals.ts b/packages/medusa/src/services/__tests__/new-totals.ts new file mode 100644 index 0000000000000..7826708b34039 --- /dev/null +++ b/packages/medusa/src/services/__tests__/new-totals.ts @@ -0,0 +1,1093 @@ +import { asClass, asValue, createContainer } from "awilix" +import { + defaultContainerMock, + giftCards, + lineItems, + shippingMethods, +} from "../__fixtures__/new-totals" +import { NewTotalsService } from "../index" +import { TaxCalculationContext } from "../../interfaces" +import { taxProviderServiceMock } from "../__mocks__/tax-provider" +import { + Discount, + DiscountRuleType, + LineItem, + Region, + ShippingMethod, +} from "../../models" +import { FlagRouter } from "../../utils/flag-router" +import TaxInclusivePricingFeatureFlag from "../../loaders/feature-flags/tax-inclusive-pricing" + +describe("New totals service", () => { + describe("Without [MEDUSA_FF_TAX_INCLUSIVE_PRICING]", () => { + describe("getLineItemTotals", () => { + let container + let newTotalsService: NewTotalsService + + beforeEach(() => { + container = createContainer({}, defaultContainerMock) + container.register( + "taxProviderService", + asValue({ + ...taxProviderServiceMock, + getTaxLinesMap: jest + .fn() + .mockImplementation(async (items: LineItem[]) => { + const result = { + lineItemsTaxLines: {}, + } + + for (const item of items) { + result.lineItemsTaxLines[item.id] = [ + { + item_id: item.id, + name: "default", + code: "default", + rate: 30, + }, + ] + } + + return result + }), + }) + ) + container.register("newTotalsService", asClass(NewTotalsService)) + newTotalsService = container.resolve("newTotalsService") + }) + + afterEach(() => { + jest.clearAllMocks() + }) + + it("should use the items tax lines to compute the totals", async () => { + const testItem = lineItems[0] + + const calculationContext = { + allocation_map: { + [testItem.id]: {}, + }, + shipping_methods: [], + } as unknown as TaxCalculationContext + + const itemsTotalsMap = await newTotalsService.getLineItemTotals( + [testItem], + { + includeTax: true, + calculationContext, + } + ) + + const taxProviderService = container.resolve("taxProviderService") + expect(taxProviderService.getTaxLinesMap).not.toHaveBeenCalled() + + // unit_price: 1000, taxes 20% + expect(itemsTotalsMap[testItem.id]).toEqual( + expect.objectContaining({ + unit_price: 1000, + subtotal: 1000, + total: 1200, + original_total: 1200, + discount_total: 0, + original_tax_total: 200, + tax_total: 200, + tax_lines: expect.arrayContaining(testItem.tax_lines), + }) + ) + }) + + it("should fetch the items tax lines to compute the totals", async () => { + const testItem = { ...lineItems[0] } as LineItem + testItem.tax_lines = [] + + const calculationContext = { + allocation_map: { + [testItem.id]: {}, + }, + shipping_methods: [], + } as unknown as TaxCalculationContext + + const itemsTotalsMap = await newTotalsService.getLineItemTotals( + [testItem], + { + includeTax: true, + calculationContext, + } + ) + + const taxProviderService = container.resolve("taxProviderService") + expect(taxProviderService.getTaxLinesMap).toHaveBeenCalledTimes(1) + expect(taxProviderService.getTaxLinesMap).toHaveBeenCalledWith( + [testItem], + calculationContext + ) + + // unit_price: 1000, taxes 30% + expect(itemsTotalsMap[testItem.id]).toEqual( + expect.objectContaining({ + unit_price: 1000, + subtotal: 1000, + total: 1300, + original_total: 1300, + discount_total: 0, + original_tax_total: 300, + tax_total: 300, + tax_lines: expect.arrayContaining([ + expect.objectContaining({ + name: "default", + code: "default", + rate: 30, + }), + ]), + }) + ) + }) + + it("should not use tax lines when includeTax is not true to compute the totals", async () => { + const testItem = { ...lineItems[0] } as LineItem + testItem.tax_lines = [] + + const calculationContext = { + allocation_map: { + [testItem.id]: {}, + }, + shipping_methods: [], + } as unknown as TaxCalculationContext + + const itemsTotalsMap = await newTotalsService.getLineItemTotals( + [testItem], + { + includeTax: false, + calculationContext, + } + ) + + const taxProviderService = container.resolve("taxProviderService") + expect(taxProviderService.getTaxLinesMap).not.toHaveBeenCalled() + + // unit_price: 1000 + expect(itemsTotalsMap[testItem.id]).toEqual( + expect.objectContaining({ + unit_price: 1000, + subtotal: 1000, + total: 1000, + original_total: 1000, + discount_total: 0, + original_tax_total: 0, + tax_total: 0, + tax_lines: expect.arrayContaining([]), + }) + ) + }) + + it("should use the provided tax rate to compute the totals", async () => { + const testItem = lineItems[0] + + const calculationContext = { + allocation_map: { + [testItem.id]: {}, + }, + shipping_methods: [], + } as unknown as TaxCalculationContext + + const itemsTotalsMap = await newTotalsService.getLineItemTotals( + [testItem], + { + taxRate: 20, + calculationContext, + } + ) + + const taxProviderService = container.resolve("taxProviderService") + expect(taxProviderService.getTaxLinesMap).not.toHaveBeenCalled() + + // unit_price: 1000, taxes 20% + expect(itemsTotalsMap[testItem.id]).toEqual( + expect.objectContaining({ + unit_price: 1000, + subtotal: 1000, + total: 1200, + original_total: 1200, + discount_total: 0, + original_tax_total: 200, + tax_total: 200, + tax_lines: expect.arrayContaining([]), + }) + ) + }) + }) + + describe("getShippingMethodTotals", () => { + let container + let newTotalsService: NewTotalsService + + beforeEach(() => { + container = createContainer({}, defaultContainerMock) + container.register( + "taxProviderService", + asValue({ + ...taxProviderServiceMock, + getTaxLinesMap: jest + .fn() + .mockImplementation( + async ( + items: LineItem[], + calculationContext: TaxCalculationContext + ) => { + const result = { + shippingMethodsTaxLines: {}, + } + + for (const method of calculationContext.shipping_methods) { + result.shippingMethodsTaxLines[method.id] = [ + { + shipping_method_id: method.id, + name: "default", + code: "default", + rate: 30, + }, + ] + } + + return result + } + ), + }) + ) + container.register("newTotalsService", asClass(NewTotalsService)) + newTotalsService = container.resolve("newTotalsService") + }) + + afterEach(() => { + jest.clearAllMocks() + }) + + it("should use the shipping method tax lines to compute the totals", async () => { + const testShippingMethod = shippingMethods[0] + + const calculationContext = { + allocation_map: {}, + shipping_methods: [testShippingMethod], + } as unknown as TaxCalculationContext + + const shippingMethodTotalsMap = + await newTotalsService.getShippingMethodTotals([testShippingMethod], { + includeTax: true, + calculationContext, + }) + + const taxProviderService = container.resolve("taxProviderService") + expect(taxProviderService.getTaxLinesMap).not.toHaveBeenCalled() + + // price: 1000, taxes: 20% + expect(shippingMethodTotalsMap[testShippingMethod.id]).toEqual( + expect.objectContaining({ + price: 1000, + subtotal: 1000, + total: 1200, + original_total: 1200, + original_tax_total: 200, + tax_total: 200, + tax_lines: expect.arrayContaining(testShippingMethod.tax_lines), + }) + ) + }) + + it("should fetch the shipping method tax lines to compute the totals", async () => { + const testShippingMethod = { ...shippingMethods[0] } as ShippingMethod + testShippingMethod.tax_lines = [] + + const calculationContext = { + allocation_map: {}, + shipping_methods: [testShippingMethod], + } as unknown as TaxCalculationContext + + const shippingMethodTotalsMap = + await newTotalsService.getShippingMethodTotals([testShippingMethod], { + includeTax: true, + calculationContext, + }) + + const taxProviderService = container.resolve("taxProviderService") + expect(taxProviderService.getTaxLinesMap).toHaveBeenCalledTimes(1) + expect(taxProviderService.getTaxLinesMap).toHaveBeenCalledWith( + [], + calculationContext + ) + + // price: 1000, taxes 30% + expect(shippingMethodTotalsMap[testShippingMethod.id]).toEqual( + expect.objectContaining({ + price: 1000, + subtotal: 1000, + total: 1300, + original_total: 1300, + original_tax_total: 300, + tax_total: 300, + tax_lines: expect.arrayContaining([ + expect.objectContaining({ + name: "default", + code: "default", + rate: 30, + }), + ]), + }) + ) + }) + + it("should not use tax lines when includeTax is not true to compute the totals", async () => { + const testShippingMethod = { ...shippingMethods[0] } as ShippingMethod + testShippingMethod.tax_lines = [] + + const calculationContext = { + allocation_map: {}, + shipping_methods: [testShippingMethod], + } as unknown as TaxCalculationContext + + const shippingMethodTotalsMap = + await newTotalsService.getShippingMethodTotals([testShippingMethod], { + includeTax: false, + calculationContext, + }) + + const taxProviderService = container.resolve("taxProviderService") + expect(taxProviderService.getTaxLinesMap).not.toHaveBeenCalled() + + // price: 1000 + expect(shippingMethodTotalsMap[testShippingMethod.id]).toEqual( + expect.objectContaining({ + price: 1000, + subtotal: 1000, + total: 1000, + original_total: 1000, + original_tax_total: 0, + tax_total: 0, + tax_lines: expect.arrayContaining([]), + }) + ) + }) + + it("should use the provided tax rate to compute the totals", async () => { + const testShippingMethod = shippingMethods[0] + + const calculationContext = { + allocation_map: {}, + shipping_methods: [testShippingMethod], + } as unknown as TaxCalculationContext + + const shippingMethodTotalsMap = + await newTotalsService.getShippingMethodTotals([testShippingMethod], { + taxRate: 20, + calculationContext, + }) + + const taxProviderService = container.resolve("taxProviderService") + expect(taxProviderService.getTaxLinesMap).not.toHaveBeenCalled() + + // unit_price: 1000, taxes 20% + expect(shippingMethodTotalsMap[testShippingMethod.id]).toEqual( + expect.objectContaining({ + price: 1000, + subtotal: 1000, + total: 1000, // Legacy does not include the taxes + original_total: 1000, // Legacy does not include the taxes + original_tax_total: 200, + tax_total: 200, + tax_lines: expect.arrayContaining([]), + }) + ) + }) + + it("should compute a total to 0 if a free shipping discount is present", async () => { + const testShippingMethod = shippingMethods[0] + + const discounts = [ + { + rule: { + type: DiscountRuleType.FREE_SHIPPING, + }, + }, + ] as Discount[] + + const calculationContext = { + allocation_map: {}, + shipping_methods: [testShippingMethod], + } as unknown as TaxCalculationContext + + const shippingMethodTotalsMap = + await newTotalsService.getShippingMethodTotals([testShippingMethod], { + includeTax: true, + calculationContext, + discounts, + }) + + const taxProviderService = container.resolve("taxProviderService") + expect(taxProviderService.getTaxLinesMap).not.toHaveBeenCalled() + + // unit_price: 1000, taxes 20% + expect(shippingMethodTotalsMap[testShippingMethod.id]).toEqual( + expect.objectContaining({ + price: 1000, + subtotal: 0, + total: 0, + original_total: 1200, + original_tax_total: 200, + tax_total: 0, + tax_lines: expect.arrayContaining(testShippingMethod.tax_lines), + }) + ) + }) + }) + + describe("getLineItemRefund", () => { + let container + let newTotalsService: NewTotalsService + + beforeEach(() => { + container = createContainer({}, defaultContainerMock) + container.register("newTotalsService", asClass(NewTotalsService)) + newTotalsService = container.resolve("newTotalsService") + }) + + afterEach(() => { + jest.clearAllMocks() + }) + + it("should compute the line item refundable amount", () => { + const testItem = lineItems[0] + + const calculationContext = { + allocation_map: {}, + shipping_methods: [], + } as unknown as TaxCalculationContext + + const refundAmount = newTotalsService.getLineItemRefund(testItem, { + calculationContext, + }) + + // unit_price: 1000, taxes: 20% + expect(refundAmount).toEqual(1200) + }) + + it("should compute the line item refundable amount using the taxRate", () => { + const testItem = lineItems[0] + + const calculationContext = { + allocation_map: {}, + shipping_methods: [], + } as unknown as TaxCalculationContext + + const refundAmount = newTotalsService.getLineItemRefund(testItem, { + taxRate: 30, + calculationContext, + }) + + // unit_price: 1000, taxes: 30% + expect(refundAmount).toEqual(1300) + }) + }) + + describe("getGiftCardTotals", () => { + let container + let newTotalsService: NewTotalsService + + beforeEach(() => { + container = createContainer({}, defaultContainerMock) + container.register("newTotalsService", asClass(NewTotalsService)) + newTotalsService = container.resolve("newTotalsService") + }) + + afterEach(() => { + jest.clearAllMocks() + }) + + it("should compute the gift cards totals amount in non taxable region", async () => { + const maxAmount = 1000 + + const testGiftCard = giftCards[0] + + const region = { + gift_cards_taxable: false, + } as Region + + const gitCardTotals = await newTotalsService.getGiftCardTotals( + maxAmount, + { + giftCards: [testGiftCard], + region, + } + ) + + expect(gitCardTotals).toEqual( + expect.objectContaining({ + total: 1000, + tax_total: 0, + }) + ) + }) + + it("should compute the gift cards totals amount in a taxable region", async () => { + const maxAmount = 1000 + + const testGiftCard = giftCards[0] + + const region = { + gift_cards_taxable: true, + tax_rate: 20, + } as Region + + const gitCardTotals = await newTotalsService.getGiftCardTotals( + maxAmount, + { + giftCards: [testGiftCard], + region, + } + ) + + expect(gitCardTotals).toEqual( + expect.objectContaining({ + total: 1000, + tax_total: 200, + }) + ) + }) + + it("should compute the gift cards totals amount in non taxable region using gift card transactions", async () => { + const maxAmount = 1000 + + const giftCardTransactions = [ + { + tax_rate: 20, + is_taxable: false, + amount: 1000, + }, + ] + + const region = { + gift_cards_taxable: false, + } as Region + + const gitCardTotals = await newTotalsService.getGiftCardTotals( + maxAmount, + { + giftCardTransactions: giftCardTransactions, + region, + } + ) + + expect(gitCardTotals).toEqual( + expect.objectContaining({ + total: 1000, + tax_total: 200, + }) + ) + }) + + it("should compute the gift cards totals amount in a taxable region using gift card transactions", async () => { + const maxAmount = 1000 + + const giftCardTransactions = [ + { + tax_rate: 20, + is_taxable: null, + amount: 1000, + }, + ] + + const region = { + gift_cards_taxable: true, + tax_rate: 30, + } as Region + + const gitCardTotals = await newTotalsService.getGiftCardTotals( + maxAmount, + { + giftCardTransactions: giftCardTransactions, + region, + } + ) + + expect(gitCardTotals).toEqual( + expect.objectContaining({ + total: 1000, + tax_total: 300, + }) + ) + }) + }) + }) + + describe("With [MEDUSA_FF_TAX_INCLUSIVE_PRICING]", () => { + describe("getLineItemTotals", () => { + let container + let newTotalsService: NewTotalsService + + beforeEach(() => { + container = createContainer({}, defaultContainerMock) + container.register( + "featureFlagRouter", + asValue( + new FlagRouter({ + [TaxInclusivePricingFeatureFlag.key]: true, + }) + ) + ) + container.register( + "taxProviderService", + asValue({ + ...taxProviderServiceMock, + getTaxLinesMap: jest + .fn() + .mockImplementation(async (items: LineItem[]) => { + const result = { + lineItemsTaxLines: {}, + } + + for (const item of items) { + result.lineItemsTaxLines[item.id] = [ + { + item_id: item.id, + name: "default", + code: "default", + rate: 30, + }, + ] + } + + return result + }), + }) + ) + container.register("newTotalsService", asClass(NewTotalsService)) + newTotalsService = container.resolve("newTotalsService") + }) + + afterEach(() => { + jest.clearAllMocks() + }) + + it("should use the items tax lines to compute the totals", async () => { + const testItem = { ...lineItems[0] } as LineItem + testItem.includes_tax = true + + const calculationContext = { + allocation_map: { + [testItem.id]: {}, + }, + shipping_methods: [], + } as unknown as TaxCalculationContext + + const itemsTotalsMap = await newTotalsService.getLineItemTotals( + [testItem], + { + includeTax: true, + calculationContext, + } + ) + + const taxProviderService = container.resolve("taxProviderService") + expect(taxProviderService.getTaxLinesMap).not.toHaveBeenCalled() + + // unit_price: 1000 including taxes, taxes 20% + expect(itemsTotalsMap[testItem.id]).toEqual( + expect.objectContaining({ + unit_price: 1000, + subtotal: 833, + total: 1000, + original_total: 1000, + discount_total: 0, + original_tax_total: 167, + tax_total: 167, + tax_lines: expect.arrayContaining(testItem.tax_lines), + }) + ) + }) + + it("should fetch the tax lines to compute the totals", async () => { + const testItem = { ...lineItems[0] } as LineItem + testItem.tax_lines = [] + testItem.includes_tax = true + + const calculationContext = { + allocation_map: { + [testItem.id]: {}, + }, + shipping_methods: [], + } as unknown as TaxCalculationContext + + const itemsTotalsMap = await newTotalsService.getLineItemTotals( + [testItem], + { + includeTax: true, + calculationContext, + } + ) + + const taxProviderService = container.resolve("taxProviderService") + expect(taxProviderService.getTaxLinesMap).toHaveBeenCalledTimes(1) + expect(taxProviderService.getTaxLinesMap).toHaveBeenCalledWith( + [testItem], + calculationContext + ) + + // unit_price: 1000 including taxes, taxes 30% + expect(itemsTotalsMap[testItem.id]).toEqual( + expect.objectContaining({ + unit_price: 1000, + subtotal: 769, + total: 1000, + original_total: 1000, + discount_total: 0, + original_tax_total: 231, + tax_total: 231, + tax_lines: expect.arrayContaining([ + expect.objectContaining({ + name: "default", + code: "default", + rate: 30, + }), + ]), + }) + ) + }) + + it("should not use tax lines when includeTax is not true to compute the totals", async () => { + const testItem = { ...lineItems[0] } as LineItem + testItem.includes_tax = true + + const calculationContext = { + allocation_map: { + [testItem.id]: {}, + }, + shipping_methods: [], + } as unknown as TaxCalculationContext + + const itemsTotalsMap = await newTotalsService.getLineItemTotals( + [testItem], + { + includeTax: false, + calculationContext, + } + ) + + const taxProviderService = container.resolve("taxProviderService") + expect(taxProviderService.getTaxLinesMap).not.toHaveBeenCalled() + + // unit_price: 1000 including taxes + expect(itemsTotalsMap[testItem.id]).toEqual( + expect.objectContaining({ + unit_price: 1000, + subtotal: 833, + total: 1000, + original_total: 1000, + discount_total: 0, + original_tax_total: 167, + tax_total: 167, + tax_lines: expect.arrayContaining([]), + }) + ) + }) + + it("should use the provided tax rate to compute the totals", async () => { + const testItem = lineItems[0] + testItem.includes_tax = true + + const calculationContext = { + allocation_map: { + [testItem.id]: {}, + }, + shipping_methods: [], + } as unknown as TaxCalculationContext + + const itemsTotalsMap = await newTotalsService.getLineItemTotals( + [testItem], + { + taxRate: 20, + calculationContext, + } + ) + + const taxProviderService = container.resolve("taxProviderService") + expect(taxProviderService.getTaxLinesMap).not.toHaveBeenCalled() + + // unit_price: 1000 including taxes, taxes 20% + expect(itemsTotalsMap[testItem.id]).toEqual( + expect.objectContaining({ + unit_price: 1000, + subtotal: 833, + total: 1000, + original_total: 1000, + discount_total: 0, + original_tax_total: 167, + tax_total: 167, + tax_lines: expect.arrayContaining([]), + }) + ) + }) + }) + + describe("getShippingMethodTotals", () => { + let container + let newTotalsService: NewTotalsService + + beforeEach(() => { + container = createContainer({}, defaultContainerMock) + container.register( + "featureFlagRouter", + asValue( + new FlagRouter({ + [TaxInclusivePricingFeatureFlag.key]: true, + }) + ) + ) + container.register( + "taxProviderService", + asValue({ + ...taxProviderServiceMock, + getTaxLinesMap: jest + .fn() + .mockImplementation( + async ( + items: LineItem[], + calculationContext: TaxCalculationContext + ) => { + const result = { + shippingMethodsTaxLines: {}, + } + + for (const method of calculationContext.shipping_methods) { + result.shippingMethodsTaxLines[method.id] = [ + { + shipping_method_id: method.id, + name: "default", + code: "default", + rate: 30, + }, + ] + } + + return result + } + ), + }) + ) + container.register("newTotalsService", asClass(NewTotalsService)) + newTotalsService = container.resolve("newTotalsService") + }) + + afterEach(() => { + jest.clearAllMocks() + }) + + it("should use the shipping method tax lines to compute the totals", async () => { + const testShippingMethod = shippingMethods[0] + testShippingMethod.includes_tax = true + + const calculationContext = { + allocation_map: {}, + shipping_methods: [testShippingMethod], + } as unknown as TaxCalculationContext + + const shippingMethodTotalsMap = + await newTotalsService.getShippingMethodTotals([testShippingMethod], { + includeTax: true, + calculationContext, + }) + + const taxProviderService = container.resolve("taxProviderService") + expect(taxProviderService.getTaxLinesMap).not.toHaveBeenCalled() + + // price: 1000 including taxes, taxes: 20% + expect(shippingMethodTotalsMap[testShippingMethod.id]).toEqual( + expect.objectContaining({ + price: 1000, + subtotal: 833, + total: 1000, + original_total: 1000, + original_tax_total: 167, + tax_total: 167, + tax_lines: expect.arrayContaining(testShippingMethod.tax_lines), + }) + ) + }) + + it("should fetch shipping method tax lines to compute the totals", async () => { + const testShippingMethod = { ...shippingMethods[0] } as ShippingMethod + testShippingMethod.tax_lines = [] + testShippingMethod.includes_tax = true + + const calculationContext = { + allocation_map: {}, + shipping_methods: [testShippingMethod], + } as unknown as TaxCalculationContext + + const shippingMethodTotalsMap = + await newTotalsService.getShippingMethodTotals([testShippingMethod], { + includeTax: true, + calculationContext, + }) + + const taxProviderService = container.resolve("taxProviderService") + expect(taxProviderService.getTaxLinesMap).toHaveBeenCalledTimes(1) + expect(taxProviderService.getTaxLinesMap).toHaveBeenCalledWith( + [], + calculationContext + ) + + // price: 1000 including taxes, taxes 30% + expect(shippingMethodTotalsMap[testShippingMethod.id]).toEqual( + expect.objectContaining({ + price: 1000, + subtotal: 769, + total: 1000, + original_total: 1000, + original_tax_total: 231, + tax_total: 231, + tax_lines: expect.arrayContaining([ + expect.objectContaining({ + name: "default", + code: "default", + rate: 30, + }), + ]), + }) + ) + }) + + it("should not use tax lines when includeTax is not true to compute the totals", async () => { + const testShippingMethod = { ...shippingMethods[0] } as ShippingMethod + testShippingMethod.tax_lines = [] + testShippingMethod.includes_tax = true + + const calculationContext = { + allocation_map: {}, + shipping_methods: [testShippingMethod], + } as unknown as TaxCalculationContext + + const shippingMethodTotalsMap = + await newTotalsService.getShippingMethodTotals([testShippingMethod], { + includeTax: false, + calculationContext, + }) + + const taxProviderService = container.resolve("taxProviderService") + expect(taxProviderService.getTaxLinesMap).not.toHaveBeenCalled() + + // price: 1000 including taxes + expect(shippingMethodTotalsMap[testShippingMethod.id]).toEqual( + expect.objectContaining({ + price: 1000, + subtotal: 1000, + total: 1000, + original_total: 1000, + original_tax_total: 0, + tax_total: 0, + tax_lines: expect.arrayContaining([]), + }) + ) + }) + + // Not applicable to legacy shipping method totals calculation + /*it("should use the provided tax rate to compute the totals", async () => {})*/ + + it("should compute a total to 0 if a free shipping discount is present", async () => { + const testShippingMethod = shippingMethods[0] + testShippingMethod.includes_tax = true + + const discounts = [ + { + rule: { + type: DiscountRuleType.FREE_SHIPPING, + }, + }, + ] as Discount[] + + const calculationContext = { + allocation_map: {}, + shipping_methods: [testShippingMethod], + } as unknown as TaxCalculationContext + + const shippingMethodTotalsMap = + await newTotalsService.getShippingMethodTotals([testShippingMethod], { + includeTax: true, + calculationContext, + discounts, + }) + + const taxProviderService = container.resolve("taxProviderService") + expect(taxProviderService.getTaxLinesMap).not.toHaveBeenCalled() + + // unit_price: 1000 including taxes, taxes 20% + expect(shippingMethodTotalsMap[testShippingMethod.id]).toEqual( + expect.objectContaining({ + price: 1000, + subtotal: 0, + total: 0, + original_total: 1000, + original_tax_total: 167, + tax_total: 0, + tax_lines: expect.arrayContaining(testShippingMethod.tax_lines), + }) + ) + }) + }) + + describe("getLineItemRefund", () => { + let container + let newTotalsService: NewTotalsService + + beforeEach(() => { + container = createContainer({}, defaultContainerMock) + container.register( + "featureFlagRouter", + asValue( + new FlagRouter({ + [TaxInclusivePricingFeatureFlag.key]: true, + }) + ) + ) + container.register("newTotalsService", asClass(NewTotalsService)) + newTotalsService = container.resolve("newTotalsService") + }) + + afterEach(() => { + jest.clearAllMocks() + }) + + it("should compute the line item refundable amount", () => { + const testItem = lineItems[0] + testItem.includes_tax = true + + const calculationContext = { + allocation_map: {}, + shipping_methods: [], + } as unknown as TaxCalculationContext + + const refundAmount = newTotalsService.getLineItemRefund(testItem, { + calculationContext, + }) + + // unit_price: 1000 including taxes, taxes: 20% + expect(refundAmount).toEqual(1000) + }) + + it("should compute the line item refundable amount using the taxRate", () => { + const testItem = lineItems[0] + testItem.includes_tax = true + + const calculationContext = { + allocation_map: {}, + shipping_methods: [], + } as unknown as TaxCalculationContext + + const refundAmount = newTotalsService.getLineItemRefund(testItem, { + taxRate: 30, + calculationContext, + }) + + // unit_price: 1000 including taxes, taxes: 30% + expect(refundAmount).toEqual(1000) + }) + }) + }) +}) diff --git a/packages/medusa/src/services/__tests__/order.js b/packages/medusa/src/services/__tests__/order.js index cbd5c8915f3e1..f3aa48f1db077 100644 --- a/packages/medusa/src/services/__tests__/order.js +++ b/packages/medusa/src/services/__tests__/order.js @@ -2,6 +2,8 @@ import { IdMap, MockManager, MockRepository } from "medusa-test-utils" import OrderService from "../order" import { InventoryServiceMock } from "../__mocks__/inventory" import { LineItemServiceMock } from "../__mocks__/line-item" +import { newTotalsServiceMock } from "../__mocks__/new-totals" +import { taxProviderServiceMock } from "../__mocks__/tax-provider" describe("OrderService", () => { const totalsService = { @@ -141,6 +143,7 @@ describe("OrderService", () => { paymentProviderService, shippingOptionService, totalsService, + newTotalsService: newTotalsServiceMock, discountService, regionService, eventBusService, @@ -184,6 +187,8 @@ describe("OrderService", () => { { id: "item_2", variant_id: "variant-2", quantity: 1 }, ], total: 100, + subtotal: 100, + discount_total: 0, } orderService.cartService_.retrieveWithTotals = jest.fn(() => @@ -209,17 +214,7 @@ describe("OrderService", () => { expect(cartService.retrieveWithTotals).toHaveBeenCalledTimes(1) expect(cartService.retrieveWithTotals).toHaveBeenCalledWith("cart_id", { - relations: [ - "region", - "payment", - "items", - "discounts", - "discounts.rule", - "gift_cards", - "shipping_methods", - "items", - "items.adjustments", - ], + relations: ["region", "payment"], }) expect(paymentProviderService.updatePayment).toHaveBeenCalledTimes(1) @@ -288,6 +283,7 @@ describe("OrderService", () => { ], subtotal: 100, total: 100, + discount_total: 0, } orderService.cartService_.retrieveWithTotals = () => { @@ -380,9 +376,11 @@ describe("OrderService", () => { { id: "item_2", variant_id: "variant-2", quantity: 1 }, ], total: 0, + subtotal: 0, + discount_total: 0, } orderService.cartService_.retrieveWithTotals = () => Promise.resolve(cart) - await orderService.createFromCart(cart) + await orderService.createFromCart("cart_id") const order = { payment_status: "awaiting", email: cart.email, @@ -462,6 +460,7 @@ describe("OrderService", () => { manager: MockManager, orderRepository: orderRepo, totalsService, + newTotalsService: newTotalsServiceMock, }) beforeAll(async () => { @@ -485,6 +484,7 @@ describe("OrderService", () => { }) const orderService = new OrderService({ totalsService, + newTotalsService: newTotalsServiceMock, manager: MockManager, orderRepository: orderRepo, }) @@ -527,6 +527,7 @@ describe("OrderService", () => { }) const orderService = new OrderService({ totalsService, + newTotalsService: newTotalsServiceMock, manager: MockManager, orderRepository: orderRepo, eventBusService, @@ -638,6 +639,7 @@ describe("OrderService", () => { const orderService = new OrderService({ totalsService, + newTotalsService: newTotalsServiceMock, manager: MockManager, orderRepository: orderRepo, paymentProviderService, @@ -738,6 +740,7 @@ describe("OrderService", () => { orderRepository: orderRepo, paymentProviderService, totalsService, + newTotalsService: newTotalsServiceMock, eventBusService, }) @@ -857,6 +860,7 @@ describe("OrderService", () => { fulfillmentService, lineItemService, totalsService, + newTotalsService: newTotalsServiceMock, eventBusService, }) @@ -1092,6 +1096,7 @@ describe("OrderService", () => { orderRepository: orderRepo, paymentProviderService, totalsService, + newTotalsService: newTotalsServiceMock, eventBusService, }) @@ -1234,6 +1239,8 @@ describe("OrderService", () => { eventBusService: eventBusService, shippingOptionService: optionService, totalsService, + taxProviderService: taxProviderServiceMock, + newTotalsService: newTotalsServiceMock, }) beforeEach(async () => { @@ -1254,8 +1261,14 @@ describe("OrderService", () => { { some: "data" }, { order: { + discount_total: 0, + gift_card_tax_total: 0, + gift_card_total: 0, id: IdMap.getId("order"), items: [], + paid_total: 0, + refundable_amount: 0, + refunded_total: 0, shipping_methods: [ { shipping_option: { @@ -1263,7 +1276,10 @@ describe("OrderService", () => { }, }, ], + shipping_total: 0, subtotal: 0, + tax_total: 0, + total: 0, }, } ) @@ -1284,8 +1300,14 @@ describe("OrderService", () => { { some: "data" }, { order: { + discount_total: 0, + gift_card_tax_total: 0, + gift_card_total: 0, id: IdMap.getId("order"), items: [], + paid_total: 0, + refundable_amount: 0, + refunded_total: 0, shipping_methods: [ { shipping_option: { @@ -1293,7 +1315,10 @@ describe("OrderService", () => { }, }, ], + shipping_total: 0, subtotal: 0, + tax_total: 0, + total: 0, }, } ) @@ -1391,6 +1416,7 @@ describe("OrderService", () => { manager: MockManager, orderRepository: orderRepo, totalsService, + newTotalsService: newTotalsServiceMock, fulfillmentService, lineItemService, eventBusService, @@ -1511,6 +1537,7 @@ describe("OrderService", () => { orderRepository: orderRepo, paymentProviderService, totalsService, + newTotalsService: newTotalsServiceMock, eventBusService, }) diff --git a/packages/medusa/src/services/__tests__/swap.ts b/packages/medusa/src/services/__tests__/swap.ts index 356899395b2e6..eeaabc4d1444a 100644 --- a/packages/medusa/src/services/__tests__/swap.ts +++ b/packages/medusa/src/services/__tests__/swap.ts @@ -1,4 +1,4 @@ -import { IdMap, MockRepository, MockManager } from "medusa-test-utils" +import { IdMap, MockManager, MockRepository } from "medusa-test-utils" import SwapService from "../swap" import { InventoryServiceMock } from "../__mocks__/inventory" @@ -16,7 +16,7 @@ import { TotalsService, } from "../index" import CartService from "../cart" -import { Order, ReturnItem, Swap } from "../../models" +import { Order, Swap } from "../../models" import { SwapRepository } from "../../repositories/swap" import LineItemAdjustmentService from "../line-item-adjustment" @@ -49,6 +49,9 @@ const cartService = { withTransaction: function () { return this }, + retrieveWithTotals: jest + .fn() + .mockReturnValue(Promise.resolve({ id: "cart" })), } as unknown as CartService const customShippingOptionService = { @@ -826,6 +829,12 @@ describe("SwapService", () => { withTransaction: function () { return this }, + retrieveWithTotals: jest.fn().mockReturnValue( + Promise.resolve({ + id: "cart", + items: [{ id: "test-item", variant_id: "variant" }], + }) + ), } as unknown as CartService const paymentProviderService = { @@ -864,7 +873,8 @@ describe("SwapService", () => { other: "data", } - cartService.retrieve = (() => cart) as unknown as CartService["retrieve"] + cartService.retrieveWithTotals = (() => + cart) as unknown as CartService["retrieveWithTotals"] const swapRepo = MockRepository({ findOneWithRelations: () => Promise.resolve(existing), diff --git a/packages/medusa/src/services/cart.ts b/packages/medusa/src/services/cart.ts index fc82417f0f2d6..3538da63ac24b 100644 --- a/packages/medusa/src/services/cart.ts +++ b/packages/medusa/src/services/cart.ts @@ -24,6 +24,7 @@ import { CartCreateProps, CartUpdateProps, FilterableCartProps, + isCart, LineItemUpdate, } from "../types/cart" import { AddressPayload, FindConfig, TotalField } from "../types/common" @@ -35,7 +36,7 @@ import CustomerService from "./customer" import DiscountService from "./discount" import EventBusService from "./event-bus" import GiftCardService from "./gift-card" -import { SalesChannelService } from "./index" +import { NewTotalsService, SalesChannelService } from "./index" import InventoryService from "./inventory" import LineItemService from "./line-item" import LineItemAdjustmentService from "./line-item-adjustment" @@ -70,6 +71,7 @@ type InjectedDependencies = { discountService: DiscountService giftCardService: GiftCardService totalsService: TotalsService + newTotalsService: NewTotalsService inventoryService: InventoryService customShippingOptionService: CustomShippingOptionService lineItemAdjustmentService: LineItemAdjustmentService @@ -112,6 +114,7 @@ class CartService extends TransactionBaseService { protected readonly giftCardService_: GiftCardService protected readonly taxProviderService_: TaxProviderService protected readonly totalsService_: TotalsService + protected readonly newTotalsService_: NewTotalsService protected readonly inventoryService_: InventoryService protected readonly customShippingOptionService_: CustomShippingOptionService protected readonly priceSelectionStrategy_: IPriceSelectionStrategy @@ -135,6 +138,7 @@ class CartService extends TransactionBaseService { discountService, giftCardService, totalsService, + newTotalsService, addressRepository, paymentSessionRepository, inventoryService, @@ -163,6 +167,7 @@ class CartService extends TransactionBaseService { this.discountService_ = discountService this.giftCardService_ = giftCardService this.totalsService_ = totalsService + this.newTotalsService_ = newTotalsService this.addressRepository_ = addressRepository this.paymentSessionRepository_ = paymentSessionRepository this.inventoryService_ = inventoryService @@ -175,126 +180,6 @@ class CartService extends TransactionBaseService { this.storeService_ = storeService } - private getTotalsRelations(config: FindConfig): string[] { - const relationSet = new Set(config.relations) - - relationSet.add("items") - relationSet.add("items.tax_lines") - relationSet.add("items.adjustments") - relationSet.add("gift_cards") - relationSet.add("discounts") - relationSet.add("discounts.rule") - relationSet.add("shipping_methods") - relationSet.add("shipping_methods.tax_lines") - relationSet.add("shipping_address") - relationSet.add("region") - relationSet.add("region.tax_rates") - - return Array.from(relationSet.values()) - } - - protected transformQueryForTotals_( - config: FindConfig - ): FindConfig & { totalsToSelect: TotalField[] } { - let { select, relations } = config - - if (!select) { - return { - select, - relations, - totalsToSelect: [], - } - } - - const totalFields = [ - "subtotal", - "tax_total", - "shipping_total", - "discount_total", - "gift_card_total", - "total", - ] - - const totalsToSelect = select.filter((v) => - totalFields.includes(v) - ) as TotalField[] - if (totalsToSelect.length > 0) { - const relationSet = new Set(relations) - relationSet.add("items") - relationSet.add("items.tax_lines") - relationSet.add("gift_cards") - relationSet.add("discounts") - relationSet.add("discounts.rule") - // relationSet.add("discounts.parent_discount") - // relationSet.add("discounts.parent_discount.rule") - // relationSet.add("discounts.parent_discount.regions") - relationSet.add("shipping_methods") - relationSet.add("shipping_address") - relationSet.add("region") - relationSet.add("region.tax_rates") - relations = Array.from(relationSet.values()) - - select = select.filter((v) => !totalFields.includes(v)) - } - - return { - relations, - select, - totalsToSelect, - } - } - - protected async decorateTotals_( - cart: Cart, - totalsToSelect: TotalField[], - options: TotalsConfig = { force_taxes: false } - ): Promise { - const totals: { [K in TotalField]?: number | null } = {} - - for (const key of totalsToSelect) { - switch (key) { - case "total": { - totals.total = await this.totalsService_.getTotal(cart, { - force_taxes: options.force_taxes, - }) - break - } - case "shipping_total": { - totals.shipping_total = await this.totalsService_.getShippingTotal( - cart - ) - break - } - case "discount_total": - totals.discount_total = await this.totalsService_.getDiscountTotal( - cart - ) - break - case "tax_total": - totals.tax_total = await this.totalsService_.getTaxTotal( - cart, - options.force_taxes - ) - break - case "gift_card_total": { - const giftCardBreakdown = await this.totalsService_.getGiftCardTotal( - cart - ) - totals.gift_card_total = giftCardBreakdown.total - totals.gift_card_tax_total = giftCardBreakdown.tax_total - break - } - case "subtotal": - totals.subtotal = await this.totalsService_.getSubtotal(cart) - break - default: - break - } - } - - return Object.assign(cart, totals) - } - /** * @param selector - the query object for find * @param config - config object @@ -315,6 +200,7 @@ class CartService extends TransactionBaseService { * Gets a cart by id. * @param cartId - the id of the cart to get. * @param options - the options to get a cart + * @param totalsConfig * @return the cart document. */ async retrieve( @@ -322,23 +208,26 @@ class CartService extends TransactionBaseService { options: FindConfig = {}, totalsConfig: TotalsConfig = {} ): Promise { + const { totalsToSelect } = this.transformQueryForTotals_(options) + + if (totalsToSelect.length) { + return await this.retrieveLegacy(cartId, options, totalsConfig) + } + const manager = this.manager_ const cartRepo = manager.getCustomRepository(this.cartRepository_) - const { select, relations, totalsToSelect } = - this.transformQueryForTotals_(options) - - const query = buildQuery({ id: cartId }, { ...options, select, relations }) + const query = buildQuery({ id: cartId }, options) - if (relations && relations.length > 0) { - query.relations = relations + if ((options.select || []).length === 0) { + query.select = undefined } - query.select = select?.length ? select : undefined - const queryRelations = query.relations query.relations = undefined + const raw = await cartRepo.findOneWithRelations(queryRelations, query) + if (!raw) { throw new MedusaError( MedusaError.Types.NOT_FOUND, @@ -346,27 +235,38 @@ class CartService extends TransactionBaseService { ) } - return await this.decorateTotals_(raw, totalsToSelect, totalsConfig) + return raw } - private async retrieveNew( + /** + * @deprecated + * @param cartId + * @param options + * @param totalsConfig + * @protected + */ + protected async retrieveLegacy( cartId: string, - options: FindConfig = {} + options: FindConfig = {}, + totalsConfig: TotalsConfig = {} ): Promise { const manager = this.manager_ const cartRepo = manager.getCustomRepository(this.cartRepository_) - const query = buildQuery({ id: cartId }, options) + const { select, relations, totalsToSelect } = + this.transformQueryForTotals_(options) - if ((options.select || []).length <= 0) { - query.select = undefined + const query = buildQuery({ id: cartId }, { ...options, select, relations }) + + if (relations && relations.length > 0) { + query.relations = relations } + query.select = select?.length ? select : undefined + const queryRelations = query.relations query.relations = undefined - const raw = await cartRepo.findOneWithRelations(queryRelations, query) - if (!raw) { throw new MedusaError( MedusaError.Types.NOT_FOUND, @@ -374,7 +274,7 @@ class CartService extends TransactionBaseService { ) } - return raw + return await this.decorateTotals_(raw, totalsToSelect, totalsConfig) } async retrieveWithTotals( @@ -384,7 +284,7 @@ class CartService extends TransactionBaseService { ): Promise { const relations = this.getTotalsRelations(options) - const cart = await this.retrieveNew(cartId, { + const cart = await this.retrieve(cartId, { ...options, relations, }) @@ -1413,7 +1313,9 @@ class CartService extends TransactionBaseService { */ async authorizePayment( cartId: string, - context: Record = {} + context: Record & { + cart_id: string + } = { cart_id: "" } ): Promise { return await this.atomicPhase_( async (transactionManager: EntityManager) => { @@ -1421,27 +1323,23 @@ class CartService extends TransactionBaseService { this.cartRepository_ ) - const cart = await this.retrieve(cartId, { - select: ["total"], - relations: [ - "items", - "items.adjustments", - "region", - "payment_sessions", - ], + const cart = await this.retrieveWithTotals(cartId, { + relations: ["payment_sessions"], }) - if (typeof cart.total === "undefined") { - throw new MedusaError( - MedusaError.Types.UNEXPECTED_STATE, - "cart.total should be defined" - ) - } - // If cart total is 0, we don't perform anything payment related - if (cart.total <= 0) { + if (cart.total! <= 0) { cart.payment_authorized_at = new Date() - return await cartRepository.save(cart) + await cartRepository.save({ + id: cart.id, + payment_authorized_at: cart.payment_authorized_at, + }) + + await this.eventBus_ + .withTransaction(transactionManager) + .emit(CartService.Events.UPDATED, cart) + + return cart } if (!cart.payment_session) { @@ -1456,21 +1354,27 @@ class CartService extends TransactionBaseService { .authorizePayment(cart.payment_session, context)) as PaymentSession const freshCart = (await this.retrieve(cart.id, { - select: ["total"], - relations: ["payment_sessions", "items", "items.adjustments"], + relations: ["payment_sessions"], })) as Cart & { payment_session: PaymentSession } if (session.status === "authorized") { freshCart.payment = await this.paymentProviderService_ .withTransaction(transactionManager) - .createPayment(freshCart) + .createPayment({ + cart_id: cart.id, + currency_code: cart.region.currency_code, + amount: cart.total!, + payment_session: freshCart.payment_session, + }) freshCart.payment_authorized_at = new Date() } const updatedCart = await cartRepository.save(freshCart) + await this.eventBus_ .withTransaction(transactionManager) .emit(CartService.Events.UPDATED, updatedCart) + return updatedCart } ) @@ -2147,23 +2051,25 @@ class CartService extends TransactionBaseService { ) } - async createTaxLines(id: string): Promise { + async createTaxLines(cartOrId: string | Cart): Promise { return await this.atomicPhase_( async (transactionManager: EntityManager) => { - const cart = await this.retrieve(id, { - relations: [ - "customer", - "discounts", - "discounts.rule", - "gift_cards", - "items", - "items.adjustments", - "region", - "region.tax_rates", - "shipping_address", - "shipping_methods", - ], - }) + const cart = isCart(cartOrId) + ? cartOrId + : await this.retrieve(cartOrId, { + relations: [ + "customer", + "discounts", + "discounts.rule", + "gift_cards", + "items", + "items.adjustments", + "region", + "region.tax_rates", + "shipping_address", + "shipping_methods", + ], + }) const calculationContext = await this.totalsService_ .withTransaction(transactionManager) @@ -2172,8 +2078,6 @@ class CartService extends TransactionBaseService { await this.taxProviderService_ .withTransaction(transactionManager) .createTaxLines(cart, calculationContext) - - return cart } ) } @@ -2197,63 +2101,86 @@ class CartService extends TransactionBaseService { ) } - async decorateTotals(cart: Cart, totalsConfig?: TotalsConfig): Promise { - const totalsService = this.totalsService_ + async decorateTotals( + cart: Cart, + totalsConfig: TotalsConfig = {} + ): Promise { + const manager = this.transactionManager_ ?? this.manager_ + const newTotalsServiceTx = this.newTotalsService_.withTransaction(manager) - const calculationContext = await totalsService.getCalculationContext(cart, { - exclude_shipping: true, - }) + const calculationContext = await this.totalsService_.getCalculationContext( + cart + ) + const includeTax = totalsConfig?.force_taxes || cart.region?.automatic_taxes + const cartItems = [...(cart.items ?? [])] + const cartShippingMethods = [...(cart.shipping_methods ?? [])] - cart.items = await Promise.all( - (cart.items || []).map(async (item) => { - const itemTotals = await totalsService.getLineItemTotals(item, cart, { - include_tax: totalsConfig?.force_taxes || cart.region.automatic_taxes, - calculation_context: calculationContext, - }) + if (includeTax) { + const taxLinesMaps = await this.taxProviderService_ + .withTransaction(manager) + .getTaxLinesMap(cartItems, calculationContext) - return Object.assign(item, itemTotals) + cartItems.forEach((item) => { + if (item.is_return) { + return + } + item.tax_lines = taxLinesMaps.lineItemsTaxLines[item.id] ?? [] + }) + cartShippingMethods.forEach((method) => { + method.tax_lines = taxLinesMaps.shippingMethodsTaxLines[method.id] ?? [] }) + } + + const itemsTotals = await newTotalsServiceTx.getLineItemTotals(cartItems, { + includeTax, + calculationContext, + }) + const shippingTotals = await newTotalsServiceTx.getShippingMethodTotals( + cartShippingMethods, + { + discounts: cart.discounts, + includeTax, + calculationContext, + } ) - cart.shipping_methods = await Promise.all( - (cart.shipping_methods || []).map(async (shippingMethod) => { - const shippingTotals = await totalsService.getShippingMethodTotals( - shippingMethod, - cart, - { - include_tax: - totalsConfig?.force_taxes || cart.region.automatic_taxes, - calculation_context: calculationContext, - } - ) + cart.subtotal = 0 + cart.discount_total = 0 + cart.item_tax_total = 0 + cart.shipping_total = 0 + cart.shipping_tax_total = 0 - return Object.assign(shippingMethod, shippingTotals) - }) - ) + cart.items = (cart.items || []).map((item) => { + const itemWithTotals = Object.assign(item, itemsTotals[item.id] ?? {}) - cart.shipping_total = cart.shipping_methods.reduce((acc, method) => { - return acc + (method.subtotal ?? 0) - }, 0) + cart.subtotal! += itemWithTotals.subtotal ?? 0 + cart.discount_total! += itemWithTotals.discount_total ?? 0 + cart.item_tax_total! += itemWithTotals.tax_total ?? 0 - cart.subtotal = cart.items.reduce((acc, item) => { - return acc + (item.subtotal ?? 0) - }, 0) + return itemWithTotals + }) - cart.discount_total = cart.items.reduce((acc, item) => { - return acc + (item.discount_total ?? 0) - }, 0) + cart.shipping_methods = (cart.shipping_methods || []).map( + (shippingMethod) => { + const methodWithTotals = Object.assign( + shippingMethod, + shippingTotals[shippingMethod.id] ?? {} + ) - cart.item_tax_total = cart.items.reduce((acc, item) => { - return acc + (item.tax_total ?? 0) - }, 0) + cart.shipping_total! += methodWithTotals.subtotal ?? 0 + cart.shipping_tax_total! += methodWithTotals.tax_total ?? 0 - cart.shipping_tax_total = cart.shipping_methods.reduce((acc, method) => { - return acc + (method.tax_total ?? 0) - }, 0) + return methodWithTotals + } + ) - const giftCardTotal = await totalsService.getGiftCardTotal(cart, { - gift_cardable: cart.subtotal - cart.discount_total, - }) + const giftCardTotal = await this.newTotalsService_.getGiftCardTotals( + cart.subtotal - cart.discount_total, + { + region: cart.region, + giftCards: cart.gift_cards, + } + ) cart.gift_card_total = giftCardTotal.total || 0 cart.gift_card_tax_total = giftCardTotal.tax_total || 0 @@ -2287,6 +2214,133 @@ class CartService extends TransactionBaseService { .withTransaction(transactionManager) .createAdjustments(cart) } + + protected transformQueryForTotals_( + config: FindConfig + ): FindConfig & { totalsToSelect: TotalField[] } { + let { select, relations } = config + + if (!select) { + return { + select, + relations, + totalsToSelect: [], + } + } + + const totalFields = [ + "subtotal", + "tax_total", + "shipping_total", + "discount_total", + "gift_card_total", + "total", + ] + + const totalsToSelect = select.filter((v) => + totalFields.includes(v) + ) as TotalField[] + if (totalsToSelect.length > 0) { + const relationSet = new Set(relations) + relationSet.add("items") + relationSet.add("items.tax_lines") + relationSet.add("gift_cards") + relationSet.add("discounts") + relationSet.add("discounts.rule") + // relationSet.add("discounts.parent_discount") + // relationSet.add("discounts.parent_discount.rule") + // relationSet.add("discounts.parent_discount.regions") + relationSet.add("shipping_methods") + relationSet.add("shipping_address") + relationSet.add("region") + relationSet.add("region.tax_rates") + relations = Array.from(relationSet.values()) + + select = select.filter((v) => !totalFields.includes(v)) + } + + return { + relations, + select, + totalsToSelect, + } + } + + /** + * @deprecated Use decorateTotals instead + * @param cart + * @param totalsToSelect + * @param options + * @protected + */ + protected async decorateTotals_( + cart: Cart, + totalsToSelect: TotalField[], + options: TotalsConfig = { force_taxes: false } + ): Promise { + const totals: { [K in TotalField]?: number | null } = {} + + for (const key of totalsToSelect) { + switch (key) { + case "total": { + totals.total = await this.totalsService_.getTotal(cart, { + force_taxes: options.force_taxes, + }) + break + } + case "shipping_total": { + totals.shipping_total = await this.totalsService_.getShippingTotal( + cart + ) + break + } + case "discount_total": + totals.discount_total = await this.totalsService_.getDiscountTotal( + cart + ) + break + case "tax_total": + totals.tax_total = await this.totalsService_.getTaxTotal( + cart, + options.force_taxes + ) + break + case "gift_card_total": { + const giftCardBreakdown = await this.totalsService_.getGiftCardTotal( + cart + ) + totals.gift_card_total = giftCardBreakdown.total + totals.gift_card_tax_total = giftCardBreakdown.tax_total + break + } + case "subtotal": + totals.subtotal = await this.totalsService_.getSubtotal(cart) + break + default: + break + } + } + + return Object.assign(cart, totals) + } + + private getTotalsRelations(config: FindConfig): string[] { + const relationSet = new Set(config.relations) + + relationSet.add("items") + relationSet.add("items.tax_lines") + relationSet.add("items.adjustments") + relationSet.add("gift_cards") + relationSet.add("discounts") + relationSet.add("discounts.rule") + relationSet.add("shipping_methods") + relationSet.add("shipping_methods.tax_lines") + relationSet.add("shipping_address") + relationSet.add("region") + relationSet.add("region.tax_rates") + + return Array.from(relationSet.values()) + } } export default CartService diff --git a/packages/medusa/src/services/index.ts b/packages/medusa/src/services/index.ts index 2d4fbf21376a0..0ec15ed12a8dd 100644 --- a/packages/medusa/src/services/index.ts +++ b/packages/medusa/src/services/index.ts @@ -49,4 +49,5 @@ export { default as SystemPaymentProviderService } from "./system-payment-provid export { default as TaxProviderService } from "./tax-provider" export { default as TaxRateService } from "./tax-rate" export { default as TotalsService } from "./totals" +export { default as NewTotalsService } from "./new-totals" export { default as UserService } from "./user" diff --git a/packages/medusa/src/services/inventory.ts b/packages/medusa/src/services/inventory.ts index 8c58c6b2e61b8..67dea75637e3e 100644 --- a/packages/medusa/src/services/inventory.ts +++ b/packages/medusa/src/services/inventory.ts @@ -71,24 +71,22 @@ class InventoryService extends TransactionBaseService { return true } - return await this.atomicPhase_(async (manager) => { - const variant = await this.productVariantService_ - .withTransaction(manager) - .retrieve(variantId) - const { inventory_quantity, allow_backorder, manage_inventory } = variant - const isCovered = - !manage_inventory || allow_backorder || inventory_quantity >= quantity + const variant = await this.productVariantService_ + .withTransaction(this.manager_) + .retrieve(variantId) + const { inventory_quantity, allow_backorder, manage_inventory } = variant + const isCovered = + !manage_inventory || allow_backorder || inventory_quantity >= quantity - if (!isCovered) { - throw new MedusaError( - MedusaError.Types.NOT_ALLOWED, - `Variant with id: ${variant.id} does not have the required inventory`, - MedusaError.Codes.INSUFFICIENT_INVENTORY - ) - } + if (!isCovered) { + throw new MedusaError( + MedusaError.Types.NOT_ALLOWED, + `Variant with id: ${variant.id} does not have the required inventory`, + MedusaError.Codes.INSUFFICIENT_INVENTORY + ) + } - return isCovered - }) + return isCovered } } diff --git a/packages/medusa/src/services/new-totals.ts b/packages/medusa/src/services/new-totals.ts new file mode 100644 index 0000000000000..845be26908656 --- /dev/null +++ b/packages/medusa/src/services/new-totals.ts @@ -0,0 +1,737 @@ +import { + ITaxCalculationStrategy, + TaxCalculationContext, + TransactionBaseService, +} from "../interfaces" +import { EntityManager } from "typeorm" +import { + Discount, + DiscountRuleType, + GiftCard, + LineItem, + LineItemTaxLine, + Region, + ShippingMethod, + ShippingMethodTaxLine, +} from "../models" +import { TaxProviderService } from "./index" +import { LineAllocationsMap } from "../types/totals" +import TaxInclusivePricingFeatureFlag from "../loaders/feature-flags/tax-inclusive-pricing" +import { FlagRouter } from "../utils/flag-router" +import { calculatePriceTaxAmount, isDefined } from "../utils" +import { MedusaError } from "medusa-core-utils" + +type LineItemTotals = { + unit_price: number + quantity: number + subtotal: number + tax_total: number + total: number + original_total: number + original_tax_total: number + tax_lines: LineItemTaxLine[] + discount_total: number +} + +type ShippingMethodTotals = { + price: number + tax_total: number + total: number + subtotal: number + original_total: number + original_tax_total: number + tax_lines: ShippingMethodTaxLine[] +} + +type InjectedDependencies = { + manager: EntityManager + taxProviderService: TaxProviderService + taxCalculationStrategy: ITaxCalculationStrategy + featureFlagRouter: FlagRouter +} + +export default class NewTotalsService extends TransactionBaseService { + protected readonly manager_: EntityManager + protected readonly transactionManager_: EntityManager | undefined + + protected readonly taxProviderService_: TaxProviderService + protected readonly featureFlagRouter_: FlagRouter + protected readonly taxCalculationStrategy_: ITaxCalculationStrategy + + constructor({ + manager, + taxProviderService, + featureFlagRouter, + taxCalculationStrategy, + }: InjectedDependencies) { + super(arguments[0]) + + this.manager_ = manager + this.taxProviderService_ = taxProviderService + this.featureFlagRouter_ = featureFlagRouter + this.taxCalculationStrategy_ = taxCalculationStrategy + } + + /** + * Calculate and return the items totals for either the legacy calculation or the new calculation + * @param items + * @param includeTax + * @param calculationContext + * @param taxRate + */ + async getLineItemTotals( + items: LineItem | LineItem[], + { + includeTax, + calculationContext, + taxRate, + }: { + includeTax?: boolean + calculationContext: TaxCalculationContext + taxRate?: number | null + } + ): Promise<{ [lineItemId: string]: LineItemTotals }> { + items = Array.isArray(items) ? items : [items] + + const manager = this.transactionManager_ ?? this.manager_ + let lineItemsTaxLinesMap: { [lineItemId: string]: LineItemTaxLine[] } = {} + + if (!taxRate && includeTax) { + // Use existing tax lines if they are present + const itemContainsTaxLines = items.some((item) => item.tax_lines?.length) + if (itemContainsTaxLines) { + items.forEach((item) => { + lineItemsTaxLinesMap[item.id] = item.tax_lines ?? [] + }) + } else { + const { lineItemsTaxLines } = await this.taxProviderService_ + .withTransaction(manager) + .getTaxLinesMap(items, calculationContext) + lineItemsTaxLinesMap = lineItemsTaxLines + } + } + + const calculationMethod = taxRate + ? this.getLineItemTotalsLegacy.bind(this) + : this.getLineItemTotals_.bind(this) + + const itemsTotals: { [lineItemId: string]: LineItemTotals } = {} + for (const item of items) { + const lineItemAllocation = + calculationContext.allocation_map[item.id] || {} + + itemsTotals[item.id] = await calculationMethod(item, { + taxRate, + includeTax, + lineItemAllocation, + taxLines: lineItemsTaxLinesMap[item.id], + calculationContext, + }) + } + + return itemsTotals + } + + /** + * Calculate and return the totals for an item + * @param item + * @param includeTax + * @param lineItemAllocation + * @param taxLines Only needed to force the usage of the specified tax lines, often in the case where the item does not hold the tax lines + * @param calculationContext + */ + protected async getLineItemTotals_( + item: LineItem, + { + includeTax, + lineItemAllocation, + taxLines, + calculationContext, + }: { + includeTax?: boolean + lineItemAllocation: LineAllocationsMap[number] + taxLines?: LineItemTaxLine[] + calculationContext: TaxCalculationContext + } + ): Promise { + let subtotal = item.unit_price * item.quantity + if ( + this.featureFlagRouter_.isFeatureEnabled( + TaxInclusivePricingFeatureFlag.key + ) && + item.includes_tax + ) { + subtotal = 0 // in that case we need to know the tax rate to compute it later + } + + const discount_total = + (lineItemAllocation.discount?.unit_amount || 0) * item.quantity + + const totals: LineItemTotals = { + unit_price: item.unit_price, + quantity: item.quantity, + subtotal, + discount_total, + total: subtotal - discount_total, + original_total: subtotal, + original_tax_total: 0, + tax_total: 0, + tax_lines: item.tax_lines ?? [], + } + + if (includeTax) { + totals.tax_lines = totals.tax_lines.length + ? totals.tax_lines + : (taxLines as LineItemTaxLine[]) + + if (!totals.tax_lines) { + throw new MedusaError( + MedusaError.Types.UNEXPECTED_STATE, + "Tax Lines must be joined to calculate taxes" + ) + } + } + + if (item.is_return) { + if (!isDefined(item.tax_lines)) { + throw new MedusaError( + MedusaError.Types.UNEXPECTED_STATE, + "Return Line Items must join tax lines" + ) + } + } + + if (totals.tax_lines.length > 0) { + totals.tax_total = await this.taxCalculationStrategy_.calculate( + [item], + totals.tax_lines, + calculationContext + ) + const noDiscountContext = { + ...calculationContext, + allocation_map: {}, // Don't account for discounts + } + + totals.original_tax_total = await this.taxCalculationStrategy_.calculate( + [item], + totals.tax_lines, + noDiscountContext + ) + + if ( + this.featureFlagRouter_.isFeatureEnabled( + TaxInclusivePricingFeatureFlag.key + ) && + item.includes_tax + ) { + totals.subtotal += + totals.unit_price * totals.quantity - totals.original_tax_total + totals.total += totals.subtotal + totals.original_total += totals.subtotal + } + + totals.total += totals.tax_total + totals.original_total += totals.original_tax_total + } + + return totals + } + + /** + * Calculate and return the legacy calculated totals using the tax rate + * @param item + * @param taxRate + * @param lineItemAllocation + * @param calculationContext + */ + protected async getLineItemTotalsLegacy( + item: LineItem, + { + taxRate, + lineItemAllocation, + calculationContext, + }: { + lineItemAllocation: LineAllocationsMap[number] + calculationContext: TaxCalculationContext + taxRate: number + } + ): Promise { + let subtotal = item.unit_price * item.quantity + if ( + this.featureFlagRouter_.isFeatureEnabled( + TaxInclusivePricingFeatureFlag.key + ) && + item.includes_tax + ) { + subtotal = 0 // in that case we need to know the tax rate to compute it later + } + + const discount_total = + (lineItemAllocation.discount?.unit_amount || 0) * item.quantity + + const totals: LineItemTotals = { + unit_price: item.unit_price, + quantity: item.quantity, + subtotal, + discount_total, + total: subtotal - discount_total, + original_total: subtotal, + original_tax_total: 0, + tax_total: 0, + tax_lines: [], + } + + taxRate = taxRate / 100 + + const includesTax = + this.featureFlagRouter_.isFeatureEnabled( + TaxInclusivePricingFeatureFlag.key + ) && item.includes_tax + const taxIncludedInPrice = !item.includes_tax + ? 0 + : Math.round( + calculatePriceTaxAmount({ + price: item.unit_price, + taxRate: taxRate, + includesTax, + }) + ) + totals.subtotal = Math.round( + (item.unit_price - taxIncludedInPrice) * item.quantity + ) + totals.total = totals.subtotal + + totals.original_tax_total = Math.round(totals.subtotal * taxRate) + totals.tax_total = Math.round((totals.subtotal - discount_total) * taxRate) + + totals.total += totals.tax_total + + if (includesTax) { + totals.original_total += totals.subtotal + } + + totals.original_total += totals.original_tax_total + + return totals + } + + /** + * Return the amount that can be refund on a line item + * @param lineItem + * @param calculationContext + * @param taxRate + */ + getLineItemRefund( + lineItem: { + id: string + unit_price: number + includes_tax: boolean + quantity: number + tax_lines: LineItemTaxLine[] + }, + { + calculationContext, + taxRate, + }: { calculationContext: TaxCalculationContext; taxRate?: number | null } + ): number { + /* + * Used for backcompat with old tax system + */ + if (taxRate != null) { + return this.getLineItemRefundLegacy(lineItem, { + calculationContext, + taxRate, + }) + } + + const includesTax = + this.featureFlagRouter_.isFeatureEnabled( + TaxInclusivePricingFeatureFlag.key + ) && lineItem.includes_tax + + const discountAmount = + (calculationContext.allocation_map[lineItem.id]?.discount?.unit_amount || + 0) * lineItem.quantity + + if (!isDefined(lineItem.tax_lines)) { + throw new MedusaError( + MedusaError.Types.UNEXPECTED_STATE, + "Cannot compute line item refund amount, tax lines are missing from the line item" + ) + } + + const totalTaxRate = lineItem.tax_lines.reduce((acc, next) => { + return acc + next.rate / 100 + }, 0) + + const taxAmountIncludedInPrice = !includesTax + ? 0 + : Math.round( + calculatePriceTaxAmount({ + price: lineItem.unit_price, + taxRate: totalTaxRate, + includesTax, + }) + ) + + const lineSubtotal = + (lineItem.unit_price - taxAmountIncludedInPrice) * lineItem.quantity - + discountAmount + + const taxTotal = lineItem.tax_lines.reduce((acc, next) => { + return acc + Math.round(lineSubtotal * (next.rate / 100)) + }, 0) + + return lineSubtotal + taxTotal + } + + /** + * @param lineItem + * @param calculationContext + * @param taxRate + * @protected + */ + protected getLineItemRefundLegacy( + lineItem: { + id: string + unit_price: number + includes_tax: boolean + quantity: number + }, + { + calculationContext, + taxRate, + }: { calculationContext: TaxCalculationContext; taxRate: number } + ): number { + const includesTax = + this.featureFlagRouter_.isFeatureEnabled( + TaxInclusivePricingFeatureFlag.key + ) && lineItem.includes_tax + + const taxAmountIncludedInPrice = !includesTax + ? 0 + : Math.round( + calculatePriceTaxAmount({ + price: lineItem.unit_price, + taxRate: taxRate / 100, + includesTax, + }) + ) + + const discountAmount = + (calculationContext.allocation_map[lineItem.id]?.discount?.unit_amount || + 0) * lineItem.quantity + + const lineSubtotal = + (lineItem.unit_price - taxAmountIncludedInPrice) * lineItem.quantity - + discountAmount + + return Math.round(lineSubtotal * (1 + taxRate / 100)) + } + + /** + * Calculate and return the gift cards totals + * @param giftCardableAmount + * @param giftCardTransactions + * @param region + * @param giftCards + */ + async getGiftCardTotals( + giftCardableAmount: number, + { + giftCardTransactions, + region, + giftCards, + }: { + region: Region + giftCardTransactions?: { + tax_rate: number | null + is_taxable: boolean | null + amount: number + }[] + giftCards?: GiftCard[] + } + ): Promise<{ + total: number + tax_total: number + }> { + if (!giftCards && !giftCardTransactions) { + throw new MedusaError( + MedusaError.Types.UNEXPECTED_STATE, + "Cannot calculate the gift cart totals. Neither the gift cards or gift card transactions have been provided" + ) + } + + if (giftCardTransactions) { + return this.getGiftCardTransactionsTotals({ + giftCardTransactions, + region, + }) + } + + const result = { + total: 0, + tax_total: 0, + } + + if (!giftCards?.length) { + return result + } + + const giftAmount = giftCards.reduce((acc, next) => acc + next.balance, 0) + result.total = Math.min(giftCardableAmount, giftAmount) + + if (region?.gift_cards_taxable) { + result.tax_total = Math.round(result.total * (region.tax_rate / 100)) + return result + } + + return result + } + + /** + * Calculate and return the gift cards totals based on their transactions + * @param gift_card_transactions + * @param region + */ + getGiftCardTransactionsTotals({ + giftCardTransactions, + region, + }: { + giftCardTransactions: { + tax_rate: number | null + is_taxable: boolean | null + amount: number + }[] + region: { gift_cards_taxable: boolean; tax_rate: number } + }): { total: number; tax_total: number } { + return giftCardTransactions.reduce( + (acc, next) => { + let taxMultiplier = (next.tax_rate || 0) / 100 + + // Previously we did not record whether a gift card was taxable or not. + // All gift cards where is_taxable === null are from the old system, + // where we defaulted to taxable gift cards. + // + // This is a backwards compatability fix for orders that were created + // before we added the gift card tax rate. + if (next.is_taxable === null && region?.gift_cards_taxable) { + taxMultiplier = region.tax_rate / 100 + } + + return { + total: acc.total + next.amount, + tax_total: acc.tax_total + next.amount * taxMultiplier, + } + }, + { + total: 0, + tax_total: 0, + } + ) + } + + /** + * Calculate and return the shipping methods totals for either the legacy calculation or the new calculation + * @param shippingMethods + * @param includeTax + * @param discounts + * @param taxRate + * @param calculationContext + */ + async getShippingMethodTotals( + shippingMethods: ShippingMethod | ShippingMethod[], + { + includeTax, + discounts, + taxRate, + calculationContext, + }: { + includeTax?: boolean + calculationContext: TaxCalculationContext + discounts?: Discount[] + taxRate?: number | null + } + ): Promise<{ [shippingMethodId: string]: ShippingMethodTotals }> { + shippingMethods = Array.isArray(shippingMethods) + ? shippingMethods + : [shippingMethods] + + const manager = this.transactionManager_ ?? this.manager_ + let shippingMethodsTaxLinesMap: { + [shippingMethodId: string]: ShippingMethodTaxLine[] + } = {} + + if (!taxRate && includeTax) { + // Use existing tax lines if they are present + const shippingMethodContainsTaxLines = shippingMethods.some( + (method) => method.tax_lines?.length + ) + if (shippingMethodContainsTaxLines) { + shippingMethods.forEach((sm) => { + shippingMethodsTaxLinesMap[sm.id] = sm.tax_lines ?? [] + }) + } else { + const calculationContextWithGivenMethod = { + ...calculationContext, + shipping_methods: shippingMethods, + } + const { shippingMethodsTaxLines } = await this.taxProviderService_ + .withTransaction(manager) + .getTaxLinesMap([], calculationContextWithGivenMethod) + shippingMethodsTaxLinesMap = shippingMethodsTaxLines + } + } + + const calculationMethod = taxRate + ? this.getShippingMethodTotalsLegacy.bind(this) + : this.getShippingMethodTotals_.bind(this) + + const shippingMethodsTotals: { + [lineItemId: string]: ShippingMethodTotals + } = {} + for (const shippingMethod of shippingMethods) { + shippingMethodsTotals[shippingMethod.id] = await calculationMethod( + shippingMethod, + { + includeTax, + calculationContext, + taxLines: shippingMethodsTaxLinesMap[shippingMethod.id], + discounts, + taxRate, + } + ) + } + + return shippingMethodsTotals + } + + /** + * Calculate and return the shipping method totals + * @param shippingMethod + * @param includeTax + * @param calculationContext + * @param taxLines + * @param discounts + */ + protected async getShippingMethodTotals_( + shippingMethod: ShippingMethod, + { + includeTax, + calculationContext, + taxLines, + discounts, + }: { + includeTax?: boolean + calculationContext: TaxCalculationContext + taxLines?: ShippingMethodTaxLine[] + discounts?: Discount[] + } + ) { + const totals: ShippingMethodTotals = { + price: shippingMethod.price, + original_total: shippingMethod.price, + total: shippingMethod.price, + subtotal: shippingMethod.price, + original_tax_total: 0, + tax_total: 0, + tax_lines: shippingMethod.tax_lines ?? [], + } + + if (includeTax) { + totals.tax_lines = totals.tax_lines.length + ? totals.tax_lines + : (taxLines as ShippingMethodTaxLine[]) + + if (!totals.tax_lines) { + throw new MedusaError( + MedusaError.Types.UNEXPECTED_STATE, + "Tax Lines must be joined to calculate taxes" + ) + } + } + + const calculationContext_: TaxCalculationContext = { + ...calculationContext, + shipping_methods: [shippingMethod], + } + + if (totals.tax_lines.length) { + const includesTax = + this.featureFlagRouter_.isFeatureEnabled( + TaxInclusivePricingFeatureFlag.key + ) && shippingMethod.includes_tax + + totals.original_tax_total = await this.taxCalculationStrategy_.calculate( + [], + totals.tax_lines, + calculationContext_ + ) + totals.tax_total = totals.original_tax_total + + if (includesTax) { + totals.subtotal -= totals.tax_total + } else { + totals.original_total += totals.original_tax_total + totals.total += totals.tax_total + } + } + + const hasFreeShipping = discounts?.some( + (d) => d.rule.type === DiscountRuleType.FREE_SHIPPING + ) + + if (hasFreeShipping) { + totals.total = 0 + totals.subtotal = 0 + totals.tax_total = 0 + } + + return totals + } + + /** + * Calculate and return the shipping method totals legacy using teh tax rate + * @param shippingMethod + * @param calculationContext + * @param taxRate + * @param discounts + */ + protected async getShippingMethodTotalsLegacy( + shippingMethod: ShippingMethod, + { + calculationContext, + discounts, + taxRate, + }: { + calculationContext: TaxCalculationContext + discounts?: Discount[] + taxRate: number + } + ): Promise { + const totals: ShippingMethodTotals = { + price: shippingMethod.price, + original_total: shippingMethod.price, + total: shippingMethod.price, + subtotal: shippingMethod.price, + original_tax_total: 0, + tax_total: 0, + tax_lines: [], + } + + totals.original_tax_total = Math.round(totals.price * (taxRate / 100)) + totals.tax_total = Math.round(totals.price * (taxRate / 100)) + + const hasFreeShipping = discounts?.some( + (d) => d.rule.type === DiscountRuleType.FREE_SHIPPING + ) + + if (hasFreeShipping) { + totals.total = 0 + totals.subtotal = 0 + totals.tax_total = 0 + } + + return totals + } +} diff --git a/packages/medusa/src/services/order.ts b/packages/medusa/src/services/order.ts index b1cafd7ca304d..1b5c5960cd33f 100644 --- a/packages/medusa/src/services/order.ts +++ b/packages/medusa/src/services/order.ts @@ -4,6 +4,7 @@ import { TransactionBaseService } from "../interfaces" import SalesChannelFeatureFlag from "../loaders/feature-flags/sales-channels" import { Address, + Cart, ClaimOrder, Fulfillment, FulfillmentItem, @@ -26,7 +27,7 @@ import { } from "../types/fulfillment" import { UpdateOrderInput } from "../types/orders" import { CreateShippingMethodDto } from "../types/shipping-options" -import { buildQuery, setMetadata } from "../utils" +import { buildQuery, isDefined, isString, setMetadata } from "../utils" import { FlagRouter } from "../utils/flag-router" import CartService from "./cart" import CustomerService from "./customer" @@ -43,6 +44,9 @@ import RegionService from "./region" import ShippingOptionService from "./shipping-option" import ShippingProfileService from "./shipping-profile" import TotalsService from "./totals" +import { NewTotalsService, TaxProviderService } from "./index" + +export const ORDER_CART_ALREADY_EXISTS_ERROR = "Order from cart already exists" type InjectedDependencies = { manager: EntityManager @@ -56,6 +60,8 @@ type InjectedDependencies = { fulfillmentService: FulfillmentService lineItemService: LineItemService totalsService: TotalsService + newTotalsService: NewTotalsService + taxProviderService: TaxProviderService regionService: RegionService cartService: CartService addressRepository: typeof AddressRepository @@ -66,6 +72,10 @@ type InjectedDependencies = { featureFlagRouter: FlagRouter } +type TotalsConfig = { + force_taxes?: boolean +} + class OrderService extends TransactionBaseService { static readonly Events = { GIFT_CARD_CREATED: "order.gift_card_created", @@ -99,6 +109,8 @@ class OrderService extends TransactionBaseService { protected readonly fulfillmentService_: FulfillmentService protected readonly lineItemService_: LineItemService protected readonly totalsService_: TotalsService + protected readonly newTotalsService_: NewTotalsService + protected readonly taxProviderService_: TaxProviderService protected readonly regionService_: RegionService protected readonly cartService_: CartService protected readonly addressRepository_: typeof AddressRepository @@ -120,6 +132,8 @@ class OrderService extends TransactionBaseService { fulfillmentService, lineItemService, totalsService, + newTotalsService, + taxProviderService, regionService, cartService, addressRepository, @@ -140,6 +154,8 @@ class OrderService extends TransactionBaseService { this.fulfillmentProviderService_ = fulfillmentProviderService this.lineItemService_ = lineItemService this.totalsService_ = totalsService + this.newTotalsService_ = newTotalsService + this.taxProviderService_ = taxProviderService this.regionService_ = regionService this.fulfillmentService_ = fulfillmentService this.discountService_ = discountService @@ -225,7 +241,7 @@ class OrderService extends TransactionBaseService { this.transformQueryForTotals(config) query.select = select - const rels = relations + const rels = this.getTotalsRelations({ relations }) delete query.relations @@ -309,22 +325,57 @@ class OrderService extends TransactionBaseService { /** * Gets an order by id. - * @param orderId - id of order to retrieve + * @param orderId - id or selector of order to retrieve * @param config - config of order to retrieve * @return the order document */ async retrieve( orderId: string, config: FindConfig = {} + ): Promise { + const { totalsToSelect } = this.transformQueryForTotals(config) + + if (totalsToSelect?.length) { + return await this.retrieveLegacy(orderId, config) + } + + const manager = this.manager_ + const orderRepo = manager.getCustomRepository(this.orderRepository_) + + const query = buildQuery({ id: orderId }, config) + + if (!(config.select || []).length) { + query.select = undefined + } + + const queryRelations = query.relations + query.relations = undefined + + const raw = await orderRepo.findOneWithRelations(queryRelations, query) + + if (!raw) { + throw new MedusaError( + MedusaError.Types.NOT_FOUND, + `Order with id ${orderId} was not found` + ) + } + + return raw + } + + protected async retrieveLegacy( + orderIdOrSelector: string | Selector, + config: FindConfig = {} ): Promise { const orderRepo = this.manager_.getCustomRepository(this.orderRepository_) const { select, relations, totalsToSelect } = this.transformQueryForTotals(config) - const query = { - where: { id: orderId }, - } as FindConfig + const selector = isString(orderIdOrSelector) + ? { id: orderIdOrSelector } + : orderIdOrSelector + const query = buildQuery(selector, config) if (relations && relations.length > 0) { query.relations = relations @@ -334,17 +385,33 @@ class OrderService extends TransactionBaseService { const rels = query.relations delete query.relations + const raw = await orderRepo.findOneWithRelations(rels, query) + if (!raw) { + const selectorConstraints = Object.entries(selector) + .map((key, value) => `${key}: ${value}`) + .join(", ") throw new MedusaError( MedusaError.Types.NOT_FOUND, - `Order with ${orderId} was not found` + `Order with ${selectorConstraints} was not found` ) } return await this.decorateTotals(raw, totalsToSelect) } + async retrieveWithTotals( + orderId: string, + options: FindConfig = {}, + totalsConfig: TotalsConfig = {} + ): Promise { + const relations = this.getTotalsRelations(options) + const order = await this.retrieve(orderId, { ...options, relations }) + + return await this.decorateTotals(order, totalsConfig) + } + /** * Gets an order by cart id. * @param cartId - cart id to find order @@ -379,6 +446,10 @@ class OrderService extends TransactionBaseService { ) } + if (!totalsToSelect?.length) { + return raw + } + return await this.decorateTotals(raw, totalsToSelect) } @@ -404,6 +475,7 @@ class OrderService extends TransactionBaseService { if (relations && relations.length > 0) { query.relations = relations } + query.relations = this.getTotalsRelations({ relations: query.relations }) query.select = select?.length ? select : undefined @@ -420,16 +492,6 @@ class OrderService extends TransactionBaseService { return await this.decorateTotals(raw, totalsToSelect) } - /** - * Checks the existence of an order by cart id. - * @param cartId - cart id to find order - * @return the order document - */ - async existsByCartId(cartId: string): Promise { - const order = await this.retrieveByCartId(cartId).catch(() => undefined) - return !!order - } - /** * @param orderId - id of the order to complete * @return the result of the find operation @@ -459,27 +521,33 @@ class OrderService extends TransactionBaseService { /** * Creates an order from a cart - * @param cartId - id of the cart to create an order from * @return resolves to the creation result. + * @param cartOrId */ - async createFromCart(cartId: string): Promise { + async createFromCart(cartOrId: string | Cart): Promise { return await this.atomicPhase_(async (manager) => { const cartServiceTx = this.cartService_.withTransaction(manager) const inventoryServiceTx = this.inventoryService_.withTransaction(manager) - const cart = await cartServiceTx.retrieveWithTotals(cartId, { - relations: [ - "region", - "payment", - "items", - "discounts", - "discounts.rule", - "gift_cards", - "shipping_methods", - "items", - "items.adjustments", - ], - }) + const exists = !!(await this.retrieveByCartId( + isString(cartOrId) ? cartOrId : cartOrId?.id, + { + select: ["id"], + } + ).catch(() => void 0)) + + if (exists) { + throw new MedusaError( + MedusaError.Types.DUPLICATE_ERROR, + ORDER_CART_ALREADY_EXISTS_ERROR + ) + } + + const cart = isString(cartOrId) + ? await cartServiceTx.retrieveWithTotals(cartOrId, { + relations: ["region", "payment"], + }) + : cartOrId if (cart.items.length === 0) { throw new MedusaError( @@ -490,30 +558,22 @@ class OrderService extends TransactionBaseService { const { payment, region, total } = cart - for (const item of cart.items) { - try { - await inventoryServiceTx.confirmInventory( + await Promise.all( + cart.items.map(async (item) => { + return await inventoryServiceTx.confirmInventory( item.variant_id, item.quantity ) - } catch (err) { - if (payment) { - await this.paymentProviderService_ - .withTransaction(manager) - .cancelPayment(payment) - } - await cartServiceTx.update(cart.id, { payment_authorized_at: null }) - throw err + }) + ).catch(async (err) => { + if (payment) { + await this.paymentProviderService_ + .withTransaction(manager) + .cancelPayment(payment) } - } - - const exists = await this.existsByCartId(cart.id) - if (exists) { - throw new MedusaError( - MedusaError.Types.DUPLICATE_ERROR, - "Order from cart already exists" - ) - } + await cartServiceTx.update(cart.id, { payment_authorized_at: null }) + throw err + }) // Would be the case if a discount code is applied that covers the item // total @@ -539,11 +599,19 @@ class OrderService extends TransactionBaseService { const orderRepo = manager.getCustomRepository(this.orderRepository_) + // TODO: Due to cascade insert we have to remove the tax_lines that have been added by the cart decorate totals. + // Is the cascade insert really used? Also, is it really necessary to pass the entire entities when creating or updating? + // We normally should only pass what is needed? + const shippingMethods = cart.shipping_methods.map((method) => { + ;(method.tax_lines as any) = undefined + return method + }) + const toCreate = { payment_status: "awaiting", discounts: cart.discounts, gift_cards: cart.gift_cards, - shipping_methods: cart.shipping_methods, + shipping_methods: shippingMethods, shipping_address_id: cart.shipping_address_id, billing_address_id: cart.billing_address_id, region_id: cart.region_id, @@ -570,18 +638,28 @@ class OrderService extends TransactionBaseService { toCreate.no_notification = draft.no_notification_order } - const o = orderRepo.create(toCreate) - const result = await orderRepo.save(o) + const rawOrder = orderRepo.create(toCreate) + const order = await orderRepo.save(rawOrder) if (total !== 0 && payment) { await this.paymentProviderService_ .withTransaction(manager) .updatePayment(payment.id, { - order_id: result.id, + order_id: order.id, }) } - let gcBalance = await this.totalsService_.getGiftCardableAmount(cart) + if (!isDefined(cart.subtotal) || !isDefined(cart.discount_total)) { + throw new MedusaError( + MedusaError.Types.UNEXPECTED_STATE, + "Unable to compute gift cardable amount during order creation from cart. The cart is missing the subtotal and/or discount_total" + ) + } + + let gcBalance = + (cart.region?.gift_cards_taxable + ? cart.subtotal! - cart.discount_total! + : cart.total! + cart.gift_card_total!) || 0 const gcService = this.giftCardService_.withTransaction(manager) for (const g of cart.gift_cards) { @@ -594,7 +672,7 @@ class OrderService extends TransactionBaseService { await gcService.createTransaction({ gift_card_id: g.id, - order_id: result.id, + order_id: order.id, amount: usage, is_taxable: cart.region.gift_cards_taxable, tax_rate: cart.region.gift_cards_taxable @@ -605,34 +683,43 @@ class OrderService extends TransactionBaseService { gcBalance = gcBalance - usage } - for (const method of cart.shipping_methods) { - await this.shippingOptionService_ - .withTransaction(manager) - .updateShippingMethod(method.id, { order_id: result.id }) - } - + const shippingOptionServiceTx = + this.shippingOptionService_.withTransaction(manager) const lineItemServiceTx = this.lineItemService_.withTransaction(manager) - for (const item of cart.items) { - await lineItemServiceTx.update(item.id, { order_id: result.id }) - } - for (const item of cart.items) { - await inventoryServiceTx.adjustInventory( - item.variant_id, - -item.quantity - ) - } + await Promise.all( + [ + cart.items.map((item) => { + return [ + lineItemServiceTx.update(item.id, { order_id: order.id }), + inventoryServiceTx.adjustInventory( + item.variant_id, + -item.quantity + ), + ] + }), + cart.shipping_methods.map((method) => { + // TODO: Due to cascade insert we have to remove the tax_lines that have been added by the cart decorate totals. + // Is the cascade insert really used? Also, is it really necessary to pass the entire entities when creating or updating? + // We normally should only pass what is needed? + ;(method.tax_lines as any) = undefined + return shippingOptionServiceTx.updateShippingMethod(method.id, { + order_id: order.id, + }) + }), + ].flat(Infinity) + ) await this.eventBus_ .withTransaction(manager) .emit(OrderService.Events.PLACED, { - id: result.id, - no_notification: result.no_notification, + id: order.id, + no_notification: order.no_notification, }) await cartServiceTx.update(cart.id, { completed_at: new Date() }) - return result + return order }) } @@ -814,8 +901,7 @@ class OrderService extends TransactionBaseService { config: CreateShippingMethodDto = {} ): Promise { return await this.atomicPhase_(async (manager) => { - const order = await this.retrieve(orderId, { - select: ["subtotal"], + const order = await this.retrieveWithTotals(orderId, { relations: [ "shipping_methods", "shipping_methods.shipping_option", @@ -1408,31 +1494,10 @@ class OrderService extends TransactionBaseService { }) } - protected async decorateTotals( + protected async decorateTotalsLegacy( order: Order, totalsFields: string[] = [] ): Promise { - if (totalsFields.some((field) => ["subtotal", "total"].includes(field))) { - const calculationContext = - await this.totalsService_.getCalculationContext(order, { - exclude_shipping: true, - }) - order.items = await Promise.all( - (order.items || []).map(async (item) => { - const itemTotals = await this.totalsService_.getLineItemTotals( - item, - order, - { - include_tax: true, - calculation_context: calculationContext, - } - ) - - return Object.assign(item, itemTotals) - }) - ) - } - for (const totalField of totalsFields) { switch (totalField) { case "shipping_total": { @@ -1537,6 +1602,149 @@ class OrderService extends TransactionBaseService { return order } + /** + * @param order + * @param totalsFieldsOrConfig + * @protected + */ + async decorateTotals( + order: Order, + totalsFieldsOrConfig?: string[] | TotalsConfig + ): Promise { + if (Array.isArray(totalsFieldsOrConfig)) { + return await this.decorateTotalsLegacy(order, totalsFieldsOrConfig) + } + + const manager = this.transactionManager_ ?? this.manager_ + const newTotalsServiceTx = this.newTotalsService_.withTransaction(manager) + + const calculationContext = await this.totalsService_.getCalculationContext( + order + ) + const orderItems = [...(order.items ?? [])] + const orderShippingMethods = [...(order.shipping_methods ?? [])] + + const itemsTotals = await newTotalsServiceTx.getLineItemTotals(orderItems, { + taxRate: order.tax_rate, + includeTax: true, + calculationContext, + }) + const shippingTotals = await newTotalsServiceTx.getShippingMethodTotals( + orderShippingMethods, + { + taxRate: order.tax_rate, + discounts: order.discounts, + includeTax: true, + calculationContext, + } + ) + + order.subtotal = 0 + order.discount_total = 0 + order.shipping_total = 0 + order.refunded_total = + Math.round(order.refunds?.reduce((acc, next) => acc + next.amount, 0)) || + 0 + order.paid_total = + order.payments?.reduce((acc, next) => (acc += next.amount), 0) || 0 + order.refundable_amount = order.paid_total - order.refunded_total || 0 + let item_tax_total = 0 + let shipping_tax_total = 0 + + order.items = (order.items || []).map((item) => { + const refundable = newTotalsServiceTx.getLineItemRefund( + { + ...item, + quantity: item.quantity - (item.returned_quantity || 0), + }, + { + calculationContext, + taxRate: order.tax_rate, + } + ) + + const itemWithTotals = { + ...item, + ...(itemsTotals[item.id] ?? {}), + refundable, + } + + order.subtotal += itemWithTotals.subtotal ?? 0 + order.discount_total += itemWithTotals.discount_total ?? 0 + item_tax_total += itemWithTotals.tax_total ?? 0 + + return itemWithTotals as LineItem + }) + + order.shipping_methods = (order.shipping_methods || []).map( + (shippingMethod) => { + const methodWithTotals = Object.assign( + shippingMethod, + shippingTotals[shippingMethod.id] ?? {} + ) + + order.shipping_total += methodWithTotals.subtotal ?? 0 + shipping_tax_total += methodWithTotals.tax_total ?? 0 + + return methodWithTotals + } + ) + + const giftCardTotal = await this.newTotalsService_.getGiftCardTotals( + order.subtotal - order.discount_total, + { + region: order.region, + giftCards: order.gift_cards, + giftCardTransactions: order.gift_card_transactions ?? [], + } + ) + order.gift_card_total = giftCardTotal.total || 0 + order.gift_card_tax_total = giftCardTotal.tax_total || 0 + + order.tax_total = + item_tax_total + shipping_tax_total - order.gift_card_tax_total + + for (const swap of order.swaps ?? []) { + swap.additional_items = swap.additional_items.map((item) => { + item.refundable = newTotalsServiceTx.getLineItemRefund( + { + ...item, + quantity: item.quantity - (item.returned_quantity || 0), + }, + { + calculationContext, + taxRate: order.tax_rate, + } + ) + return item + }) + } + + for (const claim of order.claims ?? []) { + claim.additional_items = claim.additional_items.map((item) => { + item.refundable = newTotalsServiceTx.getLineItemRefund( + { + ...item, + quantity: item.quantity - (item.returned_quantity || 0), + }, + { + calculationContext, + taxRate: order.tax_rate, + } + ) + return item + }) + } + + order.total = + order.subtotal + + order.shipping_total + + order.tax_total - + (order.gift_card_total + order.discount_total) + + return order + } + /** * Handles receiving a return. This will create a * refund to the customer. If the returned items don't match the requested @@ -1624,6 +1832,32 @@ class OrderService extends TransactionBaseService { return result }) } + + private getTotalsRelations(config: FindConfig): string[] { + const relationSet = new Set(config.relations) + + relationSet.add("items") + relationSet.add("items.tax_lines") + relationSet.add("items.adjustments") + relationSet.add("swaps") + relationSet.add("swaps.additional_items") + relationSet.add("swaps.additional_items.tax_lines") + relationSet.add("swaps.additional_items.adjustments") + relationSet.add("claims") + relationSet.add("claims.additional_items") + relationSet.add("claims.additional_items.tax_lines") + relationSet.add("claims.additional_items.adjustments") + relationSet.add("discounts") + relationSet.add("discounts.rule") + relationSet.add("gift_cards") + relationSet.add("gift_card_transactions") + relationSet.add("refunds") + relationSet.add("shipping_methods") + relationSet.add("shipping_methods.tax_lines") + relationSet.add("region") + + return Array.from(relationSet.values()) + } } export default OrderService diff --git a/packages/medusa/src/services/payment-provider.ts b/packages/medusa/src/services/payment-provider.ts index deb1deb532eaa..34133874bcb8d 100644 --- a/packages/medusa/src/services/payment-provider.ts +++ b/packages/medusa/src/services/payment-provider.ts @@ -374,11 +374,14 @@ export default class PaymentProviderService extends TransactionBaseService { } } - async createPayment( - cart: Cart & { payment_session: PaymentSession } - ): Promise { + async createPayment(data: { + cart_id: string + amount: number + currency_code: string + payment_session: PaymentSession + }): Promise { return await this.atomicPhase_(async (transactionManager) => { - const { payment_session: paymentSession, region, total } = cart + const { payment_session: paymentSession, currency_code, amount } = data const provider = this.retrieveProvider(paymentSession.provider_id) const paymentData = await provider @@ -391,10 +394,10 @@ export default class PaymentProviderService extends TransactionBaseService { const created = paymentRepo.create({ provider_id: paymentSession.provider_id, - amount: total, - currency_code: region.currency_code, + amount, + currency_code, data: paymentData, - cart_id: cart.id, + cart_id: data.cart_id, }) return await paymentRepo.save(created) diff --git a/packages/medusa/src/services/swap.ts b/packages/medusa/src/services/swap.ts index ed7f0fe71862e..ff312d41d7d5d 100644 --- a/packages/medusa/src/services/swap.ts +++ b/packages/medusa/src/services/swap.ts @@ -1,5 +1,5 @@ import { MedusaError } from "medusa-core-utils" -import { EntityManager, In } from "typeorm" +import { EntityManager } from "typeorm" import { buildQuery, isDefined, setMetadata, validateId } from "../utils" import { TransactionBaseService } from "../interfaces" @@ -719,14 +719,8 @@ class SwapService extends TransactionBaseService { const cart = await this.cartService_ .withTransaction(manager) - .retrieve(swap.cart_id, { - select: ["total"], - relations: [ - "payment", - "shipping_methods", - "items", - "items.adjustments", - ], + .retrieveWithTotals(swap.cart_id, { + relations: ["payment"], }) const { payment } = cart @@ -802,7 +796,13 @@ class SwapService extends TransactionBaseService { swap.difference_due = total swap.shipping_address_id = cart.shipping_address_id - swap.shipping_methods = cart.shipping_methods + // TODO: Due to cascade insert we have to remove the tax_lines that have been added by the cart decorate totals. + // Is the cascade insert really used? Also, is it really necessary to pass the entire entities when creating or updating? + // We normally should only pass what is needed? + swap.shipping_methods = cart.shipping_methods.map((method) => { + ;(method.tax_lines as any) = undefined + return method + }) swap.confirmed_at = new Date() swap.payment_status = total === 0 ? SwapPaymentStatus.CONFIRMED : SwapPaymentStatus.AWAITING diff --git a/packages/medusa/src/services/tax-provider.ts b/packages/medusa/src/services/tax-provider.ts index 9cd36b9ff3eac..935bc1c076e07 100644 --- a/packages/medusa/src/services/tax-provider.ts +++ b/packages/medusa/src/services/tax-provider.ts @@ -23,7 +23,7 @@ import { TransactionBaseService, } from "../interfaces" -import { TaxServiceRate } from "../types/tax-service" +import { TaxLinesMaps, TaxServiceRate } from "../types/tax-service" import TaxRateService from "./tax-rate" import EventBusService from "./event-bus" @@ -333,6 +333,42 @@ class TaxProviderService extends TransactionBaseService { }) } + /** + * Return a map of tax lines for line items and shipping methods + * @param items + * @param calculationContext + * @protected + */ + async getTaxLinesMap( + items: LineItem[], + calculationContext: TaxCalculationContext + ): Promise { + const lineItemsTaxLinesMap = {} + const shippingMethodsTaxLinesMap = {} + + const taxLines = await this.getTaxLines(items, calculationContext) + + taxLines.forEach((taxLine) => { + if ("item_id" in taxLine) { + const itemTaxLines = lineItemsTaxLinesMap[taxLine.item_id] ?? [] + itemTaxLines.push(taxLine) + lineItemsTaxLinesMap[taxLine.item_id] = itemTaxLines + } + if ("shipping_method_id" in taxLine) { + const shippingMethodTaxLines = + shippingMethodsTaxLinesMap[taxLine.shipping_method_id] ?? [] + shippingMethodTaxLines.push(taxLine) + shippingMethodsTaxLinesMap[taxLine.shipping_method_id] = + shippingMethodTaxLines + } + }) + + return { + lineItemsTaxLines: lineItemsTaxLinesMap, + shippingMethodsTaxLines: shippingMethodsTaxLinesMap, + } + } + /** * Gets the tax rates configured for a shipping option. The rates are cached * between calls. diff --git a/packages/medusa/src/strategies/__tests__/cart-completion.js b/packages/medusa/src/strategies/__tests__/cart-completion.js index deb7cdac3e01e..46b6b402cbf25 100644 --- a/packages/medusa/src/strategies/__tests__/cart-completion.js +++ b/packages/medusa/src/strategies/__tests__/cart-completion.js @@ -1,5 +1,6 @@ import { MockManager } from "medusa-test-utils" import CartCompletionStrategy from "../cart-completion" +import { newTotalsServiceMock } from "../../services/__mocks__/new-totals" const IdempotencyKeyServiceMock = { withTransaction: function () { @@ -57,32 +58,34 @@ const toTest = [ }) expect(cartServiceMock.createTaxLines).toHaveBeenCalledTimes(1) - expect(cartServiceMock.createTaxLines).toHaveBeenCalledWith("test-cart") + expect(cartServiceMock.createTaxLines).toHaveBeenCalledWith( + expect.objectContaining({ id: "test-cart" }) + ) expect(cartServiceMock.authorizePayment).toHaveBeenCalledTimes(1) expect(cartServiceMock.authorizePayment).toHaveBeenCalledWith( "test-cart", { - idempotency_key: "ikey", + cart_id: "test-cart", + idempotency_key: { + idempotency_key: "ikey", + recovery_point: "tax_lines_created", + }, } ) expect(orderServiceMock.createFromCart).toHaveBeenCalledTimes(1) expect(orderServiceMock.createFromCart).toHaveBeenCalledWith( - "test-cart" + expect.objectContaining({ id: "test-cart" }) ) - expect(orderServiceMock.retrieve).toHaveBeenCalledTimes(1) - expect(orderServiceMock.retrieve).toHaveBeenCalledWith("test-cart", { - select: [ - "subtotal", - "tax_total", - "shipping_total", - "discount_total", - "total", - ], - relations: ["shipping_address", "items", "payments"], - }) + expect(orderServiceMock.retrieveWithTotals).toHaveBeenCalledTimes(1) + expect(orderServiceMock.retrieveWithTotals).toHaveBeenCalledWith( + "test-cart", + { + relations: ["shipping_address", "items", "payments"], + } + ) }, }, ], @@ -187,6 +190,7 @@ describe("CartCompletionStrategy", () => { authorizePayment: jest.fn(() => Promise.resolve(cart)), retrieve: jest.fn(() => Promise.resolve(cart)), retrieveWithTotals: jest.fn(() => Promise.resolve(cart)), + newTotalsService: newTotalsServiceMock, } const orderServiceMock = { withTransaction: function () { @@ -194,6 +198,8 @@ describe("CartCompletionStrategy", () => { }, createFromCart: jest.fn(() => Promise.resolve(cart)), retrieve: jest.fn(() => Promise.resolve({})), + retrieveWithTotals: jest.fn(() => Promise.resolve({})), + newTotalsService: newTotalsServiceMock, } const swapServiceMock = { withTransaction: function () { diff --git a/packages/medusa/src/strategies/cart-completion.ts b/packages/medusa/src/strategies/cart-completion.ts index e8d2611f79919..6b31335840af9 100644 --- a/packages/medusa/src/strategies/cart-completion.ts +++ b/packages/medusa/src/strategies/cart-completion.ts @@ -4,7 +4,9 @@ import { EntityManager } from "typeorm" import { IdempotencyKey, Order } from "../models" import CartService from "../services/cart" import IdempotencyKeyService from "../services/idempotency-key" -import OrderService from "../services/order" +import OrderService, { + ORDER_CART_ALREADY_EXISTS_ERROR, +} from "../services/order" import SwapService from "../services/swap" import { RequestContext } from "../types/request" @@ -52,11 +54,6 @@ class CartCompletionStrategy extends AbstractCartCompletionStrategy { ): Promise { let idempotencyKey: IdempotencyKey = ikey - const idempotencyKeyService = this.idempotencyKeyService_ - const cartService = this.cartService_ - const orderService = this.orderService_ - const swapService = this.swapService_ - let inProgress = true let err: unknown = false @@ -65,30 +62,12 @@ class CartCompletionStrategy extends AbstractCartCompletionStrategy { case "started": { await this.manager_ .transaction("SERIALIZABLE", async (transactionManager) => { - idempotencyKey = await idempotencyKeyService + idempotencyKey = await this.idempotencyKeyService_ .withTransaction(transactionManager) - .workStage(idempotencyKey.idempotency_key, async (manager) => { - const cart = await cartService - .withTransaction(manager) - .retrieve(id) - - if (cart.completed_at) { - return { - response_code: 409, - response_body: { - code: MedusaError.Codes.CART_INCOMPATIBLE_STATE, - message: "Cart has already been completed", - type: MedusaError.Types.NOT_ALLOWED, - }, - } - } - - await cartService.withTransaction(manager).createTaxLines(id) - - return { - recovery_point: "tax_lines_created", - } - }) + .workStage( + idempotencyKey.idempotency_key, + async (manager) => await this.handleStarted(id, { manager }) + ) }) .catch((e) => { inProgress = false @@ -99,40 +78,16 @@ class CartCompletionStrategy extends AbstractCartCompletionStrategy { case "tax_lines_created": { await this.manager_ .transaction("SERIALIZABLE", async (transactionManager) => { - idempotencyKey = await idempotencyKeyService + idempotencyKey = await this.idempotencyKeyService_ .withTransaction(transactionManager) - .workStage(idempotencyKey.idempotency_key, async (manager) => { - const cart = await cartService - .withTransaction(manager) - .authorizePayment(id, { - ...context, - idempotency_key: idempotencyKey.idempotency_key, + .workStage( + idempotencyKey.idempotency_key, + async (manager) => + await this.handleTaxLineCreated(id, idempotencyKey, { + context, + manager, }) - - if (cart.payment_session) { - if ( - cart.payment_session.status === "requires_more" || - cart.payment_session.status === "pending" - ) { - await cartService - .withTransaction(transactionManager) - .deleteTaxLines(id) - - return { - response_code: 200, - response_body: { - data: cart, - payment_status: cart.payment_session.status, - type: "cart", - }, - } - } - } - - return { - recovery_point: "payment_authorized", - } - }) + ) }) .catch((e) => { inProgress = false @@ -144,139 +99,13 @@ class CartCompletionStrategy extends AbstractCartCompletionStrategy { case "payment_authorized": { await this.manager_ .transaction("SERIALIZABLE", async (transactionManager) => { - idempotencyKey = await idempotencyKeyService + idempotencyKey = await this.idempotencyKeyService_ .withTransaction(transactionManager) - .workStage(idempotencyKey.idempotency_key, async (manager) => { - const cart = await cartService - .withTransaction(manager) - .retrieveWithTotals(id, { - relations: ["payment", "payment_sessions"], - }) - - // If cart is part of swap, we register swap as complete - switch (cart.type) { - case "swap": { - try { - const swapId = cart.metadata?.swap_id - let swap = await swapService - .withTransaction(manager) - .registerCartCompletion(swapId as string) - - swap = await swapService - .withTransaction(manager) - .retrieve(swap.id, { - relations: ["shipping_address"], - }) - - return { - response_code: 200, - response_body: { data: swap, type: "swap" }, - } - } catch (error) { - if ( - error && - error.code === - MedusaError.Codes.INSUFFICIENT_INVENTORY - ) { - return { - response_code: 409, - response_body: { - message: error.message, - type: error.type, - code: error.code, - }, - } - } else { - throw error - } - } - } - default: { - if (typeof cart.total === "undefined") { - return { - response_code: 500, - response_body: { - message: "Unexpected state", - }, - } - } - - if (!cart.payment && cart.total > 0) { - throw new MedusaError( - MedusaError.Types.INVALID_DATA, - `Cart payment not authorized` - ) - } - - let order: Order - try { - order = await orderService - .withTransaction(manager) - .createFromCart(cart.id) - } catch (error) { - if ( - error && - error.message === "Order from cart already exists" - ) { - order = await orderService - .withTransaction(manager) - .retrieveByCartId(id, { - select: [ - "subtotal", - "tax_total", - "shipping_total", - "discount_total", - "total", - ], - relations: [ - "shipping_address", - "items", - "payments", - ], - }) - - return { - response_code: 200, - response_body: { data: order, type: "order" }, - } - } else if ( - error && - error.code === - MedusaError.Codes.INSUFFICIENT_INVENTORY - ) { - return { - response_code: 409, - response_body: { - message: error.message, - type: error.type, - code: error.code, - }, - } - } else { - throw error - } - } - - order = await orderService - .withTransaction(manager) - .retrieve(order.id, { - select: [ - "subtotal", - "tax_total", - "shipping_total", - "discount_total", - "total", - ], - relations: ["shipping_address", "items", "payments"], - }) - - return { - response_code: 200, - response_body: { data: order, type: "order" }, - } - } - } - }) + .workStage( + idempotencyKey.idempotency_key, + async (manager) => + await this.handlePaymentAuthorized(id, { manager }) + ) }) .catch((e) => { inProgress = false @@ -292,7 +121,7 @@ class CartCompletionStrategy extends AbstractCartCompletionStrategy { default: await this.manager_.transaction(async (transactionManager) => { - idempotencyKey = await idempotencyKeyService + idempotencyKey = await this.idempotencyKeyService_ .withTransaction(transactionManager) .update(idempotencyKey.idempotency_key, { recovery_point: "finished", @@ -308,11 +137,11 @@ class CartCompletionStrategy extends AbstractCartCompletionStrategy { if (idempotencyKey.recovery_point !== "started") { await this.manager_.transaction(async (transactionManager) => { try { - await orderService + await this.orderService_ .withTransaction(transactionManager) .retrieveByCartId(id) } catch (error) { - await cartService + await this.cartService_ .withTransaction(transactionManager) .deleteTaxLines(id) } @@ -326,6 +155,172 @@ class CartCompletionStrategy extends AbstractCartCompletionStrategy { response_code: idempotencyKey.response_code, } } + + protected async handleStarted( + id: string, + { manager }: { manager: EntityManager } + ) { + const cart = await this.cartService_.withTransaction(manager).retrieve(id, { + relations: [ + "customer", + "discounts", + "discounts.rule", + "gift_cards", + "items", + "items.adjustments", + "region", + "region.tax_rates", + "shipping_address", + "shipping_methods", + ], + }) + + if (cart.completed_at) { + return { + response_code: 409, + response_body: { + code: MedusaError.Codes.CART_INCOMPATIBLE_STATE, + message: "Cart has already been completed", + type: MedusaError.Types.NOT_ALLOWED, + }, + } + } + + await this.cartService_.withTransaction(manager).createTaxLines(cart) + + return { + recovery_point: "tax_lines_created", + } + } + + protected async handleTaxLineCreated( + id: string, + idempotencyKey: IdempotencyKey, + { context, manager }: { context: any; manager: EntityManager } + ) { + const cart = await this.cartService_ + .withTransaction(manager) + .authorizePayment(id, { + ...context, + cart_id: id, + idempotency_key: idempotencyKey, + }) + + if (cart.payment_session) { + if ( + cart.payment_session.status === "requires_more" || + cart.payment_session.status === "pending" + ) { + await this.cartService_.withTransaction(manager).deleteTaxLines(id) + + return { + response_code: 200, + response_body: { + data: cart, + payment_status: cart.payment_session.status, + type: "cart", + }, + } + } + } + + return { + recovery_point: "payment_authorized", + } + } + + protected async handlePaymentAuthorized( + id: string, + { manager }: { manager: EntityManager } + ) { + const orderServiceTx = this.orderService_.withTransaction(manager) + + const cart = await this.cartService_ + .withTransaction(manager) + .retrieveWithTotals(id, { + relations: ["region", "payment", "payment_sessions"], + }) + + // If cart is part of swap, we register swap as complete + if (cart.type === "swap") { + try { + const swapId = cart.metadata?.swap_id + let swap = await this.swapService_ + .withTransaction(manager) + .registerCartCompletion(swapId as string) + + swap = await this.swapService_ + .withTransaction(manager) + .retrieve(swap.id, { + relations: ["shipping_address"], + }) + + return { + response_code: 200, + response_body: { data: swap, type: "swap" }, + } + } catch (error) { + if (error && error.code === MedusaError.Codes.INSUFFICIENT_INVENTORY) { + return { + response_code: 409, + response_body: { + message: error.message, + type: error.type, + code: error.code, + }, + } + } else { + throw error + } + } + } + + if (!cart.payment && cart.total! > 0) { + throw new MedusaError( + MedusaError.Types.INVALID_DATA, + `Cart payment not authorized` + ) + } + + let order: Order + try { + order = await orderServiceTx.createFromCart(cart) + } catch (error) { + if (error && error.message === ORDER_CART_ALREADY_EXISTS_ERROR) { + order = await orderServiceTx.retrieveByCartId(id, { + relations: ["shipping_address", "payments"], + }) + + return { + response_code: 200, + response_body: { data: order, type: "order" }, + } + } else if ( + error && + error.code === MedusaError.Codes.INSUFFICIENT_INVENTORY + ) { + return { + response_code: 409, + response_body: { + message: error.message, + type: error.type, + code: error.code, + }, + } + } else { + throw error + } + } + + order = await orderServiceTx.retrieveWithTotals(order.id, { + relations: ["shipping_address", "items", "payments"], + }) + + return { + response_code: 200, + response_body: { data: order, type: "order" }, + } + } } export default CartCompletionStrategy diff --git a/packages/medusa/src/types/tax-service.ts b/packages/medusa/src/types/tax-service.ts index f0866c3e4c06d..1c70fe756a6e5 100644 --- a/packages/medusa/src/types/tax-service.ts +++ b/packages/medusa/src/types/tax-service.ts @@ -1,3 +1,12 @@ +import { LineItemTaxLine, ShippingMethodTaxLine } from "../models" + +export type TaxLinesMaps = { + lineItemsTaxLines: { [lineItemId: string]: LineItemTaxLine[] } + shippingMethodsTaxLines: { + [shippingMethodId: string]: ShippingMethodTaxLine[] + } +} + /** * The tax rate object as configured in Medusa. These may have an unspecified * numerical rate as they may be used for lookup purposes in the tax provider diff --git a/packages/medusa/src/types/totals.ts b/packages/medusa/src/types/totals.ts index ea0378a3ca804..f818e06120ca7 100644 --- a/packages/medusa/src/types/totals.ts +++ b/packages/medusa/src/types/totals.ts @@ -1,4 +1,4 @@ -import { LineItem } from "../models/line-item" +import { LineItem } from "../models" /** The amount of a gift card allocated to a line item */ export type GiftCardAllocation = { diff --git a/packages/medusa/tsconfig.json b/packages/medusa/tsconfig.json index 0fc6130e7818a..b73ee183580bb 100644 --- a/packages/medusa/tsconfig.json +++ b/packages/medusa/tsconfig.json @@ -27,6 +27,7 @@ "./dist/**/*", "./src/**/__tests__", "./src/**/__mocks__", + "./src/**/__fixtures__", "node_modules" ] } \ No newline at end of file