V · 2023年12月19日 · 黑龙江

「X」Embedding in NLP|神经网络和语言模型 Embedding 向量入门

在「X」Embedding in NLP 进阶系列中,我们介绍了自然语言处理的基础知识——自然语言中的 Token、N-gram 和词袋语言模型。今天,我们将继续和大家一起“修炼”,深入探讨神经网络语言模型,特别是循环神经网络,并简要了解如何生成 Embedding 向量。

01.深入了解神经网络

首先,简要回顾一下神经网络的构成,即神经元、多层网络和反向传播算法。如果还想更详细深入了解这些基本概念可以参考其他资源,如 CS231n 课程笔记

在机器学习中,神经元是构成所有神经网络的基本单元。本质上,神经元是神经网络中的一个单元,它对其所有输入进行加权求和,并加上一个可选的偏置项。方程式表示如下所示:

<svg xmlns="http://www.w3.org/2000/svg" width="27.929ex" height="2.922ex" role="img" focusable="false" viewBox="0 -948 12344.5 1291.3" style="vertical-align: -0.777ex;"><g stroke="currentColor" fill="currentColor" stroke-width="0" transform="matrix(1 0 0 -1 0 0)"><g data-mml-node="math"><g data-mml-node="mi"><path data-c="71" d="M33 157Q33 258 109 349T280 441Q340 441 372 389Q373 390 377 395T388 406T404 418Q438 442 450 442Q454 442 457 439T460 434Q460 425 391 149Q320 -135 320 -139Q320 -147 365 -148H390Q396 -156 396 -157T393 -175Q389 -188 383 -194H370Q339 -192 262 -192Q234 -192 211 -192T174 -192T157 -193Q143 -193 143 -185Q143 -182 145 -170Q149 -154 152 -151T172 -148Q220 -148 230 -141Q238 -136 258 -53T279 32Q279 33 272 29Q224 -10 172 -10Q117 -10 75 30T33 157ZM352 326Q329 405 277 405Q242 405 210 374T160 293Q131 214 119 129Q119 126 119 118T118 106Q118 61 136 44T179 26Q233 26 290 98L298 109L352 326Z"></path></g><g data-mml-node="mo" transform="translate(737.8, 0)"><path data-c="3D" d="M56 347Q56 360 70 367H707Q722 359 722 347Q722 336 708 328L390 327H72Q56 332 56 347ZM56 153Q56 168 72 173H708Q722 163 722 153Q722 140 707 133H70Q56 140 56 153Z"></path></g><g data-mml-node="munderover" transform="translate(1793.6, 0)"><g data-mml-node="mo"><path data-c="2211" d="M61 748Q64 750 489 750H913L954 640Q965 609 976 579T993 533T999 516H979L959 517Q936 579 886 621T777 682Q724 700 655 705T436 710H319Q183 710 183 709Q186 706 348 484T511 259Q517 250 513 244L490 216Q466 188 420 134T330 27L149 -187Q149 -188 362 -188Q388 -188 436 -188T506 -189Q679 -189 778 -162T936 -43Q946 -27 959 6H999L913 -249L489 -250Q65 -250 62 -248Q56 -246 56 -239Q56 -234 118 -161Q186 -81 245 -11L428 206Q428 207 242 462L57 717L56 728Q56 744 61 748Z"></path></g><g data-mml-node="TeXAtom" transform="translate(1056, 477.1) scale(0.707)" data-mjx-texclass="ORD"><g data-mml-node="mi"><path data-c="6E" d="M21 287Q22 293 24 303T36 341T56 388T89 425T135 442Q171 442 195 424T225 390T231 369Q231 367 232 367L243 378Q304 442 382 442Q436 442 469 415T503 336T465 179T427 52Q427 26 444 26Q450 26 453 27Q482 32 505 65T540 145Q542 153 560 153Q580 153 580 145Q580 144 576 130Q568 101 554 73T508 17T439 -10Q392 -10 371 17T350 73Q350 92 386 193T423 345Q423 404 379 404H374Q288 404 229 303L222 291L189 157Q156 26 151 16Q138 -11 108 -11Q95 -11 87 -5T76 7T74 17Q74 30 112 180T152 343Q153 348 153 366Q153 405 129 405Q91 405 66 305Q60 285 60 284Q58 278 41 278H27Q21 284 21 287Z"></path></g><g data-mml-node="mo" transform="translate(600, 0)"><path data-c="2212" d="M84 237T84 250T98 270H679Q694 262 694 250T679 230H98Q84 237 84 250Z"></path></g><g data-mml-node="mn" transform="translate(1378, 0)"><path data-c="31" d="M213 578L200 573Q186 568 160 563T102 556H83V602H102Q149 604 189 617T245 641T273 663Q275 666 285 666Q294 666 302 660V361L303 61Q310 54 315 52T339 48T401 46H427V0H416Q395 3 257 3Q121 3 100 0H88V46H114Q136 46 152 46T177 47T193 50T201 52T207 57T213 61V578Z"></path></g></g><g data-mml-node="TeXAtom" transform="translate(1056, -285.4) scale(0.707)" data-mjx-texclass="ORD"><g data-mml-node="mi"><path data-c="69" d="M184 600Q184 624 203 642T247 661Q265 661 277 649T290 619Q290 596 270 577T226 557Q211 557 198 567T184 600ZM21 287Q21 295 30 318T54 369T98 420T158 442Q197 442 223 419T250 357Q250 340 236 301T196 196T154 83Q149 61 149 51Q149 26 166 26Q175 26 185 29T208 43T235 78T260 137Q263 149 265 151T282 153Q302 153 302 143Q302 135 293 112T268 61T223 11T161 -11Q129 -11 102 10T74 74Q74 91 79 106T122 220Q160 321 166 341T173 380Q173 404 156 404H154Q124 404 99 371T61 287Q60 286 59 284T58 281T56 279T53 278T49 278T41 278H27Q21 284 21 287Z"></path></g><g data-mml-node="mo" transform="translate(345, 0)"><path data-c="3D" d="M56 347Q56 360 70 367H707Q722 359 722 347Q722 336 708 328L390 327H72Q56 332 56 347ZM56 153Q56 168 72 173H708Q722 163 722 153Q722 140 707 133H70Q56 140 56 153Z"></path></g><g data-mml-node="mn" transform="translate(1123, 0)"><path data-c="30" d="M96 585Q152 666 249 666Q297 666 345 640T423 548Q460 465 460 320Q460 165 417 83Q397 41 362 16T301 -15T250 -22Q224 -22 198 -16T137 16T82 83Q39 165 39 320Q39 494 96 585ZM321 597Q291 629 250 629Q208 629 178 597Q153 571 145 525T137 333Q137 175 145 125T181 46Q209 16 250 16Q290 16 318 46Q347 76 354 130T362 333Q362 478 354 524T321 597Z"></path></g></g></g><g data-mml-node="msub" transform="translate(4394.2, 0)"><g data-mml-node="mi"><path data-c="77" d="M580 385Q580 406 599 424T641 443Q659 443 674 425T690 368Q690 339 671 253Q656 197 644 161T609 80T554 12T482 -11Q438 -11 404 5T355 48Q354 47 352 44Q311 -11 252 -11Q226 -11 202 -5T155 14T118 53T104 116Q104 170 138 262T173 379Q173 380 173 381Q173 390 173 393T169 400T158 404H154Q131 404 112 385T82 344T65 302T57 280Q55 278 41 278H27Q21 284 21 287Q21 293 29 315T52 366T96 418T161 441Q204 441 227 416T250 358Q250 340 217 250T184 111Q184 65 205 46T258 26Q301 26 334 87L339 96V119Q339 122 339 128T340 136T341 143T342 152T345 165T348 182T354 206T362 238T373 281Q402 395 406 404Q419 431 449 431Q468 431 475 421T483 402Q483 389 454 274T422 142Q420 131 420 107V100Q420 85 423 71T442 42T487 26Q558 26 600 148Q609 171 620 213T632 273Q632 306 619 325T593 357T580 385Z"></path></g><g data-mml-node="mi" transform="translate(716, -150) scale(0.707)"><path data-c="69" d="M184 600Q184 624 203 642T247 661Q265 661 277 649T290 619Q290 596 270 577T226 557Q211 557 198 567T184 600ZM21 287Q21 295 30 318T54 369T98 420T158 442Q197 442 223 419T250 357Q250 340 236 301T196 196T154 83Q149 61 149 51Q149 26 166 26Q175 26 185 29T208 43T235 78T260 137Q263 149 265 151T282 153Q302 153 302 143Q302 135 293 112T268 61T223 11T161 -11Q129 -11 102 10T74 74Q74 91 79 106T122 220Q160 321 166 341T173 380Q173 404 156 404H154Q124 404 99 371T61 287Q60 286 59 284T58 281T56 279T53 278T49 278T41 278H27Q21 284 21 287Z"></path></g></g><g data-mml-node="msub" transform="translate(5404.1, 0)"><g data-mml-node="mi"><path data-c="78" d="M52 289Q59 331 106 386T222 442Q257 442 286 424T329 379Q371 442 430 442Q467 442 494 420T522 361Q522 332 508 314T481 292T458 288Q439 288 427 299T415 328Q415 374 465 391Q454 404 425 404Q412 404 406 402Q368 386 350 336Q290 115 290 78Q290 50 306 38T341 26Q378 26 414 59T463 140Q466 150 469 151T485 153H489Q504 153 504 145Q504 144 502 134Q486 77 440 33T333 -11Q263 -11 227 52Q186 -10 133 -10H127Q78 -10 57 16T35 71Q35 103 54 123T99 143Q142 143 142 101Q142 81 130 66T107 46T94 41L91 40Q91 39 97 36T113 29T132 26Q168 26 194 71Q203 87 217 139T245 247T261 313Q266 340 266 352Q266 380 251 392T217 404Q177 404 142 372T93 290Q91 281 88 280T72 278H58Q52 284 52 289Z"></path></g><g data-mml-node="mi" transform="translate(572, -150) scale(0.707)"><path data-c="69" d="M184 600Q184 624 203 642T247 661Q265 661 277 649T290 619Q290 596 270 577T226 557Q211 557 198 567T184 600ZM21 287Q21 295 30 318T54 369T98 420T158 442Q197 442 223 419T250 357Q250 340 236 301T196 196T154 83Q149 61 149 51Q149 26 166 26Q175 26 185 29T208 43T235 78T260 137Q263 149 265 151T282 153Q302 153 302 143Q302 135 293 112T268 61T223 11T161 -11Q129 -11 102 10T74 74Q74 91 79 106T122 220Q160 321 166 341T173 380Q173 404 156 404H154Q124 404 99 371T61 287Q60 286 59 284T58 281T56 279T53 278T49 278T41 278H27Q21 284 21 287Z"></path></g></g><g data-mml-node="mo" transform="translate(6492.3, 0)"><path data-c="2B" d="M56 237T56 250T70 270H369V420L370 570Q380 583 389 583Q402 583 409 568V270H707Q722 262 722 250T707 230H409V-68Q401 -82 391 -82H389H387Q375 -82 369 -68V230H70Q56 237 56 250Z"></path></g><g data-mml-node="mi" transform="translate(7492.5, 0)"><path data-c="62" d="M73 647Q73 657 77 670T89 683Q90 683 161 688T234 694Q246 694 246 685T212 542Q204 508 195 472T180 418L176 399Q176 396 182 402Q231 442 283 442Q345 442 383 396T422 280Q422 169 343 79T173 -11Q123 -11 82 27T40 150V159Q40 180 48 217T97 414Q147 611 147 623T109 637Q104 637 101 637H96Q86 637 83 637T76 640T73 647ZM336 325V331Q336 405 275 405Q258 405 240 397T207 376T181 352T163 330L157 322L136 236Q114 150 114 114Q114 66 138 42Q154 26 178 26Q211 26 245 58Q270 81 285 114T318 219Q336 291 336 325Z"></path></g><g data-mml-node="mo" transform="translate(8199.3, 0)"><path data-c="3D" d="M56 347Q56 360 70 367H707Q722 359 722 347Q722 336 708 328L390 327H72Q56 332 56 347ZM56 153Q56 168 72 173H708Q722 163 722 153Q722 140 707 133H70Q56 140 56 153Z"></path></g><g data-mml-node="TeXAtom" data-mjx-texclass="ORD" transform="translate(9255.1, 0)"><g data-mml-node="mi"><path data-c="77" d="M624 444Q636 441 722 441Q797 441 800 444H805V382H741L593 11Q592 10 590 8T586 4T584 2T581 0T579 -2T575 -3T571 -3T567 -4T561 -4T553 -4H542Q525 -4 518 6T490 70Q474 110 463 137L415 257L367 137Q357 111 341 72Q320 17 313 7T289 -4H277Q259 -4 253 -2T238 11L90 382H25V444H32Q47 441 140 441Q243 441 261 444H270V382H222L310 164L382 342L366 382H303V444H310Q322 441 407 441Q508 441 523 444H531V382H506Q481 382 481 380Q482 376 529 259T577 142L674 382H617V444H624Z"></path></g><g data-mml-node="mi" transform="translate(831, 0)"><path data-c="78" d="M227 0Q212 3 121 3Q40 3 28 0H21V62H117L245 213L109 382H26V444H34Q49 441 143 441Q247 441 265 444H274V382H246L281 339Q315 297 316 297Q320 297 354 341L389 382H352V444H360Q375 441 466 441Q547 441 559 444H566V382H471L355 246L504 63L545 62H586V0H578Q563 3 469 3Q365 3 347 0H338V62H366Q366 63 326 112T285 163L198 63L217 62H235V0H227Z"></path></g></g><g data-mml-node="mo" transform="translate(10915.3, 0)"><path data-c="2B" d="M56 237T56 250T70 270H369V420L370 570Q380 583 389 583Q402 583 409 568V270H707Q722 262 722 250T707 230H409V-68Q401 -82 391 -82H389H387Q375 -82 369 -68V230H70Q56 237 56 250Z"></path></g><g data-mml-node="mi" transform="translate(11915.5, 0)"><path data-c="62" d="M73 647Q73 657 77 670T89 683Q90 683 161 688T234 694Q246 694 246 685T212 542Q204 508 195 472T180 418L176 399Q176 396 182 402Q231 442 283 442Q345 442 383 396T422 280Q422 169 343 79T173 -11Q123 -11 82 27T40 150V159Q40 180 48 217T97 414Q147 611 147 623T109 637Q104 637 101 637H96Q86 637 83 637T76 640T73 647ZM336 325V331Q336 405 275 405Q258 405 240 397T207 376T181 352T163 330L157 322L136 236Q114 150 114 114Q114 66 138 42Q154 26 178 26Q211 26 245 58Q270 81 285 114T318 219Q336 291 336 325Z"></path></g></g></g></svg>

