A Truly Generic Repository, Part 1

This is part of a series on using generics in C# to make code more reusable. Other articles in this series:

The Problem

The code here is based on the ASP.NET tutorial Implementing the Repository and Unit of Work Patterns in an ASP.NET MVC Application. Like many new MVC developers, I suspect, I made it one of my first stops when learning the ropes, and I think it has done more disservice to the cause than any other article or tutorial out there. That sounds a bit harsh, I know, and mildly hypocritical, since I opened by saying that my code is based on it. Allow me to explain.

While there's some really good stuff there, the tutorial also has a number of fatal flaws. In short, using the repository and unit of work patterns this way on top of Entity Framework is nothing short of idiotic. If you look at the implementations, the methods do little more than proxy to methods on DbContext, which already combines the unit of work and repository patterns. Worst of all is the unit of work class itself, where you must write a variation of the following code for every single entity you want to work with:

private GenericRepository<Department> departmentRepository;

public GenericRepository<Department> DepartmentRepository
{
    get
    {
        if (this.departmentRepository == null)
        {
            this.departmentRepository = new GenericRepository<Department>(context);
        }
        return departmentRepository;
    }
}

That may not be too bad with the handful of entities in this sample project, but wait until you're dealing with 50, 100, or 200 entities. Suddenly your UnitOfWork class is a rat's nest of brittle code that has to be maintained, and that's before we even get to how this one seemingly innocent class violates SOLID: every new entity means editing it, which breaks the open/closed principle, and it gains another reason to change for every entity it exposes, which breaks the single responsibility principle.

So how is this different?

The following code diverges in a few major ways. First, and probably foremost, there's no UnitOfWork class. Instead, we're using a truly generic repository that can work with virtually any entity without having to new up a separate repository for each one. Second, this implementation is provider-agnostic: consumers depend only on the repository interfaces, so although my example implementation is for Entity Framework, it can be made to work with any store you have. Third, this solution is designed for dependency injection, which, again, gives you the freedom to substitute implementations easily.
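
To make the first point concrete, here is a minimal, BCL-only sketch (not the author's code) of why putting the type parameter on each method, rather than on the class as the tutorial's GenericRepository<TEntity> does, lets a single instance serve every entity type. TypeKeyedStore, Department, and Course are hypothetical names, and the store simply keeps one in-memory list per entity type, loosely mirroring how DbContext.Set<TEntity>() hands back one set per type.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical entity types, for illustration only.
public class Department
{
    public int Id { get; set; }
    public string Name { get; set; }
}

public class Course
{
    public int Id { get; set; }
    public string Title { get; set; }
}

// Hypothetical in-memory store. The type parameter lives on each method,
// not on the class, so one instance can hold any number of entity types.
public class TypeKeyedStore
{
    // One list per entity type, created lazily on first use.
    private readonly Dictionary<Type, object> sets = new Dictionary<Type, object>();

    private List<TEntity> Set<TEntity>() where TEntity : class
    {
        if (!sets.TryGetValue(typeof(TEntity), out var set))
        {
            set = new List<TEntity>();
            sets[typeof(TEntity)] = set;
        }
        return (List<TEntity>)set;
    }

    public void Create<TEntity>(TEntity entity) where TEntity : class
    {
        Set<TEntity>().Add(entity);
    }

    public IEnumerable<TEntity> Get<TEntity>(Func<TEntity, bool> filter = null) where TEntity : class
    {
        return filter == null ? Set<TEntity>().ToList() : Set<TEntity>().Where(filter).ToList();
    }
}

public static class TypeKeyedStoreExample
{
    public static (int departments, int courses) Run()
    {
        var store = new TypeKeyedStore(); // one instance...
        store.Create(new Department { Id = 1, Name = "Engineering" });
        store.Create(new Department { Id = 2, Name = "English" });
        store.Create(new Course { Id = 1, Title = "Calculus" }); // ...many entity types

        int departments = store.Get<Department>(d => d.Name.StartsWith("E")).Count();
        int courses = store.Get<Course>().Count();
        return (departments, courses);
    }
}
```

TypeKeyedStoreExample.Run() returns (2, 1): both departments match the filter, and the course lives in its own list inside the same store instance.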

The Code

Without further ado...

Interfaces

We start with two interfaces, IReadOnlyRepository and IRepository. As the names indicate, the former contains only read methods, while the latter extends it with write capability.

IReadOnlyRepository.cs

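using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Threading.Tasks;

public interface IReadOnlyRepository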
{
    IEnumerable<TEntity> GetAll<TEntity>(
        Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
        string includeProperties = null,
        int? skip = null,
        int? take = null)
        where TEntity : class, IEntity;

    Task<IEnumerable<TEntity>> GetAllAsync<TEntity>(
        Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
        string includeProperties = null,
        int? skip = null,
        int? take = null)
        where TEntity : class, IEntity;

    IEnumerable<TEntity> Get<TEntity>(
        Expression<Func<TEntity, bool>> filter = null,
        Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
        string includeProperties = null,
        int? skip = null,
        int? take = null)
        where TEntity : class, IEntity;

    Task<IEnumerable<TEntity>> GetAsync<TEntity>(
        Expression<Func<TEntity, bool>> filter = null,
        Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
        string includeProperties = null,
        int? skip = null,
        int? take = null)
        where TEntity : class, IEntity;

    TEntity GetOne<TEntity>(
        Expression<Func<TEntity, bool>> filter = null,
        string includeProperties = null)
        where TEntity : class, IEntity;

    Task<TEntity> GetOneAsync<TEntity>(
        Expression<Func<TEntity, bool>> filter = null,
        string includeProperties = null)
        where TEntity : class, IEntity;

    TEntity GetFirst<TEntity>(
        Expression<Func<TEntity, bool>> filter = null,
        Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
        string includeProperties = null)
        where TEntity : class, IEntity;

    Task<TEntity> GetFirstAsync<TEntity>(
        Expression<Func<TEntity, bool>> filter = null,
        Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
        string includeProperties = null)
        where TEntity : class, IEntity;

    TEntity GetById<TEntity>(object id)
        where TEntity : class, IEntity;

    Task<TEntity> GetByIdAsync<TEntity>(object id)
        where TEntity : class, IEntity;

    int GetCount<TEntity>(Expression<Func<TEntity, bool>> filter = null)
        where TEntity : class, IEntity;

    Task<int> GetCountAsync<TEntity>(Expression<Func<TEntity, bool>> filter = null)
        where TEntity : class, IEntity;

    bool GetExists<TEntity>(Expression<Func<TEntity, bool>> filter = null)
        where TEntity : class, IEntity;

    Task<bool> GetExistsAsync<TEntity>(Expression<Func<TEntity, bool>> filter = null)
        where TEntity : class, IEntity;
}

IRepository.cs

using System.Threading.Tasks;

public interface IRepository : IReadOnlyRepository
{
    void Create<TEntity>(TEntity entity, string createdBy = null)
        where TEntity : class, IEntity;

    void Update<TEntity>(TEntity entity, string modifiedBy = null)
        where TEntity : class, IEntity;

    void Delete<TEntity>(object id)
        where TEntity : class, IEntity;

    void Delete<TEntity>(TEntity entity)
        where TEntity : class, IEntity;

    void Save();

    Task SaveAsync();
}
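
Every method above is constrained to class, IEntity, but IEntity itself isn't shown. Here is a minimal sketch, assuming it carries only the four audit properties that EntityFrameworkRepository's Create and Update assign; the property types (including a nullable ModifiedDate) are my guesses. Student and AuditExample.StampCreated are hypothetical and show how the constraint lets generic code reach those properties.

```csharp
using System;

// Hypothetical IEntity: the article relies on this interface without showing it.
// Only the audit properties assigned by Create and Update are assumed here,
// and their types are guesses.
public interface IEntity
{
    DateTime CreatedDate { get; set; }
    string CreatedBy { get; set; }
    DateTime? ModifiedDate { get; set; }
    string ModifiedBy { get; set; }
}

// A hypothetical entity that satisfies the "class, IEntity" constraint.
public class Student : IEntity
{
    public int Id { get; set; }
    public string LastName { get; set; }

    public DateTime CreatedDate { get; set; }
    public string CreatedBy { get; set; }
    public DateTime? ModifiedDate { get; set; }
    public string ModifiedBy { get; set; }
}

public static class AuditExample
{
    // Because of the IEntity constraint, generic code can set the audit
    // fields on any entity type, much as Create does before adding to the DbSet.
    public static void StampCreated<TEntity>(TEntity entity, string createdBy)
        where TEntity : class, IEntity
    {
        entity.CreatedDate = DateTime.UtcNow;
        entity.CreatedBy = createdBy;
    }
}
```

With ModifiedDate declared as DateTime?, a freshly created entity can leave it null until Update stamps it.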

Implementations

Now we need classes that implement these interfaces. The implementations here are specific to Entity Framework, but you could conceivably write similar ones for a different ORM, such as NHibernate, or even for a Web API.
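
As a sketch of what a non-EF implementation might look like, the hypothetical InMemoryQuery below composes the same filter, orderBy, skip, and take pipeline as the GetQueryable method in the Entity Framework version, but over a plain list through LINQ to Objects (AsQueryable). Include paths mean nothing for objects already in memory, so this sketch accepts and ignores them, and it drops the IEntity constraint to stay self-contained. Something along these lines could back an in-memory repository for unit tests.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical entity for the example.
public class Instructor
{
    public int Id { get; set; }
    public string LastName { get; set; }
}

// Hypothetical in-memory counterpart to GetQueryable: same parameters,
// but backed by any IEnumerable<T> instead of a DbSet<T>.
public static class InMemoryQuery
{
    public static IQueryable<TEntity> GetQueryable<TEntity>(
        IEnumerable<TEntity> source,
        Expression<Func<TEntity, bool>> filter = null,
        Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
        string includeProperties = null, // accepted for parity, ignored in memory
        int? skip = null,
        int? take = null)
        where TEntity : class
    {
        // AsQueryable lets expression-based filters run through LINQ to Objects
        IQueryable<TEntity> query = source.AsQueryable();

        if (filter != null)
        {
            query = query.Where(filter);
        }

        if (orderBy != null)
        {
            query = orderBy(query);
        }

        if (skip.HasValue)
        {
            query = query.Skip(skip.Value);
        }

        if (take.HasValue)
        {
            query = query.Take(take.Value);
        }

        return query;
    }
}

public static class InMemoryQueryExample
{
    public static List<string> Run()
    {
        var instructors = new List<Instructor>
        {
            new Instructor { Id = 1, LastName = "Zheng" },
            new Instructor { Id = 2, LastName = "Abercrombie" },
            new Instructor { Id = 3, LastName = "Kapoor" },
            new Instructor { Id = 4, LastName = "Fakhouri" },
            new Instructor { Id = 5, LastName = "Harui" }
        };

        // Second page (size 2) of instructors other than Id 1, ordered by last name
        return InMemoryQuery.GetQueryable(
                instructors,
                filter: i => i.Id > 1,
                orderBy: q => q.OrderBy(i => i.LastName),
                skip: 2,
                take: 2)
            .Select(i => i.LastName)
            .ToList();
    }
}
```

After filtering and sorting, the sequence is Abercrombie, Fakhouri, Harui, Kapoor, so the second page is Harui and Kapoor.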

EntityFrameworkReadOnlyRepository.cs

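using System;
using System.Collections.Generic;
using System.Data.Entity;
using System.Linq;
using System.Linq.Expressions;
using System.Threading.Tasks;

public class EntityFrameworkReadOnlyRepository<TContext> : IReadOnlyRepository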
    where TContext : DbContext
{
    protected readonly TContext context;

    public EntityFrameworkReadOnlyRepository(TContext context)
    {
        this.context = context;
    }

    protected virtual IQueryable<TEntity> GetQueryable<TEntity>(
        Expression<Func<TEntity, bool>> filter = null,
        Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
        string includeProperties = null,
        int? skip = null,
        int? take = null)
        where TEntity : class, IEntity
    {
        includeProperties = includeProperties ?? string.Empty;
        IQueryable<TEntity> query = context.Set<TEntity>();

        if (filter != null)
        {
            query = query.Where(filter);
        }

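        // includeProperties is a comma-separated list of navigation property paths;
        // trim each one so "A, B" works as well as "A,B"
        foreach (var includeProperty in includeProperties.Split(
            new[] { ',' }, StringSplitOptions.RemoveEmptyEntries))
        {
            query = query.Include(includeProperty.Trim());
        }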

        if (orderBy != null)
        {
            query = orderBy(query);
        }

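        // LINQ to Entities only supports Skip on ordered input,
        // so callers that pass skip should also pass orderBy
        if (skip.HasValue)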
        {
            query = query.Skip(skip.Value);
        }

        if (take.HasValue)
        {
            query = query.Take(take.Value);
        }

        return query;
    }

    public virtual IEnumerable<TEntity> GetAll<TEntity>(
        Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
        string includeProperties = null,
        int? skip = null,
        int? take = null)
        where TEntity : class, IEntity
    {
        return GetQueryable<TEntity>(null, orderBy, includeProperties, skip, take).ToList();
    }

    public virtual async Task<IEnumerable<TEntity>> GetAllAsync<TEntity>(
        Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
        string includeProperties = null,
        int? skip = null,
        int? take = null)
        where TEntity : class, IEntity
    {
        return await GetQueryable<TEntity>(null, orderBy, includeProperties, skip, take).ToListAsync();
    }

    public virtual IEnumerable<TEntity> Get<TEntity>(
        Expression<Func<TEntity, bool>> filter = null,
        Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
        string includeProperties = null,
        int? skip = null,
        int? take = null)
        where TEntity : class, IEntity
    {
        return GetQueryable<TEntity>(filter, orderBy, includeProperties, skip, take).ToList();
    }

    public virtual async Task<IEnumerable<TEntity>> GetAsync<TEntity>(
        Expression<Func<TEntity, bool>> filter = null,
        Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
        string includeProperties = null,
        int? skip = null,
        int? take = null)
        where TEntity : class, IEntity
    {
        return await GetQueryable<TEntity>(filter, orderBy, includeProperties, skip, take).ToListAsync();
    }

    public virtual TEntity GetOne<TEntity>(
        Expression<Func<TEntity, bool>> filter = null,
        string includeProperties = null)
        where TEntity : class, IEntity
    {
        return GetQueryable<TEntity>(filter, null, includeProperties).SingleOrDefault();
    }

    public virtual async Task<TEntity> GetOneAsync<TEntity>(
        Expression<Func<TEntity, bool>> filter = null,
        string includeProperties = null)
        where TEntity : class, IEntity
    {
        return await GetQueryable<TEntity>(filter, null, includeProperties).SingleOrDefaultAsync();
    }

    public virtual TEntity GetFirst<TEntity>(
        Expression<Func<TEntity, bool>> filter = null,
        Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
        string includeProperties = null)
        where TEntity : class, IEntity
    {
        return GetQueryable<TEntity>(filter, orderBy, includeProperties).FirstOrDefault();
    }

    public virtual async Task<TEntity> GetFirstAsync<TEntity>(
        Expression<Func<TEntity, bool>> filter = null,
        Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
        string includeProperties = null)
        where TEntity : class, IEntity
    {
        return await GetQueryable<TEntity>(filter, orderBy, includeProperties).FirstOrDefaultAsync();
    }

    public virtual TEntity GetById<TEntity>(object id)
        where TEntity : class, IEntity
    {
        return context.Set<TEntity>().Find(id);
    }

    public virtual Task<TEntity> GetByIdAsync<TEntity>(object id)
        where TEntity : class, IEntity
    {
        return context.Set<TEntity>().FindAsync(id);
    }

    public virtual int GetCount<TEntity>(Expression<Func<TEntity, bool>> filter = null)
        where TEntity : class, IEntity
    {
        return GetQueryable<TEntity>(filter).Count();
    }

    public virtual Task<int> GetCountAsync<TEntity>(Expression<Func<TEntity, bool>> filter = null)
        where TEntity : class, IEntity
    {
        return GetQueryable<TEntity>(filter).CountAsync();
    }

    public virtual bool GetExists<TEntity>(Expression<Func<TEntity, bool>> filter = null)
        where TEntity : class, IEntity
    {
        return GetQueryable<TEntity>(filter).Any();
    }

    public virtual Task<bool> GetExistsAsync<TEntity>(Expression<Func<TEntity, bool>> filter = null)
        where TEntity : class, IEntity
    {
        return GetQueryable<TEntity>(filter).AnyAsync();
    }
}

EntityFrameworkRepository.cs

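using System;
using System.Data.Entity;
using System.Data.Entity.Validation;
using System.Linq;
using System.Threading.Tasks;

public class EntityFrameworkRepository<TContext> : EntityFrameworkReadOnlyRepository<TContext>, IRepository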
    where TContext : DbContext
{
    public EntityFrameworkRepository(TContext context)
        : base(context)
    {
    }

    public virtual void Create<TEntity>(TEntity entity, string createdBy = null)
        where TEntity : class, IEntity
    {
        entity.CreatedDate = DateTime.UtcNow;
        entity.CreatedBy = createdBy;
        context.Set<TEntity>().Add(entity);
    }

    public virtual void Update<TEntity>(TEntity entity, string modifiedBy = null)
        where TEntity : class, IEntity
    {
        entity.ModifiedDate = DateTime.UtcNow;
        entity.ModifiedBy = modifiedBy;
        context.Set<TEntity>().Attach(entity);
        context.Entry(entity).State = EntityState.Modified;
    }

    public virtual void Delete<TEntity>(object id)
        where TEntity : class, IEntity
    {
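        TEntity entity = context.Set<TEntity>().Find(id);
        // Find returns null when no entity has this key; there is nothing to delete then
        if (entity != null)
        {
            Delete(entity);
        }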
    }

    public virtual void Delete<TEntity>(TEntity entity)
        where TEntity : class, IEntity
    {
        var dbSet = context.Set<TEntity>();
        if (context.Entry(entity).State == EntityState.Detached)
        {
            dbSet.Attach(entity);
        }
        dbSet.Remove(entity);
    }

    public virtual void Save()
    {
        try
        {
            context.SaveChanges();
        }
        catch (DbEntityValidationException e)
        {
            ThrowEnhancedValidationException(e);
        }
    }

    public virtual async Task SaveAsync()
    {
        try
        {
            // Await inside the try so an exception carried by the task is caught here
            await context.SaveChangesAsync();
        }
        catch (DbEntityValidationException e)
        {
            ThrowEnhancedValidationException(e);
        }
    }

    protected virtual void ThrowEnhancedValidationException(DbEntityValidationException e)
    {
        var errorMessages = e.EntityValidationErrors
                .SelectMany(x => x.ValidationErrors)
                .Select(x => x.ErrorMessage);

        var fullErrorMessage = string.Join("; ", errorMessages);
        var exceptionMessage = string.Concat(e.Message, " The validation errors are: ", fullErrorMessage);
        throw new DbEntityValidationException(exceptionMessage, e.EntityValidationErrors);
    }
}
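
One subtlety behind SaveAsync: an exception from an async method is stored in the returned Task rather than thrown at the call site, so a try/catch only sees it if the task is awaited inside the try. The BCL-only sketch below demonstrates this with a hypothetical FailingSaveAsync standing in for SaveChangesAsync; it is not Entity Framework code.

```csharp
using System;
using System.Threading.Tasks;

public static class AsyncCatchExample
{
    // Hypothetical stand-in for an asynchronous save that fails.
    private static async Task<int> FailingSaveAsync()
    {
        await Task.Yield();
        throw new InvalidOperationException("save failed");
    }

    // Returning the task directly: the exception travels inside the task,
    // so this catch block never runs.
    public static Task ReturnTask(Action onCaught)
    {
        try
        {
            return FailingSaveAsync();
        }
        catch (InvalidOperationException)
        {
            onCaught();
            return Task.CompletedTask;
        }
    }

    // Awaiting inside the try: the exception is rethrown at the await,
    // so this catch block runs.
    public static async Task AwaitTask(Action onCaught)
    {
        try
        {
            await FailingSaveAsync();
        }
        catch (InvalidOperationException)
        {
            onCaught();
        }
    }

    public static (bool whenReturned, bool whenAwaited) Run()
    {
        bool returnedCaught = false;
        bool awaitedCaught = false;

        Task returned = ReturnTask(() => returnedCaught = true);
        try
        {
            returned.GetAwaiter().GetResult();
        }
        catch (InvalidOperationException)
        {
            // The fault surfaces here, in the caller, not inside ReturnTask.
        }

        AwaitTask(() => awaitedCaught = true).GetAwaiter().GetResult();
        return (returnedCaught, awaitedCaught);
    }
}
```

Run() returns (false, true): only the version that awaits inside its try gets a chance to handle the failure.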

Usage

Here, I'm providing sample code that uses Ninject as the DI container, for two reasons: 1) it's a very nice dependency injection solution that combines flexibility and simplicity, and 2) it's the DI container I'm familiar with. If you prefer a different container, it should be easy enough to adapt the code.

FooController.cs

public class FooController : Controller
{
    protected readonly IRepository repo;

    public FooController(IRepository repo)
    {
        this.repo = repo;
    }

    ...
}

NinjectWebCommon.cs

private static void RegisterServices(IKernel kernel)
{
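    // One ApplicationDbContext per HTTP request
    kernel.Bind<ApplicationDbContext>()
          .ToSelf()
          .InRequestScope();
    // The repository is request-scoped too and receives that request's context,
    // so one Save() commits every change made through it during the request
    kernel.Bind<IRepository>()
          .To<EntityFrameworkRepository<ApplicationDbContext>>()
          .InRequestScope();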
}