Knife4j/Swagger2 忽略实体类或 List 属性,含接收 MyBatis Plus 分页 Page 参数、返回 Page 对象写法

x33g5p2x  于2022-05-13 转载在 其他  
字(15.5k)|赞(0)|评价(0)|浏览(386)

分页接口使用继承了 Page 的对象接收参数:public class EntityPageQO extends Page<Entity>,因为不仅要接收分页参数,还要能把查询对象传到 entityService.page(Page, Wrapper) 方法里。

可能你只需要 Swagger 页面显示 current 和 size,但是页面上却展示了一堆入参。可以实现一个自己的 Page 查询对象(不需要 Getter/Setter,Swagger 可以直接识别出来),再使用 @ApiModelProperty(hidden = true) 来隐藏不需要的入参:

import com.baomidou.mybatisplus.core.metadata.OrderItem;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.irms.config.swagger.MyApiIgnore;
import io.swagger.annotations.ApiModelProperty;

import java.util.List;

/**
 * Paging query object used as a controller parameter; extends MyBatis-Plus {@link Page} so the
 * same instance can be handed to a service {@code page(Page, Wrapper)} call.
 *
 * <p>The fields below have no accessors of their own; they are declared only to carry Swagger
 * annotations. Presumably each one mirrors a same-named field of {@code Page} — verify against
 * the MyBatis-Plus version in use.
 *
 * <p>This is the first attempt from the article: according to the author, {@code hidden = true}
 * did not take effect for the two {@code List}-typed fields ({@code records}, {@code orders}).
 *
 * @param <T> record type of the page
 */
public class PageQO<T> extends Page<T> {
  // List-typed: the author reports this is still shown despite hidden = true.
  @ApiModelProperty(hidden = true)
  protected List<T> records;
  @ApiModelProperty(hidden = true)
  protected long total;
  // Shown in Swagger; label reads "items per page".
  @ApiModelProperty("每页数量")
  protected long size;
  // Shown in Swagger; label reads "current page number".
  @ApiModelProperty("当前页码")
  protected long current;
  // List-typed: the author reports this is still shown despite hidden = true.
  @ApiModelProperty(hidden = true)
  protected List<OrderItem> orders;
  @ApiModelProperty(hidden = true)
  protected boolean optimizeCountSql;
  @ApiModelProperty(hidden = true)
  protected boolean searchCount;
  @ApiModelProperty(hidden = true)
  protected boolean optimizeJoinOfCountSql;
  @ApiModelProperty(hidden = true)
  protected String countId;
  @ApiModelProperty(hidden = true)
  protected Long maxLimit;
  @ApiModelProperty(hidden = true)
  protected Long pages;
}

此时发现其他属性都正常隐藏了,但是 records 和 orders 上的注解无效,查询一番之后发现 Swagger 隐藏不了非基础类型的入参。解决方案是覆盖源码,使用自定义注解。Swagger(Springfox)2.8.0 和 2.9.2 的写法看最后的参考来源,我这里的版本是 2.10.5,也是 Knife4j 2.0.9 采用的版本。

MyModelAttributeParameterExpander

import cn.hutool.core.util.ReflectUtil;
import com.fasterxml.classmate.ResolvedType;
import com.fasterxml.classmate.members.ResolvedField;
import com.fasterxml.classmate.members.ResolvedMember;
import com.fasterxml.classmate.members.ResolvedMethod;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Component;
import org.springframework.util.ClassUtils;
import springfox.documentation.builders.ParameterBuilder;
import springfox.documentation.schema.Maps;
import springfox.documentation.schema.Types;
import springfox.documentation.schema.property.bean.AccessorsProvider;
import springfox.documentation.schema.property.field.FieldProvider;
import springfox.documentation.service.Parameter;
import springfox.documentation.spi.schema.AlternateTypeProvider;
import springfox.documentation.spi.schema.EnumTypeDeterminer;
import springfox.documentation.spi.service.contexts.ParameterExpansionContext;
import springfox.documentation.spring.web.plugins.DocumentationPluginsManager;
import springfox.documentation.spring.web.readers.parameter.ExpansionContext;
import springfox.documentation.spring.web.readers.parameter.ModelAttributeField;
import springfox.documentation.spring.web.readers.parameter.ModelAttributeParameterExpander;
import springfox.documentation.spring.web.readers.parameter.ModelAttributeParameterMetadataAccessor;

