Răsfoiți Sursa

feat(cli): Add service command

Michael Bromley 1 an în urmă
părinte
comite
e29accc283

+ 5 - 0
packages/cli/src/commands/add/add.ts

@@ -4,6 +4,7 @@ import { Command } from 'commander';
 import { addCodegen } from './codegen/add-codegen';
 import { addEntity } from './entity/add-entity';
 import { createNewPlugin } from './plugin/create-new-plugin';
+import { addService } from './service/add-service';
 import { addUiExtensions } from './ui-extensions/add-ui-extensions';
 
 const cancelledMessage = 'Add feature cancelled.';
@@ -18,6 +19,7 @@ export function registerAddCommand(program: Command) {
                 options: [
                     { value: 'plugin', label: '[Plugin] Add a new plugin' },
                     { value: 'entity', label: '[Plugin: Entity] Add a new entity to a plugin' },
+                    { value: 'service', label: '[Plugin: Service] Add a new service to a plugin' },
                     { value: 'uiExtensions', label: '[Plugin: UI] Set up Admin UI extensions' },
                     { value: 'codegen', label: '[Project: Codegen] Set up GraphQL code generation' },
                 ],
@@ -39,6 +41,9 @@ export function registerAddCommand(program: Command) {
                 if (featureType === 'codegen') {
                     await addCodegen();
                 }
+                if (featureType === 'service') {
+                    await addService();
+                }
             } catch (e: any) {
                 log.error(e.message as string);
                 if (e.stack) {

+ 1 - 1
packages/cli/src/commands/add/entity/add-entity.ts

@@ -3,6 +3,7 @@ import { paramCase, pascalCase } from 'change-case';
 import path from 'path';
 import { ClassDeclaration } from 'ts-morph';
 
+import { pascalCaseRegex } from '../../../constants';
 import { analyzeProject, selectPlugin } from '../../../shared/shared-prompts';
 import { VendurePluginRef } from '../../../shared/vendure-plugin-ref';
 import { createFile } from '../../../utilities/ast-utils';
@@ -138,7 +139,6 @@ export async function getCustomEntityName(_cancelledMessage: string) {
             if (!input) {
                 return 'The custom entity name cannot be empty';
             }
-            const pascalCaseRegex = /^[A-Z][a-zA-Z0-9]*$/;
             if (!pascalCaseRegex.test(input)) {
                 return 'The custom entity name must be in PascalCase, e.g. "ProductReview"';
             }

+ 4 - 0
packages/cli/src/commands/add/plugin/create-new-plugin.ts

@@ -8,6 +8,7 @@ import { VendurePluginRef } from '../../../shared/vendure-plugin-ref';
 import { addImportsToFile, createFile, getTsMorphProject } from '../../../utilities/ast-utils';
 import { addCodegen } from '../codegen/add-codegen';
 import { addEntity } from '../entity/add-entity';
+import { addService } from '../service/add-service';
 import { addUiExtensions } from '../ui-extensions/add-ui-extensions';
 
 import { GeneratePluginOptions, NewPluginTemplateContext } from './types';
@@ -74,6 +75,7 @@ export async function createNewPlugin() {
             options: [
                 { value: 'no', label: "[Finish] No, I'm done!" },
                 { value: 'entity', label: '[Plugin: Entity] Add a new entity to the plugin' },
+                { value: 'service', label: '[Plugin: Service] Add a new service to the plugin' },
                 { value: 'uiExtensions', label: '[Plugin: UI] Set up Admin UI extensions' },
                 {
                     value: 'codegen',
@@ -92,6 +94,8 @@ export async function createNewPlugin() {
             await addUiExtensions(plugin);
         } else if (featureType === 'codegen') {
             await addCodegen(plugin);
+        } else if (featureType === 'service') {
+            await addService(plugin);
         }
     }
 

+ 291 - 0
packages/cli/src/commands/add/service/add-service.ts

@@ -0,0 +1,291 @@
+import { cancel, isCancel, outro, select, text } from '@clack/prompts';
+import path from 'path';
+import { ClassDeclaration, SourceFile } from 'ts-morph';
+
+import { pascalCaseRegex } from '../../../constants';
+import { EntityRef } from '../../../shared/entity-ref';
+import { analyzeProject, selectEntity, selectPlugin } from '../../../shared/shared-prompts';
+import { VendurePluginRef } from '../../../shared/vendure-plugin-ref';
+import { addImportsToFile, createFile, kebabize } from '../../../utilities/ast-utils';
+
+const cancelledMessage = 'Add service cancelled';
+
+interface AddServiceTemplateContext {
+    type: 'basic' | 'entity';
+    serviceName: string;
+    entityRef?: EntityRef;
+}
+
+export async function addService(providedVendurePlugin?: VendurePluginRef) {
+    const project = await analyzeProject({ providedVendurePlugin, cancelledMessage });
+    const vendurePlugin = providedVendurePlugin ?? (await selectPlugin(project, cancelledMessage));
+
+    const type = await select({
+        message: 'What type of service would you like to add?',
+        options: [
+            { value: 'basic', label: 'Basic empty service' },
+            { value: 'entity', label: 'Service to perform CRUD operations on an entity' },
+        ],
+        maxItems: 10,
+    });
+    if (isCancel(type)) {
+        cancel('Cancelled');
+        process.exit(0);
+    }
+    const context: AddServiceTemplateContext = {
+        type: type as AddServiceTemplateContext['type'],
+        serviceName: 'MyService',
+    };
+    if (type === 'entity') {
+        const entityRef = await selectEntity(vendurePlugin);
+        context.entityRef = entityRef;
+        context.serviceName = `${entityRef.name}Service`;
+    }
+
+    let serviceSourceFile: SourceFile;
+    if (context.type === 'basic') {
+        serviceSourceFile = createFile(project, path.join(__dirname, 'templates/basic-service.template.ts'));
+        const name = await text({
+            message: 'What is the name of the new service?',
+            initialValue: 'MyService',
+            validate: input => {
+                if (!input) {
+                    return 'The service name cannot be empty';
+                }
+                if (!pascalCaseRegex.test(input)) {
+                    return 'The service name must be in PascalCase, e.g. "MyService"';
+                }
+            },
+        });
+        if (isCancel(name)) {
+            cancel(cancelledMessage);
+            process.exit(0);
+        }
+        context.serviceName = name;
+        serviceSourceFile.getClass('BasicServiceTemplate')?.rename(context.serviceName);
+    } else {
+        serviceSourceFile = createFile(project, path.join(__dirname, 'templates/entity-service.template.ts'));
+        const serviceClassDeclaration = serviceSourceFile
+            .getClass('EntityServiceTemplate')
+            ?.rename(context.serviceName);
+        if (!serviceClassDeclaration) {
+            throw new Error('Could not find service class declaration');
+        }
+        const entityRef = context.entityRef;
+        if (!entityRef) {
+            throw new Error('Entity class not found');
+        }
+        const templateEntityClass = serviceSourceFile.getClass('TemplateEntity');
+        if (templateEntityClass) {
+            templateEntityClass.rename(entityRef.name);
+            templateEntityClass.remove();
+        }
+        addImportsToFile(serviceClassDeclaration.getSourceFile(), {
+            moduleSpecifier: entityRef.classDeclaration.getSourceFile(),
+            namedImports: [entityRef.name],
+        });
+        const templateTranslationEntityClass = serviceSourceFile.getClass('TemplateEntityTranslation');
+        if (entityRef.isTranslatable()) {
+            const translationEntityClass = entityRef.getTranslationClass();
+            if (translationEntityClass && templateTranslationEntityClass) {
+                templateTranslationEntityClass.rename(translationEntityClass?.getName() as string);
+                templateTranslationEntityClass.remove();
+
+                addImportsToFile(serviceClassDeclaration.getSourceFile(), {
+                    moduleSpecifier: translationEntityClass.getSourceFile(),
+                    namedImports: [translationEntityClass.getName() as string],
+                });
+            }
+        } else {
+            templateTranslationEntityClass?.remove();
+        }
+        customizeInputInterfaces(serviceSourceFile, entityRef);
+        customizeFindOneMethod(serviceClassDeclaration, entityRef);
+        customizeFindAllMethod(serviceClassDeclaration, entityRef);
+        customizeCreateMethod(serviceClassDeclaration, entityRef);
+        customizeUpdateMethod(serviceClassDeclaration, entityRef);
+        removedUnusedConstructorArgs(serviceClassDeclaration, entityRef);
+    }
+
+    const serviceFileName = kebabize(context.serviceName).replace(/-service$/, '.service');
+    serviceSourceFile?.move(
+        path.join(vendurePlugin.getPluginDir().getPath(), 'services', `${serviceFileName}.ts`),
+    );
+
+    vendurePlugin.addProvider(context.serviceName);
+    addImportsToFile(vendurePlugin.classDeclaration.getSourceFile(), {
+        moduleSpecifier: serviceSourceFile,
+        namedImports: [context.serviceName],
+    });
+
+    serviceSourceFile.organizeImports();
+    await project.save();
+
+    if (!providedVendurePlugin) {
+        outro('✅  Done!');
+    }
+}
+
+function customizeFindOneMethod(serviceClassDeclaration: ClassDeclaration, entityRef: EntityRef) {
+    // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+    const findOneMethod = serviceClassDeclaration.getMethod('findOne')!;
+    findOneMethod
+        .setBodyText(writer => {
+            writer.write(` return this.connection
+            .getRepository(ctx, ${entityRef.name})
+            .findOne({
+                where: { id },
+                relations,
+            })`);
+            if (entityRef.isTranslatable()) {
+                writer.write(`.then(entity => entity && this.translator.translate(entity, ctx));`);
+            } else {
+                writer.write(`;`);
+            }
+        })
+        .formatText();
+    if (!entityRef.isTranslatable()) {
+        findOneMethod.setReturnType(`Promise<${entityRef.name} | null>`);
+    }
+}
+
+function customizeFindAllMethod(serviceClassDeclaration: ClassDeclaration, entityRef: EntityRef) {
+    // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+    const findAllMethod = serviceClassDeclaration.getMethod('findAll')!;
+    findAllMethod
+        .setBodyText(writer => {
+            writer.writeLine(`return this.listQueryBuilder`);
+            writer.write(`.build(${entityRef.name}, options,`).block(() => {
+                writer.writeLine('relations,');
+                writer.writeLine('ctx,');
+                writer.writeLine('channelId: ctx.channelId,');
+            });
+            writer.write(')');
+            writer.write('.getManyAndCount()');
+            writer.write('.then(([items, totalItems]) =>').block(() => {
+                writer.write('return').block(() => {
+                    if (entityRef.isTranslatable()) {
+                        writer.writeLine('items: items.map(item => this.translator.translate(item, ctx)),');
+                    } else {
+                        writer.writeLine('items,');
+                    }
+                    writer.writeLine('totalItems,');
+                });
+            });
+            writer.write(');');
+        })
+        .formatText();
+    if (!entityRef.isTranslatable()) {
+        findAllMethod.setReturnType(`Promise<PaginatedList<${entityRef.name}>>`);
+    }
+}
+
+function customizeCreateMethod(serviceClassDeclaration: ClassDeclaration, entityRef: EntityRef) {
+    // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+    const createMethod = serviceClassDeclaration.getMethod('create')!;
+    createMethod
+        .setBodyText(writer => {
+            if (entityRef.isTranslatable()) {
+                writer.write(`const newEntity = await this.translatableSaver.create({
+                                ctx,
+                                input,
+                                entityType: ${entityRef.name},
+                                translationType: ${entityRef.getTranslationClass()?.getName() as string},
+                                beforeSave: async f => {
+                                    // Any pre-save logic can go here
+                                },
+                            });`);
+            } else {
+                writer.writeLine(
+                    `const newEntity = await this.connection.getRepository(ctx, ${entityRef.name}).save(input);`,
+                );
+            }
+            if (entityRef.hasCustomFields()) {
+                writer.writeLine(
+                    `await this.customFieldRelationService.updateRelations(ctx, ${entityRef.name}, input, newEntity);`,
+                );
+            }
+            writer.writeLine(`return assertFound(this.findOne(ctx, newEntity.id));`);
+        })
+        .formatText();
+    if (!entityRef.isTranslatable()) {
+        createMethod.setReturnType(`Promise<${entityRef.name} | null>`);
+    }
+}
+
+function customizeUpdateMethod(serviceClassDeclaration: ClassDeclaration, entityRef: EntityRef) {
+    // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+    const updateMethod = serviceClassDeclaration.getMethod('update')!;
+    updateMethod
+        .setBodyText(writer => {
+            if (entityRef.isTranslatable()) {
+                writer.write(`const updatedEntity = await this.translatableSaver.update({
+                                ctx,
+                                input,
+                                entityType: ${entityRef.name},
+                                translationType: ${entityRef.getTranslationClass()?.getName() as string},
+                                beforeSave: async f => {
+                                    // Any pre-save logic can go here
+                                },
+                            });`);
+            } else {
+                writer.writeLine(
+                    `const entity = await this.connection.getEntityOrThrow(ctx, ${entityRef.name}, input.id);`,
+                );
+                writer.writeLine(`const updatedEntity = patchEntity(entity, input);`);
+                writer.writeLine(
+                    `await this.connection.getRepository(ctx, ${entityRef.name}).save(updatedEntity, { reload: false });`,
+                );
+            }
+            if (entityRef.hasCustomFields()) {
+                writer.writeLine(
+                    `await this.customFieldRelationService.updateRelations(ctx, ${entityRef.name}, input, updatedEntity);`,
+                );
+            }
+            writer.writeLine(`return assertFound(this.findOne(ctx, updatedEntity.id));`);
+        })
+        .formatText();
+    if (!entityRef.isTranslatable()) {
+        updateMethod.setReturnType(`Promise<${entityRef.name} | null>`);
+    }
+}
+
+function customizeInputInterfaces(serviceSourceFile: SourceFile, entityRef: EntityRef) {
+    const createInputInterface = serviceSourceFile
+        .getInterface('CreateEntityInput')
+        ?.rename(`Create${entityRef.name}Input`);
+    const updateInputInterface = serviceSourceFile
+        .getInterface('UpdateEntityInput')
+        ?.rename(`Update${entityRef.name}Input`);
+    if (!entityRef.hasCustomFields()) {
+        createInputInterface?.getProperty('customFields')?.remove();
+        updateInputInterface?.getProperty('customFields')?.remove();
+    }
+    if (entityRef.isTranslatable()) {
+        createInputInterface
+            ?.getProperty('translations')
+            ?.setType(`Array<TranslationInput<${entityRef.name}>>`);
+        updateInputInterface
+            ?.getProperty('translations')
+            ?.setType(`Array<TranslationInput<${entityRef.name}>>`);
+    } else {
+        createInputInterface?.getProperty('translations')?.remove();
+        updateInputInterface?.getProperty('translations')?.remove();
+    }
+}
+
+function removedUnusedConstructorArgs(serviceClassDeclaration: ClassDeclaration, entityRef: EntityRef) {
+    const isTranslatable = entityRef.isTranslatable();
+    const hasCustomFields = entityRef.hasCustomFields();
+    serviceClassDeclaration.getConstructors().forEach(constructor => {
+        constructor.getParameters().forEach(param => {
+            const paramName = param.getName();
+            if ((paramName === 'translatableSaver' || paramName === 'translator') && !isTranslatable) {
+                param.remove();
+            }
+            if (paramName === 'customFieldRelationService' && !hasCustomFields) {
+                param.remove();
+            }
+        });
+    });
+}

+ 13 - 0
packages/cli/src/commands/add/service/templates/basic-service.template.ts

@@ -0,0 +1,13 @@
+import { Injectable } from '@nestjs/common';
+import { Ctx, Product, RequestContext, TransactionalConnection } from '@vendure/core';
+
+@Injectable()
+export class BasicServiceTemplate {
+    constructor(private connection: TransactionalConnection) {}
+
+    async exampleMethod(@Ctx() ctx: RequestContext) {
+        // Add your method logic here
+        const result = await this.connection.getRepository(ctx, Product).findOne({});
+        return result;
+    }
+}

+ 147 - 0
packages/cli/src/commands/add/service/templates/entity-service.template.ts

@@ -0,0 +1,147 @@
+import { Injectable } from '@nestjs/common';
+import { DeletionResponse, DeletionResult, LanguageCode } from '@vendure/common/lib/generated-types';
+import { CustomFieldsObject, ID, PaginatedList } from '@vendure/common/lib/shared-types';
+import {
+    assertFound,
+    CustomFieldRelationService,
+    HasCustomFields,
+    ListQueryBuilder,
+    ListQueryOptions,
+    RelationPaths,
+    RequestContext,
+    TransactionalConnection,
+    Translatable,
+    TranslatableSaver,
+    Translated,
+    Translation,
+    TranslationInput,
+    TranslatorService,
+    VendureEntity,
+    patchEntity,
+} from '@vendure/core';
+
+// These can be replaced by generated types if you set up code generation
+interface CreateEntityInput {
+    // Define the input fields here
+    customFields?: CustomFieldsObject;
+    translations: Array<TranslationInput<TemplateEntity>>;
+}
+interface UpdateEntityInput {
+    id: ID;
+    // Define the input fields here
+    customFields?: CustomFieldsObject;
+    translations: Array<TranslationInput<TemplateEntity>>;
+}
+
+class TemplateEntity extends VendureEntity implements Translatable, HasCustomFields {
+    constructor() {
+        super();
+    }
+
+    customFields: CustomFieldsObject;
+
+    translations: Array<Translation<TemplateEntity>>;
+}
+
+class TemplateEntityTranslation extends VendureEntity implements Translation<TemplateEntity> {
+    constructor() {
+        super();
+    }
+
+    id: ID;
+    languageCode: LanguageCode;
+    base: TemplateEntity;
+    customFields: CustomFieldsObject;
+}
+
+@Injectable()
+export class EntityServiceTemplate {
+    constructor(
+        private connection: TransactionalConnection,
+        private translatableSaver: TranslatableSaver,
+        private listQueryBuilder: ListQueryBuilder,
+        private customFieldRelationService: CustomFieldRelationService,
+        private translator: TranslatorService,
+    ) {}
+
+    findAll(
+        ctx: RequestContext,
+        options?: ListQueryOptions<TemplateEntity>,
+        relations?: RelationPaths<TemplateEntity>,
+    ): Promise<PaginatedList<Translated<TemplateEntity>>> {
+        return this.listQueryBuilder
+            .build(TemplateEntity, options, {
+                relations,
+                ctx,
+                channelId: ctx.channelId,
+            })
+            .getManyAndCount()
+            .then(([_items, totalItems]) => {
+                const items = _items.map(item => this.translator.translate(item, ctx));
+                return {
+                    items,
+                    totalItems,
+                };
+            });
+    }
+
+    findOne(
+        ctx: RequestContext,
+        id: ID,
+        relations?: RelationPaths<TemplateEntity>,
+    ): Promise<Translated<TemplateEntity> | null> {
+        return this.connection
+            .getRepository(ctx, TemplateEntity)
+            .findOne({
+                where: { id },
+                relations,
+            })
+            .then(entity => entity && this.translator.translate(entity, ctx));
+    }
+
+    async create(ctx: RequestContext, input: CreateEntityInput): Promise<Translated<TemplateEntity>> {
+        const newEntity = await this.translatableSaver.create({
+            ctx,
+            input,
+            entityType: TemplateEntity,
+            translationType: TemplateEntityTranslation,
+            beforeSave: async f => {
+                // Any pre-save logic can go here
+            },
+        });
+        // Ensure any custom field relations get saved
+        await this.customFieldRelationService.updateRelations(ctx, TemplateEntity, input, newEntity);
+        return assertFound(this.findOne(ctx, newEntity.id));
+    }
+
+    async update(ctx: RequestContext, input: UpdateEntityInput): Promise<Translated<TemplateEntity>> {
+        const updatedEntity = await this.translatableSaver.update({
+            ctx,
+            input,
+            entityType: TemplateEntity,
+            translationType: TemplateEntityTranslation,
+            beforeSave: async f => {
+                // Any pre-save logic can go here
+            },
+        });
+        // This is just here to stop the import being removed by the IDE
+        patchEntity(updatedEntity, {});
+        await this.customFieldRelationService.updateRelations(ctx, TemplateEntity, input, updatedEntity);
+        return assertFound(this.findOne(ctx, updatedEntity.id));
+    }
+
+    async delete(ctx: RequestContext, id: ID): Promise<DeletionResponse> {
+        const entity = await this.connection.getEntityOrThrow(ctx, TemplateEntity, id);
+        try {
+            await this.connection.getRepository(ctx, TemplateEntity).remove(entity);
+            return {
+                result: DeletionResult.DELETED,
+            };
+        } catch (e: any) {
+            return {
+                result: DeletionResult.NOT_DELETED,
+                message: e.toString(),
+            };
+        }
+    }
+}

+ 2 - 1
packages/cli/src/constants.ts

@@ -4,9 +4,10 @@ export const defaultManipulationSettings: Partial<ManipulationSettings> = {
     quoteKind: QuoteKind.Single,
     useTrailingCommas: true,
 };
-
+export const pascalCaseRegex = /^[A-Z][a-zA-Z0-9]*$/;
 export const AdminUiExtensionTypeName = 'AdminUiExtension';
 export const AdminUiAppConfigName = 'AdminUiAppConfig';
 export const Messages = {
     NoPluginsFound: `No plugins were found in this project. Create a plugin first by selecting "[Plugin] Add a new plugin"`,
+    NoEntitiesFound: `No entities were found in this plugin.`,
 };

+ 39 - 0
packages/cli/src/shared/entity-ref.ts

@@ -0,0 +1,39 @@
+import { ClassDeclaration, Node, SyntaxKind } from 'ts-morph';
+
+export class EntityRef {
+    constructor(public classDeclaration: ClassDeclaration) {}
+
+    get name(): string {
+        return this.classDeclaration.getName() as string;
+    }
+
+    isTranslatable() {
+        return this.classDeclaration.getImplements().some(i => i.getText() === 'Translatable');
+    }
+
+    isTranslation() {
+        return this.classDeclaration.getImplements().some(i => i.getText().includes('Translation<'));
+    }
+
+    hasCustomFields() {
+        return this.classDeclaration.getImplements().some(i => i.getText() === 'HasCustomFields');
+    }
+
+    getTranslationClass(): ClassDeclaration | undefined {
+        if (!this.isTranslatable()) {
+            return;
+        }
+        const translationsDecoratorArgs = this.classDeclaration
+            .getProperty('translations')
+            ?.getDecorator('OneToMany')
+            ?.getArguments();
+
+        if (translationsDecoratorArgs) {
+            const typeFn = translationsDecoratorArgs[0];
+            if (Node.isArrowFunction(typeFn)) {
+                const translationClass = typeFn.getReturnType().getSymbolOrThrow().getDeclarations()[0];
+                return translationClass as ClassDeclaration;
+            }
+        }
+    }
+}

+ 24 - 0
packages/cli/src/shared/shared-prompts.ts

@@ -4,6 +4,7 @@ import { ClassDeclaration, Project } from 'ts-morph';
 import { Messages } from '../constants';
 import { getPluginClasses, getTsMorphProject } from '../utilities/ast-utils';
 
+import { EntityRef } from './entity-ref';
 import { VendurePluginRef } from './vendure-plugin-ref';
 
 export async function analyzeProject(options: {
@@ -43,6 +44,29 @@ export async function selectPlugin(project: Project, cancelledMessage: string):
     return new VendurePluginRef(targetPlugin as ClassDeclaration);
 }
 
+export async function selectEntity(plugin: VendurePluginRef): Promise<EntityRef> {
+    const entities = plugin.getEntities();
+    if (entities.length === 0) {
+        cancel(Messages.NoEntitiesFound);
+        process.exit(0);
+    }
+    const targetEntity = await select({
+        message: 'Select an entity',
+        options: entities
+            .filter(e => !e.isTranslation())
+            .map(e => ({
+                value: e,
+                label: e.name,
+            })),
+        maxItems: 10,
+    });
+    if (isCancel(targetEntity)) {
+        cancel('Cancelled');
+        process.exit(0);
+    }
+    return targetEntity as EntityRef;
+}
+
 export async function selectMultiplePluginClasses(
     project: Project,
     cancelledMessage: string,

+ 51 - 1
packages/cli/src/shared/vendure-plugin-ref.ts

@@ -1,8 +1,9 @@
 import { ClassDeclaration, Node, SyntaxKind } from 'ts-morph';
-import { isLiteralExpression } from 'typescript';
 
 import { AdminUiExtensionTypeName } from '../constants';
 
+import { EntityRef } from './entity-ref';
+
 export class VendurePluginRef {
     constructor(public classDeclaration: ClassDeclaration) {}
 
@@ -46,6 +47,55 @@ export class VendurePluginRef {
         }
     }
 
+    addProvider(providerClassName: string) {
+        const pluginOptions = this.getMetadataOptions();
+        const providerProperty = pluginOptions.getProperty('providers');
+        if (providerProperty) {
+            const providersArray = providerProperty.getFirstChildByKind(SyntaxKind.ArrayLiteralExpression);
+            if (providersArray) {
+                providersArray.addElement(providerClassName);
+            }
+        } else {
+            pluginOptions.addPropertyAssignment({
+                name: 'providers',
+                initializer: `[${providerClassName}]`,
+            });
+        }
+    }
+
+    getEntities(): EntityRef[] {
+        const metadataOptions = this.getMetadataOptions();
+        const entitiesProperty = metadataOptions.getProperty('entities');
+        if (!entitiesProperty) {
+            return [];
+        }
+        const entitiesArray = entitiesProperty.getFirstChildByKind(SyntaxKind.ArrayLiteralExpression);
+        if (!entitiesArray) {
+            return [];
+        }
+        const entityNames = entitiesArray
+            .getElements()
+            .filter(Node.isIdentifier)
+            .map(e => e.getText());
+
+        const entitySourceFiles = this.getSourceFile()
+            .getImportDeclarations()
+            .filter(imp => {
+                for (const namedImport of imp.getNamedImports()) {
+                    if (entityNames.includes(namedImport.getName())) {
+                        return true;
+                    }
+                }
+            })
+            .map(imp => imp.getModuleSpecifierSourceFileOrThrow());
+        return entitySourceFiles
+            .map(sourceFile =>
+                sourceFile.getClasses().filter(c => c.getExtends()?.getText() === 'VendureEntity'),
+            )
+            .flat()
+            .map(classDeclaration => new EntityRef(classDeclaration));
+    }
+
     hasUiExtensions(): boolean {
         return !!this.classDeclaration
             .getStaticProperties()

+ 2 - 1
packages/cli/src/utilities/ast-utils.ts

@@ -97,8 +97,9 @@ export function getRelativeImportPath(locations: {
 
 export function createFile(project: Project, templatePath: string) {
     const template = fs.readFileSync(templatePath, 'utf-8');
+    const tempFilePath = path.join('/.vendure-cli-temp/', path.basename(templatePath));
     try {
-        return project.createSourceFile(path.join('/.vendure-cli-temp/', templatePath), template, {
+        return project.createSourceFile(path.join('/.vendure-cli-temp/', tempFilePath), template, {
             overwrite: true,
         });
     } catch (e: any) {