diff --git a/docs/transformers/llm.md b/docs/transformers/llm.md index 922b2901cc..fd0a211b19 100644 --- a/docs/transformers/llm.md +++ b/docs/transformers/llm.md @@ -179,6 +179,39 @@ python llmexport.py \ cmake .. -DMNN_BUILD_CONVERTER=ON && make -j16 ``` 编译完成后 `build/` 目录下会生成 `MNNConvert` 可执行文件,`llmexport.py` 默认会在 `../../../build/` 下查找该工具;也可以通过 `--mnnconvert` 选项显式指定 MNNConvert 路径。若未提供本地 MNNConvert,脚本会回退到 pymnn(需先安装 `pip install MNN`)。此方案目前支持导出4bit和8bit模型。 +- 导出 segment 形式的 MNN LLM,使用 `--export mnn --segment`。该模式从 safetensors 权重和 workflow JSON 直接生成多个 MNN 子图,跳过 ONNX 中间文件,适合在 Metal 等后端上复用 decoder、logit、embedding 等 segment 模型。默认会在 `resource/*.json` 中查找匹配的 workflow,也可以通过 `--workflow /path/to/workflow.json` 显式指定。 + + ```bash + cd transformers/llm/export + python3 llmexport.py \ + --path /path/to/Qwen3-0.6B \ + --export mnn \ + --segment \ + --dst_path ./model + ``` + + segment 导出目录包含: + + ```text + model/ + ├── config.json # llm_demo 入口配置,包含 "mnn_llm_version": "segment" + ├── llm_config.json # 模型结构和模板配置 + ├── tokenizer.mtok + ├── embed.mnn + ├── decoder.mnn + ├── decoder.mnn.weight + ├── logit.mnn + ├── logit.mnn.weight + └── logit_topkv_1.mnn + ``` + + 运行 segment 模型时需要使用生成的 `config.json`: + + ```bash + ./llm_demo transformers/llm/export/model/config.json /path/to/prompt.txt + ``` + + C++ 运行时需启用 `MNN_BUILD_LLM=ON`,并打开 `MNN_LLM_SUPPORT_SEGMENT`(默认开启)。segment 路径当前仅支持 `--export mnn`,不支持 `--export onnx`。 - 如果直接转为mnn模型遇到问题,或者需要其他bits数的量化(如5bit/6bit),可以先将模型先转为onnx模型,使用`--export onnx`,然后使用./MNNConvert工具将onnx模型转为mnn模型: ``` @@ -197,7 +230,7 @@ usage: llmexport.py [-h] --path PATH [--type TYPE] [--tokenizer_path TOKENIZER_P [--gptq_path GPTQ_PATH] [--dst_path DST_PATH] [--verbose] [--test TEST] [--export EXPORT] [--onnx_slim] [--quant_bit QUANT_BIT] [--quant_block QUANT_BLOCK] [--lm_quant_bit LM_QUANT_BIT] [--mnnconvert MNNCONVERT] [--ppl] [--awq] [--omni] [--sym] [--seperate_embed] - [--lora_split] + [--lora_split] [--segment] [--workflow WORKFLOW] llm_exporter @@ -219,6 +252,8 @@ optional arguments: --verbose Whether or not to print verbose. --test TEST test model inference with query `TEST`. --export EXPORT export model to an onnx/mnn model. + --segment export segment MNN LLM from safetensors workflow directly, without ONNX export. + --workflow WORKFLOW workflow json for --segment safetensors conversion. If absent, search resource/*.json. --onnx_slim Whether or not to use onnx-slim. --quant_bit QUANT_BIT mnn quant bit, 4 or 8, default is 4. @@ -1158,4 +1193,4 @@ adb push model /data/local/tmp/MNN/model ``` cd ${MNN_ROOT} project/android/testCommon.sh ./llm_demo model/config_mlda.json -``` \ No newline at end of file +``` diff --git a/express/MathOp.cpp b/express/MathOp.cpp index ac04f27ee0..fa662d9127 100644 --- a/express/MathOp.cpp +++ b/express/MathOp.cpp @@ -584,6 +584,10 @@ VARP _Multiply(VARP x, VARP y) { return _Binary(x, y, BinaryOpOperation_MUL); } +VARP _MulSilu(VARP x, VARP y) { + return _Binary(x, y, BinaryOpOperation_MUL_SILU); +} + /*Computes Python style division of x by y. Args: x: A variable. Must be one of the following types: diff --git a/express/module/StaticModule.cpp b/express/module/StaticModule.cpp index bbc9f9b394..1eae212ef6 100644 --- a/express/module/StaticModule.cpp +++ b/express/module/StaticModule.cpp @@ -22,35 +22,81 @@ namespace MNN { namespace Express { -static const StaticModule* getStaticModule(const Module* m) { - if (m->type() == "StaticModule") { - return static_cast(m); +using ExecutionCacheKey = std::tuple; +using ExecutionCacheMap = std::map>; + +static ExecutionCacheKey makeExecutionCacheKey(const Op* op) { + return std::make_tuple(op->name()->str(), static_cast(op->type()), static_cast(op->main_type())); +} + +static bool supportPrearrangeClone(const Op* op) { + return op->main_type() == OpParameter_Convolution2D || op->main_type() == OpParameter_LayerNorm || + op->type() == OpType_Attention || op->type() == OpType_Scale || op->type() == OpType_RoPE || + op->type() == OpType_GatherV2; +} + +static void collectStaticModuleExecutions(const StaticModule* module, ExecutionCacheMap& executeMap) { + auto session = module->getSession(); + std::vector opCaches = session->getPipelineInfo(0).second; + for (auto& opCache : opCaches) { + const auto& exeCache = opCache.executionCache; + for (const auto& exeItem : exeCache) { + if (supportPrearrangeClone(exeItem.first) && exeItem.first->name()) { + executeMap.insert(std::make_pair(makeExecutionCacheKey(exeItem.first), exeItem.second)); + } + } + } +} + +static void collectBaseExecutions(const Module* base, ExecutionCacheMap& executeMap) { + if (base == nullptr) { + return; + } + if (base->type() == "StaticModule") { + collectStaticModuleExecutions(static_cast(base), executeMap); + return; } - if (m->getChildren().empty()) { - return nullptr; + for (const auto& child : base->getChildren()) { + collectBaseExecutions(child.get(), executeMap); } - return getStaticModule(m->getChildren()[0].get()); } -static std::vector> preRearrangeWeights( // NOLINT - Schedule::ScheduleInfo& scheduleInfo, Backend* firstbackend, Backend* backupBackend, const Module* base = nullptr) { - std::map> base_executions; - if (base != nullptr) { - // has base module - auto static_module = getStaticModule(base); - if (static_module) { - auto session = static_module->getSession(); - std::vector op_caches = session->getPipelineInfo(0).second; - for (auto& op_cache : op_caches) { - const auto& exe_cache = op_cache.executionCache; - for (const auto& exe_item : exe_cache) { - if (exe_item.first->name()) { - base_executions.insert(std::make_pair(exe_item.first->name()->str(), exe_item.second)); - } - } +static bool cloneBaseExecution(std::shared_ptr& exe, const ExecutionCacheMap& baseExecutions, const Op* op, + Backend* backend, Backend* backupBackend) { + if (baseExecutions.empty() || !op->name()) { + return false; + } + auto iter = baseExecutions.find(makeExecutionCacheKey(op)); + if (iter == baseExecutions.end() && op->type() == OpType_GatherV2) { + for (auto candidate = baseExecutions.begin(); candidate != baseExecutions.end(); ++candidate) { + if (std::get<0>(candidate->first) == op->name()->str() && + std::get<2>(candidate->first) == OpParameter_Convolution2D) { + iter = candidate; + break; } } } + if (iter == baseExecutions.end()) { + return false; + } + Execution* copyExecution = nullptr; + auto baseExe = iter->second.get(); + baseExe->onClone(backend, op, ©Execution); + if (copyExecution == nullptr) { + baseExe->onClone(backupBackend, op, ©Execution); + } + std::unique_ptr cloned(copyExecution); + if (cloned == nullptr || !cloned->onClone(nullptr, op, nullptr)) { + return false; + } + exe.reset(cloned.release()); + return true; +} + +static std::vector> preRearrangeWeights( // NOLINT + Schedule::ScheduleInfo& scheduleInfo, Backend* firstbackend, Backend* backupBackend, const Module::Config& config) { + ExecutionCacheMap base_executions; + collectBaseExecutions(config.base, base_executions); FileLoader loader(scheduleInfo.externalWeightPath.c_str()); auto&& pipelineInfo = scheduleInfo.pipelineInfo[0].second; std::vector> splitOps(pipelineInfo.size()); @@ -58,32 +104,22 @@ static std::vector> preRearrangeWeights( // NOLIN std::map> kvAttentionRegistry; for (int i = 0; i < pipelineInfo.size(); ++i) { auto& info = pipelineInfo[i]; - auto op = pipelineInfo[i].op; + auto op = pipelineInfo[i].op; std::unique_ptr op_table(op->UnPack()); std::shared_ptr exe; Backend* backend = firstbackend; if (info.type == Schedule::CONSTANT) { backend = backupBackend; } + if (op->type() == MNN::OpType_GatherV2) { + cloneBaseExecution(exe, base_executions, op, backend, backupBackend); + } switch (op->type()) { case MNN::OpType_DepthwiseConvInt8: case MNN::OpType_ConvInt8: case MNN::OpType_ConvolutionDepthwise: case MNN::OpType_Convolution: { - if (!base_executions.empty() && op->name()) { - auto iter = base_executions.find(op->name()->str()); - if (iter != base_executions.end()) { - auto base_exe = iter->second.get(); - Execution* copyExecution = nullptr; - base_exe->onClone(backend, op, ©Execution); - if (copyExecution == nullptr) { - base_exe->onClone(backupBackend, op, ©Execution); - } - if (copyExecution != nullptr && copyExecution->onClone(nullptr, op, nullptr)) { - exe.reset(copyExecution); - } - } - } + cloneBaseExecution(exe, base_executions, op, backend, backupBackend); if (exe == nullptr) { DataType type = DataType_DT_FLOAT; auto conv2d = op->main_as_Convolution2D(); @@ -96,21 +132,23 @@ static std::vector> preRearrangeWeights( // NOLIN int ow = 2, oh = 2; int iw = (common->kernelX() - 1) * common->dilateX() + common->strideX() * (ow - 1) + 1; int ih = (common->kernelY() - 1) * common->dilateY() + common->strideY() * (oh - 1) + 1; - TensorUtils::getDescribe(tempInput)->dimensionFormat = MNN_DATA_FORMAT_NC4HW4;; + TensorUtils::getDescribe(tempInput)->dimensionFormat = MNN_DATA_FORMAT_NC4HW4; tempInput->setLength(0, 1); tempInput->setLength(1, conv2d->common()->inputCount()); tempInput->setLength(2, ih); tempInput->setLength(3, iw); - TensorUtils::getDescribe(tempOutput)->dimensionFormat = MNN_DATA_FORMAT_NC4HW4;; + TensorUtils::getDescribe(tempOutput)->dimensionFormat = MNN_DATA_FORMAT_NC4HW4; tempOutput->setLength(0, 1); tempOutput->setLength(1, conv2d->common()->outputCount()); tempOutput->setLength(2, oh); tempOutput->setLength(3, ow); } std::shared_ptr tmpstorage; - exe.reset(OpCommonUtils::createExecutionWithExternal(backend, info.inputs, info.outputs, op, &loader, tmpstorage)); + exe.reset(OpCommonUtils::createExecutionWithExternal(backend, info.inputs, info.outputs, op, + &loader, tmpstorage)); if (exe.get() == nullptr) { - exe.reset(OpCommonUtils::createExecutionWithExternal(backupBackend, info.inputs, info.outputs, op, &loader, tmpstorage)); + exe.reset(OpCommonUtils::createExecutionWithExternal(backupBackend, info.inputs, info.outputs, + op, &loader, tmpstorage)); } if (nullptr == exe) { break; @@ -136,8 +174,7 @@ static std::vector> preRearrangeWeights( // NOLIN break; } case MNN::OpType_Attention: - case MNN::OpType_LinearAttention: - { + case MNN::OpType_LinearAttention: { // KV Cache sharing: clone from source Attention's execution instead of creating new if (op->type() == OpType_Attention && op->main_type() == OpParameter_AttentionParam) { auto param = op->main_as_AttentionParam(); @@ -176,26 +213,17 @@ static std::vector> preRearrangeWeights( // NOLIN } break; } - case MNN::OpType_LayerNorm: { - if (!base_executions.empty() && op->name()) { - auto iter = base_executions.find(op->name()->str()); - if (iter != base_executions.end()) { - auto base_exe = iter->second.get(); - Execution* copyExecution = nullptr; - base_exe->onClone(backend, op, ©Execution); - if (copyExecution == nullptr) { - base_exe->onClone(backupBackend, op, ©Execution); - } - if (copyExecution != nullptr && copyExecution->onClone(nullptr, op, nullptr)) { - exe.reset(copyExecution); - } - } - } + case MNN::OpType_LayerNorm: + case MNN::OpType_Scale: + case MNN::OpType_RoPE: { + cloneBaseExecution(exe, base_executions, op, backend, backupBackend); if (exe == nullptr) { std::shared_ptr tmpstorage; - exe.reset(OpCommonUtils::createExecutionWithExternal(backend, info.inputs, info.outputs, op, &loader, tmpstorage)); + exe.reset(OpCommonUtils::createExecutionWithExternal(backend, info.inputs, info.outputs, op, + &loader, tmpstorage)); if (exe.get() == nullptr) { - exe.reset(OpCommonUtils::createExecutionWithExternal(backupBackend, info.inputs, info.outputs, op, &loader, tmpstorage)); + exe.reset(OpCommonUtils::createExecutionWithExternal(backupBackend, info.inputs, info.outputs, + op, &loader, tmpstorage)); } if (nullptr == exe) { break; @@ -279,7 +307,8 @@ void StaticModule::resetInputOutputs() { if (des->usage != Tensor::InsideDescribe::CONSTANT && des->usage != Tensor::InsideDescribe::TRAINABLE) { des->usage = Tensor::InsideDescribe::INPUT; } - pipelineInfo.first.inputTensorCopyCache.insert(std::make_pair(mInputTensors[i], std::make_tuple(nullptr, nullptr, true, true))); + pipelineInfo.first.inputTensorCopyCache.insert( + std::make_pair(mInputTensors[i], std::make_tuple(nullptr, nullptr, true, true))); mPrevInputTensor[i].first = nullptr; mPrevInputTensor[i].second = MNN_FORWARD_CPU; } @@ -317,14 +346,14 @@ void StaticModule::resetInputOutputs() { mOutputTensors[i] = mSession->getTensor(mResource->mOutputs[mResource->mOutputFromTensor[i]]); auto des = TensorUtils::getDescribe(mOutputTensors[i]); if (des->usage == Tensor::InsideDescribe::CONSTANT && des->isMutable) { - des->useCount ++; + des->useCount++; } } for (auto& info : infos.second) { if (info.type != Schedule::Type::CONSTANT) { continue; } - for (int v=0; vusage == Tensor::InsideDescribe::CONSTANT && des->isMutable) { des->useCount--; @@ -336,15 +365,10 @@ void StaticModule::resetInputOutputs() { } } -StaticModule::StaticModule(std::vector inputs, - std::vector outputs, - std::vector>&& buffer, - Schedule::ScheduleInfo&& scheduleInfo, - std::shared_ptr sharedConst, - Session::ModeGroup&& mode, - std::shared_ptr rtm, - const Module::Config& config - ) { +StaticModule::StaticModule(std::vector inputs, std::vector outputs, + std::vector>&& buffer, Schedule::ScheduleInfo&& scheduleInfo, + std::shared_ptr sharedConst, Session::ModeGroup&& mode, + std::shared_ptr rtm, const Module::Config& config) { setType("StaticModule"); mResource.reset(new Resource); mRuntimeManager = rtm; @@ -355,7 +379,8 @@ StaticModule::StaticModule(std::vector inputs, mResource->mSharedConst = sharedConst; mResource->mModes = std::move(mode); mResource->mBnInfo.user = &mResource->mBnConfig; - mResource->mModes.inputMode = config.shapeMutable ? Interpreter::Session_Input_User : Interpreter::Session_Input_Inside; + mResource->mModes.inputMode = + config.shapeMutable ? Interpreter::Session_Input_User : Interpreter::Session_Input_Inside; mResource->mModes.outputMode = Interpreter::Session_Output_User; std::shared_ptr net_storage; std::map, DataType>> exeCache; @@ -370,7 +395,8 @@ StaticModule::StaticModule(std::vector inputs, bnCache.cache.first->pNPUModelDirPath = rtm->getInside()->mContent->mNpuDir; bnCache.cache.second->pNPUModelDirPath = rtm->getInside()->mContent->mNpuDir; if (config.rearrange) { - mResource->mBuffer = preRearrangeWeights(scheduleInfo, bnCache.cache.first.get(), bnCache.cache.second.get(), config.base); + mResource->mBuffer = + preRearrangeWeights(scheduleInfo, bnCache.cache.first.get(), bnCache.cache.second.get(), config); } else { mResource->mBuffer = std::move(buffer); } @@ -380,7 +406,7 @@ StaticModule::StaticModule(std::vector inputs, std::vector mOutputFromInput; */ for (int i = 0; i < outputs.size(); ++i) { - auto& t = outputs[i]; + auto& t = outputs[i]; bool fromInput = false; for (int j = 0; j < inputs.size(); ++j) { if (inputs[j] == t) { @@ -403,11 +429,11 @@ StaticModule::StaticModule(std::vector inputs, } mResource->mInputs = std::move(inputs); mResource->mInputNeedCPU.resize(mResource->mInputs.size()); - for (int i=0; imInputs.size(); ++i) { + for (int i = 0; i < mResource->mInputs.size(); ++i) { mResource->mInputNeedCPU[i] = false; } if (mResource->mUseContentInputs) { - for (int i=0; imInputs.size(); ++i) { + for (int i = 0; i < mResource->mInputs.size(); ++i) { auto subT = scheduleInfo.allTensors[mResource->mInputs[i]].get(); if (TensorUtils::getDescribe(subT)->usage == Tensor::InsideDescribe::CONSTANT) { mResource->mInputNeedCPU[i] = true; @@ -424,11 +450,11 @@ StaticModule::StaticModule(std::vector inputs, } } StaticModule::~StaticModule() { - mSession = nullptr; + mSession = nullptr; } void StaticModule::onClearCache() { if (nullptr != mSession) { - for (int i=0; igetPipelineInfo(0).first.inputTensorCopyCache) { @@ -465,7 +491,8 @@ ErrorCode StaticModule::_resize(const std::vector& inputs) { std::get<3>(cacheIter->second) = true; mPrevInputTensor[i] = std::make_pair(inputTensor, newType); if (std::get<1>(*cacheTensor) != nullptr) { - if (!WrapExecution::needWrap(inputTensor, TensorUtils::getDescribeOrigin(std::get<0>(*cacheTensor))->getBackend())) { + if (!WrapExecution::needWrap( + inputTensor, TensorUtils::getDescribeOrigin(std::get<0>(*cacheTensor))->getBackend())) { // No need copy now, reset it cacheIter->second = std::make_tuple(nullptr, nullptr, true, true); } @@ -523,8 +550,8 @@ ErrorCode StaticModule::_resize(const std::vector& inputs) { mSession->setNeedResize(); } if (!needResize) { - // Check if output is used by other vars. If used, must realloc output to avoid the content dirty for output vars - // If resized, the output's memory will be all released in Session::resize, don't need clear here + // Check if output is used by other vars. If used, must realloc output to avoid the content dirty for output + // vars If resized, the output's memory will be all released in Session::resize, don't need clear here for (auto& output : mOutputTensors) { auto desOrigin = TensorUtils::getDescribeOrigin(output); if ((!desOrigin->mContent->isMutable) || nullptr == desOrigin->mem.get()) { @@ -534,7 +561,8 @@ ErrorCode StaticModule::_resize(const std::vector& inputs) { if (nullptr == bn) { continue; } - if (desOrigin->mContent.use_count() > 1 && desOrigin->mContent->usage != Tensor::InsideDescribe::CONSTANT) { + if (desOrigin->mContent.use_count() > 1 && + desOrigin->mContent->usage != Tensor::InsideDescribe::CONSTANT) { desOrigin->mem = nullptr; auto res = bn->onAcquireBuffer(output, Backend::STATIC); if (!res) { @@ -568,7 +596,7 @@ ErrorCode StaticModule::_resize(const std::vector& inputs) { if (nullptr == mInputTensors[i]) { continue; } - auto exprInfo = inputs[i]->expr(); + auto exprInfo = inputs[i]->expr(); auto inputTensor = Utils::getTensor(inputs[i]); mInputTensors[i]->copyFromHostTensor(inputTensor); } @@ -594,7 +622,6 @@ ErrorCode StaticModule::_execute() { } std::vector StaticModule::onForward(const std::vector& inputs) { - AUTOTIME; // Apply before resize/clone may construct new Backends (e.g. onClone path). if (mRuntimeManager) { @@ -651,13 +678,17 @@ std::vector StaticModule::onForward(const std::vectormOutputFromTensor[i]] = Express::Variable::create(Express::Expr::create(tensor, true)); auto backend = TensorUtils::getDescribeOrigin(tensor)->getBackend(); if (backend == pipelineInfo.first.cache.first.get()) { - outputs[mResource->mOutputFromTensor[i]]->expr().first->inside()->mHoldBackend = pipelineInfo.first.cache.first; + outputs[mResource->mOutputFromTensor[i]]->expr().first->inside()->mHoldBackend = + pipelineInfo.first.cache.first; } else if (backend == pipelineInfo.first.cache.second.get()) { - outputs[mResource->mOutputFromTensor[i]]->expr().first->inside()->mHoldBackend = pipelineInfo.first.cache.second; + outputs[mResource->mOutputFromTensor[i]]->expr().first->inside()->mHoldBackend = + pipelineInfo.first.cache.second; } else if (backend == mResource->mSharedConst->defaultBackend.get()) { - outputs[mResource->mOutputFromTensor[i]]->expr().first->inside()->mHoldBackend = mResource->mSharedConst->defaultBackend; + outputs[mResource->mOutputFromTensor[i]]->expr().first->inside()->mHoldBackend = + mResource->mSharedConst->defaultBackend; } else if (backend == mResource->mSharedConst->constReplaceBackend.get()) { - outputs[mResource->mOutputFromTensor[i]]->expr().first->inside()->mHoldBackend = mResource->mSharedConst->constReplaceBackend; + outputs[mResource->mOutputFromTensor[i]]->expr().first->inside()->mHoldBackend = + mResource->mSharedConst->constReplaceBackend; } } if (mShapeInferSeperate && runResize) { @@ -697,7 +728,8 @@ int StaticModule::onOptimize(Interpreter::SessionMode stage) { mSession->fixResizeCache(); break; case MNN::Interpreter::Module_Forward_Separate: - if (mResource->mUseContentInputs || mResource->mModes.inputMode != Interpreter::Session_Input_User || mResource->mOutputFromTensor.empty()) { + if (mResource->mUseContentInputs || mResource->mModes.inputMode != Interpreter::Session_Input_User || + mResource->mOutputFromTensor.empty()) { res = NOT_SUPPORT; break; } diff --git a/include/MNN/expr/MathOp.hpp b/include/MNN/expr/MathOp.hpp index 881003c536..4135ce5e18 100644 --- a/include/MNN/expr/MathOp.hpp +++ b/include/MNN/expr/MathOp.hpp @@ -15,6 +15,7 @@ namespace Express { MNN_PUBLIC VARP _Add(VARP x, VARP y); MNN_PUBLIC VARP _Subtract(VARP x, VARP y); MNN_PUBLIC VARP _Multiply(VARP x, VARP y); +MNN_PUBLIC VARP _MulSilu(VARP x, VARP y); MNN_PUBLIC VARP _Divide(VARP x, VARP y); MNN_PUBLIC VARP _Pow(VARP x, VARP y); MNN_PUBLIC VARP _Minimum(VARP x, VARP y); diff --git a/resource/qwen3_hf_0.6b.json b/resource/qwen3_hf_0.6b.json new file mode 100644 index 0000000000..33c5f1c3eb --- /dev/null +++ b/resource/qwen3_hf_0.6b.json @@ -0,0 +1,34 @@ +{ + "models": [ + { + "name": "hf_decoder", + "blocks": [ + { + "type": "QwenTransformer", + "hiddenSize": 1024, + "headDim": 128, + "numHead": 16, + "kvNumHead": 8, + "number": 28, + "max_position_embeddings": 40960 + } + ] + }, + { + "name": "logit", + "blocks": [ + { + "type": "InnerProduct", + "prefix": "lm_head" + }, + { + "type": "TieEmbedding" + }, + { + "type": "TopKV", + "K": [1, 5] + } + ] + } + ] +} diff --git a/schema/current/MNN_generated.h b/schema/current/MNN_generated.h index 60c764f07f..035efd4ca1 100644 --- a/schema/current/MNN_generated.h +++ b/schema/current/MNN_generated.h @@ -289,6 +289,7 @@ enum OpType { OpType_SplitGeLU = 303, OpType_GroupNorm = 304, OpType_LinearAttention = 305, + OpType_RoPE = 306, OpType_Extra = 512, OpType_ConvInt8 = 513, OpType_Int8ToFloat = 514, @@ -303,7 +304,7 @@ enum OpType { OpType_MAX = OpType_GridSample }; -inline const OpType (&EnumValuesOpType())[183] { +inline const OpType (&EnumValuesOpType())[184] { static const OpType values[] = { OpType_AbsVal, OpType_QuantizedAdd, @@ -478,6 +479,7 @@ inline const OpType (&EnumValuesOpType())[183] { OpType_SplitGeLU, OpType_GroupNorm, OpType_LinearAttention, + OpType_RoPE, OpType_Extra, OpType_ConvInt8, OpType_Int8ToFloat, @@ -800,7 +802,7 @@ inline const char * const *EnumNamesOpType() { "SplitGeLU", "GroupNorm", "LinearAttention", - "", + "RoPE", "", "", "", @@ -2999,10 +3001,15 @@ struct AttentionParamT : public flatbuffers::NativeTable { std::string kv_shared_layer; int32_t layer_index; int32_t kv_shared_layer_index; + std::vector> mhq_quant; + bool output_c4; + float attnScale; AttentionParamT() : kv_cache(true), layer_index(-1), - kv_shared_layer_index(-1) { + kv_shared_layer_index(-1), + output_c4(false), + attnScale(0.0f) { } }; @@ -3023,6 +3030,15 @@ struct AttentionParam FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { int32_t kv_shared_layer_index() const { return GetField(10, -1); } + const flatbuffers::Vector> *mhq_quant() const { + return GetPointer> *>(12); + } + bool output_c4() const { + return GetField(14, 0) != 0; + } + float attnScale() const { + return GetField(16, 0.0f); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, 4) && @@ -3030,6 +3046,11 @@ struct AttentionParam FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.VerifyString(kv_shared_layer()) && VerifyField(verifier, 8) && VerifyField(verifier, 10) && + VerifyOffset(verifier, 12) && + verifier.VerifyVector(mhq_quant()) && + verifier.VerifyVectorOfTables(mhq_quant()) && + VerifyField(verifier, 14) && + VerifyField(verifier, 16) && verifier.EndTable(); } AttentionParamT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -3052,6 +3073,15 @@ struct AttentionParamBuilder { void add_kv_shared_layer_index(int32_t kv_shared_layer_index) { fbb_.AddElement(10, kv_shared_layer_index, -1); } + void add_mhq_quant(flatbuffers::Offset>> mhq_quant) { + fbb_.AddOffset(12, mhq_quant); + } + void add_output_c4(bool output_c4) { + fbb_.AddElement(14, static_cast(output_c4), 0); + } + void add_attnScale(float attnScale) { + fbb_.AddElement(16, attnScale, 0.0f); + } explicit AttentionParamBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -3069,11 +3099,17 @@ inline flatbuffers::Offset CreateAttentionParam( bool kv_cache = true, flatbuffers::Offset kv_shared_layer = 0, int32_t layer_index = -1, - int32_t kv_shared_layer_index = -1) { + int32_t kv_shared_layer_index = -1, + flatbuffers::Offset>> mhq_quant = 0, + bool output_c4 = false, + float attnScale = 0.0f) { AttentionParamBuilder builder_(_fbb); + builder_.add_attnScale(attnScale); + builder_.add_mhq_quant(mhq_quant); builder_.add_kv_shared_layer_index(kv_shared_layer_index); builder_.add_layer_index(layer_index); builder_.add_kv_shared_layer(kv_shared_layer); + builder_.add_output_c4(output_c4); builder_.add_kv_cache(kv_cache); return builder_.Finish(); } @@ -5499,6 +5535,9 @@ inline void AttentionParam::UnPackTo(AttentionParamT *_o, const flatbuffers::res { auto _e = kv_shared_layer(); if (_e) _o->kv_shared_layer = _e->str(); }; { auto _e = layer_index(); _o->layer_index = _e; }; { auto _e = kv_shared_layer_index(); _o->kv_shared_layer_index = _e; }; + { auto _e = mhq_quant(); if (_e) { _o->mhq_quant.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->mhq_quant[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } }; + { auto _e = output_c4(); _o->output_c4 = _e; }; + { auto _e = attnScale(); _o->attnScale = _e; }; } inline flatbuffers::Offset AttentionParam::Pack(flatbuffers::FlatBufferBuilder &_fbb, const AttentionParamT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -5513,12 +5552,18 @@ inline flatbuffers::Offset CreateAttentionParam(flatbuffers::Fla auto _kv_shared_layer = _o->kv_shared_layer.empty() ? 0 : _fbb.CreateString(_o->kv_shared_layer); auto _layer_index = _o->layer_index; auto _kv_shared_layer_index = _o->kv_shared_layer_index; + auto _mhq_quant = _o->mhq_quant.size() ? _fbb.CreateVector> (_o->mhq_quant.size(), [](size_t i, _VectorArgs *__va) { return CreateTensorQuantInfo(*__va->__fbb, __va->__o->mhq_quant[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _output_c4 = _o->output_c4; + auto _attnScale = _o->attnScale; return MNN::CreateAttentionParam( _fbb, _kv_cache, _kv_shared_layer, _layer_index, - _kv_shared_layer_index); + _kv_shared_layer_index, + _mhq_quant, + _output_c4, + _attnScale); } inline LinearAttentionParamT *LinearAttentionParam::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -7768,7 +7813,7 @@ inline OpParameterUnion::OpParameterUnion(const OpParameterUnion &u) FLATBUFFERS break; } case OpParameter_AttentionParam: { - value = new AttentionParamT(*reinterpret_cast(u.value)); + FLATBUFFERS_ASSERT(false); // AttentionParamT not copyable. break; } case OpParameter_StftParam: { @@ -8485,12 +8530,13 @@ inline const flatbuffers::TypeTable *OpTypeTypeTable() { { flatbuffers::ET_INT, 0, 0 }, { flatbuffers::ET_INT, 0, 0 }, { flatbuffers::ET_INT, 0, 0 }, + { flatbuffers::ET_INT, 0, 0 }, { flatbuffers::ET_INT, 0, 0 } }; static const flatbuffers::TypeFunction type_refs[] = { OpTypeTypeTable }; - static const int64_t values[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 299, 300, 301, 302, 303, 304, 305, 512, 513, 514, 515, 517, 518, 600, 601, 603, 604 }; + static const int64_t values[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 299, 300, 301, 302, 303, 304, 305, 306, 512, 513, 514, 515, 517, 518, 600, 601, 603, 604 }; static const char * const names[] = { "AbsVal", "QuantizedAdd", @@ -8665,6 +8711,7 @@ inline const flatbuffers::TypeTable *OpTypeTypeTable() { "SplitGeLU", "GroupNorm", "LinearAttention", + "RoPE", "Extra", "ConvInt8", "Int8ToFloat", @@ -8677,7 +8724,7 @@ inline const flatbuffers::TypeTable *OpTypeTypeTable() { "GridSample" }; static const flatbuffers::TypeTable tt = { - flatbuffers::ST_ENUM, 183, type_codes, type_refs, values, names + flatbuffers::ST_ENUM, 184, type_codes, type_refs, values, names }; return &tt; } @@ -9110,16 +9157,25 @@ inline const flatbuffers::TypeTable *AttentionParamTypeTable() { { flatbuffers::ET_BOOL, 0, -1 }, { flatbuffers::ET_STRING, 0, -1 }, { flatbuffers::ET_INT, 0, -1 }, - { flatbuffers::ET_INT, 0, -1 } + { flatbuffers::ET_INT, 0, -1 }, + { flatbuffers::ET_SEQUENCE, 1, 0 }, + { flatbuffers::ET_BOOL, 0, -1 }, + { flatbuffers::ET_FLOAT, 0, -1 } + }; + static const flatbuffers::TypeFunction type_refs[] = { + TensorQuantInfoTypeTable }; static const char * const names[] = { "kv_cache", "kv_shared_layer", "layer_index", - "kv_shared_layer_index" + "kv_shared_layer_index", + "mhq_quant", + "output_c4", + "attnScale" }; static const flatbuffers::TypeTable tt = { - flatbuffers::ST_TABLE, 4, type_codes, nullptr, nullptr, names + flatbuffers::ST_TABLE, 7, type_codes, type_refs, nullptr, names }; return &tt; } diff --git a/schema/current/TensorflowOp_generated.h b/schema/current/TensorflowOp_generated.h index c567eaf96f..00edade18a 100644 --- a/schema/current/TensorflowOp_generated.h +++ b/schema/current/TensorflowOp_generated.h @@ -234,11 +234,12 @@ enum BinaryOpOperation { BinaryOpOperation_LOGICALXOR = 26, BinaryOpOperation_LEFTSHIFT = 27, BinaryOpOperation_RIGHTSHIFT = 28, + BinaryOpOperation_MUL_SILU = 29, BinaryOpOperation_MIN = BinaryOpOperation_ADD, - BinaryOpOperation_MAX = BinaryOpOperation_RIGHTSHIFT + BinaryOpOperation_MAX = BinaryOpOperation_MUL_SILU }; -inline const BinaryOpOperation (&EnumValuesBinaryOpOperation())[28] { +inline const BinaryOpOperation (&EnumValuesBinaryOpOperation())[29] { static const BinaryOpOperation values[] = { BinaryOpOperation_ADD, BinaryOpOperation_SUB, @@ -267,7 +268,8 @@ inline const BinaryOpOperation (&EnumValuesBinaryOpOperation())[28] { BinaryOpOperation_BITWISE_XOR, BinaryOpOperation_LOGICALXOR, BinaryOpOperation_LEFTSHIFT, - BinaryOpOperation_RIGHTSHIFT + BinaryOpOperation_RIGHTSHIFT, + BinaryOpOperation_MUL_SILU }; return values; } @@ -303,13 +305,14 @@ inline const char * const *EnumNamesBinaryOpOperation() { "LOGICALXOR", "LEFTSHIFT", "RIGHTSHIFT", + "MUL_SILU", nullptr }; return names; } inline const char *EnumNameBinaryOpOperation(BinaryOpOperation e) { - if (e < BinaryOpOperation_ADD || e > BinaryOpOperation_RIGHTSHIFT) return ""; + if (e < BinaryOpOperation_ADD || e > BinaryOpOperation_MUL_SILU) return ""; const size_t index = static_cast(e); return EnumNamesBinaryOpOperation()[index]; } @@ -4957,12 +4960,13 @@ inline const flatbuffers::TypeTable *BinaryOpOperationTypeTable() { { flatbuffers::ET_INT, 0, 0 }, { flatbuffers::ET_INT, 0, 0 }, { flatbuffers::ET_INT, 0, 0 }, + { flatbuffers::ET_INT, 0, 0 }, { flatbuffers::ET_INT, 0, 0 } }; static const flatbuffers::TypeFunction type_refs[] = { BinaryOpOperationTypeTable }; - static const int64_t values[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28 }; + static const int64_t values[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29 }; static const char * const names[] = { "ADD", "SUB", @@ -4991,10 +4995,11 @@ inline const flatbuffers::TypeTable *BinaryOpOperationTypeTable() { "BITWISE_XOR", "LOGICALXOR", "LEFTSHIFT", - "RIGHTSHIFT" + "RIGHTSHIFT", + "MUL_SILU" }; static const flatbuffers::TypeTable tt = { - flatbuffers::ST_ENUM, 28, type_codes, type_refs, values, names + flatbuffers::ST_ENUM, 29, type_codes, type_refs, values, names }; return &tt; } diff --git a/schema/default/MNN.fbs b/schema/default/MNN.fbs index 73f1c5854f..77f777d704 100644 --- a/schema/default/MNN.fbs +++ b/schema/default/MNN.fbs @@ -195,6 +195,7 @@ enum OpType : int { SplitGeLU = 303, GroupNorm = 304, LinearAttention = 305, + RoPE = 306, Extra = 512, // quantization @@ -233,6 +234,9 @@ table AttentionParam { kv_shared_layer: string; layer_index: int = -1; kv_shared_layer_index: int = -1; + mhq_quant:[TensorQuantInfo]; // qk_scale_q, qk_scale_k, sv_scale_s, sv_scale_v + output_c4:bool = false; + attnScale: float = 0.0; } table LinearAttentionParam { diff --git a/schema/default/TensorflowOp.fbs b/schema/default/TensorflowOp.fbs index 89cd0c79b7..43826e90f9 100644 --- a/schema/default/TensorflowOp.fbs +++ b/schema/default/TensorflowOp.fbs @@ -30,6 +30,7 @@ enum BinaryOpOperation : int { LOGICALXOR = 26, LEFTSHIFT = 27, RIGHTSHIFT = 28, + MUL_SILU = 29, } table BinaryOp { diff --git a/skills/support-new-llm/SKILL.md b/skills/support-new-llm/SKILL.md index 0ea0629217..5620553093 100644 --- a/skills/support-new-llm/SKILL.md +++ b/skills/support-new-llm/SKILL.md @@ -1,11 +1,11 @@ --- name: support-new-llm -description: 为 MNN 框架添加新的 LLM 模型支持。支持从 HuggingFace/ModelScope 下载模型,分析架构,添加映射,Hook 对齐测试,导出 MNN 模型。采用 TDD 模式,分 6 步执行,每步有独立测试标准。 +description: 为 MNN 框架添加新的 LLM 模型支持。支持从 HuggingFace/ModelScope 下载模型,分析架构,添加映射,Hook 对齐测试,导出 MNN 模型;当用户明确要求 safetensors/segment/workflow/MNNConvert -f ST 时,走 safetensors segment 补充分支。采用 TDD 模式,分 6 步执行,每步有独立测试标准。 --- # MNN LLM 新模型支持 SKILL -> **触发条件**:当用户请求支持/添加/适配一个新的 LLM 模型时触发。常见表述包括:"支持xxx模型"、"添加xxx模型支持"、"适配xxx"、"导出xxx模型"等。 +> **触发条件**:当用户请求支持/添加/适配一个新的 LLM 模型时触发。常见表述包括:"支持xxx模型"、"添加xxx模型支持"、"适配xxx"、"导出xxx模型"等。若用户明确提到 `safetensors`、`--segment`、`workflow.json`、`MNNConvert -f ST` 或“绕过 ONNX 直接转换”,先读 `safetensors-segment.md`。 ## 概述 @@ -20,6 +20,8 @@ MNN 的模型导出本质上是**对照 HuggingFace transformers 库中原始模 3. 用 Python `--test` 验证映射正确性 4. 导出 MNN 模型并用 C++ 引擎验证 +默认导出链路是 `llmexport.py --export mnn`。如果目标是 safetensors segment 格式,则使用 `llmexport.py --export mnn --segment`,按 `safetensors-segment.md` 先校验 workflow、safetensors key 和 builder 约定。 + ### 注意事项 > **🚨 严禁将输出错误归因于"量化精度不够"**:4bit 量化的 0.5B 小模型都能正确输出。如果 C++ 输出完全不对(如图片识别不出、输出乱码),**一定是实现细节没有与 HF 对齐**,必须逐步 dump 数据对比定位,不要靠猜。 @@ -44,6 +46,9 @@ MNN 的模型导出本质上是**对照 HuggingFace transformers 库中原始模 | `transformers/llm/export/utils/audio.py` | Audio Encoder 实现 | 音频模型 | | `transformers/llm/export/utils/custom_op.py` | 自定义算子导出 | 新算子时 | | `transformers/llm/export/llmexport.py` | 导出主流程入口 | 偶尔 | +| `transformers/llm/export/segment.py` | safetensors segment 导出入口 | segment 分支 | +| `resource/*.json` | safetensors workflow 模板 | segment 分支 | +| `tools/converter/source/safetensors/*.cpp` | safetensors converter / builder 实现 | segment 分支 | --- @@ -98,6 +103,7 @@ MNN 的模型导出本质上是**对照 HuggingFace transformers 库中原始模 | Tier 4 (音频模型) | 1 → 2 → 3 → 5 → 4 | 需要 audio.py | | Tier 5 (视觉模型) | 1 → 2 → 3 → 5 → 4 | 需要 vision.py | | Tier 6 (全新架构) | 1 → 2 → 6 → 3 → 4 | 需要新算子(如叠加 Tier 4/5 则加入 step5) | +| Safetensors segment | 1 → S1/S2/S3 → S4/S5 | 明确要求 `--segment` / workflow / `MNNConvert -f ST` 时,参见 `safetensors-segment.md` | --- @@ -185,8 +191,10 @@ modeling_*.py 中是否有全新的 Attention 类型(非标准 SDPA)? **在开始之前,建议先浏览 `common-pitfalls.md`**,了解已知的常见问题和解决方案(RoPE 变体、dtype 级联、Jinja 限制、stop token、残差模式、MoE 支持要点、FakeLinear axis 陷阱、**do_map 静默失败与 rope_theta 间接存储**、非标准模型加载等)。 +**Safetensors segment 分支的常见问题**:workflow 超参与权重 shape 不匹配、builder 预期 key 前缀不存在、自动 workflow 匹配选错、segment runtime 没启用。详见 `safetensors-segment.md`。 + --- ## 开始执行 -**现在请打开 `skills/support-new-llm/step1-analyze.md`,开始步骤 1。** +**现在请打开 `skills/support-new-llm/step1-analyze.md`,开始步骤 1。若用户明确要求 safetensors segment 导出,同时打开 `skills/support-new-llm/safetensors-segment.md`。** diff --git a/skills/support-new-llm/safetensors-segment.md b/skills/support-new-llm/safetensors-segment.md new file mode 100644 index 0000000000..cead183cf3 --- /dev/null +++ b/skills/support-new-llm/safetensors-segment.md @@ -0,0 +1,280 @@ +# Safetensors Segment 导出补充 + +> **适用场景**:用户明确提到 `safetensors`、`--segment`、`workflow.json`、`MNNConvert -f ST`,或要求绕过 ONNX、直接从 safetensors 生成 segment 格式 MNN LLM。 + +本补充从 `skills/segment-new-llm` 的 safetensors 流程提取而来,并按当前 MNN 仓库路径修正。默认 `support-new-llm` 流程仍是 `llmexport.py --export mnn` 的标准导出;只有命中上述场景时才切到本分支。 + +不要照搬其他仓库/旧 skill 中的 PantherLLM 路径(如 `converter/resource/*.json`、`converter/mnn_safetensors_plugin`、`--customOpLibs libpantherllm_safetensors_plugin`)。当前 MNN 仓库的 segment 分支以 `transformers/llm/export/segment.py`、`resource/*.json` 和 `tools/converter/source/safetensors` 为准。 + +--- + +## 入口与核心文件 + +Segment 分支的主路径是: + +```text +HF / ModelScope model dir + | + v +safetensors weights + workflow JSON + | + v +llmexport.py --export mnn --segment + | + v +MNNConvert -f ST + | + v +segment model dir + | + v +llm_demo /config.json prompt.txt +``` + +核心文件: + +| 文件路径 | 作用 | +|---------|------| +| `transformers/llm/export/segment.py` | segment 导出入口;解析 workflow、safetensors 和导出配置 | +| `transformers/llm/export/llmexport.py` | `--segment` / `--workflow` 参数入口 | +| `resource/*.json` | workflow 模板;当前典型样例是 `resource/qwen3_hf_0.6b.json` | +| `tools/converter/source/safetensors/*.cpp` | safetensors builder / converter 实现 | +| `tools/converter/source/safetensors/SafetensorModelRegistry.hpp` | `REGISTER_SAFETENSOR_MODEL_BUILDER` 注册机制 | +| `transformers/llm/engine/src/segment.cpp` | C++ runtime 的 segment 加载路径 | + +--- + +## 步骤 S1:确认输入形态 + +先判断用户给的是哪种输入: + +| 输入形态 | 处理方式 | +|---------|---------| +| 模型目录,包含 `*.safetensors` 和 `config.json` | 可直接作为 `--path` | +| 单个 `.safetensors` 文件 | 可直接作为 `--path`,但仍需要 tokenizer/config 来源 | +| sharded safetensors + `*.safetensors.index.json` | `segment.py` 会按 index 中的 `weight_map` 顺序传给 `MNNConvert` | +| 只有 PyTorch `state_dict` / `.bin` | 先转换为 safetensors,再进入本流程 | +| 已有 workflow JSON | 显式传 `--workflow /path/to/workflow.json` | +| 没有 workflow JSON | 先从 `resource/*.json` 找最接近模板;不要盲目依赖自动匹配 | + +必须记录: + +- `model_type` +- `hidden_size` +- `num_hidden_layers` +- `num_attention_heads` +- `num_key_value_heads` +- `head_dim` +- `max_position_embeddings` +- embedding / blocks / norm / lm_head 的实际 safetensors key +- tokenizer 文件是否完整 + +--- + +## 步骤 S2:选择 workflow 与 builder + +Segment 分支不是在 `model_mapper.py` 中添加 Python 映射,而是靠 **workflow + safetensors builder**。 + +Workflow 关键点: + +- 顶层 `models[].name` 决定调用哪个 builder。 +- `blocks[]` 描述结构和超参,例如 `hiddenSize`、`headDim`、`numHead`、`kvNumHead`、`number`。 +- 当前可优先参考 `resource/qwen3_hf_0.6b.json`。 + +Builder 关键点: + +- 注册点使用 `REGISTER_SAFETENSOR_MODEL_BUILDER("name", builderFunc)`。 +- 当前文本 decoder 典型实现是 `tools/converter/source/safetensors/HuggingFaceQwen3.cpp`。 +- `logit` 典型实现是 `tools/converter/source/safetensors/Logit.cpp`。 + +判断是否能复用 workflow: + +- 权重命名与现有 builder 预期一致。 +- block 类型一致。 +- 只需要修改层数、hidden、head、kv head、head dim、max position。 +- 输出仍是 segment runtime 需要的 `embed.mnn`、`decoder.mnn`、`logit.mnn`、`logit_topkv_*.mnn` 等文件。 + +需要新增或修改 builder 的信号: + +- 权重前缀不同,现有 builder 找不到关键 tensor。 +- Attention / MLP / norm / residual 结构不同。 +- 需要新增 workflow block 字段才能表达模型结构。 +- 输出文件结构不能被现有 segment runtime 加载。 + +--- + +## 步骤 S3:转换前静态校验 + +在执行 `MNNConvert -f ST` 前,先验证 key、shape 和 workflow 超参。 + +### safetensors key 检查 + +```python +from safetensors import safe_open + +st_path = "/path/to/model.safetensors" +required_keys = [ + "model.embed_tokens.weight", + "model.layers.0.self_attn.q_proj.weight", + "model.layers.0.self_attn.k_proj.weight", + "model.layers.0.self_attn.v_proj.weight", + "model.layers.0.self_attn.o_proj.weight", + "model.layers.0.mlp.gate_proj.weight", + "model.layers.0.mlp.up_proj.weight", + "model.layers.0.mlp.down_proj.weight", + "model.norm.weight", + "lm_head.weight", +] + +with safe_open(st_path, framework="pt", device="cpu") as f: + key_set = set(f.keys()) + print("tensor_count:", len(key_set)) + for key in required_keys: + if key in key_set: + print("OK ", key, f.get_tensor(key).shape) + else: + print("MISS", key) + +missing = [key for key in required_keys if key not in key_set] +if missing: + raise SystemExit(f"missing keys: {missing}") +``` + +### workflow 超参与权重 shape 检查 + +```python +import json +from safetensors import safe_open + +workflow_path = "/path/to/workflow.json" +st_path = "/path/to/model.safetensors" + +with open(workflow_path, "r", encoding="utf-8") as f: + workflow = json.load(f) + +with safe_open(st_path, framework="pt", device="cpu") as st: + q = st.get_tensor("model.layers.0.self_attn.q_proj.weight") + +for model in workflow.get("models", []): + print("model:", model.get("name")) + for block in model.get("blocks", []): + if block.get("type") in {"QwenTransformer", "GPT2Transformer"}: + hidden = block.get("hiddenSize") + head_dim = block.get("headDim") + num_head = block.get("numHead") + if hidden is not None: + assert q.shape[1] == hidden, (q.shape, hidden) + if head_dim is not None and num_head is not None: + assert head_dim * num_head == q.shape[0], (head_dim, num_head, q.shape) + +print("workflow contract looks OK") +``` + +通过标准: + +- builder 依赖的关键 key 全部存在。 +- workflow 中的层数、hidden、head 维度与权重 shape 对齐。 +- tokenizer/config 资源来源明确。 +- 已保存参考模型的输入 prompt、token ids 和输出,用于导出后对比。 + +--- + +## 步骤 S4:执行 segment 导出 + +### 构建要求 + +```bash +mkdir -p build +cd build +cmake .. -DMNN_BUILD_LLM=ON -DMNN_BUILD_CONVERTER=ON +make -j$(nproc) +``` + +`MNN_LLM_SUPPORT_SEGMENT` 默认开启;如果被关闭,segment runtime 不能加载 `"mnn_llm_version": "segment"` 的模型。 + +### 推荐命令:通过 llmexport.py + +```bash +cd transformers/llm/export +python3 llmexport.py \ + --path /path/to/model_dir_or_safetensors \ + --export mnn \ + --segment \ + --workflow /path/to/workflow.json \ + --dst_path ./MODEL \ + --quant_bit 4 \ + --quant_block 64 +``` + +如果省略 `--workflow`,`segment.py` 会在 `resource/` 和 `transformers/llm/resource/` 下搜索可匹配的 JSON。命中多个或找不到时,应显式传入 workflow。 + +### 调试命令:直接调用 MNNConvert + +```bash +build/MNNConvert \ + -f ST \ + -i /path/to/workflow.json \ + -i /path/to/model.safetensors \ + -o /path/to/out_dir \ + --allowCustomOp \ + --saveExternalData \ + --weightQuantBits 4 \ + --weightQuantBlock 64 +``` + +多 shard safetensors 时,对每个 shard 追加一个 `-i /path/to/shard.safetensors`,顺序应与 index 中的 `weight_map` 一致。 + +--- + +## 步骤 S5:检查产物并验证 + +典型输出: + +```text +MODEL/ +├── config.json # 包含 "mnn_llm_version": "segment" +├── llm_config.json +├── tokenizer.mtok +├── embed.mnn +├── decoder.mnn +├── decoder.mnn.weight +├── logit.mnn +├── logit.mnn.weight +└── logit_topkv_1.mnn +``` + +检查: + +```bash +ls -la /path/to/MODEL +cat /path/to/MODEL/config.json +``` + +运行: + +```bash +echo "你好" > /tmp/prompt.txt +build/llm_demo /path/to/MODEL/config.json /tmp/prompt.txt +``` + +通过标准: + +- `config.json` 存在且包含 `"mnn_llm_version": "segment"`。 +- `embed.mnn`、`decoder.mnn`、`logit.mnn` 等关键文件存在且大小 > 0。 +- `llm_demo` 能加载并生成合理文本。 +- 输出与步骤 S3 保留的参考输出方向一致。 + +--- + +## 常见失败 + +| 现象 | 排查顺序 | +|------|---------| +| `no suitable workflow json` | 显式传 `--workflow`;检查 workflow 超参是否与 config 匹配 | +| `multiple suitable workflow json files` | 显式传 `--workflow`,不要让自动匹配猜 | +| `missing tensor` | 回到步骤 S3,核对 safetensors key 和 builder 硬编码前缀 | +| `unknown builder` | 检查 `models[].name` 是否已在 `tools/converter/source/safetensors` 注册 | +| 转换成功但加载失败 | 检查 `config.json`、`llm_config.json`、`tokenizer.mtok` 和输出文件名 | +| 输出完全不对 | 先查 workflow 超参、权重前缀、builder 读权重/reshape/transpose,再考虑量化 | + +不要在没有 key/shape 证据的情况下把问题归因于量化精度。 diff --git a/skills/support-new-llm/step1-analyze.md b/skills/support-new-llm/step1-analyze.md index cf041ebd38..6e6f0dc581 100644 --- a/skills/support-new-llm/step1-analyze.md +++ b/skills/support-new-llm/step1-analyze.md @@ -11,6 +11,8 @@ 根据用户提供的输入,选择对应的方式: +> **Safetensors segment 分支**:如果用户明确要求 `safetensors`、`--segment`、`workflow.json` 或 `MNNConvert -f ST`,本步骤仍需下载/确认模型目录和参考推理,但后续映射与导出要按 `safetensors-segment.md` 执行,而不是默认 ONNX 导出路径。 + ### 情况 A:用户提供本地路径 ``` @@ -72,6 +74,17 @@ snapshot_download( - `*.safetensors` 或 `pytorch_model*.bin`(模型权重) - `tokenizer.json` 或 `tokenizer.model`(tokenizer 文件) +### Safetensors segment 输入补充 + +命中 segment 分支时,还需要记录: + +- safetensors 是单文件、`model.safetensors`,还是 sharded safetensors + `*.safetensors.index.json` +- 是否已有 workflow JSON;没有则先从 `resource/*.json` 中找最接近模板 +- embedding / blocks / norm / lm_head 的实际 safetensors key +- 是否需要显式传 `--workflow`,避免自动匹配选错 + +具体 key/shape 校验脚本见 `safetensors-segment.md`。 + --- ## 1.2 阅读模型 README 和 config.json diff --git a/skills/support-new-llm/step2-mapping.md b/skills/support-new-llm/step2-mapping.md index f29c9b74ca..02ea04c8f3 100644 --- a/skills/support-new-llm/step2-mapping.md +++ b/skills/support-new-llm/step2-mapping.md @@ -10,6 +10,8 @@ MNN 使用 4 层映射将 HuggingFace 模型结构转换为统一接口: +> **Safetensors segment 分支**:如果本次目标是 `--segment` 或 `MNNConvert -f ST`,不要把主要工作放在 `model_mapper.py`。segment 分支的映射单位是 `resource/*.json` workflow 和 `tools/converter/source/safetensors` builder,流程见 `safetensors-segment.md` 的步骤 S2。 + | 映射键 | 作用 | 说明 | |--------|------|------| | `config` | HF config.json 字段 → LlmConfig 属性 | 把模型配置正确读入 | diff --git a/skills/support-new-llm/step3-test-python.md b/skills/support-new-llm/step3-test-python.md index f70d12b548..7bc8b13f86 100644 --- a/skills/support-new-llm/step3-test-python.md +++ b/skills/support-new-llm/step3-test-python.md @@ -10,6 +10,8 @@ 仅看最终输出文本是否"合理"是不够的。本步骤通过 **hook 机制**在两个模型的关键位置截取中间结果,逐层对比,精确定位映射或实现中的错误。 +> **Safetensors segment 分支**:如果目标是 `--segment` / `MNNConvert -f ST`,本步骤需要先做 `safetensors-segment.md` 中的 S3 静态校验(key、shape、workflow 超参、tokenizer/config 资源)。只有当 segment 分支仍修改了 Python LlmModel 或 transformers 逻辑时,才继续执行本文的 hook 对齐。 + **对比的两套模型**: 1. **原始 transformers 模型**:步骤 1 中加载的 `AutoModelForCausalLM`(标准答案) 2. **MNN LlmModel**:步骤 2 中映射转换后的 `LlmModel`(需要验证) diff --git a/skills/support-new-llm/step4-export.md b/skills/support-new-llm/step4-export.md index 3ef1a56ca9..6246944d68 100644 --- a/skills/support-new-llm/step4-export.md +++ b/skills/support-new-llm/step4-export.md @@ -8,6 +8,8 @@ ## 4.1 导出 MNN 模型 +> **Safetensors segment 分支**:如果目标是 `--segment` / `MNNConvert -f ST`,不要使用本节默认 ONNX 导出路径,改用 `safetensors-segment.md` 的 S4/S5:`llmexport.py --export mnn --segment --workflow ...`,并用 `llm_demo /config.json prompt.txt` 验证 segment runtime。 + ```bash cd transformers/llm/export python3 llmexport.py \ diff --git a/skills/test-ci/SKILL.md b/skills/test-ci/SKILL.md index f01873fea7..411cfa8004 100644 --- a/skills/test-ci/SKILL.md +++ b/skills/test-ci/SKILL.md @@ -65,6 +65,9 @@ Valid filters: `all` (default) · `cpu` · `opencl` · `opencl-image` · * Combined stdout/stderr for every stage is saved under `logs/test_ci-/.log` — read the named log of a failing stage for the trailing output. `rc=137` ≈ OOM-kill, `rc=139` ≈ SIGSEGV. +* For GPU/OpenCL smoke tests, verify that the intended backend actually loaded + (for example, OpenCL tuning/backend logs are present). A correct model output + alone is not sufficient when CPU fallback is possible. ## Environment variables @@ -118,6 +121,10 @@ file explains every field and every `skip` entry's rationale. [`TESTING.md`](../../TESTING.md) § "How to add a new operator test". 2. If its name prefix matches an existing stage (e.g. `op/*`), it is picked up automatically — no JSON change needed. Otherwise add a dedicated stage. +3. Do not add backend-specific skips inside an operator test. If a configured + backend fails, fix the backend implementation or, for a confirmed driver + issue, put the exact test name in the stage `skip` list with a documented + rationale in `test_stages.json`. For deeper work on operators themselves, see the [`add-new-op`](../add-new-op/SKILL.md) skill. diff --git a/source/backend/cpu/CPUAttention.cpp b/source/backend/cpu/CPUAttention.cpp index 7195ecb625..d18d141c88 100644 --- a/source/backend/cpu/CPUAttention.cpp +++ b/source/backend/cpu/CPUAttention.cpp @@ -21,18 +21,18 @@ #include "core/BufferAllocator.hpp" #include "compute/ConvolutionTiledExecutor.hpp" -#if defined (__aarch64__) +#if defined(__aarch64__) #define FLOAT16_T __fp16 #else #define FLOAT16_T float #endif - - namespace MNN { template -static void _maskQK(float * qkPacked, const float* scale, size_t seqLen, size_t processedKvSeq, int pack, int kvSeqLen, int kvoffset, int padKvSeqLen, const float* sinksPtr, const Tensor* mask, bool quantKey, bool isLowerTriangular) { +static void _maskQK(float* qkPacked, const float* scale, size_t seqLen, size_t processedKvSeq, int pack, int kvSeqLen, + int kvoffset, int padKvSeqLen, const float* sinksPtr, const Tensor* mask, bool quantKey, + bool isLowerTriangular) { /* * FIGURE 1: mask->elementSize() == seqLen * maskStride * Context: Cross Attention or Prefill stage (Full Context). @@ -53,7 +53,6 @@ static void _maskQK(float * qkPacked, const float* scale, size_t seqLen, size_t * 'X' : Masked (Value = -inf) */ - /* * FIGURE 2: mask->elementSize() != seqLen * maskStride * Context: Self-Attention Inference (Decoding stage). @@ -108,9 +107,12 @@ static void _maskQK(float * qkPacked, const float* scale, size_t seqLen, size_t return; } - int gapLen = (mask->elementSize() == (seqLen + padKvSeqLen) * (kvSeqLen + padKvSeqLen)) ? 0 : static_cast(kvSeqLen - seqLen); + int gapLen = (mask->elementSize() == (seqLen + padKvSeqLen) * (kvSeqLen + padKvSeqLen)) + ? 0 + : static_cast(kvSeqLen - seqLen); auto maskPtr = mask->host(); - auto maskCols = (mask->elementSize() == (seqLen + padKvSeqLen) * (kvSeqLen + padKvSeqLen)) ? kvSeqLen + padKvSeqLen : seqLen + padKvSeqLen; + auto maskCols = (mask->elementSize() == (seqLen + padKvSeqLen) * (kvSeqLen + padKvSeqLen)) ? kvSeqLen + padKvSeqLen + : seqLen + padKvSeqLen; for (int i = 0; i < kvBlockCount; ++i) { T* blockDataPtr = source + (i * seqLen * pack); @@ -135,20 +137,19 @@ static void _maskQK(float * qkPacked, const float* scale, size_t seqLen, size_t val += (float)currentMaskRow[currentKvSeqIndx - gapLen]; dataPtr[k] = (T)val; - } } } } ErrorCode CPUAttention::onResize(const std::vector& inputs, const std::vector& outputs) { - auto gcore = static_cast(backend())->functions(); + auto gcore = static_cast(backend())->functions(); auto core = static_cast(backend())->int8Functions(); gcore->MNNGetMatMulPackMode(&eP, &lP, &hP); - mThreadNum = ((CPUBackend *)backend())->threadNumber(); - mPack = gcore->pack; + mThreadNum = ((CPUBackend*)backend())->threadNumber(); + mPack = gcore->pack; mBytes = gcore->bytes; - int attentionOption = static_cast(backend())->getRuntime()->hint().attentionOption; + int attentionOption = static_cast(backend())->getRuntime()->hint().attentionOption; mUseFlashAttention = (attentionOption / 8 == 1); // attentionOption % 8: @@ -189,7 +190,7 @@ ErrorCode CPUAttention::onResize(const std::vector& inputs, const std:: static_cast(backend())->int8Functions()->MNNGetGemmUnit(&hP8, &lP8, &eP8); auto query = inputs[0]; - auto key = inputs[1]; + auto key = inputs[1]; int seqLen = query->length(1); int mBlockNum = 1; mNumHead = query->length(2); @@ -260,7 +261,8 @@ ErrorCode CPUAttention::onResize(const std::vector& inputs, const std:: } } - if (mSumQ.invalid() || mQueryScale.invalid() || mQueryQuantZero.invalid() || mQueryZeroPoint.invalid() || mQueryQuantScale.invalid() || mQuantQuery.invalid()) { + if (mSumQ.invalid() || mQueryScale.invalid() || mQueryQuantZero.invalid() || mQueryZeroPoint.invalid() || + mQueryQuantScale.invalid() || mQuantQuery.invalid()) { return OUT_OF_MEMORY; } @@ -294,7 +296,8 @@ ErrorCode CPUAttention::onResize(const std::vector& inputs, const std:: } } } else { - mPackQ.reset(Tensor::createDevice({mThreadNum, UP_DIV(seqLen, eP), ROUND_UP(mHeadDim, lP), eP * mBytes})); + mPackQ.reset( + Tensor::createDevice({mThreadNum, UP_DIV(seqLen, eP), ROUND_UP(mHeadDim, lP), eP * mBytes})); backend()->onAcquireBuffer(mPackQ.get(), Backend::DYNAMIC); backend()->onAcquireBuffer(mPackQKV.get(), Backend::DYNAMIC); } @@ -347,10 +350,11 @@ ErrorCode CPUAttention::onResize(const std::vector& inputs, const std:: } ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std::vector& outputs) { - auto gcore = static_cast(backend())->functions(); - auto core = static_cast(backend())->int8Functions(); + auto gcore = static_cast(backend())->functions(); + auto core = static_cast(backend())->int8Functions(); + bool outputC4 = TensorUtils::getDescribe(outputs[0])->dimensionFormat == MNN_DATA_FORMAT_NC4HW4; auto query = inputs[0]; - auto key = inputs[1]; + auto key = inputs[1]; auto value = inputs[2]; int seqLen = query->length(1); const Tensor* mask = nullptr; @@ -372,7 +376,8 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std: // reduce the value of 'query' to 'query * FP16_QSCALE', avoid fp16 overflow FLOAT16_T minValue; FLOAT16_T maxValue; - gcore->MNNCountMaxMinValue(query->host(), (float*)(&minValue), (float*)(&maxValue), query->elementSize()); + gcore->MNNCountMaxMinValue(query->host(), (float*)(&minValue), (float*)(&maxValue), + query->elementSize()); float maxV = maxValue; float minV = minValue; float absMax = ALIMAX(fabsf(maxV), fabsf(minV)); @@ -413,7 +418,7 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std: // Constant Initialization auto padSeqLength = seqLen - insertLen; seqLen = insertLen; - int kvSeqLen = mKVCacheManager->kvLength(); + int kvSeqLen = mKVCacheManager->kvLength(); int maxLen = mKVCacheManager->maxLength(); int32_t units[2] = {eP, lP}; const float* sinksPtr = sinks ? sinks->host() : nullptr; @@ -421,14 +426,17 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std: // Temporary tensors for intermediate results std::shared_ptr unpackQK(Tensor::createDevice({mThreadNum, seqLen, mBlockKV})); - std::shared_ptr softmMaxQ(Tensor::createDevice({mThreadNum, seqLen, ROUND_UP(mBlockKV, mPack)})); // [mBlockKV/mPack, seqLen, mPack ] + std::shared_ptr softmMaxQ(Tensor::createDevice( + {mThreadNum, seqLen, ROUND_UP(mBlockKV, mPack)})); // [mBlockKV/mPack, seqLen, mPack ] std::shared_ptr newPackQK; if (mValueQuantMode != KVQuantMode::Int8) { newPackQK.reset(Tensor::createDevice({mThreadNum, eP * ROUND_UP(mBlockKV, lP) * mBytes})); } else { - newPackQK.reset(Tensor::createDevice({mThreadNum, eP8 * ROUND_UP(MNN_FLASH_ATTENTION_BLOCK_SIZE, lP8)})); + newPackQK.reset( + Tensor::createDevice({mThreadNum, eP8 * ROUND_UP(MNN_FLASH_ATTENTION_BLOCK_SIZE, lP8)})); } - std::shared_ptr mTempQKBlock(Tensor::createDevice({mThreadNum, UP_DIV(mBlockKV, mPack), seqLen, mPack * mBytes})); + std::shared_ptr mTempQKBlock( + Tensor::createDevice({mThreadNum, UP_DIV(mBlockKV, mPack), seqLen, mPack * mBytes})); backend()->onAcquireBuffer(unpackQK.get(), Backend::STATIC); backend()->onAcquireBuffer(softmMaxQ.get(), Backend::STATIC); backend()->onAcquireBuffer(newPackQK.get(), Backend::STATIC); @@ -449,13 +457,12 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std: // quantQ: [seqLen,numHead,headDim] auto queryPtr = query->host(); int divPart = UP_DIV(seqLen * mNumHead, mThreadNum); - MNN_CONCURRENCY_BEGIN (tId, mThreadNum) { + MNN_CONCURRENCY_BEGIN(tId, mThreadNum) { size_t info[9] = {1, (size_t)mHeadDim, 1, 1, 1, 1, 1, 1, 0}; auto remainLu = seqLen * mNumHead - tId * divPart; if (remainLu > 0) { remainLu = ALIMIN(divPart, remainLu); for (int i = tId * divPart; i < tId * divPart + remainLu; ++i) { - // address auto srcFloatPtr = (float*)(queryPtr + i * mHeadDim * mBytes); auto dstInt8Ptr = (int8_t*)(mQuantQuery.ptr() + i * mHeadDim); @@ -467,17 +474,19 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std: auto scalePtr = (float*)(mQueryScale.ptr() + indexQ * QUANT_INFO_BYTES); auto zeroPtr = (float*)(mQueryZeroPoint.ptr() + indexQ * QUANT_INFO_BYTES); - // compute the quant/dequant scale/bias - gcore->MNNAsyQuantInfo(scalePtr, zeroPtr, quantScalePtr, quantZeroPtr, nullptr, nullptr, srcFloatPtr, info); + gcore->MNNAsyQuantInfo(scalePtr, zeroPtr, quantScalePtr, quantZeroPtr, nullptr, nullptr, + srcFloatPtr, info); scalePtr[0] *= mScale; zeroPtr[0] *= mScale; // quantize the float query to int8_t query - mQuantFunc(srcFloatPtr, dstInt8Ptr, UP_DIV(mHeadDim, gcore->pack), quantScalePtr, -128, 127, quantZeroPtr, 0); + mQuantFunc(srcFloatPtr, dstInt8Ptr, UP_DIV(mHeadDim, gcore->pack), quantScalePtr, -128, 127, + quantZeroPtr, 0); } } - } MNN_CONCURRENCY_END(); + } + MNN_CONCURRENCY_END(); // source int8_t query: [seqLen,numHead,headDim] // dest int8_t query: [numHead,seqLen/eP,headDim/lP,eP,lP] @@ -504,19 +513,15 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std: continue; } - int8_t* dst_block_ptr = dst_base_ptr + - outputOffset + - (size_t)dimBlock * (eunit * lP8); + int8_t* dst_block_ptr = dst_base_ptr + outputOffset + (size_t)dimBlock * (eunit * lP8); const size_t src_row_stride = (size_t)mNumHead * mHeadDim; for (int seqLocal = 0; seqLocal < eunit; ++seqLocal) { int innerSeq = seqBase + seqLocal; - const int8_t* src_row_ptr = src_base_ptr + - (size_t)innerSeq * src_row_stride + - (size_t)h * mHeadDim + - dimBase; + const int8_t* src_row_ptr = + src_base_ptr + (size_t)innerSeq * src_row_stride + (size_t)h * mHeadDim + dimBase; int8_t* dst_row_ptr = dst_block_ptr + seqLocal * lP8; @@ -539,7 +544,7 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std: for (int k = 0; k < eP8; ++k) { scalePtr[k] = 1.f / 255.f; #ifdef MNN_USE_SSE - zeroPtr[k] =0; + zeroPtr[k] = 0; #else zeroPtr[k] = 128.f / 255.f; #endif @@ -549,21 +554,22 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std: std::function mCompute = [=](int tId) { int8_t* qReordered = nullptr; - auto qkPacked = mTempQKBlock->host() + tId * mTempQKBlock->stride(0); - auto qkFlatten = unpackQK->host() + tId * unpackQK->stride(0); - auto qkSoftmax = softmMaxQ->host() + tId * softmMaxQ->stride(0); + auto qkPacked = mTempQKBlock->host() + tId * mTempQKBlock->stride(0); + auto qkFlatten = unpackQK->host() + tId * unpackQK->stride(0); + auto qkSoftmax = softmMaxQ->host() + tId * softmMaxQ->stride(0); auto qkReordered = newPackQK->host() + tId * newPackQK->stride(0); - auto qkvPacked = mPackQKV->host() + tId * mPackQKV->stride(0); - int headIndex = tId * numHeadDiv; - int headsToCompute = ALIMIN(numHeadDiv, mNumHead - headIndex); + auto qkvPacked = mPackQKV->host() + tId * mPackQKV->stride(0); + int headIndex = tId * numHeadDiv; + int headsToCompute = ALIMIN(numHeadDiv, mNumHead - headIndex); // Flash Attention auto runningMax = mRunningMax ? (float*)(mRunningMax->host() + tId * mRunningMax->stride(0)) : nullptr; auto runningSum = mRunningSum ? (float*)(mRunningSum->host() + tId * mRunningSum->stride(0)) : nullptr; - auto diffScale = mExpfDiffMax ? (float*)(mExpfDiffMax->host() + tId * mExpfDiffMax->stride(0)) : nullptr; + auto diffScale = + mExpfDiffMax ? (float*)(mExpfDiffMax->host() + tId * mExpfDiffMax->stride(0)) : nullptr; auto outputPacked = mTempOut ? mTempOut->host() + tId * mTempOut->stride(0) : qkvPacked; - - int kvBlocks = UP_DIV(kvSeqLen, mBlockKV); + + int kvBlocks = UP_DIV(kvSeqLen, mBlockKV); bool isLowerTriangular = (mask == nullptr); if (mask != nullptr && mask->shape().empty()) { @@ -607,7 +613,6 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std: sumParams4QxK.kernelCountUnitDouble = UP_DIV(mHeadDim, lP8); sumParams4QxK.valid = mHeadDim % lP8; - if (mBlockNum > 1) { accumbuff = (float*)(mAccumBuffer.ptr() + tId * eP8 * hP8 * QUANT_INFO_BYTES); } @@ -635,7 +640,8 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std: size_t vstride0 = ROUND_UP(mHeadDim, hP) * ROUND_UP(MNN_FLASH_ATTENTION_BLOCK_SIZE, lP); if (mValueQuantMode == KVQuantMode::Int8) { - vstride0 = (ROUND_UP(mHeadDim, hP8) * ROUND_UP(mKVCacheManager->getFlashAttentionBlockKv(), lP8) + 2 * QUANT_INFO_BYTES * mBlockNum * ROUND_UP(mHeadDim, hP8)); + vstride0 = (ROUND_UP(mHeadDim, hP8) * ROUND_UP(mKVCacheManager->getFlashAttentionBlockKv(), lP8) + + 2 * QUANT_INFO_BYTES * mBlockNum * ROUND_UP(mHeadDim, hP8)); } // use for V @@ -644,9 +650,9 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std: float vQuantScale[1] = {255.f}; float vQuantBias[1] = {-128.f}; int32_t infoInt8V[5]; - infoInt8V[0] = 1; // number + infoInt8V[0] = 1; // number infoInt8V[2] = static_cast(sumParams4QKxV.unitColBufferSize); - infoInt8V[3] = 1; // stride + infoInt8V[3] = 1; // stride int32_t elInt8V[4] = {eP8, ROUND_UP(MNN_FLASH_ATTENTION_BLOCK_SIZE, lP8), 0, 0}; // only used for float V @@ -659,6 +665,13 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std: int offset[2] = {seqLen, mNumHead * mHeadDim}; for (int h = headIndex; h < headIndex + headsToCompute; h++) { + auto dstStep = mBytes * seqLen * mPack; + if (outputC4) { + outputPacked = outputs[0]->host() + h * mHeadDim * seqLen * mBytes; + if (!mUseFlashAttention) { + qkvPacked = outputPacked; + } + } // Prepare for flash attention if (runningSum && runningMax) { if (sinksPtr == nullptr) { @@ -683,16 +696,18 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std: } // Compute the current addresses - int kvHeadIndex = h / group_size; - int8_t * keyAddr = mKVCacheManager->addrOfKey(kvHeadIndex); - int8_t * keySum = mKVCacheManager->addrOfKeySum(kvHeadIndex); - int8_t * valueAddr = mKVCacheManager->addrOfValue(kvHeadIndex); - float* valueSum = (float*)mKVCacheManager->addrOfValueSum(kvHeadIndex); + int kvHeadIndex = h / group_size; + int8_t* keyAddr = mKVCacheManager->addrOfKey(kvHeadIndex); + int8_t* keySum = mKVCacheManager->addrOfKeySum(kvHeadIndex); + int8_t* valueAddr = mKVCacheManager->addrOfValue(kvHeadIndex); + float* valueSum = (float*)mKVCacheManager->addrOfValueSum(kvHeadIndex); // Get packed Q if (mKeyQuantMode != KVQuantMode::Int8) { - qReordered = mPackQ->host() + tId * mPackQ->stride(0); - gcore->MNNAttenPackAndScaleSingleHead((float*)qReordered, (float*)(query->host() + h * mHeadDim * mBytes), mHeadDim * mNumHead, &q_scale, units, seqLen, mHeadDim); + qReordered = mPackQ->host() + tId * mPackQ->stride(0); + gcore->MNNAttenPackAndScaleSingleHead((float*)qReordered, + (float*)(query->host() + h * mHeadDim * mBytes), + mHeadDim * mNumHead, &q_scale, units, seqLen, mHeadDim); } else { qReordered = mPackQ->host() + h * mPackQ->stride(0); qSumAddr = (float*)(mSumQ.ptr() + tId * ROUND_UP(seqLen, eP8) * mBlockNum * QUANT_INFO_BYTES); @@ -809,25 +824,37 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std: int loop_e = seqLen / eP; int remain = seqLen % eP; auto qStride0 = ROUND_UP(mHeadDim, lP) * eP * mBytes; - size_t shapeParameters[7] = {(size_t)eP * lP * mBytes, ROUND_UP((size_t)mHeadDim, lP), (size_t)subKvSeqLen, (size_t)seqLen * mPack * mBytes, 0, 0, 0}; - for (int ei = 0 ; ei < loop_e; ei++) { - gcore->MNNPackedMatMul((float*)(qkPacked + (ei * eP * mPack) * mBytes), (float*)(qReordered + ei * qStride0), (float*)keyPtr, shapeParameters, nullptr, nullptr, nullptr, nullptr); + size_t shapeParameters[7] = {(size_t)eP * lP * mBytes, + ROUND_UP((size_t)mHeadDim, lP), + (size_t)subKvSeqLen, + (size_t)seqLen * mPack * mBytes, + 0, + 0, + 0}; + for (int ei = 0; ei < loop_e; ei++) { + gcore->MNNPackedMatMul((float*)(qkPacked + (ei * eP * mPack) * mBytes), + (float*)(qReordered + ei * qStride0), (float*)keyPtr, shapeParameters, + nullptr, nullptr, nullptr, nullptr); } if (remain > 0) { - gcore->MNNPackedMatMulRemain((float*)(qkPacked + (loop_e * eP * mPack) * mBytes), (float*)(qReordered + loop_e * qStride0), (float*)keyPtr, remain, shapeParameters, nullptr, nullptr, nullptr, nullptr); + gcore->MNNPackedMatMulRemain((float*)(qkPacked + (loop_e * eP * mPack) * mBytes), + (float*)(qReordered + loop_e * qStride0), (float*)keyPtr, remain, + shapeParameters, nullptr, nullptr, nullptr, nullptr); } } else { auto eRemain = seqLen; auto srcInt8 = qReordered; auto dstInt8 = qkPacked; - auto keyPtr = keyAddr + i * UP_DIV(mBlockKV, hP8) * (ROUND_UP(mHeadDim, lP8) * hP8 + 2 * hP8 * QUANT_INFO_BYTES); + auto keyPtr = keyAddr + i * UP_DIV(mBlockKV, hP8) * + (ROUND_UP(mHeadDim, lP8) * hP8 + 2 * hP8 * QUANT_INFO_BYTES); gemmParam4QxK.weightKernelSum = (float*)(keySum + i * mBlockKV * QUANT_INFO_BYTES); - gemmParam4QxK.inputScale = qScale; - gemmParam4QxK.inputBias = qBias; + gemmParam4QxK.inputScale = qScale; + gemmParam4QxK.inputBias = qBias; gemmParam4QxK.srcKernelSum = qSumAddr; while (eRemain > 0) { auto eSize = ALIMIN(eP8, eRemain); - mInt8GemmKernel(dstInt8, srcInt8, keyPtr, UP_DIV(mHeadDim, lP8), mBytes * seqLen * mPack, UP_DIV(subKvSeqLen, mPack), &gemmParam4QxK, eSize); + mInt8GemmKernel(dstInt8, srcInt8, keyPtr, UP_DIV(mHeadDim, lP8), mBytes * seqLen * mPack, + UP_DIV(subKvSeqLen, mPack), &gemmParam4QxK, eSize); eRemain -= eP8; gemmParam4QxK.inputScale += eP8; gemmParam4QxK.inputBias += eP8; @@ -853,11 +880,13 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std: (mKeyQuantMode == KVQuantMode::Int8), isLowerTriangular); } } - gcore->MNNSoftmax(qkSoftmax, (float*)qkPacked, runningMax, runningSum, diffScale, seqLen, subKvSeqLen, i * mBlockKV, kvValidOffset, mPack, useMaskInSoftmax); + gcore->MNNSoftmax(qkSoftmax, (float*)qkPacked, runningMax, runningSum, diffScale, seqLen, + subKvSeqLen, i * mBlockKV, kvValidOffset, mPack, useMaskInSoftmax); } // 3. qk @ v auto qkStride0 = ROUND_UP(subKvSeqLen, lP) * eP * mBytes; - auto rowStart = (!isLowerTriangular || i * mBlockKV < kvValidOffset)? 0 : (i * mBlockKV - kvValidOffset); + auto rowStart = + (!isLowerTriangular || i * mBlockKV < kvValidOffset) ? 0 : (i * mBlockKV - kvValidOffset); if (mValueQuantMode == KVQuantMode::TQ3) { // Vec_dot Value fusion: accumulate in rotated domain, WHT_inverse once @@ -953,8 +982,18 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std: } } else if (mValueQuantMode != KVQuantMode::Int8) { auto valuePtr = valueAddr + i * vstride0 * mBytes; - size_t shapeParameters[7] = {(size_t)eP * lP * mBytes, ROUND_UP((size_t)subKvSeqLen, lP), (size_t)mHeadDim, (size_t)seqLen * mPack * mBytes, 0, 0, 0}; - size_t bExtraStride = (i < kvBlocks - 1) ? 0 : (ROUND_UP(mKVCacheManager->getFlashAttentionBlockKv(), lP) - ROUND_UP(subKvSeqLen, lP)) * hP * mBytes; + size_t shapeParameters[7] = {(size_t)eP * lP * mBytes, + ROUND_UP((size_t)subKvSeqLen, lP), + (size_t)mHeadDim, + (size_t)dstStep, + 0, + 0, + 0}; + size_t bExtraStride = + (i < kvBlocks - 1) + ? 0 + : (ROUND_UP(mKVCacheManager->getFlashAttentionBlockKv(), lP) - ROUND_UP(subKvSeqLen, lP)) * + hP * mBytes; shapeParameters[5] = bExtraStride; int loop_e = (seqLen - rowStart) / eP; @@ -964,10 +1003,12 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std: elFloatV[0] = eP; elFloatV[1] = ROUND_UP(subKvSeqLen, lP); infoFloatV[2] = eP; - for ( ; ei < loop_e; ei++) { + for (; ei < loop_e; ei++) { srcPtr[0] = (float const*)((int8_t*)qkSoftmax + (ei * eP + rowStart) * mPack * mBytes); gcore->MNNPackC4ForMatMul_A((float*)qkReordered, srcPtr, infoFloatV, elFloatV); - gcore->MNNPackedMatMul((float*)(qkvPacked + (ei * eP + rowStart) * mPack * mBytes), (float*)qkReordered, (float*)valuePtr, shapeParameters, nullptr, nullptr, nullptr, nullptr); + gcore->MNNPackedMatMul((float*)(qkvPacked + (ei * eP + rowStart) * mPack * mBytes), + (float*)qkReordered, (float*)valuePtr, shapeParameters, nullptr, nullptr, + nullptr, nullptr); } if (remain > 0) { elFloatV[0] = remain; @@ -975,18 +1016,22 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std: srcPtr[0] = (float const*)((int8_t*)qkSoftmax + (loop_e * eP + rowStart) * mPack * mBytes); shapeParameters[0] = remain * lP * mBytes; gcore->MNNPackC4ForMatMul_A((float*)qkReordered, srcPtr, infoFloatV, elFloatV); - gcore->MNNPackedMatMulRemain((float*)(qkvPacked + (loop_e * eP + rowStart) * mPack * mBytes), (float*)qkReordered, (float*)valuePtr, remain, shapeParameters, nullptr, nullptr, nullptr, nullptr); + gcore->MNNPackedMatMulRemain((float*)(qkvPacked + (loop_e * eP + rowStart) * mPack * mBytes), + (float*)qkReordered, (float*)valuePtr, remain, shapeParameters, + nullptr, nullptr, nullptr, nullptr); } } else { // use int8 kernel to compute qk@ v auto valuePtr = valueAddr + i * vstride0; auto eRemain = seqLen - rowStart; - auto qkPtr = (int8_t*)(qkSoftmax) + rowStart * mPack * mBytes; // [UP_DIV(subKvSeqLen,pack),seqLen,pack] + auto qkPtr = + (int8_t*)(qkSoftmax) + rowStart * mPack * mBytes; // [UP_DIV(subKvSeqLen,pack),seqLen,pack] auto qkvFloat = qkvPacked + rowStart * mPack * mBytes; gemmParam4QKxV.weightKernelSum = valueSum + i * ROUND_UP(mHeadDim, hP8); sumParams4QKxV.valid = subKvSeqLen % lP8; sumParams4QKxV.LU = UP_DIV(subKvSeqLen, lP8); - auto dstInt8Ptr = (int8_t*)mQuantQK.ptr() + tId * eP8 * ROUND_UP(MNN_FLASH_ATTENTION_BLOCK_SIZE, mPack); + auto dstInt8Ptr = + (int8_t*)mQuantQK.ptr() + tId * eP8 * ROUND_UP(MNN_FLASH_ATTENTION_BLOCK_SIZE, mPack); srcPtr[0] = (const float*)(dstInt8Ptr); while (eRemain > 0) { @@ -998,14 +1043,16 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std: infoInt8V[4] = eSize; // e to process elInt8V[0] = eSize; // e to process - for (int qi = 0; qi < UP_DIV(subKvSeqLen, mPack); ++qi) { - mQuantFunc((float*)(qkPtr + qi * seqLen * mPack * mBytes), dstInt8Ptr + qi * eSize * mPack, eSize, vQuantScale, -128, 127, vQuantBias, 0); + mQuantFunc((float*)(qkPtr + qi * seqLen * mPack * mBytes), dstInt8Ptr + qi * eSize * mPack, + eSize, vQuantScale, -128, 127, vQuantBias, 0); } - core->MNNPackC4Int8ForMatMul_A(qkReordered, (int8_t const **)srcPtr, infoInt8V, elInt8V); + core->MNNPackC4Int8ForMatMul_A(qkReordered, (int8_t const**)srcPtr, infoInt8V, elInt8V); // mSumQK - gcore->MNNSumByAxisLForMatmul_A(gemmParam4QKxV.srcKernelSum, qkReordered, (float*)mQKScale.ptr(), eSize, sumParams4QKxV); - mInt8GemmKernel(qkvFloat, qkReordered, valuePtr, UP_DIV(MNN_FLASH_ATTENTION_BLOCK_SIZE, lP8), mBytes * seqLen * mPack, UP_DIV(mHeadDim, mPack), &gemmParam4QKxV, eSize); + gcore->MNNSumByAxisLForMatmul_A(gemmParam4QKxV.srcKernelSum, qkReordered, + (float*)mQKScale.ptr(), eSize, sumParams4QKxV); + mInt8GemmKernel(qkvFloat, qkReordered, valuePtr, UP_DIV(MNN_FLASH_ATTENTION_BLOCK_SIZE, lP8), + dstStep, UP_DIV(mHeadDim, mPack), &gemmParam4QKxV, eSize); eRemain -= eSize; qkPtr += (eSize * mPack * mBytes); @@ -1015,14 +1062,18 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std: // 4. flash attention, update each sub kvSeq's final results if (runningMax != nullptr && runningSum != nullptr && diffScale != nullptr) { - gcore->MNNFlashAttentionUpdateBlockOutput((float*)outputPacked, (float*)qkvPacked, diffScale, runningSum, UP_DIV(mHeadDim, mPack), seqLen, mPack, i, kvBlocks, mPackQKV->stride(0) / mBytes, mBytes, rowStart); + gcore->MNNFlashAttentionUpdateBlockOutput((float*)outputPacked, (float*)qkvPacked, diffScale, + runningSum, UP_DIV(mHeadDim, mPack), seqLen, mPack, i, + kvBlocks, mPackQKV->stride(0) / mBytes, mBytes, rowStart); } } // Final results writing: [head_dim/mPack, seq_len, mPack] -> [seq_len, num_head, head_dim] - auto dstPtr = outputs[0]->host() + h * mHeadDim * mBytes; - // offset = {seqLen, mNumHead * mHeadDim}; - gcore->MNNUnpackCUnitTranspose((float*)dstPtr, (float*)outputPacked, seqLen, mHeadDim, offset); + if (!outputC4) { + auto dstPtr = outputs[0]->host() + h * mHeadDim * mBytes; + // offset = {seqLen, mNumHead * mHeadDim}; + gcore->MNNUnpackCUnitTranspose((float*)dstPtr, (float*)outputPacked, seqLen, mHeadDim, offset); + } } }; @@ -1039,9 +1090,12 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std: if (!mKVCache) { mKVCacheManager->onClear(); } - auto ptr = outputs[0]->host(); - if (seqLen < outputs[0]->length(1)) { - ::memset(outputs[0]->host() + seqLen * mHeadDim * mNumHead * mBytes, 0, (outputs[0]->length(1)-seqLen) * mHeadDim * mNumHead * mBytes); + if (!outputC4) { + auto ptr = outputs[0]->host(); + if (seqLen < outputs[0]->length(1)) { + ::memset(outputs[0]->host() + seqLen * mHeadDim * mNumHead * mBytes, 0, + (outputs[0]->length(1) - seqLen) * mHeadDim * mNumHead * mBytes); + } } return NO_ERROR; } @@ -1082,8 +1136,8 @@ CPUAttention::CPUAttention(Backend* backend, bool kv_cache) : Execution(backend) // attentionOption / 8: // 0: do not use flash attention // 1: use flash attention - kvconfig.mKVCacheDir = static_cast(backend)->getRuntime()->hint().kvcacheDirPath; - kvconfig.mPrefixCacheDir = static_cast(backend)->getRuntime()->hint().prefixcacheDirPath; + kvconfig.mKVCacheDir = static_cast(backend)->getRuntime()->hint().kvcacheDirPath; + kvconfig.mPrefixCacheDir = static_cast(backend)->getRuntime()->hint().prefixcacheDirPath; kvconfig.mExpandChunk = 64; kvconfig.mBlockNum = 1; mKVCacheManager.reset(new CPUKVCacheManager(backend, kvconfig)); @@ -1103,4 +1157,3 @@ REGISTER_CPU_OP_CREATOR_TRANSFORMER(CPUAttentionCreator, OpType_Attention); } // namespace MNN #endif // MNN_SUPPORT_TRANSFORMER_FUSE - diff --git a/source/backend/cpu/CPUBinary.cpp b/source/backend/cpu/CPUBinary.cpp index 61ccf4fca3..a36447dbd5 100644 --- a/source/backend/cpu/CPUBinary.cpp +++ b/source/backend/cpu/CPUBinary.cpp @@ -189,11 +189,58 @@ MNNBinaryExecute CPUBinary::selectForInt(int type) { return nullptr; } +class MulSilu : public Execution { +public: + MulSilu(Backend *b) : Execution(b) { + auto func = static_cast(backend())->functions(); + auto precision = static_cast(backend())->precisionMode(); + mSilu = func->MNNSelectUnaryFunctionForFloat(UnaryOpOperation_SILU, precision); + mMul = func->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_MUL); + } + virtual ~MulSilu() = default; + virtual ErrorCode onExecute(const std::vector &inputs, + const std::vector &outputs) override { + auto input0 = inputs[0]; + auto input1 = inputs[1]; + auto output = outputs[0]; + auto size = static_cast(backend())->getTensorSize(input0); + auto schedule = static_cast(backend())->multiThreadDivide(size); + auto bytes = static_cast(backend())->functions()->bytes; + auto i0 = input0->host(); + auto i1 = input1->host(); + auto o0 = output->host(); + + MNN_CONCURRENCY_BEGIN(tId, schedule.second) { + int start = schedule.first * (int)tId; + int realSize = schedule.first; + if (tId == schedule.second - 1) { + realSize = size - start; + } + if (realSize > 0) { + auto inp = i0 + start * bytes; + auto inp1 = i1 + start * bytes; + auto out = o0 + start * bytes; + mSilu((float *)out, (float *)inp1, realSize); + mMul((float *)out, (float *)out, (float *)inp, realSize, -1); + } + } + MNN_CONCURRENCY_END(); + return NO_ERROR; + } + +private: + MNNBinaryExecute mMul; + MNNUnaryExecute mSilu; +}; + class CPUBinaryCreator : public CPUBackend::Creator { public: virtual Execution* onCreate(const std::vector& inputs, const std::vector& outputs, const MNN::Op* op, Backend* backend) const override { int32_t type = op->main_as_BinaryOp()->opType(); + if (BinaryOpOperation_MUL_SILU == type) { + return new MulSilu(backend); + } auto dataType = inputs[0]->getType(); auto core = static_cast(backend)->functions(); #ifdef MNN_SUPPORT_QUANT_EXTEND diff --git a/source/backend/cpu/CPUConvolution.hpp b/source/backend/cpu/CPUConvolution.hpp index 13abecef0f..169980861b 100644 --- a/source/backend/cpu/CPUConvolution.hpp +++ b/source/backend/cpu/CPUConvolution.hpp @@ -13,15 +13,18 @@ #include "CPUBackend.hpp" #include "core/ConvolutionCommon.hpp" namespace MNN { - class PerfConfig { +class PerfConfig { public: - PerfConfig() : isParallelInner{false}, eTile{1}, ePack{1}, hPack{1}, instructionCosts{.0f} { - } + PerfConfig() : isParallelInner{false}, eTile{1}, ePack{1}, hPack{1}, instructionCosts{.0f} {} PerfConfig(bool isParallelInner_, int eTile_, int ePack_, int hPack_, float instructionCosts_) - : isParallelInner{isParallelInner_}, eTile{eTile_}, ePack{ePack_}, hPack{hPack_}, instructionCosts{instructionCosts_} { - } + : isParallelInner{isParallelInner_}, + eTile{eTile_}, + ePack{ePack_}, + hPack{hPack_}, + instructionCosts{instructionCosts_} {} bool operator!=(const PerfConfig& other) { - return isParallelInner != other.isParallelInner || ePack != other.ePack || eTile != other.eTile || hPack != other.hPack; + return isParallelInner != other.isParallelInner || ePack != other.ePack || eTile != other.eTile || + hPack != other.hPack; } PerfConfig& operator=(const PerfConfig& other) { isParallelInner = other.isParallelInner; @@ -33,8 +36,8 @@ namespace MNN { } bool isParallelInner; // inner or outer parallel - int eTile; // L2 cache tiling - int ePack; // micro tile size along ow*oh dimension + int eTile; // L2 cache tiling + int ePack; // micro tile size along ow*oh dimension int hPack; float instructionCosts; }; @@ -58,15 +61,15 @@ class CPUConvolution : public Execution { std::vector mReluThreshold; }; struct ResourceInt8 { - std::vector mInt8WeightKernelSum; // PTQ's sum, DynamicQ not use - std::shared_ptr mWeightInt8; // PTQ's and DynamicQ's weight - std::shared_ptr mOriginBias; // PTQ's and DynamicQ's bias - std::shared_ptr mOriginScale; // PTQ's scale + bias, DynamicQ's alpha + zero; - std::shared_ptr mWeightKernelSum; // PTQ's and DynamicQ's weight kernel sum; + std::vector mInt8WeightKernelSum; // PTQ's sum, DynamicQ not use + std::shared_ptr mWeightInt8; // PTQ's and DynamicQ's weight + std::shared_ptr mOriginBias; // PTQ's and DynamicQ's bias + std::shared_ptr mOriginScale; // PTQ's scale + bias, DynamicQ's alpha + zero; + std::shared_ptr mWeightKernelSum; // PTQ's and DynamicQ's weight kernel sum; std::vector mReluThreshold; // relu or relu6 bool mRelu; - int mWeightBits; // quant bits + int mWeightBits; // quant bits bool mUseConvQuan = true; bool mWeightAsymmetricQuant = true; @@ -79,6 +82,10 @@ class CPUConvolution : public Execution { int8_t mClampMax; bool mDynamicQuant = false; int32_t mBlockNum = 1; + int32_t mHp = 0; + int32_t mLp = 0; + // For int4: 0: (x + half, x) -> int8, 1: (x, x + 1) -> int8. + int32_t mPackMode = 0; }; struct MutableResourceInt8 { MutableResourceInt8(std::shared_ptr res, Backend* backend, float* scalePtr = nullptr); @@ -96,22 +103,23 @@ class CPUConvolution : public Execution { int32_t mShiftBits = 14; bool mValid; }; - static std::shared_ptr makeResourceInt8(Backend *backend, const MNN::Op *op, int pack=4); - CPUConvolution(const Convolution2DCommon *convOp, Backend *b); + static std::shared_ptr makeResourceInt8(Backend* backend, const MNN::Op* op, int pack = 4); + CPUConvolution(const Convolution2DCommon* convOp, Backend* b); virtual ~CPUConvolution() = default; - virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; + virtual ErrorCode onResize(const std::vector& inputs, const std::vector& outputs) override; static int reorderWeightSize(int depth, int outputCount, int kernelSize, int unitDepth, int unitOC); std::vector getPostParameters() const; + public: PerfConfig mConvPerfconfig; + protected: - const Convolution2DCommon *mCommon; + const Convolution2DCommon* mCommon; // In execute, use pad from mPadX and mPadY, don't use mCommon's pad mutable int mPadX; mutable int mPadY; - }; } // namespace MNN diff --git a/source/backend/cpu/CPULayerNorm.cpp b/source/backend/cpu/CPULayerNorm.cpp index 0a2af995cb..e3deb016a6 100644 --- a/source/backend/cpu/CPULayerNorm.cpp +++ b/source/backend/cpu/CPULayerNorm.cpp @@ -38,7 +38,8 @@ std::shared_ptr CPULayerNorm::makeResource(const MNN::Op MNN_ASSERT(layer_norm_param->gamma()->size() == layer_norm_param->beta()->size()); gammasize = layer_norm_param->gamma()->size(); } - hasGammaBeta = hasGammaBeta || (layer_norm_param->external() && layer_norm_param->external()->size() > 1 && layer_norm_param->external()->data()[1] > 0); + hasGammaBeta = hasGammaBeta || (layer_norm_param->external() && layer_norm_param->external()->size() > 1 && + layer_norm_param->external()->data()[1] > 0); if (hasGammaBeta && gammasize == 0) { gammasize = layer_norm_param->external()->data()[1] / sizeof(float); } @@ -47,7 +48,8 @@ std::shared_ptr CPULayerNorm::makeResource(const MNN::Op // Use uint8_t to avoid lowp reduce float bytes res->mGamma.reset(Tensor::createDevice({gammasize * 4})); res->mBeta.reset(Tensor::createDevice({gammasize * 4})); - auto status = backend->onAcquireBuffer(res->mGamma.get(), Backend::STATIC) && backend->onAcquireBuffer(res->mBeta.get(), Backend::STATIC); + auto status = backend->onAcquireBuffer(res->mGamma.get(), Backend::STATIC) && + backend->onAcquireBuffer(res->mBeta.get(), Backend::STATIC); if (!status) { MNN_ERROR("Out of memory when gamma is acquired in CPULayerNorm.\n"); return nullptr; @@ -56,7 +58,7 @@ std::shared_ptr CPULayerNorm::makeResource(const MNN::Op if (useCachedMmap) { return res; } - + const float* gamma_data = layer_norm_param->gamma()->data(); memcpy(res->mGamma->host(), gamma_data, gammasize * sizeof(float)); const float* beta_data = layer_norm_param->beta()->data(); @@ -65,12 +67,9 @@ std::shared_ptr CPULayerNorm::makeResource(const MNN::Op return res; } -ErrorCode CPULayerNorm::onExecute(const std::vector &inputs, - const std::vector &outputs) { +ErrorCode CPULayerNorm::onExecute(const std::vector& inputs, const std::vector& outputs) { const float* gamma = mResource->mIniGammaBeta ? mResource->mGamma->host() : nullptr; const float* beta = mResource->mIniGammaBeta ? mResource->mBeta->host() : nullptr; - auto input = inputs[0]->host(); - auto output = outputs[0]->host(); auto bn = static_cast(backend()); auto core = bn->functions(); auto threadNumber = bn->threadNumber(); @@ -84,21 +83,83 @@ ErrorCode CPULayerNorm::onExecute(const std::vector &inputs, bytes = 1; } + if (mNeedUnpackC4 && core->MNNNormPacked != nullptr && bytes == 4) { + const int batch = inputs[0]->length(0); + const int channel = inputs[0]->length(1); + auto inputPtr = inputs[0]->host(); + auto outputPtr = outputs[0]->host(); + if (inputs.size() == 2 && outputs.size() == 2) { + auto input1Ptr = inputs[1]->host(); + auto output1Ptr = outputs[1]->host(); + int elementSize = static_cast(backend())->getTensorSize(inputs[0]); + int pack = core->pack; + core->MNNMatrixAdd(reinterpret_cast(outputPtr), reinterpret_cast(inputPtr), + reinterpret_cast(input1Ptr), elementSize / pack, 0, 0, 0, 1); + core->MNNNormPacked(reinterpret_cast(output1Ptr), reinterpret_cast(outputPtr), gamma, + beta, mResource->mEpsilon, batch, channel, mResource->mRMSNorm); + return NO_ERROR; + } + core->MNNNormPacked(reinterpret_cast(outputPtr), reinterpret_cast(inputPtr), gamma, beta, + mResource->mEpsilon, batch, channel, mResource->mRMSNorm); + return NO_ERROR; + } + if (mNeedUnpackC4 && bytes == 2) { + const int batch = inputs[0]->length(0); + const int channel = inputs[0]->length(1); + const int pack = core->pack; + auto inputPtr = reinterpret_cast(inputs[0]->host()); + auto outputPtr = reinterpret_cast(outputs[0]->host()); + const int16_t* input1Ptr = nullptr; + int16_t* output1Ptr = nullptr; + if (inputs.size() == 2 && outputs.size() == 2) { + input1Ptr = reinterpret_cast(inputs[1]->host()); + output1Ptr = reinterpret_cast(outputs[1]->host()); + } + MNN_CONCURRENCY_BEGIN(ttId, threadNumber) { + auto tmpInput = reinterpret_cast(mTmpInputFloat.ptr() + ttId * channel * sizeof(float)); + auto tmpOutput = reinterpret_cast(mTmpOutputFloat.ptr() + ttId * channel * sizeof(float)); + for (int n = ttId; n < batch; n += threadNumber) { + for (int c = 0; c < channel; ++c) { + const int index = ((c / pack) * batch + n) * pack + c % pack; + core->MNNLowpToFp32(inputPtr + index, tmpInput + c, 1); + if (input1Ptr != nullptr) { + float v1; + core->MNNLowpToFp32(input1Ptr + index, &v1, 1); + tmpInput[c] += v1; + core->MNNFp32ToLowp(tmpInput + c, outputPtr + index, 1); + } + } + MNNNorm(tmpOutput, tmpInput, gamma, beta, mResource->mEpsilon, channel, mResource->mRMSNorm); + auto normOutput = output1Ptr != nullptr ? output1Ptr : outputPtr; + for (int c = 0; c < channel; ++c) { + const int index = ((c / pack) * batch + n) * pack + c % pack; + core->MNNFp32ToLowp(tmpOutput + c, normOutput + index, 1); + } + } + } + MNN_CONCURRENCY_END(); + return NO_ERROR; + } + + auto input = inputs[0]->host(); + auto output = outputs[0]->host(); MNN_CONCURRENCY_BEGIN(ttId, threadNumber) { - for (int tId=ttId; tId < mOutterSize; tId += threadNumber) { + for (int tId = ttId; tId < mOutterSize; tId += threadNumber) { const float* inner_input = (const float*)(input + tId * mInnerSize * bytes); float* inner_output = (float*)(output + tId * mInnerSize * bytes); if (bytes != 4) { auto tmpInput = (float*)(mTmpInputFloat.ptr() + ttId * mInnerSize * sizeof(float)); auto tmpOutput = (float*)(mTmpOutputFloat.ptr() + ttId * mInnerSize * sizeof(float)); if (bytes == 1) { - CPUCastCreator::cast(inner_input, tmpInput, CPUCastCreator::INT8_TO_FlOAT, mInnerSize, inputQuan->scale, inputQuan->zero, inputQuan->min, inputQuan->max, bn); + CPUCastCreator::cast(inner_input, tmpInput, CPUCastCreator::INT8_TO_FlOAT, mInnerSize, + inputQuan->scale, inputQuan->zero, inputQuan->min, inputQuan->max, bn); } else { core->MNNLowpToFp32((const int16_t*)inner_input, tmpInput, mInnerSize); } MNNNorm(tmpOutput, tmpInput, gamma, beta, mResource->mEpsilon, mInnerSize, mResource->mRMSNorm); if (bytes == 1) { - CPUCastCreator::cast(tmpOutput, inner_output, CPUCastCreator::FlOAT_TO_INT8, mInnerSize, outputQuan->scale, outputQuan->zero, outputQuan->min, outputQuan->max, bn); + CPUCastCreator::cast(tmpOutput, inner_output, CPUCastCreator::FlOAT_TO_INT8, mInnerSize, + outputQuan->scale, outputQuan->zero, outputQuan->min, outputQuan->max, bn); } else { core->MNNFp32ToLowp(tmpOutput, (int16_t*)inner_output, mInnerSize); } @@ -111,10 +172,11 @@ ErrorCode CPULayerNorm::onExecute(const std::vector &inputs, return NO_ERROR; } -ErrorCode CPULayerNorm::onResize(const std::vector &inputs, - const std::vector &outputs) { +ErrorCode CPULayerNorm::onResize(const std::vector& inputs, const std::vector& outputs) { mOutterSize = 1; mInnerSize = 1; + const auto layout = TensorUtils::getDescribe(inputs[0])->dimensionFormat; + mNeedUnpackC4 = (layout == MNN_DATA_FORMAT_NC4HW4); do { // Compute outter and inner int rank = inputs.at(0)->dimensions(); @@ -135,7 +197,7 @@ ErrorCode CPULayerNorm::onResize(const std::vector &inputs, for (int i = rank - mResource->mAxis; i < rank; ++i) { mInnerSize *= inputs.at(0)->length(i); } - if (mResource->mIniGammaBeta) { + if (mResource->mIniGammaBeta && !mNeedUnpackC4) { MNN_ASSERT(mResource->mGamma->size() == mInnerSize * sizeof(float)); } } while (false); @@ -143,9 +205,12 @@ ErrorCode CPULayerNorm::onResize(const std::vector &inputs, auto threadNumber = ALIMIN(bn->threadNumber(), mOutterSize); auto buf = bn->getBufferAllocator(); - if (CPUBackend::getDataType(inputs[0]) == DataType_DT_INT8 || inputs[0]->getType().bytes() == 1 || bn->functions()->bytes != 4) { - mTmpInputFloat = buf->alloc(threadNumber * mInnerSize * sizeof(float)); - mTmpOutputFloat = buf->alloc(threadNumber * mInnerSize * sizeof(float)); + if (CPUBackend::getDataType(inputs[0]) == DataType_DT_INT8 || inputs[0]->getType().bytes() == 1 || + bn->functions()->bytes != 4) { + int tmpSize = mNeedUnpackC4 ? inputs[0]->length(1) : mInnerSize; + int tmpThreadNumber = mNeedUnpackC4 ? bn->threadNumber() : threadNumber; + mTmpInputFloat = buf->alloc(tmpThreadNumber * tmpSize * sizeof(float)); + mTmpOutputFloat = buf->alloc(tmpThreadNumber * tmpSize * sizeof(float)); buf->free(mTmpInputFloat); buf->free(mTmpOutputFloat); } @@ -165,7 +230,8 @@ bool CPULayerNorm::onClone(Backend* bn, const Op* op, Execution** dst) { class CPULayerNormCreator : public CPUBackend::Creator { public: - Execution* onCreate(const std::vector& inputs, const std::vector& outputs, const MNN::Op* op, Backend* backend) const override { + Execution* onCreate(const std::vector& inputs, const std::vector& outputs, const MNN::Op* op, + Backend* backend) const override { auto res = CPULayerNorm::makeResource(op, backend); if (nullptr == res.get()) { return nullptr; @@ -176,4 +242,4 @@ class CPULayerNormCreator : public CPUBackend::Creator { REGISTER_CPU_OP_CREATOR(CPULayerNormCreator, OpType_LayerNorm); -} // namespace MNN +} // namespace MNN diff --git a/source/backend/cpu/CPULayerNorm.hpp b/source/backend/cpu/CPULayerNorm.hpp index aa610aef88..c615a7c1f3 100644 --- a/source/backend/cpu/CPULayerNorm.hpp +++ b/source/backend/cpu/CPULayerNorm.hpp @@ -27,17 +27,19 @@ class CPULayerNorm : public Execution { CPULayerNorm(std::shared_ptr res, Backend* backend); virtual ~CPULayerNorm(); - virtual ErrorCode onExecute(const std::vector &inputs, const std::vector &outputs) override; - virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; + virtual ErrorCode onExecute(const std::vector& inputs, const std::vector& outputs) override; + virtual ErrorCode onResize(const std::vector& inputs, const std::vector& outputs) override; virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override; static std::shared_ptr makeResource(const MNN::Op* op, Backend* backend); + private: std::shared_ptr mResource; int mInnerSize = 1; int mOutterSize = 1; MemChunk mTmpInputFloat; MemChunk mTmpOutputFloat; + bool mNeedUnpackC4 = false; }; } // namespace MNN #endif /* CPULayerNorm_hpp */ diff --git a/source/backend/cpu/CPUOPRegister.cpp b/source/backend/cpu/CPUOPRegister.cpp index 9c3fdb5404..5b633a1035 100644 --- a/source/backend/cpu/CPUOPRegister.cpp +++ b/source/backend/cpu/CPUOPRegister.cpp @@ -57,6 +57,7 @@ extern void ___CPURangeCreator__OpType_Range__(); extern void ___CPUTFQuantizedConv2DCreator__OpType_TfQuantizedConv2D__(); extern void ___CPUInterp3DCreator__OpType_Interp3D__(); extern void ___CPUQuantizedAvgPoolCreator__OpType_QuantizedAvgPool__(); +extern void ___CPURoPECreator__OpType_RoPE__(); extern void ___ConvolutionFactory__OpType_Convolution__(); extern void ___CPUConvInt8Creator__OpType_ConvInt8__(); extern void ___CPURNNSequenceGRUCreator__OpType_RNNSequenceGRU__(); @@ -80,83 +81,84 @@ extern void ___CPUAttentionCreator__OpType_Attention__(); extern void ___CPULinearAttentionCreator__OpType_LinearAttention__(); #endif void registerCPUOps() { -___CPUCropAndResizeCreator__OpType_CropAndResize__(); -___CPUArgMaxCreator__OpType_ArgMax__(); -___CPUArgMaxCreator__OpType_ArgMin__(); -___CPUScaleCreator__OpType_Scale__(); -___CPUSelectCreator__OpType_Select__(); -___CPUSoftmaxCreator__OpType_Softmax__(); -___CPUDetectionPostProcessCreator__OpType_DetectionPostProcess__(); -___CPUCastCreator__OpType_Cast__(); -___CPUProposalCreator__OpType_Proposal__(); -___CPUInterpCreator__OpType_Interp__(); -___CPUGridSampleCreator__OpType_GridSample__(); -___CPUDetectionOutputCreator__OpType_DetectionOutput__(); -___CPUUnravelIndexCreator__OpType_UnravelIndex__(); -___CPUMatMulCreator__OpType_MatMul__(); -___CPUMomentsCreator__OpType_Moments__(); -___CPUSegmentMeanCreator__OpType_Segment__(); -___CPUInstanceNormCreator__OpType_InstanceNorm__(); -___CPUQuantizedLogisticCreator__OpType_QuantizedLogistic__(); -___CPUWhereCreator__OpType_Where__(); -___CPUQuantizedMaxPoolCreator__OpType_QuantizedMaxPool__(); -___CPUDeconvolutionCreator__OpType_Deconvolution__(); -___CPUBinaryCreator__OpType_BinaryOp__(); -___CPUDepthwiseCreator__OpType_QuantizedDepthwiseConv2D__(); -___CPUQuantizedSoftmaxCreator__OpType_QuantizedSoftmax__(); -___CPUPoolCreator__OpType_Pooling__(); -___CPUDetCreator__OpType_Det__(); -___CPUHistogramCreator__OpType_Histogram__(); -___CPUPluginCreator__OpType_Plugin__(); -___CPUInt8ToFloatCreator__OpType_Int8ToFloat__(); -___CPUDynamicQuantCreator__OpType_DynamicQuant__(); -___CPUROIAlignCreator__OpType_ROIAlign__(); -___CPUROIPoolingCreator__OpType_ROIPooling__(); -___CPUTopKV2Creator__OpType_TopKV2__(); -___CPUUnaryCreator__OpType_UnaryOp__(); -___CPUStftCreator__OpType_Stft__(); -___CPUReductionCreator__OpType_Reduction__(); -___CPUReluCreator__OpType_ReLU__(); -___CPUReluCreator__OpType_PReLU__(); -___CPURelu6Creator__OpType_ReLU6__(); -___CPUUniqueCreator__OpType_Unique__(); -___CPUImageProcessCreator__OpType_ImageProcess__(); -___CPUDepthwiseConvInt8Creator__OpType_DepthwiseConvInt8__(); -___CPUOneHotCreator__OpType_OneHot__(); -___CPUMatrixBandPartCreator__OpType_MatrixBandPart__(); -___CPUQuantizedAddCreator__OpType_QuantizedAdd__(); -___CPUDeconvolutionDepthwiseCreator__OpType_DeconvolutionDepthwise__(); -___CPUFloatToInt8Creator__OpType_FloatToInt8__(); -___CPULinSpaceCreator__OpType_LinSpace__(); -___CPUNonMaxSuppressionV2Creator__OpType_NonMaxSuppressionV2__(); -___CPUDequantizeCreator__OpType_Dequantize__(); -___CPURasterFactory__OpType_Raster__(); -___CPURasterFactory__OpType_While__(); -___CPUConvolutionDepthwiseCreator__OpType_ConvolutionDepthwise__(); -___CPURangeCreator__OpType_Range__(); -___CPUTFQuantizedConv2DCreator__OpType_TfQuantizedConv2D__(); -___CPUInterp3DCreator__OpType_Interp3D__(); -___CPUQuantizedAvgPoolCreator__OpType_QuantizedAvgPool__(); -___ConvolutionFactory__OpType_Convolution__(); -___CPUConvInt8Creator__OpType_ConvInt8__(); -___CPURNNSequenceGRUCreator__OpType_RNNSequenceGRU__(); -___CPUEltwiseCreator__OpType_Eltwise__(); -___CPURandomCreator__OpType_RandomUniform__(); -___CPURandomCreator__OpType_RandomNormal__(); -___CPUSetDiff1DCreator__OpType_SetDiff1D__(); -___CPUEltwiseInt8Creator__OpType_EltwiseInt8__(); -___CPUSvdCreator__OpType_Svd__(); -___CPULayerNormCreator__OpType_LayerNorm__(); -___CPUExternalConstCreator__OpType_Const__(); -___CPUExternalConstCreator__OpType_TrainableParam__(); + ___CPUCropAndResizeCreator__OpType_CropAndResize__(); + ___CPUArgMaxCreator__OpType_ArgMax__(); + ___CPUArgMaxCreator__OpType_ArgMin__(); + ___CPUScaleCreator__OpType_Scale__(); + ___CPUSelectCreator__OpType_Select__(); + ___CPUSoftmaxCreator__OpType_Softmax__(); + ___CPUDetectionPostProcessCreator__OpType_DetectionPostProcess__(); + ___CPUCastCreator__OpType_Cast__(); + ___CPUProposalCreator__OpType_Proposal__(); + ___CPUInterpCreator__OpType_Interp__(); + ___CPUGridSampleCreator__OpType_GridSample__(); + ___CPUDetectionOutputCreator__OpType_DetectionOutput__(); + ___CPUUnravelIndexCreator__OpType_UnravelIndex__(); + ___CPUMatMulCreator__OpType_MatMul__(); + ___CPUMomentsCreator__OpType_Moments__(); + ___CPUSegmentMeanCreator__OpType_Segment__(); + ___CPUInstanceNormCreator__OpType_InstanceNorm__(); + ___CPUQuantizedLogisticCreator__OpType_QuantizedLogistic__(); + ___CPUWhereCreator__OpType_Where__(); + ___CPUQuantizedMaxPoolCreator__OpType_QuantizedMaxPool__(); + ___CPUDeconvolutionCreator__OpType_Deconvolution__(); + ___CPUBinaryCreator__OpType_BinaryOp__(); + ___CPUDepthwiseCreator__OpType_QuantizedDepthwiseConv2D__(); + ___CPUQuantizedSoftmaxCreator__OpType_QuantizedSoftmax__(); + ___CPUPoolCreator__OpType_Pooling__(); + ___CPUDetCreator__OpType_Det__(); + ___CPUHistogramCreator__OpType_Histogram__(); + ___CPUPluginCreator__OpType_Plugin__(); + ___CPUInt8ToFloatCreator__OpType_Int8ToFloat__(); + ___CPUDynamicQuantCreator__OpType_DynamicQuant__(); + ___CPUROIAlignCreator__OpType_ROIAlign__(); + ___CPUROIPoolingCreator__OpType_ROIPooling__(); + ___CPUTopKV2Creator__OpType_TopKV2__(); + ___CPUUnaryCreator__OpType_UnaryOp__(); + ___CPUStftCreator__OpType_Stft__(); + ___CPUReductionCreator__OpType_Reduction__(); + ___CPUReluCreator__OpType_ReLU__(); + ___CPUReluCreator__OpType_PReLU__(); + ___CPURelu6Creator__OpType_ReLU6__(); + ___CPUUniqueCreator__OpType_Unique__(); + ___CPUImageProcessCreator__OpType_ImageProcess__(); + ___CPUDepthwiseConvInt8Creator__OpType_DepthwiseConvInt8__(); + ___CPUOneHotCreator__OpType_OneHot__(); + ___CPUMatrixBandPartCreator__OpType_MatrixBandPart__(); + ___CPUQuantizedAddCreator__OpType_QuantizedAdd__(); + ___CPUDeconvolutionDepthwiseCreator__OpType_DeconvolutionDepthwise__(); + ___CPUFloatToInt8Creator__OpType_FloatToInt8__(); + ___CPULinSpaceCreator__OpType_LinSpace__(); + ___CPUNonMaxSuppressionV2Creator__OpType_NonMaxSuppressionV2__(); + ___CPUDequantizeCreator__OpType_Dequantize__(); + ___CPURasterFactory__OpType_Raster__(); + ___CPURasterFactory__OpType_While__(); + ___CPUConvolutionDepthwiseCreator__OpType_ConvolutionDepthwise__(); + ___CPURangeCreator__OpType_Range__(); + ___CPUTFQuantizedConv2DCreator__OpType_TfQuantizedConv2D__(); + ___CPUInterp3DCreator__OpType_Interp3D__(); + ___CPUQuantizedAvgPoolCreator__OpType_QuantizedAvgPool__(); + ___CPURoPECreator__OpType_RoPE__(); + ___ConvolutionFactory__OpType_Convolution__(); + ___CPUConvInt8Creator__OpType_ConvInt8__(); + ___CPURNNSequenceGRUCreator__OpType_RNNSequenceGRU__(); + ___CPUEltwiseCreator__OpType_Eltwise__(); + ___CPURandomCreator__OpType_RandomUniform__(); + ___CPURandomCreator__OpType_RandomNormal__(); + ___CPUSetDiff1DCreator__OpType_SetDiff1D__(); + ___CPUEltwiseInt8Creator__OpType_EltwiseInt8__(); + ___CPUSvdCreator__OpType_Svd__(); + ___CPULayerNormCreator__OpType_LayerNorm__(); + ___CPUExternalConstCreator__OpType_Const__(); + ___CPUExternalConstCreator__OpType_TrainableParam__(); #ifdef MNN_SUPPORT_RENDER -___CPURasterAndInterpolateCreator__OpType_RasterAndInterpolate__(); -___CPURasterDiffCreator__OpType_RasterDiff__(); -___CPUTextureCreator__OpType_Texture__(); + ___CPURasterAndInterpolateCreator__OpType_RasterAndInterpolate__(); + ___CPURasterDiffCreator__OpType_RasterDiff__(); + ___CPUTextureCreator__OpType_Texture__(); #endif #ifdef MNN_SUPPORT_TRANSFORMER_FUSE -___CPUAttentionCreator__OpType_Attention__(); -___CPULinearAttentionCreator__OpType_LinearAttention__(); + ___CPUAttentionCreator__OpType_Attention__(); + ___CPULinearAttentionCreator__OpType_LinearAttention__(); #endif } -} +} // namespace MNN diff --git a/source/backend/cpu/CPURoPE.cpp b/source/backend/cpu/CPURoPE.cpp new file mode 100644 index 0000000000..396e535431 --- /dev/null +++ b/source/backend/cpu/CPURoPE.cpp @@ -0,0 +1,192 @@ +// +// CPURoPE.cpp +// MNN +// +// Created by MNN on 2018/08/07. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#include "CPURoPE.hpp" +#include "CPUBackend.hpp" +#include "MNN_generated.h" +#include "backend/cpu/compute/CommonOptFunction.h" +#include "core/Concurrency.h" +#include "core/Macro.h" +#include "core/TensorUtils.hpp" + +namespace MNN { +CPURoPE::CPURoPE(const Op* op, Backend* bn) : MNN::Execution(bn) { + const Op* qLayernorm = nullptr; + const Op* kLayernorm = nullptr; + if (nullptr != op && OpParameter_Extra == op->main_type()) { + auto extra = op->main_as_Extra(); + if (nullptr != extra && nullptr != extra->attr()) { + for (int i = 0; i < extra->attr()->size(); ++i) { + auto attr = extra->attr()->GetAs(i); + if (nullptr == attr || nullptr == attr->key()) { + continue; + } + if (attr->key()->str() == "rope_cut_head_dim") { + mRopeCutHeadDim = attr->i(); + continue; + } + if (attr->key()->str() == "q_norm") { + qLayernorm = flatbuffers::GetRoot(attr->tensor()->int8s()->data()); + mQNorm = CPULayerNorm::makeResource(qLayernorm, bn); + continue; + } + if (attr->key()->str() == "k_norm") { + kLayernorm = flatbuffers::GetRoot(attr->tensor()->int8s()->data()); + mKNorm = CPULayerNorm::makeResource(kLayernorm, bn); + continue; + } + } + } + } +} + +CPURoPE::~CPURoPE() { + // Do nothing. +} + +CPURoPE::CPURoPE(Backend* bn) : Execution(bn) { + // Do nothing. +} + +ErrorCode CPURoPE::onResize(const std::vector& inputs, const std::vector& outputs) { + auto bn = static_cast(backend()); + auto threadNumber = bn->threadNumber(); + auto buf = bn->getBufferAllocator(); + if (bn->functions()->bytes != 4) { + if (mQNorm) { + auto Q = inputs[0]; + int numHead = Q->length(2); + int headDim = Q->length(3); + mTmpQFloat = buf->alloc(threadNumber * numHead * headDim * sizeof(float)); + buf->free(mTmpQFloat); + } + if (mKNorm) { + auto K = inputs[1]; + int kvnumHead = K->length(2); + int headDim = K->length(3); + mTmpKFloat = buf->alloc(threadNumber * kvnumHead * headDim * sizeof(float)); + buf->free(mTmpKFloat); + } + } + return NO_ERROR; +} + +bool CPURoPE::onClone(Backend* bn, const Op* op, Execution** dst) { + if (nullptr == dst) { + return true; + } + auto rope = new CPURoPE(bn); + rope->mRopeCutHeadDim = mRopeCutHeadDim; + rope->mQNorm = mQNorm; + rope->mKNorm = mKNorm; + *dst = rope; + return true; +} + +ErrorCode CPURoPE::onExecute(const std::vector& inputs, const std::vector& outputs) { + auto Q = inputs[0]; + auto K = inputs[1]; + auto cosEven = inputs[2]; + auto cosOdd = inputs[3]; + auto sinEven = inputs[4]; + auto sinOdd = inputs[5]; + + auto QOutput = outputs[0]; + auto KOutput = outputs[1]; + int batch = Q->length(0); + int seqLen = Q->length(1); + int numHead = Q->length(2); + int headDim = Q->length(3); + int kvnumHead = K->length(2); + auto halfHeadDim = headDim / 2; + int threadNum = static_cast(backend())->threadNumber(); + int totalWork = batch * seqLen; + auto core = static_cast(backend())->functions(); + MNN_ASSERT(core->MNNRoPECompute != nullptr); + + MNN_CONCURRENCY_BEGIN(tId, threadNum) { + int start = tId * totalWork / threadNum; + int end = (tId + 1) * totalWork / threadNum; + for (int i = start; i < end; ++i) { + auto cosEvenPtr = static_cast(cosEven->host()) + i * halfHeadDim * core->bytes; + auto cosOddPtr = static_cast(cosOdd->host()) + i * halfHeadDim * core->bytes; + auto sinEvenPtr = static_cast(sinEven->host()) + i * halfHeadDim * core->bytes; + auto sinOddPtr = static_cast(sinOdd->host()) + i * halfHeadDim * core->bytes; + auto qPtr = static_cast(Q->host()) + i * numHead * headDim * core->bytes; + auto qPtrOut = static_cast(QOutput->host()) + i * numHead * headDim * core->bytes; + + if (mQNorm) { + int size = headDim; + const float* gamma = mQNorm->mIniGammaBeta ? mQNorm->mGamma->host() : nullptr; + const float* beta = mQNorm->mIniGammaBeta ? mQNorm->mBeta->host() : nullptr; + if (core->bytes == 4) { + for (int h = 0; h < numHead; ++h) { + MNNNorm(reinterpret_cast(qPtrOut) + h * headDim, + reinterpret_cast(qPtr) + h * headDim, gamma, beta, mQNorm->mEpsilon, size, + mQNorm->mRMSNorm); + } + qPtr = qPtrOut; + } else { + int totalSize = numHead * headDim; + auto tmpQ = reinterpret_cast(mTmpQFloat.ptr() + tId * totalSize * sizeof(float)); + core->MNNLowpToFp32(reinterpret_cast(qPtr), tmpQ, totalSize); + for (int h = 0; h < numHead; ++h) { + MNNNorm(tmpQ + h * headDim, tmpQ + h * headDim, gamma, beta, mQNorm->mEpsilon, size, + mQNorm->mRMSNorm); + } + core->MNNFp32ToLowp(tmpQ, reinterpret_cast(qPtrOut), totalSize); + qPtr = qPtrOut; + } + } + core->MNNRoPECompute(qPtrOut, qPtr, cosEvenPtr, cosOddPtr, sinEvenPtr, sinOddPtr, numHead, headDim, + mRopeCutHeadDim); + + qPtr = static_cast(K->host()) + i * kvnumHead * headDim * core->bytes; + qPtrOut = static_cast(KOutput->host()) + i * kvnumHead * headDim * core->bytes; + + if (mKNorm) { + int size = headDim; + const float* gamma = mKNorm->mIniGammaBeta ? mKNorm->mGamma->host() : nullptr; + const float* beta = mKNorm->mIniGammaBeta ? mKNorm->mBeta->host() : nullptr; + if (core->bytes == 4) { + for (int h = 0; h < kvnumHead; ++h) { + MNNNorm(reinterpret_cast(qPtrOut) + h * headDim, + reinterpret_cast(qPtr) + h * headDim, gamma, beta, mKNorm->mEpsilon, size, + mKNorm->mRMSNorm); + } + qPtr = qPtrOut; + } else { + int totalSize = kvnumHead * headDim; + auto tmpK = reinterpret_cast(mTmpKFloat.ptr() + tId * totalSize * sizeof(float)); + core->MNNLowpToFp32(reinterpret_cast(qPtr), tmpK, totalSize); + for (int h = 0; h < kvnumHead; ++h) { + MNNNorm(tmpK + h * headDim, tmpK + h * headDim, gamma, beta, mKNorm->mEpsilon, size, + mKNorm->mRMSNorm); + } + core->MNNFp32ToLowp(tmpK, reinterpret_cast(qPtrOut), totalSize); + qPtr = qPtrOut; + } + } + core->MNNRoPECompute(qPtrOut, qPtr, cosEvenPtr, cosOddPtr, sinEvenPtr, sinOddPtr, kvnumHead, headDim, + mRopeCutHeadDim); + } + } + MNN_CONCURRENCY_END(); + return NO_ERROR; +} + +class CPURoPECreator : public CPUBackend::Creator { +public: + virtual Execution* onCreate(const std::vector& inputs, const std::vector& outputs, + const MNN::Op* op, Backend* backend) const override { + return new CPURoPE(op, backend); + } +}; + +REGISTER_CPU_OP_CREATOR(CPURoPECreator, OpType_RoPE); +} // namespace MNN diff --git a/source/backend/cpu/CPURoPE.hpp b/source/backend/cpu/CPURoPE.hpp new file mode 100644 index 0000000000..e92d32c6f9 --- /dev/null +++ b/source/backend/cpu/CPURoPE.hpp @@ -0,0 +1,35 @@ +// +// CPURoPE.hpp +// MNN +// +// Created by MNN on 2018/08/07. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifndef CPURoPE_hpp +#define CPURoPE_hpp + +#include +#include "backend/cpu/CPULayerNorm.hpp" +#include "core/Execution.hpp" + +namespace MNN { +class CPURoPE : public Execution { +public: + CPURoPE(const Op* op, Backend* bn); + virtual ~CPURoPE(); + virtual ErrorCode onExecute(const std::vector& inputs, const std::vector& outputs) override; + virtual ErrorCode onResize(const std::vector& inputs, const std::vector& outputs) override; + virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override; + +private: + CPURoPE(Backend* bn); + int mRopeCutHeadDim = 0; + std::shared_ptr mQNorm; + std::shared_ptr mKNorm; + MemChunk mTmpQFloat; + MemChunk mTmpKFloat; +}; + +} // namespace MNN +#endif /* CPURoPE_hpp */ diff --git a/source/backend/cpu/compute/CommonOptFunction.cpp b/source/backend/cpu/compute/CommonOptFunction.cpp index 9243428664..7594cef374 100644 --- a/source/backend/cpu/compute/CommonOptFunction.cpp +++ b/source/backend/cpu/compute/CommonOptFunction.cpp @@ -4635,6 +4635,96 @@ namespace MNN { static CoreFunctions* gCoreFunction = nullptr; +static void MNNRoPEComputeBasic(void* dst, const void* src, const void* cosEven, const void* cosOdd, + const void* sinEven, const void* sinOdd, int numHead, int headDim, int ropeCutHeadDim) { + const int halfHeadDim = headDim / 2; + int ropeDim = ropeCutHeadDim; + if (ropeDim <= 0 || ropeDim > headDim) { + ropeDim = headDim; + } + ropeDim = (ropeDim / 2) * 2; + const int ropeHalfHeadDim = ropeDim / 2; + + auto srcFloat = static_cast(src); + auto dstFloat = static_cast(dst); + auto cosEvenFloat = static_cast(cosEven); + auto cosOddFloat = static_cast(cosOdd); + auto sinEvenFloat = static_cast(sinEven); + auto sinOddFloat = static_cast(sinOdd); + for (int j = 0; j < numHead; ++j) { + auto src0 = srcFloat + j * headDim; + auto src1 = src0 + halfHeadDim; + auto dst0 = dstFloat + j * headDim; + auto dst1 = dst0 + halfHeadDim; + int k = 0; + for (; k <= ropeHalfHeadDim - 4; k += 4) { + auto q0 = Vec4::load(src0 + k); + auto q1 = Vec4::load(src1 + k); + auto c0 = Vec4::load(cosEvenFloat + k); + auto c1 = Vec4::load(cosOddFloat + k); + auto s0 = Vec4::load(sinEvenFloat + k); + auto s1 = Vec4::load(sinOddFloat + k); + Vec4::save(dst0 + k, Vec4::fms(q0 * c0, q1, s0)); + Vec4::save(dst1 + k, Vec4::fma(q1 * c1, q0, s1)); + } + for (; k < ropeHalfHeadDim; ++k) { + auto q0 = src0[k]; + auto q1 = src1[k]; + dst0[k] = q0 * cosEvenFloat[k] - q1 * sinEvenFloat[k]; + dst1[k] = q1 * cosOddFloat[k] + q0 * sinOddFloat[k]; + } + if (ropeHalfHeadDim < halfHeadDim) { + ::memcpy(dst0 + ropeHalfHeadDim, src0 + ropeHalfHeadDim, (halfHeadDim - ropeHalfHeadDim) * sizeof(float)); + ::memcpy(dst1 + ropeHalfHeadDim, src1 + ropeHalfHeadDim, (halfHeadDim - ropeHalfHeadDim) * sizeof(float)); + } + } +} + +template +static void MNNNormPackedFloat(float* dest, const float* source, const float* gamma, const float* beta, float epsilon, + size_t batch, size_t channels, bool RMSNorm) { + const size_t channelUnit = UP_DIV(channels, Pack); + for (size_t n = 0; n < batch; ++n) { + float mean = 0.0f; + if (!RMSNorm) { + float sum = 0.0f; + for (size_t c = 0; c < channels; ++c) { + const size_t cu = c / Pack; + const size_t cr = c - cu * Pack; + sum += source[(cu * batch + n) * Pack + cr]; + } + mean = sum / static_cast(channels); + } + + float squareSum = 0.0f; + for (size_t c = 0; c < channels; ++c) { + const size_t cu = c / Pack; + const size_t cr = c - cu * Pack; + float v = source[(cu * batch + n) * Pack + cr]; + float d = RMSNorm ? v : (v - mean); + squareSum += d * d; + } + + const float invStd = 1.0f / std::sqrt(squareSum / static_cast(channels) + epsilon); + for (size_t c = 0; c < channels; ++c) { + const size_t cu = c / Pack; + const size_t cr = c - cu * Pack; + const size_t index = (cu * batch + n) * Pack + cr; + float v = source[index]; + float norm = RMSNorm ? (v * invStd) : ((v - mean) * invStd); + if (gamma && beta) { + norm = norm * gamma[c] + beta[c]; + } + dest[index] = norm; + } + for (size_t c = channels; c < channelUnit * Pack; ++c) { + const size_t cu = c / Pack; + const size_t cr = c - cu * Pack; + dest[(cu * batch + n) * Pack + cr] = 0.0f; + } + } +} + void MNNCoreFunctionInit() { gCoreFunction = new CoreFunctions; @@ -4645,6 +4735,7 @@ void MNNCoreFunctionInit() { gCoreFunction->MNNPackedMatMul = MNNPackedMatMul; gCoreFunction->MNNPackedMatMulRemain = MNNPackedMatMulRemain; gCoreFunction->MNNCountMaxMinValue = MNNCountMaxMinValue; + gCoreFunction->MNNNormPacked = MNNNormPackedFloat<4>; #ifdef MNN_USE_SPARSE_COMPUTE gCoreFunction->MNNGetSparseMatMulPackMode = MNNGetSparseMatMulPackMode; gCoreFunction->MNNAdjustOptimalSparseKernel = _MNNAdjustOptimalSparseKernel; @@ -4721,6 +4812,7 @@ void MNNCoreFunctionInit() { gCoreFunction->MNNQuantAttentionKey = MNNQuantAttentionKey; gCoreFunction->MNNQuantAttentionValue = MNNQuantAttentionValue; #endif // MNN_SUPPORT_TRANSFORMER_FUSE + gCoreFunction->MNNRoPECompute = MNNRoPEComputeBasic; gCoreFunction->MNNReluWithSlopeChannel = MNNReluWithSlopeChannel; gCoreFunction->MNNPoolingAvg = (decltype(gCoreFunction->MNNPoolingAvg))(poolingAvg); @@ -4770,9 +4862,9 @@ void MNNCoreFunctionInit() { if (!gCoreFunction->supportSME2) { gCoreFunction->smeCoreNumber = 0; } - MNN_PRINT("MNN_CPU_TARGET=%d effective ARM features: fp16=%d, i8sdot=%d, i8mm=%d, sme2=%d\n", - target, gCoreFunction->supportFp16arith, gCoreFunction->supportSDot, - gCoreFunction->supportI8mm, gCoreFunction->supportSME2); + MNN_PRINT("MNN_CPU_TARGET=%d effective ARM features: fp16=%d, i8sdot=%d, i8mm=%d, sme2=%d\n", target, + gCoreFunction->supportFp16arith, gCoreFunction->supportSDot, gCoreFunction->supportI8mm, + gCoreFunction->supportSME2); } #endif gCoreFunction->MNNSumByAxisLForMatmul_A = MNNSumByAxisLForMatmul_A; diff --git a/source/backend/cpu/compute/CommonOptFunction.h b/source/backend/cpu/compute/CommonOptFunction.h index 0abd5e5cf8..de16c6cf0d 100644 --- a/source/backend/cpu/compute/CommonOptFunction.h +++ b/source/backend/cpu/compute/CommonOptFunction.h @@ -25,26 +25,34 @@ extern "C" { #ifdef __aarch64__ #ifdef MNN_LOW_MEMORY -void MNNGeneralIm2col_Fp32Arm82(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, int32_t LP, int32_t pack); -void MNNGeneralIm2col_Fp32Arm86(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, int32_t LP, int32_t pack); -void MNNGeneralIm2col_Fp32Sme2(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, int32_t LP, int32_t pack); -void MNNLocalMinMaxFP32_Pack4(float* dstMin, float* dstMax, const float* source, size_t blockNum, size_t blockLU, size_t EP, size_t LP, size_t loadDstBuffer); -void MNNLocalMinMaxFP32_Pack8(float* dstMin, float* dstMax, const float* source, size_t blockNum, size_t blockLU, size_t EP, size_t LP, size_t loadDstBuffer); -void MNNDynamicQuantFP32_Pack4(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, const float* bias, size_t pack); -void MNNDynamicQuantFP32_Pack8(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, const float* bias, size_t pack); +void MNNGeneralIm2col_Fp32Arm82(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, + int32_t LP, int32_t pack); +void MNNGeneralIm2col_Fp32Arm86(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, + int32_t LP, int32_t pack); +void MNNGeneralIm2col_Fp32Sme2(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, + int32_t LP, int32_t pack); +void MNNLocalMinMaxFP32_Pack4(float* dstMin, float* dstMax, const float* source, size_t blockNum, size_t blockLU, + size_t EP, size_t LP, size_t loadDstBuffer); +void MNNLocalMinMaxFP32_Pack8(float* dstMin, float* dstMax, const float* source, size_t blockNum, size_t blockLU, + size_t EP, size_t LP, size_t loadDstBuffer); +void MNNDynamicQuantFP32_Pack4(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, + size_t realSize, const float* bias, size_t pack); +void MNNDynamicQuantFP32_Pack8(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, + size_t realSize, const float* bias, size_t pack); void MNNAbsMaxFP32_Pack4(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack); void MNNAbsMaxFP32_Pack8(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack); void MNNQuantScaleFP32(float* absmax, float* quant_scale, float* dequant_scale, size_t thread, size_t batch); -void MNNDynamicUpdateConvBiasScale(float* newbias, float* oldbias, float* weightKernelSum, float* inputZero, size_t ocQuad); +void MNNDynamicUpdateConvBiasScale(float* newbias, float* oldbias, float* weightKernelSum, float* inputZero, + size_t ocQuad); #endif // MNN_LOW_MEMORY #ifdef MNN_SME2 -void MNNPackedMatMulRemainFP32_SME2(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); +void MNNPackedMatMulRemainFP32_SME2(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, + const float* postParameters, const float* bias, const float* k, const float* b); #endif #endif // __aarch64__ - void MNNQuantAttentionKey(int8_t* dst, const float* source, float* sumKey, float* maxKey, int32_t* params); void MNNQuantAttentionValue(int8_t* dst, const float* source, float* valueQuantInfo, int32_t* params); @@ -66,9 +74,9 @@ void MNNPackC2Origin(double* dst, const double* src, size_t area, size_t depth, void MNNPackInt8C2(float* dst, const float* src, size_t area, size_t depth, int* areaOffset); void MNNPackInt8C2Origin(float* dst, const float* src, size_t area, size_t depth, int areaOffset); -void MNNPackC4Int16(int16_t* dst, const int16_t* src, size_t area,size_t depth, int* areaOffset); +void MNNPackC4Int16(int16_t* dst, const int16_t* src, size_t area, size_t depth, int* areaOffset); -void MNNPackC4Uint8(uint8_t* dst, const uint8_t* src, size_t area,size_t depth, int* areaOffset); +void MNNPackC4Uint8(uint8_t* dst, const uint8_t* src, size_t area, size_t depth, int* areaOffset); void MNNUnpackC4(float* dst, const float* src, size_t area, size_t depth, int* areaOffset); void MNNUnpackC4Origin(float* dst, const float* src, size_t area, size_t depth, int areaOffset); @@ -80,9 +88,9 @@ void MNNUnpackC2Float(float* dst, const float* src, size_t area, size_t depth, i void MNNUnpackInt8C2(float* dst, const float* src, size_t area, size_t depth, int* areaOffset); void MNNUnpackInt8C2Origin(float* dst, const float* src, size_t area, size_t depth, int areaOffset); -void MNNUnpackC4Int16(int16_t* dst, const int16_t* src, size_t area,size_t depth, int* areaOffset); +void MNNUnpackC4Int16(int16_t* dst, const int16_t* src, size_t area, size_t depth, int* areaOffset); -void MNNUnpackC4Uint8(uint8_t* dst, const uint8_t* src, size_t area,size_t depth, int* areaOffset); +void MNNUnpackC4Uint8(uint8_t* dst, const uint8_t* src, size_t area, size_t depth, int* areaOffset); void MNNScaleAndAddBias(float* dst, const float* src, const float* bias, const float* alpha, size_t planeNumber, size_t biasNumber); @@ -90,12 +98,12 @@ void MNNScaleAndAddBiasScalar(float* dst, const float* src, float bias, float al // TODO: Swap the name for MNNUnpackTranspose and MNNPackTranspose void MNNUnpackTranspose(float* dst, const float* src, size_t area, size_t depth, int* areaOffset); -void MNNUnpackTransposeInt16(int16_t* dst, const int16_t* src, size_t area,size_t depth, int* areaOffset); -void MNNUnpackTransposeUint8(uint8_t* dst, const uint8_t* src, size_t area,size_t depth, int* areaOffset); +void MNNUnpackTransposeInt16(int16_t* dst, const int16_t* src, size_t area, size_t depth, int* areaOffset); +void MNNUnpackTransposeUint8(uint8_t* dst, const uint8_t* src, size_t area, size_t depth, int* areaOffset); void MNNPackTranspose(float* dst, const float* src, size_t area, size_t depth, int* areaOffset); -void MNNPackTransposeInt16(int16_t* dst, const int16_t* src, size_t area,size_t depth, int* areaOffset); -void MNNPackTransposeUint8(uint8_t* dst, const uint8_t* src, size_t area,size_t depth, int* areaOffset); +void MNNPackTransposeInt16(int16_t* dst, const int16_t* src, size_t area, size_t depth, int* areaOffset); +void MNNPackTransposeUint8(uint8_t* dst, const uint8_t* src, size_t area, size_t depth, int* areaOffset); void MNNCopyC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count); void MNNAddC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count); @@ -122,13 +130,15 @@ void MNNReluWithSlopeCommon(float* dst, const float* src, size_t size, float slo void MNNHardSwishCommon(float* dst, const float* src, size_t size); void MNNGeluCommon(float* dst, const float* src, size_t size); void MNNGeluStandardCommon(float* dst, const float* src, size_t size); -void MNNNorm(float* dest, const float* source, const float *gamma, const float *beta, float epsilon, size_t size, bool RMSNorm = false); -void MNNSoftmax(float* softmaxDst, const float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize, int kvSeqOffset, int validOffset, int pack = 1, bool mask = false); +void MNNNorm(float* dest, const float* source, const float* gamma, const float* beta, float epsilon, size_t size, + bool RMSNorm = false); +void MNNSoftmax(float* softmaxDst, const float* input, float* runningMax, float* runningSum, float* updateScale, + int outside, int reduceSize, int kvSeqOffset, int validOffset, int pack = 1, bool mask = false); // Get Pack for MatMul's e , l , h , the pack number must be 1 or 4 * n -void MNNGetMatMulPackMode(int* eP, int *lP, int* hP); +void MNNGetMatMulPackMode(int* eP, int* lP, int* hP); -void MNNGetSparseMatMulPackMode(int* eP, int *lP, int* hP); +void MNNGetSparseMatMulPackMode(int* eP, int* lP, int* hP); /** int number = info[0]; @@ -147,27 +157,33 @@ void MNNPackC4ForMatMul_A(float* destOrigin, float const** sourceGroup, const in void MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose); // parameters: e, l, h, CStride, AStride, BStride -void MNNPackedMatMul(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); +void MNNPackedMatMul(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, + const float* bias, const float* k, const float* b); void MNNFunctionInit(); -void MNNPackedMatMulRemain(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); +void MNNPackedMatMulRemain(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, + const float* postParameters, const float* bias, const float* k, const float* b); -void MNNPackForSparseMatMul_B(float* dest, unsigned int* NNZMap, int* dataOffsetMap, int sparseBlockOC, const float* source, size_t h, size_t l, const int eP, bool transpose); -struct SparseMatMulParas -{ +void MNNPackForSparseMatMul_B(float* dest, unsigned int* NNZMap, int* dataOffsetMap, int sparseBlockOC, + const float* source, size_t h, size_t l, const int eP, bool transpose); +struct SparseMatMulParas { float* C; const float* A; const float* B; unsigned int* NNZMap; int* dataOffsetMap; }; -void MNNPackedSparseMatMulEpx1(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, unsigned int* NNZMap, int* dataOffsetMap); - -void MNNPackedSparseMatMulEpx4(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, unsigned int* NNZMap, int* dataOffsetMap); +void MNNPackedSparseMatMulEpx1(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, + const float* postParameters, const float* bias, unsigned int* NNZMap, + int* dataOffsetMap); +void MNNPackedSparseMatMulEpx4(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, + const float* postParameters, const float* bias, unsigned int* NNZMap, + int* dataOffsetMap); int MNNGetC4DivNumber(int hP); -void MNNAxByClampBroadcastUnit(float* C, const float* A, const float* B, size_t width, size_t cStride, size_t aStride, size_t height, const float* parameters); +void MNNAxByClampBroadcastUnit(float* C, const float* A, const float* B, size_t width, size_t cStride, size_t aStride, + size_t height, const float* parameters); // dim: 4-element, sizeDW, sizeDH, strideSW, strideDH void MNNTranspose32Bit(int32_t* dstO, const int32_t* srcO, int32_t* dim); // not C4 @@ -183,7 +199,8 @@ struct MatMulParam { bool ATranspose; bool BTranspose; }; -void MNNComputeMatMulForE_1(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tId); +void MNNComputeMatMulForE_1(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, + size_t tId); void MNNCopyC4Int16WithStride(const float* sourceF, float* destF, size_t srcStride, size_t dstStride, size_t count); void MNNInt8ToInt16(int16_t* dest, const int8_t* source, size_t count); @@ -208,55 +225,90 @@ void MNNPermuteSumWeightInt4Arm82(uint8_t* dest, uint8_t* source, size_t outside void MNNSumWeightInt8Arm86(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP); void MNNSumWeightInt8Arm82(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP); #ifdef MNN_SME2 -void MNNSumWeightInt8Sme2_Hp32(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP); -void MNNSumWeightInt8Sme2_Hp128(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP); -void MNNPermuteSumWeightInt4Sme2_Hp32(uint8_t* dest, uint8_t* source, size_t outside, size_t inside, float* kernelsum, int32_t* table); -void MNNPermuteSumWeightInt4Sme2_Hp128(uint8_t* dest, uint8_t* source, size_t outside, size_t inside, float* kernelsum, int32_t* table); +void MNNSumWeightInt8Sme2_Hp32(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, + size_t lP); +void MNNSumWeightInt8Sme2_Hp128(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, + size_t lP); +void MNNPermuteSumWeightInt4Sme2_Hp32(uint8_t* dest, uint8_t* source, size_t outside, size_t inside, float* kernelsum, + int32_t* table); +void MNNPermuteSumWeightInt4Sme2_Hp128(uint8_t* dest, uint8_t* source, size_t outside, size_t inside, float* kernelsum, + int32_t* table); #endif #endif } -typedef void(*MNNBinaryExecute)(void* outputRaw, const void* inputRaw0, const void* inputRaw1, int elementSize, int broadcastIndex); -typedef void(*MNNUnaryExecute)(void* outputRaw, const void* inputRaw, int elementSize); -typedef void(*MNNUnaryExecuteInt8)(void* outputRaw, const void* inputRaw, int elementSize, QuanPrePostParameters* params); -typedef void(*MNNCopyWithStride)(uint8_t* dstO, const uint8_t* srcO, int size, int stride, int ds); -typedef void(*MNNBinaryExecInt8)(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, ssize_t* inputScalesInt32, float* inputScalesFp32, const QuanPrePostParameters* params, size_t elementSize, size_t needBroadcast); +typedef void (*MNNBinaryExecute)(void* outputRaw, const void* inputRaw0, const void* inputRaw1, int elementSize, + int broadcastIndex); +typedef void (*MNNUnaryExecute)(void* outputRaw, const void* inputRaw, int elementSize); +typedef void (*MNNUnaryExecuteInt8)(void* outputRaw, const void* inputRaw, int elementSize, + QuanPrePostParameters* params); +typedef void (*MNNCopyWithStride)(uint8_t* dstO, const uint8_t* srcO, int size, int stride, int ds); +typedef void (*MNNBinaryExecInt8)(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, + ssize_t* inputScalesInt32, float* inputScalesFp32, + const QuanPrePostParameters* params, size_t elementSize, size_t needBroadcast); constexpr int InputTileMax = 14; // same value from DynamicGemm.h, cannot include from different backend code. namespace MNN { struct MatmulRelatedFunctions { // from coreFunctions - void (*MNNSumWeightInt8)(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP) = nullptr; - void (*MNNSumWeightInt8SmeHp128)(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP) = nullptr; - void (*MNNReorderWeightInt4)(uint8_t* dest, const uint8_t* source, int32_t* shape, size_t size, float* kernelsum) = nullptr; - void(*MNNGeneralIm2Col)(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, int32_t LP, int32_t pack) = nullptr; + void (*MNNSumWeightInt8)(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, + size_t lP) = nullptr; + void (*MNNSumWeightInt8SmeHp128)(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, + size_t lP) = nullptr; + void (*MNNReorderWeightInt4)(uint8_t* dest, const uint8_t* source, int32_t* shape, size_t size, + float* kernelsum) = nullptr; + void (*MNNGeneralIm2Col)(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, + int32_t LP, int32_t pack) = nullptr; // from int8CoreFunctions - void(*Int8GemmKernel)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realCount) = nullptr; - void(*Int8GemmKernelFast)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realCount) = nullptr; - void(*MNNGetGemmUnit)(int* UNIT, int* SRC_UNIT, int* DST_XUNIT) = nullptr; - void(*MNNPackC4Int8ForMatMul_A)(int8_t* destOrigin, int8_t const** sourceGroup, const int32_t* info, const int32_t* el) = nullptr; - void(*MNNGemmInt8AddBiasScale_Unit_FP16)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount) = nullptr; - void(*MNNGemmInt8AddBiasScale_w4_Unit_FP16)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount) = nullptr; + void (*Int8GemmKernel)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, + size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realCount) = nullptr; + void (*Int8GemmKernelFast)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, + size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, + size_t realCount) = nullptr; + void (*MNNGetGemmUnit)(int* UNIT, int* SRC_UNIT, int* DST_XUNIT) = nullptr; + void (*MNNPackC4Int8ForMatMul_A)(int8_t* destOrigin, int8_t const** sourceGroup, const int32_t* info, + const int32_t* el) = nullptr; + void (*MNNGemmInt8AddBiasScale_Unit_FP16)(int8_t* dst, const int8_t* src, const int8_t* weight, + size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, + const QuanPostTreatParameters* post, size_t realDstCount) = nullptr; + void (*MNNGemmInt8AddBiasScale_w4_Unit_FP16)(int8_t* dst, const int8_t* src, const int8_t* weight, + size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, + const QuanPostTreatParameters* post, size_t realDstCount) = nullptr; void (*MNNGemmInt8AddBiasScale_w2_Unit_FP16)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount) = nullptr; void (*MNNGemmInt8AddBiasScale_w3_Unit_FP16)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount) = nullptr; - void(*MNNGemmInt8AddBiasScale_Unit_FP16_DecodeMax)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount) = nullptr; - void(*MNNGemmInt8AddBiasScale_Unit_FP32_DecodeMax)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount) = nullptr; - void(*MNNGemmInt8AddBiasScale_w4_Unit_FP16_DecodeMax)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount) = nullptr; - void(*MNNGemmInt8AddBiasScale_w4_Unit_FP32_DecodeMax)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount) = nullptr; - void(*Int8GemmKernel_W4)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount) = nullptr; + void (*MNNGemmInt8AddBiasScale_Unit_FP16_DecodeMax)(int8_t* dst, const int8_t* src, const int8_t* weight, + size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, + const QuanPostTreatParameters* post, + size_t realDstCount) = nullptr; + void (*MNNGemmInt8AddBiasScale_Unit_FP32_DecodeMax)(int8_t* dst, const int8_t* src, const int8_t* weight, + size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, + const QuanPostTreatParameters* post, + size_t realDstCount) = nullptr; + void (*MNNGemmInt8AddBiasScale_w4_Unit_FP16_DecodeMax)(int8_t* dst, const int8_t* src, const int8_t* weight, + size_t src_depth_quad, size_t dst_step, + size_t dst_depth_quad, const QuanPostTreatParameters* post, + size_t realDstCount) = nullptr; + void (*MNNGemmInt8AddBiasScale_w4_Unit_FP32_DecodeMax)(int8_t* dst, const int8_t* src, const int8_t* weight, + size_t src_depth_quad, size_t dst_step, + size_t dst_depth_quad, const QuanPostTreatParameters* post, + size_t realDstCount) = nullptr; + void (*Int8GemmKernel_W4)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, + size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, + size_t realDstCount) = nullptr; void (*Int8GemmKernel_W2)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount) = nullptr; void (*Int8GemmKernel_W3)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount) = nullptr; - void(*MNNSumByAxisLForMatmul_A)(float* dest, int8_t* source, const float* dequantScale, ssize_t realDstCount, SumByAxisParams sumParams) = nullptr; + void (*MNNSumByAxisLForMatmul_A)(float* dest, int8_t* source, const float* dequantScale, ssize_t realDstCount, + SumByAxisParams sumParams) = nullptr; int eP; }; @@ -267,35 +319,46 @@ struct CoreFunctions { bool supportSDot = false; bool supportI8mm = false; bool supportSME2 = false; - bool supportRVV = false; - int smeCoreNumber = 0; + bool supportRVV = false; + int smeCoreNumber = 0; /**MatMul Pack and Functions*/ - void(*MNNGetMatMulPackMode)(int* eP, int *lP, int* hP); - void(*MNNGetSparseMatMulPackMode)(int* eP, int *lP, int* hP); - void(*MNNPackC4ForMatMul_A)(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el); - void(*MNNPackForMatMul_B)(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose); - void(*MNNGeneralIm2Col)(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, int32_t LP, int32_t pack); + void (*MNNGetMatMulPackMode)(int* eP, int* lP, int* hP); + void (*MNNGetSparseMatMulPackMode)(int* eP, int* lP, int* hP); + void (*MNNPackC4ForMatMul_A)(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el); + void (*MNNPackForMatMul_B)(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, + bool transpose); + void (*MNNGeneralIm2Col)(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, + int32_t LP, int32_t pack); // parameters: e, l, h, CStride, AStride, BStride - void(*MNNPackedMatMul)(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); - void(*MNNPackedMatMulRemain)(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); + void (*MNNPackedMatMul)(float* C, const float* A, const float* B, const size_t* parameter, + const float* postParameters, const float* bias, const float* k, const float* b); + void (*MNNPackedMatMulRemain)(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, + const float* postParameters, const float* bias, const float* k, const float* b); // int8 matmul related - void(*MNNSumByAxisLForMatmul_A)(float* dest, int8_t* source, const float* dequantScale, ssize_t realDstCount, SumByAxisParams sumParams); - void(*MNNReorderWeightInt4)(uint8_t* dest, const uint8_t* source, int32_t* shape, size_t size, float* kernelsum); - void(*MNNSumWeightInt8)(float* kernlesum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP); - void(*MNNSumWeightInt8SmeHp128)(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP); + void (*MNNSumByAxisLForMatmul_A)(float* dest, int8_t* source, const float* dequantScale, ssize_t realDstCount, + SumByAxisParams sumParams); + void (*MNNReorderWeightInt4)(uint8_t* dest, const uint8_t* source, int32_t* shape, size_t size, float* kernelsum); + void (*MNNSumWeightInt8)(float* kernlesum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP); + void (*MNNSumWeightInt8SmeHp128)(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, + size_t lP); // cpu dynamic quant - void(*MNNAbsMax)(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack) = nullptr; - void(*MNNQuantScale)(float* absmax, float* quant_scale, float* dequant_scale, size_t thread, size_t batch) = nullptr; - void(*MNNDynamicQuant)(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, int pack, const float* bias) = nullptr; - void(*MNNComputeMatMulForH_1)(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tId); - void(*MNNComputeMatMulForE_1)(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tId); + void (*MNNAbsMax)(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack) = nullptr; + void (*MNNQuantScale)(float* absmax, float* quant_scale, float* dequant_scale, size_t thread, + size_t batch) = nullptr; + void (*MNNDynamicQuant)(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, + int pack, const float* bias) = nullptr; + void (*MNNComputeMatMulForH_1)(const float* A, const float* B, float* C, const float* biasPtr, + const MatMulParam* param, size_t tId); + void (*MNNComputeMatMulForE_1)(const float* A, const float* B, float* C, const float* biasPtr, + const MatMulParam* param, size_t tId); // Rank-1 update: S[dk, dv] += k[dk] * delta[dv] (outer product add) - void(*MNNRankOneUpdate)(float* S, const float* k, const float* delta, size_t dk, size_t dv); + void (*MNNRankOneUpdate)(float* S, const float* k, const float* delta, size_t dk, size_t dv); // Read-only dual MatVec: out_k = S^T @ k, out_q = S^T @ q (does NOT modify S) - void(*MNNDualMatVec)(const float* S, const float* k, const float* q, float* out_k, float* out_q, size_t dk, size_t dv); + void (*MNNDualMatVec)(const float* S, const float* k, const float* q, float* out_k, float* out_q, size_t dk, + size_t dv); // Fused decay + rank-1 update: S[i,j] = decay * S[i,j] + k[i] * delta[j] - void(*MNNDecayRankOneUpdate)(float* S, const float* k, const float* delta, float decay, size_t dk, size_t dv); + void (*MNNDecayRankOneUpdate)(float* S, const float* k, const float* delta, float decay, size_t dk, size_t dv); // Fused gated-delta-rule kernel. Computes (all in the backend's native // precision — fp32 in default backend, fp16 in arm82; pointer type is // float* by convention): @@ -307,143 +370,183 @@ struct CoreFunctions { // 'kq' must be precomputed as dot(k,q) by the caller. void (*MNNFusedGatedDelta)(float* S, const float* k, const float* q, const float* v, float* out, float decay, float beta, float kq, size_t dk, size_t dv); - void(*MNNCountMaxMinValue)(const float* source, float* minVal, float* maxVal, size_t size); - void(*MNNDynamicUpdateConvBiasScale)(float* newbias, float* oldbias, float* weightKernelSum, float* inputZero, size_t ocQuad); - void(*MNNAsyQuantInfo)(float* scale, float* bias, float* qscale, float* qbias, float* dstMin, float* dstMax, const float* src, const size_t* info); - void(*MNNAsyQuantFunc)(int8_t* dst, const float* src, float* qscale, float* qbias, const size_t* info); - typedef void(*MNNPackedMatMulKernel)(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias); + void (*MNNCountMaxMinValue)(const float* source, float* minVal, float* maxVal, size_t size); + void (*MNNNormPacked)(float* dest, const float* source, const float* gamma, const float* beta, float epsilon, + size_t batch, size_t channels, bool RMSNorm); + void (*MNNDynamicUpdateConvBiasScale)(float* newbias, float* oldbias, float* weightKernelSum, float* inputZero, + size_t ocQuad); + void (*MNNAsyQuantInfo)(float* scale, float* bias, float* qscale, float* qbias, float* dstMin, float* dstMax, + const float* src, const size_t* info); + void (*MNNAsyQuantFunc)(int8_t* dst, const float* src, float* qscale, float* qbias, const size_t* info); + typedef void (*MNNPackedMatMulKernel)(float* C, const float* A, const float* B, const size_t* parameter, + const float* postParameters, const float* bias); MNNPackedMatMulKernel MNNPackedMatMulOC16Functions[InputTileMax] = {0}; MNNPackedMatMulKernel MNNPackedMatMulOC32Functions[InputTileMax] = {0}; MNNPackedMatMulKernel MNNPackedMatMulOC48Functions[InputTileMax] = {0}; // For Atomic Op - MNNBinaryExecute(*MNNSelectBinaryFunctionForFloat)(int opType); - MNNUnaryExecute(*MNNSelectUnaryFunctionForFloat)(int opType, int precisionMode); + MNNBinaryExecute (*MNNSelectBinaryFunctionForFloat)(int opType); + MNNUnaryExecute (*MNNSelectUnaryFunctionForFloat)(int opType, int precisionMode); #ifdef MNN_SUPPORT_QUANT_EXTEND - MNNUnaryExecuteInt8(*MNNSelectUnaryFunctionForInt8)(int opType) = nullptr; + MNNUnaryExecuteInt8 (*MNNSelectUnaryFunctionForInt8)(int opType) = nullptr; #endif // B matrix is sparsed - typedef void(*MNNPackedSparseMatMul)(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, unsigned int* NNZMap, int* dataOffsetMap); - void(*MNNAdjustOptimalSparseKernel)(int& sparseBlockOC, MNNPackedSparseMatMul& packedSparseMatMul); + typedef void (*MNNPackedSparseMatMul)(float* C, const float* A, const float* B, size_t eSize, + const size_t* parameter, const float* postParameters, const float* bias, + unsigned int* NNZMap, int* dataOffsetMap); + void (*MNNAdjustOptimalSparseKernel)(int& sparseBlockOC, MNNPackedSparseMatMul& packedSparseMatMul); /**Lowp Backend Setting*/ - void(*MNNFp32ToLowp)(const float* src, int16_t* dst, size_t size); - void(*MNNLowpToFp32)(const int16_t* src, float* dst, size_t size); + void (*MNNFp32ToLowp)(const float* src, int16_t* dst, size_t size); + void (*MNNLowpToFp32)(const int16_t* src, float* dst, size_t size); int bytes; // Byte for float - int matmulBytes = 0; // Special bytes for dense matmul, C = A*B, A, B is matmulBytes, C is bytes. If 0, means the same as bytes + int matmulBytes = + 0; // Special bytes for dense matmul, C = A*B, A, B is matmulBytes, C is bytes. If 0, means the same as bytes /**NC4HW4's Functions*/ int pack; // For pack * bytes > 16 - MNNCopyWithStride(*MNNSelectBlitFunction)(int blitBytes) = nullptr; + MNNCopyWithStride (*MNNSelectBlitFunction)(int blitBytes) = nullptr; - void(*MNNPackCUnitInt16)(int16_t* dst, const int16_t* src, size_t area, size_t depth, int* areaOffset); - void(*MNNUnpackCUnitInt16)(int16_t* dst, const int16_t* src, size_t area, size_t depth, int* areaOffset); - void(*MNNPackCUnitTransposeInt16)(int16_t* dst, const int16_t* src, size_t area, size_t depth, int* areaOffset); - void(*MNNUnpackCUnitTransposeInt16)(int16_t* dst, const int16_t* src, size_t area, size_t depth, int* areaOffset); + void (*MNNPackCUnitInt16)(int16_t* dst, const int16_t* src, size_t area, size_t depth, int* areaOffset); + void (*MNNUnpackCUnitInt16)(int16_t* dst, const int16_t* src, size_t area, size_t depth, int* areaOffset); + void (*MNNPackCUnitTransposeInt16)(int16_t* dst, const int16_t* src, size_t area, size_t depth, int* areaOffset); + void (*MNNUnpackCUnitTransposeInt16)(int16_t* dst, const int16_t* src, size_t area, size_t depth, int* areaOffset); - void(*MNNPackCUnitInt8)(int8_t* dst, const int8_t* src, size_t area, size_t depth, int* areaOffset); - void(*MNNUnpackCUnitInt8)(int8_t* dst, const int8_t* src, size_t area, size_t depth, int* areaOffset); - void(*MNNPackCUnitTransposeInt8)(int8_t* dst, const int8_t* src, size_t area, size_t depth, int* areaOffset); - void(*MNNUnpackCUnitTransposeInt8)(int8_t* dst, const int8_t* src, size_t area, size_t depth, int* areaOffset); + void (*MNNPackCUnitInt8)(int8_t* dst, const int8_t* src, size_t area, size_t depth, int* areaOffset); + void (*MNNUnpackCUnitInt8)(int8_t* dst, const int8_t* src, size_t area, size_t depth, int* areaOffset); + void (*MNNPackCUnitTransposeInt8)(int8_t* dst, const int8_t* src, size_t area, size_t depth, int* areaOffset); + void (*MNNUnpackCUnitTransposeInt8)(int8_t* dst, const int8_t* src, size_t area, size_t depth, int* areaOffset); - void(*MNNPackCUnit)(float* dst, const float* src, size_t area, size_t depth, int* areaOffset); - void(*MNNUnpackCUnit)(float* dst, const float* src, size_t area, size_t depth, int* areaOffset); - void(*MNNPackCUnitTranspose)(float* dst, const float* src, size_t area, size_t depth, int* areaOffset); - void(*MNNUnpackCUnitTranspose)(float* dst, const float* src, size_t area, size_t depth, int* areaOffset); + void (*MNNPackCUnit)(float* dst, const float* src, size_t area, size_t depth, int* areaOffset); + void (*MNNUnpackCUnit)(float* dst, const float* src, size_t area, size_t depth, int* areaOffset); + void (*MNNPackCUnitTranspose)(float* dst, const float* src, size_t area, size_t depth, int* areaOffset); + void (*MNNUnpackCUnitTranspose)(float* dst, const float* src, size_t area, size_t depth, int* areaOffset); // NC4HW4's compute function - void(*MNNConvRunForLineDepthwise)(float* dst, const float* src, const float* weight, size_t width, size_t src_w_setup, - size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, size_t height, - size_t srcHStep, size_t dstHStep, const float* bias, const float* parameters); - void(*MNNAxByClampBroadcastUnit)(float* C, const float* A, const float* B, size_t width, size_t cStride, size_t aStride, size_t height, const float* parameters); - void(*MNNMatrixAdd)(float* C, const float* A, const float* B, size_t widthC4, size_t cStride, size_t aStride, - size_t bStride, size_t height); - void(*MNNMatrixSub)(float* C, const float* A, const float* B, size_t widthC4, size_t cStride, size_t aStride, - size_t bStride, size_t height); - void(*MNNStrassenMergeCFunction)(float* c11, float* c12, float* c21, float* c22, float* xAddr, size_t cStride, size_t eSub, size_t hSub); - void(*MNNScaleAndAddBias)(float* dst, const float* src, const float* bias, const float* alpha, size_t planeNumber, size_t biasNumber); - void(*MNNGridSampleComputeCord)(float* dst, const float* src, size_t inH, size_t inW, size_t outH, size_t outW, bool alignCorners); - void(*MNNGridSampleInterp)(float* outputPtr, const float* inputPtr, const float* cordPtr, size_t inH, size_t inW, size_t outW, size_t channelCUnit, size_t inOffset, size_t outOffset, bool sampleMode, bool padMode); - void (*MNNGridSampleInterpGrad)(float* outputPtr, float* inputPtr, const float* cordPtr, size_t inH, size_t inW, size_t outW, size_t channelCUnit, size_t inOffset, size_t outOffset, bool sampleMode, bool padMode); - - void(*MNNGridSampleComputeCord3D)(float* dst, const float* src, size_t inD, size_t inH, size_t inW, size_t outD, size_t outH, size_t outW, bool alignCorners); - void(*MNNGridSampleInterp3D)(float* outputPtr, const float* inputPtr, const float* cordPtr, size_t inD, size_t inH, size_t inW, size_t outW, size_t channelCUnit, size_t inOffset, size_t outOffset, bool sampleMode, bool padMode) = nullptr; - void(*MNNRoiPoolingMax)(float* dst, const float* src, int hLen, int wLen, int iw); - void(*MNNRoiAlignMax)(float* dst, const float* src, const std::vector> &vecPos, const std::vector> &vecArea, int samplingRatioArea, int pooledHeight, int pooledWidth); - void(*MNNRoiAlignAvg)(float* dst, const float* src, const std::vector> &vecPos, const std::vector> &vecArea, int samplingRatioArea, int pooledHeight, int pooledWidth); + void (*MNNConvRunForLineDepthwise)(float* dst, const float* src, const float* weight, size_t width, + size_t src_w_setup, size_t fw, size_t fh, size_t dilateX_step, + size_t dilateY_step, size_t height, size_t srcHStep, size_t dstHStep, + const float* bias, const float* parameters); + void (*MNNAxByClampBroadcastUnit)(float* C, const float* A, const float* B, size_t width, size_t cStride, + size_t aStride, size_t height, const float* parameters); + void (*MNNMatrixAdd)(float* C, const float* A, const float* B, size_t widthC4, size_t cStride, size_t aStride, + size_t bStride, size_t height); + void (*MNNMatrixSub)(float* C, const float* A, const float* B, size_t widthC4, size_t cStride, size_t aStride, + size_t bStride, size_t height); + void (*MNNStrassenMergeCFunction)(float* c11, float* c12, float* c21, float* c22, float* xAddr, size_t cStride, + size_t eSub, size_t hSub); + void (*MNNScaleAndAddBias)(float* dst, const float* src, const float* bias, const float* alpha, size_t planeNumber, + size_t biasNumber); + void (*MNNGridSampleComputeCord)(float* dst, const float* src, size_t inH, size_t inW, size_t outH, size_t outW, + bool alignCorners); + void (*MNNGridSampleInterp)(float* outputPtr, const float* inputPtr, const float* cordPtr, size_t inH, size_t inW, + size_t outW, size_t channelCUnit, size_t inOffset, size_t outOffset, bool sampleMode, + bool padMode); + void (*MNNGridSampleInterpGrad)(float* outputPtr, float* inputPtr, const float* cordPtr, size_t inH, size_t inW, + size_t outW, size_t channelCUnit, size_t inOffset, size_t outOffset, + bool sampleMode, bool padMode); + + void (*MNNGridSampleComputeCord3D)(float* dst, const float* src, size_t inD, size_t inH, size_t inW, size_t outD, + size_t outH, size_t outW, bool alignCorners); + void (*MNNGridSampleInterp3D)(float* outputPtr, const float* inputPtr, const float* cordPtr, size_t inD, size_t inH, + size_t inW, size_t outW, size_t channelCUnit, size_t inOffset, size_t outOffset, + bool sampleMode, bool padMode) = nullptr; + void (*MNNRoiPoolingMax)(float* dst, const float* src, int hLen, int wLen, int iw); + void (*MNNRoiAlignMax)(float* dst, const float* src, const std::vector>& vecPos, + const std::vector>& vecArea, int samplingRatioArea, int pooledHeight, + int pooledWidth); + void (*MNNRoiAlignAvg)(float* dst, const float* src, const std::vector>& vecPos, + const std::vector>& vecArea, int samplingRatioArea, int pooledHeight, + int pooledWidth); float penalty; - void(*MNNCopyC4WithStride)(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count); - void(*MNNAddC4WithStride)(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count); + void (*MNNCopyC4WithStride)(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count); + void (*MNNAddC4WithStride)(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count); typedef void (*WinoTransPackFunc)(float* srcBlock, float* dstStart, size_t dstStep); - WinoTransPackFunc(*chooseWinoSourceTransformPack)(int k, int w, int ePack, int lPack, int packCUnit); - - typedef void (*WinoUnrollTransFunc)(const float* srcBlock, float* dstStart, size_t srcRowStep, size_t dstRowStep, size_t srcStep, size_t dstStep); - typedef void (*WinoUnrollDestTransFunc)(const float* srcBlock, float* dstStart, const float* bias, const float* postParameters, size_t srcRowStep, size_t dstRowStep, size_t srcStep, size_t dstStep); - WinoUnrollTransFunc(*chooseWinoSourceUnrollTransform)(int k, int w); - void(*chooseWinoDestUnrollTransform)(WinoUnrollDestTransFunc *destFunctions, size_t maxUnit, int k, int h); - - void(*MNNDeconvRunForUnitDepthWise)(const float* dst, float* src, const float* weight, size_t fw, size_t fh, - size_t weight_y_step, size_t dilateX_step, size_t dilateY_step); - void(*MNNDeconvRunForLineDepthwise)(const float* dst, float* src, const float* weight, size_t width, size_t src_w_setup, - size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step); - void(*MNNDepthwiseConvFastKernel)(float* dst, const float* src, const float* weight, size_t width, size_t src_w_setup, - size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, size_t height, - size_t srcHStep, size_t dstHStep, const float* bias, const float* parameters) = nullptr; - void(*MNNReluWithSlopeChannel)(float* dst, const float* src, const float* slope, size_t sizeQuad, size_t depthQuad); - void(*MNNPoolingAvg)(const void* channelInput, int inputWidth, int inputHeight, void *channelOutput, - int outputWidth, int outputHeight, int kernelWidth, int kernelHeight, int strideWidth, - int strideHeight, int padWidth, int padHeight, int padType, int countType); - void(*MNNPoolingMax)(const void* channelInput, int inputWidth, int inputHeight, void *channelOutput, - int outputWidth, int outputHeight, int kernelWidth, int kernelHeight, int strideWidth, - int strideHeight, int padWidth, int padHeight, int padType, int countType); - - void(*MNNPoolingMaxWithRedice)(const void* channelInput, int inputWidth, int inputHeight, void *channelOutput, - int outputWidth, int outputHeight, int kernelWidth, int kernelHeight, int strideWidth, - int strideHeight, int padWidth, int padHeight, int padType, int countType, int *RediceOutput); + WinoTransPackFunc (*chooseWinoSourceTransformPack)(int k, int w, int ePack, int lPack, int packCUnit); + + typedef void (*WinoUnrollTransFunc)(const float* srcBlock, float* dstStart, size_t srcRowStep, size_t dstRowStep, + size_t srcStep, size_t dstStep); + typedef void (*WinoUnrollDestTransFunc)(const float* srcBlock, float* dstStart, const float* bias, + const float* postParameters, size_t srcRowStep, size_t dstRowStep, + size_t srcStep, size_t dstStep); + WinoUnrollTransFunc (*chooseWinoSourceUnrollTransform)(int k, int w); + void (*chooseWinoDestUnrollTransform)(WinoUnrollDestTransFunc* destFunctions, size_t maxUnit, int k, int h); + + void (*MNNDeconvRunForUnitDepthWise)(const float* dst, float* src, const float* weight, size_t fw, size_t fh, + size_t weight_y_step, size_t dilateX_step, size_t dilateY_step); + void (*MNNDeconvRunForLineDepthwise)(const float* dst, float* src, const float* weight, size_t width, + size_t src_w_setup, size_t fw, size_t fh, size_t dilateX_step, + size_t dilateY_step); + void (*MNNDepthwiseConvFastKernel)(float* dst, const float* src, const float* weight, size_t width, + size_t src_w_setup, size_t fw, size_t fh, size_t dilateX_step, + size_t dilateY_step, size_t height, size_t srcHStep, size_t dstHStep, + const float* bias, const float* parameters) = nullptr; + void (*MNNReluWithSlopeChannel)(float* dst, const float* src, const float* slope, size_t sizeQuad, + size_t depthQuad); + void (*MNNPoolingAvg)(const void* channelInput, int inputWidth, int inputHeight, void* channelOutput, + int outputWidth, int outputHeight, int kernelWidth, int kernelHeight, int strideWidth, + int strideHeight, int padWidth, int padHeight, int padType, int countType); + void (*MNNPoolingMax)(const void* channelInput, int inputWidth, int inputHeight, void* channelOutput, + int outputWidth, int outputHeight, int kernelWidth, int kernelHeight, int strideWidth, + int strideHeight, int padWidth, int padHeight, int padType, int countType); + + void (*MNNPoolingMaxWithRedice)(const void* channelInput, int inputWidth, int inputHeight, void* channelOutput, + int outputWidth, int outputHeight, int kernelWidth, int kernelHeight, + int strideWidth, int strideHeight, int padWidth, int padHeight, int padType, + int countType, int* RediceOutput); // ImageProcess Funtions - void(*MNNRGBAToBGRA)(const unsigned char* source, unsigned char* dest, size_t count); - void(*MNNNV21ToRGBA)(const unsigned char* source, unsigned char* dest, size_t count); - void(*MNNNV21ToRGB)(const unsigned char* source, unsigned char* dest, size_t count); - void(*MNNNV21ToBGRA)(const unsigned char* source, unsigned char* dest, size_t count); - void(*MNNNV21ToBGR)(const unsigned char* source, unsigned char* dest, size_t count); - void(*MNNC1ToFloatC1)(const unsigned char* source, float* dest, const float* mean, const float* normal, size_t count); - void(*MNNC3ToFloatC3)(const unsigned char* source, float* dest, const float* mean, const float* normal, size_t count); - void(*MNNC3ToFloatRGBA)(const unsigned char* source, float* dest, const float* mean, const float* normal, size_t count); - void(*MNNsampleBilinearCommon)(const unsigned char* source, unsigned char* dest, MNN::CV::Point* points, size_t count, - size_t iw, size_t ih, size_t yStride, size_t bpp); - void(*MNNSamplerC4Nearest)(const unsigned char* source, unsigned char* dest, MNN::CV::Point* points, size_t sta, - size_t count, size_t capacity, size_t iw, size_t ih, size_t yStride); - void(*MNNSamplerC4Bilinear)(const unsigned char* source, unsigned char* dest, MNN::CV::Point* points, size_t sta, - size_t count, size_t capacity, size_t iw, size_t ih, size_t yStride); - void(*MNNSampleC4Bilinear)(const unsigned char* source, unsigned char* dest, MNN::CV::Point* points, size_t sta, - size_t count, size_t capacity, size_t iw, size_t ih, size_t yStride); - void(*MNNSampleBilinear)(const unsigned char* source, unsigned char* dest, MNN::CV::Point* points, size_t count, - size_t iw, size_t ih, size_t yStride, size_t bpp); - - void(*MNN4BitcopyWithStride)(uint8_t* dstO, const uint8_t* srcO, int size, int stride, int ds); - void(*MNN2BitcopyWithStride)(uint8_t* dstO, const uint8_t* srcO, int size, int stride, int ds); - void(*MNN1BitcopyWithStride)(uint8_t* dstO, const uint8_t* srcO, int size, int stride, int ds); - void(*MNN4BitcopyFast)(uint8_t* dstO, const uint8_t* srcO, int size, int stride, int ds); - void(*MNN2BitcopyFast)(uint8_t* dstO, const uint8_t* srcO, int size, int stride, int ds); - void(*MNN1BitcopyFast)(uint8_t* dstO, const uint8_t* srcO, int size, int stride, int ds); - void(*MNNAccumulateSequenceNumber)(float* dst, const float* src, int size); + void (*MNNRGBAToBGRA)(const unsigned char* source, unsigned char* dest, size_t count); + void (*MNNNV21ToRGBA)(const unsigned char* source, unsigned char* dest, size_t count); + void (*MNNNV21ToRGB)(const unsigned char* source, unsigned char* dest, size_t count); + void (*MNNNV21ToBGRA)(const unsigned char* source, unsigned char* dest, size_t count); + void (*MNNNV21ToBGR)(const unsigned char* source, unsigned char* dest, size_t count); + void (*MNNC1ToFloatC1)(const unsigned char* source, float* dest, const float* mean, const float* normal, + size_t count); + void (*MNNC3ToFloatC3)(const unsigned char* source, float* dest, const float* mean, const float* normal, + size_t count); + void (*MNNC3ToFloatRGBA)(const unsigned char* source, float* dest, const float* mean, const float* normal, + size_t count); + void (*MNNsampleBilinearCommon)(const unsigned char* source, unsigned char* dest, MNN::CV::Point* points, + size_t count, size_t iw, size_t ih, size_t yStride, size_t bpp); + void (*MNNSamplerC4Nearest)(const unsigned char* source, unsigned char* dest, MNN::CV::Point* points, size_t sta, + size_t count, size_t capacity, size_t iw, size_t ih, size_t yStride); + void (*MNNSamplerC4Bilinear)(const unsigned char* source, unsigned char* dest, MNN::CV::Point* points, size_t sta, + size_t count, size_t capacity, size_t iw, size_t ih, size_t yStride); + void (*MNNSampleC4Bilinear)(const unsigned char* source, unsigned char* dest, MNN::CV::Point* points, size_t sta, + size_t count, size_t capacity, size_t iw, size_t ih, size_t yStride); + void (*MNNSampleBilinear)(const unsigned char* source, unsigned char* dest, MNN::CV::Point* points, size_t count, + size_t iw, size_t ih, size_t yStride, size_t bpp); + + void (*MNN4BitcopyWithStride)(uint8_t* dstO, const uint8_t* srcO, int size, int stride, int ds); + void (*MNN2BitcopyWithStride)(uint8_t* dstO, const uint8_t* srcO, int size, int stride, int ds); + void (*MNN1BitcopyWithStride)(uint8_t* dstO, const uint8_t* srcO, int size, int stride, int ds); + void (*MNN4BitcopyFast)(uint8_t* dstO, const uint8_t* srcO, int size, int stride, int ds); + void (*MNN2BitcopyFast)(uint8_t* dstO, const uint8_t* srcO, int size, int stride, int ds); + void (*MNN1BitcopyFast)(uint8_t* dstO, const uint8_t* srcO, int size, int stride, int ds); + void (*MNNAccumulateSequenceNumber)(float* dst, const float* src, int size); // Attention - void(*MNNAttenPackAndScaleSingleHead)(float* dst, const float* srcHeadBase, size_t srcRowStride, const float* scale, const int32_t* units, size_t seqLen, size_t headDim); - void(*MNNFlashAttentionUpdateBlockOutput)(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes, int seqStart); - void(*MNNSoftmax)(float* softmaxDst, const float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize, int kvSeqOffset, int validOffset, int pack, bool mask); - void(*MNNQuantAttentionKey)(int8_t* dst, const float* source, float* sumKey, float* maxKey, int32_t* params); - void(*MNNQuantAttentionValue)(int8_t* dst, const float* source, float* valueQuantInfo, int32_t* params); + void (*MNNAttenPackAndScaleSingleHead)(float* dst, const float* srcHeadBase, size_t srcRowStride, + const float* scale, const int32_t* units, size_t seqLen, size_t headDim); + void (*MNNFlashAttentionUpdateBlockOutput)(float* dst, float* src, float* scale, float* normalizeScale, + int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, + int bytes, int seqStart); + void (*MNNSoftmax)(float* softmaxDst, const float* input, float* runningMax, float* runningSum, float* updateScale, + int outside, int reduceSize, int kvSeqOffset, int validOffset, int pack, bool mask); + void (*MNNQuantAttentionKey)(int8_t* dst, const float* source, float* sumKey, float* maxKey, int32_t* params); + void (*MNNQuantAttentionValue)(int8_t* dst, const float* source, float* valueQuantInfo, int32_t* params); + void (*MNNRoPECompute)(void* dst, const void* src, const void* cosEven, const void* cosOdd, const void* sinEven, + const void* sinOdd, int numHead, int headDim, int ropeCutHeadDim); MatmulRelatedFunctions int8MatmulRelatedFunctions; MatmulRelatedFunctions arm82MatmulRelatedFunctions; }; void MNNCoreFunctionInit(); CoreFunctions* MNNGetCoreFunctions(); -}; +}; // namespace MNN #endif /* CommonOptFunction_h */ diff --git a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp index f7e59d86dc..51aab40398 100644 --- a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp +++ b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp @@ -9,20 +9,22 @@ #include "ConvolutionTiledExecutor.hpp" #include "core/Macro.h" #include "core/BufferAllocator.hpp" +#include "SharedGather.hpp" #include #include "backend/cpu/CPUBackend.hpp" #include "core/Concurrency.h" #include "core/TensorUtils.hpp" - #define QUANT_INFO_BYTES 4 #define WEIGHT_ONLINE_REORDER 8 namespace MNN { -ConvInt8TiledExecutor::ConvInt8TiledExecutor(Backend* backend, const Op* op): CPUConvolution(op->main_as_Convolution2D()->common(), backend) {} +ConvInt8TiledExecutor::ConvInt8TiledExecutor(Backend* backend, const Op* op) + : CPUConvolution(op->main_as_Convolution2D()->common(), backend) {} -ConvInt8TiledExecutor::ConvInt8TiledExecutor(Backend* backend, const Op* op, std::shared_ptr res): CPUConvolution(op->main_as_Convolution2D()->common(), backend), mResourceInt8(res) { +ConvInt8TiledExecutor::ConvInt8TiledExecutor(Backend* backend, const Op* op, std::shared_ptr res) + : CPUConvolution(op->main_as_Convolution2D()->common(), backend), mResourceInt8(res) { if (!res->mDynamicQuant) { mMutableResource.reset(new MutableResourceInt8(res, backend)); mValid = mMutableResource->mValid; @@ -39,14 +41,19 @@ bool ConvInt8TiledExecutor::onClone(Backend* bn, const Op* op, Execution** dst) ErrorCode ConvInt8TiledExecutor::onResize(const std::vector& inputs, const std::vector& outputs) { if (nullptr != mMutableResource) { - mMutableResource->updateInputOutputScale(TensorUtils::getQuantInfo(inputs[0]), TensorUtils::getQuantInfo(outputs[0])); + mMutableResource->updateInputOutputScale(TensorUtils::getQuantInfo(inputs[0]), + TensorUtils::getQuantInfo(outputs[0])); } CPUConvolution::onResize(inputs, outputs); - ConvolutionTiledExecutor::setIm2ColParameter(mIm2ColParamter, mCommon, inputs[0], outputs[0], mPadX, mPadY, static_cast(backend())->functions(), static_cast(backend())->int8Functions()); + ConvolutionTiledExecutor::setIm2ColParameter(mIm2ColParamter, mCommon, inputs[0], outputs[0], mPadX, mPadY, + static_cast(backend())->functions(), + static_cast(backend())->int8Functions()); return NO_ERROR; } -void ConvInt8TiledExecutor::initializeConvInt8QuantInfo(std::shared_ptr &resourceInt8, const Convolution2D *conv2D, std::shared_ptr quanCommon) { +void ConvInt8TiledExecutor::initializeConvInt8QuantInfo(std::shared_ptr& resourceInt8, + const Convolution2D* conv2D, + std::shared_ptr quanCommon) { // input/output scale&zeorpoint if (conv2D->symmetricQuan()) { resourceInt8->mWeightBits = conv2D->symmetricQuan()->nbits(); @@ -75,7 +82,8 @@ void ConvInt8TiledExecutor::initializeConvInt8QuantInfo(std::shared_ptrstride(0) - int stride1 = blockL * SRC_UNIT * UNIT; // weight->stride(1) - int stride2 = UNIT * SRC_UNIT; // weight->stride(2) + int blockL = UP_DIV(ic / blockNum, SRC_UNIT) * kernelCount; + int stride0 = blockNum * SRC_UNIT * blockL * UNIT; // weight->stride(0) + int stride1 = blockL * SRC_UNIT * UNIT; // weight->stride(1) + int stride2 = UNIT * SRC_UNIT; // weight->stride(2) int weightlen = stride0 * UP_DIV(oc, UNIT); memset(dst, initval, weightlen); @@ -124,14 +132,16 @@ void ConvInt8TiledExecutor::reorderWeight(uint8_t* dst, const uint8_t* src, int3 int src_blu = src_k2 + blu * SRC_UNIT * kernelCount; for (int inId = 0; inId < UNIT; ++inId) { int i = i_hU * UNIT + inId; - if (i >= oc) continue; + if (i >= oc) + continue; int dst_inId = dst_blu + inId * SRC_UNIT; int src_inId = src_blu + i * ic * kernelCount; - + for (int blp = 0; blp < SRC_UNIT; ++blp) { int j_in_block = blu * SRC_UNIT + blp; - if (j_in_block >= blockic) continue; - + if (j_in_block >= blockic) + continue; + int dstindex = dst_inId + blp; int srcindex = src_inId + blp * kernelCount; dst[dstindex] = src[srcindex]; @@ -143,21 +153,21 @@ void ConvInt8TiledExecutor::reorderWeight(uint8_t* dst, const uint8_t* src, int3 } } // not fast - if (summerFunc != nullptr && kernelsum != nullptr) { summerFunc(kernelsum, (int8_t*)dst, blockNum * hU, blockL, UNIT, SRC_UNIT); } } -void ConvInt8TiledExecutor::packWeightAndQuantInfo(int8_t* dstbuffer, const int8_t* weight, const int8_t* quantInfo, int32_t* info, int infoBytes) { - int blockNum = info[0]; - int ocDiv = info[1]; - int blockL = info[2]; - int UNIT = info[3]; - int SRC_UNIT = info[4]; - auto ocUp4 = info[5]; - auto src0 = weight; // int8 weight: [oc/hp, blocknum, ic/lp*(kx*ky)/blocknum, hp, lp] - auto src1 = quantInfo; // dequant scale: [blocknum, ocUp4] +void ConvInt8TiledExecutor::packWeightAndQuantInfo(int8_t* dstbuffer, const int8_t* weight, const int8_t* quantInfo, + int32_t* info, int infoBytes) { + int blockNum = info[0]; + int ocDiv = info[1]; + int blockL = info[2]; + int UNIT = info[3]; + int SRC_UNIT = info[4]; + auto ocUp4 = info[5]; + auto src0 = weight; // int8 weight: [oc/hp, blocknum, ic/lp*(kx*ky)/blocknum, hp, lp] + auto src1 = quantInfo; // dequant scale: [blocknum, ocUp4] auto src2 = src1 + infoBytes * ocUp4 * blockNum; // dequant bias: [blocknum, ocUp4] int stride0 = info[0] * info[2] * info[3] * info[4]; int stride1 = info[2] * info[3] * info[4]; @@ -172,7 +182,8 @@ void ConvInt8TiledExecutor::packWeightAndQuantInfo(int8_t* dstbuffer, const int8 auto blockPtr = huPtr + bl * (stride1 + 2 * UNIT * infoBytes); memcpy(blockPtr, src0 + bl * stride1 + hU * stride0, stride1); memcpy(blockPtr + stride1, src1 + (bl * ocUp4 + hU * UNIT) * infoBytes, scaleCount * infoBytes); - memcpy(blockPtr + stride1 + UNIT * infoBytes, src2 + (bl * ocUp4 + hU * UNIT) * infoBytes, scaleCount * infoBytes); + memcpy(blockPtr + stride1 + UNIT * infoBytes, src2 + (bl * ocUp4 + hU * UNIT) * infoBytes, + scaleCount * infoBytes); } } } @@ -223,9 +234,11 @@ static void _computeReorderQuantInfo(float* weightKernelSum, int32_t* paramsKern alphaPtr[j * ocUp4 + i] = quanInfoPtr[2 * index + 1]; biasPtr[j * ocUp4 + i] = quanInfoPtr[2 * index] + (float)originOffset * quanInfoPtr[2 * index + 1]; if (realInt4OrInt8) { - accum += (ikernelSum[srcSumIndex] * quanInfoPtr[2 * index + 1] + blockSize * biasPtr[j * ocUp4 + i]); + accum += + (ikernelSum[srcSumIndex] * quanInfoPtr[2 * index + 1] + blockSize * biasPtr[j * ocUp4 + i]); } else { - accum += ((ikernelSum[srcSumIndex] - blockSize * 8)* quanInfoPtr[2 * index + 1] + blockSize * quanInfoPtr[2 * index]); + accum += ((ikernelSum[srcSumIndex] - blockSize * 8) * quanInfoPtr[2 * index + 1] + + blockSize * quanInfoPtr[2 * index]); } if (blockQuantInput) { int dstSumIndex = ocOutside * blockNum * HP + j * HP + ocInside; @@ -250,7 +263,7 @@ static void _computeReorderQuantInfo(float* weightKernelSum, int32_t* paramsKern if (realInt4OrInt8) { accum += (ikernelSum[srcSumIndex] * quanInfoPtr[index] + blockSize * biasPtr[j * ocUp4 + i]); } else { - accum += ((ikernelSum[srcSumIndex] - blockSize * 8) * quanInfoPtr[index]); + accum += ((ikernelSum[srcSumIndex] - blockSize * 8) * quanInfoPtr[index]); } if (blockQuantInput) { int dstSumIndex = ocOutside * blockNum * HP + j * HP + ocInside; @@ -265,7 +278,8 @@ static void _computeReorderQuantInfo(float* weightKernelSum, int32_t* paramsKern } } -static inline void calculateSmeNeonWorkDivision(int& ocMain, int& ocBranch, std::vector& divides, int oc, int threads, int pack, int planeSize, int divisionRatio, int smeCores) { +static inline void calculateSmeNeonWorkDivision(int& ocMain, int& ocBranch, std::vector& divides, int oc, + int threads, int pack, int planeSize, int divisionRatio, int smeCores) { // workload auto ocDivPack = UP_DIV(oc, pack); auto workUnit = UP_DIV(ocDivPack, divisionRatio * smeCores + 1 * (threads - smeCores)); @@ -278,17 +292,19 @@ static inline void calculateSmeNeonWorkDivision(int& ocMain, int& ocBranch, std: divides.assign(threads + 1, ocDivPack); divides[0] = 0; - // runtime UNIT for different core and different process(prefill or decode) - auto rtUnit4Sme = planeSize == 1? GEMM_INT8_UNIT_SME2_128 : GEMM_INT8_UNIT_SME2; + // runtime UNIT for different core and different process(prefill or decode) + auto rtUnit4Sme = planeSize == 1 ? GEMM_INT8_UNIT_SME2_128 : GEMM_INT8_UNIT_SME2; // mOcMain - auto ocPerSmeCore = ALIMIN(UP_DIV(UP_DIV(ROUND_UP(ocMain, pack), rtUnit4Sme), smeCores) * (rtUnit4Sme / pack), UP_DIV(ocMain, pack)); + auto ocPerSmeCore = ALIMIN(UP_DIV(UP_DIV(ROUND_UP(ocMain, pack), rtUnit4Sme), smeCores) * (rtUnit4Sme / pack), + UP_DIV(ocMain, pack)); for (int i = 0; i < smeCores; ++i) { divides[i + 1] = ALIMIN(divides[i] + ocPerSmeCore, UP_DIV(ocMain, pack)); } // ocRemain if (ocBranch > 0) { - auto ocPerNeonCore = UP_DIV(UP_DIV(ROUND_UP(ocBranch, pack), GEMM_INT8_UNIT_ARM82), threads - smeCores) * (GEMM_INT8_UNIT_ARM82 / pack); + auto ocPerNeonCore = UP_DIV(UP_DIV(ROUND_UP(ocBranch, pack), GEMM_INT8_UNIT_ARM82), threads - smeCores) * + (GEMM_INT8_UNIT_ARM82 / pack); for (int i = smeCores + 1; i < threads + 1; ++i) { divides[i] = ALIMIN(divides[i - 1] + ocPerNeonCore, ocDivPack); } @@ -322,7 +338,8 @@ static inline void _computeDivides4Sme(std::vector& divides, int threads, i } } -static inline void _updateMixedKernelFlag(bool &mixedKernel, bool &onlineReorderWeightSme, int threads, int eP, bool isDynamciQuant, bool postiveBothProp) { +static inline void _updateMixedKernelFlag(bool& mixedKernel, bool& onlineReorderWeightSme, int threads, int eP, + bool isDynamciQuant, bool postiveBothProp) { mixedKernel = false; if (threads >= 4 && eP == GEMM_INT8_DST_XUNIT_SME2 && isDynamciQuant && postiveBothProp) { mixedKernel = true; @@ -330,7 +347,10 @@ static inline void _updateMixedKernelFlag(bool &mixedKernel, bool &onlineReorder } } -DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const Op* op, std::shared_ptr quanCommon, bool isDynamicQuant) : ConvInt8TiledExecutor(backend, op) { +DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const Op* op, + std::shared_ptr quanCommon, + bool isDynamicQuant) + : ConvInt8TiledExecutor(backend, op) { // convolution info auto convOp = op->main_as_Convolution2D(); int kernelCount = mCommon->kernelX() * mCommon->kernelY(); @@ -373,25 +393,29 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O mArm82Functions = gcore->arm82MatmulRelatedFunctions; int UNITMain, SRC_UNITMain, DST_XUNITMain; - int UNITBranch = 0; int SRC_UNITBranch = 0, DST_XUNITBranch = 0; + int UNITBranch = 0; + int SRC_UNITBranch = 0, DST_XUNITBranch = 0; mRelatedFunctions.MNNGetGemmUnit(&UNITMain, &SRC_UNITMain, &DST_XUNITMain); if (mArm82Functions.MNNGetGemmUnit != nullptr) { // exclude cpu does not support arm82 mArm82Functions.MNNGetGemmUnit(&UNITBranch, &SRC_UNITBranch, &DST_XUNITBranch); } - // prefer to maximum decode performance & the machine supports 'sme2' & the runtime backend is 'sme2' -> mOnlineReorderWeightSme=true + // prefer to maximum decode performance & the machine supports 'sme2' & the runtime backend is 'sme2' -> + // mOnlineReorderWeightSme=true mOnlineReorderWeightSme = (weightOnlineReorderOption > 0 && DST_XUNITMain == GEMM_INT8_DST_XUNIT_SME2); if (isDynamicQuant == false) { mOnlineReorderWeightSme = false; } - _updateMixedKernelFlag(mMixedKernel, mOnlineReorderWeightSme, threads, DST_XUNITMain, isDynamicQuant, mRatioDecode&&mRatioPrefill); + _updateMixedKernelFlag(mMixedKernel, mOnlineReorderWeightSme, threads, DST_XUNITMain, isDynamicQuant, + mRatioDecode && mRatioPrefill); if (mMixedKernel) { // total work: UP_DIV(oc, pack) // (sme's work / neon's work) = divisionRatio auto workUnit = UP_DIV(UP_DIV(oc, pack), mRatioDecode * mSmeCores + 1 * (threads - mSmeCores)); - mOcMain = ALIMIN(ROUND_UP(workUnit * pack * mSmeCores * mRatioDecode, GEMM_INT8_UNIT_SME2_128), oc);; + mOcMain = ALIMIN(ROUND_UP(workUnit * pack * mSmeCores * mRatioDecode, GEMM_INT8_UNIT_SME2_128), oc); + ; mOcBranch = oc - mOcMain; } if (mOnlineReorderWeightSme) { @@ -416,6 +440,11 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O mResourceInt8->mWeightAsymmetricQuant = asyWeight; mResourceInt8->mWeightBits = 8; mResourceInt8->mBlockNum = blockNum; + mResourceInt8->mHp = UNITMain; + mResourceInt8->mLp = SRC_UNITMain; + if (DST_XUNITMain == GEMM_INT8_DST_XUNIT_SME2) { + mResourceInt8->mPackMode = 1; + } if (quanCommon && quanCommon->canUseInt4) { shapeMain[4] = SRC_UNITMain / 2; shapeBranch[4] = SRC_UNITBranch / 2; @@ -455,7 +484,8 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O mResourceInt8->mReluThreshold[0] = postPtr[2]; mResourceInt8->mReluThreshold[1] = postPtr[3]; if (gcore->bytes == 2) { - gcore->MNNFp32ToLowp(mResourceInt8->mReluThreshold.data(), reinterpret_cast(mResourceInt8->mReluThreshold.data()), 2); + gcore->MNNFp32ToLowp(mResourceInt8->mReluThreshold.data(), + reinterpret_cast(mResourceInt8->mReluThreshold.data()), 2); } // buffer allocate auto quantlenMain = 2 * blockNum * ROUND_UP(mOcMain, UNITMain) * QUANT_INFO_BYTES; @@ -463,9 +493,11 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O auto quantlenBranch = 2 * blockNum * ocUpHpBranch * QUANT_INFO_BYTES; auto weightlenBranch = shapeBranch[0] * shapeBranch[1] * shapeBranch[2] * shapeBranch[3] * shapeBranch[4]; - mResourceInt8->mWeightInt8.reset(Tensor::createDevice({weightlenMain + quantlenMain + weightlenBranch + quantlenBranch})); + mResourceInt8->mWeightInt8.reset( + Tensor::createDevice({weightlenMain + quantlenMain + weightlenBranch + quantlenBranch})); mResourceInt8->mOriginBias.reset(Tensor::createDevice({ocUp4Main + ocUpHpBranch})); // float - mResourceInt8->mWeightKernelSum.reset(Tensor::createDevice({inputBlockNum * QUANT_INFO_BYTES * (ocUpHpMain + ocUpHpBranch)})); + mResourceInt8->mWeightKernelSum.reset( + Tensor::createDevice({inputBlockNum * QUANT_INFO_BYTES * (ocUpHpMain + ocUpHpBranch)})); auto res = backend->onAcquireBuffer(mResourceInt8->mOriginBias.get(), Backend::STATIC); res &= backend->onAcquireBuffer(mResourceInt8->mWeightKernelSum.get(), Backend::STATIC); @@ -484,7 +516,8 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O ::memset(mResourceInt8->mOriginBias->host(), 0, mResourceInt8->mOriginBias->size()); // dynamic quant - bool directReadInt4weight = (kernelCount == 1 && ROUND_UP(mOcMain, UNITMain) == mOcMain && ROUND_UP(ic, SRC_UNITMain) == ic); // TODO:fix this + bool directReadInt4weight = (kernelCount == 1 && ROUND_UP(mOcMain, UNITMain) == mOcMain && + ROUND_UP(ic, SRC_UNITMain) == ic); // TODO:fix this auto ocMain = mOcMain; auto ocBranch = mOcBranch; auto target = mResourceInt8; @@ -493,11 +526,14 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O auto needToReorderWeightOnline4Sme = mOnlineReorderWeightSme; // Save bias if (convOp->bias()) { - ::memcpy(mResourceInt8->mOriginBias->host(), convOp->bias()->data(), convOp->bias()->size() * sizeof(float)); + ::memcpy(mResourceInt8->mOriginBias->host(), convOp->bias()->data(), + convOp->bias()->size() * sizeof(float)); } auto coreFuncs = static_cast(backend)->functions(); - auto reorderFunc = [=](decltype(mRelatedFunctions) funcs, std::vector shape, int UNIT, int SRC_UNIT, int DST_XUNIT, int weightlen, int scaleSize, int oc, int offsetTg, bool fastReadWeight, int8_t** addressPtr, weightSummerFuncion sumFunc) -> int { + auto reorderFunc = [=](decltype(mRelatedFunctions) funcs, std::vector shape, int UNIT, int SRC_UNIT, + int DST_XUNIT, int weightlen, int scaleSize, int oc, int offsetTg, bool fastReadWeight, + int8_t** addressPtr, weightSummerFuncion sumFunc) -> int { auto sh = shape; AutoStorage weightReordered(weightlen); AutoStorage reorderedQuantInfo(2 * scaleSize * QUANT_INFO_BYTES); @@ -552,7 +588,8 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O return -1; } - reorderWeight(packedInt8weight.get(), (uint8_t*)tmpWeight.data(), info, 0, (float*)kernelsum.get(), sumFunc); + reorderWeight(packedInt8weight.get(), (uint8_t*)tmpWeight.data(), info, 0, (float*)kernelsum.get(), + sumFunc); // pack two int4 to int8 int leng = weightlen * 2; @@ -763,36 +800,45 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O int32_t params[6] = {shape[0], shape[1], shape[2], shape[3], shape[4], ROUND_UP(oc, pack)}; int8_t* weightInt8 = addressPtr[1]; - ConvInt8TiledExecutor::packWeightAndQuantInfo(weightInt8, (int8_t*)weightReordered.get(), reorderedQuantInfo.get(), params, QUANT_INFO_BYTES); + ConvInt8TiledExecutor::packWeightAndQuantInfo(weightInt8, (int8_t*)weightReordered.get(), + reorderedQuantInfo.get(), params, QUANT_INFO_BYTES); return 0; }; auto function = [=]() -> int { - bool fastReadWeight = (kernelCount == 1 && ROUND_UP(ocMain, UNITMain) == ocMain && ROUND_UP(ic, SRC_UNITMain) == ic); + bool fastReadWeight = + (kernelCount == 1 && ROUND_UP(ocMain, UNITMain) == ocMain && ROUND_UP(ic, SRC_UNITMain) == ic); weightSummerFuncion sumFunc = funcsMain.MNNSumWeightInt8; if (mOnlineReorderWeightSme) { sumFunc = funcsMain.MNNSumWeightInt8SmeHp128; } int8_t* addressPtr[4]; - addressPtr[0] = quanCommon? quanCommon->weight.get() : (int8_t*)convOp->symmetricQuan()->weight()->data(); + addressPtr[0] = quanCommon ? quanCommon->weight.get() : (int8_t*)convOp->symmetricQuan()->weight()->data(); addressPtr[1] = target->mWeightInt8->host(); addressPtr[2] = target->mWeightKernelSum->host(); - addressPtr[3] = quanCommon? (int8_t*) quanCommon->alpha.get() : (int8_t*)convOp->symmetricQuan()->scale()->data(); + addressPtr[3] = + quanCommon ? (int8_t*)quanCommon->alpha.get() : (int8_t*)convOp->symmetricQuan()->scale()->data(); - reorderFunc(funcsMain, shapeMain, UNITMain, SRC_UNITMain, DST_XUNITMain, weightlenMain, scaleSizeMain, ocMain, 0, fastReadWeight, addressPtr, sumFunc); + reorderFunc(funcsMain, shapeMain, UNITMain, SRC_UNITMain, DST_XUNITMain, weightlenMain, scaleSizeMain, ocMain, + 0, fastReadWeight, addressPtr, sumFunc); if (ocBranch > 0) { // update the address of weight source, weight destination, weight kernel sum and weight scale - addressPtr[0] += (target->mWeightBits == 4 ? ocMain * ic * kernelCount / 2 : ocMain * ic * kernelCount); // ocMain%2==0, so divides 2 directly + addressPtr[0] += + (target->mWeightBits == 4 ? ocMain * ic * kernelCount / 2 + : ocMain * ic * kernelCount); // ocMain%2==0, so divides 2 directly addressPtr[1] += (weightlenMain + quantlenMain); addressPtr[2] += ROUND_UP(ocMain, UNITMain) * inputBlockNum * QUANT_INFO_BYTES; - addressPtr[3] += (quanCommon->asymmetric ? 2 * ocMain * blockNum * QUANT_INFO_BYTES : ocMain * blockNum * QUANT_INFO_BYTES); + addressPtr[3] += (quanCommon->asymmetric ? 2 * ocMain * blockNum * QUANT_INFO_BYTES + : ocMain * blockNum * QUANT_INFO_BYTES); sumFunc = funcsBranch.MNNSumWeightInt8; - fastReadWeight = (kernelCount == 1 && ROUND_UP(ocBranch, UNITMain) == ocBranch && ROUND_UP(ic, SRC_UNITMain) == ic); - reorderFunc(funcsBranch, shapeBranch, UNITBranch, SRC_UNITBranch, DST_XUNITBranch, weightlenBranch, scaleSizeBranch, ocBranch, 1, fastReadWeight, addressPtr, sumFunc); + fastReadWeight = + (kernelCount == 1 && ROUND_UP(ocBranch, UNITMain) == ocBranch && ROUND_UP(ic, SRC_UNITMain) == ic); + reorderFunc(funcsBranch, shapeBranch, UNITBranch, SRC_UNITBranch, DST_XUNITBranch, weightlenBranch, + scaleSizeBranch, ocBranch, 1, fastReadWeight, addressPtr, sumFunc); } return 0; }; @@ -802,9 +848,8 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O if (!isDynamicQuant) { mResourceInt8->mDynamicQuant = false; - std::shared_ptr scaleAndBias(new float[ocUpHpMain * 2 * mBlockNum], [](void* ptr) { - delete [] (float*)ptr; - }); + std::shared_ptr scaleAndBias(new float[ocUpHpMain * 2 * mBlockNum], + [](void* ptr) { delete[] (float*)ptr; }); memset(scaleAndBias.get(), 0, ocUpHpMain * 2 * mBlockNum * sizeof(float)); int weightSize; @@ -821,9 +866,9 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O if ((convOp->quanParameter() && convOp->quanParameter()->alpha()) || (quanCommon && quanCommon->alpha.get())) { int quantCount; if (convOp->quanParameter() && convOp->quanParameter()->alpha()) { - quantCount = convOp->quanParameter()->alpha()->size(); + quantCount = convOp->quanParameter()->alpha()->size(); } else { - quantCount = quanCommon->alpha.size(); + quantCount = quanCommon->alpha.size(); } if (false == weightAsy) { // symmetric quant @@ -853,7 +898,7 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O } } #else - if(convOp->symmetricQuan() && convOp->symmetricQuan()->method() == QuantizeAlgo_OVERFLOW_AWARE){ + if (convOp->symmetricQuan() && convOp->symmetricQuan()->method() == QuantizeAlgo_OVERFLOW_AWARE) { mGemmKernel = mRelatedFunctions.Int8GemmKernelFast; } if (mResourceInt8->mWeightBits == 4) { @@ -863,9 +908,9 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O } } -DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const Op* op, const DenseConvInt8TiledExecutor& exe) - : ConvInt8TiledExecutor(backend, op, exe.mResourceInt8), mGemmKernel(exe.mGemmKernel) { -} +DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const Op* op, + const DenseConvInt8TiledExecutor& exe) + : ConvInt8TiledExecutor(backend, op, exe.mResourceInt8), mGemmKernel(exe.mGemmKernel) {} DenseConvInt8TiledExecutor::~DenseConvInt8TiledExecutor() { // Do nothing @@ -875,6 +920,10 @@ bool DenseConvInt8TiledExecutor::onClone(Backend* bn, const Op* op, Execution** if (nullptr == dst) { return true; } + if (op->type() == OpType_GatherV2) { + *dst = new SharedGather(bn, mResourceInt8); + return true; + } auto exe = new DenseConvInt8TiledExecutor(bn, op, *this); if (!exe->valid()) { return false; @@ -883,8 +932,8 @@ bool DenseConvInt8TiledExecutor::onClone(Backend* bn, const Op* op, Execution** return true; } - -ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& inputs, const std::vector& outputs) { +ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& inputs, + const std::vector& outputs) { // Initialize. mUseBatchQuan = false; mIm2ColBasedInt8 = true; @@ -895,7 +944,7 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input // backend info auto core = static_cast(backend())->int8Functions(); - auto gcore =static_cast(backend())->functions(); + auto gcore = static_cast(backend())->functions(); const int threads = static_cast(backend())->threadNumber(); mRelatedFunctions = *(static_cast(backend())->int8GemmFunctions()); mArm82Functions = gcore->arm82MatmulRelatedFunctions; @@ -906,14 +955,15 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input auto inputBlockQuantOption = option % WEIGHT_ONLINE_REORDER; auto weightOnlineReorderOption = WEIGHT_ONLINE_REORDER & option; - _getProportions(static_cast(backend())->getRuntime()->hint().divisionRatio, mRatioPrefill, mRatioDecode); + _getProportions(static_cast(backend())->getRuntime()->hint().divisionRatio, mRatioPrefill, + mRatioDecode); // feature map info int batch = inputs[0]->batch(); - int inC = inputs[0]->channel(); + int inC = inputs[0]->channel(); auto output = outputs[0]; int kernelCount = mCommon->kernelY() * mCommon->kernelX(); - int inputPlane = batch * inputs[0]->width() * inputs[0]->height(); + int inputPlane = batch * inputs[0]->width() * inputs[0]->height(); auto planeSize = output->width() * output->height() * output->batch(); int UNIT, SRC_UNIT, DST_XUNIT; @@ -924,7 +974,8 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input mOnlineReorderWeightSme = false; } - _updateMixedKernelFlag(mMixedKernel, mOnlineReorderWeightSme, threads, DST_XUNIT, mResourceInt8->mDynamicQuant, mRatioDecode&&mRatioPrefill); + _updateMixedKernelFlag(mMixedKernel, mOnlineReorderWeightSme, threads, DST_XUNIT, mResourceInt8->mDynamicQuant, + mRatioDecode && mRatioPrefill); if (mOnlineReorderWeightSme && planeSize == 1) { // Decode, set runtime unit UNIT = GEMM_INT8_UNIT_SME2_128; @@ -934,7 +985,8 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input mGemmUnits[1] = SRC_UNIT; mGemmUnits[2] = DST_XUNIT; - bool fastway = (kernelCount == 1) && (output->width() == inputs[0]->width()) && (output->height() == inputs[0]->height()) && (mCommon->strideX() * mCommon->strideY()) == 1; + bool fastway = (kernelCount == 1) && (output->width() == inputs[0]->width()) && + (output->height() == inputs[0]->height()) && (mCommon->strideX() * mCommon->strideY()) == 1; if (inputPlane > 1) { mUseBatchQuan = true; } @@ -962,7 +1014,8 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input CPUConvolution::onResize(inputs, outputs); if (mResourceInt8->mDynamicQuant == false) { - mMutableResource->updateInputOutputScale(TensorUtils::getQuantInfo(inputs[0]), TensorUtils::getQuantInfo(outputs[0])); + mMutableResource->updateInputOutputScale(TensorUtils::getQuantInfo(inputs[0]), + TensorUtils::getQuantInfo(outputs[0])); if (!mMutableResource->mResource->mUseConvQuan) { // In some previous quantized models, input's scale already fused with weight's scale and output's scale. // So there is no need to read input's scale additionally. @@ -976,7 +1029,8 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input mUseBatchQuan = false; } int matmulUnits[3] = {UNIT, SRC_UNIT, DST_XUNIT}; - ConvolutionTiledExecutor::setIm2ColParameter(mIm2ColParamter, mCommon, inputs[0], outputs[0], mPadX, mPadY, gcore, core, gcore->pack, matmulUnits); + ConvolutionTiledExecutor::setIm2ColParameter(mIm2ColParamter, mCommon, inputs[0], outputs[0], mPadX, mPadY, gcore, + core, gcore->pack, matmulUnits); // Im2col info int im2colBytes = 1; @@ -989,7 +1043,7 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input } int ic = inputs[0]->channel(); int tileLimit = 0; - int outC = output->channel(); + int outC = output->channel(); int outC4 = UP_DIV(outC, gcore->pack); mOcMain = outC; mOcBranch = 0; @@ -998,25 +1052,29 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input mSplitByOc = true; // flop and io - float flop = gcore->bytes * planeSize * (ROUND_UP(output->channel(), gcore->pack) * kernelCountUnit * SRC_UNIT / 1024.0 / 1024.0 / 1024.0); - float ios = (((CPUBackend*)backend())->getTensorSize(outputs[0], true) + ((CPUBackend*)backend())->getTensorSize(inputs[0], true) + ((CPUBackend*)backend())->getTensorSize(mResourceInt8->mWeightInt8.get()) * weightBytes) / (1024.0 * 1024.0 * 1024.0); + float flop = gcore->bytes * planeSize * + (ROUND_UP(output->channel(), gcore->pack) * kernelCountUnit * SRC_UNIT / 1024.0 / 1024.0 / 1024.0); + float ios = (((CPUBackend*)backend())->getTensorSize(outputs[0], true) + + ((CPUBackend*)backend())->getTensorSize(inputs[0], true) + + ((CPUBackend*)backend())->getTensorSize(mResourceInt8->mWeightInt8.get()) * weightBytes) / + (1024.0 * 1024.0 * 1024.0); if ((threads < planeSize || mOnlineReorderWeightSme) && !mMixedKernel) { // Thread split by output nhw. tileLimit = ALIMIN(tileLimitByC, UP_DIV(planeSize, threads)); mIm2ColCount = UP_DIV(tileLimit, DST_XUNIT); auto DynamicDestUnit = DST_XUNIT * mIm2ColCount; - mTileCount = UP_DIV(planeSize, DynamicDestUnit); + mTileCount = UP_DIV(planeSize, DynamicDestUnit); if (mTileCount > threads || (mOnlineReorderWeightSme && planeSize > 1)) { mSplitByOc = false; - } + } } if (mSplitByOc) { tileLimit = ALIMIN(tileLimitByC, planeSize); mIm2ColCount = UP_DIV(tileLimit, DST_XUNIT); auto DynamicDestUnit = DST_XUNIT * mIm2ColCount; - mTileCount = UP_DIV(planeSize, DynamicDestUnit); - mDivides.resize(threads+1); + mTileCount = UP_DIV(planeSize, DynamicDestUnit); + mDivides.resize(threads + 1); mDivides[0] = 0; // output channel divided by threads if (!mMixedKernel) { @@ -1028,7 +1086,7 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input MNN_ASSERT(UNIT % gcore->pack == 0); int ocDivUnit = UP_DIV(outC4 * gcore->pack, UNIT); ocPerThread = UP_DIV(ocDivUnit, threads); - threadNeed = UP_DIV(ocDivUnit, ocPerThread); + threadNeed = UP_DIV(ocDivUnit, ocPerThread); totalWork = ocDivUnit; part = UNIT / gcore->pack; } @@ -1037,9 +1095,9 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input if (threads >= 4 && DST_XUNIT == GEMM_INT8_DST_XUNIT_SME2 && mResourceInt8->mDynamicQuant) { _computeDivides4Sme(mDivides, threads, mSmeCores, totalWork); } else { - mDivides.resize(threads+1); + mDivides.resize(threads + 1); mDivides[0] = 0; - static_cast(backend())->computeDivideSizes(totalWork, mDivides.data() + 1, flop / ios); + static_cast(backend())->computeDivideSizes(totalWork, mDivides.data() + 1, flop / ios); } for (int i = 0; i < mDivides.size(); ++i) { mDivides[i] *= part; @@ -1047,19 +1105,20 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input } else { // workload mOcMain = 0; // initialize for mixed kernel, before calculate - calculateSmeNeonWorkDivision(mOcMain, mOcBranch, mDivides, outC, threads, pack, planeSize, mRatioDecode, mSmeCores); + calculateSmeNeonWorkDivision(mOcMain, mOcBranch, mDivides, outC, threads, pack, planeSize, mRatioDecode, + mSmeCores); mThreadNums = threads; } } if (!mSplitByOc) { mThreadNums = ALIMIN(threads, mTileCount); - if (threads >= 4&&DST_XUNIT==GEMM_INT8_DST_XUNIT_SME2&&mResourceInt8->mDynamicQuant&&!mMixedKernel) { + if (threads >= 4 && DST_XUNIT == GEMM_INT8_DST_XUNIT_SME2 && mResourceInt8->mDynamicQuant && !mMixedKernel) { _computeDivides4Sme(mDivides, threads, mSmeCores, mTileCount); } else { - mDivides.resize(threads+1); + mDivides.resize(threads + 1); mDivides[0] = 0; - static_cast(backend())->computeDivideSizes(mTileCount, mDivides.data() + 1, flop / ios); + static_cast(backend())->computeDivideSizes(mTileCount, mDivides.data() + 1, flop / ios); } } mDividesTmp.resize(threads + 1); @@ -1075,11 +1134,12 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input } auto bufferAlloc = static_cast(backend())->getBufferAllocator(); - auto blitInfoSize = ConvolutionTiledExecutor::computeBlitInfoSize(workPT, mIm2ColParamter.ow, mIm2ColParamter.kernelX * mIm2ColParamter.kernelY, k); + auto blitInfoSize = ConvolutionTiledExecutor::computeBlitInfoSize( + workPT, mIm2ColParamter.ow, mIm2ColParamter.kernelX * mIm2ColParamter.kernelY, k); mBlitInfoStride = blitInfoSize.second; mBlitInfo = bufferAlloc->alloc(blitInfoSize.first); - const int unitColBufferSize = kernelCountUnit * DST_XUNIT * SRC_UNIT * sizeof(int8_t); - const int colBufferSize = unitColBufferSize * mIm2ColCount; + const int unitColBufferSize = kernelCountUnit * DST_XUNIT * SRC_UNIT * sizeof(int8_t); + const int colBufferSize = unitColBufferSize * mIm2ColCount; if (!mSplitByOc) { mTempIm2ColBuffer.reset(Tensor::createDevice({threads, colBufferSize * im2colBytes})); @@ -1143,7 +1203,6 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input } mQuantFunc = core->DynamicQuanInput_ARM82; mQuantAndReorderFunc = core->DynamicQuanInputAndReorder_ARM82; - } // A axisSum kernel } else { // use sme and neon gemmInt8 @@ -1236,7 +1295,7 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input if (mUseBatchQuan) { if (mIm2ColBasedInt8) { size = 2 * mInputBlockNum * inputPlane * QUANT_INFO_BYTES; - } else if (!mSplitByOc){ // only threads buffer needed by this case + } else if (!mSplitByOc) { // only threads buffer needed by this case size = 2 * mInputBlockNum * mIm2ColCount * DST_XUNIT * QUANT_INFO_BYTES; } else { size = 2 * mInputBlockNum * planeSize * QUANT_INFO_BYTES; @@ -1258,8 +1317,9 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input // Dynamic quant. // set im2col tensor info if (mIm2ColBasedInt8) { - mQuantInput.reset((Tensor::createDevice({batch, mIm2ColParamter.ih, mIm2ColParamter.iw, ROUND_UP(inC, gcore->pack)}))); - } else if (!mSplitByOc){ + mQuantInput.reset(( + Tensor::createDevice({batch, mIm2ColParamter.ih, mIm2ColParamter.iw, ROUND_UP(inC, gcore->pack)}))); + } else if (!mSplitByOc) { mQuantInput.reset((Tensor::createDevice({threads, colBufferSize * 1}))); } else { mQuantInput.reset((Tensor::createDevice({mTileCount, colBufferSize * 1}))); @@ -1290,12 +1350,15 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input int ocProcessedByNeon = 0; if (mMixedKernel && mRatioDecode != mRatioPrefill) { auto workUnit = UP_DIV(outC4, mRatioPrefill * mSmeCores + 1 * (threads - mSmeCores)); - ocProcessedBySme = ALIMIN(ROUND_UP(workUnit * pack * mSmeCores * mRatioPrefill, GEMM_INT8_UNIT_SME2_128), outC); + ocProcessedBySme = + ALIMIN(ROUND_UP(workUnit * pack * mSmeCores * mRatioPrefill, GEMM_INT8_UNIT_SME2_128), outC); ocProcessedBySme = ALIMAX(ocProcessedBySme, mOcMain); ocProcessedByNeon = outC - ocProcessedBySme; } - int weightlenSme = ROUND_UP(ocProcessedBySme, GEMM_INT8_UNIT_SME2_128) * mBlockNum * ROUND_UP(ic / mBlockNum, SRC_UNIT) * kernelCount; - int weightlenNeon = ROUND_UP(ocProcessedByNeon, 8) * mBlockNum * ROUND_UP(ic / mBlockNum, SRC_UNIT) * kernelCount; + int weightlenSme = ROUND_UP(ocProcessedBySme, GEMM_INT8_UNIT_SME2_128) * mBlockNum * + ROUND_UP(ic / mBlockNum, SRC_UNIT) * kernelCount; + int weightlenNeon = + ROUND_UP(ocProcessedByNeon, 8) * mBlockNum * ROUND_UP(ic / mBlockNum, SRC_UNIT) * kernelCount; if (mResourceInt8->mWeightBits == 4) { weightlenSme /= 2; weightlenNeon /= 2; @@ -1303,13 +1366,13 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input int scalebiasLenSme = 2 * mBlockNum * ROUND_UP(ocProcessedBySme, GEMM_INT8_UNIT_SME2_128) * QUANT_INFO_BYTES; int scalebiasLenNeon = 2 * mBlockNum * ROUND_UP(ocProcessedByNeon, 8) * QUANT_INFO_BYTES; - mWeight4Prefill = bufferAlloc->alloc(weightlenSme + scalebiasLenSme + weightlenNeon + scalebiasLenNeon); if (mWeight4Prefill.invalid()) { return OUT_OF_MEMORY; } if (mInputBlockNum > 1) { // only in this case, need to use weight_kernel_sum - mWeightKernelSum4Prefill = bufferAlloc->alloc(ROUND_UP(outC, GEMM_INT8_UNIT_SME2_128) * mBlockNum * sizeof(float)); + mWeightKernelSum4Prefill = + bufferAlloc->alloc(ROUND_UP(outC, GEMM_INT8_UNIT_SME2_128) * mBlockNum * sizeof(float)); if (mWeightKernelSum4Prefill.invalid()) { return OUT_OF_MEMORY; } @@ -1318,7 +1381,8 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input mToFuseInputbias2Bias = (!mUseBatchQuan && inputBlockQuantOption != 2) ? true : false; if (mToFuseInputbias2Bias) { // input data has only one bias&scale if (mIm2ColBasedInt8) { - mBiasBufferFusedInputzero = bufferAlloc->alloc(ROUND_UP(outC, UNIT) * QUANT_INFO_BYTES); // should be UP_DIV(oc, UNIT),not UP_DIV(oc, pack) + mBiasBufferFusedInputzero = bufferAlloc->alloc( + ROUND_UP(outC, UNIT) * QUANT_INFO_BYTES); // should be UP_DIV(oc, UNIT),not UP_DIV(oc, pack) } else { mBiasBufferFusedInputzero = bufferAlloc->alloc(threads * ROUND_UP(outC, UNIT) * QUANT_INFO_BYTES); } @@ -1351,7 +1415,7 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input bufferAlloc->free(mWeightKernelSum4Prefill); } } - if (mBlockNum >1 && kernelCount > 1) { + if (mBlockNum > 1 && kernelCount > 1) { bufferAlloc->free(mReorderBuffer); } if (mToFuseInputbias2Bias) { @@ -1378,8 +1442,8 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input #endif } - -static void _onlineReorderWeightPackH128ToH32(int8_t* dst, int8_t* src, int hPSrc, int hPDst, int hU, int blockNum, int blockLu, int lp, bool int4weight) { +static void _onlineReorderWeightPackH128ToH32(int8_t* dst, int8_t* src, int hPSrc, int hPDst, int hU, int blockNum, + int blockLu, int lp, bool int4weight) { // hPSrc = 4 * hPDst int unitsize_ = hPDst * lp; @@ -1398,7 +1462,7 @@ static void _onlineReorderWeightPackH128ToH32(int8_t* dst, int8_t* src, int hPSr for (int i = 0; i < hU; ++i) { for (int k = 0; k < blockNum; ++k) { auto weightsrc = (int8_t*)(src + i * srcStride0 + k * srcStride1); - auto weightdst0 = (int8_t*)(dst + (4 * i) * dstStride0 + k * dstStride1); + auto weightdst0 = (int8_t*)(dst + (4 * i) * dstStride0 + k * dstStride1); auto weightdst1 = (int8_t*)(dst + (4 * i + 1) * dstStride0 + k * dstStride1); auto weightdst2 = (int8_t*)(dst + (4 * i + 2) * dstStride0 + k * dstStride1); auto weightdst3 = (int8_t*)(dst + (4 * i + 3) * dstStride0 + k * dstStride1); @@ -1435,14 +1499,14 @@ static void _onlineReorderWeightPackH128ToH32(int8_t* dst, int8_t* src, int hPSr } if (lu > 1) { - memcpy(weightdst0, weightsrc, unitsize_); - memcpy(weightdst0 + unitsize_, weightsrc + unitsize4, unitsize_); - memcpy(weightdst1, weightsrc + unitsize_, unitsize_); - memcpy(weightdst1 + unitsize_, weightsrc + unitsize4 + unitsize_, unitsize_); - memcpy(weightdst2, weightsrc + unitsize_ * 2, unitsize_); - memcpy(weightdst2 + unitsize_, weightsrc + unitsize4 + unitsize_ * 2, unitsize_); - memcpy(weightdst3, weightsrc + unitsize_ * 3, unitsize_); - memcpy(weightdst3 + unitsize_, weightsrc + unitsize4 + unitsize_ * 3, unitsize_); + memcpy(weightdst0, weightsrc, unitsize_); + memcpy(weightdst0 + unitsize_, weightsrc + unitsize4, unitsize_); + memcpy(weightdst1, weightsrc + unitsize_, unitsize_); + memcpy(weightdst1 + unitsize_, weightsrc + unitsize4 + unitsize_, unitsize_); + memcpy(weightdst2, weightsrc + unitsize_ * 2, unitsize_); + memcpy(weightdst2 + unitsize_, weightsrc + unitsize4 + unitsize_ * 2, unitsize_); + memcpy(weightdst3, weightsrc + unitsize_ * 3, unitsize_); + memcpy(weightdst3 + unitsize_, weightsrc + unitsize4 + unitsize_ * 3, unitsize_); weightsrc += unitsize4 * 2; weightdst0 += unitsize_ * 2; @@ -1453,8 +1517,8 @@ static void _onlineReorderWeightPackH128ToH32(int8_t* dst, int8_t* src, int hPSr } if (lu > 0) { - memcpy(weightdst0, weightsrc, unitsize_); - memcpy(weightdst1, weightsrc + unitsize_, unitsize_); + memcpy(weightdst0, weightsrc, unitsize_); + memcpy(weightdst1, weightsrc + unitsize_, unitsize_); memcpy(weightdst2, weightsrc + unitsize_ * 2, unitsize_); memcpy(weightdst3, weightsrc + unitsize_ * 3, unitsize_); } @@ -1468,22 +1532,23 @@ static void _onlineReorderWeightPackH128ToH32(int8_t* dst, int8_t* src, int hPSr // Copy scales (first part of the scale/bias region) int scaleSize = hPDst * sizeof(float); - memcpy(scaleDst0, scaleSrc, scaleSize); - memcpy(scaleDst1, scaleSrc + scaleSize, scaleSize); - memcpy(scaleDst2, scaleSrc + scaleSize * 2, scaleSize); - memcpy(scaleDst3, scaleSrc + scaleSize * 3, scaleSize); + memcpy(scaleDst0, scaleSrc, scaleSize); + memcpy(scaleDst1, scaleSrc + scaleSize, scaleSize); + memcpy(scaleDst2, scaleSrc + scaleSize * 2, scaleSize); + memcpy(scaleDst3, scaleSrc + scaleSize * 3, scaleSize); // Copy biases (second part of the scale/bias region) auto biasSrcOffset = hPSrc * sizeof(float); - memcpy(scaleDst0 + scaleSize, scaleSrc + biasSrcOffset, scaleSize); - memcpy(scaleDst1 + scaleSize, scaleSrc + biasSrcOffset + scaleSize, scaleSize); - memcpy(scaleDst2 + scaleSize, scaleSrc + biasSrcOffset + scaleSize * 2, scaleSize); - memcpy(scaleDst3 + scaleSize, scaleSrc + biasSrcOffset + scaleSize * 3, scaleSize); + memcpy(scaleDst0 + scaleSize, scaleSrc + biasSrcOffset, scaleSize); + memcpy(scaleDst1 + scaleSize, scaleSrc + biasSrcOffset + scaleSize, scaleSize); + memcpy(scaleDst2 + scaleSize, scaleSrc + biasSrcOffset + scaleSize * 2, scaleSize); + memcpy(scaleDst3 + scaleSize, scaleSrc + biasSrcOffset + scaleSize * 3, scaleSize); } } } -static void _onlineReorderWeightPackH8ToH32(int8_t* dst, const int8_t* src, int blockLu, int lp, bool isInt4Weight, int srcH, int blockNum, int resOcBranch) { +static void _onlineReorderWeightPackH8ToH32(int8_t* dst, const int8_t* src, int blockLu, int lp, bool isInt4Weight, + int srcH, int blockNum, int resOcBranch) { constexpr int hPSrc = 8; constexpr int hPDst = 32; @@ -1506,7 +1571,7 @@ static void _onlineReorderWeightPackH8ToH32(int8_t* dst, const int8_t* src, int auto weightSrcBase1 = src + (4 * i + 1) * srcStride0 + k * srcStride1; auto weightSrcBase2 = src + (4 * i + 2) * srcStride0 + k * srcStride1; auto weightSrcBase3 = src + (4 * i + 3) * srcStride0 + k * srcStride1; - auto weightDstBase = dst + i * dstStride0 + k * dstStride1; + auto weightDstBase = dst + i * dstStride0 + k * dstStride1; int lu = blockLu; @@ -1517,8 +1582,8 @@ static void _onlineReorderWeightPackH8ToH32(int8_t* dst, const int8_t* src, int for (int s = 0; s < half_size; ++s) { uint8_t p0 = src_b[2 * s]; uint8_t p1 = src_b[2 * s + 1]; - dst_b[s] = (p1 & 0xF0) | (p0 >> 4); - dst_b[s + half_size] = (p1 << 4) | (p0 & 0x0F); + dst_b[s] = (p1 & 0xF0) | (p0 >> 4); + dst_b[s + half_size] = (p1 << 4) | (p0 & 0x0F); } }; while (lu >= 4) { @@ -1539,7 +1604,7 @@ static void _onlineReorderWeightPackH8ToH32(int8_t* dst, const int8_t* src, int weightSrcBase1 += 4 * srcUnitSize; weightSrcBase2 += 4 * srcUnitSize; weightSrcBase3 += 4 * srcUnitSize; - weightDstBase += 4 * dstUnitSize; + weightDstBase += 4 * dstUnitSize; lu -= 4; } @@ -1559,36 +1624,52 @@ static void _onlineReorderWeightPackH8ToH32(int8_t* dst, const int8_t* src, int weightSrcBase1 += srcUnitSize; weightSrcBase2 += srcUnitSize; weightSrcBase3 += srcUnitSize; - weightDstBase += dstUnitSize; + weightDstBase += dstUnitSize; } } else { while (lu >= 4) { // j = 0 - memcpy(weightDstBase + 0 * dstUnitSize + 0 * srcUnitSize, weightSrcBase0 + 0 * srcUnitSize, srcUnitSize); - memcpy(weightDstBase + 0 * dstUnitSize + 1 * srcUnitSize, weightSrcBase1 + 0 * srcUnitSize, srcUnitSize); - memcpy(weightDstBase + 0 * dstUnitSize + 2 * srcUnitSize, weightSrcBase2 + 0 * srcUnitSize, srcUnitSize); - memcpy(weightDstBase + 0 * dstUnitSize + 3 * srcUnitSize, weightSrcBase3 + 0 * srcUnitSize, srcUnitSize); + memcpy(weightDstBase + 0 * dstUnitSize + 0 * srcUnitSize, weightSrcBase0 + 0 * srcUnitSize, + srcUnitSize); + memcpy(weightDstBase + 0 * dstUnitSize + 1 * srcUnitSize, weightSrcBase1 + 0 * srcUnitSize, + srcUnitSize); + memcpy(weightDstBase + 0 * dstUnitSize + 2 * srcUnitSize, weightSrcBase2 + 0 * srcUnitSize, + srcUnitSize); + memcpy(weightDstBase + 0 * dstUnitSize + 3 * srcUnitSize, weightSrcBase3 + 0 * srcUnitSize, + srcUnitSize); // j = 1 - memcpy(weightDstBase + 1 * dstUnitSize + 0 * srcUnitSize, weightSrcBase0 + 1 * srcUnitSize, srcUnitSize); - memcpy(weightDstBase + 1 * dstUnitSize + 1 * srcUnitSize, weightSrcBase1 + 1 * srcUnitSize, srcUnitSize); - memcpy(weightDstBase + 1 * dstUnitSize + 2 * srcUnitSize, weightSrcBase2 + 1 * srcUnitSize, srcUnitSize); - memcpy(weightDstBase + 1 * dstUnitSize + 3 * srcUnitSize, weightSrcBase3 + 1 * srcUnitSize, srcUnitSize); + memcpy(weightDstBase + 1 * dstUnitSize + 0 * srcUnitSize, weightSrcBase0 + 1 * srcUnitSize, + srcUnitSize); + memcpy(weightDstBase + 1 * dstUnitSize + 1 * srcUnitSize, weightSrcBase1 + 1 * srcUnitSize, + srcUnitSize); + memcpy(weightDstBase + 1 * dstUnitSize + 2 * srcUnitSize, weightSrcBase2 + 1 * srcUnitSize, + srcUnitSize); + memcpy(weightDstBase + 1 * dstUnitSize + 3 * srcUnitSize, weightSrcBase3 + 1 * srcUnitSize, + srcUnitSize); // j = 2 - memcpy(weightDstBase + 2 * dstUnitSize + 0 * srcUnitSize, weightSrcBase0 + 2 * srcUnitSize, srcUnitSize); - memcpy(weightDstBase + 2 * dstUnitSize + 1 * srcUnitSize, weightSrcBase1 + 2 * srcUnitSize, srcUnitSize); - memcpy(weightDstBase + 2 * dstUnitSize + 2 * srcUnitSize, weightSrcBase2 + 2 * srcUnitSize, srcUnitSize); - memcpy(weightDstBase + 2 * dstUnitSize + 3 * srcUnitSize, weightSrcBase3 + 2 * srcUnitSize, srcUnitSize); + memcpy(weightDstBase + 2 * dstUnitSize + 0 * srcUnitSize, weightSrcBase0 + 2 * srcUnitSize, + srcUnitSize); + memcpy(weightDstBase + 2 * dstUnitSize + 1 * srcUnitSize, weightSrcBase1 + 2 * srcUnitSize, + srcUnitSize); + memcpy(weightDstBase + 2 * dstUnitSize + 2 * srcUnitSize, weightSrcBase2 + 2 * srcUnitSize, + srcUnitSize); + memcpy(weightDstBase + 2 * dstUnitSize + 3 * srcUnitSize, weightSrcBase3 + 2 * srcUnitSize, + srcUnitSize); // j = 3 - memcpy(weightDstBase + 3 * dstUnitSize + 0 * srcUnitSize, weightSrcBase0 + 3 * srcUnitSize, srcUnitSize); - memcpy(weightDstBase + 3 * dstUnitSize + 1 * srcUnitSize, weightSrcBase1 + 3 * srcUnitSize, srcUnitSize); - memcpy(weightDstBase + 3 * dstUnitSize + 2 * srcUnitSize, weightSrcBase2 + 3 * srcUnitSize, srcUnitSize); - memcpy(weightDstBase + 3 * dstUnitSize + 3 * srcUnitSize, weightSrcBase3 + 3 * srcUnitSize, srcUnitSize); + memcpy(weightDstBase + 3 * dstUnitSize + 0 * srcUnitSize, weightSrcBase0 + 3 * srcUnitSize, + srcUnitSize); + memcpy(weightDstBase + 3 * dstUnitSize + 1 * srcUnitSize, weightSrcBase1 + 3 * srcUnitSize, + srcUnitSize); + memcpy(weightDstBase + 3 * dstUnitSize + 2 * srcUnitSize, weightSrcBase2 + 3 * srcUnitSize, + srcUnitSize); + memcpy(weightDstBase + 3 * dstUnitSize + 3 * srcUnitSize, weightSrcBase3 + 3 * srcUnitSize, + srcUnitSize); weightSrcBase0 += 4 * srcUnitSize; weightSrcBase1 += 4 * srcUnitSize; weightSrcBase2 += 4 * srcUnitSize; weightSrcBase3 += 4 * srcUnitSize; - weightDstBase += 4 * dstUnitSize; + weightDstBase += 4 * dstUnitSize; lu -= 4; } @@ -1602,7 +1683,7 @@ static void _onlineReorderWeightPackH8ToH32(int8_t* dst, const int8_t* src, int weightSrcBase1 += srcUnitSize; weightSrcBase2 += srcUnitSize; weightSrcBase3 += srcUnitSize; - weightDstBase += dstUnitSize; + weightDstBase += dstUnitSize; } } @@ -1632,11 +1713,11 @@ static void _onlineReorderWeightPackH8ToH32(int8_t* dst, const int8_t* src, int const int i = hUDst; for (int k = 0; k < blockNum; ++k) { const int8_t* srcBases[4] = {nullptr, nullptr, nullptr, nullptr}; - for(int j = 0; j < hTail; ++j) { + for (int j = 0; j < hTail; ++j) { srcBases[j] = src + (4 * i + j) * srcStride0 + k * srcStride1; } - auto weightDstBase = dst + i * dstStride0 + k * dstStride1; + auto weightDstBase = dst + i * dstStride0 + k * dstStride1; int lu = blockLu; @@ -1646,17 +1727,14 @@ static void _onlineReorderWeightPackH8ToH32(int8_t* dst, const int8_t* src, int for (int s = 0; s < half_size; ++s) { uint8_t p0 = src_b[2 * s]; uint8_t p1 = src_b[2 * s + 1]; - dst_b[s] = (p1 & 0xF0) | (p0 >> 4); + dst_b[s] = (p1 & 0xF0) | (p0 >> 4); dst_b[s + half_size] = (p1 << 4) | (p0 & 0x0F); } }; - while (lu --> 0) { + while (lu-- > 0) { for (int j = 0; j < hTail; ++j) { - process_int4_block( - (uint8_t*)(weightDstBase + j * srcUnitSize), - (const uint8_t*)(srcBases[j]), - srcUnitSize - ); + process_int4_block((uint8_t*)(weightDstBase + j * srcUnitSize), (const uint8_t*)(srcBases[j]), + srcUnitSize); } // For the remaining part of the destination block, set 0 @@ -1664,13 +1742,13 @@ static void _onlineReorderWeightPackH8ToH32(int8_t* dst, const int8_t* src, int memset(weightDstBase + hTail * srcUnitSize, 0, (4 - hTail) * srcUnitSize); } - for(int j=0; j 0) { + while (lu-- > 0) { for (int j = 0; j < hTail; ++j) { memcpy(weightDstBase + j * srcUnitSize, srcBases[j], srcUnitSize); } @@ -1679,7 +1757,7 @@ static void _onlineReorderWeightPackH8ToH32(int8_t* dst, const int8_t* src, int memset(weightDstBase + hTail * srcUnitSize, 0, (4 - hTail) * srcUnitSize); } - for(int j=0; j 0) { size_t resLp = isInt4Weight ? lp / 2 : lp; size_t resChannels = ROUND_UP(resOcBranch, hPSrc); - size_t resDataLen = (size_t)blockNum * ((size_t)blockLu * resChannels * resLp + 2 * resChannels * sizeof(float)); + size_t resDataLen = + (size_t)blockNum * ((size_t)blockLu * resChannels * resLp + 2 * resChannels * sizeof(float)); // The source for residual data starts after ALL processed srcH blocks. - memcpy(dst + (size_t)hUDst * dstStride0 + (hTail > 0 ? dstStride0 : 0), - src + (size_t)srcH * srcStride0, + memcpy(dst + (size_t)hUDst * dstStride0 + (hTail > 0 ? dstStride0 : 0), src + (size_t)srcH * srcStride0, resDataLen); } } @@ -1749,7 +1827,8 @@ static void _onlineReorderWeightKernelSumH128ToH32(float* dst, float* src, int b } } -static void _onlineReorderWeightKernelSumH8ToH32(float* dst, float* src, int blockNum, int hpSrc, int hpDst, int ocNeedReorder, int ocPreserve) { +static void _onlineReorderWeightKernelSumH8ToH32(float* dst, float* src, int blockNum, int hpSrc, int hpDst, + int ocNeedReorder, int ocPreserve) { // hpDst = 4 * hpSrc // src shape: [huSrc, blockNum, hpSrc], where huSrc = huDst * 4 // dst shape: [huDst, blockNum, hpDst] @@ -1776,13 +1855,15 @@ static void _onlineReorderWeightKernelSumH8ToH32(float* dst, float* src, int blo } if (ocPreserve) { - memcpy(dst + huDst * strideDst, src + 4 * huDst * strideSrc, ROUND_UP(ocPreserve, hpSrc) * blockNum * sizeof(float)); + memcpy(dst + huDst * strideDst, src + 4 * huDst * strideSrc, + ROUND_UP(ocPreserve, hpSrc) * blockNum * sizeof(float)); } } -ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inputs, const std::vector& outputs) { +ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inputs, + const std::vector& outputs) { const auto input = inputs[0]; - auto output = outputs[0]; + auto output = outputs[0]; auto core = static_cast(backend())->int8Functions(); auto gcore = static_cast(backend())->functions(); auto dynamicOption = static_cast(backend())->getRuntime()->hint().dynamicQuantOption % 8; @@ -1791,32 +1872,32 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu int SRC_UNIT = mGemmUnits[1]; int DST_XUNIT = mGemmUnits[2]; auto blitProc = mRelatedFunctions.MNNPackC4Int8ForMatMul_A; - const int plane = output->batch() * mIm2ColParamter.oh * mIm2ColParamter.ow; - const int batch = input->batch(); - const int PackUnit = gcore->pack; - const int dstZStep = plane * PackUnit; - const int ocDiv4 = UP_DIV(output->channel(), PackUnit); - const int ocUp4 = ROUND_UP(output->channel(), PackUnit); - const int ocUpHp = ROUND_UP(output->channel(), UNIT); - const auto kernelCountUnit = mIm2ColParamter.kernelCountUnit; - const auto unitColBufferSize = kernelCountUnit * DST_XUNIT * SRC_UNIT * sizeof(int8_t); - const auto colBufferSize = unitColBufferSize * mIm2ColCount; - auto dstBytes = static_cast(backend())->getBytes(backend(), output); - const int blockL = kernelCountUnit / mBlockNum; // source depthQuad for each block. - const int kxky = mIm2ColParamter.kernelX * mIm2ColParamter.kernelY; - const int blocklu = blockL / kxky; // UP_DIV(ic,src_unit) per block - const int oc = output->channel(); - const int ic = input->channel(); - float weightBytes = 1.f; - int weightStepY = weightBytes * (UNIT * SRC_UNIT); - int inputPlane = batch * input->width() * input->height(); - - auto im2colPtr = mTempIm2ColBuffer->host(); + const int plane = output->batch() * mIm2ColParamter.oh * mIm2ColParamter.ow; + const int batch = input->batch(); + const int PackUnit = gcore->pack; + const int dstZStep = plane * PackUnit; + const int ocDiv4 = UP_DIV(output->channel(), PackUnit); + const int ocUp4 = ROUND_UP(output->channel(), PackUnit); + const int ocUpHp = ROUND_UP(output->channel(), UNIT); + const auto kernelCountUnit = mIm2ColParamter.kernelCountUnit; + const auto unitColBufferSize = kernelCountUnit * DST_XUNIT * SRC_UNIT * sizeof(int8_t); + const auto colBufferSize = unitColBufferSize * mIm2ColCount; + auto dstBytes = static_cast(backend())->getBytes(backend(), output); + const int blockL = kernelCountUnit / mBlockNum; // source depthQuad for each block. + const int kxky = mIm2ColParamter.kernelX * mIm2ColParamter.kernelY; + const int blocklu = blockL / kxky; // UP_DIV(ic,src_unit) per block + const int oc = output->channel(); + const int ic = input->channel(); + float weightBytes = 1.f; + int weightStepY = weightBytes * (UNIT * SRC_UNIT); + int inputPlane = batch * input->width() * input->height(); + + auto im2colPtr = mTempIm2ColBuffer->host(); if (SRC_UNIT > PackUnit) { memset(im2colPtr, 0, mTempIm2ColBuffer->size()); } auto weightDataPtr = mResourceInt8->mWeightInt8->host(); - auto srcKernelSumPtr = (int8_t*)mTempSrcSum.ptr(); + auto srcKernelSumPtr = (int8_t*)mTempSrcSum.ptr(); auto im2colSrc = input->host(); auto outputDataPtr = output->host(); uint8_t* biasPtr = nullptr; @@ -1830,8 +1911,8 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu } if (nullptr != mMutableResource.get()) { - biasPtr = mMutableResource->mBiasFloat->host(); - inputZeroPoint = mMutableResource->mInputZeroPoint; + biasPtr = mMutableResource->mBiasFloat->host(); + inputZeroPoint = mMutableResource->mInputZeroPoint; if (mBatchQuantInfo.get()) { float scalein = TensorUtils::getQuantInfo(inputs[0])[0]; float scaleou = TensorUtils::getQuantInfo(outputs[0])[0]; @@ -1850,7 +1931,9 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu int dropBranch = 0; #ifdef MNN_LOW_MEMORY - auto BatchAsyDynamicQuant = [&](uint8_t* floatPtr, int32_t& inputZero, uint8_t* inputDequantScale, int LDiv4, int eCount, int innerSide, int32_t availableThreads, int8_t* dstInt8, uint8_t* inputDequantBias, int tId) { + auto BatchAsyDynamicQuant = [&](uint8_t* floatPtr, int32_t& inputZero, uint8_t* inputDequantScale, int LDiv4, + int eCount, int innerSide, int32_t availableThreads, int8_t* dstInt8, + uint8_t* inputDequantBias, int tId) { // if mIm2ColBasedInt8=false, input shape: [kernelsize,mBlockNum,blocklu,EP,LP] // if mIm2ColBasedInt8=true, input shape: [ic/pack,EP,pack] auto scalePtr = (float*)inputDequantScale; @@ -1864,9 +1947,18 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu auto minPtr = mTempMaxMinValueBuffer.ptr() + tId * scaleCount * gcore->bytes; auto maxPtr = mTempMaxMinValueBuffer.ptr() + tId * scaleCount * gcore->bytes + (scaleCount / 2) * gcore->bytes; auto qscale = (float*)(mQScaleZero.ptr() + tId * scaleCount * QUANT_INFO_BYTES); - auto qbias = (float*)(mQScaleZero.ptr() + tId * scaleCount * QUANT_INFO_BYTES + (scaleCount / 2) * QUANT_INFO_BYTES); - - size_t info[9] = {(size_t)mInputBlockNum, (size_t)eCount, (size_t)innerSide, (size_t)DST_XUNIT, (size_t)SRC_UNIT, (size_t)kernelsize, (size_t)blocklu, 0, 0}; + auto qbias = + (float*)(mQScaleZero.ptr() + tId * scaleCount * QUANT_INFO_BYTES + (scaleCount / 2) * QUANT_INFO_BYTES); + + size_t info[9] = {(size_t)mInputBlockNum, + (size_t)eCount, + (size_t)innerSide, + (size_t)DST_XUNIT, + (size_t)SRC_UNIT, + (size_t)kernelsize, + (size_t)blocklu, + 0, + 0}; if (mIm2ColBasedInt8) { info[6] = LDiv4 / mInputBlockNum; } @@ -1877,7 +1969,8 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu info[8] = 1; } // scale&bias:float32 - gcore->MNNAsyQuantInfo(scalePtr, zeroPtr, qscale, qbias, (float*)minPtr, (float*)maxPtr, (float*)floatPtr, info); + gcore->MNNAsyQuantInfo(scalePtr, zeroPtr, qscale, qbias, (float*)minPtr, (float*)maxPtr, (float*)floatPtr, + info); // quant: float->int8_t if (!mToFuseInputbias2Bias) { @@ -1893,7 +1986,8 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu auto matmulBiasPtr = mResourceInt8->mOriginBias->host(); auto weightKernelSum = mResourceInt8->mWeightKernelSum->host(); auto inputZeroF = -qbias[0] * scalePtr[0]; - gcore->MNNDynamicUpdateConvBiasScale(updatedBiasPtr, matmulBiasPtr, weightKernelSum, &inputZeroF, UP_DIV(ocUpHp, 4)); + gcore->MNNDynamicUpdateConvBiasScale(updatedBiasPtr, matmulBiasPtr, weightKernelSum, &inputZeroF, + UP_DIV(ocUpHp, 4)); biasPtr = (uint8_t*)updatedBiasPtr; auto unitsize = mBatchQuantInfo->length(1) / (2 * QUANT_INFO_BYTES); auto inputScale = scalePtr[0]; @@ -1903,19 +1997,20 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu } }; - auto BatchSymDynamicQuant = [&](uint8_t* floatPtr, int32_t& inputZero, uint8_t* inputDequantScale, int LU, int EP, int LP, int32_t availableThreads, int8_t* dstInt8, int tId) { + auto BatchSymDynamicQuant = [&](uint8_t* floatPtr, int32_t& inputZero, uint8_t* inputDequantScale, int LU, int EP, + int LP, int32_t availableThreads, int8_t* dstInt8, int tId) { auto quantPtr = mQScaleZero.ptr() + tId * mSizeInputBlockQuant * QUANT_INFO_BYTES; auto maxPtr = mTempMaxMinValueBuffer.ptr() + tId * mSizeInputBlockQuant * gcore->bytes; // compute sum and absmax int divlu = UP_DIV(LU, availableThreads); - MNN_CONCURRENCY_BEGIN (tIdx, ALIMIN(availableThreads, UP_DIV(LU, divlu))) { + MNN_CONCURRENCY_BEGIN(tIdx, ALIMIN(availableThreads, UP_DIV(LU, divlu))) { auto exeLu = ALIMIN(divlu, LU - tIdx * divlu); auto batchMax = reinterpret_cast(maxPtr + tIdx * EP * gcore->bytes); - auto ptr_ = reinterpret_cast(floatPtr + tIdx * divlu * gcore->bytes * EP * LP); + auto ptr_ = reinterpret_cast(floatPtr + tIdx * divlu * gcore->bytes * EP * LP); gcore->MNNAbsMax((float*)ptr_, batchMax, exeLu, EP, LP); - } MNN_CONCURRENCY_END(); - + } + MNN_CONCURRENCY_END(); // Compute quant scale gcore->MNNQuantScale((float*)maxPtr, (float*)quantPtr, (float*)inputDequantScale, availableThreads, EP); @@ -1932,33 +2027,42 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu if (mIm2ColBasedInt8 && mResourceInt8->mDynamicQuant) { int icDiv4 = UP_DIV(input->channel(), PackUnit); if (mUseBatchQuan) { - int availthreads = (icDiv4 > mThreadNums && inputPlane > 255 ) ? mThreadNums : 1; + int availthreads = (icDiv4 > mThreadNums && inputPlane > 255) ? mThreadNums : 1; if (dynamicOption != 2) { - BatchSymDynamicQuant(input->host(), inputZeroPoint, mBatchQuantInfo->host(), icDiv4, inputPlane, PackUnit, availthreads, mQuantInput->host(), 0); + BatchSymDynamicQuant(input->host(), inputZeroPoint, mBatchQuantInfo->host(), icDiv4, + inputPlane, PackUnit, availthreads, mQuantInput->host(), 0); } else { - BatchAsyDynamicQuant(input->host(), inputZeroPoint, mBatchQuantInfo->host(), icDiv4, inputPlane, PackUnit, availthreads, mQuantInput->host(), mBatchQuantInfo->host() + mBatchQuantInfo->stride(0) / 2, 0); + BatchAsyDynamicQuant(input->host(), inputZeroPoint, mBatchQuantInfo->host(), icDiv4, + inputPlane, PackUnit, availthreads, mQuantInput->host(), + mBatchQuantInfo->host() + mBatchQuantInfo->stride(0) / 2, 0); } } else { - BatchAsyDynamicQuant(input->host(), inputZeroPoint, mBatchQuantInfo->host(), icDiv4, inputPlane, PackUnit, 1, mQuantInput->host(), mBatchQuantInfo->host() + mBatchQuantInfo->stride(0) / 2, 0); + BatchAsyDynamicQuant(input->host(), inputZeroPoint, mBatchQuantInfo->host(), icDiv4, + inputPlane, PackUnit, 1, mQuantInput->host(), + mBatchQuantInfo->host() + mBatchQuantInfo->stride(0) / 2, 0); } im2colSrc = mQuantInput->host(); } - if (mOnlineReorderWeightSme && plane > 1) { - _onlineReorderWeightPackH128ToH32((int8_t*)mWeight4Prefill.ptr(), weightDataPtr, GEMM_INT8_UNIT_SME2_128, UNIT, UP_DIV(mOcMain, GEMM_INT8_UNIT_SME2_128), mBlockNum, blockL, SRC_UNIT, mResourceInt8->mWeightBits == 4); + _onlineReorderWeightPackH128ToH32((int8_t*)mWeight4Prefill.ptr(), weightDataPtr, GEMM_INT8_UNIT_SME2_128, UNIT, + UP_DIV(mOcMain, GEMM_INT8_UNIT_SME2_128), mBlockNum, blockL, SRC_UNIT, + mResourceInt8->mWeightBits == 4); int kernelSumMainSize = 0; int kernelSumBranchSize = 0; if (dstBytes > 1 && mInputBlockNum > 1) { - _onlineReorderWeightKernelSumH128ToH32((float*)mWeightKernelSum4Prefill.ptr(), mResourceInt8->mWeightKernelSum->host(), mBlockNum, GEMM_INT8_UNIT_SME2_128, UNIT, mOcMain); + _onlineReorderWeightKernelSumH128ToH32((float*)mWeightKernelSum4Prefill.ptr(), + mResourceInt8->mWeightKernelSum->host(), mBlockNum, + GEMM_INT8_UNIT_SME2_128, UNIT, mOcMain); kernelSumMainSize = ROUND_UP(mOcMain, UNIT) * mBlockNum * QUANT_INFO_BYTES; kernelSumBranchSize = ROUND_UP(mOcBranch, 8) * mBlockNum * QUANT_INFO_BYTES; } // If change the workload distribution among SME and NEON cores. if (mMixedKernel && mRatioDecode != mRatioPrefill) { - auto offsetWeight = UP_DIV(mOcMain, GEMM_INT8_UNIT_SME2_128) * mBlockNum * blockL * SRC_UNIT * GEMM_INT8_UNIT_SME2_128; + auto offsetWeight = + UP_DIV(mOcMain, GEMM_INT8_UNIT_SME2_128) * mBlockNum * blockL * SRC_UNIT * GEMM_INT8_UNIT_SME2_128; if (mResourceInt8->mWeightBits == 4) { offsetWeight /= 2; } @@ -1967,35 +2071,44 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu // Don't change mOcMain&mOcBranch here. int tmpMain = mOcMain; int tmpBranch = mOcBranch; - calculateSmeNeonWorkDivision(tmpMain, tmpBranch, mDividesTmp, oc, threads, PackUnit, plane, mRatioPrefill, mSmeCores); + calculateSmeNeonWorkDivision(tmpMain, tmpBranch, mDividesTmp, oc, threads, PackUnit, plane, mRatioPrefill, + mSmeCores); auto updatedSmeWork = mDividesTmp[mSmeCores]; - - if (updatedSmeWork - mOriginSmeWork > 0 && ((updatedSmeWork - mOriginSmeWork) * 4 % 8 == 0)) { // To ensure pack=4, dropBranch % 2 == 0 - dropBranch = updatedSmeWork - mOriginSmeWork; // Ensure update "dropBranch" inner the loop. - memcpy(mDivides.data(), mDividesTmp.data(), (threads+1) * sizeof(float)); + if (updatedSmeWork - mOriginSmeWork > 0 && + ((updatedSmeWork - mOriginSmeWork) * 4 % 8 == 0)) { // To ensure pack=4, dropBranch % 2 == 0 + dropBranch = updatedSmeWork - mOriginSmeWork; // Ensure update "dropBranch" inner the loop. + memcpy(mDivides.data(), mDividesTmp.data(), (threads + 1) * sizeof(float)); dropBranch = mDivides[mSmeCores] - mOriginSmeWork; - _onlineReorderWeightPackH8ToH32((int8_t*)(mWeight4Prefill.ptr() + offsetWeight), weightDataPtr + offsetWeight, blockL, SRC_UNIT, mResourceInt8->mWeightBits == 4, (int)(dropBranch * PackUnit / 8), mBlockNum, (mDivides[threads] - mDivides[mSmeCores]) * PackUnit); + _onlineReorderWeightPackH8ToH32((int8_t*)(mWeight4Prefill.ptr() + offsetWeight), + weightDataPtr + offsetWeight, blockL, SRC_UNIT, + mResourceInt8->mWeightBits == 4, (int)(dropBranch * PackUnit / 8), + mBlockNum, (mDivides[threads] - mDivides[mSmeCores]) * PackUnit); } if (dstBytes > 1 && mInputBlockNum > 1) { if (dropBranch > 0) { // reorder - _onlineReorderWeightKernelSumH8ToH32((float*)(mWeightKernelSum4Prefill.ptr() + kernelSumMainSize), (float*)(mResourceInt8->mWeightKernelSum->host() + kernelSumMainSize), mBlockNum, 8, UNIT, dropBranch * PackUnit, (mDivides[threads] - mDivides[mSmeCores]) * PackUnit); + _onlineReorderWeightKernelSumH8ToH32( + (float*)(mWeightKernelSum4Prefill.ptr() + kernelSumMainSize), + (float*)(mResourceInt8->mWeightKernelSum->host() + kernelSumMainSize), mBlockNum, 8, + UNIT, dropBranch * PackUnit, (mDivides[threads] - mDivides[mSmeCores]) * PackUnit); } } } - if (dropBranch == 0) { // If dropBranch == 0, it means that the arrangement of the weights processed by the Arm82 architecture remains unchanged. + if (dropBranch == 0) { // If dropBranch == 0, it means that the arrangement of the weights processed by the + // Arm82 architecture remains unchanged. // copy - memcpy(mWeightKernelSum4Prefill.ptr() + kernelSumMainSize, mResourceInt8->mWeightKernelSum->host() + kernelSumMainSize, kernelSumBranchSize); + memcpy(mWeightKernelSum4Prefill.ptr() + kernelSumMainSize, + mResourceInt8->mWeightKernelSum->host() + kernelSumMainSize, kernelSumBranchSize); } weightDataPtr = (int8_t*)mWeight4Prefill.ptr(); } #endif if (mResourceInt8->mWeightBits == 4) { - weightBytes = 0.5; + weightBytes = 0.5; weightStepY /= 2; } else if (mResourceInt8->mWeightBits == 3) { auto packedBytesPerOc = (SRC_UNIT * 3 + 7) / 8; @@ -2043,11 +2156,13 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu ptrInputScale = inputScale; } if (mBlockNum > 1) { - accumbuff = reinterpret_cast(mAccumBuffer->host() + tId * mAccumBuffer->stride(0) * sizeof(int32_t)); + accumbuff = reinterpret_cast(mAccumBuffer->host() + + tId * mAccumBuffer->stride(0) * sizeof(int32_t)); } float* ptrY = nullptr; if (dstBytes != 1) { - ptrY = (mOnlineReorderWeightSme && mInputBlockNum > 1) ? (float*)mWeightKernelSum4Prefill.ptr() : mResourceInt8->mWeightKernelSum->host(); + ptrY = (mOnlineReorderWeightSme && mInputBlockNum > 1) ? (float*)mWeightKernelSum4Prefill.ptr() + : mResourceInt8->mWeightKernelSum->host(); } QuanPostTreatParameters quanParam; quanParam.blockNum = mBlockNum; @@ -2072,24 +2187,30 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu auto weightPtrTid = weightDataPtr; quanParam.weightKernelSum = ptrY; quanParam.biasFloat = reinterpret_cast(biasPtr); - auto im2colDstThread = im2colPtr + tId * mTempIm2ColBuffer->stride(0); - auto srcPtr = (int8_t const **)(mBlitInfo.ptr() + tId * mBlitInfoStride.first); - auto el = (int32_t *)(srcPtr + mBlitInfoStride.second); - auto xKernelSumPtrTid = reinterpret_cast(srcKernelSumPtr + tId * mBlockNum * DST_XUNIT * mIm2ColCount * QUANT_INFO_BYTES); + auto im2colDstThread = im2colPtr + tId * mTempIm2ColBuffer->stride(0); + auto srcPtr = (int8_t const**)(mBlitInfo.ptr() + tId * mBlitInfoStride.first); + auto el = (int32_t*)(srcPtr + mBlitInfoStride.second); + auto xKernelSumPtrTid = + reinterpret_cast(srcKernelSumPtr + tId * mBlockNum * DST_XUNIT * mIm2ColCount * QUANT_INFO_BYTES); int32_t info[5]; info[1] = mIm2ColParamter.iw * mIm2ColParamter.ih * batch; info[2] = static_cast(unitColBufferSize); info[3] = mIm2ColParamter.strideX; for (int tIndex = eStartIndex; tIndex < eEndIndex; tIndex += estep) { - const int xIndexStart = tIndex * DST_XUNIT * mIm2ColCount; + const int xIndexStart = tIndex * DST_XUNIT * mIm2ColCount; auto outputInTilePtr = outputDataPtr + xIndexStart * PackUnit * dstBytes; int realDstCount = ALIMIN(plane - xIndexStart, DST_XUNIT * mIm2ColCount); - ptrInputScale = (mUseBatchQuan && mIm2ColBasedInt8) ? (inputScale + xIndexStart * mInputBlockNum * QUANT_INFO_BYTES) : inputScale; - ptrInputBias = (inputBias != nullptr) ? (inputBias + xIndexStart * mInputBlockNum * QUANT_INFO_BYTES) : inputBias; + ptrInputScale = (mUseBatchQuan && mIm2ColBasedInt8) + ? (inputScale + xIndexStart * mInputBlockNum * QUANT_INFO_BYTES) + : inputScale; + ptrInputBias = + (inputBias != nullptr) ? (inputBias + xIndexStart * mInputBlockNum * QUANT_INFO_BYTES) : inputBias; // im2col auto im2colDst = im2colDstThread; - auto res = ConvolutionTiledExecutor::turnIm2ColToBlitInfo((const float**)srcPtr, el, xIndexStart, realDstCount, mIm2ColParamter, (uint8_t*)im2colSrc, im2colBytes); + auto res = + ConvolutionTiledExecutor::turnIm2ColToBlitInfo((const float**)srcPtr, el, xIndexStart, realDstCount, + mIm2ColParamter, (uint8_t*)im2colSrc, im2colBytes); int number = res.first; bool needZero = res.second; if (needZero && mIm2ColBasedInt8) { @@ -2114,18 +2235,25 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu memset(im2colDst, 0, mTempIm2ColBuffer->stride(0)); } info[2] = realDstCount; - mRelatedFunctions.MNNGeneralIm2Col((float*)im2colDst, (float const**)srcPtr, info, el, SRC_UNIT, PackUnit); // im2colDst: [lu, realDstCount, lp] + mRelatedFunctions.MNNGeneralIm2Col((float*)im2colDst, (float const**)srcPtr, info, el, SRC_UNIT, + PackUnit); // im2colDst: [lu, realDstCount, lp] } ptrInputScale = mBatchQuantInfo->host() + tId * mBatchQuantInfo->stride(0); if (dynamicOption == 2) { ptrInputBias = ptrInputScale + mBatchQuantInfo->stride(0) / 2; - BatchAsyDynamicQuant((uint8_t*)im2colDst, inputZeroPoint, ptrInputScale, kernelCountUnit, realDstCount, SRC_UNIT, 1, mQuantInput->host() + tId * mQuantInput->stride(0), ptrInputBias, tId); + BatchAsyDynamicQuant((uint8_t*)im2colDst, inputZeroPoint, ptrInputScale, kernelCountUnit, + realDstCount, SRC_UNIT, 1, + mQuantInput->host() + tId * mQuantInput->stride(0), ptrInputBias, tId); } else if (mUseBatchQuan) { - BatchSymDynamicQuant((uint8_t*)im2colDst, inputZeroPoint, ptrInputScale, kernelCountUnit, realDstCount, SRC_UNIT, 1, mQuantInput->host() + tId * mQuantInput->stride(0), tId); + BatchSymDynamicQuant((uint8_t*)im2colDst, inputZeroPoint, ptrInputScale, kernelCountUnit, + realDstCount, SRC_UNIT, 1, + mQuantInput->host() + tId * mQuantInput->stride(0), tId); } else { auto maxMinPtr = mTempMaxMinValueBuffer.ptr() + tId * 2 * gcore->bytes; ptrInputBias = ptrInputScale + mBatchQuantInfo->stride(0) / 2; - BatchAsyDynamicQuant((uint8_t*)im2colDst, inputZeroPoint, ptrInputScale, kernelCountUnit, realDstCount, SRC_UNIT, 1, mQuantInput->host() + tId * mQuantInput->stride(0), ptrInputBias, tId); + BatchAsyDynamicQuant((uint8_t*)im2colDst, inputZeroPoint, ptrInputScale, kernelCountUnit, + realDstCount, SRC_UNIT, 1, + mQuantInput->host() + tId * mQuantInput->stride(0), ptrInputBias, tId); quanParam.biasFloat = (float*)(mBiasBufferFusedInputzero.ptr() + tId * ocUpHp * QUANT_INFO_BYTES); } im2colDst = mQuantInput->host() + tId * mQuantInput->stride(0); @@ -2134,12 +2262,13 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu auto eU = UP_DIV(realDstCount, DST_XUNIT); // eU <= mIm2ColCount auto reorderBuffer = mReorderBuffer.ptr() + tId * colBufferSize; for (int k = 0; k < eU; ++k) { - int inside = blocklu * SRC_UNIT * ALIMIN(realDstCount - k * DST_XUNIT, DST_XUNIT); + int inside = blocklu * SRC_UNIT * ALIMIN(realDstCount - k * DST_XUNIT, DST_XUNIT); auto dstbuffer = reorderBuffer + k * unitColBufferSize; auto srcbuffer = im2colDst + k * unitColBufferSize; for (int i = 0; i < mBlockNum; ++i) { for (int j = 0; j < kxky; ++j) { - memcpy(dstbuffer + i * kxky * inside + j * inside, srcbuffer + i * inside + j * mBlockNum * inside, inside); + memcpy(dstbuffer + i * kxky * inside + j * inside, + srcbuffer + i * inside + j * mBlockNum * inside, inside); } } } @@ -2148,7 +2277,8 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu #endif if (mResourceInt8->mWeightAsymmetricQuant) { MNN_ASSERT(mBatchQuantInfo.get() && mBatchQuantInfo->host()); - mRelatedFunctions.MNNSumByAxisLForMatmul_A(xKernelSumPtrTid, im2colDst, (float*)ptrInputScale, realDstCount, sumParams); + mRelatedFunctions.MNNSumByAxisLForMatmul_A(xKernelSumPtrTid, im2colDst, (float*)ptrInputScale, + realDstCount, sumParams); } else { memset(xKernelSumPtrTid, 0, mBlockNum * DST_XUNIT * mIm2ColCount * QUANT_INFO_BYTES); } @@ -2162,20 +2292,23 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu quanParam.accumBuffer = accumbuff; } quanParam.srcKernelSum = ptrX; - mGemmKernel(outputInTilePtr, im2colDst, weightPtrTid, blockL, dstZStep * dstBytes, ocDivThread, &quanParam, step); + mGemmKernel(outputInTilePtr, im2colDst, weightPtrTid, blockL, dstZStep * dstBytes, ocDivThread, + &quanParam, step); ptrX += (step * mBlockNum); - realDstCount-=step; + realDstCount -= step; outputInTilePtr += DST_XUNIT * PackUnit * dstBytes; im2colDst += unitColBufferSize; - ptrInputScale = mUseBatchQuan ? (ptrInputScale + step * mInputBlockNum * QUANT_INFO_BYTES) : ptrInputScale; - ptrInputBias = (ptrInputBias != nullptr) ? (ptrInputBias + step * mInputBlockNum * QUANT_INFO_BYTES) : ptrInputBias; - } while(realDstCount > 0); + ptrInputScale = + mUseBatchQuan ? (ptrInputScale + step * mInputBlockNum * QUANT_INFO_BYTES) : ptrInputScale; + ptrInputBias = (ptrInputBias != nullptr) ? (ptrInputBias + step * mInputBlockNum * QUANT_INFO_BYTES) + : ptrInputBias; + } while (realDstCount > 0); } }; auto ocSplitFunction = [&](int threads) { // Thread split by OC - auto im2colDst = mTempIm2ColBuffer->host(); - auto srcPtr = (int8_t const **)(mBlitInfo.ptr()); - auto el = (int32_t *)(srcPtr + mBlitInfoStride.second); + auto im2colDst = mTempIm2ColBuffer->host(); + auto srcPtr = (int8_t const**)(mBlitInfo.ptr()); + auto el = (int32_t*)(srcPtr + mBlitInfoStride.second); auto xKernelSumPtr = reinterpret_cast(mTempSrcSum.ptr()); auto eU = UP_DIV(plane, DST_XUNIT); @@ -2186,7 +2319,8 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu float* reluPtr = mResourceInt8->mReluThreshold.data(); if (mIm2ColBasedInt8) { // im2col - auto res = ConvolutionTiledExecutor::turnIm2ColToBlitInfo((const float**)srcPtr, el, 0, plane, mIm2ColParamter, (uint8_t*)im2colSrc, im2colBytes); + auto res = ConvolutionTiledExecutor::turnIm2ColToBlitInfo( + (const float**)srcPtr, el, 0, plane, mIm2ColParamter, (uint8_t*)im2colSrc, im2colBytes); int number = res.first; bool needZero = res.second; if (needZero) { @@ -2214,7 +2348,8 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu while (realDstCount > 0) { int work = std::min(realDstCount, DST_XUNIT); sizePacked += (work * SRC_UNIT * kernelCountUnit); - auto res = ConvolutionTiledExecutor::turnIm2ColToBlitInfo((const float**)srcPtr, el, start, work, mIm2ColParamter, (uint8_t*)im2colSrc, im2colBytes); + auto res = ConvolutionTiledExecutor::turnIm2ColToBlitInfo( + (const float**)srcPtr, el, start, work, mIm2ColParamter, (uint8_t*)im2colSrc, im2colBytes); int number = res.first; bool needZero = res.second; if (needZero) { @@ -2223,14 +2358,17 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu info[0] = number; info[2] = work; if (number > 0) { // im2col - mRelatedFunctions.MNNGeneralIm2Col((float*)im2colDstTmp, (float const**)srcPtr, info, el, SRC_UNIT, PackUnit); // im2colDst: [lu, realDstCount, lp] + mRelatedFunctions.MNNGeneralIm2Col((float*)im2colDstTmp, (float const**)srcPtr, info, el, SRC_UNIT, + PackUnit); // im2colDst: [lu, realDstCount, lp] } if (mUseBatchQuan || dynamicOption == 2) { if (dynamicOption == 2) { - BatchAsyDynamicQuant((uint8_t*)im2colDstTmp, inputZeroPoint, ptrInputscale, kernelCountUnit, work, SRC_UNIT, 1, int8Ptr, ptrInputbias, 0); + BatchAsyDynamicQuant((uint8_t*)im2colDstTmp, inputZeroPoint, ptrInputscale, kernelCountUnit, + work, SRC_UNIT, 1, int8Ptr, ptrInputbias, 0); ptrInputbias += (mInputBlockNum * work * sizeof(int32_t)); } else { - BatchSymDynamicQuant((uint8_t*)im2colDstTmp, inputZeroPoint, ptrInputscale, kernelCountUnit, work, SRC_UNIT, 1, int8Ptr, 0); + BatchSymDynamicQuant((uint8_t*)im2colDstTmp, inputZeroPoint, ptrInputscale, kernelCountUnit, + work, SRC_UNIT, 1, int8Ptr, 0); } ptrInputscale += (mInputBlockNum * work * sizeof(int32_t)); int8Ptr += unitColBufferSize; @@ -2240,18 +2378,21 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu im2colDstTmp += (unitColBufferSize * gcore->bytes); } if (!mUseBatchQuan && dynamicOption != 2) { - BatchAsyDynamicQuant((uint8_t*)im2colDst, inputZeroPoint, ptrInputscale, kernelCountUnit, plane, SRC_UNIT, 1, mQuantInput->host(), ptrInputscale + plane * mInputBlockNum* QUANT_INFO_BYTES, 0); + BatchAsyDynamicQuant((uint8_t*)im2colDst, inputZeroPoint, ptrInputscale, kernelCountUnit, plane, + SRC_UNIT, 1, mQuantInput->host(), + ptrInputscale + plane * mInputBlockNum * QUANT_INFO_BYTES, 0); } im2colDst = mQuantInput->host(); } if (mBlockNum > 1 && kxky > 1) { for (int k = 0; k < eU; ++k) { - int inside = blocklu * SRC_UNIT * ALIMIN(DST_XUNIT, plane - k * DST_XUNIT); + int inside = blocklu * SRC_UNIT * ALIMIN(DST_XUNIT, plane - k * DST_XUNIT); auto dstbuffer = mReorderBuffer.ptr() + k * unitColBufferSize; auto srcbuffer = im2colDst + k * unitColBufferSize; for (int i = 0; i < mBlockNum; ++i) { for (int j = 0; j < kxky; ++j) { - memcpy(dstbuffer + i * kxky * inside + j * inside, srcbuffer + i * inside + j * mBlockNum * inside, inside); + memcpy(dstbuffer + i * kxky * inside + j * inside, + srcbuffer + i * inside + j * mBlockNum * inside, inside); } } } @@ -2260,7 +2401,8 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu #endif if (mResourceInt8->mWeightAsymmetricQuant) { MNN_ASSERT(mBatchQuantInfo.get() && mBatchQuantInfo->host()); - mRelatedFunctions.MNNSumByAxisLForMatmul_A(xKernelSumPtr, im2colDst, mBatchQuantInfo->host(), plane, sumParams); + mRelatedFunctions.MNNSumByAxisLForMatmul_A(xKernelSumPtr, im2colDst, mBatchQuantInfo->host(), plane, + sumParams); } else { memset(xKernelSumPtr, 0, mTileCount * mBlockNum * DST_XUNIT * mIm2ColCount * QUANT_INFO_BYTES); } @@ -2279,7 +2421,9 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu auto im2colDstThread = im2colDst; float* ptrY = nullptr; if (dstBytes != 1) { - float* wkernelSum = (mOnlineReorderWeightSme && mInputBlockNum > 1 && plane > 1) ? (float*)mWeightKernelSum4Prefill.ptr() : mResourceInt8->mWeightKernelSum->host(); + float* wkernelSum = (mOnlineReorderWeightSme && mInputBlockNum > 1 && plane > 1) + ? (float*)mWeightKernelSum4Prefill.ptr() + : mResourceInt8->mWeightKernelSum->host(); ptrY = wkernelSum + ocIndex * mInputBlockNum; } QuanPostTreatParameters quanParam; @@ -2316,7 +2460,8 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu inputScale = (uint8_t*)fakeInputScales.data(); } if (mBlockNum > 1) { - accumbuff = reinterpret_cast(mAccumBuffer->host() + tId * mAccumBuffer->stride(0) * sizeof(int32_t)); + accumbuff = reinterpret_cast(mAccumBuffer->host() + + tId * mAccumBuffer->stride(0) * sizeof(int32_t)); } auto outputInTilePtr = outputDataPtr + ocIndex * plane * dstBytes; @@ -2326,7 +2471,9 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu weightSrc = mResourceInt8->mWeightInt8->host(); } - auto weightPtrTid = weightSrc + static_cast(ocIndex * mBlockNum * blockL * SRC_UNIT * weightBytes + ocIndex * 2 * mBlockNum * QUANT_INFO_BYTES); + auto weightPtrTid = + weightSrc + static_cast(ocIndex * mBlockNum * blockL * SRC_UNIT * weightBytes + + ocIndex * 2 * mBlockNum * QUANT_INFO_BYTES); int realDstCount = plane; auto ptrX = xKernelSumPtr; @@ -2339,14 +2486,16 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu memset(accumbuff, 0, UNIT * 4 * DST_XUNIT); quanParam.accumBuffer = accumbuff; } - gemmInt8(outputInTilePtr, im2colDstThread, weightPtrTid, blockL, dstZStep * dstBytes, ocDivThread, &quanParam, step); + gemmInt8(outputInTilePtr, im2colDstThread, weightPtrTid, blockL, dstZStep * dstBytes, ocDivThread, + &quanParam, step); ptrX += (step * mBlockNum); - realDstCount-=step; + realDstCount -= step; outputInTilePtr += DST_XUNIT * PackUnit * dstBytes; im2colDstThread += unitColBufferSize; inputScale = mUseBatchQuan ? (inputScale + mInputBlockNum * step * QUANT_INFO_BYTES) : inputScale; - inputBias = (inputBias != nullptr) ? (inputBias + mInputBlockNum * step * QUANT_INFO_BYTES) : inputBias; - } while(realDstCount > 0); + inputBias = + (inputBias != nullptr) ? (inputBias + mInputBlockNum * step * QUANT_INFO_BYTES) : inputBias; + } while (realDstCount > 0); } } MNN_CONCURRENCY_END(); @@ -2369,7 +2518,8 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu outputQuantScale[i] = s; } float zero_ = TensorUtils::getQuantInfo(outputs[0])[1]; - mQuantFunc((float*)mTempOutput.ptr(), output->host(), plane * ocDiv4, outputQuantScale.data(), mResourceInt8->mClampMin, mResourceInt8->mClampMax, &zero_, 0); + mQuantFunc((float*)mTempOutput.ptr(), output->host(), plane * ocDiv4, outputQuantScale.data(), + mResourceInt8->mClampMin, mResourceInt8->mClampMax, &zero_, 0); } return NO_ERROR; } diff --git a/source/backend/cpu/compute/SharedGather.cpp b/source/backend/cpu/compute/SharedGather.cpp new file mode 100644 index 0000000000..4de178b89c --- /dev/null +++ b/source/backend/cpu/compute/SharedGather.cpp @@ -0,0 +1,157 @@ +#include "SharedGather.hpp" +#include "CommonOptFunction.h" +#include "../CPUBackend.hpp" +#include "core/BufferAllocator.hpp" +#include "core/Macro.h" + +namespace MNN { + +SharedGather::SharedGather(Backend* backend, std::shared_ptr res) : Execution(backend) { + mResource = res; +} + +SharedGather::~SharedGather() { + // Do nothing. +} + +ErrorCode SharedGather::onResize(const std::vector& inputs, const std::vector& outputs) { + auto bytes = static_cast(backend())->functions()->bytes; + auto output = outputs[0]; + int ic = output->length(output->dimensions() - 1); + if (bytes != 4) { + mCacheBuffer = static_cast(backend())->getBufferAllocator()->alloc(ic * sizeof(float)); + static_cast(backend())->getBufferAllocator()->free(mCacheBuffer); + } + return NO_ERROR; +} + +ErrorCode SharedGather::onExecute(const std::vector& inputs, const std::vector& outputs) { + auto input = inputs[0]; + auto output = outputs[0]; + int outside = input->elementSize(); + int ic = output->length(output->dimensions() - 1); + MNN_ASSERT(ic % mResource->mBlockNum == 0); + int block = ic / mResource->mBlockNum; + MNN_ASSERT(4 == mResource->mWeightBits || 8 == mResource->mWeightBits); + auto outputPtr = output->host(); + auto indice = input->host(); + auto perHpQuantSize = mResource->mBlockNum * 2 * sizeof(float) * mResource->mHp; + auto perHpWeightSize = UP_DIV(ic, mResource->mLp) * mResource->mHp * mResource->mLp * mResource->mWeightBits / 8; + auto perBlockWeightSize = block * mResource->mHp * mResource->mWeightBits / 8; + auto perBlockQuantSize = 2 * mResource->mHp * sizeof(float); + auto func = static_cast(backend())->functions(); + auto bytes = func->bytes; + + MNN_ASSERT(mResource->mLp % 2 == 0); + int lpStep = mResource->mWeightBits == 4 ? mResource->mLp / 2 : mResource->mLp; + int blockUnit = block / mResource->mLp; + int permuteUnit = mResource->mLp * mResource->mHp; + int halfPermuteStride = static_cast(permuteUnit / 2); + if (8 == mResource->mWeightBits) { + for (int z = 0; z < outside; ++z) { + auto index = indice[z]; + int zO = index / mResource->mHp; + int zI = index % mResource->mHp; + auto srcZ = mResource->mWeightInt8->host() + zO * (perHpQuantSize + perHpWeightSize); + auto dstZInt8 = outputPtr + z * ic * bytes; + float* dstZ = reinterpret_cast(dstZInt8); + if (bytes == 2) { + dstZ = reinterpret_cast(mCacheBuffer.ptr()); + } + for (int i = 0; i < mResource->mBlockNum; ++i) { + auto quantPtr = reinterpret_cast(srcZ + i * (perBlockQuantSize + perBlockWeightSize) + + perBlockWeightSize); + float scale = quantPtr[zI]; + float bias = quantPtr[zI + mResource->mHp]; + auto dstB = dstZ + i * block; + auto srcB = srcZ + i * (perBlockQuantSize + perBlockWeightSize) + zI * lpStep; + for (int j = 0; j < blockUnit; ++j) { + for (int k = 0; k < lpStep; ++k) { + dstB[j * lpStep + k] = srcB[j * lpStep * mResource->mHp + k] * scale + bias; + } + } + } + if (bytes == 2) { + func->MNNFp32ToLowp(dstZ, reinterpret_cast(dstZInt8), ic); + } + } + return NO_ERROR; + } + if (mResource->mPackMode == 0) { + for (int z = 0; z < outside; ++z) { + auto index = indice[z]; + int zO = index / mResource->mHp; + int zI = index % mResource->mHp; + int zI0 = zI / (mResource->mHp / 2); + int zI1 = zI % (mResource->mHp / 2); + int step = (1 - zI0) * 4; + auto srcZ = mResource->mWeightInt8->host() + zO * (perHpQuantSize + perHpWeightSize); + auto dstZInt8 = outputPtr + z * ic * bytes; + float* dstZ = reinterpret_cast(dstZInt8); + if (bytes == 2) { + dstZ = reinterpret_cast(mCacheBuffer.ptr()); + } + for (int i = 0; i < mResource->mBlockNum; ++i) { + auto quantPtr = reinterpret_cast(srcZ + i * (perBlockQuantSize + perBlockWeightSize) + + perBlockWeightSize); + float scale = quantPtr[zI]; + float bias = quantPtr[zI + mResource->mHp]; + auto dstB = dstZ + i * block; + auto srcB = srcZ + i * (perBlockQuantSize + perBlockWeightSize) + zI1 * mResource->mLp; + for (int j = 0; j < blockUnit; ++j) { + for (int k = 0; k < mResource->mLp; ++k) { + uint8_t w = *reinterpret_cast(srcB + j * halfPermuteStride + k); + auto w1 = (w >> step) % 16; + dstB[j * mResource->mLp + k] = w1 * scale + bias; + } + } + } + if (bytes == 2) { + func->MNNFp32ToLowp(dstZ, reinterpret_cast(dstZInt8), ic); + } + } + return NO_ERROR; + } + for (int z = 0; z < outside; ++z) { + auto index = indice[z]; + int zO = index / mResource->mHp; + int zI = index % mResource->mHp; + auto srcZ = mResource->mWeightInt8->host() + zO * (perHpQuantSize + perHpWeightSize); + auto dstZInt8 = outputPtr + z * ic * bytes; + float* dstZ = reinterpret_cast(dstZInt8); + if (bytes == 2) { + dstZ = reinterpret_cast(mCacheBuffer.ptr()); + } + for (int i = 0; i < mResource->mBlockNum; ++i) { + auto quantPtr = reinterpret_cast(srcZ + i * (perBlockQuantSize + perBlockWeightSize) + + perBlockWeightSize); + float scale = quantPtr[zI]; + float bias = quantPtr[zI + mResource->mHp]; + auto dstB = dstZ + i * block; + auto srcB = srcZ + i * (perBlockQuantSize + perBlockWeightSize) + zI * lpStep; + for (int j = 0; j < blockUnit; ++j) { + for (int k = 0; k < lpStep; ++k) { + uint8_t w = *reinterpret_cast(srcB + j * lpStep * mResource->mHp + k); + auto w0 = w % 16; + auto w1 = w / 16; + dstB[2 * (j * lpStep + k) + 0] = w0 * scale + bias; + dstB[2 * (j * lpStep + k) + 1] = w1 * scale + bias; + } + } + } + if (bytes == 2) { + func->MNNFp32ToLowp(dstZ, reinterpret_cast(dstZInt8), ic); + } + } + return NO_ERROR; +} + +bool SharedGather::onClone(Backend* bn, const Op* op, Execution** dst) { + if (nullptr == dst) { + return true; + } + *dst = new SharedGather(bn, mResource); + return true; +} + +} // namespace MNN diff --git a/source/backend/cpu/compute/SharedGather.hpp b/source/backend/cpu/compute/SharedGather.hpp new file mode 100644 index 0000000000..65bbb2aea3 --- /dev/null +++ b/source/backend/cpu/compute/SharedGather.hpp @@ -0,0 +1,22 @@ +#ifndef SharedGather_hpp +#define SharedGather_hpp + +#include "backend/cpu/CPUConvolution.hpp" +#include "core/Execution.hpp" + +namespace MNN { +class SharedGather : public Execution { +public: + SharedGather(Backend* backend, std::shared_ptr res); + virtual ~SharedGather(); + virtual ErrorCode onResize(const std::vector& inputs, const std::vector& outputs) override; + virtual ErrorCode onExecute(const std::vector& inputs, const std::vector& outputs) override; + virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override; + +private: + std::shared_ptr mResource; + MemChunk mCacheBuffer; +}; +} // namespace MNN + +#endif diff --git a/source/backend/cpu/x86_x64/AVX2Functions.cpp b/source/backend/cpu/x86_x64/AVX2Functions.cpp index 3773105760..6f5c99c0fe 100644 --- a/source/backend/cpu/x86_x64/AVX2Functions.cpp +++ b/source/backend/cpu/x86_x64/AVX2Functions.cpp @@ -12,16 +12,62 @@ #include "avxfma/FunctionSummary.hpp" #include "avx512/FunctionSummary.hpp" #include "sse/FunctionSummary.hpp" +#include namespace MNN { static int geP, glP, ghP; static CoreFunctions* gAVX2CoreFunctions = nullptr; static CoreInt8Functions* gAVX2CoreInt8Functions = nullptr; -static void _MNNGetMatMulPackMode(int* eP, int *lP, int* hP) { +static void _MNNGetMatMulPackMode(int* eP, int* lP, int* hP) { *eP = geP; *lP = glP; *hP = ghP; } +template +static void _MNNNormPacked_Float(float* dest, const float* source, const float* gamma, const float* beta, float epsilon, + size_t batch, size_t channels, bool RMSNorm) { + const size_t channelUnit = UP_DIV(channels, Pack); + for (size_t n = 0; n < batch; ++n) { + float mean = 0.0f; + if (!RMSNorm) { + float sum = 0.0f; + for (size_t c = 0; c < channels; ++c) { + const size_t cu = c / Pack; + const size_t cr = c - cu * Pack; + sum += source[(cu * batch + n) * Pack + cr]; + } + mean = sum / static_cast(channels); + } + + float squareSum = 0.0f; + for (size_t c = 0; c < channels; ++c) { + const size_t cu = c / Pack; + const size_t cr = c - cu * Pack; + float v = source[(cu * batch + n) * Pack + cr]; + float d = RMSNorm ? v : (v - mean); + squareSum += d * d; + } + + const float invStd = 1.0f / std::sqrt(squareSum / static_cast(channels) + epsilon); + for (size_t c = 0; c < channels; ++c) { + const size_t cu = c / Pack; + const size_t cr = c - cu * Pack; + const size_t index = (cu * batch + n) * Pack + cr; + float v = source[index]; + float norm = RMSNorm ? (v * invStd) : ((v - mean) * invStd); + if (gamma && beta) { + norm = norm * gamma[c] + beta[c]; + } + dest[index] = norm; + } + for (size_t c = channels; c < channelUnit * Pack; ++c) { + const size_t cu = c / Pack; + const size_t cr = c - cu * Pack; + dest[(cu * batch + n) * Pack + cr] = 0.0f; + } + } +} + #ifndef MNN_USE_AVX bool AVX2Functions::init(int cpuFlags) { return false; @@ -43,7 +89,7 @@ bool AVX2Functions::init(int cpuFlags) { ghP = 4; _AVX_ReorderInit(coreFunction); - coreFunction->MNNPackedMatMul = _AVX_MNNPackedMatMul; + coreFunction->MNNPackedMatMul = _AVX_MNNPackedMatMul; coreFunction->MNNPackedMatMulRemain = _AVX_MNNPackedMatMulRemain; #ifdef MNN_LOW_MEMORY @@ -52,8 +98,8 @@ bool AVX2Functions::init(int cpuFlags) { coreFunction->MNNAsyQuantFunc = _AVX_MNNAsyQuantFunc; coreFunction->MNNAsyQuantInfo = _AVX_MNNAsyQuantInfo; #endif - coreFunction->MNNPackC4ForMatMul_A = _AVX_MNNPackC4ForMatMul_A; - coreFunction->MNNPackForMatMul_B = _AVX_MNNPackForMatMul_B; + coreFunction->MNNPackC4ForMatMul_A = _AVX_MNNPackC4ForMatMul_A; + coreFunction->MNNPackForMatMul_B = _AVX_MNNPackForMatMul_B; coreFunction->MNNComputeMatMulForE_1 = _AVX_MNNComputeMatMulForE_1; coreFunction->MNNComputeMatMulForH_1 = _AVX_MNNComputeMatMulForH_1; // Dynamic Quant @@ -62,31 +108,29 @@ bool AVX2Functions::init(int cpuFlags) { // For Packed Functions coreFunction->pack = 8; + coreFunction->MNNNormPacked = _MNNNormPacked_Float<8>; _AVX_ExtraInit(coreFunction); // Winograd _AVX_WinogradInit(coreFunction); if (cpuFlags & libyuv::kCpuHasFMA3) { - coreFunction->MNNPackedMatMul = _AVX_MNNPackedMatMulFMA; + coreFunction->MNNPackedMatMul = _AVX_MNNPackedMatMulFMA; coreFunction->MNNPackedMatMulRemain = _AVX_MNNPackedMatMulRemainFMA; coreFunction->MNNComputeMatMulForE_1 = _AVX_MNNComputeMatMulForE_1FMA; coreFunction->MNNComputeMatMulForH_1 = _AVX_MNNComputeMatMulForH_1FMA; _AVX_ExtraInitFMA(coreFunction); } #ifdef MNN_AVX512 - if ((cpuFlags & libyuv::kCpuHasAVX512VNNI) - || (cpuFlags & libyuv::kCpuHasAVX512VL) - || (cpuFlags & libyuv::kCpuHasAVX512BW) - || (cpuFlags & libyuv::kCpuHasAVX512VBMI) - || (cpuFlags & libyuv::kCpuHasAVX512VBITALG) - || (cpuFlags & libyuv::kCpuHasAVX512VPOPCNTDQ) - || (cpuFlags & libyuv::kCpuHasAVX512VBMI2) - ) { + if ((cpuFlags & libyuv::kCpuHasAVX512VNNI) || (cpuFlags & libyuv::kCpuHasAVX512VL) || + (cpuFlags & libyuv::kCpuHasAVX512BW) || (cpuFlags & libyuv::kCpuHasAVX512VBMI) || + (cpuFlags & libyuv::kCpuHasAVX512VBITALG) || (cpuFlags & libyuv::kCpuHasAVX512VPOPCNTDQ) || + (cpuFlags & libyuv::kCpuHasAVX512VBMI2)) { coreFunction->pack = 16; + coreFunction->MNNNormPacked = _MNNNormPacked_Float<16>; _AVX512_ReorderInit(coreFunction); _AVX512_ExtraInit(coreFunction); _AVX512_WinogradInit(coreFunction); - coreFunction->MNNPackForMatMul_B = _AVX512_MNNPackForMatMul_B; - coreFunction->MNNPackC4ForMatMul_A = _AVX512_MNNPackC8ForMatMul_A; + coreFunction->MNNPackForMatMul_B = _AVX512_MNNPackForMatMul_B; + coreFunction->MNNPackC4ForMatMul_A = _AVX512_MNNPackC8ForMatMul_A; coreFunction->MNNPackedMatMul = _AVX512_MNNPackedMatMul; coreFunction->MNNPackedMatMulRemain = _AVX512_MNNPackedMatMulRemain; geP = 48; @@ -94,11 +138,11 @@ bool AVX2Functions::init(int cpuFlags) { glP = 1; _AVX512_MNNInt8FunctionInit(gAVX2CoreInt8Functions, cpuFlags & libyuv::kCpuHasAVX512VNNI); memcpy(coreFunction->MNNPackedMatMulOC16Functions, _AVX512_MNNPackedMatMulOC16Functions, - sizeof(MNN::CoreFunctions::MNNPackedMatMulKernel) * AVX512_INPUT_TILE_MAX); + sizeof(MNN::CoreFunctions::MNNPackedMatMulKernel) * AVX512_INPUT_TILE_MAX); memcpy(coreFunction->MNNPackedMatMulOC32Functions, _AVX512_MNNPackedMatMulOC32Functions, - sizeof(MNN::CoreFunctions::MNNPackedMatMulKernel) * AVX512_INPUT_TILE_MAX); + sizeof(MNN::CoreFunctions::MNNPackedMatMulKernel) * AVX512_INPUT_TILE_MAX); memcpy(coreFunction->MNNPackedMatMulOC48Functions, _AVX512_MNNPackedMatMulOC48Functions, - sizeof(MNN::CoreFunctions::MNNPackedMatMulKernel) * AVX512_INPUT_TILE_MAX); + sizeof(MNN::CoreFunctions::MNNPackedMatMulKernel) * AVX512_INPUT_TILE_MAX); } #endif { @@ -106,7 +150,8 @@ bool AVX2Functions::init(int cpuFlags) { coreFunction->int8MatmulRelatedFunctions.Int8GemmKernelFast = gAVX2CoreInt8Functions->Int8GemmKernelFast; coreFunction->int8MatmulRelatedFunctions.Int8GemmKernel_W4 = gAVX2CoreInt8Functions->Int8GemmKernel_W4; coreFunction->int8MatmulRelatedFunctions.MNNGetGemmUnit = gAVX2CoreInt8Functions->MNNGetGemmUnit; - coreFunction->int8MatmulRelatedFunctions.MNNPackC4Int8ForMatMul_A = gAVX2CoreInt8Functions->MNNPackC4Int8ForMatMul_A; + coreFunction->int8MatmulRelatedFunctions.MNNPackC4Int8ForMatMul_A = + gAVX2CoreInt8Functions->MNNPackC4Int8ForMatMul_A; coreFunction->int8MatmulRelatedFunctions.eP = 4; } return true; @@ -119,4 +164,4 @@ CoreFunctions* AVX2Functions::get() { CoreInt8Functions* AVX2Functions::getInt8() { return gAVX2CoreInt8Functions; } -}; +}; // namespace MNN diff --git a/source/backend/metal/ConvSimdGroupShader.hpp b/source/backend/metal/ConvSimdGroupShader.hpp index c10056fb50..dd12821653 100644 --- a/source/backend/metal/ConvSimdGroupShader.hpp +++ b/source/backend/metal/ConvSimdGroupShader.hpp @@ -8,7 +8,7 @@ #if MNN_METAL_ENABLED -const char* gBasicConvPrefix = R"metal( +static const char* gBasicConvPrefix = R"metal( #include #include @@ -205,7 +205,7 @@ typedef half4x4 FLOAT4x4; )metal"; -const char* gConv1x1WqSgMatrix = R"metal( +static const char* gConv1x1WqSgMatrix = R"metal( // W_QUANT_2/3 fall through to W_QUANT_4 macros for unimplemented gemm kernels so that // the source still compiles. Only conv1x1_gemv_g8_wquant_sg has true W_QUANT_2/3 paths. #if (defined(W_QUANT_2) || defined(W_QUANT_3)) && !defined(W_QUANT_4) && !defined(W_QUANT_8) @@ -1256,7 +1256,7 @@ kernel void conv1x1_gemm_32x64_wquant_sg(const device ftype2 *in [[bu } )metal"; -const char* gConv1x1WfpSgMatrix = R"metal( +static const char* gConv1x1WfpSgMatrix = R"metal( #ifdef USE_METAL_TENSOR_OPS #include #include @@ -1305,8 +1305,6 @@ kernel void conv1x1_w_dequant( FLOAT dequant_bias = FLOAT(((const device ftype *)dequantScale)[((idx_n4 * cst.block_size + bi) * 2 + 1) * 4 + idx_nl]) / (FLOAT)cst.scale_coef; #ifdef W_QUANT_3 - // W_QUANT_3 layout: 6 bytes per (4 OC, 4 IC) tile. - // Tile base byte = (idx_n4 * input_slice + idx_k4) * 6. auto wt_base = wi + (idx_n4 * cst.input_slice + idx_k4) * 6; #else auto xy_wi = wi + (idx_n4 * cst.input_slice + idx_k4) * 4 + idx_nl;// [N/4, K/4, N4, K4] @@ -1314,19 +1312,16 @@ kernel void conv1x1_w_dequant( auto xy_wf = wf + ((idx_n4 * ((cst.input_slice+3)/4) + idx_k16) * 4 + idx_nl) * 4;// [N/4, K/4, N4, K4] #ifdef W_QUANT_2 - // W_QUANT_2 layout: 1 byte per (OC, 4 IC); bits [7:6]=IC0..[1:0]=IC3, value=signed+2. - // xy_wi (uchar*) points at byte (idx_n4*K/4+idx_k4)*4 + idx_nl (OC=idx_n4*4+idx_nl). for(int k = 0; k < 4; k++) { #if W_ALIGN_K16_PROTECT if(idx_k4 + k >= cst.input_slice) { xy_wf[k] = ftype4(0); continue; } #endif - uchar b = xy_wi[4*k]; // byte for K4 group (idx_k4+k) + uchar b = xy_wi[4*k]; FLOAT4 w4 = FLOAT4((float)((b >> 6) & 3) - 2, (float)((b >> 4) & 3) - 2, (float)((b >> 2) & 3) - 2, (float)( b & 3) - 2); xy_wf[k] = (ftype4)(w4 * scale + dequant_bias); } #elif defined(W_QUANT_3) - // W_QUANT_3: 6 bytes/tile; OC=idx_nl, IC=k_inner extracted. for(int k = 0; k < 4; k++) { #if W_ALIGN_K16_PROTECT if(idx_k4 + k >= cst.input_slice) { xy_wf[k] = ftype4(0); continue; } @@ -2134,7 +2129,7 @@ kernel void conv1x1_gemm_32x16_sg(const device ftype4 *in [[buffer(0) )metal"; -const char* gConv1x1WfpSgReduce = R"metal( +static const char* gConv1x1WfpSgReduce = R"metal( kernel void conv1x1_z4_sg(const device ftype4 *in [[buffer(0)]], device ftype4 *out [[buffer(1)]], constant conv1x1_constants& cst [[buffer(2)]], @@ -2166,7 +2161,7 @@ kernel void conv1x1_z4_sg(const device ftype4 *in [[buffer(0)]], } )metal"; -const char* gConv1x1WqSgReduce = R"metal( +static const char* gConv1x1WqSgReduce = R"metal( // W_QUANT_2/3 fall through to W_QUANT_4 macros for unimplemented kernels. #if (defined(W_QUANT_2) || defined(W_QUANT_3)) && !defined(W_QUANT_4) && !defined(W_QUANT_8) @@ -2303,7 +2298,7 @@ kernel void conv1x1_gemv_g8_wquant_sg(const device ftype4 *in [[buffe int rx = gid.y; #ifdef W_QUANT_3 - auto xy_wt = wt + uz * cst.input_slice * 6; // 6 bytes per tile + auto xy_wt = wt + uz * cst.input_slice * 6; #else auto xy_wt = wt + uz * cst.input_slice; #endif @@ -2329,7 +2324,6 @@ kernel void conv1x1_gemv_g8_wquant_sg(const device ftype4 *in [[buffe FLOAT4 in40 = (FLOAT4)*(xy_in0 + z * area_size); #ifdef W_QUANT_2 - // 4 bytes / tile, byte i = OC i, bits [7:6]=IC0..[1:0]=IC3, signed=unsigned-2. uchar4 w_b = xy_wt[z]; FLOAT4x4 w_dequant; for (int i = 0; i < 4; ++i) { @@ -2339,8 +2333,6 @@ kernel void conv1x1_gemv_g8_wquant_sg(const device ftype4 *in [[buffe w_dequant[i] = w4 * scale[i] + dequant_bias[i]; } #elif defined(W_QUANT_3) - // 6 bytes / tile: bytes 0..3 low 2bit (OC i, 4 IC), bytes 4..5 high 1bit - // (byte 4 OC{0,1}, byte 5 OC{2,3}; upper nibble = OC even, lower = OC odd; bit (3-k)=IC k high). const device uchar* tilePtr = xy_wt + z * 6; uchar lo0 = tilePtr[0], lo1 = tilePtr[1], lo2 = tilePtr[2], lo3 = tilePtr[3]; uchar hi01 = tilePtr[4], hi23 = tilePtr[5]; @@ -2349,9 +2341,7 @@ kernel void conv1x1_gemv_g8_wquant_sg(const device ftype4 *in [[buffe for (int i = 0; i < 4; ++i) { uchar b = lo[i]; uchar h = (i < 2) ? hi01 : hi23; - // upper nibble for OC even (i%2==0), lower for OC odd (i%2==1). uchar hShifted = (i % 2 == 0) ? (h >> 4) : (h & 0xF); - // hShifted bit (3-k) = IC k high bit FLOAT4 w4 = FLOAT4( (float)( ((b >> 6) & 3) | (((hShifted >> 3) & 1) << 2) ) - 4, (float)( ((b >> 4) & 3) | (((hShifted >> 2) & 1) << 2) ) - 4, @@ -2501,4 +2491,3 @@ kernel void conv1x1_gemv_g16_wquant_sg(const device ftype4 *in [[buff #endif - diff --git a/source/backend/metal/LayerNormSimdGroupShader.hpp b/source/backend/metal/LayerNormSimdGroupShader.hpp index 28ad1d852b..ae8b7d7dcf 100644 --- a/source/backend/metal/LayerNormSimdGroupShader.hpp +++ b/source/backend/metal/LayerNormSimdGroupShader.hpp @@ -40,13 +40,13 @@ kernel void layernorm_in_all_sg(const device ftype *in [[buffer(0)]], float mean; float sum = 0.0f; float square_sum = 0.0f; - + for(int i = tiisg; i < cst.inside; i+=SIMD_GROUP_WIDTH) { sum += in_data[i]; } sum = simd_sum(sum); mean = sum / cst.inside; - + for(int i = tiisg; i < cst.inside; i+=SIMD_GROUP_WIDTH) { float dis = (in_data[i] - mean); square_sum += dis * dis; @@ -79,17 +79,17 @@ kernel void layernorm_in_all_rms_sg(const device ftype *in [[buffer(0)]], auto out_data = out + gid.y * cst.inside; float square_sum = 0.0f; - + for(int i = tiisg; i < cst.inside; i+=SIMD_GROUP_WIDTH) { float dis = in_data[i]; square_sum += dis * dis; } - + square_sum = simd_sum(square_sum); float var = 1.0 / sqrt(square_sum / cst.inside + cst.eps); for(int i = tiisg; i < cst.inside; i+=SIMD_GROUP_WIDTH) { - + float norm = var * ((float)in_data[i]); if(cst.has_gamma_beta) { out_data[i] = (ftype)(norm * gamma[i] + beta[i]); @@ -116,13 +116,13 @@ kernel void layernorm_x1_sg(const device ftype *in [[buffer(0)]], float mean; float sum = 0.0f; float square_sum = 0.0f; - + for(int i = tiisg; i < cst.inside; i+=SIMD_GROUP_WIDTH) { sum += in_data[i]; } sum = simd_sum(sum); mean = sum / cst.inside; - + for(int i = tiisg; i < cst.inside; i+=SIMD_GROUP_WIDTH) { float dis = (in_data[i] - mean); square_sum += dis * dis; @@ -131,7 +131,7 @@ kernel void layernorm_x1_sg(const device ftype *in [[buffer(0)]], if(tiisg == 0) { float var = 1.0 / sqrt(square_sum / cst.inside + cst.eps); - + float norm = var * ((float)in_data[gid.x] - mean); if(cst.has_gamma_beta) { out_data[gid.x] = (ftype)(norm * gamma[gid.x] + beta[gid.x]); @@ -158,7 +158,7 @@ kernel void layernorm_x4_sg(const device ftype4 *in [[buffer(0)]], float mean; float sum = 0.0f; float square_sum = 0.0f; - + for(int i = tiisg; i < cst.inside/4; i+=SIMD_GROUP_WIDTH) { sum += in_data[i].x; sum += in_data[i].y; @@ -167,7 +167,7 @@ kernel void layernorm_x4_sg(const device ftype4 *in [[buffer(0)]], } sum = simd_sum(sum); mean = sum / cst.inside; - + for(int i = tiisg; i < cst.inside/4; i+=SIMD_GROUP_WIDTH) { float dis = (in_data[i].x - mean); square_sum += dis * dis; @@ -182,7 +182,7 @@ kernel void layernorm_x4_sg(const device ftype4 *in [[buffer(0)]], if(tiisg == 0) { float var = 1.0 / sqrt(square_sum / cst.inside + cst.eps); - + float4 norm = var * ((float4)in_data[gid.x] - mean); if(cst.has_gamma_beta) { out_data[gid.x] = (ftype4)(norm * gamma[gid.x] + beta[gid.x]); @@ -208,17 +208,17 @@ kernel void layernorm_x1_rms_sg(const device ftype *in [[buffer(0)]], auto out_data = out + gid.y * cst.inside; float square_sum = 0.0f; - + for(int i = tiisg; i < cst.inside; i+=SIMD_GROUP_WIDTH) { float dis = in_data[i]; square_sum += dis * dis; } - + square_sum = simd_sum(square_sum); - + if(tiisg == 0) { float var = 1.0 / sqrt(square_sum / cst.inside + cst.eps); - + float norm = var * ((float)in_data[gid.x]); if(cst.has_gamma_beta) { out_data[gid.x] = (ftype)(norm * gamma[gid.x] + beta[gid.x]); @@ -254,10 +254,10 @@ kernel void layernorm_x4_rms_sg(const device ftype4 *in [[buffer(0)]], } square_sum_all += (square_sum[0] + square_sum[1] + square_sum[2] + square_sum[3]); square_sum_all = simd_sum(square_sum_all); - + if(tiisg == 0) { float var = 1.0 / sqrt(square_sum_all / cst.inside + cst.eps); - + float4 norm = var * ((float4)in_data[in_idx]); if(cst.has_gamma_beta) { out_data[in_idx] = (ftype4)(norm * gamma[in_idx] + beta[in_idx]); @@ -267,6 +267,118 @@ kernel void layernorm_x4_rms_sg(const device ftype4 *in [[buffer(0)]], } } +kernel void binary_layernorm_x4_sg(const device ftype4 *in0 [[buffer(0)]], + const device ftype4 *in1 [[buffer(1)]], + device ftype4 *out0 [[buffer(2)]], + device ftype4 *out1 [[buffer(3)]], + constant layernorm_constants& cst [[buffer(4)]], + const device float4 *gamma [[buffer(5)]], + const device float4 *beta [[buffer(6)]], + uint3 gid [[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + if ((int)gid.y >= cst.outside) { + return; + } + int channelUnit = cst.inside / 4; + auto in0_data = in0 + gid.y * channelUnit; + auto in1_data = in1 + gid.y * channelUnit; + auto out0_data = out0 + gid.y * channelUnit; + auto out1_data = out1 + gid.y * channelUnit; + + float4 sum4 = 0.0f; + for(int c = sgitg * SIMD_GROUP_WIDTH + tiisg; c < channelUnit; c += 64) { + sum4 += float4(in0_data[c]) + float4(in1_data[c]); + } + sum4 = simd_sum(sum4); + + threadgroup float4 sg_sum[2]; + if(tiisg == 0) { + sg_sum[sgitg] = sum4; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + float4 total_sum4 = sg_sum[0] + sg_sum[1]; + float mean = (total_sum4.x + total_sum4.y + total_sum4.z + total_sum4.w) / cst.inside; + float4 mean4 = mean; + + float4 square_sum4 = 0.0f; + for(int c = sgitg * SIMD_GROUP_WIDTH + tiisg; c < channelUnit; c += 64) { + float4 data = float4(in0_data[c]) + float4(in1_data[c]); + float4 diff = data - mean4; + square_sum4 += diff * diff; + } + square_sum4 = simd_sum(square_sum4); + + threadgroup float4 sg_square_sum[2]; + if(tiisg == 0) { + sg_square_sum[sgitg] = square_sum4; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + float4 total_square_sum4 = sg_square_sum[0] + sg_square_sum[1]; + float square_sum = total_square_sum4.x + total_square_sum4.y + total_square_sum4.z + total_square_sum4.w; + float var = 1.0f / sqrt(square_sum / cst.inside + cst.eps); + float4 var4 = var; + + for(int c = sgitg * SIMD_GROUP_WIDTH + tiisg; c < channelUnit; c += 64) { + float4 data = float4(in0_data[c]) + float4(in1_data[c]); + out0_data[c] = (ftype4)data; + float4 norm = var4 * (data - mean4); + if(cst.has_gamma_beta) { + out1_data[c] = (ftype4)(norm * gamma[c] + beta[c]); + } else { + out1_data[c] = (ftype4)norm; + } + } +} + +kernel void binary_layernorm_x4_rms_sg(const device ftype4 *in0 [[buffer(0)]], + const device ftype4 *in1 [[buffer(1)]], + device ftype4 *out0 [[buffer(2)]], + device ftype4 *out1 [[buffer(3)]], + constant layernorm_constants& cst [[buffer(4)]], + const device float4 *gamma [[buffer(5)]], + const device float4 *beta [[buffer(6)]], + uint3 gid [[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + if ((int)gid.y >= cst.outside) { + return; + } + int channelUnit = cst.inside / 4; + auto in0_data = in0 + gid.y * channelUnit; + auto in1_data = in1 + gid.y * channelUnit; + auto out0_data = out0 + gid.y * channelUnit; + auto out1_data = out1 + gid.y * channelUnit; + + float4 square_sum4 = 0.0f; + for(int c = sgitg * SIMD_GROUP_WIDTH + tiisg; c < channelUnit; c += 64) { + float4 data = float4(in0_data[c]) + float4(in1_data[c]); + square_sum4 += data * data; + } + square_sum4 = simd_sum(square_sum4); + + threadgroup float4 sg_square_sum[2]; + if(tiisg == 0) { + sg_square_sum[sgitg] = square_sum4; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + float4 total_square_sum4 = sg_square_sum[0] + sg_square_sum[1]; + float square_sum = total_square_sum4.x + total_square_sum4.y + total_square_sum4.z + total_square_sum4.w; + float var = 1.0f / sqrt(square_sum / cst.inside + cst.eps); + float4 var4 = var; + + for(int c = sgitg * SIMD_GROUP_WIDTH + tiisg; c < channelUnit; c += 64) { + float4 data = float4(in0_data[c]) + float4(in1_data[c]); + out0_data[c] = (ftype4)data; + float4 norm = var4 * data; + if(cst.has_gamma_beta) { + out1_data[c] = (ftype4)(norm * gamma[c] + beta[c]); + } else { + out1_data[c] = (ftype4)norm; + } + } +} + kernel void layernorm_x16_rms_sg(const device ftype4 *in [[buffer(0)]], device ftype4 *out [[buffer(1)]], constant layernorm_constants& cst [[buffer(2)]], @@ -325,7 +437,221 @@ kernel void layernorm_x16_rms_sg(const device ftype4 *in [[buffer(0)]], } } +kernel void layernorm_c4_sg(const device ftype4 *in [[buffer(0)]], + device ftype4 *out [[buffer(1)]], + constant layernorm_constants& cst [[buffer(2)]], + const device float4 *gamma [[buffer(3)]], + const device float4 *beta [[buffer(4)]], + uint3 gid [[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + int batch = cst.outside; + int channelUnit = cst.inside / 4; + + if ((int)gid.y >= batch) { + return; + } + + float mean1 = 0.0f; + float4 sum4 = 0.0f; + + for(int c = tiisg; c < channelUnit; c += SIMD_GROUP_WIDTH) { + int idx = c * batch + gid.y; + sum4 += float4(in[idx]); + } + + sum4 = simd_sum(sum4); + float sum = sum4[0] + sum4[1] + sum4[2] + sum4[3]; + mean1 = sum / (channelUnit * 4); + float4 mean4 = mean1; + + float4 square_sum4 = 0.0f; + for(int c = tiisg; c < channelUnit; c += SIMD_GROUP_WIDTH) { + int idx = c * batch + gid.y; + float4 diff = float4(in[idx]) - mean4; + square_sum4 += diff * diff; + } + + square_sum4 = simd_sum(square_sum4); + float square_sum = square_sum4[0] + square_sum4[1] + square_sum4[2] + square_sum4[3]; + float var = 1.0f / sqrt(square_sum / (channelUnit * 4) + cst.eps); + float4 var4 = var; + + for(int c = tiisg; c < channelUnit; c += SIMD_GROUP_WIDTH) { + int idx = c * batch + gid.y; + float4 norm = var4 * (float4(in[idx]) - mean4); + if(cst.has_gamma_beta) { + out[idx] = (ftype4)(norm * gamma[c] + beta[c]); + } else { + out[idx] = (ftype4)(norm); + } + } +} + +kernel void binary_layernorm_c4_sg(const device ftype4 *in0 [[buffer(0)]], + const device ftype4 *in1 [[buffer(1)]], + device ftype4 *out0 [[buffer(2)]], + device ftype4 *out1 [[buffer(3)]], + constant layernorm_constants& cst [[buffer(4)]], + const device float4 *gamma [[buffer(5)]], + const device float4 *beta [[buffer(6)]], + uint3 gid [[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + int batch = cst.outside; + int channelUnit = cst.inside / 4; + + if ((int)gid.y >= batch) { + return; + } + + float mean1 = 0.0f; + float4 sum4 = 0.0f; + + for(int c = sgitg * 32 + tiisg; c < channelUnit; c += 64) { + int idx = c * batch + gid.y; + float4 data = float4(in0[idx]) + float4(in1[idx]); + sum4 += data; + } + + sum4 = simd_sum(sum4); + + // cross simd group communication for threadgroup size 64 + threadgroup float4 sg_sum[2]; + if(tiisg == 0) { + sg_sum[sgitg] = sum4; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + float4 total_sum4 = sg_sum[0] + sg_sum[1]; + + float sum = total_sum4[0] + total_sum4[1] + total_sum4[2] + total_sum4[3]; + mean1 = sum / (channelUnit * 4); + float4 mean4 = mean1; + + float4 square_sum4 = 0.0f; + for(int c = sgitg * 32 + tiisg; c < channelUnit; c += 64) { + int idx = c * batch + gid.y; + float4 data = float4(in0[idx]) + float4(in1[idx]); + float4 diff = data - mean4; + square_sum4 += diff * diff; + } + + square_sum4 = simd_sum(square_sum4); + + threadgroup float4 sg_square_sum[2]; + if(tiisg == 0) { + sg_square_sum[sgitg] = square_sum4; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + float4 total_square_sum4 = sg_square_sum[0] + sg_square_sum[1]; + + float square_sum = total_square_sum4[0] + total_square_sum4[1] + total_square_sum4[2] + total_square_sum4[3]; + float var = 1.0f / sqrt(square_sum / (channelUnit * 4) + cst.eps); + float4 var4 = var; + + for(int c = sgitg * 32 + tiisg; c < channelUnit; c += 64) { + int idx = c * batch + gid.y; + float4 my_data = float4(in0[idx]) + float4(in1[idx]); + out0[idx] = (ftype4)my_data; + float4 norm = var4 * (my_data - mean4); + if(cst.has_gamma_beta) { + out1[idx] = (ftype4)(norm * gamma[c] + beta[c]); + } else { + out1[idx] = (ftype4)(norm); + } + } +} + +kernel void layernorm_c4_rms_sg(const device ftype4 *in [[buffer(0)]], + device ftype4 *out [[buffer(1)]], + constant layernorm_constants& cst [[buffer(2)]], + const device float4 *gamma [[buffer(3)]], + const device float4 *beta [[buffer(4)]], + uint3 gid [[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + int batch = cst.outside; + int channelUnit = cst.inside / 4; + + if ((int)gid.y >= batch) { + return; + } + + float4 square_sum4 = 0.0f; + + for(int c = tiisg; c < channelUnit; c += SIMD_GROUP_WIDTH) { + int idx = c * batch + gid.y; + float4 data = float4(in[idx]); + square_sum4 += data * data; + } + + square_sum4 = simd_sum(square_sum4); + float square_sum = square_sum4[0] + square_sum4[1] + square_sum4[2] + square_sum4[3]; + float var = 1.0f / sqrt(square_sum / (channelUnit * 4) + cst.eps); + float4 var4 = var; + + for(int c = tiisg; c < channelUnit; c += SIMD_GROUP_WIDTH) { + int idx = c * batch + gid.y; + float4 norm = var4 * float4(in[idx]); + if(cst.has_gamma_beta) { + out[idx] = (ftype4)(norm * gamma[c] + beta[c]); + } else { + out[idx] = (ftype4)(norm); + } + } +} + +kernel void binary_layernorm_c4_rms_sg(const device ftype4 *in0 [[buffer(0)]], + const device ftype4 *in1 [[buffer(1)]], + device ftype4 *out0 [[buffer(2)]], + device ftype4 *out1 [[buffer(3)]], + constant layernorm_constants& cst [[buffer(4)]], + const device float4 *gamma [[buffer(5)]], + const device float4 *beta [[buffer(6)]], + uint3 gid [[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + int batch = cst.outside; + int channelUnit = cst.inside / 4; + + if ((int)gid.y >= batch) { + return; + } + + float4 square_sum4 = 0.0f; + + for(int c = sgitg * 32 + tiisg; c < channelUnit; c += 64) { + int idx = c * batch + gid.y; + float4 data = float4(in0[idx]) + float4(in1[idx]); + square_sum4 += data * data; + } + + square_sum4 = simd_sum(square_sum4); + + threadgroup float4 sg_square_sum[2]; + if(tiisg == 0) { + sg_square_sum[sgitg] = square_sum4; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + float4 total_square_sum4 = sg_square_sum[0] + sg_square_sum[1]; + + float square_sum = total_square_sum4[0] + total_square_sum4[1] + total_square_sum4[2] + total_square_sum4[3]; + float var = 1.0f / sqrt(square_sum / (channelUnit * 4) + cst.eps); + float4 var4 = var; + + for(int c = sgitg * 32 + tiisg; c < channelUnit; c += 64) { + int idx = c * batch + gid.y; + float4 my_data = float4(in0[idx]) + float4(in1[idx]); + out0[idx] = (ftype4)my_data; + float4 norm = var4 * my_data; + if(cst.has_gamma_beta) { + out1[idx] = (ftype4)(norm * gamma[c] + beta[c]); + } else { + out1[idx] = (ftype4)(norm); + } + } +} + )metal"; #endif - diff --git a/source/backend/metal/MetalAttention.hpp b/source/backend/metal/MetalAttention.hpp index 01a2f05324..8c42f5abb0 100644 --- a/source/backend/metal/MetalAttention.hpp +++ b/source/backend/metal/MetalAttention.hpp @@ -21,7 +21,7 @@ namespace MNN { class AttentionBufExecution : public MetalExecution { public: - AttentionBufExecution(Backend *backend, bool kv_cache); + AttentionBufExecution(Backend *backend, bool outputC4, float attnScale, std::shared_ptr kvQuantParam); virtual ~AttentionBufExecution() = default; virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; @@ -30,8 +30,8 @@ class AttentionBufExecution : public MetalExecution { if (nullptr == dst) { return true; } - auto exe = new AttentionBufExecution(bn, mKVCache); - if (bn->getMetaPtr() == backend()->getMetaPtr()) { + auto exe = new AttentionBufExecution(bn, mOutputC4, mAttnScale, mKVQuantParameter); + if (bn->getMetaPtr() == mMeta && mMeta != nullptr) { exe->mKVCacheManager = mKVCacheManager; } *dst = exe; @@ -42,14 +42,12 @@ class AttentionBufExecution : public MetalExecution { void _init(); void compilerShader(const std::vector &inputs); void handleKVAllocMemory(); - bool mKVCache; std::shared_ptr mKVCacheManager = nullptr; + float mAttnScale = 0.0f; float mScale; + bool mOutputC4 = false; bool mShortSeq = false; - bool mUseSimpleAttention = false; - bool mUseFlashAttention = false; - bool mUseFlashAttentionFused = false; - std::shared_ptr mTempQK, mTempSoftMax, mTempOutput; + std::shared_ptr mTempQK, mTempSoftMax; int mNumHead = 0, mHeadDim = 0, mValueH = 0, mKvNumHead = 0; int mSeqLen; // for simd/tensor maxtrix load alignment @@ -59,16 +57,12 @@ class AttentionBufExecution : public MetalExecution { id mKernel_qk = nil; id mKernel_qkv = nil; id mKernel_copy = nil; + id mKernel_qk_softmax = nil; id mKernelPrefill_qk = nil; id mKernelPrefill_qkv = nil; id mParamQKV; id mParamSoftmax; id mParamCopy; - - id mKernel_flash_softmax = nil; - id mKernel_flash_matmul_qkv = nil; - id mKernel_flash_scale = nil; - id mKernel_flash_fused = nil; private: KVMeta* mMeta; @@ -78,21 +72,23 @@ class AttentionBufExecution : public MetalExecution { bool mSftmSimdReduce = false; bool mQkvSimdReduce = false; bool mQkvSimdMatrix = false; + bool mDecodeQkSoftmax = false; + bool mCopySimdReduce = false; private: bool mHasMask = false; bool mIsAddMask = false; - bool mCausalMaskScalar = false; // scalar mask input means causal mask - int mBatch, mKvSeqLen, mKvMaxLen; + int mBatch, mKvSeqLen, mKvMaxLen, mCurrentKvLen = 0; int mQseqSplitNum = 1; std::shared_ptr mTempK, mTempV; - std::shared_ptr mRunningStats; // [Batch, Head, SeqLen, 2] - std::shared_ptr mCorrectionScale; // [Batch, Head, SeqLen] bool mKvInDisk; - + + // KV static quantization (only V is quantized on Metal) + std::shared_ptr mKVQuantParameter = nullptr; + bool mQuantValue = false; // whether V cache is stored as int8 + bool mQuantKey = false; // whether K cache is stored as int8 }; } // namespace MNN #endif/* MNN_SUPPORT_TRANSFORMER_FUSE */ #endif/* MNN_METAL_ENABLED */ #endif/* MetalAttention_hpp */ - diff --git a/source/backend/metal/MetalAttention.mm b/source/backend/metal/MetalAttention.mm index 0a7eff6962..c78be1dd8e 100644 --- a/source/backend/metal/MetalAttention.mm +++ b/source/backend/metal/MetalAttention.mm @@ -7,15 +7,13 @@ // #import "MetalCast.hpp" -#import "MetalAttention.hpp" #import "MNNMetalContext.h" #import "MetalAttentionShader.hpp" -#include "backend/cpu/compute/CommonOptFunction.h" -#include "core/TensorUtils.hpp" +#import "MetalSoftmaxShader.hpp" +#import "MetalAttention.hpp" #if MNN_METAL_ENABLED #ifdef MNN_SUPPORT_TRANSFORMER_FUSE - namespace MNN { struct Param { @@ -29,24 +27,40 @@ int max_kv_len; int batch; int kv_align_len; + int mask_batch; + int mask_head_num; + int mask_q_len; + int mask_k_len; + float v_scale; + float k_scale; +}; + +struct CopyParam { + int head_count; + int kv_seq_len; + int max_kv_len; + int dst_k_offset; + int dst_v_offset; + int batch; + float v_scale; + float k_scale; }; -AttentionBufExecution::AttentionBufExecution(Backend *backend, bool kv_cahce) - : MetalExecution(backend) , mKVCache(kv_cahce) { + +AttentionBufExecution::AttentionBufExecution(Backend* backend, bool outputC4, float attnScale, + std::shared_ptr kvQuantParam) + : MetalExecution(backend), mOutputC4(outputC4), mAttnScale(attnScale), mKVQuantParameter(kvQuantParam) { _init(); } void AttentionBufExecution::_init() { - auto mtbn = static_cast(backend()); - auto context = (__bridge MNNMetalContext *)mtbn->context(); + auto mtbn = static_cast(backend()); + auto context = (__bridge MNNMetalContext*)mtbn->context(); mMeta = (KVMeta*)(mtbn->getMetaPtr()); mParamQKV = [context newDeviceBuffer:sizeof(Param) access:CPUWriteOnly]; mParamSoftmax = [context newDeviceBuffer:4 * sizeof(int) access:CPUWriteOnly]; - mParamCopy = [context newDeviceBuffer:6 * sizeof(int) access:CPUWriteOnly]; + mParamCopy = [context newDeviceBuffer:sizeof(CopyParam) access:CPUWriteOnly]; mTempQK.reset(Tensor::createDevice({0, 0})); mTempSoftMax.reset(Tensor::createDevice({0, 0})); - mRunningStats.reset(Tensor::createDevice({0, 0, 0, 0})); - mCorrectionScale.reset(Tensor::createDevice({0, 0, 0})); - mTempOutput.reset(Tensor::createDevice({0, 0})); MNN::MetalKVCacheManager::KVCacheConfig kvconfig; kvconfig.mKVCacheDir = mtbn->getRuntime()->hint().kvcacheDirPath; @@ -56,17 +70,18 @@ mKVCacheManager.reset(new MetalKVCacheManager(backend(), kvconfig)); mKvInDisk = !kvconfig.mKVCacheDir.empty(); + mKVCacheManager->setKVQuantParameter(mKVQuantParameter); } -void AttentionBufExecution::compilerShader(const std::vector &inputs) { - auto mtbn = static_cast(backend()); +void AttentionBufExecution::compilerShader(const std::vector& inputs) { + auto mtbn = static_cast(backend()); auto rt = (MetalRuntime*)mtbn->runtime(); - auto context = (__bridge MNNMetalContext *)mtbn->context(); - + auto context = (__bridge MNNMetalContext*)mtbn->context(); + auto seq_len = inputs[0]->length(1); int group_size = inputs[0]->length(2) / inputs[1]->length(2); std::string group_str = std::to_string(group_size); - + // Init Kernel std::string ftype = "float"; std::string ftype4 = "float4"; @@ -74,100 +89,108 @@ ftype = "half"; ftype4 = "half4"; } - std::vector qkKeys = { - {"matmul_qk_div_mask", ftype, group_str} - }; - if(mHeadDim % 4 != 0) { - qkKeys.emplace_back("HEAD_DIM_UNALIGNED_4"); - } - - std::vector qkvKeys = { - {"matmul_qkv", ftype, group_str} - }; - if(mQkvSimdReduce) { + const bool staticQuantK = mQuantKey && mKVQuantParameter != nullptr && mKVQuantParameter->kScale != 0.0f; + const bool staticQuantV = mQuantValue && mKVQuantParameter != nullptr && mKVQuantParameter->vScale != 0.0f; + const bool dynamicQuantK = mQuantKey && !staticQuantK; + const bool dynamicQuantV = mQuantValue && !staticQuantV; + std::vector qkKeys = {{"matmul_qk_div_mask", ftype, group_str}}; + + std::vector qkvKeys = {{"matmul_qkv", ftype, group_str}}; + if (mQkvSimdReduce) { qkvKeys.emplace_back("SIMD_GROUP_REDUCE"); } - std::vector qkPrefillKeys = { - {"matmul_qk_div_mask", ftype, group_str, "FOR_PREFILL"} - }; - if (mCausalMaskScalar) { - qkPrefillKeys.emplace_back("CAUSAL_MASK"); - if (seq_len > 1) { - qkKeys.emplace_back("CAUSAL_MASK"); - } - } - if(mHasMask) { + std::vector qkPrefillKeys = {{"matmul_qk_div_mask", ftype, group_str, "FOR_PREFILL"}}; + if (mHasMask) { if (mIsAddMask) { qkPrefillKeys.emplace_back("ADD_MASK"); - if(seq_len > 1) { + if (seq_len > 1) { qkKeys.emplace_back("ADD_MASK"); } } else { qkPrefillKeys.emplace_back("SET_MASK"); - if(seq_len > 1) { + if (seq_len > 1) { qkKeys.emplace_back("SET_MASK"); } } + } else { + qkPrefillKeys.emplace_back("DEFAULT_MASK"); + if (seq_len > 1) { + qkKeys.emplace_back("DEFAULT_MASK"); + } } - if(mQkSimdMatrix) { + if (mQkSimdMatrix) { qkPrefillKeys.emplace_back("SIMD_GROUP_MATRIX"); } - std::vector qkvPrefillKeys = { - {"matmul_qkv", ftype, group_str, "FOR_PREFILL"} - }; - if(mQkvSimdMatrix) { + std::vector qkvPrefillKeys = {{"matmul_qkv", ftype, group_str, "FOR_PREFILL"}}; + if (mQkvSimdMatrix) { qkvPrefillKeys.emplace_back("SIMD_GROUP_MATRIX"); } if (mtbn->useFp16InsteadFp32()) { qkPrefillKeys.emplace_back("MNN_METAL_FLOAT16_STORAGE"); qkvPrefillKeys.emplace_back("MNN_METAL_FLOAT16_STORAGE"); } - std::vector copyPastKeys = { - {"pastkv_copy", ftype, group_str} - }; - std::vector shaders = { - "decode_qk", - "decode_qkv", - "prefill_qk", - "prefill_qkv", - "copy" - }; - if(mQkTensorMatrix) { + if (mQuantKey) { + qkKeys.emplace_back("QUANT_K"); + qkPrefillKeys.emplace_back("QUANT_K"); + if (dynamicQuantK) { + qkKeys.emplace_back("DYNAMIC_QUANT_K"); + qkPrefillKeys.emplace_back("DYNAMIC_QUANT_K"); + } + } + if (mQuantValue) { + qkvKeys.emplace_back("QUANT_V"); + qkvPrefillKeys.emplace_back("QUANT_V"); + if (dynamicQuantV) { + qkvKeys.emplace_back("DYNAMIC_QUANT_V"); + qkvPrefillKeys.emplace_back("DYNAMIC_QUANT_V"); + } + } + std::vector copyPastKeys = {{"pastkv_copy", ftype, group_str}}; + if (mQuantValue) { + copyPastKeys.emplace_back("KV_QUANT_V"); + } + if (mQuantKey) { + copyPastKeys.emplace_back("KV_QUANT_K"); + } + if (dynamicQuantK || dynamicQuantV) { + copyPastKeys.emplace_back("DYNAMIC_QUANT"); + if (mCopySimdReduce) { + copyPastKeys.emplace_back("SIMD_GROUP_REDUCE"); + } + } + std::vector shaders = {"decode_qk", "decode_qkv", "prefill_qk", "prefill_qkv", "copy"}; + if (mQkTensorMatrix) { shaders[2] = "prefill_qk_tensor"; shaders[3] = "prefill_qkv_tensor"; qkPrefillKeys.emplace_back("USE_METAL_TENSOR_OPS"); qkvPrefillKeys.emplace_back("USE_METAL_TENSOR_OPS"); } - std::vector> keys = { - qkKeys, - qkvKeys, - qkPrefillKeys, - qkvPrefillKeys, - copyPastKeys - }; - std::vector sources = { - gMatMulDivMask, - gMatMulQKV, - gMatMulDivMask, - gMatMulQKV, - gCopyPastKV - }; + if (mOutputC4) { + qkvKeys.emplace_back("ATTENTION_C4"); + qkvPrefillKeys.emplace_back("ATTENTION_C4"); + if (mQkvSimdReduce) { + qkvKeys.emplace_back("ATTENTION_C4_VEC2"); + shaders[1] = "decode_qkv_c2"; + } + } + std::vector> keys = {qkKeys, qkvKeys, qkPrefillKeys, qkvPrefillKeys, copyPastKeys}; + std::vector sources = {gMatMulDivMask, gMatMulQKV, gMatMulDivMask, gMatMulQKV, gCopyPastKV}; std::vector> pipelines(keys.size()); - for (int i=0; ifindPipeline(keys[i]); if (nil == pipeline) { // Rebuild Pipeline - MTLCompileOptions *option = [[MTLCompileOptions alloc] init]; + MTLCompileOptions* option = [[MTLCompileOptions alloc] init]; auto dic = [NSMutableDictionary dictionaryWithCapacity:0]; [dic setValue:@(keys[i][1].c_str()) forKey:@"ftype"]; [dic setValue:@(ftype4.c_str()) forKey:@"ftype4"]; [dic setValue:@(keys[i][2].c_str()) forKey:@"GROUP_SIZE"]; - for (int j=3; jmakeComputePipelineWithSourceOption(sources[i], shaders[i].c_str(), option); rt->insertPipeline(keys[i], pipeline); } @@ -184,14 +207,13 @@ MNN_ASSERT(nil != mKernelPrefill_qkv); MNN_ASSERT(nil != mKernel_copy); - if(mSftmSimdReduce) { - - MTLCompileOptions *option = [[MTLCompileOptions alloc] init]; - auto dic = [NSMutableDictionary dictionaryWithCapacity:0]; - option.preprocessorMacros = @{ - @"ftype" : @(ftype.c_str()), - @"ftype4" : @(ftype4.c_str()), - }; + MTLCompileOptions* option = [[MTLCompileOptions alloc] init]; + auto dic = [NSMutableDictionary dictionaryWithCapacity:0]; + option.preprocessorMacros = @{ + @"ftype" : @(ftype.c_str()), + @"ftype4" : @(ftype4.c_str()), + }; + if (mSftmSimdReduce) { std::vector keys = {"softmax_sg_reduce", ftype}; keys.emplace_back("softmax_plane_sg"); auto pipeline = rt->findPipeline(keys); @@ -201,318 +223,149 @@ } mKernel_softmax = pipeline; } else { - mKernel_softmax = [context pipelineWithName:@"softmax_plane" fp16:mtbn->useFp16InsteadFp32()]; - } - - if(mUseFlashAttention) - { - std::vector> flashKeys = { - {"flash_softmax", ftype}, - {"flash_matmul_qkv", ftype}, - {"flash_scale", ftype} - }; - - MTLCompileOptions *option = [[MTLCompileOptions alloc] init]; - auto basicDic = [NSMutableDictionary dictionaryWithCapacity:0]; - [basicDic setValue:@(ftype.c_str()) forKey:@"ftype"]; - - { - NSMutableDictionary *dic = [basicDic mutableCopy]; - if(mSftmSimdReduce) { - [dic setValue:@"1" forKey:@"SIMD_GROUP_REDUCE"]; - flashKeys[0].emplace_back("SIMD_GROUP_REDUCE"); - } - option.preprocessorMacros = dic; - - - auto pipeline = rt->findPipeline(flashKeys[0]); - if (nil == pipeline) { - pipeline = mtbn->makeComputePipelineWithSourceOption(gFlashSoftmax, "flash_softmax", option); - rt->insertPipeline(flashKeys[0], pipeline); - } - - mKernel_flash_softmax = pipeline;//mtbn->makeComputePipelineWithSourceOption(gFlashSoftmax, "flash_softmax", option); - } - { - NSMutableDictionary *dic = [basicDic mutableCopy]; - if(mQkvSimdMatrix) { - [dic setValue:@"1" forKey:@"SIMD_GROUP_MATRIX"]; - flashKeys[1].emplace_back("SIMD_GROUP_MATRIX"); - } - if(mQkvSimdReduce) { - [dic setValue:@"1" forKey:@"SIMD_GROUP_REDUCE"]; - flashKeys[1].emplace_back("SIMD_GROUP_REDUCE"); - } - if (mtbn->useFp16InsteadFp32()) { - [dic setValue:@"1" forKey:@"MNN_METAL_FLOAT16_STORAGE"]; - flashKeys[1].emplace_back("MNN_METAL_FLOAT16_STORAGE"); - } - - option.preprocessorMacros = dic; - - auto pipeline = rt->findPipeline(flashKeys[1]); - if (nil == pipeline) { - pipeline = mtbn->makeComputePipelineWithSourceOption(gFlashMatMulQKV, "flash_matmul_qkv", option); - rt->insertPipeline(flashKeys[1], pipeline); - } - - mKernel_flash_matmul_qkv = pipeline; - // mKernel_flash_matmul_qkv = mtbn->makeComputePipelineWithSourceOption(gFlashMatMulQKV, "flash_matmul_qkv", option); - } - { - NSMutableDictionary *dic = [basicDic mutableCopy]; - option.preprocessorMacros = dic; - - auto pipeline = rt->findPipeline(flashKeys[2]); - if (nil == pipeline) { - pipeline = mtbn->makeComputePipelineWithSourceOption(gFlashScale, "flash_scale", option); - rt->insertPipeline(flashKeys[2], pipeline); - } - - mKernel_flash_scale = pipeline; - mKernel_flash_scale = pipeline; - // mKernel_flash_scale = mtbn->makeComputePipelineWithSourceOption(gFlashScale, "flash_scale", option); + std::vector keys = {"softmax_sg_reduce", ftype}; + keys.emplace_back("softmax_plane"); + auto pipeline = rt->findPipeline(keys); + if (nil == pipeline) { + pipeline = mtbn->makeComputePipelineWithSourceOption(gSoftmaxSgReduce, keys.back().c_str(), option); + rt->insertPipeline(keys, pipeline); } + mKernel_softmax = pipeline; } - if(mUseFlashAttentionFused) { - MTLCompileOptions *option = [[MTLCompileOptions alloc] init]; - auto basicDic = [NSMutableDictionary dictionaryWithCapacity:0]; - [basicDic setValue:@(ftype.c_str()) forKey:@"ftype"]; - [basicDic setValue:@(ftype4.c_str()) forKey:@"ftype4"]; - - std::vector keys = {"flash_attention_fused", ftype}; - { - // Fused Attention (Naive/Simd) - NSMutableDictionary *dic = [basicDic mutableCopy]; -// if(mSftmSimdReduce) { -// [dic setValue:@"1" forKey:@"SIMD_GROUP_REDUCE"]; -// keys.emplace_back("SIMD_GROUP_REDUCE"); -// } - if(mQkvSimdMatrix) { - [dic setValue:@"1" forKey:@"SIMD_GROUP_MATRIX"]; - keys.emplace_back("SIMD_GROUP_MATRIX"); + if (mDecodeQkSoftmax) { + std::string head_dim_str = std::to_string(mHeadDim); + std::vector keys = {"decode_qk_softmax", ftype, group_str, "HEAD_DIM_" + head_dim_str}; + if (mQuantKey) { + keys.emplace_back("QUANT_K"); + if (dynamicQuantK) { + keys.emplace_back("DYNAMIC_QUANT_K"); } - if (mtbn->useFp16InsteadFp32()) { - [dic setValue:@"1" forKey:@"MNN_METAL_FLOAT16_STORAGE"]; - keys.emplace_back("MNN_METAL_FLOAT16_STORAGE"); - } - if(mHasMask) { - if(mIsAddMask) { - [dic setValue:@"1" forKey:@"ADD_MASK"]; - } else { - [dic setValue:@"1" forKey:@"SET_MASK"]; - } - } else if(mCausalMaskScalar) { - // Use causal mask when scalar mask is provided - [dic setValue:@"1" forKey:@"CAUSAL_MASK"]; - keys.emplace_back("CAUSAL_MASK"); + } + auto pipeline = rt->findPipeline(keys); + if (nil == pipeline) { + MTLCompileOptions* option = [[MTLCompileOptions alloc] init]; + auto dic = [NSMutableDictionary dictionaryWithCapacity:0]; + [dic setValue:@(ftype.c_str()) forKey:@"ftype"]; + [dic setValue:@(ftype4.c_str()) forKey:@"ftype4"]; + [dic setValue:@(group_str.c_str()) forKey:@"GROUP_SIZE"]; + [dic setValue:@(head_dim_str.c_str()) forKey:@"HEAD_DIM"]; + for (int j = 4; j < keys.size(); ++j) { + [dic setValue:@"1" forKey:@(keys[j].c_str())]; } option.preprocessorMacros = dic; - - auto pipeline = rt->findPipeline(keys); - if (nil == pipeline) { - pipeline = mtbn->makeComputePipelineWithSourceOption(gFlashAttentionFused, "flash_attention_fused", option); - rt->insertPipeline(keys, pipeline); - } - mKernel_flash_fused = pipeline; + pipeline = mtbn->makeComputePipelineWithSourceOption(gDecodeQkSoftmax, "decode_qk_softmax", option); + rt->insertPipeline(keys, pipeline); } + mKernel_qk_softmax = pipeline; + MNN_ASSERT(nil != mKernel_qk_softmax); } } void AttentionBufExecution::handleKVAllocMemory() { - if(mKVCache) { - mKVCacheManager->setPastLength(mMeta != nullptr ? mMeta->previous : 0); + if (nullptr == mMeta || mMeta->previous == mMeta->remove) { + mKVCacheManager->onClear(); + mKVCacheManager->onAlloc(mMeta, mCurrentKvLen); + } else { + mKVCacheManager->onRealloc(mMeta); + } - if (nullptr == mMeta || mMeta->previous == mMeta->remove) { - mKVCacheManager->onClear(); - mKVCacheManager->onAlloc(mMeta, mSeqLen); - } else { - MNN_ASSERT(mMeta->previous == mKVCacheManager->kvLength()); - mKVCacheManager->onRealloc(mMeta); - } - - mKvSeqLen = mKVCacheManager->kvLength() + mSeqLen; - mKvMaxLen = mKVCacheManager->maxLength(); - - float useMemorySize = 1.0 * mKvMaxLen / 1024.0 * mSeqLen / 1024.0 * mBatch * mNumHead; - // elementSize larger than 32M - mQseqSplitNum = 1; - - if(mUseFlashAttentionFused) { - // no need temp memory - return; - } + mKvSeqLen = mKVCacheManager->kvLength() + mCurrentKvLen; + mKvMaxLen = mKVCacheManager->maxLength(); + float useMemorySize = 1.0 * mKvMaxLen / 1024.0 * mSeqLen / 1024.0 * mBatch * mNumHead; + // elementSize larger than 32M + mQseqSplitNum = 1; - int qSeqLenPiece = UP_DIV(mSeqLen, mQseqSplitNum); - // temp tensor alloc memory - bool needMalloc = mTempQK->length(0) != mBatch * mNumHead; - if (mTempQK->length(1) != qSeqLenPiece * mKvMaxLen) { - needMalloc = true; - } + int qSeqLenPiece = UP_DIV(mSeqLen, mQseqSplitNum); + // temp tensor alloc memory + bool needMalloc = mTempQK->length(0) != mBatch * mNumHead; + if (mTempQK->length(1) != qSeqLenPiece * mKvMaxLen) { + needMalloc = true; + } - if (needMalloc) { - mTempQK->setLength(0, mBatch * mNumHead); - mTempQK->setLength(1, qSeqLenPiece * mKvMaxLen); - mTempSoftMax->setLength(0, mBatch * mNumHead); - mTempSoftMax->setLength(1, qSeqLenPiece * mKvMaxLen); - - if (mUseFlashAttention) { - // Flash Attention - int blockSize = MNN_FLASH_ATTENTION_BLOCK_SIZE; - mTempQK->setLength(1, mSeqLen * blockSize); - mTempSoftMax->setLength(1, mSeqLen * blockSize); - - mRunningStats->setLength(0, mBatch); - mRunningStats->setLength(1, mNumHead); - mRunningStats->setLength(2, mSeqLen); - mRunningStats->setLength(3, 2 * 4/*sizeof(float)*/); - - mCorrectionScale->setLength(0, mBatch); - mCorrectionScale->setLength(1, mNumHead); - mCorrectionScale->setLength(2, mSeqLen * 4/*sizeof(float)*/); - - mTempOutput->setLength(0, mSeqLen * mBatch); - mTempOutput->setLength(1, mNumHead * mHeadDim * 4/*sizeof(float)*/); - } - - auto res = backend()->onAcquireBuffer(mTempQK.get(), Backend::STATIC) && backend()->onAcquireBuffer(mTempSoftMax.get(), Backend::STATIC); - if (mUseFlashAttention) { - res = res && backend()->onAcquireBuffer(mRunningStats.get(), Backend::STATIC); - res = res && backend()->onAcquireBuffer(mCorrectionScale.get(), Backend::STATIC); - res = res && backend()->onAcquireBuffer(mTempOutput.get(), Backend::STATIC); - } - if (!res) { - MNN_ERROR("MNN::Metal: OUT_OF_MEMORY when execute attention metal %d\n", res); - return; - } - } + if (needMalloc) { + mTempQK->setLength(0, mBatch * mNumHead); + mTempQK->setLength(1, qSeqLenPiece * mKvMaxLen); + mTempSoftMax->setLength(0, mBatch * mNumHead); + mTempSoftMax->setLength(1, qSeqLenPiece * mKvMaxLen); } + + constexpr auto allocType = Backend::DYNAMIC_IN_EXECUTION; + auto res = backend()->onAcquireBuffer(mTempQK.get(), allocType) && + backend()->onAcquireBuffer(mTempSoftMax.get(), allocType); + if (!res) { + MNN_ERROR("MNN::Metal: OUT_OF_MEMORY when execute attention metal %d\n", res); + return; + } + backend()->onReleaseBuffer(mTempQK.get(), allocType); + backend()->onReleaseBuffer(mTempSoftMax.get(), allocType); } -ErrorCode AttentionBufExecution::onResize(const std::vector &inputs, const std::vector &outputs) { - mHasMask = inputs.size() > 3; - mCausalMaskScalar = false; + +ErrorCode AttentionBufExecution::onResize(const std::vector& inputs, const std::vector& outputs) { + mHasMask = inputs.size() > 3 && inputs[3]->dimensions() >= 2; if (mHasMask) { - // In transformer graphs / unit tests, a scalar mask input is a placeholder that means: - // apply causal(lower-triangular) mask inside Attention, instead of providing an explicit mask matrix. - // See `AttentionTest::generateMask()` for the expected rule: - // keep if (j - i) <= (kv_len - q_len), else mask to -inf. - if (inputs[3]->elementSize() <= 1) { - mCausalMaskScalar = true; - mHasMask = false; // don't bind mask buffer; shader will generate causal mask - } else { - mIsAddMask = (inputs[3]->getType() == halide_type_of()); - } + mIsAddMask = (inputs[3]->getType() == halide_type_of()); } auto query = inputs[0]; auto key = inputs[1]; auto value = inputs[2]; - auto mtbn = static_cast(backend()); - auto context = (__bridge MNNMetalContext *)mtbn->context(); + auto mtbn = static_cast(backend()); + auto context = (__bridge MNNMetalContext*)mtbn->context(); auto shape = query->shape(); mBatch = shape[0]; mSeqLen = shape[1]; mNumHead = shape[2]; mHeadDim = shape[3]; - mScale = 1.0 / sqrt(mHeadDim); + mScale = (mAttnScale == 0.0f) ? (1.0f / sqrt(mHeadDim)) : mAttnScale; // TODO : define short_seq more accurately mShortSeq = mSeqLen < 16; - - int attentionOption = static_cast(backend())->getRuntime()->hint().attentionOption; // hardware resource limit - mUseFlashAttentionFused = !mShortSeq && (attentionOption / 8 == 2) && mHeadDim <= 128; - mUseFlashAttention = !mShortSeq && (attentionOption / 8 >= 1) && !mUseFlashAttentionFused; - - mUseSimpleAttention = !mUseFlashAttentionFused && !mUseFlashAttention; // Check Env mKvNumHead = key->shape()[2]; - mKvSeqLen = key->shape()[1]; + mCurrentKvLen = key->shape()[1]; + mKvSeqLen = mCurrentKvLen; // Align to mKvAlignNum, for simd/tensor matrix load mKvMaxLen = ROUND_UP(mKvSeqLen, mKvAlignNum); - - if(mKVCache) { - mKVCacheManager->onResize(mKvNumHead, mHeadDim); - return NO_ERROR; - } - - float useMemorySize = 1.0 * mKvMaxLen / 1024.0 * mSeqLen / 1024.0 * mBatch * mNumHead; - // elementSize larger than 32M - mQseqSplitNum = 1; - - // no kv_cache memory, should create temp q/k memory - mTempK.reset(Tensor::createDevice({mKvMaxLen * mHeadDim * mBatch * mKvNumHead})); - mTempV.reset(Tensor::createDevice({mKvMaxLen * mHeadDim * mBatch * mKvNumHead})); - - backend()->onAcquireBuffer(mTempK.get(), Backend::DYNAMIC); - backend()->onAcquireBuffer(mTempV.get(), Backend::DYNAMIC); - - if (mUseSimpleAttention) { - mTempQK.reset(Tensor::createDevice({mKvMaxLen * UP_DIV(mSeqLen, mQseqSplitNum) * mBatch * mNumHead})); - mTempSoftMax.reset(Tensor::createDevice({mKvMaxLen * UP_DIV(mSeqLen, mQseqSplitNum) * mBatch * mNumHead})); - - backend()->onAcquireBuffer(mTempQK.get(), Backend::DYNAMIC); - backend()->onAcquireBuffer(mTempSoftMax.get(), Backend::DYNAMIC); - } else if(mUseFlashAttention){ - int blockSize = MNN_FLASH_ATTENTION_BLOCK_SIZE; - mTempQK.reset(Tensor::createDevice({blockSize * mSeqLen * mBatch * mNumHead})); - mTempSoftMax.reset(Tensor::createDevice({blockSize * mSeqLen * mBatch * mNumHead})); - mRunningStats.reset(Tensor::createDevice({(int)mBatch * mNumHead * mSeqLen * 2 * 4/*sizeof(float)*/})); - mCorrectionScale.reset(Tensor::createDevice({mBatch * mNumHead * mSeqLen * 4/*sizeof(float)*/})); - mTempOutput.reset(Tensor::createDevice({mBatch * mNumHead * mSeqLen * mHeadDim * 4/*sizeof(float)*/})); - - backend()->onAcquireBuffer(mTempQK.get(), Backend::DYNAMIC); - backend()->onAcquireBuffer(mTempSoftMax.get(), Backend::DYNAMIC); - backend()->onAcquireBuffer(mRunningStats.get(), Backend::DYNAMIC); - backend()->onAcquireBuffer(mCorrectionScale.get(), Backend::DYNAMIC); - backend()->onAcquireBuffer(mTempOutput.get(), Backend::DYNAMIC); - - } - - // release buffer - backend()->onReleaseBuffer(mTempK.get(), Backend::DYNAMIC); - backend()->onReleaseBuffer(mTempV.get(), Backend::DYNAMIC); - if (mUseSimpleAttention || mUseFlashAttention) { - backend()->onReleaseBuffer(mTempQK.get(), Backend::DYNAMIC); - backend()->onReleaseBuffer(mTempSoftMax.get(), Backend::DYNAMIC); - } - if (mUseFlashAttention) { - backend()->onReleaseBuffer(mRunningStats.get(), Backend::DYNAMIC); - backend()->onReleaseBuffer(mCorrectionScale.get(), Backend::DYNAMIC); - backend()->onReleaseBuffer(mTempOutput.get(), Backend::DYNAMIC); - } + // Enable static KV quantization only when kv-cache is in memory and mhq_quant provides valid scale + int attentionOption = static_cast(backend())->getRuntime()->hint().attentionOption; + bool dynamicQuantK = (attentionOption % 8 >= 1); + bool dynamicQuantV = (attentionOption % 8 > 1); + + mQuantValue = !mKvInDisk && ((mKVQuantParameter != nullptr && mKVQuantParameter->vScale != 0.0f) || dynamicQuantV); + mQuantKey = !mKvInDisk && ((mKVQuantParameter != nullptr && mKVQuantParameter->kScale != 0.0f) || dynamicQuantK); + mKVCacheManager->setKVQuantParameter(mKVQuantParameter); + mKVCacheManager->setAttenQuantKeyValue(mQuantKey, mQuantValue); + mKVCacheManager->onResize(mKvNumHead, mHeadDim); return NO_ERROR; } -void AttentionBufExecution::onEncode(const std::vector &inputs, const std::vector &outputs, id encoder) { +void AttentionBufExecution::onEncode(const std::vector& inputs, const std::vector& outputs, + id encoder) { auto query = inputs[0]; auto key = inputs[1]; auto value = inputs[2]; - auto mtbn = static_cast(backend()); - auto context = (__bridge MNNMetalContext *)mtbn->context(); + auto mtbn = static_cast(backend()); + auto context = (__bridge MNNMetalContext*)mtbn->context(); auto rt = (MetalRuntime*)mtbn->runtime(); int group_size = mNumHead / mKvNumHead; - + // temp memory alloc, handle variable set Tensor* tempTensorK; Tensor* tempTensorV; handleKVAllocMemory(); - id tempBufferK; id tempBufferV; - if(mKvInDisk) { + if (mKvInDisk) { tempBufferK = mKVCacheManager->getKeyBuffer(); tempBufferV = mKVCacheManager->getValueBuffer(); - } else if(mKVCache) { + } else { tempTensorK = mKVCacheManager->getKeyTensor(); tempTensorV = mKVCacheManager->getValueTensor(); - } else { - tempTensorK = mTempK.get(); - tempTensorV = mTempV.get(); } - + // whether use simdgroup bool supportSimdReduce = rt->supportSimdGroupReduce(); bool supportSimdMatrix = rt->supportSimdGroupMatrix(); - bool supportTensorMatrix = mtbn->isSupportTensorApi();// rt->supportTensorOps(); + bool supportTensorMatrix = mtbn->isSupportTensorApi(); // rt->supportTensorOps(); // decode and thread number not too large mQkSimdReduce = supportSimdReduce && mShortSeq; @@ -524,31 +377,46 @@ mSftmSimdReduce = supportSimdReduce; mQkvSimdReduce = supportSimdReduce && mShortSeq && mHeadDim * mNumHead < mKvSeqLen * 32; mQkvSimdMatrix = supportSimdMatrix && mSeqLen >= 16; - + mCopySimdReduce = supportSimdReduce && mKVCacheManager->useDynamicScaleBuffer(); + mDecodeQkSoftmax = mShortSeq && mSeqLen <= 8 && + !mHasMask && !mKvInDisk && + group_size == 2 && mHeadDim % 8 == 0 && mKvSeqLen <= 2048; + // start to compile attention shaders compilerShader(inputs); - + // Run Copy and Format-Convert Kernel { - auto copyp = (int*)mParamCopy.contents; + auto copyp = (CopyParam*)mParamCopy.contents; /* Key -> K-Cache : [mBatch, mKvSeqLen, mKvNumHead, mHeadDim] -> [mKvMaxLen, mBatch, mKvNumHead, mHeadDim] - Value -> V-Cache : [mBatch, mKvSeqLen, mKvNumHead, mHeadDim] -> [mBatch, mKvNumHead, mHeadDim, mKvMaxLen (fill when decode)] + Value -> V-Cache : [mBatch, mKvSeqLen, mKvNumHead, mHeadDim] -> [mBatch, mKvNumHead, mHeadDim, mKvMaxLen (fill + when decode)] */ - copyp[0] = mKvNumHead * mHeadDim; + copyp->head_count = mKvNumHead * mHeadDim; // current new kv_len - copyp[1] = key->shape()[1]; - copyp[2] = mKvMaxLen; - copyp[3] = mKVCacheManager->kvLength() * copyp[0]; - copyp[4] = mKVCacheManager->kvLength(); - copyp[5] = mBatch; + copyp->kv_seq_len = key->shape()[1]; + copyp->max_kv_len = mKvMaxLen; + copyp->dst_k_offset = mKVCacheManager->kvLength() * copyp->head_count; + copyp->dst_v_offset = mKVCacheManager->kvLength(); + copyp->batch = mBatch; + if (mQuantValue && mKVQuantParameter != nullptr) { + copyp->v_scale = mKVQuantParameter->vScale; + } else { + copyp->v_scale = 0.0f; + } + if (mQuantKey && mKVQuantParameter != nullptr) { + copyp->k_scale = mKVQuantParameter->kScale; + } else { + copyp->k_scale = 0.0f; + } int copy_line = key->shape()[1]; id pipeline = mKernel_copy; [encoder setComputePipelineState:pipeline]; MetalBackend::setTensor(key, encoder, 0); MetalBackend::setTensor(value, encoder, 1); - if(mKvInDisk) { + if (mKvInDisk) { MetalBackend::setBuffer(tempBufferK, 0, encoder, 2); MetalBackend::setBuffer(tempBufferV, 0, encoder, 3); } else { @@ -556,14 +424,24 @@ MetalBackend::setTensor(tempTensorV, encoder, 3); } [encoder setBuffer:mParamCopy offset:0 atIndex:4]; - + if (mKVCacheManager->getKScaleBuffer() != nil) { + [encoder setBuffer:mKVCacheManager->getKScaleBuffer() offset:0 atIndex:8]; + [encoder setBuffer:mKVCacheManager->getVScaleBuffer() offset:0 atIndex:9]; + } + std::pair gl; - gl = [context computeBestGroupAndLocal:pipeline threads:MTLSizeMake(mKvNumHead * mHeadDim, copy_line, mBatch)]; + if (mKVCacheManager->getKScaleBuffer() != nil) { + int localSize = mCopySimdReduce ? 32 : 128; + gl = std::make_pair(MTLSizeMake(1, copy_line, mBatch), MTLSizeMake(localSize, 1, 1)); + } else if (mDecodeQkSoftmax) { + gl = std::make_pair(MTLSizeMake(UP_DIV(mKvNumHead * mHeadDim, 128), copy_line, mBatch), MTLSizeMake(128, 1, 1)); + } else { + gl = [context computeBestGroupAndLocal:pipeline threads:MTLSizeMake(mKvNumHead * mHeadDim, copy_line, mBatch)]; + } [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second]; - } - + // Update Parameters int seqLenPiece = UP_DIV(mSeqLen, mQseqSplitNum); { @@ -578,288 +456,191 @@ param->max_kv_len = mKvMaxLen; param->batch = mBatch; param->kv_align_len = mKvAlignNum; + param->mask_batch = mHasMask ? inputs[3]->length(0) : 1; + param->mask_head_num = (mHasMask && inputs[3]->dimensions() > 3) ? inputs[3]->length(1) : 1; + param->mask_q_len = (mHasMask && inputs[3]->dimensions() > 3) ? inputs[3]->length(2) : 1; + param->mask_k_len = (mHasMask && inputs[3]->dimensions() > 0) ? inputs[3]->length(inputs[3]->dimensions() - 1) : 1; + if (mQuantValue && mKVQuantParameter != nullptr) { + param->v_scale = mKVQuantParameter->vScale; + } else { + param->v_scale = 0.0f; + } + if (mQuantKey && mKVQuantParameter != nullptr) { + param->k_scale = mKVQuantParameter->kScale; + } else { + param->k_scale = 0.0f; + } } - - if (mUseSimpleAttention) { - for(int seq_idx = 0; seq_idx < mQseqSplitNum; seq_idx++) { + + for (int seq_idx = 0; seq_idx < mQseqSplitNum; seq_idx++) { + if (mDecodeQkSoftmax) { + [encoder setComputePipelineState:mKernel_qk_softmax]; + MetalBackend::setTensor(query, encoder, 0); + MetalBackend::setTensor(mTempSoftMax.get(), encoder, 1); + MetalBackend::setTensor(tempTensorK, encoder, 2); + [encoder setBytes:&seq_idx length:sizeof(seq_idx) atIndex:3]; + [encoder setBuffer:mParamQKV offset:0 atIndex:4]; + if (mQuantKey && mKVCacheManager->getKScaleBuffer() != nil) { + [encoder setBuffer:mKVCacheManager->getKScaleBuffer() offset:0 atIndex:8]; + } + int qkGroups = mBatch * (mNumHead / group_size) * seqLenPiece; + int maxLocalSize = ALIMAX(32, ((int)mKernel_qk_softmax.maxTotalThreadsPerThreadgroup / 32) * 32); + int localSize = qkGroups <= 8 ? ALIMIN(1024, maxLocalSize) : + ALIMIN(maxLocalSize, ALIMAX(64, ROUND_UP(UP_DIV(mKvSeqLen, 6), 32))); + auto gl = std::make_pair(MTLSizeMake(mBatch * (mNumHead / group_size), seqLenPiece, 1), MTLSizeMake(localSize, 1, 1)); + [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second]; + } else { // Run QK Kernel - { - id pipeline; - if (mShortSeq) { - pipeline = mKernel_qk; - } else { - pipeline = mKernelPrefill_qk; - } - //pipeline = mKernel_qk; - [encoder setComputePipelineState:pipeline]; - // [mBatch, mSeqLen, mNumHead, mHeadDim] - MetalBackend::setTensor(query, encoder, 0); - // [mBatch, mNumHead, mSeqLen, mKvSeqLen] - MetalBackend::setTensor(mTempQK.get(), encoder, 1); - // [mKvSeqLen, mBatch, mKvNumHead, mHeadDim] - if(mKvInDisk) { - MetalBackend::setBuffer(tempBufferK, 0, encoder, 2); - } else { - MetalBackend::setTensor(tempTensorK, encoder, 2); - } - [encoder setBytes:&seq_idx length:sizeof(seq_idx) atIndex:3]; - [encoder setBuffer:mParamQKV offset:0 atIndex:4]; - int kv_start = 0, current_block_len = mKvSeqLen; - [encoder setBytes:&kv_start length:sizeof(kv_start) atIndex:5]; - [encoder setBytes:¤t_block_len length:sizeof(int) atIndex:6]; - if(mHasMask && !mCausalMaskScalar) { - MetalBackend::setTensor(inputs[3], encoder, 7); - } - - int decode_grid_y = mBatch * mNumHead; - std::pair gl; - if(mShortSeq) { - gl = [context computeBestGroupAndLocal:pipeline threads:MTLSizeMake(seqLenPiece, decode_grid_y / group_size, mKvSeqLen)]; - } else if(mQkTensorMatrix) { - gl = std::make_pair(MTLSizeMake(UP_DIV(seqLenPiece, 32), UP_DIV(mKvSeqLen, 32), decode_grid_y), MTLSizeMake(128, 1, 1)); - } else if(mQkSimdMatrix) { - gl = std::make_pair(MTLSizeMake(UP_DIV(seqLenPiece, 16), UP_DIV(mKvSeqLen, 16), decode_grid_y), MTLSizeMake(32, 1, 1)); + id pipeline; + if (mShortSeq) { + pipeline = mKernel_qk; } else { - gl = [context computeBestGroupAndLocal:pipeline threads:MTLSizeMake(seqLenPiece, decode_grid_y, mKvSeqLen)]; + pipeline = mKernelPrefill_qk; } - [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second]; - + // pipeline = mKernel_qk; + [encoder setComputePipelineState:pipeline]; + // [mBatch, mSeqLen, mNumHead, mHeadDim] + MetalBackend::setTensor(query, encoder, 0); + // [mBatch, mNumHead, mSeqLen, mKvSeqLen] + MetalBackend::setTensor(mTempQK.get(), encoder, 1); + // [mKvSeqLen, mBatch, mKvNumHead, mHeadDim] + if (mKvInDisk) { + MetalBackend::setBuffer(tempBufferK, 0, encoder, 2); + } else { + MetalBackend::setTensor(tempTensorK, encoder, 2); + } + [encoder setBytes:&seq_idx length:sizeof(seq_idx) atIndex:3]; + [encoder setBuffer:mParamQKV offset:0 atIndex:4]; + if (mKVCacheManager->getKScaleBuffer() != nil) { + [encoder setBuffer:mKVCacheManager->getKScaleBuffer() offset:0 atIndex:8]; + [encoder setBuffer:mKVCacheManager->getVScaleBuffer() offset:0 atIndex:9]; + } + int kv_start = 0, current_block_len = mKvSeqLen; + [encoder setBytes:&kv_start length:sizeof(kv_start) atIndex:5]; + [encoder setBytes:¤t_block_len length:sizeof(int) atIndex:6]; + if (mHasMask) { + MetalBackend::setTensor(inputs[3], encoder, 7); } + + int decode_grid_y = mBatch * mNumHead; + std::pair gl; + if (mShortSeq) { + gl = [context computeBestGroupAndLocal:pipeline + threads:MTLSizeMake(seqLenPiece, decode_grid_y / group_size, mKvSeqLen)]; + } else if (mQkTensorMatrix) { + gl = std::make_pair(MTLSizeMake(UP_DIV(seqLenPiece, 32), UP_DIV(mKvSeqLen, 32), decode_grid_y), + MTLSizeMake(128, 1, 1)); + } else if (mQkSimdMatrix) { + gl = std::make_pair(MTLSizeMake(UP_DIV(seqLenPiece, 16), UP_DIV(mKvSeqLen, 16), decode_grid_y), + MTLSizeMake(32, 1, 1)); + } else { + gl = [context computeBestGroupAndLocal:pipeline + threads:MTLSizeMake(seqLenPiece, decode_grid_y, mKvSeqLen)]; + } + [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second]; // Run Softmax Kernel + // For softmax parameter + // [mBatch, mNumHead, mSeqLen, mKvSeqLen] + int inside = 1; + int outside = mBatch * mNumHead * seqLenPiece; + int axis = mKvSeqLen; + int axis_align = ROUND_UP(axis, mKvAlignNum); { - // For softmax parameter - // [mBatch, mNumHead, mSeqLen, mKvSeqLen] - int inside = 1; - int outside = mBatch * mNumHead * seqLenPiece; - int axis = mKvSeqLen; - int axis_align = ROUND_UP(axis, mKvAlignNum); - { - auto softmax = (int*)mParamSoftmax.contents; - // Inside, axis, outside, plane(invalid) - softmax[0] = inside; - softmax[1] = axis; - softmax[2] = outside; - softmax[3] = axis_align; - } - [encoder setComputePipelineState:mKernel_softmax]; - // [mBatch, mNumHead, mSeqLen, mKvSeqLen] - MetalBackend::setTensor(mTempQK.get(), encoder, 0); - // [mBatch, mNumHead, mSeqLen, ROUND_UP(mKvSeqLen, mKvAlignNum)] - MetalBackend::setTensor(mTempSoftMax.get(), encoder, 1); - [encoder setBuffer:mParamSoftmax offset:0 atIndex:2]; - - int thread_group_size = 32; - std::pair gl; - if(mSftmSimdReduce) { - gl = std::make_pair(MTLSizeMake(inside, outside, 1), MTLSizeMake(thread_group_size, 1, 1)); - } else { - gl = [context computeBestGroupAndLocal: mKernel_softmax threads:MTLSizeMake(inside, outside, 1)]; - } - - [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second]; - + auto softmax = (int*)mParamSoftmax.contents; + // Inside, axis, outside, plane(invalid) + softmax[0] = inside; + softmax[1] = axis; + softmax[2] = outside; + softmax[3] = axis_align; } - // Run QKV Kernel - { - - id pipeline; - if (mShortSeq) { - pipeline = mKernel_qkv; - } else { - pipeline = mKernelPrefill_qkv; - } - [encoder setComputePipelineState:pipeline]; - // [mBatch, mNumHead, mSeqLen, ROUND_UP(mKvSeqLen, mKvAlignNum)] - MetalBackend::setTensor(mTempSoftMax.get(), encoder, 0); - // [mBatch, mSeqLen, mNumHead, mHeadDim] - MetalBackend::setTensor(outputs[0], encoder, 1); - // [mBatch, mKvNumHead, mHeadDim, mMaxSeqLen] - if(mKvInDisk) { - MetalBackend::setBuffer(tempBufferV, 0, encoder, 2); - } else { - MetalBackend::setTensor(tempTensorV, encoder, 2); - } - [encoder setBytes:&seq_idx length:sizeof(seq_idx) atIndex:3]; - [encoder setBuffer:mParamQKV offset:0 atIndex:4]; - std::pair gl; - if(mQkvSimdReduce) { - gl = std::make_pair(MTLSizeMake(seqLenPiece, mBatch * mNumHead, mHeadDim), MTLSizeMake(32, 1, 1)); - } else if(mQkTensorMatrix){ - gl = std::make_pair(MTLSizeMake(UP_DIV(seqLenPiece, 32), UP_DIV(mHeadDim, 32), mBatch * mNumHead), MTLSizeMake(128, 1, 1)); - } else if(mQkvSimdMatrix){ - gl = std::make_pair(MTLSizeMake(UP_DIV(seqLenPiece, 16), UP_DIV(mHeadDim, 16), mBatch * mNumHead), MTLSizeMake(32, 1, 1)); - } else { - gl = [context computeBestGroupAndLocal:pipeline threads:MTLSizeMake(seqLenPiece, mBatch * mNumHead, mHeadDim)]; - } - [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second]; + [encoder setComputePipelineState:mKernel_softmax]; + // [mBatch, mNumHead, mSeqLen, mKvSeqLen] + MetalBackend::setTensor(mTempQK.get(), encoder, 0); + // [mBatch, mNumHead, mSeqLen, ROUND_UP(mKvSeqLen, mKvAlignNum)] + MetalBackend::setTensor(mTempSoftMax.get(), encoder, 1); + [encoder setBuffer:mParamSoftmax offset:0 atIndex:2]; + + int thread_group_size = 32; + std::pair softmaxGl; + if (mSftmSimdReduce) { + softmaxGl = std::make_pair(MTLSizeMake(inside, outside, 1), MTLSizeMake(thread_group_size, 1, 1)); + } else { + softmaxGl = [context computeBestGroupAndLocal:mKernel_softmax threads:MTLSizeMake(inside, outside, 1)]; } + + [encoder dispatchThreadgroups:softmaxGl.first threadsPerThreadgroup:softmaxGl.second]; } - } else { - // Flash Attention - if (mUseFlashAttentionFused) { - id pipeline = mKernel_flash_fused; + // Run QKV Kernel + { + id pipeline; + if (mShortSeq) { + pipeline = mKernel_qkv; + } else { + pipeline = mKernelPrefill_qkv; + } [encoder setComputePipelineState:pipeline]; - MetalBackend::setTensor(query, encoder, 0); - if(mKvInDisk) { - MetalBackend::setBuffer(tempBufferK, 0, encoder, 1); + // [mBatch, mNumHead, mSeqLen, ROUND_UP(mKvSeqLen, mKvAlignNum)] + MetalBackend::setTensor(mTempSoftMax.get(), encoder, 0); + // [mBatch, mSeqLen, mNumHead, mHeadDim] + MetalBackend::setTensor(outputs[0], encoder, 1); + // [mBatch, mKvNumHead, mHeadDim, mMaxSeqLen] + if (mKvInDisk) { MetalBackend::setBuffer(tempBufferV, 0, encoder, 2); } else { - MetalBackend::setTensor(tempTensorK, encoder, 1); MetalBackend::setTensor(tempTensorV, encoder, 2); } - if(mHasMask && !mCausalMaskScalar) { - MetalBackend::setTensor(inputs[3], encoder, 3); + [encoder setBytes:&seq_idx length:sizeof(seq_idx) atIndex:3]; + [encoder setBuffer:mParamQKV offset:0 atIndex:4]; + if (mKVCacheManager->getKScaleBuffer() != nil) { + [encoder setBuffer:mKVCacheManager->getKScaleBuffer() offset:0 atIndex:8]; + [encoder setBuffer:mKVCacheManager->getVScaleBuffer() offset:0 atIndex:9]; } - MetalBackend::setTensor(outputs[0], encoder, 4); - [encoder setBuffer:mParamQKV offset:0 atIndex:5]; - - // TEMPORARY: Revert to stable configuration for debugging - // Grid: [q_seqlen/16, batch*headNum, 1], Threadgroup: 32 threads - if(mQkvSimdMatrix) { - [encoder dispatchThreadgroups:MTLSizeMake(UP_DIV(mSeqLen, 8), mBatch * mNumHead, 1) - threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + std::pair gl; + if (mQkvSimdReduce) { + int grid_z = mOutputC4 ? UP_DIV(mHeadDim, 2) : mHeadDim; + gl = std::make_pair(MTLSizeMake(seqLenPiece, mBatch * mNumHead, grid_z), MTLSizeMake(32, 1, 1)); + } else if (mQkTensorMatrix) { + gl = std::make_pair(MTLSizeMake(UP_DIV(seqLenPiece, 32), UP_DIV(mHeadDim, 32), mBatch * mNumHead), + MTLSizeMake(128, 1, 1)); + } else if (mQkvSimdMatrix) { + gl = std::make_pair(MTLSizeMake(UP_DIV(seqLenPiece, 16), UP_DIV(mHeadDim, 16), mBatch * mNumHead), + MTLSizeMake(32, 1, 1)); } else { - [encoder dispatchThreadgroups:MTLSizeMake(mSeqLen, mBatch * mNumHead, 1) - threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; - } - } else { - int blockSize = MNN_FLASH_ATTENTION_BLOCK_SIZE; - int kv_blocks = UP_DIV(mKvSeqLen, blockSize); - - { - auto param = (Param*)mParamQKV.contents; - // Original logic updates Param per piece, but here we run full seq. - // Adjust param for KV block if needed? - // The prefill_qk uses param.key_seq_len for loop bound. - // We need to update this per block or just reuse shader carefully? - // Reuse prefill_qk: keys logic relies on param. - } - - int seq_idx = 0; // prefill usually 1 piece - - for (int i = 0; i < kv_blocks; ++i) { - int kv_start = i * blockSize; - int current_block_len = std::min(blockSize, mKvSeqLen - kv_start); - - // 1. MatMul QK -> TempQK - { - id pipeline = mKernelPrefill_qk; - [encoder setComputePipelineState:pipeline]; - MetalBackend::setTensor(query, encoder, 0); - MetalBackend::setTensor(mTempQK.get(), encoder, 1); - if(mKvInDisk) { - MetalBackend::setBuffer(tempBufferK, 0, encoder, 2); - } else { - MetalBackend::setTensor(tempTensorK, encoder, 2); - } - - [encoder setBytes:&seq_idx length:sizeof(seq_idx) atIndex:3]; - [encoder setBuffer:mParamQKV offset:0 atIndex:4]; - [encoder setBytes:&kv_start length:sizeof(kv_start) atIndex:5]; - [encoder setBytes:¤t_block_len length:sizeof(int) atIndex:6]; - if(mHasMask && !mCausalMaskScalar) { - MetalBackend::setTensor(inputs[3], encoder, 7); - } - - int decode_grid_y = mBatch * mNumHead; - std::pair gl; - - // Block len logic mirroring original - if(mQkTensorMatrix) { - gl = std::make_pair(MTLSizeMake(UP_DIV(mSeqLen, 32), UP_DIV(current_block_len, 32), decode_grid_y), MTLSizeMake(128, 1, 1)); - } else if(mQkSimdMatrix) { - gl = std::make_pair(MTLSizeMake(UP_DIV(mSeqLen, 16), UP_DIV(current_block_len, 16), decode_grid_y), MTLSizeMake(32, 1, 1)); - } else { - gl = [context computeBestGroupAndLocal:pipeline threads:MTLSizeMake(mSeqLen, decode_grid_y, current_block_len)]; - } - [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second]; - } - - // 2. Flash Softmax - { - [encoder setComputePipelineState:mKernel_flash_softmax]; - MetalBackend::setTensor(mTempQK.get(), encoder, 0); - MetalBackend::setTensor(mTempSoftMax.get(), encoder, 1); - MetalBackend::setTensor(mRunningStats.get(), encoder, 2); - MetalBackend::setTensor(mCorrectionScale.get(), encoder, 3); - [encoder setBytes:¤t_block_len length:sizeof(int) atIndex:4]; - [encoder setBuffer:mParamQKV offset:0 atIndex:5]; - [encoder setBytes:&kv_start length:sizeof(int) atIndex:6]; - - // Grid: [SeqLen, Batch*Head, 1] - std::pair gl; - if (mSftmSimdReduce) { - gl = std::make_pair(MTLSizeMake(mSeqLen, mBatch * mNumHead, 1), MTLSizeMake(32, 1, 1)); - } else { - gl = [context computeBestGroupAndLocal:mKernel_flash_softmax threads:MTLSizeMake(mSeqLen, mBatch * mNumHead, 1)]; - } - [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second]; - } - - // 3. Flash MatMul QKV - { - [encoder setComputePipelineState:mKernel_flash_matmul_qkv]; - MetalBackend::setTensor(mTempSoftMax.get(), encoder, 0); // P_block - MetalBackend::setTensor(mTempOutput.get(), encoder, 1); // tempOutput - - // V_block: needs to be just V tensor - if(mKvInDisk) { - MetalBackend::setBuffer(tempBufferV, 0, encoder, 2); - } else { - MetalBackend::setTensor(tempTensorV, encoder, 2); - } - - MetalBackend::setTensor(mCorrectionScale.get(), encoder, 3); - [encoder setBytes:&kv_start length:sizeof(int) atIndex:4]; - [encoder setBytes:¤t_block_len length:sizeof(int) atIndex:5]; - [encoder setBuffer:mParamQKV offset:0 atIndex:6]; - - // Grid: [HeadDim/4, SeqLen, Batch*Head] - // We use float4 for HeadDim - std::pair gl; - if(mQkvSimdReduce) { - gl = std::make_pair(MTLSizeMake(UP_DIV(mHeadDim, 4), mSeqLen, mBatch * mNumHead), MTLSizeMake(32, 1, 1)); - } else if(mQkvSimdMatrix){ - gl = std::make_pair(MTLSizeMake(UP_DIV(mSeqLen, 16), UP_DIV(mHeadDim, 16), mBatch * mNumHead), MTLSizeMake(32, 1, 1)); - } else { - gl = [context computeBestGroupAndLocal:mKernel_flash_matmul_qkv threads:MTLSizeMake(UP_DIV(mHeadDim, 4), mSeqLen, mBatch * mNumHead)]; - } - - [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second]; - } - } - - // 4. Flash Scale - { - [encoder setComputePipelineState:mKernel_flash_scale]; - MetalBackend::setTensor(mTempOutput.get(), encoder, 0); // tempOutput - MetalBackend::setTensor(outputs[0], encoder, 1); - MetalBackend::setTensor(mRunningStats.get(), encoder, 2); - [encoder setBuffer:mParamQKV offset:0 atIndex:3]; - - auto gl = [context computeBestGroupAndLocal:mKernel_flash_scale threads:MTLSizeMake(UP_DIV(mHeadDim, 4), mSeqLen, mBatch * mNumHead)]; - [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second]; + gl = [context computeBestGroupAndLocal:pipeline + threads:MTLSizeMake(seqLenPiece, mBatch * mNumHead, mHeadDim)]; } + [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second]; } - } + // Update status - if(mKVCache) { - mKVCacheManager->setPastLength(mKVCacheManager->kvLength() + mSeqLen); - } + mKVCacheManager->setPastLength(mKVCacheManager->kvLength() + mCurrentKvLen); return; } class AttentionBufCreator : public MetalBackend::Creator { public: - virtual Execution *onCreate(const std::vector &inputs, const MNN::Op *op, Backend *backend, const std::vector &outputs) const override { + virtual Execution* onCreate(const std::vector& inputs, const MNN::Op* op, Backend* backend, + const std::vector& outputs) const override { auto param = op->main_as_AttentionParam(); - return new AttentionBufExecution(backend, param->kv_cache()); + std::shared_ptr quantParam; + if (nullptr != param->mhq_quant() && param->mhq_quant()->size() > 0) { + MNN_ASSERT(param->mhq_quant()->size() == 4); + std::vector mhqscale(param->mhq_quant()->size()); + for (int i = 0; i < mhqscale.size(); ++i) { + mhqscale[i] = param->mhq_quant()->GetAs(i)->scale(); + } + quantParam.reset(new KVQuantParameter); + quantParam->qScale = mhqscale[0]; + quantParam->kScale = mhqscale[1]; + quantParam->qkScale = mhqscale[2]; + quantParam->vScale = mhqscale[3]; + } + return new AttentionBufExecution(backend, param->output_c4(), param->attnScale(), quantParam); } }; REGISTER_METAL_OP_TRANSFORMER_CREATOR(AttentionBufCreator, OpType_Attention); } // namespace MNN -#endif/* MNN_SUPPORT_TRANSFORMER_FUSE */ +#endif /* MNN_SUPPORT_TRANSFORMER_FUSE */ #endif - diff --git a/source/backend/metal/MetalAttentionShader.hpp b/source/backend/metal/MetalAttentionShader.hpp index 85eb00e9d8..0f3807860a 100644 --- a/source/backend/metal/MetalAttentionShader.hpp +++ b/source/backend/metal/MetalAttentionShader.hpp @@ -28,8 +28,32 @@ struct Param { int max_kv_len; int batch; int kv_align_len; + int mask_batch; + int mask_head_num; + int mask_q_len; + int mask_k_len; + float v_scale; + float k_scale; }; +static inline bool attention_mask_hit(constant Param& param, int k) { + if (param.mask_k_len <= 1) { + return true; + } + int mask_k_start = max(param.key_seq_len - param.mask_k_len, 0); + int local_k = k - mask_k_start; + return local_k >= 0 && local_k < param.mask_k_len; +} + +static inline int attention_mask_offset(constant Param& param, int b, int hn, int q, int k) { + int mask_b = param.mask_batch <= 1 ? 0 : b; + int mask_h = param.mask_head_num <= 1 ? 0 : hn; + int mask_q = param.mask_q_len <= 1 ? 0 : min(q, param.mask_q_len - 1); + int mask_k_start = max(param.key_seq_len - param.mask_k_len, 0); + int local_k = param.mask_k_len <= 1 ? 0 : clamp(k - mask_k_start, 0, param.mask_k_len - 1); + return ((mask_b * param.mask_head_num + mask_h) * param.mask_q_len + mask_q) * param.mask_k_len + local_k; +} + #if MNN_METAL_FLOAT16_STORAGE typedef simdgroup_half8x8 simdgroup_T8x8; #else @@ -37,7 +61,18 @@ typedef simdgroup_float8x8 simdgroup_T8x8; #endif #define SIMD_GROUP_WIDTH 32 - +#ifdef QUANT_K +#ifdef DYNAMIC_QUANT_K +#define GETK(v, token_idx) ftype((float(v) * k_scales[(token_idx) * 2] + k_scales[(token_idx) * 2 + 1])) +#define GETK4(v, token_idx) (float4(v) * k_scales[(token_idx) * 2] + k_scales[(token_idx) * 2 + 1]) +#else +#define GETK(v, token_idx) ftype((float(v) * param.k_scale)) +#define GETK4(v, token_idx) (float4(v) * param.k_scale) +#endif +#else +#define GETK(v, token_idx) v +#define GETK4(v, token_idx) v +#endif #ifdef USE_METAL_TENSOR_OPS kernel void prefill_qk_tensor(const device ftype4* input0 [[buffer(0)]], device ftype* output [[buffer(1)]], @@ -51,6 +86,8 @@ kernel void prefill_qk_tensor(const device ftype4* input0 [[buffer(0)]], #elif defined(SET_MASK) const device int* mask [[buffer(7)]], #endif + device ftype* k_scales [[buffer(8)]], + device ftype* v_scales [[buffer(9)]], uint3 gid[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], @@ -65,7 +102,7 @@ kernel void prefill_qk_tensor(const device ftype4* input0 [[buffer(0)]], */ threadgroup ftype sdata[2048] = {0.f}; - const int K = 32, M = 32, N = 32; + const int K = 32, M = 32, N = 32; const int tb_offset = M * K; auto tA = tensor, tensor_inline>((threadgroup ftype*)sdata, dextents(K, M));//[M, K] auto tB = tensor, tensor_inline>((threadgroup ftype*)sdata + tb_offset, dextents(K, N));//[N, K] @@ -133,14 +170,18 @@ kernel void prefill_qk_tensor(const device ftype4* input0 [[buffer(0)]], auto A_offset = input0 + ((b * q_seq_len + idx_slq) * head_num + hn) * head_dim / 4 + (0 * 4 + kl) * 2 + 0; // [mKvSeqLen, mBatch, mKvNumHead, mHeadDim] +#ifdef QUANT_K + auto B_offset = (const device char4*)past_key + ((idx_slk * param.batch + b)* head_num / group + zin) * head_dim / 4 + (0 * 4 + kl) * 2 + 0; +#else auto B_offset = past_key + ((idx_slk * param.batch + b)* head_num / group + zin) * head_dim / 4 + (0 * 4 + kl) * 2 + 0; +#endif for(int i = 0; i < head_dim/4; i += 8){ ((threadgroup ftype4*)sdata)[(ml * 4 + kl) * 2 + 0] = A_offset[i + 0]; ((threadgroup ftype4*)sdata)[(ml * 4 + kl) * 2 + 1] = A_offset[i + 1]; - ((threadgroup ftype4*)sdata)[256 + (nl * 4 + kl) * 2 + 0] = B_offset[i + 0]; - ((threadgroup ftype4*)sdata)[256 + (nl * 4 + kl) * 2 + 1] = B_offset[i + 1]; + ((threadgroup ftype4*)sdata)[256 + (nl * 4 + kl) * 2 + 0] = (ftype4)GETK4(B_offset[i + 0], idx_slk * param.batch + b); + ((threadgroup ftype4*)sdata)[256 + (nl * 4 + kl) * 2 + 1] = (ftype4)GETK4(B_offset[i + 1], idx_slk * param.batch + b); threadgroup_barrier(mem_flags::mem_threadgroup); auto sA = tA.slice(0, 0); @@ -161,6 +202,10 @@ kernel void prefill_qk_tensor(const device ftype4* input0 [[buffer(0)]], float Vscale = (float)param.scale; +#if defined(DEFAULT_MASK) + int kv_valid_offset = max(k_seq_len - q_seq_len, 0); +#endif + int base_k_idx = (slk * 4 + ncl) * 8 + 0; auto xy_out = output + (z * q_seq_piece_len + slq * 32 + mcl) * output_k_len + base_k_idx + 0; if(slq * 32 + mcl < q_seq_piece_len && seq_idx * q_seq_piece_len + slq * 32 + mcl < q_seq_len) { @@ -168,97 +213,152 @@ kernel void prefill_qk_tensor(const device ftype4* input0 [[buffer(0)]], if(base_k_idx + 0 < output_k_len) { auto out0 = ((threadgroup float*)sdata)[sindex_base + 0] * Vscale; #ifdef ADD_MASK - auto mask_val = (kv_start + base_k_idx + 0) >= k_seq_len - q_seq_len ? mask[(ori_q_idx * q_seq_len + (kv_start + base_k_idx + 0) - k_seq_len + q_seq_len)] : 0.0; - out0 = mask_val + out0; + if (attention_mask_hit(param, kv_start + base_k_idx + 0)) { + auto mask_val = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + base_k_idx + 0)]; + out0 = mask_val + out0; + } #elif defined(SET_MASK) - out0 = mask[(ori_q_idx * k_seq_len + (kv_start + base_k_idx + 0))] == 0 ? -FLT_MAX : out0; - #elif defined(CAUSAL_MASK) - // keep if j <= i + (k_len - q_len), else -inf - out0 = (kv_start + base_k_idx + 0) > (ori_q_idx + (k_seq_len - q_seq_len)) ? -FLT_MAX : out0; + if (attention_mask_hit(param, kv_start + base_k_idx + 0)) { + out0 = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + base_k_idx + 0)] == 0 ? -FLT_MAX : out0; + } + #elif defined(DEFAULT_MASK) + int k_global = kv_start + base_k_idx + 0; + if (k_global > kv_valid_offset + ori_q_idx) { + out0 = -FLT_MAX; + } #endif xy_out[0] = out0; } if(base_k_idx + 1 < output_k_len) { auto out0 = ((threadgroup float*)sdata)[sindex_base + 1] * Vscale; #ifdef ADD_MASK - auto mask_val = (kv_start + base_k_idx + 1) >= k_seq_len - q_seq_len ? mask[(ori_q_idx * q_seq_len + (kv_start + base_k_idx + 1) - k_seq_len + q_seq_len)] : 0.0; - out0 = mask_val + out0; + if (attention_mask_hit(param, kv_start + base_k_idx + 1)) { + auto mask_val = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + base_k_idx + 1)]; + out0 = mask_val + out0; + } #elif defined(SET_MASK) - out0 = mask[(ori_q_idx * k_seq_len + (kv_start + base_k_idx + 1))] == 0 ? -FLT_MAX : out0; - #elif defined(CAUSAL_MASK) - out0 = (kv_start + base_k_idx + 1) > (ori_q_idx + (k_seq_len - q_seq_len)) ? -FLT_MAX : out0; + if (attention_mask_hit(param, kv_start + base_k_idx + 1)) { + out0 = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + base_k_idx + 1)] == 0 ? -FLT_MAX : out0; + } + #elif defined(DEFAULT_MASK) + int k_global = kv_start + base_k_idx + 1; + if (k_global > kv_valid_offset + ori_q_idx) { + out0 = -FLT_MAX; + } #endif xy_out[1] = out0; } if(base_k_idx + 2 < output_k_len) { auto out0 = ((threadgroup float*)sdata)[sindex_base + 2] * Vscale; #ifdef ADD_MASK - auto mask_val = (kv_start + base_k_idx + 2) >= k_seq_len - q_seq_len ? mask[(ori_q_idx * q_seq_len + (kv_start + base_k_idx + 2) - k_seq_len + q_seq_len)] : 0.0; - out0 = mask_val + out0; + if (attention_mask_hit(param, kv_start + base_k_idx + 2)) { + auto mask_val = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + base_k_idx + 2)]; + out0 = mask_val + out0; + } #elif defined(SET_MASK) - out0 = mask[(ori_q_idx * k_seq_len + (kv_start + base_k_idx + 2))] == 0 ? -FLT_MAX : out0; - #elif defined(CAUSAL_MASK) - out0 = (kv_start + base_k_idx + 2) > (ori_q_idx + (k_seq_len - q_seq_len)) ? -FLT_MAX : out0; + if (attention_mask_hit(param, kv_start + base_k_idx + 2)) { + out0 = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + base_k_idx + 2)] == 0 ? -FLT_MAX : out0; + } + #elif defined(DEFAULT_MASK) + int k_global = kv_start + base_k_idx + 2; + if (k_global > kv_valid_offset + ori_q_idx) { + out0 = -FLT_MAX; + } #endif xy_out[2] = out0; } if(base_k_idx + 3 < output_k_len) { auto out0 = ((threadgroup float*)sdata)[sindex_base + 3] * Vscale; #ifdef ADD_MASK - auto mask_val = (kv_start + base_k_idx + 3) >= k_seq_len - q_seq_len ? mask[(ori_q_idx * q_seq_len + (kv_start + base_k_idx + 3) - k_seq_len + q_seq_len)] : 0.0; - out0 = mask_val + out0; + if (attention_mask_hit(param, kv_start + base_k_idx + 3)) { + auto mask_val = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + base_k_idx + 3)]; + out0 = mask_val + out0; + } #elif defined(SET_MASK) - out0 = mask[(ori_q_idx * k_seq_len + (kv_start + base_k_idx + 3))] == 0 ? -FLT_MAX : out0; - #elif defined(CAUSAL_MASK) - out0 = (kv_start + base_k_idx + 3) > (ori_q_idx + (k_seq_len - q_seq_len)) ? -FLT_MAX : out0; + if (attention_mask_hit(param, kv_start + base_k_idx + 3)) { + out0 = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + base_k_idx + 3)] == 0 ? -FLT_MAX : out0; + } + #elif defined(DEFAULT_MASK) + int k_global = kv_start + base_k_idx + 3; + if (k_global > kv_valid_offset + ori_q_idx) { + out0 = -FLT_MAX; + } #endif xy_out[3] = out0; } if(base_k_idx + 4 < output_k_len) { auto out0 = ((threadgroup float*)sdata)[sindex_base + 4] * Vscale; #ifdef ADD_MASK - auto mask_val = (kv_start + base_k_idx + 4) >= k_seq_len - q_seq_len ? mask[(ori_q_idx * q_seq_len + (kv_start + base_k_idx + 4) - k_seq_len + q_seq_len)] : 0.0; - out0 = mask_val + out0; + if (attention_mask_hit(param, kv_start + base_k_idx + 4)) { + auto mask_val = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + base_k_idx + 4)]; + out0 = mask_val + out0; + } #elif defined(SET_MASK) - out0 = mask[(ori_q_idx * k_seq_len + (kv_start + base_k_idx + 4))] == 0 ? -FLT_MAX : out0; - #elif defined(CAUSAL_MASK) - out0 = (kv_start + base_k_idx + 4) > (ori_q_idx + (k_seq_len - q_seq_len)) ? -FLT_MAX : out0; + if (attention_mask_hit(param, kv_start + base_k_idx + 4)) { + out0 = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + base_k_idx + 4)] == 0 ? -FLT_MAX : out0; + } + #elif defined(DEFAULT_MASK) + int k_global = kv_start + base_k_idx + 4; + if (k_global > kv_valid_offset + ori_q_idx) { + out0 = -FLT_MAX; + } #endif xy_out[4] = out0; } if(base_k_idx + 5 < output_k_len) { auto out0 = ((threadgroup float*)sdata)[sindex_base + 5] * Vscale; #ifdef ADD_MASK - auto mask_val = (kv_start + base_k_idx + 5) >= k_seq_len - q_seq_len ? mask[(ori_q_idx * q_seq_len + (kv_start + base_k_idx + 5) - k_seq_len + q_seq_len)] : 0.0; - out0 = mask_val + out0; + if (attention_mask_hit(param, kv_start + base_k_idx + 5)) { + auto mask_val = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + base_k_idx + 5)]; + out0 = mask_val + out0; + } #elif defined(SET_MASK) - out0 = mask[(ori_q_idx * k_seq_len + (kv_start + base_k_idx + 5))] == 0 ? -FLT_MAX : out0; - #elif defined(CAUSAL_MASK) - out0 = (kv_start + base_k_idx + 5) > (ori_q_idx + (k_seq_len - q_seq_len)) ? -FLT_MAX : out0; + if (attention_mask_hit(param, kv_start + base_k_idx + 5)) { + out0 = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + base_k_idx + 5)] == 0 ? -FLT_MAX : out0; + } + #elif defined(DEFAULT_MASK) + int k_global = kv_start + base_k_idx + 5; + if (k_global > kv_valid_offset + ori_q_idx) { + out0 = -FLT_MAX; + } #endif xy_out[5] = out0; } if(base_k_idx + 6 < output_k_len) { auto out0 = ((threadgroup float*)sdata)[sindex_base + 6] * Vscale; #ifdef ADD_MASK - auto mask_val = (kv_start + base_k_idx + 6) >= k_seq_len - q_seq_len ? mask[(ori_q_idx * q_seq_len + (kv_start + base_k_idx + 6) - k_seq_len + q_seq_len)] : 0.0; - out0 = mask_val + out0; + if (attention_mask_hit(param, kv_start + base_k_idx + 6)) { + auto mask_val = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + base_k_idx + 6)]; + out0 = mask_val + out0; + } #elif defined(SET_MASK) - out0 = mask[(ori_q_idx * k_seq_len + (kv_start + base_k_idx + 6))] == 0 ? -FLT_MAX : out0; - #elif defined(CAUSAL_MASK) - out0 = (kv_start + base_k_idx + 6) > (ori_q_idx + (k_seq_len - q_seq_len)) ? -FLT_MAX : out0; + if (attention_mask_hit(param, kv_start + base_k_idx + 6)) { + out0 = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + base_k_idx + 6)] == 0 ? -FLT_MAX : out0; + } + #elif defined(DEFAULT_MASK) + int k_global = kv_start + base_k_idx + 6; + if (k_global > kv_valid_offset + ori_q_idx) { + out0 = -FLT_MAX; + } #endif xy_out[6] = out0; } if(base_k_idx + 7 < output_k_len) { auto out0 = ((threadgroup float*)sdata)[sindex_base + 7] * Vscale; #ifdef ADD_MASK - auto mask_val = (kv_start + base_k_idx + 7) >= k_seq_len - q_seq_len ? mask[(ori_q_idx * q_seq_len + (kv_start + base_k_idx + 7) - k_seq_len + q_seq_len)] : 0.0; - out0 = mask_val + out0; + if (attention_mask_hit(param, kv_start + base_k_idx + 7)) { + auto mask_val = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + base_k_idx + 7)]; + out0 = mask_val + out0; + } #elif defined(SET_MASK) - out0 = mask[(ori_q_idx * k_seq_len + (kv_start + base_k_idx + 7))] == 0 ? -FLT_MAX : out0; - #elif defined(CAUSAL_MASK) - out0 = (kv_start + base_k_idx + 7) > (ori_q_idx + (k_seq_len - q_seq_len)) ? -FLT_MAX : out0; + if (attention_mask_hit(param, kv_start + base_k_idx + 7)) { + out0 = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + base_k_idx + 7)] == 0 ? -FLT_MAX : out0; + } + #elif defined(DEFAULT_MASK) + int k_global = kv_start + base_k_idx + 7; + if (k_global > kv_valid_offset + ori_q_idx) { + out0 = -FLT_MAX; + } #endif xy_out[7] = out0; } @@ -281,6 +381,8 @@ kernel void prefill_qk(const device ftype* input0 [[buffer(0)]], #elif defined(SET_MASK) const device int* mask [[buffer(7)]], #endif + device ftype* k_scales [[buffer(8)]], + device ftype* v_scales [[buffer(9)]], #ifdef SIMD_GROUP_MATRIX uint3 gid[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], @@ -304,7 +406,7 @@ kernel void prefill_qk(const device ftype* input0 [[buffer(0)]], #ifdef USE_METAL_TENSOR_OPS - const int K = 8, M = 16, N = 16; + const int K = 8, M = 16, N = 16; auto tA = tensor, tensor_inline>((threadgroup ftype*)sdata, dextents(K, M));//[M, K] auto tB = tensor, tensor_inline>((threadgroup ftype*)sdata + 128, dextents(N, K));//[K, N] @@ -369,18 +471,20 @@ kernel void prefill_qk(const device ftype* input0 [[buffer(0)]], auto A_offset = input0 + ((b * q_seq_len + idx_slq) * head_num + hn) * head_dim + (0 * 2 + kl) * 4 + 0; // [mKvSeqLen, mBatch, mKvNumHead, mHeadDim] +#ifdef QUANT_K + auto B_offset = (const device char*)past_key + ((idx_slk * param.batch + b)* head_num / group + zin) * head_dim + 0 * 8 + kl * 4 + 0; +#else auto B_offset = past_key + ((idx_slk * param.batch + b)* head_num / group + zin) * head_dim + 0 * 8 + kl * 4 + 0; +#endif for(int i = 0; i < head_dim; i += 8){ - ((threadgroup ftype*)sdata)[rcl * 8 + kl * 4 + 0] = A_offset[i + 0]; - ((threadgroup ftype*)sdata)[rcl * 8 + kl * 4 + 1] = A_offset[i + 1]; - ((threadgroup ftype*)sdata)[rcl * 8 + kl * 4 + 2] = A_offset[i + 2]; - ((threadgroup ftype*)sdata)[rcl * 8 + kl * 4 + 3] = A_offset[i + 3]; - - ((threadgroup ftype*)sdata)[128 + (kl * 4 + 0) * 16 + rcl] = B_offset[i + 0]; - ((threadgroup ftype*)sdata)[128 + (kl * 4 + 1) * 16 + rcl] = B_offset[i + 1]; - ((threadgroup ftype*)sdata)[128 + (kl * 4 + 2) * 16 + rcl] = B_offset[i + 2]; - ((threadgroup ftype*)sdata)[128 + (kl * 4 + 3) * 16 + rcl] = B_offset[i + 3]; + // 向量化写入 Q(4 元素一组) + *((threadgroup ftype4*)(&((threadgroup ftype*)sdata)[rcl * 8 + kl * 4])) = *((const device ftype4*)(&A_offset[i])); + + ((threadgroup ftype*)sdata)[128 + (kl * 4 + 0) * 16 + rcl] = GETK(B_offset[i + 0], idx_slk * param.batch + b); + ((threadgroup ftype*)sdata)[128 + (kl * 4 + 1) * 16 + rcl] = GETK(B_offset[i + 1], idx_slk * param.batch + b); + ((threadgroup ftype*)sdata)[128 + (kl * 4 + 2) * 16 + rcl] = GETK(B_offset[i + 2], idx_slk * param.batch + b); + ((threadgroup ftype*)sdata)[128 + (kl * 4 + 3) * 16 + rcl] = GETK(B_offset[i + 3], idx_slk * param.batch + b); threadgroup_barrier(mem_flags::mem_threadgroup); #ifdef USE_METAL_TENSOR_OPS @@ -391,10 +495,10 @@ kernel void prefill_qk(const device ftype* input0 [[buffer(0)]], #else simdgroup_load(sga[0], (const threadgroup ftype*)sdata, 8); simdgroup_load(sga[1], ((const threadgroup ftype*)sdata) + 64, 8); - + simdgroup_load(sgb[0], ((const threadgroup ftype*)sdata) + 128, 16); simdgroup_load(sgb[1], ((const threadgroup ftype*)sdata) + 136, 16); - + simdgroup_multiply_accumulate(sgd[0], sga[0], sgb[0], sgd[0]); simdgroup_multiply_accumulate(sgd[1], sga[1], sgb[0], sgd[1]); simdgroup_multiply_accumulate(sgd[2], sga[0], sgb[1], sgd[2]); @@ -426,111 +530,162 @@ kernel void prefill_qk(const device ftype* input0 [[buffer(0)]], float Vscale = (float)param.scale; +#if defined(DEFAULT_MASK) + int kv_valid_offset = k_seq_len - q_seq_len; +#endif + auto xy_out = output + (z * q_seq_piece_len + slq * 16 + rcl) * output_k_len + slk * 16 + kl * 8 + 0; if(slq * 16 + rcl < q_seq_piece_len && seq_idx * q_seq_piece_len + slq * 16 + rcl < q_seq_len) { int ori_q_idx = seq_idx * q_seq_piece_len + slq * 16 + rcl; if(slk * 16 + kl * 8 + 0 < output_k_len) { auto out0 = ((threadgroup float*)sdata)[sindex_base + 0] * Vscale; #ifdef ADD_MASK - auto mask_val = (kv_start + slk * 16 + kl * 8 + 0) >= k_seq_len - q_seq_len ? mask[(ori_q_idx * q_seq_len + (kv_start + slk * 16 + kl * 8 + 0) - k_seq_len + q_seq_len)] : 0.0; - out0 = mask_val + out0; + if (attention_mask_hit(param, kv_start + slk * 16 + kl * 8 + 0)) { + auto mask_val = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + slk * 16 + kl * 8 + 0)]; + out0 = mask_val + out0; + } #elif defined(SET_MASK) - out0 = mask[(ori_q_idx * k_seq_len + (kv_start + slk * 16 + kl * 8 + 0))] == 0 ? -FLT_MAX : out0; - #elif defined(CAUSAL_MASK) - // Causal mask: keep if key_pos <= query_pos + (kv_len - q_len), else -inf - int key_pos = kv_start + slk * 16 + kl * 8 + 0; - if (key_pos > (ori_q_idx + (k_seq_len - q_seq_len))) out0 = -FLT_MAX; + if (attention_mask_hit(param, kv_start + slk * 16 + kl * 8 + 0)) { + out0 = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + slk * 16 + kl * 8 + 0)] == 0 ? -FLT_MAX : out0; + } + #elif defined(DEFAULT_MASK) + int k_global = kv_start + slk * 16 + kl * 8 + 0; + if (k_global > kv_valid_offset + ori_q_idx) { + out0 = -FLT_MAX; + } #endif xy_out[0] = out0; } if(slk * 16 + kl * 8 + 1 < output_k_len) { auto out0 = ((threadgroup float*)sdata)[sindex_base + 1] * Vscale; #ifdef ADD_MASK - auto mask_val = (kv_start + slk * 16 + kl * 8 + 1) >= k_seq_len - q_seq_len ? mask[(ori_q_idx * q_seq_len + (kv_start + slk * 16 + kl * 8 + 1) - k_seq_len + q_seq_len)] : 0.0; - out0 = mask_val + out0; + if (attention_mask_hit(param, kv_start + slk * 16 + kl * 8 + 1)) { + auto mask_val = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + slk * 16 + kl * 8 + 1)]; + out0 = mask_val + out0; + } #elif defined(SET_MASK) - out0 = mask[(ori_q_idx * k_seq_len + (kv_start + slk * 16 + kl * 8 + 1))] == 0 ? -FLT_MAX : out0; - #elif defined(CAUSAL_MASK) - int key_pos = kv_start + slk * 16 + kl * 8 + 1; - if (key_pos > (ori_q_idx + (k_seq_len - q_seq_len))) out0 = -FLT_MAX; + if (attention_mask_hit(param, kv_start + slk * 16 + kl * 8 + 1)) { + out0 = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + slk * 16 + kl * 8 + 1)] == 0 ? -FLT_MAX : out0; + } + #elif defined(DEFAULT_MASK) + int k_global = kv_start + slk * 16 + kl * 8 + 1; + if (k_global > kv_valid_offset + ori_q_idx) { + out0 = -FLT_MAX; + } #endif xy_out[1] = out0; } if(slk * 16 + kl * 8 + 2 < output_k_len) { auto out0 = ((threadgroup float*)sdata)[sindex_base + 2] * Vscale; #ifdef ADD_MASK - auto mask_val = (kv_start + slk * 16 + kl * 8 + 2) >= k_seq_len - q_seq_len ? mask[(ori_q_idx * q_seq_len + (kv_start + slk * 16 + kl * 8 + 2) - k_seq_len + q_seq_len)] : 0.0; - out0 = mask_val + out0; + if (attention_mask_hit(param, kv_start + slk * 16 + kl * 8 + 2)) { + auto mask_val = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + slk * 16 + kl * 8 + 2)]; + out0 = mask_val + out0; + } #elif defined(SET_MASK) - out0 = mask[(ori_q_idx * k_seq_len + (kv_start + slk * 16 + kl * 8 + 2))] == 0 ? -FLT_MAX : out0; - #elif defined(CAUSAL_MASK) - int key_pos = kv_start + slk * 16 + kl * 8 + 2; - if (key_pos > (ori_q_idx + (k_seq_len - q_seq_len))) out0 = -FLT_MAX; + if (attention_mask_hit(param, kv_start + slk * 16 + kl * 8 + 2)) { + out0 = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + slk * 16 + kl * 8 + 2)] == 0 ? -FLT_MAX : out0; + } + #elif defined(DEFAULT_MASK) + int k_global = kv_start + slk * 16 + kl * 8 + 2; + if (k_global > kv_valid_offset + ori_q_idx) { + out0 = -FLT_MAX; + } #endif xy_out[2] = out0; } if(slk * 16 + kl * 8 + 3 < output_k_len) { auto out0 = ((threadgroup float*)sdata)[sindex_base + 3] * Vscale; #ifdef ADD_MASK - auto mask_val = (kv_start + slk * 16 + kl * 8 + 3) >= k_seq_len - q_seq_len ? mask[(ori_q_idx * q_seq_len + (kv_start + slk * 16 + kl * 8 + 3) - k_seq_len + q_seq_len)] : 0.0; - out0 = mask_val + out0; + if (attention_mask_hit(param, kv_start + slk * 16 + kl * 8 + 3)) { + auto mask_val = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + slk * 16 + kl * 8 + 3)]; + out0 = mask_val + out0; + } #elif defined(SET_MASK) - out0 = mask[(ori_q_idx * k_seq_len + (kv_start + slk * 16 + kl * 8 + 3))] == 0 ? -FLT_MAX : out0; - #elif defined(CAUSAL_MASK) - int key_pos = kv_start + slk * 16 + kl * 8 + 3; - if (key_pos > (ori_q_idx + (k_seq_len - q_seq_len))) out0 = -FLT_MAX; + if (attention_mask_hit(param, kv_start + slk * 16 + kl * 8 + 3)) { + out0 = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + slk * 16 + kl * 8 + 3)] == 0 ? -FLT_MAX : out0; + } + #elif defined(DEFAULT_MASK) + int k_global = kv_start + slk * 16 + kl * 8 + 3; + if (k_global > kv_valid_offset + ori_q_idx) { + out0 = -FLT_MAX; + } #endif xy_out[3] = out0; } if(slk * 16 + kl * 8 + 4 < output_k_len) { auto out0 = ((threadgroup float*)sdata)[sindex_base + 4] * Vscale; #ifdef ADD_MASK - auto mask_val = (kv_start + slk * 16 + kl * 8 + 4) >= k_seq_len - q_seq_len ? mask[(ori_q_idx * q_seq_len + (kv_start + slk * 16 + kl * 8 + 4) - k_seq_len + q_seq_len)] : 0.0; - out0 = mask_val + out0; + if (attention_mask_hit(param, kv_start + slk * 16 + kl * 8 + 4)) { + auto mask_val = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + slk * 16 + kl * 8 + 4)]; + out0 = mask_val + out0; + } #elif defined(SET_MASK) - out0 = mask[(ori_q_idx * k_seq_len + (kv_start + slk * 16 + kl * 8 + 4))] == 0 ? -FLT_MAX : out0; - #elif defined(CAUSAL_MASK) - int key_pos = kv_start + slk * 16 + kl * 8 + 4; - if (key_pos > (ori_q_idx + (k_seq_len - q_seq_len))) out0 = -FLT_MAX; + if (attention_mask_hit(param, kv_start + slk * 16 + kl * 8 + 4)) { + out0 = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + slk * 16 + kl * 8 + 4)] == 0 ? -FLT_MAX : out0; + } + #elif defined(DEFAULT_MASK) + int k_global = kv_start + slk * 16 + kl * 8 + 4; + if (k_global > kv_valid_offset + ori_q_idx) { + out0 = -FLT_MAX; + } #endif xy_out[4] = out0; } if(slk * 16 + kl * 8 + 5 < output_k_len) { auto out0 = ((threadgroup float*)sdata)[sindex_base + 5] * Vscale; #ifdef ADD_MASK - auto mask_val = (kv_start + slk * 16 + kl * 8 + 5) >= k_seq_len - q_seq_len ? mask[(ori_q_idx * q_seq_len + (kv_start + slk * 16 + kl * 8 + 5) - k_seq_len + q_seq_len)] : 0.0; - out0 = mask_val + out0; + if (attention_mask_hit(param, kv_start + slk * 16 + kl * 8 + 5)) { + auto mask_val = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + slk * 16 + kl * 8 + 5)]; + out0 = mask_val + out0; + } #elif defined(SET_MASK) - out0 = mask[(ori_q_idx * k_seq_len + (kv_start + slk * 16 + kl * 8 + 5))] == 0 ? -FLT_MAX : out0; - #elif defined(CAUSAL_MASK) - int key_pos = kv_start + slk * 16 + kl * 8 + 5; - if (key_pos > (ori_q_idx + (k_seq_len - q_seq_len))) out0 = -FLT_MAX; + if (attention_mask_hit(param, kv_start + slk * 16 + kl * 8 + 5)) { + out0 = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + slk * 16 + kl * 8 + 5)] == 0 ? -FLT_MAX : out0; + } + #elif defined(DEFAULT_MASK) + int k_global = kv_start + slk * 16 + kl * 8 + 5; + if (k_global > kv_valid_offset + ori_q_idx) { + out0 = -FLT_MAX; + } #endif xy_out[5] = out0; } if(slk * 16 + kl * 8 + 6 < output_k_len) { auto out0 = ((threadgroup float*)sdata)[sindex_base + 6] * Vscale; #ifdef ADD_MASK - auto mask_val = (kv_start + slk * 16 + kl * 8 + 6) >= k_seq_len - q_seq_len ? mask[(ori_q_idx * q_seq_len + (kv_start + slk * 16 + kl * 8 + 6) - k_seq_len + q_seq_len)] : 0.0; - out0 = mask_val + out0; + if (attention_mask_hit(param, kv_start + slk * 16 + kl * 8 + 6)) { + auto mask_val = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + slk * 16 + kl * 8 + 6)]; + out0 = mask_val + out0; + } #elif defined(SET_MASK) - out0 = mask[(ori_q_idx * k_seq_len + (kv_start + slk * 16 + kl * 8 + 6))] == 0 ? -FLT_MAX : out0; - #elif defined(CAUSAL_MASK) - int key_pos = kv_start + slk * 16 + kl * 8 + 6; - if (key_pos > (ori_q_idx + (k_seq_len - q_seq_len))) out0 = -FLT_MAX; + if (attention_mask_hit(param, kv_start + slk * 16 + kl * 8 + 6)) { + out0 = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + slk * 16 + kl * 8 + 6)] == 0 ? -FLT_MAX : out0; + } + #elif defined(DEFAULT_MASK) + int k_global = kv_start + slk * 16 + kl * 8 + 6; + if (k_global > kv_valid_offset + ori_q_idx) { + out0 = -FLT_MAX; + } #endif xy_out[6] = out0; } if(slk * 16 + kl * 8 + 7 < output_k_len) { auto out0 = ((threadgroup float*)sdata)[sindex_base + 7] * Vscale; #ifdef ADD_MASK - auto mask_val = (kv_start + slk * 16 + kl * 8 + 7) >= k_seq_len - q_seq_len ? mask[(ori_q_idx * q_seq_len + (kv_start + slk * 16 + kl * 8 + 7) - k_seq_len + q_seq_len)] : 0.0; - out0 = mask_val + out0; + if (attention_mask_hit(param, kv_start + slk * 16 + kl * 8 + 7)) { + auto mask_val = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + slk * 16 + kl * 8 + 7)]; + out0 = mask_val + out0; + } #elif defined(SET_MASK) - out0 = mask[(ori_q_idx * k_seq_len + (kv_start + slk * 16 + kl * 8 + 7))] == 0 ? -FLT_MAX : out0; - #elif defined(CAUSAL_MASK) - int key_pos = kv_start + slk * 16 + kl * 8 + 7; - if (key_pos > (ori_q_idx + (k_seq_len - q_seq_len))) out0 = -FLT_MAX; + if (attention_mask_hit(param, kv_start + slk * 16 + kl * 8 + 7)) { + out0 = mask[attention_mask_offset(param, b, hn, ori_q_idx, kv_start + slk * 16 + kl * 8 + 7)] == 0 ? -FLT_MAX : out0; + } + #elif defined(DEFAULT_MASK) + int k_global = kv_start + slk * 16 + kl * 8 + 7; + if (k_global > kv_valid_offset + ori_q_idx) { + out0 = -FLT_MAX; + } #endif xy_out[7] = out0; } @@ -540,7 +695,7 @@ kernel void prefill_qk(const device ftype* input0 [[buffer(0)]], const int x = gid.x; // query_seq_len const int y = gid.y; // head_num * batch const int z = gid.z; // key_seq_len - + int q_idx = seq_idx * param.q_seq_piece_len + x; int z_global = kv_start + z; if (x >= param.q_seq_piece_len || q_idx >= param.query_seq_len || y >= param.head_num * param.batch || z_global >= param.key_seq_len) { @@ -553,7 +708,7 @@ kernel void prefill_qk(const device ftype* input0 [[buffer(0)]], int head_dim = param.head_dim; int b = y / head_num; int hn = y % head_num; - + const int offset = head_num * head_dim; const int offset_head = y * head_dim; const int offset_head_kv = (hn / group) * head_dim; @@ -562,27 +717,55 @@ kernel void prefill_qk(const device ftype* input0 [[buffer(0)]], float Vscale = (float)param.scale; // [mKvSeqLen, mBatch, mKvNumHead, mHeadDim] +#ifdef QUANT_K + const device char* B_offset = (const device char*)past_key + ((z_global * param.batch + b) * offset / group + offset_head_kv); +#else device const ftype* B_offset = past_key + (z_global * param.batch + b) * offset / group + offset_head_kv; +#endif const int output_offset = y * param.q_seq_piece_len * output_k_len; float out0 = 0.0; - - for(int i = 0; i < head_dim; ++i){ - float A = (float)(A_offset[i]); - float B = (float)(B_offset[i]); - out0 += B * A; + + // 两路流水:每次处理 8 个标量(两个 float4),减少循环开销 + int itN = head_dim / 8; // head_dim 保证 16 对齐,因此 /8 为整数 + const device ftype4* A4p = (const device ftype4*)A_offset; +#ifdef QUANT_K + const device char4* B4p_c = (const device char4*)B_offset; +#else + const device ftype4* B4p = (const device ftype4*)B_offset; +#endif + for (int i = 0; i < itN; ++i) { +#ifdef QUANT_K + float4 B0 = GETK4(B4p_c[i * 2 + 0], z_global * param.batch + b); + float4 B1 = GETK4(B4p_c[i * 2 + 1], z_global * param.batch + b); +#else + float4 B0 = float4(B4p[i * 2 + 0]); + float4 B1 = float4(B4p[i * 2 + 1]); +#endif + float4 A0 = float4(A4p[i * 2 + 0]); + float4 A1 = float4(A4p[i * 2 + 1]); + out0 += dot(A0, B0) + dot(A1, B1); } - + out0 *= Vscale; - - #ifdef ADD_MASK - auto mask_val = z_global >= key_seq_len - query_seq_len ? mask[((q_idx + 0) * query_seq_len + (z_global - key_seq_len + query_seq_len))] : 0.0; - out0 = mask_val + out0; - #elif defined(SET_MASK) - out0 = mask[((q_idx + 0) * key_seq_len + (z_global + 0))] == 0 ? -FLT_MAX : out0; - #elif defined(CAUSAL_MASK) - // keep if j <= i + (k_len - q_len), else -inf - out0 = z_global > (q_idx + (key_seq_len - query_seq_len)) ? -FLT_MAX : out0; - #endif + +#ifdef ADD_MASK + if (attention_mask_hit(param, z_global)) { + auto mask_val = mask[attention_mask_offset(param, b, hn, q_idx, z_global)]; + out0 = mask_val + out0; + } +#elif defined(SET_MASK) + if (attention_mask_hit(param, z_global)) { + out0 = mask[attention_mask_offset(param, b, hn, q_idx, z_global)] == 0 ? -FLT_MAX : out0; + } +#elif defined(DEFAULT_MASK) + { + int kv_valid_offset = max(key_seq_len - query_seq_len, 0); + int k_global = z_global; + if (k_global > kv_valid_offset + q_idx) { + out0 = -FLT_MAX; + } + } +#endif output[output_offset + x * output_k_len + z] = (ftype)out0; #endif } @@ -600,11 +783,25 @@ kernel void decode_qk(const device ftype* input0 [[buffer(0)]], #elif defined(SET_MASK) const device int* mask [[buffer(7)]], #endif + device ftype* k_scales [[buffer(8)]], + device ftype* v_scales [[buffer(9)]], +#ifdef SIMD_GROUP_REDUCE + uint3 gid[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]] +#else uint3 gid[[thread_position_in_grid]] +#endif ) { +#ifdef SIMD_GROUP_REDUCE + int x = gid.x; // query_seq_len + int y = gid.y; // head_num * batch + int z = gid.z; // key_seq_len +#else int x = gid.x; // query_seq_len int y = gid.y; // head_num * batch int z = gid.z; // key_seq_len +#endif int group = param.group; int kv_head_num = param.head_num / group; if (x >= param.query_seq_len || y >= kv_head_num * param.batch || z >= param.key_seq_len) { @@ -614,7 +811,7 @@ kernel void decode_qk(const device ftype* input0 [[buffer(0)]], int key_seq_len = param.key_seq_len; int head_num = param.head_num; int head_dim = param.head_dim; - + int b = y / kv_head_num; int kv_hn = y % kv_head_num; const int offset = head_num * head_dim; @@ -624,49 +821,107 @@ kernel void decode_qk(const device ftype* input0 [[buffer(0)]], // [mBatch, mSeqLen, mNumHead, mHeadDim] const device ftype* A_offset = input0 + (b * param.query_seq_len + x) * offset + offset_head; // [mKvSeqLen, mBatch, mKvNumHead, mHeadDim] +#ifdef QUANT_K + const device char* Pastkey_offset = (const device char*)past_key + ((z * param.batch + b) * offset / group + offset_head_kv); +#else device ftype* Pastkey_offset = past_key + (z * param.batch + b) * offset / group + offset_head_kv; +#endif float Vscale = (float)param.scale; + // 保持与原 Mask 分支一致的计算路径,避免提前返回带来的数值波动 float out[GROUP_SIZE] = {0.0}; - #ifdef HEAD_DIM_UNALIGNED_4 +#if defined(QUANT_K) && defined(DYNAMIC_QUANT_K) + int k_token_idx = z * param.batch + b; + float k_scale = k_scales[k_token_idx * 2]; + float k_bias = k_scales[k_token_idx * 2 + 1]; +#endif + +#ifdef SIMD_GROUP_REDUCE { - for(int i = 0; i < head_dim; i++){ - float B = (float)Pastkey_offset[i]; - for(int j = 0; j < group; j++) { - float A = A_offset[i + head_dim * j]; - out[j] += A * B; + int itN = head_dim / 8; + for (int i = tiisg; i < itN; i+=SIMD_GROUP_WIDTH) { +#ifdef QUANT_K +#ifdef DYNAMIC_QUANT_K + float4 B0 = float4(((const device char4*)Pastkey_offset)[i * 2 + 0]) * k_scale + k_bias; + float4 B1 = float4(((const device char4*)Pastkey_offset)[i * 2 + 1]) * k_scale + k_bias; +#else + float4 B0 = GETK4(((const device char4*)Pastkey_offset)[i * 2 + 0], z * param.batch + b); + float4 B1 = GETK4(((const device char4*)Pastkey_offset)[i * 2 + 1], z * param.batch + b); +#endif +#else + float4 B0 = float4(((const device ftype4*)Pastkey_offset)[i * 2 + 0]); + float4 B1 = float4(((const device ftype4*)Pastkey_offset)[i * 2 + 1]); +#endif + for (int j = 0; j < group; j++) { + const device ftype4* Ajp = (const device ftype4*)(A_offset + head_dim * j); + float4 A0 = float4(Ajp[i * 2 + 0]); + float4 A1 = float4(Ajp[i * 2 + 1]); + out[j] += dot(A0, B0) + dot(A1, B1); } } } - #else + for(int j = 0; j < group; j++) { + out[j] = simd_sum(out[j]); + } +#else { - for(int i = 0; i < head_dim/4; i++){ - float4 B = float4(((const device ftype4*)Pastkey_offset)[i]); - for(int j = 0; j < group; j++) { - float4 A = float4(((const device ftype4*)(A_offset + head_dim * j))[i]); - out[j] += dot(A, B); + // 统一使用 float4 向量化点积(QUANT_K 走 GETK4) + int itN = head_dim / 8; + for (int i = 0; i < itN; ++i) { +#ifdef QUANT_K +#ifdef DYNAMIC_QUANT_K + float4 B0 = float4(((const device char4*)Pastkey_offset)[i * 2 + 0]) * k_scale + k_bias; + float4 B1 = float4(((const device char4*)Pastkey_offset)[i * 2 + 1]) * k_scale + k_bias; +#else + float4 B0 = GETK4(((const device char4*)Pastkey_offset)[i * 2 + 0], z * param.batch + b); + float4 B1 = GETK4(((const device char4*)Pastkey_offset)[i * 2 + 1], z * param.batch + b); +#endif +#else + float4 B0 = float4(((const device ftype4*)Pastkey_offset)[i * 2 + 0]); + float4 B1 = float4(((const device ftype4*)Pastkey_offset)[i * 2 + 1]); +#endif + for (int j = 0; j < group; j++) { + const device ftype4* Ajp = (const device ftype4*)(A_offset + head_dim * j); + float4 A0 = float4(Ajp[i * 2 + 0]); + float4 A1 = float4(Ajp[i * 2 + 1]); + out[j] += dot(A0, B0) + dot(A1, B1); } } } - #endif - #ifdef ADD_MASK - float mask_val = z >= key_seq_len - param.query_seq_len ? mask[((x + 0) * param.query_seq_len + (z - key_seq_len + param.query_seq_len))] : 0.0; - #elif defined(SET_MASK) - int mask_val = mask[((x + 0) * key_seq_len + (z + 0))]; - #endif +#endif + +#ifdef SIMD_GROUP_REDUCE + if (tiisg == 0) { +#endif + for(int j = 0; j < group; j++) { out[j] *= Vscale; #ifdef ADD_MASK - out[j] += mask_val; + if (attention_mask_hit(param, z)) { + float mask_val = mask[attention_mask_offset(param, b, kv_hn * group + j, x, z)]; + out[j] += mask_val; + } #elif defined(SET_MASK) - out[j] = mask_val == 0 ? -FLT_MAX : out[j]; - #elif defined(CAUSAL_MASK) - out[j] = z > (x + (key_seq_len - param.query_seq_len)) ? -FLT_MAX : out[j]; + if (attention_mask_hit(param, z)) { + int mask_val = mask[attention_mask_offset(param, b, kv_hn * group + j, x, z)]; + out[j] = mask_val == 0 ? -FLT_MAX : out[j]; + } + #elif defined(DEFAULT_MASK) + { + int kv_valid_offset = max(key_seq_len - param.query_seq_len, 0); + int k_global = z; + if (k_global > kv_valid_offset + x) { + out[j] = -FLT_MAX; + } + } #endif output[((y * group + j) * param.query_seq_len + x) * key_seq_len + z] = (ftype)out[j]; } +#ifdef SIMD_GROUP_REDUCE + } +#endif } )metal"; @@ -681,16 +936,264 @@ struct Param { int dst_k_offset; int dst_v_offset; int batch; + float v_scale; + float k_scale; }; // Key: [batch, kv_seq_len, head_num / group * head_dim] -> [max_kv_len, batch, head_num / group * head_dim] // Value: [batch, kv_seq_len, head_num / group * head_dim] -> [batch, head_num / group * head_dim, max_kv_len] + +#ifdef KV_QUANT_K +#define KOUT_TYPE char +#else +#define KOUT_TYPE ftype +#endif + +#ifdef KV_QUANT_V +#define VOUT_TYPE char +#else +#define VOUT_TYPE ftype +#endif + + kernel void copy(const device ftype* input0 [[buffer(0)]], const device ftype* input1 [[buffer(1)]], - device ftype* output0 [[buffer(2)]], - device ftype* output1 [[buffer(3)]], + device KOUT_TYPE* output0 [[buffer(2)]], + device VOUT_TYPE* output1 [[buffer(3)]], constant Param& param [[buffer(4)]], + device ftype* k_scales [[buffer(8)]], + device ftype* v_scales [[buffer(9)]], +#ifdef DYNAMIC_QUANT + uint3 gid[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint titg[[thread_index_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint3 tptg_3d[[threads_per_threadgroup]] +#else uint3 gid[[thread_position_in_grid]] +#endif ) { +#ifdef DYNAMIC_QUANT + const int y = gid.y; // kv_seq_len + const int b = gid.z; // batch + const uint tptg = tptg_3d.x * tptg_3d.y * tptg_3d.z; + if (y >= param.kv_seq_len || b >= param.batch) { + return; + } + +#if defined(KV_QUANT_K) || defined(KV_QUANT_V) + float k_scale = param.k_scale; + float k_bias = 0.0f; + float v_scale = param.v_scale; + float v_bias = 0.0f; + +#ifdef DYNAMIC_QUANT + // Dynamic quantization scale calculation + { +#ifdef KV_QUANT_K + float min_k = 1000000.0f; + float max_k = -1000000.0f; +#endif +#ifdef KV_QUANT_V + float min_v = 1000000.0f; + float max_v = -1000000.0f; +#endif + + int vector_end = (param.head_count / 4) * 4; + for (int x = int(titg) * 4; x < vector_end; x += int(tptg) * 4) { + const int in_idx = (b * param.kv_seq_len + y) * param.head_count + x; +#ifdef KV_QUANT_K + float4 k4 = float4(((const device ftype4*)(input0 + in_idx))[0]); + float k_min = metal::min(metal::min(k4.x, k4.y), metal::min(k4.z, k4.w)); + float k_max = metal::max(metal::max(k4.x, k4.y), metal::max(k4.z, k4.w)); + min_k = metal::min(min_k, k_min); + max_k = metal::max(max_k, k_max); +#endif +#ifdef KV_QUANT_V + float4 v4 = float4(((const device ftype4*)(input1 + in_idx))[0]); + float v_min = metal::min(metal::min(v4.x, v4.y), metal::min(v4.z, v4.w)); + float v_max = metal::max(metal::max(v4.x, v4.y), metal::max(v4.z, v4.w)); + min_v = metal::min(min_v, v_min); + max_v = metal::max(max_v, v_max); +#endif + } + for (int x = vector_end + int(titg); x < param.head_count; x += int(tptg)) { + const int in_idx = (b * param.kv_seq_len + y) * param.head_count + x; +#ifdef KV_QUANT_K + float k = (float)input0[in_idx]; + min_k = metal::min(min_k, k); + max_k = metal::max(max_k, k); +#endif +#ifdef KV_QUANT_V + float v = (float)input1[in_idx]; + min_v = metal::min(min_v, v); + max_v = metal::max(max_v, v); +#endif + } + +#ifdef SIMD_GROUP_REDUCE +#ifdef KV_QUANT_K + min_k = simd_min(min_k); + max_k = simd_max(max_k); +#endif +#ifdef KV_QUANT_V + min_v = simd_min(min_v); + max_v = simd_max(max_v); +#endif +#else +#ifdef KV_QUANT_K + threadgroup float tg_min_k[256]; + threadgroup float tg_max_k[256]; +#endif +#ifdef KV_QUANT_V + threadgroup float tg_min_v[256]; + threadgroup float tg_max_v[256]; +#endif + +#ifdef KV_QUANT_K + tg_min_k[titg] = min_k; + tg_max_k[titg] = max_k; +#endif +#ifdef KV_QUANT_V + tg_min_v[titg] = min_v; + tg_max_v[titg] = max_v; +#endif + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (titg == 0) { + for (uint i = 1; i < tptg; i++) { +#ifdef KV_QUANT_K + min_k = metal::min(min_k, tg_min_k[i]); + max_k = metal::max(max_k, tg_max_k[i]); +#endif +#ifdef KV_QUANT_V + min_v = metal::min(min_v, tg_min_v[i]); + max_v = metal::max(max_v, tg_max_v[i]); +#endif + } +#ifdef KV_QUANT_K + tg_min_k[0] = min_k; + tg_max_k[0] = max_k; +#endif +#ifdef KV_QUANT_V + tg_min_v[0] = min_v; + tg_max_v[0] = max_v; +#endif + } + threadgroup_barrier(mem_flags::mem_threadgroup); +#ifdef KV_QUANT_K + min_k = tg_min_k[0]; + max_k = tg_max_k[0]; +#endif +#ifdef KV_QUANT_V + min_v = tg_min_v[0]; + max_v = tg_max_v[0]; +#endif +#endif +#ifdef KV_QUANT_K + k_scale = (max_k - min_k) / 255.0f; + if (k_scale < 1e-6f) k_scale = 1e-6f; + k_bias = min_k + 128.0f * k_scale; +#endif +#ifdef KV_QUANT_V + v_scale = (max_v - min_v) / 255.0f; + if (v_scale < 1e-6f) v_scale = 1e-6f; + v_bias = min_v + 128.0f * v_scale; +#endif + + if (titg == 0) { +#ifdef KV_QUANT_K + int k_tok_idx = param.dst_k_offset / param.head_count + (y * param.batch + b); + k_scales[k_tok_idx * 2 + 0] = k_scale; + k_scales[k_tok_idx * 2 + 1] = k_bias; +#endif +#ifdef KV_QUANT_V + int v_tok_idx = b * param.max_kv_len + (param.dst_k_offset / param.head_count + y); + v_scales[v_tok_idx * 2 + 0] = v_scale; + v_scales[v_tok_idx * 2 + 1] = v_bias; +#endif + } + } +#endif // DYNAMIC_QUANT +#endif // KV_QUANT_K || KV_QUANT_V + + int vector_end = (param.head_count / 4) * 4; + for (int x = int(titg) * 4; x < vector_end; x += int(tptg) * 4) { + const int in_idx = (b * param.kv_seq_len + y) * param.head_count + x; + + // Write K + int out_idx_k = param.dst_k_offset + (y * param.batch + b) * param.head_count + x; +#ifdef KV_QUANT_K + float4 k = float4(((const device ftype4*)(input0 + in_idx))[0]); + if (k_scale == 0.0f) { + ((device char4*)(output0 + out_idx_k))[0] = char4(0); + } else { + int4 qi = int4(rint((k - k_bias) / k_scale)); + qi = clamp(qi, int4(-128), int4(127)); + ((device char4*)(output0 + out_idx_k))[0] = char4(qi); + } +#else + ((device ftype4*)(output0 + out_idx_k))[0] = ((const device ftype4*)(input0 + in_idx))[0]; +#endif + + // Write V + int out_idx_v = param.dst_v_offset + (b * param.head_count + x) * param.max_kv_len + y; +#ifdef KV_QUANT_V + float4 v = float4(((const device ftype4*)(input1 + in_idx))[0]); + if (v_scale == 0.0f) { + output1[out_idx_v] = (char)0; + output1[out_idx_v + param.max_kv_len] = (char)0; + output1[out_idx_v + param.max_kv_len * 2] = (char)0; + output1[out_idx_v + param.max_kv_len * 3] = (char)0; + } else { + int4 qi = int4(rint((v - v_bias) / v_scale)); + qi = clamp(qi, int4(-128), int4(127)); + output1[out_idx_v] = (char)qi.x; + output1[out_idx_v + param.max_kv_len] = (char)qi.y; + output1[out_idx_v + param.max_kv_len * 2] = (char)qi.z; + output1[out_idx_v + param.max_kv_len * 3] = (char)qi.w; + } +#else + output1[out_idx_v] = input1[in_idx]; + output1[out_idx_v + param.max_kv_len] = input1[in_idx + 1]; + output1[out_idx_v + param.max_kv_len * 2] = input1[in_idx + 2]; + output1[out_idx_v + param.max_kv_len * 3] = input1[in_idx + 3]; +#endif + } + for (int x = vector_end + int(titg); x < param.head_count; x += int(tptg)) { + const int in_idx = (b * param.kv_seq_len + y) * param.head_count + x; + + int out_idx_k = param.dst_k_offset + (y * param.batch + b) * param.head_count + x; +#ifdef KV_QUANT_K + float k = (float)input0[in_idx]; + if (k_scale == 0.0f) { + output0[out_idx_k] = (char)0; + } else { + float q = (k - k_bias) / k_scale; + int qi = (int)rint(q); + qi = clamp(qi, -128, 127); + output0[out_idx_k] = (char)qi; + } +#else + output0[out_idx_k] = input0[in_idx]; +#endif + + int out_idx_v = param.dst_v_offset + (b * param.head_count + x) * param.max_kv_len + y; +#ifdef KV_QUANT_V + float v = (float)input1[in_idx]; + if (v_scale == 0.0f) { + output1[out_idx_v] = (char)0; + } else { + float q = (v - v_bias) / v_scale; + int qi = (int)rint(q); + qi = clamp(qi, -128, 127); + output1[out_idx_v] = (char)qi; + } +#else + output1[out_idx_v] = input1[in_idx]; +#endif + } +#else const int x = gid.x; // head_num / group * head_dim const int y = gid.y; // kv_seq_len const int b = gid.z; // batch @@ -698,12 +1201,41 @@ kernel void copy(const device ftype* input0 [[buffer(0)]], return; } const int in_idx = (b * param.kv_seq_len + y) * param.head_count + x; - int out_idx = param.dst_k_offset + (y * param.batch + b) * param.head_count + x; - output0[out_idx] = input0[in_idx]; - out_idx = param.dst_v_offset + (b * param.head_count + x) * param.max_kv_len + y; - output1[out_idx] = input1[in_idx]; + int out_idx_k = param.dst_k_offset + (y * param.batch + b) * param.head_count + x; +#ifdef KV_QUANT_K + float k = (float)input0[in_idx]; + if (param.k_scale == 0.0f) { + output0[out_idx_k] = (char)0; + } else { + float q = k / param.k_scale; + int qi = (int)rint(q); + qi = clamp(qi, -128, 127); + output0[out_idx_k] = (char)qi; + } +#else + output0[out_idx_k] = input0[in_idx]; +#endif + + int out_idx_v = param.dst_v_offset + (b * param.head_count + x) * param.max_kv_len + y; +#ifdef KV_QUANT_V + float v = (float)input1[in_idx]; + if (param.v_scale == 0.0f) { + output1[out_idx_v] = (char)0; + } else { + float q = v / param.v_scale; + int qi = (int)rint(q); + qi = clamp(qi, -128, 127); + output1[out_idx_v] = (char)qi; + } +#else + output1[out_idx_v] = input1[in_idx]; +#endif +#endif } + +#undef KOUT_TYPE +#undef VOUT_TYPE )metal"; const char* gMatMulQKV = R"metal( @@ -725,12 +1257,31 @@ struct Param { int max_kv_len; int batch; int kv_align_len; + int mask_batch; + int mask_head_num; + int mask_q_len; + int mask_k_len; + float v_scale; + float k_scale; }; #if MNN_METAL_FLOAT16_STORAGE typedef simdgroup_half8x8 simdgroup_T8x8; #else typedef simdgroup_float8x8 simdgroup_T8x8; #endif +#ifdef QUANT_V +#ifdef DYNAMIC_QUANT_V +#define GETV(v, tok_idx) ftype((float(v) * v_scales[(tok_idx) * 2] + v_scales[(tok_idx) * 2 + 1])) +#define GETV4(v, tok_idx) (float4(v) * float4(v_scales[(tok_idx) * 2], v_scales[((tok_idx) + 1) * 2], v_scales[((tok_idx) + 2) * 2], v_scales[((tok_idx) + 3) * 2]) + \ + float4(v_scales[(tok_idx) * 2 + 1], v_scales[((tok_idx) + 1) * 2 + 1], v_scales[((tok_idx) + 2) * 2 + 1], v_scales[((tok_idx) + 3) * 2 + 1])) +#else +#define GETV(v, tok_idx) ftype((float(v) * param.v_scale)) +#define GETV4(v, tok_idx) (float4(v) * param.v_scale) +#endif +#else +#define GETV(v, tok_idx) v +#define GETV4(v, tok_idx) v +#endif #ifdef USE_METAL_TENSOR_OPS kernel void prefill_qkv_tensor(const device ftype* input0 [[buffer(0)]], @@ -738,6 +1289,8 @@ kernel void prefill_qkv_tensor(const device ftype* input0 [[buffer(0)]], device ftype4* past_value [[buffer(2)]], constant int &seq_idx [[buffer(3)]], constant Param& param [[buffer(4)]], + device ftype* k_scales [[buffer(8)]], + device ftype* v_scales [[buffer(9)]], uint3 gid[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], @@ -753,7 +1306,7 @@ kernel void prefill_qkv_tensor(const device ftype* input0 [[buffer(0)]], threadgroup ftype sdata[2048] = {0.f}; - const int K = 32, M = 32, N = 32; + const int K = 32, M = 32, N = 32; const int tb_offset = M * K; auto tA = tensor, tensor_inline>((threadgroup ftype*)sdata, dextents(K, M));//[M, K] auto tB = tensor, tensor_inline>((threadgroup ftype*)sdata + tb_offset, dextents(K, N));//[N, K] @@ -764,7 +1317,7 @@ kernel void prefill_qkv_tensor(const device ftype* input0 [[buffer(0)]], auto cT = mmOps.get_destination_cooperative_tensor(); - // QK:[32, 4] + // QK:[32, 4] int ml = tiitg / 4;// 0~31 int kl = tiitg % 4;// 0~3 @@ -820,21 +1373,20 @@ kernel void prefill_qkv_tensor(const device ftype* input0 [[buffer(0)]], int idx_qk_sl = sl * 32 + ml < q_seq_piece_len ? (sl * 32 + ml) : q_seq_piece_len - 1; auto A_offset = input0 + (z * q_seq_piece_len + idx_qk_sl) * align_value_len + (0 * 4 + kl) * 8 + 0; +#ifdef QUANT_V + auto B_offset = (const device char4*)past_value + (zin * head_dim + hm * 32 + nl) * param.max_kv_len / 4 + (0 * 4 + kvl) * 2 + 0; +#else auto B_offset = past_value + (zin * head_dim + hm * 32 + nl) * param.max_kv_len / 4 + (0 * 4 + kvl) * 2 + 0; - +#endif + for(int i = 0; i < (value_seq_len+3)/4; i += 8){ - ((threadgroup ftype*)sdata)[(ml * 4 + kl) * 8 + 0] = A_offset[4*i + 0]; - ((threadgroup ftype*)sdata)[(ml * 4 + kl) * 8 + 1] = A_offset[4*i + 1]; - ((threadgroup ftype*)sdata)[(ml * 4 + kl) * 8 + 2] = A_offset[4*i + 2]; - ((threadgroup ftype*)sdata)[(ml * 4 + kl) * 8 + 3] = A_offset[4*i + 3]; - ((threadgroup ftype*)sdata)[(ml * 4 + kl) * 8 + 4] = A_offset[4*i + 4]; - ((threadgroup ftype*)sdata)[(ml * 4 + kl) * 8 + 5] = A_offset[4*i + 5]; - ((threadgroup ftype*)sdata)[(ml * 4 + kl) * 8 + 6] = A_offset[4*i + 6]; - ((threadgroup ftype*)sdata)[(ml * 4 + kl) * 8 + 7] = A_offset[4*i + 7]; - - ((threadgroup ftype4*)sdata)[256 + (nl * 4 + kvl) * 2 + 0] = B_offset[i + 0]; - ((threadgroup ftype4*)sdata)[256 + (nl * 4 + kvl) * 2 + 1] = B_offset[i + 1]; + // 向量化写入 P(两次 ftype4,覆盖 8 个标量) + *((threadgroup ftype4*)(&((threadgroup ftype*)sdata)[(ml * 4 + kl) * 8 + 0])) = *((const device ftype4*)(&A_offset[4*i + 0])); + *((threadgroup ftype4*)(&((threadgroup ftype*)sdata)[(ml * 4 + kl) * 8 + 4])) = *((const device ftype4*)(&A_offset[4*i + 4])); + + ((threadgroup ftype4*)sdata)[256 + (nl * 4 + kvl) * 2 + 0] = (ftype4)GETV4(B_offset[i + 0], b * param.max_kv_len + i * 4 + 0); + ((threadgroup ftype4*)sdata)[256 + (nl * 4 + kvl) * 2 + 1] = (ftype4)GETV4(B_offset[i + 1], b * param.max_kv_len + i * 4 + 4); threadgroup_barrier(mem_flags::mem_threadgroup); @@ -849,7 +1401,7 @@ kernel void prefill_qkv_tensor(const device ftype* input0 [[buffer(0)]], auto tC = tensor, tensor_inline>((threadgroup float*)sdata, dextents(N, M)); // [M , N] cT.store(tC); threadgroup_barrier(mem_flags::mem_threadgroup); - + // [M32, N4, N2, n4] auto sindex_base = (mcl * 4 + ncl) * 2 + 0; @@ -874,6 +1426,8 @@ kernel void prefill_qkv(const device ftype* input0 [[buffer(0)]], device ftype* past_value [[buffer(2)]], constant int &seq_idx [[buffer(3)]], constant Param& param [[buffer(4)]], + device ftype* k_scales [[buffer(8)]], + device ftype* v_scales [[buffer(9)]], #ifdef SIMD_GROUP_MATRIX uint3 gid[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], @@ -896,7 +1450,7 @@ kernel void prefill_qkv(const device ftype* input0 [[buffer(0)]], #ifdef USE_METAL_TENSOR_OPS - const int K = 8, M = 16, N = 16; + const int K = 8, M = 16, N = 16; auto tA = tensor, tensor_inline>((threadgroup ftype*)sdata, dextents(K, M));//[M, K] auto tB = tensor, tensor_inline>((threadgroup ftype*)sdata + 128, dextents(N, K));//[K, N] @@ -960,19 +1514,19 @@ kernel void prefill_qkv(const device ftype* input0 [[buffer(0)]], int idx_qk_sl = sl * 16 + rcl < q_seq_piece_len ? (sl * 16 + rcl) : q_seq_piece_len - 1; auto A_offset = input0 + (z * q_seq_piece_len + idx_qk_sl) * align_value_len + (0 * 2 + kl) * 4 + 0; +#ifdef QUANT_V + auto B_offset = (const device char*)past_value + (zin * head_dim + hm * 16 + nl * 4 + 0) * param.max_kv_len + (0 * 8 + kcl); +#else auto B_offset = past_value + (zin * head_dim + hm * 16 + nl * 4 + 0) * param.max_kv_len + (0 * 8 + kcl); - - - for(int i = 0; i < value_seq_len; i += 8){ - ((threadgroup ftype*)sdata)[rcl * 8 + kl * 4 + 0] = A_offset[i + 0]; - ((threadgroup ftype*)sdata)[rcl * 8 + kl * 4 + 1] = A_offset[i + 1]; - ((threadgroup ftype*)sdata)[rcl * 8 + kl * 4 + 2] = A_offset[i + 2]; - ((threadgroup ftype*)sdata)[rcl * 8 + kl * 4 + 3] = A_offset[i + 3]; - - ((threadgroup ftype*)sdata)[128 + kcl * 16 + nl * 4 + 0] = B_offset[i + 0 * param.max_kv_len]; - ((threadgroup ftype*)sdata)[128 + kcl * 16 + nl * 4 + 1] = B_offset[i + 1 * param.max_kv_len]; - ((threadgroup ftype*)sdata)[128 + kcl * 16 + nl * 4 + 2] = B_offset[i + 2 * param.max_kv_len]; - ((threadgroup ftype*)sdata)[128 + kcl * 16 + nl * 4 + 3] = B_offset[i + 3 * param.max_kv_len]; +#endif + + for(int i = 0; i < align_value_len; i += 8){ + *((threadgroup ftype4*)(&((threadgroup ftype*)sdata)[rcl * 8 + kl * 4 + 0])) = *((const device ftype4*)(&A_offset[i + 0])); + + ((threadgroup ftype*)sdata)[128 + kcl * 16 + nl * 4 + 0] = GETV(B_offset[i + 0 * param.max_kv_len], b * param.max_kv_len + i); + ((threadgroup ftype*)sdata)[128 + kcl * 16 + nl * 4 + 1] = GETV(B_offset[i + 1 * param.max_kv_len], b * param.max_kv_len + i); + ((threadgroup ftype*)sdata)[128 + kcl * 16 + nl * 4 + 2] = GETV(B_offset[i + 2 * param.max_kv_len], b * param.max_kv_len + i); + ((threadgroup ftype*)sdata)[128 + kcl * 16 + nl * 4 + 3] = GETV(B_offset[i + 3 * param.max_kv_len], b * param.max_kv_len + i); threadgroup_barrier(mem_flags::mem_threadgroup); @@ -984,10 +1538,10 @@ kernel void prefill_qkv(const device ftype* input0 [[buffer(0)]], #else simdgroup_load(sga[0], (const threadgroup ftype*)sdata, 8); simdgroup_load(sga[1], ((const threadgroup ftype*)sdata) + 64, 8); - + simdgroup_load(sgb[0], ((const threadgroup ftype*)sdata) + 128, 16); simdgroup_load(sgb[1], ((const threadgroup ftype*)sdata) + 136, 16); - + simdgroup_multiply_accumulate(sgd[0], sga[0], sgb[0], sgd[0]); simdgroup_multiply_accumulate(sgd[1], sga[1], sgb[0], sgd[1]); simdgroup_multiply_accumulate(sgd[2], sga[0], sgb[1], sgd[2]); @@ -1018,6 +1572,36 @@ kernel void prefill_qkv(const device ftype* input0 [[buffer(0)]], #endif // [N2, M2, M8, N8] +#ifdef ATTENTION_C4 + // [mNumHead * (mHeadDim / 4), mBatch * mSeqLen, 4] + auto xy_out = output + (b * q_seq_len + seq_idx * q_seq_piece_len + sl * 16 + rcl) * 4 + (hn * head_dim / 4 + hm * 4 + kl * 2) * 4 * param.batch * q_seq_len + 0; + if(sl * 16 + rcl < q_seq_piece_len && seq_idx * q_seq_piece_len + sl * 16 + rcl < q_seq_len) { + if(hm * 16 + kl * 8 + 0 < head_dim) { + xy_out[0] = ((threadgroup float*)sdata)[sindex_base + 0]; + } + if(hm * 16 + kl * 8 + 1 < head_dim) { + xy_out[1] = ((threadgroup float*)sdata)[sindex_base + 1]; + } + if(hm * 16 + kl * 8 + 2 < head_dim) { + xy_out[2] = ((threadgroup float*)sdata)[sindex_base + 2]; + } + if(hm * 16 + kl * 8 + 3 < head_dim) { + xy_out[3] = ((threadgroup float*)sdata)[sindex_base + 3]; + } + if(hm * 16 + kl * 8 + 4 < head_dim) { + xy_out[q_seq_len * 4 + 0] = ((threadgroup float*)sdata)[sindex_base + 4]; + } + if(hm * 16 + kl * 8 + 5 < head_dim) { + xy_out[q_seq_len * 4 + 1] = ((threadgroup float*)sdata)[sindex_base + 5]; + } + if(hm * 16 + kl * 8 + 6 < head_dim) { + xy_out[q_seq_len * 4 + 2] = ((threadgroup float*)sdata)[sindex_base + 6]; + } + if(hm * 16 + kl * 8 + 7 < head_dim) { + xy_out[q_seq_len * 4 + 3] = ((threadgroup float*)sdata)[sindex_base + 7]; + } + } +#else // [mBatch, mSeqLen, mNumHead, mHeadDim] auto xy_out = output + ((b * q_seq_len + seq_idx * q_seq_piece_len + sl * 16 + rcl) * head_num + hn) * head_dim + hm * 16 + kl * 8 + 0; if(sl * 16 + rcl < q_seq_piece_len && seq_idx * q_seq_piece_len + sl * 16 + rcl < q_seq_len) { @@ -1046,6 +1630,7 @@ kernel void prefill_qkv(const device ftype* input0 [[buffer(0)]], xy_out[7] = ((threadgroup float*)sdata)[sindex_base + 7]; } } +#endif #else const int x = gid.x; // q_seq_len @@ -1073,17 +1658,36 @@ kernel void prefill_qkv(const device ftype* input0 [[buffer(0)]], // [mBatch, mNumHead, mSeqLen, mKvSeqLen] device const ftype *A_offset = input0 + (y * q_seq_piece_len + x) * align_value_len; +#ifdef QUANT_V + const device char *B_offset = ((const device char*)past_value) + offset_head * param.max_kv_len; +#else device const ftype *B_offset = past_value + offset_head * param.max_kv_len; - float out = 0.0; - - for(int i = 0; i < value_seq_len; ++i){ - float A0 = (float)A_offset[i]; - float B = (float)B_offset[i]; - out += A0 * B; - } +#endif + float4 out4 = 0.0; + + for(int i = 0; i < align_value_len; i += 4){ + float4 A = float4(((const device ftype4*)(A_offset + i))[0]); +#ifdef QUANT_V + float4 B = GETV4(((const device char4*)(B_offset + i))[0], b * param.max_kv_len + i); +#else + float4 B = float4(((const device ftype4*)(B_offset + i))[0]); +#endif + out4 += A * B; + } + float out = out4.x + out4.y + out4.z + out4.w; +#ifdef ATTENTION_C4 + // [mNumHead * (mHeadDim / 4), mBatch * mSeqLen, 4] + { + int c = hn * head_dim + z; + int co = c / 4; + int ci = c % 4; + output[(b * q_seq_len + x) * 4 + ci + co * param.batch * q_seq_len * 4] = (ftype)out; + } +#else // [mBatch, mSeqLen, mNumHead, mHeadDim] output[(b * q_seq_len + q_idx) * stride * group + (hn * head_dim + z)] = out; #endif +#endif } kernel void decode_qkv(const device ftype* input0 [[buffer(0)]], @@ -1092,6 +1696,8 @@ kernel void decode_qkv(const device ftype* input0 [[buffer(0)]], // docode actually not compute in block constant int &seq_idx [[buffer(3)]], constant Param& param [[buffer(4)]], + device ftype* k_scales [[buffer(8)]], + device ftype* v_scales [[buffer(9)]], #ifdef SIMD_GROUP_REDUCE uint3 gid[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], @@ -1120,186 +1726,145 @@ kernel void decode_qkv(const device ftype* input0 [[buffer(0)]], const int offset_head = (yin * head_dim + z) * param.max_kv_len; device const ftype *A_offset = input0 + (y * q_seq_len + x) * align_value_len; +#ifdef QUANT_V + const device char *Pastvalue_offset8 = ((const device char*)past_value) + offset_head; +#else device ftype *Pastvalue_offset = past_value + offset_head; +#endif float out = 0; - + #ifdef SIMD_GROUP_REDUCE - for(int i = tiisg; i < value_seq_len; i+=SIMD_GROUP_WIDTH){ - float A = (float)A_offset[i]; - float B = (float)Pastvalue_offset[i]; - - out += A * B; + float4 out4 = 0; + for(int i = tiisg * 4; i < align_value_len; i+=SIMD_GROUP_WIDTH * 4){ + float4 A = float4(((const device ftype4*)(A_offset + i))[0]); +#ifdef QUANT_V + float4 B = GETV4(((const device char4*)(Pastvalue_offset8 + i))[0], b * param.max_kv_len + i); +#else + float4 B = float4(((const device ftype4*)(Pastvalue_offset + i))[0]); +#endif + out4 += A * B; } + out = out4.x + out4.y + out4.z + out4.w; out = simd_sum(out); if(tiisg == 0) { +#ifdef ATTENTION_C4 + // [mNumHead * (mHeadDim / 4), mBatch * mSeqLen, 4] + { + int c = hn * head_dim + z; + int co = c / 4; + int ci = c % 4; + output[(b * q_seq_len + x) * 4 + ci + co * param.batch * q_seq_len * 4] = (ftype)out; + } +#else // [mBatch, mSeqLen, mNumHead, mHeadDim] output[((b * q_seq_len + x) * head_num + hn) * head_dim + z] = (ftype)out; +#endif } #else - for(int i = 0; i < value_seq_len; i++){ - float A = (float)A_offset[i]; - float B = (float)Pastvalue_offset[i]; - - out += A * B; + float4 out4 = 0; + for(int i = 0; i < align_value_len; i += 4){ + float4 A = float4(((const device ftype4*)(A_offset + i))[0]); +#ifdef QUANT_V + float4 B = GETV4(((const device char4*)(Pastvalue_offset8 + i))[0], b * param.max_kv_len + i); +#else + float4 B = float4(((const device ftype4*)(Pastvalue_offset + i))[0]); +#endif + out4 += A * B; } + out = out4.x + out4.y + out4.z + out4.w; +#ifdef ATTENTION_C4 + // [mNumHead * (mHeadDim / 4), mBatch * mSeqLen, 4] + { + int c = hn * head_dim + z; + int co = c / 4; + int ci = c % 4; + output[(b * q_seq_len + x) * 4 + ci + co * param.batch * q_seq_len * 4] = (ftype)out; + } +#else output[((b * q_seq_len + x) * head_num + hn) * head_dim + z] = (ftype)out; #endif +#endif } -)metal"; - -const char* gSoftmaxSgReduce = R"metal( -#include -using namespace metal; -struct softmax_shape { - int inside_size; - int axis_length; - int outside_size; - int axis_align_length; -}; -#define SIMD_GROUP_WIDTH 32 - -kernel void softmax_plane_sg(const device ftype *in [[buffer(0)]], - device ftype *out [[buffer(1)]], - constant softmax_shape& s [[buffer(2)]], - uint2 gid[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]] - ) { - // threadgroup contain one simdgroup - // simdgroup compute axis data - if ((int)gid.x >= s.inside_size || (int)gid.y >= s.outside_size) return; - - auto in_offset = gid.y * s.axis_length * s.inside_size + gid.x; - auto out_offset = gid.y * s.axis_align_length * s.inside_size + gid.x; - auto axis_in = in + in_offset; - auto axis_out = out + out_offset; - - // get max - float max1 = -FLT_MAX; - for (int i = tiisg; i < s.axis_length; i+=SIMD_GROUP_WIDTH) { - max1 = max(max1, float(axis_in[i * s.inside_size])); - } - max1 = simd_max(max1); - // get sum - float sum1 = 0; - for (int i = tiisg; i < s.axis_length; i+=SIMD_GROUP_WIDTH) { - sum1 += exp(float(axis_in[i * s.inside_size]) - float(max1)); - } - sum1 = simd_sum(sum1); - - // output - for (int i = tiisg; i < s.axis_align_length; i+=SIMD_GROUP_WIDTH) { - axis_out[i * s.inside_size] = i >= s.axis_length ? ftype(0.0) : ftype(exp(float(axis_in[i * s.inside_size]) - float(max1)) / sum1); +kernel void decode_qkv_c2(const device ftype* input0 [[buffer(0)]], + device ftype* output [[buffer(1)]], + device ftype* past_value [[buffer(2)]], + constant int &seq_idx [[buffer(3)]], + constant Param& param [[buffer(4)]], + device ftype* k_scales [[buffer(8)]], + device ftype* v_scales [[buffer(9)]], + uint3 gid[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]] +) { + const int x = gid.x; + const int y = gid.y; + const int z = gid.z * 2; + if (x >= param.query_seq_len || y >= param.head_num * param.batch || z >= param.head_dim) { + return; } -} - - -)metal"; - -const char* gFlashSoftmax = R"metal( -#include -using namespace metal; + int head_dim = param.head_dim; + int head_num = param.head_num; + int q_seq_len = param.query_seq_len; + int group = param.group; + int b = y / head_num; + int hn = y % head_num; -struct Param { - int query_seq_len; - int q_seq_piece_len; - int key_seq_len; - int head_num; - int group; - int head_dim; - float scale; - int max_kv_len; - int batch; - int kv_align_len; -}; + int yin = b * (head_num / group) + hn / group; + int value_seq_len = param.key_seq_len; + int align_value_len = ((value_seq_len + param.kv_align_len - 1) / param.kv_align_len) * param.kv_align_len; -kernel void flash_softmax( - const device ftype* input [[buffer(0)]], - device ftype* output [[buffer(1)]], - device float* runningStats [[buffer(2)]], - device float* correctionScale [[buffer(3)]], - constant int& block_len [[buffer(4)]], - constant Param& param [[buffer(5)]], - constant int& kv_start [[buffer(6)]], -#ifdef SIMD_GROUP_REDUCE - uint2 gid [[threadgroup_position_in_grid]], - uint tiisg [[thread_index_in_simdgroup]] + device const ftype *A_offset = input0 + (y * q_seq_len + x) * align_value_len; +#ifdef QUANT_V + const device char *B0 = ((const device char*)past_value) + (yin * head_dim + z + 0) * param.max_kv_len; + const device char *B1 = ((const device char*)past_value) + (yin * head_dim + z + 1) * param.max_kv_len; #else - uint3 gid [[thread_position_in_grid]] + device const ftype *B0 = past_value + (yin * head_dim + z + 0) * param.max_kv_len; + device const ftype *B1 = past_value + (yin * head_dim + z + 1) * param.max_kv_len; #endif -) { -#ifdef SIMD_GROUP_REDUCE - int s = gid.x; - int bh = gid.y; + + float4 out0 = 0; + float4 out1 = 0; + for(int i = tiisg * 4; i < align_value_len; i += SIMD_GROUP_WIDTH * 4){ + float4 A = float4(((const device ftype4*)(A_offset + i))[0]); +#ifdef QUANT_V +#ifdef DYNAMIC_QUANT_V + int tok_idx = b * param.max_kv_len + i; + float4 scale4 = float4(v_scales[tok_idx * 2], v_scales[(tok_idx + 1) * 2], + v_scales[(tok_idx + 2) * 2], v_scales[(tok_idx + 3) * 2]); + float4 bias4 = float4(v_scales[tok_idx * 2 + 1], v_scales[(tok_idx + 1) * 2 + 1], + v_scales[(tok_idx + 2) * 2 + 1], v_scales[(tok_idx + 3) * 2 + 1]); + out0 += A * (float4(((const device char4*)(B0 + i))[0]) * scale4 + bias4); + out1 += A * (float4(((const device char4*)(B1 + i))[0]) * scale4 + bias4); #else - int s = gid.x; - int bh = gid.y; + out0 += A * GETV4(((const device char4*)(B0 + i))[0], b * param.max_kv_len + i); + out1 += A * GETV4(((const device char4*)(B1 + i))[0], b * param.max_kv_len + i); #endif - - if (s >= param.query_seq_len || bh >= param.batch * param.head_num) { - return; - } - - int seq_len = param.query_seq_len; - int stat_idx = (bh * seq_len + s) * 2; - int block_offset = (bh * seq_len + s) * block_len; - - float prev_max = (float)runningStats[stat_idx]; - float prev_sum = (float)runningStats[stat_idx + 1]; - - float safe_min = -10000.0; - if (kv_start == 0) { - prev_max = safe_min; - prev_sum = 0; - } - - float block_max = safe_min; -#ifdef SIMD_GROUP_REDUCE - for (int i = tiisg; i < block_len; i += 32) { - block_max = max(block_max, float(input[block_offset + i])); - } - block_max = simd_max(block_max); #else - for (int i = 0; i < block_len; ++i) { - block_max = max(block_max, float(input[block_offset + i])); - } + out0 += A * float4(((const device ftype4*)(B0 + i))[0]); + out1 += A * float4(((const device ftype4*)(B1 + i))[0]); #endif - - float new_max = max(prev_max, block_max); - float scale = exp(prev_max - new_max); - - float block_sum = 0; -#ifdef SIMD_GROUP_REDUCE - for (int i = tiisg; i < block_len; i += 32) { - float val = exp(float(input[block_offset + i]) - new_max); - output[block_offset + i] = (ftype)val; - block_sum += val; } - block_sum = simd_sum(block_sum); -#else - for (int i = 0; i < block_len; ++i) { - float val = exp(float(input[block_offset + i]) - new_max); - output[block_offset + i] = (ftype)val; - block_sum += val; - } -#endif - - float new_sum = prev_sum * scale + block_sum; - -#ifdef SIMD_GROUP_REDUCE - if (tiisg == 0) { -#endif - runningStats[stat_idx] = (float)new_max; - runningStats[stat_idx + 1] = (float)new_sum; - correctionScale[bh * seq_len + s] = (float)scale; -#ifdef SIMD_GROUP_REDUCE + float r0 = out0.x + out0.y + out0.z + out0.w; + float r1 = out1.x + out1.y + out1.z + out1.w; + r0 = simd_sum(r0); + r1 = simd_sum(r1); + if(tiisg == 0) { + int c0 = hn * head_dim + z; + int co0 = c0 / 4; + int ci0 = c0 % 4; + output[(b * q_seq_len + x) * 4 + ci0 + co0 * param.batch * q_seq_len * 4] = (ftype)r0; + if (z + 1 < head_dim) { + int c1 = c0 + 1; + int co1 = c1 / 4; + int ci1 = c1 % 4; + output[(b * q_seq_len + x) * 4 + ci1 + co1 * param.batch * q_seq_len * 4] = (ftype)r1; + } } -#endif } + )metal"; -const char* gFlashMatMulQKV = R"metal( +const char* gDecodeQkSoftmax = R"metal( #include #include using namespace metal; @@ -1314,812 +1879,179 @@ struct Param { int max_kv_len; int batch; int kv_align_len; + int mask_batch; + int mask_head_num; + int mask_q_len; + int mask_k_len; + float v_scale; + float k_scale; }; +#define SIMD_GROUP_WIDTH 32 -#if MNN_METAL_FLOAT16_STORAGE -typedef simdgroup_half8x8 simdgroup_T8x8; -#else -typedef simdgroup_float8x8 simdgroup_T8x8; +kernel void decode_qk_softmax(const device ftype* input0 [[buffer(0)]], + device ftype* output [[buffer(1)]], + device ftype* past_key [[buffer(2)]], + constant int &seq_idx [[buffer(3)]], + constant Param& param [[buffer(4)]], +#if defined(QUANT_K) && defined(DYNAMIC_QUANT_K) + device ftype* k_scales [[buffer(8)]], #endif - - -kernel void flash_matmul_qkv( - const device ftype* P_block [[buffer(0)]], - device float* Output [[buffer(1)]], - const device ftype* V_block [[buffer(2)]], - const device float* correctionScale [[buffer(3)]], - constant int& kv_start [[buffer(4)]], - constant int& block_len [[buffer(5)]], - constant Param& param [[buffer(6)]], -#if defined(SIMD_GROUP_MATRIX) uint3 gid[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], + uint tid[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]] -#elif defined(SIMD_GROUP_REDUCE) - uint3 gid[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]] -#else - uint3 gid [[thread_position_in_grid]] -#endif + uint sgitg[[simdgroup_index_in_threadgroup]], + uint3 tptg_3d[[threads_per_threadgroup]] ) { -#if defined(SIMD_GROUP_MATRIX) - threadgroup float sdata[256 + 128] = {0.f}; // 128 for A, 128 for B, 128 for C scaling - simdgroup_float8x8 sgd[4]; - for (int i = 0; i < 4; i++){ - sgd[i] = make_filled_simdgroup_matrix(0.f); - } - - const int sl = gid.x; // s / 16 - const int hm = gid.y; // d / 16 - const int bh = gid.z; - - int b = bh / param.head_num; - int h = bh % param.head_num; - int kv_h = h / param.group; - int yin = b * (param.head_num / param.group) + kv_h; - - int rcl = tiitg / 2; // 0~15 - int kl = tiitg % 2; // 0~1 - int nl = tiitg / 8; // 0~3 - int kcl = tiitg % 8; // 0~7 - - int head_dim = param.head_dim; - int q_seq_len = param.query_seq_len; - - if (sl * 16 >= q_seq_len || hm * 16 >= head_dim) return; - - // 0. Load old Output and scale - if (kv_start > 0) { - for (int i = 0; i < 8; ++i) { - int idx = tiitg * 8 + i; - int local_s = idx / 16; - int local_d = idx % 16; - // Map threads to 16x16 block of Output - int cur_s = sl * 16 + local_s; - int cur_d = hm * 16 + local_d; - if (cur_s < q_seq_len && cur_d < head_dim) { - float scale = correctionScale[bh * q_seq_len + cur_s]; - int out_idx = ((b * q_seq_len + cur_s) * param.head_num + h) * head_dim + cur_d; - float val = Output[out_idx] * scale; - // Store to sdata for loading into sgd - ((threadgroup float*)sdata)[local_s * 16 + local_d] = val; - } else { - ((threadgroup float*)sdata)[local_s * 16 + local_d] = 0.f; - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - simdgroup_load(sgd[0], (threadgroup float*)sdata, 16); - simdgroup_load(sgd[1], (threadgroup float*)sdata + 128, 16); - simdgroup_load(sgd[2], (threadgroup float*)sdata + 8, 16); - simdgroup_load(sgd[3], (threadgroup float*)sdata + 136, 16); - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - auto A_offset = P_block + (bh * q_seq_len + sl * 16 + rcl) * block_len + (0 * 2 + kl) * 4 + 0; - auto B_offset = V_block + (yin * head_dim + hm * 16 + nl * 4 + 0) * param.max_kv_len + (kv_start + (0 * 8 + kcl)); - - for (int i = 0; i < block_len; i += 8) { - ((threadgroup ftype*)sdata)[rcl * 8 + kl * 4 + 0] = A_offset[i + 0]; - ((threadgroup ftype*)sdata)[rcl * 8 + kl * 4 + 1] = A_offset[i + 1]; - ((threadgroup ftype*)sdata)[rcl * 8 + kl * 4 + 2] = A_offset[i + 2]; - ((threadgroup ftype*)sdata)[rcl * 8 + kl * 4 + 3] = A_offset[i + 3]; - - ((threadgroup ftype*)sdata)[128 + kcl * 16 + nl * 4 + 0] = B_offset[i + 0 * param.max_kv_len]; - ((threadgroup ftype*)sdata)[128 + kcl * 16 + nl * 4 + 1] = B_offset[i + 1 * param.max_kv_len]; - ((threadgroup ftype*)sdata)[128 + kcl * 16 + nl * 4 + 2] = B_offset[i + 2 * param.max_kv_len]; - ((threadgroup ftype*)sdata)[128 + kcl * 16 + nl * 4 + 3] = B_offset[i + 3 * param.max_kv_len]; - - threadgroup_barrier(mem_flags::mem_threadgroup); - - simdgroup_T8x8 sga[2], sgb[2]; - simdgroup_load(sga[0], (const threadgroup ftype*)sdata, 8); - simdgroup_load(sga[1], ((const threadgroup ftype*)sdata) + 64, 8); - simdgroup_load(sgb[0], ((const threadgroup ftype*)sdata) + 128, 16); - simdgroup_load(sgb[1], ((const threadgroup ftype*)sdata) + 136, 16); - - simdgroup_multiply_accumulate(sgd[0], sga[0], sgb[0], sgd[0]); - simdgroup_multiply_accumulate(sgd[1], sga[1], sgb[0], sgd[1]); - simdgroup_multiply_accumulate(sgd[2], sga[0], sgb[1], sgd[2]); - simdgroup_multiply_accumulate(sgd[3], sga[1], sgb[1], sgd[3]); - - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - simdgroup_store(sgd[0], (threadgroup float*)sdata, 16); - simdgroup_store(sgd[1], (threadgroup float*)sdata + 128, 16); - simdgroup_store(sgd[2], (threadgroup float*)sdata + 8, 16); - simdgroup_store(sgd[3], (threadgroup float*)sdata + 136, 16); - threadgroup_barrier(mem_flags::mem_threadgroup); - - for (int i = 0; i < 8; ++i) { - int idx = tiitg * 8 + i; - int local_s = idx / 16; - int local_d = idx % 16; - int cur_s = sl * 16 + local_s; - int cur_d = hm * 16 + local_d; - if (cur_s < q_seq_len && cur_d < head_dim) { - int out_idx = ((b * q_seq_len + cur_s) * param.head_num + h) * head_dim + cur_d; - Output[out_idx] = ((threadgroup float*)sdata)[local_s * 16 + local_d]; - } - } - -#elif defined(SIMD_GROUP_REDUCE) - int d_vec = gid.x; - int s = gid.y; - int bh = gid.z; - - int head_dim = param.head_dim; - if (d_vec * 4 >= head_dim || s >= param.query_seq_len || bh >= param.batch * param.head_num) return; - - int b = bh / param.head_num; - int h = bh % param.head_num; - int kv_h = h / param.group; - int yin = b * (param.head_num / param.group) + kv_h; - int v_base_offset = yin * head_dim * param.max_kv_len; - - int p_offset = (bh * param.query_seq_len + s) * block_len; - int out_idx = ((b * param.query_seq_len + s) * param.head_num + h) * head_dim + d_vec * 4; - - float4 acc = 0; - if (kv_start > 0 && tiisg == 0) { - acc = float4(Output[out_idx], Output[out_idx+1], Output[out_idx+2], Output[out_idx+3]); - acc *= correctionScale[bh * param.query_seq_len + s]; - } - - float4 p_v_acc = 0; - for (int k = tiisg; k < block_len; k += 32) { - ftype p_val = P_block[p_offset + k]; - int seq_idx = kv_start + k; - float4 v; - v.x = (float)V_block[v_base_offset + (d_vec * 4 + 0) * param.max_kv_len + seq_idx]; - v.y = (float)V_block[v_base_offset + (d_vec * 4 + 1) * param.max_kv_len + seq_idx]; - v.z = (float)V_block[v_base_offset + (d_vec * 4 + 2) * param.max_kv_len + seq_idx]; - v.w = (float)V_block[v_base_offset + (d_vec * 4 + 3) * param.max_kv_len + seq_idx]; - p_v_acc += (float)p_val * v; - } - p_v_acc.x = simd_sum(p_v_acc.x); - p_v_acc.y = simd_sum(p_v_acc.y); - p_v_acc.z = simd_sum(p_v_acc.z); - p_v_acc.w = simd_sum(p_v_acc.w); - - if (tiisg == 0) { - acc += p_v_acc; - Output[out_idx] = acc.x; - Output[out_idx+1] = acc.y; - Output[out_idx+2] = acc.z; - Output[out_idx+3] = acc.w; - } + threadgroup float scores0[2048]; + threadgroup float scores1[2048]; + threadgroup float reduce0[32]; + threadgroup float reduce1[32]; + + const int tptg = int(tptg_3d.x * tptg_3d.y * tptg_3d.z); + const int sg_count = tptg / SIMD_GROUP_WIDTH; + const int kv_head_num = param.head_num / GROUP_SIZE; + const int b = int(gid.x) / kv_head_num; + const int kv_hn = int(gid.x) - b * kv_head_num; +#ifdef HEAD_DIM + const int head_dim = HEAD_DIM; #else - int d_vec = gid.x; - int s = gid.y; - int bh = gid.z; - - int head_dim = param.head_dim; - if (d_vec * 4 >= head_dim || s >= param.query_seq_len || bh >= param.batch * param.head_num) { - return; - } - - int b = bh / param.head_num; - int h = bh % param.head_num; - int kv_h = h / param.group; - - int p_offset = (bh * param.query_seq_len + s) * block_len; - // V layout: [batch, kv_num_head * head_dim, max_kv_len] - // Same as decode_qkv: offset_head = (yin * head_dim + z) * max_kv_len, where yin = b * kv_num_head + kv_h - // So for (batch, kv_head, head_dim_idx), base offset = (b * kv_num_head + kv_h) * head_dim * max_kv_len - int yin = b * (param.head_num / param.group) + kv_h; - int v_base_offset = yin * head_dim * param.max_kv_len; - - int out_idx = ((b * param.query_seq_len + s) * param.head_num + h) * head_dim + d_vec * 4; - float scale = (float)correctionScale[bh * param.query_seq_len + s]; - - float4 acc = 0; - if (kv_start > 0) { - acc = float4(Output[out_idx], Output[out_idx+1], Output[out_idx+2], Output[out_idx+3]); - acc *= (float)scale; - } - - for (int k = 0; k < block_len; ++k) { - ftype p_val = P_block[p_offset + k]; - int seq_idx = kv_start + k; - - // V layout: [batch, kv_num_head * head_dim, max_kv_len] - // For (batch, kv_head, head_dim_idx, seq_idx): offset = ((b * kv_num_head + kv_h) * head_dim + head_dim_idx) * max_kv_len + seq_idx - float v0 = 0, v1 = 0, v2 = 0, v3 = 0; - int d0 = d_vec * 4 + 0; - int d1 = d_vec * 4 + 1; - int d2 = d_vec * 4 + 2; - int d3 = d_vec * 4 + 3; - - if (d0 < head_dim) { - int v_idx0 = v_base_offset + d0 * param.max_kv_len + seq_idx; - v0 = (float)V_block[v_idx0]; - } - if (d1 < head_dim) { - int v_idx1 = v_base_offset + d1 * param.max_kv_len + seq_idx; - v1 = (float)V_block[v_idx1]; - } - if (d2 < head_dim) { - int v_idx2 = v_base_offset + d2 * param.max_kv_len + seq_idx; - v2 = (float)V_block[v_idx2]; - } - if (d3 < head_dim) { - int v_idx3 = v_base_offset + d3 * param.max_kv_len + seq_idx; - v3 = (float)V_block[v_idx3]; - } - - acc += (float)p_val * float4(v0, v1, v2, v3); - } - - Output[out_idx] = (float)acc.x; - Output[out_idx+1] = (float)acc.y; - Output[out_idx+2] = (float)acc.z; - Output[out_idx+3] = (float)acc.w; + const int head_dim = param.head_dim; #endif -} -)metal"; - -const char* gFlashScale = R"metal( -#include -using namespace metal; + const int key_seq_len = param.key_seq_len; + const int align_key_len = ((key_seq_len + param.kv_align_len - 1) / param.kv_align_len) * param.kv_align_len; + const int x = int(gid.y); + const int q_idx = seq_idx * param.q_seq_piece_len + x; -struct Param { - int query_seq_len; - int q_seq_piece_len; - int key_seq_len; - int head_num; - int group; - int head_dim; - float scale; - int max_kv_len; - int batch; - int kv_align_len; -}; - -kernel void flash_scale( - const device float* Input [[buffer(0)]], - device ftype* Output [[buffer(1)]], - const device float* runningStats [[buffer(2)]], - constant Param& param [[buffer(3)]], - uint3 gid [[thread_position_in_grid]] -) { - int d_vec = gid.x; - int s = gid.y; - int bh = gid.z; - - if (d_vec * 4 >= param.head_dim || s >= param.query_seq_len || bh >= param.batch * param.head_num) { + if (b >= param.batch || kv_hn >= kv_head_num || x >= param.q_seq_piece_len || q_idx >= param.query_seq_len) { return; } - - int stat_idx = (bh * param.query_seq_len + s) * 2; - float sum = (float)runningStats[stat_idx + 1]; - float inv_sum = 1.0 / sum; - - int b = bh / param.head_num; - int h = bh % param.head_num; - - int out_idx = ((b * param.query_seq_len + s) * param.head_num + h) * param.head_dim + d_vec * 4; - - Output[out_idx] = (ftype)(inv_sum * (float)Input[out_idx] ); - Output[out_idx+1] = (ftype)(inv_sum * (float)Input[out_idx+1]); - Output[out_idx+2] = (ftype)(inv_sum * (float)Input[out_idx+2]); - Output[out_idx+3] = (ftype)(inv_sum * (float)Input[out_idx+3]); -} -)metal"; - - -const char* gFlashAttentionFused = R"metal( -#include -#include -using namespace metal; - -struct Param { - int query_seq_len; - int q_seq_piece_len; - int key_seq_len; - int head_num; - int group; - int head_dim; - float scale; - int max_kv_len; - int batch; - int kv_align_len; -}; -#if MNN_METAL_FLOAT16_STORAGE -typedef simdgroup_half8x8 simdgroup_T8x8; + const int head0 = kv_hn * GROUP_SIZE; + const int head1 = head0 + 1; + const int query_offset = (b * param.query_seq_len + q_idx) * param.head_num * head_dim; + const device ftype* query0 = input0 + query_offset + head0 * head_dim; + const device ftype* query1 = input0 + query_offset + head1 * head_dim; + const int key_head_offset = kv_hn * head_dim; + const int key_stride = kv_head_num * head_dim; + + float local_max0 = -FLT_MAX; + float local_max1 = -FLT_MAX; + const int kv_valid_limit = max(key_seq_len - param.query_seq_len, 0) + q_idx; + for (int k = int(tid); k < key_seq_len; k += tptg) { +#ifdef QUANT_K + const device char* key = (const device char*)past_key + (k * param.batch + b) * key_stride + key_head_offset; #else -typedef simdgroup_float8x8 simdgroup_T8x8; + const device ftype* key = past_key + (k * param.batch + b) * key_stride + key_head_offset; +#endif + float s0 = 0.0f; + float s1 = 0.0f; + const device ftype4* q04 = (const device ftype4*)query0; + const device ftype4* q14 = (const device ftype4*)query1; +#ifdef QUANT_K + const device char4* k4 = (const device char4*)key; +#ifdef DYNAMIC_QUANT_K + const int k_token_idx = k * param.batch + b; + const float k_scale = float(k_scales[k_token_idx * 2]); + const float k_bias = float(k_scales[k_token_idx * 2 + 1]); #endif - -// 定义支持缓存的最大 Head Dim -#define MAX_HEAD_DIM 128 -// Padding 步长,避免 Shared Memory Bank Conflict -// 128 + 8 = 136,错开 Bank 索引 -#define Q_SMEM_STRIDE (MAX_HEAD_DIM + 8) - -// 优化配置 -#define Q_BLOCK 8 -#define K_BLOCK_16 16 -#define TG_SIZE 128 -#define SIMD_GROUPS 4 -#define K_BLOCK 64 - - - -#define HEAD_DIM 128 -typedef uint4 vec_128b; - - -// 调整为 5120 (20KB),M4 缓存充足 -#define SMEM_SIZE 4096 - -kernel void flash_attention_fused( - const device ftype* query [[buffer(0)]], - const device ftype* key [[buffer(1)]], - const device ftype* value [[buffer(2)]], - const device ftype* mask [[buffer(3)]], - device ftype* output [[buffer(4)]], - constant Param& param [[buffer(5)]], - uint ltid [[thread_index_in_threadgroup]], // 0..127 global inside group -#if defined(SIMD_GROUP_REDUCE) - uint3 gid [[thread_position_in_grid]], - uint3 tgid [[threadgroup_position_in_grid]], - uint tiisg [[thread_index_in_simdgroup]], - uint sgitg [[simdgroup_index_in_threadgroup]] #else - uint3 gid [[thread_position_in_grid]], - uint3 tgid [[threadgroup_position_in_grid]], - uint tiisg [[thread_index_in_simdgroup]], - uint sgitg [[simdgroup_index_in_threadgroup]] + const device ftype4* k4 = (const device ftype4*)key; #endif -) { -#ifdef SIMD_GROUP_MATRIX - - // 8个query threadgroup=128 K_BLOCK=16 - // Shared Memory 布局 - threadgroup ftype sdata[4096]; - threadgroup ftype* sdata_q = sdata; - threadgroup float* sdata_work = (threadgroup float*)(sdata + Q_BLOCK * Q_SMEM_STRIDE); - - threadgroup float* sdata_partials = sdata_work + 128; - threadgroup float* sdata_scale = sdata_work + 128; - threadgroup float* sdata_final_sum = sdata_work + 136; - threadgroup float* sdata_scratch = sdata_work + 512; - - int sl_blk = tgid.x; - int bh = tgid.y; - int head_dim = param.head_dim; - int q_seq_len = param.query_seq_len; - - if (sl_blk * Q_BLOCK >= q_seq_len) return; - - int b = bh / param.head_num; - int h = bh % param.head_num; - int kv_h = h / param.group; - int kv_len = param.key_seq_len; - int max_kv_len = param.max_kv_len; - - // 1. 协作加载 Query - { - int q_base_offset = ((b * q_seq_len + sl_blk * Q_BLOCK) * param.head_num + h) * head_dim; - const device ftype* q_ptr_base = query + q_base_offset; - int q_global_stride = param.head_num * head_dim; - - for (int i = ltid; i < Q_BLOCK * head_dim; i += TG_SIZE) { - int r = i / head_dim; - int c = i % head_dim; - - // [Fix] 检查 Query 是否越界,越界部分填充 0 避免计算错误 - int global_r = sl_blk * Q_BLOCK + r; - if (global_r < q_seq_len) { - sdata_q[r * Q_SMEM_STRIDE + c] = q_ptr_base[r * q_global_stride + c]; - } else { - sdata_q[r * Q_SMEM_STRIDE + c] = 0.0f; - } - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - - // 2. 状态初始化 - float row_max = -FLT_MAX; - float row_sum = 0.0f; - - simdgroup_float8x8 acc_reg[2][2]; - for(int i=0; i<2; ++i) for(int j=0; j<2; ++j) acc_reg[i][j] = make_filled_simdgroup_matrix(0.f); - - int k_seq_stride = param.batch * (param.head_num / param.group) * head_dim; - int k_base = (b * (param.head_num / param.group) + kv_h) * head_dim; - int v_base = (b * (param.head_num / param.group) + kv_h) * head_dim * max_kv_len; - - - const int global_r = sl_blk * Q_BLOCK + Q_BLOCK - 1 + kv_len - param.query_seq_len; - - // 3. 主循环 K Block - // ========================================== - for (int t_blk = 0; t_blk < kv_len; t_blk += K_BLOCK_16) { - - bool skip_block = false; - - // Block mask all -inf - if(t_blk > global_r) { - skip_block = true; + for (int d = 0; d < head_dim / 8; ++d) { +#ifdef QUANT_K +#ifdef DYNAMIC_QUANT_K + float4 k0 = float4(k4[d * 2 + 0]) * k_scale + k_bias; + float4 k1 = float4(k4[d * 2 + 1]) * k_scale + k_bias; +#else + float4 k0 = float4(k4[d * 2 + 0]) * param.k_scale; + float4 k1 = float4(k4[d * 2 + 1]) * param.k_scale; +#endif +#else + float4 k0 = float4(k4[d * 2 + 0]); + float4 k1 = float4(k4[d * 2 + 1]); +#endif + s0 += dot(float4(q04[d * 2 + 0]), k0) + dot(float4(q04[d * 2 + 1]), k1); + s1 += dot(float4(q14[d * 2 + 0]), k0) + dot(float4(q14[d * 2 + 1]), k1); } - -#if 0 - // ============================================================ - // [Optimization] Fast & Safe Mask Check (Hybrid Version) - // ============================================================ - { - // 1. 内存安全:将 scratch 移至 +1024,避开 partials (0~640) - threadgroup float* sdata_flag = sdata_work + 1024; - - // 2. 每个线程计算局部最大值 - float my_max = -FLT_MAX; - - // 循环 stride 覆盖 (兼容 TG_SIZE < 128) - for (int i = ltid; i < 128; i += TG_SIZE) { - int r_local = i / 16; - int c_local = i % 16; - - int global_r = sl_blk * Q_BLOCK + r_local; - int global_c = t_blk + c_local; - - if (global_c < kv_len) { - #ifdef ADD_MASK - int mask_offset = global_c - kv_len + param.query_seq_len; - if (global_r < q_seq_len) { - // 如果 mask_offset 有效,读取 Mask - if (mask_offset >= 0 && mask_offset < param.query_seq_len) { - float m_val = (float)mask[global_r * param.query_seq_len + mask_offset]; - my_max = max(my_max, m_val); - } else { - // 越界 (通常是左侧非Mask区) = 有效 (0.0f) -> 不能跳过 - my_max = max(my_max, 0.0f); - } - } - #elif defined(SET_MASK) - if (global_r < q_seq_len) { - float m_val = (float)mask[global_r * kv_len + global_c]; - // SET_MASK: 0 表示 Mask, 非0 表示 Keep - if (m_val != 0.0f) my_max = max(my_max, 0.0f); - } - #else - my_max = 0.0f; // 无 Mask 宏,始终 Active - #endif - } - } - - // 3. SIMD Group 内快速归约 - float sg_max = simd_max(my_max); - - // 4. 写入 Shared Memory (仅每个SG的第一个线程写) - if (tiisg == 0) { - sdata_flag[sgitg] = sg_max; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // 5. Thread 0 汇总 (仅需检查 4 个值) - if (ltid == 0) { - float block_max = -FLT_MAX; - int num_sg = (TG_SIZE + 31) / 32; - - for (int i = 0; i < num_sg; ++i) { - block_max = max(block_max, sdata_flag[i]); - } - - // 阈值判定:只有全为极小值时才跳过 - // 写入 1.0f 表示跳过 - sdata_flag[0] = (block_max <= -10000.0f) ? 1.0f : 0.0f; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // 6. 读取跳过标志 - if (sdata_flag[0] > 0.5f) { - skip_block = true; - } + s0 *= param.scale; + s1 *= param.scale; + if (k > kv_valid_limit) { + s0 = -FLT_MAX; + s1 = -FLT_MAX; } -#endif - // ============================================================ - - // 使用 if 包裹,而非 continue,确保控制流绝对安全 - if (!skip_block) { - - // --- Step A: Q * K^T --- - simdgroup_float8x8 sg_score[2]; - sg_score[0] = make_filled_simdgroup_matrix(0.f); - sg_score[1] = make_filled_simdgroup_matrix(0.f); - - int d_start = sgitg * 32; - int d_end = min(d_start + 32, head_dim); - - for (int d = d_start; d < d_end; d += 8) { - simdgroup_T8x8 sgq; - simdgroup_T8x8 sgk[2]; - - simdgroup_load(sgq, sdata_q + d, Q_SMEM_STRIDE, ulong2(0), false); - - const device ftype* k_curr = key + k_base + t_blk * k_seq_stride + d; - ulong k_stride = ulong(k_seq_stride); - - simdgroup_barrier(mem_flags::mem_none); - simdgroup_load(sgk[0], k_curr, k_stride, ulong2(0), true); - simdgroup_load(sgk[1], k_curr + 8 * k_seq_stride, k_stride, ulong2(0), true); - simdgroup_barrier(mem_flags::mem_none); - - simdgroup_multiply_accumulate(sg_score[0], sgq, sgk[0], sg_score[0]); - simdgroup_multiply_accumulate(sg_score[1], sgq, sgk[1], sg_score[1]); - } - - int smem_score_offset = sgitg * 128; - simdgroup_store(sg_score[0], sdata_partials + smem_score_offset, 16); - simdgroup_store(sg_score[1], sdata_partials + smem_score_offset + 8, 16); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // --- Step B: Reduction --- - if (ltid < 128) { - float sum = 0.0f; - #pragma unroll - for (int i = 0; i < SIMD_GROUPS; ++i) { - sum += sdata_partials[i * 128 + ltid]; - } - sdata_work[ltid] = sum; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // --- Step C: Softmax & P (SG0 Only) --- - if (ltid < 8) { - float m_prev = row_max; - float s_prev = row_sum; - float m_curr = -FLT_MAX; - - for (int j=0; j<16; ++j) { - float val = sdata_work[ltid * 16 + j] * param.scale; - int ti = t_blk + j; - - #ifdef ADD_MASK - int mask_offset = ti - kv_len + param.query_seq_len; - if (ti < kv_len && mask_offset >= 0 && mask_offset < param.query_seq_len) - val += (float)mask[(sl_blk * Q_BLOCK + ltid) * param.query_seq_len + mask_offset]; - else if (ti >= kv_len) val = -FLT_MAX; - #elif defined(SET_MASK) - if (ti >= kv_len || mask[(sl_blk * Q_BLOCK + ltid) * kv_len + ti] == 0) val = -FLT_MAX; - #elif defined(CAUSAL_MASK) - // Causal mask: keep if ti <= query_pos + (kv_len - query_seq_len), else -inf - int query_pos = sl_blk * Q_BLOCK + ltid; - if (ti > query_pos + (kv_len - param.query_seq_len)) val = -FLT_MAX; - #endif - - sdata_work[ltid * 16 + j] = val; - m_curr = max(m_curr, val); - } - - float m_new = max(m_prev, m_curr); - float exp_diff = exp(m_prev - m_new); - float s_curr = 0.0f; - - threadgroup ftype* sdata_p_out = (threadgroup ftype*)sdata_work; - for (int j=0; j<16; ++j) { - float p = exp(sdata_work[ltid * 16 + j] - m_new); - sdata_p_out[ltid * 16 + j] = (ftype)p; - s_curr += p; - } - - row_max = m_new; - row_sum = s_prev * exp_diff + s_curr; - sdata_scale[ltid] = exp_diff; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // --- Step D: P * V --- - for (int iter = 0; iter < 2; ++iter) { - int d_tile = sgitg * 2 + iter; - if (d_tile * 16 >= head_dim) continue; - - // 注意:这里继续使用 sdata_scratch (offset 512) 没问题, - // 因为它只在 Block 计算内部使用,不会跨 Block 影响 Skip 逻辑。 - // 且 Skip Flag 已经用完了。 - threadgroup float* my_scratch = sdata_work + 512 + sgitg * 128; - - simdgroup_store(acc_reg[iter][0], my_scratch, 16); - simdgroup_store(acc_reg[iter][1], my_scratch + 8, 16); - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg < 8) { - float sc = sdata_scale[tiisg]; - #pragma unroll - for (int j=0; j<16; ++j) my_scratch[tiisg * 16 + j] *= sc; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - simdgroup_load(acc_reg[iter][0], my_scratch, 16); - simdgroup_load(acc_reg[iter][1], my_scratch + 8, 16); - - threadgroup ftype* sdata_p = (threadgroup ftype*)sdata_work; - simdgroup_T8x8 sgp[2]; - simdgroup_load(sgp[0], sdata_p, 16, ulong2(0), false); - simdgroup_load(sgp[1], sdata_p + 8, 16, ulong2(0), false); - - int d_start = d_tile * 16; - const device ftype* v_curr = value + v_base + d_start * max_kv_len + t_blk; - - simdgroup_T8x8 sgv[4]; - simdgroup_barrier(mem_flags::mem_none); - simdgroup_load(sgv[0], v_curr, max_kv_len, ulong2(0), true); - simdgroup_load(sgv[1], v_curr + 8 * max_kv_len, max_kv_len, ulong2(0), true); - simdgroup_load(sgv[2], v_curr + 8, max_kv_len, ulong2(0), true); - simdgroup_load(sgv[3], v_curr + 8 * max_kv_len + 8, max_kv_len, ulong2(0), true); - simdgroup_barrier(mem_flags::mem_none); - - simdgroup_multiply_accumulate(acc_reg[iter][0], sgp[0], sgv[0], acc_reg[iter][0]); - simdgroup_multiply_accumulate(acc_reg[iter][0], sgp[1], sgv[2], acc_reg[iter][0]); - simdgroup_multiply_accumulate(acc_reg[iter][1], sgp[0], sgv[1], acc_reg[iter][1]); - simdgroup_multiply_accumulate(acc_reg[iter][1], sgp[1], sgv[3], acc_reg[iter][1]); - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } // End of if (!skip_block) - } // End K Loop + scores0[k] = s0; + scores1[k] = s1; + local_max0 = max(local_max0, s0); + local_max1 = max(local_max1, s1); + } - // 4. Output Finalization (保持不变) - if (ltid < 8) { - sdata_final_sum[ltid] = row_sum; + local_max0 = simd_max(local_max0); + local_max1 = simd_max(local_max1); + if (tiisg == 0) { + reduce0[sgitg] = local_max0; + reduce1[sgitg] = local_max1; } threadgroup_barrier(mem_flags::mem_threadgroup); - - threadgroup float* my_out_buf = sdata_scratch + sgitg * 128; - - for (int iter = 0; iter < 2; ++iter) { - int d_tile = sgitg * 2 + iter; - if (d_tile * 16 >= head_dim) continue; - - simdgroup_store(acc_reg[iter][0], my_out_buf, 16); - simdgroup_store(acc_reg[iter][1], my_out_buf + 8, 16); - simdgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg < 8) { - float inv_sum = 1.0f / sdata_final_sum[tiisg]; - int qi = sl_blk * Q_BLOCK + tiisg; - if (qi < q_seq_len) { - device ftype* out_ptr = output + ((b * q_seq_len + qi) * param.head_num + h) * head_dim + d_tile * 16; - #pragma unroll - for (int j=0; j<16; ++j) { - if (d_tile * 16 + j < head_dim) { - out_ptr[j] = (ftype)(my_out_buf[tiisg * 16 + j] * inv_sum); - } - } - } + if (sgitg == 0 && tiisg == 0) { + float max0 = -FLT_MAX; + float max1 = -FLT_MAX; + for (int i = 0; i < sg_count; ++i) { + max0 = max(max0, reduce0[i]); + max1 = max(max1, reduce1[i]); } - simdgroup_barrier(mem_flags::mem_threadgroup); + reduce0[0] = max0; + reduce1[0] = max1; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + const float max0 = reduce0[0]; + const float max1 = reduce1[0]; + + float local_sum0 = 0.0f; + float local_sum1 = 0.0f; + for (int k = int(tid); k < key_seq_len; k += tptg) { + float v0 = exp(scores0[k] - max0); + float v1 = exp(scores1[k] - max1); + scores0[k] = v0; + scores1[k] = v1; + local_sum0 += v0; + local_sum1 += v1; } -#else - // ===== Optimized Basic Version: Threadgroup parallel without simd_sum ===== - // Grid: [SeqLen, Batch*Head, 1], Threadgroup: [THREADS_PER_GROUP, 1, 1] - // Each threadgroup processes one Q token, threads cooperate on dimension reduction - - threadgroup float shared_reduce[256]; // For manual reduction, max 256 threads - - int s = tgid.x; // query sequence position - int bh = tgid.y; // batch * head_num - uint tid = tiisg; // thread index in simdgroup (0-31) - uint threads_per_group = sgitg * 32 + tiisg; // global thread index in threadgroup - - if (s >= param.query_seq_len || bh >= param.batch * param.head_num) return; - - int b = bh / param.head_num; - int h = bh % param.head_num; - int kv_h = h / param.group; - - int head_dim = param.head_dim; - int kv_len = param.key_seq_len; - int max_kv_len = param.max_kv_len; - int group = param.group; - - int q_offset = ((b * param.query_seq_len + s) * param.head_num + h) * head_dim; - - // Each thread processes multiple dimensions - int d_per_thread = (head_dim + 31) / 32; // Assume 32 threads per group - - float acc[8] = {0.0f}; - - // Load Q values for this thread's dimensions - float q_local[8] = {0.0f}; - for (int i = 0; i < d_per_thread; ++i) { - int d = tid + i * 32; - if (d < head_dim) { - q_local[i] = (float)query[q_offset + d]; - } + local_sum0 = simd_sum(local_sum0); + local_sum1 = simd_sum(local_sum1); + if (tiisg == 0) { + reduce0[sgitg] = local_sum0; + reduce1[sgitg] = local_sum1; } - - float cur_max = -FLT_MAX; - float cur_sum = 0.0f; - - // K/V offsets - int kv_head_num = param.head_num / group; - int k_seq_stride = param.batch * kv_head_num * head_dim; - int k_base_offset = (b * kv_head_num + kv_h) * head_dim; - auto k_ptr_start = key + k_base_offset; - - int v_base_offset = (b * kv_head_num + kv_h) * head_dim * max_kv_len; - - for (int t = 0; t < kv_len; ++t) { - auto k_ptr_t = k_ptr_start + t * k_seq_stride; - - // Each thread computes partial dot product for its dimensions - float partial_dot = 0.0f; - for (int i = 0; i < d_per_thread; ++i) { - int d = tid + i * 32; - if (d < head_dim) { - partial_dot += q_local[i] * (float)k_ptr_t[d]; - } - } - - // Manual reduction across threads (no simd_sum) - shared_reduce[tid] = partial_dot; - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Tree reduction in threadgroup memory - float score = 0.0f; - if (tid == 0) { - for (int i = 0; i < 32; ++i) { - score += shared_reduce[i]; - } - score *= param.scale; - - #ifdef ADD_MASK - int mask_offset = t - kv_len + param.query_seq_len; - if (mask_offset >= 0 && mask_offset < param.query_seq_len) { - float m = (float)mask[s * param.query_seq_len + mask_offset]; - score += m; - } - #elif defined(SET_MASK) - int mask_val = mask[s * kv_len + t]; - if (mask_val == 0) score = -FLT_MAX; - #elif defined(CAUSAL_MASK) - // Causal mask: keep if t <= s + (kv_len - query_seq_len), else -inf - if (t > s + (kv_len - param.query_seq_len)) score = -FLT_MAX; - #endif - - shared_reduce[0] = score; // Store score for all threads - } - threadgroup_barrier(mem_flags::mem_threadgroup); - score = shared_reduce[0]; // All threads read the score - - // Online softmax update - float new_max = max(cur_max, score); - float exp_score = exp(score - new_max); - float running_scale = exp(cur_max - new_max); - - cur_sum = cur_sum * running_scale + exp_score; - cur_max = new_max; - - // Update accumulator for each dimension - for (int i = 0; i < d_per_thread; ++i) { - int d = tid + i * 32; - if (d < head_dim) { - float v_val = (float)value[v_base_offset + d * max_kv_len + t]; - acc[i] = acc[i] * running_scale + v_val * exp_score; - } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (sgitg == 0 && tiisg == 0) { + float sum0 = 0.0f; + float sum1 = 0.0f; + for (int i = 0; i < sg_count; ++i) { + sum0 += reduce0[i]; + sum1 += reduce1[i]; } + reduce0[0] = sum0; + reduce1[0] = sum1; } - - float inv_sum = 1.0f / cur_sum; - - // Write output - auto out_ptr = output + ((b * param.query_seq_len + s) * param.head_num + h) * head_dim; - for (int i = 0; i < d_per_thread; ++i) { - int d = tid + i * 32; - if (d < head_dim) { - out_ptr[d] = (ftype)(acc[i] * inv_sum); - } + threadgroup_barrier(mem_flags::mem_threadgroup); + const float inv_sum0 = 1.0f / reduce0[0]; + const float inv_sum1 = 1.0f / reduce1[0]; + + const int base0 = ((b * param.head_num + head0) * param.query_seq_len + q_idx) * align_key_len; + const int base1 = ((b * param.head_num + head1) * param.query_seq_len + q_idx) * align_key_len; + for (int k = int(tid); k < key_seq_len; k += tptg) { + output[base0 + k] = (ftype)(scores0[k] * inv_sum0); + output[base1 + k] = (ftype)(scores1[k] * inv_sum1); + } + for (int k = int(tid) + key_seq_len; k < align_key_len; k += tptg) { + output[base0 + k] = (ftype)0.0f; + output[base1 + k] = (ftype)0.0f; } -#endif } )metal"; -#endif/* MNN_SUPPORT_TRANSFORMER_FUSE */ -#endif +// softmax sg reduce source moved to MetalSoftmaxShader.cpp +#endif /* MNN_SUPPORT_TRANSFORMER_FUSE */ +#endif diff --git a/source/backend/metal/MetalBackend.hpp b/source/backend/metal/MetalBackend.hpp index 5990fc6f32..31abac9ee5 100644 --- a/source/backend/metal/MetalBackend.hpp +++ b/source/backend/metal/MetalBackend.hpp @@ -89,6 +89,10 @@ class MetalRuntime : public Runtime { BufferAllocator* createDynamicAllocator(int index, bool secondResize) const; mutable id _waiting = nil; + size_t maxThreadSize() const { + return mMaxThreadSize; + } + private: MetalRuntime(void* context); void* mContext = nullptr; @@ -114,6 +118,7 @@ class MetalRuntime : public Runtime { bool mSimdGroupReduce; bool mSimdGroupMatrix; bool mTensorOps; + size_t mMaxThreadSize; }; @@ -284,6 +289,7 @@ class MetalBackend : public Backend { bool mUseFloatAsFp16; bool mIsIphone = false; BufferAllocator* mCurrentAllocator = nullptr; + std::shared_ptr mExecutionBufferPool; }; diff --git a/source/backend/metal/MetalBackend.mm b/source/backend/metal/MetalBackend.mm index 1184cd30cf..4edbeaadeb 100644 --- a/source/backend/metal/MetalBackend.mm +++ b/source/backend/metal/MetalBackend.mm @@ -86,6 +86,7 @@ static void _MetalApplyTensor(uint8_t* host, size_t offset, Tensor* t) { mRuntime = runtime; auto ctx = (__bridge MNNMetalContext *)runtime->context(); mBufferPool.reset(runtime->createDynamicAllocator(0, false)); + mExecutionBufferPool.reset(new EagerBufferAllocator(runtime->buffer(0)->root, 1024)); mCurrentAllocator = mBufferPool.get(); mUseFloatAsFp16 = usefp16AsFp32; mMemoryMode = mode; @@ -269,6 +270,10 @@ MemChunk chunk() override { buffer = mCurrentAllocator->alloc(size, true); allocator = mCurrentAllocator; } break; + case Backend::DYNAMIC_IN_EXECUTION: { + buffer = mExecutionBufferPool->alloc(size, false); + allocator = mExecutionBufferPool.get(); + } break; default:{ break; } @@ -291,6 +296,9 @@ MemChunk chunk() override { bool MetalBackend::onClearBuffer() { mCurrentAllocator->release(true); + if (mExecutionBufferPool.get() != nullptr) { + mExecutionBufferPool->release(true); + } if (nullptr != mRuntime->mStaticAllocatorRaw.get()) { mRuntime->mStaticAllocator->sync(); mRuntime->mStaticAllocator = mRuntime->mStaticAllocatorRaw; @@ -1076,6 +1084,7 @@ static void _execute(id encoder, const MetalBackend::C mSimdGroupReduce = false; mSimdGroupMatrix = false; } + mMaxThreadSize = [[ctx device] maxThreadsPerThreadgroup].width; // Metal4 Support M1/A14 and later chips #ifdef MNN_METAL_TENSOR if (@available(iOS 13.0, macOS 10.15, *)) { diff --git a/source/backend/metal/MetalBinary.mm b/source/backend/metal/MetalBinary.mm index 997e2de298..c7f811d66a 100755 --- a/source/backend/metal/MetalBinary.mm +++ b/source/backend/metal/MetalBinary.mm @@ -66,6 +66,12 @@ */ return @"select(V0%V1,(V0%V1)+V1,(V0%V1<0&&V1>0)||(V0%V1>0&&V1<0))"; } + if (BinaryOpOperation_MUL_SILU == originOp) { + if (!inputFloat) { + return nil; + } + return @"V0*(V1/(1.0f+exp(-V1)))"; + } CHECK(BinaryOpOperation_ADD, @"V0+V1"); CHECK(BinaryOpOperation_ATAN2, @"atan2(V0,V1)"); CHECK(BinaryOpOperation_SUB, @"V0-V1"); diff --git a/source/backend/metal/MetalConvolution1x1.mm b/source/backend/metal/MetalConvolution1x1.mm index 8bf5237efc..0196842ddd 100644 --- a/source/backend/metal/MetalConvolution1x1.mm +++ b/source/backend/metal/MetalConvolution1x1.mm @@ -9,6 +9,7 @@ #import "backend/metal/MetalConvolution1x1.hpp" #import "core/Macro.h" #import "backend/metal/MetalBackend.hpp" +#import "backend/metal/MetalSharedGather.hpp" #import "ConvSimdGroupShader.hpp" #if MNN_METAL_ENABLED @@ -53,6 +54,17 @@ if (nullptr == dst) { return true; } + if (op->type() == OpType_GatherV2) { + // SharedGather path: reuse quantized weight and dequant resources + if (!mDequantScaleBias.get() || (mDequantBits != 4 && mDequantBits != 8)) { + // Quantized weight is required for SharedGather + return false; + } + auto conv2D = mOp->main_as_Convolution2D(); + int oc = conv2D->common()->outputCount(); + *dst = new MetalSharedGather(bn, oc, mWeight, mDequantScaleBias, mDequantBits, mScaleCoef); + return true; + } *dst = new MetalConvolution1x1(bn, op, mWeight, mBias, mDequantScaleBias, mDequantBits, mScaleCoef); return true; } @@ -130,7 +142,7 @@ MetalRuntime* rt = (MetalRuntime *)backend->runtime(); std::string basicShaderPrefix = gBasicConvPrefix; - + // if M is small, dequant weight in shader // if device not support simdgroup matrix, only support dequant in shader bool dequantInShader = (area < 64) || !(rt->supportSimdGroupMatrix()); @@ -143,11 +155,11 @@ dequantInShader = false; } mPreDequantWeight = false; - + #ifdef MNN_LOW_MEMORY if (mDequantScaleBias.get() && dequantInShader) { //printf("inner dequant MNK: %d %d %d %d\n", area, oc, ic, blockSize); - + std::string sgmWqShader = gConv1x1WqSgMatrix; std::string sgrWqShader = gConv1x1WqSgReduce; @@ -184,7 +196,7 @@ } if(rt->supportSimdGroupReduce() && area <= short_seq) { baseKeys.emplace_back("conv1x1_wquant_sg_reduce"); - + std::string sgrWqStr = basicShaderPrefix + sgrWqShader; if(area > 1) { auto keys = baseKeys; @@ -212,7 +224,7 @@ mPipeline = pipeline; mThreads = std::make_pair(MTLSizeMake(UP_DIV(oc, 4), piece, 1), MTLSizeMake(32, 1, 1)); } else if(mDequantBits != 2 && mDequantBits != 3 && oc > 16384 && oc_4 % 2 == 0) { - // g16 path not extended for W_QUANT_2/3 — fall back to g8. + // g16 path not extended for W_QUANT_2/3, fall back to g8. auto keys = baseKeys; keys.emplace_back("conv1x1_gemv_g16_wquant_sg"); auto pipeline = rt->findPipeline(keys); @@ -263,7 +275,7 @@ } mPipeline = pipeline; mThreads = std::make_pair(MTLSizeMake(UP_DIV(area, 32), UP_DIV(oc, 64), 1), MTLSizeMake(128, 1, 1)); - + } else if(area >= 32 && area * oc > 128 * 2048) { auto keys = baseKeys; keys.emplace_back("conv1x1_gemm_32x16_wquant_sg"); @@ -364,16 +376,16 @@ std::string sgmWfpShader = gConv1x1WfpSgMatrix; std::string sgrWfpShader = gConv1x1WfpSgReduce; - + // Dequant using single shader if (mDequantScaleBias.get()) { baseKeys.emplace_back("conv1x1_dequant_weight_outter"); std::string sgmWfpStr = basicShaderPrefix + sgmWfpShader; - + mPreDequantWeight = true; { NSMutableDictionary *dic = [baseDic mutableCopy]; - + auto keys = baseKeys; keys.emplace_back("conv1x1_w_dequant"); if(mDequantBits == 2) { @@ -394,27 +406,27 @@ keys.emplace_back("W_ALIGN_K16_PROTECT"); } option.preprocessorMacros = dic; - + int bytes = backend->useFp16InsteadFp32() ? 2 : 4; // accquire space - mTempWeight.reset(Tensor::createDevice(std::vector{ROUND_UP(oc, 4) * ROUND_UP(ic, 16) * bytes})); + mTempWeight.reset(Tensor::createDevice(std::vector{ROUND_UP(oc, 4) * ROUND_UP(ic, 32) * bytes})); backend->onAcquireBuffer(mTempWeight.get(), Backend::DYNAMIC); backend->onReleaseBuffer(mTempWeight.get(), Backend::DYNAMIC); - + auto pipeline = rt->findPipeline(keys); if (nil == pipeline) { pipeline = backend->makeComputePipelineWithSourceOption(sgmWfpStr.c_str(), "conv1x1_w_dequant", option); rt->insertPipeline(keys, pipeline); } mDequantPipeline = pipeline; - + mDequantThreads = [context computeBestGroupAndLocal:pipeline threads:MTLSizeMake(UP_DIV(oc, 1), UP_DIV(ic, 16), 1)]; } - + { auto keys = baseKeys; keys.emplace_back("conv1x1_gemm_32x64_split_k_sg"); - + NSMutableDictionary *dic = [baseDic mutableCopy]; if (ic_4 % 8 != 0) { [dic setValue:@"1" forKey:@"MNN_METAL_SRC_PROTECT"]; @@ -439,10 +451,10 @@ mThreads = std::make_pair(MTLSizeMake(UP_DIV(area, 32), UP_DIV(oc, 64), 1), MTLSizeMake(128, 1, 1)); //printf("out dequant MNK: %d %d %d %d\n", area, oc, ic, blockSize); } - + return NO_ERROR; } - + option.preprocessorMacros = baseDic; if(rt->supportSimdGroupMatrix()) { @@ -617,7 +629,7 @@ static_cast(backend())->flushEncoder(); static_cast(backend())->commit_net(); static_cast(backend())->wait(); - + auto buffer = static_cast(backend())->getBuffer(input); auto ptr = (float*)((int8_t*)buffer.first.contents + buffer.second); for(int i=0; i<64; i++) { @@ -641,7 +653,7 @@ } printf("\n\n"); } - + { auto buffer = static_cast(backend())->getBuffer(output); auto ptr = (float*)((int8_t*)buffer.first.contents + buffer.second); diff --git a/source/backend/metal/MetalConvolutionCommon.mm b/source/backend/metal/MetalConvolutionCommon.mm index f2bba2a3d9..26a06d2788 100644 --- a/source/backend/metal/MetalConvolutionCommon.mm +++ b/source/backend/metal/MetalConvolutionCommon.mm @@ -415,21 +415,19 @@ kernel void weight_transform_common(const device IType* src [[buffer(0)]], #ifdef MNN_LOW_MEMORY if (subBits == 3) { // 3-bit packed: 6 bytes / (4 OC, 4 IC) tile. - // Bytes 0..3: low 2 bits. Byte ro holds OC=ro, bits [7:6]=IC0..[1:0]=IC3, value=signed+4 lower 2 bits. - // Bytes 4..5: high 1 bit. Byte 4 holds OC{0,1}, byte 5 holds OC{2,3}. - // Within high-byte: upper nibble [7..4] = OC even (i%2==0), bit (3-k) = IC k high bit. - // lower nibble [3..0] = OC odd (i%2==1), bit (3-k) = IC k high bit. size_t weight_bytes = (size_t)group * goc_4 * gic_4 * kh * kw * 6; std::shared_ptr weightLow(MNN::Tensor::createDevice({(int)weight_bytes})); if (!backend->onAcquireBuffer(weightLow.get(), Backend::STATIC)) { MNN_ERROR("Memory alloc error!\n"); return nullptr; } - if (nil == src) return weightLow; + if (nil == src) { + return weightLow; + } auto buf = MetalBackend::getBuffer(weightLow.get()); auto dstPtr = (uint8_t*)[buf.first contents] + buf.second; ::memset(dstPtr, 0, weight_bytes); - auto srcPtr = (const int8_t*)src; // signed [-4,3] per weight + auto srcPtr = (const int8_t*)src; for (int g = 0; g < group; g++) { for (int o = 0; o < goc; o++) { int zo = o / 4, ro = o % 4; @@ -438,11 +436,9 @@ kernel void weight_transform_common(const device IType* src [[buffer(0)]], for (int h = 0; h < kh; h++) { for (int w = 0; w < kw; w++) { int srcIdx = ((g * goc + o) * gic + i) * kh * kw + h * kw + w; - int sv = (int)srcPtr[srcIdx] + 4; // unsigned [0,7] + int sv = (int)srcPtr[srcIdx] + 4; int tileBase = (((g * goc_4 + zo) * gic_4 + zi) * kh + h) * kw * 6 + w * 6; - // low 2 bits dstPtr[tileBase + ro] |= (uint8_t)((sv & 3) << (6 - ri * 2)); - // high 1 bit int hiByte = tileBase + 4 + (ro / 2); int hiShift = (ro % 2 == 0 ? 4 : 0) + (3 - ri); dstPtr[hiByte] |= (uint8_t)(((sv >> 2) & 1) << hiShift); @@ -454,20 +450,20 @@ kernel void weight_transform_common(const device IType* src [[buffer(0)]], return weightLow; } if (subBits == 2) { - // 2-bit packed: 4 bytes / (4 OC, 4 IC) tile. Byte ro holds 1 OC's 4 IC values. - // Bits [7:6]=IC0, [5:4]=IC1, [3:2]=IC2, [1:0]=IC3, value = signed_weight + 2 in [0,3]. - // Mirrors W_QUANT_8 tile order: byte index ro within a tile = OC inner. + // 2-bit packed: 4 bytes / (4 OC, 4 IC) tile. size_t weight_bytes = (size_t)group * goc_4 * gic_4 * kh * kw * 4; std::shared_ptr weightLow(MNN::Tensor::createDevice({(int)weight_bytes})); if (!backend->onAcquireBuffer(weightLow.get(), Backend::STATIC)) { MNN_ERROR("Memory alloc error!\n"); return nullptr; } - if (nil == src) return weightLow; + if (nil == src) { + return weightLow; + } auto buf = MetalBackend::getBuffer(weightLow.get()); auto dstPtr = (uint8_t*)[buf.first contents] + buf.second; ::memset(dstPtr, 0, weight_bytes); - auto srcPtr = (const int8_t*)src; // signed [-2,1] per weight + auto srcPtr = (const int8_t*)src; for (int g = 0; g < group; g++) { for (int o = 0; o < goc; o++) { int zo = o / 4, ro = o % 4; @@ -478,8 +474,7 @@ kernel void weight_transform_common(const device IType* src [[buffer(0)]], int srcIdx = ((g * goc + o) * gic + i) * kh * kw + h * kw + w; int sv = (int)srcPtr[srcIdx] + 2; int tileBase = (((g * goc_4 + zo) * gic_4 + zi) * kh + h) * kw * 4 + w * 4; - int byteIdx = tileBase + ro; - dstPtr[byteIdx] |= (uint8_t)((sv & 3) << (6 - ri * 2)); + dstPtr[tileBase + ro] |= (uint8_t)((sv & 3) << (6 - ri * 2)); } } } diff --git a/source/backend/metal/MetalKVCacheManager.hpp b/source/backend/metal/MetalKVCacheManager.hpp index 7da4c83114..96660044b5 100644 --- a/source/backend/metal/MetalKVCacheManager.hpp +++ b/source/backend/metal/MetalKVCacheManager.hpp @@ -1,4 +1,3 @@ - // // MetalKVCacheManager.hpp // MNN @@ -21,10 +20,17 @@ namespace MNN { class MetalKVCacheManager : public KVCacheManager{ private: - id mKeyBuffer; - id mValueBuffer; - size_t mCurrentTotalSize; - + id mKeyBuffer = nil; + id mValueBuffer = nil; + id mKScaleBuffer = nil; + id mVScaleBuffer = nil; + // Only used when KV cache is stored on disk. For in-memory path V may use int8. + size_t mCurrentTotalSize = 0; + + bool mQuantValue = false; // whether V is stored as int8 in cache + bool mQuantKey = false; // whether K is stored as int8 in cache + std::shared_ptr mKVQuantParameter = nullptr; + private: void expandKVCacheInDisk(size_t oldSize, size_t curSize, size_t old_piece_stride, size_t old_piece_size, size_t new_piece_stride, bool need_copy, file_t specKeyFile = INVALID_FILE, file_t specValueFile = INVALID_FILE); void expandKVCacheInMem(size_t oldSize, size_t old_piece_stride, size_t old_piece_size, size_t new_piece_stride, bool need_copy); @@ -44,14 +50,34 @@ class MetalKVCacheManager : public KVCacheManager{ id getKeyBuffer() { return mKeyBuffer; } + id getKScaleBuffer() { + return mKScaleBuffer; + } + id getVScaleBuffer() { + return mVScaleBuffer; + } id getValueBuffer() { return mValueBuffer; } - + void setPastLength(int length) { mPastLength = length; } + void setKVQuantParameter(std::shared_ptr p) { + mKVQuantParameter = p; + } + void setAttenQuantKeyValue(bool quantKey, bool quantValue) { + mQuantKey = quantKey; + mQuantValue = quantValue; + } + bool useDynamicScaleBuffer() const { + return (mQuantKey || mQuantValue) && mKVQuantParameter == nullptr; + } + bool quantValue() const { + return mQuantValue; + } + virtual void onResize(int kv_num_head, int head_dim); virtual void onClear(); virtual void onAlloc(KVMeta* meta, int seq_len); diff --git a/source/backend/metal/MetalKVCacheManager.mm b/source/backend/metal/MetalKVCacheManager.mm index 5f4be7ccf2..68b82ea791 100644 --- a/source/backend/metal/MetalKVCacheManager.mm +++ b/source/backend/metal/MetalKVCacheManager.mm @@ -13,10 +13,13 @@ #import "MetalKVCacheManager.hpp" namespace MNN { - + void MetalKVCacheManager::onResize(int kv_num_head, int head_dim) { mKvNumHead = kv_num_head; mHeadDim = head_dim; + auto mtbn = static_cast(mBackend); + // Record bytes for K cache element. When mQuantKey is enabled, key is stored as int8. + mBytes = mQuantKey ? 1 : (mtbn->useFp16InsteadFp32() ? 2 : 4); } void MetalKVCacheManager::onAlloc(KVMeta* meta, int seq_len) { @@ -25,10 +28,8 @@ auto context = (__bridge MNNMetalContext *)mtbn->context(); auto kv_seq_len = mMeta != nullptr ? mMeta->add : seq_len; - int byte = 4; - if(mtbn->useFp16InsteadFp32()) { - byte = 2; - } + int keyByte = mQuantKey ? 1 : (mtbn->useFp16InsteadFp32() ? 2 : 4); + int valueByte = mQuantValue ? 1 : (mtbn->useFp16InsteadFp32() ? 2 : 4); // load disk prefix kvcache if(mMeta != nullptr && mMeta->file_name.size() > 0 && mMeta->file_flag == KVMeta::PendingRead) { // create new files @@ -51,8 +52,8 @@ if(oldKeySize != oldValueSize) { MNN_ERROR("[Error]: Kvcache in disk size of key and value should equal with metal backend\n"); } - size_t oldKeyMaxLength = oldKeySize / (mKvNumHead * mHeadDim * byte); - size_t oldValueMaxLength = oldValueSize / (mKvNumHead * mHeadDim * byte); + size_t oldKeyMaxLength = oldKeySize / (mKvNumHead * mHeadDim * (mtbn->useFp16InsteadFp32() ? 2 : 4)); + size_t oldValueMaxLength = oldValueSize / (mKvNumHead * mHeadDim * (mtbn->useFp16InsteadFp32() ? 2 : 4)); size_t oldMaxLength = ALIMIN(oldKeyMaxLength, oldValueMaxLength); if(oldMaxLength < meta->seqlen_in_disk) { MNN_ERROR("[Error]: Kvcache in disk size smaller than saved lengthInDiskToload:%d\n", (int)meta->seqlen_in_disk); @@ -60,13 +61,13 @@ int kv_seq_len = ROUND_UP(meta->add + meta->seqlen_in_disk, mConfig.mKvAlignNum); mMaxLength = kv_seq_len > oldMaxLength ? ROUND_UP(meta->add + meta->seqlen_in_disk + mConfig.mExpandChunk, mConfig.mKvAlignNum) : oldMaxLength; - size_t totalSize = mKvNumHead * mMaxLength * mHeadDim * byte; + size_t totalSize = mKvNumHead * mMaxLength * mHeadDim * (mtbn->useFp16InsteadFp32() ? 2 : 4); mCurrentTotalSize = totalSize; - size_t old_piece_size = meta->seqlen_in_disk * byte; - size_t old_piece_stride = oldMaxLength * byte; - size_t new_piece_stride = mMaxLength * byte; - + size_t old_piece_size = meta->seqlen_in_disk * (mtbn->useFp16InsteadFp32() ? 2 : 4); + size_t old_piece_stride = oldMaxLength * (mtbn->useFp16InsteadFp32() ? 2 : 4); + size_t new_piece_stride = mMaxLength * (mtbn->useFp16InsteadFp32() ? 2 : 4); + mCurrentTotalSize = ALIMAX(mCurrentTotalSize, oldKeySize); mCurrentTotalSize = ALIMAX(mCurrentTotalSize, oldValueSize); @@ -79,10 +80,10 @@ return; } - + // align max kv_seq_len to mKvAlignNum, for simd/tensor matrix load alignment mMaxLength = ROUND_UP(kv_seq_len + mConfig.mExpandChunk, mConfig.mKvAlignNum); - size_t totalSize = mKvNumHead * mMaxLength * mHeadDim * byte; + size_t totalSize = mKvNumHead * mMaxLength * mHeadDim * keyByte; mCurrentTotalSize = totalSize; bool storeKvInDisk = !mConfig.mKVCacheDir.empty(); bool sharePrefixKv = mMeta != nullptr && mMeta->file_name.size() > 0 && mMeta->file_flag == KVMeta::PendingWrite; @@ -93,11 +94,11 @@ MNN_PRINT("Failed to create prefix cache file dir: %s\n", mConfig.mPrefixCacheDir.c_str()); } } - + if(storeKvInDisk || sharePrefixKv) { std::string keyStoredDst = ""; std::string valueStoredDst = ""; - + if(mMeta != nullptr) { mBasePrefixFileName = MNNFilePathConcat(mConfig.mPrefixCacheDir, mMeta->file_name) + "_" + std::to_string(mMeta->layer_index); keyStoredDst = sharePrefixKv ? mBasePrefixFileName + ".k" : ""; @@ -109,23 +110,25 @@ resetKVCacheFileSize(totalSize, totalSize); mmapKVCache(totalSize, totalSize); mKVCacheInDisk = true; - mKeyBuffer = [[context device] newBufferWithBytesNoCopy:mMapKeyAddr length:totalSize options:MTLResourceStorageModeShared deallocator:nil]; mValueBuffer = [[context device] newBufferWithBytesNoCopy:mMapValueAddr length:totalSize options:MTLResourceStorageModeShared deallocator:nil]; - + if (useDynamicScaleBuffer()) { + int scaleByte = mtbn->useFp16InsteadFp32() ? 2 : 4; + mKScaleBuffer = [[context device] newBufferWithLength:mMaxLength * scaleByte * 2 options:MTLResourceStorageModeShared]; + mVScaleBuffer = [[context device] newBufferWithLength:mMaxLength * scaleByte * 2 options:MTLResourceStorageModeShared]; + } else { + mKScaleBuffer = nil; + mVScaleBuffer = nil; + } auto new_key_ptr = (uint8_t*)[mKeyBuffer contents]; - ::memset(new_key_ptr, 0, mMaxLength * mKvNumHead * mHeadDim * byte); - + ::memset(new_key_ptr, 0, mMaxLength * mKvNumHead * mHeadDim * keyByte); auto new_value_ptr = (uint8_t*)[mValueBuffer contents]; - ::memset(new_value_ptr, 0, mMaxLength * mKvNumHead * mHeadDim * byte); - + ::memset(new_value_ptr, 0, mMaxLength * mKvNumHead * mHeadDim * (mtbn->useFp16InsteadFp32() ? 2 : 4)); } else { // past_key: [maxlen, kvNumhead, headdim] - auto new_key = Tensor::createDevice({mMaxLength, mKvNumHead, mHeadDim}); + Tensor* new_key = Tensor::createDevice({mMaxLength, mKvNumHead, mHeadDim * keyByte}); // past_value: [kvNumhead, headdim, maxlen] - auto new_value = Tensor::createDevice({mKvNumHead, mHeadDim, mMaxLength}); - - + Tensor* new_value = Tensor::createDevice({mKvNumHead, mHeadDim, mMaxLength * valueByte}); auto res = mBackend->onAcquireBuffer(new_key, Backend::STATIC); res = res && mBackend->onAcquireBuffer(new_value, Backend::STATIC); if(!res) { @@ -134,63 +137,67 @@ // memset for qkv matmul mad, in case dirty data auto newKeyBuf = MetalBackend::getBuffer(new_key); auto new_key_ptr = (uint8_t*)[newKeyBuf.first contents] + newKeyBuf.second; - ::memset(new_key_ptr, 0, mMaxLength * mKvNumHead * mHeadDim * byte); - + ::memset(new_key_ptr, 0, mMaxLength * mKvNumHead * mHeadDim * keyByte); auto newValueBuf = MetalBackend::getBuffer(new_value); auto new_value_ptr = (uint8_t*)[newValueBuf.first contents] + newValueBuf.second; - ::memset(new_value_ptr, 0, mMaxLength * mKvNumHead * mHeadDim * byte); - + ::memset(new_value_ptr, 0, mMaxLength * mKvNumHead * mHeadDim * valueByte); mPastKey.reset(new_key); mPastValue.reset(new_value); + if (useDynamicScaleBuffer()) { + int scaleByte = mtbn->useFp16InsteadFp32() ? 2 : 4; + mKScaleBuffer = [[context device] newBufferWithLength:mMaxLength * scaleByte * 2 options:MTLResourceStorageModeShared]; + mVScaleBuffer = [[context device] newBufferWithLength:mMaxLength * scaleByte * 2 options:MTLResourceStorageModeShared]; + } else { + mKScaleBuffer = nil; + mVScaleBuffer = nil; + } + } - } void MetalKVCacheManager::onRealloc(KVMeta* meta) { mMeta = meta; auto kv_seq_len = mMeta->previous + mMeta->add - mMeta->remove + mMeta->computeReverseSize(); auto mtbn = static_cast(mBackend); - - int byte = 4; - if(mtbn->useFp16InsteadFp32()) { - byte = 2; - } - - auto start = mPastLength - mMeta->remove; + + int keyByte = mQuantKey ? 1 : (mtbn->useFp16InsteadFp32() ? 2 : 4); + int valueByte = mQuantValue ? 1 : (mtbn->useFp16InsteadFp32() ? 2 : 4); + + auto start = mMeta->previous - mMeta->remove; // latest length larger than maxLen if (kv_seq_len > mMaxLength) { // copy mPastLength including all remove/reverse to new buffer first auto copy_len = mPastLength; bool needCopy = mPastLength > 0; - - size_t old_size = mKvNumHead * copy_len * mHeadDim * byte; - size_t old_piece_size = copy_len * byte; - size_t old_piece_stride = mMaxLength * byte; + + size_t old_size = (size_t)mKvNumHead * copy_len * mHeadDim * keyByte; + size_t old_piece_size = (size_t)copy_len * valueByte; + size_t old_piece_stride = (size_t)mMaxLength * valueByte; // align max kv_seq_len to mKvAlignNum, for simd/tensor matrix load alignment mMaxLength = ROUND_UP(kv_seq_len + mConfig.mExpandChunk, mConfig.mKvAlignNum); - + auto oldTotalSize = mCurrentTotalSize; - size_t size = mKvNumHead * mMaxLength * mHeadDim * byte; + size_t size = (size_t)mKvNumHead * mMaxLength * mHeadDim * keyByte; mCurrentTotalSize = size; - size_t new_piece_stride = mMaxLength * byte; - + size_t new_piece_stride = (size_t)mMaxLength * valueByte; + mPastLength = (int)start; if(mKVCacheInDisk) { expandKVCacheInDisk(oldTotalSize, mCurrentTotalSize, old_piece_stride, old_piece_size, new_piece_stride, needCopy); } else { - expandKVCacheInMem(oldTotalSize, old_piece_stride, old_piece_size, new_piece_stride, needCopy); + expandKVCacheInMem(old_size, old_piece_stride, old_piece_size, new_piece_stride, needCopy); } } - + // Remove { if (0 == mMeta->n_reserve) { mPastLength = start; return; } - + int8_t *key_ptr = nullptr; int8_t *value_ptr = nullptr; if(mKVCacheInDisk) { @@ -213,11 +220,33 @@ auto copy_src_index = src_start + begin; auto copy_dst_index = start; for(int i = 0; i < length; i++) { - ::memcpy(key_ptr + (copy_dst_index + i) * mKvNumHead * mHeadDim * byte, key_ptr + (copy_src_index + i) * mKvNumHead * mHeadDim * byte, mKvNumHead * mHeadDim * byte); + ::memcpy(key_ptr + (copy_dst_index + i) * mKvNumHead * mHeadDim * keyByte, key_ptr + (copy_src_index + i) * mKvNumHead * mHeadDim * keyByte, mKvNumHead * mHeadDim * keyByte); } for(int j = 0; j < mKvNumHead * mHeadDim; j++) { for(int i = 0; i < length; i++) { - ::memcpy(value_ptr + (j * mMaxLength + copy_dst_index + i) * byte, value_ptr + (j * mMaxLength + copy_src_index + i) * byte, byte); + ::memcpy(value_ptr + (j * mMaxLength + copy_dst_index + i) * valueByte, value_ptr + (j * mMaxLength + copy_src_index + i) * valueByte, valueByte); + } + } + if (mKScaleBuffer != nil) { + int scaleByte = mtbn->useFp16InsteadFp32() ? 2 : 4; + if (scaleByte == 2) { + int16_t* k_scale_ptr = (int16_t*)[mKScaleBuffer contents]; + int16_t* v_scale_ptr = (int16_t*)[mVScaleBuffer contents]; + for(int i = 0; i < length; i++) { + k_scale_ptr[(copy_dst_index + i) * 2 + 0] = k_scale_ptr[(copy_src_index + i) * 2 + 0]; + k_scale_ptr[(copy_dst_index + i) * 2 + 1] = k_scale_ptr[(copy_src_index + i) * 2 + 1]; + v_scale_ptr[(copy_dst_index + i) * 2 + 0] = v_scale_ptr[(copy_src_index + i) * 2 + 0]; + v_scale_ptr[(copy_dst_index + i) * 2 + 1] = v_scale_ptr[(copy_src_index + i) * 2 + 1]; + } + } else { + float* k_scale_ptr = (float*)[mKScaleBuffer contents]; + float* v_scale_ptr = (float*)[mVScaleBuffer contents]; + for(int i = 0; i < length; i++) { + k_scale_ptr[(copy_dst_index + i) * 2 + 0] = k_scale_ptr[(copy_src_index + i) * 2 + 0]; + k_scale_ptr[(copy_dst_index + i) * 2 + 1] = k_scale_ptr[(copy_src_index + i) * 2 + 1]; + v_scale_ptr[(copy_dst_index + i) * 2 + 0] = v_scale_ptr[(copy_src_index + i) * 2 + 0]; + v_scale_ptr[(copy_dst_index + i) * 2 + 1] = v_scale_ptr[(copy_src_index + i) * 2 + 1]; + } } } start += length; @@ -225,73 +254,103 @@ mPastLength = (int)start; } } - + void MetalKVCacheManager::expandKVCacheInMem(size_t oldSize, size_t old_piece_stride, size_t old_piece_size, size_t new_piece_stride, bool need_copy) { auto mtbn = static_cast(mBackend); - int byte = 4; - if(mtbn->useFp16InsteadFp32()) { - byte = 2; - } + int keyByte = mQuantKey ? 1 : (mtbn->useFp16InsteadFp32() ? 2 : 4); + int valueByte = mQuantValue ? 1 : (mtbn->useFp16InsteadFp32() ? 2 : 4); // past_key: [maxlen, kvNumhead, headdim] - auto new_key = Tensor::createDevice({mMaxLength, mKvNumHead, mHeadDim}); + Tensor* new_key = Tensor::createDevice({mMaxLength, mKvNumHead, mHeadDim * keyByte}); // past_value: [kvNumhead, headdim, maxlen] - auto new_value = Tensor::createDevice({mKvNumHead, mHeadDim, mMaxLength}); - + Tensor* new_value = Tensor::createDevice({mKvNumHead, mHeadDim, mMaxLength * valueByte}); + auto res = mBackend->onAcquireBuffer(new_key, Backend::STATIC); res = res && mBackend->onAcquireBuffer(new_value, Backend::STATIC); if(!res) { MNN_ERROR("attition kv cache realloc memory error:%d\n", res); } - + // memset for qkv matmul mad, in case dirty data auto newKeyBuf = MetalBackend::getBuffer(new_key); auto new_key_ptr = (uint8_t*)[newKeyBuf.first contents] + newKeyBuf.second; - ::memset(new_key_ptr, 0, mMaxLength * mKvNumHead * mHeadDim * byte); - + ::memset(new_key_ptr, 0, (size_t)mMaxLength * mKvNumHead * mHeadDim * keyByte); + auto newValueBuf = MetalBackend::getBuffer(new_value); auto new_value_ptr = (uint8_t*)[newValueBuf.first contents] + newValueBuf.second; - ::memset(new_value_ptr, 0, mMaxLength * mKvNumHead * mHeadDim * byte); - + ::memset(new_value_ptr, 0, (size_t)mMaxLength * mKvNumHead * mHeadDim * valueByte); + if (need_copy) { auto keyBuf = MetalBackend::getBuffer(mPastKey.get()); auto key_ptr = (uint8_t*)[keyBuf.first contents] + keyBuf.second;; ::memcpy(new_key_ptr, key_ptr, oldSize); - + auto valueBuf = MetalBackend::getBuffer(mPastValue.get()); auto value_ptr = (uint8_t*)[valueBuf.first contents] + valueBuf.second; for(int i = 0; i < mKvNumHead * mHeadDim; i++) { ::memcpy(new_value_ptr + i * new_piece_stride, value_ptr + i * old_piece_stride, old_piece_size); } } - + mPastKey.reset(new_key); mPastValue.reset(new_value); + + auto context = (__bridge MNNMetalContext *)mtbn->context(); + if (useDynamicScaleBuffer()) { + int scaleByte = mtbn->useFp16InsteadFp32() ? 2 : 4; + id newKScale = [[context device] newBufferWithLength:mMaxLength * scaleByte * 2 options:MTLResourceStorageModeShared]; + id newVScale = [[context device] newBufferWithLength:mMaxLength * scaleByte * 2 options:MTLResourceStorageModeShared]; + if (need_copy && mKScaleBuffer != nil) { + ::memcpy([newKScale contents], [mKScaleBuffer contents], (old_piece_size / valueByte) * scaleByte * 2); + ::memcpy([newVScale contents], [mVScaleBuffer contents], (old_piece_size / valueByte) * scaleByte * 2); + } + mKScaleBuffer = newKScale; + mVScaleBuffer = newVScale; + } else { + mKScaleBuffer = nil; + mVScaleBuffer = nil; + } } - + void MetalKVCacheManager::expandKVCacheInDisk(size_t oldSize, size_t curSize, size_t old_piece_stride, size_t old_piece_size, size_t new_piece_stride, bool need_copy, file_t specKeyFile, file_t specValueFile) { auto mtbn = static_cast(mBackend); auto context = (__bridge MNNMetalContext *)mtbn->context(); - + mmapKVCache(oldSize, oldSize, specKeyFile, specValueFile); std::vector prevKey, prevValue; prevKey.resize(oldSize); prevValue.resize(oldSize); memcpy(prevKey.data(), mMapKeyAddr, oldSize); memcpy(prevValue.data(), mMapValueAddr, oldSize); - + unmapKVCache(oldSize, oldSize); resetKVCacheFileSize(curSize, curSize); mmapKVCache(curSize, curSize); - + // reset id mKeyBuffer = [[context device] newBufferWithBytesNoCopy:mMapKeyAddr length:curSize options:MTLResourceStorageModeShared deallocator:nil]; mValueBuffer = [[context device] newBufferWithBytesNoCopy:mMapValueAddr length:curSize options:MTLResourceStorageModeShared deallocator:nil]; - - + + int valueByte = mQuantValue ? 1 : (mtbn->useFp16InsteadFp32() ? 2 : 4); + if (useDynamicScaleBuffer()) { + int scaleByte = mtbn->useFp16InsteadFp32() ? 2 : 4; + id newKScale = [[context device] newBufferWithLength:mMaxLength * scaleByte * 2 options:MTLResourceStorageModeShared]; + id newVScale = [[context device] newBufferWithLength:mMaxLength * scaleByte * 2 options:MTLResourceStorageModeShared]; + if (need_copy && mKScaleBuffer != nil) { + ::memcpy([newKScale contents], [mKScaleBuffer contents], (old_piece_size / valueByte) * scaleByte * 2); + ::memcpy([newVScale contents], [mVScaleBuffer contents], (old_piece_size / valueByte) * scaleByte * 2); + } + mKScaleBuffer = newKScale; + mVScaleBuffer = newVScale; + } else { + mKScaleBuffer = nil; + mVScaleBuffer = nil; + } + + // Step 3: Move the kvcache from temporary buffers in memory to disk memset(mMapKeyAddr, 0, curSize); memset(mMapValueAddr, 0, curSize); - + if (need_copy) { ::memcpy(mMapKeyAddr, prevKey.data(), oldSize); for(int i = 0; i < mKvNumHead * mHeadDim; i++) { @@ -299,15 +358,27 @@ } } } - + void MetalKVCacheManager::onClear() { if (mKVCacheInDisk) { mKeyBuffer = nil; mValueBuffer = nil; - + // mSaveShareKvPrefix also need unmap file unmapKVCache(mCurrentTotalSize, mCurrentTotalSize); - if(!mSaveShareKvPrefix) { + if(mSaveShareKvPrefix) { + // set prefix cachefile validation + auto k_file = mBasePrefixFileName + ".k"; + if(MNNFileExist(k_file.c_str())) { + auto k_sync_file = mBasePrefixFileName + "_sync.k"; + MNNCreateFile(k_sync_file.c_str()); + } + auto v_file = mBasePrefixFileName + ".v"; + if(MNNFileExist(v_file.c_str())) { + auto v_sync_file = mBasePrefixFileName + "_sync.v"; + MNNCreateFile(v_sync_file.c_str()); + } + } else { // delete temp kvcache file removeKVCacheFile(); } @@ -315,10 +386,11 @@ } mPastKey.reset(); mPastValue.reset(); + mKScaleBuffer = nil; + mVScaleBuffer = nil; mMaxLength = 0; mPastLength = 0; } } // namespace MNN #endif // MNN_SUPPORT_TRANSFORMER_FUSE - diff --git a/source/backend/metal/MetalLayerNorm.hpp b/source/backend/metal/MetalLayerNorm.hpp index 535bcaaed0..3c2c1cd219 100644 --- a/source/backend/metal/MetalLayerNorm.hpp +++ b/source/backend/metal/MetalLayerNorm.hpp @@ -24,6 +24,7 @@ class MetalLayerNorm : public MetalExecution { bool mHasGammaBeta = false; bool mRMSNorm = false; + int mGammaSize = 0; std::shared_ptr mGammaBuffer; std::shared_ptr mBetaBuffer; }; @@ -37,6 +38,9 @@ class MetalLayerNorm : public MetalExecution { private: int mOutside; int mInside; + bool mIsNC4HW4 = false; + bool mIsBinaryNCHW = false; + int mChannelUnit; std::shared_ptr mResource; id mShapeBuffer; id mPipeline; diff --git a/source/backend/metal/MetalLayerNorm.mm b/source/backend/metal/MetalLayerNorm.mm index 72688d1086..6728f6bda6 100755 --- a/source/backend/metal/MetalLayerNorm.mm +++ b/source/backend/metal/MetalLayerNorm.mm @@ -10,6 +10,7 @@ #import "backend/metal/MNNMetalContext.h" #import "backend/metal/MetalBackend.hpp" #import "LayerNormSimdGroupShader.hpp" +#import "core/TensorUtils.hpp" #if MNN_METAL_ENABLED namespace MNN { @@ -49,6 +50,7 @@ auto externalSize = layernorm->external()->size(); gamma_size = static_cast(externalInfo[1]) / sizeof(float); } + res->mGammaSize = gamma_size; if (gamma_size > 0) { res->mHasGammaBeta = true; res->mGammaBuffer.reset(Tensor::createDevice({(int)(gamma_size * sizeof(float))})); @@ -65,7 +67,7 @@ const float* gamma_data = layernorm->gamma()->data(); auto gammaPtr = MetalBackend::getBuffer(res->mGammaBuffer.get()); memcpy((uint8_t*)gammaPtr.first.contents + gammaPtr.second, (const void *)gamma_data, gamma_size * sizeof(float)); - + if (layernorm->beta()->size() != gamma_size) { MNN_ERROR("Size of gamma and beta are not match in MetalLayerNorm.\n"); } @@ -82,11 +84,20 @@ auto context = (__bridge MNNMetalContext *)backend->context(); auto input = inputs[0], output = outputs[0]; - + auto c4 = TensorUtils::getDescribe(input)->dimensionFormat == MNN_DATA_FORMAT_NC4HW4; + mOutside = 1; mInside = 1; int rank = input->dimensions(); - if (mResource->mGroup > 1) { + const bool channelNCHW = !c4 && rank == 4 && input->length(2) == 1 && input->length(3) == 1 && + mResource->mAxisSize == 1 && mResource->mGammaSize == input->length(1); + if (c4) { + mOutside = input->length(0); + mInside = input->length(1); + } else if (channelNCHW) { + mOutside = input->length(0); + mInside = input->length(1); + } else if (mResource->mGroup > 1) { mOutside = input->length(0) * mResource->mGroup; for (int i = 1; i < rank; i++) { mInside *= input->length(i); @@ -109,107 +120,226 @@ bool parallel = (mInside > 32) && ((mInside & 3) == 0); auto inside = parallel ? mInside/4 : mInside; auto rt = (MetalRuntime *)backend->runtime(); - if(rt->supportSimdGroupReduce()) { - // basic marco info - std::string ftype = "float"; - std::string ftype4 = "float4"; - if (backend->useFp16InsteadFp32()) { - ftype = "half"; - ftype4 = "half4"; - } - MTLCompileOptions *option = [[MTLCompileOptions alloc] init]; - auto dic = [NSMutableDictionary dictionaryWithCapacity:0]; - option.preprocessorMacros = @{ - @"ftype" : @(ftype.c_str()), - @"ftype4" : @(ftype4.c_str()), - }; - std::vector baseKeys = {"layernorm_sg_reduce", ftype}; - if(mResource->mRMSNorm) { - // pretty much threads compute all inside dims in a threadgroup - if(mOutside / 512.0 * mInside / 512.0 > 1.0) { - auto keys = baseKeys; - keys.emplace_back("layernorm_in_all_rms_sg"); - auto pipeline = rt->findPipeline(keys); - if (nil == pipeline) { - pipeline = backend->makeComputePipelineWithSourceOption(gLayerNormSgReduce, "layernorm_in_all_rms_sg", option); - rt->insertPipeline(keys, pipeline); + if (c4) { + mIsNC4HW4 = true; + mChannelUnit = UP_DIV(mInside, 4); + if (inputs.size() == 2 && outputs.size() == 2) { + if(rt->supportSimdGroupReduce()) { + std::string ftype = "float"; + std::string ftype4 = "float4"; + if (backend->useFp16InsteadFp32()) { + ftype = "half"; + ftype4 = "half4"; } - mPipeline = pipeline; - mThreads = std::make_pair(MTLSizeMake(1, mOutside, 1), MTLSizeMake(32, 1, 1)); - } else if(parallel) { - if(inside >= 16 && inside * mOutside >= 2048) { + + MTLCompileOptions *option = [[MTLCompileOptions alloc] init]; + option.preprocessorMacros = @{ + @"ftype" : @(ftype.c_str()), + @"ftype4" : @(ftype4.c_str()), + }; + std::vector baseKeys = {"layernorm_sg_reduce", ftype}; + if(mResource->mRMSNorm) { auto keys = baseKeys; - keys.emplace_back("layernorm_x16_rms_sg"); + keys.emplace_back("binary_layernorm_c4_rms_sg"); auto pipeline = rt->findPipeline(keys); if (nil == pipeline) { - pipeline = backend->makeComputePipelineWithSourceOption(gLayerNormSgReduce, "layernorm_x16_rms_sg", option); + pipeline = backend->makeComputePipelineWithSourceOption(gLayerNormSgReduce, "binary_layernorm_c4_rms_sg", option); rt->insertPipeline(keys, pipeline); } mPipeline = pipeline; - mThreads = std::make_pair(MTLSizeMake(UP_DIV(inside, 4), mOutside, 1), MTLSizeMake(32, 1, 1)); + mThreads = std::make_pair(MTLSizeMake(1, mOutside, 1), MTLSizeMake(64, 1, 1)); } else { auto keys = baseKeys; - keys.emplace_back("layernorm_x4_rms_sg"); + keys.emplace_back("binary_layernorm_c4_sg"); auto pipeline = rt->findPipeline(keys); if (nil == pipeline) { - pipeline = backend->makeComputePipelineWithSourceOption(gLayerNormSgReduce, "layernorm_x4_rms_sg", option); + pipeline = backend->makeComputePipelineWithSourceOption(gLayerNormSgReduce, "binary_layernorm_c4_sg", option); rt->insertPipeline(keys, pipeline); } mPipeline = pipeline; - mThreads = std::make_pair(MTLSizeMake(inside, mOutside, 1), MTLSizeMake(32, 1, 1)); + mThreads = std::make_pair(MTLSizeMake(1, mOutside, 1), MTLSizeMake(64, 1, 1)); } - } else { - auto keys = baseKeys; - keys.emplace_back("layernorm_x1_rms_sg"); - auto pipeline = rt->findPipeline(keys); - if (nil == pipeline) { - pipeline = backend->makeComputePipelineWithSourceOption(gLayerNormSgReduce, "layernorm_x1_rms_sg", option); - rt->insertPipeline(keys, pipeline); + } else { + if(mResource->mRMSNorm) { + mPipeline = [context pipelineWithName:@"binary_layernorm_c4_rms" fp16:backend->useFp16InsteadFp32()]; + } else { + mPipeline = [context pipelineWithName:@"binary_layernorm_c4" fp16:backend->useFp16InsteadFp32()]; } - mPipeline = pipeline; - mThreads = std::make_pair(MTLSizeMake(inside, mOutside, 1), MTLSizeMake(32, 1, 1)); + mThreads = [context computeBestGroupAndLocal:mPipeline threads:MTLSizeMake((NSUInteger)mChannelUnit, (NSUInteger)mOutside, 1)]; } - } else { - if(mOutside / 512.0 * mInside / 512.0 > 1.0) { + } else if(rt->supportSimdGroupReduce()) { + std::string ftype = "float"; + std::string ftype4 = "float4"; + if (backend->useFp16InsteadFp32()) { + ftype = "half"; + ftype4 = "half4"; + } + + MTLCompileOptions *option = [[MTLCompileOptions alloc] init]; + auto dic = [NSMutableDictionary dictionaryWithCapacity:0]; + option.preprocessorMacros = @{ + @"ftype" : @(ftype.c_str()), + @"ftype4" : @(ftype4.c_str()), + }; + std::vector baseKeys = {"layernorm_sg_reduce", ftype}; + if(mResource->mRMSNorm) { auto keys = baseKeys; - keys.emplace_back("layernorm_in_all_sg"); + keys.emplace_back("layernorm_c4_rms_sg"); auto pipeline = rt->findPipeline(keys); if (nil == pipeline) { - pipeline = backend->makeComputePipelineWithSourceOption(gLayerNormSgReduce, "layernorm_in_all_sg", option); + pipeline = backend->makeComputePipelineWithSourceOption(gLayerNormSgReduce, "layernorm_c4_rms_sg", option); rt->insertPipeline(keys, pipeline); } mPipeline = pipeline; mThreads = std::make_pair(MTLSizeMake(1, mOutside, 1), MTLSizeMake(32, 1, 1)); - } else if(parallel) { - auto keys = baseKeys; - keys.emplace_back("layernorm_x4_sg"); - auto pipeline = rt->findPipeline(keys); - if (nil == pipeline) { - pipeline = backend->makeComputePipelineWithSourceOption(gLayerNormSgReduce, "layernorm_x4_sg", option); - rt->insertPipeline(keys, pipeline); - } - mPipeline = pipeline; - mThreads = std::make_pair(MTLSizeMake(inside, mOutside, 1), MTLSizeMake(32, 1, 1)); } else { auto keys = baseKeys; - keys.emplace_back("layernorm_x1_sg"); + keys.emplace_back("layernorm_c4_sg"); auto pipeline = rt->findPipeline(keys); if (nil == pipeline) { - pipeline = backend->makeComputePipelineWithSourceOption(gLayerNormSgReduce, "layernorm_x1_sg", option); + pipeline = backend->makeComputePipelineWithSourceOption(gLayerNormSgReduce, "layernorm_c4_sg", option); rt->insertPipeline(keys, pipeline); } mPipeline = pipeline; - mThreads = std::make_pair(MTLSizeMake(inside, mOutside, 1), MTLSizeMake(32, 1, 1)); + mThreads = std::make_pair(MTLSizeMake(1, mOutside, 1), MTLSizeMake(32, 1, 1)); + } + } else { + if(mResource->mRMSNorm) { + mPipeline = [context pipelineWithName:@"layernorm_c4_rms" fp16:backend->useFp16InsteadFp32()]; + } else { + mPipeline = [context pipelineWithName:@"layernorm_c4" fp16:backend->useFp16InsteadFp32()]; } + mThreads = [context computeBestGroupAndLocal:mPipeline threads:MTLSizeMake((NSUInteger)mChannelUnit, (NSUInteger)mOutside, 1)]; } } else { - if(mResource->mRMSNorm){ - mPipeline = [context pipelineWithName:parallel ? @"layernorm_x4_rms" : @"layernorm_x1_rms" fp16:backend->useFp16InsteadFp32()]; - }else{ - mPipeline = [context pipelineWithName:parallel ? @"layernorm_x4" : @"layernorm_x1" fp16:backend->useFp16InsteadFp32()]; + mIsNC4HW4 = false; + mIsBinaryNCHW = inputs.size() == 2 && outputs.size() == 2 && channelNCHW; + if(rt->supportSimdGroupReduce()) { + // basic marco info + std::string ftype = "float"; + std::string ftype4 = "float4"; + if (backend->useFp16InsteadFp32()) { + ftype = "half"; + ftype4 = "half4"; + } + + MTLCompileOptions *option = [[MTLCompileOptions alloc] init]; + auto dic = [NSMutableDictionary dictionaryWithCapacity:0]; + option.preprocessorMacros = @{ + @"ftype" : @(ftype.c_str()), + @"ftype4" : @(ftype4.c_str()), + }; + std::vector baseKeys = {"layernorm_sg_reduce", ftype}; + if (mIsBinaryNCHW) { + auto keys = baseKeys; + if (mResource->mRMSNorm) { + keys.emplace_back("binary_layernorm_x4_rms_sg"); + auto pipeline = rt->findPipeline(keys); + if (nil == pipeline) { + pipeline = backend->makeComputePipelineWithSourceOption(gLayerNormSgReduce, "binary_layernorm_x4_rms_sg", option); + rt->insertPipeline(keys, pipeline); + } + mPipeline = pipeline; + } else { + keys.emplace_back("binary_layernorm_x4_sg"); + auto pipeline = rt->findPipeline(keys); + if (nil == pipeline) { + pipeline = backend->makeComputePipelineWithSourceOption(gLayerNormSgReduce, "binary_layernorm_x4_sg", option); + rt->insertPipeline(keys, pipeline); + } + mPipeline = pipeline; + } + mThreads = std::make_pair(MTLSizeMake(1, mOutside, 1), MTLSizeMake(64, 1, 1)); + } else if(mResource->mRMSNorm) { + // pretty much threads compute all inside dims in a threadgroup + if(mOutside / 512.0 * mInside / 512.0 > 1.0) { + auto keys = baseKeys; + keys.emplace_back("layernorm_in_all_rms_sg"); + auto pipeline = rt->findPipeline(keys); + if (nil == pipeline) { + pipeline = backend->makeComputePipelineWithSourceOption(gLayerNormSgReduce, "layernorm_in_all_rms_sg", option); + rt->insertPipeline(keys, pipeline); + } + mPipeline = pipeline; + mThreads = std::make_pair(MTLSizeMake(1, mOutside, 1), MTLSizeMake(32, 1, 1)); + } else if(parallel) { + if(inside >= 16 && inside * mOutside >= 2048) { + auto keys = baseKeys; + keys.emplace_back("layernorm_x16_rms_sg"); + auto pipeline = rt->findPipeline(keys); + if (nil == pipeline) { + pipeline = backend->makeComputePipelineWithSourceOption(gLayerNormSgReduce, "layernorm_x16_rms_sg", option); + rt->insertPipeline(keys, pipeline); + } + mPipeline = pipeline; + mThreads = std::make_pair(MTLSizeMake(UP_DIV(inside, 4), mOutside, 1), MTLSizeMake(32, 1, 1)); + } else { + auto keys = baseKeys; + keys.emplace_back("layernorm_x4_rms_sg"); + auto pipeline = rt->findPipeline(keys); + if (nil == pipeline) { + pipeline = backend->makeComputePipelineWithSourceOption(gLayerNormSgReduce, "layernorm_x4_rms_sg", option); + rt->insertPipeline(keys, pipeline); + } + mPipeline = pipeline; + mThreads = std::make_pair(MTLSizeMake(inside, mOutside, 1), MTLSizeMake(32, 1, 1)); + } + } else { + auto keys = baseKeys; + keys.emplace_back("layernorm_x1_rms_sg"); + auto pipeline = rt->findPipeline(keys); + if (nil == pipeline) { + pipeline = backend->makeComputePipelineWithSourceOption(gLayerNormSgReduce, "layernorm_x1_rms_sg", option); + rt->insertPipeline(keys, pipeline); + } + mPipeline = pipeline; + mThreads = std::make_pair(MTLSizeMake(inside, mOutside, 1), MTLSizeMake(32, 1, 1)); + } + } else { + if(mOutside / 512.0 * mInside / 512.0 > 1.0) { + auto keys = baseKeys; + keys.emplace_back("layernorm_in_all_sg"); + auto pipeline = rt->findPipeline(keys); + if (nil == pipeline) { + pipeline = backend->makeComputePipelineWithSourceOption(gLayerNormSgReduce, "layernorm_in_all_sg", option); + rt->insertPipeline(keys, pipeline); + } + mPipeline = pipeline; + mThreads = std::make_pair(MTLSizeMake(1, mOutside, 1), MTLSizeMake(32, 1, 1)); + } else if(parallel) { + auto keys = baseKeys; + keys.emplace_back("layernorm_x4_sg"); + auto pipeline = rt->findPipeline(keys); + if (nil == pipeline) { + pipeline = backend->makeComputePipelineWithSourceOption(gLayerNormSgReduce, "layernorm_x4_sg", option); + rt->insertPipeline(keys, pipeline); + } + mPipeline = pipeline; + mThreads = std::make_pair(MTLSizeMake(inside, mOutside, 1), MTLSizeMake(32, 1, 1)); + } else { + auto keys = baseKeys; + keys.emplace_back("layernorm_x1_sg"); + auto pipeline = rt->findPipeline(keys); + if (nil == pipeline) { + pipeline = backend->makeComputePipelineWithSourceOption(gLayerNormSgReduce, "layernorm_x1_sg", option); + rt->insertPipeline(keys, pipeline); + } + mPipeline = pipeline; + mThreads = std::make_pair(MTLSizeMake(inside, mOutside, 1), MTLSizeMake(32, 1, 1)); + } + } + } else { + if (mIsBinaryNCHW) { + return NOT_SUPPORT; + } + if(mResource->mRMSNorm){ + mPipeline = [context pipelineWithName:parallel ? @"layernorm_x4_rms" : @"layernorm_x1_rms" fp16:backend->useFp16InsteadFp32()]; + }else{ + mPipeline = [context pipelineWithName:parallel ? @"layernorm_x4" : @"layernorm_x1" fp16:backend->useFp16InsteadFp32()]; + } + mThreads = [context computeBestGroupAndLocal:mPipeline threads:MTLSizeMake((NSUInteger)inside, (NSUInteger)mOutside, 1)]; } - mThreads = [context computeBestGroupAndLocal:mPipeline threads:MTLSizeMake((NSUInteger)inside, (NSUInteger)mOutside, 1)]; } return NO_ERROR; } @@ -219,19 +349,39 @@ auto backend = static_cast(this->backend()); auto context = (__bridge MNNMetalContext *)backend->context(); auto input = inputs[0], output = outputs[0]; + + if (mPipeline == nil) { + MNN_ERROR("MetalLayerNorm: mPipeline is nil!\n"); + return; + } + [encoder setComputePipelineState:mPipeline]; - MetalBackend::setTensor(input, encoder, 0); - MetalBackend::setTensor(output, encoder, 1); - [encoder setBuffer:mShapeBuffer offset:0 atIndex:2]; - if (!mResource->mHasGammaBeta) { - // Set fake buffer to avoid validate - MetalBackend::setTensor(input, encoder, 3); - MetalBackend::setTensor(input, encoder, 4); + if (inputs.size() == 2 && outputs.size() == 2 && (mIsNC4HW4 || mIsBinaryNCHW)) { + MetalBackend::setTensor(inputs[0], encoder, 0); + MetalBackend::setTensor(inputs[1], encoder, 1); + MetalBackend::setTensor(outputs[0], encoder, 2); + MetalBackend::setTensor(outputs[1], encoder, 3); + [encoder setBuffer:mShapeBuffer offset:0 atIndex:4]; + if (!mResource->mHasGammaBeta) { + MetalBackend::setTensor(inputs[0], encoder, 5); + MetalBackend::setTensor(inputs[0], encoder, 6); + } else { + MetalBackend::setTensor(mResource->mGammaBuffer.get(), encoder, 5); + MetalBackend::setTensor(mResource->mBetaBuffer.get(), encoder, 6); + } } else { - MetalBackend::setTensor(mResource->mGammaBuffer.get(), encoder, 3); - MetalBackend::setTensor(mResource->mBetaBuffer.get(), encoder, 4); + MetalBackend::setTensor(input, encoder, 0); + MetalBackend::setTensor(output, encoder, 1); + [encoder setBuffer:mShapeBuffer offset:0 atIndex:2]; + if (!mResource->mHasGammaBeta) { + // Set fake buffer to avoid validate + MetalBackend::setTensor(input, encoder, 3); + MetalBackend::setTensor(input, encoder, 4); + } else { + MetalBackend::setTensor(mResource->mGammaBuffer.get(), encoder, 3); + MetalBackend::setTensor(mResource->mBetaBuffer.get(), encoder, 4); + } } - [encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second]; MNN_PRINT_ENCODER(context, encoder); } diff --git a/source/backend/metal/MetalOPRegister.mm b/source/backend/metal/MetalOPRegister.mm index 01b04181d0..f54d7003db 100644 --- a/source/backend/metal/MetalOPRegister.mm +++ b/source/backend/metal/MetalOPRegister.mm @@ -15,6 +15,7 @@ extern void ___MetalLayerNormCreator__OpType_LayerNorm__(); #ifdef MNN_SUPPORT_TRANSFORMER_FUSE extern void ___AttentionBufCreator__OpType_Attention__(); + extern void ___MetalRoPECreator__OpType_RoPE__(); #endif extern void ___MetalMatMulCreator__OpType_MatMul__(); extern void ___MetalBinaryCreator__OpType_BinaryOp__(); @@ -72,9 +73,10 @@ void registerMetalOps() { ___MetalReLU6Creator__OpType_ReLU6__(); ___MetalReLU6Creator__OpType_ReLU__(); #ifdef MNN_SUPPORT_TRANSFORMER_FUSE + ___MetalRoPECreator__OpType_RoPE__(); ___AttentionBufCreator__OpType_Attention__(); ___MetalLinearAttentionCreator__OpType_LinearAttention__(); #endif } #endif -} \ No newline at end of file +} diff --git a/source/backend/metal/MetalRope.mm b/source/backend/metal/MetalRope.mm new file mode 100644 index 0000000000..3b146f0444 --- /dev/null +++ b/source/backend/metal/MetalRope.mm @@ -0,0 +1,397 @@ +// +// MetalRope.mm +// MNN +// +// Fused RoPE (Rotary Positional Embedding) kernel for Metal backend via Extra op path. +// +// Inputs: x, cosEven, cosOdd, sinEven, sinOdd +// Output: same shape as x +// +// For last dimension D (must be even), let halfD = D/2 and split x as +// even = x[..., 0:halfD] +// odd = x[..., halfD:] +// Then compute +// q0 = even * cosEven - odd * sinEven +// q1 = odd * cosOdd + even * sinOdd +// and concatenate [q0, q1] along the last dimension. +// + +#define MNN_UNUSED(x) +#import "MNNMetalContext.h" +#import "backend/metal/MetalBackend.hpp" +#import "MetalExecution.hpp" +#import "MetalLayerNorm.hpp" +#import "core/TensorUtils.hpp" +#import "core/Macro.h" +#include "MNN_generated.h" +#include + +#if MNN_METAL_ENABLED +#ifdef MNN_SUPPORT_TRANSFORMER_FUSE + +namespace MNN { + +struct RopeParam { + int outerSize; + int halfD; + int ropeHalfD; + int D; + int numHead; + int kvnumHead; + int fullHead; + float qEps; + float kEps; +}; + +// Metal kernel source. ftype is float / half selected by MNN_METAL_FLOAT16_STORAGE. +static const char* gMetalRopeKernelSource = R"metal( +#include +#include +using namespace metal; +#ifdef MNN_METAL_FLOAT16_STORAGE +typedef half ftype; +#else +typedef float ftype; +#endif + +struct RopeParam { + int outerSize; + int halfD; + int ropeHalfD; + int D; + int numHead; + int kvnumHead; + int fullHead; + float qEps; + float kEps; +}; + +#if defined(Q_NORM) || defined(K_NORM) +kernel void rope_kernel( + const device ftype* q [[ buffer(0) ]], + const device ftype* k [[ buffer(1) ]], + const device ftype* cosEven [[ buffer(2) ]], + const device ftype* cosOdd [[ buffer(3) ]], + const device ftype* sinEven [[ buffer(4) ]], + const device ftype* sinOdd [[ buffer(5) ]], + device ftype* qo [[ buffer(6) ]], + device ftype* ko [[ buffer(7) ]], + constant RopeParam& p [[ buffer(8) ]], +#ifdef Q_NORM + const device float* qGamma [[ buffer(9) ]], +#endif +#ifdef K_NORM + const device float* kGamma [[ buffer(10) ]], +#endif +#ifdef USE_SG + uint3 gid [[ threadgroup_position_in_grid]], + uint tiisg [[ thread_index_in_simdgroup]], + uint sgitg [[ simdgroup_index_in_threadgroup ]] +#else + uint3 gid [[ thread_position_in_grid]] +#endif +) { +#ifdef USE_SG + uint actual_z = gid.z * 2 + sgitg; + if (gid.y >= (uint)p.outerSize || actual_z >= p.fullHead) { + return; + } + int step = 32; + int start = tiisg; +#else + uint actual_z = gid.z; + if (gid.x >= 1 || gid.y >= (uint)p.outerSize || actual_z >= p.fullHead) { + return; + } + int step = 1; + int start = 0; +#endif + + const device ftype* x = q + actual_z * p.D + gid.y * p.D * p.numHead; + device ftype* y = qo + actual_z * p.D + gid.y * p.D * p.numHead; + bool isQ = true; + if (actual_z >= p.numHead) { + x = k + (actual_z-p.numHead) * p.D + gid.y * p.D * p.kvnumHead; + y = ko + (actual_z-p.numHead) * p.D + gid.y * p.D * p.kvnumHead; + isQ = false; + } + + float square_sum = 0.0f; +#ifdef Q_NORM + if (isQ) { + for (int i = start; i < p.D; i += step) { + float val = x[i]; + square_sum += val * val; + } +#ifdef USE_SG + square_sum = simd_sum(square_sum); +#endif + } +#endif +#ifdef K_NORM + if (!isQ) { + for (int i = start; i < p.D; i += step) { + float val = x[i]; + square_sum += val * val; + } +#ifdef USE_SG + square_sum = simd_sum(square_sum); +#endif + } +#endif + + float var = 0; +#ifdef Q_NORM + if (isQ) { + var = 1.0 / sqrt(square_sum / p.D + p.qEps); + } +#endif +#ifdef K_NORM + if (!isQ) { + var = 1.0 / sqrt(square_sum / p.D + p.kEps); + } +#endif + + for (int i = start; i < p.halfD; i += step) { + ftype evenVal = x[i]; + ftype oddVal = x[i + p.halfD]; +#ifdef Q_NORM + if (isQ) { + evenVal = evenVal * var * qGamma[i]; + oddVal = oddVal * var * qGamma[i + p.halfD]; + } +#endif +#ifdef K_NORM + if (!isQ) { + evenVal = evenVal * var * kGamma[i]; + oddVal = oddVal * var * kGamma[i + p.halfD]; + } +#endif + + if (i < p.ropeHalfD) { + int cosIndex = gid.y * p.halfD + i; + ftype cEven = cosEven[cosIndex]; + ftype cOdd = cosOdd[cosIndex]; + ftype sEven = sinEven[cosIndex]; + ftype sOdd = sinOdd[cosIndex]; + + y[i] = evenVal * cEven - oddVal * sEven; + y[i + p.halfD] = oddVal * cOdd + evenVal * sOdd; + } else { + y[i] = evenVal; + y[i + p.halfD] = oddVal; + } + } +} +#else +kernel void rope_kernel( + const device ftype* q [[ buffer(0) ]], + const device ftype* k [[ buffer(1) ]], + const device ftype* cosEven [[ buffer(2) ]], + const device ftype* cosOdd [[ buffer(3) ]], + const device ftype* sinEven [[ buffer(4) ]], + const device ftype* sinOdd [[ buffer(5) ]], + device ftype* qo [[ buffer(6) ]], + device ftype* ko [[ buffer(7) ]], + constant RopeParam& p [[ buffer(8) ]], + uint3 gid [[ thread_position_in_grid]]) { + if (gid.x >= (uint)p.halfD || gid.y >= (uint)p.outerSize || gid.z >= p.fullHead) { + return; + } + const device ftype* x = q + gid.z * p.D + gid.y * p.D * p.numHead; + device ftype* y = qo + gid.z * p.D + gid.y * p.D * p.numHead; + if (gid.z >= p.numHead) { + x = k + (gid.z-p.numHead) * p.D + gid.y * p.D * p.kvnumHead; + y = ko + (gid.z-p.numHead) * p.D + gid.y * p.D * p.kvnumHead; + } + ftype evenVal = x[gid.x]; + ftype oddVal = x[gid.x + p.halfD]; + + if (gid.x < (uint)p.ropeHalfD) { + int cosIndex = gid.y * p.halfD + gid.x; + + ftype cEven = cosEven[cosIndex]; + ftype cOdd = cosOdd[cosIndex]; + ftype sEven = sinEven[cosIndex]; + ftype sOdd = sinOdd[cosIndex]; + + ftype q0 = evenVal * cEven - oddVal * sEven; + ftype q1 = oddVal * cOdd + evenVal * sOdd; + + y[gid.x] = q0; + y[gid.x + p.halfD] = q1; + } else { + y[gid.x] = evenVal; + y[gid.x + p.halfD] = oddVal; + } +} +#endif +)metal"; + +class MetalRopeExecution : public MetalExecution { +public: + explicit MetalRopeExecution(Backend *backend, int ropeCutHeadDim, std::shared_ptr qNorm, std::shared_ptr kNorm) + : MetalExecution(backend), mRopeCutHeadDim(ropeCutHeadDim), mQNorm(qNorm), mKNorm(kNorm) { + auto mtbn = static_cast(backend); + auto context = (__bridge MNNMetalContext *)mtbn->context(); + mParam = [context newDeviceBuffer:sizeof(RopeParam) access:CPUWriteOnly]; + auto rt = static_cast(mtbn->getRuntime()); + std::vector keys = {"rope_kernel"}; + MTLCompileOptions *option = [[MTLCompileOptions alloc] init]; + NSMutableDictionary *macros = [NSMutableDictionary dictionary]; + if (mtbn->useFp16InsteadFp32()) { + macros[@"MNN_METAL_FLOAT16_STORAGE"] = @"1"; + keys.emplace_back("fp16"); + } + if (mQNorm) { + macros[@"Q_NORM"] = @"1"; + keys.emplace_back("q_norm"); + } + if (mKNorm) { + macros[@"K_NORM"] = @"1"; + keys.emplace_back("k_norm"); + } + if ((mQNorm || mKNorm) && rt->supportSimdGroupReduce()) { + macros[@"USE_SG"] = @"1"; + keys.emplace_back("sg"); + mUseSG = true; + } else { + mUseSG = false; + } + option.preprocessorMacros = macros; + auto pipeline = rt->findPipeline(keys); + if (nil == pipeline) { + pipeline = mtbn->makeComputePipelineWithSourceOption(gMetalRopeKernelSource, "rope_kernel", option); + rt->insertPipeline(keys, pipeline); + } + mPipeline = pipeline; + if (nil == mPipeline) { + MNN_ERROR("MetalRope: failed to compile rope_kernel.\n"); + } + } + + virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override { + MNN_ASSERT(6 == inputs.size()); + MNN_ASSERT(2 == outputs.size()); + auto q = inputs[0]; + auto k = inputs[1]; + int headDim = q->length(3); + int batch = q->length(0); + int seqLen = q->length(1); + int numHead = q->length(2); + int kvnumHead = k->length(2); + + RopeParam* p = (RopeParam*)(mParam.contents); + p->outerSize = static_cast(batch * seqLen); + p->halfD = headDim / 2; + int ropeDim = mRopeCutHeadDim; + if (ropeDim <= 0 || ropeDim > headDim) { + ropeDim = headDim; + } + ropeDim = (ropeDim / 2) * 2; + p->ropeHalfD = ropeDim / 2; + p->D = headDim; + p->numHead = numHead; + p->kvnumHead = kvnumHead; + p->fullHead = kvnumHead + numHead; + p->qEps = mQNorm ? mQNorm->mEps : 0.0f; + p->kEps = mKNorm ? mKNorm->mEps : 0.0f; + auto mtbn = static_cast(backend()); + auto context = (__bridge MNNMetalContext *)mtbn->context(); + if (mQNorm || mKNorm) { + if (mUseSG) { + mThreads = std::make_pair(MTLSizeMake(1, p->outerSize, (NSUInteger)(numHead + kvnumHead + 1) / 2), MTLSizeMake(64, 1, 1)); + } else { + mThreads = [context computeBestGroupAndLocal:mPipeline threads:MTLSizeMake(1, p->outerSize, (NSUInteger)(numHead + kvnumHead))]; + } + } else { + mThreads = [context computeBestGroupAndLocal:mPipeline threads:MTLSizeMake((NSUInteger)p->halfD, p->outerSize, (NSUInteger)(numHead + kvnumHead))]; + } + return NO_ERROR; + } + + virtual void onEncode(const std::vector &inputs, const std::vector &outputs, id encoder) override { + if (nil == mPipeline) { + return; + } + + auto backend = static_cast(this->backend()); + + [encoder setComputePipelineState:mPipeline]; + MetalBackend::setTensor(inputs[0], encoder, 0); + MetalBackend::setTensor(inputs[1], encoder, 1); + MetalBackend::setTensor(inputs[2], encoder, 2); + MetalBackend::setTensor(inputs[3], encoder, 3); + MetalBackend::setTensor(inputs[4], encoder, 4); + MetalBackend::setTensor(inputs[5], encoder, 5); + MetalBackend::setTensor(outputs[0], encoder, 6); + MetalBackend::setTensor(outputs[1], encoder, 7); + [encoder setBuffer:mParam offset:0 atIndex:8]; + if (mQNorm && mQNorm->mGammaBuffer) { + MetalBackend::setTensor(mQNorm->mGammaBuffer.get(), encoder, 9); + } + if (mKNorm && mKNorm->mGammaBuffer) { + MetalBackend::setTensor(mKNorm->mGammaBuffer.get(), encoder, 10); + } + [encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second]; + } + + virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { + if (nullptr == dst) { + return true; + } + auto rope = new MetalRopeExecution(bn, mRopeCutHeadDim, mQNorm, mKNorm); + *dst = rope; + return true; + } + +private: + int mRopeCutHeadDim = 0; + bool mUseSG = false; + std::shared_ptr mQNorm; + std::shared_ptr mKNorm; + id mParam = nil; + id mPipeline = nil; + std::pair mThreads; +}; + +class MetalRoPECreator : public MetalBackend::Creator { +public: + virtual Execution *onCreate(const std::vector &inputs, const MNN::Op *op, Backend *backend, const std::vector& outputs) const { + int ropeCutHeadDim = 0; + std::shared_ptr qNorm; + std::shared_ptr kNorm; + if (nullptr != op && OpParameter_Extra == op->main_type()) { + auto extra = op->main_as_Extra(); + if (nullptr != extra && nullptr != extra->attr()) { + for (int i = 0; i < extra->attr()->size(); ++i) { + auto attr = extra->attr()->GetAs(i); + if (nullptr == attr || nullptr == attr->key()) { + continue; + } + if (attr->key()->str() == "rope_cut_head_dim") { + ropeCutHeadDim = attr->i(); + continue; + } + if (attr->key()->str() == "q_norm") { + auto qLayernorm = flatbuffers::GetRoot(attr->tensor()->int8s()->data()); + qNorm = MetalLayerNorm::makeResource(backend, qLayernorm->main_as_LayerNorm()); + continue; + } + if (attr->key()->str() == "k_norm") { + auto kLayernorm = flatbuffers::GetRoot(attr->tensor()->int8s()->data()); + kNorm = MetalLayerNorm::makeResource(backend, kLayernorm->main_as_LayerNorm()); + continue; + } + } + } + } + return new MetalRopeExecution(backend, ropeCutHeadDim, qNorm, kNorm); + } +}; +REGISTER_METAL_OP_CREATOR(MetalRoPECreator, OpType_RoPE); + +} // namespace MNN + +#endif // MNN_SUPPORT_TRANSFORMER_FUSE +#endif // MNN_METAL_ENABLED diff --git a/source/backend/metal/MetalSharedGather.hpp b/source/backend/metal/MetalSharedGather.hpp new file mode 100755 index 0000000000..743583afcc --- /dev/null +++ b/source/backend/metal/MetalSharedGather.hpp @@ -0,0 +1,70 @@ +// +// MetalSharedGather.hpp +// MNN + +#ifndef MetalSharedGather_hpp +#define MetalSharedGather_hpp + +#import "MetalExecution.hpp" +#import "MNN_generated.h" + +#if MNN_METAL_ENABLED +namespace MNN { + +// SharedGather implementation on Metal backend. +// It reuses quantized 1x1 convolution weights and dequantization parameters +// to gather selected output-channel rows into a dense floating-point matrix. +class MetalSharedGather : public MetalExecution { +public: + MetalSharedGather(Backend *backend, + int oc, + std::shared_ptr weight, + std::shared_ptr dequantScaleBias, + int dequantBits, + float scaleCoef); + virtual ~MetalSharedGather() = default; + + virtual ErrorCode onResize(const std::vector &inputs, + const std::vector &outputs) override; + + virtual void onEncode(const std::vector &inputs, + const std::vector &outputs, + id encoder) override; + + virtual bool onClone(Backend *bn, const Op *op, Execution **dst) override; + +private: + // conv1x1_constants host-side mirror, used by conv1x1_w_dequant and shared_gather kernels + struct Conv1x1Constants { + int input_size; // repurposed as ic for SharedGather + int input_slice; // ic_4 + int output_width; // selectSize + int output_height; // unused for SharedGather + int output_size; // total elements (selectSize * ic) + int output_slice; // oc_4 + int output_channel; // oc + int batch; // unused for SharedGather + int block_size; // quant block size along K axis + int activation; // not used (no activation) + float scale_coef; // scale normalization factor + }; + +private: + int mOc = 0; // number of output channels (rows in weight) + std::shared_ptr mWeight; // quantized int8/4 weight + std::shared_ptr mDequantScaleBias; // packed scale + bias + int mDequantBits = 0; // 4 or 8 + float mScaleCoef = 1.0f; + + // direct int4/int8 quant gather pipeline (preferred path) + id mQuantPipeline = nil; + std::pair mQuantThreads; + + // constant buffer shared by dequant and gather kernels + id mConstBuffer = nil; +}; + +} // namespace MNN +#endif /* MNN_METAL_ENABLED */ + +#endif /* MetalSharedGather_hpp */ diff --git a/source/backend/metal/MetalSharedGather.mm b/source/backend/metal/MetalSharedGather.mm new file mode 100755 index 0000000000..ab5302a104 --- /dev/null +++ b/source/backend/metal/MetalSharedGather.mm @@ -0,0 +1,248 @@ +// +// MetalSharedGather.mm +// MNN + +#import "backend/metal/MetalSharedGather.hpp" +#import "backend/metal/MetalBackend.hpp" +#import "backend/metal/MNNMetalContext.h" +#import "core/Macro.h" +#import "backend/metal/ConvSimdGroupShader.hpp" + +#if MNN_METAL_ENABLED + +namespace MNN { + +// gSharedGatherQuant: directly decode int4/int8 weights and gather on-the-fly. +// Layout and dequant parameters follow conv1x1 low-memory path. +// Weight layout: [N/4, K/4, N4, K4] (packed), linear index for a pack: +// offset = ((idx_n4 * cst.input_slice + idx_k4) * 4 + idx_nl) +// - idx_n4 = n / 4, idx_nl = n % 4 +// - idx_k4 = k / 4, comp = k % 4 +// W_QUANT_4: +// uchar2 pack = wt[offset]; +// w0 = (pack.x >> 4) - 8; w1 = (pack.x & 15) - 8; +// w2 = (pack.y >> 4) - 8; w3 = (pack.y & 15) - 8; +// choose w = w{comp}. +// W_QUANT_8: +// char4 pack = wt[offset]; choose pack.{x,y,z,w} by comp. +// Dequant scale/bias: +// blockK4PerBi = (cst.input_slice + cst.block_size - 1) / cst.block_size; +// bi = clamp(idx_k4 / blockK4PerBi, 0, cst.block_size-1); +// sbIndex = idx_n4 * cst.block_size + bi; +// scaleVec = dequantScale[2*sbIndex+0] / cst.scale_coef; // ftype4 +// biasVec = dequantScale[2*sbIndex+1] / cst.scale_coef; +// out = w * scaleVec[idx_nl] + biasVec[idx_nl]. +// Thread grid: 1D over all elements (selectSize * ic). +static const char* gSharedGatherQuant = R"metal( +kernel void shared_gather_quant( + device ftype4 *wf [[buffer(0)]], +#ifdef W_QUANT_4 + const device uchar2 *wi [[buffer(1)]], +#elif defined(W_QUANT_8) + const device char4 *wi [[buffer(1)]], +#else + const device ftype4 *wi [[buffer(1)]],// [N/4, K/4, N4, K4] +#endif + const device int *indices [[buffer(2)]], + constant conv1x1_constants& cst [[buffer(3)]], + const device ftype4 *dequantScale [[buffer(4)]], + uint2 gid [[thread_position_in_grid]]) { + int ic = cst.input_size; + int selectSize = cst.output_width; + int idx_k16 = gid.y; // K/16 + + int idx_k4 = idx_k16 * 4; + + if(idx_k4 >= cst.input_slice || gid.x >= selectSize) { + return; + } + + int idx_n = indices[gid.x]; // N + + int idx_n4 = idx_n/4; + int idx_nl = idx_n%4; + + int block = (cst.input_slice + cst.block_size - 1) / cst.block_size; + + + int bi = idx_k4 / block; + // [N/4, cst.block_size, 2/*scale_bias*/, N4] + FLOAT scale = FLOAT(((const device ftype *)dequantScale)[((idx_n4 * cst.block_size + bi) * 2 + 0) * 4 + idx_nl]) / (FLOAT)cst.scale_coef; + FLOAT dequant_bias = FLOAT(((const device ftype *)dequantScale)[((idx_n4 * cst.block_size + bi) * 2 + 1) * 4 + idx_nl]) / (FLOAT)cst.scale_coef; + + auto xy_wi = wi + (idx_n4 * cst.input_slice + idx_k4) * 4 + idx_nl;// [N/4, K/4, N4, K4] + auto xy_wf = wf + (ic * gid.x + idx_k16 * 16) / 4; + + #ifdef W_QUANT_4 + for(int k = 0; k < 4; k++) { + uchar2 w_int4 = xy_wi[4*k]; // [N/4, K/4, N4, K4] + FLOAT4 w4 = FLOAT4((float)(w_int4[0] >> 4) - 8, (float)(w_int4[0] & 15) - 8, (float)(w_int4[1] >> 4) - 8, (float)(w_int4[1] & 15) - 8); + FLOAT4 res = w4 * scale + dequant_bias; + xy_wf[k] = (ftype4)res; + } + #elif defined(W_QUANT_8) + for(int k = 0; k < 4; k++) { + char4 w_int4 = xy_wi[4*k]; // [N/4, K/4, N4, K4] + FLOAT4 w4 = FLOAT4((float)w_int4[0], (float)w_int4[1], (float)w_int4[2], (float)w_int4[3]); + FLOAT4 res = w4 * scale + dequant_bias; + xy_wf[k] = (ftype4)res; + } + #endif +} +)metal"; + +MetalSharedGather::MetalSharedGather(Backend *backend, + int oc, + std::shared_ptr weight, + std::shared_ptr dequantScaleBias, + int dequantBits, + float scaleCoef) + : MetalExecution(backend) { + mOc = oc; + mWeight = std::move(weight); + mDequantScaleBias = std::move(dequantScaleBias); + mDequantBits = dequantBits; + mScaleCoef = scaleCoef; +} + +ErrorCode MetalSharedGather::onResize(const std::vector &inputs, + const std::vector &outputs) { + auto backend = static_cast(this->backend()); + auto context = (__bridge MNNMetalContext *)backend->context(); + + auto input = inputs[0]; // indices tensor + auto output = outputs[0]; // gathered weight rows + + if (nullptr == mWeight.get() || nullptr == mDequantScaleBias.get()) { + // Only support quantized weights for SharedGather + return NOT_SUPPORT; + } + + // Logical sizes + int selectSize = input->elementSize(); + int ic = output->length(output->dimensions() - 1); + int oc = mOc; + int oc_4 = UP_DIV(oc, 4); + int ic_4 = UP_DIV(ic, 4); + + int bytes = backend->useFp16InsteadFp32() ? 2 : 4; + int blockSize = 1; + if (mDequantScaleBias.get()) { + // Layout in MetalConvolutionCommon::getDequantScale: [alignOutputCount, blockSize, 2, 4] + blockSize = (int)(mDequantScaleBias->usize() / bytes / oc_4 / 2 / 4); + if (blockSize <= 0) { + blockSize = 1; + } + } + if (ic % 16 != 0) { + MNN_PRINT("Currnetly metal shared gather don's support ic not align to 16: %d\n", ic); + return NOT_SUPPORT; + } + + // Prepare constant buffer shared by quant/dequant and gather kernels + mConstBuffer = backend->getConstBuffer(sizeof(Conv1x1Constants)); + auto param = (Conv1x1Constants *)mConstBuffer.contents; + ::memset(param, 0, sizeof(Conv1x1Constants)); + param->input_size = ic; // reinterpret as ic + param->input_slice = ic_4; // ic_4 + param->output_width = selectSize; + param->output_height = 1; + param->output_size = selectSize * ic; + param->output_slice = oc_4; + param->output_channel = oc; + param->batch = 1; + param->block_size = blockSize; + param->activation = 0; + param->scale_coef = mScaleCoef; + + // basic macro info for fp16/fp32 + std::string ftype = "float"; + std::string ftype2 = "float2"; + std::string ftype4 = "float4"; + std::string ftype2x4 = "float2x4"; + std::string ftype4x4 = "float4x4"; + if (backend->useFp16InsteadFp32()) { + ftype = "half"; + ftype2 = "half2"; + ftype4 = "half4"; + ftype2x4 = "half2x4"; + ftype4x4 = "half4x4"; + } + + auto baseDic = [NSMutableDictionary dictionaryWithCapacity:0]; + [baseDic setValue:@(ftype.c_str()) forKey:@"ftype"]; + [baseDic setValue:@(ftype2.c_str()) forKey:@"ftype2"]; + [baseDic setValue:@(ftype4.c_str()) forKey:@"ftype4"]; + [baseDic setValue:@(ftype2x4.c_str()) forKey:@"ftype2x4"]; + [baseDic setValue:@(ftype4x4.c_str()) forKey:@"ftype4x4"]; + [baseDic setValue:@"1" forKey:@"MNN_METAL_FLOAT32_COMPUTER"]; + if (backend->useFp16InsteadFp32()) { + [baseDic setValue:@"1" forKey:@"MNN_METAL_FLOAT16_STORAGE"]; + } + + MetalRuntime *rt = (MetalRuntime *)backend->runtime(); + std::string basicShaderPrefix = gBasicConvPrefix; + + // Preferred path: direct int4/int8 quant gather in shader + mQuantPipeline = nil; + + MTLCompileOptions *optionQuant = [[MTLCompileOptions alloc] init]; + NSMutableDictionary *dic = [baseDic mutableCopy]; + std::vector keys = {ftype4, "MNN_METAL_FLOAT32_COMPUTER", "shared_gather_quant"}; + if (mDequantBits == 4) { + [dic setValue:@"1" forKey:@"W_QUANT_4"]; + keys.emplace_back("W_QUANT_4"); + } else { + [dic setValue:@"1" forKey:@"W_QUANT_8"]; + keys.emplace_back("W_QUANT_8"); + } + optionQuant.preprocessorMacros = dic; + + auto pipeline = rt->findPipeline(keys); + if (nil == pipeline) { + std::string shader = basicShaderPrefix + gSharedGatherQuant; + pipeline = backend->makeComputePipelineWithSourceOption(shader.c_str(), "shared_gather_quant", optionQuant); + rt->insertPipeline(keys, pipeline); + } + mQuantPipeline = pipeline; + + auto threads = MTLSizeMake((NSUInteger)selectSize, UP_DIV(ic, 16), 1); + mQuantThreads = [context computeBestGroupAndLocal:pipeline threads:threads]; + + // In int4/int8 path we do not build global dequant + blit by default. + return NO_ERROR; +} + +void MetalSharedGather::onEncode(const std::vector &inputs, + const std::vector &outputs, + id encoder) { + auto backend = static_cast(this->backend()); + + auto input = inputs[0]; + auto output = outputs[0]; + + // Preferred path: direct quant gather + [encoder setComputePipelineState:mQuantPipeline]; + MetalBackend::setTensor(output, encoder, 0); // out + MetalBackend::setTensor(mWeight.get(), encoder, 1); // quant weight + MetalBackend::setTensor(input, encoder, 2); // indices + [encoder setBuffer:mConstBuffer offset:0 atIndex:3]; + if (nullptr != mDequantScaleBias.get()) { + MetalBackend::setTensor(mDequantScaleBias.get(), encoder, 4); // dequantScaleBias + } + [encoder dispatchThreadgroups:mQuantThreads.first threadsPerThreadgroup:mQuantThreads.second]; + MNN_PRINT_ENCODER((__bridge MNNMetalContext *)backend->context(), encoder); + return; +} + +bool MetalSharedGather::onClone(Backend *bn, const Op *op, Execution **dst) { + if (nullptr == dst) { + return true; + } + *dst = new MetalSharedGather(bn, mOc, mWeight, mDequantScaleBias, mDequantBits, mScaleCoef); + return true; +} + +} // namespace MNN + +#endif /* MNN_METAL_ENABLED */ diff --git a/source/backend/metal/MetalSoftmaxShader.cpp b/source/backend/metal/MetalSoftmaxShader.cpp new file mode 100644 index 0000000000..90d2c71db3 --- /dev/null +++ b/source/backend/metal/MetalSoftmaxShader.cpp @@ -0,0 +1,212 @@ +// Copyright @ MNN +#include "MetalSoftmaxShader.hpp" + +namespace MNN { + +// Plane Softmax (scalar) +const char* gSoftmaxPlaneSrc = R"metal( +#include +using namespace metal; +struct softmax_shape { + int inside_size; + int axis_length; + int outside_size; + int flat_length; +}; +kernel void softmax_plane(const device T* in [[buffer(0)]], + device T* out [[buffer(1)]], + constant softmax_shape& s [[buffer(2)]], + uint2 gid [[thread_position_in_grid]]) { + if ((int)gid.x >= s.inside_size || (int)gid.y >= s.outside_size) return; + const int axis_off = int(gid.y) * s.axis_length * s.inside_size + int(gid.x); + const device T* axis_in = in + axis_off; + device T* axis_out = out + axis_off; + float maxv = -FLT_MAX; + for (int i = 0; i < s.axis_length; ++i) { + maxv = max(maxv, float(axis_in[i * s.inside_size])); + } + float sumv = 0.0f; + for (int i = 0; i < s.axis_length; ++i) { + sumv += exp(float(axis_in[i * s.inside_size]) - maxv); + } + for (int i = 0; i < s.axis_length; ++i) { + axis_out[i * s.inside_size] = (T)(exp(float(axis_in[i * s.inside_size]) - maxv) / sumv); + } +} +)metal"; + +// Plane Softmax with simd group reduce (scalar) +const char* gSoftmaxPlaneSgSrc = R"metal( +#include +#include +using namespace metal; +struct softmax_shape { + int inside_size; + int axis_length; + int outside_size; + int flat_length; +}; +#define SIMD_GROUP_WIDTH 32 +kernel void softmax_plane_sg(const device T* in [[buffer(0)]], + device T* out [[buffer(1)]], + constant softmax_shape& s [[buffer(2)]], + uint2 gid [[threadgroup_position_in_grid]], + uint tiisg [[thread_index_in_simdgroup]]) { + if ((int)gid.x >= s.inside_size || (int)gid.y >= s.outside_size) return; + const int axis_off = int(gid.y) * s.axis_length * s.inside_size + int(gid.x); + const device T* axis_in = in + axis_off; + device T* axis_out = out + axis_off; + float lmax = -FLT_MAX; + for (int i = tiisg; i < s.axis_length; i += SIMD_GROUP_WIDTH) { + lmax = max(lmax, float(axis_in[i * s.inside_size])); + } + float maxv = simd_max(lmax); + float lsum = 0.0f; + for (int i = tiisg; i < s.axis_length; i += SIMD_GROUP_WIDTH) { + lsum += exp(float(axis_in[i * s.inside_size]) - maxv); + } + float sumv = simd_sum(lsum); + for (int i = tiisg; i < s.axis_length; i += SIMD_GROUP_WIDTH) { + axis_out[i * s.inside_size] = (T)(exp(float(axis_in[i * s.inside_size]) - maxv) / sumv); + } +} +)metal"; + +// Plane Softmax with multi-simdgroup threadgroup reduction +const char* gSoftmaxPlaneSgTG = R"metal( +#include +#include +using namespace metal; +struct softmax_shape { + int inside_size; + int axis_length; + int outside_size; + int flat_length; +}; +#define SIMD_GROUP_WIDTH 32 +#ifndef TG_SIZE +#define TG_SIZE 128 +#endif +#define SG_PER_TG (TG_SIZE / SIMD_GROUP_WIDTH) + +kernel void softmax_plane_sg_tg(const device T* in [[buffer(0)]], + device T* out [[buffer(1)]], + constant softmax_shape& s [[buffer(2)]], + uint2 gtp [[threadgroup_position_in_grid]], + uint tiisg [[thread_index_in_simdgroup]], + uint sgitg [[simdgroup_index_in_threadgroup]]) { + if ((int)gtp.x >= s.inside_size || (int)gtp.y >= s.outside_size) return; + const int axis_off = int(gtp.y) * s.axis_length * s.inside_size + int(gtp.x); + const device T* axis_in = in + axis_off; + device T* axis_out = out + axis_off; + + const int stride = SIMD_GROUP_WIDTH * SG_PER_TG; + int start = int(tiisg) + int(sgitg) * SIMD_GROUP_WIDTH; + + // 1) Max reduction + float lmax = -FLT_MAX; + for (int i = start; i < s.axis_length; i += stride) { + lmax = max(lmax, float(axis_in[i * s.inside_size])); + } + float sgMax = simd_max(lmax); + threadgroup float tgMax[SG_PER_TG]; + if (tiisg == 0) tgMax[sgitg] = sgMax; + threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup float finalMaxStore[1]; + if (sgitg == 0 && tiisg == 0) { + float fm = -FLT_MAX; + for (int k = 0; k < SG_PER_TG; ++k) fm = max(fm, tgMax[k]); + finalMaxStore[0] = fm; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + float maxv = finalMaxStore[0]; + + // 2) Sum reduction + float lsum = 0.0f; + for (int i = start; i < s.axis_length; i += stride) { + lsum += exp(float(axis_in[i * s.inside_size]) - maxv); + } + float sgSum = simd_sum(lsum); + threadgroup float tgSum[SG_PER_TG]; + if (tiisg == 0) tgSum[sgitg] = sgSum; + threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup float finalSumStore[1]; + if (sgitg == 0 && tiisg == 0) { + float fs = 0.0f; + for (int k = 0; k < SG_PER_TG; ++k) fs += tgSum[k]; + finalSumStore[0] = fs; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + float sumv = finalSumStore[0]; + + // 3) Write back + for (int i = start; i < s.axis_length; i += stride) { + axis_out[i * s.inside_size] = (T)(exp(float(axis_in[i * s.inside_size]) - maxv) / sumv); + } +} +)metal"; + +// Attention variant (uses ftype and axis_align_length) +const char* gSoftmaxSgReduce = R"metal( +#include +using namespace metal; +struct softmax_shape { + int inside_size; + int axis_length; + int outside_size; + int axis_align_length; +}; +#define SIMD_GROUP_WIDTH 32 + +kernel void softmax_plane(const device ftype *in [[buffer(0)]], + device ftype *out [[buffer(1)]], + constant softmax_shape& s [[buffer(2)]], + uint2 gid [[thread_position_in_grid]]) { + if ((int)gid.x >= s.inside_size || (int)gid.y >= s.outside_size) return; + auto in_offset = gid.y * s.axis_length * s.inside_size + gid.x; + auto out_offset = gid.y * s.axis_align_length * s.inside_size + gid.x; + auto axis_in = in + in_offset; + auto axis_out = out + out_offset; + float max1 = -FLT_MAX; + for (int i = 0; i < s.axis_length; i++) { + max1 = max(max1, float(axis_in[i * s.inside_size])); + } + float sum1 = 0; + for (int i = 0; i < s.axis_length; i++) { + sum1 += exp(float(axis_in[i * s.inside_size]) - float(max1)); + } + for (int i = 0; i < s.axis_align_length; i++) { + axis_out[i * s.inside_size] = i >= s.axis_length ? ftype(0.0) : ftype(exp(float(axis_in[i * s.inside_size]) - float(max1)) / sum1); + } +} + +kernel void softmax_plane_sg(const device ftype *in [[buffer(0)]], + device ftype *out [[buffer(1)]], + constant softmax_shape& s [[buffer(2)]], + uint2 gid[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]] + ) { + if ((int)gid.x >= s.inside_size || (int)gid.y >= s.outside_size) return; + auto in_offset = gid.y * s.axis_length * s.inside_size + gid.x; + auto out_offset = gid.y * s.axis_align_length * s.inside_size + gid.x; + auto axis_in = in + in_offset; + auto axis_out = out + out_offset; + float max1 = -FLT_MAX; + for (int i = tiisg; i < s.axis_length; i+=SIMD_GROUP_WIDTH) { + max1 = max(max1, float(axis_in[i * s.inside_size])); + } + max1 = simd_max(max1); + float sum1 = 0; + for (int i = tiisg; i < s.axis_length; i+=SIMD_GROUP_WIDTH) { + sum1 += exp(float(axis_in[i * s.inside_size]) - float(max1)); + } + sum1 = simd_sum(sum1); + for (int i = tiisg; i < s.axis_align_length; i+=SIMD_GROUP_WIDTH) { + axis_out[i * s.inside_size] = i >= s.axis_length ? ftype(0.0) : ftype(exp(float(axis_in[i * s.inside_size]) - float(max1)) / sum1); + } +} + +)metal"; + +} diff --git a/source/backend/metal/MetalSoftmaxShader.hpp b/source/backend/metal/MetalSoftmaxShader.hpp new file mode 100644 index 0000000000..4b842b8f5e --- /dev/null +++ b/source/backend/metal/MetalSoftmaxShader.hpp @@ -0,0 +1,16 @@ +// Copyright @ MNN +#pragma once + +namespace MNN { + +// Plane softmax (scalar T macro) +extern const char* gSoftmaxPlaneSrc; +// Plane softmax with simd_max/simd_sum (scalar T macro) +extern const char* gSoftmaxPlaneSgSrc; +// Plane softmax with enlarged local size (multi-simdgroup reduce) +extern const char* gSoftmaxPlaneSgTG; + +// Plane softmax with simd reduce used by Attention (uses ftype and axis_align_length) +extern const char* gSoftmaxSgReduce; + +} diff --git a/source/backend/metal/MetalTopKV2.mm b/source/backend/metal/MetalTopKV2.mm index c37f0ff20e..ef56539214 100644 --- a/source/backend/metal/MetalTopKV2.mm +++ b/source/backend/metal/MetalTopKV2.mm @@ -15,40 +15,190 @@ #if MNN_METAL_ENABLED namespace MNN { -static const int kTopKThreadNumber = 128; -static const int kTopKLocalK = 8; -static const int kTopKCandidateNumber = kTopKThreadNumber * kTopKLocalK; +static const int kTopKLocalK = 16; -static const char* gTopKV2Template = R"metal( +static const char* gTopKV2K1Template = R"metal( #include #include using namespace metal; -#define THREAD_NUMBER 128 -#define LOCAL_K 8 -#define CANDIDATE_NUMBER (THREAD_NUMBER * LOCAL_K) +#define SIMD_GROUP_WIDTH 32 struct TopKParam { int4 size; // rowSize, k, numRows, pad }; -inline bool afterAsc(T aValue, int aIndex, T bValue, int bIndex) { +inline bool better(T aValue, int aIndex, T bValue, int bIndex) { + if (bIndex < 0) { + return true; + } + if (aIndex < 0) { + return false; + } +#ifdef SORT_DESC if (aValue > bValue) { return true; } if (aValue < bValue) { return false; } - // tie-break: larger index comes after - return aIndex > bIndex; +#else + if (aValue < bValue) { + return true; + } + if (aValue > bValue) { + return false; + } +#endif + return aIndex < bIndex; +} + +kernel void topkv2(device T* outValue [[buffer(0)]], + device int* outIndex [[buffer(1)]], + const device T* inValue [[buffer(2)]], + constant TopKParam& p [[buffer(3)]], +#ifdef SIMD_GROUP_REDUCE + uint3 tgp [[threadgroup_position_in_grid]], + uint tiisg [[thread_index_in_simdgroup]], + uint sgitg [[simdgroup_index_in_threadgroup]] +#else + uint tid [[thread_index_in_threadgroup]], + uint3 tgp [[threadgroup_position_in_grid]] +#endif + ) { + const uint row = tgp.x; + const int rowSize = p.size.x; + const int numRows = p.size.z; + if ((int)row >= numRows) { + return; + } + +#ifdef IS_INT + const T initWorst = (T)(2147483647); + const T initBestWorst = (T)(-2147483648); +#else +#ifdef USE_FP16 + const T initWorst = (T)(65504.0h); + const T initBestWorst = (T)(-65504.0h); +#else + const T initWorst = (T)(FLT_MAX); + const T initBestWorst = (T)(-FLT_MAX); +#endif +#endif + + const device T* rowIn = inValue + row * (uint)rowSize; + + T bestVal; + int bestIdx = -1; +#ifdef SORT_DESC + bestVal = initBestWorst; +#else + bestVal = initWorst; +#endif + +#ifdef SIMD_GROUP_REDUCE + const uint tid = tiisg + sgitg * SIMD_GROUP_WIDTH; +#endif + + for (int i = (int)tid; i < rowSize; i += (int)THREAD_NUMBER) { + const T val = rowIn[i]; + if (better(val, i, bestVal, bestIdx)) { + bestVal = val; + bestIdx = i; + } + } + +#ifdef SIMD_GROUP_REDUCE + // SIMD group 内归约 + 跨 SG 合并(需要 threadgroup_barrier) +#ifdef SORT_DESC + T sgBestVal = simd_max(bestVal); +#else + T sgBestVal = simd_min(bestVal); +#endif + const int INF_IDX = 2147483647; + int candidate = (bestIdx >= 0 && bestVal == sgBestVal) ? bestIdx : INF_IDX; + int sgBestIdx = simd_min(candidate); + + // 写入每个 simdgroup 的结果 + threadgroup T sharedBestVal[THREAD_NUMBER / SIMD_GROUP_WIDTH]; + threadgroup int sharedBestIdx[THREAD_NUMBER / SIMD_GROUP_WIDTH]; + if (tiisg == 0) { + sharedBestVal[sgitg] = sgBestVal; + sharedBestIdx[sgitg] = sgBestIdx; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // 跨 simdgroup 合并 + if (tiisg == 0 && sgitg == 0) { + const uint NUM_SG = THREAD_NUMBER / SIMD_GROUP_WIDTH; + T finalVal; + int finalIdx = -1; +#ifdef SORT_DESC + finalVal = initBestWorst; +#else + finalVal = initWorst; +#endif + for (uint i = 0; i < NUM_SG; ++i) { + T v = sharedBestVal[i]; + int idx = sharedBestIdx[i]; + if (better(v, idx, finalVal, finalIdx)) { + finalVal = v; + finalIdx = idx; + } + } + device T* rowOut = outValue + row * (uint)1; + device int* rowIdx = outIndex + row * (uint)1; + rowOut[0] = finalVal; + rowIdx[0] = finalIdx; + } +#else + // Fallback: threadgroup tree reduction + threadgroup T sharedBestVal[THREAD_NUMBER]; + threadgroup int sharedBestIdx[THREAD_NUMBER]; + sharedBestVal[tid] = bestVal; + sharedBestIdx[tid] = bestIdx; + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint s = THREAD_NUMBER / 2; s > 0; s >>= 1) { + if (tid < s) { + T val1 = sharedBestVal[tid]; + int idx1 = sharedBestIdx[tid]; + T val2 = sharedBestVal[tid + s]; + int idx2 = sharedBestIdx[tid + s]; + if (!better(val1, idx1, val2, idx2)) { + sharedBestVal[tid] = val2; + sharedBestIdx[tid] = idx2; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + if (tid == 0) { + device T* rowOut = outValue + row * (uint)1; + device int* rowIdx = outIndex + row * (uint)1; + rowOut[0] = sharedBestVal[0]; + rowIdx[0] = sharedBestIdx[0]; + } +#endif } +)metal"; + + +static const char* gTopKV2K32Template = R"metal( +#include +#include +using namespace metal; + +#define LOCAL_K 16 + +struct TopKParam { + int4 size; // rowSize, k, numRows, pad +}; inline bool better(T aValue, int aIndex, T bValue, int bIndex) { - // b is invalid => always better if (bIndex < 0) { return true; } - // a is invalid => never better if (aIndex < 0) { return false; } @@ -67,7 +217,6 @@ inline bool better(T aValue, int aIndex, T bValue, int bIndex) { return false; } #endif - // tie-break: smaller index wins return aIndex < bIndex; } @@ -88,11 +237,18 @@ kernel void topkv2(device T* outValue [[buffer(0)]], #ifdef IS_INT const T initWorst = (T)(2147483647); const T initBestWorst = (T)(-2147483648); +#else +#ifdef USE_FP16 + const T initWorst = (T)(65504.0h); + const T initBestWorst = (T)(-65504.0h); #else const T initWorst = (T)(FLT_MAX); const T initBestWorst = (T)(-FLT_MAX); +#endif #endif + const device T* rowIn = inValue + row * (uint)rowSize; + thread T localValue[LOCAL_K]; thread int localIndex[LOCAL_K]; #ifdef SORT_DESC @@ -107,25 +263,24 @@ kernel void topkv2(device T* outValue [[buffer(0)]], } #endif - const device T* rowIn = inValue + row * (uint)rowSize; - for (int i = (int)tid; i < rowSize; i += (int)THREAD_NUMBER) { const T value = rowIn[i]; - if (!better(value, i, localValue[LOCAL_K - 1], localIndex[LOCAL_K - 1])) { + if (!better(value, i, localValue[k - 1], localIndex[k - 1])) { continue; } - uint insertPos = LOCAL_K; - for (uint j = 0; j < LOCAL_K; ++j) { + uint insertPos = k; + for (uint j = 0; j < k; ++j) { if (better(value, i, localValue[j], localIndex[j])) { insertPos = j; break; } } - if (insertPos >= LOCAL_K) { + if (insertPos >= k) { continue; } - for (uint j = LOCAL_K - 1; j > insertPos; --j) { + for (uint j = k - 1; j > 0; --j) { + if (j == insertPos) break; localValue[j] = localValue[j - 1]; localIndex[j] = localIndex[j - 1]; } @@ -133,39 +288,41 @@ kernel void topkv2(device T* outValue [[buffer(0)]], localIndex[insertPos] = i; } - threadgroup T sharedValue[CANDIDATE_NUMBER]; - threadgroup int sharedIndex[CANDIDATE_NUMBER]; - const uint base = tid * LOCAL_K; - for (uint i = 0; i < LOCAL_K; ++i) { - sharedValue[base + i] = localValue[i]; - sharedIndex[base + i] = localIndex[i]; + threadgroup T sharedValue[THREAD_NUMBER][LOCAL_K]; + threadgroup int sharedIndex[THREAD_NUMBER][LOCAL_K]; + for (uint i = 0; i < k; ++i) { + sharedValue[tid][i] = localValue[i]; + sharedIndex[tid][i] = localIndex[i]; } threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint size = 2; size <= CANDIDATE_NUMBER; size <<= 1) { - for (uint stride = size >> 1; stride > 0; stride >>= 1) { - for (uint idx = tid; idx < CANDIDATE_NUMBER; idx += THREAD_NUMBER) { - const uint ixj = idx ^ stride; - if (ixj <= idx) { - continue; - } - bool up = ((idx & size) == 0); -#ifdef SORT_DESC - up = !up; -#endif + for (uint s = THREAD_NUMBER / 2; s > 0; s >>= 1) { + if (tid < s) { + T aVals[LOCAL_K]; + int aIdxs[LOCAL_K]; + T bVals[LOCAL_K]; + int bIdxs[LOCAL_K]; + for (uint i = 0; i < k; ++i) { + aVals[i] = sharedValue[tid][i]; + aIdxs[i] = sharedIndex[tid][i]; + bVals[i] = sharedValue[tid + s][i]; + bIdxs[i] = sharedIndex[tid + s][i]; + } - const bool after = afterAsc(sharedValue[idx], sharedIndex[idx], sharedValue[ixj], sharedIndex[ixj]); - if (up == after) { - const T tValue = sharedValue[idx]; - sharedValue[idx] = sharedValue[ixj]; - sharedValue[ixj] = tValue; - const int tIndex = sharedIndex[idx]; - sharedIndex[idx] = sharedIndex[ixj]; - sharedIndex[ixj] = tIndex; + uint ai = 0, bi = 0; + for (uint oi = 0; oi < k; ++oi) { + if (better(aVals[ai], aIdxs[ai], bVals[bi], bIdxs[bi])) { + sharedValue[tid][oi] = aVals[ai]; + sharedIndex[tid][oi] = aIdxs[ai]; + ai++; + } else { + sharedValue[tid][oi] = bVals[bi]; + sharedIndex[tid][oi] = bIdxs[bi]; + bi++; } } - threadgroup_barrier(mem_flags::mem_threadgroup); } + threadgroup_barrier(mem_flags::mem_threadgroup); } if (tid == 0) { @@ -173,16 +330,24 @@ kernel void topkv2(device T* outValue [[buffer(0)]], device int* rowIdx = outIndex + row * (uint)k; const int realK = min(k, rowSize); for (int i = 0; i < realK; ++i) { - rowOut[i] = sharedValue[i]; - rowIdx[i] = sharedIndex[i]; + rowOut[i] = sharedValue[0][i]; + rowIdx[i] = sharedIndex[0][i]; } } } )metal"; + class MetalTopKV2 : public MetalExecution { +private: + id mParam = nil; + id mPipeline = nil; + int mGroupNumber = 0; + int mTopK = 0; + bool mLargest; + int mLocalThreadNumber = 0; public: - MetalTopKV2(id pipeline, Backend* backend) : MetalExecution(backend), mPipeline(pipeline) { + MetalTopKV2(Backend* backend, bool largest) : MetalExecution(backend), mLargest(largest) { auto mtbn = static_cast(backend); auto context = (__bridge MNNMetalContext*)mtbn->context(); mParam = mtbn->getConstBuffer(sizeof(int) * 4); @@ -203,9 +368,74 @@ ErrorCode onResize(const std::vector& inputs, const std::vector(backend()); + int kTopKThreadNumber = static_cast(mtbn->getRuntime())->maxThreadSize(); const int numRows = input->elementSize() / rowSize; const int k = output->length(output->dimensions() - 1); + if (k > 1) { + kTopKThreadNumber = 32768 / kTopKLocalK / (2 * sizeof(float)); + } + + const int kTopKCandidateNumber = kTopKThreadNumber * kTopKLocalK; + + if (k <= 0 || k > rowSize) { + return NOT_SUPPORT; + } + if (k > kTopKCandidateNumber) { + MNN_ERROR("Metal TopK don't support k=%dlarger than %d\n", k, kTopKCandidateNumber); + return NOT_SUPPORT; + } + + const bool useFp16 = mtbn->useFp16InsteadFp32(); + bool largest = mLargest; + NSString* T = MetalCast::getScalarType(inputs[0]->getType(), useFp16); + + std::vector keys = { + "topkv2", + std::string([T UTF8String]), + largest ? "largest" : "smallest", + }; + mLocalThreadNumber = kTopKThreadNumber; + + const char* sourceTemplate = nullptr; + if (k == 1) { + keys.push_back("k1"); + sourceTemplate = gTopKV2K1Template; + } else if (k <= kTopKLocalK) { + keys.push_back("smallk"); + sourceTemplate = gTopKV2K32Template; + } + auto pipeline = mtbn->runtime()->findPipeline(keys); + if (nil == pipeline) { + MTLCompileOptions* compileOptions = [[MTLCompileOptions alloc] init]; + auto dic = [NSMutableDictionary dictionaryWithCapacity:0]; + [dic setValue:T forKey:@"T"]; + if (largest) { + [dic setValue:@"1" forKey:@"SORT_DESC"]; + } + if (inputs[0]->getType().code == halide_type_int && inputs[0]->getType().bits == 32) { + [dic setValue:@"1" forKey:@"IS_INT"]; + } + if (useFp16 && inputs[0]->getType().code != halide_type_int) { + [dic setValue:@"1" forKey:@"USE_FP16"]; + } + // Keep THREAD_NUMBER in sync with host-side kTopKThreadNumber + [dic setValue:[NSString stringWithFormat:@"%d", kTopKThreadNumber] forKey:@"THREAD_NUMBER"]; + // Enable SIMD group reduction for K=1 when supported + if (k == 1 && ((MetalRuntime*)mtbn->runtime())->supportSimdGroupReduce()) { + [dic setValue:@"1" forKey:@"SIMD_GROUP_REDUCE"]; + } + compileOptions.preprocessorMacros = dic; + + pipeline = mtbn->makeComputePipelineWithSourceOption(sourceTemplate, "topkv2", compileOptions); + mtbn->runtime()->insertPipeline(keys, pipeline); + } + if (nil == pipeline) { + MNN_ERROR("Create TopKV2 pipeline error\n"); + return NOT_SUPPORT; + } + mPipeline = pipeline; auto p = (int*)mParam.contents; p[0] = rowSize; p[1] = k; @@ -213,6 +443,7 @@ ErrorCode onResize(const std::vector& inputs, const std::vector& inputs, const std::vector& ou if (mGroupNumber <= 0) { return; } + auto mtbn = static_cast(backend()); [encoder setComputePipelineState:mPipeline]; MetalBackend::setTensor(outputs[0], encoder, 0); MetalBackend::setTensor(outputs[1], encoder, 1); MetalBackend::setTensor(inputs[0], encoder, 2); [encoder setBuffer:mParam offset:0 atIndex:3]; [encoder dispatchThreadgroups:MTLSizeMake(mGroupNumber, 1, 1) - threadsPerThreadgroup:MTLSizeMake(kTopKThreadNumber, 1, 1)]; + threadsPerThreadgroup:MTLSizeMake(mLocalThreadNumber, 1, 1)]; } -private: - id mParam = nil; - id mPipeline = nil; - int mGroupNumber = 0; }; class MetalTopKV2Creator : public MetalBackend::Creator { @@ -244,54 +472,17 @@ void onEncode(const std::vector& inputs, const std::vector& ou if (TensorUtils::getDescribe(inputs[0])->dimensionFormat == MNN_DATA_FORMAT_NC4HW4) { return nullptr; } - - const int rowSize = inputs[0]->length(inputs[0]->dimensions() - 1); - const int k = outputs[0]->length(outputs[0]->dimensions() - 1); - if (k <= 0 || k > rowSize) { - return nullptr; - } - // Limit by threadgroup candidate capacity: THREAD_NUMBER * LOCAL_K - if (k > kTopKCandidateNumber) { - return nullptr; - } - bool largest = true; auto param = op->main_as_TopKV2(); if (nullptr != param) { largest = param->largest(); } - - auto mtbn = static_cast(backend); - const bool useFp16 = mtbn->useFp16InsteadFp32(); - NSString* T = MetalCast::getScalarType(inputs[0]->getType(), useFp16); - - std::vector keys = { - "topkv2", - std::string([T UTF8String]), - largest ? "largest" : "smallest", - }; - - auto pipeline = mtbn->runtime()->findPipeline(keys); - if (nil == pipeline) { - MTLCompileOptions* compileOptions = [[MTLCompileOptions alloc] init]; - auto dic = [NSMutableDictionary dictionaryWithCapacity:0]; - [dic setValue:T forKey:@"T"]; - if (largest) { - [dic setValue:@"1" forKey:@"SORT_DESC"]; - } - if (inputs[0]->getType().code == halide_type_int && inputs[0]->getType().bits == 32) { - [dic setValue:@"1" forKey:@"IS_INT"]; - } - compileOptions.preprocessorMacros = dic; - - pipeline = mtbn->makeComputePipelineWithSourceOption(gTopKV2Template, "topkv2", compileOptions); - mtbn->runtime()->insertPipeline(keys, pipeline); - } - if (nil == pipeline) { - MNN_ERROR("Create TopKV2 pipeline error\n"); + auto output = outputs[0]; + const int k = output->length(output->dimensions() - 1); + if (k > kTopKLocalK) { return nullptr; } - return new MetalTopKV2(pipeline, backend); + return new MetalTopKV2(backend, largest); } }; diff --git a/source/backend/opencl/core/OpenCLOPRegister.cpp b/source/backend/opencl/core/OpenCLOPRegister.cpp index c5892cb0f4..922876c12b 100644 --- a/source/backend/opencl/core/OpenCLOPRegister.cpp +++ b/source/backend/opencl/core/OpenCLOPRegister.cpp @@ -70,78 +70,80 @@ extern void ___OpenCLSplitGeluBufCreator__OpType_SplitGeLU__BUFFER__(); extern void ___OpenCLGroupNormBufCreator__OpType_GroupNorm__BUFFER__(); extern void ___OpenCLLinearAttentionBufCreator__OpType_LinearAttention__BUFFER__(); extern void ___OpenCLAttentionBufCreator__OpType_Attention__BUFFER__(); +extern void ___OpenCLRopeBufCreator__OpType_RoPE__BUFFER__(); #endif void registerOpenCLOps() { #ifndef MNN_OPENCL_BUFFER_CLOSED -___OpenCLInterp3DBufCreator__OpType_Interp3D__BUFFER__(); -___OpenCLReductionBufCreator__OpType_Reduction__BUFFER__(); -___OpenCLArgMaxBufCreator__OpType_ArgMax__BUFFER__(); -___OpenCLArgMaxBufCreator__OpType_ArgMin__BUFFER__(); -___OpenCLMatMulBufCreator__OpType_MatMul__BUFFER__(); -___OpenCLRasterBufCreator__OpType_Raster__BUFFER__(); -___OpenCLLayerNormBufCreator__OpType_LayerNorm__BUFFER__(); -___OpenCLDepthwiseConvolutionBufCreator__OpType_ConvolutionDepthwise__BUFFER__(); -___OpenCLInterpBufCreator__OpType_Interp__BUFFER__(); -___OpenCLBinaryBufCreator__OpType_Eltwise__BUFFER__(); -___OpenCLBinaryBufCreator__OpType_BinaryOp__BUFFER__(); -___OpenCLConvolutionBufCreator__OpType_Convolution__BUFFER__(); -___OpenCLSelectBufCreator__OpType_Select__BUFFER__(); -___OpenCLPoolBufCreator__OpType_Pooling__BUFFER__(); -___OpenCLDeconvolutionBufCreator__OpType_Deconvolution__BUFFER__(); -___OpenCLCastBufCreator__OpType_Cast__BUFFER__(); -___OpenCLReluBufCreator__OpType_ReLU__BUFFER__(); -___OpenCLReluBufCreator__OpType_PReLU__BUFFER__(); -___OpenCLReluBufCreator__OpType_ReLU6__BUFFER__(); -___OpenCLSoftmaxBufCreator__OpType_Softmax__BUFFER__(); -___OpenCLLoopBufCreator__OpType_While__BUFFER__(); -___OpenCLRangeBufCreator__OpType_Range__BUFFER__(); -___OpenCLUnaryBufCreator__OpType_UnaryOp__BUFFER__(); -___OpenCLUnaryBufCreator__OpType_Sigmoid__BUFFER__(); -___OpenCLUnaryBufCreator__OpType_TanH__BUFFER__(); -___OpenCLFuseBufCreator__OpType_Extra__BUFFER__(); -___OpenCLGridSampleBufCreator__OpType_GridSample__BUFFER__(); -___OpenCLScaleBufCreator__OpType_Scale__BUFFER__(); -___OpenCLTopKV2BufCreator__OpType_TopKV2__BUFFER__(); + ___OpenCLInterp3DBufCreator__OpType_Interp3D__BUFFER__(); + ___OpenCLReductionBufCreator__OpType_Reduction__BUFFER__(); + ___OpenCLArgMaxBufCreator__OpType_ArgMax__BUFFER__(); + ___OpenCLArgMaxBufCreator__OpType_ArgMin__BUFFER__(); + ___OpenCLMatMulBufCreator__OpType_MatMul__BUFFER__(); + ___OpenCLRasterBufCreator__OpType_Raster__BUFFER__(); + ___OpenCLLayerNormBufCreator__OpType_LayerNorm__BUFFER__(); + ___OpenCLDepthwiseConvolutionBufCreator__OpType_ConvolutionDepthwise__BUFFER__(); + ___OpenCLInterpBufCreator__OpType_Interp__BUFFER__(); + ___OpenCLBinaryBufCreator__OpType_Eltwise__BUFFER__(); + ___OpenCLBinaryBufCreator__OpType_BinaryOp__BUFFER__(); + ___OpenCLConvolutionBufCreator__OpType_Convolution__BUFFER__(); + ___OpenCLSelectBufCreator__OpType_Select__BUFFER__(); + ___OpenCLPoolBufCreator__OpType_Pooling__BUFFER__(); + ___OpenCLDeconvolutionBufCreator__OpType_Deconvolution__BUFFER__(); + ___OpenCLCastBufCreator__OpType_Cast__BUFFER__(); + ___OpenCLReluBufCreator__OpType_ReLU__BUFFER__(); + ___OpenCLReluBufCreator__OpType_PReLU__BUFFER__(); + ___OpenCLReluBufCreator__OpType_ReLU6__BUFFER__(); + ___OpenCLSoftmaxBufCreator__OpType_Softmax__BUFFER__(); + ___OpenCLLoopBufCreator__OpType_While__BUFFER__(); + ___OpenCLRangeBufCreator__OpType_Range__BUFFER__(); + ___OpenCLUnaryBufCreator__OpType_UnaryOp__BUFFER__(); + ___OpenCLUnaryBufCreator__OpType_Sigmoid__BUFFER__(); + ___OpenCLUnaryBufCreator__OpType_TanH__BUFFER__(); + ___OpenCLFuseBufCreator__OpType_Extra__BUFFER__(); + ___OpenCLGridSampleBufCreator__OpType_GridSample__BUFFER__(); + ___OpenCLScaleBufCreator__OpType_Scale__BUFFER__(); + ___OpenCLTopKV2BufCreator__OpType_TopKV2__BUFFER__(); #endif -___OpenCLDepthwiseConvolutionCreator__OpType_ConvolutionDepthwise__IMAGE__(); -___OpenCLMatMulCreator__OpType_MatMul__IMAGE__(); -___OpenCLUnaryCreator__OpType_UnaryOp__IMAGE__(); -___OpenCLUnaryCreator__OpType_Sigmoid__IMAGE__(); -___OpenCLUnaryCreator__OpType_TanH__IMAGE__(); -___OpenCLScaleCreator__OpType_Scale__IMAGE__(); -___OpenCLSoftmaxCreator__OpType_Softmax__IMAGE__(); -___OpenCLEltwiseCreator__OpType_Eltwise__IMAGE__(); -___OpenCLEltwiseCreator__OpType_BinaryOp__IMAGE__(); -___OpenCLRangeCreator__OpType_Range__IMAGE__(); -___OpenCLRasterCreator__OpType_Raster__IMAGE__(); -___OpenCLFuseCreator__OpType_Extra__IMAGE__(); -___OpenCLLoopCreator__OpType_While__IMAGE__(); -___OpenCLTrainableParamCreator__OpType_TrainableParam__IMAGE__(); -___OpenCLReluCreator__OpType_ReLU__IMAGE__(); -___OpenCLReluCreator__OpType_PReLU__IMAGE__(); -___OpenCLReluCreator__OpType_ReLU6__IMAGE__(); -___OpenCLConvolutionCreator__OpType_Convolution__IMAGE__(); -___OpenCLLayerNormCreator__OpType_LayerNorm__IMAGE__(); -___OpenCLReductionCreator__OpType_Reduction__IMAGE__(); -___OpenCLRoiPoolingCreator__OpType_ROIPooling__IMAGE__(); -___OpenCLPoolCreator__OpType_Pooling__IMAGE__(); -___OpenCLSelectCreator__OpType_Select__IMAGE__(); -___OpenCLDeconvolutionCreator__OpType_Deconvolution__IMAGE__(); -___OpenCLDepthwiseDeconvolutionCreator__OpType_DeconvolutionDepthwise__IMAGE__(); -___OpenCLInterp3DCreator__OpType_Interp3D__IMAGE__(); -___OpenCLCastCreator__OpType_Cast__IMAGE__(); -___OpenCLInterpCreator__OpType_Interp__IMAGE__(); -___OpenCLGridSampleCreator__OpType_GridSample__IMAGE__(); -___OpenCLTopKV2Creator__OpType_TopKV2__IMAGE__(); + ___OpenCLDepthwiseConvolutionCreator__OpType_ConvolutionDepthwise__IMAGE__(); + ___OpenCLMatMulCreator__OpType_MatMul__IMAGE__(); + ___OpenCLUnaryCreator__OpType_UnaryOp__IMAGE__(); + ___OpenCLUnaryCreator__OpType_Sigmoid__IMAGE__(); + ___OpenCLUnaryCreator__OpType_TanH__IMAGE__(); + ___OpenCLScaleCreator__OpType_Scale__IMAGE__(); + ___OpenCLSoftmaxCreator__OpType_Softmax__IMAGE__(); + ___OpenCLEltwiseCreator__OpType_Eltwise__IMAGE__(); + ___OpenCLEltwiseCreator__OpType_BinaryOp__IMAGE__(); + ___OpenCLRangeCreator__OpType_Range__IMAGE__(); + ___OpenCLRasterCreator__OpType_Raster__IMAGE__(); + ___OpenCLFuseCreator__OpType_Extra__IMAGE__(); + ___OpenCLLoopCreator__OpType_While__IMAGE__(); + ___OpenCLTrainableParamCreator__OpType_TrainableParam__IMAGE__(); + ___OpenCLReluCreator__OpType_ReLU__IMAGE__(); + ___OpenCLReluCreator__OpType_PReLU__IMAGE__(); + ___OpenCLReluCreator__OpType_ReLU6__IMAGE__(); + ___OpenCLConvolutionCreator__OpType_Convolution__IMAGE__(); + ___OpenCLLayerNormCreator__OpType_LayerNorm__IMAGE__(); + ___OpenCLReductionCreator__OpType_Reduction__IMAGE__(); + ___OpenCLRoiPoolingCreator__OpType_ROIPooling__IMAGE__(); + ___OpenCLPoolCreator__OpType_Pooling__IMAGE__(); + ___OpenCLSelectCreator__OpType_Select__IMAGE__(); + ___OpenCLDeconvolutionCreator__OpType_Deconvolution__IMAGE__(); + ___OpenCLDepthwiseDeconvolutionCreator__OpType_DeconvolutionDepthwise__IMAGE__(); + ___OpenCLInterp3DCreator__OpType_Interp3D__IMAGE__(); + ___OpenCLCastCreator__OpType_Cast__IMAGE__(); + ___OpenCLInterpCreator__OpType_Interp__IMAGE__(); + ___OpenCLGridSampleCreator__OpType_GridSample__IMAGE__(); + ___OpenCLTopKV2Creator__OpType_TopKV2__IMAGE__(); #ifdef MNN_SUPPORT_TRANSFORMER_FUSE -___OpenCLSelfAttentionBufCreator__OpType_FmhaV2__BUFFER__(); -___OpenCLSplitGeluBufCreator__OpType_SplitGeLU__BUFFER__(); -___OpenCLGroupNormBufCreator__OpType_GroupNorm__BUFFER__(); -___OpenCLLinearAttentionBufCreator__OpType_LinearAttention__BUFFER__(); -___OpenCLAttentionBufCreator__OpType_Attention__BUFFER__(); + ___OpenCLSelfAttentionBufCreator__OpType_FmhaV2__BUFFER__(); + ___OpenCLSplitGeluBufCreator__OpType_SplitGeLU__BUFFER__(); + ___OpenCLGroupNormBufCreator__OpType_GroupNorm__BUFFER__(); + ___OpenCLLinearAttentionBufCreator__OpType_LinearAttention__BUFFER__(); + ___OpenCLAttentionBufCreator__OpType_Attention__BUFFER__(); + ___OpenCLRopeBufCreator__OpType_RoPE__BUFFER__(); #endif } -} -} +} // namespace OpenCL +} // namespace MNN #endif diff --git a/source/backend/opencl/execution/buffer/AttentionBufExecution.cpp b/source/backend/opencl/execution/buffer/AttentionBufExecution.cpp index 749618dc02..26fa8d20b1 100644 --- a/source/backend/opencl/execution/buffer/AttentionBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/AttentionBufExecution.cpp @@ -13,8 +13,8 @@ namespace MNN { namespace OpenCL { -KVCacheCLManager::KVCacheCLManager(Backend *backend, bool kv_cahce) : mKVCache(kv_cahce){ - mOpenCLBackend = static_cast(backend); +KVCacheCLManager::KVCacheCLManager(Backend* backend, bool kv_cahce) : mKVCache(kv_cahce) { + mOpenCLBackend = static_cast(backend); } void KVCacheCLManager::allocKVCache(const KVMeta* meta, int seqlen) { @@ -22,7 +22,7 @@ void KVCacheCLManager::allocKVCache(const KVMeta* meta, int seqlen) { return; } mPastLength = meta != nullptr ? meta->previous : 0; - if(mOpenCLBackend->getPrecision() != BackendConfig::Precision_High){ + if (mOpenCLBackend->getPrecision() != BackendConfig::Precision_High) { mByte = 2; } // Fully execute reallocKVCache (including Remove and mPastLength update) @@ -55,22 +55,26 @@ bool KVCacheCLManager::reallocKVCache(const KVMeta* meta, int seqlen, bool isExe size_t newMaxlen = ROUND_UP(mMaxLength, 4); size_t bufferSize = UP_DIV(mMaxLength, 4) * mKvNumHead * mHeadDim * 4 * mByte; // past_key: [1, numhead, headdim, maxlen] - auto newKey = new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, bufferSize); + auto newKey = new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), + CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, bufferSize); // past_value: [1, numhead, maxlen, headdim] - auto newValue = new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, bufferSize); + auto newValue = new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), + CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, bufferSize); - if(needCopy){ + if (needCopy) { // copy key { size_t oldMaxlenSize = oldMaxlen * mByte; size_t newMaxlenSize = newMaxlen * mByte; - char *newKeyPtr = (char*)mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*newKey, true, CL_MAP_WRITE, 0, bufferSize, nullptr, nullptr, &res); - char *keyPtr = (char*)mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*mPastKey.get(), true, CL_MAP_READ, 0, oldSize, nullptr, nullptr, &res); - if(newKeyPtr != nullptr && keyPtr != nullptr && res == CL_SUCCESS){ - for(int i = 0; i < mKvNumHead * mHeadDim; ++i){ + char* newKeyPtr = (char*)mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer( + *newKey, true, CL_MAP_WRITE, 0, bufferSize, nullptr, nullptr, &res); + char* keyPtr = (char*)mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer( + *mPastKey.get(), true, CL_MAP_READ, 0, oldSize, nullptr, nullptr, &res); + if (newKeyPtr != nullptr && keyPtr != nullptr && res == CL_SUCCESS) { + for (int i = 0; i < mKvNumHead * mHeadDim; ++i) { ::memcpy(newKeyPtr + i * newMaxlenSize, keyPtr + i * oldMaxlenSize, oldMaxlenSize); } - }else{ + } else { MNN_ERROR("Map error key_ptr == nullptr \n"); MNN_ASSERT(false); } @@ -80,15 +84,18 @@ bool KVCacheCLManager::reallocKVCache(const KVMeta* meta, int seqlen, bool isExe // copy value { - char *newValuePtr = (char*)mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*newValue, true, CL_MAP_WRITE, 0, bufferSize, nullptr, nullptr, &res); - char *valuePtr = (char*)mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*mPastValue.get(), true, CL_MAP_READ, 0, oldSize, nullptr, nullptr, &res); - if(newValuePtr != nullptr && valuePtr != nullptr && res == CL_SUCCESS){ - for(int i = 0; i < mKvNumHead; ++i){ - for(int j = 0; j < copylen; ++j){ - ::memcpy(newValuePtr + (i * newMaxlen + j) * mHeadDim * mByte, valuePtr + (i * oldMaxlen + j) * mHeadDim * mByte, mHeadDim * mByte); + char* newValuePtr = (char*)mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer( + *newValue, true, CL_MAP_WRITE, 0, bufferSize, nullptr, nullptr, &res); + char* valuePtr = (char*)mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer( + *mPastValue.get(), true, CL_MAP_READ, 0, oldSize, nullptr, nullptr, &res); + if (newValuePtr != nullptr && valuePtr != nullptr && res == CL_SUCCESS) { + for (int i = 0; i < mKvNumHead; ++i) { + for (int j = 0; j < copylen; ++j) { + ::memcpy(newValuePtr + (i * newMaxlen + j) * mHeadDim * mByte, + valuePtr + (i * oldMaxlen + j) * mHeadDim * mByte, mHeadDim * mByte); } } - }else{ + } else { MNN_ERROR("Map error value_ptr == nullptr \n"); MNN_ASSERT(false); } @@ -99,14 +106,14 @@ bool KVCacheCLManager::reallocKVCache(const KVMeta* meta, int seqlen, bool isExe mPastKey.reset(newKey); mPastValue.reset(newValue); // resize phase don't update mPastLength value, excute phase will update it - if(isExecute){ + if (isExecute) { mPastLength = start; } } // Remove // resize phase don't remove kvcache, excute phase will do it - if(isExecute){ + if (isExecute) { if (0 == meta->n_reserve) { mPastLength = start; return true; @@ -127,12 +134,14 @@ bool KVCacheCLManager::reallocKVCache(const KVMeta* meta, int seqlen, bool isExe auto copySrcIndex = start + begin; auto copyDstIndex = start; - for(int i = 0; i < mKvNumHead * mHeadDim; i++) { - ::memcpy(keyPtr + (i * mMaxLength + copyDstIndex) * mByte, keyPtr + (i * mMaxLength + copySrcIndex) * mByte, length * mByte); + for (int i = 0; i < mKvNumHead * mHeadDim; i++) { + ::memcpy(keyPtr + (i * mMaxLength + copyDstIndex) * mByte, + keyPtr + (i * mMaxLength + copySrcIndex) * mByte, length * mByte); } - for(int i = 0; i < mKvNumHead; i++) { - for(int j = 0; j < length; j++) { - ::memcpy(valuePtr + (i * mMaxLength + copyDstIndex + j) * mHeadDim * mByte, valuePtr + (i * mMaxLength + copySrcIndex + j) * mHeadDim * mByte, mHeadDim * mByte); + for (int i = 0; i < mKvNumHead; i++) { + for (int j = 0; j < length; j++) { + ::memcpy(valuePtr + (i * mMaxLength + copyDstIndex + j) * mHeadDim * mByte, + valuePtr + (i * mMaxLength + copySrcIndex + j) * mHeadDim * mByte, mHeadDim * mByte); } } start += length; @@ -144,8 +153,8 @@ bool KVCacheCLManager::reallocKVCache(const KVMeta* meta, int seqlen, bool isExe return true; } -void AttentionBufExecution::handleKVCache(const std::vector &inputs, const std::vector &outputs) { - if(mHasMask) { +void AttentionBufExecution::handleKVCache(const std::vector& inputs, const std::vector& outputs) { + if (mHasMask) { auto mask = inputs[3]; mIsAddMask = (mask->getType() == halide_type_of()); } @@ -156,30 +165,31 @@ void AttentionBufExecution::handleKVCache(const std::vector &inputs, c int batch = shape[0]; int seqlen = shape[1]; + int kvInputLen = key->shape()[1]; int numHead = shape[2]; int kvNumHead = key->shape()[2]; int headDim = shape[3]; - if(nullptr == mMeta) { + if (nullptr == mMeta) { mPastKvSeqlen = 0; - mKvSeqlen = seqlen; - mKeyValueMaxlen = ROUND_UP(seqlen, 4); - mDecodeTmpMaxlen = ROUND_UP(seqlen, 4); + mKvSeqlen = kvInputLen; + mKeyValueMaxlen = ROUND_UP(kvInputLen, 4); + mDecodeTmpMaxlen = ROUND_UP(kvInputLen, 4); return; } mKVCacheCLManager->setArgs(numHead, kvNumHead, headDim); - mKVCacheCLManager->allocKVCache(mMeta, seqlen); + mKVCacheCLManager->allocKVCache(mMeta, kvInputLen); mKeyValueMaxlen = ROUND_UP(mKVCacheCLManager->maxLength(), 4); mDecodeTmpMaxlen = mKeyValueMaxlen; mPastKvSeqlen = mKVCacheCLManager->pastKvLength(); - mKvSeqlen = mPastKvSeqlen + seqlen; + mKvSeqlen = mPastKvSeqlen + kvInputLen; } ErrorCode AttentionBufExecution::init() { - if(nullptr == mMeta) { + if (nullptr == mMeta) { return NO_ERROR; } - //clear update arg vector, if prefill and decode use the same one + // clear update arg vector, if prefill and decode use the same one mOpRecordUpdateInfo.clear(); mRgQUpdateInfo.update_kernel_args.clear(); mRgQUpdateInfo.update_global_size.clear(); @@ -206,85 +216,89 @@ ErrorCode AttentionBufExecution::init() { return NO_ERROR; } -ErrorCode AttentionBufExecution::UpdateArgs(const std::vector &inputs, const std::vector &outputs){ - if(nullptr == mMeta) { +ErrorCode AttentionBufExecution::UpdateArgs(const std::vector& inputs, const std::vector& outputs) { + if (nullptr == mMeta) { return NO_ERROR; } auto query = inputs[0]; auto key = inputs[1]; auto value = inputs[2]; - auto mask = inputs[3]; auto shape = query->shape(); int batch = shape[0]; int seqlen = shape[1]; + int kvInputLen = key->shape()[1]; int numHead = shape[2]; int kvNumHead = key->shape()[2]; int headDim = shape[3]; int group_size = numHead / kvNumHead; - float scale = 1.0 / sqrt(headDim); - auto mask_shape = mask->shape(); - int dim = mask->dimensions(); - MNN_ASSERT(dim >= 2); - int mask_seqlen = mask_shape[dim - 2]; - int maskKvlen = mask_shape[dim - 1]; + float scale = (mAttnScale == 0.0f) ? (1.0f / sqrt(headDim)) : mAttnScale; mPastKvSeqlen = mKVCacheCLManager->pastKvLength(); - mKvSeqlen = mKVCacheCLManager->pastKvLength() + seqlen; - mKVCacheCLManager->addKvLength(seqlen); + mKvSeqlen = mKVCacheCLManager->pastKvLength() + kvInputLen; + mKVCacheCLManager->addKvLength(kvInputLen); // prefill - if(mIsDecode == false){ + if (mIsDecode == false) { int maskKvlen = mKvSeqlen; int maskQlen = seqlen; - if(mHasMask) { + if (mHasMask) { auto mask = inputs[3]; auto mask_shape = mask->shape(); int dim = mask->dimensions(); MNN_ASSERT(dim >= 2); maskQlen = mask_shape[dim - 2]; - maskKvlen = mask_shape[dim - 1]; + maskKvlen = mask_shape[dim - 1]; } // key value static memory has been changed, need reset args - if(mKeyValueMaxlen != ROUND_UP(mKVCacheCLManager->maxLength(), 4)){ + if (mKeyValueMaxlen != ROUND_UP(mKVCacheCLManager->maxLength(), 4)) { mKeyValueMaxlen = ROUND_UP(mKVCacheCLManager->maxLength(), 4); } - if(false == mLongPrefill){ + if (false == mLongPrefill) { mGlobalWorkSizeQk0 = UP_DIV(mKvSeqlen, 4); mQkPrefillGlobal_size[1] = ROUND_UP(mGlobalWorkSizeQk0, std::max((uint32_t)1, mLocalWorkSizeQk[1])); mGlobalWorkSizeQk[1] = mQkPrefillGlobal_size[1]; mTempQ.reset(Tensor::createDevice({ROUND_UP(seqlen, 4) * ROUND_UP(headDim, 4) * batch * numHead})); mTempQK.reset(Tensor::createDevice({ROUND_UP(seqlen, 4) * mKvSeqlen * numHead * batch})); mTempSoftMax.reset(Tensor::createDevice({ROUND_UP(seqlen, 4) * mKvSeqlen * numHead * batch})); - if(mIsAddMask) { - mTempMask.reset(Tensor::createDevice({ROUND_UP(maskQlen, 4) * ROUND_UP(maskKvlen, 4) * batch})); - } else { - mTempMask.reset(Tensor::createDevice({ROUND_UP(maskQlen, 4) * ROUND_UP(maskKvlen, 4) * batch})); + if (mHasMask) { + if (mIsAddMask) { + mTempMask.reset( + Tensor::createDevice({ROUND_UP(maskQlen, 4) * ROUND_UP(maskKvlen, 4) * batch})); + } else { + mTempMask.reset( + Tensor::createDevice({ROUND_UP(maskQlen, 4) * ROUND_UP(maskKvlen, 4) * batch})); + } } mOpenCLBackend->onAcquireBuffer(mTempQ.get(), Backend::DYNAMIC_IN_EXECUTION); - mOpenCLBackend->onAcquireBuffer(mTempMask.get(), Backend::DYNAMIC_IN_EXECUTION); + if (mHasMask) { + mOpenCLBackend->onAcquireBuffer(mTempMask.get(), Backend::DYNAMIC_IN_EXECUTION); + } mOpenCLBackend->onAcquireBuffer(mTempQK.get(), Backend::DYNAMIC_IN_EXECUTION); mOpenCLBackend->onAcquireBuffer(mTempSoftMax.get(), Backend::DYNAMIC_IN_EXECUTION); mOpenCLBackend->onReleaseBuffer(mTempQ.get(), Backend::DYNAMIC_IN_EXECUTION); - mOpenCLBackend->onReleaseBuffer(mTempMask.get(), Backend::DYNAMIC_IN_EXECUTION); + if (mHasMask) { + mOpenCLBackend->onReleaseBuffer(mTempMask.get(), Backend::DYNAMIC_IN_EXECUTION); + } mOpenCLBackend->onReleaseBuffer(mTempQK.get(), Backend::DYNAMIC_IN_EXECUTION); mOpenCLBackend->onReleaseBuffer(mTempSoftMax.get(), Backend::DYNAMIC_IN_EXECUTION); } - #ifndef ENABLE_OPENCL_TIME_PROFILER - if(mOpenCLBackend->isUseRecordQueue()){ - if(mLongPrefill){ +#ifndef ENABLE_OPENCL_TIME_PROFILER + if (mOpenCLBackend->isUseRecordQueue()) { + if (mLongPrefill) { mRgUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->key()))(); mRgUpdateInfo.update_kernel_args[1].arg_value = &(*(mKVCacheCLManager->value()))(); - }else{ + } else { mRgQUpdateInfo.update_kernel_args[0].arg_value = &openCLDeferBuffer(mTempQ.get())(); mRgUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->key()))(); - mRgMUpdateInfo.update_kernel_args[0].arg_value = &openCLDeferBuffer(mTempMask.get())(); mQkUpdateInfo.update_kernel_args[1].arg_value = &openCLDeferBuffer(mTempQ.get())(); mQkUpdateInfo.update_kernel_args[2].arg_value = &(*(mKVCacheCLManager->key()))(); - if(mHasMask){ + if (mHasMask) { + mRgMUpdateInfo.update_kernel_args[0].arg_value = &openCLDeferBuffer(mTempMask.get())(); mQkUpdateInfo.update_kernel_args[3].arg_value = &openCLDeferBuffer(mTempMask.get())(); mQkUpdateInfo.update_kernel_args[4].arg_value = &openCLDeferBuffer(mTempQK.get())(); - }else{ + } else { mQkUpdateInfo.update_kernel_args[3].arg_value = &openCLDeferBuffer(mTempQK.get())(); + mQkUpdateInfo.update_kernel_args[4].arg_value = &openCLDeferBuffer(mTempQK.get())(); } mSoftMaxUpdateInfo.update_kernel_args[0].arg_value = &openCLDeferBuffer(mTempQK.get())(); mSoftMaxUpdateInfo.update_kernel_args[1].arg_value = &openCLDeferBuffer(mTempSoftMax.get())(); @@ -293,15 +307,15 @@ ErrorCode AttentionBufExecution::UpdateArgs(const std::vector &inputs, mQkvUpdateInfo.update_kernel_args[1].arg_value = &(*(mKVCacheCLManager->value()))(); } } else { - #endif - if(mLongPrefill){ +#endif + if (mLongPrefill) { // rearrange key value cl_int ret = CL_SUCCESS; ret |= mKernel_rearrange_vec[0]->get().setArg(9, *mKVCacheCLManager->key()); ret |= mKernel_rearrange_vec[0]->get().setArg(10, *mKVCacheCLManager->value()); ret |= mKernel_rearrange_vec[0]->get().setArg(14, mKeyValueMaxlen); MNN_CHECK_CL_SUCCESS(ret, "reSetArg rearrange_k"); - }else{ + } else { { // rearrange query cl_int ret = CL_SUCCESS; @@ -316,7 +330,7 @@ ErrorCode AttentionBufExecution::UpdateArgs(const std::vector &inputs, ret |= mKernel_rearrange->get().setArg(6, mKeyValueMaxlen); MNN_CHECK_CL_SUCCESS(ret, "reSetArg rearrange_k"); } - if(mHasMask){ + if (mHasMask) { // rearrange mask cl_int ret = CL_SUCCESS; ret |= mKernel_rearrangeMask->get().setArg(4, openCLDeferBuffer(mTempMask.get())); @@ -324,13 +338,17 @@ ErrorCode AttentionBufExecution::UpdateArgs(const std::vector &inputs, } { // matmul qk - mGlobalWorkSizeQk = {static_cast(UP_DIV(seqlen, 4)), static_cast(UP_DIV(mKvSeqlen, 4)), static_cast(numHead*batch)}; + mGlobalWorkSizeQk = {static_cast(UP_DIV(seqlen, 4)), + static_cast(UP_DIV(mKvSeqlen, 4)), + static_cast(numHead * batch)}; cl_int ret = CL_SUCCESS; ret |= mKernel_qk->get().setArg(1, mGlobalWorkSizeQk0); ret |= mKernel_qk->get().setArg(3, openCLDeferBuffer(mTempQ.get())); ret |= mKernel_qk->get().setArg(4, *mKVCacheCLManager->key()); - if(mHasMask) { + if (mHasMask) { ret |= mKernel_qk->get().setArg(5, openCLDeferBuffer(mTempMask.get())); + } else { + ret |= mKernel_qk->get().setArg(5, openCLDeferBuffer(mTempQK.get())); } ret |= mKernel_qk->get().setArg(6, openCLDeferBuffer(mTempQK.get())); ret |= mKernel_qk->get().setArg(10, mKvSeqlen); @@ -366,15 +384,15 @@ ErrorCode AttentionBufExecution::UpdateArgs(const std::vector &inputs, MNN_CHECK_CL_SUCCESS(ret, "reSetArg matmul_qkv_decode"); } } - #ifndef ENABLE_OPENCL_TIME_PROFILER +#ifndef ENABLE_OPENCL_TIME_PROFILER } - #endif +#endif return NO_ERROR; } // Decode mKeyValueMaxlen = ROUND_UP(mKVCacheCLManager->maxLength(), 4); - if(mKvSeqlen > mDecodeTmpMaxlen){ + if (mKvSeqlen > mDecodeTmpMaxlen) { mDecodeTmpMaxlen = mKeyValueMaxlen; mTempQK.reset(Tensor::createDevice({mDecodeTmpMaxlen * numHead})); mTempSoftMax.reset(Tensor::createDevice({mDecodeTmpMaxlen * numHead})); @@ -389,7 +407,7 @@ ErrorCode AttentionBufExecution::UpdateArgs(const std::vector &inputs, #ifndef ENABLE_OPENCL_TIME_PROFILER // use record, only update args - if(mOpenCLBackend->isUseRecordQueue()){ + if (mOpenCLBackend->isUseRecordQueue()) { mRgUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->key()))(); mQkUpdateInfo.update_kernel_args[1].arg_value = &(*(mKVCacheCLManager->key()))(); mQkUpdateInfo.update_kernel_args[2].arg_value = &openCLDeferBuffer(mTempQK.get())(); @@ -463,16 +481,16 @@ ErrorCode AttentionBufExecution::UpdateArgs(const std::vector &inputs, return NO_ERROR; } -int AttentionBufExecution::getLocalSize(int size, int maxGroupSize){ +int AttentionBufExecution::getLocalSize(int size, int maxGroupSize) { int local_size = 1; - while(local_size * 2 <= maxGroupSize && local_size * 2 <= size){ + while (local_size * 2 <= maxGroupSize && local_size * 2 <= size) { local_size *= 2; } return local_size; } -ErrorCode AttentionBufExecution::longPrefillResize(const std::vector &inputs, const std::vector &outputs){ - +ErrorCode AttentionBufExecution::longPrefillResize(const std::vector& inputs, + const std::vector& outputs) { auto query = inputs[0]; auto key = inputs[1]; auto value = inputs[2]; @@ -481,56 +499,89 @@ ErrorCode AttentionBufExecution::longPrefillResize(const std::vector & int batch = shape[0]; int seqlen = shape[1]; + int kvInputLen = key->shape()[1]; int numHead = shape[2]; int kvNumHead = key->shape()[2]; int headDim = shape[3]; int group_size = numHead / kvNumHead; - float scale = 1.0 / sqrt(headDim); + float scale = (mAttnScale == 0.0f) ? (1.0f / sqrt(headDim)) : mAttnScale; + int maskQlen = seqlen; + int maskKvlen = mKvSeqlen; + if (mHasMask) { + auto mask = inputs[3]; + auto maskShape = mask->shape(); + int dim = mask->dimensions(); + MNN_ASSERT(dim >= 2); + maskQlen = maskShape[dim - 2]; + maskKvlen = maskShape[dim - 1]; + } mAlignQ = 32; mAlignKV = 32; mAlignHDK = 4; mAlignHDN = 32; - float useMemorySize = 1.0 * ROUND_UP(seqlen, mAlignQ) / 1024.0 * ROUND_UP(seqlen, mAlignKV) / 1024.0 * batch * numHead; + float useMemorySize = + 1.0 * ROUND_UP(seqlen, mAlignQ) / 1024.0 * ROUND_UP(mKvSeqlen, mAlignKV) / 1024.0 * batch * numHead; // elementSize larger than 32M - if(useMemorySize > 32.0) { + if (useMemorySize > 32.0) { mQseqSplitNum = useMemorySize >= 256.0 ? 8 : ((useMemorySize < 128.0) ? 2 : 4); } // splitPiecesSize need aligned to 32, make sure XgemmBatched globalsize be divisible by localsize int splitPiecesSize = ROUND_UP(seqlen, mAlignQ) / mQseqSplitNum; - while((splitPiecesSize % 32) != 0){ + while ((splitPiecesSize % 32) != 0) { mAlignQ *= 2; splitPiecesSize = ROUND_UP(seqlen, mAlignQ) / mQseqSplitNum; } - mKernel_rearrange_vec.resize(1); mGwsRearrgVec.resize(1); mLwsRearrgVec.resize(1); - mKernel_mask_vec.resize(1); mGwsMaskVec.resize(1); mLwsMaskVec.resize(1); - mKernel_qk_vec.resize(mQseqSplitNum); mGwsQkVec.resize(mQseqSplitNum); mLwsQkVec.resize(mQseqSplitNum); - mKernel_softmax_vec.resize(mQseqSplitNum); mGwsSoftMaxVec.resize(mQseqSplitNum); mLwsSoftMaxVec.resize(mQseqSplitNum); - mKernel_trans_vec.resize(mQseqSplitNum); mGwsTransVec.resize(mQseqSplitNum); mLwsTransVec.resize(mQseqSplitNum); - mKernel_qkv_vec.resize(mQseqSplitNum); mGwsQkvVec.resize(mQseqSplitNum); mLwsQkvVec.resize(mQseqSplitNum); - mKernel_clip_vec.resize(1); mGwsClipVec.resize(1); mLwsClipVec.resize(1); - - mTempQ.reset(Tensor::createDevice({ROUND_UP(seqlen, mAlignQ) * ROUND_UP(headDim, mAlignHDK) * batch * numHead})); - mTempK.reset(Tensor::createDevice({ROUND_UP(seqlen, mAlignKV) * ROUND_UP(headDim, mAlignHDK) * batch * numHead})); - mTempV.reset(Tensor::createDevice({ROUND_UP(seqlen, mAlignKV) * ROUND_UP(headDim, mAlignHDN) * batch * numHead})); - if(mHasMask) { - if(mIsAddMask) { - mTempMask.reset(Tensor::createDevice({ROUND_UP(seqlen, mAlignQ) * ROUND_UP(seqlen, mAlignKV) * batch})); + mKernel_rearrange_vec.resize(1); + mGwsRearrgVec.resize(1); + mLwsRearrgVec.resize(1); + mKernel_mask_vec.resize(1); + mGwsMaskVec.resize(1); + mLwsMaskVec.resize(1); + mKernel_qk_vec.resize(mQseqSplitNum); + mGwsQkVec.resize(mQseqSplitNum); + mLwsQkVec.resize(mQseqSplitNum); + mKernel_softmax_vec.resize(mQseqSplitNum); + mGwsSoftMaxVec.resize(mQseqSplitNum); + mLwsSoftMaxVec.resize(mQseqSplitNum); + mKernel_trans_vec.resize(mQseqSplitNum); + mGwsTransVec.resize(mQseqSplitNum); + mLwsTransVec.resize(mQseqSplitNum); + mKernel_qkv_vec.resize(mQseqSplitNum); + mGwsQkvVec.resize(mQseqSplitNum); + mLwsQkvVec.resize(mQseqSplitNum); + mKernel_clip_vec.resize(1); + mGwsClipVec.resize(1); + mLwsClipVec.resize(1); + + mTempQ.reset( + Tensor::createDevice({ROUND_UP(seqlen, mAlignQ) * ROUND_UP(headDim, mAlignHDK) * batch * numHead})); + mTempK.reset(Tensor::createDevice( + {ROUND_UP(mKvSeqlen, mAlignKV) * ROUND_UP(headDim, mAlignHDK) * batch * kvNumHead})); + mTempV.reset(Tensor::createDevice( + {ROUND_UP(mKvSeqlen, mAlignKV) * ROUND_UP(headDim, mAlignHDN) * batch * kvNumHead})); + if (mHasMask) { + if (mIsAddMask) { + mTempMask.reset( + Tensor::createDevice({ROUND_UP(maskQlen, mAlignQ) * ROUND_UP(maskKvlen, mAlignKV) * batch})); } else { - mTempMask.reset(Tensor::createDevice({ROUND_UP(seqlen, mAlignQ) * ROUND_UP(seqlen, mAlignKV) * batch})); + mTempMask.reset( + Tensor::createDevice({ROUND_UP(maskQlen, mAlignQ) * ROUND_UP(maskKvlen, mAlignKV) * batch})); } } - mTempQK.reset(Tensor::createDevice({ROUND_UP(seqlen, mAlignQ) * ROUND_UP(seqlen, mAlignKV) * batch * numHead / mQseqSplitNum})); - mTempSoftMax.reset(Tensor::createDevice({ROUND_UP(seqlen, mAlignQ) * ROUND_UP(seqlen, mAlignKV) * batch * numHead / mQseqSplitNum})); - mTempQKV.reset(Tensor::createDevice({ROUND_UP(seqlen, mAlignQ) * ROUND_UP(headDim, mAlignHDN) * batch * numHead})); - + mTempQK.reset(Tensor::createDevice( + {ROUND_UP(seqlen, mAlignQ) * ROUND_UP(mKvSeqlen, mAlignKV) * batch * numHead / mQseqSplitNum})); + mTempSoftMax.reset(Tensor::createDevice( + {ROUND_UP(seqlen, mAlignQ) * ROUND_UP(mKvSeqlen, mAlignKV) * batch * numHead / mQseqSplitNum})); + mTempQKV.reset( + Tensor::createDevice({ROUND_UP(seqlen, mAlignQ) * ROUND_UP(headDim, mAlignHDN) * batch * numHead})); mOpenCLBackend->onAcquireBuffer(mTempQ.get(), Backend::DYNAMIC); mOpenCLBackend->onAcquireBuffer(mTempK.get(), Backend::DYNAMIC); mOpenCLBackend->onAcquireBuffer(mTempV.get(), Backend::DYNAMIC); - if(mHasMask) { + if (mHasMask) { mOpenCLBackend->onAcquireBuffer(mTempMask.get(), Backend::DYNAMIC); } mOpenCLBackend->onAcquireBuffer(mTempQK.get(), Backend::DYNAMIC); @@ -539,7 +590,7 @@ ErrorCode AttentionBufExecution::longPrefillResize(const std::vector & mOpenCLBackend->onReleaseBuffer(mTempQ.get(), Backend::DYNAMIC); mOpenCLBackend->onReleaseBuffer(mTempK.get(), Backend::DYNAMIC); - if(mHasMask) { + if (mHasMask) { mOpenCLBackend->onReleaseBuffer(mTempMask.get(), Backend::DYNAMIC); } mOpenCLBackend->onReleaseBuffer(mTempSoftMax.get(), Backend::DYNAMIC); @@ -547,31 +598,34 @@ ErrorCode AttentionBufExecution::longPrefillResize(const std::vector & mOpenCLBackend->onReleaseBuffer(mTempQK.get(), Backend::DYNAMIC); mOpenCLBackend->onReleaseBuffer(mTempQKV.get(), Backend::DYNAMIC); - // query: [batch, seqLenQ, headNum, headDim] -> mTempQ: [batch*headNum, ROUND_UP(headDim, mAlignHDK), ROUND_UP(seqLenQ, mAlignQ)] - // key: [batch, seqLenKV/4, headNum/group, headDim, seqLenKV_4] -> mTempK: [batch*headNum/group, ROUND_UP(headDim, mAlignHDK), ROUND_UP(seqLenKV, mAlignKV)] - // value: [batch, seqLenKV/4, headNum/group, headDim, seqLenKV_4] -> mTempV: [batch*headNum/group, ROUND_UP(seqLenKV, mAlignKV), ROUND_UP(headDim, mAlignHDK] - // key & value -> pastKey & pastValue (copy) + // query: [batch, seqLenQ, headNum, headDim] -> mTempQ: [batch*headNum, ROUND_UP(headDim, mAlignHDK), + // ROUND_UP(seqLenQ, mAlignQ)] key: [batch, seqLenKV/4, headNum/group, headDim, seqLenKV_4] -> mTempK: + // [batch*headNum/group, ROUND_UP(headDim, mAlignHDK), ROUND_UP(seqLenKV, mAlignKV)] value: [batch, seqLenKV/4, + // headNum/group, headDim, seqLenKV_4] -> mTempV: [batch*headNum/group, ROUND_UP(seqLenKV, mAlignKV), + // ROUND_UP(headDim, mAlignHDK] key & value -> pastKey & pastValue (copy) int seq_idx = 0; // rearrange qkv { std::set buildOption; - if((headDim % 4) != 0){ + if ((headDim % 4) != 0) { buildOption.emplace("-DHEADDIM_LEAVE"); } // generate cache for every option { auto option = buildOption; - auto kernel = runtime->buildKernel("attention_buf", "rearrange_qkv", option, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); + auto kernel = runtime->buildKernel("attention_buf", "rearrange_qkv", option, mOpenCLBackend->getPrecision(), + inputs[0], outputs[0]); } { auto option = buildOption; option.emplace("-DSEQLEN_LEAVE"); - auto kernel = runtime->buildKernel("attention_buf", "rearrange_qkv", option, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); + auto kernel = runtime->buildKernel("attention_buf", "rearrange_qkv", option, mOpenCLBackend->getPrecision(), + inputs[0], outputs[0]); } - if((seqlen % 4) != 0){ + if ((seqlen % 4) != 0 || (kvInputLen % 4) != 0) { buildOption.emplace("-DSEQLEN_LEAVE"); } - if(nullptr != mMeta) { + if (nullptr != mMeta) { buildOption.emplace("-DSAVE_KV"); } int seq_len_pack_q = ROUND_UP(seqlen, mAlignQ); @@ -581,14 +635,16 @@ ErrorCode AttentionBufExecution::longPrefillResize(const std::vector & int head_dim_pack_v = ROUND_UP(headDim, mAlignHDN); int tile[4] = {mAlignQ, mAlignKV, mAlignHDK, mAlignHDN}; - int shape[4] = {seqlen, mKvSeqlen, numHead, headDim}; + int shape[4] = {seqlen, kvInputLen, numHead, headDim}; int param[4] = {group_size, batch, 0, 0}; - mKernel_rearrange_vec[seq_idx] = runtime->buildKernel("attention_buf", "rearrange_qkv", buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); - auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_rearrange_vec[seq_idx])); + mKernel_rearrange_vec[seq_idx] = runtime->buildKernel("attention_buf", "rearrange_qkv", buildOption, + mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); + auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_rearrange_vec[seq_idx])); - mGwsRearrgVec[seq_idx] = {static_cast(ALIMAX(UP_DIV(seq_len_pack_q, 4), UP_DIV(seq_len_pack_kv, 4))), \ - static_cast(ALIMAX(UP_DIV(head_dim_pack_qk, 4), UP_DIV(head_dim_pack_v, 4))), \ - static_cast(batch*numHead)}; + mGwsRearrgVec[seq_idx] = { + static_cast(ALIMAX(UP_DIV(seq_len_pack_q, 4), UP_DIV(seq_len_pack_kv, 4))), + static_cast(ALIMAX(UP_DIV(head_dim_pack_qk, 4), UP_DIV(head_dim_pack_v, 4))), + static_cast(batch * numHead)}; uint32_t index = 0; cl_int ret = CL_SUCCESS; @@ -601,7 +657,7 @@ ErrorCode AttentionBufExecution::longPrefillResize(const std::vector & ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, openCLBuffer(mTempQ.get())); ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, openCLBuffer(mTempK.get())); ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, openCLBuffer(mTempV.get())); - if(nullptr != mMeta) { + if (nullptr != mMeta) { ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, *mKVCacheCLManager->key()); ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, *mKVCacheCLManager->value()); } @@ -611,34 +667,40 @@ ErrorCode AttentionBufExecution::longPrefillResize(const std::vector & ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, mKeyValueMaxlen); MNN_CHECK_CL_SUCCESS(ret, "setArg rearrange_qkv"); - mLwsRearrgVec[seq_idx] = localWS3DDefault(mGwsRearrgVec[seq_idx], maxWorkGroupSize, runtime, "rearrange_qkv", mKernel_rearrange_vec[seq_idx], mOpenCLBackend->getCLTuneLevel(), "attention_buf").first; - mGwsRearrgVec[seq_idx][0] = ROUND_UP(mGwsRearrgVec[seq_idx][0], std::max((uint32_t)1, mLwsRearrgVec[seq_idx][0])); - mGwsRearrgVec[seq_idx][1] = ROUND_UP(mGwsRearrgVec[seq_idx][1], std::max((uint32_t)1, mLwsRearrgVec[seq_idx][1])); - mGwsRearrgVec[seq_idx][2] = ROUND_UP(mGwsRearrgVec[seq_idx][2], std::max((uint32_t)1, mLwsRearrgVec[seq_idx][2])); - if(nullptr != mMeta) { + mLwsRearrgVec[seq_idx] = + localWS3DDefault(mGwsRearrgVec[seq_idx], maxWorkGroupSize, runtime, "rearrange_qkv", + mKernel_rearrange_vec[seq_idx], mOpenCLBackend->getCLTuneLevel(), "attention_buf") + .first; + mGwsRearrgVec[seq_idx][0] = + ROUND_UP(mGwsRearrgVec[seq_idx][0], std::max((uint32_t)1, mLwsRearrgVec[seq_idx][0])); + mGwsRearrgVec[seq_idx][1] = + ROUND_UP(mGwsRearrgVec[seq_idx][1], std::max((uint32_t)1, mLwsRearrgVec[seq_idx][1])); + mGwsRearrgVec[seq_idx][2] = + ROUND_UP(mGwsRearrgVec[seq_idx][2], std::max((uint32_t)1, mLwsRearrgVec[seq_idx][2])); + if (nullptr != mMeta) { mRgUpdateInfo.update_kernel_args.push_back({0, 9, sizeof(cl_mem), &(*(mKVCacheCLManager->key()))()}); mRgUpdateInfo.update_kernel_args.push_back({0, 10, sizeof(cl_mem), &(*(mKVCacheCLManager->value()))()}); } mRgUpdateInfo.update_kernel_args.push_back({0, 14, sizeof(mKeyValueMaxlen), &mKeyValueMaxlen}); mOpRecordUpdateInfo.emplace_back(&mRgUpdateInfo); - mOpenCLBackend->recordKernel3d(mKernel_rearrange_vec[seq_idx], mGwsRearrgVec[seq_idx], mLwsRearrgVec[seq_idx], &mRgUpdateInfo); + mOpenCLBackend->recordKernel3d(mKernel_rearrange_vec[seq_idx], mGwsRearrgVec[seq_idx], mLwsRearrgVec[seq_idx], + &mRgUpdateInfo); } // mask rearaange - if(mHasMask) - { + if (mHasMask) { std::set buildOption; - int seq_len_pack_q = ROUND_UP(seqlen, mAlignQ); - int seq_len_pack_kv = ROUND_UP(mKvSeqlen, mAlignKV); - int shape[4] = {seqlen, mKvSeqlen, mAlignQ, mAlignKV}; + int seq_len_pack_q = ROUND_UP(maskQlen, mAlignQ); + int seq_len_pack_kv = ROUND_UP(maskKvlen, mAlignKV); + int shape[4] = {seqlen, maskKvlen, mAlignQ, mAlignKV}; - mKernel_mask_vec[seq_idx] = runtime->buildKernel("attention_buf", "rearrange_mask", buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); - auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_mask_vec[seq_idx])); + mKernel_mask_vec[seq_idx] = runtime->buildKernel("attention_buf", "rearrange_mask", buildOption, + mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); + auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_mask_vec[seq_idx])); - mGwsMaskVec[seq_idx] = {static_cast(UP_DIV(seq_len_pack_q, 4)), \ - static_cast(UP_DIV(seq_len_pack_kv, 4)), \ - static_cast(batch)}; + mGwsMaskVec[seq_idx] = {static_cast(UP_DIV(seq_len_pack_q, 4)), + static_cast(UP_DIV(seq_len_pack_kv, 4)), static_cast(batch)}; uint32_t index = 0; cl_int ret = CL_SUCCESS; @@ -650,19 +712,23 @@ ErrorCode AttentionBufExecution::longPrefillResize(const std::vector & ret |= mKernel_mask_vec[seq_idx]->get().setArg(index++, shape); MNN_CHECK_CL_SUCCESS(ret, "setArg rearrange_mask"); - mLwsMaskVec[seq_idx] = localWS3DDefault(mGwsMaskVec[seq_idx], maxWorkGroupSize, runtime, "rearrange_mask", mKernel_mask_vec[seq_idx], mOpenCLBackend->getCLTuneLevel(), "attention_buf").first; + mLwsMaskVec[seq_idx] = + localWS3DDefault(mGwsMaskVec[seq_idx], maxWorkGroupSize, runtime, "rearrange_mask", + mKernel_mask_vec[seq_idx], mOpenCLBackend->getCLTuneLevel(), "attention_buf") + .first; mGwsMaskVec[seq_idx][0] = ROUND_UP(mGwsMaskVec[seq_idx][0], std::max((uint32_t)1, mLwsMaskVec[seq_idx][0])); mGwsMaskVec[seq_idx][1] = ROUND_UP(mGwsMaskVec[seq_idx][1], std::max((uint32_t)1, mLwsMaskVec[seq_idx][1])); mGwsMaskVec[seq_idx][2] = ROUND_UP(mGwsMaskVec[seq_idx][2], std::max((uint32_t)1, mLwsMaskVec[seq_idx][2])); mOpenCLBackend->recordKernel3d(mKernel_mask_vec[seq_idx], mGwsMaskVec[seq_idx], mLwsMaskVec[seq_idx]); } - for(int seq_idx = 0; seq_idx < mQseqSplitNum; seq_idx++) { + for (int seq_idx = 0; seq_idx < mQseqSplitNum; seq_idx++) { // qk matmul { - // Q : [batch*headNum, ROUND_UP(headDim, mAlignHDK), ROUND_UP(seqLenQ, mAlignQ) / mQseqSplitNum] -> [B, K, M] - // K : [batch*headNum/group, ROUND_UP(headDim, mAlignHDK), ROUND_UP(seqLenKV, mAlignKV)] -> [B, K, N] - // QV: [Batch * numHead, ROUND_UP(seqLenQ, mAlignQ) / mQseqSplitNum, ROUND_UP(seqLenKV, mAlignKV)] -> [B, M, N] + // Q : [batch*headNum, ROUND_UP(headDim, mAlignHDK), ROUND_UP(seqLenQ, mAlignQ) / mQseqSplitNum] -> [B, K, + // M] K : [batch*headNum/group, ROUND_UP(headDim, mAlignHDK), ROUND_UP(seqLenKV, mAlignKV)] -> [B, K, N] QV: + // [Batch * numHead, ROUND_UP(seqLenQ, mAlignQ) / mQseqSplitNum, ROUND_UP(seqLenKV, mAlignKV)] -> [B, M, + // N] int loop = batch * numHead; int e_pack = ROUND_UP(seqlen, mAlignQ); int e_pack_piece = e_pack / mQseqSplitNum; @@ -672,19 +738,27 @@ ErrorCode AttentionBufExecution::longPrefillResize(const std::vector & std::set buildOptions; int biasType = 0; - std::vector bufferVec = {openCLBuffer(mTempQ.get()), openCLBuffer(mTempK.get()), openCLBuffer(mTempQK.get())}; - if(mHasMask) { + std::vector bufferVec = {openCLBuffer(mTempQ.get()), openCLBuffer(mTempK.get()), + openCLBuffer(mTempQK.get())}; + if (mHasMask) { bufferVec.emplace_back(openCLBuffer(mTempMask.get())); + } else { + bufferVec.emplace_back(openCLBuffer(mTempQK.get())); } - if(mIsAddMask) { + if (mIsAddMask) { biasType = 2; - } else if(mHasMask) { - biasType = 5;// int value mask + } else if (mHasMask) { + biasType = 5; // int value mask } uint32_t layout = 14; // 10 means mix-precision, 4 means layout - auto param = getGemmParams({(uint32_t)e_pack_piece, (uint32_t)h_pack, (uint32_t)l_pack, layout, (uint32_t)loop, (uint32_t)(biasType + 10*(group_size-1))}, bufferVec, mOpenCLBackend->getOpenCLRuntime(), mOpenCLBackend->getPrecision(), mOpenCLBackend->getCLTuneLevel()); - - int KWG=param[0], KWI=param[1], MDIMA=param[2], MDIMC=param[3], MWG=param[4], NDIMB=param[5], NDIMC=param[6], NWG=param[7], SA=param[8], SB=param[9], STRM=param[10], STRN=param[11], VWM=param[12], VWN=param[13]; + auto param = getGemmParams({(uint32_t)e_pack_piece, (uint32_t)h_pack, (uint32_t)l_pack, layout, + (uint32_t)loop, (uint32_t)(biasType + 10 * (group_size - 1))}, + bufferVec, mOpenCLBackend->getOpenCLRuntime(), mOpenCLBackend->getPrecision(), + mOpenCLBackend->getCLTuneLevel()); + + int KWG = param[0], KWI = param[1], MDIMA = param[2], MDIMC = param[3], MWG = param[4], NDIMB = param[5], + NDIMC = param[6], NWG = param[7], SA = param[8], SB = param[9], STRM = param[10], STRN = param[11], + VWM = param[12], VWN = param[13]; buildOptions.emplace("-DKWG=" + std::to_string(KWG)); buildOptions.emplace("-DKWI=" + std::to_string(KWI)); buildOptions.emplace("-DMDIMA=" + std::to_string(MDIMA)); @@ -699,7 +773,7 @@ ErrorCode AttentionBufExecution::longPrefillResize(const std::vector & buildOptions.emplace("-DSTRN=" + std::to_string(STRN)); buildOptions.emplace("-DVWM=" + std::to_string(VWM)); buildOptions.emplace("-DVWN=" + std::to_string(VWN)); - if(layout >= 4) { + if (layout >= 4) { buildOptions.emplace("-DOUTPUTMN"); } @@ -708,12 +782,12 @@ ErrorCode AttentionBufExecution::longPrefillResize(const std::vector & int localM = MDIMC; int localN = NDIMC; - if(mOpenCLBackend->getOpenCLRuntime()->getGpuType() == GpuType::ADRENO) { + if (mOpenCLBackend->getOpenCLRuntime()->getGpuType() == GpuType::ADRENO) { buildOptions.emplace("-DUSE_CL_MAD=1"); buildOptions.emplace("-DRELAX_WORKGROUP_SIZE=1"); } buildOptions.emplace("-DONLY_HAVE_ALPHA"); - if(biasType >= 1) { + if (biasType >= 1) { buildOptions.emplace("-DBIAS_TYPE=" + std::to_string(biasType)); } @@ -723,12 +797,14 @@ ErrorCode AttentionBufExecution::longPrefillResize(const std::vector & buildOptions.emplace("-DPRECISION_COMPUTE8=float8 -DCONVERT_PRECISION_COMPUTE8=convert_float8"); buildOptions.emplace("-DPRECISION_COMPUTE16=float16 -DCONVERT_PRECISION_COMPUTE16=convert_float16"); - mKernel_qk_vec[seq_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("matmul_params_buf", "XgemmBatched", buildOptions, mOpenCLBackend->getPrecision()); + mKernel_qk_vec[seq_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel( + "matmul_params_buf", "XgemmBatched", buildOptions, mOpenCLBackend->getPrecision()); int out_per_thread_m = tileM / localM; int out_per_thread_n = tileN / localN; - mGwsQkVec[seq_idx] = {static_cast(e_pack_piece/out_per_thread_m), static_cast(h_pack/out_per_thread_n), static_cast(loop)}; + mGwsQkVec[seq_idx] = {static_cast(e_pack_piece / out_per_thread_m), + static_cast(h_pack / out_per_thread_n), static_cast(loop)}; mLwsQkVec[seq_idx] = {static_cast(localM), static_cast(localN), 1}; float alpha = scale; @@ -742,7 +818,7 @@ ErrorCode AttentionBufExecution::longPrefillResize(const std::vector & int stride[4] = {e_pack, h_pack, h_pack, h_pack}; int group[4] = {1, group_size, 1, loop}; - int idx = 0; + int idx = 0; cl_int ret = CL_SUCCESS; ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, static_cast(e_pack_piece)); ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, static_cast(h_pack)); @@ -751,8 +827,10 @@ ErrorCode AttentionBufExecution::longPrefillResize(const std::vector & ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, beta); ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, openCLBuffer(mTempQ.get())); ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, openCLBuffer(mTempK.get())); - if(mHasMask) { + if (mHasMask) { ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, openCLBuffer(mTempMask.get())); + } else { + ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, openCLBuffer(mTempQK.get())); } ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, openCLBuffer(mTempQK.get())); ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, batch_offset); @@ -769,18 +847,21 @@ ErrorCode AttentionBufExecution::longPrefillResize(const std::vector & // Sotmax: [Batch * numHead, ROUND_UP(seqLenQ, mAlignQ) / mQseqSplitNum, ROUND_UP(seqLenKV, mAlignKV)] // axis : 2 (last dim) int softmaxShape[4]; - softmaxShape[0] = batch*numHead; + softmaxShape[0] = batch * numHead; softmaxShape[1] = ROUND_UP(seqlen, mAlignQ) / mQseqSplitNum; softmaxShape[2] = ROUND_UP(mKvSeqlen, mAlignKV); - auto MaxLocalSize = std::min(std::min(runtime->getMaxWorkItemSizes()[0], mMaxWorkGroupSize), static_cast(256)); + auto MaxLocalSize = + std::min(std::min(runtime->getMaxWorkItemSizes()[0], mMaxWorkGroupSize), static_cast(256)); int localSize = 64; std::set buildOption; buildOption.emplace("-DSOFTMAX_LOCAL_SIZE=" + std::to_string(localSize)); - mKernel_softmax_vec[seq_idx] = runtime->buildKernel("self_attention_buf", "softmax_inside", buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); - mGwsSoftMaxVec[seq_idx] = {static_cast(localSize), static_cast(softmaxShape[1]), static_cast(softmaxShape[0])}; + mKernel_softmax_vec[seq_idx] = runtime->buildKernel("self_attention_buf", "softmax_inside", buildOption, + mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); + mGwsSoftMaxVec[seq_idx] = {static_cast(localSize), static_cast(softmaxShape[1]), + static_cast(softmaxShape[0])}; uint32_t index = 0; cl_int ret = CL_SUCCESS; @@ -794,7 +875,8 @@ ErrorCode AttentionBufExecution::longPrefillResize(const std::vector & MNN_CHECK_CL_SUCCESS(ret, "setArg Attention softmax"); mLwsSoftMaxVec[seq_idx] = {static_cast(localSize), 1, 1}; - mOpenCLBackend->recordKernel3d(mKernel_softmax_vec[seq_idx], mGwsSoftMaxVec[seq_idx], mLwsSoftMaxVec[seq_idx]); + mOpenCLBackend->recordKernel3d(mKernel_softmax_vec[seq_idx], mGwsSoftMaxVec[seq_idx], + mLwsSoftMaxVec[seq_idx]); } { // Sotmax: [Batch * numHead, ROUND_UP(seqLenQ, mAlignQ) / mQseqSplitNum, ROUND_UP(seqLenKV, mAlignKV)] @@ -804,10 +886,12 @@ ErrorCode AttentionBufExecution::longPrefillResize(const std::vector & int transDimH = ROUND_UP(mKvSeqlen, mAlignKV); std::set buildOptions; - mKernel_trans_vec[seq_idx] = runtime->buildKernel("self_attention_buf", "trans_3d_buf", buildOptions, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); - uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(mKernel_trans_vec[seq_idx])); + mKernel_trans_vec[seq_idx] = runtime->buildKernel("self_attention_buf", "trans_3d_buf", buildOptions, + mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); + uint32_t maxWorkGroupSize = static_cast( + mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(mKernel_trans_vec[seq_idx])); - mGwsTransVec[seq_idx] = {(uint32_t)transDimW/8, (uint32_t)transDimH/8, (uint32_t)(loop)}; + mGwsTransVec[seq_idx] = {(uint32_t)transDimW / 8, (uint32_t)transDimH / 8, (uint32_t)(loop)}; uint32_t index = 0; cl_int ret = CL_SUCCESS; @@ -820,20 +904,28 @@ ErrorCode AttentionBufExecution::longPrefillResize(const std::vector & ret |= mKernel_trans_vec[seq_idx]->get().setArg(index++, transDimW); ret |= mKernel_trans_vec[seq_idx]->get().setArg(index++, transDimH); MNN_CHECK_CL_SUCCESS(ret, "setArg Attention transpose"); - mLwsTransVec[seq_idx] = localWS3DDefault(mGwsTransVec[seq_idx], maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), "trans_3d_buf", mKernel_trans_vec[seq_idx], mOpenCLBackend->getCLTuneLevel(), "self_attention_buf").first; - - mGwsTransVec[seq_idx][0] = ROUND_UP(mGwsTransVec[seq_idx][0], std::max((uint32_t)1, mLwsTransVec[seq_idx][0])); - mGwsTransVec[seq_idx][1] = ROUND_UP(mGwsTransVec[seq_idx][1], std::max((uint32_t)1, mLwsTransVec[seq_idx][1])); - mGwsTransVec[seq_idx][2] = ROUND_UP(mGwsTransVec[seq_idx][2], std::max((uint32_t)1, mLwsTransVec[seq_idx][2])); + mLwsTransVec[seq_idx] = + localWS3DDefault(mGwsTransVec[seq_idx], maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), + "trans_3d_buf", mKernel_trans_vec[seq_idx], mOpenCLBackend->getCLTuneLevel(), + "self_attention_buf") + .first; + + mGwsTransVec[seq_idx][0] = + ROUND_UP(mGwsTransVec[seq_idx][0], std::max((uint32_t)1, mLwsTransVec[seq_idx][0])); + mGwsTransVec[seq_idx][1] = + ROUND_UP(mGwsTransVec[seq_idx][1], std::max((uint32_t)1, mLwsTransVec[seq_idx][1])); + mGwsTransVec[seq_idx][2] = + ROUND_UP(mGwsTransVec[seq_idx][2], std::max((uint32_t)1, mLwsTransVec[seq_idx][2])); mOpenCLBackend->recordKernel3d(mKernel_trans_vec[seq_idx], mGwsTransVec[seq_idx], mLwsTransVec[seq_idx]); } // qk * value { - // Trans: [Batch * numHead, ROUND_UP(seqLenKV, mAlignKV), ROUND_UP(seqLenQ, mAlignQ) / mQseqSplitNum] -> [B, K, M] - // V : [Batch * numHead / group, ROUND_UP(seqLenKV, mAlignKV), ROUND_UP(headDim, mAlignHDN)] -> [B, K, N] - // QKV : [Batch * numHead, ROUND_UP(headDim, mAlignHDN), ROUND_UP(seqLenQ, mAlignQ) / mQseqSplitNum] -> [B, N, M] + // Trans: [Batch * numHead, ROUND_UP(seqLenKV, mAlignKV), ROUND_UP(seqLenQ, mAlignQ) / mQseqSplitNum] -> + // [B, K, M] V : [Batch * numHead / group, ROUND_UP(seqLenKV, mAlignKV), ROUND_UP(headDim, mAlignHDN)] + // -> [B, K, N] QKV : [Batch * numHead, ROUND_UP(headDim, mAlignHDN), ROUND_UP(seqLenQ, mAlignQ) / + // mQseqSplitNum] -> [B, N, M] int loop = batch * numHead; int e_pack = ROUND_UP(seqlen, mAlignQ); @@ -844,9 +936,14 @@ ErrorCode AttentionBufExecution::longPrefillResize(const std::vector & std::set buildOptions; uint32_t layout = 0; - auto param = getGemmParams({(uint32_t)e_pack_piece, (uint32_t)h_pack, (uint32_t)l_pack, layout, (uint32_t)loop, (uint32_t)0}, {openCLBuffer(mTempQK.get()), openCLBuffer(mTempV.get()), openCLBuffer(mTempQKV.get())}, mOpenCLBackend->getOpenCLRuntime(), mOpenCLBackend->getPrecision(), mOpenCLBackend->getCLTuneLevel()); - - int KWG=param[0], KWI=param[1], MDIMA=param[2], MDIMC=param[3], MWG=param[4], NDIMB=param[5], NDIMC=param[6], NWG=param[7], SA=param[8], SB=param[9], STRM=param[10], STRN=param[11], VWM=param[12], VWN=param[13]; + auto param = getGemmParams( + {(uint32_t)e_pack_piece, (uint32_t)h_pack, (uint32_t)l_pack, layout, (uint32_t)loop, (uint32_t)0}, + {openCLBuffer(mTempQK.get()), openCLBuffer(mTempV.get()), openCLBuffer(mTempQKV.get())}, + mOpenCLBackend->getOpenCLRuntime(), mOpenCLBackend->getPrecision(), mOpenCLBackend->getCLTuneLevel()); + + int KWG = param[0], KWI = param[1], MDIMA = param[2], MDIMC = param[3], MWG = param[4], NDIMB = param[5], + NDIMC = param[6], NWG = param[7], SA = param[8], SB = param[9], STRM = param[10], STRN = param[11], + VWM = param[12], VWN = param[13]; buildOptions.emplace("-DKWG=" + std::to_string(KWG)); buildOptions.emplace("-DKWI=" + std::to_string(KWI)); buildOptions.emplace("-DMDIMA=" + std::to_string(MDIMA)); @@ -861,7 +958,7 @@ ErrorCode AttentionBufExecution::longPrefillResize(const std::vector & buildOptions.emplace("-DSTRN=" + std::to_string(STRN)); buildOptions.emplace("-DVWM=" + std::to_string(VWM)); buildOptions.emplace("-DVWN=" + std::to_string(VWN)); - if(layout >= 4) { + if (layout >= 4) { buildOptions.emplace("-DOUTPUTMN"); } @@ -870,17 +967,19 @@ ErrorCode AttentionBufExecution::longPrefillResize(const std::vector & int localM = MDIMC; int localN = NDIMC; - if(mOpenCLBackend->getOpenCLRuntime()->getGpuType() == GpuType::ADRENO) { + if (mOpenCLBackend->getOpenCLRuntime()->getGpuType() == GpuType::ADRENO) { buildOptions.emplace("-DUSE_CL_MAD=1"); buildOptions.emplace("-DRELAX_WORKGROUP_SIZE=1"); } - mKernel_qkv_vec[seq_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("matmul_params_buf", "XgemmBatched", buildOptions, mOpenCLBackend->getPrecision()); + mKernel_qkv_vec[seq_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel( + "matmul_params_buf", "XgemmBatched", buildOptions, mOpenCLBackend->getPrecision()); int out_per_thread_m = tileM / localM; int out_per_thread_n = tileN / localN; - mGwsQkvVec[seq_idx] = {static_cast(e_pack_piece/out_per_thread_m), static_cast(h_pack/out_per_thread_n), static_cast(loop)}; + mGwsQkvVec[seq_idx] = {static_cast(e_pack_piece / out_per_thread_m), + static_cast(h_pack / out_per_thread_n), static_cast(loop)}; mLwsQkvVec[seq_idx] = {static_cast(localM), static_cast(localN), 1}; float alpha = 1.0f; @@ -893,7 +992,7 @@ ErrorCode AttentionBufExecution::longPrefillResize(const std::vector & int stride[4] = {e_pack_piece, h_pack, e_pack, h_pack}; int group[4] = {1, group_size, 1, loop}; - int idx = 0; + int idx = 0; cl_int ret = CL_SUCCESS; ret |= mKernel_qkv_vec[seq_idx]->get().setArg(idx++, static_cast(e_pack_piece)); ret |= mKernel_qkv_vec[seq_idx]->get().setArg(idx++, static_cast(h_pack)); @@ -919,10 +1018,12 @@ ErrorCode AttentionBufExecution::longPrefillResize(const std::vector & // output: [batch, seqLenQ/4, headNum, headDim, seqLenQ_4] std::set buildOption; - mKernel_clip_vec[seq_idx] = runtime->buildKernel("attention_buf", "qkv_transpose_output", buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); - auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_clip_vec[seq_idx])); + mKernel_clip_vec[seq_idx] = runtime->buildKernel("attention_buf", "qkv_transpose_output", buildOption, + mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); + auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_clip_vec[seq_idx])); - mGwsClipVec[seq_idx] = {static_cast(UP_DIV(seqlen, 4)), static_cast(UP_DIV(headDim, 4)), static_cast(batch*numHead)}; + mGwsClipVec[seq_idx] = {static_cast(UP_DIV(seqlen, 4)), static_cast(UP_DIV(headDim, 4)), + static_cast(batch * numHead)}; uint32_t index = 0; cl_int ret = CL_SUCCESS; @@ -937,7 +1038,10 @@ ErrorCode AttentionBufExecution::longPrefillResize(const std::vector & ret |= mKernel_clip_vec[seq_idx]->get().setArg(index++, numHead); ret |= mKernel_clip_vec[seq_idx]->get().setArg(index++, headDim); - mLwsClipVec[seq_idx] = localWS3DDefault(mGwsClipVec[seq_idx], maxWorkGroupSize, runtime, "qkv_transpose_output", mKernel_clip_vec[seq_idx], mOpenCLBackend->getCLTuneLevel(), "attention_buf").first; + mLwsClipVec[seq_idx] = + localWS3DDefault(mGwsClipVec[seq_idx], maxWorkGroupSize, runtime, "qkv_transpose_output", + mKernel_clip_vec[seq_idx], mOpenCLBackend->getCLTuneLevel(), "attention_buf") + .first; mGwsClipVec[seq_idx][0] = ROUND_UP(mGwsClipVec[seq_idx][0], std::max((uint32_t)1, mLwsClipVec[seq_idx][0])); mGwsClipVec[seq_idx][1] = ROUND_UP(mGwsClipVec[seq_idx][1], std::max((uint32_t)1, mLwsClipVec[seq_idx][1])); mGwsClipVec[seq_idx][2] = ROUND_UP(mGwsClipVec[seq_idx][2], std::max((uint32_t)1, mLwsClipVec[seq_idx][2])); @@ -950,8 +1054,8 @@ ErrorCode AttentionBufExecution::longPrefillResize(const std::vector & return NO_ERROR; } -ErrorCode AttentionBufExecution::prefillResize(const std::vector &inputs, const std::vector &outputs){ - +ErrorCode AttentionBufExecution::prefillResize(const std::vector& inputs, + const std::vector& outputs) { auto runtime = mOpenCLBackend->getOpenCLRuntime(); auto query = inputs[0]; auto key = inputs[1]; @@ -960,23 +1064,24 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector &inpu int batch = shape[0]; int seqlen = shape[1]; + int kvInputLen = key->shape()[1]; int numHead = shape[2]; int kvNumHead = key->shape()[2]; int headDim = shape[3]; int groupSize = numHead / kvNumHead; - float scale = 1.0 / sqrt(headDim); + float scale = (mAttnScale == 0.0f) ? (1.0f / sqrt(headDim)) : mAttnScale; int maskKvlen = mKvSeqlen; int maskQlen = seqlen; - if(mHasMask) { + if (mHasMask) { auto mask = inputs[3]; auto mask_shape = mask->shape(); int dim = mask->dimensions(); MNN_ASSERT(dim >= 2); maskQlen = mask_shape[dim - 2]; - maskKvlen = mask_shape[dim - 1]; - if(mIsAddMask) { + maskKvlen = mask_shape[dim - 1]; + if (mIsAddMask) { mTempMask.reset(Tensor::createDevice({ROUND_UP(maskQlen, 4) * ROUND_UP(maskKvlen, 4) * batch})); } else { mTempMask.reset(Tensor::createDevice({ROUND_UP(maskQlen, 4) * ROUND_UP(maskKvlen, 4) * batch})); @@ -990,17 +1095,17 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector &inpu mOpenCLBackend->onAcquireBuffer(mTempQK.get(), Backend::DYNAMIC_IN_EXECUTION); mOpenCLBackend->onAcquireBuffer(mTempSoftMax.get(), Backend::DYNAMIC_IN_EXECUTION); mOpenCLBackend->onAcquireBuffer(mTempQ.get(), Backend::DYNAMIC_IN_EXECUTION); - if(mHasMask){ + if (mHasMask) { mOpenCLBackend->onAcquireBuffer(mTempMask.get(), Backend::DYNAMIC_IN_EXECUTION); } cl::Buffer keyBuffer, valueBuffer; - if(nullptr != mMeta) { + if (nullptr != mMeta) { keyBuffer = *mKVCacheCLManager->key(); valueBuffer = *mKVCacheCLManager->value(); } else { - mTempK.reset(Tensor::createDevice({ROUND_UP(seqlen, 4) * ROUND_UP(headDim, 4) * numHead * batch})); - mTempV.reset(Tensor::createDevice({ROUND_UP(seqlen, 4) * ROUND_UP(headDim, 4) * numHead * batch})); + mTempK.reset(Tensor::createDevice({ROUND_UP(kvInputLen, 4) * ROUND_UP(headDim, 4) * kvNumHead * batch})); + mTempV.reset(Tensor::createDevice({ROUND_UP(kvInputLen, 4) * ROUND_UP(headDim, 4) * kvNumHead * batch})); mOpenCLBackend->onAcquireBuffer(mTempK.get(), Backend::DYNAMIC); mOpenCLBackend->onAcquireBuffer(mTempV.get(), Backend::DYNAMIC); mOpenCLBackend->onReleaseBuffer(mTempV.get(), Backend::DYNAMIC); @@ -1011,7 +1116,7 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector &inpu mOpenCLBackend->onReleaseBuffer(mTempQ.get(), Backend::DYNAMIC_IN_EXECUTION); mOpenCLBackend->onReleaseBuffer(mTempQK.get(), Backend::DYNAMIC_IN_EXECUTION); mOpenCLBackend->onReleaseBuffer(mTempSoftMax.get(), Backend::DYNAMIC_IN_EXECUTION); - if(mHasMask){ + if (mHasMask) { mOpenCLBackend->onReleaseBuffer(mTempMask.get(), Backend::DYNAMIC_IN_EXECUTION); } @@ -1019,12 +1124,12 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector &inpu // rearrange query std::set buildOption; - mKernel_rearrangeQ = runtime->buildKernel("attention_buf", "rearrange_q", buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); - auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_rearrangeQ)); + mKernel_rearrangeQ = runtime->buildKernel("attention_buf", "rearrange_q", buildOption, + mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); + auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_rearrangeQ)); - mGlobalWorkSizeRearrgQ = {static_cast(UP_DIV(seqlen, 4)), \ - static_cast(UP_DIV(headDim, 4)), \ - static_cast(numHead*batch)}; + mGlobalWorkSizeRearrgQ = {static_cast(UP_DIV(seqlen, 4)), static_cast(UP_DIV(headDim, 4)), + static_cast(numHead * batch)}; uint32_t index = 0; cl_int ret = CL_SUCCESS; @@ -1038,25 +1143,31 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector &inpu ret |= mKernel_rearrangeQ->get().setArg(index++, numHead); MNN_CHECK_CL_SUCCESS(ret, "setArg rearrange_q"); - mLocalWorkSizeRearrgQ = localWS3DDefault(mGlobalWorkSizeRearrgQ, maxWorkGroupSize, runtime, "rearrange_q", mKernel_rearrangeQ, mOpenCLBackend->getCLTuneLevel(), "attention_buf").first; - mGlobalWorkSizeRearrgQ[0] = ROUND_UP(mGlobalWorkSizeRearrgQ[0], std::max((uint32_t)1, mLocalWorkSizeRearrgQ[0])); - mGlobalWorkSizeRearrgQ[1] = ROUND_UP(mGlobalWorkSizeRearrgQ[1], std::max((uint32_t)1, mLocalWorkSizeRearrgQ[1])); - mGlobalWorkSizeRearrgQ[2] = ROUND_UP(mGlobalWorkSizeRearrgQ[2], std::max((uint32_t)1, mLocalWorkSizeRearrgQ[2])); + mLocalWorkSizeRearrgQ = localWS3DDefault(mGlobalWorkSizeRearrgQ, maxWorkGroupSize, runtime, "rearrange_q", + mKernel_rearrangeQ, mOpenCLBackend->getCLTuneLevel(), "attention_buf") + .first; + mGlobalWorkSizeRearrgQ[0] = + ROUND_UP(mGlobalWorkSizeRearrgQ[0], std::max((uint32_t)1, mLocalWorkSizeRearrgQ[0])); + mGlobalWorkSizeRearrgQ[1] = + ROUND_UP(mGlobalWorkSizeRearrgQ[1], std::max((uint32_t)1, mLocalWorkSizeRearrgQ[1])); + mGlobalWorkSizeRearrgQ[2] = + ROUND_UP(mGlobalWorkSizeRearrgQ[2], std::max((uint32_t)1, mLocalWorkSizeRearrgQ[2])); mRgQUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &openCLDeferBuffer(mTempQ.get())()}); mOpRecordUpdateInfo.emplace_back(&mRgQUpdateInfo); - mOpenCLBackend->recordKernel3d(mKernel_rearrangeQ, mGlobalWorkSizeRearrgQ, mLocalWorkSizeRearrgQ, &mRgQUpdateInfo); + mOpenCLBackend->recordKernel3d(mKernel_rearrangeQ, mGlobalWorkSizeRearrgQ, mLocalWorkSizeRearrgQ, + &mRgQUpdateInfo); } { // rearrange key std::set buildOption; buildOption.emplace("-DOPENCL_PREFILL_ATTENTION"); - mKernel_rearrange = runtime->buildKernel("attention_buf", "rearrange_k", buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); - auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_rearrange)); + mKernel_rearrange = runtime->buildKernel("attention_buf", "rearrange_k", buildOption, + mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); + auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_rearrange)); - mGlobalWorkSizeRearrg = {static_cast(UP_DIV(seqlen, 4)), \ - static_cast(UP_DIV(headDim, 4)), \ - static_cast(kvNumHead * batch)}; + mGlobalWorkSizeRearrg = {static_cast(UP_DIV(kvInputLen, 4)), + static_cast(UP_DIV(headDim, 4)), static_cast(kvNumHead * batch)}; uint32_t index = 0; cl_int ret = CL_SUCCESS; @@ -1067,17 +1178,19 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector &inpu ret |= mKernel_rearrange->get().setArg(index++, keyBuffer); ret |= mKernel_rearrange->get().setArg(index++, mPastKvSeqlen); ret |= mKernel_rearrange->get().setArg(index++, mKeyValueMaxlen); - ret |= mKernel_rearrange->get().setArg(index++, seqlen); + ret |= mKernel_rearrange->get().setArg(index++, kvInputLen); ret |= mKernel_rearrange->get().setArg(index++, kvNumHead); ret |= mKernel_rearrange->get().setArg(index++, numHead); ret |= mKernel_rearrange->get().setArg(index++, headDim); MNN_CHECK_CL_SUCCESS(ret, "setArg rearrange_k"); - mLocalWorkSizeRearrg = localWS3DDefault(mGlobalWorkSizeRearrg, maxWorkGroupSize, runtime, "rearrange_k", mKernel_rearrange, mOpenCLBackend->getCLTuneLevel(), "attention_buf").first; + mLocalWorkSizeRearrg = localWS3DDefault(mGlobalWorkSizeRearrg, maxWorkGroupSize, runtime, "rearrange_k", + mKernel_rearrange, mOpenCLBackend->getCLTuneLevel(), "attention_buf") + .first; mGlobalWorkSizeRearrg[0] = ROUND_UP(mGlobalWorkSizeRearrg[0], std::max((uint32_t)1, mLocalWorkSizeRearrg[0])); mGlobalWorkSizeRearrg[1] = ROUND_UP(mGlobalWorkSizeRearrg[1], std::max((uint32_t)1, mLocalWorkSizeRearrg[1])); mGlobalWorkSizeRearrg[2] = ROUND_UP(mGlobalWorkSizeRearrg[2], std::max((uint32_t)1, mLocalWorkSizeRearrg[2])); - if(nullptr != mMeta) { + if (nullptr != mMeta) { mRgUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &(*(mKVCacheCLManager->key()))()}); } mRgUpdateInfo.update_kernel_args.push_back({0, 5, sizeof(mPastKvSeqlen), &mPastKvSeqlen}); @@ -1085,16 +1198,20 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector &inpu mOpRecordUpdateInfo.emplace_back(&mRgUpdateInfo); mOpenCLBackend->recordKernel3d(mKernel_rearrange, mGlobalWorkSizeRearrg, mLocalWorkSizeRearrg, &mRgUpdateInfo); } - if (mHasMask){ + if (mHasMask) { std::set buildOption; - if(mIsAddMask){ + if (mIsAddMask) { buildOption.emplace("-DADD_MASK"); - } else if(mHasMask) { + } else if (mHasMask) { buildOption.emplace("-DSET_MASK"); + } else { + buildOption.emplace("-DDEFAULT_MASK"); } - mKernel_rearrangeMask = runtime->buildKernel("attention_buf", "rearrange_mask_shortprefill", buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); - mGlobalWorkSizeRearrgM = {static_cast(UP_DIV(maskQlen, 4)), static_cast(UP_DIV(maskKvlen, 4)), static_cast(batch)}; - auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_rearrangeMask)); + mKernel_rearrangeMask = runtime->buildKernel("attention_buf", "rearrange_mask_shortprefill", buildOption, + mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); + mGlobalWorkSizeRearrgM = {static_cast(UP_DIV(maskQlen, 4)), + static_cast(UP_DIV(maskKvlen, 4)), static_cast(batch)}; + auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_rearrangeMask)); uint32_t index = 0; cl_int ret = CL_SUCCESS; ret |= mKernel_rearrangeMask->get().setArg(index++, mGlobalWorkSizeRearrgM[0]); @@ -1105,26 +1222,37 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector &inpu ret |= mKernel_rearrangeMask->get().setArg(index++, maskQlen); ret |= mKernel_rearrangeMask->get().setArg(index++, maskKvlen); MNN_CHECK_CL_SUCCESS(ret, "setArg rearrange_mask_shortprefill"); - mLocalWorkSizeRearrgM = localWS3DDefault(mGlobalWorkSizeRearrgM, maxWorkGroupSize, runtime, "rearrange_mask_shortprefill", mKernel_rearrangeMask, mOpenCLBackend->getCLTuneLevel(), "attention_buf").first; - mGlobalWorkSizeRearrgM[0] = ROUND_UP(mGlobalWorkSizeRearrgM[0], std::max((uint32_t)1, mLocalWorkSizeRearrgM[0])); - mGlobalWorkSizeRearrgM[1] = ROUND_UP(mGlobalWorkSizeRearrgM[1], std::max((uint32_t)1, mLocalWorkSizeRearrgM[1])); - mGlobalWorkSizeRearrgM[2] = ROUND_UP(mGlobalWorkSizeRearrgM[2], std::max((uint32_t)1, mLocalWorkSizeRearrgM[2])); + mLocalWorkSizeRearrgM = + localWS3DDefault(mGlobalWorkSizeRearrgM, maxWorkGroupSize, runtime, "rearrange_mask_shortprefill", + mKernel_rearrangeMask, mOpenCLBackend->getCLTuneLevel(), "attention_buf") + .first; + mGlobalWorkSizeRearrgM[0] = + ROUND_UP(mGlobalWorkSizeRearrgM[0], std::max((uint32_t)1, mLocalWorkSizeRearrgM[0])); + mGlobalWorkSizeRearrgM[1] = + ROUND_UP(mGlobalWorkSizeRearrgM[1], std::max((uint32_t)1, mLocalWorkSizeRearrgM[1])); + mGlobalWorkSizeRearrgM[2] = + ROUND_UP(mGlobalWorkSizeRearrgM[2], std::max((uint32_t)1, mLocalWorkSizeRearrgM[2])); mRgMUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &openCLDeferBuffer(mTempMask.get())()}); mOpRecordUpdateInfo.emplace_back(&mRgMUpdateInfo); - mOpenCLBackend->recordKernel3d(mKernel_rearrangeMask, mGlobalWorkSizeRearrgM, mLocalWorkSizeRearrgM, &mRgMUpdateInfo); + mOpenCLBackend->recordKernel3d(mKernel_rearrangeMask, mGlobalWorkSizeRearrgM, mLocalWorkSizeRearrgM, + &mRgMUpdateInfo); } { // matmul qk std::set buildOption; - if(mIsAddMask){ + if (mIsAddMask) { buildOption.emplace("-DADD_MASK"); - } else if(mHasMask) { + } else if (mHasMask) { buildOption.emplace("-DSET_MASK"); + } else { + buildOption.emplace("-DDEFAULT_MASK"); } buildOption.emplace("-DNUMHEAD_GROUP_SIZE=" + std::to_string(groupSize)); - mKernel_qk = runtime->buildKernel("attention_buf", "matmul_qk_div_mask_prefill", buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); - mGlobalWorkSizeQk = {static_cast(UP_DIV(seqlen, 4)), static_cast(UP_DIV(mKvSeqlen, 4)), static_cast(numHead*batch)}; - auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_qk)); + mKernel_qk = runtime->buildKernel("attention_buf", "matmul_qk_div_mask_prefill", buildOption, + mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); + mGlobalWorkSizeQk = {static_cast(UP_DIV(seqlen, 4)), static_cast(UP_DIV(mKvSeqlen, 4)), + static_cast(numHead * batch)}; + auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_qk)); uint32_t index = 0; cl_int ret = CL_SUCCESS; @@ -1133,8 +1261,10 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector &inpu ret |= mKernel_qk->get().setArg(index++, mGlobalWorkSizeQk[2]); ret |= mKernel_qk->get().setArg(index++, openCLDeferBuffer(mTempQ.get())); ret |= mKernel_qk->get().setArg(index++, keyBuffer); - if(mHasMask) { + if (mHasMask) { ret |= mKernel_qk->get().setArg(index++, openCLDeferBuffer(mTempMask.get())); + } else { + ret |= mKernel_qk->get().setArg(index++, openCLDeferBuffer(mTempQK.get())); } ret |= mKernel_qk->get().setArg(index++, openCLDeferBuffer(mTempQK.get())); ret |= mKernel_qk->get().setArg(index++, scale); @@ -1146,24 +1276,27 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector &inpu ret |= mKernel_qk->get().setArg(index++, headDim); MNN_CHECK_CL_SUCCESS(ret, "setArg matmul_qk_div_mask_prefill"); - mLocalWorkSizeQk = localWS3DDefault(mGlobalWorkSizeQk, maxWorkGroupSize, runtime, "matmul_qk_div_mask_prefill", mKernel_qk, mOpenCLBackend->getCLTuneLevel(), "attention_buf").first; + mLocalWorkSizeQk = localWS3DDefault(mGlobalWorkSizeQk, maxWorkGroupSize, runtime, "matmul_qk_div_mask_prefill", + mKernel_qk, mOpenCLBackend->getCLTuneLevel(), "attention_buf") + .first; mGlobalWorkSizeQk[0] = ROUND_UP(mGlobalWorkSizeQk[0], std::max((uint32_t)1, mLocalWorkSizeQk[0])); mGlobalWorkSizeQk[1] = ROUND_UP(mGlobalWorkSizeQk[1], std::max((uint32_t)1, mLocalWorkSizeQk[1])); mGlobalWorkSizeQk[2] = ROUND_UP(mGlobalWorkSizeQk[2], std::max((uint32_t)1, mLocalWorkSizeQk[2])); mQkUpdateInfo.update_kernel_args.push_back({0, 1, sizeof(mGlobalWorkSizeQk0), &mGlobalWorkSizeQk0}); mQkUpdateInfo.update_kernel_args.push_back({0, 3, sizeof(cl_mem), &openCLDeferBuffer(mTempQ.get())()}); - if(nullptr != mMeta) { + if (nullptr != mMeta) { mQkUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &(*(mKVCacheCLManager->key()))()}); } - if(mHasMask){ + if (mHasMask) { mQkUpdateInfo.update_kernel_args.push_back({0, 5, sizeof(cl_mem), &openCLDeferBuffer(mTempMask.get())()}); mQkUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(cl_mem), &openCLDeferBuffer(mTempQK.get())()}); mQkUpdateInfo.update_kernel_args.push_back({0, 10, sizeof(mKvSeqlen), &mKvSeqlen}); mQkUpdateInfo.update_kernel_args.push_back({0, 11, sizeof(mKeyValueMaxlen), &mKeyValueMaxlen}); - }else{ + } else { mQkUpdateInfo.update_kernel_args.push_back({0, 5, sizeof(cl_mem), &openCLDeferBuffer(mTempQK.get())()}); - mQkUpdateInfo.update_kernel_args.push_back({0, 9, sizeof(mKvSeqlen), &mKvSeqlen}); - mQkUpdateInfo.update_kernel_args.push_back({0, 10, sizeof(mKeyValueMaxlen), &mKeyValueMaxlen}); + mQkUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(cl_mem), &openCLDeferBuffer(mTempQK.get())()}); + mQkUpdateInfo.update_kernel_args.push_back({0, 10, sizeof(mKvSeqlen), &mKvSeqlen}); + mQkUpdateInfo.update_kernel_args.push_back({0, 11, sizeof(mKeyValueMaxlen), &mKeyValueMaxlen}); } mQkPrefillGlobal_size[0] = mGlobalWorkSizeQk[0]; mQkPrefillGlobal_size[1] = mGlobalWorkSizeQk[1]; @@ -1174,15 +1307,17 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector &inpu } { // softmax - int inside = ROUND_UP(seqlen, 4); + int inside = ROUND_UP(seqlen, 4); int outside = numHead * batch; int localSize = 64; std::set buildOption; buildOption.emplace("-DSOFTMAX_LOCAL_SIZE=" + std::to_string(localSize)); - mKernel_softmax = runtime->buildKernel("softmax_buf", "softmax_v4_buf", buildOption, mOpenCLBackend->getPrecision()); - mGlobalWorkSizeSoftMax = {static_cast(localSize), static_cast(UP_DIV(inside, 4)), static_cast(outside)}; - auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_softmax)); + mKernel_softmax = + runtime->buildKernel("softmax_buf", "softmax_v4_buf", buildOption, mOpenCLBackend->getPrecision()); + mGlobalWorkSizeSoftMax = {static_cast(localSize), static_cast(UP_DIV(inside, 4)), + static_cast(outside)}; + auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_softmax)); uint32_t index = 0; cl_int ret = CL_SUCCESS; @@ -1197,29 +1332,37 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector &inpu MNN_CHECK_CL_SUCCESS(ret, "setArg softmax"); mLocalWorkSizeSoftMax = {static_cast(localSize), 1, 1}; - if(localSize == 1){ - mLocalWorkSizeSoftMax = localWS3DDefault(mGlobalWorkSizeSoftMax, maxWorkGroupSize, runtime, "softmax", mKernel_softmax, mOpenCLBackend->getCLTuneLevel(), "softmax_buf").first; + if (localSize == 1) { + mLocalWorkSizeSoftMax = localWS3DDefault(mGlobalWorkSizeSoftMax, maxWorkGroupSize, runtime, "softmax", + mKernel_softmax, mOpenCLBackend->getCLTuneLevel(), "softmax_buf") + .first; } - mGlobalWorkSizeSoftMax[0] = ROUND_UP(mGlobalWorkSizeSoftMax[0], std::max((uint32_t)1, mLocalWorkSizeSoftMax[0])); - mGlobalWorkSizeSoftMax[1] = ROUND_UP(mGlobalWorkSizeSoftMax[1], std::max((uint32_t)1, mLocalWorkSizeSoftMax[1])); - mGlobalWorkSizeSoftMax[2] = ROUND_UP(mGlobalWorkSizeSoftMax[2], std::max((uint32_t)1, mLocalWorkSizeSoftMax[2])); + mGlobalWorkSizeSoftMax[0] = + ROUND_UP(mGlobalWorkSizeSoftMax[0], std::max((uint32_t)1, mLocalWorkSizeSoftMax[0])); + mGlobalWorkSizeSoftMax[1] = + ROUND_UP(mGlobalWorkSizeSoftMax[1], std::max((uint32_t)1, mLocalWorkSizeSoftMax[1])); + mGlobalWorkSizeSoftMax[2] = + ROUND_UP(mGlobalWorkSizeSoftMax[2], std::max((uint32_t)1, mLocalWorkSizeSoftMax[2])); mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 3, sizeof(cl_mem), &openCLDeferBuffer(mTempQK.get())()}); - mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &openCLDeferBuffer(mTempSoftMax.get())()}); + mSoftMaxUpdateInfo.update_kernel_args.push_back( + {0, 4, sizeof(cl_mem), &openCLDeferBuffer(mTempSoftMax.get())()}); mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 7, sizeof(mKvSeqlen), &mKvSeqlen}); mOpRecordUpdateInfo.emplace_back(&mSoftMaxUpdateInfo); - mOpenCLBackend->recordKernel3d(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, &mSoftMaxUpdateInfo); + mOpenCLBackend->recordKernel3d(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, + &mSoftMaxUpdateInfo); } { // rearrange value std::set buildOption; buildOption.emplace("-DOPENCL_PREFILL_ATTENTION"); - mKernel_rearrangeV = runtime->buildKernel("attention_buf", "rearrange_v", buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); - auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_rearrangeV)); + mKernel_rearrangeV = runtime->buildKernel("attention_buf", "rearrange_v", buildOption, + mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); + auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_rearrangeV)); - mGlobalWorkSizeRearrgV = {static_cast(UP_DIV(headDim, 4)), \ - static_cast(UP_DIV(seqlen, 4)), \ - static_cast(kvNumHead * batch)}; + mGlobalWorkSizeRearrgV = {static_cast(UP_DIV(headDim, 4)), + static_cast(UP_DIV(kvInputLen, 4)), + static_cast(kvNumHead * batch)}; uint32_t index = 0; cl_int ret = CL_SUCCESS; @@ -1230,30 +1373,41 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector &inpu ret |= mKernel_rearrangeV->get().setArg(index++, valueBuffer); ret |= mKernel_rearrangeV->get().setArg(index++, mPastKvSeqlen); ret |= mKernel_rearrangeV->get().setArg(index++, mKeyValueMaxlen); - ret |= mKernel_rearrangeV->get().setArg(index++, seqlen); + ret |= mKernel_rearrangeV->get().setArg(index++, kvInputLen); ret |= mKernel_rearrangeV->get().setArg(index++, kvNumHead); ret |= mKernel_rearrangeV->get().setArg(index++, headDim); MNN_CHECK_CL_SUCCESS(ret, "setArg rearrange_v"); - mLocalWorkSizeRearrgV = localWS3DDefault(mGlobalWorkSizeRearrgV, maxWorkGroupSize, runtime, "rearrange_v", mKernel_rearrangeV, mOpenCLBackend->getCLTuneLevel(), "attention_buf").first; - mGlobalWorkSizeRearrgV[0] = ROUND_UP(mGlobalWorkSizeRearrgV[0], std::max((uint32_t)1, mLocalWorkSizeRearrgV[0])); - mGlobalWorkSizeRearrgV[1] = ROUND_UP(mGlobalWorkSizeRearrgV[1], std::max((uint32_t)1, mLocalWorkSizeRearrgV[1])); - mGlobalWorkSizeRearrgV[2] = ROUND_UP(mGlobalWorkSizeRearrgV[2], std::max((uint32_t)1, mLocalWorkSizeRearrgV[2])); - if(nullptr != mMeta) { + mLocalWorkSizeRearrgV = localWS3DDefault(mGlobalWorkSizeRearrgV, maxWorkGroupSize, runtime, "rearrange_v", + mKernel_rearrangeV, mOpenCLBackend->getCLTuneLevel(), "attention_buf") + .first; + mGlobalWorkSizeRearrgV[0] = + ROUND_UP(mGlobalWorkSizeRearrgV[0], std::max((uint32_t)1, mLocalWorkSizeRearrgV[0])); + mGlobalWorkSizeRearrgV[1] = + ROUND_UP(mGlobalWorkSizeRearrgV[1], std::max((uint32_t)1, mLocalWorkSizeRearrgV[1])); + mGlobalWorkSizeRearrgV[2] = + ROUND_UP(mGlobalWorkSizeRearrgV[2], std::max((uint32_t)1, mLocalWorkSizeRearrgV[2])); + if (nullptr != mMeta) { mRgVUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &(*(mKVCacheCLManager->value()))()}); } mRgVUpdateInfo.update_kernel_args.push_back({0, 5, sizeof(mPastKvSeqlen), &mPastKvSeqlen}); mRgVUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(mKeyValueMaxlen), &mKeyValueMaxlen}); mOpRecordUpdateInfo.emplace_back(&mRgVUpdateInfo); - mOpenCLBackend->recordKernel3d(mKernel_rearrangeV, mGlobalWorkSizeRearrgV, mLocalWorkSizeRearrgV, &mRgVUpdateInfo); + mOpenCLBackend->recordKernel3d(mKernel_rearrangeV, mGlobalWorkSizeRearrgV, mLocalWorkSizeRearrgV, + &mRgVUpdateInfo); } // qk * value { std::set buildOption; buildOption.emplace("-DNUMHEAD_GROUP_SIZE=" + std::to_string(groupSize)); - mKernel_qkv = runtime->buildKernel("attention_buf", "matmul_qkv_prefill", buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); - auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_qkv)); - mGlobalWorkSizeQkv = {static_cast(UP_DIV(headDim, 8)), static_cast(UP_DIV(seqlen, 4)), static_cast(numHead*batch)}; + if (mOutputC4) { + buildOption.emplace("-DATTENTION_C4"); + } + mKernel_qkv = runtime->buildKernel("attention_buf", "matmul_qkv_prefill", buildOption, + mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); + auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_qkv)); + mGlobalWorkSizeQkv = {static_cast(UP_DIV(headDim, 8)), static_cast(UP_DIV(seqlen, 4)), + static_cast(numHead * batch)}; uint32_t index = 0; cl_int ret = CL_SUCCESS; @@ -1269,14 +1423,17 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector &inpu ret |= mKernel_qkv->get().setArg(index++, numHead); ret |= mKernel_qkv->get().setArg(index++, kvNumHead); ret |= mKernel_qkv->get().setArg(index++, headDim); + ret |= mKernel_qkv->get().setArg(index++, batch); MNN_CHECK_CL_SUCCESS(ret, "setArg matmul_qkv_prefill"); - mLocalWorkSizeQkv = localWS3DDefault(mGlobalWorkSizeQkv, maxWorkGroupSize, runtime, "matmul_qkv_prefill", mKernel_qkv, mOpenCLBackend->getCLTuneLevel(), "attention_buf").first; + mLocalWorkSizeQkv = localWS3DDefault(mGlobalWorkSizeQkv, maxWorkGroupSize, runtime, "matmul_qkv_prefill", + mKernel_qkv, mOpenCLBackend->getCLTuneLevel(), "attention_buf") + .first; mGlobalWorkSizeQkv[0] = ROUND_UP(mGlobalWorkSizeQkv[0], std::max((uint32_t)1, mLocalWorkSizeQkv[0])); mGlobalWorkSizeQkv[1] = ROUND_UP(mGlobalWorkSizeQkv[1], std::max((uint32_t)1, mLocalWorkSizeQkv[1])); mGlobalWorkSizeQkv[2] = ROUND_UP(mGlobalWorkSizeQkv[2], std::max((uint32_t)1, mLocalWorkSizeQkv[2])); mQkvUpdateInfo.update_kernel_args.push_back({0, 3, sizeof(cl_mem), &openCLDeferBuffer(mTempSoftMax.get())()}); - if(nullptr != mMeta) { + if (nullptr != mMeta) { mQkvUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &(*(mKVCacheCLManager->value()))()}); } mQkvUpdateInfo.update_kernel_args.push_back({0, 7, sizeof(mKvSeqlen), &mKvSeqlen}); @@ -1289,8 +1446,7 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector &inpu return NO_ERROR; } -ErrorCode AttentionBufExecution::decodeResize(const std::vector &inputs, const std::vector &outputs){ - +ErrorCode AttentionBufExecution::decodeResize(const std::vector& inputs, const std::vector& outputs) { auto runtime = mOpenCLBackend->getOpenCLRuntime(); auto query = inputs[0]; auto key = inputs[1]; @@ -1303,11 +1459,10 @@ ErrorCode AttentionBufExecution::decodeResize(const std::vector &input int kvNumHead = key->shape()[2]; int headDim = shape[3]; int group_size = numHead / kvNumHead; - float scale = 1.0 / sqrt(headDim); - + float scale = (mAttnScale == 0.0f) ? (1.0f / sqrt(headDim)) : mAttnScale; cl::Buffer keyBuffer, valueBuffer; - if(nullptr != mMeta) { + if (nullptr != mMeta) { keyBuffer = *mKVCacheCLManager->key(); valueBuffer = *mKVCacheCLManager->value(); } else { @@ -1331,12 +1486,12 @@ ErrorCode AttentionBufExecution::decodeResize(const std::vector &input // rearrange key std::set buildOption; - mKernel_rearrange = runtime->buildKernel("attention_buf", "rearrange_k", buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); - auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_rearrange)); + mKernel_rearrange = runtime->buildKernel("attention_buf", "rearrange_k", buildOption, + mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); + auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_rearrange)); - mGlobalWorkSizeRearrg = {static_cast(1), \ - static_cast(UP_DIV(headDim, 4)), \ - static_cast(kvNumHead * batch)}; + mGlobalWorkSizeRearrg = {static_cast(1), static_cast(UP_DIV(headDim, 4)), + static_cast(kvNumHead * batch)}; uint32_t index = 0; cl_int ret = CL_SUCCESS; @@ -1353,16 +1508,19 @@ ErrorCode AttentionBufExecution::decodeResize(const std::vector &input ret |= mKernel_rearrange->get().setArg(index++, headDim); MNN_CHECK_CL_SUCCESS(ret, "setArg rearrange_k"); - mLocalWorkSizeRearrg = localWS3DDefault(mGlobalWorkSizeRearrg, maxWorkGroupSize, runtime, "rearrange_k", mKernel_rearrange, mOpenCLBackend->getCLTuneLevel(), "attention_buf").first; + mLocalWorkSizeRearrg = localWS3DDefault(mGlobalWorkSizeRearrg, maxWorkGroupSize, runtime, "rearrange_k", + mKernel_rearrange, mOpenCLBackend->getCLTuneLevel(), "attention_buf") + .first; mGlobalWorkSizeRearrg[0] = ROUND_UP(mGlobalWorkSizeRearrg[0], std::max((uint32_t)1, mLocalWorkSizeRearrg[0])); mGlobalWorkSizeRearrg[1] = ROUND_UP(mGlobalWorkSizeRearrg[1], std::max((uint32_t)1, mLocalWorkSizeRearrg[1])); mGlobalWorkSizeRearrg[2] = ROUND_UP(mGlobalWorkSizeRearrg[2], std::max((uint32_t)1, mLocalWorkSizeRearrg[2])); - if(nullptr != mMeta) { + if (nullptr != mMeta) { mRgUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &(*(mKVCacheCLManager->key()))()}); mRgUpdateInfo.update_kernel_args.push_back({0, 5, sizeof(mPastKvSeqlen), &mPastKvSeqlen}); mRgUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(mKeyValueMaxlen), &mKeyValueMaxlen}); mOpRecordUpdateInfo.emplace_back(&mRgUpdateInfo); - mOpenCLBackend->recordKernel3d(mKernel_rearrange, mGlobalWorkSizeRearrg, mLocalWorkSizeRearrg, &mRgUpdateInfo); + mOpenCLBackend->recordKernel3d(mKernel_rearrange, mGlobalWorkSizeRearrg, mLocalWorkSizeRearrg, + &mRgUpdateInfo); } else { mOpenCLBackend->recordKernel3d(mKernel_rearrange, mGlobalWorkSizeRearrg, mLocalWorkSizeRearrg); } @@ -1371,9 +1529,10 @@ ErrorCode AttentionBufExecution::decodeResize(const std::vector &input // matmul qk std::set buildOption; buildOption.emplace("-DNUMHEAD_GROUP_SIZE=" + std::to_string(group_size)); - mKernel_qk = runtime->buildKernel("attention_buf", "matmul_qk_decode", buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); - mGlobalWorkSizeQk = {static_cast(UP_DIV(mKvSeqlen, 4)), static_cast(numHead)}; - auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_qk)); + mKernel_qk = runtime->buildKernel("attention_buf", "matmul_qk_decode", buildOption, + mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); + mGlobalWorkSizeQk = {static_cast(UP_DIV(mKvSeqlen, 4)), static_cast(numHead)}; + auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_qk)); uint32_t index = 0; cl_int ret = CL_SUCCESS; @@ -1389,10 +1548,12 @@ ErrorCode AttentionBufExecution::decodeResize(const std::vector &input ret |= mKernel_qk->get().setArg(index++, headDim); MNN_CHECK_CL_SUCCESS(ret, "setArg matmul_qk_decode"); - mLocalWorkSizeQk = localWS2DDefault(mGlobalWorkSizeQk, maxWorkGroupSize, runtime, "matmul_qk_decode", mKernel_qk, mOpenCLBackend->getCLTuneLevel(), "attention_buf").first; + mLocalWorkSizeQk = localWS2DDefault(mGlobalWorkSizeQk, maxWorkGroupSize, runtime, "matmul_qk_decode", + mKernel_qk, mOpenCLBackend->getCLTuneLevel(), "attention_buf") + .first; mGlobalWorkSizeQk[0] = ROUND_UP(mGlobalWorkSizeQk[0], std::max((uint32_t)1, mLocalWorkSizeQk[0])); mGlobalWorkSizeQk[1] = ROUND_UP(mGlobalWorkSizeQk[1], std::max((uint32_t)1, mLocalWorkSizeQk[1])); - if(nullptr != mMeta) { + if (nullptr != mMeta) { mQkUpdateInfo.update_kernel_args.push_back({0, 0, sizeof(mGlobalWorkSizeQk0), &mGlobalWorkSizeQk0}); mQkUpdateInfo.update_kernel_args.push_back({0, 3, sizeof(cl_mem), &(*(mKVCacheCLManager->key()))()}); mQkUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &openCLDeferBuffer(mTempQK.get())()}); @@ -1409,15 +1570,17 @@ ErrorCode AttentionBufExecution::decodeResize(const std::vector &input } { // softmax - int inside = 1; + int inside = 1; int outside = numHead; int localSize = 64; std::set buildOption; buildOption.emplace("-DSOFTMAX_LOCAL_SIZE=" + std::to_string(localSize)); - mKernel_softmax = runtime->buildKernel("softmax_buf", "softmax_in1_buf", buildOption, mOpenCLBackend->getPrecision()); - mGlobalWorkSizeSoftMax = {static_cast(localSize), static_cast(inside), static_cast(outside)}; - auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_softmax)); + mKernel_softmax = + runtime->buildKernel("softmax_buf", "softmax_in1_buf", buildOption, mOpenCLBackend->getPrecision()); + mGlobalWorkSizeSoftMax = {static_cast(localSize), static_cast(inside), + static_cast(outside)}; + auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_softmax)); uint32_t index = 0; cl_int ret = CL_SUCCESS; @@ -1432,18 +1595,26 @@ ErrorCode AttentionBufExecution::decodeResize(const std::vector &input MNN_CHECK_CL_SUCCESS(ret, "setArg softmax"); mLocalWorkSizeSoftMax = {static_cast(localSize), 1, 1}; - if(localSize == 1){ - mLocalWorkSizeSoftMax = localWS3DDefault(mGlobalWorkSizeSoftMax, maxWorkGroupSize, runtime, "softmax", mKernel_softmax, mOpenCLBackend->getCLTuneLevel(), "softmax_buf").first; + if (localSize == 1) { + mLocalWorkSizeSoftMax = localWS3DDefault(mGlobalWorkSizeSoftMax, maxWorkGroupSize, runtime, "softmax", + mKernel_softmax, mOpenCLBackend->getCLTuneLevel(), "softmax_buf") + .first; } - mGlobalWorkSizeSoftMax[0] = ROUND_UP(mGlobalWorkSizeSoftMax[0], std::max((uint32_t)1, mLocalWorkSizeSoftMax[0])); - mGlobalWorkSizeSoftMax[1] = ROUND_UP(mGlobalWorkSizeSoftMax[1], std::max((uint32_t)1, mLocalWorkSizeSoftMax[1])); - mGlobalWorkSizeSoftMax[2] = ROUND_UP(mGlobalWorkSizeSoftMax[2], std::max((uint32_t)1, mLocalWorkSizeSoftMax[2])); - if(nullptr != mMeta) { - mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 3, sizeof(cl_mem), &openCLDeferBuffer(mTempQK.get())()}); - mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &openCLDeferBuffer(mTempSoftMax.get())()}); + mGlobalWorkSizeSoftMax[0] = + ROUND_UP(mGlobalWorkSizeSoftMax[0], std::max((uint32_t)1, mLocalWorkSizeSoftMax[0])); + mGlobalWorkSizeSoftMax[1] = + ROUND_UP(mGlobalWorkSizeSoftMax[1], std::max((uint32_t)1, mLocalWorkSizeSoftMax[1])); + mGlobalWorkSizeSoftMax[2] = + ROUND_UP(mGlobalWorkSizeSoftMax[2], std::max((uint32_t)1, mLocalWorkSizeSoftMax[2])); + if (nullptr != mMeta) { + mSoftMaxUpdateInfo.update_kernel_args.push_back( + {0, 3, sizeof(cl_mem), &openCLDeferBuffer(mTempQK.get())()}); + mSoftMaxUpdateInfo.update_kernel_args.push_back( + {0, 4, sizeof(cl_mem), &openCLDeferBuffer(mTempSoftMax.get())()}); mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 7, sizeof(mKvSeqlen), &mKvSeqlen}); mOpRecordUpdateInfo.emplace_back(&mSoftMaxUpdateInfo); - mOpenCLBackend->recordKernel3d(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, &mSoftMaxUpdateInfo); + mOpenCLBackend->recordKernel3d(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, + &mSoftMaxUpdateInfo); } else { mOpenCLBackend->recordKernel3d(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax); } @@ -1452,12 +1623,12 @@ ErrorCode AttentionBufExecution::decodeResize(const std::vector &input // rearrange value std::set buildOption; - mKernel_rearrangeV = runtime->buildKernel("attention_buf", "rearrange_v", buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); - auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_rearrangeV)); + mKernel_rearrangeV = runtime->buildKernel("attention_buf", "rearrange_v", buildOption, + mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); + auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_rearrangeV)); - mGlobalWorkSizeRearrgV = {static_cast(UP_DIV(headDim, 4)), \ - static_cast(1), \ - static_cast(kvNumHead * batch)}; + mGlobalWorkSizeRearrgV = {static_cast(UP_DIV(headDim, 4)), static_cast(1), + static_cast(kvNumHead * batch)}; uint32_t index = 0; cl_int ret = CL_SUCCESS; @@ -1473,16 +1644,22 @@ ErrorCode AttentionBufExecution::decodeResize(const std::vector &input ret |= mKernel_rearrangeV->get().setArg(index++, headDim); MNN_CHECK_CL_SUCCESS(ret, "setArg rearrange_v"); - mLocalWorkSizeRearrgV = localWS3DDefault(mGlobalWorkSizeRearrgV, maxWorkGroupSize, runtime, "rearrange_v", mKernel_rearrangeV, mOpenCLBackend->getCLTuneLevel(), "attention_buf").first; - mGlobalWorkSizeRearrgV[0] = ROUND_UP(mGlobalWorkSizeRearrgV[0], std::max((uint32_t)1, mLocalWorkSizeRearrgV[0])); - mGlobalWorkSizeRearrgV[1] = ROUND_UP(mGlobalWorkSizeRearrgV[1], std::max((uint32_t)1, mLocalWorkSizeRearrgV[1])); - mGlobalWorkSizeRearrgV[2] = ROUND_UP(mGlobalWorkSizeRearrgV[2], std::max((uint32_t)1, mLocalWorkSizeRearrgV[2])); - if(nullptr != mMeta) { + mLocalWorkSizeRearrgV = localWS3DDefault(mGlobalWorkSizeRearrgV, maxWorkGroupSize, runtime, "rearrange_v", + mKernel_rearrangeV, mOpenCLBackend->getCLTuneLevel(), "attention_buf") + .first; + mGlobalWorkSizeRearrgV[0] = + ROUND_UP(mGlobalWorkSizeRearrgV[0], std::max((uint32_t)1, mLocalWorkSizeRearrgV[0])); + mGlobalWorkSizeRearrgV[1] = + ROUND_UP(mGlobalWorkSizeRearrgV[1], std::max((uint32_t)1, mLocalWorkSizeRearrgV[1])); + mGlobalWorkSizeRearrgV[2] = + ROUND_UP(mGlobalWorkSizeRearrgV[2], std::max((uint32_t)1, mLocalWorkSizeRearrgV[2])); + if (nullptr != mMeta) { mRgVUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &(*(mKVCacheCLManager->value()))()}); mRgVUpdateInfo.update_kernel_args.push_back({0, 5, sizeof(mPastKvSeqlen), &mPastKvSeqlen}); mRgVUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(mKeyValueMaxlen), &mKeyValueMaxlen}); mOpRecordUpdateInfo.emplace_back(&mRgVUpdateInfo); - mOpenCLBackend->recordKernel3d(mKernel_rearrangeV, mGlobalWorkSizeRearrgV, mLocalWorkSizeRearrgV, &mRgVUpdateInfo); + mOpenCLBackend->recordKernel3d(mKernel_rearrangeV, mGlobalWorkSizeRearrgV, mLocalWorkSizeRearrgV, + &mRgVUpdateInfo); } else { mOpenCLBackend->recordKernel3d(mKernel_rearrangeV, mGlobalWorkSizeRearrgV, mLocalWorkSizeRearrgV); } @@ -1499,16 +1676,22 @@ ErrorCode AttentionBufExecution::decodeResize(const std::vector &input std::shared_ptr kernel[total_kernel * total_kernel]; std::vector globalWorkSize[total_kernel * total_kernel]; std::vector localWorkSize[total_kernel * total_kernel]; - std::pair min_cost(INT_MAX, 0);//(min_time, min_index) + std::pair min_cost(INT_MAX, 0); //(min_time, min_index) for (int i = 0; i < actual_kernel; i++) { - for(int j = 0; j < actual_kernel; j++){ + for (int j = 0; j < actual_kernel; j++) { int knl_idx = i * total_kernel + j; auto option = buildOption; option.emplace(unroll[j]); - kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("attention_buf", kernelName[i], option, mOpenCLBackend->getPrecision()); - uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx])); - globalWorkSize[knl_idx] = {static_cast(UP_DIV(headDim, itemC[i])), static_cast(numHead)}; + if (mOutputC4) { + option.emplace("-DATTENTION_C4"); + } + kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel( + "attention_buf", kernelName[i], option, mOpenCLBackend->getPrecision()); + uint32_t maxWorkGroupSize = + static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx])); + globalWorkSize[knl_idx] = {static_cast(UP_DIV(headDim, itemC[i])), + static_cast(numHead)}; uint32_t index = 0; cl_int ret = CL_SUCCESS; ret |= kernel[knl_idx]->get().setArg(index++, globalWorkSize[knl_idx][0]); @@ -1523,19 +1706,25 @@ ErrorCode AttentionBufExecution::decodeResize(const std::vector &input ret |= kernel[knl_idx]->get().setArg(index++, headDim); MNN_CHECK_CL_SUCCESS(ret, "setArg matmul_qkv_decode"); std::pair, int> retTune; - retTune = localWS2DDefault(globalWorkSize[knl_idx], maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), kernelName[i] + unroll[j], kernel[knl_idx], mOpenCLBackend->getCLTuneLevel(), "attention_buf"); - if(min_cost.first > retTune.second) { + retTune = localWS2DDefault(globalWorkSize[knl_idx], maxWorkGroupSize, + mOpenCLBackend->getOpenCLRuntime(), kernelName[i] + unroll[j], + kernel[knl_idx], mOpenCLBackend->getCLTuneLevel(), "attention_buf"); + if (min_cost.first > retTune.second) { min_cost.first = retTune.second; min_cost.second = knl_idx; mLocalWorkSizeQkv = {retTune.first[0], retTune.first[1]}; } } } - int min_index = min_cost.second / 2; - int min_index_unroll = min_cost.second % 2; + int min_index = min_cost.second / 2; + int min_index_unroll = min_cost.second % 2; mGlobalWorkSizeQkv = {globalWorkSize[min_cost.second][0], globalWorkSize[min_cost.second][1]}; buildOption.emplace(unroll[min_index_unroll]); - mKernel_qkv = runtime->buildKernel("attention_buf", kernelName[min_index], buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); + if (mOutputC4) { + buildOption.emplace("-DATTENTION_C4"); + } + mKernel_qkv = runtime->buildKernel("attention_buf", kernelName[min_index], buildOption, + mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); uint32_t index = 0; cl_int ret = CL_SUCCESS; @@ -1553,8 +1742,9 @@ ErrorCode AttentionBufExecution::decodeResize(const std::vector &input mGlobalWorkSizeQkv[0] = ROUND_UP(mGlobalWorkSizeQkv[0], std::max((uint32_t)1, mLocalWorkSizeQkv[0])); mGlobalWorkSizeQkv[1] = ROUND_UP(mGlobalWorkSizeQkv[1], std::max((uint32_t)1, mLocalWorkSizeQkv[1])); - if(nullptr != mMeta) { - mQkvUpdateInfo.update_kernel_args.push_back({0, 2, sizeof(cl_mem), &openCLDeferBuffer(mTempSoftMax.get())()}); + if (nullptr != mMeta) { + mQkvUpdateInfo.update_kernel_args.push_back( + {0, 2, sizeof(cl_mem), &openCLDeferBuffer(mTempSoftMax.get())()}); mQkvUpdateInfo.update_kernel_args.push_back({0, 3, sizeof(cl_mem), &(*(mKVCacheCLManager->value()))()}); mQkvUpdateInfo.update_kernel_args.push_back({0, 5, sizeof(mKvSeqlen), &mKvSeqlen}); mQkvUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(mKeyValueMaxlen), &mKeyValueMaxlen}); @@ -1570,7 +1760,7 @@ ErrorCode AttentionBufExecution::decodeResize(const std::vector &input } // [Batch, q_seqlen, HeadNum, HeadDim] -> [Batch, kv_seqlen, HeadNum, HeadDim] -ErrorCode AttentionBufExecution::onResize(const std::vector &inputs, const std::vector &outputs) { +ErrorCode AttentionBufExecution::onResize(const std::vector& inputs, const std::vector& outputs) { mOpenCLBackend->startRecord(mRecording); auto shape = inputs[0]->shape(); @@ -1579,12 +1769,12 @@ ErrorCode AttentionBufExecution::onResize(const std::vector &inputs, c int numHead = shape[2]; int headDim = shape[3]; int kvNumHead = inputs[1]->shape()[2]; - if(nullptr != mMeta) { + if (nullptr != mMeta) { // if has kv_cache, default has mask -// MNN_ASSERT(inputs.size() > 3); + // MNN_ASSERT(inputs.size() > 3); } - mHasMask = inputs.size() > 3; - mIsDecode = seqlen == 1 && mMeta->add == 1; + mHasMask = inputs.size() > 3 && inputs[3]->dimensions() > 2; + mIsDecode = seqlen == 1 && (nullptr == mMeta || mMeta->add == 1); // reset updateArgs variable and kernel vector init(); @@ -1592,13 +1782,14 @@ ErrorCode AttentionBufExecution::onResize(const std::vector &inputs, c handleKVCache(inputs, outputs); mLongPrefill = false; - if(mIsDecode) { + if (mIsDecode) { return decodeResize(inputs, outputs); } else { - if(mPastKvSeqlen == 0){ + if (mPastKvSeqlen == 0) { std::pair, uint32_t> tuneInfo; - std::string info = "attention_" + std::to_string(batch) + "_" + std::to_string(numHead) + "_" + std::to_string(headDim) + "_" + std::to_string(kvNumHead); - if(seqlen > 16){ + std::string info = "attention_" + std::to_string(batch) + "_" + std::to_string(numHead) + "_" + + std::to_string(headDim) + "_" + std::to_string(kvNumHead); + if (seqlen > 16) { if (getTunedInfo(info, {static_cast(seqlen)}, tuneInfo, mOpenCLBackend->getOpenCLRuntime(), mOpenCLBackend->getCLTuneLevel())) { mLongPrefill = tuneInfo.first[0]; @@ -1615,23 +1806,25 @@ ErrorCode AttentionBufExecution::onResize(const std::vector &inputs, c longPrefillResize(inputs, outputs); auto longPrefillTime = getExecuteTime(); mLongPrefill = false; - if(longPrefillTime < shortPrefillTime){ + if (longPrefillTime < shortPrefillTime) { mLongPrefill = true; } - std::pair, uint32_t> tuneInfoTmp = std::make_pair, uint32_t>({mLongPrefill}, 0); - setTunedInfo(info, {static_cast(seqlen)}, tuneInfoTmp, mOpenCLBackend->getOpenCLRuntime(), "attention_buf"); + std::pair, uint32_t> tuneInfoTmp = + std::make_pair, uint32_t>({mLongPrefill}, 0); + setTunedInfo(info, {static_cast(seqlen)}, tuneInfoTmp, + mOpenCLBackend->getOpenCLRuntime(), "attention_buf"); init(); } else { - if(seqlen > 512){ + if (seqlen > 512) { mLongPrefill = true; } } } } } - if(mLongPrefill){ + if (mLongPrefill) { longPrefillResize(inputs, outputs); - }else{ + } else { prefillResize(inputs, outputs); } } @@ -1639,128 +1832,159 @@ ErrorCode AttentionBufExecution::onResize(const std::vector &inputs, c return NO_ERROR; } -int AttentionBufExecution::getExecuteTime(){ +int AttentionBufExecution::getExecuteTime() { int executeTime = 0; auto runtime = mOpenCLBackend->getOpenCLRuntime(); - if(mLongPrefill) { + if (mLongPrefill) { int seq_idx = 0; cl::Event event0, event1, event2, event3, event4, event5, event6; - run3DKernelDefault(mKernel_rearrange_vec[seq_idx], mGwsRearrgVec[seq_idx], mLwsRearrgVec[seq_idx], mOpenCLBackend->getOpenCLRuntime(), &event0); + run3DKernelDefault(mKernel_rearrange_vec[seq_idx], mGwsRearrgVec[seq_idx], mLwsRearrgVec[seq_idx], + mOpenCLBackend->getOpenCLRuntime(), &event0); executeTime += runtime->getEventTime(event0); - if(mHasMask) { - run3DKernelDefault(mKernel_mask_vec[seq_idx], mGwsMaskVec[seq_idx], mLwsMaskVec[seq_idx], mOpenCLBackend->getOpenCLRuntime(), &event1); + if (mHasMask) { + run3DKernelDefault(mKernel_mask_vec[seq_idx], mGwsMaskVec[seq_idx], mLwsMaskVec[seq_idx], + mOpenCLBackend->getOpenCLRuntime(), &event1); executeTime += runtime->getEventTime(event1); } - for(int seq_idx = 0; seq_idx < mQseqSplitNum; seq_idx++) { - run3DKernelDefault(mKernel_qk_vec[seq_idx], mGwsQkVec[seq_idx], mLwsQkVec[seq_idx], mOpenCLBackend->getOpenCLRuntime(), &event2); + for (int seq_idx = 0; seq_idx < mQseqSplitNum; seq_idx++) { + run3DKernelDefault(mKernel_qk_vec[seq_idx], mGwsQkVec[seq_idx], mLwsQkVec[seq_idx], + mOpenCLBackend->getOpenCLRuntime(), &event2); executeTime += runtime->getEventTime(event2); - run3DKernelDefault(mKernel_softmax_vec[seq_idx], mGwsSoftMaxVec[seq_idx], mLwsSoftMaxVec[seq_idx], mOpenCLBackend->getOpenCLRuntime(), &event3); + run3DKernelDefault(mKernel_softmax_vec[seq_idx], mGwsSoftMaxVec[seq_idx], mLwsSoftMaxVec[seq_idx], + mOpenCLBackend->getOpenCLRuntime(), &event3); executeTime += runtime->getEventTime(event3); - run3DKernelDefault(mKernel_trans_vec[seq_idx], mGwsTransVec[seq_idx], mLwsTransVec[seq_idx], mOpenCLBackend->getOpenCLRuntime(), &event4); + run3DKernelDefault(mKernel_trans_vec[seq_idx], mGwsTransVec[seq_idx], mLwsTransVec[seq_idx], + mOpenCLBackend->getOpenCLRuntime(), &event4); executeTime += runtime->getEventTime(event4); - run3DKernelDefault(mKernel_qkv_vec[seq_idx], mGwsQkvVec[seq_idx], mLwsQkvVec[seq_idx], mOpenCLBackend->getOpenCLRuntime(), &event5); + run3DKernelDefault(mKernel_qkv_vec[seq_idx], mGwsQkvVec[seq_idx], mLwsQkvVec[seq_idx], + mOpenCLBackend->getOpenCLRuntime(), &event5); executeTime += runtime->getEventTime(event5); } seq_idx = 0; - run3DKernelDefault(mKernel_clip_vec[seq_idx], mGwsClipVec[seq_idx], mLwsClipVec[seq_idx], mOpenCLBackend->getOpenCLRuntime(), &event6); + run3DKernelDefault(mKernel_clip_vec[seq_idx], mGwsClipVec[seq_idx], mLwsClipVec[seq_idx], + mOpenCLBackend->getOpenCLRuntime(), &event6); executeTime += runtime->getEventTime(event6); - } else{ + } else { cl::Event event0, event1, event2, event3, event4, event5, event6; - run3DKernelDefault(mKernel_rearrangeQ, mGlobalWorkSizeRearrgQ, mLocalWorkSizeRearrgQ, mOpenCLBackend->getOpenCLRuntime(), &event0); + run3DKernelDefault(mKernel_rearrangeQ, mGlobalWorkSizeRearrgQ, mLocalWorkSizeRearrgQ, + mOpenCLBackend->getOpenCLRuntime(), &event0); executeTime += runtime->getEventTime(event0); - run3DKernelDefault(mKernel_rearrange, mGlobalWorkSizeRearrg, mLocalWorkSizeRearrg, mOpenCLBackend->getOpenCLRuntime(), &event1); + run3DKernelDefault(mKernel_rearrange, mGlobalWorkSizeRearrg, mLocalWorkSizeRearrg, + mOpenCLBackend->getOpenCLRuntime(), &event1); executeTime += runtime->getEventTime(event1); - if(mHasMask) { - run3DKernelDefault(mKernel_rearrangeMask, mGlobalWorkSizeRearrgM, mLocalWorkSizeRearrgM, mOpenCLBackend->getOpenCLRuntime(), &event2); + if (mHasMask) { + run3DKernelDefault(mKernel_rearrangeMask, mGlobalWorkSizeRearrgM, mLocalWorkSizeRearrgM, + mOpenCLBackend->getOpenCLRuntime(), &event2); executeTime += runtime->getEventTime(event2); } - run3DKernelDefault(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, mOpenCLBackend->getOpenCLRuntime(), &event3); + run3DKernelDefault(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, mOpenCLBackend->getOpenCLRuntime(), + &event3); executeTime += runtime->getEventTime(event3); - run3DKernelDefault(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, mOpenCLBackend->getOpenCLRuntime(), &event4); + run3DKernelDefault(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, + mOpenCLBackend->getOpenCLRuntime(), &event4); executeTime += runtime->getEventTime(event4); - run3DKernelDefault(mKernel_rearrangeV, mGlobalWorkSizeRearrgV, mLocalWorkSizeRearrgV, mOpenCLBackend->getOpenCLRuntime(), &event5); + run3DKernelDefault(mKernel_rearrangeV, mGlobalWorkSizeRearrgV, mLocalWorkSizeRearrgV, + mOpenCLBackend->getOpenCLRuntime(), &event5); executeTime += runtime->getEventTime(event5); - run3DKernelDefault(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv, mOpenCLBackend->getOpenCLRuntime(), &event6); + run3DKernelDefault(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv, mOpenCLBackend->getOpenCLRuntime(), + &event6); executeTime += runtime->getEventTime(event6); } return executeTime; } -ErrorCode AttentionBufExecution::onExecute(const std::vector &inputs, const std::vector &outputs) { +ErrorCode AttentionBufExecution::onExecute(const std::vector& inputs, const std::vector& outputs) { #ifdef LOG_VERBOSE MNN_PRINT("start AttentionBufExecution onExecute !\n"); #endif - if(nullptr != mMeta){ + if (nullptr != mMeta) { // If allocKVCache already ran reallocKVCache(true) during resize phase, // skip it here to avoid double-executing Remove. For subsequent decode // iterations (no resize), mReallocDone is false so we still run it. if (mKVCacheCLManager->isReallocDone()) { mKVCacheCLManager->clearReallocDone(); } else { - auto shape = inputs[0]->shape(); - int seqlen = shape[1]; - mKVCacheCLManager->reallocKVCache(mMeta, seqlen); + int kvInputLen = inputs[1]->shape()[1]; + mKVCacheCLManager->reallocKVCache(mMeta, kvInputLen); } } UpdateArgs(inputs, outputs); #ifdef ENABLE_OPENCL_TIME_PROFILER - if(mLongPrefill) { + if (mLongPrefill) { int seq_idx = 0; cl::Event event0, event1, event2, event3, event4, event5, event6; - run3DKernelDefault(mKernel_rearrange_vec[seq_idx], mGwsRearrgVec[seq_idx], mLwsRearrgVec[seq_idx], mOpenCLBackend->getOpenCLRuntime(), &event0); + run3DKernelDefault(mKernel_rearrange_vec[seq_idx], mGwsRearrgVec[seq_idx], mLwsRearrgVec[seq_idx], + mOpenCLBackend->getOpenCLRuntime(), &event0); mOpenCLBackend->getOpenCLRuntime()->pushEvent({"rearrange_qkv", event0}); - if(mHasMask) { - run3DKernelDefault(mKernel_mask_vec[seq_idx], mGwsMaskVec[seq_idx], mLwsMaskVec[seq_idx], mOpenCLBackend->getOpenCLRuntime(), &event1); + if (mHasMask) { + run3DKernelDefault(mKernel_mask_vec[seq_idx], mGwsMaskVec[seq_idx], mLwsMaskVec[seq_idx], + mOpenCLBackend->getOpenCLRuntime(), &event1); mOpenCLBackend->getOpenCLRuntime()->pushEvent({"rearrange_mask", event1}); } - for(int seq_idx = 0; seq_idx < mQseqSplitNum; seq_idx++) { - run3DKernelDefault(mKernel_qk_vec[seq_idx], mGwsQkVec[seq_idx], mLwsQkVec[seq_idx], mOpenCLBackend->getOpenCLRuntime(), &event2); + for (int seq_idx = 0; seq_idx < mQseqSplitNum; seq_idx++) { + run3DKernelDefault(mKernel_qk_vec[seq_idx], mGwsQkVec[seq_idx], mLwsQkVec[seq_idx], + mOpenCLBackend->getOpenCLRuntime(), &event2); mOpenCLBackend->getOpenCLRuntime()->pushEvent({"matmul_qk_div_mask", event2}); - run3DKernelDefault(mKernel_softmax_vec[seq_idx], mGwsSoftMaxVec[seq_idx], mLwsSoftMaxVec[seq_idx], mOpenCLBackend->getOpenCLRuntime(), &event3); + run3DKernelDefault(mKernel_softmax_vec[seq_idx], mGwsSoftMaxVec[seq_idx], mLwsSoftMaxVec[seq_idx], + mOpenCLBackend->getOpenCLRuntime(), &event3); mOpenCLBackend->getOpenCLRuntime()->pushEvent({"softmax", event3}); - run3DKernelDefault(mKernel_trans_vec[seq_idx], mGwsTransVec[seq_idx], mLwsTransVec[seq_idx], mOpenCLBackend->getOpenCLRuntime(), &event4); + run3DKernelDefault(mKernel_trans_vec[seq_idx], mGwsTransVec[seq_idx], mLwsTransVec[seq_idx], + mOpenCLBackend->getOpenCLRuntime(), &event4); mOpenCLBackend->getOpenCLRuntime()->pushEvent({"transpose_softmax", event4}); - run3DKernelDefault(mKernel_qkv_vec[seq_idx], mGwsQkvVec[seq_idx], mLwsQkvVec[seq_idx], mOpenCLBackend->getOpenCLRuntime(), &event5); + run3DKernelDefault(mKernel_qkv_vec[seq_idx], mGwsQkvVec[seq_idx], mLwsQkvVec[seq_idx], + mOpenCLBackend->getOpenCLRuntime(), &event5); mOpenCLBackend->getOpenCLRuntime()->pushEvent({"matmul_qkv", event5}); } seq_idx = 0; - run3DKernelDefault(mKernel_clip_vec[seq_idx], mGwsClipVec[seq_idx], mLwsClipVec[seq_idx], mOpenCLBackend->getOpenCLRuntime(), &event6); + run3DKernelDefault(mKernel_clip_vec[seq_idx], mGwsClipVec[seq_idx], mLwsClipVec[seq_idx], + mOpenCLBackend->getOpenCLRuntime(), &event6); mOpenCLBackend->getOpenCLRuntime()->pushEvent({"rearrange_output", event6}); - } else{ - if(mIsDecode){ + } else { + if (mIsDecode) { cl::Event event0, event1, event2, event3, event4; - run3DKernelDefault(mKernel_rearrange, mGlobalWorkSizeRearrg, mLocalWorkSizeRearrg, mOpenCLBackend->getOpenCLRuntime(), &event0); + run3DKernelDefault(mKernel_rearrange, mGlobalWorkSizeRearrg, mLocalWorkSizeRearrg, + mOpenCLBackend->getOpenCLRuntime(), &event0); mOpenCLBackend->getOpenCLRuntime()->pushEvent({"rearrange_k", event0}); runKernel2D(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, mOpenCLBackend->getOpenCLRuntime(), &event1); mOpenCLBackend->getOpenCLRuntime()->pushEvent({"matmul_qk_div_mask", event1}); - run3DKernelDefault(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, mOpenCLBackend->getOpenCLRuntime(), &event2); + run3DKernelDefault(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, + mOpenCLBackend->getOpenCLRuntime(), &event2); mOpenCLBackend->getOpenCLRuntime()->pushEvent({"softmax", event2}); - run3DKernelDefault(mKernel_rearrangeV, mGlobalWorkSizeRearrgV, mLocalWorkSizeRearrgV, mOpenCLBackend->getOpenCLRuntime(), &event3); + run3DKernelDefault(mKernel_rearrangeV, mGlobalWorkSizeRearrgV, mLocalWorkSizeRearrgV, + mOpenCLBackend->getOpenCLRuntime(), &event3); mOpenCLBackend->getOpenCLRuntime()->pushEvent({"rearrange_v", event3}); - runKernel2D(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv, mOpenCLBackend->getOpenCLRuntime(), &event4); + runKernel2D(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv, mOpenCLBackend->getOpenCLRuntime(), + &event4); mOpenCLBackend->getOpenCLRuntime()->pushEvent({"matmul_qkv", event4}); - }else{ + } else { cl::Event event0, event1, event2, event3, event4, event5, event6; - run3DKernelDefault(mKernel_rearrangeQ, mGlobalWorkSizeRearrgQ, mLocalWorkSizeRearrgQ, mOpenCLBackend->getOpenCLRuntime(), &event0); + run3DKernelDefault(mKernel_rearrangeQ, mGlobalWorkSizeRearrgQ, mLocalWorkSizeRearrgQ, + mOpenCLBackend->getOpenCLRuntime(), &event0); mOpenCLBackend->getOpenCLRuntime()->pushEvent({"rearrange_q", event0}); - run3DKernelDefault(mKernel_rearrange, mGlobalWorkSizeRearrg, mLocalWorkSizeRearrg, mOpenCLBackend->getOpenCLRuntime(), &event1); + run3DKernelDefault(mKernel_rearrange, mGlobalWorkSizeRearrg, mLocalWorkSizeRearrg, + mOpenCLBackend->getOpenCLRuntime(), &event1); mOpenCLBackend->getOpenCLRuntime()->pushEvent({"rearrange_k", event1}); - if(mHasMask) { - run3DKernelDefault(mKernel_rearrangeMask, mGlobalWorkSizeRearrgM, mLocalWorkSizeRearrgM, mOpenCLBackend->getOpenCLRuntime(), &event2); + if (mHasMask) { + run3DKernelDefault(mKernel_rearrangeMask, mGlobalWorkSizeRearrgM, mLocalWorkSizeRearrgM, + mOpenCLBackend->getOpenCLRuntime(), &event2); mOpenCLBackend->getOpenCLRuntime()->pushEvent({"rearrange_mask_shortprefill", event2}); } - run3DKernelDefault(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, mOpenCLBackend->getOpenCLRuntime(), &event3); + run3DKernelDefault(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, mOpenCLBackend->getOpenCLRuntime(), + &event3); mOpenCLBackend->getOpenCLRuntime()->pushEvent({"matmul_qk_div_mask", event3}); - run3DKernelDefault(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, mOpenCLBackend->getOpenCLRuntime(), &event4); + run3DKernelDefault(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, + mOpenCLBackend->getOpenCLRuntime(), &event4); mOpenCLBackend->getOpenCLRuntime()->pushEvent({"softmax", event4}); - run3DKernelDefault(mKernel_rearrangeV, mGlobalWorkSizeRearrgV, mLocalWorkSizeRearrgV, mOpenCLBackend->getOpenCLRuntime(), &event5); + run3DKernelDefault(mKernel_rearrangeV, mGlobalWorkSizeRearrgV, mLocalWorkSizeRearrgV, + mOpenCLBackend->getOpenCLRuntime(), &event5); mOpenCLBackend->getOpenCLRuntime()->pushEvent({"rearrange_v", event5}); - run3DKernelDefault(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv, mOpenCLBackend->getOpenCLRuntime(), &event6); + run3DKernelDefault(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv, mOpenCLBackend->getOpenCLRuntime(), + &event6); mOpenCLBackend->getOpenCLRuntime()->pushEvent({"matmul_qkv", event6}); } } #else - if(mOpenCLBackend->isUseRecordQueue()){ + if (mOpenCLBackend->isUseRecordQueue()) { mOpenCLBackend->addRecord(mRecording, mOpRecordUpdateInfo); #ifdef LOG_VERBOSE MNN_PRINT("End AttentionBufExecution onExecute... \n"); @@ -1768,37 +1992,51 @@ ErrorCode AttentionBufExecution::onExecute(const std::vector &inputs, return NO_ERROR; } - if(mLongPrefill) { + if (mLongPrefill) { int seq_idx = 0; - run3DKernelDefault(mKernel_rearrange_vec[seq_idx], mGwsRearrgVec[seq_idx], mLwsRearrgVec[seq_idx], mOpenCLBackend->getOpenCLRuntime()); - if(mHasMask) { - run3DKernelDefault(mKernel_mask_vec[seq_idx], mGwsMaskVec[seq_idx], mLwsMaskVec[seq_idx], mOpenCLBackend->getOpenCLRuntime()); + run3DKernelDefault(mKernel_rearrange_vec[seq_idx], mGwsRearrgVec[seq_idx], mLwsRearrgVec[seq_idx], + mOpenCLBackend->getOpenCLRuntime()); + if (mHasMask) { + run3DKernelDefault(mKernel_mask_vec[seq_idx], mGwsMaskVec[seq_idx], mLwsMaskVec[seq_idx], + mOpenCLBackend->getOpenCLRuntime()); } - for(int seq_idx = 0; seq_idx < mQseqSplitNum; seq_idx++) { - run3DKernelDefault(mKernel_qk_vec[seq_idx], mGwsQkVec[seq_idx], mLwsQkVec[seq_idx], mOpenCLBackend->getOpenCLRuntime()); - run3DKernelDefault(mKernel_softmax_vec[seq_idx], mGwsSoftMaxVec[seq_idx], mLwsSoftMaxVec[seq_idx], mOpenCLBackend->getOpenCLRuntime()); - run3DKernelDefault(mKernel_trans_vec[seq_idx], mGwsTransVec[seq_idx], mLwsTransVec[seq_idx], mOpenCLBackend->getOpenCLRuntime()); - run3DKernelDefault(mKernel_qkv_vec[seq_idx], mGwsQkvVec[seq_idx], mLwsQkvVec[seq_idx], mOpenCLBackend->getOpenCLRuntime()); - + for (int seq_idx = 0; seq_idx < mQseqSplitNum; seq_idx++) { + run3DKernelDefault(mKernel_qk_vec[seq_idx], mGwsQkVec[seq_idx], mLwsQkVec[seq_idx], + mOpenCLBackend->getOpenCLRuntime()); + run3DKernelDefault(mKernel_softmax_vec[seq_idx], mGwsSoftMaxVec[seq_idx], mLwsSoftMaxVec[seq_idx], + mOpenCLBackend->getOpenCLRuntime()); + run3DKernelDefault(mKernel_trans_vec[seq_idx], mGwsTransVec[seq_idx], mLwsTransVec[seq_idx], + mOpenCLBackend->getOpenCLRuntime()); + run3DKernelDefault(mKernel_qkv_vec[seq_idx], mGwsQkvVec[seq_idx], mLwsQkvVec[seq_idx], + mOpenCLBackend->getOpenCLRuntime()); } seq_idx = 0; - run3DKernelDefault(mKernel_clip_vec[seq_idx], mGwsClipVec[seq_idx], mLwsClipVec[seq_idx], mOpenCLBackend->getOpenCLRuntime()); - } else{ - if(mIsDecode){ - run3DKernelDefault(mKernel_rearrange, mGlobalWorkSizeRearrg, mLocalWorkSizeRearrg, mOpenCLBackend->getOpenCLRuntime()); + run3DKernelDefault(mKernel_clip_vec[seq_idx], mGwsClipVec[seq_idx], mLwsClipVec[seq_idx], + mOpenCLBackend->getOpenCLRuntime()); + } else { + if (mIsDecode) { + run3DKernelDefault(mKernel_rearrange, mGlobalWorkSizeRearrg, mLocalWorkSizeRearrg, + mOpenCLBackend->getOpenCLRuntime()); runKernel2D(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, mOpenCLBackend->getOpenCLRuntime()); - run3DKernelDefault(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, mOpenCLBackend->getOpenCLRuntime()); - run3DKernelDefault(mKernel_rearrangeV, mGlobalWorkSizeRearrgV, mLocalWorkSizeRearrgV, mOpenCLBackend->getOpenCLRuntime()); + run3DKernelDefault(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, + mOpenCLBackend->getOpenCLRuntime()); + run3DKernelDefault(mKernel_rearrangeV, mGlobalWorkSizeRearrgV, mLocalWorkSizeRearrgV, + mOpenCLBackend->getOpenCLRuntime()); runKernel2D(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv, mOpenCLBackend->getOpenCLRuntime()); - }else{ - run3DKernelDefault(mKernel_rearrangeQ, mGlobalWorkSizeRearrgQ, mLocalWorkSizeRearrgQ, mOpenCLBackend->getOpenCLRuntime()); - run3DKernelDefault(mKernel_rearrange, mGlobalWorkSizeRearrg, mLocalWorkSizeRearrg, mOpenCLBackend->getOpenCLRuntime()); - if(mHasMask) { - run3DKernelDefault(mKernel_rearrangeMask, mGlobalWorkSizeRearrgM, mLocalWorkSizeRearrgM, mOpenCLBackend->getOpenCLRuntime()); + } else { + run3DKernelDefault(mKernel_rearrangeQ, mGlobalWorkSizeRearrgQ, mLocalWorkSizeRearrgQ, + mOpenCLBackend->getOpenCLRuntime()); + run3DKernelDefault(mKernel_rearrange, mGlobalWorkSizeRearrg, mLocalWorkSizeRearrg, + mOpenCLBackend->getOpenCLRuntime()); + if (mHasMask) { + run3DKernelDefault(mKernel_rearrangeMask, mGlobalWorkSizeRearrgM, mLocalWorkSizeRearrgM, + mOpenCLBackend->getOpenCLRuntime()); } run3DKernelDefault(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, mOpenCLBackend->getOpenCLRuntime()); - run3DKernelDefault(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, mOpenCLBackend->getOpenCLRuntime()); - run3DKernelDefault(mKernel_rearrangeV, mGlobalWorkSizeRearrgV, mLocalWorkSizeRearrgV, mOpenCLBackend->getOpenCLRuntime()); + run3DKernelDefault(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, + mOpenCLBackend->getOpenCLRuntime()); + run3DKernelDefault(mKernel_rearrangeV, mGlobalWorkSizeRearrgV, mLocalWorkSizeRearrgV, + mOpenCLBackend->getOpenCLRuntime()); run3DKernelDefault(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv, mOpenCLBackend->getOpenCLRuntime()); } } @@ -1811,19 +2049,29 @@ ErrorCode AttentionBufExecution::onExecute(const std::vector &inputs, return NO_ERROR; } -AttentionBufExecution::AttentionBufExecution(const MNN::Op *op, Backend* backend, bool kv_cahce) : CommonExecution(backend, op) { +AttentionBufExecution::AttentionBufExecution(const MNN::Op* op, Backend* backend, bool outputC4) + : CommonExecution(backend, op) { mMeta = (KVMeta*)(backend->getMetaPtr()); + mOutputC4 = outputC4; + mAttnScale = op->main_as_AttentionParam()->attnScale(); mKVCacheCLManager.reset(new KVCacheCLManager(backend, nullptr != mMeta)); - mOpenCLBackend = static_cast(backend); - auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("softmax_buf", "softmax_buf", {"-DSOFTMAX_LOCAL_SIZE=512"}, mOpenCLBackend->getPrecision()); + mOpenCLBackend = static_cast(backend); + auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel( + "softmax_buf", "softmax_buf", {"-DSOFTMAX_LOCAL_SIZE=512"}, mOpenCLBackend->getPrecision()); OPENCL_CHECK_KERNEL_CTOR(kernel); mMaxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel)); } -AttentionBufExecution::AttentionBufExecution(std::shared_ptr manager, const MNN::Op *op, Backend *backend) : CommonExecution(backend, op), mKVCacheCLManager(manager) { +AttentionBufExecution::AttentionBufExecution(std::shared_ptr manager, const MNN::Op* op, + Backend* backend) + : CommonExecution(backend, op), mKVCacheCLManager(manager) { mMeta = (KVMeta*)(backend->getMetaPtr()); - mOpenCLBackend = static_cast(backend); - auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("softmax_buf", "softmax_buf", {"-DSOFTMAX_LOCAL_SIZE=512"}, mOpenCLBackend->getPrecision()); + mOpenCLBackend = static_cast(backend); + auto param = op->main_as_AttentionParam(); + mOutputC4 = param->output_c4(); + mAttnScale = param->attnScale(); + auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel( + "softmax_buf", "softmax_buf", {"-DSOFTMAX_LOCAL_SIZE=512"}, mOpenCLBackend->getPrecision()); OPENCL_CHECK_KERNEL_CTOR(kernel); mMaxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel)); } @@ -1832,7 +2080,7 @@ bool AttentionBufExecution::onClone(Backend* bn, const Op* op, Execution** dst) if (nullptr == dst) { return true; } - if (bn->getMetaPtr() == backend()->getMetaPtr()) { + if (bn->getMetaPtr() == mMeta && mMeta != nullptr) { *dst = new AttentionBufExecution(mKVCacheCLManager, op, bn); } else { *dst = new AttentionBufExecution(op, bn, true); @@ -1842,8 +2090,8 @@ bool AttentionBufExecution::onClone(Backend* bn, const Op* op, Execution** dst) class AttentionBufCreator : public OpenCLBackend::Creator { public: - virtual Execution *onCreate(const std::vector &inputs, const std::vector &outputs, - const MNN::Op *op, Backend *backend) const override { + virtual Execution* onCreate(const std::vector& inputs, const std::vector& outputs, + const MNN::Op* op, Backend* backend) const override { for (int i = 0; i < inputs.size(); ++i) { TensorUtils::setTensorSupportPack(inputs[i], false); } @@ -1851,11 +2099,11 @@ class AttentionBufCreator : public OpenCLBackend::Creator { TensorUtils::setTensorSupportPack(outputs[i], false); } auto param = op->main_as_AttentionParam(); - OPENCL_CREATOR_CHECK(new AttentionBufExecution(op, backend, param->kv_cache())); + OPENCL_CREATOR_CHECK(new AttentionBufExecution(op, backend, param->output_c4())); } }; REGISTER_OPENCL_OP_CREATOR_TRANSFORMER(AttentionBufCreator, OpType_Attention, BUFFER); } // namespace OpenCL } // namespace MNN -#endif /* MNN_SUPPORT_TRANSFORMER_FUSE */ \ No newline at end of file +#endif /* MNN_SUPPORT_TRANSFORMER_FUSE */ diff --git a/source/backend/opencl/execution/buffer/AttentionBufExecution.hpp b/source/backend/opencl/execution/buffer/AttentionBufExecution.hpp index 3a38c62506..a840e7822a 100644 --- a/source/backend/opencl/execution/buffer/AttentionBufExecution.hpp +++ b/source/backend/opencl/execution/buffer/AttentionBufExecution.hpp @@ -19,34 +19,22 @@ namespace OpenCL { class KVCacheCLManager { public: - KVCacheCLManager(Backend *backend, bool kv_cache); + KVCacheCLManager(Backend* backend, bool kv_cache); ~KVCacheCLManager() = default; void allocKVCache(const KVMeta* meta, int seqlen); bool reallocKVCache(const KVMeta* meta, int seqlen, bool isExecute = true); - void setArgs(int numHead, int kvNumHead, int headDim){ + void setArgs(int numHead, int kvNumHead, int headDim) { mNumHead = numHead; mKvNumHead = kvNumHead; mHeadDim = headDim; } - int pastKvLength() { - return mPastLength; - } - void addKvLength(int seq_len){ - mPastLength += seq_len; - } - int maxLength() { - return mMaxLength; - } - int numHead() { - return mNumHead; - } - const cl::Buffer * key() { - return mPastKey.get(); - } - const cl::Buffer * value() { - return mPastValue.get(); - } + int pastKvLength() { return mPastLength; } + void addKvLength(int seq_len) { mPastLength += seq_len; } + int maxLength() { return mMaxLength; } + int numHead() { return mNumHead; } + const cl::Buffer* key() { return mPastKey.get(); } + const cl::Buffer* value() { return mPastValue.get(); } // Called after allocKVCache completes reallocKVCache in resize phase. // onExecute checks this to avoid double-executing realloc/Remove. @@ -59,40 +47,40 @@ class KVCacheCLManager { const int mExpandChunk = 64; std::shared_ptr mPastKey, mPastValue; int mPastLength = 0, mMaxLength = 0, mNumHead = 0, mKvNumHead = 0, mHeadDim = 0; - OpenCLBackend *mOpenCLBackend; + OpenCLBackend* mOpenCLBackend; int mByte = 4; }; class AttentionBufExecution : public CommonExecution { public: - AttentionBufExecution(const MNN::Op *op, Backend *backend, bool kv_cache); - AttentionBufExecution(std::shared_ptr manager, const MNN::Op *op, Backend *backend); - ErrorCode longPrefillResize(const std::vector &inputs, const std::vector &outputs); - ErrorCode prefillResize(const std::vector &inputs, const std::vector &outputs); - ErrorCode decodeResize(const std::vector &inputs, const std::vector &outputs); + AttentionBufExecution(const MNN::Op* op, Backend* backend, bool outputC4); + AttentionBufExecution(std::shared_ptr manager, const MNN::Op* op, Backend* backend); + ErrorCode longPrefillResize(const std::vector& inputs, const std::vector& outputs); + ErrorCode prefillResize(const std::vector& inputs, const std::vector& outputs); + ErrorCode decodeResize(const std::vector& inputs, const std::vector& outputs); - ErrorCode UpdateArgs(const std::vector &inputs, const std::vector &outputs); + ErrorCode UpdateArgs(const std::vector& inputs, const std::vector& outputs); ErrorCode init(); int getExecuteTime(); virtual ~AttentionBufExecution() = default; - virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; - virtual ErrorCode onExecute(const std::vector &inputs, const std::vector &outputs) override; + virtual ErrorCode onResize(const std::vector& inputs, const std::vector& outputs) override; + virtual ErrorCode onExecute(const std::vector& inputs, const std::vector& outputs) override; virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override; private: - + bool mOutputC4 = false; + float mAttnScale = 0.0f; KVMeta* mMeta; int getLocalSize(int size, int maxGroupSize); bool mIsDecode = false; - void handleKVCache(const std::vector &inputs, const std::vector &outputs); + void handleKVCache(const std::vector& inputs, const std::vector& outputs); int mPastKvSeqlen = 0; int mKvSeqlen = 0; int mKeyValueMaxlen = 0; int mDecodeTmpMaxlen = 0; - uint32_t mMaxWorkGroupSize; - OpenCLBackend *mOpenCLBackend; + OpenCLBackend* mOpenCLBackend; RecordUpdateInfo mRgUpdateInfo; RecordUpdateInfo mRgQUpdateInfo; RecordUpdateInfo mRgMUpdateInfo; @@ -106,6 +94,7 @@ class AttentionBufExecution : public CommonExecution { std::vector mOpRecordUpdateInfo; std::shared_ptr mKVCacheCLManager; std::shared_ptr mTempQK, mTempSoftMax; + private: int mAlignQ, mAlignKV, mAlignHDK, mAlignHDN; bool mLongPrefill = false; @@ -114,6 +103,7 @@ class AttentionBufExecution : public CommonExecution { bool mIsAddMask = false; bool mNeedKvCache = true; bool mHasMask = false; + private: std::vector> mKernel_rearrange_vec; std::vector> mKernel_mask_vec; @@ -122,7 +112,7 @@ class AttentionBufExecution : public CommonExecution { std::vector> mKernel_qk_vec; std::vector> mKernel_softmax_vec; std::vector> mKernel_qkv_vec; - + std::vector> mGwsQkVec; std::vector> mLwsQkVec; std::vector> mGwsSoftMaxVec; @@ -137,6 +127,7 @@ class AttentionBufExecution : public CommonExecution { std::vector> mLwsTransVec; std::vector> mGwsClipVec; std::vector> mLwsClipVec; + private: std::shared_ptr mKernel_rearrangeQ; std::shared_ptr mKernel_rearrangeV; @@ -145,7 +136,7 @@ class AttentionBufExecution : public CommonExecution { std::shared_ptr mKernel_qk; std::shared_ptr mKernel_softmax; std::shared_ptr mKernel_qkv; - + std::vector mGlobalWorkSizeQk; std::vector mLocalWorkSizeQk; std::vector mGlobalWorkSizeSoftMax; @@ -160,9 +151,8 @@ class AttentionBufExecution : public CommonExecution { std::vector mLocalWorkSizeRearrg; std::vector mGlobalWorkSizeRearrgM; std::vector mLocalWorkSizeRearrgM; - }; } // namespace OpenCL } // namespace MNN #endif /* AttentionBufExecution_hpp */ -#endif /* MNN_SUPPORT_TRANSFORMER_FUSE */ \ No newline at end of file +#endif /* MNN_SUPPORT_TRANSFORMER_FUSE */ diff --git a/source/backend/opencl/execution/buffer/BinaryBufExecution.cpp b/source/backend/opencl/execution/buffer/BinaryBufExecution.cpp index 3241ca39b1..f234cb095c 100644 --- a/source/backend/opencl/execution/buffer/BinaryBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/BinaryBufExecution.cpp @@ -13,9 +13,11 @@ namespace MNN { namespace OpenCL { -BinaryBufExecution::BinaryBufExecution(const std::vector &inputs, const std::string &compute, const MNN::Op *op, Backend *backend) +BinaryBufExecution::BinaryBufExecution(const std::vector& inputs, const std::string& compute, + const MNN::Op* op, Backend* backend) : CommonExecution(backend, op), mCompute(compute) { - if(op->type() == OpType_BinaryOp && op->main_as_BinaryOp()->opType() == BinaryOpOperation_MOD && (inputs[0]->getType().code == halide_type_int || inputs[0]->getType().code == halide_type_uint)){ + if (op->type() == OpType_BinaryOp && op->main_as_BinaryOp()->opType() == BinaryOpOperation_MOD && + (inputs[0]->getType().code == halide_type_int || inputs[0]->getType().code == halide_type_uint)) { mBuildOptions.emplace("-DINT_COMPUTE_MOD"); } mBuildOptions.emplace("-DOPERATOR=" + compute); @@ -23,36 +25,37 @@ BinaryBufExecution::BinaryBufExecution(const std::vector &inputs, cons uint32_t BinaryBufExecution::realSize(const Tensor* tensor) { uint32_t num = 1; - for(int i = 0; i < tensor->dimensions(); i++) { + for (int i = 0; i < tensor->dimensions(); i++) { num *= tensor->length(i); } return num; } #ifdef MNN_SUPPORT_INTEL_SUBGROUP -ErrorCode BinaryBufExecution::SubgroupOnResize(const std::vector &inputs, const std::vector &outputs) { - auto openCLBackend = static_cast(backend()); - auto output = outputs[0]; - auto inputShape0 = tensorShapeFormat(inputs[0]); - auto inputShape1 = tensorShapeFormat(inputs[1]); - auto outputShape = tensorShapeFormat(output); - auto runTime = ((OpenCLBackend *)backend())->getOpenCLRuntime(); - int shape[4] = {outputShape[0], outputShape[1], outputShape[2], outputShape[3]}; +ErrorCode BinaryBufExecution::SubgroupOnResize(const std::vector& inputs, + const std::vector& outputs) { + auto openCLBackend = static_cast(backend()); + auto output = outputs[0]; + auto inputShape0 = tensorShapeFormat(inputs[0]); + auto inputShape1 = tensorShapeFormat(inputs[1]); + auto outputShape = tensorShapeFormat(output); + auto runTime = ((OpenCLBackend*)backend())->getOpenCLRuntime(); + int shape[4] = {outputShape[0], outputShape[1], outputShape[2], outputShape[3]}; int fullCount[2] = {1, 1}; int input0_c_pack = TensorUtils::getTensorChannelPack(inputs[0]); int input1_c_pack = TensorUtils::getTensorChannelPack(inputs[1]); int output_c_pack = TensorUtils::getTensorChannelPack(output); - + int activationType = 0; - if(mOp->type() == OpType_BinaryOp) { + if (mOp->type() == OpType_BinaryOp) { activationType = mOp->main_as_BinaryOp()->activationType(); } - auto &unit = mUnits[0]; + auto& unit = mUnits[0]; std::set buildOptions = mBuildOptions; - if(output->getType().code == halide_type_int) { - if(output->getType().bits == 8){ + if (output->getType().code == halide_type_int) { + if (output->getType().bits == 8) { buildOptions.emplace("-DINTEL_DATA=uchar"); buildOptions.emplace("-DAS_INPUT_DATA=as_char"); buildOptions.emplace("-DAS_INPUT_DATA4=as_char4"); @@ -60,7 +63,7 @@ ErrorCode BinaryBufExecution::SubgroupOnResize(const std::vector &inpu buildOptions.emplace("-DINTEL_SUB_GROUP_READ=intel_sub_group_block_read_uc"); buildOptions.emplace("-DINTEL_SUB_GROUP_READ4=intel_sub_group_block_read_uc4"); buildOptions.emplace("-DINTEL_SUB_GROUP_WRITE4=intel_sub_group_block_write_uc4"); - } else if(output->getType().bits == 32){ + } else if (output->getType().bits == 32) { buildOptions.emplace("-DINTEL_DATA=uint"); buildOptions.emplace("-DAS_INPUT_DATA=as_int"); buildOptions.emplace("-DAS_INPUT_DATA4=as_int4"); @@ -69,8 +72,8 @@ ErrorCode BinaryBufExecution::SubgroupOnResize(const std::vector &inpu buildOptions.emplace("-DINTEL_SUB_GROUP_READ4=intel_sub_group_block_read4"); buildOptions.emplace("-DINTEL_SUB_GROUP_WRITE4=intel_sub_group_block_write4"); } - } else if(output->getType().code == halide_type_uint){ - if(output->getType().bits == 8){ + } else if (output->getType().code == halide_type_uint) { + if (output->getType().bits == 8) { buildOptions.emplace("-DINTEL_DATA=uchar"); buildOptions.emplace("-DAS_INPUT_DATA=as_uchar"); buildOptions.emplace("-DAS_INPUT_DATA4=as_uchar4"); @@ -78,7 +81,7 @@ ErrorCode BinaryBufExecution::SubgroupOnResize(const std::vector &inpu buildOptions.emplace("-DINTEL_SUB_GROUP_READ=intel_sub_group_block_read_uc"); buildOptions.emplace("-DINTEL_SUB_GROUP_READ4=intel_sub_group_block_read_uc4"); buildOptions.emplace("-DINTEL_SUB_GROUP_WRITE4=intel_sub_group_block_write_uc4"); - } else if(output->getType().bits == 32){ + } else if (output->getType().bits == 32) { buildOptions.emplace("-DINTEL_DATA=uint"); buildOptions.emplace("-DAS_INPUT_DATA=as_uint"); buildOptions.emplace("-DAS_INPUT_DATA4=as_uint4"); @@ -88,7 +91,7 @@ ErrorCode BinaryBufExecution::SubgroupOnResize(const std::vector &inpu buildOptions.emplace("-DINTEL_SUB_GROUP_WRITE4=intel_sub_group_block_write4"); } } else { - if(openCLBackend->getPrecision() != BackendConfig::Precision_High){ + if (openCLBackend->getPrecision() != BackendConfig::Precision_High) { buildOptions.emplace("-DINTEL_DATA=ushort"); buildOptions.emplace("-DAS_INPUT_DATA=as_half"); buildOptions.emplace("-DAS_INPUT_DATA4=as_half4"); @@ -96,7 +99,7 @@ ErrorCode BinaryBufExecution::SubgroupOnResize(const std::vector &inpu buildOptions.emplace("-DINTEL_SUB_GROUP_READ=intel_sub_group_block_read_us"); buildOptions.emplace("-DINTEL_SUB_GROUP_READ4=intel_sub_group_block_read_us4"); buildOptions.emplace("-DINTEL_SUB_GROUP_WRITE4=intel_sub_group_block_write_us4"); - }else{ + } else { buildOptions.emplace("-DINTEL_DATA=uint"); buildOptions.emplace("-DAS_INPUT_DATA=as_float"); buildOptions.emplace("-DAS_INPUT_DATA4=as_float4"); @@ -107,8 +110,9 @@ ErrorCode BinaryBufExecution::SubgroupOnResize(const std::vector &inpu } } std::string kernelName = "binary_buf_c" + std::to_string(input0_c_pack) + "_c" + std::to_string(input1_c_pack) + - "_c" + std::to_string(output_c_pack); - unit.kernel = runTime->buildKernel("binary_subgroup_buf", kernelName, buildOptions, openCLBackend->getPrecision(), inputs[0], output); + "_c" + std::to_string(output_c_pack); + unit.kernel = runTime->buildKernel("binary_subgroup_buf", kernelName, buildOptions, openCLBackend->getPrecision(), + inputs[0], output); mMaxWorkGroupSize = static_cast(runTime->getMaxWorkGroupSize(unit.kernel)); fullCount[0] = realSize(inputs[0]) == 1 ? 0 : 1; @@ -117,14 +121,14 @@ ErrorCode BinaryBufExecution::SubgroupOnResize(const std::vector &inpu auto input0pad = TensorUtils::getDescribe(inputs[0])->mPads; auto input1pad = TensorUtils::getDescribe(inputs[1])->mPads; auto outputpad = TensorUtils::getDescribe(output)->mPads; - + uint32_t index = 0; cl_int ret = CL_SUCCESS; if (input0_c_pack == 16 && input1_c_pack == 16) { - mGlobalWorkSize = {(uint32_t)UP_DIV(outputShape[2], 4) * outputShape[1], - (uint32_t)ROUND_UP(outputShape[3], 16), (uint32_t)outputShape[0]}; + mGlobalWorkSize = {(uint32_t)UP_DIV(outputShape[2], 4) * outputShape[1], (uint32_t)ROUND_UP(outputShape[3], 16), + (uint32_t)outputShape[0]}; unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; - unit.localWorkSize = {1, 16, 1}; + unit.localWorkSize = {1, 16, 1}; ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[0]); ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[1]); ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[2]); @@ -144,7 +148,7 @@ ErrorCode BinaryBufExecution::SubgroupOnResize(const std::vector &inpu openCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); } else { mGlobalWorkSize = {(uint32_t)outputShape[2] * outputShape[1], (uint32_t)UP_DIV(outputShape[3], 4), - (uint32_t)outputShape[0]}; + (uint32_t)outputShape[0]}; ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[0]); ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[1]); @@ -163,37 +167,40 @@ ErrorCode BinaryBufExecution::SubgroupOnResize(const std::vector &inpu ret |= unit.kernel->get().setArg(index++, static_cast(outputpad.right)); MNN_CHECK_CL_SUCCESS(ret, "setArg BinaryBufExecution"); - mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, openCLBackend->getOpenCLRuntime(), kernelName, unit.kernel, openCLBackend->getCLTuneLevel(), "binary_subgroup_buf").first; + mLocalWorkSize = + localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, openCLBackend->getOpenCLRuntime(), kernelName, + unit.kernel, openCLBackend->getCLTuneLevel(), "binary_subgroup_buf") + .first; unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; - unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; + unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; openCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); } - + for (int i = 2; i < inputs.size(); ++i) { fullCount[0] = 1; fullCount[1] = realSize(inputs[i]) == 1 ? 0 : 1; - auto &unit = mUnits[i-1]; + auto& unit = mUnits[i - 1]; int input0_c_pack_tmp = TensorUtils::getTensorChannelPack(output); int input1_c_pack_tmp = TensorUtils::getTensorChannelPack(inputs[i]); int output_c_pack_tmp = TensorUtils::getTensorChannelPack(output); - std::string kernelNameTmp = "binary_buf_c" + std::to_string(input0_c_pack_tmp) + "_c" + std::to_string(input1_c_pack_tmp) + - "_c" + std::to_string(output_c_pack_tmp); - unit.kernel = runTime->buildKernel("binary_subgroup_buf", kernelNameTmp, buildOptions, openCLBackend->getPrecision(), inputs[i], output); + std::string kernelNameTmp = "binary_buf_c" + std::to_string(input0_c_pack_tmp) + "_c" + + std::to_string(input1_c_pack_tmp) + "_c" + std::to_string(output_c_pack_tmp); + unit.kernel = runTime->buildKernel("binary_subgroup_buf", kernelNameTmp, buildOptions, + openCLBackend->getPrecision(), inputs[i], output); mMaxWorkGroupSize = static_cast(runTime->getMaxWorkGroupSize(unit.kernel)); auto input0padtmp = TensorUtils::getDescribe(output)->mPads; auto input1padtmp = TensorUtils::getDescribe(inputs[i])->mPads; auto outputpadtmp = TensorUtils::getDescribe(output)->mPads; - uint32_t index = 0; if (input0_c_pack_tmp == 16 && input1_c_pack_tmp == 16) { - mGlobalWorkSize = {(uint32_t)UP_DIV(outputShape[2], 4) * outputShape[1], - (uint32_t)ROUND_UP(outputShape[3], 16), (uint32_t)outputShape[0]}; + mGlobalWorkSize = {(uint32_t)UP_DIV(outputShape[2], 4) * outputShape[1], + (uint32_t)ROUND_UP(outputShape[3], 16), (uint32_t)outputShape[0]}; unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; - unit.localWorkSize = {1, 16, 1}; + unit.localWorkSize = {1, 16, 1}; ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[0]); ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[1]); ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[2]); @@ -213,7 +220,7 @@ ErrorCode BinaryBufExecution::SubgroupOnResize(const std::vector &inpu openCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); } else { mGlobalWorkSize = {(uint32_t)outputShape[2] * outputShape[1], (uint32_t)UP_DIV(outputShape[3], 4), - (uint32_t)outputShape[0]}; + (uint32_t)outputShape[0]}; ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[0]); ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[1]); @@ -232,10 +239,13 @@ ErrorCode BinaryBufExecution::SubgroupOnResize(const std::vector &inpu ret |= unit.kernel->get().setArg(index++, static_cast(outputpadtmp.right)); MNN_CHECK_CL_SUCCESS(ret, "setArg BinaryBufExecution MultiInput"); - mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, openCLBackend->getOpenCLRuntime(), kernelNameTmp, unit.kernel, openCLBackend->getCLTuneLevel(), "binary_subgroup_buf").first; + mLocalWorkSize = + localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, openCLBackend->getOpenCLRuntime(), kernelNameTmp, + unit.kernel, openCLBackend->getCLTuneLevel(), "binary_subgroup_buf") + .first; unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; - unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; + unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; openCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); } } @@ -243,16 +253,17 @@ ErrorCode BinaryBufExecution::SubgroupOnResize(const std::vector &inpu } #endif /* MNN_SUPPORT_INTEL_SUBGROUP */ -ErrorCode BinaryBufExecution::onEncode(const std::vector &inputs, const std::vector &outputs) { +ErrorCode BinaryBufExecution::onEncode(const std::vector& inputs, const std::vector& outputs) { MNN_ASSERT(inputs.size() >= 2); mUnits.resize(inputs.size() - 1); - + auto openCLBackend = static_cast(backend()); auto output = outputs[0]; auto outputShape = tensorShapeFormat(output); - auto runTime = ((OpenCLBackend *)backend())->getOpenCLRuntime(); + auto runTime = ((OpenCLBackend*)backend())->getOpenCLRuntime(); #ifdef MNN_SUPPORT_INTEL_SUBGROUP - if (runTime->isSupportedIntelSubgroup() && MNN::MNN_DATA_FORMAT_NC4HW4 == TensorUtils::getDescribe(output)->dimensionFormat) { + if (runTime->isSupportedIntelSubgroup() && + MNN::MNN_DATA_FORMAT_NC4HW4 == TensorUtils::getDescribe(output)->dimensionFormat) { return SubgroupOnResize(inputs, outputs); } #endif /* MNN_SUPPORT_INTEL_SUBGROUP */ @@ -260,32 +271,33 @@ ErrorCode BinaryBufExecution::onEncode(const std::vector &inputs, cons fullCount[0] = realSize(inputs[0]) == 1 ? 0 : 1; fullCount[1] = realSize(inputs[1]) == 1 ? 0 : 1; int totalSize = 0; - if(MNN::MNN_DATA_FORMAT_NC4HW4 == TensorUtils::getDescribe(output)->dimensionFormat){ + if (MNN::MNN_DATA_FORMAT_NC4HW4 == TensorUtils::getDescribe(output)->dimensionFormat) { totalSize = outputShape[0] * outputShape[1] * outputShape[2] * ROUND_UP(outputShape[3], 4); - }else{ + } else { totalSize = outputShape[0] * outputShape[1] * outputShape[2] * outputShape[3]; } - + int activationType = 0; - if(mOp->type() == OpType_BinaryOp) { + if (mOp->type() == OpType_BinaryOp) { activationType = mOp->main_as_BinaryOp()->activationType(); } - auto &unit = mUnits[0]; - + auto& unit = mUnits[0]; + std::set buildOptions = mBuildOptions; - if(totalSize % 4 != 0) { + if (totalSize % 4 != 0) { buildOptions.emplace("-DPACK_LEAVE"); } - if(fullCount[0] == 0) { + if (fullCount[0] == 0) { buildOptions.emplace("-DA_SINGLE"); } - if(fullCount[1] == 0) { + if (fullCount[1] == 0) { buildOptions.emplace("-DB_SINGLE"); } - unit.kernel = runTime->buildKernel("binary_buf", "binary_buf", buildOptions, openCLBackend->getPrecision(), inputs[0], output); - mMaxWorkGroupSize = static_cast(runTime->getMaxWorkGroupSize(unit.kernel)); + unit.kernel = runTime->buildKernel("binary_buf", "binary_buf", buildOptions, openCLBackend->getPrecision(), + inputs[0], output); + mMaxWorkGroupSize = static_cast(runTime->getMaxWorkGroupSize(unit.kernel)); - mGlobalWorkSize = {(uint32_t)UP_DIV(totalSize, 4), (uint32_t)1}; + mGlobalWorkSize = {(uint32_t)UP_DIV(totalSize, 4), (uint32_t)1}; uint32_t index = 0; cl_int ret = CL_SUCCESS; ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[0]); @@ -299,16 +311,17 @@ ErrorCode BinaryBufExecution::onEncode(const std::vector &inputs, cons std::string name = "binary_buf"; mLocalWorkSize = {(uint32_t)16, (uint32_t)1}; - + unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]}; - unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1]}; + unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1]}; openCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); for (int i = 2; i < inputs.size(); ++i) { fullCount[0] = 1; fullCount[1] = realSize(inputs[i]) == 1 ? 0 : 1; - auto &unit = mUnits[i-1]; - - unit.kernel = runTime->buildKernel("binary_buf", "binary_buf", buildOptions, openCLBackend->getPrecision(), inputs[i], output); + auto& unit = mUnits[i - 1]; + + unit.kernel = runTime->buildKernel("binary_buf", "binary_buf", buildOptions, openCLBackend->getPrecision(), + inputs[i], output); uint32_t index = 0; ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[0]); @@ -321,32 +334,36 @@ ErrorCode BinaryBufExecution::onEncode(const std::vector &inputs, cons MNN_CHECK_CL_SUCCESS(ret, "setArg BinaryBufExecution MultiInput"); unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]}; - unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1]}; + unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1]}; openCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); } - + return NO_ERROR; } class BinaryBufCreator : public OpenCLBackend::Creator { public: - virtual Execution *onCreate(const std::vector &inputs, const std::vector &outputs, - const MNN::Op *op, Backend *backend) const override { + virtual Execution* onCreate(const std::vector& inputs, const std::vector& outputs, + const MNN::Op* op, Backend* backend) const override { #ifdef MNN_SUPPORT_INTEL_SUBGROUP for (int i = 0; i < inputs.size(); ++i) { int channel = inputs[i]->channel(); - if (channel >= 16 && static_cast(backend)->getOpenCLRuntime()->isSupportedIntelSubgroup() - && MNN::MNN_DATA_FORMAT_NC4HW4 == TensorUtils::getDescribe(inputs[i])->dimensionFormat) { + if (channel >= 16 && static_cast(backend)->getOpenCLRuntime()->isSupportedIntelSubgroup() && + MNN::MNN_DATA_FORMAT_NC4HW4 == TensorUtils::getDescribe(inputs[i])->dimensionFormat) { TensorUtils::setTensorChannelPack(inputs[i], 16); } } #endif /* MNN_SUPPORT_INTEL_SUBGROUP */ if (op->type() == OpType_Eltwise) { switch (op->main_as_Eltwise()->type()) { - case EltwiseType_SUM: OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "in0+in1", op, backend)); - case EltwiseType_PROD: OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "in0*in1", op, backend)); - case EltwiseType_SUB: OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "in0-in1", op, backend)); - case EltwiseType_MAXIMUM: OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "in0>in1?in0:in1", op, backend)); + case EltwiseType_SUM: + OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "in0+in1", op, backend)); + case EltwiseType_PROD: + OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "in0*in1", op, backend)); + case EltwiseType_SUB: + OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "in0-in1", op, backend)); + case EltwiseType_MAXIMUM: + OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "in0>in1?in0:in1", op, backend)); default: break; } @@ -357,24 +374,69 @@ class BinaryBufCreator : public OpenCLBackend::Creator { MNN_ASSERT(inputs.size() > 1); switch (op->main_as_BinaryOp()->opType()) { - case BinaryOpOperation_MUL: OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "in0*in1", op, backend)); - case BinaryOpOperation_ADD: OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "in0+in1", op, backend)); - case BinaryOpOperation_SUB: OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "in0-in1", op, backend)); - case BinaryOpOperation_REALDIV: OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "sign(in1)*in0/(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001))", op, backend)); - case BinaryOpOperation_MINIMUM: OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "in0>in1?in1:in0", op, backend)); - case BinaryOpOperation_MAXIMUM: OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "in0>in1?in0:in1", op, backend)); - case BinaryOpOperation_GREATER: OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "convert_float4(-isgreater(in0,in1))", op, backend)); - case BinaryOpOperation_LESS: OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "convert_float4(-isless(in0,in1))", op, backend)); - case BinaryOpOperation_LESS_EQUAL: OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "convert_float4(-islessequal(in0,in1))", op, backend)); - case BinaryOpOperation_GREATER_EQUAL: OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "convert_float4(-isgreaterequal(in0,in1))", op, backend)); - case BinaryOpOperation_EQUAL: OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "convert_float4(-isequal(in0,in1))", op, backend)); - case BinaryOpOperation_FLOORDIV: OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "floor(sign(in1)*in0/(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001)))", op, backend)); - case BinaryOpOperation_FLOORMOD: OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "in0-floor(sign(in1)*in0/(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001)))*in1", op, backend)); - case BinaryOpOperation_POW: OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "pow(in0,in1)", op, backend)); - case BinaryOpOperation_SquaredDifference: OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "(in0-in1)*(in0-in1)", op, backend)); - case BinaryOpOperation_ATAN2: OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "(in1==(float)0?(sign(in0)*(float4)(PI/2)):(atan(in0/in1)+(in1>(float4)0?(float4)0:sign(in0)*(float)PI)))", op, backend)); - case BinaryOpOperation_NOTEQUAL: OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "convert_float4(-isnotequal(in0,in1))", op, backend)); - case BinaryOpOperation_MOD: OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "in0-floor(sign(in1)*in0/(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001)))*in1", op, backend)); + case BinaryOpOperation_MUL: + OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "in0*in1", op, backend)); + case BinaryOpOperation_MUL_SILU: + OPENCL_CREATOR_CHECK(new BinaryBufExecution( + inputs, "in0*(in1*native_recip((float4)1+native_exp(-in1)))", op, backend)); + case BinaryOpOperation_ADD: + OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "in0+in1", op, backend)); + case BinaryOpOperation_SUB: + OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "in0-in1", op, backend)); + case BinaryOpOperation_REALDIV: + OPENCL_CREATOR_CHECK(new BinaryBufExecution( + inputs, + "sign(in1)*in0/(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001))", op, + backend)); + case BinaryOpOperation_MINIMUM: + OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "in0>in1?in1:in0", op, backend)); + case BinaryOpOperation_MAXIMUM: + OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "in0>in1?in0:in1", op, backend)); + case BinaryOpOperation_GREATER: + OPENCL_CREATOR_CHECK( + new BinaryBufExecution(inputs, "convert_float4(-isgreater(in0,in1))", op, backend)); + case BinaryOpOperation_LESS: + OPENCL_CREATOR_CHECK( + new BinaryBufExecution(inputs, "convert_float4(-isless(in0,in1))", op, backend)); + case BinaryOpOperation_LESS_EQUAL: + OPENCL_CREATOR_CHECK( + new BinaryBufExecution(inputs, "convert_float4(-islessequal(in0,in1))", op, backend)); + case BinaryOpOperation_GREATER_EQUAL: + OPENCL_CREATOR_CHECK( + new BinaryBufExecution(inputs, "convert_float4(-isgreaterequal(in0,in1))", op, backend)); + case BinaryOpOperation_EQUAL: + OPENCL_CREATOR_CHECK( + new BinaryBufExecution(inputs, "convert_float4(-isequal(in0,in1))", op, backend)); + case BinaryOpOperation_FLOORDIV: + OPENCL_CREATOR_CHECK(new BinaryBufExecution( + inputs, + "floor(sign(in1)*in0/" + "(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001)))", + op, backend)); + case BinaryOpOperation_FLOORMOD: + OPENCL_CREATOR_CHECK(new BinaryBufExecution( + inputs, + "in0-floor(sign(in1)*in0/" + "(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001)))*in1", + op, backend)); + case BinaryOpOperation_POW: + OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "pow(in0,in1)", op, backend)); + case BinaryOpOperation_SquaredDifference: + OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, "(in0-in1)*(in0-in1)", op, backend)); + case BinaryOpOperation_ATAN2: + OPENCL_CREATOR_CHECK(new BinaryBufExecution(inputs, + "(in1==(float)0?(sign(in0)*(float4)(PI/2)):(atan(in0/" + "in1)+(in1>(float4)0?(float4)0:sign(in0)*(float)PI)))", + op, backend)); + case BinaryOpOperation_NOTEQUAL: + OPENCL_CREATOR_CHECK( + new BinaryBufExecution(inputs, "convert_float4(-isnotequal(in0,in1))", op, backend)); + case BinaryOpOperation_MOD: + OPENCL_CREATOR_CHECK(new BinaryBufExecution( + inputs, + "in0-floor(sign(in1)*in0/" + "(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001)))*in1", + op, backend)); default: break; } diff --git a/source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.cpp b/source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.cpp index 509776b963..f24a78fe8e 100644 --- a/source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.cpp +++ b/source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.cpp @@ -6,6 +6,7 @@ #ifdef MNN_LOW_MEMORY #ifndef MNN_OPENCL_BUFFER_CLOSED #include "ConvBufLowMemoryExecution.hpp" +#include "SharedGatherBufExecution.hpp" // #define LOG_VERBOSE namespace MNN { namespace OpenCL { @@ -13,33 +14,34 @@ namespace OpenCL { #define PACK_CIN 4 // set mDequantScale mDequantOffset mNumQuantBit mFilterDataPtr from mConv2dParams -void ConvBufLowMemoryExecution::getInfoFromOpLowMemory(void *weight_ptr) { +void ConvBufLowMemoryExecution::getInfoFromOpLowMemory(void* weight_ptr) { auto quanCommon = ConvolutionCommon::load(mOp, this->backend(), false, true, weight_ptr); - if(quanCommon == nullptr){ + if (quanCommon == nullptr) { mValid = false; auto staticMapAlloc = mOpenCLBackend->getStaticAllocatorMMap(); - if(mOpenCLBackend->getRuntime()->hint().useCachedMmap && staticMapAlloc != nullptr){ + if (mOpenCLBackend->getRuntime()->hint().useCachedMmap && staticMapAlloc != nullptr) { staticMapAlloc->setRemove(true); } return; } // set mResource->mNumQuantBit - if(quanCommon->canUseInt2){ + if (quanCommon->canUseInt2) { mResource->mNumQuantBit = 2; - } else if(quanCommon->canUseInt3){ + } else if (quanCommon->canUseInt3) { mResource->mNumQuantBit = 3; - } else if(quanCommon->canUseInt4){ + } else if (quanCommon->canUseInt4) { mResource->mNumQuantBit = 4; - }else{ + } else { mResource->mNumQuantBit = 8; } if (mOp->main_as_Convolution2D()->common()->inputCount() > 0) { mResource->mInputChannel = mOp->main_as_Convolution2D()->common()->inputCount(); } else { - mResource->mInputChannel = quanCommon->weight.size() / (mResource->mKernelWidth * mResource->mKernelHeight * mResource->mOutputChannel); + mResource->mInputChannel = quanCommon->weight.size() / + (mResource->mKernelWidth * mResource->mKernelHeight * mResource->mOutputChannel); } // src of alpha in CPU - float * dequantAlpha = quanCommon->alpha.get(); + float* dequantAlpha = quanCommon->alpha.get(); int totalCount = quanCommon->alphaSize; int soSize = 1; if (quanCommon->asymmetric) { @@ -53,108 +55,116 @@ void ConvBufLowMemoryExecution::getInfoFromOpLowMemory(void *weight_ptr) { int numAlphaPack = ROUND_UP(numAlpha, 4); int fpBytes = mOpenCLBackend->fpBytes(); int buffer_size = mResource->mBlockSize * numAlphaPack * fpBytes * soSize + sizeof(float); - + auto staticMapAlloc = mOpenCLBackend->getStaticAllocatorMMap(); - if(mOpenCLBackend->getRuntime()->hint().useCachedMmap && staticMapAlloc != nullptr){ + if (mOpenCLBackend->getRuntime()->hint().useCachedMmap && staticMapAlloc != nullptr) { mResource->mDequantScaleOffsetBuffer = staticMapAlloc.get()->allocBuffer(buffer_size); - }else{ - mResource->mDequantScaleOffsetBuffer.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size)); + } else { + mResource->mDequantScaleOffsetBuffer.reset(new cl::Buffer( + mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size)); } // transfer data from src in cpu to dst in gpu cl_int resBias, resScaleOffset; float coef = 1.0; - - void * dequantScaleOffsetBufferMap = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*mResource->mDequantScaleOffsetBuffer.get(), true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &resScaleOffset); - if(mOpenCLBackend->getRuntime()->hint().useCachedMmap > 1){ - if(fpBytes == 2){ - float* coefMapPtr = (float*)(((half_float::half*)dequantScaleOffsetBufferMap) + (numAlphaPack * mResource->mBlockSize * soSize)); + + void* dequantScaleOffsetBufferMap = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer( + *mResource->mDequantScaleOffsetBuffer.get(), true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, + &resScaleOffset); + if (mOpenCLBackend->getRuntime()->hint().useCachedMmap > 1) { + if (fpBytes == 2) { + float* coefMapPtr = (float*)(((half_float::half*)dequantScaleOffsetBufferMap) + + (numAlphaPack * mResource->mBlockSize * soSize)); coef = coefMapPtr[0]; - }else{ - coef = ((float *)dequantScaleOffsetBufferMap)[(numAlphaPack * mResource->mBlockSize * soSize)]; + } else { + coef = ((float*)dequantScaleOffsetBufferMap)[(numAlphaPack * mResource->mBlockSize * soSize)]; } - }else{ - if(fpBytes == 2) { + } else { + if (fpBytes == 2) { float max_data = 0.0f; - if (quanCommon->asymmetric){ + if (quanCommon->asymmetric) { for (int i = 0; i < numAlpha; ++i) { auto srcZ = dequantAlpha + i * mResource->mBlockSize * 2; - for(int j = 0; j < mResource->mBlockSize; ++j){ - float s = fabsf(srcZ[2*j+0]); - float b = fabsf(srcZ[2*j+1]); + for (int j = 0; j < mResource->mBlockSize; ++j) { + float s = fabsf(srcZ[2 * j + 0]); + float b = fabsf(srcZ[2 * j + 1]); float temp = ALIMAX(s, b); - if(temp > max_data) { + if (temp > max_data) { max_data = temp; } } } - }else{ + } else { for (int i = 0; i < numAlpha; ++i) { auto srcZ = dequantAlpha + i * mResource->mBlockSize; - for(int j = 0; j < mResource->mBlockSize; ++j){ + for (int j = 0; j < mResource->mBlockSize; ++j) { float s = fabsf(srcZ[j]); - if(s > max_data) { + if (s > max_data) { max_data = s; } } } } - if(abs(max_data) >= 0.000001f){ + if (abs(max_data) >= 0.000001f) { coef = 1000.0f / max_data; } if (dequantScaleOffsetBufferMap != nullptr && resScaleOffset == CL_SUCCESS) { if (quanCommon->asymmetric) { for (int i = 0; i < numAlpha; ++i) { auto srcZ = dequantAlpha + i * mResource->mBlockSize * 2; - for(int j = 0; j < mResource->mBlockSize; ++j){ - float o = srcZ[2*j+0]; - float s = srcZ[2*j+1]; + for (int j = 0; j < mResource->mBlockSize; ++j) { + float o = srcZ[2 * j + 0]; + float s = srcZ[2 * j + 1]; // For int4, absorb -8 bias into offset: offset_new = offset - 8 * scale if (mResource->mNumQuantBit == 4) { o = o - 8.0f * s; } - ((half_float::half*)dequantScaleOffsetBufferMap)[(j * numAlphaPack + i) * 2] = (half_float::half)(s * coef); - ((half_float::half*)dequantScaleOffsetBufferMap)[(j * numAlphaPack + i) * 2 + 1] = (half_float::half)(o * coef); + ((half_float::half*)dequantScaleOffsetBufferMap)[(j * numAlphaPack + i) * 2] = + (half_float::half)(s * coef); + ((half_float::half*)dequantScaleOffsetBufferMap)[(j * numAlphaPack + i) * 2 + 1] = + (half_float::half)(o * coef); } } } else { for (int i = 0; i < numAlpha; ++i) { auto srcZ = dequantAlpha + i * mResource->mBlockSize; - for(int j = 0; j < mResource->mBlockSize; ++j){ - ((half_float::half*)dequantScaleOffsetBufferMap)[(j * numAlphaPack + i)] = (half_float::half)(srcZ[j] * coef); + for (int j = 0; j < mResource->mBlockSize; ++j) { + ((half_float::half*)dequantScaleOffsetBufferMap)[(j * numAlphaPack + i)] = + (half_float::half)(srcZ[j] * coef); } } } - float* coefMapPtr = (float*)(((half_float::half*)dequantScaleOffsetBufferMap) + (numAlphaPack * mResource->mBlockSize * soSize)); + float* coefMapPtr = (float*)(((half_float::half*)dequantScaleOffsetBufferMap) + + (numAlphaPack * mResource->mBlockSize * soSize)); coefMapPtr[0] = coef; } else { MNN_ERROR("Map error dequantBufferMap == nullptr \n"); MNN_ASSERT(false); } - } else{ + } else { if (dequantScaleOffsetBufferMap != nullptr && resScaleOffset == CL_SUCCESS) { if (quanCommon->asymmetric) { for (int i = 0; i < numAlpha; ++i) { auto srcZ = dequantAlpha + i * mResource->mBlockSize * 2; - for(int j = 0; j < mResource->mBlockSize; ++j){ - float o = srcZ[2*j+0]; - float s = srcZ[2*j+1]; + for (int j = 0; j < mResource->mBlockSize; ++j) { + float o = srcZ[2 * j + 0]; + float s = srcZ[2 * j + 1]; // For int4, absorb -8 bias into offset: offset_new = offset - 8 * scale if (mResource->mNumQuantBit == 4) { o = o - 8.0f * s; } - ((float *)dequantScaleOffsetBufferMap)[(j * numAlphaPack + i) * 2] = s * coef; - ((float *)dequantScaleOffsetBufferMap)[(j * numAlphaPack + i) * 2 + 1] = o * coef; + ((float*)dequantScaleOffsetBufferMap)[(j * numAlphaPack + i) * 2] = s * coef; + ((float*)dequantScaleOffsetBufferMap)[(j * numAlphaPack + i) * 2 + 1] = o * coef; } } } else { for (int i = 0; i < numAlpha; ++i) { auto srcZ = dequantAlpha + i * mResource->mBlockSize; - for(int j = 0; j < mResource->mBlockSize; ++j){ - ((float *)dequantScaleOffsetBufferMap)[(j * numAlphaPack + i)] = srcZ[j] * coef; + for (int j = 0; j < mResource->mBlockSize; ++j) { + ((float*)dequantScaleOffsetBufferMap)[(j * numAlphaPack + i)] = srcZ[j] * coef; } } } - ((float *)dequantScaleOffsetBufferMap)[(numAlphaPack * mResource->mBlockSize * soSize)] = coef; + ((float*)dequantScaleOffsetBufferMap)[(numAlphaPack * mResource->mBlockSize * soSize)] = coef; } else { MNN_ERROR("Map error dequantBufferMap == nullptr \n"); MNN_ASSERT(false); @@ -162,9 +172,10 @@ void ConvBufLowMemoryExecution::getInfoFromOpLowMemory(void *weight_ptr) { } } mResource->mCoef = coef; - mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*mResource->mDequantScaleOffsetBuffer.get(), dequantScaleOffsetBufferMap); + mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject( + *mResource->mDequantScaleOffsetBuffer.get(), dequantScaleOffsetBufferMap); // set mFilterDataPtr - mFilterDataPtr = (void *)quanCommon->weight.get(); + mFilterDataPtr = (void*)quanCommon->weight.get(); } bool ConvBufLowMemoryExecution::convertToQuantWeight1x1Buffer(cl::Buffer input) { @@ -173,36 +184,39 @@ bool ConvBufLowMemoryExecution::convertToQuantWeight1x1Buffer(cl::Buffer input) #endif auto runtime = mOpenCLBackend->getOpenCLRuntime(); std::string kernelName = "conv2d_1x1_weight_quant_buffer"; - if(mResource->mUseImage){ + if (mResource->mUseImage) { kernelName = "conv2d_1x1_weight_quant_image"; } std::set buildOptions; if (mResource->mNumQuantBit == 8) { buildOptions.emplace("-DUSE_LOW_BIT_WEIGHT_INT8"); - } else if (mResource->mNumQuantBit == 4){ + } else if (mResource->mNumQuantBit == 4) { // int4 case buildOptions.emplace("-DUSE_LOW_BIT_WEIGHT_INT4"); - } else if (mResource->mNumQuantBit == 3){ + } else if (mResource->mNumQuantBit == 3) { buildOptions.emplace("-DUSE_LOW_BIT_WEIGHT_INT3"); - } else if (mResource->mNumQuantBit == 2){ + } else if (mResource->mNumQuantBit == 2) { buildOptions.emplace("-DUSE_LOW_BIT_WEIGHT_INT2"); - } else {/* More types to be supported. */} + } else { /* More types to be supported. */ + } - mBufferToConv1x1Kernel = runtime->buildKernelWithCache("buffer_convert_quant", kernelName, buildOptions, mOpenCLBackend->getPrecision()); + mBufferToConv1x1Kernel = + runtime->buildKernelWithCache("buffer_convert_quant", kernelName, buildOptions, mOpenCLBackend->getPrecision()); if (mBufferToConv1x1Kernel == nullptr) { return false; } auto kernel = mBufferToConv1x1Kernel->get(); - uint32_t gws[2] = {static_cast(UP_DIV(mResource->mInputChannel, PACK_CIN)), static_cast(UP_DIV(mResource->mOutputChannel, PACK_COUT))}; + uint32_t gws[2] = {static_cast(UP_DIV(mResource->mInputChannel, PACK_CIN)), + static_cast(UP_DIV(mResource->mOutputChannel, PACK_COUT))}; uint32_t idx = 0; cl_int ret = CL_SUCCESS; ret |= kernel.setArg(idx++, gws[0]); ret |= kernel.setArg(idx++, gws[1]); ret |= kernel.setArg(idx++, input); - if(mResource->mUseImage){ + if (mResource->mUseImage) { ret |= kernel.setArg(idx++, *mResource->mKernelImage.get()); - }else{ + } else { ret |= kernel.setArg(idx++, *mResource->mKernelBuffer.get()); } ret |= kernel.setArg(idx++, mResource->mInputChannel); @@ -221,8 +235,8 @@ bool ConvBufLowMemoryExecution::convertToQuantWeight1x1Buffer(cl::Buffer input) } res = runtime->commandQueue().enqueueNDRangeKernel(kernel, cl::NullRange, - cl::NDRange(roundUpGroupWorkSize[0], roundUpGroupWorkSize[1]), - cl::NDRange(lws[0], lws[1]), nullptr, &event); + cl::NDRange(roundUpGroupWorkSize[0], roundUpGroupWorkSize[1]), + cl::NDRange(lws[0], lws[1]), nullptr, &event); event.wait(); MNN_CHECK_CL_SUCCESS(res, "convertToQuantWeight1x1Buffer"); @@ -236,57 +250,61 @@ bool ConvBufLowMemoryExecution::convertToQuantWeight1x1Buffer(cl::Buffer input) // set mKernelBuffer for the 1x1 kernels void ConvBufLowMemoryExecution::set1x1WeightLowMemory() { bool preAllocGpuMem = mResource->mInputChannel != 0 && mResource->mConv2dParams->quanParameter(); - if(preAllocGpuMem){ + if (preAllocGpuMem) { mResource->mNumQuantBit = mResource->mConv2dParams->quanParameter()->aMaxOrBits(); - if(mResource->mNumQuantBit == 0){ + if (mResource->mNumQuantBit == 0) { // support old model for external weight file with int4/int8 quant mResource->mNumQuantBit = ConvolutionCommon::getQuantBitFromExternalFile(mOp); } - } else{ + } else { getInfoFromOpLowMemory(nullptr); - if(mValid == false){ + if (mValid == false) { return; } } cl_int res = CL_SUCCESS; - std::shared_ptr filterBuffer(Tensor::createDevice({ROUND_UP(mResource->mOutputChannel, PACK_COUT), ROUND_UP(mResource->mInputChannel, PACK_CIN), 1, 1})); + std::shared_ptr filterBuffer(Tensor::createDevice( + {ROUND_UP(mResource->mOutputChannel, PACK_COUT), ROUND_UP(mResource->mInputChannel, PACK_CIN), 1, 1})); const size_t orig_bytes = filterBuffer->usize() / sizeof(float); // OC_align * IC_align bytes (1B per weight) size_t staging_size = orig_bytes; size_t output_size = orig_bytes; size_t cpy_size = mResource->mOutputChannel * mResource->mInputChannel; int actual_packCin = PACK_CIN; // shared part for all cases - if (mResource->mNumQuantBit == 4){ + if (mResource->mNumQuantBit == 4) { // int4 case staging_size /= 2; output_size /= 2; cpy_size = UP_DIV(cpy_size, 2); - } else if(mResource->mNumQuantBit == 3){ + } else if (mResource->mNumQuantBit == 3) { // int3 case: 3/8 byte per weight in packed output, staging is 1B per weight output_size = (output_size * 3) / 8; actual_packCin = PACK_CIN * 2; // 8, forces image off for w3 (vload12 hard on image) - } else if(mResource->mNumQuantBit == 2){ + } else if (mResource->mNumQuantBit == 2) { // int2 case: 1/4 byte per weight in packed output, staging is 1B per weight output_size /= 4; actual_packCin = PACK_CIN * 2; // 8 - } else if(mResource->mNumQuantBit == 8){ + } else if (mResource->mNumQuantBit == 8) { actual_packCin /= 2; - } else {/* More types to be supported. */} - if(mOpenCLBackend->getRuntime()->hint().useCachedMmap <= 1){ - cl::Buffer filterBufferCL(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, staging_size); - void *mapPtr = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(filterBufferCL, true, CL_MAP_WRITE, 0, staging_size, nullptr, nullptr, &res); - if(mapPtr != nullptr && res == CL_SUCCESS){ - if(preAllocGpuMem){ + } else { /* More types to be supported. */ + } + if (mOpenCLBackend->getRuntime()->hint().useCachedMmap <= 1) { + cl::Buffer filterBufferCL(mOpenCLBackend->getOpenCLRuntime()->context(), + CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, staging_size); + void* mapPtr = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer( + filterBufferCL, true, CL_MAP_WRITE, 0, staging_size, nullptr, nullptr, &res); + if (mapPtr != nullptr && res == CL_SUCCESS) { + if (preAllocGpuMem) { getInfoFromOpLowMemory(mapPtr); - if(mValid == false){ + if (mValid == false) { return; } // For 2/3bit forceQuant, ConvolutionCommon::load keeps the blob in a separate // allocation (mFilterDataPtr) instead of writing into mapPtr. Copy it now. - if(mResource->mNumQuantBit == 2 || mResource->mNumQuantBit == 3){ + if (mResource->mNumQuantBit == 2 || mResource->mNumQuantBit == 3) { ::memcpy(mapPtr, mFilterDataPtr, cpy_size); } - } else{ + } else { ::memcpy(mapPtr, mFilterDataPtr, cpy_size); } } else { @@ -295,61 +313,69 @@ void ConvBufLowMemoryExecution::set1x1WeightLowMemory() { } mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(filterBufferCL, mapPtr); // Use Image load weights (only for 4bit/8bit; 2/3bit stick to buffer) - if(mResource->mNumQuantBit == 4 || mResource->mNumQuantBit == 8){ - if(UP_DIV(mResource->mInputChannel, actual_packCin) <= 16384 && ROUND_UP(mResource->mOutputChannel, PACK_COUT) <= 16384){ + if (mResource->mNumQuantBit == 4 || mResource->mNumQuantBit == 8) { + if (UP_DIV(mResource->mInputChannel, actual_packCin) <= 16384 && + ROUND_UP(mResource->mOutputChannel, PACK_COUT) <= 16384) { mResource->mUseImage = true; } } auto staticMapAlloc = mOpenCLBackend->getStaticAllocatorMMap(); - if(mResource->mUseImage){ + if (mResource->mUseImage) { size_t w = UP_DIV(mResource->mInputChannel, actual_packCin); size_t h = UP_DIV(mResource->mOutputChannel, PACK_COUT); - if(mOpenCLBackend->getRuntime()->hint().useCachedMmap && staticMapAlloc != nullptr){ + if (mOpenCLBackend->getRuntime()->hint().useCachedMmap && staticMapAlloc != nullptr) { mResource->mKernelImage = staticMapAlloc.get()->allocImage(w, h, CL_SIGNED_INT32); - }else{ - mResource->mKernelImage.reset(new cl::Image2D(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE, cl::ImageFormat(CL_RGBA, CL_SIGNED_INT32), w, h, 0, nullptr, &res)); + } else { + mResource->mKernelImage.reset( + new cl::Image2D(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE, + cl::ImageFormat(CL_RGBA, CL_SIGNED_INT32), w, h, 0, nullptr, &res)); } if (nullptr == mResource->mKernelImage.get() || res != CL_SUCCESS) { MNN_ERROR("Alloc Image %d x %d error, code:%d \n", (int)w, (int)h, (int)res); } - }else{ - if(mOpenCLBackend->getRuntime()->hint().useCachedMmap && staticMapAlloc != nullptr){ + } else { + if (mOpenCLBackend->getRuntime()->hint().useCachedMmap && staticMapAlloc != nullptr) { mResource->mKernelBuffer = staticMapAlloc.get()->allocBuffer(output_size); - }else{ - mResource->mKernelBuffer.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, output_size)); + } else { + mResource->mKernelBuffer.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), + CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, output_size)); } } convertToQuantWeight1x1Buffer(filterBufferCL); - }else { - if(preAllocGpuMem){ + } else { + if (preAllocGpuMem) { getInfoFromOpLowMemory(nullptr); - if(mValid == false){ + if (mValid == false) { return; } } // Use Image load weights (only for 4bit/8bit; 2/3bit stick to buffer) - if(mResource->mNumQuantBit == 4 || mResource->mNumQuantBit == 8){ - if(UP_DIV(mResource->mInputChannel, actual_packCin) <= 16384 && ROUND_UP(mResource->mOutputChannel, PACK_COUT) <= 16384){ + if (mResource->mNumQuantBit == 4 || mResource->mNumQuantBit == 8) { + if (UP_DIV(mResource->mInputChannel, actual_packCin) <= 16384 && + ROUND_UP(mResource->mOutputChannel, PACK_COUT) <= 16384) { mResource->mUseImage = true; } } auto staticMapAlloc = mOpenCLBackend->getStaticAllocatorMMap(); - if(mResource->mUseImage){ + if (mResource->mUseImage) { size_t w = UP_DIV(mResource->mInputChannel, actual_packCin); size_t h = UP_DIV(mResource->mOutputChannel, PACK_COUT); - if(mOpenCLBackend->getRuntime()->hint().useCachedMmap && staticMapAlloc != nullptr){ + if (mOpenCLBackend->getRuntime()->hint().useCachedMmap && staticMapAlloc != nullptr) { mResource->mKernelImage = staticMapAlloc.get()->allocImage(w, h, CL_SIGNED_INT32); - }else{ - mResource->mKernelImage.reset(new cl::Image2D(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE, cl::ImageFormat(CL_RGBA, CL_SIGNED_INT32), w, h, 0, nullptr, &res)); + } else { + mResource->mKernelImage.reset( + new cl::Image2D(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE, + cl::ImageFormat(CL_RGBA, CL_SIGNED_INT32), w, h, 0, nullptr, &res)); } if (nullptr == mResource->mKernelImage.get() || res != CL_SUCCESS) { MNN_ERROR("Alloc Image %d x %d error, code:%d \n", (int)w, (int)h, (int)res); } - }else{ - if(mOpenCLBackend->getRuntime()->hint().useCachedMmap && staticMapAlloc != nullptr){ + } else { + if (mOpenCLBackend->getRuntime()->hint().useCachedMmap && staticMapAlloc != nullptr) { mResource->mKernelBuffer = staticMapAlloc.get()->allocBuffer(output_size); - }else{ - mResource->mKernelBuffer.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, output_size)); + } else { + mResource->mKernelBuffer.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), + CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, output_size)); } } } @@ -357,39 +383,44 @@ void ConvBufLowMemoryExecution::set1x1WeightLowMemory() { // set mFilter for the general kernels void ConvBufLowMemoryExecution::setGeneralWeightLowMemory() { bool preAllocGpuMem = mResource->mInputChannel != 0 && mResource->mConv2dParams->quanParameter(); - if(preAllocGpuMem){ + if (preAllocGpuMem) { mResource->mNumQuantBit = mResource->mConv2dParams->quanParameter()->aMaxOrBits(); - if(mResource->mNumQuantBit == 0){ + if (mResource->mNumQuantBit == 0) { // support old model for external weight file with int4/int8 quant mResource->mNumQuantBit = ConvolutionCommon::getQuantBitFromExternalFile(mOp); } - } else{ + } else { getInfoFromOpLowMemory(nullptr); - if(mValid == false){ + if (mValid == false) { return; } } - - if(mOpenCLBackend->getRuntime()->hint().useCachedMmap <= 1){ - std::shared_ptr filterBuffer(Tensor::createDevice({ROUND_UP(mResource->mOutputChannel, 4), mResource->mInputChannel, mResource->mKernelWidth, mResource->mKernelHeight})); + + if (mOpenCLBackend->getRuntime()->hint().useCachedMmap <= 1) { + std::shared_ptr filterBuffer( + Tensor::createDevice({ROUND_UP(mResource->mOutputChannel, 4), mResource->mInputChannel, + mResource->mKernelWidth, mResource->mKernelHeight})); size_t buffer_size = filterBuffer->usize() / sizeof(float); - size_t cpy_size = mResource->mOutputChannel * mResource->mInputChannel * mResource->mKernelWidth * mResource->mKernelHeight; - if (mResource->mNumQuantBit == 4){ + size_t cpy_size = + mResource->mOutputChannel * mResource->mInputChannel * mResource->mKernelWidth * mResource->mKernelHeight; + if (mResource->mNumQuantBit == 4) { buffer_size /= 2; cpy_size = UP_DIV(cpy_size, 2); } - cl::Buffer filterBufferCL(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size); + cl::Buffer filterBufferCL(mOpenCLBackend->getOpenCLRuntime()->context(), + CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size); filterBuffer->buffer().device = (uint64_t)(&filterBufferCL); // map and pack data from filterDataPtr cl_int res; - auto ptrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(filterBufferCL, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res); - if(ptrCL != nullptr && res == CL_SUCCESS) { - if(preAllocGpuMem){ + auto ptrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer( + filterBufferCL, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res); + if (ptrCL != nullptr && res == CL_SUCCESS) { + if (preAllocGpuMem) { getInfoFromOpLowMemory(ptrCL); - if(mValid == false){ + if (mValid == false) { return; } - } else{ + } else { ::memcpy(ptrCL, mFilterDataPtr, cpy_size); } } else { @@ -398,17 +429,22 @@ void ConvBufLowMemoryExecution::setGeneralWeightLowMemory() { mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(filterBufferCL, ptrCL); if (mResource->mNumQuantBit == 8) { // ROUND_UP(IC, 4), UP_DIV(OC, 4) * mKernelWidth * mKernelHeight - mResource->mFilter.reset(Tensor::createDevice({1, UP_DIV(mResource->mOutputChannel, 4) * mResource->mKernelWidth * mResource->mKernelHeight, 1, 4 * ROUND_UP(mResource->mInputChannel, 4)})); + mResource->mFilter.reset(Tensor::createDevice( + {1, UP_DIV(mResource->mOutputChannel, 4) * mResource->mKernelWidth * mResource->mKernelHeight, 1, + 4 * ROUND_UP(mResource->mInputChannel, 4)})); if (!(mOpenCLBackend->onAcquireBuffer(mResource->mFilter.get(), Backend::STATIC))) { mValid = false; return; } - } else if (mResource->mNumQuantBit == 4){ + } else if (mResource->mNumQuantBit == 4) { // ROUND_UP(IC, 4), UP_DIV(OC, 4) * mKernelWidth * mKernelHeight // For int4 case, data stored in mFilter should be uint8_t, // while "Tensor::createDevice" occupies more memory than "Tensor::createDevice". - // Therefore, we use "Tensor::createDevice" currently, leaving "Tensor::createDevice" to be supported. - mResource->mFilter.reset(Tensor::createDevice({1, UP_DIV(mResource->mOutputChannel, 4) * mResource->mKernelWidth * mResource->mKernelHeight, 1, 2 * ROUND_UP(mResource->mInputChannel, 4)})); + // Therefore, we use "Tensor::createDevice" currently, leaving "Tensor::createDevice" to be + // supported. + mResource->mFilter.reset(Tensor::createDevice( + {1, UP_DIV(mResource->mOutputChannel, 4) * mResource->mKernelWidth * mResource->mKernelHeight, 1, + 2 * ROUND_UP(mResource->mInputChannel, 4)})); if (!(mOpenCLBackend->onAcquireBuffer(mResource->mFilter.get(), Backend::STATIC))) { mValid = false; return; @@ -416,60 +452,70 @@ void ConvBufLowMemoryExecution::setGeneralWeightLowMemory() { } // convert to NC4HW4 MNN::OpenCL::BufferConvertor bufferConvertor{mOpenCLBackend->getOpenCLRuntime()}; - bufferConvertor.convertToNC4HW4Buffer(filterBuffer.get(), MNN::OpenCL::CONV2D_FILTER, mResource->mFilter.get(), mOpenCLBackend->getPrecision(), false, true, true, mResource->mNumQuantBit); - }else{ - if(preAllocGpuMem){ + bufferConvertor.convertToNC4HW4Buffer(filterBuffer.get(), MNN::OpenCL::CONV2D_FILTER, mResource->mFilter.get(), + mOpenCLBackend->getPrecision(), false, true, true, + mResource->mNumQuantBit); + } else { + if (preAllocGpuMem) { getInfoFromOpLowMemory(nullptr); - if(mValid == false){ + if (mValid == false) { return; } } if (mResource->mNumQuantBit == 8) { // ROUND_UP(IC, 4), UP_DIV(OC, 4) * mKernelWidth * mKernelHeight - mResource->mFilter.reset(Tensor::createDevice({1, UP_DIV(mResource->mOutputChannel, 4) * mResource->mKernelWidth * mResource->mKernelHeight, 1, 4 * ROUND_UP(mResource->mInputChannel, 4)})); + mResource->mFilter.reset(Tensor::createDevice( + {1, UP_DIV(mResource->mOutputChannel, 4) * mResource->mKernelWidth * mResource->mKernelHeight, 1, + 4 * ROUND_UP(mResource->mInputChannel, 4)})); if (!(mOpenCLBackend->onAcquireBuffer(mResource->mFilter.get(), Backend::STATIC))) { mValid = false; return; } - } else if (mResource->mNumQuantBit == 4){ + } else if (mResource->mNumQuantBit == 4) { // ROUND_UP(IC, 4), UP_DIV(OC, 4) * mKernelWidth * mKernelHeight // For int4 case, data stored in mFilter should be uint8_t, // while "Tensor::createDevice" occupies more memory than "Tensor::createDevice". - // Therefore, we use "Tensor::createDevice" currently, leaving "Tensor::createDevice" to be supported. - mResource->mFilter.reset(Tensor::createDevice({1, UP_DIV(mResource->mOutputChannel, 4) * mResource->mKernelWidth * mResource->mKernelHeight, 1, 2 * ROUND_UP(mResource->mInputChannel, 4)})); + // Therefore, we use "Tensor::createDevice" currently, leaving "Tensor::createDevice" to be + // supported. + mResource->mFilter.reset(Tensor::createDevice( + {1, UP_DIV(mResource->mOutputChannel, 4) * mResource->mKernelWidth * mResource->mKernelHeight, 1, + 2 * ROUND_UP(mResource->mInputChannel, 4)})); if (!(mOpenCLBackend->onAcquireBuffer(mResource->mFilter.get(), Backend::STATIC))) { mValid = false; return; } } } - } // select the fastest kernel for the general cases by tuning -void ConvBufLowMemoryExecution::tuneGeneralCaseLowMemory(Tensor * input, Tensor * output) { +void ConvBufLowMemoryExecution::tuneGeneralCaseLowMemory(Tensor* input, Tensor* output) { mUnits.resize(1); - auto &unit = mUnits[0]; - std::vector inputShape = tensorShapeFormat(input); + auto& unit = mUnits[0]; + std::vector inputShape = tensorShapeFormat(input); std::vector outputShape = tensorShapeFormat(output); - const int batch = outputShape.at(0); - const int height = outputShape.at(1); - const int width = outputShape.at(2); - const int outChannel = outputShape.at(3); - const int inputHeight = inputShape.at(1); - const int inputWidth = inputShape.at(2); + const int batch = outputShape.at(0); + const int height = outputShape.at(1); + const int width = outputShape.at(2); + const int outChannel = outputShape.at(3); + const int inputHeight = inputShape.at(1); + const int inputWidth = inputShape.at(2); const int inputChannels = inputShape.at(3); const int inputChannelBlocks = UP_DIV(inputChannels, 4); const int blockDim = mResource->mInputChannel / mResource->mBlockSize; - std::string info = std::to_string(inputChannels) + "_" + std::to_string(outChannel) + "_" + std::to_string(mResource->mKernelHeight) + "_" + std::to_string(mResource->mKernelWidth) + "_" + std::to_string(mResource->mStrides[0]) + "_" + std::to_string(mResource->mStrides[1]) + "_" + std::to_string(mResource->mDilations[0]) + "_" + std::to_string(mResource->mDilations[1]); - int inputImageShape[2] = {inputHeight, inputWidth}; + std::string info = std::to_string(inputChannels) + "_" + std::to_string(outChannel) + "_" + + std::to_string(mResource->mKernelHeight) + "_" + std::to_string(mResource->mKernelWidth) + "_" + + std::to_string(mResource->mStrides[0]) + "_" + std::to_string(mResource->mStrides[1]) + "_" + + std::to_string(mResource->mDilations[0]) + "_" + std::to_string(mResource->mDilations[1]); + int inputImageShape[2] = {inputHeight, inputWidth}; int outputImageShape[2] = {height, width}; - int kernelShape[2] = {mResource->mKernelHeight, mResource->mKernelWidth}; - int strideShape[2] = {mResource->mStrides[0], mResource->mStrides[1]}; - int paddingShape[2] = {mPaddings[0], mPaddings[1]}; - int dilationShape[2] = {mResource->mDilations[0], mResource->mDilations[1]}; + int kernelShape[2] = {mResource->mKernelHeight, mResource->mKernelWidth}; + int strideShape[2] = {mResource->mStrides[0], mResource->mStrides[1]}; + int paddingShape[2] = {mPaddings[0], mPaddings[1]}; + int dilationShape[2] = {mResource->mDilations[0], mResource->mDilations[1]}; // {"conv_2d_c4h1w2", "conv_2d_c4h1w1", "conv_2d_c8h1w1", "conv_2d_c4h1w4", "conv_2d_c8h2w1", "conv_2d_c4h4w1"}; const int total_kernel = 4; - std::string kernelName[total_kernel] = {"conv_2d_int_c4h1w1", "conv_2d_int_c4h1w2", "conv_2d_int_c4h1w4", "conv_2d_int_c8h1w4"}; + std::string kernelName[total_kernel] = {"conv_2d_int_c4h1w1", "conv_2d_int_c4h1w2", "conv_2d_int_c4h1w4", + "conv_2d_int_c8h1w4"}; int itemC[total_kernel] = {4, 4, 4, 8}; int itemH[total_kernel] = {1, 1, 1, 1}; int itemW[total_kernel] = {1, 2, 4, 4}; @@ -477,24 +523,29 @@ void ConvBufLowMemoryExecution::tuneGeneralCaseLowMemory(Tensor * input, Tensor std::shared_ptr kernel[total_kernel]; std::vector globalWorkSize[total_kernel]; std::vector localWorkSize[total_kernel]; - std::pair min_cost(INT_MAX, 0);//(min_time, min_index) + std::pair min_cost(INT_MAX, 0); //(min_time, min_index) // MNN_PRINT("Checking kernel %d.\n", knlCheck); for (int knl_idx = 0; knl_idx < actual_kernel; knl_idx++) { std::set buildOption = mResource->mBuildOptions; - if(itemC[knl_idx] == 8 && outputShape.at(3) % itemC[knl_idx] > 0 && outputShape.at(3) % itemC[knl_idx] <= 4){ + if (itemC[knl_idx] == 8 && outputShape.at(3) % itemC[knl_idx] > 0 && outputShape.at(3) % itemC[knl_idx] <= 4) { buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT"); } - if((outputShape.at(2) % itemW[knl_idx]) != 0 || (outputShape.at(1) % itemH[knl_idx]) != 0){ + if ((outputShape.at(2) % itemW[knl_idx]) != 0 || (outputShape.at(1) % itemH[knl_idx]) != 0) { buildOption.emplace("-DBLOCK_LEAVE"); } - if(inputChannels % 4 != 0){ + if (inputChannels % 4 != 0) { buildOption.emplace("-DINPUT_CHANNEL_BOUNDARY_PROTECT"); } - kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_int_buf", kernelName[knl_idx], buildOption, mOpenCLBackend->getPrecision()); - uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx])); + kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_int_buf", kernelName[knl_idx], + buildOption, mOpenCLBackend->getPrecision()); + uint32_t maxWorkGroupSize = + static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx])); - globalWorkSize[knl_idx] = {static_cast(UP_DIV(outputShape.at(3), itemC[knl_idx]) * UP_DIV(outputShape.at(2), itemW[knl_idx])), static_cast(outputShape.at(0) * UP_DIV(outputShape.at(1), itemH[knl_idx]))}; - uint32_t idx = 0; + globalWorkSize[knl_idx] = { + static_cast(UP_DIV(outputShape.at(3), itemC[knl_idx]) * + UP_DIV(outputShape.at(2), itemW[knl_idx])), + static_cast(outputShape.at(0) * UP_DIV(outputShape.at(1), itemH[knl_idx]))}; + uint32_t idx = 0; cl_int ret = CL_SUCCESS; ret |= kernel[knl_idx]->get().setArg(idx++, globalWorkSize[knl_idx][0]); ret |= kernel[knl_idx]->get().setArg(idx++, globalWorkSize[knl_idx][1]); @@ -519,29 +570,33 @@ void ConvBufLowMemoryExecution::tuneGeneralCaseLowMemory(Tensor * input, Tensor ret |= kernel[knl_idx]->get().setArg(idx++, static_cast(mResource->mCoef)); MNN_CHECK_CL_SUCCESS(ret, "setArg ConvBufLowMemory Kernel Select"); std::pair, int> retTune; - retTune = localWS2DDefault(globalWorkSize[knl_idx], maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), kernelName[knl_idx] + info, kernel[knl_idx], mOpenCLBackend->getCLTuneLevel(), "conv_2d_int_buf"); - if(min_cost.first > retTune.second) { + retTune = localWS2DDefault(globalWorkSize[knl_idx], maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), + kernelName[knl_idx] + info, kernel[knl_idx], mOpenCLBackend->getCLTuneLevel(), + "conv_2d_int_buf"); + if (min_cost.first > retTune.second) { min_cost.first = retTune.second; min_cost.second = knl_idx; mLocalWorkSize = {retTune.first[0], retTune.first[1]}; } } - int min_index = min_cost.second; + int min_index = min_cost.second; mGlobalWorkSize = {globalWorkSize[min_index][0], globalWorkSize[min_index][1]}; std::set buildOption = mResource->mBuildOptions; - if(itemC[min_index] == 8 && outputShape.at(3) % itemC[min_index] > 0 && outputShape.at(3) % itemC[min_index] <= 4){ + if (itemC[min_index] == 8 && outputShape.at(3) % itemC[min_index] > 0 && + outputShape.at(3) % itemC[min_index] <= 4) { buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT"); } - if((outputShape.at(2) % itemW[min_index]) != 0 || (outputShape.at(1) % itemH[min_index]) != 0){ + if ((outputShape.at(2) % itemW[min_index]) != 0 || (outputShape.at(1) % itemH[min_index]) != 0) { buildOption.emplace("-DBLOCK_LEAVE"); } - if(inputChannels % 4 != 0){ + if (inputChannels % 4 != 0) { buildOption.emplace("-DINPUT_CHANNEL_BOUNDARY_PROTECT"); } - unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_int_buf", kernelName[min_index], buildOption, mOpenCLBackend->getPrecision()); + unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_int_buf", kernelName[min_index], buildOption, + mOpenCLBackend->getPrecision()); - uint32_t idx = 0; + uint32_t idx = 0; cl_int ret = CL_SUCCESS; ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]); ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[1]); @@ -572,13 +627,13 @@ void ConvBufLowMemoryExecution::tuneGeneralCaseLowMemory(Tensor * input, Tensor } // weight inverse quantization, use xgemm opt -void ConvBufLowMemoryExecution::useFPWeightGemmLowMemory(Tensor * input, Tensor * output) { +void ConvBufLowMemoryExecution::useFPWeightGemmLowMemory(Tensor* input, Tensor* output) { mUnits.resize(3); auto runtime = mOpenCLBackend->getOpenCLRuntime(); - std::vector inputShape = tensorShapeFormat(input); + std::vector inputShape = tensorShapeFormat(input); std::vector outputShape = tensorShapeFormat(output); int channelPack = 2; - if(mResource->mNumQuantBit == 4 || mResource->mNumQuantBit == 3 || mResource->mNumQuantBit == 2){ + if (mResource->mNumQuantBit == 4 || mResource->mNumQuantBit == 3 || mResource->mNumQuantBit == 2) { channelPack = 4; } int area = inputShape.at(1) * inputShape.at(2); @@ -588,21 +643,21 @@ void ConvBufLowMemoryExecution::useFPWeightGemmLowMemory(Tensor * input, Tensor int mAlignK = 4; int mAlignN = 16; int mAlignM = 64; - + // set M Align and N Align - if(mResource->mOutputChannel > 1024) { + if (mResource->mOutputChannel > 1024) { mAlignN = 128; - } else if(mResource->mOutputChannel > 512) { + } else if (mResource->mOutputChannel > 512) { mAlignN = 64; - } else if(mResource->mOutputChannel > 96) { + } else if (mResource->mOutputChannel > 96) { mAlignN = 32; } float ratio = 1.0 * M / 1024.0 * N / 1024.0 * K / 1024.0; - if(M > 1024 && ratio >= 1.0) { + if (M > 1024 && ratio >= 1.0) { mAlignM = 128; - } else if(M > 512 && ratio >= 0.1) { + } else if (M > 512 && ratio >= 0.1) { mAlignM = 64; - } else if(M > 96){ + } else if (M > 96) { mAlignM = 32; } else { mAlignM = 16; @@ -611,35 +666,39 @@ void ConvBufLowMemoryExecution::useFPWeightGemmLowMemory(Tensor * input, Tensor int alignN = ROUND_UP(N, mAlignN); int alignK = ROUND_UP(K, mAlignK); int blockDim = mResource->mInputChannel / mResource->mBlockSize; - + // alloc temp bufer - mConvGemmWeightTensor.reset(Tensor::createDevice({ROUND_UP(mResource->mOutputChannel, mAlignN) * ROUND_UP(mResource->mInputChannel, std::max(mAlignK, channelPack))})); + mConvGemmWeightTensor.reset( + Tensor::createDevice({ROUND_UP(mResource->mOutputChannel, mAlignN) * + ROUND_UP(mResource->mInputChannel, std::max(mAlignK, channelPack))})); mConvGemmInpTensor.reset(Tensor::createDevice({alignK * alignM})); mConvGemmOutTensor.reset(Tensor::createDevice({alignN * alignM})); mOpenCLBackend->onAcquireBuffer(mConvGemmWeightTensor.get(), Backend::DYNAMIC); mOpenCLBackend->onAcquireBuffer(mConvGemmOutTensor.get(), Backend::DYNAMIC); mOpenCLBackend->onAcquireBuffer(mConvGemmInpTensor.get(), Backend::DYNAMIC); - - //weight inverse quantization and rearrange + + // weight inverse quantization and rearrange { - auto &unit = mUnits[0]; + auto& unit = mUnits[0]; int outputChannelAlign = ROUND_UP(mResource->mOutputChannel, alignN); int outputChannel4Align = ROUND_UP(mResource->mOutputChannel, 4); int inputChannel4Align = ROUND_UP(mResource->mInputChannel, 4); std::set buildOption = mResource->mBuildOptions; - if(mResource->mUseImage){ + if (mResource->mUseImage) { buildOption.emplace("-DUSE_IMAGE"); } - mGlobalWorkSize = {static_cast(UP_DIV(mResource->mInputChannel, channelPack)), static_cast(UP_DIV(mResource->mOutputChannel, 8))}; - unit.kernel = runtime->buildKernel("gemm_conv1x1_buf", "inverse_quant_weight", buildOption, mOpenCLBackend->getPrecision()); + mGlobalWorkSize = {static_cast(UP_DIV(mResource->mInputChannel, channelPack)), + static_cast(UP_DIV(mResource->mOutputChannel, 8))}; + unit.kernel = runtime->buildKernel("gemm_conv1x1_buf", "inverse_quant_weight", buildOption, + mOpenCLBackend->getPrecision()); uint32_t maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(unit.kernel)); uint32_t idx = 0; cl_int ret = CL_SUCCESS; ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]); ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[1]); - if(mResource->mUseImage){ + if (mResource->mUseImage) { ret |= unit.kernel->get().setArg(idx++, *mResource->mKernelImage.get()); - }else{ + } else { ret |= unit.kernel->get().setArg(idx++, *mResource->mKernelBuffer.get()); } ret |= unit.kernel->get().setArg(idx++, *mResource->mDequantScaleOffsetBuffer.get()); @@ -651,25 +710,29 @@ void ConvBufLowMemoryExecution::useFPWeightGemmLowMemory(Tensor * input, Tensor ret |= unit.kernel->get().setArg(idx++, static_cast(blockDim)); ret |= unit.kernel->get().setArg(idx++, static_cast(mResource->mCoef)); MNN_CHECK_CL_SUCCESS(ret, "setArg inverse_quant_weight"); - - mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, runtime, "inverse_quant_weight", unit.kernel, mOpenCLBackend->getCLTuneLevel(), "gemm_conv1x1_buf").first; + + mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, runtime, "inverse_quant_weight", + unit.kernel, mOpenCLBackend->getCLTuneLevel(), "gemm_conv1x1_buf") + .first; mOpenCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]}; unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1]}; } - + // rearange input { - auto &unit = mUnits[1]; + auto& unit = mUnits[1]; std::set buildOptions = mResource->mBuildOptions; - + int m_pack = 4; - mGlobalWorkSize = {static_cast(alignM/m_pack), static_cast(alignK/4)}; - unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_buf", "transpose_pad", buildOptions, mOpenCLBackend->getPrecision()); - uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(unit.kernel)); + mGlobalWorkSize = {static_cast(alignM / m_pack), static_cast(alignK / 4)}; + unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_buf", "transpose_pad", buildOptions, + mOpenCLBackend->getPrecision()); + uint32_t maxWorkGroupSize = + static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(unit.kernel)); int offset = 0; - int idx = 0; + int idx = 0; cl_int ret = CL_SUCCESS; ret |= unit.kernel->get().setArg(idx++, static_cast(mGlobalWorkSize[0])); ret |= unit.kernel->get().setArg(idx++, static_cast(mGlobalWorkSize[1])); @@ -681,27 +744,32 @@ void ConvBufLowMemoryExecution::useFPWeightGemmLowMemory(Tensor * input, Tensor ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input)); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mConvGemmInpTensor.get())); MNN_CHECK_CL_SUCCESS(ret, "setArg transpose_pad"); - mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, runtime, "transpose_pad", unit.kernel, mOpenCLBackend->getCLTuneLevel(), "gemm_buf").first; + mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, runtime, "transpose_pad", unit.kernel, + mOpenCLBackend->getCLTuneLevel(), "gemm_buf") + .first; mOpenCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]}; unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1]}; } - + // call gemm strassen { mStrassenComputor.reset(new StrassenMatrixComputor(backend(), 3)); - mStrassenComputor->onEncode(alignM, alignK, alignN, alignM, alignN, alignN, openCLBuffer(mConvGemmInpTensor.get()), openCLBuffer(mConvGemmWeightTensor.get()), openCLBuffer(mConvGemmOutTensor.get()), false, openCLBuffer(mResource->mBias.get())); + mStrassenComputor->onEncode(alignM, alignK, alignN, alignM, alignN, alignN, + openCLBuffer(mConvGemmInpTensor.get()), openCLBuffer(mConvGemmWeightTensor.get()), + openCLBuffer(mConvGemmOutTensor.get()), false, + openCLBuffer(mResource->mBias.get())); } - + // call output transpose { - auto &unit = mUnits[2]; + auto& unit = mUnits[2]; std::set buildOptions = mResource->mBuildOptions; int pack_m = 1; - if(M % 8 == 0) { + if (M % 8 == 0) { pack_m = 8; - } else if(M % 4 == 0) { + } else if (M % 4 == 0) { pack_m = 4; } buildOptions.emplace("-DM_VEC=" + std::to_string(pack_m)); @@ -718,7 +786,7 @@ void ConvBufLowMemoryExecution::useFPWeightGemmLowMemory(Tensor * input, Tensor mGlobalWorkSize = {static_cast(UP_DIV(M, pack_m)), static_cast(UP_DIV(N, 4))}; int offset = 0; - int idx = 0; + int idx = 0; cl_int ret = CL_SUCCESS; ret |= unit.kernel->get().setArg(idx++, static_cast(mGlobalWorkSize[0])); ret |= unit.kernel->get().setArg(idx++, static_cast(mGlobalWorkSize[1])); @@ -732,7 +800,9 @@ void ConvBufLowMemoryExecution::useFPWeightGemmLowMemory(Tensor * input, Tensor ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); MNN_CHECK_CL_SUCCESS(ret, "setArg transpose_bias"); - mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, runtime, "transpose_bias", unit.kernel, mOpenCLBackend->getCLTuneLevel(), "gemm_buf").first; + mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, runtime, "transpose_bias", unit.kernel, + mOpenCLBackend->getCLTuneLevel(), "gemm_buf") + .first; mOpenCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]}; unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1]}; @@ -740,13 +810,13 @@ void ConvBufLowMemoryExecution::useFPWeightGemmLowMemory(Tensor * input, Tensor mOpenCLBackend->onReleaseBuffer(mConvGemmWeightTensor.get(), Backend::DYNAMIC); mOpenCLBackend->onReleaseBuffer(mConvGemmInpTensor.get(), Backend::DYNAMIC); mOpenCLBackend->onReleaseBuffer(mConvGemmOutTensor.get(), Backend::DYNAMIC); - + return; } -void ConvBufLowMemoryExecution::tuneGemvLowMemory(Tensor * input, Tensor * output) { +void ConvBufLowMemoryExecution::tuneGemvLowMemory(Tensor* input, Tensor* output) { mUnits.resize(1); - auto &unit = mUnits[0]; - std::vector inputShape = tensorShapeFormat(input); + auto& unit = mUnits[0]; + std::vector inputShape = tensorShapeFormat(input); std::vector outputShape = tensorShapeFormat(output); const int outChannel = outputShape.at(3); const int inputChannels = inputShape.at(3); @@ -761,16 +831,16 @@ void ConvBufLowMemoryExecution::tuneGemvLowMemory(Tensor * input, Tensor * outpu std::string info = std::to_string(inputChannels) + "_" + std::to_string(outChannel); std::set buildOption = mResource->mBuildOptions; int inputChannelLeaves = 0; - if(mResource->mNumQuantBit == 4 || mResource->mNumQuantBit == 3 || mResource->mNumQuantBit == 2){ + if (mResource->mNumQuantBit == 4 || mResource->mNumQuantBit == 3 || mResource->mNumQuantBit == 2) { inputChannelLeaves = useLocalMem ? (inputChannels % 4) : (blockDim % 4); } else { inputChannelLeaves = useLocalMem ? (inputChannels % 2) : (blockDim % 2); } - if(outChannel % 8 != 0){ + if (outChannel % 8 != 0) { buildOption.emplace("-DOUTPUT_CHANNEL_LEAVES"); } buildOption.emplace("-DINPUT_CHANNEL_LEAVES_NUM=" + std::to_string(inputChannelLeaves)); - if(mResource->mUseImage){ + if (mResource->mUseImage) { buildOption.emplace("-DUSE_IMAGE"); } @@ -806,13 +876,15 @@ void ConvBufLowMemoryExecution::tuneGemvLowMemory(Tensor * input, Tensor * outpu } int local_size = useLocalMem ? 128 : 1; - if(useLocalMem && mOpenCLBackend->getCLTuneLevel() != None && mOpenCLBackend->getCLTuneLevel() != Fast){ + if (useLocalMem && mOpenCLBackend->getCLTuneLevel() != None && mOpenCLBackend->getCLTuneLevel() != Fast) { int min_time = INT_MAX; - for (int ksize = 8; ksize <= 256; ksize*=2) { + for (int ksize = 8; ksize <= 256; ksize *= 2) { auto option = buildOption; option.emplace("-DWGS=" + std::to_string(ksize)); - auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemv_conv1x1_buf", "gemv_conv_c8_buf", option, mOpenCLBackend->getPrecision()); - uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel)); + auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemv_conv1x1_buf", "gemv_conv_c8_buf", + option, mOpenCLBackend->getPrecision()); + uint32_t maxWorkGroupSize = + static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel)); std::vector gws = {static_cast(ksize), static_cast(UP_DIV(outChannel, 8))}; std::vector lws = {static_cast(ksize), 1}; uint32_t idx = 0; @@ -825,9 +897,9 @@ void ConvBufLowMemoryExecution::tuneGemvLowMemory(Tensor * input, Tensor * outpu } else { ret |= kernel->get().setArg(idx++, openCLBuffer(input)); } - if(mResource->mUseImage){ + if (mResource->mUseImage) { ret |= kernel->get().setArg(idx++, *mResource->mKernelImage.get()); - }else{ + } else { ret |= kernel->get().setArg(idx++, *mResource->mKernelBuffer.get()); } ret |= kernel->get().setArg(idx++, *mResource->mDequantScaleOffsetBuffer.get()); @@ -843,18 +915,21 @@ void ConvBufLowMemoryExecution::tuneGemvLowMemory(Tensor * input, Tensor * outpu ret |= kernel->get().setArg(idx++, static_cast(mResource->mCoef)); MNN_CHECK_CL_SUCCESS(ret, "setArg gemv_conv_c8_buf Kernel Select"); std::pair, int> retTune; - int cost_time = get2DUseLocalMemTime(gws, lws, mOpenCLBackend->getOpenCLRuntime(), "gemv_conv_c8_buf" + info, kernel, "gemv_conv1x1_buf"); - if(min_time > cost_time) { + int cost_time = get2DUseLocalMemTime(gws, lws, mOpenCLBackend->getOpenCLRuntime(), + "gemv_conv_c8_buf" + info, kernel, "gemv_conv1x1_buf"); + if (min_time > cost_time) { local_size = ksize; min_time = cost_time; } } } - + buildOption.emplace("-DWGS=" + std::to_string(local_size)); mGlobalWorkSize = {static_cast(local_size), static_cast(UP_DIV(outChannel, 8))}; - unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemv_conv1x1_buf", "gemv_conv_c8_buf", buildOption, mOpenCLBackend->getPrecision()); - uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(unit.kernel)); + unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemv_conv1x1_buf", "gemv_conv_c8_buf", buildOption, + mOpenCLBackend->getPrecision()); + uint32_t maxWorkGroupSize = + static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(unit.kernel)); uint32_t idx = 0; cl_int ret = CL_SUCCESS; ret |= unit.kernel->get().setArg(idx++, static_cast(mGlobalWorkSize[0])); @@ -865,9 +940,9 @@ void ConvBufLowMemoryExecution::tuneGemvLowMemory(Tensor * input, Tensor * outpu } else { ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input)); } - if(mResource->mUseImage){ + if (mResource->mUseImage) { ret |= unit.kernel->get().setArg(idx++, *mResource->mKernelImage.get()); - }else{ + } else { ret |= unit.kernel->get().setArg(idx++, *mResource->mKernelBuffer.get()); } ret |= unit.kernel->get().setArg(idx++, *mResource->mDequantScaleOffsetBuffer.get()); @@ -882,20 +957,23 @@ void ConvBufLowMemoryExecution::tuneGemvLowMemory(Tensor * input, Tensor * outpu ret |= unit.kernel->get().setArg(idx++, static_cast(blockDim)); ret |= unit.kernel->get().setArg(idx++, static_cast(mResource->mCoef)); MNN_CHECK_CL_SUCCESS(ret, "setArg gemv_conv_c8_buf"); - if(useLocalMem){ + if (useLocalMem) { mLocalWorkSize = {static_cast(local_size), 1}; - }else{ - mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), "gemv_conv_c8_buf" + info, unit.kernel, mOpenCLBackend->getCLTuneLevel(), "gemv_conv1x1_buf").first; + } else { + mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), + "gemv_conv_c8_buf" + info, unit.kernel, mOpenCLBackend->getCLTuneLevel(), + "gemv_conv1x1_buf") + .first; } mOpenCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]}; unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1]}; return; } -void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * output) { +void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor* input, Tensor* output) { mUnits.resize(1); - auto &unit = mUnits[0]; - std::vector inputShape = tensorShapeFormat(input); + auto& unit = mUnits[0]; + std::vector inputShape = tensorShapeFormat(input); std::vector outputShape = tensorShapeFormat(output); const int outChannel = outputShape.at(3); const int inputChannels = inputShape.at(3); @@ -905,13 +983,13 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu const int outputChannelAlign = ROUND_UP(outChannel, 4); const int blockNum = mResource->mBlockSize; const int blockDim = mResource->mInputChannel / mResource->mBlockSize; - + int global_y = batch * width_height; std::string kernelName = "gemm_b4_c8"; std::set buildOption = mResource->mBuildOptions; int inputChannelLeaves = 0; int inputBatchLeaves = global_y % 4; - if(mResource->mNumQuantBit == 4){ + if (mResource->mNumQuantBit == 4) { inputChannelLeaves = blockDim % 4; kernelName += "_int4_buf"; } else { @@ -920,21 +998,22 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu } buildOption.emplace("-DINPUT_CHANNEL_LEAVES_NUM=" + std::to_string(inputChannelLeaves)); buildOption.emplace("-DINPUT_BATCH_LEAVES_NUM=" + std::to_string(inputBatchLeaves)); - if(mResource->mUseImage){ + if (mResource->mUseImage) { buildOption.emplace("-DUSE_IMAGE"); } // generate cache for every option for (int i = 0; i < 4; i++) { std::set option = mResource->mBuildOptions; - if(mResource->mUseImage){ + if (mResource->mUseImage) { option.emplace("-DUSE_IMAGE"); } option.emplace("-DINPUT_CHANNEL_LEAVES_NUM=" + std::to_string(inputChannelLeaves)); option.emplace("-DINPUT_BATCH_LEAVES_NUM=" + std::to_string(i)); - auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_conv1x1_buf", kernelName, option, mOpenCLBackend->getPrecision()); + auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_conv1x1_buf", kernelName, option, + mOpenCLBackend->getPrecision()); } std::string info = std::to_string(inputChannels) + "_" + std::to_string(outChannel); - if(global_y <= 16) { + if (global_y <= 16) { mUnits.resize(3); int outputChannelAlign8 = ROUND_UP(outChannel, 8); mConvGemmInpTensor.reset(Tensor::createDevice({inputChannelAlign * ROUND_UP(global_y, 4)})); @@ -943,14 +1022,17 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu mOpenCLBackend->onAcquireBuffer(mConvGemmOutTensor.get(), Backend::DYNAMIC); mOpenCLBackend->onReleaseBuffer(mConvGemmInpTensor.get(), Backend::DYNAMIC); mOpenCLBackend->onReleaseBuffer(mConvGemmOutTensor.get(), Backend::DYNAMIC); - + { - //c4nhw4 -> nhwc - auto &unit = mUnits[0]; - unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_conv1x1_buf", "gemm_c4nhw4_to_nhwc", buildOption, mOpenCLBackend->getPrecision()); - uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(unit.kernel)); - - mGlobalWorkSize = {static_cast(UP_DIV(global_y, 4)), static_cast(UP_DIV(inputChannels, 4))}; + // c4nhw4 -> nhwc + auto& unit = mUnits[0]; + unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_conv1x1_buf", "gemm_c4nhw4_to_nhwc", + buildOption, mOpenCLBackend->getPrecision()); + uint32_t maxWorkGroupSize = + static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(unit.kernel)); + + mGlobalWorkSize = {static_cast(UP_DIV(global_y, 4)), + static_cast(UP_DIV(inputChannels, 4))}; uint32_t idx = 0; cl_int ret = CL_SUCCESS; ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]); @@ -961,7 +1043,10 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu ret |= unit.kernel->get().setArg(idx++, static_cast(inputChannels)); ret |= unit.kernel->get().setArg(idx++, static_cast(inputChannelAlign)); MNN_CHECK_CL_SUCCESS(ret, "setArg gemm_c4nhw4_to_nhwc"); - mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), "gemm_c4nhw4_to_nhwc", unit.kernel, mOpenCLBackend->getCLTuneLevel(), "gemm_conv1x1_buf").first; + mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), + "gemm_c4nhw4_to_nhwc", unit.kernel, mOpenCLBackend->getCLTuneLevel(), + "gemm_conv1x1_buf") + .first; mOpenCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]}; unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1]}; @@ -969,22 +1054,26 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu { const int inputChannelBlocks = UP_DIV(inputChannels, 4); const int outputChannelBlocks = UP_DIV(outChannel, 4); - auto &unit = mUnits[1]; + auto& unit = mUnits[1]; std::set buildOption = mResource->mBuildOptions; - if(mResource->mUseImage){ + if (mResource->mUseImage) { buildOption.emplace("-DUSE_IMAGE"); } buildOption.emplace("-DCOMPUTE_BATCH"); - + int local_size = 64; - if(mOpenCLBackend->getCLTuneLevel() != None && mOpenCLBackend->getCLTuneLevel() != Fast){ + if (mOpenCLBackend->getCLTuneLevel() != None && mOpenCLBackend->getCLTuneLevel() != Fast) { int min_time = INT_MAX; - for (int ksize = 16; ksize <= 256; ksize*=2) { + for (int ksize = 16; ksize <= 256; ksize *= 2) { auto option = buildOption; option.emplace("-DWGS=" + std::to_string(ksize)); - auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemv_conv1x1_buf", "gemv_conv_c8_buf", option, mOpenCLBackend->getPrecision()); - uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel)); - std::vector gws = {static_cast(ksize), static_cast(UP_DIV(outChannel, 8)), static_cast(UP_DIV(global_y, 4))}; + auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel( + "gemv_conv1x1_buf", "gemv_conv_c8_buf", option, mOpenCLBackend->getPrecision()); + uint32_t maxWorkGroupSize = + static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel)); + std::vector gws = {static_cast(ksize), + static_cast(UP_DIV(outChannel, 8)), + static_cast(UP_DIV(global_y, 4))}; std::vector lws = {static_cast(ksize), 1, 1}; uint32_t idx = 0; cl_int ret = CL_SUCCESS; @@ -992,9 +1081,9 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu ret |= kernel->get().setArg(idx++, static_cast(gws[1])); ret |= kernel->get().setArg(idx++, static_cast(gws[2])); ret |= kernel->get().setArg(idx++, openCLBuffer(mConvGemmInpTensor.get())); - if(mResource->mUseImage){ + if (mResource->mUseImage) { ret |= kernel->get().setArg(idx++, *mResource->mKernelImage.get()); - }else{ + } else { ret |= kernel->get().setArg(idx++, *mResource->mKernelBuffer.get()); } ret |= kernel->get().setArg(idx++, *mResource->mDequantScaleOffsetBuffer.get()); @@ -1010,26 +1099,31 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu ret |= kernel->get().setArg(idx++, static_cast(mResource->mCoef)); MNN_CHECK_CL_SUCCESS(ret, "setArg gemv_conv_c8_buf Kernel Select"); std::pair, int> retTune; - int cost_time = get2DUseLocalMemTime(gws, lws, mOpenCLBackend->getOpenCLRuntime(), "gemv_conv_c8_buf" + info + "_batch", kernel, "gemv_conv1x1_buf"); - if(min_time > cost_time) { + int cost_time = + get2DUseLocalMemTime(gws, lws, mOpenCLBackend->getOpenCLRuntime(), + "gemv_conv_c8_buf" + info + "_batch", kernel, "gemv_conv1x1_buf"); + if (min_time > cost_time) { local_size = ksize; min_time = cost_time; } } } buildOption.emplace("-DWGS=" + std::to_string(local_size)); - mGlobalWorkSize = {static_cast(local_size), static_cast(UP_DIV(outChannel, 8)), static_cast(UP_DIV(global_y, 4))}; - unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemv_conv1x1_buf", "gemv_conv_c8_buf", buildOption, mOpenCLBackend->getPrecision()); - uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(unit.kernel)); + mGlobalWorkSize = {static_cast(local_size), static_cast(UP_DIV(outChannel, 8)), + static_cast(UP_DIV(global_y, 4))}; + unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemv_conv1x1_buf", "gemv_conv_c8_buf", + buildOption, mOpenCLBackend->getPrecision()); + uint32_t maxWorkGroupSize = + static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(unit.kernel)); uint32_t idx = 0; cl_int ret = CL_SUCCESS; ret |= unit.kernel->get().setArg(idx++, static_cast(mGlobalWorkSize[0])); ret |= unit.kernel->get().setArg(idx++, static_cast(mGlobalWorkSize[1])); ret |= unit.kernel->get().setArg(idx++, static_cast(mGlobalWorkSize[2])); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mConvGemmInpTensor.get())); - if(mResource->mUseImage){ + if (mResource->mUseImage) { ret |= unit.kernel->get().setArg(idx++, *mResource->mKernelImage.get()); - }else{ + } else { ret |= unit.kernel->get().setArg(idx++, *mResource->mKernelBuffer.get()); } ret |= unit.kernel->get().setArg(idx++, *mResource->mDequantScaleOffsetBuffer.get()); @@ -1050,10 +1144,13 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; } { - auto &unit = mUnits[2]; - unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_conv1x1_buf", "gemm_nhwc_to_c4nhw4", buildOption, mOpenCLBackend->getPrecision()); - uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(unit.kernel)); - mGlobalWorkSize = {static_cast(UP_DIV(global_y, 4)), static_cast(UP_DIV(outChannel, 4))}; + auto& unit = mUnits[2]; + unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_conv1x1_buf", "gemm_nhwc_to_c4nhw4", + buildOption, mOpenCLBackend->getPrecision()); + uint32_t maxWorkGroupSize = + static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(unit.kernel)); + mGlobalWorkSize = {static_cast(UP_DIV(global_y, 4)), + static_cast(UP_DIV(outChannel, 4))}; uint32_t idx = 0; cl_int ret = CL_SUCCESS; ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]); @@ -1063,7 +1160,10 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu ret |= unit.kernel->get().setArg(idx++, static_cast(global_y)); ret |= unit.kernel->get().setArg(idx++, static_cast(outputChannelAlign8)); MNN_CHECK_CL_SUCCESS(ret, "setArg gemm_nhwc_to_c4nhw4"); - mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), "gemm_nhwc_to_c4nhw4", unit.kernel, mOpenCLBackend->getCLTuneLevel(), "gemm_conv1x1_buf").first; + mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), + "gemm_nhwc_to_c4nhw4", unit.kernel, mOpenCLBackend->getCLTuneLevel(), + "gemm_conv1x1_buf") + .first; mOpenCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]}; unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1]}; @@ -1100,9 +1200,11 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu } } } - unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_conv1x1_buf", kernelName, buildOption, mOpenCLBackend->getPrecision()); - uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(unit.kernel)); - + unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_conv1x1_buf", kernelName, buildOption, + mOpenCLBackend->getPrecision()); + uint32_t maxWorkGroupSize = + static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(unit.kernel)); + mGlobalWorkSize = {static_cast(UP_DIV(global_y, 4)), static_cast(UP_DIV(outChannel, 8))}; uint32_t idx = 0; cl_int ret = CL_SUCCESS; @@ -1113,9 +1215,9 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu } else { ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input)); } - if(mResource->mUseImage){ + if (mResource->mUseImage) { ret |= unit.kernel->get().setArg(idx++, *mResource->mKernelImage.get()); - }else{ + } else { ret |= unit.kernel->get().setArg(idx++, *mResource->mKernelBuffer.get()); } ret |= unit.kernel->get().setArg(idx++, *mResource->mDequantScaleOffsetBuffer.get()); @@ -1128,13 +1230,18 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu ret |= unit.kernel->get().setArg(idx++, static_cast(blockDim)); ret |= unit.kernel->get().setArg(idx++, mResource->mCoef); MNN_CHECK_CL_SUCCESS(ret, "setArg gemm_conv1x1_buf"); - mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), kernelName + info, unit.kernel, mOpenCLBackend->getCLTuneLevel(), "gemm_conv1x1_buf").first; + mLocalWorkSize = + localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), kernelName + info, + unit.kernel, mOpenCLBackend->getCLTuneLevel(), "gemm_conv1x1_buf") + .first; mOpenCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]}; unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1]}; return; } -ConvBufLowMemoryExecution::ConvBufLowMemoryExecution(const std::vector &inputs, const std::vector &outputs, const MNN::Op *op, Backend *backend) +ConvBufLowMemoryExecution::ConvBufLowMemoryExecution(const std::vector& inputs, + const std::vector& outputs, const MNN::Op* op, + Backend* backend) : ConvBufCommonExecution(op->main_as_Convolution2D(), backend), CommonExecution(backend, op) { if (!mConvComValid) { mValid = false; @@ -1143,27 +1250,29 @@ ConvBufLowMemoryExecution::ConvBufLowMemoryExecution(const std::vector #ifdef LOG_VERBOSE MNN_PRINT("Start ConvBufLowMemoryExecution init !\n"); #endif - mOpenCLBackend = static_cast(backend); - const auto *conv2dParams = op->main_as_Convolution2D(); - const auto *conv2dCommonParams = conv2dParams->common(); - mResource->mConv2dParams = conv2dParams; - mResource->mConv2dCommonParams = conv2dCommonParams; - mResource->mStrides = {conv2dCommonParams->strideY(), conv2dCommonParams->strideX()}; - mResource->mDilations = {conv2dCommonParams->dilateY(), conv2dCommonParams->dilateX()}; + mOpenCLBackend = static_cast(backend); + const auto* conv2dParams = op->main_as_Convolution2D(); + const auto* conv2dCommonParams = conv2dParams->common(); + mResource->mConv2dParams = conv2dParams; + mResource->mConv2dCommonParams = conv2dCommonParams; + mResource->mStrides = {conv2dCommonParams->strideY(), conv2dCommonParams->strideX()}; + mResource->mDilations = {conv2dCommonParams->dilateY(), conv2dCommonParams->dilateX()}; auto padding = ConvolutionCommon::convolutionPad(inputs[0], outputs[0], conv2dCommonParams); - mPaddings[0] = padding.second;//padY - mPaddings[1] = padding.first;//padX + mPaddings[0] = padding.second; // padY + mPaddings[1] = padding.first; // padX - mResource->mKernelWidth = conv2dCommonParams->kernelX(); - mResource->mKernelHeight = conv2dCommonParams->kernelY(); + mResource->mKernelWidth = conv2dCommonParams->kernelX(); + mResource->mKernelHeight = conv2dCommonParams->kernelY(); mResource->mInputChannel = conv2dCommonParams->inputCount(); mResource->mOutputChannel = conv2dCommonParams->outputCount(); - - //select opt conv method - if (mResource->mKernelHeight == mResource->mKernelWidth && mResource->mKernelHeight == 1 && mResource->mStrides[0] == 1 && mResource->mStrides[1] == 1 && conv2dCommonParams->padX() == 0 && conv2dCommonParams->padY() == 0 && conv2dCommonParams->dilateX() == 1 && conv2dCommonParams->dilateY() == 1) { + + // select opt conv method + if (mResource->mKernelHeight == mResource->mKernelWidth && mResource->mKernelHeight == 1 && + mResource->mStrides[0] == 1 && mResource->mStrides[1] == 1 && conv2dCommonParams->padX() == 0 && + conv2dCommonParams->padY() == 0 && conv2dCommonParams->dilateX() == 1 && conv2dCommonParams->dilateY() == 1) { set1x1WeightLowMemory(); mResource->mConv1x1Opt = true; - }else { + } else { // set mFilter for not 1x1 case setGeneralWeightLowMemory(); } @@ -1179,13 +1288,14 @@ ConvBufLowMemoryExecution::ConvBufLowMemoryExecution(const std::vector #endif } -ConvBufLowMemoryExecution::ConvBufLowMemoryExecution(std::shared_ptr resource, const MNN::Op* op, Backend *backend) +ConvBufLowMemoryExecution::ConvBufLowMemoryExecution(std::shared_ptr resource, const MNN::Op* op, + Backend* backend) : ConvBufCommonExecution(backend), CommonExecution(backend, op) { mResource = resource; - const auto *conv2dParams = op->main_as_Convolution2D(); - const auto *conv2dCommonParams = conv2dParams->common(); - mResource->mConv2dParams = conv2dParams; - mResource->mConv2dCommonParams = conv2dCommonParams; + const auto* conv2dParams = op->main_as_Convolution2D(); + const auto* conv2dCommonParams = conv2dParams->common(); + mResource->mConv2dParams = conv2dParams; + mResource->mConv2dCommonParams = conv2dCommonParams; } ConvBufLowMemoryExecution::~ConvBufLowMemoryExecution() { @@ -1206,38 +1316,46 @@ bool ConvBufLowMemoryExecution::onClone(Backend* bn, const Op* op, Execution** d if (nullptr == dst) { return true; } + if (op->type() == OpType_GatherV2) { + if (!SharedGatherBufExecution::validResource(mResource)) { + return false; + } + *dst = new SharedGatherBufExecution(mResource, op, bn); + return true; + } *dst = new ConvBufLowMemoryExecution(mResource, op, bn); return true; } -ErrorCode ConvBufLowMemoryExecution::onResize(const std::vector &inputs, const std::vector &outputs) { +ErrorCode ConvBufLowMemoryExecution::onResize(const std::vector& inputs, const std::vector& outputs) { #ifdef LOG_VERBOSE MNN_PRINT("Start ConvBufLowMemoryExecution onResize !\n"); #endif auto runTime = mOpenCLBackend->getOpenCLRuntime(); mOpenCLBackend->startRecord(mRecording); mUnits.resize(1); - auto input = inputs[0]; + auto input = inputs[0]; auto output = outputs[0]; auto padding = ConvolutionCommon::convolutionPad(input, output, mResource->mConv2dCommonParams); - mPaddings[0] = padding.second;//padY - mPaddings[1] = padding.first;//padX + mPaddings[0] = padding.second; // padY + mPaddings[1] = padding.first; // padX // onclone default use conv1x1Opt, need reset std::vector outputShape = tensorShapeFormat(output); const int batch = outputShape.at(0) * outputShape.at(1) * outputShape.at(2); mUseFPWeight = false; if (mResource->mConv1x1Opt) { - if(batch == 1){ + if (batch == 1) { tuneGemvLowMemory(input, output); } else { // 2/3 bit have no dedicated GEMM kernel yet; always fall back to inverse-quant + FP gemm. - if(mResource->mNumQuantBit == 2 || mResource->mNumQuantBit == 3){ + if (mResource->mNumQuantBit == 2 || mResource->mNumQuantBit == 3) { mUseFPWeight = true; useFPWeightGemmLowMemory(input, output); } else { std::pair, uint32_t> tuneInfo; - std::string info = "convBufLowMemory_" + std::to_string(mResource->mInputChannel) + "_" + std::to_string(mResource->mOutputChannel); - if(batch > 16){ + std::string info = "convBufLowMemory_" + std::to_string(mResource->mInputChannel) + "_" + + std::to_string(mResource->mOutputChannel); + if (batch > 16) { if (getTunedInfo(info, {static_cast(batch)}, tuneInfo, mOpenCLBackend->getOpenCLRuntime(), mOpenCLBackend->getCLTuneLevel())) { mUseFPWeight = tuneInfo.first[0]; @@ -1252,21 +1370,23 @@ ErrorCode ConvBufLowMemoryExecution::onResize(const std::vector &input useFPWeightGemmLowMemory(input, output); auto longBatchTime = getExecuteTime(); mUseFPWeight = false; - if(longBatchTime < shortBatchTime){ + if (longBatchTime < shortBatchTime) { mUseFPWeight = true; } - std::pair, uint32_t> tuneInfoTmp = std::make_pair, uint32_t>({mUseFPWeight}, 0); - setTunedInfo(info, {static_cast(batch)}, tuneInfoTmp, mOpenCLBackend->getOpenCLRuntime(), "gemm_conv1x1_buf"); + std::pair, uint32_t> tuneInfoTmp = + std::make_pair, uint32_t>({mUseFPWeight}, 0); + setTunedInfo(info, {static_cast(batch)}, tuneInfoTmp, + mOpenCLBackend->getOpenCLRuntime(), "gemm_conv1x1_buf"); } else { - if(batch > 512){ + if (batch > 512) { mUseFPWeight = true; } } } } - if(mUseFPWeight){ + if (mUseFPWeight) { useFPWeightGemmLowMemory(input, output); - }else{ + } else { tuneGemmLowMemory(input, output); } } @@ -1274,15 +1394,16 @@ ErrorCode ConvBufLowMemoryExecution::onResize(const std::vector &input } else { tuneGeneralCaseLowMemory(input, output); } - for (auto &unit : mUnits) { + for (auto& unit : mUnits) { bool lws_null = true; for (size_t i = 0; i < unit.globalWorkSize.dimensions(); ++i) { - unit.globalWorkSize.get()[i] = ROUND_UP(unit.globalWorkSize.get()[i], std::max((size_t)1, unit.localWorkSize.get()[i])); - if(unit.localWorkSize.get()[i] != 0) { + unit.globalWorkSize.get()[i] = + ROUND_UP(unit.globalWorkSize.get()[i], std::max((size_t)1, unit.localWorkSize.get()[i])); + if (unit.localWorkSize.get()[i] != 0) { lws_null = false; } } - if(lws_null){ + if (lws_null) { unit.localWorkSize = cl::NullRange; } } @@ -1293,67 +1414,57 @@ ErrorCode ConvBufLowMemoryExecution::onResize(const std::vector &input return NO_ERROR; } -int ConvBufLowMemoryExecution::getExecuteTime(){ - for (auto &unit : mUnits) { +int ConvBufLowMemoryExecution::getExecuteTime() { + for (auto& unit : mUnits) { bool lws_null = true; for (size_t i = 0; i < unit.globalWorkSize.dimensions(); ++i) { - unit.globalWorkSize.get()[i] = ROUND_UP(unit.globalWorkSize.get()[i], std::max((size_t)1, unit.localWorkSize.get()[i])); - if(unit.localWorkSize.get()[i] != 0) { + unit.globalWorkSize.get()[i] = + ROUND_UP(unit.globalWorkSize.get()[i], std::max((size_t)1, unit.localWorkSize.get()[i])); + if (unit.localWorkSize.get()[i] != 0) { lws_null = false; } } - if(lws_null){ + if (lws_null) { unit.localWorkSize = cl::NullRange; } } int executeTime = 0; auto runtime = mOpenCLBackend->getOpenCLRuntime(); auto res = CL_SUCCESS; - if(mUseFPWeight){ + if (mUseFPWeight) { // arrange input and weight int i = 0; - for (; i < 2; ++i){ + for (; i < 2; ++i) { auto unit = mUnits[i]; cl::Event event; - res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), - cl::NullRange, - unit.globalWorkSize, - unit.localWorkSize, - nullptr, - &event); + res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), cl::NullRange, unit.globalWorkSize, + unit.localWorkSize, nullptr, &event); executeTime += runtime->getEventTime(event); } // call gemm execute executeTime += mStrassenComputor->getExecuteTime(); - + // rearrange output - for (; i < mUnits.size(); ++i){ + for (; i < mUnits.size(); ++i) { auto unit = mUnits[i]; cl::Event event; - res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), - cl::NullRange, - unit.globalWorkSize, - unit.localWorkSize, - nullptr, - &event); + res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), cl::NullRange, unit.globalWorkSize, + unit.localWorkSize, nullptr, &event); executeTime += runtime->getEventTime(event); } - }else{ - for (auto &unit : mUnits) { + } else { + for (auto& unit : mUnits) { cl::Event event; - res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), - cl::NullRange, - unit.globalWorkSize, - unit.localWorkSize, - nullptr, - &event); + res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), cl::NullRange, unit.globalWorkSize, + unit.localWorkSize, nullptr, &event); executeTime += runtime->getEventTime(event); } } return executeTime; } -ErrorCode ConvBufLowMemoryExecution::onExecute(const std::vector &inputs, const std::vector &outputs) { +ErrorCode ConvBufLowMemoryExecution::onExecute(const std::vector& inputs, + const std::vector& outputs) { #ifdef LOG_VERBOSE MNN_PRINT("Start ConvBufLowMemoryExecution onExecute !\n"); #endif @@ -1361,74 +1472,56 @@ ErrorCode ConvBufLowMemoryExecution::onExecute(const std::vector &inpu #ifdef ENABLE_OPENCL_TIME_PROFILER int idx = 0; #else - if(mOpenCLBackend->isUseRecordQueue()){ + if (mOpenCLBackend->isUseRecordQueue()) { mOpenCLBackend->addRecord(mRecording, mOpRecordUpdateInfo); return NO_ERROR; } #endif auto res = CL_SUCCESS; - if(mUseFPWeight){ + if (mUseFPWeight) { // arrange input and weight int i = 0; - for (; i < 2; ++i){ + for (; i < 2; ++i) { auto unit = mUnits[i]; - #ifdef ENABLE_OPENCL_TIME_PROFILER +#ifdef ENABLE_OPENCL_TIME_PROFILER cl::Event event; - res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), - cl::NullRange, - unit.globalWorkSize, - unit.localWorkSize, - nullptr, - &event); + res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), cl::NullRange, unit.globalWorkSize, + unit.localWorkSize, nullptr, &event); runtime->pushEvent({EnumNameOpType(mOpType) + std::to_string(idx++), event}); - #else - res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), - cl::NullRange, - unit.globalWorkSize, - unit.localWorkSize); - #endif +#else + res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), cl::NullRange, unit.globalWorkSize, + unit.localWorkSize); +#endif MNN_CHECK_CL_SUCCESS(res, EnumNameOpType(mOp->type())); } // call gemm execute mStrassenComputor->onExecute(); - + // rearrange output - for (; i < mUnits.size(); ++i){ + for (; i < mUnits.size(); ++i) { auto unit = mUnits[i]; - #ifdef ENABLE_OPENCL_TIME_PROFILER +#ifdef ENABLE_OPENCL_TIME_PROFILER cl::Event event; - res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), - cl::NullRange, - unit.globalWorkSize, - unit.localWorkSize, - nullptr, - &event); + res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), cl::NullRange, unit.globalWorkSize, + unit.localWorkSize, nullptr, &event); runtime->pushEvent({EnumNameOpType(mOpType) + std::to_string(idx++), event}); - #else - res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), - cl::NullRange, - unit.globalWorkSize, - unit.localWorkSize); - #endif +#else + res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), cl::NullRange, unit.globalWorkSize, + unit.localWorkSize); +#endif MNN_CHECK_CL_SUCCESS(res, EnumNameOpType(mOp->type())); } - }else{ - for (auto &unit : mUnits) { - #ifdef ENABLE_OPENCL_TIME_PROFILER + } else { + for (auto& unit : mUnits) { +#ifdef ENABLE_OPENCL_TIME_PROFILER cl::Event event; - res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), - cl::NullRange, - unit.globalWorkSize, - unit.localWorkSize, - nullptr, - &event); + res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), cl::NullRange, unit.globalWorkSize, + unit.localWorkSize, nullptr, &event); runtime->pushEvent({EnumNameOpType(mOpType) + std::to_string(idx++), event}); - #else - res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), - cl::NullRange, - unit.globalWorkSize, +#else + res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), cl::NullRange, unit.globalWorkSize, unit.localWorkSize); - #endif +#endif MNN_CHECK_CL_SUCCESS(res, EnumNameOpType(mOp->type())); } } @@ -1441,4 +1534,4 @@ ErrorCode ConvBufLowMemoryExecution::onExecute(const std::vector &inpu } // namespace OpenCL } // namespace MNN #endif /* MNN_OPENCL_BUFFER_CLOSED */ -#endif /* MNN_LOW_MEMORY */ \ No newline at end of file +#endif /* MNN_LOW_MEMORY */ diff --git a/source/backend/opencl/execution/buffer/LayerNormBufExecution.cpp b/source/backend/opencl/execution/buffer/LayerNormBufExecution.cpp index f43482b63d..1dc23fc3ee 100644 --- a/source/backend/opencl/execution/buffer/LayerNormBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/LayerNormBufExecution.cpp @@ -12,20 +12,22 @@ namespace MNN { namespace OpenCL { -LayerNormBufExecution::LayerNormBufExecution(const std::vector &inputs, const MNN::Op *op, Backend *backend) +LayerNormBufExecution::LayerNormBufExecution(const std::vector& inputs, const MNN::Op* op, Backend* backend) : CommonExecution(backend, op) { - mOpenCLBackend = static_cast(backend); - auto runtime = mOpenCLBackend->getOpenCLRuntime(); + mOpenCLBackend = static_cast(backend); + auto runtime = mOpenCLBackend->getOpenCLRuntime(); const auto* layer_norm_param = op->main_as_LayerNorm(); - mResource.reset(new LayernormResource); + mResource.reset(new LayernormResource); if (nullptr != layer_norm_param->axis()) { mResource->axis_size = layer_norm_param->axis()->size(); } mResource->epsilon_ = layer_norm_param->epsilon(); mResource->group_ = layer_norm_param->group(); mResource->RMSNorm = layer_norm_param->useRMSNorm(); - auto bufferUnitSize = mOpenCLBackend->getPrecision() != BackendConfig::Precision_High ? sizeof(half_float::half) : sizeof(float); - auto kernel = runtime->buildKernel("layernorm_buf", "layernorm_buf", {"-DLOCAL_SIZE=512"}, mOpenCLBackend->getPrecision()); + auto bufferUnitSize = + mOpenCLBackend->getPrecision() != BackendConfig::Precision_High ? sizeof(half_float::half) : sizeof(float); + auto kernel = + runtime->buildKernel("layernorm_buf", "layernorm_buf", {"-DLOCAL_SIZE=512"}, mOpenCLBackend->getPrecision()); OPENCL_CHECK_KERNEL_CTOR(kernel); mResource->mMaxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(kernel)); @@ -35,82 +37,94 @@ LayerNormBufExecution::LayerNormBufExecution(const std::vector &inputs MNN_ASSERT(layer_norm_param->gamma()->size() == layer_norm_param->beta()->size()); gammasize = layer_norm_param->gamma()->size(); } - mResource->has_gamma_beta_ = mResource->has_gamma_beta_ || (layer_norm_param->external() && layer_norm_param->external()->size() > 1 && layer_norm_param->external()->data()[1] > 0); + mResource->has_gamma_beta_ = + mResource->has_gamma_beta_ || (layer_norm_param->external() && layer_norm_param->external()->size() > 1 && + layer_norm_param->external()->data()[1] > 0); if (mResource->has_gamma_beta_ && gammasize == 0) { gammasize = layer_norm_param->external()->data()[1] / sizeof(float); } - + auto staticMapAlloc = mOpenCLBackend->getStaticAllocatorMMap(); - if(mResource->has_gamma_beta_){ + if (mResource->has_gamma_beta_) { { auto error = CL_SUCCESS; int size = gammasize; - if(mOpenCLBackend->getRuntime()->hint().useCachedMmap && staticMapAlloc != nullptr){ + if (mOpenCLBackend->getRuntime()->hint().useCachedMmap && staticMapAlloc != nullptr) { mResource->mGammaBuffer = staticMapAlloc.get()->allocBuffer(ALIGN_UP4(size) * bufferUnitSize); - }else{ - mResource->mGammaBuffer.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, ALIGN_UP4(size) * bufferUnitSize)); + } else { + mResource->mGammaBuffer.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), + CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, + ALIGN_UP4(size) * bufferUnitSize)); } - if(mOpenCLBackend->getRuntime()->hint().useCachedMmap <= 1){ - auto GammaPtrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*(mResource->mGammaBuffer.get()), true, CL_MAP_WRITE, 0, ALIGN_UP4(size) * bufferUnitSize, nullptr, nullptr, &error); + if (mOpenCLBackend->getRuntime()->hint().useCachedMmap <= 1) { + auto GammaPtrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer( + *(mResource->mGammaBuffer.get()), true, CL_MAP_WRITE, 0, ALIGN_UP4(size) * bufferUnitSize, nullptr, + nullptr, &error); const float* gamma_data = layer_norm_param->gamma()->data(); - if(GammaPtrCL != nullptr && error == CL_SUCCESS){ - if(mOpenCLBackend->getPrecision() != BackendConfig::Precision_High){ - for (int i = 0; i < size; i++) - { + if (GammaPtrCL != nullptr && error == CL_SUCCESS) { + if (mOpenCLBackend->getPrecision() != BackendConfig::Precision_High) { + for (int i = 0; i < size; i++) { ((half_float::half*)GammaPtrCL)[i] = (half_float::half)(gamma_data[i]); } - for(int i=size; igetOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*mResource->mGammaBuffer.get(), GammaPtrCL); + mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*mResource->mGammaBuffer.get(), + GammaPtrCL); } } { auto error = CL_SUCCESS; int size = gammasize; - if(mOpenCLBackend->getRuntime()->hint().useCachedMmap && staticMapAlloc != nullptr){ + if (mOpenCLBackend->getRuntime()->hint().useCachedMmap && staticMapAlloc != nullptr) { mResource->mBetaBuffer = staticMapAlloc.get()->allocBuffer(ALIGN_UP4(size) * bufferUnitSize); - }else{ - mResource->mBetaBuffer.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, ALIGN_UP4(size) * bufferUnitSize)); + } else { + mResource->mBetaBuffer.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), + CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, + ALIGN_UP4(size) * bufferUnitSize)); } - if(mOpenCLBackend->getRuntime()->hint().useCachedMmap <= 1){ - auto BetaPtrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*(mResource->mBetaBuffer.get()), true, CL_MAP_WRITE, 0, ALIGN_UP4(size) * bufferUnitSize, nullptr, nullptr, &error); + if (mOpenCLBackend->getRuntime()->hint().useCachedMmap <= 1) { + auto BetaPtrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer( + *(mResource->mBetaBuffer.get()), true, CL_MAP_WRITE, 0, ALIGN_UP4(size) * bufferUnitSize, nullptr, + nullptr, &error); const float* beta_data = layer_norm_param->beta()->data(); - if(BetaPtrCL != nullptr && error == CL_SUCCESS){ - if(mOpenCLBackend->getPrecision() != BackendConfig::Precision_High){ - for (int i = 0; i < size; i++) - { + if (BetaPtrCL != nullptr && error == CL_SUCCESS) { + if (mOpenCLBackend->getPrecision() != BackendConfig::Precision_High) { + for (int i = 0; i < size; i++) { ((half_float::half*)BetaPtrCL)[i] = (half_float::half)(beta_data[i]); } - for(int i=size; igetOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*mResource->mBetaBuffer.get(), BetaPtrCL); + mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*mResource->mBetaBuffer.get(), + BetaPtrCL); } } } } -LayerNormBufExecution::LayerNormBufExecution(std::shared_ptr resource, const Op* op, Backend* backend): CommonExecution(backend, op) { +LayerNormBufExecution::LayerNormBufExecution(std::shared_ptr resource, const Op* op, + Backend* backend) + : CommonExecution(backend, op) { mResource = resource; - mOpenCLBackend = (OpenCLBackend *)backend; + mOpenCLBackend = (OpenCLBackend*)backend; } -bool LayerNormBufExecution::onClone(Backend *bn, const Op *op, Execution **dst) { +bool LayerNormBufExecution::onClone(Backend* bn, const Op* op, Execution** dst) { if (!mValid) { return false; } @@ -121,23 +135,27 @@ bool LayerNormBufExecution::onClone(Backend *bn, const Op *op, Execution **dst) return true; } -int LayerNormBufExecution::getLocalSize(int size, int maxGroupSize){ +int LayerNormBufExecution::getLocalSize(int size, int maxGroupSize) { int local_size = 1; - while(local_size * 2 <= maxGroupSize && local_size * 2 <= size){ + while (local_size * 2 <= maxGroupSize && local_size * 2 <= size) { local_size *= 2; } return local_size; } -ErrorCode LayerNormBufExecution::onEncode(const std::vector &inputs, const std::vector &outputs) { +ErrorCode LayerNormBufExecution::onEncode(const std::vector& inputs, const std::vector& outputs) { mUnits.resize(1); - auto &unit = mUnits[0]; - Tensor *input = inputs[0]; - Tensor *output = outputs[0]; - auto runtime = ((OpenCLBackend *)backend())->getOpenCLRuntime(); - auto MaxLocalSize = std::min(std::min(runtime->getMaxWorkItemSizes()[0], mResource->mMaxWorkGroupSize), (uint32_t)256); + auto& unit = mUnits[0]; + Tensor* input = inputs[0]; + Tensor* output = outputs[0]; + auto runtime = ((OpenCLBackend*)backend())->getOpenCLRuntime(); + auto MaxLocalSize = + std::min(std::min(runtime->getMaxWorkItemSizes()[0], mResource->mMaxWorkGroupSize), (uint32_t)256); - std::vector inputShape = tensorShapeFormat(input); + const auto layout = TensorUtils::getDescribe(input)->dimensionFormat; + bool isNC4HW4 = layout == MNN_DATA_FORMAT_NC4HW4; + + std::vector inputShape = tensorShapeFormat(input); std::vector outputShape = tensorShapeFormat(output); int rank = inputs.at(0)->dimensions(); @@ -158,21 +176,46 @@ ErrorCode LayerNormBufExecution::onEncode(const std::vector &inputs, c } inner_size /= mResource->group_; } - - int local_size = getLocalSize(inner_size / 4, MaxLocalSize); + + if (isNC4HW4) { + inner_size = inputs.at(0)->length(1); + outter_size = 1; + for (int i = 0; i < rank; i++) { + if (i != 1) { + outter_size *= inputs.at(0)->length(i); + } + } + } + + int local_size; + std::string kernelName; + if (isNC4HW4) { + int channelUnit = UP_DIV(inner_size, 4); + local_size = getLocalSize(channelUnit, MaxLocalSize); + if (inputs.size() == 2 && outputs.size() == 2) { + kernelName = "binary_layernorm_c4_buf"; + } else { + kernelName = "layernorm_c4_buf"; + } + } else { + local_size = getLocalSize(inner_size / 4, MaxLocalSize); + kernelName = "layernorm_buf"; + } + std::set buildOptions; buildOptions.emplace("-DLOCAL_SIZE=" + std::to_string(local_size)); - if(mResource->RMSNorm){ + if (mResource->RMSNorm) { buildOptions.emplace("-DRMSNORM"); } - if(mResource->has_gamma_beta_){ + if (mResource->has_gamma_beta_) { buildOptions.emplace("-DGAMMA_BETA"); } - if(inner_size % 4 != 0){ + if (!isNC4HW4 && inner_size % 4 != 0) { buildOptions.emplace("-DPACK_LEAVE"); } - - unit.kernel = runtime->buildKernel("layernorm_buf", "layernorm_buf", buildOptions, mOpenCLBackend->getPrecision()); + + unit.kernel = runtime->buildKernel("layernorm_buf", kernelName, buildOptions, mOpenCLBackend->getPrecision()); + OPENCL_CHECK_KERNEL(unit.kernel); mGWS = {static_cast(local_size), static_cast(outter_size)}; mLWS = {static_cast(local_size), 1}; @@ -180,10 +223,17 @@ ErrorCode LayerNormBufExecution::onEncode(const std::vector &inputs, c cl_int ret = CL_SUCCESS; ret |= unit.kernel->get().setArg(idx++, mGWS[0]); ret |= unit.kernel->get().setArg(idx++, mGWS[1]); - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input)); - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); + if (isNC4HW4 && inputs.size() == 2 && outputs.size() == 2) { + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(inputs[0])); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(inputs[1])); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(outputs[0])); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(outputs[1])); + } else { + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input)); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); + } ret |= unit.kernel->get().setArg(idx++, static_cast(inner_size)); - if(mResource->has_gamma_beta_){ + if (mResource->has_gamma_beta_) { ret |= unit.kernel->get().setArg(idx++, *mResource->mGammaBuffer.get()); ret |= unit.kernel->get().setArg(idx++, *mResource->mBetaBuffer.get()); } @@ -194,21 +244,19 @@ ErrorCode LayerNormBufExecution::onEncode(const std::vector &inputs, c unit.localWorkSize = {mLWS[0], mLWS[1]}; return NO_ERROR; - } class LayerNormBufCreator : public OpenCLBackend::Creator { public: virtual ~LayerNormBufCreator() = default; - virtual Execution *onCreate(const std::vector &inputs, const std::vector &outputs, - const MNN::Op *op, Backend *backend) const override { - for (int i = 0; i < inputs.size(); ++i) { + virtual Execution* onCreate(const std::vector& inputs, const std::vector& outputs, + const MNN::Op* op, Backend* backend) const override { + for (int i = 0; i < inputs.size(); ++i) { TensorUtils::setTensorSupportPack(inputs[i], false); } for (int i = 0; i < outputs.size(); ++i) { TensorUtils::setTensorSupportPack(outputs[i], false); } - const auto* layer_norm_param = op->main_as_LayerNorm(); OPENCL_CREATOR_CHECK(new LayerNormBufExecution(inputs, op, backend)); } }; diff --git a/source/backend/opencl/execution/buffer/RopeBufExecution.cpp b/source/backend/opencl/execution/buffer/RopeBufExecution.cpp new file mode 100644 index 0000000000..e8bc94b97a --- /dev/null +++ b/source/backend/opencl/execution/buffer/RopeBufExecution.cpp @@ -0,0 +1,208 @@ +// +// RopeBufExecution.cpp +// MNN +// +// OpenCL buffer-path implementation of RoPE (Rotary Positional Embedding). +// + +#ifdef MNN_SUPPORT_TRANSFORMER_FUSE + +#include "RopeBufExecution.hpp" +#include "MNN_generated.h" +#include "core/OpCommonUtils.hpp" +#include "core/TensorUtils.hpp" + +namespace MNN { +namespace OpenCL { + +RopeBufExecution::RopeBufExecution(const MNN::Op* op, Backend* backend) : CommonExecution(backend, op) { + mOpenCLBackend = static_cast(backend); + + if (nullptr != op && OpParameter_Extra == op->main_type()) { + auto extra = op->main_as_Extra(); + if (nullptr != extra && nullptr != extra->attr()) { + for (int i = 0; i < extra->attr()->size(); ++i) { + auto attr = extra->attr()->GetAs(i); + if (nullptr == attr || nullptr == attr->key()) { + continue; + } + if (attr->key()->str() == "rope_cut_head_dim") { + mRopeCutHeadDim = attr->i(); + continue; + } + if (attr->key()->str() == "q_norm") { + auto qLayernorm = flatbuffers::GetRoot(attr->tensor()->int8s()->data()); + auto param = qLayernorm->main_as_LayerNorm(); + mQEps = param->epsilon(); + if (param->gamma()) { + int size = param->gamma()->size(); + mQGamma.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), + CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, + ALIGN_UP4(size) * sizeof(float))); + auto error = CL_SUCCESS; + auto ptr = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer( + *mQGamma, true, CL_MAP_WRITE, 0, ALIGN_UP4(size) * sizeof(float), nullptr, nullptr, &error); + if (ptr != nullptr && error == CL_SUCCESS) { + ::memset(ptr, 0, ALIGN_UP4(size) * sizeof(float)); + ::memcpy(ptr, param->gamma()->data(), size * sizeof(float)); + } + mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*mQGamma, ptr); + } + continue; + } + if (attr->key()->str() == "k_norm") { + auto kLayernorm = flatbuffers::GetRoot(attr->tensor()->int8s()->data()); + auto param = kLayernorm->main_as_LayerNorm(); + mKEps = param->epsilon(); + if (param->gamma()) { + int size = param->gamma()->size(); + mKGamma.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), + CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, + ALIGN_UP4(size) * sizeof(float))); + auto error = CL_SUCCESS; + auto ptr = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer( + *mKGamma, true, CL_MAP_WRITE, 0, ALIGN_UP4(size) * sizeof(float), nullptr, nullptr, &error); + if (ptr != nullptr && error == CL_SUCCESS) { + ::memset(ptr, 0, ALIGN_UP4(size) * sizeof(float)); + ::memcpy(ptr, param->gamma()->data(), size * sizeof(float)); + } + mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*mKGamma, ptr); + } + continue; + } + } + } + } +} + +RopeBufExecution::RopeBufExecution(const MNN::Op* op, Backend* backend, int ropeCutHeadDim, + std::shared_ptr qGamma, float qEps, std::shared_ptr kGamma, + float kEps) + : CommonExecution(backend, op), + mRopeCutHeadDim(ropeCutHeadDim), + mQGamma(qGamma), + mKGamma(kGamma), + mQEps(qEps), + mKEps(kEps) { + mOpenCLBackend = static_cast(backend); +} + +bool RopeBufExecution::onClone(Backend* bn, const Op* op, Execution** dst) { + if (nullptr == dst) { + return true; + } + *dst = new RopeBufExecution(op, bn, mRopeCutHeadDim, mQGamma, mQEps, mKGamma, mKEps); + return true; +} + +ErrorCode RopeBufExecution::onEncode(const std::vector& inputs, const std::vector& outputs) { + MNN_ASSERT(inputs.size() == 6); + MNN_ASSERT(outputs.size() == 2); + + auto q = inputs[0]; + auto k = inputs[1]; + + int batch = q->length(0); + int seqLen = q->length(1); + int numHead = q->length(2); + int headDim = q->length(3); + int kvNumHead = k->length(2); + + int halfD = headDim / 2; + int ropeDim = mRopeCutHeadDim; + if (ropeDim <= 0 || ropeDim > headDim) { + ropeDim = headDim; + } + ropeDim = (ropeDim / 2) * 2; + int ropeHalfD = ropeDim / 2; + if (ropeHalfD > halfD) { + ropeHalfD = halfD; + } + + int outerSize = batch * seqLen; + int fullHead = numHead + kvNumHead; + + mUnits.resize(1); + auto& unit = mUnits[0]; + + auto runtime = mOpenCLBackend->getOpenCLRuntime(); + + std::set buildOptions; + if (mQGamma) { + buildOptions.emplace("-DQ_NORM"); + } + if (mKGamma) { + buildOptions.emplace("-DK_NORM"); + } + unit.kernel = runtime->buildKernel("rope_buf", "rope_buf", buildOptions, mOpenCLBackend->getPrecision()); + OPENCL_CHECK_KERNEL(unit.kernel); + mMaxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(unit.kernel)); + + if (mQGamma || mKGamma) { + mGlobalWorkSize = {1, static_cast(outerSize), static_cast(fullHead)}; + } else { + mGlobalWorkSize = {static_cast(halfD), static_cast(outerSize), + static_cast(fullHead)}; + } + + uint32_t idx = 0; + cl_int ret = CL_SUCCESS; + ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]); + ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[1]); + ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[2]); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(inputs[0])); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(inputs[1])); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(inputs[2])); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(inputs[3])); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(inputs[4])); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(inputs[5])); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(outputs[0])); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(outputs[1])); + ret |= unit.kernel->get().setArg(idx++, outerSize); + ret |= unit.kernel->get().setArg(idx++, halfD); + ret |= unit.kernel->get().setArg(idx++, ropeHalfD); + ret |= unit.kernel->get().setArg(idx++, headDim); + ret |= unit.kernel->get().setArg(idx++, numHead); + ret |= unit.kernel->get().setArg(idx++, kvNumHead); + if (mQGamma) { + ret |= unit.kernel->get().setArg(idx++, *mQGamma); + ret |= unit.kernel->get().setArg(idx++, mQEps); + } + if (mKGamma) { + ret |= unit.kernel->get().setArg(idx++, *mKGamma); + ret |= unit.kernel->get().setArg(idx++, mKEps); + } + MNN_CHECK_CL_SUCCESS(ret, "setArg RopeBufExecution"); + + mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, runtime, "rope_buf", unit.kernel, + mOpenCLBackend->getCLTuneLevel(), "rope_buf") + .first; + + mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); + + unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; + unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; + + return NO_ERROR; +} + +class RopeBufCreator : public OpenCLBackend::Creator { +public: + virtual Execution* onCreate(const std::vector& inputs, const std::vector& outputs, + const MNN::Op* op, Backend* backend) const override { + for (int i = 0; i < inputs.size(); ++i) { + TensorUtils::setTensorSupportPack(inputs[i], false); + } + for (int i = 0; i < outputs.size(); ++i) { + TensorUtils::setTensorSupportPack(outputs[i], false); + } + OPENCL_CREATOR_CHECK(new RopeBufExecution(op, backend)); + } +}; + +REGISTER_OPENCL_OP_CREATOR_TRANSFORMER(RopeBufCreator, OpType_RoPE, BUFFER); + +} // namespace OpenCL +} // namespace MNN + +#endif /* MNN_SUPPORT_TRANSFORMER_FUSE */ diff --git a/source/backend/opencl/execution/buffer/RopeBufExecution.hpp b/source/backend/opencl/execution/buffer/RopeBufExecution.hpp new file mode 100644 index 0000000000..5e26ce9b53 --- /dev/null +++ b/source/backend/opencl/execution/buffer/RopeBufExecution.hpp @@ -0,0 +1,44 @@ +// +// RopeBufExecution.hpp +// MNN +// +// OpenCL buffer-path implementation of RoPE (Rotary Positional Embedding). +// + +#ifdef MNN_SUPPORT_TRANSFORMER_FUSE + +#ifndef RopeBufExecution_hpp +#define RopeBufExecution_hpp + +#include "backend/opencl/execution/image/CommonExecution.hpp" + +namespace MNN { +namespace OpenCL { + +class RopeBufExecution : public CommonExecution { +public: + RopeBufExecution(const MNN::Op* op, Backend* backend); + RopeBufExecution(const MNN::Op* op, Backend* backend, int ropeCutHeadDim, std::shared_ptr qGamma, + float qEps, std::shared_ptr kGamma, float kEps); + virtual ~RopeBufExecution() = default; + + virtual ErrorCode onEncode(const std::vector& inputs, const std::vector& outputs) override; + virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override; + +private: + OpenCLBackend* mOpenCLBackend = nullptr; + uint32_t mMaxWorkGroupSize = 0; + std::vector mGlobalWorkSize = {1, 1, 1}; + std::vector mLocalWorkSize = {1, 1, 1}; + int mRopeCutHeadDim = 0; + std::shared_ptr mQGamma; + std::shared_ptr mKGamma; + float mQEps = 0.0f; + float mKEps = 0.0f; +}; + +} // namespace OpenCL +} // namespace MNN + +#endif /* RopeBufExecution_hpp */ +#endif /* MNN_SUPPORT_TRANSFORMER_FUSE */ diff --git a/source/backend/opencl/execution/buffer/SharedGatherBufExecution.cpp b/source/backend/opencl/execution/buffer/SharedGatherBufExecution.cpp new file mode 100644 index 0000000000..c2e985b4a2 --- /dev/null +++ b/source/backend/opencl/execution/buffer/SharedGatherBufExecution.cpp @@ -0,0 +1,117 @@ +#ifndef MNN_OPENCL_BUFFER_CLOSED + +#include + +#include "backend/opencl/execution/buffer/SharedGatherBufExecution.hpp" +#include "core/Macro.h" +#include "core/TensorUtils.hpp" + +namespace MNN { +namespace OpenCL { + +SharedGatherBufExecution::SharedGatherBufExecution(std::shared_ptr resource, const Op* op, + Backend* backend) + : CommonExecution(backend, op), + mOpenCLBackend(static_cast(backend)), + mResource(std::move(resource)) {} + +bool SharedGatherBufExecution::validResource(const std::shared_ptr& resource) { + if (!resource.get() || !resource->mConv1x1Opt) { + return false; + } + if (!resource->mDequantScaleOffsetBuffer.get()) { + return false; + } + if (resource->mUseImage) { + if (!resource->mKernelImage.get()) { + return false; + } + } else if (!resource->mKernelBuffer.get()) { + return false; + } + return resource->mNumQuantBit == 4 || resource->mNumQuantBit == 8; +} + +bool SharedGatherBufExecution::onClone(Backend* bn, const Op* op, Execution** dst) { + if (!mValid) { + return false; + } + if (nullptr == dst) { + return true; + } + *dst = new SharedGatherBufExecution(mResource, op, bn); + return true; +} + +ErrorCode SharedGatherBufExecution::onEncode(const std::vector& inputs, const std::vector& outputs) { + mUnits.resize(1); + auto& unit = mUnits[0]; + auto runtime = mOpenCLBackend->getOpenCLRuntime(); + + MNN_ASSERT(inputs.size() == 1); + MNN_ASSERT(outputs.size() == 1); + auto indices = inputs[0]; + auto output = outputs[0]; + + if (!validResource(mResource)) { + return NOT_SUPPORT; + } + + const int selectSize = indices->elementSize(); + const int ic = output->length(output->dimensions() - 1); + const int oc = mResource->mOutputChannel; + if (selectSize <= 0 || ic <= 0 || oc <= 0) { + return NOT_SUPPORT; + } + const int blockSize = mResource->mBlockSize; + if (blockSize <= 0) { + return NOT_SUPPORT; + } + + std::set buildOptions; + if (mResource->mNumQuantBit == 8) { + buildOptions.emplace("-DUSE_LOW_BIT_WEIGHT_INT8"); + } else { + buildOptions.emplace("-DUSE_LOW_BIT_WEIGHT_INT4"); + } + if (mResource->mBuildOptions.find("-DASYMMETRIC") != mResource->mBuildOptions.end()) { + buildOptions.emplace("-DASYMMETRIC"); + } + const char* kernelName = mResource->mUseImage ? "shared_gather_quant_image" : "shared_gather_quant_buffer"; + unit.kernel = + runtime->buildKernel("shared_gather_buf", kernelName, buildOptions, mOpenCLBackend->getPrecision(), indices, + output); + OPENCL_CHECK_KERNEL(unit.kernel); + + mGWS = {(uint32_t)selectSize, (uint32_t)UP_DIV(ic, 4)}; + mLWS = {16, 16}; + + uint32_t idx = 0; + cl_int ret = CL_SUCCESS; + ret |= unit.kernel->get().setArg(idx++, (int)mGWS[0]); + ret |= unit.kernel->get().setArg(idx++, (int)mGWS[1]); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); + if (mResource->mUseImage) { + ret |= unit.kernel->get().setArg(idx++, *mResource->mKernelImage.get()); + } else { + ret |= unit.kernel->get().setArg(idx++, *mResource->mKernelBuffer.get()); + } + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(indices)); + ret |= unit.kernel->get().setArg(idx++, *mResource->mDequantScaleOffsetBuffer.get()); + ret |= unit.kernel->get().setArg(idx++, ic); + ret |= unit.kernel->get().setArg(idx++, oc); + ret |= unit.kernel->get().setArg(idx++, ic / blockSize); + ret |= unit.kernel->get().setArg(idx++, mResource->mCoef); + MNN_CHECK_CL_SUCCESS(ret, "setArg SharedGatherBufExecution"); + + mOpenCLBackend->recordKernel2d(unit.kernel, mGWS, mLWS); + unit.globalWorkSize = {mGWS[0], mGWS[1]}; + unit.localWorkSize = {mLWS[0], mLWS[1]}; + + return NO_ERROR; +} + +} // namespace OpenCL +} // namespace MNN + +#endif /* MNN_OPENCL_BUFFER_CLOSED */ diff --git a/source/backend/opencl/execution/buffer/SharedGatherBufExecution.hpp b/source/backend/opencl/execution/buffer/SharedGatherBufExecution.hpp new file mode 100644 index 0000000000..bdeaeee278 --- /dev/null +++ b/source/backend/opencl/execution/buffer/SharedGatherBufExecution.hpp @@ -0,0 +1,31 @@ +#ifndef MNN_OPENCL_BUFFER_CLOSED +#ifndef SharedGatherBufExecution_hpp +#define SharedGatherBufExecution_hpp + +#include "backend/opencl/execution/buffer/ConvBufExecution.hpp" +#include "backend/opencl/execution/image/CommonExecution.hpp" + +namespace MNN { +namespace OpenCL { + +class SharedGatherBufExecution : public CommonExecution { +public: + SharedGatherBufExecution(std::shared_ptr resource, const Op* op, Backend* backend); + virtual ~SharedGatherBufExecution() = default; + + static bool validResource(const std::shared_ptr& resource); + virtual ErrorCode onEncode(const std::vector& inputs, const std::vector& outputs) override; + virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override; + +private: + OpenCLBackend* mOpenCLBackend = nullptr; + std::shared_ptr mResource; + std::vector mGWS{0, 0}; + std::vector mLWS{0, 0}; +}; + +} // namespace OpenCL +} // namespace MNN + +#endif /* SharedGatherBufExecution_hpp */ +#endif /* MNN_OPENCL_BUFFER_CLOSED */ diff --git a/source/backend/opencl/execution/cl/attention_buf.cl b/source/backend/opencl/execution/cl/attention_buf.cl index 67495c5673..0cfbb23280 100644 --- a/source/backend/opencl/execution/cl/attention_buf.cl +++ b/source/backend/opencl/execution/cl/attention_buf.cl @@ -493,6 +493,8 @@ __kernel void matmul_qk_div_mask_prefill(GLOBAL_SIZE_3_DIMS __global const FLOAT* mask, #elif defined(SET_MASK) __global const int* mask, // [1 1 query_seq_len mask_key_seq_len] + #else + __global const FLOAT* mask, #endif __global FLOAT *qk, // [batch head_num kv_seq_length query_seq_len_4] __private const float scale, @@ -573,6 +575,34 @@ __kernel void matmul_qk_div_mask_prefill(GLOBAL_SIZE_3_DIMS out1 = (mask1 == (float4)0) ? (float4)(-FLT_MAX) : out1; out2 = (mask2 == (float4)0) ? (float4)(-FLT_MAX) : out2; out3 = (mask3 == (float4)0) ? (float4)(-FLT_MAX) : out3; + #elif defined(DEFAULT_MASK) + { + int kv_valid_offset = key_seq_len - query_seq_len; + int k0 = y4 + 0; + int k1 = y4 + 1; + int k2 = y4 + 2; + int k3 = y4 + 3; + int q0 = x4 + 0; + int q1 = x4 + 1; + int q2 = x4 + 2; + int q3 = x4 + 3; + if (k0 > kv_valid_offset + q0) { out0.s0 = -FLT_MAX; } + if (k1 > kv_valid_offset + q0) { out1.s0 = -FLT_MAX; } + if (k2 > kv_valid_offset + q0) { out2.s0 = -FLT_MAX; } + if (k3 > kv_valid_offset + q0) { out3.s0 = -FLT_MAX; } + if (k0 > kv_valid_offset + q1) { out0.s1 = -FLT_MAX; } + if (k1 > kv_valid_offset + q1) { out1.s1 = -FLT_MAX; } + if (k2 > kv_valid_offset + q1) { out2.s1 = -FLT_MAX; } + if (k3 > kv_valid_offset + q1) { out3.s1 = -FLT_MAX; } + if (k0 > kv_valid_offset + q2) { out0.s2 = -FLT_MAX; } + if (k1 > kv_valid_offset + q2) { out1.s2 = -FLT_MAX; } + if (k2 > kv_valid_offset + q2) { out2.s2 = -FLT_MAX; } + if (k3 > kv_valid_offset + q2) { out3.s2 = -FLT_MAX; } + if (k0 > kv_valid_offset + q3) { out0.s3 = -FLT_MAX; } + if (k1 > kv_valid_offset + q3) { out1.s3 = -FLT_MAX; } + if (k2 > kv_valid_offset + q3) { out2.s3 = -FLT_MAX; } + if (k3 > kv_valid_offset + q3) { out3.s3 = -FLT_MAX; } + } #endif } @@ -644,7 +674,8 @@ __kernel void matmul_qkv_prefill(GLOBAL_SIZE_3_DIMS __private const int max_len, __private const int head_num, __private const int kv_head_num, - __private const int head_dim) { + __private const int head_dim, + __private const int batch) { const int x = get_global_id(0); // head_dim const int y = get_global_id(1); // query_seq_len @@ -704,6 +735,24 @@ __kernel void matmul_qkv_prefill(GLOBAL_SIZE_3_DIMS out3 = mad((COMPUTE_FLOAT8)qk_vec.s3, past_vec, out3); } +#ifdef ATTENTION_C4 + int output_offset = (z * head_dim + x8) * query_seq_len * batch + (b * query_seq_len + y4) * 4; + const int stride = query_seq_len * batch * 4; + vstore4(CONVERT_FLOAT4(out0.lo), 0, output + output_offset); + vstore4(CONVERT_FLOAT4(out0.hi), 0, output + output_offset + stride); + if(y4 + 1 >= query_seq_len) return; + output_offset += 4; + vstore4(CONVERT_FLOAT4(out1.lo), 0, output + output_offset); + vstore4(CONVERT_FLOAT4(out1.hi), 0, output + output_offset + stride); + if(y4 + 2 >= query_seq_len) return; + output_offset += 4; + vstore4(CONVERT_FLOAT4(out2.lo), 0, output + output_offset); + vstore4(CONVERT_FLOAT4(out2.hi), 0, output + output_offset + stride); + if(y4 + 3 >= query_seq_len) return; + output_offset += 4; + vstore4(CONVERT_FLOAT4(out3.lo), 0, output + output_offset); + vstore4(CONVERT_FLOAT4(out3.hi), 0, output + output_offset + stride); +#else const int output_offset = ((b * query_seq_len + y4) * head_num + z) * head_dim + x8; const int stride = head_num * head_dim; vstore8(CONVERT_FLOAT8(out0), 0, output + output_offset); @@ -713,6 +762,7 @@ __kernel void matmul_qkv_prefill(GLOBAL_SIZE_3_DIMS vstore8(CONVERT_FLOAT8(out2), 0, output + output_offset + stride + stride); if(y4 + 3 >= query_seq_len) return; vstore8(CONVERT_FLOAT8(out3), 0, output + output_offset + stride + stride + stride); +#endif } @@ -865,4 +915,3 @@ __kernel void matmul_qkv_decode_b4(GLOBAL_SIZE_2_DIMS const int output_offset = y * head_dim + x4; vstore4(CONVERT_FLOAT4(out0), 0, output + output_offset); } - diff --git a/source/backend/opencl/execution/cl/attention_buf_mnn_cl.cpp b/source/backend/opencl/execution/cl/attention_buf_mnn_cl.cpp index 94fea8f0d8..78a89ea80c 100644 --- a/source/backend/opencl/execution/cl/attention_buf_mnn_cl.cpp +++ b/source/backend/opencl/execution/cl/attention_buf_mnn_cl.cpp @@ -435,6 +435,8 @@ const char* attention_buf = " __global const FLOAT* mask,\n" " #elif defined(SET_MASK)\n" " __global const int* mask,// [1 1 query_seq_len mask_key_seq_len]\n" +" #else\n" +" __global const FLOAT* mask,\n" " #endif\n" " __global FLOAT *qk,// [batch head_num kv_seq_length query_seq_len_4]\n" " __private const float scale,\n" @@ -514,6 +516,34 @@ const char* attention_buf = " out1=(mask1 == (float4)0) ? (float4)(-FLT_MAX) : out1;\n" " out2=(mask2 == (float4)0) ? (float4)(-FLT_MAX) : out2;\n" " out3=(mask3 == (float4)0) ? (float4)(-FLT_MAX) : out3;\n" +" #elif defined(DEFAULT_MASK)\n" +" {\n" +" int kv_valid_offset=key_seq_len-query_seq_len;\n" +" int k0=y4+0;\n" +" int k1=y4+1;\n" +" int k2=y4+2;\n" +" int k3=y4+3;\n" +" int q0=x4+0;\n" +" int q1=x4+1;\n" +" int q2=x4+2;\n" +" int q3=x4+3;\n" +" if (k0>kv_valid_offset+q0) { out0.s0=-FLT_MAX; }\n" +" if (k1>kv_valid_offset+q0) { out1.s0=-FLT_MAX; }\n" +" if (k2>kv_valid_offset+q0) { out2.s0=-FLT_MAX; }\n" +" if (k3>kv_valid_offset+q0) { out3.s0=-FLT_MAX; }\n" +" if (k0>kv_valid_offset+q1) { out0.s1=-FLT_MAX; }\n" +" if (k1>kv_valid_offset+q1) { out1.s1=-FLT_MAX; }\n" +" if (k2>kv_valid_offset+q1) { out2.s1=-FLT_MAX; }\n" +" if (k3>kv_valid_offset+q1) { out3.s1=-FLT_MAX; }\n" +" if (k0>kv_valid_offset+q2) { out0.s2=-FLT_MAX; }\n" +" if (k1>kv_valid_offset+q2) { out1.s2=-FLT_MAX; }\n" +" if (k2>kv_valid_offset+q2) { out2.s2=-FLT_MAX; }\n" +" if (k3>kv_valid_offset+q2) { out3.s2=-FLT_MAX; }\n" +" if (k0>kv_valid_offset+q3) { out0.s3=-FLT_MAX; }\n" +" if (k1>kv_valid_offset+q3) { out1.s3=-FLT_MAX; }\n" +" if (k2>kv_valid_offset+q3) { out2.s3=-FLT_MAX; }\n" +" if (k3>kv_valid_offset+q3) { out3.s3=-FLT_MAX; }\n" +" }\n" " #endif\n" " }\n" " \n" @@ -583,7 +613,8 @@ const char* attention_buf = " __private const int max_len,\n" " __private const int head_num,\n" " __private const int kv_head_num,\n" -" __private const int head_dim) {\n" +" __private const int head_dim,\n" +" __private const int batch) {\n" " \n" " const int x=get_global_id(0); // head_dim\n" " const int y=get_global_id(1); // query_seq_len\n" @@ -643,6 +674,24 @@ const char* attention_buf = " out3=mad((COMPUTE_FLOAT8)qk_vec.s3,past_vec,out3);\n" " }\n" " \n" +"#ifdef ATTENTION_C4\n" +" int output_offset=(z*head_dim+x8)*query_seq_len*batch+(b*query_seq_len+y4)*4;\n" +" const int stride=query_seq_len*batch*4;\n" +" vstore4(CONVERT_FLOAT4(out0.lo),0,output+output_offset);\n" +" vstore4(CONVERT_FLOAT4(out0.hi),0,output+output_offset+stride);\n" +" if(y4+1 >= query_seq_len) return;\n" +" output_offset += 4;\n" +" vstore4(CONVERT_FLOAT4(out1.lo),0,output+output_offset);\n" +" vstore4(CONVERT_FLOAT4(out1.hi),0,output+output_offset+stride);\n" +" if(y4+2 >= query_seq_len) return;\n" +" output_offset += 4;\n" +" vstore4(CONVERT_FLOAT4(out2.lo),0,output+output_offset);\n" +" vstore4(CONVERT_FLOAT4(out2.hi),0,output+output_offset+stride);\n" +" if(y4+3 >= query_seq_len) return;\n" +" output_offset += 4;\n" +" vstore4(CONVERT_FLOAT4(out3.lo),0,output+output_offset);\n" +" vstore4(CONVERT_FLOAT4(out3.hi),0,output+output_offset+stride);\n" +"#else\n" " const int output_offset=((b*query_seq_len+y4)*head_num+z)*head_dim+x8;\n" " const int stride=head_num*head_dim;\n" " vstore8(CONVERT_FLOAT8(out0),0,output+output_offset);\n" @@ -652,6 +701,7 @@ const char* attention_buf = " vstore8(CONVERT_FLOAT8(out2),0,output+output_offset+stride+stride);\n" " if(y4+3 >= query_seq_len) return;\n" " vstore8(CONVERT_FLOAT8(out3),0,output+output_offset+stride+stride+stride);\n" +"#endif\n" "}\n" "__kernel void matmul_qkv_decode_b8(GLOBAL_SIZE_2_DIMS\n" " __global const FLOAT *qk,// qk [1 head_num qk_seq_len 1]\n" diff --git a/source/backend/opencl/execution/cl/layernorm_buf.cl b/source/backend/opencl/execution/cl/layernorm_buf.cl index 09a6ce6a42..73713e56c2 100644 --- a/source/backend/opencl/execution/cl/layernorm_buf.cl +++ b/source/backend/opencl/execution/cl/layernorm_buf.cl @@ -2,6 +2,242 @@ #pragma OPENCL EXTENSION cl_khr_fp16 : enable #endif +__kernel void layernorm_c4_buf(__private int global_dim0, __private int global_dim1, + __global const FLOAT4 * input, + __global FLOAT4 * output, + __private const int inside, +#ifdef GAMMA_BETA + __global const FLOAT4 *gamma, + __global const FLOAT4 *beta, +#endif + __private float epsilon){ + int2 pos = (int2)(get_global_id(0), get_global_id(1)); +#if LOCAL_SIZE > 1 + float4 local sum_mnn[LOCAL_SIZE]; + #ifndef RMSNORM + float4 local sum_mean_mnn[LOCAL_SIZE]; + #endif + if (pos.x < global_dim0 && pos.y < global_dim1) { + const int lid = get_local_id(0); + const int batch = global_dim1; + const int channelUnit = inside / 4; + + float4 in_sum = 0; + int index = lid; + #ifdef RMSNORM + float4 mean = (float4)0; + #else + for(; index < channelUnit; index+=LOCAL_SIZE){ + int idx = index * batch + pos.y; + float4 in = convert_float4(input[idx]); + in_sum += in; + } + sum_mean_mnn[lid] = in_sum; + + barrier(CLK_LOCAL_MEM_FENCE); + for(int i = LOCAL_SIZE/2; i > 0; i /= 2){ + if (lid < i) + sum_mean_mnn[lid] = sum_mean_mnn[lid] + sum_mean_mnn[lid + i]; + barrier(CLK_LOCAL_MEM_FENCE); + } + + float sum_all = sum_mean_mnn[0].x + sum_mean_mnn[0].y + sum_mean_mnn[0].z + sum_mean_mnn[0].w; + float4 mean = (float4)(sum_all / inside); + #endif + + in_sum = 0; + index = lid; + for(; index < channelUnit; index+=LOCAL_SIZE){ + int idx = index * batch + pos.y; + float4 in = convert_float4(input[idx]); + in_sum += (in - mean) * (in - mean); + } + sum_mnn[lid] = in_sum; + barrier(CLK_LOCAL_MEM_FENCE); + for(int i = LOCAL_SIZE/2; i > 0; i /= 2){ + if (lid < i) + sum_mnn[lid] = sum_mnn[lid] + sum_mnn[lid + i]; + barrier(CLK_LOCAL_MEM_FENCE); + } + float square_sum_all = sum_mnn[0].x + sum_mnn[0].y + sum_mnn[0].z + sum_mnn[0].w; + float4 square_sum = (float4)(square_sum_all / inside); + float4 value = (float4)1.0f / (float4)sqrt(square_sum + (float4)epsilon); + index = lid; + for(; index < channelUnit; index+=LOCAL_SIZE){ + int idx = index * batch + pos.y; + float4 in = convert_float4(input[idx]); + #ifdef GAMMA_BETA + float4 out = (in - mean) * value * convert_float4(gamma[index]) + convert_float4(beta[index]); + #else + float4 out = (in - mean) * value; + #endif + output[idx] = CONVERT_FLOAT4(out); + } + } +#else + if (pos.x < global_dim0 && pos.y < global_dim1) { + const int batch = global_dim1; + const int channelUnit = inside / 4; + + float4 in_sum = 0; + #ifdef RMSNORM + float4 mean = (float4)0; + #else + for(int index = 0; index < channelUnit; index++){ + int idx = index * batch + pos.y; + float4 in = convert_float4(input[idx]); + in_sum += in; + } + float sum_all = in_sum.x + in_sum.y + in_sum.z + in_sum.w; + float4 mean = (float4)(sum_all / inside); + #endif + + in_sum = 0; + for(int index = 0; index < channelUnit; index++){ + int idx = index * batch + pos.y; + float4 in = convert_float4(input[idx]); + in_sum += (in - mean) * (in - mean); + } + float square_sum_all = in_sum.x + in_sum.y + in_sum.z + in_sum.w; + float4 square_sum = (float4)(square_sum_all / inside); + float4 value = (float4)1.0f / (float4)sqrt(square_sum + (float4)epsilon); + int idx = pos.x * batch + pos.y; + float4 in = convert_float4(input[idx]); + #ifdef GAMMA_BETA + float4 out = (in - mean) * value * convert_float4(gamma[pos.x]) + convert_float4(beta[pos.x]); + #else + float4 out = (in - mean) * value; + #endif + output[idx] = CONVERT_FLOAT4(out); + } +#endif +} + +__kernel void binary_layernorm_c4_buf(__private int global_dim0, __private int global_dim1, + __global const FLOAT4 * input0, + __global const FLOAT4 * input1, + __global FLOAT4 * output0, + __global FLOAT4 * output1, + __private const int inside, +#ifdef GAMMA_BETA + __global const FLOAT4 *gamma, + __global const FLOAT4 *beta, +#endif + __private float epsilon){ + int2 pos = (int2)(get_global_id(0), get_global_id(1)); +#if LOCAL_SIZE > 1 + float4 local sum_mnn[LOCAL_SIZE]; + #ifndef RMSNORM + float4 local sum_mean_mnn[LOCAL_SIZE]; + #endif + if (pos.x < global_dim0 && pos.y < global_dim1) { + const int lid = get_local_id(0); + const int batch = global_dim1; + const int channelUnit = inside / 4; + + float4 in_sum = 0; + int index = lid; + #ifdef RMSNORM + float4 mean = (float4)0; + #else + for(; index < channelUnit; index+=LOCAL_SIZE){ + int idx = index * batch + pos.y; + float4 in = convert_float4(input0[idx]) + convert_float4(input1[idx]); + output0[idx] = CONVERT_FLOAT4(in); + in_sum += in; + } + sum_mean_mnn[lid] = in_sum; + + barrier(CLK_LOCAL_MEM_FENCE); + for(int i = LOCAL_SIZE/2; i > 0; i /= 2){ + if (lid < i) + sum_mean_mnn[lid] = sum_mean_mnn[lid] + sum_mean_mnn[lid + i]; + barrier(CLK_LOCAL_MEM_FENCE); + } + + float sum_all = sum_mean_mnn[0].x + sum_mean_mnn[0].y + sum_mean_mnn[0].z + sum_mean_mnn[0].w; + float4 mean = (float4)(sum_all / inside); + #endif + + in_sum = 0; + index = lid; + for(; index < channelUnit; index+=LOCAL_SIZE){ + int idx = index * batch + pos.y; + #ifdef RMSNORM + float4 in = convert_float4(input0[idx]) + convert_float4(input1[idx]); + output0[idx] = CONVERT_FLOAT4(in); + #else + float4 in = convert_float4(output0[idx]); + #endif + in_sum += (in - mean) * (in - mean); + } + sum_mnn[lid] = in_sum; + barrier(CLK_LOCAL_MEM_FENCE); + for(int i = LOCAL_SIZE/2; i > 0; i /= 2){ + if (lid < i) + sum_mnn[lid] = sum_mnn[lid] + sum_mnn[lid + i]; + barrier(CLK_LOCAL_MEM_FENCE); + } + float square_sum_all = sum_mnn[0].x + sum_mnn[0].y + sum_mnn[0].z + sum_mnn[0].w; + float4 square_sum = (float4)(square_sum_all / inside); + float4 value = (float4)1.0f / (float4)sqrt(square_sum + (float4)epsilon); + index = lid; + for(; index < channelUnit; index+=LOCAL_SIZE){ + int idx = index * batch + pos.y; + float4 in = convert_float4(output0[idx]); + #ifdef GAMMA_BETA + float4 out = (in - mean) * value * convert_float4(gamma[index]) + convert_float4(beta[index]); + #else + float4 out = (in - mean) * value; + #endif + output1[idx] = CONVERT_FLOAT4(out); + } + } +#else + if (pos.x < global_dim0 && pos.y < global_dim1) { + const int batch = global_dim1; + const int channelUnit = inside / 4; + + float4 in_sum = 0; + #ifdef RMSNORM + float4 mean = (float4)0; + #else + for(int index = 0; index < channelUnit; index++){ + int idx = index * batch + pos.y; + float4 in = convert_float4(input0[idx]) + convert_float4(input1[idx]); + output0[idx] = CONVERT_FLOAT4(in); + in_sum += in; + } + float sum_all = in_sum.x + in_sum.y + in_sum.z + in_sum.w; + float4 mean = (float4)(sum_all / inside); + #endif + + in_sum = 0; + for(int index = 0; index < channelUnit; index++){ + int idx = index * batch + pos.y; + #ifdef RMSNORM + float4 in = convert_float4(input0[idx]) + convert_float4(input1[idx]); + output0[idx] = CONVERT_FLOAT4(in); + #else + float4 in = convert_float4(output0[idx]); + #endif + in_sum += (in - mean) * (in - mean); + } + float square_sum_all = in_sum.x + in_sum.y + in_sum.z + in_sum.w; + float4 square_sum = (float4)(square_sum_all / inside); + float4 value = (float4)1.0f / (float4)sqrt(square_sum + (float4)epsilon); + int idx = pos.x * batch + pos.y; + float4 in = convert_float4(output0[idx]); + #ifdef GAMMA_BETA + float4 out = (in - mean) * value * convert_float4(gamma[pos.x]) + convert_float4(beta[pos.x]); + #else + float4 out = (in - mean) * value; + #endif + output1[idx] = CONVERT_FLOAT4(out); + } +#endif +} + __kernel void layernorm_buf(__private int global_dim0, __private int global_dim1, __global const FLOAT * input, __global FLOAT * output, diff --git a/source/backend/opencl/execution/cl/layernorm_buf_mnn_cl.cpp b/source/backend/opencl/execution/cl/layernorm_buf_mnn_cl.cpp index 731c70b3d4..57941b5078 100644 --- a/source/backend/opencl/execution/cl/layernorm_buf_mnn_cl.cpp +++ b/source/backend/opencl/execution/cl/layernorm_buf_mnn_cl.cpp @@ -5,6 +5,228 @@ const char* layernorm_buf = "#ifdef MNN_SUPPORT_FP16\n" "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" "#endif\n" +"__kernel void layernorm_c4_buf(__private int global_dim0,__private int global_dim1,\n" +" __global const FLOAT4*input,\n" +" __global FLOAT4*output,\n" +" __private const int inside,\n" +"#ifdef GAMMA_BETA\n" +" __global const FLOAT4 *gamma,\n" +" __global const FLOAT4 *beta,\n" +"#endif\n" +" __private float epsilon){\n" +" int2 pos=(int2)(get_global_id(0),get_global_id(1));\n" +"#if LOCAL_SIZE>1\n" +" float4 local sum_mnn[LOCAL_SIZE];\n" +" #ifndef RMSNORM\n" +" float4 local sum_mean_mnn[LOCAL_SIZE];\n" +" #endif\n" +" if (pos.x0; i /= 2){\n" +" if (lid0; i /= 2){\n" +" if (lid1\n" +" float4 local sum_mnn[LOCAL_SIZE];\n" +" #ifndef RMSNORM\n" +" float4 local sum_mean_mnn[LOCAL_SIZE];\n" +" #endif\n" +" if (pos.x0; i /= 2){\n" +" if (lid0; i /= 2){\n" +" if (lid OpenCLProgramMap = { #endif #endif {"nearest", nearest}, +#ifndef MNN_OPENCL_BUFFER_CLOSED + {"rope_buf", rope_buf}, +#endif #ifndef MNN_OPENCL_BUFFER_CLOSED #ifdef MNN_SUPPORT_INTEL_SUBGROUP {"pooling_subgroup_buf", pooling_subgroup_buf}, @@ -298,6 +307,9 @@ const std::map OpenCLProgramMap = { #endif #ifndef MNN_OPENCL_BUFFER_CLOSED {"matmul_buf", matmul_buf}, +#endif +#ifndef MNN_OPENCL_BUFFER_CLOSED + {"shared_gather_buf", shared_gather_buf}, #endif {"pooling", pooling}, #ifndef MNN_OPENCL_BUFFER_CLOSED @@ -373,6 +385,7 @@ const std::map OpenCLProgramMd5Map = { {"binary_subgroup_buf", "8444f988543cd4a4d9b124442f02f999"}, {"depthwise_conv2d_subgroup_buf", "3e37457e72b7e629655aa04bd03e559e"}, {"nearest", "e8b2081c5e50ae6d370989f816cda543"}, + {"rope_buf", "bc211ff80619392e567e1bb9b1a2c80f"}, {"pooling_subgroup_buf", "9c935c0caabe2ee20822fcfd7722472e"}, {"pooling_buf", "806c95095431e361be2af7f4e9eae65e"}, {"winogradTransformSource2_5_1", "f0ee12556faf4fe0222e2a4e64c53c5c"}, @@ -391,7 +404,7 @@ const std::map OpenCLProgramMd5Map = { {"loop", "4849a55cd99f0ebab72a10527455341f"}, {"argmax_buf", "ae4a1ae3461b2758609022ac7569b11b"}, {"buffer_convert_subgroup_buf", "d968b717e537464a7fa08e742c9a0319"}, - {"attention_buf", "7d05b22865927ca19dae5762ba6f1df9"}, + {"attention_buf", "0863fdf099277eda2ce579ba3b014ed8"}, {"groupnorm_buf", "7f4b041b77ba98165ab624d94444f327"}, {"unary_subgroup_buf", "31e3768f899da6da45084f617b13c282"}, {"gemm", "5729018147348682e02762ed5ec14d0c"}, @@ -400,12 +413,13 @@ const std::map OpenCLProgramMd5Map = { {"range", "97feaf25d837a325382c162ad77ae0ca"}, {"scale_buf", "9176b8e86fd4d326e7fa14640ce13b48"}, {"matmul_buf", "b66faece7f0591d49c289e5227d9f680"}, + {"shared_gather_buf", "74b2cbe87698151c5c3cf718fa279cd4"}, {"pooling", "900d1388836badea36a7e06ad7763b0d"}, {"conv_2d_buf", "2faa0378ab0d702419a92ecc2073851a"}, {"gemm_int", "4e64d43a8ca423a9d0dc68dcfcd64c06"}, {"buffer_to_image", "bad95040692206db84b5a1bcc0b6f248"}, {"winogradTransformDest2_3_1", "f2aaa52d652565e70a44868d4f6028e9"}, - {"layernorm_buf", "5f6b88b29da72f51bdc85064b5663bb2"}, + {"layernorm_buf", "971da88e7c1f885d58a81a59a31a88a2"}, {"softmax_buf", "12052d403f3fa0cdfea2559296e88e6c"}, {"conv_2d_c16_subgroup_buf", "81f9027f323b6890d08d49dab10a15e4"}, {"input_transe_buf", "c80482cd531add8582edc242bcbfa947"}, diff --git a/source/backend/opencl/execution/cl/rope_buf.cl b/source/backend/opencl/execution/cl/rope_buf.cl new file mode 100644 index 0000000000..fc19fb161e --- /dev/null +++ b/source/backend/opencl/execution/cl/rope_buf.cl @@ -0,0 +1,131 @@ +#ifdef MNN_SUPPORT_FP16 +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#endif + +#define GLOBAL_SIZE_3_DIMS \ + __private const int global_size_dim0, __private const int global_size_dim1, __private const int global_size_dim2, + +#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) \ + if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { \ + return; \ + } + +__kernel void rope_buf(GLOBAL_SIZE_3_DIMS + __global const FLOAT *q, + __global const FLOAT *k, + __global const FLOAT *cosEven, + __global const FLOAT *cosOdd, + __global const FLOAT *sinEven, + __global const FLOAT *sinOdd, + __global FLOAT *q_out, + __global FLOAT *k_out, + __private const int outerSize, + __private const int halfD, + __private const int ropeHalfD, + __private const int headDim, + __private const int numHead, + __private const int kvNumHead +#ifdef Q_NORM + , __global const float *qGamma + , __private const float qEps +#endif +#ifdef K_NORM + , __global const float *kGamma + , __private const float kEps +#endif + ) { + const int x = get_global_id(0); + const int y = get_global_id(1); + const int z = get_global_id(2); + DEAL_NON_UNIFORM_DIM3(x, y, z); + + const int fullHead = numHead + kvNumHead; +#if defined(Q_NORM) || defined(K_NORM) + if (x >= 1 || y >= outerSize || z >= fullHead) { + return; + } +#else + if (x >= halfD || y >= outerSize || z >= fullHead) { + return; + } +#endif + + const int D = headDim; + bool isQ = (z < numHead); + __global const FLOAT *in_ptr = isQ ? (q + (y * numHead + z) * D) : (k + (y * kvNumHead + z - numHead) * D); + __global FLOAT *out_ptr = isQ ? (q_out + (y * numHead + z) * D) : (k_out + (y * kvNumHead + z - numHead) * D); + + float var = 0.0f; +#ifdef Q_NORM + if (isQ) { + for (int i = 0; i < D; ++i) { + float val = (float)in_ptr[i]; + var += val * val; + } + var = 1.0f / sqrt(var / D + qEps); + } +#endif +#ifdef K_NORM + if (!isQ) { + for (int i = 0; i < D; ++i) { + float val = (float)in_ptr[i]; + var += val * val; + } + var = 1.0f / sqrt(var / D + kEps); + } +#endif + +#if defined(Q_NORM) || defined(K_NORM) + for (int i = 0; i < halfD; ++i) { + const int cosIndex = y * halfD + i; + FLOAT cEven = cosEven[cosIndex]; + FLOAT cOdd = cosOdd[cosIndex]; + FLOAT sEven = sinEven[cosIndex]; + FLOAT sOdd = sinOdd[cosIndex]; + + FLOAT evenVal = in_ptr[i]; + FLOAT oddVal = in_ptr[i + halfD]; +#ifdef Q_NORM + if (isQ) { + evenVal = (FLOAT)((float)evenVal * var * qGamma[i]); + oddVal = (FLOAT)((float)oddVal * var * qGamma[i + halfD]); + } +#endif +#ifdef K_NORM + if (!isQ) { + evenVal = (FLOAT)((float)evenVal * var * kGamma[i]); + oddVal = (FLOAT)((float)oddVal * var * kGamma[i + halfD]); + } +#endif + + if (i < ropeHalfD) { + FLOAT v0 = evenVal * cEven - oddVal * sEven; + FLOAT v1 = oddVal * cOdd + evenVal * sOdd; + out_ptr[i] = v0; + out_ptr[i + halfD] = v1; + } else { + out_ptr[i] = evenVal; + out_ptr[i + halfD] = oddVal; + } + } +#else + const int cosIndex = y * halfD + x; + FLOAT cEven = cosEven[cosIndex]; + FLOAT cOdd = cosOdd[cosIndex]; + FLOAT sEven = sinEven[cosIndex]; + FLOAT sOdd = sinOdd[cosIndex]; + + FLOAT evenVal = in_ptr[x]; + FLOAT oddVal = in_ptr[x + halfD]; + + if (x < ropeHalfD) { + FLOAT v0 = evenVal * cEven - oddVal * sEven; + FLOAT v1 = oddVal * cOdd + evenVal * sOdd; + out_ptr[x] = v0; + out_ptr[x + halfD] = v1; + } else { + out_ptr[x] = evenVal; + out_ptr[x + halfD] = oddVal; + } +#endif +} diff --git a/source/backend/opencl/execution/cl/rope_buf_mnn_cl.cpp b/source/backend/opencl/execution/cl/rope_buf_mnn_cl.cpp new file mode 100644 index 0000000000..f817d63268 --- /dev/null +++ b/source/backend/opencl/execution/cl/rope_buf_mnn_cl.cpp @@ -0,0 +1,123 @@ +#include "opencl_source_map.hpp" +namespace MNN { +#ifndef MNN_OPENCL_BUFFER_CLOSED +const char* rope_buf = +"#ifdef MNN_SUPPORT_FP16\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" +"#endif\n" +"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n" +"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n" +"__kernel void rope_buf(GLOBAL_SIZE_3_DIMS\n" +" __global const FLOAT *q,\n" +" __global const FLOAT *k,\n" +" __global const FLOAT *cosEven,\n" +" __global const FLOAT *cosOdd,\n" +" __global const FLOAT *sinEven,\n" +" __global const FLOAT *sinOdd,\n" +" __global FLOAT *q_out,\n" +" __global FLOAT *k_out,\n" +" __private const int outerSize,\n" +" __private const int halfD,\n" +" __private const int ropeHalfD,\n" +" __private const int headDim,\n" +" __private const int numHead,\n" +" __private const int kvNumHead\n" +"#ifdef Q_NORM\n" +" ,__global const float *qGamma\n" +" ,__private const float qEps\n" +"#endif\n" +"#ifdef K_NORM\n" +" ,__global const float *kGamma\n" +" ,__private const float kEps\n" +"#endif\n" +" ) {\n" +" const int x=get_global_id(0);\n" +" const int y=get_global_id(1);\n" +" const int z=get_global_id(2);\n" +" DEAL_NON_UNIFORM_DIM3(x,y,z);\n" +" const int fullHead=numHead+kvNumHead;\n" +"#if defined(Q_NORM) || defined(K_NORM)\n" +" if (x >= 1 || y >= outerSize || z >= fullHead) {\n" +" return;\n" +" }\n" +"#else\n" +" if (x >= halfD || y >= outerSize || z >= fullHead) {\n" +" return;\n" +" }\n" +"#endif\n" +" const int D=headDim;\n" +" bool isQ=(z= global_size_dim0 || (input2) >= global_size_dim1) { \ + return; \ + } + +__constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + +__kernel void shared_gather_quant_buffer( + GLOBAL_SIZE_2_DIMS + __global OUTPUT_TYPE* output, +#ifdef USE_LOW_BIT_WEIGHT_INT8 + __global const char* weight, +#elif defined(USE_LOW_BIT_WEIGHT_INT4) + __global const uchar* weight, +#else + __global const FLOAT* weight, +#endif + __global const int* indices, + __global const FLOAT* dequantScaleOffset, + __private const int ic, + __private const int oc, + __private const int blockSize, + __private const float coef +) { + const int select_idx = get_global_id(0); + const int k4 = get_global_id(1); + DEAL_NON_UNIFORM_DIM2(select_idx, k4); + + const int base_ic = k4 << 2; + if (base_ic >= ic) { + return; + } + + const int ocIndex = indices[select_idx]; + if (ocIndex < 0 || ocIndex >= oc) { + return; + } + + const int icC4 = (ic + 3) >> 2; + const int out_c_idx = ocIndex >> 2; + const int oc_in4 = ocIndex & 3; + const int ocBlock = ocIndex >> 3; + const int oc_in8 = ocIndex & 7; + const int dstChannelC4 = ((oc + 3) >> 2) << 2; + const int tileIndex = ocBlock * icC4 + k4; + +#ifdef USE_LOW_BIT_WEIGHT_INT8 + const int weightTileStride = 32; + const int weightBase = tileIndex * weightTileStride; +#elif defined(USE_LOW_BIT_WEIGHT_INT4) + const int weightTileStride = 16; + const int weightBase = tileIndex * weightTileStride; +#else + const int weightTileStride = 0; + const int weightBase = 0; +#endif + + const int outBase = select_idx * ic + base_ic; + COMPUTE_FLOAT4 out4 = (COMPUTE_FLOAT4)(0, 0, 0, 0); + + for (int i = 0; i < 4; ++i) { + const int icIndex = base_ic + i; + if (icIndex >= ic) { + break; + } + + const int blockIndex = icIndex / blockSize; + const int channelIndex = (out_c_idx << 2) + oc_in4; + int scaleIndex = blockIndex * dstChannelC4 + channelIndex; + +#ifdef ASYMMETRIC + scaleIndex = scaleIndex * 2; + FLOAT sRaw = dequantScaleOffset[scaleIndex + 0]; + FLOAT bRaw = dequantScaleOffset[scaleIndex + 1]; + COMPUTE_FLOAT scale = (COMPUTE_FLOAT)(convert_float(sRaw) / coef); + COMPUTE_FLOAT offset = (COMPUTE_FLOAT)(convert_float(bRaw) / coef); +#else + FLOAT sRaw = dequantScaleOffset[scaleIndex]; + COMPUTE_FLOAT scale = (COMPUTE_FLOAT)(convert_float(sRaw) / coef); + COMPUTE_FLOAT offset = (COMPUTE_FLOAT)0; +#endif + + COMPUTE_FLOAT wVal = (COMPUTE_FLOAT)0; +#ifdef USE_LOW_BIT_WEIGHT_INT8 + const int byteIndex = weightBase + i * 8 + oc_in8; + char qw = weight[byteIndex]; + wVal = (COMPUTE_FLOAT)qw; +#elif defined(USE_LOW_BIT_WEIGHT_INT4) + const int byteIndex = weightBase + i * 4 + (oc_in8 >> 1); + uchar packed = weight[byteIndex]; + int nibble = (oc_in8 & 1) == 0 ? ((packed >> 4) & 0x0F) : (packed & 0x0F); +#ifdef ASYMMETRIC + wVal = (COMPUTE_FLOAT)nibble; +#else + wVal = (COMPUTE_FLOAT)((int)nibble - 8); +#endif +#else + const int byteIndex = weightBase + i * 8 + oc_in8; + wVal = (COMPUTE_FLOAT)weight[byteIndex]; +#endif + + COMPUTE_FLOAT v = mad(wVal, scale, offset); + if (i == 0) { + out4.s0 = v; + } else if (i == 1) { + out4.s1 = v; + } else if (i == 2) { + out4.s2 = v; + } else { + out4.s3 = v; + } + } + + OUTPUT_TYPE4 outVec = CONVERT_OUTPUT4(out4); + if (base_ic + 3 < ic) { + vstore4(outVec, 0, output + outBase); + } else { + OUTPUT_TYPE* outPtr = (OUTPUT_TYPE*)(&outVec); + const int remain = ic - base_ic; + for (int i = 0; i < remain; ++i) { + output[outBase + i] = outPtr[i]; + } + } +} + +__kernel void shared_gather_quant_image( + GLOBAL_SIZE_2_DIMS + __global OUTPUT_TYPE* output, + __read_only image2d_t weight, + __global const int* indices, + __global const FLOAT* dequantScaleOffset, + __private const int ic, + __private const int oc, + __private const int blockSize, + __private const float coef +) { + const int select_idx = get_global_id(0); + const int k4 = get_global_id(1); + DEAL_NON_UNIFORM_DIM2(select_idx, k4); + + const int base_ic = k4 << 2; + if (base_ic >= ic) { + return; + } + + const int ocIndex = indices[select_idx]; + if (ocIndex < 0 || ocIndex >= oc) { + return; + } + + const int out_c_idx = ocIndex >> 2; + const int oc_in4 = ocIndex & 3; + const int ocBlock = ocIndex >> 3; + const int oc_in8 = ocIndex & 7; + const int dstChannelC4 = ((oc + 3) >> 2) << 2; + const int outBase = select_idx * ic + base_ic; + COMPUTE_FLOAT4 out4 = (COMPUTE_FLOAT4)(0, 0, 0, 0); + +#ifdef USE_LOW_BIT_WEIGHT_INT4 + const uchar16 weightBytes = as_uchar16(read_imagei(weight, SAMPLER, (int2)(k4, ocBlock))); +#endif + + for (int i = 0; i < 4; ++i) { + const int icIndex = base_ic + i; + if (icIndex >= ic) { + break; + } + + const int blockIndex = icIndex / blockSize; + const int channelIndex = (out_c_idx << 2) + oc_in4; + int scaleIndex = blockIndex * dstChannelC4 + channelIndex; + +#ifdef ASYMMETRIC + scaleIndex = scaleIndex * 2; + FLOAT sRaw = dequantScaleOffset[scaleIndex + 0]; + FLOAT bRaw = dequantScaleOffset[scaleIndex + 1]; + COMPUTE_FLOAT scale = (COMPUTE_FLOAT)(convert_float(sRaw) / coef); + COMPUTE_FLOAT offset = (COMPUTE_FLOAT)(convert_float(bRaw) / coef); +#else + FLOAT sRaw = dequantScaleOffset[scaleIndex]; + COMPUTE_FLOAT scale = (COMPUTE_FLOAT)(convert_float(sRaw) / coef); + COMPUTE_FLOAT offset = (COMPUTE_FLOAT)0; +#endif + + COMPUTE_FLOAT wVal = (COMPUTE_FLOAT)0; +#ifdef USE_LOW_BIT_WEIGHT_INT8 + const int imageX = (k4 << 1) + (i >> 1); + const char16 weightBytes = as_char16(read_imagei(weight, SAMPLER, (int2)(imageX, ocBlock))); + char qw = weightBytes[(i & 1) * 8 + oc_in8]; + wVal = (COMPUTE_FLOAT)qw; +#elif defined(USE_LOW_BIT_WEIGHT_INT4) + uchar packed = weightBytes[i * 4 + (oc_in8 >> 1)]; + int nibble = (oc_in8 & 1) == 0 ? ((packed >> 4) & 0x0F) : (packed & 0x0F); +#ifdef ASYMMETRIC + wVal = (COMPUTE_FLOAT)nibble; +#else + wVal = (COMPUTE_FLOAT)((int)nibble - 8); +#endif +#else + const int imageX = (k4 << 1) + (i >> 1); + const char16 weightBytes = as_char16(read_imagei(weight, SAMPLER, (int2)(imageX, ocBlock))); + wVal = (COMPUTE_FLOAT)weightBytes[(i & 1) * 8 + oc_in8]; +#endif + + COMPUTE_FLOAT v = mad(wVal, scale, offset); + if (i == 0) { + out4.s0 = v; + } else if (i == 1) { + out4.s1 = v; + } else if (i == 2) { + out4.s2 = v; + } else { + out4.s3 = v; + } + } + + OUTPUT_TYPE4 outVec = CONVERT_OUTPUT4(out4); + if (base_ic + 3 < ic) { + vstore4(outVec, 0, output + outBase); + } else { + OUTPUT_TYPE* outPtr = (OUTPUT_TYPE*)(&outVec); + const int remain = ic - base_ic; + for (int i = 0; i < remain; ++i) { + output[outBase + i] = outPtr[i]; + } + } +} diff --git a/source/backend/opencl/execution/cl/shared_gather_buf_mnn_cl.cpp b/source/backend/opencl/execution/cl/shared_gather_buf_mnn_cl.cpp new file mode 100644 index 0000000000..3b1910243a --- /dev/null +++ b/source/backend/opencl/execution/cl/shared_gather_buf_mnn_cl.cpp @@ -0,0 +1,211 @@ +#include "opencl_source_map.hpp" +namespace MNN { +#ifndef MNN_OPENCL_BUFFER_CLOSED +const char* shared_gather_buf = +"#ifdef MNN_SUPPORT_FP16\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" +"#endif\n" +"#define GLOBAL_SIZE_2_DIMS __private const int global_size_dim0,__private const int global_size_dim1,\n" +"#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if ((input1) >= global_size_dim0 || (input2) >= global_size_dim1) { "" return; "" }\n" +"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" +"__kernel void shared_gather_quant_buffer(\n" +" GLOBAL_SIZE_2_DIMS\n" +" __global OUTPUT_TYPE* output,\n" +"#ifdef USE_LOW_BIT_WEIGHT_INT8\n" +" __global const char* weight,\n" +"#elif defined(USE_LOW_BIT_WEIGHT_INT4)\n" +" __global const uchar* weight,\n" +"#else\n" +" __global const FLOAT* weight,\n" +"#endif\n" +" __global const int* indices,\n" +" __global const FLOAT* dequantScaleOffset,\n" +" __private const int ic,\n" +" __private const int oc,\n" +" __private const int blockSize,\n" +" __private const float coef\n" +") {\n" +" const int select_idx=get_global_id(0);\n" +" const int k4=get_global_id(1);\n" +" DEAL_NON_UNIFORM_DIM2(select_idx,k4);\n" +" const int base_ic=k4 << 2;\n" +" if (base_ic >= ic) {\n" +" return;\n" +" }\n" +" const int ocIndex=indices[select_idx];\n" +" if (ocIndex<0 || ocIndex >= oc) {\n" +" return;\n" +" }\n" +" const int icC4=(ic+3) >> 2;\n" +" const int out_c_idx=ocIndex >> 2;\n" +" const int oc_in4=ocIndex & 3;\n" +" const int ocBlock=ocIndex >> 3;\n" +" const int oc_in8=ocIndex & 7;\n" +" const int dstChannelC4=((oc+3) >> 2) << 2;\n" +" const int tileIndex=ocBlock*icC4+k4;\n" +"#ifdef USE_LOW_BIT_WEIGHT_INT8\n" +" const int weightTileStride=32;\n" +" const int weightBase=tileIndex*weightTileStride;\n" +"#elif defined(USE_LOW_BIT_WEIGHT_INT4)\n" +" const int weightTileStride=16;\n" +" const int weightBase=tileIndex*weightTileStride;\n" +"#else\n" +" const int weightTileStride=0;\n" +" const int weightBase=0;\n" +"#endif\n" +" const int outBase=select_idx*ic+base_ic;\n" +" COMPUTE_FLOAT4 out4=(COMPUTE_FLOAT4)(0,0,0,0);\n" +" for (int i=0; i<4; ++i) {\n" +" const int icIndex=base_ic+i;\n" +" if (icIndex >= ic) {\n" +" break;\n" +" }\n" +" const int blockIndex=icIndex/blockSize;\n" +" const int channelIndex=(out_c_idx << 2)+oc_in4;\n" +" int scaleIndex=blockIndex*dstChannelC4+channelIndex;\n" +"#ifdef ASYMMETRIC\n" +" scaleIndex=scaleIndex*2;\n" +" FLOAT sRaw=dequantScaleOffset[scaleIndex+0];\n" +" FLOAT bRaw=dequantScaleOffset[scaleIndex+1];\n" +" COMPUTE_FLOAT scale=(COMPUTE_FLOAT)(convert_float(sRaw)/coef);\n" +" COMPUTE_FLOAT offset=(COMPUTE_FLOAT)(convert_float(bRaw)/coef);\n" +"#else\n" +" FLOAT sRaw=dequantScaleOffset[scaleIndex];\n" +" COMPUTE_FLOAT scale=(COMPUTE_FLOAT)(convert_float(sRaw)/coef);\n" +" COMPUTE_FLOAT offset=(COMPUTE_FLOAT)0;\n" +"#endif\n" +" COMPUTE_FLOAT wVal=(COMPUTE_FLOAT)0;\n" +"#ifdef USE_LOW_BIT_WEIGHT_INT8\n" +" const int byteIndex=weightBase+i*8+oc_in8;\n" +" char qw=weight[byteIndex];\n" +" wVal=(COMPUTE_FLOAT)qw;\n" +"#elif defined(USE_LOW_BIT_WEIGHT_INT4)\n" +" const int byteIndex=weightBase+i*4+(oc_in8 >> 1);\n" +" uchar packed=weight[byteIndex];\n" +" int nibble=(oc_in8 & 1) == 0 ? ((packed >> 4) & 0x0F) : (packed & 0x0F);\n" +"#ifdef ASYMMETRIC\n" +" wVal=(COMPUTE_FLOAT)nibble;\n" +"#else\n" +" wVal=(COMPUTE_FLOAT)((int)nibble-8);\n" +"#endif\n" +"#else\n" +" const int byteIndex=weightBase+i*8+oc_in8;\n" +" wVal=(COMPUTE_FLOAT)weight[byteIndex];\n" +"#endif\n" +" COMPUTE_FLOAT v=mad(wVal,scale,offset);\n" +" if (i == 0) {\n" +" out4.s0=v;\n" +" } else if (i == 1) {\n" +" out4.s1=v;\n" +" } else if (i == 2) {\n" +" out4.s2=v;\n" +" } else {\n" +" out4.s3=v;\n" +" }\n" +" }\n" +" OUTPUT_TYPE4 outVec=CONVERT_OUTPUT4(out4);\n" +" if (base_ic+3= ic) {\n" +" return;\n" +" }\n" +" const int ocIndex=indices[select_idx];\n" +" if (ocIndex<0 || ocIndex >= oc) {\n" +" return;\n" +" }\n" +" const int out_c_idx=ocIndex >> 2;\n" +" const int oc_in4=ocIndex & 3;\n" +" const int ocBlock=ocIndex >> 3;\n" +" const int oc_in8=ocIndex & 7;\n" +" const int dstChannelC4=((oc+3) >> 2) << 2;\n" +" const int outBase=select_idx*ic+base_ic;\n" +" COMPUTE_FLOAT4 out4=(COMPUTE_FLOAT4)(0,0,0,0);\n" +"#ifdef USE_LOW_BIT_WEIGHT_INT4\n" +" const uchar16 weightBytes=as_uchar16(read_imagei(weight,SAMPLER,(int2)(k4,ocBlock)));\n" +"#endif\n" +" for (int i=0; i<4; ++i) {\n" +" const int icIndex=base_ic+i;\n" +" if (icIndex >= ic) {\n" +" break;\n" +" }\n" +" const int blockIndex=icIndex/blockSize;\n" +" const int channelIndex=(out_c_idx << 2)+oc_in4;\n" +" int scaleIndex=blockIndex*dstChannelC4+channelIndex;\n" +"#ifdef ASYMMETRIC\n" +" scaleIndex=scaleIndex*2;\n" +" FLOAT sRaw=dequantScaleOffset[scaleIndex+0];\n" +" FLOAT bRaw=dequantScaleOffset[scaleIndex+1];\n" +" COMPUTE_FLOAT scale=(COMPUTE_FLOAT)(convert_float(sRaw)/coef);\n" +" COMPUTE_FLOAT offset=(COMPUTE_FLOAT)(convert_float(bRaw)/coef);\n" +"#else\n" +" FLOAT sRaw=dequantScaleOffset[scaleIndex];\n" +" COMPUTE_FLOAT scale=(COMPUTE_FLOAT)(convert_float(sRaw)/coef);\n" +" COMPUTE_FLOAT offset=(COMPUTE_FLOAT)0;\n" +"#endif\n" +" COMPUTE_FLOAT wVal=(COMPUTE_FLOAT)0;\n" +"#ifdef USE_LOW_BIT_WEIGHT_INT8\n" +" const int imageX=(k4 << 1)+(i >> 1);\n" +" const char16 weightBytes=as_char16(read_imagei(weight,SAMPLER,(int2)(imageX,ocBlock)));\n" +" char qw=weightBytes[(i & 1)*8+oc_in8];\n" +" wVal=(COMPUTE_FLOAT)qw;\n" +"#elif defined(USE_LOW_BIT_WEIGHT_INT4)\n" +" uchar packed=weightBytes[i*4+(oc_in8 >> 1)];\n" +" int nibble=(oc_in8 & 1) == 0 ? ((packed >> 4) & 0x0F) : (packed & 0x0F);\n" +"#ifdef ASYMMETRIC\n" +" wVal=(COMPUTE_FLOAT)nibble;\n" +"#else\n" +" wVal=(COMPUTE_FLOAT)((int)nibble-8);\n" +"#endif\n" +"#else\n" +" const int imageX=(k4 << 1)+(i >> 1);\n" +" const char16 weightBytes=as_char16(read_imagei(weight,SAMPLER,(int2)(imageX,ocBlock)));\n" +" wVal=(COMPUTE_FLOAT)weightBytes[(i & 1)*8+oc_in8];\n" +"#endif\n" +" COMPUTE_FLOAT v=mad(wVal,scale,offset);\n" +" if (i == 0) {\n" +" out4.s0=v;\n" +" } else if (i == 1) {\n" +" out4.s1=v;\n" +" } else if (i == 2) {\n" +" out4.s2=v;\n" +" } else {\n" +" out4.s3=v;\n" +" }\n" +" }\n" +" OUTPUT_TYPE4 outVec=CONVERT_OUTPUT4(out4);\n" +" if (base_ic+3& inputs, const std::vector& outputs, Context& context, CommandBuffer& res) const override { + if (inputs.size() == 1) { + std::shared_ptr cmdP(new Command); + auto& cmd = *cmdP; + cmd.op = op; + cmd.inputs = inputs; + cmd.outputs = outputs; + res.command.emplace_back(std::move(cmdP)); + return true; + } _computeGather(inputs, outputs, context, res, op); return true; } virtual bool onRecompute(const Op* op, const std::vector& inputs, const std::vector& outputs, Context& context, CommandBuffer& cmd) const override { - if (cmd.command.size() != 1) { + if (cmd.command.size() != 1 || inputs.size() == 1) { return false; } int axis = 0; @@ -362,23 +371,25 @@ class GeometryGatherND : public GeometryComputer { paramSize = dimCount; } // recompute reshape - auto des = TensorUtils::getDescribeOrigin(reshapeIndice.get()); - des->offset = 0; - auto nativeDes = TensorUtils::getDescribe(reshapeIndice.get()); - nativeDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; - nativeDes->regions = {GeometryComputerUtils::makeRawAddressRef(indice, 0, mSliceN * indiceNd)}; + auto desOrigin = TensorUtils::getDescribeOrigin(reshapeIndice.get()); + desOrigin->mem = nullptr; + auto des = TensorUtils::getDescribe(reshapeIndice.get()); + desOrigin->offset = 0; + des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; + des->regions = {GeometryComputerUtils::makeRawAddressRef(indice, 0, mSliceN * indiceNd)}; // recompute broadcast - des = TensorUtils::getDescribeOrigin(broadcastStride.get()); - des->offset = 0; - nativeDes = TensorUtils::getDescribe(broadcastStride.get()); - nativeDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; - nativeDes->regions[0].origin = constStride.get(); - nativeDes->regions[0].size[0] = 1; - nativeDes->regions[0].size[1] = mSliceN; - nativeDes->regions[0].size[2] = indiceNd; - nativeDes->regions[0].dst.stride[0] = indiceNd*mSliceN; - nativeDes->regions[0].dst.stride[1] = indiceNd; - nativeDes->regions[0].dst.stride[2] = 1; + desOrigin = TensorUtils::getDescribeOrigin(broadcastStride.get()); + desOrigin->mem = nullptr; + des = TensorUtils::getDescribe(broadcastStride.get()); + desOrigin->offset = 0; + des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; + des->regions[0].origin = constStride.get(); + des->regions[0].size[0] = 1; + des->regions[0].size[1] = mSliceN; + des->regions[0].size[2] = indiceNd; + des->regions[0].dst.stride[0] = indiceNd*mSliceN; + des->regions[0].dst.stride[1] = indiceNd; + des->regions[0].dst.stride[2] = 1; // recompute loop auto loopCmd = cmd.command[cmd.command.size() - 1]; auto param = loopCmd->op->main_as_LoopParam(); diff --git a/source/geometry/GeometryLayernorm.cpp b/source/geometry/GeometryLayernorm.cpp index 120a95c91f..5dd92d0fed 100644 --- a/source/geometry/GeometryLayernorm.cpp +++ b/source/geometry/GeometryLayernorm.cpp @@ -16,18 +16,18 @@ class GeometryLayerNorm : public GeometryComputer { virtual bool onCompute(const Op* op, const std::vector& inputs, const std::vector& outputs, Context& context, CommandBuffer& res) const override { /* Target: Ensure reduce dimensions must be a sequence subset [-rank,...,rank-1] */ - MNN_ASSERT(1 == outputs.size()); - MNN_ASSERT(1 == inputs.size()); auto layernorm = op->main_as_LayerNorm(); - if (!layernorm->axis()) { + if (!layernorm->axis() || op->defaultDimentionFormat() == MNN_DATA_FORMAT_NC4HW4) { std::shared_ptr cmdP(new Command); auto& cmd = *cmdP; cmd.op = op; - cmd.inputs = {inputs[0]}; + cmd.inputs = inputs; cmd.outputs = std::move(outputs); res.command.emplace_back(std::move(cmdP)); return true; } + MNN_ASSERT(1 == outputs.size()); + MNN_ASSERT(1 == inputs.size()); auto reduceDims = layernorm->axis()->data(); int reduceDimensionCount = layernorm->axis()->size(); auto inputShape = inputs[0]->shape(); diff --git a/source/shape/ShapeAttention.cpp b/source/shape/ShapeAttention.cpp index 75efa2321f..14225a8f9a 100644 --- a/source/shape/ShapeAttention.cpp +++ b/source/shape/ShapeAttention.cpp @@ -12,6 +12,16 @@ namespace MNN { #ifdef MNN_SUPPORT_TRANSFORMER_FUSE +class RoPESizeComputer : public SizeComputer { + virtual bool onComputeSize(const MNN::Op* op, const std::vector& inputs, + const std::vector& outputs) const override { + MNN_ASSERT(inputs.size() == 6); + MNN_ASSERT(outputs.size() == 2); + TensorUtils::copyShape(inputs[0], outputs[0], true); + TensorUtils::copyShape(inputs[1], outputs[1], true); + return true; + } +}; class FmhaV2SizeComputer : public SizeComputer { virtual bool onComputeSize(const MNN::Op* op, const std::vector& inputs, @@ -63,12 +73,22 @@ class AttentionSizeComputer : public SizeComputer { const std::vector& outputs) const override { auto input = inputs[0], output = outputs[0]; MNN_ASSERT(input->buffer().dimensions == 4); - output->buffer().dim[0].extent = input->buffer().dim[0].extent; - output->buffer().dim[1].extent = input->buffer().dim[1].extent; - output->buffer().dim[2].extent = input->buffer().dim[2].extent * input->buffer().dim[3].extent; - output->buffer().dimensions = 3; - output->buffer().type = input->buffer().type; - TensorUtils::getDescribe(output)->dimensionFormat = TensorUtils::getDescribe(input)->dimensionFormat; + if (op->main_as_AttentionParam()->output_c4()) { + output->buffer().dim[0].extent = input->buffer().dim[0].extent * input->buffer().dim[1].extent; + output->buffer().dim[1].extent = input->buffer().dim[2].extent * input->buffer().dim[3].extent; + output->buffer().dim[2].extent = 1; + output->buffer().dim[3].extent = 1; + output->buffer().dimensions = 4; + output->buffer().type = input->buffer().type; + TensorUtils::getDescribe(output)->dimensionFormat = MNN_DATA_FORMAT_NC4HW4; + } else { + output->buffer().dim[0].extent = input->buffer().dim[0].extent; + output->buffer().dim[1].extent = input->buffer().dim[1].extent; + output->buffer().dim[2].extent = input->buffer().dim[2].extent * input->buffer().dim[3].extent; + output->buffer().dimensions = 3; + output->buffer().type = input->buffer().type; + TensorUtils::getDescribe(output)->dimensionFormat = TensorUtils::getDescribe(input)->dimensionFormat; + } return true; } virtual float onComputeFlops(const MNN::Op* op, const std::vector& inputs, const std::vector& outputs) const override { @@ -127,9 +147,9 @@ class LinearAttentionSizeComputer : public SizeComputer { REGISTER_SHAPE_INPUTS_TRANSFORMER_FUSE(FmhaV2SizeComputer, OpType_FmhaV2); REGISTER_SHAPE_INPUTS_TRANSFORMER_FUSE(FmhcaSizeComputer, OpType_Fmhca); +REGISTER_SHAPE_INPUTS_TRANSFORMER_FUSE(RoPESizeComputer, OpType_RoPE); REGISTER_SHAPE_INPUTS_TRANSFORMER_FUSE(AttentionSizeComputer, OpType_Attention); REGISTER_SHAPE_INPUTS_TRANSFORMER_FUSE(LinearAttentionSizeComputer, OpType_LinearAttention); #endif } // namespace MNN - diff --git a/source/shape/ShapeGatherV2.cpp b/source/shape/ShapeGatherV2.cpp index 44a252df4e..c857854ff7 100644 --- a/source/shape/ShapeGatherV2.cpp +++ b/source/shape/ShapeGatherV2.cpp @@ -14,8 +14,32 @@ namespace MNN { class GatherV2Computer : public SizeComputer { virtual bool onComputeSize(const MNN::Op* op, const std::vector& inputs, const std::vector& outputs) const override { - auto params = inputs[0]; - auto indices = inputs[1]; + Tensor* indices = nullptr; + int32_t paramShape[MNN_MAX_TENSOR_DIM]; + int32_t paramDim = 0; + if (inputs.size() == 1) { + if (nullptr == op->main_as_Input()) { + MNN_ERROR("One input GatherV2 should has blob parameter\n"); + return false; + } + indices = inputs[0]; + auto blob = op->main_as_Input(); + outputs[0]->setType(blob->dtype()); + if (nullptr != blob->dims()) { + paramDim = blob->dims()->size(); + ::memcpy(paramShape, blob->dims()->data(), paramDim * sizeof(int)); + } + TensorUtils::getDescribe(outputs[0])->dimensionFormat = blob->dformat(); + } else { + auto params = inputs[0]; + indices = inputs[1]; + paramDim = params->dimensions(); + for (int i = 0; i < params->dimensions(); ++i) { + paramShape[i] = params->length(i); + } + outputs[0]->buffer().type = params->buffer().type; + TensorUtils::getDescribe(outputs[0])->dimensionFormat = TensorUtils::getDescribe(inputs[0])->dimensionFormat; + } if (indices->getType().code != halide_type_int) { return false; } @@ -28,41 +52,39 @@ class GatherV2Computer : public SizeComputer { axis = op->main_as_Axis()->axis(); } - if( axis <= -params->buffer().dimensions || axis >= params->buffer().dimensions) { + if( axis <= -paramDim || axis >= paramDim) { return false; } if (axis < 0) { - axis = params->buffer().dimensions + axis; + axis = paramDim + axis; } - const int gather_dim_size = params->buffer().dim[axis].extent; + const int gather_dim_size = paramShape[axis]; MNN_ASSERT(gather_dim_size <= std::numeric_limits::max()); - const int numDimensions = params->buffer().dimensions + indices->buffer().dimensions - 1; + const int numDimensions = paramDim + indices->buffer().dimensions - 1; MNN_ASSERT(axis <= numDimensions); std::vector result_shape; for (int i = 0; i < axis; i++) { - result_shape.push_back(params->buffer().dim[i].extent); + result_shape.push_back(paramShape[i]); } for (int i = 0; i < indices->buffer().dimensions; i++) { result_shape.push_back(indices->buffer().dim[i].extent); } - for (int i = axis + 1; i < params->buffer().dimensions; i++) { - result_shape.push_back(params->buffer().dim[i].extent); + for (int i = axis + 1; i < paramDim; i++) { + result_shape.push_back(paramShape[i]); } outputs[0]->buffer().dimensions = (int)result_shape.size(); - outputs[0]->buffer().type = params->buffer().type; for (int i = 0; i < result_shape.size(); i++) { outputs[0]->buffer().dim[i].extent = result_shape.at(i); } - TensorUtils::getDescribe(outputs[0])->dimensionFormat = TensorUtils::getDescribe(inputs[0])->dimensionFormat; return true; } }; diff --git a/source/shape/ShapeRegister.cpp b/source/shape/ShapeRegister.cpp index c4b30e1b9d..a2c237074a 100644 --- a/source/shape/ShapeRegister.cpp +++ b/source/shape/ShapeRegister.cpp @@ -121,6 +121,7 @@ extern void ___SplitGeLUSizeComputer__OpType_SplitGeLU__(); extern void ___SeqLen2SpatialSizeComputer__OpType_SeqLen2Spatial__(); extern void ___FmhaV2SizeComputer__OpType_FmhaV2__(); extern void ___FmhcaSizeComputer__OpType_Fmhca__(); +extern void ___RoPESizeComputer__OpType_RoPE__(); extern void ___AttentionSizeComputer__OpType_Attention__(); extern void ___LinearAttentionSizeComputer__OpType_LinearAttention__(); #endif @@ -245,6 +246,7 @@ ___SplitGeLUSizeComputer__OpType_SplitGeLU__(); ___SeqLen2SpatialSizeComputer__OpType_SeqLen2Spatial__(); ___FmhaV2SizeComputer__OpType_FmhaV2__(); ___FmhcaSizeComputer__OpType_Fmhca__(); +___RoPESizeComputer__OpType_RoPE__(); ___AttentionSizeComputer__OpType_Attention__(); ___LinearAttentionSizeComputer__OpType_LinearAttention__(); #endif diff --git a/test/op/AttentionTest.cpp b/test/op/AttentionTest.cpp index 845e65dff5..e9655adf2c 100644 --- a/test/op/AttentionTest.cpp +++ b/test/op/AttentionTest.cpp @@ -28,7 +28,7 @@ const int pastLength = 101; #define GENERATE_TOKENS 128 static KVMeta gMeta; -static std::shared_ptr _makeAttentionModule(int attentionMode = 8) { +static std::shared_ptr _makeAttentionModule(int attentionMode = 8, bool outputC4 = false) { auto Q = _Input(); auto K = _Input(); auto V = _Input(); @@ -38,6 +38,7 @@ static std::shared_ptr _makeAttentionModule(int attentionMode = 8) { attention->main.type = MNN::OpParameter_AttentionParam; attention->main.value = new MNN::AttentionParamT; attention->main.AsAttentionParam()->kv_cache = true; + attention->main.AsAttentionParam()->output_c4 = outputC4; auto o = Variable::create(Expr::create(attention.get(), {Q, K, V, mask})); auto buffer = Variable::save({o}); MNN::ScheduleConfig config; @@ -582,5 +583,52 @@ SpeedAttentionTest() = default; }; MNNTestSuiteRegister(AttentionTest, "op/attention"); + +class AttentionC4Test : public AttentionTest { +public: + AttentionC4Test() = default; + virtual ~AttentionC4Test() = default; + + bool compareC4Result(int seqLen) { + const float* resultPtr = Output->readMap(); + const int hidden = NumHead * HeadDim; + std::vector actual(seqLen * hidden); + std::vector expected(seqLen * hidden); + for (int i = 0; i < seqLen; ++i) { + for (int h = 0; h < NumHead; ++h) { + for (int d = 0; d < HeadDim; ++d) { + int c = h * HeadDim + d; + int c4Index = (c % 4) + 4 * i + 4 * seqLen * (c / 4); + int logicalIndex = i * hidden + c; + actual[logicalIndex] = resultPtr[c4Index]; + expected[logicalIndex] = expected_result[i][h][d]; + } + } + } + if (!checkVectorByRelativeError(actual.data(), expected.data(), actual.size(), 0.02f)) { + MNN_ERROR("AttentionC4Test failed!\n"); + return false; + } + return true; + } + + virtual bool run(int precision) { + srand(2024); + const int seqLen = 10; + std::shared_ptr naiveAttention(new NaiveAttention); + generateInput(seqLen, precision); + generateMask(seqLen, seqLen); + expected_result = naiveAttention->onExecute(query, key, value, mask, seqLen); + gMeta.previous = 0; + gMeta.remove = 0; + gMeta.add = seqLen; + auto attn = _makeAttentionModule(8, true); + Output = attn->onForward({Query, Key, Value, Mask})[0]; + gMeta.sync(); + return compareC4Result(seqLen); + } +}; + +MNNTestSuiteRegister(AttentionC4Test, "op/attention_c4"); MNNTestSuiteRegister(SpeedAttentionTest, "speed/attention"); #endif diff --git a/test/op/BinaryOPTest.cpp b/test/op/BinaryOPTest.cpp index 44ec301a53..55f0d88f51 100644 --- a/test/op/BinaryOPTest.cpp +++ b/test/op/BinaryOPTest.cpp @@ -8,6 +8,7 @@ #include #include +#include #include "MNNTestSuite.h" #include "TestUtils.h" #include "MNN_generated.h" @@ -208,6 +209,21 @@ class MultiplyInt8Test : public BinaryTestCommon { } }; +class MultiplySiluTest : public BinaryTestCommon { +public: + virtual ~MultiplySiluTest() = default; + virtual bool run(int precision) { + std::vector x = {-2.0f, -1.0f, 0.5f, 1.5f, 2.0f, -3.0f}; + std::vector y = {-1.5f, -0.5f, 0.25f, 1.0f, 2.0f, 3.0f}; + std::vector expected(x.size()); + for (int i = 0; i < expected.size(); ++i) { + expected[i] = x[i] * (y[i] / (1.0f + std::exp(-y[i]))); + } + return test(MNN::Express::_MulSilu, "MultiplySiluTest", 0.01f, x, y, expected, {2, 3}, {2, 3}, + {2, 3}); + } +}; + class DivideTest : public BinaryTestCommon { public: virtual ~DivideTest() = default; @@ -714,6 +730,7 @@ MNNTestSuiteRegister(BinaryBroadcastShapeTest, "op/binary/broadcastShapeTest"); MNNTestSuiteRegister(AddTest, "op/binary/add"); MNNTestSuiteRegister(SubtractTest, "op/binary/subtract"); MNNTestSuiteRegister(MultiplyTest, "op/binary/multiply"); +MNNTestSuiteRegister(MultiplySiluTest, "op/binary/mulsilu"); MNNTestSuiteRegister(DivideTest, "op/binary/divide"); MNNTestSuiteRegister(PowTest, "op/binary/pow"); MNNTestSuiteRegister(MinimumTest, "op/binary/minimum"); @@ -752,4 +769,3 @@ MNNTestSuiteRegister(SquaredDifferenceInt8Test, "op/binary/sqdInt8"); MNNTestSuiteRegister(AddC4Test, "op/binary/addC4"); MNNTestSuiteRegister(AddBroastTest, "op/binary/AddBroast"); - diff --git a/test/op/LayerNormTest.cpp b/test/op/LayerNormTest.cpp index b6913edb2b..c13fa111cf 100644 --- a/test/op/LayerNormTest.cpp +++ b/test/op/LayerNormTest.cpp @@ -141,6 +141,53 @@ static bool testKernel (std::vector inputdata, std::vector targetd } return true; } + +static int nc4hw4Offset(int n, int c, int plane, int batch) { + return (c % 4) + 4 * plane * n + 4 * plane * batch * (c / 4); +} + +static void computeChannelLayerNorm(const std::vector& input, std::vector& output, int batch, int channel, + const std::vector& gamma, const std::vector& beta) { + output.resize(batch * channel); + for (int n = 0; n < batch; ++n) { + float mean = 0.0f; + for (int c = 0; c < channel; ++c) { + mean += input[n * channel + c]; + } + mean /= channel; + float variance = 0.0f; + for (int c = 0; c < channel; ++c) { + float v = input[n * channel + c] - mean; + variance += v * v; + } + variance /= channel; + float inv = 1.0f / std::sqrt(variance + eps); + for (int c = 0; c < channel; ++c) { + float v = (input[n * channel + c] - mean) * inv; + if (!gamma.empty()) { + v = v * gamma[c] + (beta.empty() ? 0.0f : beta[c]); + } + output[n * channel + c] = v; + } + } +} + +static bool checkNC4HW4Logical(VARP output, const std::vector& expected, int batch, int channel, + const char* testName) { + auto ptr = output->readMap(); + std::vector actual(batch * channel); + for (int n = 0; n < batch; ++n) { + for (int c = 0; c < channel; ++c) { + actual[n * channel + c] = ptr[nc4hw4Offset(n, c, 1, batch)]; + } + } + if (!checkVector(actual.data(), expected.data(), batch * channel, 0.02f)) { + MNN_ERROR("%s failed!\n", testName); + return false; + } + return true; +} + class LayerNormTest : public MNNTestCase { public: virtual ~LayerNormTest() = default; @@ -241,3 +288,92 @@ class LayerNormTest : public MNNTestCase { } }; MNNTestSuiteRegister(LayerNormTest, "op/layernorm"); + +class LayerNormC4Test : public MNNTestCase { +public: + virtual ~LayerNormC4Test() = default; + virtual bool run(int precision) { + const int batch = 2; + const int channel = 8; + const int physicalSize = batch * UP_DIV(channel, 4) * 4; + std::vector logical = {-1.0f, 0.5f, 2.0f, -0.5f, 1.5f, 0.25f, -1.25f, 0.75f, + 3.0f, 1.0f, -2.0f, 0.0f, 4.0f, -1.5f, 2.5f, -0.25f}; + std::vector packed(physicalSize, 0.0f); + for (int n = 0; n < batch; ++n) { + for (int c = 0; c < channel; ++c) { + packed[nc4hw4Offset(n, c, 1, batch)] = logical[n * channel + c]; + } + } + std::vector gamma = {0.5f, 0.75f, 1.0f, 1.25f, 1.5f, 0.8f, 1.2f, 0.6f}; + std::vector beta = {0.1f, -0.2f, 0.3f, -0.4f, 0.5f, 0.05f, -0.15f, 0.25f}; + auto input = _Input({batch, channel, 1, 1}, NC4HW4); + ::memcpy(input->writeMap(), packed.data(), packed.size() * sizeof(float)); + input->unMap(); + std::unique_ptr op(new OpT); + op->main.type = OpParameter_LayerNorm; + op->type = OpType_LayerNorm; + op->defaultDimentionFormat = MNN_DATA_FORMAT_NC4HW4; + op->main.value = new LayerNormT; + op->main.AsLayerNorm()->gamma = gamma; + op->main.AsLayerNorm()->beta = beta; + op->main.AsLayerNorm()->epsilon = eps; + op->main.AsLayerNorm()->axis = {1}; + auto output = Variable::create(Expr::create(std::move(op), {input})); + std::vector expected; + computeChannelLayerNorm(logical, expected, batch, channel, gamma, beta); + return checkNC4HW4Logical(output, expected, batch, channel, "LayerNormC4Test"); + } +}; + +class BinaryLayerNormC4Test : public MNNTestCase { +public: + virtual ~BinaryLayerNormC4Test() = default; + virtual bool run(int precision) { + const int batch = 2; + const int channel = 8; + const int physicalSize = batch * UP_DIV(channel, 4) * 4; + std::vector logical0 = {-1.0f, 0.5f, 2.0f, -0.5f, 1.5f, 0.25f, -1.25f, 0.75f, + 3.0f, 1.0f, -2.0f, 0.0f, 4.0f, -1.5f, 2.5f, -0.25f}; + std::vector logical1 = {0.25f, -0.5f, 0.75f, 1.0f, -1.25f, 0.5f, -0.75f, 1.25f, + -0.75f, 0.5f, 1.25f, -1.0f, 0.25f, 0.75f, -0.5f, 1.5f}; + std::vector packed0(physicalSize, 0.0f), packed1(physicalSize, 0.0f), sumLogical(batch * channel); + for (int n = 0; n < batch; ++n) { + for (int c = 0; c < channel; ++c) { + int logicalIndex = n * channel + c; + int packedIndex = nc4hw4Offset(n, c, 1, batch); + packed0[packedIndex] = logical0[logicalIndex]; + packed1[packedIndex] = logical1[logicalIndex]; + sumLogical[logicalIndex] = logical0[logicalIndex] + logical1[logicalIndex]; + } + } + std::vector gamma = {1.0f, 0.5f, 1.5f, 0.75f, 1.25f, 0.8f, 1.2f, 0.6f}; + std::vector beta(channel, 0.0f); + auto input0 = _Input({batch, channel, 1, 1}, NC4HW4); + auto input1 = _Input({batch, channel, 1, 1}, NC4HW4); + ::memcpy(input0->writeMap(), packed0.data(), packed0.size() * sizeof(float)); + ::memcpy(input1->writeMap(), packed1.data(), packed1.size() * sizeof(float)); + input0->unMap(); + input1->unMap(); + + std::unique_ptr op(new OpT); + op->main.type = OpParameter_LayerNorm; + op->type = OpType_LayerNorm; + op->defaultDimentionFormat = MNN_DATA_FORMAT_NC4HW4; + op->main.value = new LayerNormT; + op->main.AsLayerNorm()->gamma = gamma; + op->main.AsLayerNorm()->beta = beta; + op->main.AsLayerNorm()->epsilon = eps; + op->main.AsLayerNorm()->axis = {1}; + auto expr = Expr::create(std::move(op), {input0, input1}, 2); + auto sumOutput = Variable::create(expr, 0); + auto normOutput = Variable::create(expr, 1); + + std::vector expectedNorm; + computeChannelLayerNorm(sumLogical, expectedNorm, batch, channel, gamma, beta); + return checkNC4HW4Logical(sumOutput, sumLogical, batch, channel, "BinaryLayerNormC4SumTest") && + checkNC4HW4Logical(normOutput, expectedNorm, batch, channel, "BinaryLayerNormC4NormTest"); + } +}; + +MNNTestSuiteRegister(LayerNormC4Test, "op/layernorm/c4"); +MNNTestSuiteRegister(BinaryLayerNormC4Test, "op/layernorm/c4_binary"); diff --git a/test/op/RoPETest.cpp b/test/op/RoPETest.cpp new file mode 100644 index 0000000000..fcd156d69b --- /dev/null +++ b/test/op/RoPETest.cpp @@ -0,0 +1,117 @@ +// +// RoPETest.cpp +// MNNTests +// +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef MNN_SUPPORT_TRANSFORMER_FUSE + +#include +#include +#include +#include +#include "MNNTestSuite.h" +#include "TestUtils.h" + +using namespace MNN; +using namespace MNN::Express; + +static EXPRP _RoPEExpr(VARP q, VARP k, VARP cosEven, VARP cosOdd, VARP sinEven, VARP sinOdd, int ropeCutHeadDim) { + std::unique_ptr op(new OpT); + op->type = OpType_RoPE; + op->main.type = OpParameter_Extra; + op->main.value = new ExtraT; + std::unique_ptr attr(new AttributeT); + attr->key = "rope_cut_head_dim"; + attr->i = ropeCutHeadDim; + op->main.AsExtra()->attr.emplace_back(std::move(attr)); + return Expr::create(std::move(op), {q, k, cosEven, cosOdd, sinEven, sinOdd}, 2); +} + +static void computeRopeExpected(const std::vector& input, std::vector& output, + const std::vector& cosEven, const std::vector& cosOdd, + const std::vector& sinEven, const std::vector& sinOdd, int outer, + int head, int headDim, int ropeCutHeadDim) { + output = input; + int halfDim = headDim / 2; + int ropeHalfDim = std::min(ropeCutHeadDim / 2, halfDim); + for (int o = 0; o < outer; ++o) { + for (int h = 0; h < head; ++h) { + for (int i = 0; i < ropeHalfDim; ++i) { + int base = (o * head + h) * headDim; + int trig = o * halfDim + i; + float evenVal = input[base + i]; + float oddVal = input[base + i + halfDim]; + output[base + i] = evenVal * cosEven[trig] - oddVal * sinEven[trig]; + output[base + i + halfDim] = oddVal * cosOdd[trig] + evenVal * sinOdd[trig]; + } + } + } +} + +class RoPETest : public MNNTestCase { +public: + virtual ~RoPETest() = default; + virtual bool run(int precision) { + const int batch = 1; + const int seqLen = 2; + const int qHead = 2; + const int kHead = 1; + const int headDim = 6; + const int halfDim = headDim / 2; + const int outer = batch * seqLen; + const int ropeCutHeadDim = 4; + + std::vector qData = { + 0.2f, -0.4f, 0.6f, 1.0f, -1.2f, 1.4f, + -0.7f, 0.9f, -1.1f, 0.3f, 0.5f, -0.8f, + 1.2f, -1.0f, 0.4f, -0.6f, 0.8f, -0.2f, + -1.4f, 1.1f, -0.9f, 0.7f, -0.5f, 0.3f + }; + std::vector kData = { + -0.3f, 0.4f, -0.5f, 0.6f, -0.7f, 0.8f, + 1.0f, -1.1f, 1.2f, -1.3f, 1.4f, -1.5f + }; + std::vector cosEven = {0.90f, 0.80f, 0.70f, 0.60f, 0.50f, 0.40f}; + std::vector cosOdd = {0.91f, 0.81f, 0.71f, 0.61f, 0.51f, 0.41f}; + std::vector sinEven = {0.10f, -0.20f, 0.30f, -0.40f, 0.50f, -0.60f}; + std::vector sinOdd = {0.11f, -0.21f, 0.31f, -0.41f, 0.51f, -0.61f}; + + auto q = _Input({batch, seqLen, qHead, headDim}, NCHW); + auto k = _Input({batch, seqLen, kHead, headDim}, NCHW); + auto c0 = _Input({batch, seqLen, halfDim}, NCHW); + auto c1 = _Input({batch, seqLen, halfDim}, NCHW); + auto s0 = _Input({batch, seqLen, halfDim}, NCHW); + auto s1 = _Input({batch, seqLen, halfDim}, NCHW); + ::memcpy(q->writeMap(), qData.data(), qData.size() * sizeof(float)); + ::memcpy(k->writeMap(), kData.data(), kData.size() * sizeof(float)); + ::memcpy(c0->writeMap(), cosEven.data(), cosEven.size() * sizeof(float)); + ::memcpy(c1->writeMap(), cosOdd.data(), cosOdd.size() * sizeof(float)); + ::memcpy(s0->writeMap(), sinEven.data(), sinEven.size() * sizeof(float)); + ::memcpy(s1->writeMap(), sinOdd.data(), sinOdd.size() * sizeof(float)); + q->unMap(); + k->unMap(); + c0->unMap(); + c1->unMap(); + s0->unMap(); + s1->unMap(); + + auto expr = _RoPEExpr(q, k, c0, c1, s0, s1, ropeCutHeadDim); + auto qOut = Variable::create(expr, 0); + auto kOut = Variable::create(expr, 1); + std::vector qExpected, kExpected; + computeRopeExpected(qData, qExpected, cosEven, cosOdd, sinEven, sinOdd, outer, qHead, headDim, ropeCutHeadDim); + computeRopeExpected(kData, kExpected, cosEven, cosOdd, sinEven, sinOdd, outer, kHead, headDim, ropeCutHeadDim); + if (!checkVector(qOut->readMap(), qExpected.data(), qExpected.size(), 0.01f) || + !checkVector(kOut->readMap(), kExpected.data(), kExpected.size(), 0.01f)) { + MNN_ERROR("RoPETest failed!\n"); + return false; + } + return true; + } +}; + +MNNTestSuiteRegister(RoPETest, "op/rope"); + +#endif diff --git a/test/op/SharedGatherTest.cpp b/test/op/SharedGatherTest.cpp new file mode 100644 index 0000000000..3731e21d14 --- /dev/null +++ b/test/op/SharedGatherTest.cpp @@ -0,0 +1,160 @@ +// +// SharedGatherTest.cpp +// MNNTests +// +// Copyright © 2018, Alibaba Group Holding Limited +// + +// SharedGather requires the int8/low-memory weight quant path. The +// corresponding executor (DenseConvInt8TiledExecutor::onClone -> SharedGather) +// is only selected by ConvolutionFloatFactory when MNN_LOW_MEMORY is enabled. +// Without it, the conv base falls back to DenseConvolutionTiledExecutor, which +// cannot serve SharedGather, so this test cannot validate the feature and must +// be skipped to avoid spurious failures in non-low-memory CI builds. +#ifdef MNN_LOW_MEMORY + +#include +#include +#include +#include +#include "MNNTestSuite.h" +#include "TestUtils.h" +#include "core/IDSTEncoder.hpp" + +using namespace MNN; +using namespace MNN::Express; + +static std::shared_ptr makeSharedGatherRuntime() { + auto status = MNNTestSuite::get()->pStaus; + BackendConfig backendConfig; + backendConfig.precision = static_cast(status.precision); + backendConfig.memory = BackendConfig::Memory_Low; + ScheduleConfig config; + config.type = static_cast(status.forwardType); + config.backendConfig = &backendConfig; + config.numThread = status.thread > 0 ? status.thread : 1; + return std::shared_ptr(Executor::RuntimeManager::createRuntimeManager(config), + Executor::RuntimeManager::destroy); +} + +static VARP makeSharedConv(VARP input, const std::vector& weight, int ic, int oc) { + std::unique_ptr conv(new OpT); + conv->type = OpType_Convolution; + conv->name = "shared_weight"; + conv->main.type = OpParameter_Convolution2D; + conv->main.value = new Convolution2DT; + auto conv2D = conv->main.AsConvolution2D(); + conv2D->common.reset(new Convolution2DCommonT); + conv2D->common->kernelX = 1; + conv2D->common->kernelY = 1; + conv2D->common->strideX = 1; + conv2D->common->strideY = 1; + conv2D->common->dilateX = 1; + conv2D->common->dilateY = 1; + conv2D->common->group = 1; + conv2D->common->inputCount = ic; + conv2D->common->outputCount = oc; + conv2D->bias.resize(oc, 0.0f); + + std::vector scale; + scale.reserve(oc * 2); + for (int o = 0; o < oc; ++o) { + scale.emplace_back(-1.0f + 0.05f * o); + scale.emplace_back(0.125f + 0.01f * o); + } + IDSTEncoder::EncodeOptions options; + options.bits = 4; + conv2D->quanParameter = IDSTEncoder::encode(weight.data(), scale, ic, oc, true, nullptr, -8, options); + + auto expr = Expr::create(std::move(conv), {input}); + expr->setName("shared_weight"); + auto output = Variable::create(expr); + output->setName("shared_weight"); + return output; +} + +static VARP makeSharedGather(VARP indices, int ic, int oc) { + std::unique_ptr gather(new OpT); + gather->type = OpType_GatherV2; + gather->name = "shared_weight"; + gather->main.type = OpParameter_Input; + gather->main.value = new InputT; + gather->main.AsInput()->dims = {oc, ic}; + gather->main.AsInput()->dtype = DataType_DT_FLOAT; + gather->main.AsInput()->dformat = MNN_DATA_FORMAT_NCHW; + auto expr = Expr::create(std::move(gather), {indices}); + expr->setName("shared_weight"); + auto output = Variable::create(expr); + output->setName("shared_weight"); + return output; +} + +class SharedGatherTest : public MNNTestCase { +public: + virtual ~SharedGatherTest() = default; + virtual bool run(int precision) { + const int ic = 64; + const int oc = 8; + std::vector weight(oc * ic); + for (int o = 0; o < oc; ++o) { + float minValue = -1.0f + 0.05f * o; + float step = 0.125f + 0.01f * o; + for (int c = 0; c < ic; ++c) { + weight[o * ic + c] = minValue + step * ((c + o) % 16); + } + } + + auto baseInput = _Input({1, ic, 1, 1}, NCHW); + ::memset(baseInput->writeMap(), 0, ic * sizeof(float)); + baseInput->unMap(); + auto convOutput = makeSharedConv(baseInput, weight, ic, oc); + auto baseBuffer = Variable::save({convOutput}); + + int indicesData[] = {3, 0, 7, 2}; + auto indicesInput = _Input({4}, NCHW, halide_type_of()); + indicesInput->setName("x"); + auto gatherOutput = makeSharedGather(indicesInput, ic, oc); + auto gatherBuffer = Variable::save({gatherOutput}); + + auto runtime = makeSharedGatherRuntime(); + Module::Config baseConfig; + baseConfig.rearrange = true; + std::shared_ptr base(Module::load({}, {}, (const uint8_t*)baseBuffer.data(), baseBuffer.size(), runtime, + &baseConfig)); + if (!base) { + MNN_ERROR("SharedGatherTest load base module failed!\n"); + return false; + } + Module::Config gatherConfig; + gatherConfig.rearrange = true; + gatherConfig.base = base.get(); + std::shared_ptr gather(Module::load({}, {}, (const uint8_t*)gatherBuffer.data(), gatherBuffer.size(), + runtime, &gatherConfig)); + if (!gather) { + MNN_ERROR("SharedGatherTest load gather module failed!\n"); + return false; + } + auto runtimeIndices = _Input({4}, NCHW, halide_type_of()); + ::memcpy(runtimeIndices->writeMap(), indicesData, sizeof(indicesData)); + runtimeIndices->unMap(); + auto outputs = gather->onForward({runtimeIndices}); + if (outputs.empty() || outputs[0] == nullptr) { + MNN_ERROR("SharedGatherTest forward failed!\n"); + return false; + } + auto output = outputs[0]; + std::vector expected(4 * ic); + for (int i = 0; i < 4; ++i) { + ::memcpy(expected.data() + i * ic, weight.data() + indicesData[i] * ic, ic * sizeof(float)); + } + if (!checkVector(output->readMap(), expected.data(), expected.size(), 0.02f)) { + MNN_ERROR("SharedGatherTest failed!\n"); + return false; + } + return true; + } +}; + +MNNTestSuiteRegister(SharedGatherTest, "op/shared_gather"); + +#endif // MNN_LOW_MEMORY diff --git a/tools/converter/CMakeLists.txt b/tools/converter/CMakeLists.txt index 1e757c9f82..1c0d9dd4a4 100644 --- a/tools/converter/CMakeLists.txt +++ b/tools/converter/CMakeLists.txt @@ -20,10 +20,12 @@ IF(MNN_BUILD_CONVERTER) SET(MNN_CONVERTER_BACKENDS_OBJECTS "") include_directories(${CMAKE_CURRENT_LIST_DIR}/include) include_directories(${CMAKE_CURRENT_LIST_DIR}/source/tflite/schema) + include_directories(${CMAKE_CURRENT_LIST_DIR}/source/safetensors) include_directories(${CMAKE_CURRENT_BINARY_DIR}) include(${CMAKE_CURRENT_LIST_DIR}/source/compression/CMakeLists.txt) include(${CMAKE_CURRENT_LIST_DIR}/source/tensorflow/CMakeLists.txt) include(${CMAKE_CURRENT_LIST_DIR}/source/onnx/CMakeLists.txt) + include(${CMAKE_CURRENT_LIST_DIR}/source/safetensors/CMakeLists.txt) include(${CMAKE_CURRENT_LIST_DIR}/source/caffe/CMakeLists.txt) include(${CMAKE_CURRENT_LIST_DIR}/source/MNN/CMakeLists.txt) include(${CMAKE_CURRENT_LIST_DIR}/source/optimizer/CMakeLists.txt) diff --git a/tools/converter/include/config.hpp b/tools/converter/include/config.hpp index 5f2c9931d6..f165a9b211 100644 --- a/tools/converter/include/config.hpp +++ b/tools/converter/include/config.hpp @@ -11,6 +11,7 @@ #include #include #include +#include struct PostTreatContext; class MNN_PUBLIC modelConfig { public: @@ -23,13 +24,14 @@ class MNN_PUBLIC modelConfig { saveHalfFloat(false){ } ~ modelConfig (); - enum MODEL_SOURCE { TENSORFLOW = 0, CAFFE, ONNX, MNN, TFLITE, TORCH, JSON, MAX_SOURCE }; + enum MODEL_SOURCE { TENSORFLOW = 0, CAFFE, ONNX, MNN, TFLITE, TORCH, JSON, SAFETENSORS, MAX_SOURCE }; // MNN model path std::string MNNModel; // if model is tensorflow, this value is NULL; std::string prototxtFile; // tensorflow pb, or caffe model + std::vector modelFiles; std::string modelFile; // bizCode std::string bizCode; diff --git a/tools/converter/source/common/RemoveParams.cpp b/tools/converter/source/common/RemoveParams.cpp index 0344cff8d1..badd328297 100644 --- a/tools/converter/source/common/RemoveParams.cpp +++ b/tools/converter/source/common/RemoveParams.cpp @@ -70,6 +70,8 @@ void RemoveAndStoreParam(std::unique_ptr& op, std::ofstream* fs, int64 } break; } +// It will some case cause error +#ifdef MNN_SUPPORT_CONST_EXTERNAL case MNN::OpParameter_Blob: { auto param = op->main.AsBlob(); size_t totalSize = 1; @@ -100,6 +102,7 @@ void RemoveAndStoreParam(std::unique_ptr& op, std::ofstream* fs, int64 } break; } +#endif default: break; } diff --git a/tools/converter/source/common/cli.cpp b/tools/converter/source/common/cli.cpp index d2dd9ab4d4..e1639ce2cd 100644 --- a/tools/converter/source/common/cli.cpp +++ b/tools/converter/source/common/cli.cpp @@ -39,6 +39,7 @@ #include #include #include "core/MemoryFormater.h" +#include "../safetensors/SafetensorConverter.hpp" modelConfig::~modelConfig() { if (nullptr != compressInfo) { delete compressInfo; @@ -148,12 +149,12 @@ bool Cli::initializeMNNConvertArgs(modelConfig &modelPath, int argc, char **argv "Convert Other Model Format To MNN Model\n")( std::make_pair("v", "version"), "show current version")(std::make_pair("f", "framework"), #ifdef MNN_BUILD_TORCH - "model type, ex: [TF,CAFFE,ONNX,TFLITE,MNN,TORCH,JSON]", + "model type, ex: [TF,CAFFE,ONNX,TFLITE,MNN,TORCH,JSON,ST]", #else - "model type, ex: [TF,CAFFE,ONNX,TFLITE,MNN,JSON]", + "model type, ex: [TF,CAFFE,ONNX,TFLITE,MNN,JSON,ST]", #endif cxxopts::value())( - "modelFile", "tensorflow Pb or caffeModel, ex: *.pb,*caffemodel", cxxopts::value())( + std::make_pair("i", "modelFile"), "tensorflow Pb or caffeModel, ex: *.pb,*caffemodel", cxxopts::value>())( "batch", "if model input's batch is not set, set as the batch size you set", cxxopts::value())( "keepInputFormat", "keep input dimension format or not, default: true", cxxopts::value())( "optimizeLevel", @@ -161,7 +162,7 @@ bool Cli::initializeMNNConvertArgs(modelConfig &modelPath, int argc, char **argv "every input case is right, 2: normally right but some case may be wrong, default 1", cxxopts::value())("optimizePrefer", "graph optimize option, 0 for normal, 1 for smalleset, 2 for fastest", cxxopts::value())("prototxt", "only used for caffe, ex: *.prototxt", - cxxopts::value())("MNNModel", "MNN model, ex: *.mnn", + cxxopts::value())(std::make_pair("o", "MNNModel"), "MNN model, ex: *.mnn", cxxopts::value())( "fp16", "save Conv's weight/bias in half_float data type")( "benchmarkModel", @@ -259,6 +260,8 @@ bool Cli::initializeMNNConvertArgs(modelConfig &modelPath, int argc, char **argv #endif } else if ("JSON" == frameWork) { modelPath.model = modelConfig::JSON; + } else if ("ST" == frameWork) { + modelPath.model = modelConfig::SAFETENSORS; } else { std::cout << "Framework Input ERROR or Not Support This Model Type Now!" << std::endl; return false; @@ -283,7 +286,13 @@ bool Cli::initializeMNNConvertArgs(modelConfig &modelPath, int argc, char **argv // model file path if (result.count("modelFile")) { - const std::string modelFile = result["modelFile"].as(); + auto files = result["modelFile"].as>(); + modelPath.modelFiles = files; + if (files.empty()) { + DLOG(INFO) << "modelFile Not set Invalid, use --modelFile to set!"; + return false; + } + const std::string modelFile = files[0]; if (CommonKit::FileIsExist(modelFile)) { modelPath.modelFile = modelFile; } else { @@ -562,6 +571,24 @@ bool Cli::convertModel(modelConfig& modelPath) { dumpModelInfo(modelPath.modelFile.c_str()); return true; } + if (modelPath.model == modelConfig::SAFETENSORS) { + std::cout << "Create Converter with config: " << modelPath.modelFiles[0] << std::endl; + MNN::SafeTensors::Converter converter(modelPath.modelFiles[0]); + auto models = converter.listModels(); + for (int i = 1; i < modelPath.modelFiles.size(); ++i) { + std::cout << "Load Safetensors " << modelPath.modelFiles[i] << std::endl; + converter.loadSafeTensors(modelPath.modelFiles[i]); + } + for (auto& name : models) { + auto newConfig = modelPath; + newConfig.MNNModel = modelPath.MNNModel + "/"; + if (!converter.convert(name, newConfig)) { + std::cout << "Convert " << name << " Failed" << std::endl; + return false; + } + } + return true; + } std::cout << "Start to Convert Other Model Format To MNN Model..., target version: " << modelPath.targetVersion << std::endl; std::unique_ptr netT = std::unique_ptr(new MNN::NetT()); int parseRes = 1; diff --git a/tools/converter/source/optimizer/postconvert/AddTensorFormatConverter.cpp b/tools/converter/source/optimizer/postconvert/AddTensorFormatConverter.cpp index 1e7b98afcd..cf2b8f950c 100644 --- a/tools/converter/source/optimizer/postconvert/AddTensorFormatConverter.cpp +++ b/tools/converter/source/optimizer/postconvert/AddTensorFormatConverter.cpp @@ -20,6 +20,7 @@ static void _setInputFormat(std::vector& tensorFormat, int inde enum FormatSetType { NC4HW4_SINGLE, // only first input / output is nc4hw4 NC4HW4_FULL, // all nc4hw4 + NC4HW4_OUTPUT, // output nc4hw4 COMPABILIT_SINGLE, // only first input / output is compability COMPABILIT_FULL, // all format should be same ORIGIN @@ -38,6 +39,18 @@ static FormatSetType _getFormatType(const OpT* op, MNN_DATA_FORMAT originFormat) case MNN::OpType_PReLU: case MNN::OpType_Dilation2D: return NC4HW4_SINGLE; + case MNN::OpType_Attention: + if (op->main.AsAttentionParam()->output_c4) { + return NC4HW4_OUTPUT; + } else { + return ORIGIN; + } + case MNN::OpType_LayerNorm: + if (op->defaultDimentionFormat == MNN_DATA_FORMAT_NC4HW4) { + return NC4HW4_FULL; + } else { + return ORIGIN; + } case MNN::OpType_ConvInt8: case MNN::OpType_Pooling: case MNN::OpType_Pooling3D: @@ -136,6 +149,8 @@ static MNN_DATA_FORMAT _getRequireFormat(FormatSetType type, int inputIndex, MNN return originFormat; case NC4HW4_FULL: return MNN_DATA_FORMAT_NC4HW4; + case NC4HW4_OUTPUT: + return originFormat; case NC4HW4_SINGLE: if (inputIndex == 0) { return MNN_DATA_FORMAT_NC4HW4; @@ -214,6 +229,16 @@ static bool _computeTensorFormat(std::vector& tensorFormat, std } return true; } + case NC4HW4_OUTPUT: + { + for (int i=0; iinputIndexes.size(); ++i) { + _setInputFormat(tensorFormat, op->inputIndexes[i], originFormat); + } + for (int i=0; ioutputIndexes.size(); ++i) { + tensorFormat[op->outputIndexes[i]] = MNN_DATA_FORMAT_NC4HW4; + } + return true; + } case COMPABILIT_SINGLE: { for (int i=1; iinputIndexes.size(); ++i) { diff --git a/tools/converter/source/safetensors/CMakeLists.txt b/tools/converter/source/safetensors/CMakeLists.txt new file mode 100644 index 0000000000..45678c3711 --- /dev/null +++ b/tools/converter/source/safetensors/CMakeLists.txt @@ -0,0 +1,12 @@ +file(GLOB SafeTensors_SRC CONFIGURE_DEPENDS + ${CMAKE_CURRENT_LIST_DIR}/*.cpp + ${CMAKE_CURRENT_LIST_DIR}/*.c + ${CMAKE_CURRENT_LIST_DIR}/*.cc + ${CMAKE_CURRENT_LIST_DIR}/*.h + ${CMAKE_CURRENT_LIST_DIR}/*.hpp +) + +add_library(MNNConverterSafeTensors OBJECT ${SafeTensors_SRC}) + +list(APPEND MNN_CONVERTER_BACKENDS_OBJECTS $) +list(APPEND MNN_CONVERTER_BACKENDS_TARGETS MNNConverterSafeTensors) diff --git a/tools/converter/source/safetensors/HuggingFaceQwen3.cpp b/tools/converter/source/safetensors/HuggingFaceQwen3.cpp new file mode 100644 index 0000000000..5686566c51 --- /dev/null +++ b/tools/converter/source/safetensors/HuggingFaceQwen3.cpp @@ -0,0 +1,401 @@ +#include +#include + +#include +#include "MNN_generated.h" + +#include "../optimizer/Global.hpp" +#include "SafetensorConverter.hpp" +#include "SafetensorModelRegistry.hpp" +#include "SafetensorUtils.hpp" +#include "WorkflowJson.hpp" +#include "HuggingFaceQwen3.hpp" + +using namespace MNN::Express; +using namespace MNN::Express::SafeTensorUtils; + +namespace MNN { +namespace SafeTensors { + +static VARP _linear2d(VARP x4d, VARP weightOI, VARP bias = nullptr) { + auto wInfo = weightOI->getInfo(); + if (nullptr == wInfo || wInfo->dim.size() < 2) { + return nullptr; + } + + const int outDim = wInfo->dim[0]; + const int inDim = wInfo->dim[1]; + if (inDim <= 0 || outDim <= 0) { + return nullptr; + } + + if (nullptr == weightOI->readMap()) { + weightOI = _Cast(weightOI); + weightOI.fix(VARP::CONSTANT); + } + + std::vector weightData(weightOI->getInfo()->size); + ::memcpy(weightData.data(), weightOI->readMap(), weightData.size() * sizeof(float)); + + std::vector biasData(outDim, 0.0f); + if (nullptr != bias) { + if (nullptr == bias->readMap()) { + bias = _Cast(bias); + bias.fix(VARP::CONSTANT); + } + ::memcpy(biasData.data(), bias->readMap(), outDim * sizeof(float)); + } + + return _Conv(std::move(weightData), std::move(biasData), x4d, {inDim, outDim}, {1, 1}); +} + +class HuggingFaceQwen3 { +public: + HuggingFaceQwen3(const Converter* converter) : mConverter(converter) {} + + struct BlockInfo { + int hiddenSize = 0; + int headDim = 0; + int numberHead = 0; + int ropeCutHeadDim = 0; + + VARP cosEven; + VARP cosOdd; + VARP sinEven; + VARP sinOdd; + + VARP shapeQKV; + }; + + std::pair makeBlock(VARP hiddenState, VARP add, const BlockInfo& info, VARP mask, int blockIndex) { + auto blockPrefix = std::string("model.layers.") + std::to_string(blockIndex) + "."; + auto attnPrefix = blockPrefix + "self_attn."; + auto mlpPrefix = blockPrefix + "mlp."; + + auto setName = [](const VARP& v, const std::string& name) { + if (nullptr != v.get()) { + v->setName(name); + } + }; + + auto load = [this](const std::string& name) { + if (mConverter->hasTensor(name)) { + return mConverter->loadTensor(name, false); + } + return MNN::Express::VARP(nullptr); + }; + + const int hiddenSize = info.hiddenSize; + const float ln_eps = 1.0e-6f; + const bool useC4Opt = true; + + // RMSNorm + QKV + auto ln1Weight = load(blockPrefix + "input_layernorm.weight"); + + VARP hiddenStateNorm; + if (nullptr != add.get()) { + auto res = _BinaryLayerNorm(hiddenState, add, {ln1Weight, nullptr, ln_eps, true, hiddenSize, true}); + hiddenStateNorm = res.second; + hiddenState = res.first; + } else { + hiddenStateNorm = _TransformerLayerNorm(hiddenState, {ln1Weight, nullptr, ln_eps, true, hiddenSize, true}); + } + setName(hiddenStateNorm, blockPrefix + "input_layernorm.out"); + auto hiddenStateNorm4d = hiddenStateNorm; + + auto qWeight = load(attnPrefix + "q_proj.weight"); + auto kWeight = load(attnPrefix + "k_proj.weight"); + auto vWeight = load(attnPrefix + "v_proj.weight"); + auto oWeight = load(attnPrefix + "o_proj.weight"); + + VARP qBias; + VARP kBias; + VARP vBias; + if (mConverter->hasTensor(attnPrefix + "q_proj.bias")) { + qBias = load(attnPrefix + "q_proj.bias"); + } + if (mConverter->hasTensor(attnPrefix + "k_proj.bias")) { + kBias = load(attnPrefix + "k_proj.bias"); + } + if (mConverter->hasTensor(attnPrefix + "v_proj.bias")) { + vBias = load(attnPrefix + "v_proj.bias"); + } + + auto qWeightInfo = qWeight->getInfo(); + auto kWeightInfo = kWeight->getInfo(); + auto vWeightInfo = vWeight->getInfo(); + if (nullptr == qWeightInfo || nullptr == kWeightInfo || nullptr == vWeightInfo) { + return {nullptr, nullptr}; + } + + const int queryHiddenSize = qWeightInfo->dim[0]; + int headDim = info.headDim; + int numHeads = info.numberHead > 0 ? info.numberHead : (queryHiddenSize / headDim); + if (headDim <= 0 || numHeads <= 0 || numHeads * headDim != queryHiddenSize) { + return {nullptr, nullptr}; + } + const int attnOutSize = headDim * numHeads; + + auto q = _linear2d(hiddenStateNorm4d, qWeight, qBias); + setName(q, attnPrefix + "q_proj.out"); + auto k = _linear2d(hiddenStateNorm4d, kWeight, kBias); + setName(k, attnPrefix + "k_proj.out"); + auto v = _linear2d(hiddenStateNorm4d, vWeight, vBias); + setName(v, attnPrefix + "v_proj.out"); + + auto shapeqkv = info.shapeQKV; + + q = _Reshape(q, shapeqkv); + setName(q, attnPrefix + "q_proj.out_reshape"); + RopeInfo ropeParam; + ropeParam.cutHeadDim = info.ropeCutHeadDim; + + if (mConverter->hasTensor(attnPrefix + "q_norm.weight")) { + auto qNorm = load(attnPrefix + "q_norm.weight"); + ropeParam.qNorm = {qNorm, nullptr, ln_eps, true}; + } + + k = _Reshape(k, shapeqkv); + setName(k, attnPrefix + "k_proj.out_reshape"); + if (mConverter->hasTensor(attnPrefix + "k_norm.weight")) { + auto kNorm = load(attnPrefix + "k_norm.weight"); + ropeParam.kNorm = {kNorm, nullptr, ln_eps, true}; + } + + v = _Reshape(v, shapeqkv); + setName(v, attnPrefix + "v_proj.out_reshape"); + + // RoPE + { + auto ropeOutputs = _TransformerRoPE(q, k, info.cosEven, info.cosOdd, info.sinEven, info.sinOdd, ropeParam); + q = ropeOutputs[0]; + k = ropeOutputs[1]; + setName(q, attnPrefix + "q_after_rope"); + setName(k, attnPrefix + "k_after_rope"); + } + + auto attn = _GPT2Attention(numHeads, headDim, q, k, v, nullptr, nullptr, nullptr, nullptr, mask, useC4Opt); + setName(attn, attnPrefix + "attention.out"); + + if (attnOutSize != numHeads * headDim) { + return {nullptr, nullptr}; + } + auto o = _linear2d(attn, oWeight); + setName(o, attnPrefix + "o_proj.out"); + + // RMSNorm + MLP + auto ln2Weight = load(blockPrefix + "post_attention_layernorm.weight"); + auto fuseLayerNorm = _BinaryLayerNorm(hiddenState, o, {ln2Weight, nullptr, ln_eps, true, hiddenSize, true}); + hiddenStateNorm = fuseLayerNorm.second; + hiddenState = fuseLayerNorm.first; + setName(hiddenState, blockPrefix + "resid1"); + setName(hiddenStateNorm, blockPrefix + "post_attention_layernorm.out"); + hiddenStateNorm4d = hiddenStateNorm; + + auto gateWeight = load(mlpPrefix + "gate_proj.weight"); + auto upWeight = load(mlpPrefix + "up_proj.weight"); + auto downWeight = load(mlpPrefix + "down_proj.weight"); + + auto gate = _linear2d(hiddenStateNorm4d, gateWeight); + setName(gate, mlpPrefix + "gate_proj.out"); + auto up = _linear2d(hiddenStateNorm4d, upWeight); + setName(up, mlpPrefix + "up_proj.out"); + + auto ffn = _MulSilu(up, gate); + setName(ffn, mlpPrefix + "mul_silu.out"); + + ffn = _linear2d(ffn, downWeight); + setName(ffn, mlpPrefix + "down_proj.out"); + + return {hiddenState, ffn}; + } + +private: + const Converter* mConverter = nullptr; +}; + +void HuggingFaceQwen3Convert(const Converter* converter, MNN::NetT* dst, const HuggingFaceQwen3Config& config) { + if (nullptr == converter || nullptr == dst) { + return; + } + + HuggingFaceQwen3 qwen3(converter); + + int blockSize = config.blockNumber; + if (blockSize <= 0) { + const int maxBlockSize = 256; + for (int blockIndex = 0; blockIndex < maxBlockSize; ++blockIndex) { + auto prefix = std::string("model.layers.") + std::to_string(blockIndex) + ".self_attn.q_proj.weight"; + if (!converter->hasTensor(prefix)) { + blockSize = blockIndex; + break; + } + } + } + + int hiddenSize = config.hiddenSize; + int headDim = config.headDim; + int numHead = config.numHead; + + if (hiddenSize <= 0 || headDim <= 0 || numHead <= 0) { + auto qWeight0 = converter->loadTensor("model.layers.0.self_attn.q_proj.weight"); + auto kWeight0 = converter->loadTensor("model.layers.0.self_attn.k_proj.weight"); + if (nullptr != qWeight0.get() && nullptr != qWeight0->getInfo() && qWeight0->getInfo()->dim.size() >= 2) { + const int queryHiddenSize = qWeight0->getInfo()->dim[0]; + const int inputHiddenSize = qWeight0->getInfo()->dim[1]; + if (hiddenSize <= 0) { + hiddenSize = inputHiddenSize; + } + + if (numHead > 0 && headDim <= 0 && queryHiddenSize % numHead == 0) { + headDim = queryHiddenSize / numHead; + } else if (headDim > 0 && numHead <= 0 && queryHiddenSize % headDim == 0) { + numHead = queryHiddenSize / headDim; + } else if (headDim <= 0 && numHead <= 0) { + int kvHiddenSize = 0; + if (nullptr != kWeight0.get() && nullptr != kWeight0->getInfo() && kWeight0->getInfo()->dim.size() >= 2) { + kvHiddenSize = kWeight0->getInfo()->dim[0]; + } + static const int candidates[] = {128, 96, 80, 72, 64, 48, 40, 32}; + for (int cand : candidates) { + if (cand <= 0) { + continue; + } + if (queryHiddenSize % cand != 0) { + continue; + } + if (kvHiddenSize > 0 && kvHiddenSize % cand != 0) { + continue; + } + int candHead = queryHiddenSize / cand; + if (candHead > 0 && candHead <= 64) { + headDim = cand; + numHead = candHead; + break; + } + } + } + } + } + + HuggingFaceQwen3::BlockInfo blockInfo; + blockInfo.hiddenSize = hiddenSize > 0 ? hiddenSize : 1024; + blockInfo.headDim = headDim > 0 ? headDim : 128; + blockInfo.numberHead = numHead > 0 ? numHead : 16; + blockInfo.ropeCutHeadDim = config.ropeCutHeadDim; + + auto embed = _Input({1, -1, blockInfo.hiddenSize}, NCHW, halide_type_of()); + embed->setName("input_embedding"); + + auto position = _Input({1, -1}, NCHW, halide_type_of()); + position->setName("position_ids"); + + auto mask = _Input({}, NCHW, halide_type_of()); + mask->setName("mask"); + + auto one = _Unsqueeze(_Scalar(1), {0}); + auto negone = _Unsqueeze(_Scalar(-1), {0}); + auto shapeHiddenState = _Shape(embed, true); + auto seqLenVar = _Slice(shapeHiddenState, _Unsqueeze(_Scalar(1), {0}), one); + auto batchLenVar = _Slice(shapeHiddenState, _Unsqueeze(_Scalar(0), {0}), one); + + auto headDimVar = _Unsqueeze(_Scalar(blockInfo.headDim), {0}); + + const int posEmbEnd = config.maxPositionEmbeddings > 0 ? config.maxPositionEmbeddings : 32768; + const float ropeTheta = config.ropeTheta > 0.0f ? config.ropeTheta : 100000.0f; + auto posEmb = _PrecomputePosEmbedding(blockInfo.headDim, posEmbEnd, ropeTheta); + posEmb.fix(VARP::CONSTANT); + posEmb->setName("precompute_posemb"); + + posEmb = _GatherV2(posEmb, position, _Scalar(1)); + auto cosAndsin = _Split(posEmb, {2}, 0); + + blockInfo.shapeQKV = _Concat({batchLenVar, seqLenVar, negone, headDimVar}, 0); + blockInfo.shapeQKV->setName("shape_qkv"); + + blockInfo.cosEven = _Squeeze(cosAndsin[0], {0}); + blockInfo.cosOdd = _Squeeze(cosAndsin[0], {0}); + blockInfo.sinEven = _Squeeze(cosAndsin[1], {0}); + blockInfo.sinOdd = _Squeeze(cosAndsin[1], {0}); + + auto hiddenState = _Reshape(embed, {-1, hiddenSize, 1, 1}); + hiddenState = _Convert(hiddenState, NC4HW4); + VARP add = nullptr; + for (int blockIndex = 0; blockIndex < blockSize; ++blockIndex) { + auto res = qwen3.makeBlock(hiddenState, add, blockInfo, mask, blockIndex); + hiddenState = res.first; + add = res.second; + hiddenState->setName("block" + std::to_string(blockIndex)); + } + + // Final RMSNorm + if (add.get() != nullptr) { + hiddenState = _Add(hiddenState, add); + } + auto normWeight = converter->loadTensor("model.norm.weight"); + hiddenState = _TransformerLayerNorm(hiddenState, {normWeight, nullptr, 1.0e-6f, true, blockInfo.hiddenSize, true}); + hiddenState = _Reshape(hiddenState, shapeHiddenState); + hiddenState->setName("hidden_state"); + + std::vector outputs = {hiddenState}; + std::vector outputNames = {"hidden_state"}; + if (config.outputLastHiddenState) { + auto lastHiddenState = _MakeLastHiddenStateOutput(hiddenState, blockInfo.hiddenSize); + outputs.emplace_back(lastHiddenState); + outputNames.emplace_back("last_hidden_state"); + } + + Variable::save(outputs, dst); + dst->sourceType = NetSource_ONNX; + dst->outputName = std::move(outputNames); +} + +namespace { +static bool _convertHuggingFaceDecoderModel(const Converter* converter, const rapidjson::Value* model, modelConfig& modelPath) { + if (nullptr == converter) { + return false; + } + + auto netT = std::unique_ptr(new MNN::NetT); + HuggingFaceQwen3Config config; + + if (nullptr != model && model->IsObject()) { + auto blocks = WorkflowJson::getArray(*model, "blocks"); + if (nullptr != blocks) { + for (auto& block : blocks->GetArray()) { + if (!block.IsObject()) { + continue; + } + auto type = WorkflowJson::getString(block, "type"); + if (type == "QwenTransformer") { + config.hiddenSize = WorkflowJson::getInt(block, "hiddenSize", config.hiddenSize); + config.headDim = WorkflowJson::getInt(block, "headDim", config.headDim); + config.numHead = WorkflowJson::getInt(block, "numHead", config.numHead); + config.kvNumHead = WorkflowJson::getInt(block, "kvNumHead", config.kvNumHead); + config.blockNumber = WorkflowJson::getInt(block, "number", config.blockNumber); + config.maxPositionEmbeddings = WorkflowJson::getInt(block, "maxPositionEmbeddings", config.maxPositionEmbeddings); + config.maxPositionEmbeddings = WorkflowJson::getInt(block, "max_position_embeddings", config.maxPositionEmbeddings); + // backward compatible field name (legacy) + config.maxPositionEmbeddings = WorkflowJson::getInt(block, "bit", config.maxPositionEmbeddings); + config.ropeTheta = WorkflowJson::getFloat(block, "ropeTheta", config.ropeTheta); + config.ropeTheta = WorkflowJson::getFloat(block, "rope_theta", config.ropeTheta); + config.ropeCutHeadDim = WorkflowJson::getInt(block, "ropeCutHeadDim", config.ropeCutHeadDim); + config.ropeCutHeadDim = WorkflowJson::getInt(block, "rope_cut_head_dim", config.ropeCutHeadDim); + break; + } + } + } + } + + auto path = modelPath.MNNModel; + modelPath.MNNModel = path + "decoder.mnn"; + HuggingFaceQwen3Convert(converter, netT.get(), config); + optimizeAndWrite(modelPath, netT); + return true; +} + +REGISTER_SAFETENSOR_MODEL_BUILDER("hf_decoder", _convertHuggingFaceDecoderModel); +} // namespace + +} // namespace SafeTensors +} // namespace MNN diff --git a/tools/converter/source/safetensors/HuggingFaceQwen3.hpp b/tools/converter/source/safetensors/HuggingFaceQwen3.hpp new file mode 100644 index 0000000000..29d0d1387c --- /dev/null +++ b/tools/converter/source/safetensors/HuggingFaceQwen3.hpp @@ -0,0 +1,26 @@ +#ifndef HuggingFaceQwen3_hpp +#define HuggingFaceQwen3_hpp + +#include "SafetensorConverter.hpp" + +namespace MNN { +namespace SafeTensors { + +struct HuggingFaceQwen3Config { + int hiddenSize = 0; + int headDim = 0; + int numHead = 0; + int kvNumHead = 0; + int blockNumber = 0; + int maxPositionEmbeddings = 0; + float ropeTheta = 0.0f; + int ropeCutHeadDim = 0; + bool outputLastHiddenState = true; +}; + +void HuggingFaceQwen3Convert(const Converter* converter, MNN::NetT* dst, const HuggingFaceQwen3Config& config); + +} // namespace SafeTensors +} // namespace MNN + +#endif diff --git a/tools/converter/source/safetensors/Logit.cpp b/tools/converter/source/safetensors/Logit.cpp new file mode 100644 index 0000000000..970586505a --- /dev/null +++ b/tools/converter/source/safetensors/Logit.cpp @@ -0,0 +1,513 @@ +#include +#include +#include +#include +#include +#include "MNN_generated.h" +#include + +#include "SafetensorConverter.hpp" +#include "Logit.hpp" +#include "SafetensorModelRegistry.hpp" +#include "SafetensorUtils.hpp" +#include "WorkflowJson.hpp" + +using namespace MNN::Express; +using namespace MNN::Express::SafeTensorUtils; + +namespace MNN { +namespace SafeTensors { + +static inline void _setNameIfEmpty(const VARP& v, const std::string& name) { + if (nullptr != v.get() && v->name().empty()) { + v->setName(name); + } +} + +static VARP _linear2d(VARP x4d, VARP weightOI, VARP bias = nullptr) { + auto wInfo = weightOI->getInfo(); + if (nullptr == wInfo || wInfo->dim.size() < 2) { + return nullptr; + } + + const int outDim = wInfo->dim[0]; + const int inDim = wInfo->dim[1]; + if (inDim <= 0 || outDim <= 0) { + return nullptr; + } + + std::vector weightData(weightOI->getInfo()->size); + ::memcpy(weightData.data(), weightOI->readMap(), weightData.size() * sizeof(float)); + std::vector biasData(outDim, 0.0f); + if (nullptr != bias) { + ::memcpy(biasData.data(), bias->readMap(), outDim * sizeof(float)); + } + + return _Conv(std::move(weightData), std::move(biasData), x4d, {inDim, outDim}, {1, 1}); +} + +// Deep-copy `src` into `dst` via flatbuffers Pack/UnPack and strip Convolution2D +// weight payloads — the runtime reuses the original logit's quantized weights. +static void _cloneLogitNet(const MNN::NetT* src, MNN::NetT* dst) { + flatbuffers::FlatBufferBuilder fbb; + fbb.Finish(MNN::CreateNet(fbb, src)); + std::unique_ptr cloned(flatbuffers::GetRoot(fbb.GetBufferPointer())->UnPack()); + *dst = std::move(*cloned); + for (auto& op : dst->oplists) { + if (op && op->main.type == OpParameter_Convolution2D) { + auto conv = op->main.AsConvolution2D(); + conv->weight.clear(); + conv->bias.clear(); + conv->quanParameter.reset(); + conv->external.clear(); + } + } +} + +// Locate logits tensor index — prefer outputName mapping, fall back to the last +// op with an output index. +static int _findLogitsIndex(const MNN::NetT* net) { + if (!net->outputName.empty()) { + const auto& outName = net->outputName[0]; + for (int i = 0; i < (int)net->tensorName.size(); ++i) { + if (net->tensorName[i] == outName) return i; + } + } + for (int i = (int)net->oplists.size() - 1; i >= 0; --i) { + if (net->oplists[i] && !net->oplists[i]->outputIndexes.empty()) { + return net->oplists[i]->outputIndexes[0]; + } + } + return -1; +} + +static int _addTensor(MNN::NetT* net, const std::string& name) { + int idx = (int)net->tensorName.size(); + net->tensorName.push_back(name); + return idx; +} + +static int _appendConstInt(MNN::NetT* net, const std::string& opName, const std::string& tensorName, int value) { + int idx = _addTensor(net, tensorName); + std::unique_ptr op(new OpT); + op->type = OpType_Const; + op->main.type = OpParameter_Blob; + op->main.value = new BlobT; + auto blob = op->main.AsBlob(); + blob->dataFormat = MNN_DATA_FORMAT_NCHW; + blob->dataType = DataType_DT_INT32; + blob->dims = {1}; + blob->int32s = {value}; + op->name = opName; + op->outputIndexes = {idx}; + net->oplists.emplace_back(std::move(op)); + return idx; +} + +static int _appendSoftmaxOp(MNN::NetT* net, int inputIdx, const std::string& opName, const std::string& tensorName, int axis = -1) { + int idx = _addTensor(net, tensorName); + std::unique_ptr op(new OpT); + op->type = OpType_Softmax; + op->main.type = OpParameter_Axis; + op->main.value = new AxisT; + op->main.AsAxis()->axis = axis; + op->name = opName; + op->inputIndexes = {inputIdx}; + op->outputIndexes = {idx}; + net->oplists.emplace_back(std::move(op)); + return idx; +} + +// Returns {valuesIdx, indicesIdx}. Default largest=true (no main parameter). +static std::pair _appendTopK2Op(MNN::NetT* net, int inputIdx, int kIdx, + const std::string& opName, + const std::string& valuesName, + const std::string& indicesName) { + int valuesIdx = _addTensor(net, valuesName); + int indicesIdx = _addTensor(net, indicesName); + std::unique_ptr op(new OpT); + op->type = OpType_TopKV2; + op->name = opName; + op->inputIndexes = {inputIdx, kIdx}; + op->outputIndexes = {valuesIdx, indicesIdx}; + net->oplists.emplace_back(std::move(op)); + return {valuesIdx, indicesIdx}; +} + +static int _appendUnaryOp(MNN::NetT* net, int inputIdx, UnaryOpOperation kind, + const std::string& opName, const std::string& tensorName) { + int idx = _addTensor(net, tensorName); + std::unique_ptr op(new OpT); + op->type = OpType_UnaryOp; + op->main.type = OpParameter_UnaryOp; + op->main.value = new UnaryOpT; + op->main.AsUnaryOp()->opType = kind; + op->name = opName; + op->inputIndexes = {inputIdx}; + op->outputIndexes = {idx}; + net->oplists.emplace_back(std::move(op)); + return idx; +} + +// Clone `logit` into `dst` and resolve the logits tensor index. Returns -1 on +// failure (also logs the error tagged with `fnTag`). +static int _cloneAndFindLogits(const MNN::NetT* logit, MNN::NetT* dst, const char* fnTag) { + if (nullptr == logit || nullptr == dst) { + return -1; + } + _cloneLogitNet(logit, dst); + int idx = _findLogitsIndex(dst); + if (idx < 0) { + MNN_ERROR("%s: 未找到 logits 输出\n", fnTag); + } + return idx; +} + +void LogitConvert(const Converter* converter, MNN::NetT* dst, const LogitConfig& config) { + if (nullptr == converter || nullptr == dst) { + return; + } + + auto weight = converter->loadTensor(config.wteWeightName); + if (nullptr == weight.get() && config.wteWeightName.size() > 7 && config.wteWeightName.substr(0, 7) == "module.") { + weight = converter->loadTensor(config.wteWeightName.substr(7)); + } + if (nullptr == weight.get() || nullptr == weight->getInfo() || weight->getInfo()->dim.size() < 2) { + MNN_ERROR("LogitConvert: missing/invalid %s\n", config.wteWeightName.c_str()); + return; + } + + const int d0 = weight->getInfo()->dim[0]; + const int d1 = weight->getInfo()->dim[1]; + if (d0 <= 0 || d1 <= 0) { + MNN_ERROR("LogitConvert: invalid wte weight shape\n"); + return; + } + + int hiddenSize = config.hiddenSize; + int vocabSize = 0; + bool needTranspose = false; + + if (hiddenSize > 0) { + if (d1 == hiddenSize) { + vocabSize = d0; + } else if (d0 == hiddenSize) { + vocabSize = d1; + needTranspose = true; + } else { + // Fallback: assume [vocab, hidden] + hiddenSize = d1; + vocabSize = d0; + } + } else { + // Fallback heuristics: vocab is usually larger than hidden + if (d0 >= d1) { + vocabSize = d0; + hiddenSize = d1; + } else { + vocabSize = d1; + hiddenSize = d0; + needTranspose = true; + } + } + + if (needTranspose) { + weight = _Transpose(weight, {1, 0}); // -> [vocab, hidden] + } + + // Input hidden state: [B, S, H] + auto hiddenState = _Input({1, -1, hiddenSize}, NCHW, halide_type_of()); + hiddenState->setName(config.inputName); + + auto shapeHidden = _Shape(hiddenState, true); + auto one = _Unsqueeze(_Scalar(1), {0}); + auto batchVar = _Slice(shapeHidden, _Unsqueeze(_Scalar(0), {0}), one); + auto seqVar = _Slice(shapeHidden, _Unsqueeze(_Scalar(1), {0}), one); + + auto hidden2d = _Reshape(hiddenState, {-1, hiddenSize, 1, 1}); + + // Optional bias + VARP bias = nullptr; + auto prefix = config.wteWeightName; + const std::string suffix = ".weight"; + if (prefix.size() > suffix.size() && prefix.compare(prefix.size() - suffix.size(), suffix.size(), suffix) == 0) { + prefix.resize(prefix.size() - suffix.size()); + } + auto biasName = prefix + ".bias"; + if (converter->hasTensor(biasName)) { + bias = converter->loadTensor(biasName); + } else if (biasName.size() > 7 && biasName.substr(0, 7) == "module." && converter->hasTensor(biasName.substr(7))) { + bias = converter->loadTensor(biasName.substr(7)); + } + + // Quantized path if weight_qscale exists + VARP logits2d = nullptr; + auto wScaleName = config.wteWeightName + "_qscale"; + std::string textScaleName = config.wteWeightName; + if (textScaleName.find(".weight") != std::string::npos) { + textScaleName.replace(textScaleName.find(".weight"), 7, ".text_embedding.weight_qscale"); + } + + auto loadScale = [&](const std::string& name) -> VARP { + if (converter->hasTensor(name)) { + return converter->loadTensor(name); + } + if (name.size() > 7 && name.substr(0, 7) == "module." && converter->hasTensor(name.substr(7))) { + return converter->loadTensor(name.substr(7)); + } + return nullptr; + }; + + VARP wScale = loadScale(wScaleName); + if (nullptr == wScale.get()) { + wScale = loadScale(textScaleName); + } + + if (wScale.get() != nullptr) { + logits2d = _QConvolution1x1(hiddenSize, hidden2d, nullptr, nullptr, weight, wScale, nullptr, bias, vocabSize); + } else { + // Float path + if (weight->getInfo()->type.code != halide_type_float) { + MNN_ERROR("LogitConvert: wte weight is not float and no qscale found\n"); + return; + } + logits2d = _linear2d(hidden2d, weight, bias); + } + + if (nullptr == logits2d.get()) { + MNN_ERROR("LogitConvert: build logits failed\n"); + return; + } + _setNameIfEmpty(logits2d, prefix + ".out2d"); + + auto vocabVar = _Unsqueeze(_Scalar(vocabSize), {0}); + auto logits3d = _Reshape(logits2d, _Concat({batchVar, seqVar, vocabVar}, 0)); + logits3d->setName(config.outputName); + + Variable::save({logits3d}, dst); + dst->sourceType = NetSource_ONNX; + dst->outputName = {config.outputName}; + +} + +void MakeTieEmbedding(const Converter* converter, const MNN::NetT* src, MNN::NetT* dst) { + if (nullptr == converter || nullptr == src || nullptr == dst) { + return; + } + + std::string sharedName; + int ic = 0; + int oc = 0; + + for (auto& op : src->oplists) { + if (nullptr == op) { + continue; + } + if (op->type != OpType_Convolution) { + continue; + } + auto conv = op->main.AsConvolution2D(); + if (nullptr == conv || nullptr == conv->common) { + continue; + } + if (conv->common->inputCount <= 0 || conv->common->outputCount <= 0) { + continue; + } + // Keep the last conv as shared weight provider (align with makeSharedGather.py). + sharedName = op->name; + ic = conv->common->inputCount; + oc = conv->common->outputCount; + } + + if (sharedName.empty() || ic <= 0 || oc <= 0) { + MNN_ERROR("MakeTieEmbedding: can't find valid convolution in src\n"); + return; + } + + // Indices input. + auto input = _Input({-1}, NCHW, halide_type_of()); + input->setName("x"); + + // GatherV2 with OpParameter_Input main is a special form that lets runtime reuse + // the quantized weights from base model's convolution execution. + std::unique_ptr gather(new OpT); + gather->type = OpType_GatherV2; + gather->main.type = OpParameter_Input; + gather->main.value = new InputT; + gather->main.AsInput()->dims = {oc, ic}; + gather->main.AsInput()->dtype = DataType_DT_FLOAT; + gather->main.AsInput()->dformat = MNN_DATA_FORMAT_NCHW; + + auto gatherExpr = Expr::create(gather.get(), {input}); + gatherExpr->setName(sharedName); + + auto output = Variable::create(gatherExpr); + output->setName(sharedName); + + Variable::save({output}, dst); + dst->sourceType = NetSource_ONNX; + dst->outputName = {sharedName}; +} + + +// Clone `logit` into `dst` and append a TopKV2 producing top-K indices. +void MakeTopKV(const Converter* /*converter*/, const MNN::NetT* logit, MNN::NetT* dst, int K) { + int logitsIdx = _cloneAndFindLogits(logit, dst, "MakeTopKV"); + if (logitsIdx < 0) return; + + int kIdx = _appendConstInt(dst, "const_topk_k", "topk_k", K); + auto vi = _appendTopK2Op(dst, logitsIdx, kIdx, "TopKV2", "topk_values", "topk_indices"); + dst->outputName = {dst->tensorName[vi.second]}; +} + +// Clone `logit` into `dst` and append a Softmax (axis=-1) as the new output. +void MakeSoftmax(const Converter* /*converter*/, const MNN::NetT* logit, MNN::NetT* dst) { + int logitsIdx = _cloneAndFindLogits(logit, dst, "MakeSoftmax"); + if (logitsIdx < 0) return; + + int smIdx = _appendSoftmaxOp(dst, logitsIdx, "LogitSoftmax", "logit_softmax"); + dst->outputName = {dst->tensorName[smIdx]}; +} + +// Clone `logit` into `dst`, append Softmax → TopKV2 → Log(values) for beam search. +// Outputs {log(values), indices}. +void MakeBeamTopKV(const Converter* /*converter*/, const MNN::NetT* logit, MNN::NetT* dst, int K) { + int logitsIdx = _cloneAndFindLogits(logit, dst, "MakeBeamTopKV"); + if (logitsIdx < 0) return; + + int smIdx = _appendSoftmaxOp(dst, logitsIdx, "BeamSoftmax", "beam_softmax"); + int kIdx = _appendConstInt(dst, "const_beam_topk_k", "beam_topk_k", K); + auto vi = _appendTopK2Op(dst, smIdx, kIdx, "BeamTopKV2", "beam_topk_values", "beam_topk_indices"); + int logIdx = _appendUnaryOp(dst, vi.first, UnaryOpOperation_LOG, "BeamTopKV2_Log", "beam_topk_log_values"); + dst->outputName = {dst->tensorName[logIdx], dst->tensorName[vi.second]}; +} + +namespace { + +// Parse the K parameter, accepting either an int array or a single int (legacy). +// Always returns a non-empty list (defaults to {1}). +static std::vector _parseKList(const rapidjson::Value& block) { + std::vector kList; + if (auto kArr = WorkflowJson::getArray(block, "K")) { + for (auto& kv : kArr->GetArray()) { + if (kv.IsInt() && kv.GetInt() > 0) kList.push_back(kv.GetInt()); + } + } else { + int K = WorkflowJson::getInt(block, "K", 1); + if (K > 0) kList.push_back(K); + } + if (kList.empty()) kList.push_back(1); + return kList; +} + +// Build and write a per-K logit variant for each entry in `kList`. External +// data is force-disabled for variants since they reuse the base logit's +// quantized weights. +template +static void _saveLogitVariants(const Converter* converter, MNN::NetT* logitNet, + modelConfig& modelPath, const std::string& path, + const std::vector& kList, + const std::string& filePrefix, Builder build) { + auto originExternal = modelPath.saveExternalData; + modelPath.saveExternalData = false; + for (int K : kList) { + auto net = std::unique_ptr(new MNN::NetT); + build(converter, logitNet, net.get(), K); + modelPath.MNNModel = path + filePrefix + std::to_string(K) + ".mnn"; + optimizeAndWrite(modelPath, net); + } + modelPath.saveExternalData = originExternal; +} + +static bool _convertLogitModel(const Converter* converter, const rapidjson::Value* model, modelConfig& modelPath) { + if (nullptr == converter) { + return false; + } + + LogitConfig config; + auto path = modelPath.MNNModel; + + auto logitNet = std::unique_ptr(new MNN::NetT); + std::unique_ptr embeddingNet; + + auto ensureLogits = [&]() { + if (logitNet->oplists.empty()) { + LogitConvert(converter, logitNet.get(), config); + } + }; + + if (nullptr != model && model->IsObject()) { + auto blocks = WorkflowJson::getArray(*model, "blocks"); + if (nullptr != blocks) { + for (auto& block : blocks->GetArray()) { + if (!block.IsObject()) { + continue; + } + + auto prefix = WorkflowJson::getString(block, "prefix"); + if (!prefix.empty()) { + auto wteWeightName = prefix; + if (wteWeightName.find(".weight") == std::string::npos) { + auto candidate = wteWeightName + ".weight"; + if (converter->hasTensor(candidate)) { + wteWeightName = candidate; + } + } + if (converter->hasTensor(wteWeightName)) { + config.wteWeightName = wteWeightName; + } + } + + auto type = WorkflowJson::getString(block, "type"); + if (type == "InnerProduct") { + LogitConvert(converter, logitNet.get(), config); + continue; + } + if (type == "TieEmbedding") { + embeddingNet = std::unique_ptr(new MNN::NetT); + MakeTieEmbedding(converter, logitNet.get(), embeddingNet.get()); + continue; + } + if (type == "TopKV") { + ensureLogits(); + _saveLogitVariants(converter, logitNet.get(), modelPath, path, + _parseKList(block), "logit_topkv_", MakeTopKV); + continue; + } + if (type == "Softmax") { + ensureLogits(); + auto originExternal = modelPath.saveExternalData; + modelPath.saveExternalData = false; + auto softmaxNet = std::unique_ptr(new MNN::NetT); + MakeSoftmax(converter, logitNet.get(), softmaxNet.get()); + modelPath.MNNModel = path + "logit_softmax.mnn"; + optimizeAndWrite(modelPath, softmaxNet); + modelPath.saveExternalData = originExternal; + continue; + } + if (type == "BeamTopKV") { + ensureLogits(); + _saveLogitVariants(converter, logitNet.get(), modelPath, path, + _parseKList(block), "logit_beam_", MakeBeamTopKV); + continue; + } + } + } + } + + modelPath.MNNModel = path + "logit.mnn"; + optimizeAndWrite(modelPath, logitNet); + + if (nullptr != embeddingNet.get()) { + modelPath.MNNModel = path + "embed.mnn"; + optimizeAndWrite(modelPath, embeddingNet); + } + return true; +} + +REGISTER_SAFETENSOR_MODEL_BUILDER("logit", _convertLogitModel); + +} // namespace + +} // namespace SafeTensors +} // namespace MNN diff --git a/tools/converter/source/safetensors/Logit.hpp b/tools/converter/source/safetensors/Logit.hpp new file mode 100644 index 0000000000..d946f97226 --- /dev/null +++ b/tools/converter/source/safetensors/Logit.hpp @@ -0,0 +1,30 @@ +#ifndef Logit_hpp +#define Logit_hpp + +#include +#include "SafetensorConverter.hpp" + +namespace MNN { +namespace SafeTensors { + +struct LogitConfig { + // Default to Qwen/GPT2 word embedding matrix + std::string wteWeightName = "module.gpt2.transformer.wte.weight"; + + // Optional: if provided, will be used to disambiguate weight layout + int hiddenSize = 0; + + std::string inputName = "hidden_state"; + std::string outputName = "output"; +}; + +void LogitConvert(const Converter* converter, MNN::NetT* dst, const LogitConfig& config); +void MakeTieEmbedding(const Converter* converter, const MNN::NetT* src, MNN::NetT* dst); +void MakeTopKV(const Converter* converter, const MNN::NetT* logit, MNN::NetT* dst, int K); +void MakeSoftmax(const Converter* converter, const MNN::NetT* logit, MNN::NetT* dst); +void MakeBeamTopKV(const Converter* converter, const MNN::NetT* logit, MNN::NetT* dst, int K); + +} // namespace SafeTensors +} // namespace MNN + +#endif diff --git a/tools/converter/source/safetensors/SafetensorConverter.cpp b/tools/converter/source/safetensors/SafetensorConverter.cpp new file mode 100644 index 0000000000..ab0a15768e --- /dev/null +++ b/tools/converter/source/safetensors/SafetensorConverter.cpp @@ -0,0 +1,214 @@ +#include +#include +#include + +#include +#include + +#include + +#include "SafetensorConverter.hpp" +#include "SafetensorModelRegistry.hpp" +#include "WorkflowJson.hpp" + +#include "../common/CommonUtils.hpp" + +#define SAFETENSORS_CPP_IMPLEMENTATION +#include "safetensors.hh" +namespace MNN { +namespace SafeTensors { + +static halide_type_t _convertSafeTensorDType(safetensors::dtype dtype) { + switch (dtype) { + case safetensors::kBOOL: + // Safetensors stores BOOL as 1 byte. MNN's 1-bit bool type is not widely supported. + return halide_type_of(); + case safetensors::kUINT8: + return halide_type_of(); + case safetensors::kINT8: + return halide_type_of(); + case safetensors::kINT16: + return halide_type_of(); + case safetensors::kUINT16: + return halide_type_of(); + case safetensors::kINT32: + return halide_type_of(); + case safetensors::kUINT32: + return halide_type_of(); + case safetensors::kINT64: + return halide_type_of(); + case safetensors::kUINT64: + return halide_type_of(); + case safetensors::kFLOAT16: + return halide_type_t(halide_type_float, 16); + case safetensors::kBFLOAT16: + return halide_type_t(halide_type_bfloat, 16); + case safetensors::kFLOAT32: + return halide_type_of(); + case safetensors::kFLOAT64: + return halide_type_of(); + default: + break; + } + return halide_type_of(); +} + +struct Converter::Content { + rapidjson::Document mWorkFlow; + safetensors::safetensors_t mSt; +}; + +Converter::Converter(const std::string& jsonFile) { + mMain = new Content; + + std::ifstream fileNames(jsonFile); + std::ostringstream output; + output << fileNames.rdbuf(); + auto outputStr = output.str(); + + mMain->mWorkFlow.Parse(outputStr.c_str()); + if (mMain->mWorkFlow.HasParseError() || !mMain->mWorkFlow.IsObject()) { + MNN_ERROR("Invalid json\n"); + mMain->mWorkFlow.SetObject(); + return; + } +} + +Converter::~ Converter() { + delete mMain; +} + +std::vector Converter::listModels() const { + std::vector res; + if (nullptr == mMain) { + return res; + } + auto models = WorkflowJson::getArray(mMain->mWorkFlow, "models"); + if (nullptr == models) { + return res; + } + for (auto& model : models->GetArray()) { + if (!model.IsObject()) { + continue; + } + auto name = WorkflowJson::getString(model, "name"); + if (name.empty()) { + continue; + } + res.emplace_back(std::move(name)); + } + return res; +} +void Converter::loadSafeTensors(const std::string& safeTensorFile) { + std::string warn, err; + auto ret = safetensors::mmap_from_file(safeTensorFile, &mMain->mSt, &warn, &err); + if (warn.size()) { + FUNC_PRINT_ALL(warn.c_str(), s); + } + if (!ret) { + FUNC_PRINT_ALL(err.c_str(), s); + return; + } +} +bool Converter::convert(const std::string& name, modelConfig& modelPath) { + auto builder = SafetensorModelRegistry::get()->find(name); + if (nullptr == builder) { + MNN_ERROR("SafetensorConverter: unsupported model %s\n", name.c_str()); + return false; + } + + const rapidjson::Value* model = nullptr; + if (mMain != nullptr && mMain->mWorkFlow.IsObject()) { + auto models = WorkflowJson::getArray(mMain->mWorkFlow, "models"); + if (nullptr != models) { + for (auto& item : models->GetArray()) { + if (!item.IsObject()) { + continue; + } + auto modelName = WorkflowJson::getString(item, "name"); + MNN_PRINT("Checking model config for: %s (target: %s)\n", modelName.c_str(), name.c_str()); + if (!modelName.empty() && modelName == name) { + model = &item; + int weightQuantBits = WorkflowJson::getInt(item, "weightQuantBits", -1); + if (weightQuantBits >= 0) { + MNN_PRINT("Override weightQuantBits to %d for model %s\n", weightQuantBits, name.c_str()); + modelPath.weightQuantBits = weightQuantBits; + } + break; + } + } + } + } + + return builder(this, model, modelPath); +} +bool Converter::hasTensor(const std::string& name) const { + safetensors::tensor_t t; + if (mMain->mSt.tensors.at(name, &t)) { + return true; + } + return false; +} + +MNN::Express::VARP Converter::loadTensor(const std::string& name, bool print) const { + safetensors::tensor_t t; + bool find = mMain->mSt.tensors.at(name, &t); + if (!find) { + if (print) { + FUNC_PRINT_ALL(name.c_str(), s); + } + return nullptr; + } + + const uint8_t* dataBufferAddr = nullptr; + size_t dataBufferSize = 0; + if (mMain->mSt.mmaped) { + dataBufferAddr = mMain->mSt.databuffer_addr; + dataBufferSize = mMain->mSt.databuffer_size; + } else { + dataBufferAddr = mMain->mSt.storage.data(); + dataBufferSize = mMain->mSt.storage.size(); + } + if (nullptr == dataBufferAddr || dataBufferSize == 0) { + MNN_ERROR("Safetensors databuffer is empty, please call loadSafeTensors first\n"); + return nullptr; + } + + const size_t offsetBegin = t.data_offsets[0]; + const size_t offsetEnd = t.data_offsets[1]; + if (offsetBegin > offsetEnd || offsetEnd > dataBufferSize) { + MNN_ERROR("Invalid tensor offsets for %s: [%zu, %zu), databuffer=%zu\n", name.c_str(), offsetBegin, offsetEnd, dataBufferSize); + return nullptr; + } + + const size_t nitems = safetensors::get_shape_size(t); + const size_t itemBytes = safetensors::get_dtype_bytes(t.dtype); + const size_t expectedBytes = nitems * itemBytes; + const size_t actualBytes = offsetEnd - offsetBegin; + if (expectedBytes != actualBytes) { + MNN_ERROR("Invalid tensor %s: expected %zu bytes(%zu*%zu), got %zu\n", name.c_str(), expectedBytes, nitems, itemBytes, actualBytes); + return nullptr; + } + + MNN::Express::INTS shape; + shape.reserve(t.shape.size()); + for (auto dim : t.shape) { + if (dim > static_cast(std::numeric_limits::max())) { + MNN_ERROR("Tensor %s has too large shape dim: %zu\n", name.c_str(), dim); + return nullptr; + } + shape.emplace_back(static_cast(dim)); + } + + auto tensorStart = dataBufferAddr + offsetBegin; + auto dtype = _convertSafeTensorDType(t.dtype); + auto var = MNN::Express::_Const(tensorStart, std::move(shape), MNN::Express::NCHW, dtype); + if (t.dtype == safetensors::kBFLOAT16) { + var = MNN::Express::_Cast(var); + var.fix(MNN::Express::VARP::CONSTANT); + } + return var; +} + +}; +}; diff --git a/tools/converter/source/safetensors/SafetensorConverter.hpp b/tools/converter/source/safetensors/SafetensorConverter.hpp new file mode 100644 index 0000000000..8305372060 --- /dev/null +++ b/tools/converter/source/safetensors/SafetensorConverter.hpp @@ -0,0 +1,27 @@ +#ifndef SafetensorConverter_hpp +#define SafetensorConverter_hpp +#include +#include +#include +#include +#include "config.hpp" +namespace MNN { +namespace SafeTensors { +class MNN_PUBLIC Converter { +public: + Converter(const std::string& jsonFile); + ~ Converter(); + std::vector listModels() const; + void loadSafeTensors(const std::string& safeTensorFile); + bool convert(const std::string& name, modelConfig& modelPath); + MNN::Express::VARP loadTensor(const std::string& name, bool printNotFound = true) const; + bool hasTensor(const std::string& name) const; + struct Content; +private: + Content* mMain = nullptr; +}; +}; +}; + +#endif + diff --git a/tools/converter/source/safetensors/SafetensorModelRegistry.cpp b/tools/converter/source/safetensors/SafetensorModelRegistry.cpp new file mode 100644 index 0000000000..4f9478da12 --- /dev/null +++ b/tools/converter/source/safetensors/SafetensorModelRegistry.cpp @@ -0,0 +1,83 @@ +#include "SafetensorModelRegistry.hpp" + +#include + +#include + +#include "MNN_generated.h" +#include "PostConverter.hpp" +#include "writeFb.hpp" +#include "../common/CommonUtils.hpp" + +namespace MNN { +namespace SafeTensors { + +struct SafetensorModelRegistry::Impl { + std::unordered_map builders; +}; + +SafetensorModelRegistry* SafetensorModelRegistry::get() { + static SafetensorModelRegistry gRegistry; + if (!gRegistry.mImpl) { + gRegistry.mImpl.reset(new Impl); + } + return &gRegistry; +} + +void SafetensorModelRegistry::insert(const std::string& name, ModelBuilder builder) { + if (name.empty() || builder == nullptr) { + return; + } + if (!mImpl) { + mImpl.reset(new Impl); + } + auto iter = mImpl->builders.find(name); + if (iter != mImpl->builders.end()) { + MNN_PRINT("SafetensorModelRegistry: override builder for %s\n", name.c_str()); + } + mImpl->builders[name] = builder; +} + +ModelBuilder SafetensorModelRegistry::find(const std::string& name) const { + if (!mImpl) { + return nullptr; + } + auto iter = mImpl->builders.find(name); + if (iter == mImpl->builders.end()) { + return nullptr; + } + return iter->second; +} + +SafetensorModelRegister::SafetensorModelRegister(const char* name, ModelBuilder builder) { + if (nullptr == name || builder == nullptr) { + return; + } + SafetensorModelRegistry::get()->insert(name, builder); +} + +MNN_PUBLIC void optimizeAndWrite(modelConfig& modelPath, std::unique_ptr& netT) { + if (nullptr == netT.get()) { + return; + } + + std::unique_ptr metaOp(new MNN::OpT); + metaOp->type = MNN::OpType_Extra; + metaOp->main.value = new MNN::ExtraT; + metaOp->main.type = MNN::OpParameter_Extra; + metaOp->main.AsExtra()->type = "Meta"; + metaOp->main.AsExtra()->engine = "MNN"; + + std::vector expectedPass; + CommonKit::loadCompress(modelPath); + + std::unique_ptr newNet = optimizeNet(netT, modelPath.forTraining, modelPath, expectedPass); + if (nullptr != newNet) { + (void)writeFb(newNet, modelPath, std::move(metaOp)); + } else { + MNN_ERROR("SafetensorModelRegistry: optimizeNet failed, skip writing %s\n", modelPath.MNNModel.c_str()); + } +} + +} // namespace SafeTensors +} // namespace MNN diff --git a/tools/converter/source/safetensors/SafetensorModelRegistry.hpp b/tools/converter/source/safetensors/SafetensorModelRegistry.hpp new file mode 100644 index 0000000000..de1f46194b --- /dev/null +++ b/tools/converter/source/safetensors/SafetensorModelRegistry.hpp @@ -0,0 +1,53 @@ +#ifndef SafetensorModelRegistry_hpp +#define SafetensorModelRegistry_hpp + +#include +#include + +#include "config.hpp" + +#include + +#include + +namespace MNN { +struct NetT; + +namespace SafeTensors { + +class Converter; + +using ModelBuilder = bool (*)(const Converter* converter, const rapidjson::Value* model, modelConfig& modelPath); + +class MNN_PUBLIC SafetensorModelRegistry { +public: + static SafetensorModelRegistry* get(); + + void insert(const std::string& name, ModelBuilder builder); + ModelBuilder find(const std::string& name) const; + +private: + SafetensorModelRegistry() = default; + + struct Impl; + std::unique_ptr mImpl; +}; + +class MNN_PUBLIC SafetensorModelRegister { +public: + SafetensorModelRegister(const char* name, ModelBuilder builder); +}; + +MNN_PUBLIC void optimizeAndWrite(modelConfig& modelPath, std::unique_ptr& netT); + +} // namespace SafeTensors +} // namespace MNN + +#define MNN_SAFETENSOR_JOIN_INNER(x, y) x##y +#define MNN_SAFETENSOR_JOIN(x, y) MNN_SAFETENSOR_JOIN_INNER(x, y) + +#define REGISTER_SAFETENSOR_MODEL_BUILDER(modelName, builderFunc) \ + static ::MNN::SafeTensors::SafetensorModelRegister \ + MNN_SAFETENSOR_JOIN(__mnn_safetensor_model_register_, __COUNTER__)(modelName, builderFunc) + +#endif diff --git a/tools/converter/source/safetensors/SafetensorUtils.cpp b/tools/converter/source/safetensors/SafetensorUtils.cpp new file mode 100644 index 0000000000..1bd642f39b --- /dev/null +++ b/tools/converter/source/safetensors/SafetensorUtils.cpp @@ -0,0 +1,437 @@ +#include "SafetensorUtils.hpp" + +#include +#include +#include + +#include + +#include "MNN_generated.h" +#include "core/IDSTEncoder.hpp" + +namespace MNN { +namespace Express { +namespace SafeTensorUtils { + +VARP _MakeLastHiddenStateOutput(VARP hiddenState, int hiddenSize) { + if (nullptr == hiddenState.get()) { + return nullptr; + } + std::vector sizes = {1, 1, hiddenSize}; + auto sizeVar = _Const(sizes.data(), {3}, NCHW, halide_type_of()); + std::vector begins = {0, -1, 0}; + auto beginVar = _Const(begins.data(), {3}, NCHW, halide_type_of()); + auto output = _Slice(hiddenState, beginVar, sizeVar); + output->setName("last_hidden_state"); + return output; +} + +VARP _GPT2Attention(int numHead, int headDim, VARP q, VARP k, VARP v, VARP qk_scale_q, VARP qk_scale_k, + VARP sv_scale_s, VARP sv_scale_v, VARP mask, bool supportC4Opt, float attnScale) { + std::unique_ptr op(new OpT); + op->type = OpType_Attention; + op->main.value = new AttentionParamT; + op->main.type = OpParameter_AttentionParam; + op->main.AsAttentionParam()->kv_cache = true; + op->main.AsAttentionParam()->attnScale = attnScale; + bool supportC4 = (headDim % 16 == 0) && supportC4Opt; + op->main.AsAttentionParam()->output_c4 = supportC4; + if (nullptr != qk_scale_q || nullptr != qk_scale_k) { + op->main.AsAttentionParam()->mhq_quant.resize(4); + for (int i = 0; i < 4; ++i) { + op->main.AsAttentionParam()->mhq_quant[i].reset(new TensorQuantInfoT); + op->main.AsAttentionParam()->mhq_quant[i]->scale = 0.0f; + } + auto& mhqQuant = op->main.AsAttentionParam()->mhq_quant; + if (nullptr != qk_scale_q) { + mhqQuant[0]->scale = qk_scale_q->readMap()[0]; + } + if (nullptr != qk_scale_k) { + mhqQuant[1]->scale = qk_scale_k->readMap()[0]; + } + if (nullptr != sv_scale_s) { + mhqQuant[2]->scale = sv_scale_s->readMap()[0]; + } + if (nullptr != sv_scale_v) { + mhqQuant[3]->scale = sv_scale_v->readMap()[0]; + } + } + VARP output; + if (nullptr != mask.get()) { + output = Variable::create(Expr::create(op.get(), {q, k, v, mask})); + } else { + output = Variable::create(Expr::create(op.get(), {q, k, v})); + } + if (!supportC4) { + output = _Reshape(output, {-1, numHead * headDim, 1, 1}); + } + return output; +} + +static void _splitBufToArray(const uint8_t* buf, uint8_t* arr, size_t arrLen, size_t needBits) { + unsigned char mask = (1 << needBits) - 1; + unsigned char* tmp = (unsigned char*)buf; + int offset = 0; + for (size_t i = 0; i < arrLen; ++i) { + unsigned char idx = 0; + long shift = 8 - needBits - offset % 8; + if (shift < 0) { + idx = (tmp[offset / 8] << (0 - shift)) & mask; + idx |= (tmp[(offset / 8) + 1] >> (8 + shift)) & mask; + } else { + idx = (tmp[offset / 8] >> shift) & mask; + } + offset += needBits; + if (offset % 8 == 0) { + tmp += offset / 8; + offset = 0; + } + arr[i] = idx; + } +} + +VARP _QConvolution1x1(int inputCount, VARP input, VARP inputScale, VARP inputZero, VARP weight, VARP wscale, + VARP wzeropoint, VARP bias, int outputCount, bool scaleInputCount, int weightBit) { + std::unique_ptr conv(new OpT); + conv->type = OpType_Convolution; + conv->main.type = OpParameter_Convolution2D; + conv->main.value = new Convolution2DT; + auto parm = conv->main.AsConvolution2D(); + parm->common.reset(new Convolution2DCommonT); + if (outputCount > 0) { + parm->common->outputCount = outputCount; + } else { + parm->common->outputCount = (int)bias->getInfo()->size; + outputCount = parm->common->outputCount; + } + auto weightSize = weight->getInfo()->size; + auto weightInputCount = weightSize / parm->common->outputCount; + if (0 == weightBit) { + weightBit = 8 * (int)weightInputCount / inputCount; + } + MNN_ASSERT(weightBit <= 8); + if (nullptr == wscale.get()) { + parm->weight.resize(weightSize); + auto ptr = weight->readMap(); + if (nullptr == ptr) { + MNN_ERROR("_QConvolution1x1: weight->readMap() is nullptr!\n"); + return nullptr; + } + ::memcpy(parm->weight.data(), ptr, weightSize * sizeof(float)); + parm->common->inputCount = inputCount; + parm->bias.resize(parm->common->outputCount); + if (nullptr != bias) { + auto bptr = bias->readMap(); + if (nullptr == bptr) { + MNN_ERROR("_QConvolution1x1: bias->readMap() is nullptr!\n"); + return nullptr; + } + ::memcpy(parm->bias.data(), bptr, bias->getInfo()->size * sizeof(float)); + } else { + ::memset(parm->bias.data(), 0, parm->bias.size() * sizeof(float)); + } + return Variable::create(Expr::create(conv.get(), {input})); + } + + if (scaleInputCount) { + std::vector scales(inputCount); + auto scalePtr = wscale->readMap(); + if (nullptr == scalePtr) { + MNN_ERROR("_QConvolution1x1: wscale->readMap() is nullptr!\n"); + return nullptr; + } + ::memcpy(scales.data(), scalePtr, inputCount * sizeof(float)); + std::vector emptyBias; + input = _Scale(input, inputCount, std::move(scales), std::move(emptyBias)); + wscale = _Const(1.0f, {outputCount}, NCHW); + } + if (wscale->getInfo()->size == 1 && parm->common->outputCount > 1) { + auto scalePtr = wscale->readMap(); + if (nullptr == scalePtr) { + MNN_ERROR("_QConvolution1x1: scalar wscale->readMap() is nullptr!\n"); + return nullptr; + } + std::vector scales(parm->common->outputCount, scalePtr[0]); + wscale = _Const(scales.data(), {parm->common->outputCount}, NCHW); + } + auto scaleSize = wscale->getInfo()->size; + if (parm->common->outputCount > scaleSize) { + MNN_ERROR("scaleSize %zu <= outputCount %d\n", scaleSize, parm->common->outputCount); + return nullptr; + } + + parm->common->inputCount = inputCount; + parm->bias.resize(parm->common->outputCount); + if (nullptr != bias) { + auto bptr = bias->readMap(); + if (nullptr == bptr) { + MNN_ERROR("_QConvolution1x1 quant: bias->readMap() is nullptr!\n"); + return nullptr; + } + ::memcpy(parm->bias.data(), bptr, bias->getInfo()->size * sizeof(float)); + } else { + ::memset(parm->bias.data(), 0, parm->bias.size() * sizeof(float)); + } + if (nullptr != inputScale) { + auto scale = inputScale->readMap()[0]; + float zeroPoint = 0.0f; + if (nullptr != inputZero) { + zeroPoint = inputZero->readMap()[0]; + } + input->writeScaleMap(scale, zeroPoint); + } + + int n = parm->common->outputCount; + int k = parm->common->inputCount; + std::vector weightInt8(n * k); + if (4 == weightBit) { + int kDiv8 = k / 8; + auto weightSrcInt8 = weight->readMap(); + for (int i = 0; i < kDiv8; ++i) { + for (int u = 0; u < 4; ++u) { + for (int v = 0; v < n; ++v) { + auto packed = weightSrcInt8[(i * 4 + u) * n + v]; + int8_t item1 = packed >> 4; + int8_t item0 = packed - item1 * 16; + if (item0 >= 8) { + item0 -= 16; + } + MNN_ASSERT(item1 <= 7 && item1 >= -8); + weightInt8[v * k + i * 8 + u] = item0; + weightInt8[v * k + i * 8 + u + 4] = item1; + } + } + } + } else if (weightBit == 8) { + ::memcpy(weightInt8.data(), weight->readMap(), n * k); + } else { + auto weightSrcUInt8 = weight->readMap(); + auto weightUInt8 = (uint8_t*)weightInt8.data(); + _splitBufToArray(weightSrcUInt8, weightUInt8, n * k, weightBit); + int offset = 1 << (weightBit - 1); + for (int i = 0; i < n * k; ++i) { + weightInt8[i] = (int)weightUInt8[i] - offset; + } + } + + int dstWeightBit = weightBit; + if (8 == weightBit) { + int maxV = -256; + int minV = 256; + for (int v = 0; v < n * k; ++v) { + auto q = weightInt8[v]; + if (q > maxV) { + maxV = q; + } + if (q < minV) { + minV = q; + } + } + int targetBit = 0; + if (maxV >= 0) { + targetBit = (int)ceil(log(maxV + 1) / log(2)) + 1; + } + if (minV < 0) { + auto d1 = (int)ceil(log(-minV) / log(2)) + 1; + if (d1 > targetBit) { + targetBit = d1; + } + } + dstWeightBit = targetBit; + } + if (dstWeightBit > 4) { + dstWeightBit = 8; + } else if (dstWeightBit > 1) { + dstWeightBit = 4; + } else { + dstWeightBit = 1; + } + + std::vector scale; + bool async = false; + if (nullptr != wzeropoint) { + if (_ReduceMax(_Abs(_Cast(wzeropoint)))->readMap()[0] >= 1e-11f) { + async = true; + } + } + if (async) { + scale.resize(2 * scaleSize); + if (wzeropoint->getInfo()->type.code == halide_type_float) { + auto zeroPoint = wzeropoint->readMap(); + auto scalePtr = wscale->readMap(); + for (int i = 0; i < scaleSize; ++i) { + scale[2 * i + 1] = scalePtr[i]; + scale[2 * i + 0] = zeroPoint[i]; + } + } else { + auto zeroPoint = wzeropoint->readMap(); + auto scalePtr = wscale->readMap(); + for (int i = 0; i < scaleSize; ++i) { + scale[2 * i + 1] = scalePtr[i]; + scale[2 * i + 0] = -scalePtr[i] * zeroPoint[i]; + } + } + } else { + scale.resize(scaleSize); + ::memcpy(scale.data(), wscale->readMap(), scaleSize * sizeof(float)); + } + auto kernelSize = n * k / scaleSize; + parm->quanParameter = IDSTEncoder::encode(nullptr, scale, kernelSize, scaleSize, async, weightInt8.data(), 1, + {dstWeightBit, false}); + return Variable::create(Expr::create(conv.get(), {input})); +} + +static std::unique_ptr _makeLayerNorm(const LayerNormInfo& info) { + auto inputDim = info.hiddenSize; + if (0 == inputDim) { + if (nullptr != info.inputLayerNormWeight && nullptr != info.inputLayerNormWeight->getInfo()) { + inputDim = (int)info.inputLayerNormWeight->getInfo()->size; + } else { + MNN_ERROR("_TransformerLayerNorm: hiddenSize is 0 and inputLayerNormWeight is missing!\n"); + } + } + std::unique_ptr layerNorm(new OpT); + layerNorm->type = OpType_LayerNorm; + layerNorm->main.value = new LayerNormT; + layerNorm->main.type = OpParameter_LayerNorm; + layerNorm->main.AsLayerNorm()->axis = {-1}; + layerNorm->main.AsLayerNorm()->group = 1; + layerNorm->main.AsLayerNorm()->epsilon = info.ln_eps; + layerNorm->main.AsLayerNorm()->useRMSNorm = info.useRMSNorm; + if (info.useC4) { + layerNorm->defaultDimentionFormat = MNN_DATA_FORMAT_NC4HW4; + } + if (nullptr != info.inputLayerNormWeight) { + layerNorm->main.AsLayerNorm()->beta.resize(inputDim); + layerNorm->main.AsLayerNorm()->gamma.resize(inputDim); + if (nullptr != info.inputLayerNormBias) { + ::memcpy(layerNorm->main.AsLayerNorm()->beta.data(), info.inputLayerNormBias->readMap(), + inputDim * sizeof(float)); + } else { + ::memset(layerNorm->main.AsLayerNorm()->beta.data(), 0, inputDim * sizeof(float)); + } + ::memcpy(layerNorm->main.AsLayerNorm()->gamma.data(), info.inputLayerNormWeight->readMap(), + inputDim * sizeof(float)); + } + return layerNorm; +} + +std::pair _BinaryLayerNorm(VARP r0, VARP r1, const LayerNormInfo& info) { + std::unique_ptr layerNorm = _makeLayerNorm(info); + auto expr = Expr::create(layerNorm.get(), {r0, r1}, 2); + return {Variable::create(expr, 0), Variable::create(expr, 1)}; +} + +VARP _TransformerLayerNorm(VARP hiddenState, const LayerNormInfo& info) { + std::unique_ptr layerNorm = _makeLayerNorm(info); + return Variable::create(Expr::create(layerNorm.get(), {hiddenState})); +} + +static void _fillRopeTable(float* dst, const std::vector& cosTable, const std::vector& sinTable, + int end, int halfDim) { + const int tableSize = end * halfDim; + for (int t = 0; t < end; ++t) { + for (int i = 0; i < halfDim; ++i) { + const int evenIndex = (2 * i) % halfDim; + const int oddIndex = (2 * i + 1) % halfDim; + const int srcEven = t * halfDim + evenIndex; + const int srcOdd = t * halfDim + oddIndex; + const int dstIndex = t * halfDim + i; + dst[dstIndex] = cosTable[srcEven]; + dst[tableSize + dstIndex] = cosTable[srcOdd]; + dst[2 * tableSize + dstIndex] = sinTable[srcEven]; + dst[3 * tableSize + dstIndex] = sinTable[srcOdd]; + } + } +} + +VARP _PrecomputePosEmbedding(int dim, int end, float theta, bool interleaved) { + if (dim % 2 != 0 || end <= 0 || theta <= 0.0f) { + return nullptr; + } + + const int halfDim = dim / 2; + const int tableSize = end * halfDim; + std::vector cosTable(tableSize); + std::vector sinTable(tableSize); + for (int t = 0; t < end; ++t) { + for (int i = 0; i < halfDim; ++i) { + const float exponent = static_cast(2 * i) / static_cast(dim); + const float invFreq = 1.0f / std::pow(theta, exponent); + const float freq = static_cast(t) * invFreq; + const int offset = t * halfDim + i; + cosTable[offset] = std::cos(freq); + sinTable[offset] = std::sin(freq); + } + } + if (!interleaved) { + std::vector freqsCis(2 * tableSize); + ::memcpy(freqsCis.data(), cosTable.data(), tableSize * sizeof(float)); + ::memcpy(freqsCis.data() + tableSize, sinTable.data(), tableSize * sizeof(float)); + auto res = _Const(freqsCis.data(), {2, end, halfDim}, NCHW, halide_type_of()); + res.fix(VARP::CONSTANT); + return res; + } + + std::vector ropeTables(4 * tableSize); + _fillRopeTable(ropeTables.data(), cosTable, sinTable, end, halfDim); + auto res = _Const(ropeTables.data(), {4, end, halfDim}, NCHW, halide_type_of()); + res.fix(VARP::CONSTANT); + return res; +} + +VARPS _TransformerRoPE(VARP q, VARP k, VARP cosEven, VARP cosOdd, VARP sinEven, VARP sinOdd, const RopeInfo& info) { + std::unique_ptr qnorm; + std::unique_ptr knorm; + if (nullptr != info.qNorm.inputLayerNormWeight.get()) { + qnorm = _makeLayerNorm(info.qNorm); + } + if (nullptr != info.kNorm.inputLayerNormWeight.get()) { + knorm = _makeLayerNorm(info.kNorm); + } + + std::unique_ptr ropeOp(new OpT); + ropeOp->type = OpType_RoPE; + ExtraT* extra = nullptr; + if (info.cutHeadDim > 0 || nullptr != qnorm || nullptr != knorm) { + ropeOp->main.type = OpParameter_Extra; + extra = new ExtraT; + extra->type = "RoPE"; + extra->engine = "MNN"; + ropeOp->main.value = extra; + } + if (nullptr != qnorm) { + std::unique_ptr attr(new AttributeT); + flatbuffers::FlatBufferBuilder builder; + builder.Finish(Op::Pack(builder, qnorm.get())); + attr->key = "q_norm"; + attr->tensor.reset(new BlobT); + attr->tensor->dataType = DataType_DT_INT8; + attr->tensor->int8s.resize(builder.GetSize()); + ::memcpy(attr->tensor->int8s.data(), builder.GetBufferPointer(), builder.GetSize()); + extra->attr.emplace_back(std::move(attr)); + } + if (nullptr != knorm) { + std::unique_ptr attr(new AttributeT); + flatbuffers::FlatBufferBuilder builder; + builder.Finish(Op::Pack(builder, knorm.get())); + attr->key = "k_norm"; + attr->tensor.reset(new BlobT); + attr->tensor->dataType = DataType_DT_INT8; + attr->tensor->int8s.resize(builder.GetSize()); + ::memcpy(attr->tensor->int8s.data(), builder.GetBufferPointer(), builder.GetSize()); + extra->attr.emplace_back(std::move(attr)); + } + if (info.cutHeadDim > 0) { + std::unique_ptr attr(new AttributeT); + attr->key = "rope_cut_head_dim"; + attr->i = info.cutHeadDim; + extra->attr.emplace_back(std::move(attr)); + } + auto ropeExpr = Expr::create(ropeOp.get(), {q, k, cosEven, cosOdd, sinEven, sinOdd}, 2); + return {Variable::create(ropeExpr, 0), Variable::create(ropeExpr, 1)}; +} + +} // namespace SafeTensorUtils +} // namespace Express +} // namespace MNN diff --git a/tools/converter/source/safetensors/SafetensorUtils.hpp b/tools/converter/source/safetensors/SafetensorUtils.hpp new file mode 100644 index 0000000000..a400ecfc96 --- /dev/null +++ b/tools/converter/source/safetensors/SafetensorUtils.hpp @@ -0,0 +1,50 @@ +#ifndef SafetensorUtils_hpp +#define SafetensorUtils_hpp + +#include + +#include +#include + +namespace MNN { +namespace Express { +namespace SafeTensorUtils { + +struct LayerNormInfo { + VARP inputLayerNormWeight; + VARP inputLayerNormBias; + float ln_eps = 0.0f; + bool useRMSNorm = false; + int hiddenSize = 0; + bool useC4 = false; + + LayerNormInfo() = default; + LayerNormInfo(VARP weight, VARP bias, float eps, bool rms, int hidden = 0, bool c4 = false) + : inputLayerNormWeight(weight), inputLayerNormBias(bias), ln_eps(eps), useRMSNorm(rms), hiddenSize(hidden), useC4(c4) { + } +}; + +struct RopeInfo { + LayerNormInfo qNorm; + LayerNormInfo kNorm; + int cutHeadDim = 0; +}; + +MNN_PUBLIC VARP _QConvolution1x1(int inputCount, VARP input, VARP inputScale, VARP inputZero, VARP weight, + VARP wscale, VARP wzeropoint, VARP bias, int outputcount = 0, + bool scaleInputCount = false, int weightBits = 0); +MNN_PUBLIC VARP _TransformerLayerNorm(VARP hiddenState, const LayerNormInfo& info); +MNN_PUBLIC std::pair _BinaryLayerNorm(VARP r0, VARP r1, const LayerNormInfo& info); +MNN_PUBLIC VARP _GPT2Attention(int numHead, int headDim, VARP q, VARP k, VARP v, VARP qk_scale_q, VARP qk_scale_k, + VARP sv_scale_s, VARP sv_scale_v, VARP mask, bool supportC4Opt = false, + float attnScale = 0.0f); +MNN_PUBLIC VARP _PrecomputePosEmbedding(int dim, int end, float theta = 1000000.0f, bool interleaved = false); +MNN_PUBLIC VARPS _TransformerRoPE(VARP q, VARP k, VARP cosEven, VARP cosOdd, VARP sinEven, VARP sinOdd, + const RopeInfo& info); +MNN_PUBLIC VARP _MakeLastHiddenStateOutput(VARP hiddenState, int hiddenSize); + +} // namespace SafeTensorUtils +} // namespace Express +} // namespace MNN + +#endif diff --git a/tools/converter/source/safetensors/WorkflowJson.hpp b/tools/converter/source/safetensors/WorkflowJson.hpp new file mode 100644 index 0000000000..6d54d9e120 --- /dev/null +++ b/tools/converter/source/safetensors/WorkflowJson.hpp @@ -0,0 +1,124 @@ +#ifndef WorkflowJson_hpp +#define WorkflowJson_hpp + +#include + +#include + +namespace MNN { +namespace SafeTensors { +namespace WorkflowJson { + +inline const rapidjson::Value* _findMember(const rapidjson::Value& obj, const char* key) { + if (!obj.IsObject() || nullptr == key) { + return nullptr; + } + auto it = obj.FindMember(key); + if (it == obj.MemberEnd()) { + return nullptr; + } + return &it->value; +} + +inline std::string getString(const rapidjson::Value& obj, const char* key, const std::string& defaultValue = "") { + auto v = _findMember(obj, key); + if (nullptr == v || !v->IsString()) { + return defaultValue; + } + return v->GetString(); +} + +inline bool getBool(const rapidjson::Value& obj, const char* key, bool defaultValue = false) { + auto v = _findMember(obj, key); + if (nullptr == v) { + return defaultValue; + } + return v->GetBool(); +} +inline int getInt(const rapidjson::Value& obj, const char* key, int defaultValue = 0) { + auto v = _findMember(obj, key); + if (nullptr == v || !v->IsInt()) { + return defaultValue; + } + return v->GetInt(); +} + +inline float getFloat(const rapidjson::Value& obj, const char* key, float defaultValue = 0.0f) { + auto v = _findMember(obj, key); + if (nullptr == v) { + return defaultValue; + } + if (v->IsFloat()) { + return v->GetFloat(); + } + if (v->IsDouble()) { + return static_cast(v->GetDouble()); + } + return defaultValue; +} + +inline const rapidjson::Value* getArray(const rapidjson::Value& obj, const char* key) { + auto v = _findMember(obj, key); + if (nullptr == v || !v->IsArray()) { + return nullptr; + } + return v; +} + +inline bool firstArrayStringEquals(const rapidjson::Value& obj, const char* key, const char* expected) { + if (nullptr == expected) { + return false; + } + auto v = getArray(obj, key); + if (nullptr == v || v->Empty()) { + return false; + } + auto& first = (*v)[0]; + if (!first.IsString()) { + return false; + } + return first.GetString() == std::string(expected); +} + +inline bool arrayStringContains(const rapidjson::Value& obj, const char* key, const char* expected) { + if (nullptr == expected) { + return false; + } + auto v = getArray(obj, key); + if (nullptr == v) { + return false; + } + const std::string target(expected); + for (auto& item : v->GetArray()) { + if (item.IsString() && item.GetString() == target) { + return true; + } + } + return false; +} + +inline const rapidjson::Value* findFirstBlockByType(const rapidjson::Value& model, const char* type) { + if (nullptr == type) { + return nullptr; + } + auto blocks = getArray(model, "blocks"); + if (nullptr == blocks) { + return nullptr; + } + for (auto& item : blocks->GetArray()) { + if (!item.IsObject()) { + continue; + } + auto t = _findMember(item, "type"); + if (nullptr != t && t->IsString() && t->GetString() == std::string(type)) { + return &item; + } + } + return nullptr; +} + +} // namespace WorkflowJson +} // namespace SafeTensors +} // namespace MNN + +#endif diff --git a/tools/converter/source/safetensors/safetensors.hh b/tools/converter/source/safetensors/safetensors.hh new file mode 100644 index 0000000000..e0add85f6e --- /dev/null +++ b/tools/converter/source/safetensors/safetensors.hh @@ -0,0 +1,4865 @@ +// SPDX-License-Identifier: MIT +// Copyright 2023 - Present, Syoyo Fujita. +// Inspired from: +// https://gist.github.com/Narsil/5d6bf307995158ad2c4994f323967284 +#pragma once + +#include +#include +#include +#include +#include + +#ifdef __ANDROID__ +#ifdef SAFETENSORS_CPP_ANDROID_LOAD_FROM_ASSETS +#include +#endif + +#ifdef SAFETENSORS_CPP_IMPLEMENTATION +AAssetManager *asset_manager = nullptr; +#else +extern AAssetManager *asset_manager; +#endif +#endif + + +namespace safetensors { + +constexpr size_t kMaxDim = + 8; // must be equal to SAFETENSORS_C_MAX_DIM in `safetensors-c.h` + +enum dtype { + kBOOL, + kUINT8, + kINT8, + kINT16, + kUINT16, + kFLOAT16, + kBFLOAT16, + kINT32, + kUINT32, + kFLOAT32, + kFLOAT64, + kINT64, + kUINT64, +}; + +namespace minijson { + +// Simple C++ implementation of Python's OrderedDict like dictonary +// (preserves key insertion order) +// Modified for JSON: +// - No duplicated key allowed + +template +class ordered_dict { + public: + bool at(const size_t idx, T *dst) const { + if (idx >= _keys.size()) { + return false; + } + + if (!_m.count(_keys[idx])) { + // This should not happen though. + return false; + } + + (*dst) = _m.at(_keys[idx]); + + return true; + } + + bool count(const std::string &key) const { return _m.count(key); } + + void insert(const std::string &key, const T &value) { + if (_m.count(key)) { + // overwrite existing value + } else { + _keys.push_back(key); + } + + _m[key] = value; + } + + void insert(const std::string &key, T &&value) { + if (_m.count(key)) { + // overwrite existing value + } else { + _keys.push_back(key); + } + + _m[key] = std::move(value); + } + + bool at(const std::string &key, T *dst) const { + if (!_m.count(key)) { + // This should not happen though. + return false; + } + + (*dst) = _m.at(key); + + return true; + } + + const std::vector &keys() const { return _keys; } + + size_t size() const { return _m.size(); } + + bool erase(const std::string &key) { + // simple linear search + for (size_t i = 0; i < _keys.size(); i++) { + if (_keys[i] == key) { + _keys.erase(_keys.begin() + i); + _m.erase(key); + return true; + } + } + + return false; + } + + private: + std::vector _keys; + std::map _m; +}; + +} // namespace minijson + +template +using ordered_dict = minijson::ordered_dict; + +struct tensor_t { + safetensors::dtype dtype; + std::vector shape; + std::array data_offsets; +}; + +struct safetensors_t { + // we need ordered dict(preserves the order of key insertion) + // as done in Python's OrderedDict, since JSON data may not be sorted by its key string. + ordered_dict tensors; + ordered_dict metadata; + std::vector storage; // empty when mmap'ed + size_t header_size{0}; // JSON size + + bool mmaped{false}; + + // + // Following members are set when mmaped. + // + const uint8_t *mmap_addr{nullptr}; + size_t mmap_size{0}; + const uint8_t *databuffer_addr{nullptr}; // [mmap_addr + header_size + 8] + size_t databuffer_size{0}; // mmap_size - header_size - 8 + // opaque pointer to safetensors_file and safetensors_mmap + void *st_file{nullptr}; + void *st_mmap{nullptr}; + + ~safetensors_t(); +}; + +// +// Load safetensors from file. +// databuffer is copied to `safetensors_t::storage`. +// +// @param[in] filename Filepath. Assume UTF-8 filepath. +// @param[out] st safetensors data. +// @param[out] warn Warning message buffer(can be nullptr if you don't need +// warning message) +// @param[out] err Error message buffer(can be nullptr if you don't need error +// message) +// +// @return true upon success. `err` will be filled when false. +bool load_from_file(const std::string &filename, safetensors_t *st, + std::string *warn, std::string *err); + +// +// Load safetensors data from memory. +// databuffer is copied to `safetensors_t::storage`. +// +// @param[in] addr Memory address of safetensors data. +// @param[in] nbytes The size in bytes. +// @param[in] filename Filename of corresponding memory data. Can be empty. +// @param[out] st safetensors data. +// @param[out] warn Warning message buffer(can be nullptr if you don't need +// warning message) +// @param[out] err Error message buffer(can be nullptr if you don't need error +// message) +// +// @return true upon success. `err` will be filled when false. +// +bool load_from_memory(const uint8_t *addr, const size_t nbytes, + const std::string &filename, safetensors_t *st, + std::string *warn, std::string *err); + +// +// Load safetensors with memory mapping(i.e. zero-copy). +// databuffer is not copied to `safetensors_t` object, thus the app must hold +// file during `safetensor_t` object is live. +// +// @param[in] filename Filepath. Assume UTF-8 filepath. +// @param[out] st safetensors data. +// @param[out] warn Warning message buffer(can be nullptr if you don't need +// warning message) +// @param[out] err Error message buffer(can be nullptr if you don't need error +// message) +// +// @return true upon success. `err` will be filled when false. +bool mmap_from_file(const std::string &filename, safetensors_t *st, + std::string *warn, std::string *err); + +// +// Load safetensors from mmaped region. +// databuffer is not copied to `safetensors_t` object, thus the app must not +// free/unmap `addr` during `safetensor_t` object is live. +// +// @param[in] addr mmaped memory address of safetensors data. +// @param[in] nbytes mmap bytes. +// @param[in] filename Filename of corresponding memory data. Can be empty. +// @param[out] st safetensors data. +// @param[out] warn Warning message buffer(can be nullptr if you don't need +// warning message) +// @param[out] err Error message buffer(can be nullptr if you don't need error +// message) +// +// @return true upon success. `err` will be filled when false. +bool mmap_from_memory(const uint8_t *arr, const size_t nbytes, + const std::string &filename, safetensors_t *st, + std::string *warn, std::string *err); + +// +// Save safetensors to file. +// +// @param[in] st safetensors data. +// @param[in] filename Filepath. Assume UTF-8 filepath. +// @param[out] warn Warning message buffer(can be nullptr if you don't need +// warning message) +// @param[out] err Error message buffer(can be nullptr if you don't need error +// message) +// +// @return true upon success. `err` will be filled when false. +bool save_to_file(const safetensors_t &st, const std::string &filename, + std::string *warn, std::string *err); + +// +// Save safetensors to memory. +// +// @param[in] st safetensors data. +// @param[out] data_out Serialized safetensor data. +// @param[out] warn Warning message buffer(can be nullptr if you don't need +// warning message) +// @param[out] err Error message buffer(can be nullptr if you don't need error +// message) +// +// @return true upon success. `err` will be filled when false. +bool save_to_memory(const std::string &filename, std::vector *data_out, + std::string *warn, std::string *err); + +// +// Utility functions +// + +// Returns shape[0] * shape[1] * ... +// Empty Tensor(any shape[i] is 0) returns 0. +// Zero-rank tensor([]) return 1. +size_t get_shape_size(const tensor_t &t); + +// Returns dtype size in bytes. +size_t get_dtype_bytes(const safetensors::dtype dtype); +std::string get_dtype_str(const safetensors::dtype dtype); + +// Validate data_offsets of all tensors in safetensors_t. +bool validate_data_offsets(const safetensors_t &st, std::string &err); + +uint16_t float_to_bfloat16(float x); +float bfloat16_to_float(uint16_t x); + +uint16_t float_to_fp16(float x); +float fp16_to_float(uint16_t x); + +} // namespace safetensors + +#if defined(SAFETENSORS_CPP_IMPLEMENTATION) + +#include +#include +#include + +#ifdef __has_include +#if __has_include() +#include +#if defined(_POSIX_MAPPED_FILES) +#include +#endif +#if defined(_POSIX_MEMLOCK_RANGE) +#include +#endif +#endif +#endif + +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#include // for _fseeki64 +#include +#endif + +#if !defined(MINIJSON_IMPLEMENTATION) +#define MINIJSON_IMPLEMENTATION +#endif + +// minijson: https://github.com/syoyo/minijson + +/* + * JSON parser: C++ oriented JSON parser. + */ + +#include +#include +#include +#include +#include + +//#define __MINIJSON_LIBERAL + +// We recommended to use simdjson from_chars. +// Using strtod() is a fallback +#if defined(MINIJSON_USE_STRTOD) +// Use stdlib's strtod +#include +#else + +namespace minijson { +namespace simdjson { +namespace internal { + +double from_chars(const char *first) noexcept; +double from_chars(const char *first, const char *end) noexcept; + +char *to_chars(char *first, const char *last, double value); + +} // namespace internal +} // namespace simdjson +} // namesapce minijson + +#endif + +namespace minijson { + +namespace detail { + +double from_chars(const char *p); +const char *my_strchr(const char *p, int ch); + +} // namespace detail + +namespace detail { + +// +// Usage: +// - set_input() +// - scan_string() +// - success: use `token_buffer` string +// - error: use `error_message` +// +struct string_parser { + // input string must be UTF-8 + void set_input(const std::string &s) { _input = s; } + + bool scan_string(); + + void reset() { + if (_input.size()) { + current = _input[0]; + } else { + current = '\0'; + } + curr_idx = 0; + token_buffer.clear(); + } + + // fetch next token. + unsigned char get() { + if ((curr_idx + 1) < _input.size()) { + curr_idx++; + current = _input[curr_idx]; + return current; + } + current = '\0'; + return current; + } + + bool eof() { + if (_input.empty()) { + return true; + } + + if (curr_idx >= _input.size()) { + return true; + } + + return false; + } + + void add(const unsigned char c) { token_buffer += c; } + + void add(const int i) { + // use lower 8bit + token_buffer += static_cast(i & 0xff); + } + + int get_codepoint(); + + bool next_byte_in_range(const std::initializer_list ranges); + + std::string error_message; + std::string token_buffer; // output + + unsigned char current{'\0'}; + size_t curr_idx{0}; + std::string _input; +}; + +} // namespace detail + +typedef enum { + unknown_type, + null_type, + boolean_type, + number_type, + string_type, + array_type, + object_type, +} type; + +typedef enum { + no_error, + undefined_error, + invalid_token_error, + unknown_type_error, + memory_allocation_error, + corrupted_json_error, + duplicated_key_error, +} error; + +class value; + +typedef bool boolean; +typedef double number; +typedef std::string string; +typedef safetensors::ordered_dict object; +typedef std::vector array; +typedef struct { +} null_t; + +// null_t null; + +template +struct TypeTraits; + +template <> +struct TypeTraits { + static constexpr uint32_t type_id() { return 0; } +}; + +template <> +struct TypeTraits { + static constexpr uint32_t type_id() { return 1; } +}; + +template <> +struct TypeTraits { + static constexpr uint32_t type_id() { return 2; } +}; + +template <> +struct TypeTraits { + static constexpr uint32_t type_id() { return 3; } +}; + +template <> +struct TypeTraits { + static constexpr uint32_t type_id() { return 4; } +}; + +template <> +struct TypeTraits { + static constexpr uint32_t type_id() { return 5; } +}; + +class value { + private: + type t; + union { + null_t n; + boolean b; + number d; + std::string *s; + array *a; + object *o; + } u; + + void _free_u() { + if (t == string_type) { + delete this->u.s; + this->u.s = nullptr; + } + if (t == array_type) { + delete this->u.a; + this->u.a = nullptr; + } + if (t == object_type) { + delete this->u.o; + this->u.o = nullptr; + } + } + + public: + value() : t(unknown_type), u() {} + value(null_t n) : t(null_type), u() { u.n = n; } + value(boolean b) : t(boolean_type), u() { u.b = b; } + value(number d) : t(boolean_type), u() { u.d = d; } + value(const char *s) : t(string_type), u() { u.s = new std::string(s); } + value(const std::string &s) : t(string_type), u() { + u.s = new std::string(s); + } + value(const array &a) : t(array_type), u() { u.a = new array(a); } + value(const object &o) : t(object_type), u() { u.o = new object(o); } + value(const value &v) : t(v.t), u() { + if (t == array_type) { + u.a = new array(); + *u.a = *v.u.a; + } else if (t == object_type) { + u.o = new object(); + *u.o = *v.u.o; + } else if (t == string_type) { + u.s = new std::string(); + *u.s = *v.u.s; + } else + u.d = v.u.d; + } + ~value() { _free_u(); } + + template + bool is() const { + if (TypeTraits::type_id() == TypeTraits::type_id() && + t == null_type) + return true; + if (TypeTraits::type_id() == TypeTraits::type_id() && + t == boolean_type) + return true; + if (TypeTraits::type_id() == TypeTraits::type_id() && + t == number_type) + return true; + if (TypeTraits::type_id() == TypeTraits::type_id() && + t == string_type) + return true; + if (TypeTraits::type_id() == TypeTraits::type_id() && + t == array_type) + return true; + if (TypeTraits::type_id() == TypeTraits::type_id() && + t == object_type) + return true; + return false; + } + + template + const T *as() const { + if ((t == array_type) && + (TypeTraits::type_id() == TypeTraits::type_id())) { + return reinterpret_cast(u.a); + } + + if ((t == object_type) && + (TypeTraits::type_id() == TypeTraits::type_id())) { + return reinterpret_cast(u.o); + } + + if ((t == string_type) && + (TypeTraits::type_id() == TypeTraits::type_id())) { + return reinterpret_cast(u.s); + } + + if ((t == null_type) && + (TypeTraits::type_id() == TypeTraits::type_id())) { + return reinterpret_cast(&u.n); + } + + if ((t == boolean_type) && + (TypeTraits::type_id() == TypeTraits::type_id())) { + return reinterpret_cast(&u.b); + } + + if ((t == number_type) && + (TypeTraits::type_id() == TypeTraits::type_id())) { + return reinterpret_cast(&u.d); + } + + return nullptr; + } + + template + T *as() { + if ((t == array_type) && + (TypeTraits::type_id() == TypeTraits::type_id())) { + return reinterpret_cast(u.a); + } + + if ((t == object_type) && + (TypeTraits::type_id() == TypeTraits::type_id())) { + return reinterpret_cast(u.o); + } + + if ((t == string_type) && + (TypeTraits::type_id() == TypeTraits::type_id())) { + return reinterpret_cast(u.s); + } + + if ((t == null_type) && + (TypeTraits::type_id() == TypeTraits::type_id())) { + return reinterpret_cast(&u.n); + } + + if ((t == boolean_type) && + (TypeTraits::type_id() == TypeTraits::type_id())) { + return reinterpret_cast(&u.b); + } + + if ((t == number_type) && + (TypeTraits::type_id() == TypeTraits::type_id())) { + return reinterpret_cast(&u.d); + } + + return nullptr; + } + + null_t &operator=(null_t &n) { + t = null_type; + u.n = n; + return u.n; + } + boolean &operator=(boolean b) { + t = boolean_type; + u.b = b; + return u.b; + } + number &operator=(number d) { + t = number_type; + u.d = d; + return u.d; + } + const std::string &operator=(const char *s) { + _free_u(); + t = string_type; + u.s = new std::string(s); + return *u.s; + } + const std::string &operator=(const std::string &s) { + _free_u(); + t = string_type; + u.s = new std::string(s); + return *u.s; + } + const object &operator=(const object &o) { + _free_u(); + t = object_type; + u.o = new object(o); + return *u.o; + } + const array &operator=(const array &a) { + _free_u(); + t = array_type; + u.a = new array(a); + return *u.a; + } + const value &operator=(const value &v) { + _free_u(); + t = v.t; + if (t == array_type) { + u.a = new array(*v.u.a); + } else if (t == object_type) { + u.o = new object(*v.u.o); + } else if (t == string_type) { + u.s = new std::string(*v.u.s); + } else + u.d = v.u.d; + return *this; + } + + std::string type_name() const { + if (t == array_type) { + return "array"; + } + + if (t == object_type) { + return "object"; + } + + if (t == string_type) { + return "string"; + } + + if (t == null_type) { + return "null"; + } + + if (t == boolean_type) { + return "boolean"; + } + + if (t == number_type) { + return "number"; + } + + return "[[invalid]]"; + } + + std::string str(const char *p) const { + std::stringstream ss; + ss << '"'; + while (*p) { + if (*p == '\n') { + ss << "\\n"; + } else if (*p == '\r') { + ss << "\\r"; + } else if (*p == '\t') { + ss << "\\t"; + } else if (detail::my_strchr("\"", *p)) { + ss << "\\" << *p; + } else { + ss << *p; + } + p++; + } + ss << '"'; + return ss.str(); + } + + std::string str() const { + std::stringstream ss; + if (t == unknown_type) { + ss << "undefined"; + } else if (t == null_type) { + ss << "null"; + } else if (t == boolean_type) { + ss << (u.b ? "true" : "false"); + } else if (t == number_type) { + ss << double(u.d); + } else if (t == string_type) { + ss << str(u.s->c_str()); + } else if (const array *pa = as()) { + array::const_iterator i; + ss << "["; + // array a = get(); + for (i = pa->begin(); i != pa->end(); i++) { + if (i != pa->begin()) ss << ", "; + ss << i->str(); + } + ss << "]"; + } else if (auto po = as()) { + // object::const_iterator i; + ss << "{"; + // object o = get(); + for (size_t i = 0; i < po->size(); i++) { + if (i > 0) ss << ", "; + ss << "\"" << po->keys()[i] << "\""; + + value v; + if (po->at(i, &v)) { + ss << ": " << v.str(); + } else { + // TODO: report error + ss << ": null"; + } + } + ss << "}"; + } + return ss.str(); + } +}; + +#define MINIJSON_SKIP(i) \ + while (*i && detail::my_strchr("\r\n \t", *i)) { \ + i++; \ + } + +template +inline error parse_object(Iter &i, value &v) { + object o; + i++; + MINIJSON_SKIP(i) + if (!(*i)) { + return corrupted_json_error; + } + if (*i != '\x7d') { + while (*i) { + value vk, vv; + error e = parse_string(i, vk); + if (e != no_error) return e; + MINIJSON_SKIP(i) + if (!(*i)) { + return corrupted_json_error; + } + if (*i != ':') return invalid_token_error; + i++; + e = parse_any(i, vv); + if (e != no_error) return e; + + auto ps = vk.as(); + if (!ps) { + return unknown_type_error; + } + + if (o.count(*ps)) { + return duplicated_key_error; + } + o.insert(*ps, vv); + + MINIJSON_SKIP(i) + if (!(*i)) { + return corrupted_json_error; + } + if (*i == '\x7d') break; + if (*i != ',') return invalid_token_error; + i++; + MINIJSON_SKIP(i) + if (!(*i)) { + return corrupted_json_error; + } +#ifdef __MINIJSON_LIBERAL + if (*i == '\x7d') break; +#endif + } + } + v = value(o); + i++; + return no_error; +} + +template +inline error parse_array(Iter &i, value &v) { + array a; + i++; + MINIJSON_SKIP(i) + if (!(*i)) { + return corrupted_json_error; + } + if (*i != ']') { + while (*i) { + value va; + error e = parse_any(i, va); + if (e != no_error) return e; + a.push_back(va); + MINIJSON_SKIP(i) + if (!(*i)) { + return corrupted_json_error; + } + if (*i == ']') break; + if (*i != ',') return invalid_token_error; + i++; + MINIJSON_SKIP(i) + if (!(*i)) { + return corrupted_json_error; + } +#ifdef __MINIJSON_LIBERAL + if (*i == '\x7d') break; +#endif + } + } + v = value(a); + i++; + return no_error; +} + +template +inline error parse_null(Iter &i, value &v) { + Iter p = i; + if (*i == 'n' && *(i + 1) == 'u' && *(i + 2) == 'l' && *(i + 3) == 'l') { + i += 4; + v = null_t(); + } + if (*i && nullptr == detail::my_strchr(":,\x7d]\r\n ", *i)) { + i = p; + return undefined_error; + } + return no_error; +} + +template +inline error parse_boolean(Iter &i, value &v) { + Iter p = i; + if (*i == 't' && *(i + 1) == 'r' && *(i + 2) == 'u' && *(i + 3) == 'e') { + i += 4; + v = static_cast(true); + } else if (*i == 'f' && *(i + 1) == 'a' && *(i + 2) == 'l' && + *(i + 3) == 's' && *(i + 4) == 'e') { + i += 5; + v = static_cast(false); + } + if (*i && nullptr == detail::my_strchr(":,\x7d]\r\n ", *i)) { + i = p; + return undefined_error; + } + return no_error; +} + +template +inline error parse_number(Iter &i, value &v) { + Iter p = i; + + if (*i == '-') { + i++; + } + +#define MINIJSON_IS_NUM(x) ('0' <= x && x <= '9') +#define MINIJSON_IS_ALNUM(x) \ + (('0' <= x && x <= '9') || ('a' <= x && x <= 'f') || ('A' <= x && x <= 'F')) + if (*i == '0' && *(i + 1) == 'x' && MINIJSON_IS_ALNUM(*(i + 2))) { + i += 3; + while (MINIJSON_IS_ALNUM(*i)) i++; + v = static_cast(detail::from_chars(p)); + } else { + while (MINIJSON_IS_NUM(*i)) i++; + if (*i == '.') { + i++; + if (!MINIJSON_IS_NUM(*i)) { + i = p; + return invalid_token_error; + } + while (MINIJSON_IS_NUM(*i)) i++; + } + if (*i == 'e') { + i++; + if (!MINIJSON_IS_NUM(*i)) { + i = p; + return invalid_token_error; + } + while (MINIJSON_IS_NUM(*i)) i++; + } + v = static_cast(detail::from_chars(p)); + } + if (*i && nullptr == detail::my_strchr(":,\x7d]\r\n ", *i)) { + i = p; + return invalid_token_error; + } + return no_error; +} + +template +inline error parse_string(Iter &i, value &v) { + if (*i != '"') return invalid_token_error; + + Iter s = i; + char t = *i++; // = '"' + Iter p = i; + +#if 0 + std::stringstream ss; + while (*i && *i != t) { + if (*i == '\\' && *(i + 1)) { + i++; + if (*i == 'n') + ss << "\n"; + else if (*i == 'r') + ss << "\r"; + else if (*i == 't') + ss << "\t"; + else + ss << *i; + } else { + ss << *i; + } + i++; + } +#else + // read until '"' + while (*i && *i != t) { + if (*i == '\\' && *(i + 1)) { + i++; + } + i++; + } + +#endif + if (!*i) return invalid_token_error; + if (i < p) { + return corrupted_json_error; + } + +#if 0 + v = std::string(p, size_t(i - p)); + + i++; + if (*i && nullptr == detail::my_strchr(":,\x7d]\r\n ", *i)) { + i = p; + return invalid_token_error; + } + +#else + + i++; + if (*i && nullptr == detail::my_strchr(":,\x7d]\r\n ", *i)) { + i = p; + return invalid_token_error; + } + + // include first and last '"' char + std::string buf(s, size_t(i - s)); + + detail::string_parser str_parser; + str_parser.set_input(buf); + + if (!str_parser.scan_string()) { + // TODO: error message + // str_parser.error_message; + return invalid_token_error; + } else { + v = str_parser.token_buffer; + } + +#endif + + return no_error; +} + +template +inline error parse_any(Iter &i, value &v) { + MINIJSON_SKIP(i) + if (*i == '\x7b') return parse_object(i, v); + if (*i == '[') return parse_array(i, v); + if (*i == 't' || *i == 'f') return parse_boolean(i, v); + if (*i == 'n') return parse_null(i, v); + if ((*i == '-') || ('0' <= *i && *i <= '9')) return parse_number(i, v); + if (*i == '"') return parse_string(i, v); + return invalid_token_error; +} + +template +inline error parse(Iter &i, value &v) { + return parse_any(i, v); +} + +#undef MINIJSON_SKIP + +inline const char *errstr(error e) { + const char *s = "unknown error"; + switch (e) { + case no_error: { + s = "no error"; + break; + } + case undefined_error: { + s = "undefined"; + break; + } + case invalid_token_error: { + s = "invalid token"; + break; + } + case unknown_type_error: { + s = "unknown type"; + break; + } + case memory_allocation_error: { + s = "memory allocation error"; + break; + } + case corrupted_json_error: { + s = "input is corrupted"; + break; + } + case duplicated_key_error: { + s = "duplicated key found"; + break; + } + // default: return "unknown error"; + } + + return s; +} + +} // namespace minijson + +#if defined(MINIJSON_IMPLEMENTATION) + +namespace minijson { + +namespace detail { + +// clang-format off +// +// From json.hpp --------------------------------------------------------- +// __ _____ _____ _____ +// __| | __| | | | JSON for Modern C++ +// | | |__ | | | | | | version 3.11.3 +// |_____|_____|_____|_|___| https://github.com/nlohmann/json +// +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann +// SPDX-License-Identifier: MIT + +#if 1 + #define JSON_HEDLEY_UNLIKELY(cond) (cond) + #define JSON_HEDLEY_LIKELY(cond) (cond) + + /*! + @brief get codepoint from 4 hex characters following `\u` + + For input "\u c1 c2 c3 c4" the codepoint is: + (c1 * 0x1000) + (c2 * 0x0100) + (c3 * 0x0010) + c4 + = (c1 << 12) + (c2 << 8) + (c3 << 4) + (c4 << 0) + + Furthermore, the possible characters '0'..'9', 'A'..'F', and 'a'..'f' + must be converted to the integers 0x0..0x9, 0xA..0xF, 0xA..0xF, resp. The + conversion is done by subtracting the offset (0x30, 0x37, and 0x57) + between the ASCII value of the character and the desired integer value. + + @return codepoint (0x0000..0xFFFF) or -1 in case of an error (e.g. EOF or + non-hex character) + */ + int string_parser::get_codepoint() + { + // this function only makes sense after reading `\u` + //JSON_ASSERT(current == 'u'); + if (current != 'u') { + return -1; + } + int codepoint = 0; + + const auto factors = { 12u, 8u, 4u, 0u }; + for (const auto factor : factors) + { + get(); + + if (current >= '0' && current <= '9') + { + codepoint += static_cast((static_cast(current) - 0x30u) << factor); + } + else if (current >= 'A' && current <= 'F') + { + codepoint += static_cast((static_cast(current) - 0x37u) << factor); + } + else if (current >= 'a' && current <= 'f') + { + codepoint += static_cast((static_cast(current) - 0x57u) << factor); + } + else + { + return -1; + } + } + + if (0x0000 <= codepoint && codepoint <= 0xFFFF) { + } else { + return -1; + } + return codepoint; + } + + /*! + @brief check if the next byte(s) are inside a given range + + Adds the current byte and, for each passed range, reads a new byte and + checks if it is inside the range. If a violation was detected, set up an + error message and return false. Otherwise, return true. + + @param[in] ranges list of integers; interpreted as list of pairs of + inclusive lower and upper bound, respectively + + @pre The passed list @a ranges must have 2, 4, or 6 elements; that is, + 1, 2, or 3 pairs. This precondition is enforced by an assertion. + + @return true if and only if no range violation was detected + */ + bool string_parser::next_byte_in_range(const std::initializer_list ranges) + { + if (ranges.size() == 2 || ranges.size() == 4 || ranges.size() == 6) { + } else { + return false; + } + + add(current); + + for (auto range = ranges.begin(); range != ranges.end(); ++range) + { + get(); + if (JSON_HEDLEY_LIKELY(*range <= current && current <= *(++range))) // NOLINT(bugprone-inc-dec-in-conditions) + { + add(current); + } + else + { + error_message = "invalid string: ill-formed UTF-8 byte"; + return false; + } + } + + return true; + } + /*! + @brief scan a string literal + + This function scans a string according to Sect. 7 of RFC 8259. While + scanning, bytes are escaped and copied into buffer token_buffer. Then the + function returns successfully, token_buffer is *not* null-terminated (as it + may contain \0 bytes), and token_buffer.size() is the number of bytes in the + string. + + @return true if string could be successfully scanned, + false otherwise + + @note In case of errors, variable error_message contains a textual + description. + */ + bool string_parser::scan_string() + { + // reset token_buffer (ignore opening quote) + reset(); + + // we entered the function by reading an open quote + //JSON_ASSERT(current == '\"'); + if (current != '\"') { + error_message = "first character must be '\"'"; + return false; + } + + + while (!eof()) + { + + // get next character + switch (get()) + { + + // closing quote + case '\"': + { + return true; + } + + // escapes + case '\\': + { + switch (get()) + { + // quotation mark + case '\"': + add('\"'); + break; + // reverse solidus + case '\\': + add('\\'); + break; + // solidus + case '/': + add('/'); + break; + // backspace + case 'b': + add('\b'); + break; + // form feed + case 'f': + add('\f'); + break; + // line feed + case 'n': + add('\n'); + break; + // carriage return + case 'r': + add('\r'); + break; + // tab + case 't': + add('\t'); + break; + + // unicode escapes + case 'u': + { + const int codepoint1 = get_codepoint(); + int codepoint = codepoint1; // start with codepoint1 + + if (JSON_HEDLEY_UNLIKELY(codepoint1 == -1)) + { + error_message = "invalid string: '\\u' must be followed by 4 hex digits"; + return false; + } + + // check if code point is a high surrogate + if (0xD800 <= codepoint1 && codepoint1 <= 0xDBFF) + { + // expect next \uxxxx entry + if (JSON_HEDLEY_LIKELY(get() == '\\' && get() == 'u')) + { + const int codepoint2 = get_codepoint(); + + if (JSON_HEDLEY_UNLIKELY(codepoint2 == -1)) + { + error_message = "invalid string: '\\u' must be followed by 4 hex digits"; + return false; + } + + // check if codepoint2 is a low surrogate + if (JSON_HEDLEY_LIKELY(0xDC00 <= codepoint2 && codepoint2 <= 0xDFFF)) + { + // overwrite codepoint + codepoint = static_cast( + // high surrogate occupies the most significant 22 bits + (static_cast(codepoint1) << 10u) + // low surrogate occupies the least significant 15 bits + + static_cast(codepoint2) + // there is still the 0xD800, 0xDC00 and 0x10000 noise + // in the result, so we have to subtract with: + // (0xD800 << 10) + DC00 - 0x10000 = 0x35FDC00 + - 0x35FDC00u); + } + else + { + error_message = "invalid string: surrogate U+D800..U+DBFF must be followed by U+DC00..U+DFFF"; + return false; + } + } + else + { + error_message = "invalid string: surrogate U+D800..U+DBFF must be followed by U+DC00..U+DFFF"; + return false; + } + } + else + { + if (JSON_HEDLEY_UNLIKELY(0xDC00 <= codepoint1 && codepoint1 <= 0xDFFF)) + { + error_message = "invalid string: surrogate U+DC00..U+DFFF must follow U+D800..U+DBFF"; + return false; + } + } + + // result of the above calculation yields a proper codepoint + //JSON_ASSERT(0x00 <= codepoint && codepoint <= 0x10FFFF); + if (0x00 <= codepoint && codepoint <= 0x10FFFF) { + } else { + error_message = "invalid string: invalid codepoint"; + return false; + } + + // translate codepoint into bytes + if (codepoint < 0x80) + { + // 1-byte characters: 0xxxxxxx (ASCII) + add(static_cast(codepoint)); + } + else if (codepoint <= 0x7FF) + { + // 2-byte characters: 110xxxxx 10xxxxxx + add(static_cast(0xC0u | (static_cast(codepoint) >> 6u))); + add(static_cast(0x80u | (static_cast(codepoint) & 0x3Fu))); + } + else if (codepoint <= 0xFFFF) + { + // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx + add(static_cast(0xE0u | (static_cast(codepoint) >> 12u))); + add(static_cast(0x80u | ((static_cast(codepoint) >> 6u) & 0x3Fu))); + add(static_cast(0x80u | (static_cast(codepoint) & 0x3Fu))); + } + else + { + // 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx + add(static_cast(0xF0u | (static_cast(codepoint) >> 18u))); + add(static_cast(0x80u | ((static_cast(codepoint) >> 12u) & 0x3Fu))); + add(static_cast(0x80u | ((static_cast(codepoint) >> 6u) & 0x3Fu))); + add(static_cast(0x80u | (static_cast(codepoint) & 0x3Fu))); + } + + break; + } + + // other characters after escape + default: + error_message = "invalid string: forbidden character after backslash"; + return false; + } + + break; + } + + // invalid control characters + case 0x00: + { + error_message = "invalid string: control character U+0000 (NUL) must be escaped to \\u0000"; + return false; + } + + case 0x01: + { + error_message = "invalid string: control character U+0001 (SOH) must be escaped to \\u0001"; + return false; + } + + case 0x02: + { + error_message = "invalid string: control character U+0002 (STX) must be escaped to \\u0002"; + return false; + } + + case 0x03: + { + error_message = "invalid string: control character U+0003 (ETX) must be escaped to \\u0003"; + return false; + } + + case 0x04: + { + error_message = "invalid string: control character U+0004 (EOT) must be escaped to \\u0004"; + return false; + } + + case 0x05: + { + error_message = "invalid string: control character U+0005 (ENQ) must be escaped to \\u0005"; + return false; + } + + case 0x06: + { + error_message = "invalid string: control character U+0006 (ACK) must be escaped to \\u0006"; + return false; + } + + case 0x07: + { + error_message = "invalid string: control character U+0007 (BEL) must be escaped to \\u0007"; + return false; + } + + case 0x08: + { + error_message = "invalid string: control character U+0008 (BS) must be escaped to \\u0008 or \\b"; + return false; + } + + case 0x09: + { + error_message = "invalid string: control character U+0009 (HT) must be escaped to \\u0009 or \\t"; + return false; + } + + case 0x0A: + { + error_message = "invalid string: control character U+000A (LF) must be escaped to \\u000A or \\n"; + return false; + } + + case 0x0B: + { + error_message = "invalid string: control character U+000B (VT) must be escaped to \\u000B"; + return false; + } + + case 0x0C: + { + error_message = "invalid string: control character U+000C (FF) must be escaped to \\u000C or \\f"; + return false; + } + + case 0x0D: + { + error_message = "invalid string: control character U+000D (CR) must be escaped to \\u000D or \\r"; + return false; + } + + case 0x0E: + { + error_message = "invalid string: control character U+000E (SO) must be escaped to \\u000E"; + return false; + } + + case 0x0F: + { + error_message = "invalid string: control character U+000F (SI) must be escaped to \\u000F"; + return false; + } + + case 0x10: + { + error_message = "invalid string: control character U+0010 (DLE) must be escaped to \\u0010"; + return false; + } + + case 0x11: + { + error_message = "invalid string: control character U+0011 (DC1) must be escaped to \\u0011"; + return false; + } + + case 0x12: + { + error_message = "invalid string: control character U+0012 (DC2) must be escaped to \\u0012"; + return false; + } + + case 0x13: + { + error_message = "invalid string: control character U+0013 (DC3) must be escaped to \\u0013"; + return false; + } + + case 0x14: + { + error_message = "invalid string: control character U+0014 (DC4) must be escaped to \\u0014"; + return false; + } + + case 0x15: + { + error_message = "invalid string: control character U+0015 (NAK) must be escaped to \\u0015"; + return false; + } + + case 0x16: + { + error_message = "invalid string: control character U+0016 (SYN) must be escaped to \\u0016"; + return false; + } + + case 0x17: + { + error_message = "invalid string: control character U+0017 (ETB) must be escaped to \\u0017"; + return false; + } + + case 0x18: + { + error_message = "invalid string: control character U+0018 (CAN) must be escaped to \\u0018"; + return false; + } + + case 0x19: + { + error_message = "invalid string: control character U+0019 (EM) must be escaped to \\u0019"; + return false; + } + + case 0x1A: + { + error_message = "invalid string: control character U+001A (SUB) must be escaped to \\u001A"; + return false; + } + + case 0x1B: + { + error_message = "invalid string: control character U+001B (ESC) must be escaped to \\u001B"; + return false; + } + + case 0x1C: + { + error_message = "invalid string: control character U+001C (FS) must be escaped to \\u001C"; + return false; + } + + case 0x1D: + { + error_message = "invalid string: control character U+001D (GS) must be escaped to \\u001D"; + return false; + } + + case 0x1E: + { + error_message = "invalid string: control character U+001E (RS) must be escaped to \\u001E"; + return false; + } + + case 0x1F: + { + error_message = "invalid string: control character U+001F (US) must be escaped to \\u001F"; + return false; + } + + // U+0020..U+007F (except U+0022 (quote) and U+005C (backspace)) + case 0x20: + case 0x21: + case 0x23: + case 0x24: + case 0x25: + case 0x26: + case 0x27: + case 0x28: + case 0x29: + case 0x2A: + case 0x2B: + case 0x2C: + case 0x2D: + case 0x2E: + case 0x2F: + case 0x30: + case 0x31: + case 0x32: + case 0x33: + case 0x34: + case 0x35: + case 0x36: + case 0x37: + case 0x38: + case 0x39: + case 0x3A: + case 0x3B: + case 0x3C: + case 0x3D: + case 0x3E: + case 0x3F: + case 0x40: + case 0x41: + case 0x42: + case 0x43: + case 0x44: + case 0x45: + case 0x46: + case 0x47: + case 0x48: + case 0x49: + case 0x4A: + case 0x4B: + case 0x4C: + case 0x4D: + case 0x4E: + case 0x4F: + case 0x50: + case 0x51: + case 0x52: + case 0x53: + case 0x54: + case 0x55: + case 0x56: + case 0x57: + case 0x58: + case 0x59: + case 0x5A: + case 0x5B: + case 0x5D: + case 0x5E: + case 0x5F: + case 0x60: + case 0x61: + case 0x62: + case 0x63: + case 0x64: + case 0x65: + case 0x66: + case 0x67: + case 0x68: + case 0x69: + case 0x6A: + case 0x6B: + case 0x6C: + case 0x6D: + case 0x6E: + case 0x6F: + case 0x70: + case 0x71: + case 0x72: + case 0x73: + case 0x74: + case 0x75: + case 0x76: + case 0x77: + case 0x78: + case 0x79: + case 0x7A: + case 0x7B: + case 0x7C: + case 0x7D: + case 0x7E: + case 0x7F: + { + add(current); + break; + } + + // U+0080..U+07FF: bytes C2..DF 80..BF + case 0xC2: + case 0xC3: + case 0xC4: + case 0xC5: + case 0xC6: + case 0xC7: + case 0xC8: + case 0xC9: + case 0xCA: + case 0xCB: + case 0xCC: + case 0xCD: + case 0xCE: + case 0xCF: + case 0xD0: + case 0xD1: + case 0xD2: + case 0xD3: + case 0xD4: + case 0xD5: + case 0xD6: + case 0xD7: + case 0xD8: + case 0xD9: + case 0xDA: + case 0xDB: + case 0xDC: + case 0xDD: + case 0xDE: + case 0xDF: + { + if (JSON_HEDLEY_UNLIKELY(!next_byte_in_range({0x80, 0xBF}))) + { + return false; + } + break; + } + + // U+0800..U+0FFF: bytes E0 A0..BF 80..BF + case 0xE0: + { + if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0xA0, 0xBF, 0x80, 0xBF})))) + { + return false; + } + break; + } + + // U+1000..U+CFFF: bytes E1..EC 80..BF 80..BF + // U+E000..U+FFFF: bytes EE..EF 80..BF 80..BF + case 0xE1: + case 0xE2: + case 0xE3: + case 0xE4: + case 0xE5: + case 0xE6: + case 0xE7: + case 0xE8: + case 0xE9: + case 0xEA: + case 0xEB: + case 0xEC: + case 0xEE: + case 0xEF: + { + if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0x80, 0xBF, 0x80, 0xBF})))) + { + return false; + } + break; + } + + // U+D000..U+D7FF: bytes ED 80..9F 80..BF + case 0xED: + { + if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0x80, 0x9F, 0x80, 0xBF})))) + { + return false; + } + break; + } + + // U+10000..U+3FFFF F0 90..BF 80..BF 80..BF + case 0xF0: + { + if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0x90, 0xBF, 0x80, 0xBF, 0x80, 0xBF})))) + { + return false; + } + break; + } + + // U+40000..U+FFFFF F1..F3 80..BF 80..BF 80..BF + case 0xF1: + case 0xF2: + case 0xF3: + { + if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0x80, 0xBF, 0x80, 0xBF, 0x80, 0xBF})))) + { + return false; + } + break; + } + + // U+100000..U+10FFFF F4 80..8F 80..BF 80..BF + case 0xF4: + { + if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0x80, 0x8F, 0x80, 0xBF, 0x80, 0xBF})))) + { + return false; + } + break; + } + + // remaining bytes (80..C1 and F5..FF) are ill-formed + default: + { + error_message = "invalid string: ill-formed UTF-8 byte"; + return false; + } + } + } + + error_message = "invalid string: missing closing quote"; + return false; + } +#endif +// end json.hpp +// clang-format on + +} // namespace detail + +namespace detail { + +double from_chars(const char *p) { +#if defined(MINIJSON_USE_STRTOD) + return strtod(p, nullptr); +#else + return simdjson::internal::from_chars(p); +#endif +} + +const char *my_strchr(const char *p, int ch) { + char c; + + constexpr uint64_t kMaxCount = 1024ull * 1024ull; // up to 1M chars + + uint64_t cnt{0}; + + c = ch; + for (;; ++p, cnt++) { + if (cnt > kMaxCount) { + return nullptr; + } + + if (*p == c) { + return (p); + } + if (*p == '\0') { + return (nullptr); + } + } +} + +} // namespace detail +} // namespace minijson + +#if !defined(MINIJSON_USE_STRTOD) + +#include +#include + +namespace minijson { +namespace simdjson { +namespace internal { + +/** + * The code in the internal::from_chars function is meant to handle the + *floating-point number parsing when we have more than 19 digits in the decimal + *mantissa. This should only be seen in adversarial scenarios: we do not expect + *production systems to even produce such floating-point numbers. + * + * The parser is based on work by Nigel Tao (at + *https://github.com/google/wuffs/) who credits Ken Thompson for the design (via + *a reference to the Go source code). See + * https://github.com/google/wuffs/blob/aa46859ea40c72516deffa1b146121952d6dfd3b/internal/cgen/base/floatconv-submodule-data.c + * https://github.com/google/wuffs/blob/46cd8105f47ca07ae2ba8e6a7818ef9c0df6c152/internal/cgen/base/floatconv-submodule-code.c + * It is probably not very fast but it is a fallback that should almost never be + * called in real life. Google Wuffs is published under APL 2.0. + **/ + +namespace { +constexpr uint32_t max_digits = 768; +constexpr int32_t decimal_point_range = 2047; +} // namespace + +struct adjusted_mantissa { + uint64_t mantissa; + int power2; + adjusted_mantissa() : mantissa(0), power2(0) {} +}; + +struct decimal { + uint32_t num_digits; + int32_t decimal_point; + bool negative; + bool truncated; + uint8_t digits[max_digits]; +}; + +template +struct binary_format { + static constexpr int mantissa_explicit_bits(); + static constexpr int minimum_exponent(); + static constexpr int infinite_power(); + static constexpr int sign_index(); +}; + +template <> +constexpr int binary_format::mantissa_explicit_bits() { + return 52; +} + +template <> +constexpr int binary_format::minimum_exponent() { + return -1023; +} +template <> +constexpr int binary_format::infinite_power() { + return 0x7FF; +} + +template <> +constexpr int binary_format::sign_index() { + return 63; +} + +inline bool is_integer(char c) noexcept { return (c >= '0' && c <= '9'); } + +// This should always succeed since it follows a call to parse_number. +static decimal parse_decimal(const char *&p) noexcept { + decimal answer; + answer.num_digits = 0; + answer.decimal_point = 0; + answer.truncated = false; + answer.negative = (*p == '-'); + if ((*p == '-') || (*p == '+')) { + ++p; + } + + while (*p == '0') { + ++p; + } + while (is_integer(*p)) { + if (answer.num_digits < max_digits) { + answer.digits[answer.num_digits] = uint8_t(*p - '0'); + } + answer.num_digits++; + ++p; + } + if (*p == '.') { + ++p; + const char *first_after_period = p; + // if we have not yet encountered a zero, we have to skip it as well + if (answer.num_digits == 0) { + // skip zeros + while (*p == '0') { + ++p; + } + } + while (is_integer(*p)) { + if (answer.num_digits < max_digits) { + answer.digits[answer.num_digits] = uint8_t(*p - '0'); + } + answer.num_digits++; + ++p; + } + answer.decimal_point = int32_t(first_after_period - p); + } + if (answer.num_digits > 0) { + const char *preverse = p - 1; + int32_t trailing_zeros = 0; + while ((*preverse == '0') || (*preverse == '.')) { + if (*preverse == '0') { + trailing_zeros++; + } + --preverse; + } + answer.decimal_point += int32_t(answer.num_digits); + answer.num_digits -= uint32_t(trailing_zeros); + } + if (answer.num_digits > max_digits) { + answer.num_digits = max_digits; + answer.truncated = true; + } + if (('e' == *p) || ('E' == *p)) { + ++p; + bool neg_exp = false; + if ('-' == *p) { + neg_exp = true; + ++p; + } else if ('+' == *p) { + ++p; + } + int32_t exp_number = 0; // exponential part + while (is_integer(*p)) { + uint8_t digit = uint8_t(*p - '0'); + if (exp_number < 0x10000) { + exp_number = 10 * exp_number + digit; + } + ++p; + } + answer.decimal_point += (neg_exp ? -exp_number : exp_number); + } + return answer; +} + +// This should always succeed since it follows a call to parse_number. +// Will not read at or beyond the "end" pointer. +static decimal parse_decimal(const char *&p, const char *end) noexcept { + decimal answer; + answer.num_digits = 0; + answer.decimal_point = 0; + answer.truncated = false; + if (p == end) { + return answer; + } // should never happen + answer.negative = (*p == '-'); + if ((*p == '-') || (*p == '+')) { + ++p; + } + + while ((p != end) && (*p == '0')) { + ++p; + } + while ((p != end) && is_integer(*p)) { + if (answer.num_digits < max_digits) { + answer.digits[answer.num_digits] = uint8_t(*p - '0'); + } + answer.num_digits++; + ++p; + } + if ((p != end) && (*p == '.')) { + ++p; + if (p == end) { + return answer; + } // should never happen + const char *first_after_period = p; + // if we have not yet encountered a zero, we have to skip it as well + if (answer.num_digits == 0) { + // skip zeros + while (*p == '0') { + ++p; + } + } + while ((p != end) && is_integer(*p)) { + if (answer.num_digits < max_digits) { + answer.digits[answer.num_digits] = uint8_t(*p - '0'); + } + answer.num_digits++; + ++p; + } + answer.decimal_point = int32_t(first_after_period - p); + } + if (answer.num_digits > 0) { + const char *preverse = p - 1; + int32_t trailing_zeros = 0; + while ((*preverse == '0') || (*preverse == '.')) { + if (*preverse == '0') { + trailing_zeros++; + } + --preverse; + } + answer.decimal_point += int32_t(answer.num_digits); + answer.num_digits -= uint32_t(trailing_zeros); + } + if (answer.num_digits > max_digits) { + answer.num_digits = max_digits; + answer.truncated = true; + } + if ((p != end) && (('e' == *p) || ('E' == *p))) { + ++p; + if (p == end) { + return answer; + } // should never happen + bool neg_exp = false; + if ('-' == *p) { + neg_exp = true; + ++p; + } else if ('+' == *p) { + ++p; + } + int32_t exp_number = 0; // exponential part + while ((p != end) && is_integer(*p)) { + uint8_t digit = uint8_t(*p - '0'); + if (exp_number < 0x10000) { + exp_number = 10 * exp_number + digit; + } + ++p; + } + answer.decimal_point += (neg_exp ? -exp_number : exp_number); + } + return answer; +} + +namespace { + +// remove all final zeroes +inline void trim(decimal &h) { + while ((h.num_digits > 0) && (h.digits[h.num_digits - 1] == 0)) { + h.num_digits--; + } +} + +uint32_t number_of_digits_decimal_left_shift(decimal &h, uint32_t shift) { + shift &= 63; + const static uint16_t number_of_digits_decimal_left_shift_table[65] = { + 0x0000, 0x0800, 0x0801, 0x0803, 0x1006, 0x1009, 0x100D, 0x1812, 0x1817, + 0x181D, 0x2024, 0x202B, 0x2033, 0x203C, 0x2846, 0x2850, 0x285B, 0x3067, + 0x3073, 0x3080, 0x388E, 0x389C, 0x38AB, 0x38BB, 0x40CC, 0x40DD, 0x40EF, + 0x4902, 0x4915, 0x4929, 0x513E, 0x5153, 0x5169, 0x5180, 0x5998, 0x59B0, + 0x59C9, 0x61E3, 0x61FD, 0x6218, 0x6A34, 0x6A50, 0x6A6D, 0x6A8B, 0x72AA, + 0x72C9, 0x72E9, 0x7B0A, 0x7B2B, 0x7B4D, 0x8370, 0x8393, 0x83B7, 0x83DC, + 0x8C02, 0x8C28, 0x8C4F, 0x9477, 0x949F, 0x94C8, 0x9CF2, 0x051C, 0x051C, + 0x051C, 0x051C, + }; + uint32_t x_a = number_of_digits_decimal_left_shift_table[shift]; + uint32_t x_b = number_of_digits_decimal_left_shift_table[shift + 1]; + uint32_t num_new_digits = x_a >> 11; + uint32_t pow5_a = 0x7FF & x_a; + uint32_t pow5_b = 0x7FF & x_b; + const static uint8_t + number_of_digits_decimal_left_shift_table_powers_of_5[0x051C] = { + 5, 2, 5, 1, 2, 5, 6, 2, 5, 3, 1, 2, 5, 1, 5, 6, 2, 5, 7, 8, 1, 2, 5, + 3, 9, 0, 6, 2, 5, 1, 9, 5, 3, 1, 2, 5, 9, 7, 6, 5, 6, 2, 5, 4, 8, 8, + 2, 8, 1, 2, 5, 2, 4, 4, 1, 4, 0, 6, 2, 5, 1, 2, 2, 0, 7, 0, 3, 1, 2, + 5, 6, 1, 0, 3, 5, 1, 5, 6, 2, 5, 3, 0, 5, 1, 7, 5, 7, 8, 1, 2, 5, 1, + 5, 2, 5, 8, 7, 8, 9, 0, 6, 2, 5, 7, 6, 2, 9, 3, 9, 4, 5, 3, 1, 2, 5, + 3, 8, 1, 4, 6, 9, 7, 2, 6, 5, 6, 2, 5, 1, 9, 0, 7, 3, 4, 8, 6, 3, 2, + 8, 1, 2, 5, 9, 5, 3, 6, 7, 4, 3, 1, 6, 4, 0, 6, 2, 5, 4, 7, 6, 8, 3, + 7, 1, 5, 8, 2, 0, 3, 1, 2, 5, 2, 3, 8, 4, 1, 8, 5, 7, 9, 1, 0, 1, 5, + 6, 2, 5, 1, 1, 9, 2, 0, 9, 2, 8, 9, 5, 5, 0, 7, 8, 1, 2, 5, 5, 9, 6, + 0, 4, 6, 4, 4, 7, 7, 5, 3, 9, 0, 6, 2, 5, 2, 9, 8, 0, 2, 3, 2, 2, 3, + 8, 7, 6, 9, 5, 3, 1, 2, 5, 1, 4, 9, 0, 1, 1, 6, 1, 1, 9, 3, 8, 4, 7, + 6, 5, 6, 2, 5, 7, 4, 5, 0, 5, 8, 0, 5, 9, 6, 9, 2, 3, 8, 2, 8, 1, 2, + 5, 3, 7, 2, 5, 2, 9, 0, 2, 9, 8, 4, 6, 1, 9, 1, 4, 0, 6, 2, 5, 1, 8, + 6, 2, 6, 4, 5, 1, 4, 9, 2, 3, 0, 9, 5, 7, 0, 3, 1, 2, 5, 9, 3, 1, 3, + 2, 2, 5, 7, 4, 6, 1, 5, 4, 7, 8, 5, 1, 5, 6, 2, 5, 4, 6, 5, 6, 6, 1, + 2, 8, 7, 3, 0, 7, 7, 3, 9, 2, 5, 7, 8, 1, 2, 5, 2, 3, 2, 8, 3, 0, 6, + 4, 3, 6, 5, 3, 8, 6, 9, 6, 2, 8, 9, 0, 6, 2, 5, 1, 1, 6, 4, 1, 5, 3, + 2, 1, 8, 2, 6, 9, 3, 4, 8, 1, 4, 4, 5, 3, 1, 2, 5, 5, 8, 2, 0, 7, 6, + 6, 0, 9, 1, 3, 4, 6, 7, 4, 0, 7, 2, 2, 6, 5, 6, 2, 5, 2, 9, 1, 0, 3, + 8, 3, 0, 4, 5, 6, 7, 3, 3, 7, 0, 3, 6, 1, 3, 2, 8, 1, 2, 5, 1, 4, 5, + 5, 1, 9, 1, 5, 2, 2, 8, 3, 6, 6, 8, 5, 1, 8, 0, 6, 6, 4, 0, 6, 2, 5, + 7, 2, 7, 5, 9, 5, 7, 6, 1, 4, 1, 8, 3, 4, 2, 5, 9, 0, 3, 3, 2, 0, 3, + 1, 2, 5, 3, 6, 3, 7, 9, 7, 8, 8, 0, 7, 0, 9, 1, 7, 1, 2, 9, 5, 1, 6, + 6, 0, 1, 5, 6, 2, 5, 1, 8, 1, 8, 9, 8, 9, 4, 0, 3, 5, 4, 5, 8, 5, 6, + 4, 7, 5, 8, 3, 0, 0, 7, 8, 1, 2, 5, 9, 0, 9, 4, 9, 4, 7, 0, 1, 7, 7, + 2, 9, 2, 8, 2, 3, 7, 9, 1, 5, 0, 3, 9, 0, 6, 2, 5, 4, 5, 4, 7, 4, 7, + 3, 5, 0, 8, 8, 6, 4, 6, 4, 1, 1, 8, 9, 5, 7, 5, 1, 9, 5, 3, 1, 2, 5, + 2, 2, 7, 3, 7, 3, 6, 7, 5, 4, 4, 3, 2, 3, 2, 0, 5, 9, 4, 7, 8, 7, 5, + 9, 7, 6, 5, 6, 2, 5, 1, 1, 3, 6, 8, 6, 8, 3, 7, 7, 2, 1, 6, 1, 6, 0, + 2, 9, 7, 3, 9, 3, 7, 9, 8, 8, 2, 8, 1, 2, 5, 5, 6, 8, 4, 3, 4, 1, 8, + 8, 6, 0, 8, 0, 8, 0, 1, 4, 8, 6, 9, 6, 8, 9, 9, 4, 1, 4, 0, 6, 2, 5, + 2, 8, 4, 2, 1, 7, 0, 9, 4, 3, 0, 4, 0, 4, 0, 0, 7, 4, 3, 4, 8, 4, 4, + 9, 7, 0, 7, 0, 3, 1, 2, 5, 1, 4, 2, 1, 0, 8, 5, 4, 7, 1, 5, 2, 0, 2, + 0, 0, 3, 7, 1, 7, 4, 2, 2, 4, 8, 5, 3, 5, 1, 5, 6, 2, 5, 7, 1, 0, 5, + 4, 2, 7, 3, 5, 7, 6, 0, 1, 0, 0, 1, 8, 5, 8, 7, 1, 1, 2, 4, 2, 6, 7, + 5, 7, 8, 1, 2, 5, 3, 5, 5, 2, 7, 1, 3, 6, 7, 8, 8, 0, 0, 5, 0, 0, 9, + 2, 9, 3, 5, 5, 6, 2, 1, 3, 3, 7, 8, 9, 0, 6, 2, 5, 1, 7, 7, 6, 3, 5, + 6, 8, 3, 9, 4, 0, 0, 2, 5, 0, 4, 6, 4, 6, 7, 7, 8, 1, 0, 6, 6, 8, 9, + 4, 5, 3, 1, 2, 5, 8, 8, 8, 1, 7, 8, 4, 1, 9, 7, 0, 0, 1, 2, 5, 2, 3, + 2, 3, 3, 8, 9, 0, 5, 3, 3, 4, 4, 7, 2, 6, 5, 6, 2, 5, 4, 4, 4, 0, 8, + 9, 2, 0, 9, 8, 5, 0, 0, 6, 2, 6, 1, 6, 1, 6, 9, 4, 5, 2, 6, 6, 7, 2, + 3, 6, 3, 2, 8, 1, 2, 5, 2, 2, 2, 0, 4, 4, 6, 0, 4, 9, 2, 5, 0, 3, 1, + 3, 0, 8, 0, 8, 4, 7, 2, 6, 3, 3, 3, 6, 1, 8, 1, 6, 4, 0, 6, 2, 5, 1, + 1, 1, 0, 2, 2, 3, 0, 2, 4, 6, 2, 5, 1, 5, 6, 5, 4, 0, 4, 2, 3, 6, 3, + 1, 6, 6, 8, 0, 9, 0, 8, 2, 0, 3, 1, 2, 5, 5, 5, 5, 1, 1, 1, 5, 1, 2, + 3, 1, 2, 5, 7, 8, 2, 7, 0, 2, 1, 1, 8, 1, 5, 8, 3, 4, 0, 4, 5, 4, 1, + 0, 1, 5, 6, 2, 5, 2, 7, 7, 5, 5, 5, 7, 5, 6, 1, 5, 6, 2, 8, 9, 1, 3, + 5, 1, 0, 5, 9, 0, 7, 9, 1, 7, 0, 2, 2, 7, 0, 5, 0, 7, 8, 1, 2, 5, 1, + 3, 8, 7, 7, 7, 8, 7, 8, 0, 7, 8, 1, 4, 4, 5, 6, 7, 5, 5, 2, 9, 5, 3, + 9, 5, 8, 5, 1, 1, 3, 5, 2, 5, 3, 9, 0, 6, 2, 5, 6, 9, 3, 8, 8, 9, 3, + 9, 0, 3, 9, 0, 7, 2, 2, 8, 3, 7, 7, 6, 4, 7, 6, 9, 7, 9, 2, 5, 5, 6, + 7, 6, 2, 6, 9, 5, 3, 1, 2, 5, 3, 4, 6, 9, 4, 4, 6, 9, 5, 1, 9, 5, 3, + 6, 1, 4, 1, 8, 8, 8, 2, 3, 8, 4, 8, 9, 6, 2, 7, 8, 3, 8, 1, 3, 4, 7, + 6, 5, 6, 2, 5, 1, 7, 3, 4, 7, 2, 3, 4, 7, 5, 9, 7, 6, 8, 0, 7, 0, 9, + 4, 4, 1, 1, 9, 2, 4, 4, 8, 1, 3, 9, 1, 9, 0, 6, 7, 3, 8, 2, 8, 1, 2, + 5, 8, 6, 7, 3, 6, 1, 7, 3, 7, 9, 8, 8, 4, 0, 3, 5, 4, 7, 2, 0, 5, 9, + 6, 2, 2, 4, 0, 6, 9, 5, 9, 5, 3, 3, 6, 9, 1, 4, 0, 6, 2, 5, + }; + const uint8_t *pow5 = + &number_of_digits_decimal_left_shift_table_powers_of_5[pow5_a]; + uint32_t i = 0; + uint32_t n = pow5_b - pow5_a; + for (; i < n; i++) { + if (i >= h.num_digits) { + return num_new_digits - 1; + } else if (h.digits[i] == pow5[i]) { + continue; + } else if (h.digits[i] < pow5[i]) { + return num_new_digits - 1; + } else { + return num_new_digits; + } + } + return num_new_digits; +} + +} // end of anonymous namespace + +static uint64_t round(decimal &h) { + if ((h.num_digits == 0) || (h.decimal_point < 0)) { + return 0; + } else if (h.decimal_point > 18) { + return UINT64_MAX; + } + // at this point, we know that h.decimal_point >= 0 + uint32_t dp = uint32_t(h.decimal_point); + uint64_t n = 0; + for (uint32_t i = 0; i < dp; i++) { + n = (10 * n) + ((i < h.num_digits) ? h.digits[i] : 0); + } + bool round_up = false; + if (dp < h.num_digits) { + round_up = h.digits[dp] >= 5; // normally, we round up + // but we may need to round to even! + if ((h.digits[dp] == 5) && (dp + 1 == h.num_digits)) { + round_up = h.truncated || ((dp > 0) && (1 & h.digits[dp - 1])); + } + } + if (round_up) { + n++; + } + return n; +} + +// computes h * 2^-shift +static void decimal_left_shift(decimal &h, uint32_t shift) { + if (h.num_digits == 0) { + return; + } + uint32_t num_new_digits = number_of_digits_decimal_left_shift(h, shift); + int32_t read_index = int32_t(h.num_digits - 1); + uint32_t write_index = h.num_digits - 1 + num_new_digits; + uint64_t n = 0; + + while (read_index >= 0) { + n += uint64_t(h.digits[read_index]) << shift; + uint64_t quotient = n / 10; + uint64_t remainder = n - (10 * quotient); + if (write_index < max_digits) { + h.digits[write_index] = uint8_t(remainder); + } else if (remainder > 0) { + h.truncated = true; + } + n = quotient; + write_index--; + read_index--; + } + while (n > 0) { + uint64_t quotient = n / 10; + uint64_t remainder = n - (10 * quotient); + if (write_index < max_digits) { + h.digits[write_index] = uint8_t(remainder); + } else if (remainder > 0) { + h.truncated = true; + } + n = quotient; + write_index--; + } + h.num_digits += num_new_digits; + if (h.num_digits > max_digits) { + h.num_digits = max_digits; + } + h.decimal_point += int32_t(num_new_digits); + trim(h); +} + +// computes h * 2^shift +static void decimal_right_shift(decimal &h, uint32_t shift) { + uint32_t read_index = 0; + uint32_t write_index = 0; + + uint64_t n = 0; + + while ((n >> shift) == 0) { + if (read_index < h.num_digits) { + n = (10 * n) + h.digits[read_index++]; + } else if (n == 0) { + return; + } else { + while ((n >> shift) == 0) { + n = 10 * n; + read_index++; + } + break; + } + } + h.decimal_point -= int32_t(read_index - 1); + if (h.decimal_point < -decimal_point_range) { // it is zero + h.num_digits = 0; + h.decimal_point = 0; + h.negative = false; + h.truncated = false; + return; + } + uint64_t mask = (uint64_t(1) << shift) - 1; + while (read_index < h.num_digits) { + uint8_t new_digit = uint8_t(n >> shift); + n = (10 * (n & mask)) + h.digits[read_index++]; + h.digits[write_index++] = new_digit; + } + while (n > 0) { + uint8_t new_digit = uint8_t(n >> shift); + n = 10 * (n & mask); + if (write_index < max_digits) { + h.digits[write_index++] = new_digit; + } else if (new_digit > 0) { + h.truncated = true; + } + } + h.num_digits = write_index; + trim(h); +} + +template +adjusted_mantissa compute_float(decimal &d) { + adjusted_mantissa answer; + if (d.num_digits == 0) { + // should be zero + answer.power2 = 0; + answer.mantissa = 0; + return answer; + } + // At this point, going further, we can assume that d.num_digits > 0. + // We want to guard against excessive decimal point values because + // they can result in long running times. Indeed, we do + // shifts by at most 60 bits. We have that log(10**400)/log(2**60) ~= 22 + // which is fine, but log(10**299995)/log(2**60) ~= 16609 which is not + // fine (runs for a long time). + // + if (d.decimal_point < -324) { + // We have something smaller than 1e-324 which is always zero + // in binary64 and binary32. + // It should be zero. + answer.power2 = 0; + answer.mantissa = 0; + return answer; + } else if (d.decimal_point >= 310) { + // We have something at least as large as 0.1e310 which is + // always infinite. + answer.power2 = binary::infinite_power(); + answer.mantissa = 0; + return answer; + } + + static const uint32_t max_shift = 60; + static const uint32_t num_powers = 19; + static const uint8_t powers[19] = { + 0, 3, 6, 9, 13, 16, 19, 23, 26, 29, // + 33, 36, 39, 43, 46, 49, 53, 56, 59, // + }; + int32_t exp2 = 0; + while (d.decimal_point > 0) { + uint32_t n = uint32_t(d.decimal_point); + uint32_t shift = (n < num_powers) ? powers[n] : max_shift; + decimal_right_shift(d, shift); + if (d.decimal_point < -decimal_point_range) { + // should be zero + answer.power2 = 0; + answer.mantissa = 0; + return answer; + } + exp2 += int32_t(shift); + } + // We shift left toward [1/2 ... 1]. + while (d.decimal_point <= 0) { + uint32_t shift; + if (d.decimal_point == 0) { + if (d.digits[0] >= 5) { + break; + } + shift = (d.digits[0] < 2) ? 2 : 1; + } else { + uint32_t n = uint32_t(-d.decimal_point); + shift = (n < num_powers) ? powers[n] : max_shift; + } + decimal_left_shift(d, shift); + if (d.decimal_point > decimal_point_range) { + // we want to get infinity: + answer.power2 = 0xFF; + answer.mantissa = 0; + return answer; + } + exp2 -= int32_t(shift); + } + // We are now in the range [1/2 ... 1] but the binary format uses [1 ... 2]. + exp2--; + constexpr int32_t minimum_exponent = binary::minimum_exponent(); + while ((minimum_exponent + 1) > exp2) { + uint32_t n = uint32_t((minimum_exponent + 1) - exp2); + if (n > max_shift) { + n = max_shift; + } + decimal_right_shift(d, n); + exp2 += int32_t(n); + } + if ((exp2 - minimum_exponent) >= binary::infinite_power()) { + answer.power2 = binary::infinite_power(); + answer.mantissa = 0; + return answer; + } + + const int mantissa_size_in_bits = binary::mantissa_explicit_bits() + 1; + decimal_left_shift(d, mantissa_size_in_bits); + + uint64_t mantissa = round(d); + // It is possible that we have an overflow, in which case we need + // to shift back. + if (mantissa >= (uint64_t(1) << mantissa_size_in_bits)) { + decimal_right_shift(d, 1); + exp2 += 1; + mantissa = round(d); + if ((exp2 - minimum_exponent) >= binary::infinite_power()) { + answer.power2 = binary::infinite_power(); + answer.mantissa = 0; + return answer; + } + } + answer.power2 = exp2 - binary::minimum_exponent(); + if (mantissa < (uint64_t(1) << binary::mantissa_explicit_bits())) { + answer.power2--; + } + answer.mantissa = + mantissa & ((uint64_t(1) << binary::mantissa_explicit_bits()) - 1); + return answer; +} + +template +adjusted_mantissa parse_long_mantissa(const char *first) { + decimal d = parse_decimal(first); + return compute_float(d); +} + +template +adjusted_mantissa parse_long_mantissa(const char *first, const char *end) { + decimal d = parse_decimal(first, end); + return compute_float(d); +} + +double from_chars(const char *first) noexcept { + bool negative = first[0] == '-'; + if (negative) { + first++; + } + adjusted_mantissa am = parse_long_mantissa>(first); + uint64_t word = am.mantissa; + word |= uint64_t(am.power2) + << binary_format::mantissa_explicit_bits(); + word = negative ? word | (uint64_t(1) << binary_format::sign_index()) + : word; + double value; + std::memcpy(&value, &word, sizeof(double)); + return value; +} + +double from_chars(const char *first, const char *end) noexcept { + bool negative = first[0] == '-'; + if (negative) { + first++; + } + adjusted_mantissa am = parse_long_mantissa>(first, end); + uint64_t word = am.mantissa; + word |= uint64_t(am.power2) + << binary_format::mantissa_explicit_bits(); + word = negative ? word | (uint64_t(1) << binary_format::sign_index()) + : word; + double value; + std::memcpy(&value, &word, sizeof(double)); + return value; +} + +} // namespace internal +} // namespace simdjson +} // namespace minijson + +namespace minijson { +namespace simdjson { +namespace internal { +/*! +implements the Grisu2 algorithm for binary to decimal floating-point +conversion. +Adapted from JSON for Modern C++ + +This implementation is a slightly modified version of the reference +implementation which may be obtained from +http://florian.loitsch.com/publications (bench.tar.gz). +The code is distributed under the MIT license, Copyright (c) 2009 Florian +Loitsch. For a detailed description of the algorithm see: [1] Loitsch, "Printing +Floating-Point Numbers Quickly and Accurately with Integers", Proceedings of the +ACM SIGPLAN 2010 Conference on Programming Language Design and Implementation, +PLDI 2010 [2] Burger, Dybvig, "Printing Floating-Point Numbers Quickly and +Accurately", Proceedings of the ACM SIGPLAN 1996 Conference on Programming +Language Design and Implementation, PLDI 1996 +*/ +namespace dtoa_impl { + +template +Target reinterpret_bits(const Source source) { + static_assert(sizeof(Target) == sizeof(Source), "size mismatch"); + + Target target; + std::memcpy(&target, &source, sizeof(Source)); + return target; +} + +struct diyfp // f * 2^e +{ + static constexpr int kPrecision = 64; // = q + + std::uint64_t f = 0; + int e = 0; + + constexpr diyfp(std::uint64_t f_, int e_) noexcept : f(f_), e(e_) {} + + /*! + @brief returns x - y + @pre x.e == y.e and x.f >= y.f + */ + static diyfp sub(const diyfp &x, const diyfp &y) noexcept { + return {x.f - y.f, x.e}; + } + + /*! + @brief returns x * y + @note The result is rounded. (Only the upper q bits are returned.) + */ + static diyfp mul(const diyfp &x, const diyfp &y) noexcept { + static_assert(kPrecision == 64, "internal error"); + + // Computes: + // f = round((x.f * y.f) / 2^q) + // e = x.e + y.e + q + + // Emulate the 64-bit * 64-bit multiplication: + // + // p = u * v + // = (u_lo + 2^32 u_hi) (v_lo + 2^32 v_hi) + // = (u_lo v_lo ) + 2^32 ((u_lo v_hi ) + (u_hi v_lo )) + + // 2^64 (u_hi v_hi ) = (p0 ) + 2^32 ((p1 ) + (p2 )) + // + 2^64 (p3 ) = (p0_lo + 2^32 p0_hi) + 2^32 ((p1_lo + + // 2^32 p1_hi) + (p2_lo + 2^32 p2_hi)) + 2^64 (p3 ) = + // (p0_lo ) + 2^32 (p0_hi + p1_lo + p2_lo ) + 2^64 (p1_hi + + // p2_hi + p3) = (p0_lo ) + 2^32 (Q ) + 2^64 (H ) = (p0_lo ) + + // 2^32 (Q_lo + 2^32 Q_hi ) + 2^64 (H ) + // + // (Since Q might be larger than 2^32 - 1) + // + // = (p0_lo + 2^32 Q_lo) + 2^64 (Q_hi + H) + // + // (Q_hi + H does not overflow a 64-bit int) + // + // = p_lo + 2^64 p_hi + + const std::uint64_t u_lo = x.f & 0xFFFFFFFFu; + const std::uint64_t u_hi = x.f >> 32u; + const std::uint64_t v_lo = y.f & 0xFFFFFFFFu; + const std::uint64_t v_hi = y.f >> 32u; + + const std::uint64_t p0 = u_lo * v_lo; + const std::uint64_t p1 = u_lo * v_hi; + const std::uint64_t p2 = u_hi * v_lo; + const std::uint64_t p3 = u_hi * v_hi; + + const std::uint64_t p0_hi = p0 >> 32u; + const std::uint64_t p1_lo = p1 & 0xFFFFFFFFu; + const std::uint64_t p1_hi = p1 >> 32u; + const std::uint64_t p2_lo = p2 & 0xFFFFFFFFu; + const std::uint64_t p2_hi = p2 >> 32u; + + std::uint64_t Q = p0_hi + p1_lo + p2_lo; + + // The full product might now be computed as + // + // p_hi = p3 + p2_hi + p1_hi + (Q >> 32) + // p_lo = p0_lo + (Q << 32) + // + // But in this particular case here, the full p_lo is not required. + // Effectively we only need to add the highest bit in p_lo to p_hi (and + // Q_hi + 1 does not overflow). + + Q += std::uint64_t{1} << (64u - 32u - 1u); // round, ties up + + const std::uint64_t h = p3 + p2_hi + p1_hi + (Q >> 32u); + + return {h, x.e + y.e + 64}; + } + + /*! + @brief normalize x such that the significand is >= 2^(q-1) + @pre x.f != 0 + */ + static diyfp normalize(diyfp x) noexcept { + while ((x.f >> 63u) == 0) { + x.f <<= 1u; + x.e--; + } + + return x; + } + + /*! + @brief normalize x such that the result has the exponent E + @pre e >= x.e and the upper e - x.e bits of x.f must be zero. + */ + static diyfp normalize_to(const diyfp &x, + const int target_exponent) noexcept { + const int delta = x.e - target_exponent; + + return {x.f << delta, target_exponent}; + } +}; + +struct boundaries { + diyfp w; + diyfp minus; + diyfp plus; +}; + +/*! +Compute the (normalized) diyfp representing the input number 'value' and its +boundaries. +@pre value must be finite and positive +*/ +template +boundaries compute_boundaries(FloatType value) { + // Convert the IEEE representation into a diyfp. + // + // If v is denormal: + // value = 0.F * 2^(1 - bias) = ( F) * 2^(1 - bias - (p-1)) + // If v is normalized: + // value = 1.F * 2^(E - bias) = (2^(p-1) + F) * 2^(E - bias - (p-1)) + + static_assert(std::numeric_limits::is_iec559, + "internal error: dtoa_short requires an IEEE-754 " + "floating-point implementation"); + + constexpr int kPrecision = + std::numeric_limits::digits; // = p (includes the hidden bit) + constexpr int kBias = + std::numeric_limits::max_exponent - 1 + (kPrecision - 1); + constexpr int kMinExp = 1 - kBias; + constexpr std::uint64_t kHiddenBit = std::uint64_t{1} + << (kPrecision - 1); // = 2^(p-1) + + using bits_type = typename std::conditional::type; + + const std::uint64_t bits = reinterpret_bits(value); + const std::uint64_t E = bits >> (kPrecision - 1); + const std::uint64_t F = bits & (kHiddenBit - 1); + + const bool is_denormal = E == 0; + const diyfp v = is_denormal + ? diyfp(F, kMinExp) + : diyfp(F + kHiddenBit, static_cast(E) - kBias); + + // Compute the boundaries m- and m+ of the floating-point value + // v = f * 2^e. + // + // Determine v- and v+, the floating-point predecessor and successor if v, + // respectively. + // + // v- = v - 2^e if f != 2^(p-1) or e == e_min (A) + // = v - 2^(e-1) if f == 2^(p-1) and e > e_min (B) + // + // v+ = v + 2^e + // + // Let m- = (v- + v) / 2 and m+ = (v + v+) / 2. All real numbers _strictly_ + // between m- and m+ round to v, regardless of how the input rounding + // algorithm breaks ties. + // + // ---+-------------+-------------+-------------+-------------+--- (A) + // v- m- v m+ v+ + // + // -----------------+------+------+-------------+-------------+--- (B) + // v- m- v m+ v+ + + const bool lower_boundary_is_closer = F == 0 && E > 1; + const diyfp m_plus = diyfp(2 * v.f + 1, v.e - 1); + const diyfp m_minus = lower_boundary_is_closer + ? diyfp(4 * v.f - 1, v.e - 2) // (B) + : diyfp(2 * v.f - 1, v.e - 1); // (A) + + // Determine the normalized w+ = m+. + const diyfp w_plus = diyfp::normalize(m_plus); + + // Determine w- = m- such that e_(w-) = e_(w+). + const diyfp w_minus = diyfp::normalize_to(m_minus, w_plus.e); + + return {diyfp::normalize(v), w_minus, w_plus}; +} + +// Given normalized diyfp w, Grisu needs to find a (normalized) cached +// power-of-ten c, such that the exponent of the product c * w = f * 2^e lies +// within a certain range [alpha, gamma] (Definition 3.2 from [1]) +// +// alpha <= e = e_c + e_w + q <= gamma +// +// or +// +// f_c * f_w * 2^alpha <= f_c 2^(e_c) * f_w 2^(e_w) * 2^q +// <= f_c * f_w * 2^gamma +// +// Since c and w are normalized, i.e. 2^(q-1) <= f < 2^q, this implies +// +// 2^(q-1) * 2^(q-1) * 2^alpha <= c * w * 2^q < 2^q * 2^q * 2^gamma +// +// or +// +// 2^(q - 2 + alpha) <= c * w < 2^(q + gamma) +// +// The choice of (alpha,gamma) determines the size of the table and the form of +// the digit generation procedure. Using (alpha,gamma)=(-60,-32) works out well +// in practice: +// +// The idea is to cut the number c * w = f * 2^e into two parts, which can be +// processed independently: An integral part p1, and a fractional part p2: +// +// f * 2^e = ( (f div 2^-e) * 2^-e + (f mod 2^-e) ) * 2^e +// = (f div 2^-e) + (f mod 2^-e) * 2^e +// = p1 + p2 * 2^e +// +// The conversion of p1 into decimal form requires a series of divisions and +// modulos by (a power of) 10. These operations are faster for 32-bit than for +// 64-bit integers, so p1 should ideally fit into a 32-bit integer. This can be +// achieved by choosing +// +// -e >= 32 or e <= -32 := gamma +// +// In order to convert the fractional part +// +// p2 * 2^e = p2 / 2^-e = d[-1] / 10^1 + d[-2] / 10^2 + ... +// +// into decimal form, the fraction is repeatedly multiplied by 10 and the digits +// d[-i] are extracted in order: +// +// (10 * p2) div 2^-e = d[-1] +// (10 * p2) mod 2^-e = d[-2] / 10^1 + ... +// +// The multiplication by 10 must not overflow. It is sufficient to choose +// +// 10 * p2 < 16 * p2 = 2^4 * p2 <= 2^64. +// +// Since p2 = f mod 2^-e < 2^-e, +// +// -e <= 60 or e >= -60 := alpha + +constexpr int kAlpha = -60; +constexpr int kGamma = -32; + +struct cached_power // c = f * 2^e ~= 10^k +{ + std::uint64_t f; + int e; + int k; +}; + +/*! +For a normalized diyfp w = f * 2^e, this function returns a (normalized) cached +power-of-ten c = f_c * 2^e_c, such that the exponent of the product w * c +satisfies (Definition 3.2 from [1]) + alpha <= e_c + e + q <= gamma. +*/ +inline cached_power get_cached_power_for_binary_exponent(int e) { + // Now + // + // alpha <= e_c + e + q <= gamma (1) + // ==> f_c * 2^alpha <= c * 2^e * 2^q + // + // and since the c's are normalized, 2^(q-1) <= f_c, + // + // ==> 2^(q - 1 + alpha) <= c * 2^(e + q) + // ==> 2^(alpha - e - 1) <= c + // + // If c were an exact power of ten, i.e. c = 10^k, one may determine k as + // + // k = ceil( log_10( 2^(alpha - e - 1) ) ) + // = ceil( (alpha - e - 1) * log_10(2) ) + // + // From the paper: + // "In theory the result of the procedure could be wrong since c is rounded, + // and the computation itself is approximated [...]. In practice, however, + // this simple function is sufficient." + // + // For IEEE double precision floating-point numbers converted into + // normalized diyfp's w = f * 2^e, with q = 64, + // + // e >= -1022 (min IEEE exponent) + // -52 (p - 1) + // -52 (p - 1, possibly normalize denormal IEEE numbers) + // -11 (normalize the diyfp) + // = -1137 + // + // and + // + // e <= +1023 (max IEEE exponent) + // -52 (p - 1) + // -11 (normalize the diyfp) + // = 960 + // + // This binary exponent range [-1137,960] results in a decimal exponent + // range [-307,324]. One does not need to store a cached power for each + // k in this range. For each such k it suffices to find a cached power + // such that the exponent of the product lies in [alpha,gamma]. + // This implies that the difference of the decimal exponents of adjacent + // table entries must be less than or equal to + // + // floor( (gamma - alpha) * log_10(2) ) = 8. + // + // (A smaller distance gamma-alpha would require a larger table.) + + // NB: + // Actually this function returns c, such that -60 <= e_c + e + 64 <= -34. + + constexpr int kCachedPowersMinDecExp = -300; + constexpr int kCachedPowersDecStep = 8; + + static constexpr std::array kCachedPowers = {{ + {0xAB70FE17C79AC6CA, -1060, -300}, {0xFF77B1FCBEBCDC4F, -1034, -292}, + {0xBE5691EF416BD60C, -1007, -284}, {0x8DD01FAD907FFC3C, -980, -276}, + {0xD3515C2831559A83, -954, -268}, {0x9D71AC8FADA6C9B5, -927, -260}, + {0xEA9C227723EE8BCB, -901, -252}, {0xAECC49914078536D, -874, -244}, + {0x823C12795DB6CE57, -847, -236}, {0xC21094364DFB5637, -821, -228}, + {0x9096EA6F3848984F, -794, -220}, {0xD77485CB25823AC7, -768, -212}, + {0xA086CFCD97BF97F4, -741, -204}, {0xEF340A98172AACE5, -715, -196}, + {0xB23867FB2A35B28E, -688, -188}, {0x84C8D4DFD2C63F3B, -661, -180}, + {0xC5DD44271AD3CDBA, -635, -172}, {0x936B9FCEBB25C996, -608, -164}, + {0xDBAC6C247D62A584, -582, -156}, {0xA3AB66580D5FDAF6, -555, -148}, + {0xF3E2F893DEC3F126, -529, -140}, {0xB5B5ADA8AAFF80B8, -502, -132}, + {0x87625F056C7C4A8B, -475, -124}, {0xC9BCFF6034C13053, -449, -116}, + {0x964E858C91BA2655, -422, -108}, {0xDFF9772470297EBD, -396, -100}, + {0xA6DFBD9FB8E5B88F, -369, -92}, {0xF8A95FCF88747D94, -343, -84}, + {0xB94470938FA89BCF, -316, -76}, {0x8A08F0F8BF0F156B, -289, -68}, + {0xCDB02555653131B6, -263, -60}, {0x993FE2C6D07B7FAC, -236, -52}, + {0xE45C10C42A2B3B06, -210, -44}, {0xAA242499697392D3, -183, -36}, + {0xFD87B5F28300CA0E, -157, -28}, {0xBCE5086492111AEB, -130, -20}, + {0x8CBCCC096F5088CC, -103, -12}, {0xD1B71758E219652C, -77, -4}, + {0x9C40000000000000, -50, 4}, {0xE8D4A51000000000, -24, 12}, + {0xAD78EBC5AC620000, 3, 20}, {0x813F3978F8940984, 30, 28}, + {0xC097CE7BC90715B3, 56, 36}, {0x8F7E32CE7BEA5C70, 83, 44}, + {0xD5D238A4ABE98068, 109, 52}, {0x9F4F2726179A2245, 136, 60}, + {0xED63A231D4C4FB27, 162, 68}, {0xB0DE65388CC8ADA8, 189, 76}, + {0x83C7088E1AAB65DB, 216, 84}, {0xC45D1DF942711D9A, 242, 92}, + {0x924D692CA61BE758, 269, 100}, {0xDA01EE641A708DEA, 295, 108}, + {0xA26DA3999AEF774A, 322, 116}, {0xF209787BB47D6B85, 348, 124}, + {0xB454E4A179DD1877, 375, 132}, {0x865B86925B9BC5C2, 402, 140}, + {0xC83553C5C8965D3D, 428, 148}, {0x952AB45CFA97A0B3, 455, 156}, + {0xDE469FBD99A05FE3, 481, 164}, {0xA59BC234DB398C25, 508, 172}, + {0xF6C69A72A3989F5C, 534, 180}, {0xB7DCBF5354E9BECE, 561, 188}, + {0x88FCF317F22241E2, 588, 196}, {0xCC20CE9BD35C78A5, 614, 204}, + {0x98165AF37B2153DF, 641, 212}, {0xE2A0B5DC971F303A, 667, 220}, + {0xA8D9D1535CE3B396, 694, 228}, {0xFB9B7CD9A4A7443C, 720, 236}, + {0xBB764C4CA7A44410, 747, 244}, {0x8BAB8EEFB6409C1A, 774, 252}, + {0xD01FEF10A657842C, 800, 260}, {0x9B10A4E5E9913129, 827, 268}, + {0xE7109BFBA19C0C9D, 853, 276}, {0xAC2820D9623BF429, 880, 284}, + {0x80444B5E7AA7CF85, 907, 292}, {0xBF21E44003ACDD2D, 933, 300}, + {0x8E679C2F5E44FF8F, 960, 308}, {0xD433179D9C8CB841, 986, 316}, + {0x9E19DB92B4E31BA9, 1013, 324}, + }}; + + // This computation gives exactly the same results for k as + // k = ceil((kAlpha - e - 1) * 0.30102999566398114) + // for |e| <= 1500, but doesn't require floating-point operations. + // NB: log_10(2) ~= 78913 / 2^18 + const int f = kAlpha - e - 1; + const int k = (f * 78913) / (1 << 18) + static_cast(f > 0); + + const int index = (-kCachedPowersMinDecExp + k + (kCachedPowersDecStep - 1)) / + kCachedPowersDecStep; + + const cached_power cached = kCachedPowers[static_cast(index)]; + + return cached; +} + +/*! +For n != 0, returns k, such that pow10 := 10^(k-1) <= n < 10^k. +For n == 0, returns 1 and sets pow10 := 1. +*/ +inline int find_largest_pow10(const std::uint32_t n, std::uint32_t &pow10) { + // LCOV_EXCL_START + if (n >= 1000000000) { + pow10 = 1000000000; + return 10; + } + // LCOV_EXCL_STOP + else if (n >= 100000000) { + pow10 = 100000000; + return 9; + } else if (n >= 10000000) { + pow10 = 10000000; + return 8; + } else if (n >= 1000000) { + pow10 = 1000000; + return 7; + } else if (n >= 100000) { + pow10 = 100000; + return 6; + } else if (n >= 10000) { + pow10 = 10000; + return 5; + } else if (n >= 1000) { + pow10 = 1000; + return 4; + } else if (n >= 100) { + pow10 = 100; + return 3; + } else if (n >= 10) { + pow10 = 10; + return 2; + } else { + pow10 = 1; + return 1; + } +} + +inline void grisu2_round(char *buf, int len, std::uint64_t dist, + std::uint64_t delta, std::uint64_t rest, + std::uint64_t ten_k) { + // <--------------------------- delta ----> + // <---- dist ---------> + // --------------[------------------+-------------------]-------------- + // M- w M+ + // + // ten_k + // <------> + // <---- rest ----> + // --------------[------------------+----+--------------]-------------- + // w V + // = buf * 10^k + // + // ten_k represents a unit-in-the-last-place in the decimal representation + // stored in buf. + // Decrement buf by ten_k while this takes buf closer to w. + + // The tests are written in this order to avoid overflow in unsigned + // integer arithmetic. + + while (rest < dist && delta - rest >= ten_k && + (rest + ten_k < dist || dist - rest > rest + ten_k - dist)) { + buf[len - 1]--; + rest += ten_k; + } +} + +/*! +Generates V = buffer * 10^decimal_exponent, such that M- <= V <= M+. +M- and M+ must be normalized and share the same exponent -60 <= e <= -32. +*/ +inline void grisu2_digit_gen(char *buffer, int &length, int &decimal_exponent, + diyfp M_minus, diyfp w, diyfp M_plus) { + static_assert(kAlpha >= -60, "internal error"); + static_assert(kGamma <= -32, "internal error"); + + // Generates the digits (and the exponent) of a decimal floating-point + // number V = buffer * 10^decimal_exponent in the range [M-, M+]. The diyfp's + // w, M- and M+ share the same exponent e, which satisfies alpha <= e <= + // gamma. + // + // <--------------------------- delta ----> + // <---- dist ---------> + // --------------[------------------+-------------------]-------------- + // M- w M+ + // + // Grisu2 generates the digits of M+ from left to right and stops as soon as + // V is in [M-,M+]. + + std::uint64_t delta = + diyfp::sub(M_plus, M_minus) + .f; // (significand of (M+ - M-), implicit exponent is e) + std::uint64_t dist = + diyfp::sub(M_plus, w) + .f; // (significand of (M+ - w ), implicit exponent is e) + + // Split M+ = f * 2^e into two parts p1 and p2 (note: e < 0): + // + // M+ = f * 2^e + // = ((f div 2^-e) * 2^-e + (f mod 2^-e)) * 2^e + // = ((p1 ) * 2^-e + (p2 )) * 2^e + // = p1 + p2 * 2^e + + const diyfp one(std::uint64_t{1} << -M_plus.e, M_plus.e); + + auto p1 = static_cast( + M_plus.f >> + -one.e); // p1 = f div 2^-e (Since -e >= 32, p1 fits into a 32-bit int.) + std::uint64_t p2 = M_plus.f & (one.f - 1); // p2 = f mod 2^-e + + // 1) + // + // Generate the digits of the integral part p1 = d[n-1]...d[1]d[0] + + std::uint32_t pow10; + const int k = find_largest_pow10(p1, pow10); + + // 10^(k-1) <= p1 < 10^k, pow10 = 10^(k-1) + // + // p1 = (p1 div 10^(k-1)) * 10^(k-1) + (p1 mod 10^(k-1)) + // = (d[k-1] ) * 10^(k-1) + (p1 mod 10^(k-1)) + // + // M+ = p1 + p2 * 2^e + // = d[k-1] * 10^(k-1) + (p1 mod 10^(k-1)) + p2 * 2^e + // = d[k-1] * 10^(k-1) + ((p1 mod 10^(k-1)) * 2^-e + p2) * 2^e + // = d[k-1] * 10^(k-1) + ( rest) * 2^e + // + // Now generate the digits d[n] of p1 from left to right (n = k-1,...,0) + // + // p1 = d[k-1]...d[n] * 10^n + d[n-1]...d[0] + // + // but stop as soon as + // + // rest * 2^e = (d[n-1]...d[0] * 2^-e + p2) * 2^e <= delta * 2^e + + int n = k; + while (n > 0) { + // Invariants: + // M+ = buffer * 10^n + (p1 + p2 * 2^e) (buffer = 0 for n = k) + // pow10 = 10^(n-1) <= p1 < 10^n + // + const std::uint32_t d = p1 / pow10; // d = p1 div 10^(n-1) + const std::uint32_t r = p1 % pow10; // r = p1 mod 10^(n-1) + // + // M+ = buffer * 10^n + (d * 10^(n-1) + r) + p2 * 2^e + // = (buffer * 10 + d) * 10^(n-1) + (r + p2 * 2^e) + // + buffer[length++] = static_cast('0' + d); // buffer := buffer * 10 + d + // + // M+ = buffer * 10^(n-1) + (r + p2 * 2^e) + // + p1 = r; + n--; + // + // M+ = buffer * 10^n + (p1 + p2 * 2^e) + // pow10 = 10^n + // + + // Now check if enough digits have been generated. + // Compute + // + // p1 + p2 * 2^e = (p1 * 2^-e + p2) * 2^e = rest * 2^e + // + // Note: + // Since rest and delta share the same exponent e, it suffices to + // compare the significands. + const std::uint64_t rest = (std::uint64_t{p1} << -one.e) + p2; + if (rest <= delta) { + // V = buffer * 10^n, with M- <= V <= M+. + + decimal_exponent += n; + + // We may now just stop. But instead look if the buffer could be + // decremented to bring V closer to w. + // + // pow10 = 10^n is now 1 ulp in the decimal representation V. + // The rounding procedure works with diyfp's with an implicit + // exponent of e. + // + // 10^n = (10^n * 2^-e) * 2^e = ulp * 2^e + // + const std::uint64_t ten_n = std::uint64_t{pow10} << -one.e; + grisu2_round(buffer, length, dist, delta, rest, ten_n); + + return; + } + + pow10 /= 10; + // + // pow10 = 10^(n-1) <= p1 < 10^n + // Invariants restored. + } + + // 2) + // + // The digits of the integral part have been generated: + // + // M+ = d[k-1]...d[1]d[0] + p2 * 2^e + // = buffer + p2 * 2^e + // + // Now generate the digits of the fractional part p2 * 2^e. + // + // Note: + // No decimal point is generated: the exponent is adjusted instead. + // + // p2 actually represents the fraction + // + // p2 * 2^e + // = p2 / 2^-e + // = d[-1] / 10^1 + d[-2] / 10^2 + ... + // + // Now generate the digits d[-m] of p1 from left to right (m = 1,2,...) + // + // p2 * 2^e = d[-1]d[-2]...d[-m] * 10^-m + // + 10^-m * (d[-m-1] / 10^1 + d[-m-2] / 10^2 + ...) + // + // using + // + // 10^m * p2 = ((10^m * p2) div 2^-e) * 2^-e + ((10^m * p2) mod 2^-e) + // = ( d) * 2^-e + ( r) + // + // or + // 10^m * p2 * 2^e = d + r * 2^e + // + // i.e. + // + // M+ = buffer + p2 * 2^e + // = buffer + 10^-m * (d + r * 2^e) + // = (buffer * 10^m + d) * 10^-m + 10^-m * r * 2^e + // + // and stop as soon as 10^-m * r * 2^e <= delta * 2^e + + int m = 0; + for (;;) { + // Invariant: + // M+ = buffer * 10^-m + 10^-m * (d[-m-1] / 10 + d[-m-2] / 10^2 + ...) + // * 2^e + // = buffer * 10^-m + 10^-m * (p2 ) + // * 2^e = buffer * 10^-m + 10^-m * (1/10 * (10 * p2) ) * 2^e = + // buffer * 10^-m + 10^-m * (1/10 * ((10*p2 div 2^-e) * 2^-e + + // (10*p2 mod 2^-e)) * 2^e + // + p2 *= 10; + const std::uint64_t d = p2 >> -one.e; // d = (10 * p2) div 2^-e + const std::uint64_t r = p2 & (one.f - 1); // r = (10 * p2) mod 2^-e + // + // M+ = buffer * 10^-m + 10^-m * (1/10 * (d * 2^-e + r) * 2^e + // = buffer * 10^-m + 10^-m * (1/10 * (d + r * 2^e)) + // = (buffer * 10 + d) * 10^(-m-1) + 10^(-m-1) * r * 2^e + // + buffer[length++] = static_cast('0' + d); // buffer := buffer * 10 + d + // + // M+ = buffer * 10^(-m-1) + 10^(-m-1) * r * 2^e + // + p2 = r; + m++; + // + // M+ = buffer * 10^-m + 10^-m * p2 * 2^e + // Invariant restored. + + // Check if enough digits have been generated. + // + // 10^-m * p2 * 2^e <= delta * 2^e + // p2 * 2^e <= 10^m * delta * 2^e + // p2 <= 10^m * delta + delta *= 10; + dist *= 10; + if (p2 <= delta) { + break; + } + } + + // V = buffer * 10^-m, with M- <= V <= M+. + + decimal_exponent -= m; + + // 1 ulp in the decimal representation is now 10^-m. + // Since delta and dist are now scaled by 10^m, we need to do the + // same with ulp in order to keep the units in sync. + // + // 10^m * 10^-m = 1 = 2^-e * 2^e = ten_m * 2^e + // + const std::uint64_t ten_m = one.f; + grisu2_round(buffer, length, dist, delta, p2, ten_m); + + // By construction this algorithm generates the shortest possible decimal + // number (Loitsch, Theorem 6.2) which rounds back to w. + // For an input number of precision p, at least + // + // N = 1 + ceil(p * log_10(2)) + // + // decimal digits are sufficient to identify all binary floating-point + // numbers (Matula, "In-and-Out conversions"). + // This implies that the algorithm does not produce more than N decimal + // digits. + // + // N = 17 for p = 53 (IEEE double precision) + // N = 9 for p = 24 (IEEE single precision) +} + +/*! +v = buf * 10^decimal_exponent +len is the length of the buffer (number of decimal digits) +The buffer must be large enough, i.e. >= max_digits10. +*/ +inline void grisu2(char *buf, int &len, int &decimal_exponent, diyfp m_minus, + diyfp v, diyfp m_plus) { + // --------(-----------------------+-----------------------)-------- (A) + // m- v m+ + // + // --------------------(-----------+-----------------------)-------- (B) + // m- v m+ + // + // First scale v (and m- and m+) such that the exponent is in the range + // [alpha, gamma]. + + const cached_power cached = get_cached_power_for_binary_exponent(m_plus.e); + + const diyfp c_minus_k(cached.f, cached.e); // = c ~= 10^-k + + // The exponent of the products is = v.e + c_minus_k.e + q and is in the range + // [alpha,gamma] + const diyfp w = diyfp::mul(v, c_minus_k); + const diyfp w_minus = diyfp::mul(m_minus, c_minus_k); + const diyfp w_plus = diyfp::mul(m_plus, c_minus_k); + + // ----(---+---)---------------(---+---)---------------(---+---)---- + // w- w w+ + // = c*m- = c*v = c*m+ + // + // diyfp::mul rounds its result and c_minus_k is approximated too. w, w- and + // w+ are now off by a small amount. + // In fact: + // + // w - v * 10^k < 1 ulp + // + // To account for this inaccuracy, add resp. subtract 1 ulp. + // + // --------+---[---------------(---+---)---------------]---+-------- + // w- M- w M+ w+ + // + // Now any number in [M-, M+] (bounds included) will round to w when input, + // regardless of how the input rounding algorithm breaks ties. + // + // And digit_gen generates the shortest possible such number in [M-, M+]. + // Note that this does not mean that Grisu2 always generates the shortest + // possible number in the interval (m-, m+). + const diyfp M_minus(w_minus.f + 1, w_minus.e); + const diyfp M_plus(w_plus.f - 1, w_plus.e); + + decimal_exponent = -cached.k; // = -(-k) = k + + grisu2_digit_gen(buf, len, decimal_exponent, M_minus, w, M_plus); +} + +/*! +v = buf * 10^decimal_exponent +len is the length of the buffer (number of decimal digits) +The buffer must be large enough, i.e. >= max_digits10. +*/ +template +void grisu2(char *buf, int &len, int &decimal_exponent, FloatType value) { + static_assert(diyfp::kPrecision >= std::numeric_limits::digits + 3, + "internal error: not enough precision"); + + // If the neighbors (and boundaries) of 'value' are always computed for + // double-precision numbers, all float's can be recovered using strtod (and + // strtof). However, the resulting decimal representations are not exactly + // "short". + // + // The documentation for 'std::to_chars' + // (https://en.cppreference.com/w/cpp/utility/to_chars) says "value is + // converted to a string as if by std::sprintf in the default ("C") locale" + // and since sprintf promotes float's to double's, I think this is exactly + // what 'std::to_chars' does. On the other hand, the documentation for + // 'std::to_chars' requires that "parsing the representation using the + // corresponding std::from_chars function recovers value exactly". That + // indicates that single precision floating-point numbers should be recovered + // using 'std::strtof'. + // + // NB: If the neighbors are computed for single-precision numbers, there is a + // single float + // (7.0385307e-26f) which can't be recovered using strtod. The resulting + // double precision value is off by 1 ulp. +#if 0 + const boundaries w = compute_boundaries(static_cast(value)); +#else + const boundaries w = compute_boundaries(value); +#endif + + grisu2(buf, len, decimal_exponent, w.minus, w.w, w.plus); +} + +/*! +@brief appends a decimal representation of e to buf +@return a pointer to the element following the exponent. +@pre -1000 < e < 1000 +*/ +inline char *append_exponent(char *buf, int e) { + if (e < 0) { + e = -e; + *buf++ = '-'; + } else { + *buf++ = '+'; + } + + auto k = static_cast(e); + if (k < 10) { + // Always print at least two digits in the exponent. + // This is for compatibility with printf("%g"). + *buf++ = '0'; + *buf++ = static_cast('0' + k); + } else if (k < 100) { + *buf++ = static_cast('0' + k / 10); + k %= 10; + *buf++ = static_cast('0' + k); + } else { + *buf++ = static_cast('0' + k / 100); + k %= 100; + *buf++ = static_cast('0' + k / 10); + k %= 10; + *buf++ = static_cast('0' + k); + } + + return buf; +} + +/*! +@brief prettify v = buf * 10^decimal_exponent +If v is in the range [10^min_exp, 10^max_exp) it will be printed in fixed-point +notation. Otherwise it will be printed in exponential notation. +@pre min_exp < 0 +@pre max_exp > 0 +*/ +inline char *format_buffer(char *buf, int len, int decimal_exponent, + int min_exp, int max_exp) { + const int k = len; + const int n = len + decimal_exponent; + + // v = buf * 10^(n-k) + // k is the length of the buffer (number of decimal digits) + // n is the position of the decimal point relative to the start of the buffer. + + if (k <= n && n <= max_exp) { + // digits[000] + // len <= max_exp + 2 + + std::memset(buf + k, '0', static_cast(n) - static_cast(k)); + // Make it look like a floating-point number (#362, #378) + buf[n + 0] = '.'; + buf[n + 1] = '0'; + return buf + (static_cast(n)) + 2; + } + + if (0 < n && n <= max_exp) { + // dig.its + // len <= max_digits10 + 1 + std::memmove(buf + (static_cast(n) + 1), buf + n, + static_cast(k) - static_cast(n)); + buf[n] = '.'; + return buf + (static_cast(k) + 1U); + } + + if (min_exp < n && n <= 0) { + // 0.[000]digits + // len <= 2 + (-min_exp - 1) + max_digits10 + + std::memmove(buf + (2 + static_cast(-n)), buf, + static_cast(k)); + buf[0] = '0'; + buf[1] = '.'; + std::memset(buf + 2, '0', static_cast(-n)); + return buf + (2U + static_cast(-n) + static_cast(k)); + } + + if (k == 1) { + // dE+123 + // len <= 1 + 5 + + buf += 1; + } else { + // d.igitsE+123 + // len <= max_digits10 + 1 + 5 + + std::memmove(buf + 2, buf + 1, static_cast(k) - 1); + buf[1] = '.'; + buf += 1 + static_cast(k); + } + + *buf++ = 'e'; + return append_exponent(buf, n - 1); +} + +} // namespace dtoa_impl + +/*! +The format of the resulting decimal representation is similar to printf's %g +format. Returns an iterator pointing past-the-end of the decimal representation. +@note The input number must be finite, i.e. NaN's and Inf's are not supported. +@note The buffer must be large enough. +@note The result is NOT null-terminated. +*/ +char *to_chars(char *first, const char *last, double value) { + static_cast(last); // maybe unused - fix warning + + // bool negative = std::signbit(value); + bool negative = (*reinterpret_cast(&value)) & (1 << 31ull); + if (negative) { + value = -value; + *first++ = '-'; + } + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wfloat-equal" +#endif + + if (value == 0) // +-0 + { + *first++ = '0'; + // Make it look like a floating-point number (#362, #378) + *first++ = '.'; + *first++ = '0'; + return first; + } + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + + // Compute v = buffer * 10^decimal_exponent. + // The decimal digits are stored in the buffer, which needs to be interpreted + // as an unsigned decimal integer. + // len is the length of the buffer, i.e. the number of decimal digits. + int len = 0; + int decimal_exponent = 0; + dtoa_impl::grisu2(first, len, decimal_exponent, value); + // Format the buffer like printf("%.*g", prec, value) + constexpr int kMinExp = -4; + constexpr int kMaxExp = std::numeric_limits::digits10; + + return dtoa_impl::format_buffer(first, len, decimal_exponent, kMinExp, + kMaxExp); +} +} // namespace internal +} // namespace simdjson +} // namespace minijson + +#endif // !MINIJSON_USE_STRTOD + +#endif // MINIJSON_IMPLEMENTATION + + +namespace safetensors { + +// Max header(JSON) size. 100 MB as done in original safetensors implementation. +constexpr size_t kMaxJSONSize = 1024ull * 1024ull * 100ull; + +namespace detail { + +#ifdef _WIN32 +std::wstring UTF8ToWchar(const std::string &str) { + int wstr_size = + MultiByteToWideChar(CP_UTF8, 0, str.data(), int(str.size()), nullptr, 0); + std::wstring wstr(size_t(wstr_size), 0); + MultiByteToWideChar(CP_UTF8, 0, str.data(), int(str.size()), &wstr[0], + int(wstr.size())); + return wstr; +} + +std::string WcharToUTF8(const std::wstring &wstr) { + int str_size = WideCharToMultiByte(CP_UTF8, 0, wstr.data(), int(wstr.size()), + nullptr, 0, nullptr, nullptr); + std::string str(size_t(str_size), 0); + WideCharToMultiByte(CP_UTF8, 0, wstr.data(), int(wstr.size()), &str[0], + int(str.size()), nullptr, nullptr); + return str; +} +#endif + +bool ReadWholeFile(std::vector *out, std::string *err, + const std::string &filepath, void *) { +#ifdef SAFETENSORS_CPP_ANDROID_LOAD_FROM_ASSETS + if (asset_manager) { + AAsset *asset = AAssetManager_open(asset_manager, filepath.c_str(), + AASSET_MODE_STREAMING); + if (!asset) { + if (err) { + (*err) += "File open error : " + filepath + "\n"; + } + return false; + } + size_t size = AAsset_getLength(asset); + if (size == 0) { + if (err) { + (*err) += "Invalid file size : " + filepath + + " (does the path point to a directory?)"; + } + return false; + } + out->resize(size); + AAsset_read(asset, reinterpret_cast(&out->at(0)), size); + AAsset_close(asset); + return true; + } else { + if (err) { + (*err) += "No asset manager specified : " + filepath + "\n"; + } + return false; + } +#else +#ifdef _WIN32 +#if defined(__GLIBCXX__) // mingw + int file_descriptor = + _wopen(UTF8ToWchar(filepath).c_str(), _O_RDONLY | _O_BINARY); + __gnu_cxx::stdio_filebuf wfile_buf(file_descriptor, std::ios_base::in); + std::istream f(&wfile_buf); +#elif defined(_MSC_VER) || defined(_LIBCPP_VERSION) + // For libcxx, assume _LIBCPP_HAS_OPEN_WITH_WCHAR is defined to accept + // `wchar_t *` + std::ifstream f(UTF8ToWchar(filepath).c_str(), std::ifstream::binary); +#else + // Unknown compiler/runtime + std::ifstream f(filepath.c_str(), std::ifstream::binary); +#endif +#else + std::ifstream f(filepath.c_str(), std::ifstream::binary); +#endif + if (!f) { + if (err) { + (*err) += "File open error : " + filepath + "\n"; + } + return false; + } + + // For directory(and pipe?), peek() will fail(Posix gnustl/libc++ only) + f.peek(); + if (!f) { + if (err) { + (*err) += + "File read error. Maybe empty file or invalid file : " + filepath + + "\n"; + } + return false; + } + + f.seekg(0, f.end); + size_t sz = static_cast(f.tellg()); + + // std::cout << "sz = " << sz << "\n"; + f.seekg(0, f.beg); + + if (int64_t(sz) < 0) { + if (err) { + (*err) += "Invalid file size : " + filepath + + " (does the path point to a directory?)"; + } + return false; + } else if (sz == 0) { + if (err) { + (*err) += "File is empty : " + filepath + "\n"; + } + return false; + } else if (sz >= (std::numeric_limits::max)()) { + if (err) { + (*err) += "Invalid file size : " + filepath + "\n"; + } + return false; + } + + out->resize(sz); + f.read(reinterpret_cast(&out->at(0)), + static_cast(sz)); + + return true; +#endif +} + +bool parse_metadata(const ::minijson::value &v, + ordered_dict &dst, std::string *err) { + if (auto po = v.as<::minijson::object>()) { + for (size_t i = 0; i < po->size(); i++) { + ::minijson::value ov; + if (!po->at(i, &ov)) { + if (err) { + (*err) += + "[Internal error] Invalid object found in __metadata__, at index " + std::to_string(i) + ".\n"; + } + return false; + } + + if (auto so = ov.as()) { + if (dst.count(po->keys()[i])) { + // This should not be happen though + if (err) { + (*err) += + "Duplicate key `" + po->keys()[i] + "` found in __metadata__.\n"; + } + return false; + } + + dst.insert(po->keys()[i], *so); + } else { + if (err) { + (*err) += "`" + po->keys()[i] + "` must be string value.\n"; + } + return false; + } + } + } else { + if (err) { + (*err) += "`__metadata__` value must be JSON object.\n"; + } + return false; + } + + return true; +} + +bool parse_dtype(const ::minijson::value &v, safetensors::dtype &dtype, + std::string *err) { + if (auto so = v.as()) { + if ((*so) == "BOOL") { + dtype = safetensors::dtype::kBOOL; + } else if ((*so) == "U8") { + dtype = safetensors::dtype::kUINT8; + } else if ((*so) == "I8") { + dtype = safetensors::dtype::kINT8; + } else if ((*so) == "U16") { + dtype = safetensors::dtype::kUINT16; + } else if ((*so) == "I16") { + dtype = safetensors::dtype::kINT16; + } else if ((*so) == "U32") { + dtype = safetensors::dtype::kUINT32; + } else if ((*so) == "I32") { + dtype = safetensors::dtype::kINT32; + } else if ((*so) == "U64") { + dtype = safetensors::dtype::kUINT64; + } else if ((*so) == "I64") { + dtype = safetensors::dtype::kINT64; + } else if ((*so) == "F16") { + dtype = safetensors::dtype::kFLOAT16; + } else if ((*so) == "BF16") { + dtype = safetensors::dtype::kBFLOAT16; + } else if ((*so) == "F32") { + dtype = safetensors::dtype::kFLOAT32; + } else if ((*so) == "F64") { + dtype = safetensors::dtype::kFLOAT64; + } else { + if (err) { + (*err) += "Unknown `dtype` string: " + *so + ".\n"; + } + return false; + } + } else { + if (err) { + (*err) += + "`dtype` item should be string type but got " + v.type_name() + ".\n"; + } + return false; + } + + return true; +} + +bool parse_shape(const ::minijson::value &v, std::vector &dst, + std::string *err) { + // NOTE: + // - Empty tensors (tensors with 1 dimension being 0) are allowed + // - [] is allowed(0-Rank tensor = merely a scalar) + if (auto pa = v.as<::minijson::array>()) { + ::minijson::array::const_iterator i; + + for (i = pa->begin(); i != pa->end(); i++) { + if (auto pn = i->as<::minijson::number>()) { + if (dst.size() >= kMaxDim) { + if (err) { + (*err) += "`shape` length must be less than " + + std::to_string(kMaxDim) + " but got " + + std::to_string(dst.size()) + ".\n"; + } + return false; + } + + dst.push_back(size_t(*pn)); + + } else { + if (err) { + (*err) += "Array item in `shape` must be number type, but got " + + i->type_name() + ".\n"; + } + return false; + } + } + } else { + if (err) { + (*err) += + "`shape` value must be JSON array, but got " + v.type_name() + ".\n"; + } + return false; + } + + return true; +} + +bool parse_data_offsets(const ::minijson::value &v, std::array &dst, + std::string *err) { + if (auto pa = v.as<::minijson::array>()) { + ::minijson::array::const_iterator i; + size_t cnt = 0; + + for (i = pa->begin(); i != pa->end(); i++) { + if (auto pn = i->as<::minijson::number>()) { + if (cnt >= 2) { + if (err) { + (*err) += "`data_offsets` length must be 2.\n"; + } + return false; + } + + dst[cnt] = size_t(*pn); + + cnt++; + + } else { + if (err) { + (*err) += + "Array item in `data_offsets` must be number type, but got " + + i->type_name() + ".\n"; + } + return false; + } + } + + if (cnt != 2) { + if (err) { + (*err) += "`data_offsets` length must be 2.\n"; + } + return false; + } + } else { + if (err) { + (*err) += "`data_offsets` value must be JSON array, but got " + + v.type_name() + ".\n"; + } + return false; + } + + return true; +} + +bool parse_tensor(const std::string &name, const ::minijson::value &v, + tensor_t &tensor, std::string *err) { + if (auto po = v.as<::minijson::object>()) { + + bool dtype_found{false}; + bool shape_found{false}; + bool data_offsets_found{false}; + + dtype dtype; + std::vector shape; + std::array data_offsets{}; + + for (size_t i = 0; i < po->size(); i++) { + std::string key = po->keys()[i]; + + if (key == "dtype") { + ::minijson::value value; + if (!po->at(i, &value)) { + if (err) { + (*err) += "Internal error. `dtype` has invalid object.\n"; + } + return false; + } + + if (!parse_dtype(value, dtype, err)) { + return false; + } + + dtype_found = true; + } else if (key == "shape") { + ::minijson::value value; + if (!po->at(i, &value)) { + if (err) { + (*err) += "Internal error. `shape` has invalid object.\n"; + } + return false; + } + + if (!parse_shape(value, shape, err)) { + return false; + } + + shape_found = true; + } else if (key == "data_offsets") { + ::minijson::value value; + if (!po->at(i, &value)) { + if (err) { + (*err) += "Internal error. `data_offsets` has invalid object.\n"; + } + return false; + } + if (!parse_data_offsets(value, data_offsets, err)) { + return false; + } + + data_offsets_found = true; + } else { + // Unknown key. Report error? + } + } + + if (!dtype_found) { + if (err) { + (*err) += "`" + name + "` does not have `dtype` item.\n"; + } + return false; + } + + if (!shape_found) { + if (err) { + (*err) += "`" + name + "` does not have `shape` item.\n"; + } + return false; + } + + bool is_empty_tensor{false}; + if ((shape.size() > 0)) { + for (size_t i = 0; i < shape.size(); i++) { + if (shape[i] == 0) { + is_empty_tensor = true; + break; + } + } + } + + if (is_empty_tensor) { + // They are not storing any data in the databuffer, yet retaining size in + // the header. So ignore data_offsets + if (data_offsets_found) { + // TODO: make this warn instead of err? + if (err) { + (*err) += + "`" + name + + "` is empty tensors(tensors with 1 dimension being 0), and no " + "data in databuffer, but `data_offsets` item is provided.\n"; + } + // DO NOT RETURN FALSE, JUST CONTINUE + } + } else { + if (!data_offsets_found) { + if (err) { + (*err) += "`" + name + "` does not have `data_offsets` item.\n"; + } + return false; + } + } + + tensor.dtype = dtype; + tensor.shape = shape; + tensor.data_offsets = data_offsets; + + } else { + if (err) { + (*err) += "`" + name + "` value must be JSON object.\n"; + } + return false; + } + + return true; +} + +// From llama.cpp +#if defined(_WIN32) +static std::string safetensors_format_win_err(DWORD err) { + LPSTR buf; + size_t size = FormatMessageA( + FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | + FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, + NULL); + if (!size) { + return "FormatMessageA failed"; + } + std::string ret(buf, size); + LocalFree(buf); + return ret; +} +#endif + +struct safetensors_file { + // use FILE * so we don't have to re-open the file to mmap + FILE *fp{nullptr}; + size_t size{0}; + mutable bool _valid{false}; + std::string _err; + + safetensors_file(const char *fname, const char *mode) { + fp = std::fopen(fname, mode); + if (fp == nullptr) { + _err = "failed to open " + std::string(fname) + ":" + + std::string(strerror(errno)) + "\n"; + _valid = false; + } else { + seek(0, SEEK_END); + size = tell(); + seek(0, SEEK_SET); + _valid = true; + } + } + + ~safetensors_file() { + if (fp) { + std::fclose(fp); + fp = nullptr; + } + } + + size_t tell() const { +#ifdef _WIN32 + __int64 ret = _ftelli64(fp); +#else + long ret = std::ftell(fp); +#endif + if (ret == -1) { + // this really shouldn't fail + _valid = false; + return 0; + } + + return (size_t)ret; + } + + void seek(size_t offset, int whence) const { +#ifdef _WIN32 + int ret = _fseeki64(fp, (__int64)offset, whence); +#else + int ret = std::fseek(fp, (long)offset, whence); +#endif + if (ret == 0) { + _valid = false; + } + } + + bool &is_valid() const { return _valid; } + + const std::string &get_error() const { return _err; } +}; + +struct safetensors_mmap { + uint8_t *addr{nullptr}; + size_t size{0}; + + bool _valid{false}; + std::string _warn; + std::string _err; + + const bool is_valid() const { return _valid; } + + const std::string &get_error() const { return _err; } + + const std::string &get_warning() const { return _warn; } + + safetensors_mmap(const safetensors_mmap &) = delete; + +#ifdef _POSIX_MAPPED_FILES + static constexpr bool SUPPORTED = true; + + safetensors_mmap(struct safetensors_file *file, + size_t prefetch = (size_t)-1 /* -1 = max value */, + bool numa = false) { + size = file->size; + int fd = fileno(file->fp); + int flags = MAP_SHARED; + // prefetch/readahead impairs performance on NUMA systems + if (numa) { + prefetch = 0; + } +#ifdef __linux__ + if (prefetch) { + flags |= MAP_POPULATE; + } +#endif + addr = reinterpret_cast( + mmap(NULL, file->size, PROT_READ, flags, fd, 0)); + if (addr == MAP_FAILED) { + _valid = false; + _err = "mmap failed: " + std::string(strerror(errno)) + "\n"; + + size = 0; + addr = nullptr; + + return; + } + + if (prefetch > 0) { + // Advise the kernel to preload the mapped memory + if (posix_madvise(addr, std::min(file->size, prefetch), + POSIX_MADV_WILLNEED)) { + _warn += "posix_madvise(.., POSIX_MADV_WILLNEED) failed: " + + std::string(strerror(errno)) + "\n"; + } + } + if (numa) { + // advise the kernel not to use readahead + // (because the next page might not belong on the same node) + if (posix_madvise(addr, file->size, POSIX_MADV_RANDOM)) { + _warn += "posix_madvise(.., POSIX_MADV_RANDOM) failed: " + + std::string(strerror(errno)) + "\n"; + } + } + + _valid = true; + } + + ~safetensors_mmap() { + if (_valid) { + munmap(addr, size); + } + size = 0; + addr = nullptr; + _valid = false; + } + +#elif defined(_WIN32) + static constexpr bool SUPPORTED = true; + + safetensors_mmap(struct safetensors_file *file, bool prefetch = true, + bool numa = false) { + (void)numa; + + size = file->size; + + HANDLE hFile = (HANDLE)_get_osfhandle(_fileno(file->fp)); + + HANDLE hMapping = + CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL); + DWORD error = GetLastError(); + + if (hMapping == NULL) { + // TODO: get error message + _err = "CreateFileMappingA failed: " + safetensors_format_win_err(error) + + "\n"; + _valid = false; + size = 0; + addr = nullptr; + return; + } + + addr = reinterpret_cast( + MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0)); + error = GetLastError(); + CloseHandle(hMapping); + + if (addr == NULL) { + _err = + "MapViewOfFile failed: " + safetensors_format_win_err(error) + "\n"; + } + +#if _WIN32_WINNT >= _WIN32_WINNT_WIN8 + if (prefetch) { + // PrefetchVirtualMemory is only present on Windows 8 and above, so we + // dynamically load it + BOOL(WINAPI * pPrefetchVirtualMemory) + (HANDLE, ULONG_PTR, PWIN32_MEMORY_RANGE_ENTRY, ULONG); + HMODULE hKernel32 = GetModuleHandleW(L"kernel32.dll"); + + // may fail on pre-Windows 8 systems + pPrefetchVirtualMemory = + reinterpret_cast( + GetProcAddress(hKernel32, "PrefetchVirtualMemory")); + + if (pPrefetchVirtualMemory) { + // advise the kernel to preload the mapped memory + WIN32_MEMORY_RANGE_ENTRY range; + range.VirtualAddress = addr; + range.NumberOfBytes = (SIZE_T)size; + if (!pPrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) { + _warn += "PrefetchVirtualMemory failed: " + + safetensors_format_win_err(GetLastError()) + "\n"; + } + } + } +#endif + } + ~safetensors_mmap() { + if (!UnmapViewOfFile(addr)) { + _warn += "UnmapViewOfFile failed: " + + safetensors_format_win_err(GetLastError()) + "\n"; + } + } +#else + static constexpr bool SUPPORTED = false; + + safetensors_mmap(struct safetensors_file *file, bool prefetch = true, + bool numa = false) { + (void)file; + (void)prefetch; + (void)numa; + + _valid = false; + _err = "mmap not supported\n"; + addr = nullptr; + size = 0; + } +#endif +}; + +// Based on MIOPen bfloat16 +// https://github.com/ROCmSoftwarePlatform/MIOpen/blob/master/src/kernels/bfloat16_dev.hpp + +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2019 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + *all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +typedef union cvt_bf16_fp32 { + uint32_t u32; + uint16_t ushortvec[2]; + + float f32; +} cvt_bf16_fp32_t; + +float bfloat16_to_float(uint16_t src_val) { + cvt_bf16_fp32_t target_val; + + target_val.ushortvec[0] = 0; + target_val.ushortvec[1] = src_val; + + return target_val.f32; +} + +uint16_t float_to_bfloat16(float src_val) { + cvt_bf16_fp32_t target_val; + target_val.f32 = src_val; + // BF16 round and NaN preservation code matches + // https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/include/rocblas_bfloat16.h + if ((~target_val.u32 & 0x7f800000) == 0) // Inf or NaN + { + // When all of the exponent bits are 1, the value is Inf or NaN. + // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero + // mantissa bit. Quiet NaN is indicated by the most significant mantissa + // bit being 1. Signaling NaN is indicated by the most significant + // mantissa bit being 0 but some other bit(s) being 1. If any of the + // lower 16 bits of the mantissa are 1, we set the least significant bit + // of the bfloat16 mantissa, in order to preserve signaling NaN in case + // the bloat16's mantissa bits are all 0. + if ((target_val.u32 & 0xffff) != 0) { + target_val.u32 |= 0x10000; // Preserve signaling NaN + } + } else { +#if 1 // MIOPEN_USE_RNE_BFLOAT16 + // When the exponent bits are not all 1s, then the value is zero, normal, + // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus + // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). + // This causes the bfloat16's mantissa to be incremented by 1 if the 16 + // least significant bits of the float mantissa are greater than 0x8000, + // or if they are equal to 0x8000 and the least significant bit of the + // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when + // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already + // has the value 0x7f, then incrementing it causes it to become 0x00 and + // the exponent is incremented by one, which is the next higher FP value + // to the unrounded bfloat16 value. When the bfloat16 value is subnormal + // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up + // to a normal value with an exponent of 0x01 and a mantissa of 0x00. + // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, + // incrementing it causes it to become an exponent of 0xFF and a mantissa + // of 0x00, which is Inf, the next higher value to the unrounded value. + target_val.u32 += (0x7fff + (target_val.ushortvec[1] & 1)); +#endif // MIOPEN_USE_RNE_BFLOAT16 + } + + return target_val.ushortvec[1]; +} + +// half <-> float conversion based on: https://gist.github.com/rygorous/2156668 +// (CC0 license) +// + +// Little endian +union FP32le { + unsigned int u; + float f; + struct { + unsigned int Mantissa : 23; + unsigned int Exponent : 8; + unsigned int Sign : 1; + } s; +}; + +// Little endian +union float16le { + unsigned short u; + struct { + unsigned int Mantissa : 10; + unsigned int Exponent : 5; + unsigned int Sign : 1; + } s; +}; + +float half_to_float_le(float16le h) { + static const FP32le magic = {113 << 23}; + static const unsigned int shifted_exp = 0x7c00 + << 13; // exponent mask after shift + FP32le o; + + o.u = (h.u & 0x7fffU) << 13U; // exponent/mantissa bits + unsigned int exp_ = shifted_exp & o.u; // just the exponent + o.u += (127 - 15) << 23; // exponent adjust + + // handle exponent special cases + if (exp_ == shifted_exp) // Inf/NaN? + o.u += (128 - 16) << 23; // extra exp adjust + else if (exp_ == 0) // Zero/Denormal? + { + o.u += 1 << 23; // extra exp adjust + o.f -= magic.f; // renormalize + } + + o.u |= (h.u & 0x8000U) << 16U; // sign bit + return o.f; +} + +uint16_t float_to_half_full_le(float _f) { + FP32le f; + f.f = _f; + float16le o = {0}; + + // Based on ISPC reference code (with minor modifications) + if (f.s.Exponent == 0) // Signed zero/denormal (which will underflow) + o.s.Exponent = 0; + else if (f.s.Exponent == 255) // Inf or NaN (all exponent bits set) + { + o.s.Exponent = 31; + o.s.Mantissa = f.s.Mantissa ? 0x200 : 0; // NaN->qNaN and Inf->Inf + } else // Normalized number + { + // Exponent unbias the single, then bias the halfp + int newexp = f.s.Exponent - 127 + 15; + if (newexp >= 31) // Overflow, return signed infinity + o.s.Exponent = 31; + else if (newexp <= 0) // Underflow + { + if ((14 - newexp) <= 24) // Mantissa might be non-zero + { + unsigned int mant = f.s.Mantissa | 0x800000; // Hidden 1 bit + o.s.Mantissa = mant >> (14 - newexp); + if ((mant >> (13 - newexp)) & 1) // Check for rounding + o.u++; // Round, might overflow into exp bit, but this is OK + } + } else { + o.s.Exponent = static_cast(newexp); + o.s.Mantissa = f.s.Mantissa >> 13; + if (f.s.Mantissa & 0x1000) // Check for rounding + o.u++; // Round, might overflow to inf, this is OK + } + } + + o.s.Sign = f.s.Sign; + + return o.u; +} + +bool parse_safetensors_header(const uint8_t *addr, const size_t nbytes, + const std::string &filename, safetensors_t *st, + std::string *warn, std::string *err) { + if (nbytes < 16) { + if (err) { + (*err) += "Size is too short.\n"; + } + return false; + } + + uint64_t header_size{0}; + memcpy(reinterpret_cast(&header_size), addr, + sizeof(uint64_t)); + + if (header_size < 4) { + if (err) { + (*err) += "Header size is too short.\n"; + } + return false; + } + + if ((8 + header_size) > nbytes) { + if (err) { + (*err) += "Header size " + std::to_string(header_size) + + " + 8 exceeds input size " + std::to_string(nbytes) + " .\n"; + } + return false; + } + + if (header_size > kMaxJSONSize) { + if (err) { + (*err) += "Header JSON size exceeds the limit(" + + std::to_string(kMaxJSONSize) + ").\n"; + } + return false; + } + + // assume JSON data is small enough. + std::string json_str(reinterpret_cast(&addr[8]), header_size); + const char *p = json_str.c_str(); + + ::minijson::value v; + ::minijson::error e = ::minijson::parse(p, v); + + if (e != ::minijson::no_error) { + if (err) { + std::string json_err(::minijson::errstr(e)); + (*err) += "JSON parse error: " + json_err + "\n"; + } + + return false; + } + + ordered_dict tensors; + ordered_dict metadata; + + // root element must be dict. + if (auto po = v.as<::minijson::object>()) { + for (size_t i = 0; i < po->size(); i++) { + std::string key = po->keys()[i]; + + if (key == "__metadata__") { + ::minijson::value value; + if (!po->at(i, &value)) { + if (err) { + (*err) += "Internal error. Invalid object in __metadata__.\n"; + } + return false; + } + + if (!detail::parse_metadata(value, metadata, err)) { + return false; + } + } else { + // tensor + + if (tensors.count(key)) { + if (err) { + (*err) += "Duplicate key `" + key + "` found.\n"; + } + return false; + } + + ::minijson::value value; + if (!po->at(i, &value)) { + if (err) { + (*err) += "Internal error. Invalid object in `" + key + "`.\n"; + } + return false; + } + + tensor_t tensor; + if (!detail::parse_tensor(key, value, tensor, err)) { + return false; + } + + tensors.insert(key, std::move(tensor)); + } + } + } else { + if (err) { + (*err) += "JSON root elements must be object(dict)\n"; + } + } + + st->tensors = std::move(tensors); + st->metadata = std::move(metadata); + st->header_size = header_size; + +#if 0 + size_t databuffer_size = nbytes - header_size - 8; + + st->storage.resize(nbytes); + memcpy(st->storage.data(), addr + 8 + header_size, nbytes); + + st->mmaped = false; + st->mmap_addr = addr + 8 + header_size; + st->mmap_size = 0; +#endif + + return true; +} + +} // namespace detail + +safetensors_t::~safetensors_t() { + if (st_mmap) { + detail::safetensors_mmap *p = + reinterpret_cast(st_mmap); + delete p; + st_mmap = nullptr; + } + + if (st_file) { + detail::safetensors_file *p = + reinterpret_cast(st_file); + delete p; + st_file = nullptr; + } +} + +// +// - 8byte: header_size +// - json data(header_size bytes) +// - tensor data(filesize - header_size) +// + +bool load_from_file(const std::string &filename, safetensors_t *st, + std::string *warn, std::string *err) { + std::vector data; + if (!detail::ReadWholeFile(&data, err, filename, nullptr)) { + return false; + } + + return load_from_memory(reinterpret_cast(data.data()), + data.size(), filename, st, warn, err); +} + +bool load_from_memory(const uint8_t *addr, const size_t nbytes, + const std::string &filename, safetensors_t *st, + std::string *warn, std::string *err) { + if (nbytes < 16) { + if (err) { + (*err) += "Size is too short.\n"; + } + return false; + } + + if (!detail::parse_safetensors_header(addr, nbytes, filename, st, warn, + err)) { + return false; + } + + size_t databuffer_size = nbytes - st->header_size - 8; + + st->storage.resize(databuffer_size); + memcpy(st->storage.data(), addr + 8 + st->header_size, databuffer_size); + + st->mmaped = false; + st->mmap_addr = nullptr; + st->mmap_size = 0; + st->databuffer_addr = nullptr; + st->databuffer_size = 0; + + return true; +} + +bool mmap_from_file(const std::string &filename, safetensors_t *st, + std::string *warn, std::string *err) { + if (!st) { + return false; + } + + detail::safetensors_file *pf = + new detail::safetensors_file(filename.c_str(), "rb"); + if (!pf->is_valid()) { + if (err) { + (*err) += pf->get_error(); + } + delete pf; + return false; + } + + // TODO: prefetch, numa + detail::safetensors_mmap *pm = new detail::safetensors_mmap(pf); + + bool ret = mmap_from_memory(pm->addr, pm->size, filename, st, warn, err); + + if (!ret) { + delete pm; + delete pf; + + return false; + } + + st->mmap_addr = pm->addr; + st->mmap_size = pm->size; + + st->databuffer_addr = st->mmap_addr + 8 + st->header_size; + st->databuffer_size = st->mmap_size - (8 + st->header_size); + + // retain pointer + st->st_file = pf; + st->st_mmap = pm; + + st->mmaped = true; + + return true; +} + +bool mmap_from_memory(const uint8_t *addr, const size_t nbytes, + const std::string &filename, safetensors_t *st, + std::string *warn, std::string *err) { + if (!addr) { + return false; + } + + if (nbytes < 16) { + return false; + } + + if (!st) { + return false; + } + + if (!detail::parse_safetensors_header(addr, nbytes, filename, st, warn, + err)) { + return false; + } + + size_t databuffer_size = nbytes - st->header_size - 8; + + st->mmaped = true; + + st->mmap_addr = addr; + st->mmap_size = nbytes; + + st->databuffer_addr = st->mmap_addr + 8 + st->header_size; + st->databuffer_size = st->mmap_size - (8 + st->header_size); + + return true; +} + +float bfloat16_to_float(uint16_t x) { return detail::bfloat16_to_float(x); } + +uint16_t float_to_bfloat16(float x) { return detail::float_to_bfloat16(x); } + +float fp16_to_float(uint16_t x) { + detail::float16le src; + src.u = x; + return detail::half_to_float_le(src); +} + +uint16_t float_to_fp16(float x) { return detail::float_to_half_full_le(x); } + +size_t get_dtype_bytes(const safetensors::dtype dtype) { + size_t sz = 0; + + switch (dtype) { + case safetensors::dtype::kBOOL: + // Original Rust implementaion uses 1. + sz = 1; + break; + case safetensors::dtype::kUINT8: + sz = 1; + break; + case safetensors::dtype::kINT8: + sz = 1; + break; + case safetensors::dtype::kUINT16: + sz = 2; + break; + case safetensors::dtype::kINT16: + sz = 2; + break; + case safetensors::dtype::kINT32: + sz = 4; + break; + case safetensors::dtype::kUINT32: + sz = 4; + break; + case safetensors::dtype::kFLOAT16: + sz = 2; + break; + case safetensors::dtype::kBFLOAT16: + sz = 2; + break; + case safetensors::dtype::kFLOAT32: + sz = 4; + break; + case safetensors::dtype::kFLOAT64: + sz = 8; + break; + case safetensors::dtype::kINT64: + sz = 8; + break; + case safetensors::dtype::kUINT64: + sz = 8; + break; + } + + return sz; +} + +std::string get_dtype_str(const safetensors::dtype dtype) { + switch (dtype) { + case safetensors::dtype::kBOOL: + return "BOOL"; + case safetensors::dtype::kUINT8: + return "U8"; + case safetensors::dtype::kINT8: + return "I8"; + case safetensors::dtype::kUINT16: + return "U16"; + case safetensors::dtype::kINT16: + return "I16"; + case safetensors::dtype::kINT32: + return "I32"; + case safetensors::dtype::kUINT32: + return "U32"; + case safetensors::dtype::kFLOAT16: + return "F16"; + case safetensors::dtype::kBFLOAT16: + return "BF16"; + case safetensors::dtype::kFLOAT32: + return "F32"; + case safetensors::dtype::kFLOAT64: + return "F64"; + case safetensors::dtype::kINT64: + return "I64"; + case safetensors::dtype::kUINT64: + return "U64"; + } + return "???"; +} + +// Empty Tensor returns 0. +// Zero-rank Tensor reuturns 1(scalar) +size_t get_shape_size(const tensor_t &t) { + if (t.shape.empty()) { + return 1; + } + + if (t.shape.size() >= kMaxDim) { // invalid ndim + return 0; + } + + size_t sz = 1; + + for (size_t i = 0; i < t.shape.size(); i++) { + sz *= t.shape[i]; + } + + return sz; +} + +bool validate_data_offsets(const safetensors_t &st, std::string &err) { + bool valid{true}; + + std::stringstream ss; + + size_t databuffersize; + if (st.mmaped) { + databuffersize = st.databuffer_size; + } else { + databuffersize = st.storage.size(); + } + + size_t ntensors{0}; + // Iterate with key insertion order. + for (size_t i =0 ;i < st.tensors.size(); i++) { + + std::string key = st.tensors.keys()[i]; + + tensor_t tensor; + if (!st.tensors.at(i, &tensor)) { + ss << "Internal error: Failed to get tensor at [" << i << "]\n"; + valid = false; + continue; + } + + if (tensor.data_offsets[0] > tensor.data_offsets[1]) { + ss << key << ".data_offsets.BEGIN " << tensor.data_offsets[0] + << " must be less than or equal to data_offsets.END " + << tensor.data_offsets[1] << "\n"; + valid = false; + } + + size_t tensor_size = get_dtype_bytes(tensor.dtype) * get_shape_size(tensor); + + if (tensor_size == 0) { + // OK + continue; + } + + // data_offsets are absolute offset from the databuffer(file) + if (tensor.data_offsets[0] > databuffersize) { + ss << "Tensor `" << key << "`.data_offset.BEGIN " + << tensor.data_offsets[0] << " exceeds databuffer size " + << databuffersize << ".\n"; + valid = false; + } + + if (tensor.data_offsets[1] > databuffersize) { + ss << "Tensor `" << key << "`.data_offset.END " + << tensor.data_offsets[1] << " exceeds databuffer size " + << databuffersize << ".\n"; + valid = false; + } + + size_t data_size = tensor.data_offsets[1] - tensor.data_offsets[0]; + + if (tensor_size != data_size) { + ss << "Data size mismatch. The size in Tensor `" << key << "` is " + << tensor_size << ", but the size from data_offsets is " << data_size + << "\n"; + valid = false; + } + + ntensors++; + if (ntensors == st.tensors.size()) { + // Last element's data_offsets[1] must be equal to databuffer size. + if (tensor.data_offsets[1] != databuffersize) { + ss << "The last tensor's data_offset.END(" << tensor.data_offsets[1] + << ") must be equal to databufer size " << databuffersize << ".\n"; + valid = false; + } + } + } + + if (!valid) { + err = ss.str(); + } + + return valid; +} + +bool save_to_memory(const safetensors_t &st, std::vector *dst, + std::string *warn, std::string *err) { + // directly serialize JSON string. + std::stringstream ss; + + // NOTE: The last offset **must** be the end of the file, + // so write __metadata__ first(if metadata part exists) + + std::string _err; + if (!validate_data_offsets(st, _err)) { + if (err) { + (*err) += "Invalid safensors is provided.\n"; + (*err) += _err; + } + return false; + } + + ss << "{"; + if (st.metadata.size()) { + ss << "\"__metadata__\": {"; + size_t nmeta = 0; + for (size_t i = 0; i < st.metadata.size(); i++) { + std::string key = st.metadata.keys()[i]; + std::string value; + st.metadata.at(i, &value); + + if (nmeta > 0) { + ss << ", "; + } + ss << "\"" + key + "\": \"" << value << "\""; + nmeta++; + } + ss << "}"; + + if (st.tensors.size()) { + ss << ", "; + } + } + + size_t ntensors = 0; + { + for (size_t i = 0; i < st.tensors.size(); i++) { + + std::string key = st.tensors.keys()[i]; + safetensors::tensor_t tensor; + st.tensors.at(i, &tensor); + + if (tensor.shape.size() > safetensors::kMaxDim) { + if (err) { + (*err) += key + ".shape is too large.\n"; + (*err) += _err; + } + return false; + } + + if (ntensors > 0) { + ss << ", "; + } + ss << "\"" << key << "\": {"; + ss << "\"dtype\": \"" << safetensors::get_dtype_str(tensor.dtype) + << "\", "; + ss << "\"shape\": ["; + for (size_t i = 0; i < tensor.shape.size(); i++) { + if (i > 0) { + ss << ", "; + } + ss << tensor.shape[i]; + } + ss << "]"; + ss << ", \"data_offsets\": [" << tensor.data_offsets[0] << ", " + << tensor.data_offsets[1] << "]"; + ss << "}"; + ntensors++; + } + } + ss << "}"; + + std::string header_str = ss.str(); + + uint64_t header_size = header_str.size(); // do not include '\n' + + const void *databuffer_addr{nullptr}; + size_t databuffer_size{0}; + if (st.mmaped) { + databuffer_size = st.databuffer_size; + databuffer_addr = st.databuffer_addr; + } else { + databuffer_size = st.storage.size(); + databuffer_addr = reinterpret_cast(st.storage.data()); + } + + // make databuffer addr start from the multiple of 8. + size_t pad_bytes = 0; + if ((header_size % 8) != 0) { + pad_bytes = 8 - (header_size % 8); + } + //printf("header_size = %d\n", int(header_size)); + //printf("pad_bytes = %d\n", int(pad_bytes)); + size_t padded_header_size = header_size + pad_bytes; + dst->resize(8 + padded_header_size + databuffer_size); + + // write padded header_size + memcpy(dst->data(), &padded_header_size, 8); + + // write header + memcpy(dst->data() + 8, header_str.data(), header_size); + + // Use whitespace for trailing padding. + memset(dst->data() + 8 + header_size, 0x20, pad_bytes); + + memcpy(dst->data() + 8 + padded_header_size, databuffer_addr, + databuffer_size); + + return true; +} + +bool save_to_file(const safetensors_t &st, const std::string &filename, + std::string *warn, std::string *err) { + // TODO: Use more reliable io. + std::ofstream ofs(filename, std::ios::binary); + + if (!ofs) { + if (err) { + (*err) += "Failed to open `" + filename + + "` to write. File is either existing directory or " + "write-protected, or disk is full?\n"; + } + return false; + } + + std::vector buf; + if (!save_to_memory(st, &buf, warn, err)) { + return false; + } + + ofs.write(reinterpret_cast(buf.data()), buf.size()); + if (!ofs) { + if (err) { + (*err) += "Failed to write safetensor data to `" + filename + + "`. Maybe no disk space available?(Required bytes : " + + std::to_string(buf.size()) + "\n"; + } + return false; + } + + return true; +} + +} // namespace safetensors + +#endif diff --git a/transformers/README.md b/transformers/README.md index 24ff424ceb..01fd10b927 100644 --- a/transformers/README.md +++ b/transformers/README.md @@ -57,6 +57,26 @@ The directory structure is as follows: + Direct Conversion to MNN Model Use `--export mnn` to directly convert to an MNN model. Note that you need to either install pymnn or specify the path to the MNNConvert tool using the `--mnnconvert` option. At least one of these conditions must be met. If pymnn is not installed and the MNNConvert tool's path is not specified via --mnnconvert, the llmexport.py script will search for the MNNConvert tool in the directory "../../../build/". Ensure that the MNNConvert file exists in this directory. This method currently supports exporting 4-bit and 8-bit models. ++ Segment MNN Export +Use `--export mnn --segment` to export a segment-format MNN LLM directly from safetensors weights and a workflow JSON, without generating ONNX first. If `--workflow` is not specified, `llmexport.py` searches `resource/*.json` for a matching workflow. + +``` +cd transformers/llm/export +python3 llmexport.py \ + --path /path/to/Qwen3-0.6B \ + --export mnn \ + --segment \ + --dst_path ./model +``` + +The output directory contains `config.json` with `"mnn_llm_version": "segment"`, `llm_config.json`, `tokenizer.mtok`, `embed.mnn`, `decoder.mnn`, `decoder.mnn.weight`, `logit.mnn`, `logit.mnn.weight`, and `logit_topkv_1.mnn`. Run the segment model with the generated `config.json`: + +``` +./llm_demo transformers/llm/export/model/config.json /path/to/prompt.txt +``` + +The C++ runtime must be built with `MNN_BUILD_LLM=ON` and `MNN_LLM_SUPPORT_SEGMENT=ON` (enabled by default). Segment export currently supports `--export mnn` only. + + If you encounter issues with directly converting to an MNN model or require quantization with other bit depths (e.g., 5-bit/6-bit), you can first convert the model to an ONNX model using `--export onnx`. Then, use the MNNConvert tool to convert the ONNX model to an MNN model with the following command: ``` @@ -72,7 +92,7 @@ Use `--export mnn` to directly convert to an MNN model. Note that you need to ei ``` usage: llmexport.py [-h] --path PATH [--type TYPE] [--lora_path LORA_PATH] [--dst_path DST_PATH] [--test TEST] [--export EXPORT] [--quant_bit QUANT_BIT] [--quant_block QUANT_BLOCK] [--lm_quant_bit LM_QUANT_BIT] - [--mnnconvert MNNCONVERT] + [--mnnconvert MNNCONVERT] [--segment] [--workflow WORKFLOW] llm_exporter @@ -89,6 +109,8 @@ options: --dst_path DST_PATH export onnx/mnn model to path, default is `./model`. --test TEST test model inference with query `TEST`. --export EXPORT export model to an onnx/mnn model. + --segment export segment MNN LLM from safetensors workflow directly, without ONNX export. + --workflow WORKFLOW workflow json for --segment safetensors conversion. If absent, search resource/*.json. --quant_bit QUANT_BIT mnn quant bit, 4 or 8, default is 4. --quant_block QUANT_BLOCK diff --git a/transformers/llm/engine/CMakeLists.txt b/transformers/llm/engine/CMakeLists.txt index 1b005a7858..8fbe34d874 100644 --- a/transformers/llm/engine/CMakeLists.txt +++ b/transformers/llm/engine/CMakeLists.txt @@ -1,6 +1,7 @@ option(BUILD_MLS "Build PC Commandline." OFF) option(MNN_LLM_BUILD_DEMO "Build LLM demo" ON) option(LLM_SUPPORT_HTTP_RESOURCE "Support HTTP resource download" ON) +option(MNN_LLM_SUPPORT_SEGMENT "Enable mnn_llm_version=segment runtime path." ON) set(LLM_DEPS ${MNN_DEPS}) if (MNN_BUILD_OPENCV) @@ -44,6 +45,9 @@ else() endif() # jinja.cpp template engine (always enabled, header-only) target_compile_definitions(llm PRIVATE LLM_USE_JINJA) +if (MNN_LLM_SUPPORT_SEGMENT) + target_compile_definitions(llm PRIVATE MNN_LLM_SUPPORT_SEGMENT) +endif() # Option to store MNN_PRINT/MNN_ERROR output into a string buffer. # Only enabled on Android. Pass -DLLM_LOG_TO_STRING=ON at cmake configure time. @@ -166,4 +170,4 @@ set_property(TARGET mls PROPERTY CXX_STANDARD_REQUIRED ON) # target_compile_options(mls PRIVATE -std=c++17) target_link_libraries(mls PRIVATE ${LLM_DEPS}) target_compile_definitions(mls PRIVATE CPPHTTPLIB_OPENSSL_SUPPORT) -endif() \ No newline at end of file +endif() diff --git a/transformers/llm/engine/include/llm/llm.hpp b/transformers/llm/engine/include/llm/llm.hpp index b7a06db562..ff0e52da0a 100644 --- a/transformers/llm/engine/include/llm/llm.hpp +++ b/transformers/llm/engine/include/llm/llm.hpp @@ -164,7 +164,7 @@ class MNN_PUBLIC Llm { void response(const ChatMessages& chat_prompts, std::ostream* os = &std::cout, const char* end_with = nullptr, int max_new_tokens = -1); void response(MNN::Express::VARP input_embeds, std::ostream* os = &std::cout, const char* end_with = nullptr, int max_new_tokens = -1); virtual void generate_init(std::ostream* os = nullptr, const char* end_with = nullptr); - void generate(int max_token); + virtual void generate(int max_token); std::vector generate(const std::vector& input_ids, int max_new_tokens = -1); std::vector generate(MNN::Express::VARP input_embeds, int max_tokens = -1); bool stoped(); @@ -278,4 +278,4 @@ class MNN_PUBLIC Embedding : public Llm { } } -#endif // LLM_hpp \ No newline at end of file +#endif // LLM_hpp diff --git a/transformers/llm/engine/src/llm.cpp b/transformers/llm/engine/src/llm.cpp index 2f163a4906..4e2c21c51a 100644 --- a/transformers/llm/engine/src/llm.cpp +++ b/transformers/llm/engine/src/llm.cpp @@ -23,6 +23,9 @@ #include "diskembedding.hpp" #include "sampler.hpp" #include "omni.hpp" +#ifdef MNN_LLM_SUPPORT_SEGMENT +#include "segment.hpp" +#endif #include "speculative_decoding/generate.hpp" #include "core/MNNFileUtils.h" @@ -87,6 +90,16 @@ static inline void _llmOrigError(const char* msg) { Llm* Llm::createLLM(const std::string& config_path) { std::shared_ptr config(new LlmConfig(config_path)); Llm* llm = nullptr; + const auto llmVersion = config->mnn_llm_version(); +#ifdef MNN_LLM_SUPPORT_SEGMENT + if (llmVersion == "segment") { + return createSegmentLlm(config); + } +#else + if (llmVersion == "segment") { + MNN_ERROR("[Error]: mnn_llm_version=segment requires MNN_LLM_SUPPORT_SEGMENT.\n"); + } +#endif if (config->is_visual() || config->is_audio() || config->has_talker()) { llm = new Omni(config); } else { @@ -1617,4 +1630,4 @@ bool Llm::is_stop(int token_id) { return stop; } } // namespace Transformer -} // namespace MNN \ No newline at end of file +} // namespace MNN diff --git a/transformers/llm/engine/src/llmconfig.hpp b/transformers/llm/engine/src/llmconfig.hpp index e95c672352..3cad3af05e 100644 --- a/transformers/llm/engine/src/llmconfig.hpp +++ b/transformers/llm/engine/src/llmconfig.hpp @@ -81,6 +81,18 @@ class LlmConfig { } else { config_ = ujson::json::parse("{}"); base_dir_ = path; + if (!base_dir_.empty() && base_dir_.back() != '/' && base_dir_.back() != '\\') { + base_dir_ += "/"; + } + std::ifstream config_file(base_dir_ + "config.json"); + if (config_file.is_open()) { + std::ostringstream ostr; + ostr << config_file.rdbuf(); + auto model_config = ujson::json::parse(ostr.str()); + if (model_config.contains("mnn_llm_version")) { + config_.merge(model_config); + } + } } } // using config's base_dir @@ -146,6 +158,10 @@ class LlmConfig { std::string context_file() const { return base_dir_ + config_.value("context_file", "context.json"); } + + std::string mnn_llm_version() const { + return config_.value("mnn_llm_version", ""); + } // model file config end > // < generate config start @@ -634,4 +650,4 @@ class LlmConfig { } // Transformer } // MNN -#endif \ No newline at end of file +#endif diff --git a/transformers/llm/engine/src/segment.cpp b/transformers/llm/engine/src/segment.cpp new file mode 100644 index 0000000000..88ce7ccd4e --- /dev/null +++ b/transformers/llm/engine/src/segment.cpp @@ -0,0 +1,468 @@ +#ifdef MNN_LLM_SUPPORT_SEGMENT + +#include "segment.hpp" + +#include +#include +#include + +#include +#include "core/MNNFileUtils.h" +#include "kvmeta.hpp" +#include "llmconfig.hpp" +#include "tokenizer/tokenizer.hpp" + +namespace MNN { +namespace Transformer { +namespace { + +using namespace Express; +using RuntimeManager = Express::Executor::RuntimeManager; + +static bool segmentCheckFile(const std::string& path, const char* name) { + if (!MNNFileExist(path.c_str())) { + MNN_ERROR("[Error]: segment %s not found: %s\n", name, path.c_str()); + return false; + } + std::ifstream f(path); + if (!f.is_open()) { + MNN_ERROR("[Error]: failed to open segment %s: %s\n", name, path.c_str()); + return false; + } + return true; +} + +static std::string segmentPath(const LlmConfig& config, const std::string& name) { + return config.base_dir_ + name; +} + +static MNNForwardType segmentForwardType(std::shared_ptr config) { + if (config->config_.contains("forwardtype")) { + return static_cast(config->config_.value("forwardtype", 0)); + } + const auto type = config->backend_type(); + if (type == "metal") + return MNN_FORWARD_METAL; + if (type == "cuda") + return MNN_FORWARD_CUDA; + if (type == "opencl") + return MNN_FORWARD_OPENCL; + if (type == "opengl") + return MNN_FORWARD_OPENGL; + if (type == "vulkan") + return MNN_FORWARD_VULKAN; + if (type == "npu") + return MNN_FORWARD_NN; + return MNN_FORWARD_CPU; +} + +static void segmentApplyBackendConfig(std::shared_ptr config, BackendConfig* backend) { + if (backend == nullptr) { + return; + } + if (config->config_.contains("precision")) { + backend->precision = static_cast(config->config_.value("precision", 2)); + } else if (config->precision() == "high") { + backend->precision = BackendConfig::Precision_High; + } else if (config->precision() == "low") { + backend->precision = BackendConfig::Precision_Low; + } + if (config->config_.contains("memory")) { + backend->memory = static_cast(config->config_.value("memory", 2)); + } else if (config->memory() == "high") { + backend->memory = BackendConfig::Memory_High; + } else if (config->memory() == "low") { + backend->memory = BackendConfig::Memory_Low; + } +} + +static VARP segmentTakeLastHidden(VARP hidden) { + if (hidden == nullptr) { + return nullptr; + } + auto info = hidden->getInfo(); + if (info == nullptr || info->dim.size() < 3) { + return hidden; + } + const int seqLen = info->dim[1]; + const int hiddenSize = info->dim[2]; + if (seqLen <= 0 || hiddenSize <= 0 || seqLen == 1) { + return hidden; + } + const size_t bytes = static_cast(hiddenSize) * info->type.bytes(); + const uint8_t* src = hidden->readMap(); + if (src == nullptr || bytes == 0) { + return hidden; + } + auto out = _Input({1, 1, hiddenSize}, info->order, info->type); + ::memcpy(out->writeMap(), src + static_cast(seqLen - 1) * bytes, bytes); + out.fix(VARP::CONSTANT); + return out; +} + +static void segmentWait(VARP var) { + if (var == nullptr || var->getTensor() == nullptr) { + return; + } + ((MNN::Tensor*)var->getTensor())->wait(MNN::Tensor::MAP_TENSOR_READ, true); +} + +} // namespace + +class SegmentLlm final : public Llm { +public: + explicit SegmentLlm(std::shared_ptr config) : Llm(config) { + mSeqLenIndex = 1; + mMeta->layer_nums = mConfig->config_.value("layer_nums", 0); + } + + bool load() override; + VARP embedding(const std::vector& input_ids) override; + VARP gen_attention_mask(int seq_len) override; + VARP gen_position_ids(int seq_len) override; + std::vector forwardRaw(VARP hiddenState, VARP mask, VARP inputPos, VARPS extraArgs = {}) override; + int sample(VARP logits, int offset = 0, int size = 0) override; + void response(const std::vector& input_ids, std::ostream* os = &std::cout, const char* end_with = nullptr, + int max_new_tokens = -1) override; + void generate(int max_token) override; + +private: + bool loadTokenizer(); + bool loadModules(); + bool prefill(const std::vector& input_ids); + int sampleFromHidden(VARP hidden); + VARP decoderForward(VARP input, VARP mask = nullptr, VARP positionIds = nullptr); + void updateSegmentContext(int seqLen, int genLen); + +private: + std::shared_ptr mEmbedModule; + std::shared_ptr mDecoderModule; + std::shared_ptr mDecoderPrefillModule; + std::shared_ptr mLogitBaseModule; + std::shared_ptr mLogitModule; + VARP mLastHidden; + int mMaxDecodeTokens = 1024; +}; + +bool SegmentLlm::loadTokenizer() { + std::string tokenizerPath = mConfig->tokenizer_file(); + if (!segmentCheckFile(tokenizerPath, "tokenizer file")) { + return false; + } + mTokenizer.reset(Tokenizer::createTokenizer(tokenizerPath)); + if (mTokenizer == nullptr) { + MNN_ERROR("[Error]: failed to load segment tokenizer: %s\n", tokenizerPath.c_str()); + return false; + } + + if (mConfig->config_.contains("jinja")) { + setChatTemplate(); + return true; + } + + std::ifstream tokenConfig(segmentPath(*mConfig, "token_config.json")); + if (tokenConfig.is_open()) { + std::ostringstream ostr; + ostr << tokenConfig.rdbuf(); + auto json = ujson::json::parse(ostr.str()); + if (json.contains("chat_template")) { + mTokenizer->set_chat_template(json["chat_template"].get(), json.value("eos_token", "")); + } + } + return true; +} + +bool SegmentLlm::loadModules() { + const std::string embedPath = segmentPath(*mConfig, "embed.mnn"); + const std::string decoderPath = segmentPath(*mConfig, "decoder.mnn"); + const std::string decoderWeightPath = decoderPath + ".weight"; + const std::string logitPath = segmentPath(*mConfig, "logit.mnn"); + const std::string logitWeightPath = logitPath + ".weight"; + const std::string logitTopkPath = segmentPath(*mConfig, "logit_topkv_1.mnn"); + + if (!segmentCheckFile(embedPath, "embed model") || !segmentCheckFile(decoderPath, "decoder model") || + !segmentCheckFile(decoderWeightPath, "decoder weight") || !segmentCheckFile(logitPath, "logit model") || + !segmentCheckFile(logitWeightPath, "logit weight") || !segmentCheckFile(logitTopkPath, "logit topk model")) { + return false; + } + + BackendConfig backendConfig; + segmentApplyBackendConfig(mConfig, &backendConfig); + + ScheduleConfig decoderSchedule; + decoderSchedule.backendConfig = &backendConfig; + decoderSchedule.type = segmentForwardType(mConfig); + decoderSchedule.numThread = 1; + if (decoderSchedule.type == MNN_FORWARD_OPENCL) { + decoderSchedule.numThread |= 64; + } + + mRuntimeManager.reset(RuntimeManager::createRuntimeManager(decoderSchedule), RuntimeManager::destroy); + mRuntimeManager->setHintPtr(Interpreter::KVCACHE_INFO, mMeta.get()); + Module::Config decoderConfig; + decoderConfig.rearrange = true; + if (decoderSchedule.type == MNN_FORWARD_OPENCL || decoderSchedule.type == MNN_FORWARD_VULKAN) { + decoderConfig.shapeMutable = false; + } + + const std::vector decoderInputs = {"input_embedding", "mask", "position_ids"}; + mDecoderModule.reset( + Module::load(decoderInputs, {"last_hidden_state"}, decoderPath.c_str(), mRuntimeManager, &decoderConfig)); + if (!mDecoderModule) { + mDecoderModule.reset( + Module::load(decoderInputs, {"hidden_state"}, decoderPath.c_str(), mRuntimeManager, &decoderConfig)); + } + if (!mDecoderModule) { + MNN_ERROR("[Error]: load segment decoder.mnn failed\n"); + return false; + } + mDecoderPrefillModule.reset(Module::clone(mDecoderModule.get())); + if (!mDecoderPrefillModule) { + MNN_ERROR("[Error]: clone segment decoder prefill module failed\n"); + return false; + } + + BackendConfig otherBackendConfig = backendConfig; + if (mConfig->config_.contains("otherPrecision")) { + otherBackendConfig.precision = static_cast( + mConfig->config_.value("otherPrecision", (int)otherBackendConfig.precision)); + } + ScheduleConfig otherSchedule = decoderSchedule; + otherSchedule.backendConfig = &otherBackendConfig; + mProcessorRuntimeManager.reset(RuntimeManager::createRuntimeManager(otherSchedule), RuntimeManager::destroy); + mProcessorRuntimeManager->setHintPtr(Interpreter::KVCACHE_INFO, nullptr); + + Module::Config moduleConfig; + moduleConfig.rearrange = true; + mLogitBaseModule.reset(Module::load({}, {}, logitPath.c_str(), mProcessorRuntimeManager, &moduleConfig)); + if (!mLogitBaseModule) { + MNN_ERROR("[Error]: load segment logit.mnn failed\n"); + return false; + } + + Module::Config depConfig = moduleConfig; + depConfig.base = mLogitBaseModule.get(); + mLogitModule.reset(Module::load({}, {}, logitTopkPath.c_str(), mProcessorRuntimeManager, &depConfig)); + mEmbedModule.reset(Module::load({}, {}, embedPath.c_str(), mProcessorRuntimeManager, &depConfig)); + if (!mLogitModule || !mEmbedModule) { + MNN_ERROR("[Error]: load segment logit_topkv_1.mnn/embed.mnn failed\n"); + return false; + } + return true; +} + +bool SegmentLlm::load() { + MNN::Express::ExecutorScope s(mExecutor); + Timer _t; + mMaxDecodeTokens = mConfig->config_.value("max_decode_tokens", mConfig->max_new_tokens()); + if (!loadTokenizer() || !loadModules()) { + return false; + } + mContext->load_us += _t.durationInUs(); + mContext->status = LlmStatus::RUNNING; + return true; +} + +VARP SegmentLlm::embedding(const std::vector& input_ids) { + MNN::Express::ExecutorScope s(mExecutor); + if (input_ids.empty() || !mEmbedModule) { + return nullptr; + } + auto var = _Input({1, static_cast(input_ids.size())}, NCHW, halide_type_of()); + ::memcpy(var->writeMap(), input_ids.data(), input_ids.size() * sizeof(int)); + auto outputs = mEmbedModule->onForward({var}); + return outputs.empty() ? nullptr : outputs[0]; +} + +VARP SegmentLlm::gen_attention_mask(int seq_len) { + auto mask = _Input({}, NCHW, halide_type_of()); + *mask->writeMap() = 0.0f; + mask.fix(VARP::CONSTANT); + return mask; +} + +VARP SegmentLlm::gen_position_ids(int seq_len) { + auto positionIds = _Input({1, seq_len}, NCHW, halide_type_of()); + auto ptr = positionIds->writeMap(); + const int start = static_cast(mMeta->previous) - static_cast(mMeta->remove); + for (int i = 0; i < seq_len; ++i) { + ptr[i] = start + i; + } + positionIds.fix(VARP::CONSTANT); + return positionIds; +} + +VARP SegmentLlm::decoderForward(VARP input, VARP mask, VARP positionIds) { + if (input == nullptr) { + return nullptr; + } + auto info = input->getInfo(); + if (info == nullptr || info->dim.size() < 3) { + return nullptr; + } + const int seqLen = info->dim[1]; + if (mask == nullptr) { + mask = gen_attention_mask(seqLen); + } + if (positionIds == nullptr) { + positionIds = gen_position_ids(seqLen); + } + auto module = (seqLen == 1) ? mDecoderModule : mDecoderPrefillModule; + mMeta->add = seqLen; + auto outputs = module->onForward({input, mask, positionIds}); + mMeta->sync(); + if (outputs.empty()) { + mContext->status = LlmStatus::INTERNAL_ERROR; + return nullptr; + } + segmentWait(outputs[0]); + return outputs[0]; +} + +std::vector SegmentLlm::forwardRaw(VARP hiddenState, VARP mask, VARP inputPos, VARPS extraArgs) { + auto hidden = decoderForward(hiddenState, mask, inputPos); + if (hidden == nullptr || !mLogitModule) { + mContext->status = LlmStatus::INTERNAL_ERROR; + return {}; + } + mLastHidden = segmentTakeLastHidden(hidden); + auto outputs = mLogitModule->onForward({hidden}); + if (outputs.empty()) { + mContext->status = LlmStatus::INTERNAL_ERROR; + return {}; + } + return outputs; +} + +int SegmentLlm::sample(VARP logits, int offset, int size) { + if (logits == nullptr) { + return -1; + } + auto info = logits->getInfo(); + if (info == nullptr || info->size <= 0) { + return -1; + } + const int* topk = logits->readMap(); + return topk == nullptr ? -1 : topk[info->size - 1]; +} + +int SegmentLlm::sampleFromHidden(VARP hidden) { + if (hidden == nullptr || !mLogitModule) { + return -1; + } + auto outputs = mLogitModule->onForward({hidden}); + if (outputs.empty()) { + return -1; + } + return sample(outputs[0]); +} + +void SegmentLlm::updateSegmentContext(int seqLen, int genLen) { + mContext->all_seq_len += seqLen; + mContext->gen_seq_len += genLen; +} + +bool SegmentLlm::prefill(const std::vector& input_ids) { + if (input_ids.empty()) { + return false; + } + mContext->history_tokens.insert(mContext->history_tokens.end(), input_ids.begin(), input_ids.end()); + Timer _t; + auto emb = embedding(input_ids); + auto hidden = decoderForward(emb); + if (hidden == nullptr) { + mContext->status = LlmStatus::INTERNAL_ERROR; + return false; + } + mLastHidden = segmentTakeLastHidden(hidden); + if (mLastHidden.get() != nullptr) { + mLastHidden.fix(VARP::CONSTANT); + } + updateSegmentContext(static_cast(input_ids.size()), 0); + mContext->prompt_len = static_cast(input_ids.size()); + mContext->prefill_us += _t.durationInUs(); + return true; +} + +void SegmentLlm::generate(int max_token) { + CHECK_LLM_RUNNING(mContext); + MNN::Express::ExecutorScope s(mExecutor); + if (max_token < 0) { + max_token = mMaxDecodeTokens; + } + max_token = std::min(max_token, mMaxDecodeTokens); + int len = 0; + while (len < max_token) { + if (mContext->status == LlmStatus::USER_CANCEL || mContext->status == LlmStatus::INTERNAL_ERROR) { + break; + } + Timer _t; + int token = sampleFromHidden(mLastHidden); + if (token < 0) { + mContext->decode_us += _t.durationInUs(); + mContext->status = LlmStatus::INTERNAL_ERROR; + break; + } + mContext->current_token = token; + if (is_stop(token)) { + mContext->decode_us += _t.durationInUs(); + if (mContext->os != nullptr) { + *mContext->os << mContext->end_with << std::flush; + } + break; + } + + mContext->history_tokens.push_back(token); + mContext->output_tokens.push_back(token); + auto decodeStr = tokenizer_decode(token); + mContext->generate_str += decodeStr; + if (mContext->os != nullptr) { + *mContext->os << decodeStr << std::flush; + } + + auto emb = embedding({token}); + auto hidden = decoderForward(emb); + if (hidden == nullptr) { + mContext->decode_us += _t.durationInUs(); + mContext->status = LlmStatus::INTERNAL_ERROR; + break; + } + mLastHidden = segmentTakeLastHidden(hidden); + if (mLastHidden.get() != nullptr) { + mLastHidden.fix(VARP::CONSTANT); + } + updateSegmentContext(1, 1); + mContext->decode_us += _t.durationInUs(); + ++len; + } + if (len >= max_token) { + mContext->status = LlmStatus::MAX_TOKENS_FINISHED; + } +} + +void SegmentLlm::response(const std::vector& input_ids, std::ostream* os, const char* end_with, + int max_new_tokens) { + MNN::Express::ExecutorScope s(mExecutor); + if (!end_with) { + end_with = "\n"; + } + generate_init(os, end_with); + if (!prefill(input_ids)) { + return; + } + if (max_new_tokens < 0) { + max_new_tokens = mMaxDecodeTokens; + } + if (max_new_tokens > 0) { + generate(max_new_tokens); + } +} + +Llm* createSegmentLlm(std::shared_ptr config) { + return new SegmentLlm(std::move(config)); +} + +} // namespace Transformer +} // namespace MNN + +#endif // MNN_LLM_SUPPORT_SEGMENT diff --git a/transformers/llm/engine/src/segment.hpp b/transformers/llm/engine/src/segment.hpp new file mode 100644 index 0000000000..8c4663399c --- /dev/null +++ b/transformers/llm/engine/src/segment.hpp @@ -0,0 +1,20 @@ +#ifdef MNN_LLM_SUPPORT_SEGMENT + +#ifndef LLM_SEGMENT_HPP +#define LLM_SEGMENT_HPP + +#include + +#include "llm/llm.hpp" + +namespace MNN { +namespace Transformer { + +Llm* createSegmentLlm(std::shared_ptr config); + +} // namespace Transformer +} // namespace MNN + +#endif // LLM_SEGMENT_HPP + +#endif // MNN_LLM_SUPPORT_SEGMENT diff --git a/transformers/llm/export/llmexport.py b/transformers/llm/export/llmexport.py index 8835b9c6e9..ccf1f09296 100644 --- a/transformers/llm/export/llmexport.py +++ b/transformers/llm/export/llmexport.py @@ -21,6 +21,7 @@ from utils.smooth_quantizer import SmoothQuantizer from utils.omni_quantizer import OmniQuantizer from utils.torch_utils import onnx_export +import segment as segment_export class LlmExporter(torch.nn.Module): ''' @@ -29,7 +30,10 @@ class LlmExporter(torch.nn.Module): def __init__(self, args): super().__init__() self.init_from_args(args) - self.load_model(args.path) + if segment_export.enabled(args) and getattr(args, 'test', None) is None: + segment_export.load_metadata(self, args.path) + else: + self.load_model(args.path) def init_from_args(self, args): self.args = args @@ -47,7 +51,7 @@ def init_from_args(self, args): # init export dst dir if not os.path.exists(self.args.dst_path): os.makedirs(self.args.dst_path) - if not os.path.exists(self.onnx_path): + if not segment_export.enabled(self.args) and not os.path.exists(self.onnx_path): os.makedirs(self.onnx_path) @spinner_run(f'load pretrained model ', True) @@ -679,6 +683,9 @@ def export_language(self): self.onnx_load_param(onnx_model) def export(self, export_type): + if segment_export.enabled(self.args): + segment_export.export(self, export_type) + return if not self.args.skip_weight: if self.args.omni: self.omni_quant() @@ -877,6 +884,8 @@ def build_args(parser): parser.add_argument('--quant_config', type=str, default=None, help='path to the JSON file for op-wise quantization configuration.') parser.add_argument('--generate_for_npu', action='store_true', help='Whether or not to generate model for NPU deployment, default is False.') parser.add_argument('--skip_weight', action='store_true', help='Whether or not to skip loading model weights, useful for testing export flow.') + parser.add_argument('--segment', action='store_true', help='Export segment MNN LLM from safetensors workflow directly, without ONNX export.') + parser.add_argument('--workflow', type=str, default=None, help='workflow json for --segment safetensors conversion. If absent, search resource/*.json.') # omni quant parser.add_argument('--omni_epochs', type=int, default=20, help='OmniQuant 优化的轮数') parser.add_argument('--omni_lr', type=float, default=5e-3, help='OmniQuant 的学习率') @@ -916,4 +925,4 @@ def main(): llm_exporter.export(args.export) if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/transformers/llm/export/segment.py b/transformers/llm/export/segment.py new file mode 100644 index 0000000000..582a0b0078 --- /dev/null +++ b/transformers/llm/export/segment.py @@ -0,0 +1,288 @@ +import glob +import json +import os + +from utils.config import LlmConfig +from utils.mnn_converter import MNNConverter +from utils.spinner import spinner_run +from utils.tokenizer import LlmTokenizer + + +def enabled(args): + return getattr(args, 'segment', False) + + +@spinner_run(f'load segment export metadata ', True) +def load_metadata(exporter, model_path): + model_path = os.path.abspath(os.path.expanduser(model_path)) + if exporter.args.tokenizer_path == exporter.args.path: + exporter.args.tokenizer_path = model_path + else: + tokenizer_path = os.path.expanduser(exporter.args.tokenizer_path) + if os.path.exists(tokenizer_path): + tokenizer_path = os.path.abspath(tokenizer_path) + exporter.args.tokenizer_path = tokenizer_path + + exporter.config = LlmConfig.from_pretrained(model_path) + exporter.model_type = exporter.config.model_type + exporter.tokenizer = LlmTokenizer.from_pretrained( + exporter.args.tokenizer_path, + model_type=exporter.model_type + ) + exporter.model = None + exporter.visual = None + exporter.audio = None + exporter.talker = None + exporter.mtp = None + exporter.scale_emb = None + exporter.llm_config = { + 'model_type': exporter.config.model_type, + 'hidden_size': exporter.config.hidden_size, + 'layer_nums': exporter.config.num_hidden_layers, + 'attention_mask': 'float', + 'attention_type': exporter.config.attention_type, + 'is_mrope': False + } + if exporter.config.sliding_window > 0: + exporter.llm_config['sliding_window'] = exporter.config.sliding_window + if hasattr(exporter.tokenizer, 'get_chat_template'): + chat_template = exporter.tokenizer.get_chat_template() + if chat_template is not None: + exporter.llm_config['jinja'] = { + 'chat_template': chat_template + } + if exporter.tokenizer.bos_token: + exporter.llm_config['jinja']['bos'] = exporter.tokenizer.bos_token + if exporter.tokenizer.eos_token: + exporter.llm_config['jinja']['eos'] = exporter.tokenizer.eos_token + if exporter.model_type == 'glm_ocr': + exporter.llm_config['jinja'] = { + 'chat_template': "[gMASK]{% for message in messages %}{% if message.role == \"user\" %}<|user|>\n{{ message.content }}{% elif message.role == \"assistant\" %}<|assistant|>\n{{ message.content }}{% elif message.role == \"system\" %}<|system|>\n{{ message.content }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>\n{% endif %}", + 'eos': '<|endoftext|>' + } + source_llm_config = os.path.join(model_path, 'llm_config.json') + if os.path.exists(source_llm_config): + with open(source_llm_config, 'r', encoding='utf-8') as f: + exporter.llm_config.update(json.load(f)) + return model_path + + +def _resource_dirs(): + export_dir = os.path.dirname(os.path.abspath(__file__)) + repo_root = os.path.abspath(os.path.join(export_dir, '../../..')) + candidates = [ + os.path.join(repo_root, 'resource'), + os.path.join(repo_root, 'transformers', 'llm', 'resource') + ] + return [path for path in candidates if os.path.isdir(path)] + + +def _workflow_score(exporter, workflow_path): + try: + with open(workflow_path, 'r', encoding='utf-8') as f: + workflow = json.load(f) + except Exception: + return None + models = workflow.get('models', []) + if not isinstance(models, list) or len(models) == 0: + return None + + model_names = [model.get('name', '') for model in models if isinstance(model, dict)] + lowered_names = [name.lower() for name in model_names] + score = 0 + if any(name in ('hf_decoder', 'decoder', 'gpt2_decoder') for name in lowered_names): + score += 20 + if any(name in ('logit', 'logit_mobile') for name in lowered_names): + score += 20 + if any(name in ('encoder', 'encoder_mobile', 'audio_proj', 'ntp1', 'wpe') for name in lowered_names): + score -= 30 + + filename = os.path.basename(workflow_path).lower() + model_type = str(getattr(exporter, 'model_type', '') or '').lower() + if model_type and model_type in filename: + score += 15 + if 'qwen' in model_type and 'qwen' in filename: + score += 10 + if 'hf' in filename and any(name == 'hf_decoder' for name in lowered_names): + score += 10 + + blocks = [] + for model in models: + if not isinstance(model, dict): + continue + for block in model.get('blocks', []): + if isinstance(block, dict): + blocks.append(block) + + cfg_pairs = { + 'hiddenSize': getattr(exporter.config, 'hidden_size', None), + 'number': getattr(exporter.config, 'num_hidden_layers', None), + 'headDim': getattr(exporter.config, 'head_dim', None), + 'numHead': getattr(exporter.config, 'num_attention_heads', None), + 'kvNumHead': getattr(exporter.config, 'num_key_value_heads', None), + 'max_position_embeddings': getattr(getattr(exporter.config, 'origin_config', None), 'max_position_embeddings', None) + } + weights = { + 'hiddenSize': 40, + 'number': 35, + 'headDim': 20, + 'numHead': 20, + 'kvNumHead': 20, + 'max_position_embeddings': 5 + } + for key, cfg_value in cfg_pairs.items(): + if cfg_value is None or isinstance(cfg_value, list): + continue + for block in blocks: + workflow_value = block.get(key) + if workflow_value is None and key == 'max_position_embeddings': + workflow_value = block.get('maxPositionEmbeddings') + if workflow_value == cfg_value: + score += weights[key] + break + return score + + +def _resolve_workflow(exporter): + workflow = getattr(exporter.args, 'workflow', None) + if workflow: + workflow = os.path.abspath(os.path.expanduser(workflow)) + if not os.path.exists(workflow): + raise FileNotFoundError(f'workflow json not found: {workflow}') + return workflow + + candidates = [] + for resource_dir in _resource_dirs(): + for path in glob.glob(os.path.join(resource_dir, '**', '*.json'), recursive=True): + score = _workflow_score(exporter, path) + if score is not None and score > 0: + candidates.append((score, os.path.abspath(path))) + candidates.sort(key=lambda item: (-item[0], item[1])) + if not candidates: + searched = ', '.join(_resource_dirs()) + raise RuntimeError(f'--workflow is not set and no suitable workflow json was found under: {searched}') + + best_score = candidates[0][0] + best = [path for score, path in candidates if score == best_score] + if len(best) > 1: + lines = '\n'.join([f' {path}' for path in best]) + raise RuntimeError(f'--workflow is not set and multiple suitable workflow json files were found:\n{lines}\nPlease pass --workflow explicitly.') + + workflow = candidates[0][1] + print(f'--workflow is not set, use workflow json: {workflow}') + return workflow + + +def _resolve_safetensors(model_path): + model_path = os.path.abspath(os.path.expanduser(model_path)) + if os.path.isfile(model_path): + if model_path.endswith('.safetensors'): + return [model_path] + raise RuntimeError(f'--segment expects --path to be a model directory or a .safetensors file, got: {model_path}') + + model_file = os.path.join(model_path, 'model.safetensors') + if os.path.exists(model_file): + return [model_file] + + index_files = sorted(glob.glob(os.path.join(model_path, '*.safetensors.index.json'))) + if index_files: + with open(index_files[0], 'r', encoding='utf-8') as f: + index = json.load(f) + ordered = [] + for filename in index.get('weight_map', {}).values(): + if filename not in ordered: + ordered.append(filename) + paths = [os.path.join(model_path, filename) for filename in ordered] + else: + paths = sorted(glob.glob(os.path.join(model_path, '*.safetensors'))) + + paths = [path for path in paths if os.path.exists(path)] + if not paths: + raise RuntimeError(f'no safetensors file found under: {model_path}') + if len(paths) > 1: + print(f'found {len(paths)} safetensors files, pass all of them to MNNConvert') + return paths + + +def _quant_args(exporter): + quant_bit = exporter.args.quant_bit + if quant_bit == 32: + return [] + if quant_bit == 16: + return ['--fp16'] + return [ + '--weightQuantBits', + str(quant_bit), + '--weightQuantBlock', + str(exporter.args.quant_block) + ] + + +@spinner_run(f'convert safetensors model to ') +def _convert_safetensors(exporter, workflow_path, safetensors_paths): + convert_args = [ + '', + '-f', + 'ST', + '-i', + str(workflow_path) + ] + for safetensors_path in safetensors_paths: + convert_args += ['-i', str(safetensors_path)] + convert_args += [ + '-o', + str(exporter.args.dst_path), + '--allowCustomOp' + ] + if exporter.args.transformer_fuse: + convert_args += ['--transformerFuse'] + if exporter.args.group_conv_native: + convert_args += ['--groupConvNative'] + if exporter.args.sym: + convert_args += ['--weightQuantAsymmetric=0'] + convert_args += ['--saveExternalData'] + if exporter.args.hqq: + convert_args += ['--hqq'] + convert_args += _quant_args(exporter) + MNNConverter(exporter).convert(convert_args) + return exporter.args.dst_path + + +def _export_config(exporter, tokenizer_file): + with open(f'{exporter.args.dst_path}/export_args.json', 'w', encoding='utf-8') as f: + json.dump(exporter.args.__dict__, f, ensure_ascii=False, indent=4) + config_json = f'{exporter.args.dst_path}/llm_config.json' + with open(config_json, 'w', encoding='utf-8') as f: + json.dump(exporter.llm_config, f, ensure_ascii=False, indent=4) + + stop_ids = getattr(exporter.tokenizer, 'stop_ids', []) + eos_token = getattr(exporter.tokenizer, 'eos_token_id', None) + if eos_token is None and len(stop_ids) > 0: + eos_token = stop_ids[0] + if isinstance(eos_token, list): + eos_token = eos_token[0] if len(eos_token) > 0 else None + config = { + 'forwardtype': 1, + 'precision': 2, + 'memory': 2, + 'speculative': 0, + 'draft_len': 1, + 'max_decode_tokens': exporter.max_new_tokens, + 'mnn_llm_version': 'segment', + 'tokenizer_file': os.path.basename(tokenizer_file) + } + if eos_token is not None: + config['eos_token'] = int(eos_token) + with open(f'{exporter.args.dst_path}/config.json', 'w', encoding='utf-8') as f: + json.dump(config, f, ensure_ascii=False, indent=4) + return config_json + + +def export(exporter, export_type): + if export_type != 'mnn': + raise RuntimeError('--segment only supports --export mnn') + workflow = _resolve_workflow(exporter) + safetensors_paths = _resolve_safetensors(exporter.args.path) + _convert_safetensors(exporter, workflow, safetensors_paths) + tokenizer_file = exporter.export_tokenizer() + _export_config(exporter, tokenizer_file)