import java.beans.BeanInfo;
import java.beans.IntrospectionException;
import java.beans.Introspector;
import java.beans.PropertyDescriptor;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.*;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import static com.irms.config.swagger.ParameterTypeDeterminer.determineScalarParameterType;
import static java.util.Collections.emptySet;
import static java.util.Optional.ofNullable;
import static java.util.function.Function.identity;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toMap;
import static org.springframework.util.StringUtils.isEmpty;
import static springfox.documentation.schema.Collections.collectionElementType;
import static springfox.documentation.schema.Collections.isContainerType;
import static springfox.documentation.schema.Types.isVoid;
import static springfox.documentation.schema.Types.typeNameFor;

/**
 * Replacement for springfox's {@link ModelAttributeParameterExpander} that additionally skips
 * bean properties whose backing field is annotated with {@link MyApiIgnore}.
 *
 * <p>Most of this class is copied from the springfox sources; the customised parts are the
 * {@code super(...)} call in the constructor and {@code propertyDescriptors(Class)}. Remember to
 * re-sync the copied code after every swagger (springfox) upgrade.
 */
@Primary
@Component
public class MyModelAttributeParameterExpander extends ModelAttributeParameterExpander {

  // Log under this class's own name so messages are attributed to the override, not the parent.
  private static final Logger LOG = LoggerFactory.getLogger(MyModelAttributeParameterExpander.class);
  private final FieldProvider fields;
  private final AccessorsProvider accessors;
  private final EnumTypeDeterminer enumTypeDeterminer;

  @Autowired
  private DocumentationPluginsManager pluginsManager;

  /**
   * Creates the expander and keeps local references to the collaborators, which are also passed
   * on to the parent class.
   *
   * @param fields provider of the resolved fields of a type
   * @param accessors provider of the resolved accessor methods of a type
   * @param enumTypeDeterminer decides whether a class is treated as an enum
   */
  @Autowired
  public MyModelAttributeParameterExpander(
    FieldProvider fields,
    AccessorsProvider accessors,
    EnumTypeDeterminer enumTypeDeterminer) {
    // Added relative to the copied source: the parent has no default constructor to fall back on.
    super(fields, accessors, enumTypeDeterminer);
    this.fields = fields;
    this.accessors = accessors;
    this.enumTypeDeterminer = enumTypeDeterminer;
  }