在这里,<math xmlns="http://www.w3.org/1998/Math/MathML">
<msub>

<mi>x</mi>
<mn>0</mn>

</msub>
<mo>,</mo>
<msub>

<mi>x</mi>
<mn>1</mn>

</msub>
<mo>,</mo>
<mo>.</mo>
<mo>.</mo>
<mo>.</mo>
<mo>,</mo>
<msub>

<mi>x</mi>
<mrow>
  <mi>n</mi>
  <mo>−</mo>
  <mn>1</mn>
</mrow>

</msub>
</math> 代表上一层神经元的输出,<math xmlns="http://www.w3.org/1998/Math/MathML">
<msub>

<mi>w</mi>
<mn>0</mn>

</msub>
<mo>,</mo>
<msub>

<mi>w</mi>
<mn>1</mn>

</msub>
<mo>,</mo>
<mo>.</mo>
<mo>.</mo>
<mo>.</mo>
<mo>,</mo>
<msub>

<mi>w</mi>
<mrow>
  <mi>n</mi>
  <mo>−</mo>
  <mn>1</mn>
</mrow>

</msub>
</math>代表这个神经元用来综合输出值的权重。

如果一个多层神经网络仅由上述方程中的加权和组成,我们可以将所有项合并为一个单一的线性层——这对于建模 Token 之间的关系或编码复杂文本并不是很理想。这就是为什么所有神经元在加权和之后都包含一个非线性激活函数,其中我们最熟知的例子就是修正线性单元(ReLU)函数:

<svg xmlns="http://www.w3.org/2000/svg" width="24.407ex" height="5.43ex" role="img" focusable="false" viewBox="0 -1450 10788.1 2400" style="vertical-align: -2.149ex;"><g stroke="currentColor" fill="currentColor" stroke-width="0" transform="matrix(1 0 0 -1 0 0)"><g data-mml-node="math"><g data-mml-node="mrow"><g data-mml-node="mo"></g><g data-mml-node="mi"><path data-c="52" d="M230 637Q203 637 198 638T193 649Q193 676 204 682Q206 683 378 683Q550 682 564 680Q620 672 658 652T712 606T733 563T739 529Q739 484 710 445T643 385T576 351T538 338L545 333Q612 295 612 223Q612 212 607 162T602 80V71Q602 53 603 43T614 25T640 16Q668 16 686 38T712 85Q717 99 720 102T735 105Q755 105 755 93Q755 75 731 36Q693 -21 641 -21H632Q571 -21 531 4T487 82Q487 109 502 166T517 239Q517 290 474 313Q459 320 449 321T378 323H309L277 193Q244 61 244 59Q244 55 245 54T252 50T269 48T302 46H333Q339 38 339 37T336 19Q332 6 326 0H311Q275 2 180 2Q146 2 117 2T71 2T50 1Q33 1 33 10Q33 12 36 24Q41 43 46 45Q50 46 61 46H67Q94 46 127 49Q141 52 146 61Q149 65 218 339T287 628Q287 635 230 637ZM630 554Q630 586 609 608T523 636Q521 636 500 636T462 637H440Q393 637 386 627Q385 624 352 494T319 361Q319 360 388 360Q466 361 492 367Q556 377 592 426Q608 449 619 486T630 554Z"></path></g><g data-mml-node="mi" transform="translate(759, 0)"><path data-c="65" d="M39 168Q39 225 58 272T107 350T174 402T244 433T307 442H310Q355 442 388 420T421 355Q421 265 310 237Q261 224 176 223Q139 223 138 221Q138 219 132 186T125 128Q125 81 146 54T209 26T302 45T394 111Q403 121 406 121Q410 121 419 112T429 98T420 82T390 55T344 24T281 -1T205 -11Q126 -11 83 42T39 168ZM373 353Q367 405 305 405Q272 405 244 391T199 357T170 316T154 280T149 261Q149 260 169 260Q282 260 327 284T373 353Z"></path></g><g data-mml-node="mi" transform="translate(1225, 0)"><path data-c="4C" d="M228 637Q194 637 192 641Q191 643 191 649Q191 673 202 682Q204 683 217 683Q271 680 344 680Q485 680 506 683H518Q524 677 524 674T522 656Q517 641 513 637H475Q406 636 394 628Q387 624 380 600T313 336Q297 271 279 198T252 88L243 52Q243 48 252 48T311 46H328Q360 46 379 47T428 54T478 72T522 106T564 161Q580 191 594 228T611 270Q616 273 628 273H641Q647 264 647 262T627 203T583 83T557 9Q555 4 553 3T537 0T494 -1Q483 -1 418 -1T294 0H116Q32 0 32 10Q32 17 34 24Q39 43 44 45Q48 46 59 46H65Q92 46 125 49Q139 52 144 61Q147 65 216 339T285 628Q285 635 228 637Z"></path></g><g data-mml-node="mi" transform="translate(1906, 0)"><path data-c="55" d="M107 637Q73 637 71 641Q70 643 70 649Q70 673 81 682Q83 683 98 683Q139 681 234 681Q268 681 297 681T342 682T362 682Q378 682 378 672Q378 670 376 658Q371 641 366 638H364Q362 638 359 638T352 638T343 637T334 637Q295 636 284 634T266 623Q265 621 238 518T184 302T154 169Q152 155 152 140Q152 86 183 55T269 24Q336 24 403 69T501 205L552 406Q599 598 599 606Q599 633 535 637Q511 637 511 648Q511 650 513 660Q517 676 519 679T529 683Q532 683 561 682T645 680Q696 680 723 681T752 682Q767 682 767 672Q767 650 759 642Q756 637 737 637Q666 633 648 597Q646 592 598 404Q557 235 548 205Q515 105 433 42T263 -22Q171 -22 116 34T60 167V183Q60 201 115 421Q164 622 164 628Q164 635 107 637Z"></path></g><g data-mml-node="mo" transform="translate(2673, 0)"><path data-c="28" d="M94 250Q94 319 104 381T127 488T164 576T202 643T244 695T277 729T302 750H315H319Q333 750 333 741Q333 738 316 720T275 667T226 581T184 443T167 250T184 58T225 -81T274 -167T316 -220T333 -241Q333 -250 318 -250H315H302L274 -226Q180 -141 137 -14T94 250Z"></path></g><g data-mml-node="mi" transform="translate(3062, 0)"><path data-c="71" d="M33 157Q33 258 109 349T280 441Q340 441 372 389Q373 390 377 395T388 406T404 418Q438 442 450 442Q454 442 457 439T460 434Q460 425 391 149Q320 -135 320 -139Q320 -147 365 -148H390Q396 -156 396 -157T393 -175Q389 -188 383 -194H370Q339 -192 262 -192Q234 -192 211 -192T174 -192T157 -193Q143 -193 143 -185Q143 -182 145 -170Q149 -154 152 -151T172 -148Q220 -148 230 -141Q238 -136 258 -53T279 32Q279 33 272 29Q224 -10 172 -10Q117 -10 75 30T33 157ZM352 326Q329 405 277 405Q242 405 210 374T160 293Q131 214 119 129Q119 126 119 118T118 106Q118 61 136 44T179 26Q233 26 290 98L298 109L352 326Z"></path></g><g data-mml-node="mo" transform="translate(3522, 0)"><path data-c="29" d="M60 749L64 750Q69 750 74 750H86L114 726Q208 641 251 514T294 250Q294 182 284 119T261 12T224 -76T186 -143T145 -194T113 -227T90 -246Q87 -249 86 -250H74Q66 -250 63 -250T58 -247T55 -238Q56 -237 66 -225Q221 -64 221 250T66 725Q56 737 55 738Q55 746 60 749Z"></path></g><g data-mml-node="mo" transform="translate(4188.8, 0)"><path data-c="3D" d="M56 347Q56 360 70 367H707Q722 359 722 347Q722 336 708 328L390 327H72Q56 332 56 347ZM56 153Q56 168 72 173H708Q722 163 722 153Q722 140 707 133H70Q56 140 56 153Z"></path></g><g data-mml-node="mrow" transform="translate(5244.6, 0)"><g data-mml-node="mo"><path data-c="7B" d="M618 -943L612 -949H582L568 -943Q472 -903 411 -841T332 -703Q327 -682 327 -653T325 -350Q324 -28 323 -18Q317 24 301 61T264 124T221 171T179 205T147 225T132 234Q130 238 130 250Q130 255 130 258T131 264T132 267T134 269T139 272T144 275Q207 308 256 367Q310 436 323 519Q324 529 325 851Q326 1124 326 1154T332 1205Q369 1358 566 1443L582 1450H612L618 1444V1429Q618 1413 616 1411L608 1406Q599 1402 585 1393T552 1372T515 1343T479 1305T449 1257T429 1200Q425 1180 425 1152T423 851Q422 579 422 549T416 498Q407 459 388 424T346 364T297 318T250 284T214 264T197 254L188 251L205 242Q290 200 345 138T416 3Q421 -18 421 -48T423 -349Q423 -397 423 -472Q424 -677 428 -694Q429 -697 429 -699Q434 -722 443 -743T465 -782T491 -816T519 -845T548 -868T574 -886T595 -899T610 -908L616 -910Q618 -912 618 -928V-943Z"></path></g><g data-mml-node="mtable" transform="translate(750, 0)"><g data-mml-node="mtr" transform="translate(0, 700)"><g data-mml-node="mtd"></g><g data-mml-node="mtd" transform="translate(1000, 0)"><g data-mml-node="mn"><path data-c="30" d="M96 585Q152 666 249 666Q297 666 345 640T423 548Q460 465 460 320Q460 165 417 83Q397 41 362 16T301 -15T250 -22Q224 -22 198 -16T137 16T82 83Q39 165 39 320Q39 494 96 585ZM321 597Q291 629 250 629Q208 629 178 597Q153 571 145 525T137 333Q137 175 145 125T181 46Q209 16 250 16Q290 16 318 46Q347 76 354 130T362 333Q362 478 354 524T321 597Z"></path></g></g><g data-mml-node="mtd" transform="translate(2500, 0)"><g data-mml-node="mi"><path data-c="71" d="M33 157Q33 258 109 349T280 441Q340 441 372 389Q373 390 377 395T388 406T404 418Q438 442 450 442Q454 442 457 439T460 434Q460 425 391 149Q320 -135 320 -139Q320 -147 365 -148H390Q396 -156 396 -157T393 -175Q389 -188 383 -194H370Q339 -192 262 -192Q234 -192 211 -192T174 -192T157 -193Q143 -193 143 -185Q143 -182 145 -170Q149 -154 152 -151T172 -148Q220 -148 230 -141Q238 -136 258 -53T279 32Q279 33 272 29Q224 -10 172 -10Q117 -10 75 30T33 157ZM352 326Q329 405 277 405Q242 405 210 374T160 293Q131 214 119 129Q119 126 119 118T118 106Q118 61 136 44T179 26Q233 26 290 98L298 109L352 326Z"></path></g><g data-mml-node="mo" transform="translate(737.8, 0)"><path data-c="2264" d="M674 636Q682 636 688 630T694 615T687 601Q686 600 417 472L151 346L399 228Q687 92 691 87Q694 81 694 76Q694 58 676 56H670L382 192Q92 329 90 331Q83 336 83 348Q84 359 96 365Q104 369 382 500T665 634Q669 636 674 636ZM84 -118Q84 -108 99 -98H678Q694 -104 694 -118Q694 -130 679 -138H98Q84 -131 84 -118Z"></path></g><g data-mml-node="mn" transform="translate(1793.6, 0)"><path data-c="30" d="M96 585Q152 666 249 666Q297 666 345 640T423 548Q460 465 460 320Q460 165 417 83Q397 41 362 16T301 -15T250 -22Q224 -22 198 -16T137 16T82 83Q39 165 39 320Q39 494 96 585ZM321 597Q291 629 250 629Q208 629 178 597Q153 571 145 525T137 333Q137 175 145 125T181 46Q209 16 250 16Q290 16 318 46Q347 76 354 130T362 333Q362 478 354 524T321 597Z"></path></g></g></g><g data-mml-node="mtr" transform="translate(0, -700)"><g data-mml-node="mtd"></g><g data-mml-node="mtd" transform="translate(1020, 0)"><g data-mml-node="mi"><path data-c="71" d="M33 157Q33 258 109 349T280 441Q340 441 372 389Q373 390 377 395T388 406T404 418Q438 442 450 442Q454 442 457 439T460 434Q460 425 391 149Q320 -135 320 -139Q320 -147 365 -148H390Q396 -156 396 -157T393 -175Q389 -188 383 -194H370Q339 -192 262 -192Q234 -192 211 -192T174 -192T157 -193Q143 -193 143 -185Q143 -182 145 -170Q149 -154 152 -151T172 -148Q220 -148 230 -141Q238 -136 258 -53T279 32Q279 33 272 29Q224 -10 172 -10Q117 -10 75 30T33 157ZM352 326Q329 405 277 405Q242 405 210 374T160 293Q131 214 119 129Q119 126 119 118T118 106Q118 61 136 44T179 26Q233 26 290 98L298 109L352 326Z"></path></g></g><g data-mml-node="mtd" transform="translate(2520, 0)"><g data-mml-node="mi"><path data-c="71" d="M33 157Q33 258 109 349T280 441Q340 441 372 389Q373 390 377 395T388 406T404 418Q438 442 450 442Q454 442 457 439T460 434Q460 425 391 149Q320 -135 320 -139Q320 -147 365 -148H390Q396 -156 396 -157T393 -175Q389 -188 383 -194H370Q339 -192 262 -192Q234 -192 211 -192T174 -192T157 -193Q143 -193 143 -185Q143 -182 145 -170Q149 -154 152 -151T172 -148Q220 -148 230 -141Q238 -136 258 -53T279 32Q279 33 272 29Q224 -10 172 -10Q117 -10 75 30T33 157ZM352 326Q329 405 277 405Q242 405 210 374T160 293Q131 214 119 129Q119 126 119 118T118 106Q118 61 136 44T179 26Q233 26 290 98L298 109L352 326Z"></path></g><g data-mml-node="mo" transform="translate(737.8, 0)"><path data-c="2265" d="M83 616Q83 624 89 630T99 636Q107 636 253 568T543 431T687 361Q694 356 694 346T687 331Q685 329 395 192L107 56H101Q83 58 83 76Q83 77 83 79Q82 86 98 95Q117 105 248 167Q326 204 378 228L626 346L360 472Q291 505 200 548Q112 589 98 597T83 616ZM84 -118Q84 -108 99 -98H678Q694 -104 694 -118Q694 -130 679 -138H98Q84 -131 84 -118Z"></path></g><g data-mml-node="mi" transform="translate(1793.6, 0)"><path data-c="71" d="M33 157Q33 258 109 349T280 441Q340 441 372 389Q373 390 377 395T388 406T404 418Q438 442 450 442Q454 442 457 439T460 434Q460 425 391 149Q320 -135 320 -139Q320 -147 365 -148H390Q396 -156 396 -157T393 -175Q389 -188 383 -194H370Q339 -192 262 -192Q234 -192 211 -192T174 -192T157 -193Q143 -193 143 -185Q143 -182 145 -170Q149 -154 152 -151T172 -148Q220 -148 230 -141Q238 -136 258 -53T279 32Q279 33 272 29Q224 -10 172 -10Q117 -10 75 30T33 157ZM352 326Q329 405 277 405Q242 405 210 374T160 293Q131 214 119 129Q119 126 119 118T118 106Q118 61 136 44T179 26Q233 26 290 98L298 109L352 326Z"></path></g></g></g></g><g data-mml-node="mo" transform="translate(5543.6, 0)"></g></g><g data-mml-node="mo" transform="translate(10788.1, 0)"></g></g></g></g></svg>

