在mybatis执行SQL语句之前进行拦击处理实例
比较适用于在分页时候进行拦截。对分页的SQL语句通过封装处理,处理成不同的分页sql。
实用性比较强。
importjava.sql.Connection; importjava.sql.PreparedStatement; importjava.sql.ResultSet; importjava.sql.SQLException; importjava.util.List; importjava.util.Properties; importorg.apache.ibatis.executor.parameter.ParameterHandler; importorg.apache.ibatis.executor.statement.RoutingStatementHandler; importorg.apache.ibatis.executor.statement.StatementHandler; importorg.apache.ibatis.mapping.BoundSql; importorg.apache.ibatis.mapping.MappedStatement; importorg.apache.ibatis.mapping.ParameterMapping; importorg.apache.ibatis.plugin.Interceptor; importorg.apache.ibatis.plugin.Intercepts; importorg.apache.ibatis.plugin.Invocation; importorg.apache.ibatis.plugin.Plugin; importorg.apache.ibatis.plugin.Signature; importorg.apache.ibatis.scripting.defaults.DefaultParameterHandler; importcom.yidao.utils.Page; importcom.yidao.utils.ReflectHelper; /** * *分页拦截器,用于拦截需要进行分页查询的操作,然后对其进行分页处理。 *利用拦截器实现Mybatis分页的原理: *要利用JDBC对数据库进行操作就必须要有一个对应的Statement对象,Mybatis在执行Sql语句前就会产生一个包含Sql语句的Statement对象,而且对应的Sql语句 *是在Statement之前产生的,所以我们就可以在它生成Statement之前对用来生成Statement的Sql语句下手。在Mybatis中Statement语句是通过RoutingStatementHandler对象的 *prepare方法生成的。所以利用拦截器实现Mybatis分页的一个思路就是拦截StatementHandler接口的prepare方法,然后在拦截器方法中把Sql语句改成对应的分页查询Sql语句,之后再调用 *StatementHandler对象的prepare方法,即调用invocation.proceed()。 *对于分页而言,在拦截器里面我们还需要做的一个操作就是统计满足当前条件的记录一共有多少,这是通过获取到了原始的Sql语句后,把它改为对应的统计语句再利用Mybatis封装好的参数和设 *置参数的功能把Sql语句中的参数进行替换,之后再执行查询记录数的Sql语句进行总记录数的统计。 * */ @Intercepts({@Signature(type=StatementHandler.class,method="prepare",args={Connection.class})}) publicclassPageInterceptorimplementsInterceptor{ privateStringdialect="";//数据库方言 privateStringpageSqlId="";//mapper.xml中需要拦截的ID(正则匹配) publicObjectintercept(Invocationinvocation)throwsThrowable{ //对于StatementHandler其实只有两个实现类,一个是RoutingStatementHandler,另一个是抽象类BaseStatementHandler, //BaseStatementHandler有三个子类,分别是SimpleStatementHandler,PreparedStatementHandler和CallableStatementHandler, //SimpleStatementHandler是用于处理Statement的,PreparedStatementHandler是处理PreparedStatement的,而CallableStatementHandler是 //处理CallableStatement的。Mybatis在进行Sql语句处理的时候都是建立的RoutingStatementHandler,而在RoutingStatementHandler里面拥有一个 //StatementHandler类型的delegate属性,RoutingStatementHandler会依据Statement的不同建立对应的BaseStatementHandler,即SimpleStatementHandler、 //PreparedStatementHandler或CallableStatementHandler,在RoutingStatementHandler里面所有StatementHandler接口方法的实现都是调用的delegate对应的方法。 //我们在PageInterceptor类上已经用@Signature标记了该Interceptor只拦截StatementHandler接口的prepare方法,又因为Mybatis只有在建立RoutingStatementHandler的时候 //是通过Interceptor的plugin方法进行包裹的,所以我们这里拦截到的目标对象肯定是RoutingStatementHandler对象。 if(invocation.getTarget()instanceofRoutingStatementHandler){ RoutingStatementHandlerstatementHandler=(RoutingStatementHandler)invocation.getTarget(); StatementHandlerdelegate=(StatementHandler)ReflectHelper.getFieldValue(statementHandler,"delegate"); BoundSqlboundSql=delegate.getBoundSql(); Objectobj=boundSql.getParameterObject(); if(objinstanceofPage>){ Page>page=(Page>)obj; //通过反射获取delegate父类BaseStatementHandler的mappedStatement属性 MappedStatementmappedStatement=(MappedStatement)ReflectHelper.getFieldValue(delegate,"mappedStatement"); //拦截到的prepare方法参数是一个Connection对象 Connectionconnection=(Connection)invocation.getArgs()[0]; //获取当前要执行的Sql语句,也就是我们直接在Mapper映射语句中写的Sql语句 Stringsql=boundSql.getSql(); //给当前的page参数对象设置总记录数 this.setTotalRecord(page, mappedStatement,connection); //获取分页Sql语句 StringpageSql=this.getPageSql(page,sql); //利用反射设置当前BoundSql对应的sql属性为我们建立好的分页Sql语句 ReflectHelper.setFieldValue(boundSql,"sql",pageSql); } } returninvocation.proceed(); } /** *给当前的参数对象page设置总记录数 * *@parampageMapper映射语句对应的参数对象 *@parammappedStatementMapper映射语句 *@paramconnection当前的数据库连接 */ privatevoidsetTotalRecord(Page>page, MappedStatementmappedStatement,Connectionconnection){ //获取对应的BoundSql,这个BoundSql其实跟我们利用StatementHandler获取到的BoundSql是同一个对象。 //delegate里面的boundSql也是通过mappedStatement.getBoundSql(paramObj)方法获取到的。 BoundSqlboundSql=mappedStatement.getBoundSql(page); //获取到我们自己写在Mapper映射语句中对应的Sql语句 Stringsql=boundSql.getSql(); //通过查询Sql语句获取到对应的计算总记录数的sql语句 StringcountSql=this.getCountSql(sql); //通过BoundSql获取对应的参数映射 ListparameterMappings=boundSql.getParameterMappings(); //利用Configuration、查询记录数的Sql语句countSql、参数映射关系parameterMappings和参数对象page建立查询记录数对应的BoundSql对象。 BoundSqlcountBoundSql=newBoundSql(mappedStatement.getConfiguration(),countSql,parameterMappings,page); //通过mappedStatement、参数对象page和BoundSql对象countBoundSql建立一个用于设定参数的ParameterHandler对象 ParameterHandlerparameterHandler=newDefaultParameterHandler(mappedStatement,page,countBoundSql); //通过connection建立一个countSql对应的PreparedStatement对象。 PreparedStatementpstmt=null; ResultSetrs=null; try{ pstmt=connection.prepareStatement(countSql); //通过parameterHandler给PreparedStatement对象设置参数 parameterHandler.setParameters(pstmt); //之后就是执行获取总记录数的Sql语句和获取结果了。 rs=pstmt.executeQuery(); if(rs.next()){ inttotalRecord=rs.getInt(1); //给当前的参数page对象设置总记录数 page.setTotalRecord(totalRecord); } }catch(SQLExceptione){ e.printStackTrace(); }finally{ try{ if(rs!=null) rs.close(); if(pstmt!=null) pstmt.close(); }catch(SQLExceptione){ e.printStackTrace(); } } } /** *根据原Sql语句获取对应的查询总记录数的Sql语句 *@paramsql *@return */ privateStringgetCountSql(Stringsql){ intindex=sql.indexOf("from"); return"selectcount(*)"+sql.substring(index); } /** *根据page对象获取对应的分页查询Sql语句,这里只做了两种数据库类型,Mysql和Oracle *其它的数据库都没有进行分页 * *@parampage分页对象 *@paramsql原sql语句 *@return */ privateStringgetPageSql(Page>page,Stringsql){ StringBuffersqlBuffer=newStringBuffer(sql); if("mysql".equalsIgnoreCase(dialect)){ returngetMysqlPageSql(page,sqlBuffer); }elseif("oracle".equalsIgnoreCase(dialect)){ returngetOraclePageSql(page,sqlBuffer); } returnsqlBuffer.toString(); } /** *获取Mysql数据库的分页查询语句 *@parampage分页对象 *@paramsqlBuffer包含原sql语句的StringBuffer对象 *@returnMysql数据库分页语句 */ privateStringgetMysqlPageSql(Page>page,StringBuffersqlBuffer){ //计算第一条记录的位置,Mysql中记录的位置是从0开始的。 //System.out.println("page:"+page.getPage()+"-------"+page.getRows()); intoffset=(page.getPage()-1)*page.getRows(); sqlBuffer.append("limit").append(offset).append(",").append(page.getRows()); returnsqlBuffer.toString(); } /** *获取Oracle数据库的分页查询语句 *@parampage分页对象 *@paramsqlBuffer包含原sql语句的StringBuffer对象 *@returnOracle数据库的分页查询语句 */ privateStringgetOraclePageSql(Page>page,StringBuffersqlBuffer){ //计算第一条记录的位置,Oracle分页是通过rownum进行的,而rownum是从1开始的 intoffset=(page.getPage()-1)*page.getRows()+1; sqlBuffer.insert(0,"selectu.*,rownumrfrom(").append(")uwhererownum<").append(offset+page.getRows()); sqlBuffer.insert(0,"select*from(").append(")wherer>=").append(offset); //上面的Sql语句拼接之后大概是这个样子: //select*from(selectu.*,rownumrfrom(select*fromt_user)uwhererownum<31)wherer>=16 returnsqlBuffer.toString(); } /** *拦截器对应的封装原始对象的方法 */ publicObjectplugin(Objectarg0){ //TODOAuto-generatedmethodstub if(arg0instanceofStatementHandler){ returnPlugin.wrap(arg0,this); }else{ returnarg0; } } /** *设置注册拦截器时设定的属性 */ publicvoidsetProperties(Propertiesp){ } publicStringgetDialect(){ returndialect; } publicvoidsetDialect(Stringdialect){ this.dialect=dialect; } publicStringgetPageSqlId(){ returnpageSqlId; } publicvoidsetPageSqlId(StringpageSqlId){ this.pageSqlId=pageSqlId; } }
xml配置:
Page类
packagecom.yidao.utils; /**自己看看,需要什么字段加什么字段吧*/ publicclassPage{ privateIntegerrows; privateIntegerpage=1; privateIntegertotalRecord; publicIntegergetRows(){ returnrows; } publicvoidsetRows(Integerrows){ this.rows=rows; } publicIntegergetPage(){ returnpage; } publicvoidsetPage(Integerpage){ this.page=page; } publicIntegergetTotalRecord(){ returntotalRecord; } publicvoidsetTotalRecord(IntegertotalRecord){ this.totalRecord=totalRecord; } }
ReflectHelper类
packagecom.yidao.utils; importjava.lang.reflect.Field; importorg.apache.commons.lang3.reflect.FieldUtils; publicclassReflectHelper{ publicstaticObjectgetFieldValue(Objectobj,StringfieldName){ if(obj==null){ returnnull; } FieldtargetField=getTargetField(obj.getClass(),fieldName); try{ returnFieldUtils.readField(targetField,obj,true); }catch(IllegalAccessExceptione){ e.printStackTrace(); } returnnull; } publicstaticFieldgetTargetField(Class>targetClass,StringfieldName){ Fieldfield=null; try{ if(targetClass==null){ returnfield; } if(Object.class.equals(targetClass)){ returnfield; } field=FieldUtils.getDeclaredField(targetClass,fieldName,true); if(field==null){ field=getTargetField(targetClass.getSuperclass(),fieldName); } }catch(Exceptione){ } returnfield; } publicstaticvoidsetFieldValue(Objectobj,StringfieldName,Objectvalue){ if(null==obj){return;} FieldtargetField=getTargetField(obj.getClass(),fieldName); try{ FieldUtils.writeField(targetField,obj,value); }catch(IllegalAccessExceptione){ e.printStackTrace(); } } }
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持毛票票。