  /**
   * Expands the parameter type of {@code context} into individual parameters.
   *
   * <p>Attributes are gathered from valid getters and public fields. Non-simple attributes whose
   * type has not been seen yet are expanded recursively, collection attributes are handled next,
   * and simple attributes are turned into parameters last. Hidden and void parameters are
   * removed from the result.
   *
   * @param context expansion context describing the type to expand
   * @return the visible, non-void parameters derived from the type
   */
  @Override
  public List<Parameter> expand(ExpansionContext context) {
    List<Parameter> parameters = new ArrayList<>();
    Set<PropertyDescriptor> propertyDescriptors = propertyDescriptors(context.getParamType().getErasedType());
    Map<Method, PropertyDescriptor> propertyLookupByGetter
      = propertyDescriptorsByMethod(context.getParamType().getErasedType(), propertyDescriptors);
    Iterable<ResolvedMethod> getters = accessors.in(context.getParamType()).stream()
      .filter(onlyValidGetters(propertyLookupByGetter.keySet())).collect(toList());

    Map<String, ResolvedField> fieldsByName =
      StreamSupport.stream(this.fields.in(context.getParamType()).spliterator(), false)
        .collect(toMap((ResolvedMember::getName), identity()));

    LOG.debug("Expanding parameter type: {}", context.getParamType());
    final AlternateTypeProvider alternateTypeProvider = context.getDocumentationContext().getAlternateTypeProvider();

    List<ModelAttributeField> attributes =
      allModelAttributes(
        propertyLookupByGetter,
        getters,
        fieldsByName,
        alternateTypeProvider);

    // Pass 1: recurse into every non-simple attribute whose type has not been seen before.
    attributes.stream()
      .filter(simpleType().negate())
      .filter(recursiveType(context).negate())
      .forEach((each) -> {
        LOG.debug("Attempting to expand expandable property: {}", each.getName());
        parameters.addAll(
          expand(
            context.childContext(
              nestedParentName(context.getParentName(), each),
              each.getFieldType(),
              context.getOperationContext())));
      });

    // Pass 2: collections whose element type is not the parameter type itself.
    Stream<ModelAttributeField> collectionTypes = attributes.stream()
      .filter(isCollection().and(recursiveCollectionItemType(context.getParamType()).negate()));
    collectionTypes.forEachOrdered((each) -> {
      LOG.debug("Attempting to expand collection/array field: {}", each.getName());

      ResolvedType itemType = collectionElementType(each.getFieldType());
      if (Types.isBaseType(itemType) || enumTypeDeterminer.isEnum(itemType.getErasedType())) {
        parameters.add(simpleFields(context.getParentName(), context, each));
      } else {
        ExpansionContext childContext = context.childContext(
          nestedParentName(context.getParentName(), each),
          itemType,
          context.getOperationContext());
        if (!context.hasSeenType(itemType)) {
          parameters.addAll(expand(childContext));
        }
      }
    });

    // Pass 3: simple attributes become parameters directly.
    Stream<ModelAttributeField> simpleFields = attributes.stream().filter(simpleType());
    simpleFields.forEach((each) -> {
      parameters.add(simpleFields(context.getParentName(), context, each));
    });
    return parameters.stream()
      .filter(((Predicate<Parameter>) Parameter::isHidden).negate())
      .filter(voidParameters().negate())
      .collect(toList());
  }

  /**
   * Collects model attributes from public fields followed by those derived from getters.
   */
  private List<ModelAttributeField> allModelAttributes(
    Map<Method, PropertyDescriptor> propertyLookupByGetter,
    Iterable<ResolvedMethod> getters,
    Map<String, ResolvedField> fieldsByName,
    AlternateTypeProvider alternateTypeProvider) {

    Stream<ModelAttributeField> modelAttributesFromGetters = StreamSupport.stream(getters.spliterator(), false)
      .map(toModelAttributeField(fieldsByName, propertyLookupByGetter, alternateTypeProvider));

    Stream<ModelAttributeField> modelAttributesFromFields = fieldsByName.values().stream()
      .filter(ResolvedMember::isPublic)
      .map(toModelAttributeField(alternateTypeProvider));

    return Stream.concat(
        modelAttributesFromFields,
        modelAttributesFromGetters)
      .collect(toList());
  }

  /**
   * Maps a resolved field to a model attribute, using the field's alternate type.
   */
  private Function<ResolvedField, ModelAttributeField> toModelAttributeField(
    final AlternateTypeProvider alternateTypeProvider) {

    return input -> new ModelAttributeField(
      alternateTypeProvider.alternateFor(input.getType()),
      input.getName(),
      input,
      input);
  }

  /** Matches parameters whose type is void (or absent). */
  private Predicate<Parameter> voidParameters() {
    return input -> isVoid(input.getType().orElse(null));
  }

  /** Matches attributes whose collection element type equals {@code paramType}. */
  private Predicate<ModelAttributeField> recursiveCollectionItemType(final ResolvedType paramType) {
    return input -> Objects.equals(collectionElementType(input.getFieldType()), paramType);
  }