对于大多数现代神经网络语言模型来说,高斯误差线性单元(GELU)激活函数更常见:

<svg xmlns="http://www.w3.org/2000/svg" width="18.578ex" height="2.262ex" role="img" focusable="false" viewBox="0 -750 8211.6 1000" style="vertical-align: -0.566ex;"><g stroke="currentColor" fill="currentColor" stroke-width="0" transform="matrix(1 0 0 -1 0 0)"><g data-mml-node="math"><g data-mml-node="mi"><path data-c="47" d="M50 252Q50 367 117 473T286 641T490 704Q580 704 633 653Q642 643 648 636T656 626L657 623Q660 623 684 649Q691 655 699 663T715 679T725 690L740 705H746Q760 705 760 698Q760 694 728 561Q692 422 692 421Q690 416 687 415T669 413H653Q647 419 647 422Q647 423 648 429T650 449T651 481Q651 552 619 605T510 659Q492 659 471 656T418 643T357 615T294 567T236 496T189 394T158 260Q156 242 156 221Q156 173 170 136T206 79T256 45T308 28T353 24Q407 24 452 47T514 106Q517 114 529 161T541 214Q541 222 528 224T468 227H431Q425 233 425 235T427 254Q431 267 437 273H454Q494 271 594 271Q634 271 659 271T695 272T707 272Q721 272 721 263Q721 261 719 249Q714 230 709 228Q706 227 694 227Q674 227 653 224Q646 221 643 215T629 164Q620 131 614 108Q589 6 586 3Q584 1 581 1Q571 1 553 21T530 52Q530 53 528 52T522 47Q448 -22 322 -22Q201 -22 126 55T50 252Z"></path></g><g data-mml-node="mi" transform="translate(786, 0)"><path data-c="45" d="M492 213Q472 213 472 226Q472 230 477 250T482 285Q482 316 461 323T364 330H312Q311 328 277 192T243 52Q243 48 254 48T334 46Q428 46 458 48T518 61Q567 77 599 117T670 248Q680 270 683 272Q690 274 698 274Q718 274 718 261Q613 7 608 2Q605 0 322 0H133Q31 0 31 11Q31 13 34 25Q38 41 42 43T65 46Q92 46 125 49Q139 52 144 61Q146 66 215 342T285 622Q285 629 281 629Q273 632 228 634H197Q191 640 191 642T193 659Q197 676 203 680H757Q764 676 764 669Q764 664 751 557T737 447Q735 440 717 440H705Q698 445 698 453L701 476Q704 500 704 528Q704 558 697 578T678 609T643 625T596 632T532 634H485Q397 633 392 631Q388 629 386 622Q385 619 355 499T324 377Q347 376 372 376H398Q464 376 489 391T534 472Q538 488 540 490T557 493Q562 493 565 493T570 492T572 491T574 487T577 483L544 351Q511 218 508 216Q505 213 492 213Z"></path></g><g data-mml-node="mi" transform="translate(1550, 0)"><path data-c="4C" d="M228 637Q194 637 192 641Q191 643 191 649Q191 673 202 682Q204 683 217 683Q271 680 344 680Q485 680 506 683H518Q524 677 524 674T522 656Q517 641 513 637H475Q406 636 394 628Q387 624 380 600T313 336Q297 271 279 198T252 88L243 52Q243 48 252 48T311 46H328Q360 46 379 47T428 54T478 72T522 106T564 161Q580 191 594 228T611 270Q616 273 628 273H641Q647 264 647 262T627 203T583 83T557 9Q555 4 553 3T537 0T494 -1Q483 -1 418 -1T294 0H116Q32 0 32 10Q32 17 34 24Q39 43 44 45Q48 46 59 46H65Q92 46 125 49Q139 52 144 61Q147 65 216 339T285 628Q285 635 228 637Z"></path></g><g data-mml-node="mi" transform="translate(2231, 0)"><path data-c="55" d="M107 637Q73 637 71 641Q70 643 70 649Q70 673 81 682Q83 683 98 683Q139 681 234 681Q268 681 297 681T342 682T362 682Q378 682 378 672Q378 670 376 658Q371 641 366 638H364Q362 638 359 638T352 638T343 637T334 637Q295 636 284 634T266 623Q265 621 238 518T184 302T154 169Q152 155 152 140Q152 86 183 55T269 24Q336 24 403 69T501 205L552 406Q599 598 599 606Q599 633 535 637Q511 637 511 648Q511 650 513 660Q517 676 519 679T529 683Q532 683 561 682T645 680Q696 680 723 681T752 682Q767 682 767 672Q767 650 759 642Q756 637 737 637Q666 633 648 597Q646 592 598 404Q557 235 548 205Q515 105 433 42T263 -22Q171 -22 116 34T60 167V183Q60 201 115 421Q164 622 164 628Q164 635 107 637Z"></path></g><g data-mml-node="mo" transform="translate(2998, 0)"><path data-c="28" d="M94 250Q94 319 104 381T127 488T164 576T202 643T244 695T277 729T302 750H315H319Q333 750 333 741Q333 738 316 720T275 667T226 581T184 443T167 250T184 58T225 -81T274 -167T316 -220T333 -241Q333 -250 318 -250H315H302L274 -226Q180 -141 137 -14T94 250Z"></path></g><g data-mml-node="mi" transform="translate(3387, 0)"><path data-c="71" d="M33 157Q33 258 109 349T280 441Q340 441 372 389Q373 390 377 395T388 406T404 418Q438 442 450 442Q454 442 457 439T460 434Q460 425 391 149Q320 -135 320 -139Q320 -147 365 -148H390Q396 -156 396 -157T393 -175Q389 -188 383 -194H370Q339 -192 262 -192Q234 -192 211 -192T174 -192T157 -193Q143 -193 143 -185Q143 -182 145 -170Q149 -154 152 -151T172 -148Q220 -148 230 -141Q238 -136 258 -53T279 32Q279 33 272 29Q224 -10 172 -10Q117 -10 75 30T33 157ZM352 326Q329 405 277 405Q242 405 210 374T160 293Q131 214 119 129Q119 126 119 118T118 106Q118 61 136 44T179 26Q233 26 290 98L298 109L352 326Z"></path></g><g data-mml-node="mo" transform="translate(3847, 0)"><path data-c="29" d="M60 749L64 750Q69 750 74 750H86L114 726Q208 641 251 514T294 250Q294 182 284 119T261 12T224 -76T186 -143T145 -194T113 -227T90 -246Q87 -249 86 -250H74Q66 -250 63 -250T58 -247T55 -238Q56 -237 66 -225Q221 -64 221 250T66 725Q56 737 55 738Q55 746 60 749Z"></path></g><g data-mml-node="mstyle" transform="translate(4236, 0)"><g data-mml-node="mspace"></g></g><g data-mml-node="mo" transform="translate(5513.8, 0)"><path data-c="3D" d="M56 347Q56 360 70 367H707Q722 359 722 347Q722 336 708 328L390 327H72Q56 332 56 347ZM56 153Q56 168 72 173H708Q722 163 722 153Q722 140 707 133H70Q56 140 56 153Z"></path></g><g data-mml-node="mi" transform="translate(6569.6, 0)"><path data-c="71" d="M33 157Q33 258 109 349T280 441Q340 441 372 389Q373 390 377 395T388 406T404 418Q438 442 450 442Q454 442 457 439T460 434Q460 425 391 149Q320 -135 320 -139Q320 -147 365 -148H390Q396 -156 396 -157T393 -175Q389 -188 383 -194H370Q339 -192 262 -192Q234 -192 211 -192T174 -192T157 -193Q143 -193 143 -185Q143 -182 145 -170Q149 -154 152 -151T172 -148Q220 -148 230 -141Q238 -136 258 -53T279 32Q279 33 272 29Q224 -10 172 -10Q117 -10 75 30T33 157ZM352 326Q329 405 277 405Q242 405 210 374T160 293Q131 214 119 129Q119 126 119 118T118 106Q118 61 136 44T179 26Q233 26 290 98L298 109L352 326Z"></path></g><g data-mml-node="mi" transform="translate(7029.6, 0)"><path data-c="3A6" d="M312 622Q310 623 307 625T303 629T297 631T286 634T270 635T246 636T211 637H184V683H196Q220 680 361 680T526 683H538V637H511Q468 637 447 635T422 631T411 622V533L425 531Q525 519 595 466T665 342Q665 301 642 267T583 209T506 172T425 152L411 150V61Q417 55 421 53T447 48T511 46H538V0H526Q502 3 361 3T196 0H184V46H211Q231 46 245 46T270 47T286 48T297 51T303 54T307 57T312 61V150H310Q309 151 289 153T232 166T160 195Q149 201 136 210T103 238T69 284T56 342Q56 414 128 467T294 530Q309 532 310 533H312V622ZM170 342Q170 207 307 188H312V495H309Q301 495 282 491T231 469T186 423Q170 389 170 342ZM415 188Q487 199 519 236T551 342Q551 384 539 414T507 459T470 481T434 491T415 495H410V188H415Z"></path></g><g data-mml-node="mi" transform="translate(7751.6, 0)"><path data-c="71" d="M33 157Q33 258 109 349T280 441Q340 441 372 389Q373 390 377 395T388 406T404 418Q438 442 450 442Q454 442 457 439T460 434Q460 425 391 149Q320 -135 320 -139Q320 -147 365 -148H390Q396 -156 396 -157T393 -175Q389 -188 383 -194H370Q339 -192 262 -192Q234 -192 211 -192T174 -192T157 -193Q143 -193 143 -185Q143 -182 145 -170Q149 -154 152 -151T172 -148Q220 -148 230 -141Q238 -136 258 -53T279 32Q279 33 272 29Q224 -10 172 -10Q117 -10 75 30T33 157ZM352 326Q329 405 277 405Q242 405 210 374T160 293Q131 214 119 129Q119 126 119 118T118 106Q118 61 136 44T179 26Q233 26 290 98L298 109L352 326Z"></path></g></g></g></svg>

