
[model] Add deepseek model. #274

Merged
merged 3 commits into main from model/add_deepseek on Mar 28, 2024

Conversation

marvin-Yu (Contributor)

No description provided.

@marvin-Yu marked this pull request as ready for review March 21, 2024 13:07
@@ -123,7 +123,14 @@ def split_and_convert(self, input_dir, output_dir, dtype, processes):
     config["llama"]["num_layer"] = str(hf_config["num_hidden_layers"])
     config["llama"]["layernorm_eps"] = str(hf_config.get("rms_norm_eps", 1e-6))
     config["llama"]["layernorm_type"] = "pre_layernorm"
-    config["llama"]["activation_type"] = "silu"
+    config["llama"]["activation_type"] = str(hf_config["hidden_act"])
Contributor:

Keep the LLaMa and deepseek code separate, and avoid changing the llama code as much as possible. For anything involving deepseek you can create new files, but they should build on the llama code.

Contributor Author:

Same issue as before: the deepseek model architecture simply reuses llama's "LlamaForCausalLM", so I suggest reusing the llama code and adding support for the remaining llama RoPE types.

@@ -46,5 +48,14 @@ class LlamaRotaryEmbedding {
     void llamaCalEmb(const float *inv_freq, const int max_position_embeddings);

 private:
-    static bool initialized;
+    bool initialized = false;
Contributor:

This needs to be static. As written, you will have multiple identical instances of the sin and cos tables, but one model only needs a single copy of sin and cos.
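
To make the distinction concrete, here is a minimal sketch (not the actual xFasterTransformer class; names are illustrative) of why the flag's storage class matters: the rotary embedding is typically constructed once per decoder layer, so a per-instance flag lets every layer redo the initialization, while a static flag is shared by all instances.

#include <cstdio>

class RotaryEmbeddingSketch {
public:
    RotaryEmbeddingSketch() {
        if (!initialized) {      // guards the (expensive) table computation
            initialized = true;
            printf("computing sin/cos tables\n");
        }
    }

private:
    static bool initialized;     // shared by every instance: tables built once
    // bool initialized = false; // per-instance: each layer would rebuild them
};

bool RotaryEmbeddingSketch::initialized = false;

int main() {
    RotaryEmbeddingSketch layer0, layer1; // with the static flag, only layer0 prints
    return 0;
}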

Contributor Author:

The memory that sin/cos point to is only initialized the first time; this buffer is maintained by the ctx context's memory pool.
emb_cos = ctx->getBuffer(emb_cos_str, max_position_embeddings * inv_freq_size);
emb_sin = ctx->getBuffer(emb_sin_str, max_position_embeddings * inv_freq_size);
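
A minimal sketch of the keyed-pool behavior described here: the getBuffer(name, size) call appears in the snippet above, but the pool internals below are assumed for illustration and are not xFasterTransformer's implementation. Because the pool is keyed by name, every instance asking for emb_cos/emb_sin gets a pointer to the same allocation, so the tables exist once per model even without a static member.

#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

class ContextSketch {
public:
    // Returns the existing allocation when the same name is requested again,
    // so repeated calls from different instances share one buffer.
    float *getBuffer(const std::string &name, size_t size) {
        auto it = pool.find(name);
        if (it == pool.end())
            it = pool.emplace(name, std::vector<float>(size)).first;
        return it->second.data();
    }

private:
    std::unordered_map<std::string, std::vector<float>> pool;
};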

@@ -46,5 +48,14 @@ class LlamaRotaryEmbedding {
     void llamaCalEmb(const float *inv_freq, const int max_position_embeddings);

 private:
-    static bool initialized;
+    bool initialized = false;
Contributor:

This needs to be static. As written, you will have multiple identical instances of the sin and cos tables, but one model only needs a single copy of sin and cos.

Contributor Author:

Same as above.

 #pragma omp parallel for
     for (size_t i = 0; i < max_position_embeddings; i++) {
         float *pcos = emb_cos + i * inv_freq_size;
         float *psin = emb_sin + i * inv_freq_size;

         for (size_t j = 0; j < inv_freq_size; j++) {
-            float tmp = i * inv_freq[j];
+            float tmp = i * inv_freq[j] / this->scaling_factor;
Contributor:

Keep the LLaMa and deepseek code separate, and avoid changing the llama code as much as possible. For anything involving deepseek you can create new files, but they should build on the llama code. Here you could create a new deepseek rope file.

Contributor Author:

deepseek, like Yi, directly reuses the llama model architecture, and LinearScaling RoPE is also supported within the llama model.
config.json: https://huggingface.co/deepseek-ai/deepseek-coder-33b-instruct/blob/main/config.json#L3
LinearScaling rope: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L148-L155
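
For reference, a self-contained sketch of the linear-scaling RoPE behavior discussed here, mirroring the one-line diff above and the linked transformers implementation (the function and parameter names are illustrative, not xFasterTransformer's API): each position index is divided by scaling_factor before the sin/cos tables are built, which stretches the pretrained context window, and a factor of 1.0 reduces to plain RoPE. The linked deepseek config selects this path via a "rope_scaling" entry of type "linear".

#include <cmath>
#include <cstddef>

// Builds cos/sin tables of shape [max_position_embeddings, inv_freq_size]
// with linear position scaling: angle = (i / scaling_factor) * inv_freq[j].
void calcEmbLinearScaling(const float *inv_freq, size_t inv_freq_size,
                          size_t max_position_embeddings, float scaling_factor,
                          float *emb_cos, float *emb_sin) {
    for (size_t i = 0; i < max_position_embeddings; i++) {
        for (size_t j = 0; j < inv_freq_size; j++) {
            float angle = i * inv_freq[j] / scaling_factor; // the one-line change
            emb_cos[i * inv_freq_size + j] = cosf(angle);
            emb_sin[i * inv_freq_size + j] = sinf(angle);
        }
    }
}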

Contributor:

Done. This scaling_factor is a LLaMa param.

@@ -117,7 +117,7 @@ def build_inputs_chatglm(tokenizer, query: List[str], padding, history: List[Tup
         model_prompt = prompt_pool["chatglm2"]
     if "chatglm3" in args.model_name.lower():
         model_prompt = prompt_pool["chatglm3"]
-    if "llama" in args.model_name.lower():
+    if "llama" in args.model_name.lower() or "deepseek" in args.model_name.lower():
Contributor:

Use a separate if for deepseek.

Contributor Author:

deepseek is structurally identical to llama; I suggest reusing the llama path.

Contributor:

Add if

README.md Outdated
@@ -141,6 +142,7 @@ xFasterTransformer supports a different model format from Huggingface, but it's

 Supported model convert list:
   - LlamaConvert
+  - DeepseekConvert
Contributor:

Move it further down the list.

@@ -34,6 +34,7 @@ def with_mpirun():
     "automodel": ["AutoModel"],
     "tools": [
         "LlamaConvert",
+        "DeepseekConvert",
Contributor:

Move it further down the list.

@marvin-Yu merged commit b29259a into main Mar 28, 2024
1 check passed
@marvin-Yu deleted the model/add_deepseek branch March 29, 2024 08:37