  /**
   * Builds a single parameter for a simple attribute by delegating to the plugins manager.
   *
   * @param parentName name prefix of the enclosing attribute, possibly empty
   * @param context current expansion context
   * @param each attribute to convert
   * @return the parameter produced by the registered expander plugins
   */
  private Parameter simpleFields(
    String parentName,
    ExpansionContext context,
    ModelAttributeField each) {
    LOG.debug("Attempting to expand field: {}", each);
    String dataTypeName = ofNullable(typeNameFor(each.getFieldType().getErasedType()))
      .orElse(each.getFieldType().getErasedType().getSimpleName());
    // Fixed: the second placeholder was missing, so the field type was never logged.
    LOG.debug("Building parameter for field: {}, with type: {}", each, each.getFieldType());
    ParameterExpansionContext parameterExpansionContext = new ParameterExpansionContext(
      dataTypeName,
      parentName,
      // springfox's own ParameterTypeDeterminer is not public and cannot be accessed from another
      // package, so the statically imported local copy is used instead.
      determineScalarParameterType(
        context.getOperationContext().consumes(),
        context.getOperationContext().httpMethod()),
      new ModelAttributeParameterMetadataAccessor(
        each.annotatedElements(),
        each.getFieldType(),
        each.getName()),
      context.getDocumentationContext().getDocumentationType(),
      new ParameterBuilder());
    return pluginsManager.expandParameter(parameterExpansionContext);
  }

  /** Matches attributes whose type the context has already seen. */
  private Predicate<ModelAttributeField> recursiveType(final ExpansionContext context) {
    return input -> context.hasSeenType(input.getFieldType());
  }

  /**
   * Matches attributes that are neither collections nor maps and are either in a
   * {@code java.lang*} package, a base/primitive type, or an enum.
   */
  private Predicate<ModelAttributeField> simpleType() {
    return isCollection().negate().and(isMap().negate())
      .and(
        belongsToJavaPackage()
          .or(isBaseType())
          .or(isEnum()));
  }

  /** Matches attributes whose type is a container type. */
  private Predicate<ModelAttributeField> isCollection() {
    return input -> isContainerType(input.getFieldType());
  }

  /** Matches attributes whose type is a map type. */
  private Predicate<ModelAttributeField> isMap() {
    return input -> Maps.isMapType(input.getFieldType());
  }

  /** Matches attributes whose erased type is considered an enum. */
  private Predicate<ModelAttributeField> isEnum() {
    return input -> enumTypeDeterminer.isEnum(input.getFieldType().getErasedType());
  }

  /** Matches attributes whose package name starts with {@code java.lang}. */
  private Predicate<ModelAttributeField> belongsToJavaPackage() {
    return input -> ClassUtils.getPackageName(input.getFieldType().getErasedType()).startsWith("java.lang");
  }

  /** Matches attributes whose type is a springfox base type or a primitive. */
  private Predicate<ModelAttributeField> isBaseType() {
    return input -> Types.isBaseType(input.getFieldType())
      || input.getFieldType().isPrimitive();
  }

  /**
   * Maps a getter to a model attribute named after its bean property; the same-named field,
   * if any, is attached as the second annotated element.
   */
  private Function<ResolvedMethod, ModelAttributeField> toModelAttributeField(
    final Map<String, ResolvedField> fieldsByName,
    final Map<Method, PropertyDescriptor> propertyLookupByGetter,
    final AlternateTypeProvider alternateTypeProvider) {
    return input -> {
      String name = propertyLookupByGetter.get(input.getRawMember()).getName();
      return new ModelAttributeField(
        fieldType(alternateTypeProvider, input),
        name,
        input,
        fieldsByName.get(name));
    };
  }

  /** Matches resolved methods whose raw member is one of {@code methods}. */
  private Predicate<ResolvedMethod> onlyValidGetters(final Set<Method> methods) {
    return input -> methods.contains(input.getRawMember());
  }

  /**
   * Builds the dotted parameter-name prefix for a nested attribute; {@code [0]} is appended to
   * the attribute name when it is a container of non-base elements.
   */
  private String nestedParentName(String parentName, ModelAttributeField attribute) {
    String name = attribute.getName();
    ResolvedType fieldType = attribute.getFieldType();
    if (isContainerType(fieldType) && !Types.isBaseType(collectionElementType(fieldType))) {
      name += "[0]";
    }

    if (isEmpty(parentName)) {
      return name;
    }
    return String.format("%s.%s", parentName, name);
  }

  /** Resolves the alternate type for a getter's return type. */
  private ResolvedType fieldType(AlternateTypeProvider alternateTypeProvider, ResolvedMethod method) {
    return alternateTypeProvider.alternateFor(method.getType());
  }