在这里,<math xmlns="http://www.w3.org/1998/Math/MathML">
<mi mathvariant="normal">Φ</mi>
<mi>q</mi>
</math> 代表高斯累积分布函数,可以用 <math xmlns="http://www.w3.org/1998/Math/MathML">
<mi>G</mi>
<mi>E</mi>
<mi>L</mi>
<mi>U</mi>
<mo stretchy="false">(</mo>
<mi>q</mi>
<mo stretchy="false">)</mo>
<mo>≈</mo>
<mfrac>

<mi>q</mi>
<mrow>
  <mn>1</mn>
  <mo>+</mo>
  <msup>
    <mi>e</mi>
    <mrow>
      <mo>−</mo>
      <mn>1.702</mn>
      <mi>q</mi>
    </mrow>
  </msup>
</mrow>

</mfrac>
</math> 来表示。这个激活函数在上述的加权求和之后被应用。总而言之,一个单一的神经元看起来像这样:

为了学习更复杂的函数,我们可以将神经元堆叠起来——一个接一个地形成一个层。同一层中的所有神经元接收相同的输入;它们之间唯一的区别是权重 <math xmlns="http://www.w3.org/1998/Math/MathML">
<mrow>

<mi mathvariant="bold">w</mi>

</mrow>
</math> 和偏置 <math xmlns="http://www.w3.org/1998/Math/MathML">
<mtext>b</mtext>
</math>。我们可以用矩阵符号将上述方程表示一个单层:

<svg xmlns="http://www.w3.org/2000/svg" width="21.208ex" height="2.262ex" role="img" focusable="false" viewBox="0 -750 9374 1000" style="vertical-align: -0.566ex;"><g stroke="currentColor" fill="currentColor" stroke-width="0" transform="matrix(1 0 0 -1 0 0)"><g data-mml-node="math"><g data-mml-node="TeXAtom" data-mjx-texclass="ORD"><g data-mml-node="mi"><path data-c="79" d="M84 -102Q84 -110 87 -119T102 -138T133 -149Q148 -148 162 -143T186 -131T206 -114T222 -95T234 -76T243 -59T249 -45T252 -37L269 0L96 382H26V444H34Q49 441 146 441Q252 441 270 444H279V382H255Q232 382 232 380L337 151L442 382H394V444H401Q413 441 495 441Q568 441 574 444H580V382H510L406 152Q298 -84 297 -87Q269 -139 225 -169T131 -200Q85 -200 54 -172T23 -100Q23 -64 44 -50T87 -35Q111 -35 130 -50T152 -92V-100H84V-102Z"></path></g></g><g data-mml-node="mo" transform="translate(884.8, 0)"><path data-c="3D" d="M56 347Q56 360 70 367H707Q722 359 722 347Q722 336 708 328L390 327H72Q56 332 56 347ZM56 153Q56 168 72 173H708Q722 163 722 153Q722 140 707 133H70Q56 140 56 153Z"></path></g><g data-mml-node="mi" transform="translate(1940.6, 0)"><path data-c="47" d="M50 252Q50 367 117 473T286 641T490 704Q580 704 633 653Q642 643 648 636T656 626L657 623Q660 623 684 649Q691 655 699 663T715 679T725 690L740 705H746Q760 705 760 698Q760 694 728 561Q692 422 692 421Q690 416 687 415T669 413H653Q647 419 647 422Q647 423 648 429T650 449T651 481Q651 552 619 605T510 659Q492 659 471 656T418 643T357 615T294 567T236 496T189 394T158 260Q156 242 156 221Q156 173 170 136T206 79T256 45T308 28T353 24Q407 24 452 47T514 106Q517 114 529 161T541 214Q541 222 528 224T468 227H431Q425 233 425 235T427 254Q431 267 437 273H454Q494 271 594 271Q634 271 659 271T695 272T707 272Q721 272 721 263Q721 261 719 249Q714 230 709 228Q706 227 694 227Q674 227 653 224Q646 221 643 215T629 164Q620 131 614 108Q589 6 586 3Q584 1 581 1Q571 1 553 21T530 52Q530 53 528 52T522 47Q448 -22 322 -22Q201 -22 126 55T50 252Z"></path></g><g data-mml-node="mi" transform="translate(2726.6, 0)"><path data-c="45" d="M492 213Q472 213 472 226Q472 230 477 250T482 285Q482 316 461 323T364 330H312Q311 328 277 192T243 52Q243 48 254 48T334 46Q428 46 458 48T518 61Q567 77 599 117T670 248Q680 270 683 272Q690 274 698 274Q718 274 718 261Q613 7 608 2Q605 0 322 0H133Q31 0 31 11Q31 13 34 25Q38 41 42 43T65 46Q92 46 125 49Q139 52 144 61Q146 66 215 342T285 622Q285 629 281 629Q273 632 228 634H197Q191 640 191 642T193 659Q197 676 203 680H757Q764 676 764 669Q764 664 751 557T737 447Q735 440 717 440H705Q698 445 698 453L701 476Q704 500 704 528Q704 558 697 578T678 609T643 625T596 632T532 634H485Q397 633 392 631Q388 629 386 622Q385 619 355 499T324 377Q347 376 372 376H398Q464 376 489 391T534 472Q538 488 540 490T557 493Q562 493 565 493T570 492T572 491T574 487T577 483L544 351Q511 218 508 216Q505 213 492 213Z"></path></g><g data-mml-node="mi" transform="translate(3490.6, 0)"><path data-c="4C" d="M228 637Q194 637 192 641Q191 643 191 649Q191 673 202 682Q204 683 217 683Q271 680 344 680Q485 680 506 683H518Q524 677 524 674T522 656Q517 641 513 637H475Q406 636 394 628Q387 624 380 600T313 336Q297 271 279 198T252 88L243 52Q243 48 252 48T311 46H328Q360 46 379 47T428 54T478 72T522 106T564 161Q580 191 594 228T611 270Q616 273 628 273H641Q647 264 647 262T627 203T583 83T557 9Q555 4 553 3T537 0T494 -1Q483 -1 418 -1T294 0H116Q32 0 32 10Q32 17 34 24Q39 43 44 45Q48 46 59 46H65Q92 46 125 49Q139 52 144 61Q147 65 216 339T285 628Q285 635 228 637Z"></path></g><g data-mml-node="mi" transform="translate(4171.6, 0)"><path data-c="55" d="M107 637Q73 637 71 641Q70 643 70 649Q70 673 81 682Q83 683 98 683Q139 681 234 681Q268 681 297 681T342 682T362 682Q378 682 378 672Q378 670 376 658Q371 641 366 638H364Q362 638 359 638T352 638T343 637T334 637Q295 636 284 634T266 623Q265 621 238 518T184 302T154 169Q152 155 152 140Q152 86 183 55T269 24Q336 24 403 69T501 205L552 406Q599 598 599 606Q599 633 535 637Q511 637 511 648Q511 650 513 660Q517 676 519 679T529 683Q532 683 561 682T645 680Q696 680 723 681T752 682Q767 682 767 672Q767 650 759 642Q756 637 737 637Q666 633 648 597Q646 592 598 404Q557 235 548 205Q515 105 433 42T263 -22Q171 -22 116 34T60 167V183Q60 201 115 421Q164 622 164 628Q164 635 107 637Z"></path></g><g data-mml-node="mo" transform="translate(4938.6, 0)"><path data-c="28" d="M94 250Q94 319 104 381T127 488T164 576T202 643T244 695T277 729T302 750H315H319Q333 750 333 741Q333 738 316 720T275 667T226 581T184 443T167 250T184 58T225 -81T274 -167T316 -220T333 -241Q333 -250 318 -250H315H302L274 -226Q180 -141 137 -14T94 250Z"></path></g><g data-mml-node="TeXAtom" data-mjx-texclass="ORD" transform="translate(5327.6, 0)"><g data-mml-node="mi"><path data-c="57" d="M915 686L1052 683Q1142 683 1157 686H1164V624H1073L957 320Q930 249 900 170T855 52T839 10Q834 0 826 -5Q821 -7 799 -7H792Q777 -7 772 -5T759 10Q759 11 748 39T716 122T676 228L594 442L512 228Q486 159 455 78Q433 19 428 9T416 -5Q411 -7 389 -7H379Q356 -7 349 10Q349 12 334 51T288 170T231 320L116 624H24V686H35Q44 683 183 683Q331 683 355 686H368V624H323Q278 624 278 623L437 207L499 369L561 531L526 624H434V686H445Q454 683 593 683Q741 683 765 686H778V624H733Q688 624 688 623L847 207Q848 207 927 415T1006 624H905V686H915Z"></path></g></g><g data-mml-node="TeXAtom" data-mjx-texclass="ORD" transform="translate(6516.6, 0)"><g data-mml-node="mi"><path data-c="78" d="M227 0Q212 3 121 3Q40 3 28 0H21V62H117L245 213L109 382H26V444H34Q49 441 143 441Q247 441 265 444H274V382H246L281 339Q315 297 316 297Q320 297 354 341L389 382H352V444H360Q375 441 466 441Q547 441 559 444H566V382H471L355 246L504 63L545 62H586V0H578Q563 3 469 3Q365 3 347 0H338V62H366Q366 63 326 112T285 163L198 63L217 62H235V0H227Z"></path></g></g><g data-mml-node="mo" transform="translate(7345.8, 0)"><path data-c="2B" d="M56 237T56 250T70 270H369V420L370 570Q380 583 389 583Q402 583 409 568V270H707Q722 262 722 250T707 230H409V-68Q401 -82 391 -82H389H387Q375 -82 369 -68V230H70Q56 237 56 250Z"></path></g><g data-mml-node="TeXAtom" data-mjx-texclass="ORD" transform="translate(8346, 0)"><g data-mml-node="mi"><path data-c="62" d="M32 686L123 690Q214 694 215 694H221V409Q289 450 378 450Q479 450 539 387T600 221Q600 122 535 58T358 -6H355Q272 -6 203 53L160 1L129 0H98V301Q98 362 98 435T99 525Q99 591 97 604T83 620Q69 624 42 624H29V686H32ZM227 105L232 99Q237 93 242 87T258 73T280 59T306 49T339 45Q380 45 411 66T451 131Q457 160 457 230Q457 264 456 284T448 329T430 367T396 389T343 398Q282 398 235 355L227 348V105Z"></path></g></g><g data-mml-node="mo" transform="translate(8985, 0)"><path data-c="29" d="M60 749L64 750Q69 750 74 750H86L114 726Q208 641 251 514T294 250Q294 182 284 119T261 12T224 -76T186 -143T145 -194T113 -227T90 -246Q87 -249 86 -250H74Q66 -250 63 -250T58 -247T55 -238Q56 -237 66 -225Q221 -64 221 250T66 725Q56 737 55 738Q55 746 60 749Z"></path></g></g></g></svg>

