1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
class StreamingLLM(Llama):
    """``Llama`` subclass with sliding-window KV-cache trimming and named "venv" segments.

    Tokens evaluated through :meth:`eval_t` are appended to the context and
    counted against the most recently opened segment (``self.venv[-1]``).  A
    segment can later be rolled back (:meth:`venv_revision`), cut out of the
    middle of the history (:meth:`venv_remove`) or merged into the base
    segment (:meth:`venv_disband`).

    Bookkeeping (created by :meth:`_venv_init`):
        venv: token counts per segment. ``venv[0]`` is the unnamed base
            segment; ``venv[i]`` for ``i >= 1`` belongs to ``venv_idx_map[i - 1]``.
        venv_idx_map: segment names, oldest first; names may repeat.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # eval_t and the venv_* methods read self.venv / self.venv_idx_map;
        # without this they fail with AttributeError unless the caller
        # remembered to invoke the private _venv_init() by hand.
        self._venv_init()

    def kv_cache_seq_trim(self):
        """Drop every KV-cache entry at or beyond position ``self.n_tokens``.

        Uses -1 for the sequence id and the end position, which the llama.cpp
        KV-cache API treats as "any sequence" and "to the end".
        """
        self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)

    def kv_cache_seq_ltrim(self, n_keep, n_discard=256, n_past=-1):
        """Discard ``n_discard`` tokens right after the first ``n_keep`` and close the gap.

        The steps are:

        1. Remove cache positions ``[n_keep, n_keep + n_discard)``.
        2. Shift ``[n_keep + n_discard, n_past)`` left by ``n_discard``.
        3. Apply the same shift to ``self.input_ids``.
        4. Set ``self.n_tokens`` to the new length.

        Args:
            n_keep: number of leading tokens that are never discarded.
            n_discard: number of tokens to drop after the kept prefix. It is
                clamped to the tokens actually present between ``n_keep`` and
                ``n_past``. This keeps ``self.n_tokens`` from going negative
                and keeps the two ``input_ids`` slices the same length.
            n_past: logical end of the history; a negative value means
                ``self.n_tokens``.
        """
        if n_past < 0:
            n_past = self.n_tokens
        # Asking to discard more than exists would make n_past - n_discard
        # negative (corrupting n_tokens) and mismatch the input_ids slices.
        n_discard = max(0, min(n_discard, n_past - n_keep))
        # NOTE(review): removal targets every sequence (-1) but the shift only
        # sequence 0 — fine if only sequence 0 is ever used; confirm.
        self._ctx.kv_cache_seq_rm(-1, n_keep, n_keep + n_discard)
        self._ctx.kv_cache_seq_shift(0, n_keep + n_discard, n_past, -n_discard)
        self.input_ids[n_keep:n_past - n_discard] = self.input_ids[n_keep + n_discard:n_past]
        self.n_tokens = n_past - n_discard

    def _venv_init(self):
        """Reset segment bookkeeping to a single, empty base segment."""
        self.venv = [0]
        self.venv_idx_map = []

    def venv_create(self, name: str):
        """Open a new segment called ``name``; subsequent eval_t tokens count toward it.

        Returns:
            ``name``, unchanged.
        """
        self.venv.append(0)
        self.venv_idx_map.append(name)
        return name

    def venv_disband(self, name_set):
        """Fold the leading run of segments named in ``name_set`` into the base segment.

        Scanning from the oldest named segment, each segment whose name is in
        ``name_set`` has its count added to ``venv[0]`` and is removed. The scan
        stops at the first segment whose name is not in the set. No tokens are
        removed from the context.

        Returns:
            False if there are no named segments, or if no name in ``name_set``
            is currently in use. True otherwise, even when nothing was merged.
        """
        if len(self.venv) <= 1:
            return False
        name_set = {x for x in name_set if x in self.venv_idx_map}
        if not name_set:
            return False
        while self.venv_idx_map and self.venv_idx_map[0] in name_set:
            self.venv_idx_map.pop(0)
            self.venv[0] += self.venv.pop(1)
        return True

    def venv_revision(self, name: str):
        """Roll back to the most recent segment called ``name``.

        Every segment opened after that segment is closed. Their tokens are
        removed from ``self.n_tokens`` and trimmed from the KV cache. The
        ``name`` segment itself, with its tokens, is kept.

        Returns:
            False if there are no named segments or ``name`` is unknown, else True.
        """
        if len(self.venv) <= 1 or name not in self.venv_idx_map:
            return False
        dropped = 0
        # Terminates before emptying the map because ``name`` is known to be in it.
        while self.venv_idx_map and self.venv_idx_map[-1] != name:
            self.venv_idx_map.pop()
            dropped += self.venv.pop()
        if dropped:
            # min() keeps n_tokens non-negative if segment counts overshoot it.
            self.n_tokens -= min(dropped, self.n_tokens)
            self.kv_cache_seq_trim()
        return True

    def venv_remove(self, name: str):
        """Cut every segment called ``name`` out of the history.

        Segments are processed oldest first:

        * The newest segment's tokens sit at the tail. They are dropped by
          shrinking ``self.n_tokens`` and trimming the cache.
        * Any other segment's tokens are removed from the middle via
          :meth:`kv_cache_seq_ltrim`, which shifts the later tokens down.

        Returns:
            False if there are no named segments or ``name`` is unknown, else True.
        """
        if len(self.venv) <= 1 or name not in self.venv_idx_map:
            return False
        # Index into self.venv (offset by one for the unnamed base segment).
        venv_idx = self.venv_idx_map.index(name) + 1
        while self.venv_idx_map:
            self.venv_idx_map.pop(venv_idx - 1)
            if venv_idx == len(self.venv) - 1:
                self.n_tokens -= min(self.venv.pop(), self.n_tokens)
                self.kv_cache_seq_trim()
                break
            # Start position of this segment: total tokens minus the tokens in
            # this segment and every later one.
            n_keep = self.n_tokens - sum(self.venv[venv_idx:])
            n_discard = self.venv.pop(venv_idx)
            self.kv_cache_seq_ltrim(n_keep, n_discard)
            try:
                venv_idx = self.venv_idx_map.index(name, venv_idx - 1) + 1
            except ValueError:
                break
        return True

    def eval_t(self, tokens, n_keep=4, n_discard=256, im_start=None):
        """Append ``tokens`` to the context, sliding the window first if it would overflow.

        If ``self.n_tokens + len(tokens)`` exceeds ``self._n_ctx``, tokens
        after the first ``n_keep`` are discarded first. The amount is
        ``max(n_discard, overflow)``, capped at what is present (see
        :meth:`kv_cache_seq_ltrim`).

        Tokens are then handled in chunks of ``self.n_batch``. After each
        chunk, ``self.n_tokens`` and the active segment count
        ``self.venv[-1]`` grow by the chunk length.

        Args:
            tokens: token ids to evaluate.
            n_keep: leading tokens protected from the sliding window.
            n_discard: minimum number of tokens to discard on overflow.
            im_start: accepted for compatibility; not used here.
        """
        n_incoming = len(tokens)
        if self._n_ctx < self.n_tokens + n_incoming:
            tmp_n_discard = max(n_discard, self.n_tokens + n_incoming - self._n_ctx)
            self.kv_cache_seq_ltrim(n_keep, tmp_n_discard)
            # NOTE(review): tokens discarded here are not subtracted from any
            # self.venv entry, so segment counts may exceed self.n_tokens
            # afterwards — confirm whether that is intended.
        for i in range(0, n_incoming, self.n_batch):
            batch = tokens[i:i + self.n_batch]
            # NOTE(review): the per-batch decode step (feeding ``batch`` to the
            # model and recording tokens/logits) appears to be elided in this
            # source — TODO restore it here, before the counters advance.
            n_tokens = len(batch)
            self.n_tokens += n_tokens
            self.venv[-1] += n_tokens
|