  /**
   * Returns the bean property descriptors of {@code clazz}, leaving out every property whose
   * same-named field is annotated with {@link MyApiIgnore}.
   *
   * @param clazz bean class to introspect
   * @return the remaining descriptors, or an empty set when introspection fails
   */
  private Set<PropertyDescriptor> propertyDescriptors(final Class<?> clazz) {
    try {
      Set<PropertyDescriptor> visible = new HashSet<>();
      for (PropertyDescriptor descriptor : getBeanInfo(clazz).getPropertyDescriptors()) {
        // Any field lookup works here as long as it can also find fields of superclasses.
        Field field = ReflectUtil.getField(clazz, descriptor.getName());
        if (field != null && field.isAnnotationPresent(MyApiIgnore.class)) {
          // Skip properties marked with @MyApiIgnore.
          continue;
        }
        visible.add(descriptor);
      }
      return visible;
    } catch (IntrospectionException e) {
      LOG.warn(String.format("Failed to get bean properties on (%s)", clazz), e);
    }
    return emptySet();
  }

  /**
   * Indexes the descriptors by read method, dropping those without a getter and the
   * {@code isEmpty} pseudo-property.
   */
  private Map<Method, PropertyDescriptor> propertyDescriptorsByMethod(
    final Class<?> clazz,
    Set<PropertyDescriptor> propertyDescriptors) {
    return propertyDescriptors.stream()
      .filter(input -> input.getReadMethod() != null
        // NOTE(review): kept as copied. This is true only when clazz is Collection or one of its
        // supertypes; Collection.class.isAssignableFrom(clazz) may have been intended - confirm
        // against the springfox sources before changing.
        && !clazz.isAssignableFrom(Collection.class)
        && !"isEmpty".equals(input.getReadMethod().getName()))
      .collect(toMap(PropertyDescriptor::getReadMethod, identity()));

  }

  /** Introspects {@code clazz}; kept as a separate package-visible method. */
  BeanInfo getBeanInfo(Class<?> clazz) throws IntrospectionException {
    return Introspector.getBeanInfo(clazz);
  }

  public DocumentationPluginsManager getPluginsManager() {
    return pluginsManager;
  }

  public void setPluginsManager(DocumentationPluginsManager pluginsManager) {
    this.pluginsManager = pluginsManager;
  }
}

ParameterTypeDeterminer 和源码一模一样就行:

import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;

import java.util.Set;

/**
 * Local copy of the springfox helper of the same name, which is not accessible from outside its
 * package. Decides where a scalar request parameter is carried.
 */
class ParameterTypeDeterminer {
  /** Static utility; instantiation is rejected. */
  private ParameterTypeDeterminer() {
    throw new UnsupportedOperationException();
  }

  /**
   * Determines the parameter type of a scalar parameter.
   *
   * @param consumes media types the operation consumes
   * @param method HTTP method of the operation
   * @return {@code "form"} for a POST consuming URL-encoded forms, {@code "formData"} for a POST
   *     consuming multipart form data, otherwise {@code "query"}
   */
  public static String determineScalarParameterType(Set<? extends MediaType> consumes, HttpMethod method) {
    final boolean post = method == HttpMethod.POST;

    // URL-encoded forms take precedence over multipart when both are consumed.
    if (consumes.contains(MediaType.APPLICATION_FORM_URLENCODED) && post) {
      return "form";
    }
    if (consumes.contains(MediaType.MULTIPART_FORM_DATA) && post) {
      return "formData";
    }
    return "query";
  }
}

MyApiIgnore

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Marker annotation for fields that should be left out of the expanded Swagger parameters.
 * It is retained at runtime so it can be detected reflectively, as
 * {@code MyModelAttributeParameterExpander} does when collecting property descriptors.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface MyApiIgnore {
}

最后再修改下 PageQO,给需要忽略的属性加上 @MyApiIgnore 注解即可,当然之前的 @ApiModelProperty(hidden = true) 也可以替换成这个注解。