在这里, <math xmlns="http://www.w3.org/1998/Math/MathML">
<mrow>

<mi mathvariant="bold">**w**</mi>

</mrow>
</math> 是一个二维矩阵,包含应用于输入 <math xmlns="http://www.w3.org/1998/Math/MathML">
<mtext>x</mtext>
</math> 的所有权重;矩阵的每一行对应一个神经元的权重。这种类型的层通常被称为密集层或全连接层,因为所有输入 <math xmlns="http://www.w3.org/1998/Math/MathML">
<mtext>x</mtext>
</math> 都连接到所有输出 <math xmlns="http://www.w3.org/1998/Math/MathML">
<mtext>y</mtext>
</math>。

我们可以将这两个层串联起来,创建一个基本的前馈网络:

<math xmlns="http://www.w3.org/1998/Math/MathML">
<mtable displaystyle="true" columnalign="right left right left right left right left right left right left" columnspacing="0em 2em 0em 2em 0em 2em 0em 2em 0em 2em 0em" rowspacing="3pt">

<mtr>
  <mtd>
    <mrow>
      <mrow>
        <maligngroup/>
        <msub>
          <mrow>
            <mi mathvariant="bold">h</mi>
          </mrow>
          <mn>1</mn>
        </msub>
      </mrow>
      <mrow>
        <maligngroup/>
        <mi></mi>
        <mo>=</mo>
        <mi>G</mi>
        <mi>E</mi>
        <mi>L</mi>
        <mi>U</mi>
        <mo stretchy="false">(</mo>
        <msub>
          <mrow>
            <mi mathvariant="bold">W</mi>
          </mrow>
          <mn>1</mn>
        </msub>
        <mrow>
          <mi mathvariant="bold">x</mi>
        </mrow>
        <mo>+</mo>
        <msub>
          <mrow>
            <mi mathvariant="bold">b</mi>
          </mrow>
          <mn>1</mn>
        </msub>
        <mo stretchy="false">)</mo>
      </mrow>
    </mrow>
  </mtd>
</mtr>
<mtr>
  <mtd>
    <mrow>
      <mrow>
        <maligngroup/>
        <mrow>
          <mi mathvariant="bold">y</mi>
        </mrow>
      </mrow>
      <mrow>
        <maligngroup/>
        <mi></mi>
        <mo>=</mo>
        <mi>G</mi>
        <mi>E</mi>
        <mi>L</mi>
        <mi>U</mi>
        <mo stretchy="false">(</mo>
        <msub>
          <mrow>
            <mi mathvariant="bold">W</mi>
          </mrow>
          <mn>2</mn>
        </msub>
        <msub>
          <mrow>
            <mi mathvariant="bold">h</mi>
          </mrow>
          <mn>1</mn>
        </msub>
        <mo>+</mo>
        <msub>
          <mrow>
            <mi mathvariant="bold">b</mi>
          </mrow>
          <mn>2</mn>
        </msub>
        <mo stretchy="false">)</mo>
      </mrow>
    </mrow>
  </mtd>
</mtr>

</mtable>
</math>

这里我们引入了一个新的隐藏层 h1,它既没有直接连接到输入 <math xmlns="http://www.w3.org/1998/Math/MathML">
<mtext>x</mtext>
</math>,也没有直接连接到输出 <math xmlns="http://www.w3.org/1998/Math/MathML">
<mtext>y</mtext>
</math> 。这一层有效地增加了网络的深度,增加了总的参数数量(多个权重矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML">
<mrow>

<mi mathvariant="bold">**w**</mi>

</mrow>
</math> )。此时,需要注意:随着添加的隐藏层增多,靠近输入层的隐藏值(激活值)与 <math xmlns="http://www.w3.org/1998/Math/MathML">
<mtext>x</mtext>
</math> 更“相似”,而靠近输出的激活值则与 <math xmlns="http://www.w3.org/1998/Math/MathML">
<mtext>y</mtext>
</math> 更相似。

我们在后续的文章中将基于这个原则探讨 Embedding 向量。隐藏层的概念对理解向量搜索至关重要。