/**
 * Paging query object used as a controller parameter; extends MyBatis-Plus {@link Page} so the
 * same instance can be handed to a service {@code page(Page, Wrapper)} call.
 *
 * <p>The fields below have no accessors of their own; they are declared only to carry
 * annotations. Presumably each one mirrors a same-named field of {@code Page} — verify against
 * the MyBatis-Plus version in use.
 *
 * <p>Final version from the article: the two {@code List}-typed fields use {@code @MyApiIgnore},
 * which the custom parameter expander checks; the remaining hidden fields keep
 * {@code @ApiModelProperty(hidden = true)}.
 *
 * @param <T> record type of the page
 */
public class PageQO<T> extends Page<T> {
  // Non-basic type: excluded through the custom annotation.
  @MyApiIgnore
  protected List<T> records;
  @ApiModelProperty(hidden = true)
  protected long total;
  // Shown in Swagger; label reads "items per page".
  @ApiModelProperty("每页数量")
  protected long size;
  // Shown in Swagger; label reads "current page number".
  @ApiModelProperty("当前页码")
  protected long current;
  // Non-basic type: excluded through the custom annotation.
  @MyApiIgnore
  protected List<OrderItem> orders;
  @ApiModelProperty(hidden = true)
  protected boolean optimizeCountSql;
  @ApiModelProperty(hidden = true)
  protected boolean searchCount;
  @ApiModelProperty(hidden = true)
  protected boolean optimizeJoinOfCountSql;
  @ApiModelProperty(hidden = true)
  protected String countId;
  @ApiModelProperty(hidden = true)
  protected Long maxLimit;
  @ApiModelProperty(hidden = true)
  protected Long pages;
}

至于返回对象中含 Page 对象时,同最开始的理由,用的也是 entityService.page(Page, Wrapper) 方法。想要在 Swagger 中显示也很简单,定义一个 PageVO,在返回方法中重新 set 即可。

PageVO

import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import lombok.EqualsAndHashCode;

import java.util.List;

/**
 * Response view of a MyBatis-Plus {@link Page} that exposes only the records, the total record
 * count, the page size and the current page number, so Swagger documents just those fields.
 *
 * @param <T> record type of the page
 */
@EqualsAndHashCode(callSuper = false)
@Data
@ApiModel(value = "分页对象")
public class PageVO<T> {

  @ApiModelProperty("分页数据")
  private List<T> records;
  @ApiModelProperty("总记录数")
  private long total;
  // Fixed label: this field is copied from page.getSize(), i.e. the number of items per page.
  // It was previously documented as "总页数" (total pages).
  @ApiModelProperty("每页数量")
  private long size;
  @ApiModelProperty("当前页数")
  private long current;

  /**
   * Copies the displayable values out of {@code page}.
   *
   * @param page source page; must not be {@code null}
   */
  public PageVO(Page<T> page) {
    this.records = page.getRecords();
    this.total = page.getTotal();
    this.size = page.getSize();
    this.current = page.getCurrent();
  }
}

BaseController/Result/R

// ……
  /**
   * Builds a response with code 200 whose data is the given page converted to a {@link PageVO}.
   *
   * @param data page to return to the caller
   * @param <T> record type of the page
   * @return the populated response
   */
  protected <T> R<PageVO<T>> ok(Page<T> data) {
    final R<PageVO<T>> response = new R<>();
    response.setCode(200);
    final PageVO<T> payload = new PageVO<>(data);
    response.setData(payload);
    return response;
  }
  // ……

EntityController

// ……
  /**
   * Lists entities page by page. The query object doubles as the page passed to the service,
   * and no filter conditions are added to the wrapper.
   *
   * @param query validated paging parameters
   * @return the resulting page wrapped as a {@code PageVO}
   */
  @ApiOperation("XXX 列表")
  @GetMapping
  public R<PageVO<Entity>> list(@Validated EntityPageQO query) {
    LambdaQueryWrapper<Entity> unfiltered = new LambdaQueryWrapper<Entity>();
    EntityPageQO page = entityService.page(query, unfiltered);
    return ok(page);
  }
  // ……

最终效果:

参考来源

创作打卡挑战赛

赢取流量/现金/CSDN周边激励大奖

相关文章