前馈网络中单个神经元的参数可以通过一个称为反向传播的过程进行更新,本质上就是微积分中链式法则的重复应用。大家可以搜索一些专门讲解反向传播的课程,这些课程会介绍反向传播为什么对训练神经网络如此有效。这里我们不多做赘述,其基本过程如下所示:

  1. 通过神经网络输入一批数据。
  2. 计算损失。这通常是回归的 L2 损失(平方差)和分类的交叉熵损失。
  3. 使用这个损失来计算与最后一个隐藏层权重的损失梯度 <math xmlns="http://www.w3.org/1998/Math/MathML">
    <mfrac>
    <mrow>
    <mi mathvariant="normal">∂</mi>
    <mi mathvariant="normal">Λ</mi>
    </mrow>
    <mrow>
    <mi mathvariant="normal">∂</mi>
    <mrow>

     <msub>
       <mi mathvariant="bold">W</mi>
       <mi mathvariant="bold">n</mi>
     </msub>

    </mrow>
    </mrow>
    </mfrac>
    </math>。

  4. 计算通过最后一个隐藏层的损失,即 <math xmlns="http://www.w3.org/1998/Math/MathML">
    <mfrac>
    <mrow>
    <mi mathvariant="normal">∂</mi>
    <mi mathvariant="normal">Λ</mi>
    </mrow>
    <mrow>
    <mi mathvariant="normal">∂</mi>
    <mrow>

     <msub>
       <mi mathvariant="bold">h</mi>
       <mrow>
         <mi mathvariant="bold">n</mi>
         <mo mathvariant="bold">−</mo>
         <mn mathvariant="bold">1</mn>
       </mrow>
     </msub>

    </mrow>
    </mrow>
    </mfrac>
    </math>。

  5. 将这个损失反向传播到倒数第二个隐藏层的权重 <math xmlns="http://www.w3.org/1998/Math/MathML">
    <mfrac>
    <mrow>
    <mi mathvariant="normal">∂</mi>
    <mi mathvariant="normal">Λ</mi>
    </mrow>
    <mrow>
    <mi mathvariant="normal">∂</mi>
    <mrow>

     <msub>
       <mi mathvariant="bold">W</mi>
       <mrow>
         <mi mathvariant="bold">n</mi>
         <mo mathvariant="bold">−</mo>
         <mn mathvariant="bold">1</mn>
       </mrow>
     </msub>

    </mrow>
    </mrow>
    </mfrac>
    </math>。

  6. 重复步骤 4 和 5,直到计算出所有权重的偏导数。

在计算出与网络中所有权重相关的损失的偏导数后,可以根据优化器和学习率进行一次大规模的权重更新。这个过程会重复进行,直到模型达到收敛或所有轮次都完成。

02.循环神经网络

所有形式的文本和自然语言本质上都是顺序性的,也就是说单词 /Token 是一个接一个地处理的。看似简单的变化,比如增加一个单词、颠倒两个连续的 Token,或增加标点符号,都可能导致解释上的巨大差异。例如,“let's eat, Charles”和“let's eat Charles”两个短语完全是两回事。由于自然语言具备顺序性这一特性,因此循环神经网络(RNNs)是自然而然成为了语言建模的不二之选。

递归是一种独特的递归形式,其中函数是神经网络而不是代码。RNN 还有着生物学起源——人类大脑可以类比为一个(人工)神经网络,我们输入的单词或说出的话语都是生物学处理的结果。

RNN 由两个组成部分:1)一个标准的前馈网络和2)一个递归组件。前馈网络与我们在前一节中讨论的相同。对于递归组件,最后一个隐藏状态被反馈到输入中,以便网络可以保持先前的上下文。因此,先前的知识(以前一个时间步的隐藏层的形式)在每一个新的时间步被注入网络。

基于上述对 RNN 的宏观定义和解释,我们可以大致了解其实现方式以及为什么 RNN 在语义建模时表现良好。

首先,RNN 的循环结构使它们能够根据顺序捕捉和处理数据,其数据处理方式类似于人类说话、阅读和写作方式。此外,RNN 还可以有效访问来自较早时间的“信息”,比 n-gram 模型和纯前馈网络更能理解自然语言。

大家可以试试用 PyTorch 来实现一个 RNN。注意,这需要对 PyTorch 基础有深入的理解;如果对 PyTorch 还不太熟悉 ,建议大家先阅读该链接

首先定义一个简单的前馈网络,然后将其扩展为一个简单的 RNN,先定义层:

from torch import Tensor
import torch.nn as nn
class BasicNN(nn.Module):    
   def __init__(self, in_dims: int, hidden_dims: int, out_dims: int):        
       super(BasicNN, self).__init__()        
       self.w0 = nn.Linear(in_dims, hidden_dims)
       self.w1 = nn.Linear(hidden_dims, out_dims)

注意,由于我们仅仅输出原始的逻辑值,我们还没有定义损失的样式。在训练时,可以根据实际情况加上某种标准,比如 nn.CrossEntropyLoss

现在,我们可以实现前向传递:

    def forward(self, x: Tensor):
        h = self.w0(x)
        y = self.w1(h)
        return y

这两段代码片段结合在一起形成了一个非常基础的前馈神经网络。为了将其变成 RNN,我们需要从最后一个隐藏状态添加一个反馈回路回到输入:

    def forward(self, x: Tensor, h_p: Tensor):
        h = self.w0(torch.cat(x, h_p))        
        y = self.w1(h)        
        return (y, h)

上述基本上就是全部步骤。由于我们现在增加了由 w0 定义的神经元层的输入数量,我们需要在 __init__中更新它的定义。现在让我们来完成这个操作,并将所有内容整合到一个代码片段中:

import torch.nn as nn
from torch import Tensor

class SimpleRNN(nn.Module):
    def __init__(self, in_dims: int, hidden_dims: int, out_dims: int):
        super(RNN, self).__init__()
        self.w0 = nn.Linear(in_dims + hidden_dims, hidden_dims)
        self.w1 = nn.Linear(hidden_dims, out_dims)
        
    def forward(self, x: Tensor, h_p: Tensor):
        h = self.w0(torch.cat(x, h_p))
        y = self.w1(h)
        return (y, h)

在每次前向传递中,隐藏层h的激活值与输出一起返回。这些激活值随后可以与序列中的每个新 Token一起再次传回模型中。这样一个过程如下所示(以下代码仅作示意):

model = SimpleRNN(n_in, n_hidden, n_out)

...

h = torch.zeros(1, n_hidden)
for token in range(seq):
    (out, h) = model(token, )

至此,我们成功定义了一个简单的前馈网络,并将其扩展为一个简单的 RNN。

03.语言模型 Embedding

我们在上面例子中看到的隐藏层有效地将已经输入到 RNN 的所有内容(所有 Token)进行编码。更具体而言,所有解析 RNN 已看到的文本所需的信息应包含在激活值 h 中。换句话说,h 编码了输入序列的语义,而由 h 定义的有序浮点值集合就是 Embedding 向量,简称为 Embedding。

这些向量表示广泛构成了向量搜索和向量数据库的基础。尽管当今自然语言的 Embedding 是由另一类称为 Transformer 的机器学习模型生成的,而不是 RNN,但本质概念基本相同:将文本内容编码为计算机可理解的 Embedding 向量。我们将在下一篇博客文章中详细讨论如何使用 Embedding 向量。

04.总结

我们在 PyTorch 中实现了一个简单的循环神经网络,并简要介绍了语言模型Embedding。虽然循环神经网络是理解语言的强大工具,并且可以广泛应用于各种应用中(机器翻译、分类、问答等),但它们仍然不是用于生成 Embedding 向量的 ML 模型类型。

在接下来的教程中,我们将使用开源的 Transformer 模型来生成 Embedding 向量,并通过对它们进行向量搜索和运算来展示向量的强大功能。此外,我们也将回到词袋模型的概念,看看这两者如何一起用于编码词汇和语义。敬请期待!

本文由mdnice多平台发布

推荐阅读
关注数
4189
内容数
867
SegmentFault 思否旗下人工智能领域产业媒体,专注技术与产业,一起探索人工智能。
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息